首先,Keras 训练好的模型通过自带的 `model.save()` 保存下来,得到的是 HDF5 格式的文件(扩展名通常为 .h5,本文中为 .model)。
模型通过 `my_model = keras.models.load_model(filepath)` 载入。
要将该模型转换为 .pb 格式的 TensorFlow 模型,代码如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# -*- coding: utf-8 -*-
# Load a Keras model saved with model.save() and write it out as a frozen
# TensorFlow GraphDef (.pb) whose variables have been turned into constants.
#
# NOTE: the keras layer / preprocessing, sklearn, jieba, numpy and collections
# imports below are not used by the conversion in this script; they are kept
# so that any other code relying on these names keeps working.
from keras.layers.core import Activation, Dense, Flatten
from keras.layers.embeddings import Embedding
from keras.layers.recurrent import LSTM
from keras.layers import Dropout
from keras.layers.wrappers import Bidirectional
from keras.models import Sequential, load_model
from keras.preprocessing import sequence
from sklearn.model_selection import train_test_split
import collections
from collections import defaultdict
import jieba
import numpy as np
import sys

if sys.version_info[0] < 3:
    # Python 2 only: make utf-8 the implicit str<->unicode codec.
    # `reload` is not a builtin on Python 3, so the hack must be guarded.
    reload(sys)  # noqa: F821  (Python 2 builtin)
    sys.setdefaultencoding('utf-8')

import tensorflow as tf
import os
import os.path as osp
from keras import backend as K
from tensorflow.python.framework import graph_io


def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """Freeze the graph held by *session*, baking variable values in as constants.

    Args:
        session: TensorFlow session that owns the graph and the current
            variable values.
        keep_var_names: variable op names excluded from the whitelist passed
            to ``convert_variables_to_constants``, i.e. not converted.
        output_names: op names of the graph outputs to keep. The caller's list
            is not modified. The op names of all global variables are appended
            to the copy, so those nodes are retained as well.
        clear_devices: if True, strip device placements from every node so the
            graph can be loaded on any device.

    Returns:
        The frozen ``GraphDef`` produced by ``convert_variables_to_constants``.
    """
    from tensorflow.python.framework.graph_util import convert_variables_to_constants
    graph = session.graph
    with graph.as_default():
        freeze_var_names = list(
            set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
        # Copy instead of `+=` on the argument so the caller's list is not mutated.
        output_names = list(output_names or [])
        output_names += [v.op.name for v in tf.global_variables()]
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = convert_variables_to_constants(
            session, input_graph_def, output_names, freeze_var_names)
        return frozen_graph


input_fld = '/data/codebase/Keyword-fenci/brand_recogniton_biLSTM/'
weight_file = 'biLSTM_brand_recognize.model'
output_graph_name = 'tensor_model_v3.pb'

output_fld = osp.join(input_fld, 'tensorflow_model')
if not os.path.isdir(output_fld):
    os.makedirs(output_fld)
weight_file_path = osp.join(input_fld, weight_file)

# Learning phase 0 (inference) must be set before the model is built/loaded so
# that training-only behaviour (e.g. Dropout) is not baked into the graph.
K.set_learning_phase(0)
net_model = load_model(weight_file_path)

# %-formatting prints the same text on Python 2 and 3 (a two-argument print()
# prints a tuple on Python 2).
print('input is : %s' % net_model.input.name)
print('output is: %s' % net_model.output.name)

sess = K.get_session()

frozen_graph = freeze_session(sess, output_names=[net_model.output.op.name])

# Write a *binary* GraphDef: the file is named .pb and is later read back with
# GraphDef.ParseFromString, which cannot parse the text format (as_text=True).
graph_io.write_graph(frozen_graph, output_fld, output_graph_name, as_text=False)

print('saved the constant graph (ready for inference) at: %s'
      % osp.join(output_fld, output_graph_name))
运行后,模型就被保存成了 .pb 格式的文件。
问题来了:这样保存下来的 .pb 文件是 frozen model(冻结图)。它只是一个把变量都转换成了常量的 GraphDef,并不是 SavedModel。
如果直接用 TensorFlow Serving 加载该模型,会报如下错误:
E tensorflow_serving/core/aspired_versions_manager.cc:358] Servable {name: mnist version: 1} cannot be loaded: Not found: Could not find meta graph def matching supplied tags: { serve }. To inspect available tag-sets in the SavedModel, please use the SavedModel CLI: `saved_model_cli`
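这是因为 TensorFlow Serving 要读取的是 SavedModel 格式。SavedModel 中包含带标签(tag)的 MetaGraphDef 和签名(signature),而 frozen model 中没有带 `serve` 标签的 MetaGraphDef,所以会报找不到匹配 `{ serve }` 的 meta graph 的错误。
因此需要将 frozen model 转换为 SavedModel 格式,解决方案如下: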
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# Wrap the frozen GraphDef written above into a SavedModel that TensorFlow
# Serving can load. Relies on `tf`, `osp`, `net_model`, `output_fld` and
# `output_graph_name` defined by the freezing script above.
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants
from google.protobuf import text_format
from google.protobuf.message import DecodeError


def _load_graph_def(path):
    """Read a GraphDef from *path*, accepting binary or text protobuf format.

    ``graph_io.write_graph(..., as_text=True)`` produces text format even when
    the file is named ``.pb``, so fall back to the text parser if the binary
    parse fails.
    """
    with tf.gfile.GFile(path, "rb") as f:
        data = f.read()
    graph_def = tf.GraphDef()
    try:
        graph_def.ParseFromString(data)
    except DecodeError:
        # Start from a clean message: a failed binary parse may leave it dirty.
        graph_def = tf.GraphDef()
        if isinstance(data, bytes):
            data = data.decode('utf-8')
        text_format.Merge(data, graph_def)
    return graph_def


# NOTE(review): SavedModelBuilder refuses to write into an existing directory.
# TensorFlow Serving also expects <model_base_path>/<numeric version>/ (e.g.
# .../saved_model/1) -- confirm how this directory is deployed.
export_dir = '/data/codebase/Keyword-fenci/brand_recogniton_biLSTM/saved_model'
# Read exactly the file the freezing step wrote, rather than a separately
# hard-coded (and differently named) path.
graph_pb = osp.join(output_fld, output_graph_name)

builder = tf.saved_model.builder.SavedModelBuilder(export_dir)

graph_def = _load_graph_def(graph_pb)

sigs = {}

with tf.Session(graph=tf.Graph()) as sess:
    # name="" is important to ensure we don't get spurious prefixing, so the
    # tensor names recorded from the Keras model resolve unchanged.
    tf.import_graph_def(graph_def, name="")
    g = tf.get_default_graph()
    inp = g.get_tensor_by_name(net_model.input.name)
    out = g.get_tensor_by_name(net_model.output.name)

    # Parenthesized so the assignment may span lines (a bare line break after
    # `=` is a SyntaxError).
    sigs[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = (
        tf.saved_model.signature_def_utils.predict_signature_def(
            {"in": inp}, {"out": out}))

    # The SERVING ("serve") tag is what TensorFlow Serving looks up.
    builder.add_meta_graph_and_variables(sess,
                                         [tag_constants.SERVING],
                                         signature_def_map=sigs)

builder.save()
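保存下来的 SavedModel 文件夹下会有两项:
`saved_model.pb` 文件和 `variables` 目录。
由于所有变量都已在 frozen model 中被转换成常量,`variables` 目录可以为空。
将该 SavedModel 交给 TensorFlow Serving 加载,即可成功读取。注意:TensorFlow Serving 要求模型位于 `<model_base_path>/<数字版本号>/` 形式的子目录下(例如 `saved_model/1/`),需要把导出的文件放到这样的版本目录中。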
最后
以上就是冷静苗条最近收集整理的关于如何将keras训练好的模型转换成tensorflow的.pb的文件并在TensorFlow serving环境调用的全部内容,更多相关如何将keras训练好的模型转换成tensorflow的.pb的文件并在TensorFlow内容请搜索靠谱客的其他文章。
发表评论 取消回复