我是靠谱客的博主 自信背包,最近开发中收集的这篇文章主要介绍Tensorflow 合并多个pb模型并共用一个输入的方法,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

存在两个pb模型的输入相同,现想把两个pb模型合并成一个模型并共用一个输入节点,借鉴合并多个tensorflow模型的办法

代码如下:

import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants


IMAGE_SIZE = 48  #假设模型输入是48x48
origin_model = "1.pb" #第一个模型
new_model = "2.pb"    #第二个模型


def load_graphdef(filename):
    with tf.gfile.GFile(filename, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    return graph_def


def load_graph(graph_def, prefix):
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name=prefix)

    return graph


graph1 = load_graphdef(origin_model)
graph2 = load_graphdef(new_model)
x = tf.placeholder(dtype=tf.float32, shape=[None, IMAGE_SIZE, IMAGE_SIZE, 3], name='input')
graph1_out, = tf.import_graph_def(graph1, input_map={"input:0":x}, return_elements=['output:0'], name="model_origin")  #其中“input:0”为模型输入节点名, return_elements返回的是输出节点名
graph2_out, = tf.import_graph_def(graph2, input_map={"input:0":x}, return_elements=['output:0'], name="model_new3class")   #同上

z = tf.concat([graph1_out[:,:-3], graph2_out], 1)  #假设最后三类用新的模型重新训练进行拼接
tf.identity(z, "output")

init_op = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init_op)
    graph = convert_variables_to_constants(sess, sess.graph_def, ["output"])
    tf.train.write_graph(graph, '.', 'merge.pb', as_text=False)

通过以上代码便可以将两个同样输入的pb模型合并

参考:https://www.cnblogs.com/th3Bear/p/11438310.html

最后

以上就是自信背包为你收集整理的Tensorflow 合并多个pb模型并共用一个输入的方法的全部内容,希望文章能够帮你解决Tensorflow 合并多个pb模型并共用一个输入的方法所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(65)

评论列表共有 0 条评论

立即
投稿
返回
顶部