我是靠谱客的博主 重要路灯,最近开发中收集的这篇文章主要介绍将tensorflow保存的ckpt文件转换成冻结的pb文件,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

准备材料保存好的ckpt模型。

import tensorflow as tf
from model import mobilenetv2
import numpy as np
import scipy.io as si

checkpoint_dir ='./old_models/scene/'
data = si.loadmat('./datasets/test_32x32.mat')['X'].transpose((0,2,3,1))
label = si.loadmat('./datasets/test_32x32.mat')['y'][0]
img = np.expand_dims(data[0],axis=0)
targ = label[0]

#重新定义模型图
tf.reset_default_graph()

x = tf.placeholder(tf.float32,shape=[1,32,32,3],name='input')#重新定义输入
out = mobilenetv2(x,num_classes=11,is_train=False,reuse = False)#网络输出

ckpt=tf.train.get_checkpoint_state(checkpoint_dir)

aver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.restore(sess,ckpt.model_checkpoint_path)
    out1 = sess.run(out,feed_dict={x:img})
    #将变量变为常量
    output_grap_def = tf.graph_util.convert_variables_to_constants(sess,
                                                                   sess.graph_def,
                                                                   output_node_names=['mobilenetv2/Flatten/flatten/Reshape'])
    #将模型以二进制的方式保存为pb文件
    with tf.gfile.FastGFile('./models/scene_new.pb',mode='wb') as f1:
        f1.write(output_grap_def.SerializeToString())

上面的代码对图像的输入节点进行了修改,而且我们是通过自己训练用的模型文件恢复图的结构的。事实上我们也可以通过ckpt文件的直接恢复图,这与我之前博客中提到模型调用的方法一致的。

import tensorflow as tf

output_node_names = 'Mean'

saver = tf.train.import_meta_graph('./models/model3_2/model.ckpt.meta')
grah = tf.get_default_graph()
input_graph_def = grah.as_graph_def()

with tf.Session() as sess:
        saver.restore(sess,'./models/model3_2/model.ckpt' ) #恢复图并得到数据
        output_graph_def = tf.graph_util.convert_variables_to_constants(  # 模型持久化,将变量值固定
            sess=sess,
            input_graph_def=input_graph_def,# 等于:sess.graph_def
            output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开
 
        with tf.gfile.GFile('./models/freeze.pb', "wb") as f: #保存模型
            f.write(output_graph_def.SerializeToString()) #序列化输出
        print("%d ops in the final graph." % len(output_graph_def.node)) 

之后我们生成的网络可以通过python调用也可以直接部署到安卓端

最后

以上就是重要路灯为你收集整理的将tensorflow保存的ckpt文件转换成冻结的pb文件的全部内容,希望文章能够帮你解决将tensorflow保存的ckpt文件转换成冻结的pb文件所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(56)

评论列表共有 0 条评论

立即
投稿
返回
顶部