概述
一、背景
为了快速的搭建神经网络、训练模型,使用了Keras框架来搭建网络并进行训练,得到训练的h5模型文件后,需要将模型部署成服务,而pb格式的文件一般比较适合部署,pb模型文件的大小要比h5文件小一点,同时pb文件也适用于在TensorFlow Serving,所以需要把Keras保存的h5模型文件转成TensorFlow加载的pb格式来使用。同时本人也参考了网上几乎所有的模型格式转换的文章,经过一番尝试后,终于成功了,现将模型格式转换方法和分别使用tf1.x和tf2.x加载转换后的pb文件的总结如下。
二、h5文件转pb文件的方法
首先声明一下,这里的h5文件都是用keras框架中的save()方法保存的
方法一:
使用大佬写好的keras_to_tensorflow.py程序进行转化文件格式,项目地址:https://github.com/amir-abdi/keras_to_tensorflow
该作者提供了一份很好模型格式转换工具,能够满足绝大多数人的需求了。原理很简单:首先用Keras读取.h5模型文件,然后用 tensorflow的convert_variables_to_constants函数将所有变量转换成常量,最后再write_graph就是一个包含了网络以及参数值的 .pb文件了。
如果你的Keras模型是一个包含了网络结构和权重的h5文件,那么使用下面的命令就可以了:
python keras_to_tensorflow.py
--input_model="h5_model_path/model.h5"
--output_model="save_pb_model_path/model.pb"
以上命令包含两个参数,第一个是模型输入路径,第二个模型输出路径。输出路径即使你没创建好,代码也会帮你创建。建议使用绝对路径。
注:该工具支持Tensorflow1.x版本
方法二:
使用下面的函数进行转换*(注:该函数支持Tensorflow1.x版本)*
def h5_to_pb(h5_model, output_dir, model_name, out_prefix="output_", log_tensorboard=True):
"""
.h5模型文件转换成pb模型文件
:param h5_model: .h5模型
:param output_dir: pb模型文件保存路径
:param model_name: pb模型文件名称
:param out_prefix: 根据训练,需要修改
:param log_tensorboard: 是否生成日志文件,默认为True
:return: pb模型文件
"""
if not os.path.exists(output_dir):
os.mkdir(output_dir)
out_nodes = []
for i in range(len(h5_model.outputs)):
out_nodes.append(out_prefix + str(i + 1))
tf.identity(h5_model.output[i], out_prefix + str(i + 1))
sess = backend.get_session()
from tensorflow.python.framework import graph_util, graph_io
# 写入pb模型文件
init_graph = sess.graph.as_graph_def()
main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)
graph_io.write_graph(main_graph, output_dir, name=model_name, as_text=False)
# 输出日志文件
if log_tensorboard:
from tensorflow.python.tools import import_pb_to_tensorboard
import_pb_to_tensorboard.import_to_tensorboard(os.path.join(output_dir, model_name), output_dir)
使用该函数前首先要加载h5模型文件,转换pb格式就只有两行代码
load_h5_model = load_model(h5_file_path, custom_objects=get_custom_objects())
h5_to_pb(load_h5_model, output_dir=pb_model_path, model_name=pb_model_name)
如果模型有自定义层,加载时要在custom_objects中写明,如
load_h5_model = load_model(h5_file_path, custom_objects={'CRF': CRF, 'crf_loss': crf_loss, 'crf_viterbi_accuracy': crf_viterbi_accuracy})
方法三:
利用tf2.x版本框架冻结图结构将h5文件转pb文件,具体参考下面函数
def frozen_graph(h5_file_path, pb_model_path):
"""
冻结模型,可以将训练好的.h5模型文件转成.pb文件
:param h5_file_path: h5模型文件路径
:param pb_model_path: pb模型文件保存路径
:return:
"""
# 加载模型,如有自定义层请参考方法二末尾处如何加载
model = tf.keras.models.load_model(h5_file_path, compile=False)
model.summary()
full_model = tf.function(lambda input_1: model(input_1))
full_model = full_model.get_concrete_function(tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))
# Get frozen ConcreteFunction
frozen_func = convert_variables_to_constants_v2(full_model)
frozen_func.graph.as_graph_def()
layers = [op.name for op in frozen_func.graph.get_operations()]
# print("-" * 50)
# print("Frozen model layers: ")
# for layer in layers:
# print(layer)
# print("-" * 50)
# print("Frozen model inputs: ")
# print(frozen_func.inputs)
# print("Frozen model outputs: ")
# print(frozen_func.outputs)
# Save frozen graph from frozen ConcreteFunction to hard drive
tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
logdir=pb_model_path,
name="model_name.pb",
as_text=False)
print('model has been saved')
使用该函数时可能会遇到的错误:“AttributeError: module 'tensorflow.python.framework.ops' has no attribute '_TensorLike'”
错误原因:tensorflow和keras版本不匹配
解决方法:升级keras或降级tensorflow,但必须保证tensorflow版本大于2.0
以上方法所有代码在下面框架版本中测试通过:
方法一和方法二:
keras==2.2.4
tensorflow-gpu==1.15.0
方法三:
keras==2.4.3
tensorflow-gpu==2.3.1
三、利用Tensorflow加载pb模型
我们已经将h5文件转换成pb文件了,那现在就要测试一下文件是否能加载成功以及预测情况
(一)使用tf1.x框架加载由h5转成的pb文件
加载及预测的主要代码如下
with tf.Session() as sess:
with gfile.FastGFile(pb_file_path, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
sess.graph.as_default()
tf.import_graph_def(graph_def)
# print all operation names
for op in sess.graph.get_operations():
print(op.name)
# 输入(此处get_tensor_by_name方法中的参数为model的第一层op的name)
input_x = sess.graph.get_tensor_by_name('import/input_1:0')
# 输出(此处get_tensor_by_name方法中的参数为model的最后一层op的name)
output = sess.graph.get_tensor_by_name('import/crf_1/one_hot:0')
# 预测结果,input_data为向量化后的预测输入,注意输入shape要与模型保持一致
ret = sess.run(output, {input_x: input_data})
(二)使用tf2.x框架加载由h5转成的pb文件
加载及预测的主要代码如下
with tf.compat.v1.Session() as sess:
with tf.io.gfile.GFile(pb_file_path, 'rb') as f:
graph_def = tf.compat.v1.GraphDef()
graph_def.ParseFromString(f.read())
sess.graph.as_default()
tf.compat.v1.import_graph_def(graph_def)
# print all operation names
# for op in sess.graph.get_operations():
# print(op.name)
# 输入
input_x = sess.graph.get_tensor_by_name('import/input_1:0')
# 输出
output = sess.graph.get_tensor_by_name('import/model_1/crf_1/one_hot:0')
# 预测结果
ret = sess.run(output, {input_x: input_data})
整体逻辑与tf1.x版本一致,只是把tf1.x的代码改为tf2.x的写法。
最后我们将keras训练好的命名实体识别模型h5文件转成pb文件,通过tensorflow框架加载pb格式的模型文件,成功识别了新句子中的命名实体,结果如下图:
最后
以上就是兴奋黑裤为你收集整理的keras冻结_Keras训练的h5文件转pb文件并用Tensorflow加载的全部内容,希望文章能够帮你解决keras冻结_Keras训练的h5文件转pb文件并用Tensorflow加载所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
发表评论 取消回复