我是靠谱客的博主 单薄水池,最近开发中收集的这篇文章主要介绍tf2 freeze冻结为pb(支持单输入和多输入),觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

tensorflow2.x 训练后的模型文件为 .data-00000-of-00001 和 .index,然后一般转化都是转化为pb模型和variables和assets文件,如下图所示
在这里插入图片描述

但是可能在某种场景,需要我们使用的是freeze graph,也就是以上三个文件冻结为一个pb文件。看到有人说训练时要转成.h5文件,这个方法我没试过,但是我的模型已经训完了,再训练太麻烦了。探索出的冻结方法如下。

单输入

首先通过 saved_model_cli 查看图中pb文件信息在这里插入图片描述
根据输出信息可得,是单输入的模型,输入的shape为[batch, height, width, channel]

因此,可以通过以下代码进行freeze

import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
import pdb

# 只有一个input的写法
try:
    new_model = tf.keras.models.load_model('det')

except Exception as e:
    print(e)
    print("load model fail")
    exit()

full_model = tf.function(lambda x: new_model(x))
# 下面写成None是因为batch不固定
full_model = full_model.get_concrete_function(x=tf.TensorSpec(shape=(None, 640, 640, 3), dtype=tf.float32, name='img'))

forzen_func = convert_variables_to_constants_v2(full_model)
forzen_func.graph.as_graph_def()

layers = [op.name for op in forzen_func.graph.get_operations()]
print("-"*50)
print("Frozen model layers:")
for layer in layers:
    print(layer)

print("*"*50)
print("Frozen model input:")
print(forzen_func.inputs)
print("Frozen model output:")
print(forzen_func.outputs)

tf.io.write_graph(
    graph_or_graph_def=forzen_func.graph,
    logdir="./",
    name="det.pb",
    as_text=False
)

多输入

然而,有时候模型比较复杂,需要两个输入或者多个输入,那么上面的代码就需要改动,还是通过 saved_model_cli 查看
在这里插入图片描述
根据输出信息可得,是双输入的模型,-1是不固定的意思,转换代码如下:

import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
import pdb

try:
    new_model = tf.saved_model.load('rec')
    infer = new_model.signatures["serving_default"]

except Exception as e:
    print(e)
    print("load model fail")
    exit()

x_tensor_spec = tf.TensorSpec(shape=[None, 32, None, 1], dtype=tf.float32)
y_tensor_spec = tf.TensorSpec(shape=[None, None], dtype=tf.int32)    

# inputs和mask可以通过saved_model_cli得到
full_model = tf.function(infer).get_concrete_function(inputs=x_tensor_spec, mask=y_tensor_spec)

# 下面都一样
forzen_func = convert_variables_to_constants_v2(full_model)
forzen_func.graph.as_graph_def()

layers = [op.name for op in forzen_func.graph.get_operations()]
print("-"*50)
print("Frozen model layers:")
for layer in layers:
    print(layer)

print("*"*50)
print("Frozen model input:")
print(forzen_func.inputs)
print("Frozen model output:")
print(forzen_func.outputs)

tf.io.write_graph(
    graph_or_graph_def=forzen_func.graph,
    logdir="./",
    name="rec.pb",
    as_text=False
)

参考:

https://leimao.github.io/blog/Save-Load-Inference-From-TF2-Frozen-Graph/

不过参考的上面链接只有单input,多input的还是和作者说的不同,需要用我的代码转换。

我也在作者github的issues下作了回答

https://github.com/leimao/Frozen-Graph-TensorFlow/issues/5

最后

以上就是单薄水池为你收集整理的tf2 freeze冻结为pb(支持单输入和多输入)的全部内容,希望文章能够帮你解决tf2 freeze冻结为pb(支持单输入和多输入)所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(50)

评论列表共有 0 条评论

立即
投稿
返回
顶部