概述
本文采用tensorflow的slim库进行迁移学习,网站为:github-slim
参考:TensorFlow下slim库函数的使用以及使用VGG网络进行预训练、迁移学习(附代码)
slim 的源代码涉及多个 .py 文件,初学者不便阅读;针对不同的训练对象,需要修改的参数又分散在各个文件中,不太方便。因此这里将整个迁移学习流程整理为三个 .py 文件:creat_tfrecord.py 用于将样本转化为 tensorflow 的 tfrecord 格式;input_data.py 用于读取生成的 tfrecord 格式数据,并以队列的形式提供样本;finetune_mydata.py 是主要的 demo,它调用上述两个 .py 文件,需要修改的参数都已放在该文件的最前面。
迁移学习主代码
根据自己的数据库要修改的参数已经放在了代码的最前端。
from nets import vgg
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
import input_data
import os
import creat_tfrecord
slim = tf.contrib.slim
# ---- Dataset settings (adjust these for your own data) ----
DATA_DIR = './datasets/data/Sample'  # path of the dataset
NUM_CLASSES = 2  # number of output classes
NUM_TRAIN = 3600  # total number of training samples
NUM_VAL = 1200  # total number of validation samples
IMAGE_SIZE = vgg.vgg_16.default_image_size  # default input image size of vgg_16

# ---- Checkpoint and logging paths ----
checkpoint_file = './model/vgg_16.ckpt'  # officially downloaded checkpoint file
save_dir = './result/vgg16/fine_tune'  # where the fine-tuned model is saved
trained_file = 'ships_fine_tune.ckpt'  # file name of the fine-tuned model
log_dir = './logs/train'  # training log directory

