背景
keras是一个比较适合初学者上手的高级神经网络API,它能够以TensorFlow, CNTK, 或者 Theano作为后端运行。而keras训练完的模型是.h5文件,如果想要在移动端运行模型需要tflite模型文件
实现
附上从github上找到的一段转换代码,但是要稍作修改
复制代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
# coding: utf-8
# In[ ]:
'''
Input arguments:
num_output: this value has nothing to do with the number of classes, batch_size, etc.,
and it is mostly equal to 1. If the network is a **multi-stream network**
(forked network with multiple outputs), set the value to the number of outputs.
quantize: if set to True, use the quantize feature of Tensorflow
(https://www.tensorflow.org/performance/quantization) [default: False]
use_theano: Thaeno and Tensorflow implement convolution in different ways.
When using Keras with Theano backend, the order is set to 'channels_first'.
This feature is not fully tested, and doesn't work with quantizization [default: False]
input_fld: directory holding the keras weights file [default: .]
output_fld: destination directory to save the tensorflow files [default: .]
input_model_file: name of the input weight file [default: 'model.h5']
output_model_file: name of the output weight file [default: args.input_model_file + '.pb']
graph_def: if set to True, will write the graph definition as an ascii file [default: False]
output_graphdef_file: if graph_def is set to True, the file name of the
graph definition [default: model.ascii]
output_node_prefix: the prefix to use for output nodes. [default: output_node]
'''
# Parse input arguments
# In[ ]:
import argparse
parser = argparse.ArgumentParser(description='set input arguments')
parser.add_argument('-input_fld', action="store",
dest='input_fld', type=str, default='.')
parser.add_argument('-output_fld', action="store",
dest='output_fld', type=str, default='')
parser.add_argument('-input_model_file', action="store",
dest='input_model_file', type=str, default='model.h5')
parser.add_argument('-output_model_file', action="store",
dest='output_model_file', type=str, default='')
parser.add_argument('-output_graphdef_file', action="store",
dest='output_graphdef_file', type=str, default='model.ascii')
parser.add_argument('-num_outputs', action="store",
dest='num_outputs', type=int, default=1)
parser.add_argument('-graph_def', action="store",
dest='graph_def', type=bool, default=False)
parser.add_argument('-output_node_prefix', action="store",
dest='output_node_prefix', type=str, default='output_node')
parser.add_argument('-quantize', action="store",
dest='quantize', type=bool, default=False)
parser.add_argument('-theano_backend', action="store",
dest='theano_backend', type=bool, default=False)
parser.add_argument('-f')
args = parser.parse_args()
parser.print_help()
print('input args: ', args)
if args.theano_backend is True and args.quantize is True:
raise ValueError("Quantize feature does not work with theano backend.")
# initialize
# In[ ]:
from keras.models import load_model
import tensorflow as tf
from pathlib import Path
from keras import backend as K
from keras.applications import mobilenet
from keras.utils.generic_utils import CustomObjectScope
output_fld = args.input_fld if args.output_fld == '' else args.output_fld
if args.output_model_file == '':
args.output_model_file = str(Path(args.input_model_file).name) + '.pb'
Path(output_fld).mkdir(parents=True, exist_ok=True)
weight_file_path = str(Path(args.input_fld) / args.input_model_file)
# Load keras model and rename output
# In[ ]:
K.set_learning_phase(0)
if args.theano_backend:
K.set_image_data_format('channels_first')
else:
K.set_image_data_format('channels_last')
# try:
# 主要修改在这里,需要加上这行,否则会报错
with CustomObjectScope({'relu6': mobilenet.relu6, 'DepthwiseConv2D': mobilenet.DepthwiseConv2D}):
net_model = load_model(weight_file_path)
# except ValueError as err:
# print('''Input file specified ({}) only holds the weights, and not the model defenition.
# Save the model using mode.save(filename.h5) which will contain the network architecture
# as well as its weights.
# If the model is saved using model.save_weights(filename.h5), the model architecture is
# expected to be saved separately in a json format and loaded prior to loading the weights.
# Check the keras documentation for more details (https://keras.io/getting-started/faq/)'''
# .format(weight_file_path))
# raise err
num_output = args.num_outputs
pred = [None]*num_output
pred_node_names = [None]*num_output
for i in range(num_output):
pred_node_names[i] = args.output_node_prefix+str(i)
pred[i] = tf.identity(net_model.outputs[i], name=pred_node_names[i])
print('output nodes names are: ', pred_node_names)
# [optional] write graph definition in ascii
# In[ ]:
sess = K.get_session()
if args.graph_def:
f = args.output_graphdef_file
tf.train.write_graph(sess.graph.as_graph_def(), output_fld, f, as_text=True)
print('saved the graph definition in ascii format at: ', str(Path(output_fld) / f))
# convert variables to constants and save
# In[ ]:
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import graph_io
if args.quantize:
from tensorflow.tools.graph_transforms import TransformGraph
transforms = ["quantize_weights", "quantize_nodes"]
transformed_graph_def = TransformGraph(sess.graph.as_graph_def(), [], pred_node_names, transforms)
constant_graph = graph_util.convert_variables_to_constants(sess, transformed_graph_def, pred_node_names)
else:
constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), pred_node_names)
graph_io.write_graph(constant_graph, output_fld, args.output_model_file, as_text=False)
print('saved the freezed graph (ready for inference) at: ', str(Path(output_fld) / args.output_model_file))
复制代码
keras转tensorflow完成后,接下来我们就要将.pb文件转化为.tflite文件。这里查阅了很多资料,记录一下坑的地方
- 如果你的tensorflow是1.8的话,先要将tensorflow升级到1.9或者降级到1.7,因为1.8的toco命令不好使。升级方法就是pip3 install -U tensorflow 或者pip3 install --upgrade tensorflow
- 升级完成后就可以使用toco命令了,注意:如果之前你是用virtualenv安装的整个环境,那么先source ./bin/activate激活环境,在环境下才能使用
复制代码
1
2
3
4
5
6
7
8toco --graph_def_file mobilenet_v1_1.0_224_frozen.pb --output_format=TFLITE --output_file=mobilenet_v1_1.0_224_test.tflite --inference_type=FLOAT --input_arrays=input --output_arrays=MobilenetV1/Predictions/Reshape_1 --input_shapes=1,224,224,3 复制代码
这里注意,千万不要按照教程里的命令进行,因为这里有几个坑点:
-
- 1.9的toco命令已经用参数--graph_def_file代替了--input_file
-
- 1.9的toco命令已经将参数--input_type取消掉
所以最后可以运行成功的命令,如上
- 上面的命令运行成功后,就可以将自己的pb文件转化成tflite文件了,只要替换graph_def_file后面的pb文件名字和output_file后面的输出文件名字,然后重点是知道你训练的模型的input层的name和output层的name,至于怎么找到这两个层的name,最好用tensorflow中的load_graph函数load一下你的pb模型,遍历graph找到对应层的name既可。分别用input层name和output层name替换input_arrays和output_arrays参数后面的值
最后
以上就是如意裙子最近收集整理的关于如何将keras训练的模型转换成tensorflow lite模型的全部内容,更多相关如何将keras训练的模型转换成tensorflow内容请搜索靠谱客的其他文章。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复