概述
tensorflow2.x 训练后的模型文件为 .data-00000-of-00001 和 .index,然后一般转化都是转化为pb模型和variables和assets文件,如下图所示
但是可能在某种场景,需要我们使用的是freeze graph,也就是以上三个文件冻结为一个pb文件。看到有人说训练时要转成.h5文件,这个方法我没试过,但是我的模型已经训完了,再训练太麻烦了。探索出的冻结方法如下。
单输入
首先通过 saved_model_cli 查看图中pb文件信息
根据输出信息可得,是单输入的模型,输入的shape为[batch, height, width, channel]
因此,可以通过以下代码进行freeze
import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
import pdb
# 只有一个input的写法
try:
new_model = tf.keras.models.load_model('det')
except Exception as e:
print(e)
print("load model fail")
exit()
full_model = tf.function(lambda x: new_model(x))
# 下面写成None是因为batch不固定
full_model = full_model.get_concrete_function(x=tf.TensorSpec(shape=(None, 640, 640, 3), dtype=tf.float32, name='img'))
forzen_func = convert_variables_to_constants_v2(full_model)
forzen_func.graph.as_graph_def()
layers = [op.name for op in forzen_func.graph.get_operations()]
print("-"*50)
print("Frozen model layers:")
for layer in layers:
print(layer)
print("*"*50)
print("Frozen model input:")
print(forzen_func.inputs)
print("Frozen model output:")
print(forzen_func.outputs)
tf.io.write_graph(
graph_or_graph_def=forzen_func.graph,
logdir="./",
name="det.pb",
as_text=False
)
多输入
然而,有时候模型比较复杂,需要两个输入或者多个输入,那么上面的代码就需要改动,还是通过 saved_model_cli 查看
根据输出信息可得,是双输入的模型,-1是不固定的意思,转换代码如下:
import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
import pdb
try:
new_model = tf.saved_model.load('rec')
infer = new_model.signatures["serving_default"]
except Exception as e:
print(e)
print("load model fail")
exit()
x_tensor_spec = tf.TensorSpec(shape=[None, 32, None, 1], dtype=tf.float32)
y_tensor_spec = tf.TensorSpec(shape=[None, None], dtype=tf.int32)
# inputs和mask可以通过saved_model_cli得到
full_model = tf.function(infer).get_concrete_function(inputs=x_tensor_spec, mask=y_tensor_spec)
# 下面都一样
forzen_func = convert_variables_to_constants_v2(full_model)
forzen_func.graph.as_graph_def()
layers = [op.name for op in forzen_func.graph.get_operations()]
print("-"*50)
print("Frozen model layers:")
for layer in layers:
print(layer)
print("*"*50)
print("Frozen model input:")
print(forzen_func.inputs)
print("Frozen model output:")
print(forzen_func.outputs)
tf.io.write_graph(
graph_or_graph_def=forzen_func.graph,
logdir="./",
name="rec.pb",
as_text=False
)
参考:
https://leimao.github.io/blog/Save-Load-Inference-From-TF2-Frozen-Graph/
不过参考的上面链接只有单input,多input的还是和作者说的不同,需要用我的代码转换。
我也在作者github的issues下作了回答
https://github.com/leimao/Frozen-Graph-TensorFlow/issues/5
最后
以上就是单薄水池为你收集整理的tf2 freeze冻结为pb(支持单输入和多输入)的全部内容,希望文章能够帮你解决tf2 freeze冻结为pb(支持单输入和多输入)所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
发表评论 取消回复