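Keras is a high-level neural network API that is fairly easy for beginners to pick up, and it can run with TensorFlow, CNTK, or Theano as its backend. A trained Keras model is saved as an .h5 file, but running the model on a mobile device requires a tflite model file. The conversion takes two steps: the script below freezes the .h5 model into a TensorFlow .pb graph (the key modification for MobileNet is marked in its comments), and `toco` then converts the .pb file into a .tflite file.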
```python
# coding: utf-8
'''
Input arguments:
num_outputs: this value has nothing to do with the number of classes, batch_size, etc.,
    and it is mostly equal to 1. If the network is a **multi-stream network**
    (forked network with multiple outputs), set the value to the number of outputs.
quantize: if set to True, use the quantize feature of Tensorflow
    (https://www.tensorflow.org/performance/quantization) [default: False]
theano_backend: Theano and Tensorflow implement convolution in different ways.
    When using Keras with Theano backend, the order is set to 'channels_first'.
    This feature is not fully tested, and doesn't work with quantization [default: False]
input_fld: directory holding the keras weights file [default: .]
output_fld: destination directory to save the tensorflow files [default: same as input_fld]
input_model_file: name of the input weight file [default: 'model.h5']
output_model_file: name of the output weight file [default: args.input_model_file + '.pb']
graph_def: if set to True, will write the graph definition as an ascii file [default: False]
output_graphdef_file: if graph_def is set to True, the file name of the
    graph definition [default: model.ascii]
output_node_prefix: the prefix to use for output nodes. [default: output_node]
'''

# Parse input arguments
import argparse


def str2bool(value):
    # argparse's type=bool would call bool('False'), which is True, so parse
    # the usual spellings of true/false explicitly instead.
    lowered = value.lower()
    if lowered in ('true', '1', 'yes'):
        return True
    if lowered in ('false', '0', 'no'):
        return False
    raise argparse.ArgumentTypeError('boolean value expected, got %r' % value)


parser = argparse.ArgumentParser(description='set input arguments')
parser.add_argument('-input_fld', action="store",
                    dest='input_fld', type=str, default='.')
parser.add_argument('-output_fld', action="store",
                    dest='output_fld', type=str, default='')
parser.add_argument('-input_model_file', action="store",
                    dest='input_model_file', type=str, default='model.h5')
parser.add_argument('-output_model_file', action="store",
                    dest='output_model_file', type=str, default='')
parser.add_argument('-output_graphdef_file', action="store",
                    dest='output_graphdef_file', type=str, default='model.ascii')
parser.add_argument('-num_outputs', action="store",
                    dest='num_outputs', type=int, default=1)
parser.add_argument('-graph_def', action="store",
                    dest='graph_def', type=str2bool, default=False)
parser.add_argument('-output_node_prefix', action="store",
                    dest='output_node_prefix', type=str, default='output_node')
parser.add_argument('-quantize', action="store",
                    dest='quantize', type=str2bool, default=False)
parser.add_argument('-theano_backend', action="store",
                    dest='theano_backend', type=str2bool, default=False)
args = parser.parse_args()
print('input args: ', args)

if args.theano_backend and args.quantize:
    raise ValueError("Quantize feature does not work with theano backend.")
```
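A note on the boolean options (`-graph_def`, `-quantize`, `-theano_backend`): with `type=bool`, argparse simply calls `bool()` on the raw command-line string, and every non-empty string is truthy, so `-quantize False` turns quantization on. The standard-library snippet below reproduces that pitfall and shows one alternative design, `action='store_true'`, where passing the flag by itself means True. This alternative changes the command line (you would write `-quantize` rather than `-quantize True`), so treat it as a sketch of an option rather than the script's interface.

```python
import argparse

# Pitfall: type=bool runs bool("False"), and any non-empty string is truthy.
naive = argparse.ArgumentParser()
naive.add_argument('-quantize', dest='quantize', type=bool, default=False)
print(naive.parse_args(['-quantize', 'False']).quantize)  # True, not False

# Alternative: a presence flag that takes no value; absent means False.
flag = argparse.ArgumentParser()
flag.add_argument('-quantize', dest='quantize', action='store_true')
print(flag.parse_args([]).quantize)             # False
print(flag.parse_args(['-quantize']).quantize)  # True
```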
```python
# initialize
from pathlib import Path

import tensorflow as tf
from keras import backend as K
from keras.applications import mobilenet
from keras.models import load_model
from keras.utils.generic_utils import CustomObjectScope
from tensorflow.python.framework import graph_io
from tensorflow.python.framework import graph_util

output_fld = args.input_fld if args.output_fld == '' else args.output_fld
if args.output_model_file == '':
    args.output_model_file = str(Path(args.input_model_file).name) + '.pb'
Path(output_fld).mkdir(parents=True, exist_ok=True)
weight_file_path = str(Path(args.input_fld) / args.input_model_file)

# Load keras model and rename output
# With the Theano backend the image order is 'channels_first'; TensorFlow uses 'channels_last'.
if args.theano_backend:
    K.set_image_data_format('channels_first')
else:
    K.set_image_data_format('channels_last')

try:
    # The main modification is here: MobileNet's custom relu6 activation and
    # DepthwiseConv2D layer must be registered, otherwise load_model raises an error.
    # Assumes an older Keras (e.g. 2.1.x) in which keras.applications.mobilenet
    # still exports relu6 and DepthwiseConv2D.
    with CustomObjectScope({'relu6': mobilenet.relu6,
                            'DepthwiseConv2D': mobilenet.DepthwiseConv2D}):
        net_model = load_model(weight_file_path)
except ValueError:
    print('Input file specified ({}) only holds the weights, and not the model definition. '
          'Save the model using model.save(filename.h5), which will contain the network '
          'architecture as well as its weights. If the model is saved using '
          'model.save_weights(filename.h5), the model architecture is expected to be saved '
          'separately in a json format and loaded prior to loading the weights. Check the '
          'keras documentation for more details (https://keras.io/getting-started/faq/)'
          .format(weight_file_path))
    raise

num_output = args.num_outputs
pred = [None] * num_output
pred_node_names = [None] * num_output
for i in range(num_output):
    # Wrap each output in an identity op with a predictable name (output_node0, ...)
    # so it can still be referenced after freezing, e.g. as toco's output array.
    pred_node_names[i] = args.output_node_prefix + str(i)
    pred[i] = tf.identity(net_model.outputs[i], name=pred_node_names[i])
print('output nodes names are: ', pred_node_names)

# [optional] write graph definition in ascii
sess = K.get_session()
if args.graph_def:
    f = args.output_graphdef_file
    tf.train.write_graph(sess.graph.as_graph_def(), output_fld, f, as_text=True)
    print('saved the graph definition in ascii format at: ', str(Path(output_fld) / f))

# convert variables to constants and save
if args.quantize:
    from tensorflow.tools.graph_transforms import TransformGraph
    transforms = ["quantize_weights", "quantize_nodes"]
    transformed_graph_def = TransformGraph(sess.graph.as_graph_def(), [], pred_node_names, transforms)
    constant_graph = graph_util.convert_variables_to_constants(sess, transformed_graph_def, pred_node_names)
else:
    constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), pred_node_names)
graph_io.write_graph(constant_graph, output_fld, args.output_model_file, as_text=False)
print('saved the frozen graph (ready for inference) at: ', str(Path(output_fld) / args.output_model_file))
```
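The `-quantize` option hands the graph to TensorFlow's `quantize_weights` and `quantize_nodes` graph transforms. To give a feel for the first idea, storing float weights as 8-bit integers plus a shared range, here is a pure-Python sketch of min/max linear quantization. It is a simplified model, not TensorFlow's actual implementation, and `quantize_8bit` / `dequantize_8bit` are hypothetical names.

```python
def quantize_8bit(weights):
    """Map floats linearly onto the integers 0..255 using the tensor's own min/max."""
    lo, hi = min(weights), max(weights)
    # One shared step size for the whole tensor; fall back to 1.0 when all
    # values are equal so we never divide by zero.
    scale = (hi - lo) / 255.0 or 1.0
    codes = [int(round((w - lo) / scale)) for w in weights]
    return codes, lo, scale


def dequantize_8bit(codes, lo, scale):
    """Turn the stored bytes back into approximate floats for computation."""
    return [lo + c * scale for c in codes]


weights = [-1.0, -0.5, 0.0, 0.25, 1.0]
codes, lo, scale = quantize_8bit(weights)
restored = dequantize_8bit(codes, lo, scale)
print(codes)
# The reconstruction error per weight is at most half a step (scale / 2).
print(max(abs(w - r) for w, r in zip(weights, restored)))
```

The trade-off in this sketch is the one the option makes: roughly 4× smaller weights (one byte instead of a 32-bit float) in exchange for a bounded rounding error per weight.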
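- If you have TensorFlow 1.8, first upgrade to 1.9 or downgrade to 1.7, because the `toco` command in 1.8 does not work. To upgrade, run `pip3 install -U tensorflow` (or the equivalent `pip3 install --upgrade tensorflow`). Today an unpinned upgrade installs TensorFlow 2.x, which this TF 1.x workflow does not target, so pin a version instead, e.g. `pip3 install tensorflow==1.9.0`.
- After the upgrade, the `toco` command is available. Note: if you installed the whole environment with virtualenv, activate it first with `source ./bin/activate`; `toco` only works inside the activated environment.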
```shell
toco --graph_def_file mobilenet_v1_1.0_224_frozen.pb
```
- In 1.9, `toco` replaced the `--input_file` parameter with `--graph_def_file`.
- In 1.9, `toco` removed the `--input_type` parameter.
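- Once the command above runs successfully, you can convert your own .pb file into a tflite file: replace the .pb file name after `--graph_def_file` and the output file name after `--output_file`. The key is knowing the names of your trained model's input layer and output layer. The easiest way to find them is to load your .pb model into TensorFlow (for example with a `load_graph` helper function) and iterate over the graph to find the corresponding node names. Then put the input layer name after `--input_arrays` and the output layer name after `--output_arrays`. For a graph frozen with the script above, the output names are the ones it prints after `output nodes names are:` (`output_node0` by default).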
