How to convert a trained Keras model into a TensorFlow .pb file and serve it with TensorFlow Serving

Overview

A model trained in Keras and saved with the built-in model.save() ends up as a .model (.h5) file.

The model is loaded back with my_model = keras.models.load_model(filepath)
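The save/load round trip can be sketched like this (the tiny model below is hypothetical, purely for illustration; the real model in this article is a bidirectional LSTM):

```python
# A minimal sketch of the Keras save/load round trip.
# The layer sizes here are made up for illustration.
import numpy as np
from keras.models import Sequential, load_model
from keras.layers import Dense

model = Sequential([Dense(2, input_shape=(4,), activation='softmax')])
model.compile(optimizer='sgd', loss='categorical_crossentropy')

model.save('my_model.h5')             # any extension works; the format is HDF5
my_model = load_model('my_model.h5')  # restores architecture, weights, optimizer state

x = np.zeros((3, 4), dtype='float32')
```

The reloaded model produces the same predictions as the original, which is a quick way to confirm the file is intact before converting it.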

To convert this model into a .pb TensorFlow model, use the following code:

# -*- coding: utf-8 -*-
import os
import os.path as osp
import tensorflow as tf
from keras.models import load_model
from keras import backend as K


def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """Freeze the current TF session: convert all variables to constants
    and return a GraphDef pruned to the given output nodes."""
    from tensorflow.python.framework.graph_util import convert_variables_to_constants
    graph = session.graph
    with graph.as_default():
        freeze_var_names = list(set(v.op.name for v in tf.global_variables())
                                .difference(keep_var_names or []))
        output_names = output_names or []
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            # Strip device placements so the frozen graph loads anywhere.
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = convert_variables_to_constants(session, input_graph_def,
                                                      output_names, freeze_var_names)
        return frozen_graph
input_fld = '/data/codebase/Keyword-fenci/brand_recogniton_biLSTM/'
weight_file = 'biLSTM_brand_recognize.model'
output_graph_name = 'tensor_model_v3.pb'

output_fld = input_fld + '/tensorflow_model/'
if not os.path.isdir(output_fld):
    os.mkdir(output_fld)
weight_file_path = osp.join(input_fld, weight_file)

# Disable training-phase behavior (dropout etc.) before loading the model.
K.set_learning_phase(0)
net_model = load_model(weight_file_path)

# Note these names: they are needed again when building the SavedModel.
print('input is :', net_model.input.name)
print('output is:', net_model.output.name)

sess = K.get_session()
frozen_graph = freeze_session(sess, output_names=[net_model.output.op.name])

from tensorflow.python.framework import graph_io

# as_text=False writes a binary GraphDef, which is what
# graph_def.ParseFromString() expects when reading the file back later.
graph_io.write_graph(frozen_graph, output_fld, output_graph_name, as_text=False)

print('saved the constant graph (ready for inference) at: ',
      osp.join(output_fld, output_graph_name))

The model is now saved as a .pb file.

Now the problem: a .pb file saved this way is a frozen model.

Starting it under TensorFlow Serving fails with:

 E tensorflow_serving/core/aspired_versions_manager.cc:358] Servable {name: mnist version: 1} cannot be loaded: Not found: Could not find meta graph def matching supplied tags: { serve }. To inspect available tag-sets in the SavedModel, please use the SavedModel CLI: `saved_model_cli`

TensorFlow Serving expects to read a SavedModel, so the frozen model must be converted into the SavedModel format. The solution:

import tensorflow as tf
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants

export_dir = '/data/codebase/Keyword-fenci/brand_recogniton_biLSTM/saved_model'
graph_pb = '/data/codebase/Keyword-fenci/brand_recogniton_biLSTM/tensorflow_model/tensor_model_v3.pb'

builder = tf.saved_model.builder.SavedModelBuilder(export_dir)

with tf.gfile.GFile(graph_pb, "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

sigs = {}

with tf.Session(graph=tf.Graph()) as sess:
    # name="" is important to ensure we don't get spurious prefixing
    tf.import_graph_def(graph_def, name="")
    g = tf.get_default_graph()
    # net_model.input.name / net_model.output.name are the tensor names
    # printed by the conversion script above.
    inp = g.get_tensor_by_name(net_model.input.name)
    out = g.get_tensor_by_name(net_model.output.name)

    sigs[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = \
        tf.saved_model.signature_def_utils.predict_signature_def(
            {"in": inp}, {"out": out})

    builder.add_meta_graph_and_variables(sess,
                                         [tag_constants.SERVING],
                                         signature_def_map=sigs)

builder.save()
                   

The resulting saved_model directory contains two entries:

saved_model.pb   variables

where the variables directory may be empty.

Point TensorFlow Serving at the SavedModel directory and the model now loads successfully.

 
