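I am 玩命中心, a blogger on 靠谱客. This article explains how to convert a Keras H5 model into a TensorFlow pb file. I'm sharing it here in the hope that it is a useful reference.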

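Strictly speaking, calling the result a "TensorFlow pb file" is not quite accurate. TensorFlow saves training checkpoints as ckpt files, which store the parameters and the network structure separately. The .pb produced here is a frozen graph: the trained weights are stored as constants inside the graph definition. To deploy a neural network on an embedded edge device, the model usually has to be converted to a tflite file. This article targets the Kendryte K210 chip, so the model must also be converted into a kmodel file that the chip's KPU can run.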
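Freezing is what turns the separate checkpoint pieces into a single .pb. It works in two steps. First, variable nodes are replaced by constants that hold their trained values. Second, any nodes the outputs do not need are dropped.

The toy sketch below models that idea with plain Python dictionaries. Note the assumptions: the dict-based graph format and the `freeze()` helper are hypothetical. They only illustrate the concept that `graph_util.convert_variables_to_constants` applies to a real GraphDef; this is not TensorFlow's implementation.

```python
# Toy model of "freezing" (hypothetical format, not TensorFlow's code):
# weights live in a separate "checkpoint" dict, and freezing copies them
# into the graph as constants while pruning nodes the outputs don't need.

def freeze(graph, checkpoint, output_names):
    """Return a frozen copy of `graph`.

    graph: dict mapping node name -> {"op": str, "inputs": list[str]}
    checkpoint: dict mapping variable name -> stored value
    output_names: nodes whose results must be preserved
    """
    # Walk backwards from the outputs to find every node they depend on.
    needed, stack = set(), list(output_names)
    while stack:
        name = stack.pop()
        if name in needed:
            continue
        needed.add(name)
        stack.extend(graph[name]["inputs"])

    frozen = {}
    for name in needed:
        node = graph[name]
        if node["op"] == "Variable":
            # Bake the stored value into the graph as a constant node.
            frozen[name] = {"op": "Const", "inputs": [], "value": checkpoint[name]}
        else:
            frozen[name] = dict(node)
    return frozen


graph = {
    "x": {"op": "Placeholder", "inputs": []},
    "w": {"op": "Variable", "inputs": []},
    "y": {"op": "MatMul", "inputs": ["x", "w"]},
    "loss": {"op": "Mean", "inputs": ["y"]},  # training-only node
}
checkpoint = {"w": [[0.5], [2.0]]}

frozen = freeze(graph, checkpoint, ["y"])
print(sorted(frozen))     # ['w', 'x', 'y']
print(frozen["w"]["op"])  # Const
```

The training-only `loss` node disappears because nothing on the path to `y` needs it. Meanwhile `w` no longer depends on an external checkpoint, which is why a frozen .pb can be shipped as a single file.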
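The conversion is done with the keras_to_tensorflow.py script from the amir-abdi/keras_to_tensorflow repository on GitHub.
The core of the script is quite simple. These are the command-line flags it defines: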

```python
from absl import flags

FLAGS = flags.FLAGS

flags.DEFINE_string('input_model', None, 'Path to the input model.')
flags.DEFINE_string('input_model_json', None, 'Path to the input model '
                                              'architecture in json format.')
flags.DEFINE_string('input_model_yaml', None, 'Path to the input model '
                                              'architecture in yaml format.')
flags.DEFINE_string('output_model', None, 'Path where the converted model will '
                                          'be stored.')
flags.DEFINE_boolean('save_graph_def', False,
                     'Whether to save the graphdef.pbtxt file which contains '
                     'the graph definition in ASCII format.')
flags.DEFINE_string('output_nodes_prefix', None,
                    'If set, the output nodes will be renamed to '
                    '`output_nodes_prefix`+i, where `i` will numerate the '
                    'number of output nodes of the network.')
flags.DEFINE_boolean('quantize', False,
                     'If set, the resultant TensorFlow graph weights will be '
                     'converted from float into eight-bit equivalents. See '
                     'documentation here: '
                     'https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/graph_transforms')
flags.DEFINE_boolean('channels_first', False,
                     'Whether channels are the first dimension of a tensor. '
                     'The default is TensorFlow behaviour where channels are '
                     'the last dimension.')
flags.DEFINE_boolean('output_meta_ckpt', False,
                     'If set to True, exports the model as .meta, .index, and '
                     '.data files, with a checkpoint file. These can be later '
                     'loaded in TensorFlow to continue training.')

# Only the input and output paths are mandatory.
flags.mark_flag_as_required('input_model')
flags.mark_flag_as_required('output_model')
```
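The `--quantize` flag stores the weights as eight-bit values instead of float32. A common way to do this is min/max linear quantization:

- map a tensor's range [min, max] onto the 256 integer levels 0 to 255;
- keep min and the step size so approximate floats can be reconstructed.

The sketch below shows that idea on a plain Python list. It rests on two assumptions: `quantize_8bit` and `dequantize_8bit` are hypothetical helpers, and this is not TensorFlow's exact `quantize_weights` implementation, which operates on graph constants and has its own details.

```python
# Sketch of min/max linear (affine) 8-bit quantization of a weight list.
# Illustrative only; not TensorFlow's quantize_weights transform.

def quantize_8bit(weights):
    lo, hi = min(weights), max(weights)
    # 256 levels (0..255) spread evenly across [lo, hi];
    # a constant tensor just maps every value to level 0.
    scale = (hi - lo) / 255 if hi > lo else 1.0
    codes = [round((w - lo) / scale) for w in weights]
    return codes, lo, scale


def dequantize_8bit(codes, lo, scale):
    # Reconstruct approximate floats from the integer codes.
    return [lo + q * scale for q in codes]


weights = [-1.0, 0.0, 0.5, 1.0]
codes, lo, scale = quantize_8bit(weights)
restored = dequantize_8bit(codes, lo, scale)
print(codes, max(abs(a - b) for a, b in zip(weights, restored)))
```

Each weight then needs one byte instead of four. The price is a rounding error of at most half a step, (max - min) / 510, per value.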

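The script defines its command-line options with absl flags. When running it from a terminal, you only need to specify the input and output model paths. You can also put the command in a shell script such as run.sh.
For example: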

```shell
cd path/to/script_dir  # the directory containing keras_to_tensorflow.py
python keras_to_tensorflow.py --input_model model.h5 --output_model model.pb
```
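The article mentions wrapping this command in a run.sh script. If the conversion is part of a larger pipeline, the same invocation can also be assembled from Python.

The sketch below only builds the argument list from a few of the flags defined above. Note the assumptions:

- `build_command` is a hypothetical helper.
- Actually running the result would still need something like `subprocess.run(cmd, check=True)`, in an environment where the script's TensorFlow 1.x dependencies are installed.

```python
import shlex


def build_command(input_model, output_model, quantize=False,
                  save_graph_def=False, output_nodes_prefix=None):
    """Hypothetical helper: argv list for keras_to_tensorflow.py."""
    cmd = ["python", "keras_to_tensorflow.py",
           "--input_model", input_model,
           "--output_model", output_model]
    # Boolean absl flags are switched on by passing the bare flag name.
    if quantize:
        cmd.append("--quantize")
    if save_graph_def:
        cmd.append("--save_graph_def")
    if output_nodes_prefix:
        cmd += ["--output_nodes_prefix", output_nodes_prefix]
    return cmd


print(shlex.join(build_command("model.h5", "model.pb")))
# python keras_to_tensorflow.py --input_model model.h5 --output_model model.pb
```

Building a list rather than a single string avoids shell-quoting problems when model paths contain spaces. `shlex.join` is only used to display the command.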

The converted model in .pb format will then be at the path given by --output_model.
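Where the .pb ends up is decided at the top of the script's main() function:

- a bare file name is placed in the current working directory;
- any other path is used as given;
- the optional .pbtxt and checkpoint files share the same folder and stem.

The sketch below restates just that path logic with pathlib, so it can be checked without TensorFlow. Note the assumption: `resolve_output_paths` and its `cwd` parameter are hypothetical additions for illustration.

```python
from pathlib import Path


def resolve_output_paths(output_model, cwd=None):
    """Mirror the path handling at the top of main() in keras_to_tensorflow.py."""
    cwd = Path.cwd() if cwd is None else Path(cwd)
    # A bare file name (parent is '.') is resolved against the working directory.
    if str(Path(output_model).parent) == '.':
        output_model = str(cwd / output_model)
    out = Path(output_model)
    return {
        "folder": out.parent,
        "pb_name": out.name,
        "pbtxt_name": out.stem + '.pbtxt',     # written only with --save_graph_def
        "ckpt_prefix": out.parent / out.stem,  # used only with --output_meta_ckpt
    }


paths = resolve_output_paths("model.pb", cwd="/home/user/k210")
print(paths["folder"], paths["pb_name"], paths["pbtxt_name"])
```

A relative path such as `out/model.pb` is not made absolute. It is still written relative to the directory the script is run from.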
Full code of keras_to_tensorflow.py

```python
# Written for TensorFlow 1.x and standalone Keras 2.x (graph mode): it relies on
# K.get_session(), tf.train.Saver and tensorflow.tools.graph_transforms.
import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import graph_io
from pathlib import Path
from absl import app
from absl import flags
from absl import logging
import keras
from keras import backend as K
from keras.models import model_from_json, model_from_yaml

# Build the graph in inference mode (e.g. dropout disabled).
K.set_learning_phase(0)
FLAGS = flags.FLAGS

flags.DEFINE_string('input_model', None, 'Path to the input model.')
flags.DEFINE_string('input_model_json', None, 'Path to the input model '
                                              'architecture in json format.')
flags.DEFINE_string('input_model_yaml', None, 'Path to the input model '
                                              'architecture in yaml format.')
flags.DEFINE_string('output_model', None, 'Path where the converted model will '
                                          'be stored.')
flags.DEFINE_boolean('save_graph_def', False,
                     'Whether to save the graphdef.pbtxt file which contains '
                     'the graph definition in ASCII format.')
flags.DEFINE_string('output_nodes_prefix', None,
                    'If set, the output nodes will be renamed to '
                    '`output_nodes_prefix`+i, where `i` will numerate the '
                    'number of output nodes of the network.')
flags.DEFINE_boolean('quantize', False,
                     'If set, the resultant TensorFlow graph weights will be '
                     'converted from float into eight-bit equivalents. See '
                     'documentation here: '
                     'https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/graph_transforms')
flags.DEFINE_boolean('channels_first', False,
                     'Whether channels are the first dimension of a tensor. '
                     'The default is TensorFlow behaviour where channels are '
                     'the last dimension.')
flags.DEFINE_boolean('output_meta_ckpt', False,
                     'If set to True, exports the model as .meta, .index, and '
                     '.data files, with a checkpoint file. These can be later '
                     'loaded in TensorFlow to continue training.')

flags.mark_flag_as_required('input_model')
flags.mark_flag_as_required('output_model')


def load_model(input_model_path, input_json_path=None, input_yaml_path=None):
    """Load a full Keras model, or weights plus a json/yaml architecture."""
    if not Path(input_model_path).exists():
        raise FileNotFoundError(
            'Model file `{}` does not exist.'.format(input_model_path))
    try:
        # A file written by model.save() holds both architecture and weights.
        model = keras.models.load_model(input_model_path)
        return model
    except FileNotFoundError as err:
        logging.error('Input model file (%s) does not exist.', input_model_path)
        raise err
    except ValueError as wrong_file_err:
        # The .h5 file only holds weights: rebuild the architecture first.
        if input_json_path:
            if not Path(input_json_path).exists():
                raise FileNotFoundError(
                    'Model description json file `{}` does not exist.'.format(
                        input_json_path))
            try:
                model = model_from_json(Path(input_json_path).read_text())
                model.load_weights(input_model_path)
                return model
            except Exception as err:
                logging.error("Couldn't load model from json.")
                raise err
        elif input_yaml_path:
            if not Path(input_yaml_path).exists():
                raise FileNotFoundError(
                    'Model description yaml file `{}` does not exist.'.format(
                        input_yaml_path))
            try:
                model = model_from_yaml(Path(input_yaml_path).read_text())
                model.load_weights(input_model_path)
                return model
            except Exception as err:
                logging.error("Couldn't load model from yaml.")
                raise err
        else:
            logging.error(
                'Input file specified only holds the weights, and not '
                'the model definition. Save the model using '
                'model.save(filename.h5) which will contain the network '
                'architecture as well as its weights. '
                'If the model is saved using the '
                'model.save_weights(filename) function, either '
                'input_model_json or input_model_yaml flags should be set to '
                'import the network architecture prior to loading the '
                'weights. \n'
                'Check the keras documentation for more details '
                '(https://keras.io/getting-started/faq/)')
            raise wrong_file_err


def main(args):
    # If output_model path is relative and in cwd, make it absolute from root
    output_model = FLAGS.output_model
    if str(Path(output_model).parent) == '.':
        output_model = str((Path.cwd() / output_model))

    output_fld = Path(output_model).parent
    output_model_name = Path(output_model).name
    output_model_stem = Path(output_model).stem
    output_model_pbtxt_name = output_model_stem + '.pbtxt'

    # Create output directory if it does not exist
    Path(output_model).parent.mkdir(parents=True, exist_ok=True)

    if FLAGS.channels_first:
        K.set_image_data_format('channels_first')
    else:
        K.set_image_data_format('channels_last')

    model = load_model(FLAGS.input_model, FLAGS.input_model_json,
                       FLAGS.input_model_yaml)

    # TODO(amirabdi): Support networks with multiple inputs
    orig_output_node_names = [node.op.name for node in model.outputs]
    if FLAGS.output_nodes_prefix:
        num_output = len(orig_output_node_names)
        pred = [None] * num_output
        converted_output_node_names = [None] * num_output

        # Create dummy tf nodes to rename output
        for i in range(num_output):
            converted_output_node_names[i] = '{}{}'.format(
                FLAGS.output_nodes_prefix, i)
            pred[i] = tf.identity(model.outputs[i],
                                  name=converted_output_node_names[i])
    else:
        converted_output_node_names = orig_output_node_names
    logging.info('Converted output node names are: %s',
                 str(converted_output_node_names))

    sess = K.get_session()
    if FLAGS.output_meta_ckpt:
        # Optionally keep a trainable checkpoint next to the frozen graph.
        saver = tf.train.Saver()
        saver.save(sess, str(output_fld / output_model_stem))

    if FLAGS.save_graph_def:
        tf.train.write_graph(sess.graph.as_graph_def(), str(output_fld),
                             output_model_pbtxt_name, as_text=True)
        logging.info('Saved the graph definition in ascii format at %s',
                     str(Path(output_fld) / output_model_pbtxt_name))

    # Freeze: replace variables with constants holding their current values.
    if FLAGS.quantize:
        from tensorflow.tools.graph_transforms import TransformGraph
        transforms = ["quantize_weights", "quantize_nodes"]
        transformed_graph_def = TransformGraph(sess.graph.as_graph_def(), [],
                                               converted_output_node_names,
                                               transforms)
        constant_graph = graph_util.convert_variables_to_constants(
            sess,
            transformed_graph_def,
            converted_output_node_names)
    else:
        constant_graph = graph_util.convert_variables_to_constants(
            sess,
            sess.graph.as_graph_def(),
            converted_output_node_names)

    # Write the frozen graph as a binary .pb file.
    graph_io.write_graph(constant_graph, str(output_fld), output_model_name,
                         as_text=False)
    logging.info('Saved the frozen graph at %s',
                 str(Path(output_fld) / output_model_name))


if __name__ == "__main__":
    app.run(main)
```
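The output node names matter downstream. Tools that convert the frozen .pb further (for example to tflite) typically ask for the output tensor names. The logic in main() that chooses these names is small enough to restate without TensorFlow:

- without `--output_nodes_prefix`, the original Keras op names are kept;
- with it, output `i` gets an identity node named prefix + i.

The helper below is a hypothetical, TensorFlow-free mirror of that logic. Note the assumption: the name `dense_2/Softmax` is only an example of what a Keras output op might be called.

```python
def converted_output_names(orig_names, prefix=None):
    """Hypothetical mirror of how main() picks the output node names to freeze."""
    if prefix:
        # main() adds one tf.identity node named prefix + i per model output.
        return ['{}{}'.format(prefix, i) for i in range(len(orig_names))]
    # Without a prefix, the original Keras output op names are used as-is.
    return list(orig_names)


orig = ['dense_2/Softmax']
print(converted_output_names(orig))             # ['dense_2/Softmax']
print(converted_output_names(orig, 'output_'))  # ['output_0']
```

A fixed prefix gives predictable names such as `output_0`, so later conversion steps do not depend on whatever name Keras auto-generated for the last layer.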


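This content was contributed by users or collected from the web and is provided for learning and reference; copyright belongs to the original author.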