概述
准备材料保存好的ckpt模型。
import tensorflow as tf
from model import mobilenetv2
import numpy as np
import scipy.io as si
checkpoint_dir ='./old_models/scene/'
data = si.loadmat('./datasets/test_32x32.mat')['X'].transpose((0,2,3,1))
label = si.loadmat('./datasets/test_32x32.mat')['y'][0]
img = np.expand_dims(data[0],axis=0)
targ = label[0]
#重新定义模型图
tf.reset_default_graph()
x = tf.placeholder(tf.float32,shape=[1,32,32,3],name='input')#重新定义输入
out = mobilenetv2(x,num_classes=11,is_train=False,reuse = False)#网络输出
ckpt=tf.train.get_checkpoint_state(checkpoint_dir)
aver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver.restore(sess,ckpt.model_checkpoint_path)
out1 = sess.run(out,feed_dict={x:img})
#将变量变为常量
output_grap_def = tf.graph_util.convert_variables_to_constants(sess,
sess.graph_def,
output_node_names=['mobilenetv2/Flatten/flatten/Reshape'])
#将模型以二进制的方式保存为pb文件
with tf.gfile.FastGFile('./models/scene_new.pb',mode='wb') as f1:
f1.write(output_grap_def.SerializeToString())
上面的代码对图像的输入节点进行了修改,而且我们是通过自己训练用的模型文件恢复图的结构的。事实上我们也可以通过ckpt文件的直接恢复图,这与我之前博客中提到模型调用的方法一致的。
import tensorflow as tf
output_node_names = 'Mean'
saver = tf.train.import_meta_graph('./models/model3_2/model.ckpt.meta')
grah = tf.get_default_graph()
input_graph_def = grah.as_graph_def()
with tf.Session() as sess:
saver.restore(sess,'./models/model3_2/model.ckpt' ) #恢复图并得到数据
output_graph_def = tf.graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
sess=sess,
input_graph_def=input_graph_def,# 等于:sess.graph_def
output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开
with tf.gfile.GFile('./models/freeze.pb', "wb") as f: #保存模型
f.write(output_graph_def.SerializeToString()) #序列化输出
print("%d ops in the final graph." % len(output_graph_def.node))
之后我们生成的网络可以通过python调用也可以直接部署到安卓端
最后
以上就是重要路灯为你收集整理的将tensorflow保存的ckpt文件转换成冻结的pb文件的全部内容,希望文章能够帮你解决将tensorflow保存的ckpt文件转换成冻结的pb文件所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复