我是靠谱客的博主 如意裙子,这篇文章主要介绍如何将keras训练的模型转换成tensorflow lite模型,现在分享给大家,希望可以做个参考。

背景

keras是一个比较适合初学者上手的高级神经网络API,它能够以TensorFlow, CNTK, 或者 Theano作为后端运行。而keras训练完的模型是.h5文件,如果想要在移动端运行模型需要tflite模型文件

实现

附上从github上找到的一段转换代码,但是要稍作修改

复制代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
# coding: utf-8 # In[ ]: ''' Input arguments: num_output: this value has nothing to do with the number of classes, batch_size, etc., and it is mostly equal to 1. If the network is a **multi-stream network** (forked network with multiple outputs), set the value to the number of outputs. quantize: if set to True, use the quantize feature of Tensorflow (https://www.tensorflow.org/performance/quantization) [default: False] use_theano: Thaeno and Tensorflow implement convolution in different ways. When using Keras with Theano backend, the order is set to 'channels_first'. This feature is not fully tested, and doesn't work with quantizization [default: False] input_fld: directory holding the keras weights file [default: .] output_fld: destination directory to save the tensorflow files [default: .] input_model_file: name of the input weight file [default: 'model.h5'] output_model_file: name of the output weight file [default: args.input_model_file + '.pb'] graph_def: if set to True, will write the graph definition as an ascii file [default: False] output_graphdef_file: if graph_def is set to True, the file name of the graph definition [default: model.ascii] output_node_prefix: the prefix to use for output nodes. [default: output_node] ''' # Parse input arguments # In[ ]: import argparse parser = argparse.ArgumentParser(description='set input arguments') parser.add_argument('-input_fld', action="store", dest='input_fld', type=str, default='.') parser.add_argument('-output_fld', action="store", dest='output_fld', type=str, default='') parser.add_argument('-input_model_file', action="store", dest='input_model_file', type=str, default='model.h5') parser.add_argument('-output_model_file', action="store", dest='output_model_file', type=str, default='') parser.add_argument('-output_graphdef_file', action="store", dest='output_graphdef_file', type=str, default='model.ascii') parser.add_argument('-num_outputs', action="store", dest='num_outputs', type=int, default=1) parser.add_argument('-graph_def', action="store", dest='graph_def', type=bool, default=False) parser.add_argument('-output_node_prefix', action="store", dest='output_node_prefix', type=str, default='output_node') parser.add_argument('-quantize', action="store", dest='quantize', type=bool, default=False) parser.add_argument('-theano_backend', action="store", dest='theano_backend', type=bool, default=False) parser.add_argument('-f') args = parser.parse_args() parser.print_help() print('input args: ', args) if args.theano_backend is True and args.quantize is True: raise ValueError("Quantize feature does not work with theano backend.") # initialize # In[ ]: from keras.models import load_model import tensorflow as tf from pathlib import Path from keras import backend as K from keras.applications import mobilenet from keras.utils.generic_utils import CustomObjectScope output_fld = args.input_fld if args.output_fld == '' else args.output_fld if args.output_model_file == '': args.output_model_file = str(Path(args.input_model_file).name) + '.pb' Path(output_fld).mkdir(parents=True, exist_ok=True) weight_file_path = str(Path(args.input_fld) / args.input_model_file) # Load keras model and rename output # In[ ]: K.set_learning_phase(0) if args.theano_backend: K.set_image_data_format('channels_first') else: K.set_image_data_format('channels_last') # try: # 主要修改在这里,需要加上这行,否则会报错 with CustomObjectScope({'relu6': mobilenet.relu6, 'DepthwiseConv2D': mobilenet.DepthwiseConv2D}): net_model = load_model(weight_file_path) # except ValueError as err: # print('''Input file specified ({}) only holds the weights, and not the model defenition. # Save the model using mode.save(filename.h5) which will contain the network architecture # as well as its weights. # If the model is saved using model.save_weights(filename.h5), the model architecture is # expected to be saved separately in a json format and loaded prior to loading the weights. # Check the keras documentation for more details (https://keras.io/getting-started/faq/)''' # .format(weight_file_path)) # raise err num_output = args.num_outputs pred = [None]*num_output pred_node_names = [None]*num_output for i in range(num_output): pred_node_names[i] = args.output_node_prefix+str(i) pred[i] = tf.identity(net_model.outputs[i], name=pred_node_names[i]) print('output nodes names are: ', pred_node_names) # [optional] write graph definition in ascii # In[ ]: sess = K.get_session() if args.graph_def: f = args.output_graphdef_file tf.train.write_graph(sess.graph.as_graph_def(), output_fld, f, as_text=True) print('saved the graph definition in ascii format at: ', str(Path(output_fld) / f)) # convert variables to constants and save # In[ ]: from tensorflow.python.framework import graph_util from tensorflow.python.framework import graph_io if args.quantize: from tensorflow.tools.graph_transforms import TransformGraph transforms = ["quantize_weights", "quantize_nodes"] transformed_graph_def = TransformGraph(sess.graph.as_graph_def(), [], pred_node_names, transforms) constant_graph = graph_util.convert_variables_to_constants(sess, transformed_graph_def, pred_node_names) else: constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), pred_node_names) graph_io.write_graph(constant_graph, output_fld, args.output_model_file, as_text=False) print('saved the freezed graph (ready for inference) at: ', str(Path(output_fld) / args.output_model_file)) 复制代码

keras转tensorflow完成后,接下来我们就要将.pb文件转化为.tflite文件。这里查阅了很多资料,记录一下坑的地方

  1. 如果你的tensorflow是1.8的话,先要将tensorflow升级到1.9或者降级到1.7,因为1.8的toco命令不好使。升级方法就是pip3 install -U tensorflow 或者pip3 install --upgrade tensorflow
  2. 升级完成后就可以使用toco命令了,注意:如果之前你是用virtualenv安装的整个环境,那么先source ./bin/activate激活环境,在环境下才能使用
复制代码
1
2
3
4
5
6
7
8
toco --graph_def_file mobilenet_v1_1.0_224_frozen.pb --output_format=TFLITE --output_file=mobilenet_v1_1.0_224_test.tflite --inference_type=FLOAT --input_arrays=input --output_arrays=MobilenetV1/Predictions/Reshape_1 --input_shapes=1,224,224,3 复制代码

这里注意,千万不要按照教程里的命令进行,因为这里有几个坑点:

    1. 1.9的toco命令已经用参数--graph_def_file代替了--input_file
    1. 1.9的toco命令已经将参数--input_type取消掉

所以最后可以运行成功的命令,如上

  1. 上面的命令运行成功后,就可以将自己的pb文件转化成tflite文件了,只要替换graph_def_file后面的pb文件名字和output_file后面的输出文件名字,然后重点是知道你训练的模型的input层的name和output层的name,至于怎么找到这两个层的name,最好用tensorflow中的load_graph函数load一下你的pb模型,遍历graph找到对应层的name既可。分别用input层name和output层name替换input_arrays和output_arrays参数后面的值

最后

以上就是如意裙子最近收集整理的关于如何将keras训练的模型转换成tensorflow lite模型的全部内容,更多相关如何将keras训练的模型转换成tensorflow内容请搜索靠谱客的其他文章。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(61)

评论列表共有 0 条评论

立即
投稿
返回
顶部