概述
转自:https://blog.csdn.net/amanfromearth/article/details/79155926#commentBox
在使用Tensorflow做读取并finetune的时候,发现在读取官方给的inception_v3预训练模型总是出现各种错误,现记录其正确的读取方式和各种错误做法:
关键代码如下:
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.contrib.slim.python.slim.nets import inception_v3
from research.slim.preprocessing import inception_preprocessing
Pretrained_model_dir = '/Users/apple/tensorflow_model/models-master/research/slim/pre_train/inception_v3.ckpt'
image_size = 299
# 读取网络
with slim.arg_scope(inception_v3.inception_v3_arg_scope()):
imgPath = 'test.jpg'
testImage_string = tf.gfile.FastGFile(imgPath, 'rb').read()
testImage = tf.image.decode_jpeg(testImage_string, channels=3)
processed_image = inception_preprocessing.preprocess_image(testImage, image_size, image_size, is_training=False)
processed_images = tf.expand_dims(processed_image, 0)
logits, end_points = inception_v3.inception_v3(processed_images, num_classes=128, is_training=False)
w1 = tf.Variable(tf.truncated_normal([128, 5], stddev=tf.sqrt(1/128)))
b1 = tf.Variable(tf.zeros([5]))
logits = tf.nn.leaky_relu(tf.matmul(logits, w1) + b1)
with tf.Session() as sess:
# 先初始化所有变量,避免有些变量未读取而产生错误
init = tf.global_variables_initializer()
sess.run(init)
#加载预训练模型
print('Loading model check point from {:s}'.format(Pretrained_model_dir))
#这里的exclusions是不需要读取预训练模型中的Logits,因为默认的类别数目是1000,当你的类别数目不是1000的时候,如果还要读取的话,就会报错
exclusions = ['InceptionV3/Logits',
'InceptionV3/AuxLogits']
#创建一个列表,包含除了exclusions之外所有需要读取的变量
inception_except_logits = slim.get_variables_to_restore(exclude=exclusions)
#建立一个从预训练模型checkpoint中读取上述列表中的相应变量的参数的函数
init_fn = slim.assign_from_checkpoint_fn(Pretrained_model_dir, inception_except_logits,ignore_missing_vars=True)
#运行该函数
init_fn(sess)
print('Loaded.')
out = sess.run(logits)
print(out.shape)
print(out)
其中可能会出现的错误如下:
错误1
- 1
- 2
- 3
原因:
预训练模型中的类别数class_num=1000,这里输入的class_num=5,当读取完整模型的时候当然会出错。
解决方案:
选择不读取包含类别数的Logits层和AuxLogits层:
- 1
- 2
错误2
Tensor name “xxxx” not found in checkpoint files
- 1
- 2
- 3
- 4
这里的Tensor name可以是所有inception_v3中变量的名字,出现这种情况的各种原因和解决方案是:
1.创建图的时候没有用arg_scope,是这样创建的:
- 1
解决方案:
在这里加上arg_scope,里面调用的是库中自带的inception_v3_arg_scope
- 1
- 2
2.在读取checkpoint的时候未初始化所有变量,即未运行
- 1
- 2
这样会导致有一些checkpoint中不存在的变量未被初始化,比如使用Momentum时的每一层的Momentum参数等。
3.使用slim.assign_from_checkpoint_fn()
函数时,没有添加ignore_missing_vars=True
属性,由于默认ignore_missing_vars=False,所以,当使用非SGD的optimizer的时候(如Momentum、RMSProp等)时,会提示Momentum或者RMSProp的参数在checkpoint中无法找到,如:
使用Momentum时:
- 1
- 2
- 3
- 4
使用RMSProp时:
- 1
- 2
- 3
- 4
解决方法很简单,就是把ignore_missing_vars=True
- 1
注意:一定要在之前的步骤都完成之后才能设成True,不然如果变量名称全部出错的话,会忽视掉checkpoint中所有的变量,从而不读取任何参数。
以上就是我碰见的问题,希望有所帮助。
最后
以上就是繁荣黑猫为你收集整理的slim 读取并使用预训练模型 inception_v3 迁移学习的全部内容,希望文章能够帮你解决slim 读取并使用预训练模型 inception_v3 迁移学习所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
发表评论 取消回复