我是靠谱客的博主 危机大象,最近开发中收集的这篇文章主要介绍将tf1训练的模型导入tf2进行推理,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

最近做比赛遇到一个问题,tf1训练的模型提交后因为环境问题导致线上无法运行,故尝试线上用tf2进行推理。

步骤需要1.将tf1的.pd模型结构和ckpt存档导出。2.将模型结构和ckpt存档转化为.pd静态图(frozen graph)。3.使用tf2读取.pd静态图(frozen graph)进行推理

首先定义基于tf1的模型

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()


with tf.Session() as sess:
    a = tf.placeholder(tf.float32, [1])
    b = tf.placeholder(tf.float32, [1])
    c = tf.get_variable("w", [1])
    d = a*c
    out = tf.add(d, b)
    # 初始化变量
    sess.run(tf.variables_initializer(tf.global_variables()))
    # 保存图结构
    tf.train.write_graph(sess.graph_def, './', 'graph_define.pb', as_text=True)
    # 保存参数存档
    saver = tf.train.Saver()
    saver.save(sess, 'checkpoint.ckpt')

之后转化模型为静态图(frozen graph)

import tensorflow as tf
from tensorflow.python.tools import freeze_graph

# 转化模型
with tf.compat.v1.Session() as sess:
    freeze_graph.freeze_graph(
        input_graph='./graph_define.pb',
        input_saver='',
        input_binary=False,
        input_checkpoint='./checkpoint.ckpt',
        output_node_names='Add',
        restore_op_name='save/restore_all',
        filename_tensor_name='save/Const:0',
        output_graph='./frozen_model.pb',
        clear_devices=False,
        initializer_nodes=''
    )

其中input_graph是图结构文件,input_checkpoint是参数文件,output_node_names是输出节点名

需要注意的是模型训练时必须是基于原生tf1的,如果训练时使用tf.compat.v1而不引入tf.disable_v2_behavior(),会在转换时报错。

最后使用tf2加载静态图进行推理

import tensorflow as tf


def wrap_frozen_graph(graph_def, inputs, outputs):
    def _imports_graph_def():
        tf.compat.v1.import_graph_def(graph_def, name="")

    wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
    import_graph = wrapped_import.graph

    return wrapped_import.prune(
        tf.nest.map_structure(import_graph.as_graph_element, inputs),
        tf.nest.map_structure(import_graph.as_graph_element, outputs))


def load_pb(filename):
    # Load frozen graph using TensorFlow 1.x functions
    with tf.io.gfile.GFile(filename, "rb") as f:
        graph_def = tf.compat.v1.GraphDef()
        loaded = graph_def.ParseFromString(f.read())

    # Wrap frozen graph to ConcreteFunctions
    frozen_func = wrap_frozen_graph(
        graph_def=graph_def,
        inputs=['Placeholder:0', 'Placeholder_1:0'],  # input tensor name of your model
        outputs="Add:0"  # output tensor name of your model
    )
    return frozen_func


model = load_pb('frozen_model.pb')

print(model(tf.constant(3, tf.float32), tf.constant(4, tf.float32)))

实测比赛代码也可以在线上运行了

参考资料:

TF 保存模型为 .pb格式 - 静悟生慧 - 博客园

tensorflow的三种保存格式总结-1(.ckpt) - 知乎

[深度学习] TensorFlow中模型的freeze_graph - 知乎

如何在TF2中使用TF1.x的.pb模型_MD笔记-CSDN博客

最后

以上就是危机大象为你收集整理的将tf1训练的模型导入tf2进行推理的全部内容,希望文章能够帮你解决将tf1训练的模型导入tf2进行推理所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(41)

评论列表共有 0 条评论

立即
投稿
返回
顶部