我是靠谱客的博主 繁荣黑猫,最近开发中收集的这篇文章主要介绍slim 读取并使用预训练模型 inception_v3 迁移学习,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

 

转自:https://blog.csdn.net/amanfromearth/article/details/79155926#commentBox

在使用Tensorflow做读取并finetune的时候,发现在读取官方给的inception_v3预训练模型总是出现各种错误,现记录其正确的读取方式和各种错误做法: 

关键代码如下:

import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.contrib.slim.python.slim.nets import inception_v3
from research.slim.preprocessing import inception_preprocessing
Pretrained_model_dir = '/Users/apple/tensorflow_model/models-master/research/slim/pre_train/inception_v3.ckpt'
image_size = 299
# 读取网络
with slim.arg_scope(inception_v3.inception_v3_arg_scope()):
imgPath = 'test.jpg'
testImage_string = tf.gfile.FastGFile(imgPath, 'rb').read()
testImage = tf.image.decode_jpeg(testImage_string, channels=3)
processed_image = inception_preprocessing.preprocess_image(testImage, image_size, image_size, is_training=False)
processed_images = tf.expand_dims(processed_image, 0)
logits, end_points = inception_v3.inception_v3(processed_images, num_classes=128, is_training=False)
w1 = tf.Variable(tf.truncated_normal([128, 5], stddev=tf.sqrt(1/128)))
b1 = tf.Variable(tf.zeros([5]))
logits = tf.nn.leaky_relu(tf.matmul(logits, w1) + b1)
with tf.Session() as sess:
# 先初始化所有变量,避免有些变量未读取而产生错误
init = tf.global_variables_initializer()
sess.run(init)
#加载预训练模型
print('Loading model check point from {:s}'.format(Pretrained_model_dir))
#这里的exclusions是不需要读取预训练模型中的Logits,因为默认的类别数目是1000,当你的类别数目不是1000的时候,如果还要读取的话,就会报错
exclusions = ['InceptionV3/Logits',
'InceptionV3/AuxLogits']
#创建一个列表,包含除了exclusions之外所有需要读取的变量
inception_except_logits = slim.get_variables_to_restore(exclude=exclusions)
#建立一个从预训练模型checkpoint中读取上述列表中的相应变量的参数的函数
init_fn = slim.assign_from_checkpoint_fn(Pretrained_model_dir, inception_except_logits,ignore_missing_vars=True)
#运行该函数
init_fn(sess)
print('Loaded.')
out = sess.run(logits)
print(out.shape)
print(out)

 

其中可能会出现的错误如下: 
错误1

 
  • 1
  • 2
  • 3

原因: 
预训练模型中的类别数class_num=1000,这里输入的class_num=5,当读取完整模型的时候当然会出错。 
解决方案: 
选择不读取包含类别数的Logits层和AuxLogits层:

 
  • 1
  • 2

错误2 
Tensor name “xxxx” not found in checkpoint files 

 
  • 1
  • 2
  • 3
  • 4

这里的Tensor name可以是所有inception_v3中变量的名字,出现这种情况的各种原因和解决方案是: 
1.创建图的时候没有用arg_scope,是这样创建的:

 
  • 1

解决方案: 
在这里加上arg_scope,里面调用的是库中自带的inception_v3_arg_scope

 
  • 1
  • 2

2.在读取checkpoint的时候未初始化所有变量,即未运行

 
  • 1
  • 2

这样会导致有一些checkpoint中不存在的变量未被初始化,比如使用Momentum时的每一层的Momentum参数等。

3.使用slim.assign_from_checkpoint_fn()函数时,没有添加ignore_missing_vars=True属性,由于默认ignore_missing_vars=False,所以,当使用非SGD的optimizer的时候(如Momentum、RMSProp等)时,会提示Momentum或者RMSProp的参数在checkpoint中无法找到,如: 
使用Momentum时:

 
  • 1
  • 2
  • 3
  • 4

使用RMSProp时:

 
  • 1
  • 2
  • 3
  • 4

解决方法很简单,就是把ignore_missing_vars=True

 
  • 1

注意:一定要在之前的步骤都完成之后才能设成True,不然如果变量名称全部出错的话,会忽视掉checkpoint中所有的变量,从而不读取任何参数。

以上就是我碰见的问题,希望有所帮助。

最后

以上就是繁荣黑猫为你收集整理的slim 读取并使用预训练模型 inception_v3 迁移学习的全部内容,希望文章能够帮你解决slim 读取并使用预训练模型 inception_v3 迁移学习所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(48)

评论列表共有 0 条评论

立即
投稿
返回
顶部