# ---- Training parameters ----
batch_size = 64
learning_rate = 0.0001
training_epochs = 10  # number of epochs
display_epoch = 1  # report once every `display_epoch` epochs
# float() keeps this a true division so that ceil() really rounds up even
# where `/` on two ints would truncate.
train_num_batch = int(np.ceil(NUM_TRAIN / float(batch_size)))  # training batches per epoch
val_num_batch = int(np.ceil(NUM_VAL / float(batch_size)))  # validation batches per epoch
def Ships_fine_tuning():
    '''
    Fine-tune a pre-trained VGG16 network on the custom dataset.

    The network is built with ``NUM_CLASSES`` outputs. When no checkpoint
    exists in ``save_dir`` every variable except those under ``vgg_16/fc8``
    is restored from ``checkpoint_file``; otherwise training resumes from the
    latest checkpoint found in ``save_dir``. The optimizer is created without
    a ``var_list``, so all trainable variables are updated.

    Side effects: creates ``save_dir`` if needed, generates the TFRecord files
    through ``creat_tfrecord.run``, writes summaries to ``log_dir`` and saves
    one checkpoint per epoch under ``save_dir``.
    '''
    # 1. Prepare the output directory and load the data.
    if not tf.gfile.Exists(save_dir):
        tf.gfile.MakeDirs(save_dir)

    # Split the samples into training / validation sets and write them as
    # TFRecord files.
    creat_tfrecord.run(DATA_DIR, NUM_VAL)

    # Batch tensors for the training and the validation split.
    train_images, train_labels = input_data.get_batch_images_and_label(
        DATA_DIR, batch_size, NUM_CLASSES, True, IMAGE_SIZE, IMAGE_SIZE)
    test_images, test_labels = input_data.get_batch_images_and_label(
        DATA_DIR, batch_size, NUM_CLASSES, False, IMAGE_SIZE, IMAGE_SIZE)

    # Argument scope holding the default layer arguments of the model.
    arg_scope = vgg.vgg_arg_scope()
    # arg_scope = resnet_v1.resnet_arg_scope()

    # 2. Placeholders and network structure.
    with slim.arg_scope(arg_scope):
        # Input images.
        input_images = tf.placeholder(
            dtype=tf.float32, shape=[None, IMAGE_SIZE, IMAGE_SIZE, 3])
        # One-hot image labels.
        input_labels = tf.placeholder(
            dtype=tf.float32, shape=[None, NUM_CLASSES])
        # Training or testing flag passed on to vgg_16. (Original author's
        # note: at test time dropout is effectively disabled.)
        is_training = tf.placeholder(dtype=tf.bool)

        # Build the VGG16 network. (Original author's note: to freeze all
        # layers, trainable=False can be specified for slim.conv2d.)
        logits, end_points = vgg.vgg_16(
            input_images, is_training=is_training, num_classes=NUM_CLASSES)
        # print(end_points)  # every element is named vgg_16/xx

    # Restore every variable of the current graph except the fc8 layer from
    # the checkpoint. This must be collected before the optimizer is created
    # so that optimizer variables are not part of the list.
    params = slim.get_variables_to_restore(exclude=['vgg_16/fc8'])
    # A Saver built from an explicit list only saves / restores those
    # variables.
    restorer = tf.train.Saver(params)
    # Alternative kept from the original as a note: collect the variables
    # scope by scope with slim.get_variables(scope="vgg_16/conv1") ...
    # slim.get_variables(scope="vgg_16/fc7") and pass that list to the Saver.
    # Asking to restore a variable that is missing from the checkpoint raises
    # an error.

    # 3. Cost function and optimizer.
    cost = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(
            labels=input_labels, logits=logits))
    loss_summary = tf.summary.scalar('loss', cost)

    optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost)

    # Accuracy: compare predicted labels with the true labels.
    pred = tf.argmax(logits, axis=1)
    correct = tf.equal(pred, tf.argmax(input_labels, 1))
    accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
    acc_summary = tf.summary.scalar('accuracy', accuracy)

    # Saver used to write (and resume from) the fine-tuned checkpoints.
    save = tf.train.Saver(max_to_keep=training_epochs)

    with tf.Session() as sess:
        merged = tf.summary.merge([loss_summary, acc_summary])
        # Write the training log to log_dir.
        train_writer = tf.summary.FileWriter(log_dir, sess.graph)

        sess.run(tf.global_variables_initializer())

        # Resume from the most recent fine-tuned checkpoint if there is one.
        ckpt = tf.train.latest_checkpoint(save_dir)
        if ckpt is not None:
            save.restore(sess, ckpt)
            print('从上次训练保存后的模型继续训练!')
        else:
            restorer.restore(sess, checkpoint_file)
            print('从官方模型加载训练!')

        # Coordinator managing the queue-runner threads; the file names only
        # start being enqueued once the queue runners are started.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        # The original also carried a disabled snippet (section 4) that
        # fetched one training and one validation batch and displayed the
        # first image of each with matplotlib.

        try:
            print('开始训练!')
            for epoch in range(training_epochs):
                train_loss = 0.0
                for i in range(train_num_batch):
                    imgs, labs = sess.run([train_images, train_labels])
                    _, loss, train_summaries = sess.run(
                        [optimizer, cost, merged],
                        feed_dict={input_images: imgs,
                                   input_labels: labs,
                                   is_training: True})
                    # Log every batch. The global step advances by
                    # train_num_batch per epoch (the number of steps actually
                    # run), not by batch_size.
                    train_writer.add_summary(
                        train_summaries, (i + 1) + epoch * train_num_batch)
                    train_loss += loss

                # Print progress information.
                if epoch % display_epoch == 0:
                    # Accuracy is measured on the last training batch only.
                    train_accuracy = sess.run(
                        accuracy,
                        feed_dict={input_images: imgs,
                                   input_labels: labs,
                                   is_training: False})
                    print('Epoch {}/{}  average cost {:.9f}  train accuracy {:.2f}'.format(
                        epoch + 1, training_epochs,
                        train_loss / train_num_batch, train_accuracy))

                # Evaluate on the validation split.
                val_accuracy = 0.0
                val_loss = 0.0
                for j in range(val_num_batch):
                    imgs, labs = sess.run([test_images, test_labels])
                    cost_values, accuracy_values = sess.run(
                        [cost, accuracy],
                        feed_dict={input_images: imgs,
                                   input_labels: labs,
                                   is_training: False})
                    val_accuracy += accuracy_values
                    val_loss += cost_values
                print('Epoch {}/{}  Test cost {:.9f} Test accuracy {:.2f}'.format(
                    epoch + 1, training_epochs,
                    val_loss / val_num_batch, val_accuracy / val_num_batch))

                # Save the model.
                save.save(sess, os.path.join(save_dir, trained_file),
                          global_step=epoch)
                print('Epoch {}/{}  模型保存成功'.format(epoch + 1, training_epochs))

            print('训练完成')
        finally:
            # Stop the queue threads and flush the log even if training
            # failed part-way through.
            coord.request_stop()
            coord.join(threads)
            train_writer.close()
if __name__ == '__main__':
    # Remove any graph that already exists before building the model.
    tf.reset_default_graph()
    Ships_fine_tuning()
