Overview
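At work I needed to call a Keras-trained model from C++. Keras provides no C++ interface, so the .h5 model file produced by Keras first has to be converted into a TensorFlow .pb file.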
-
Training the model with Keras
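If you already have a trained Keras model, you can skip this step. Note, however, that the model must be saved with model.save('keras.h5'), which stores both the model architecture and its weights. A file saved with model.save_weights('keras.h5') holds only the weights, and the conversion below will fail on it. If you do not have a trained model yet, train one first by following the earlier post "Windows + TensorFlow/Keras + VGG16: training on your own dataset".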
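Before converting, you can check which kind of .h5 file you have. The sketch below assumes the HDF5 layout written by Keras 2.x: model.save() stores the architecture as a `model_config` attribute at the file root, while model.save_weights() does not, which is why load_model raises a ValueError on a weights-only file. It assumes h5py is installed; the helper name `has_model_architecture` is hypothetical.

```python
import h5py


def has_model_architecture(h5_path):
    """Return True if the Keras .h5 file stores the model architecture.

    Assumption: Keras 2.x HDF5 layout, where model.save() writes a
    'model_config' attribute at the file root and model.save_weights()
    does not.
    """
    with h5py.File(h5_path, "r") as f:
        return "model_config" in f.attrs
```

For example, `has_model_architecture('keras.h5')` returning False means the file must be re-saved with model.save() before running the conversion.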
-
Converting the .h5 file to a .pb file
Create a file named h5_to_pb.py and copy the following code into it; nothing needs to be changed. The script is written against the TensorFlow 1.x graph/session APIs (K.get_session(), graph_util.convert_variables_to_constants) with standalone Keras, so it is meant for a TensorFlow 1.x + Keras 2.x environment.

```python
"""
Copyright (c) 2017, by the Authors: Amir H. Abdi
This software is freely available under the MIT Public License.
Please see the License file in the root for details.

The following code snippet will convert the keras model file,
which is saved using model.save('kerasmodel_weight_file'),
to the frozen .pb tensorflow weight file which holds both the
network architecture and its associated weights.
"""

'''
Input arguments:

num_outputs:
    this value has nothing to do with the number of classes, batch_size, etc.,
    and it is mostly equal to 1. If the network is a **multi-stream network**
    (forked network with multiple outputs), set the value to the number of outputs.

quantize:
    if set to True, use the quantize feature of Tensorflow
    (https://www.tensorflow.org/performance/quantization) [default: False]

theano_backend:
    Theano and Tensorflow implement convolution in different ways.
    When using Keras with Theano backend, the order is set to 'channels_first'.
    This feature is not fully tested, and doesn't work with quantization [default: False]

input_fld:
    directory holding the keras weights file [default: .]

output_fld:
    destination directory to save the tensorflow files [default: same as input_fld]

input_model_file:
    name of the input weight file [default: 'model.h5']

output_model_file:
    name of the output weight file [default: file name of input_model_file + '.pb']

graph_def:
    if set to True, will write the graph definition as an ascii file [default: False]

output_graphdef_file:
    if graph_def is set to True, the file name of the graph definition [default: model.ascii]

output_node_prefix:
    the prefix to use for output nodes. [default: output_node]
'''

# Parse input arguments
import argparse

parser = argparse.ArgumentParser(description='set input arguments')
parser.add_argument('-input_fld', action="store", dest='input_fld', type=str, default='.')
parser.add_argument('-output_fld', action="store", dest='output_fld', type=str, default='')
parser.add_argument('-input_model_file', action="store", dest='input_model_file', type=str, default='model.h5')
parser.add_argument('-output_model_file', action="store", dest='output_model_file', type=str, default='')
parser.add_argument('-output_graphdef_file', action="store", dest='output_graphdef_file', type=str, default='model.ascii')
parser.add_argument('-num_outputs', action="store", dest='num_outputs', type=int, default=1)
# Caution: type=bool turns any non-empty string (even "False") into True,
# so only pass -graph_def / -quantize / -theano_backend when you want True.
parser.add_argument('-graph_def', action="store", dest='graph_def', type=bool, default=False)
parser.add_argument('-output_node_prefix', action="store", dest='output_node_prefix', type=str, default='output_node')
parser.add_argument('-quantize', action="store", dest='quantize', type=bool, default=False)
parser.add_argument('-theano_backend', action="store", dest='theano_backend', type=bool, default=False)
parser.add_argument('-f')  # absorbs the "-f <kernel file>" argument Jupyter passes
args = parser.parse_args()
parser.print_help()
print('input args: ', args)

if args.theano_backend is True and args.quantize is True:
    raise ValueError("Quantize feature does not work with theano backend.")

# initialize
from keras.models import load_model
import tensorflow as tf
from pathlib import Path
from keras import backend as K

output_fld = args.input_fld if args.output_fld == '' else args.output_fld
if args.output_model_file == '':
    args.output_model_file = str(Path(args.input_model_file).name) + '.pb'
Path(output_fld).mkdir(parents=True, exist_ok=True)
weight_file_path = str(Path(args.input_fld) / args.input_model_file)

# Load keras model and rename output
K.set_learning_phase(0)  # 0 = inference mode (e.g. dropout disabled)
if args.theano_backend:
    K.set_image_data_format('channels_first')
else:
    K.set_image_data_format('channels_last')

try:
    net_model = load_model(weight_file_path)
except ValueError as err:
    print('''Input file specified ({}) only holds the weights, and not the model definition.
    Save the model using model.save(filename.h5) which will contain the network architecture
    as well as its weights.
    If the model is saved using model.save_weights(filename.h5), the model architecture is
    expected to be saved separately in a json format and loaded prior to loading the weights.
    Check the keras documentation for more details (https://keras.io/getting-started/faq/)'''
          .format(weight_file_path))
    raise err

# Disabled alternatives: wrap each output in tf.identity so it gets a
# predictable name (output_node0, output_node1, ...).
# num_output = args.num_outputs
# pred = [None]*num_output
# pred_node_names = [None]*num_output
# for i in range(num_output):
#     pred_node_names[i] = args.output_node_prefix+str(i)
#     pred[i] = tf.identity(net_model.outputs[i], name=pred_node_names[i])

# num_output = len(net_model.output_names)
# pred_node_names = [None]*num_output
# pred = [None]*num_output
# # pred_node_names = net_model.output_names
# for i in range(num_output):
#     pred_node_names[i] = args.output_node_prefix+str(i)
#     pred[i] = tf.identity(net_model.outputs[i], name=pred_node_names[i])

# Keep the model's own input/output op names; note them down, since these
# are the names you feed and fetch when running the .pb graph later.
input_node_names = [node.op.name for node in net_model.inputs]
print('Input nodes names are: ', input_node_names)
pred_node_names = [node.op.name for node in net_model.outputs]
print('Output nodes names are: ', pred_node_names)
# print("net_model.input.op.name:", net_model.input.op.name)
# print("net_model.output.op.name:", net_model.output.op.name)
# print("net_model.input_names:", net_model.input_names)
# print("net_model.output_names:", net_model.output_names)

# [optional] write graph definition in ascii
sess = K.get_session()

if args.graph_def:
    f = args.output_graphdef_file
    tf.train.write_graph(sess.graph.as_graph_def(), output_fld, f, as_text=True)
    print('saved the graph definition in ascii format at: ', str(Path(output_fld) / f))

# convert variables to constants and save
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import graph_io

if args.quantize:
    from tensorflow.tools.graph_transforms import TransformGraph
    transforms = ["quantize_weights", "quantize_nodes"]
    transformed_graph_def = TransformGraph(sess.graph.as_graph_def(), [], pred_node_names, transforms)
    constant_graph = graph_util.convert_variables_to_constants(sess, transformed_graph_def, pred_node_names)
else:
    constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), pred_node_names)

graph_io.write_graph(constant_graph, output_fld, args.output_model_file, as_text=False)
print('saved the frozen graph (ready for inference) at: ', str(Path(output_fld) / args.output_model_file))
```
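One pitfall in the argument parsing of h5_to_pb.py: with type=bool, argparse calls bool() on the raw string, and every non-empty string, including "False", is truthy, so `-quantize False` actually turns quantization on. A minimal fix, assuming you are willing to edit the script, is a small converter such as the hypothetical `str2bool` below, passed as `type=str2bool` for -graph_def, -quantize and -theano_backend. The sketch contrasts both behaviours.

```python
import argparse


def str2bool(value):
    """Convert common true/false spellings to a real bool for argparse."""
    lowered = value.strip().lower()
    if lowered in ("true", "1", "yes", "y"):
        return True
    if lowered in ("false", "0", "no", "n"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean, got %r" % value)


def build_parser():
    parser = argparse.ArgumentParser(description="set input arguments")
    # type=bool: bool("False") is True, so any value switches the flag on
    parser.add_argument("-graph_def", dest="graph_def", type=bool, default=False)
    # type=str2bool: "False" is parsed as False, as intended
    parser.add_argument("-quantize", dest="quantize", type=str2bool, default=False)
    return parser


args = build_parser().parse_args(["-graph_def", "False", "-quantize", "False"])
print(args.graph_def, args.quantize)  # True False
```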
After saving the file, open a command prompt, cd into the directory that contains it, and run:
python h5_to_pb.py -input_model_file models/vgg16_use.h5 -output_model_file models/vgg16_use.h5.pb
Here, models is the folder that holds the trained Keras model, as shown in the figure below.
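To see where the .pb file ends up, the sketch below reproduces the path handling of h5_to_pb.py (the output_fld, output_model_file and weight_file_path lines) as a standalone function; `resolve_paths` is a hypothetical name. Note that when -output_model_file is omitted, the default name is built from the file name only (Path(...).name), so the .pb is written to the output folder, which defaults to the current directory rather than models/.

```python
from pathlib import Path


def resolve_paths(input_fld=".", output_fld="", input_model_file="model.h5",
                  output_model_file=""):
    """Mirror how h5_to_pb.py derives its input .h5 path and output .pb path."""
    # An empty -output_fld falls back to -input_fld
    out_dir = input_fld if output_fld == "" else output_fld
    # An empty -output_model_file becomes "<file name of the .h5>.pb"
    if output_model_file == "":
        output_model_file = Path(input_model_file).name + ".pb"
    weight_file_path = Path(input_fld) / input_model_file
    pb_path = Path(out_dir) / output_model_file
    return weight_file_path, pb_path


# The command above: both the .h5 and the .pb live under models/
print(resolve_paths(input_model_file="models/vgg16_use.h5",
                    output_model_file="models/vgg16_use.h5.pb"))
# Without -output_model_file the .pb lands in the current directory
print(resolve_paths(input_model_file="models/vgg16_use.h5"))
```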
If the corresponding .pb file is generated, the conversion succeeded.
-
Testing whether the converted .pb file is correct
This step checks that the model was converted correctly; it is optional and can be skipped.
Likewise, create a file named load_pb_test.py and copy in the following code:

```python
import argparse

import tensorflow as tf

tf.reset_default_graph()  # reset the computation graph


def network_structure(args):
    model_path = args.model + '.pb'
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        output_graph_def = tf.GraphDef()
        # get the default graph
        graph = tf.get_default_graph()
        with open(model_path, "rb") as f:
            output_graph_def.ParseFromString(f.read())
            _ = tf.import_graph_def(output_graph_def, name="")
        # number of operation nodes in the imported graph
        print("%d ops in the final graph." % len(output_graph_def.node))
        tensor_name = [tensor.name for tensor in output_graph_def.node]
        print(tensor_name)
        print('---------------------------')
        # Write an event file to log_graph_<model> for TensorBoard; view it with:
        #   tensorboard --logdir="log_graph_<model>"
        summaryWriter = tf.summary.FileWriter('log_graph_' + args.model, graph)
        summaryWriter.close()  # flush the event file to disk
        cnt = 0
        for op in graph.get_operations():
            # print each op's name and its output tensors
            print(op.name, op.values())
            cnt += 1
            if args.n:
                if cnt == args.n:
                    break


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str,
                        help="path of the model to inspect, without the .pb suffix")
    parser.add_argument('--n', type=int,
                        help='print only the first n ops (useful when the graph is large)')
    args = parser.parse_args()
    network_structure(args)
```
After saving, run the following in the command prompt:
python load_pb_test.py --model models/vgg16_use.h5 --n 100
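Here, --n 100 sets how many of the graph's operations are printed; if your network has many layers, use a larger number. If the output looks like the figure below, the model was converted correctly.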
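How load_pb_test.py interprets its two arguments can be summarized in plain Python: --model is the path without the .pb suffix (the script appends ".pb" and also reuses the value to name the TensorBoard log folder), and --n only truncates the printout when it is a positive number, so omitting it or passing 0 prints every op. The function names below are hypothetical and only mirror the script's logic; they do not load TensorFlow.

```python
def derive_paths(model):
    """Mirror load_pb_test.py: the .pb file it reads and the log dir it writes."""
    model_path = model + ".pb"
    log_dir = "log_graph_" + model
    return model_path, log_dir


def ops_to_print(op_names, n=None):
    """Mirror the cnt/break loop: stop after n ops only when n is truthy."""
    printed = []
    for name in op_names:
        printed.append(name)
        if n and len(printed) == n:
            break
    return printed


model_path, log_dir = derive_paths("models/vgg16_use.h5")
print(model_path)  # models/vgg16_use.h5.pb
print(log_dir)     # log_graph_models/vgg16_use.h5
# Illustrative op names only, not real output of the script
print(ops_to_print(["input_1", "block1_conv1/kernel", "block1_conv1/Conv2D"], n=2))
```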
Finally
That is everything 故意身影 collected on converting a Keras model to the TensorFlow model format; I hope it helps with the problems you run into when doing this conversion.
This content was contributed by users or collected from the web for learning and reference; copyright belongs to the original author.