概述
首先,Keras 训练好的模型通过自带的 model.save() 保存为 HDF5 格式的文件(扩展名通常为 .h5,本文中使用的是 .model)。
模型通过 my_model = keras.models.load_model(filepath) 载入。
要将该模型转换为 .pb 格式的 TensorFlow 模型,代码如下(基于 Python 2 与 TensorFlow 1.x):
1 # -*- coding: utf-8 -*-
2 from keras.layers.core import Activation, Dense, Flatten
3 from keras.layers.embeddings import Embedding
4 from keras.layers.recurrent import LSTM
5 from keras.layers import Dropout
6 from keras.layers.wrappers import Bidirectional
7 from keras.models import Sequential,load_model
8 from keras.preprocessing import sequence
9 from sklearn.model_selection import train_test_split
10 import collections
11 from collections import defaultdict
12 import jieba
13 import numpy as np
14 import sys
15 reload(sys)
16 sys.setdefaultencoding('utf-8')
17 import tensorflow as tf
18 import os
19 import os.path as osp
20 from keras import backend as K
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """Freeze the global variables of ``session``'s graph into constants.

    Args:
        session: TensorFlow session whose graph and current variable values
            are frozen.
        keep_var_names: Optional iterable of variable op names that must NOT
            be converted to constants.
        output_names: Optional list of output node names to keep in the
            frozen graph. The caller's list is left unmodified.
        clear_devices: If True, clear the device placement of every node so
            the frozen graph is not tied to the devices it was built on.

    Returns:
        The frozen ``GraphDef`` produced by ``convert_variables_to_constants``.
    """
    from tensorflow.python.framework.graph_util import convert_variables_to_constants

    graph = session.graph
    with graph.as_default():
        variable_names = [v.op.name for v in tf.global_variables()]
        # Variables to convert: every global variable except the ones to keep.
        freeze_var_names = list(set(variable_names).difference(keep_var_names or []))
        # Build a new list instead of using "+=" so a list passed in by the
        # caller is not mutated as a side effect.
        output_names = list(output_names or [])
        # Listing the variable nodes as outputs keeps them (as constants) in the
        # frozen graph even if no requested output depends on them.
        output_names.extend(variable_names)
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = convert_variables_to_constants(
            session, input_graph_def, output_names, freeze_var_names)
        return frozen_graph
# Location of the trained Keras model and name of the frozen graph to write.
input_fld = '/data/codebase/Keyword-fenci/brand_recogniton_biLSTM/'
weight_file = 'biLSTM_brand_recognize.model'
output_graph_name = 'tensor_model_v3.pb'

# osp.join avoids the doubled "//" that plain string concatenation produced,
# because input_fld already ends with a slash.
output_fld = osp.join(input_fld, 'tensorflow_model')
if not os.path.isdir(output_fld):
    os.makedirs(output_fld)
weight_file_path = osp.join(input_fld, weight_file)

# Set the Keras learning phase to 0 (test mode) before loading the model, so
# the exported graph uses inference-time behaviour (e.g. Dropout disabled).
K.set_learning_phase(0)
net_model = load_model(weight_file_path)

print('input is :', net_model.input.name)
print('output is:', net_model.output.name)

sess = K.get_session()
frozen_graph = freeze_session(sess, output_names=[net_model.output.op.name])

from tensorflow.python.framework import graph_io

# Write a binary protobuf: the SavedModel conversion step reads this file with
# GraphDef.ParseFromString, which cannot parse the text format as_text=True emits.
graph_io.write_graph(frozen_graph, output_fld, output_graph_name, as_text=False)

print('saved the constant graph (ready for inference) at: ', osp.join(output_fld, output_graph_name))
这样模型就保存成了 .pb 格式的文件。
但问题随之而来:这样保存的 .pb 文件是冻结图(frozen model),其中的变量都已被转换为常量。
如果直接用 TensorFlow Serving 加载该模型,会报如下错误:
E tensorflow_serving/core/aspired_versions_manager.cc:358] Servable {name: mnist version: 1} cannot be loaded: Not found: Could not find meta graph def matching supplied tags: { serve }. To inspect available tag-sets in the SavedModel, please use the SavedModel CLI: `saved_model_cli`
这是因为 TensorFlow Serving 需要加载的是 SavedModel 格式,并且其中要包含带 serve 标签的 MetaGraphDef(即错误信息中的 tags: { serve })。
因此需要把 frozen model 转换为 SavedModel 格式,解决方案如下:
64 from tensorflow.python.saved_model import signature_constants
65 from tensorflow.python.saved_model import tag_constants
66
# Output directory for the SavedModel. SavedModelBuilder refuses to write into
# an existing non-empty directory, so remove it before re-running.
# NOTE(review): TensorFlow Serving loads numbered version subdirectories of its
# model base path (e.g. saved_model/1) -- confirm the layout used for deployment.
export_dir = '/data/codebase/Keyword-fenci/brand_recogniton_biLSTM/saved_model'
# Must match output_graph_name used when the frozen graph was written.
graph_pb = '/data/codebase/Keyword-fenci/brand_recogniton_biLSTM/tensorflow_model/tensor_model_v3.pb'

# Parse the binary frozen GraphDef first: creating the builder creates
# export_dir, so a failed read afterwards would block the next run.
with tf.gfile.GFile(graph_pb, "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

builder = tf.saved_model.builder.SavedModelBuilder(export_dir)

sigs = {}

with tf.Session(graph=tf.Graph()) as sess:
    # name="" is important to ensure we don't get spurious prefixing
    tf.import_graph_def(graph_def, name="")
    g = sess.graph
    # Tensor names are taken from the Keras model loaded by the freezing
    # script, so net_model must still be in scope here. Because of name=""
    # the imported tensors keep exactly those names.
    inp = g.get_tensor_by_name(net_model.input.name)
    out = g.get_tensor_by_name(net_model.output.name)

    # Parenthesised so the assignment can continue on the next line; a bare
    # trailing "=" is a SyntaxError.
    sigs[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = (
        tf.saved_model.signature_def_utils.predict_signature_def(
            {"in": inp}, {"out": out}))

    builder.add_meta_graph_and_variables(sess,
                                         [tag_constants.SERVING],
                                         signature_def_map=sigs)

builder.save()
这样保存下来的 SavedModel 目录下包含一个文件和一个子目录:
saved_model.pb variables/
其中 variables 目录是空的,因为冻结图中已没有变量(都已转换为常量)。
将这个 SavedModel 目录交给 TensorFlow Serving 加载,即可成功读取!
最后
以上就是冷静苗条为你收集整理的如何将keras训练好的模型转换成tensorflow的.pb的文件并在TensorFlow serving环境调用的全部内容,希望文章能够帮你解决如何将keras训练好的模型转换成tensorflow的.pb的文件并在TensorFlow serving环境调用所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
发表评论 取消回复