读取样本生成tfrecord文件
当你的样本数量过大时,要调高_NUM_SHARDS
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import os
import random
import sys
import pdb
import tensorflow as tf
from datasets import dataset_utils
# Seed for repeatability.
_RANDOM_SEED = 0
# The number of shards per dataset split.
# Increase this value if you have a very large number of samples.
_NUM_SHARDS = 2
class ImageReader(object):
    """Small wrapper around a TensorFlow JPEG-decoding op."""

    def __init__(self):
        # String placeholder fed with the encoded bytes, and the op decoding
        # it as a 3-channel image.
        self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
        self._decode_jpeg = tf.image.decode_jpeg(
            self._decode_jpeg_data, channels=3)

    def read_image_dims(self, sess, image_data):
        """Decode `image_data` in `sess` and return the first two dimensions
        of the result (documented by callers as height, width)."""
        decoded = self.decode_jpeg(sess, image_data)
        dim0 = decoded.shape[0]
        dim1 = decoded.shape[1]
        return dim0, dim1

    def decode_jpeg(self, sess, image_data):
        """Run the decode op on `image_data` and return the decoded image.

        Asserts that the result has three dimensions, the last of size 3.
        """
        feed = {self._decode_jpeg_data: image_data}
        decoded = sess.run(self._decode_jpeg, feed_dict=feed)
        # pdb.set_trace()
        shape = decoded.shape
        assert len(shape) == 3
        assert shape[2] == 3
        return decoded
def _get_filenames_and_classes(dataset_dir):
    """Returns a list of filenames and inferred class names.

    Args:
      dataset_dir: A directory containing one sub-directory per class. Each
        sub-directory is named after its class and holds that class's samples.

    Returns:
      A list of image file paths (each one prefixed with `dataset_dir`) and
      the list of sub-directory names, representing class names. Both lists
      are in sorted order so that a later seeded shuffle gives the same split
      regardless of the order in which the file system lists entries.
    """
    # Only sub-directories count as classes; plain files placed directly in
    # dataset_dir are ignored.
    class_names = sorted(
        entry for entry in os.listdir(dataset_dir)
        if os.path.isdir(os.path.join(dataset_dir, entry)))

    photo_filenames = []
    for class_name in class_names:
        class_dir = os.path.join(dataset_dir, class_name)
        for entry in sorted(os.listdir(class_dir)):
            path = os.path.join(class_dir, entry)
            # A nested directory is not a sample; skip it instead of handing
            # it to the image decoder later on.
            if os.path.isfile(path):
                photo_filenames.append(path)

    return photo_filenames, class_names
def _get_dataset_filename(dataset_dir, split_name, shard_id):
    """Return the path, inside `dataset_dir`, of the TFRecord file for shard
    `shard_id` of split `split_name` (out of `_NUM_SHARDS` shards)."""
    name_fields = (split_name, shard_id, _NUM_SHARDS)
    basename = 'myimage_%s_%05d-of-%05d.tfrecord' % name_fields
    return os.path.join(dataset_dir, basename)
def _convert_dataset(split_name, filenames, class_names_to_ids, dataset_dir):
    """Converts the given filenames to a TFRecord dataset.

    The files are divided into `_NUM_SHARDS` consecutive chunks and each chunk
    is written to its own TFRecord file. The class of a sample is the name of
    the directory that contains it.

    Args:
      split_name: The name of the dataset, either 'train' or 'validation'.
      filenames: A list of paths to the image files. Every image is recorded
        with format b'jpg' and decoded with the JPEG decoder of ImageReader.
      class_names_to_ids: A dictionary from class names (strings) to ids
        (integers).
      dataset_dir: The directory where the converted datasets are stored.
    """
    assert split_name in ['train', 'validation']

    num_per_shard = int(math.ceil(len(filenames) / float(_NUM_SHARDS)))

    with tf.Graph().as_default():
        image_reader = ImageReader()

        with tf.Session('') as sess:
            for shard_id in range(_NUM_SHARDS):
                output_filename = _get_dataset_filename(
                    dataset_dir, split_name, shard_id)

                with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
                    start_ndx = shard_id * num_per_shard
                    end_ndx = min((shard_id + 1) * num_per_shard, len(filenames))
                    for i in range(start_ndx, end_ndx):
                        # '\r' returns to the start of the line so the
                        # progress message is rewritten in place.
                        sys.stdout.write('\r>> Converting image %d/%d shard %d' % (
                            i + 1, len(filenames), shard_id))
                        sys.stdout.flush()

                        # Read the file; the context manager closes the
                        # handle as soon as the bytes are in memory.
                        with tf.gfile.FastGFile(filenames[i], 'rb') as image_file:
                            image_data = image_file.read()
                        print(filenames[i])
                        height, width = image_reader.read_image_dims(sess, image_data)

                        class_name = os.path.basename(os.path.dirname(filenames[i]))
                        class_id = class_names_to_ids[class_name]

                        example = dataset_utils.image_to_tfexample(
                            image_data, b'jpg', height, width, class_id)
                        tfrecord_writer.write(example.SerializeToString())

    # Finish the progress line.
    sys.stdout.write('\n')
    sys.stdout.flush()
