概述
存在两个pb模型的输入相同,现想把两个pb模型合并成一个模型并共用一个输入节点,借鉴合并多个tensorflow模型的办法
代码如下:
import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants
IMAGE_SIZE = 48 #假设模型输入是48x48
origin_model = "1.pb" #第一个模型
new_model = "2.pb" #第二个模型
def load_graphdef(filename):
with tf.gfile.GFile(filename, "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
return graph_def
def load_graph(graph_def, prefix):
with tf.Graph().as_default() as graph:
tf.import_graph_def(graph_def, name=prefix)
return graph
graph1 = load_graphdef(origin_model)
graph2 = load_graphdef(new_model)
x = tf.placeholder(dtype=tf.float32, shape=[None, IMAGE_SIZE, IMAGE_SIZE, 3], name='input')
graph1_out, = tf.import_graph_def(graph1, input_map={"input:0":x}, return_elements=['output:0'], name="model_origin") #其中“input:0”为模型输入节点名, return_elements返回的是输出节点名
graph2_out, = tf.import_graph_def(graph2, input_map={"input:0":x}, return_elements=['output:0'], name="model_new3class") #同上
z = tf.concat([graph1_out[:,:-3], graph2_out], 1) #假设最后三类用新的模型重新训练进行拼接
tf.identity(z, "output")
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init_op)
graph = convert_variables_to_constants(sess, sess.graph_def, ["output"])
tf.train.write_graph(graph, '.', 'merge.pb', as_text=False)
通过以上代码便可以将两个同样输入的pb模型合并
参考:https://www.cnblogs.com/th3Bear/p/11438310.html
最后
以上就是自信背包为你收集整理的Tensorflow 合并多个pb模型并共用一个输入的方法的全部内容,希望文章能够帮你解决Tensorflow 合并多个pb模型并共用一个输入的方法所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复