概述
最近做比赛遇到一个问题,tf1训练的模型提交后因为环境问题导致线上无法运行,故尝试线上用tf2进行推理。
步骤需要1.将tf1的.pd模型结构和ckpt存档导出。2.将模型结构和ckpt存档转化为.pd静态图(frozen graph)。3.使用tf2读取.pd静态图(frozen graph)进行推理
首先定义基于tf1的模型
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
with tf.Session() as sess:
a = tf.placeholder(tf.float32, [1])
b = tf.placeholder(tf.float32, [1])
c = tf.get_variable("w", [1])
d = a*c
out = tf.add(d, b)
# 初始化变量
sess.run(tf.variables_initializer(tf.global_variables()))
# 保存图结构
tf.train.write_graph(sess.graph_def, './', 'graph_define.pb', as_text=True)
# 保存参数存档
saver = tf.train.Saver()
saver.save(sess, 'checkpoint.ckpt')
之后转化模型为静态图(frozen graph)
import tensorflow as tf
from tensorflow.python.tools import freeze_graph
# 转化模型
with tf.compat.v1.Session() as sess:
freeze_graph.freeze_graph(
input_graph='./graph_define.pb',
input_saver='',
input_binary=False,
input_checkpoint='./checkpoint.ckpt',
output_node_names='Add',
restore_op_name='save/restore_all',
filename_tensor_name='save/Const:0',
output_graph='./frozen_model.pb',
clear_devices=False,
initializer_nodes=''
)
其中input_graph是图结构文件,input_checkpoint是参数文件,output_node_names是输出节点名
需要注意的是模型训练时必须是基于原生tf1的,如果训练时使用tf.compat.v1而不引入tf.disable_v2_behavior(),会在转换时报错。
最后使用tf2加载静态图进行推理
import tensorflow as tf
def wrap_frozen_graph(graph_def, inputs, outputs):
def _imports_graph_def():
tf.compat.v1.import_graph_def(graph_def, name="")
wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
import_graph = wrapped_import.graph
return wrapped_import.prune(
tf.nest.map_structure(import_graph.as_graph_element, inputs),
tf.nest.map_structure(import_graph.as_graph_element, outputs))
def load_pb(filename):
# Load frozen graph using TensorFlow 1.x functions
with tf.io.gfile.GFile(filename, "rb") as f:
graph_def = tf.compat.v1.GraphDef()
loaded = graph_def.ParseFromString(f.read())
# Wrap frozen graph to ConcreteFunctions
frozen_func = wrap_frozen_graph(
graph_def=graph_def,
inputs=['Placeholder:0', 'Placeholder_1:0'], # input tensor name of your model
outputs="Add:0" # output tensor name of your model
)
return frozen_func
model = load_pb('frozen_model.pb')
print(model(tf.constant(3, tf.float32), tf.constant(4, tf.float32)))
实测比赛代码也可以在线上运行了
参考资料:
TF 保存模型为 .pb格式 - 静悟生慧 - 博客园
tensorflow的三种保存格式总结-1(.ckpt) - 知乎
[深度学习] TensorFlow中模型的freeze_graph - 知乎
如何在TF2中使用TF1.x的.pb模型_MD笔记-CSDN博客
最后
以上就是危机大象为你收集整理的将tf1训练的模型导入tf2进行推理的全部内容,希望文章能够帮你解决将tf1训练的模型导入tf2进行推理所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复