def _dataset_exists(dataset_dir):
    """Return True only if the TFRecord file of every shard of both the
    'train' and the 'validation' split exists in `dataset_dir`."""
    expected_files = (
        _get_dataset_filename(dataset_dir, split, shard)
        for split in ['train', 'validation']
        for shard in range(_NUM_SHARDS))
    # all() stops at the first missing file, like the original early return.
    return all(tf.gfile.Exists(path) for path in expected_files)
def run(dataset_dir, _NUM_VALIDATION):
    """
    Read the samples, split them into a training and a validation set and
    convert both to TFRecord format.

    Nothing is regenerated when all expected TFRecord files already exist.

    Args:
      dataset_dir: The dataset directory where the dataset is stored; the
        TFRecord files and the labels file are written there as well.
      _NUM_VALIDATION: Number of samples, taken from the front of the
        shuffled file list, that form the validation set. The remaining
        samples form the training set.
    """
    if not tf.gfile.Exists(dataset_dir):
        tf.gfile.MakeDirs(dataset_dir)

    if _dataset_exists(dataset_dir):
        print('Dataset files already exist. Exiting without re-creating them.')
        return

    photo_filenames, class_names = _get_filenames_and_classes(dataset_dir)
    class_names_to_ids = dict(zip(class_names, range(len(class_names))))

    # Divide into train and test:
    random.seed(_RANDOM_SEED)
    random.shuffle(photo_filenames)
    training_filenames = photo_filenames[_NUM_VALIDATION:]
    validation_filenames = photo_filenames[:_NUM_VALIDATION]

    # First, convert the training and validation sets.
    _convert_dataset('train', training_filenames, class_names_to_ids,
                     dataset_dir)
    _convert_dataset('validation', validation_filenames, class_names_to_ids,
                     dataset_dir)

    # Finally, write the labels file:
    labels_to_class_names = dict(zip(range(len(class_names)), class_names))
    dataset_utils.write_label_file(labels_to_class_names, dataset_dir)

    print('\nFinished converting the dataset!')
读取tfrecord文件并以队列形式供给
要调整的参数已放在最前端。
另外,在对数据进行预处理时,我选用的是tensorflow默认的图片放缩函数以及标准化函数,用这两个函数处理过的训练结果较好,见后面。如果采用slim库的预处理函数反而训练结果不好。
import tensorflow as tf
import os
from preprocessing import vgg_preprocessing
from datasets import dataset_utils
slim = tf.contrib.slim
# Pattern of the TFRecord file names; '%s' is replaced by the split name.
_FILE_PATTERN = 'myimage_%s_*.tfrecord'

# Number of samples in each split. Change these to the sizes of your own
# dataset; here 3600 means 3600 samples.
SPLITS_TO_SIZES = {'train': 3600, 'validation': 1200}

# Number of classes. Change this to your own number of classes.
_NUM_CLASSES = 2

_ITEMS_TO_DESCRIPTIONS = {
    'image': 'A color image of varying size.',
    # Derived from _NUM_CLASSES so the stated label range cannot drift from
    # the configured number of classes.
    'label': 'A single integer between 0 and %d' % (_NUM_CLASSES - 1),
}
def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
    """Builds the slim `Dataset` describing one split of the TFRecord data.

    Args:
      split_name: A train/validation split name (a key of SPLITS_TO_SIZES).
      dataset_dir: The base directory of the dataset sources.
      file_pattern: The file pattern to use when matching the dataset sources.
        It is assumed that the pattern contains a '%s' string so that the
        split name can be inserted. Falls back to _FILE_PATTERN when empty.
      reader: The TensorFlow reader type; tf.TFRecordReader when None.

    Returns:
      A `Dataset` namedtuple.

    Raises:
      ValueError: if `split_name` is not a valid train/validation split.
    """
    if split_name not in SPLITS_TO_SIZES:
        raise ValueError('split name %s was not recognized.' % split_name)

    pattern = file_pattern or _FILE_PATTERN
    data_sources = os.path.join(dataset_dir, pattern % split_name)

    # None is allowed in the signature so that callers can ask for the
    # default reader.
    reader_type = tf.TFRecordReader if reader is None else reader

    # How the serialized tf.Example features are parsed.
    feature_spec = {
        'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),
        'image/class/label': tf.FixedLenFeature(
            [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
    }
    # How the parsed features are turned into the 'image' and 'label' items.
    item_handlers = {
        'image': slim.tfexample_decoder.Image(),
        'label': slim.tfexample_decoder.Tensor('image/class/label'),
    }
    example_decoder = slim.tfexample_decoder.TFExampleDecoder(
        feature_spec, item_handlers)

    # The label map is only read when dataset_utils reports that the
    # directory has one.
    if dataset_utils.has_labels(dataset_dir):
        label_map = dataset_utils.read_label_file(dataset_dir)
    else:
        label_map = None

    return slim.dataset.Dataset(
        data_sources=data_sources,
        reader=reader_type,
        decoder=example_decoder,
        num_samples=SPLITS_TO_SIZES[split_name],
        items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
        num_classes=_NUM_CLASSES,
        labels_to_names=label_map)
def read_image_and_label(dataset_dir,is_training=False):
    '''
    Read one sample from the TFRecord data.

    args:
        dataset_dir: directory containing the dataset
        is_training: True loads the 'train' split, otherwise the
            'validation' split is loaded
    return:
        image,label: two tensors holding one image and its label, obtained
            from a slim DatasetDataProvider (original author's note: the
            sample is read at random)
    '''
    # Pick the split to read.
    chosen_split = 'train' if is_training else 'validation'
    dataset = get_split(split_name=chosen_split, dataset_dir=dataset_dir)

    # Create a data provider and ask it for the two items of one sample.
    data_provider = slim.dataset_data_provider.DatasetDataProvider(dataset)
    image, label = data_provider.get(['image', 'label'])
    return image, label
def get_batch_images_and_label(dataset_dir,batch_size,num_classes,is_training=False,output_height=224, output_width=224,num_threads=10):
    '''
    Take out batch_size samples at a time.

    Note: preprocessing here does not use the slim preprocessing functions
    (the vgg_preprocessing call is left disabled). Each image is resized with
    tf.image.resize_images (method=0, bilinear interpolation) and then
    standardized with tf.image.per_image_standardization. For your own
    network you may substitute a preprocessing function that suits your
    images.

    args:
        dataset_dir: directory containing the dataset
        batch_size: number of samples taken out at a time
        num_classes: number of output classes, used to one-hot encode labels
        is_training: True loads the training set, otherwise the validation set
        output_height: height of the output images
        output_width: width of the output images
        num_threads: number of threads passed to tf.train.batch
    return:
        images,labels: a batch of batch_size images and the one-hot encoding
            of the corresponding labels
    '''
    # Get a single image and label.
    image, label = read_image_and_label(dataset_dir, is_training)

    # Disabled alternative from the original:
    # image = vgg_preprocessing.preprocess_image(image, output_height, output_width,is_training=is_training)

    # Resize, then standardize. resize_images expects the size as
    # [new_height, new_width], so height comes first.
    resized_image = tf.image.resize_images(
        image, [output_height, output_width], method=0)
    image = tf.image.per_image_standardization(resized_image)

    # Other disabled alternatives from the original:
    # image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    # image = tf.image.resize_image_with_crop_or_pad(image, output_height, output_width)

    # Original author's note: shuffle_batch would shuffle the order of the
    # data, while batch does not.
    images, labels = tf.train.batch(
        [image, label],
        batch_size=batch_size,
        capacity=5 * batch_size,
        num_threads=num_threads)

    # One-hot encoding.
    labels = slim.one_hot_encoding(labels, num_classes)
    return images, labels
运行
这里我训练10个epoch,结果如下,可以说训练集的准确率达到了1,验证集的准确率也已达到95%,取得了不错的结果。
最后用tensorboard可以查看训练的loss和accuracy图,代码如下
tensorboard --logdir=logs/train
测试单张图片
def test_on_image_tf(image_path='./test/6.jpg'):
    '''
    Classify a single image with the fine-tuned network (plain TensorFlow
    form: the preprocessed image is fed through a placeholder).

    The image is resized and standardized the same way as the training
    pipeline, the latest checkpoint in ``save_dir`` is restored, the original
    image is displayed and the predicted class is printed.

    args:
        image_path: path of the JPEG image to classify
    '''
    # Load and preprocess the image.
    org_image = tf.image.decode_jpeg(tf.read_file(image_path), channels=3)
    crop_image = tf.image.resize_images(
        org_image, [IMAGE_SIZE, IMAGE_SIZE], method=0)
    image = tf.image.per_image_standardization(crop_image)
    # Reshape to the batch form required by the network input.
    image = tf.reshape(image, [1, IMAGE_SIZE, IMAGE_SIZE, 3])

    # Plain TensorFlow needs placeholders: the input images and the
    # training/testing flag passed on to vgg_16.
    input_images = tf.placeholder(
        dtype=tf.float32, shape=[None, IMAGE_SIZE, IMAGE_SIZE, 3])
    is_training = tf.placeholder(dtype=tf.bool)

    # Build the network inside the model's argument scope.
    arg_scope = vgg.vgg_arg_scope()
    with slim.arg_scope(arg_scope):
        logits, end_points = vgg.vgg_16(
            input_images, is_training=is_training, num_classes=NUM_CLASSES)

    # Predicted label.
    pred = tf.argmax(logits, axis=1)

    restorer = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        ckpt = tf.train.latest_checkpoint(save_dir)
        if ckpt is None:
            # Without trained weights the prediction would come from a
            # randomly initialised network, so report and stop instead.
            print('未找到已训练的模型: {}'.format(save_dir))
            return

        # Restore the model.
        restorer.restore(sess, ckpt)
        print("Model restored.")

        org_imgs, standard_imgs = sess.run([org_image, image])
        pred_value = sess.run(
            pred,
            feed_dict={input_images: standard_imgs, is_training: False})

        plt.imshow(org_imgs)
        plt.title('Original test image')
        plt.show()

        # pred has one entry per image in the batch; there is exactly one.
        if pred_value[0] == 0:
            print('预测结果为:非船舶')
        else:
            print('预测结果为:船舶')
def test_on_image_slim(image_path='./test/6.jpg'):
    '''
    Classify a single image with the fine-tuned network, slim style: the
    preprocessed image tensor is passed to the network directly, so no
    placeholder is needed.

    The latest checkpoint in ``save_dir`` is restored, the original image is
    displayed and the predicted class is printed.

    args:
        image_path: path of the JPEG image to classify
    '''
    # Load and preprocess the image.
    org_image = tf.image.decode_jpeg(tf.read_file(image_path), channels=3)
    crop_image = tf.image.resize_images(
        org_image, [IMAGE_SIZE, IMAGE_SIZE], method=0)
    image = tf.image.per_image_standardization(crop_image)
    # Reshape to the batch form required by the network input.
    image = tf.reshape(image, [1, IMAGE_SIZE, IMAGE_SIZE, 3])

    # Build the network inside the model's argument scope and feed it the
    # image tensor.
    arg_scope = vgg.vgg_arg_scope()
    with slim.arg_scope(arg_scope):
        logits, end_points = vgg.vgg_16(
            image, is_training=False, num_classes=NUM_CLASSES)

    # Predicted label.
    pred = tf.argmax(logits, axis=1)

    restorer = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        ckpt = tf.train.latest_checkpoint(save_dir)
        if ckpt is None:
            # Without trained weights the prediction would come from a
            # randomly initialised network, so report and stop instead.
            print('未找到已训练的模型: {}'.format(save_dir))
            return

        # Restore the model.
        restorer.restore(sess, ckpt)
        print("Model restored.")

        # Two fetches, two results. The original unpacked them into three
        # names, which raises ValueError.
        org_imgs, pred_value = sess.run([org_image, pred])

        plt.imshow(org_imgs)
        plt.title('Original test image')
        plt.show()

        # pred has one entry per image in the batch; there is exactly one.
        if pred_value[0] == 0:
            print('预测结果为:非船舶')
        else:
            print('预测结果为:船舶')
上述两种代码皆可
最后
以上就是现实铃铛为你收集整理的【tensorflow】利用slim进行迁移学习的全部内容,希望文章能够帮你解决【tensorflow】利用slim进行迁移学习所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复