TensorFlow Pitfalls: tf.data Data Loading
- Overview
- Problem 1:
- Problem description:
- Solution:
- Problem 2:
- Problem description:
- Solution:
Overview
This series records some of the problems I ran into while training deep learning models with the TensorFlow framework.
Problem 1:
Problem description:
Training on multiple GPUs with a tf.data input pipeline, tf.summary.image showed that every GPU received the same batch of images in each step. Why?
Part of the code:
def train_model_multi_gpu(sess):
    inf_dir = os.path.join("./experiments/logs/", cfg.MODEL_TYPE)
    if not os.path.exists(inf_dir):
        os.makedirs(inf_dir)
    with tf.device('/cpu:0'):
        global_step = tf.get_variable('global_step', [],
                                      initializer=tf.constant_initializer(0),
                                      trainable=False)
        learning_rate = configure_learning_rate(global_step)
        optimizer = configure_optimizer(learning_rate)
        # bucket by sequence length: batch tensors built ONCE, outside the GPU loop
        images_batch_bucket, labels_batch_bucket, im_width_batch_bucket = data_input_bucket()
        tower_grads = []
        with tf.variable_scope(tf.get_variable_scope()):
            for i in range(cfg.TRAIN.NUM_GPUS):
                with tf.device('/gpu:%d' % i):
                    with tf.name_scope('%s_%d' % ("TOWER", i)) as scope:
                        # loss by bucket: every tower reuses the SAME batch tensors
                        im_batch = images_batch_bucket
                        label_batch = labels_batch_bucket
                        width_batch = im_width_batch_bucket
                        loss = tower_loss_bucket(scope, im_batch, label_batch, width_batch)
                        ...
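One way to surface this symptom is to attach an image summary inside each tower's name scope and compare the panels in TensorBoard. A minimal sketch, assuming the names above; the helper itself is illustrative and not part of the original code:

# Hypothetical diagnostic helper: one image summary per tower.
# If every TOWER_*/inputs panel in TensorBoard shows identical pictures,
# the towers are sharing a single batch.
def add_tower_image_summary(tower_idx, image_batch):
    with tf.name_scope('TOWER_%d_summaries' % tower_idx):
        tf.summary.image('inputs', image_batch, max_outputs=3)

Called with im_batch inside the tower loop above, this is how the identical batches became visible.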
Solution:
Reference: multi-gpu-towers-training-methods
In the code block above, images_batch_bucket, labels_batch_bucket, im_width_batch_bucket = data_input_bucket()
fetches the batch tensors through a tf iterator that is built once, outside the per-GPU loop, so every tower reads the same batch. With tf.data you should instead call "Iterator.get_next() once per GPU to get multiple different batches". The corrected code:
def train_model_multi_gpu(sess):
    inf_dir = os.path.join(cfg.MODEL_DIR, cfg.MODEL_TYPE)
    if not os.path.exists(inf_dir):
        os.makedirs(inf_dir)
    with tf.device('/cpu:0'):
        global_step = tf.get_variable('global_step', [],
                                      initializer=tf.constant_initializer(0),
                                      trainable=False)
        learning_rate = configure_learning_rate(global_step)
        optimizer = configure_optimizer(learning_rate)
        tower_grads = []
        with tf.variable_scope(tf.get_variable_scope()):
            for i in range(cfg.TRAIN.NUM_GPUS):
                with tf.device('/gpu:%d' % i):
                    # create a dataset and iterator per GPU
                    with tf.name_scope('%s_%d' % ("TOWER", i)) as scope:
                        # bucket by sequence length: fetch a fresh batch inside
                        # each tower, so every GPU gets different data
                        images_batch_bucket, labels_batch_bucket, im_width_batch_bucket = data_input_bucket()
                        loss = tower_loss_bucket(scope, images_batch_bucket,
                                                 labels_batch_bucket, im_width_batch_bucket)
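The same point in isolation: build one tf.data pipeline, then call Iterator.get_next() once per GPU. Each call creates a separate op that advances the iterator, so the towers receive distinct batches. A minimal runnable sketch, where the dataset contents, batch size, and GPU count are placeholders:

import tensorflow as tf

NUM_GPUS = 2  # placeholder

# One dataset and one iterator, built once on the host.
dataset = tf.data.Dataset.range(100).batch(8)
iterator = dataset.make_one_shot_iterator()

tower_batches = []
for i in range(NUM_GPUS):
    with tf.device('/gpu:%d' % i):
        # get_next() is called once PER GPU: each call is a distinct op,
        # so running them together yields a different batch per tower.
        tower_batches.append(iterator.get_next())

with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
    b0, b1 = sess.run(tower_batches)
    print(b0)  # e.g. [0 1 2 3 4 5 6 7]
    print(b1)  # e.g. [ 8  9 10 11 12 13 14 15]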
Problem 2:
Problem description:
The model trains normally, but the generated checkpoint file does not contain all the variables, so tf.train.Saver.restore fails because the expected variables are not found. Why?
Solution:
Reference: graph & saver: no variables to save
saver = tf.train.Saver() was created in the wrong place. It should be constructed only after the whole graph has been built, so that it "Gets all variables in graph":
with graph.as_default():
    # [Variable and model creation goes here.]
    saver = tf.train.Saver()  # Gets all variables in `graph`.
with tf.Session(graph=graph) as sess:
    # restore() requires the checkpoint path as its second argument;
    # checkpoint_path here stands in for the path to an existing checkpoint.
    saver.restore(sess, checkpoint_path)
    # Do some work with the model....
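For contrast, a minimal sketch of the failure mode, with placeholder variable names and checkpoint paths: a Saver constructed halfway through graph building silently captures only the variables that exist at that moment, which is exactly why a later restore complains about missing variables. (Constructing a Saver before any variable exists raises ValueError: No variables to save.)

import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    v1 = tf.get_variable('v1', shape=[2])
    early_saver = tf.train.Saver()  # created mid-build: only sees v1
    v2 = tf.get_variable('v2', shape=[2])
    full_saver = tf.train.Saver()   # created after the full graph: sees v1 and v2
    init_op = tf.global_variables_initializer()

with tf.Session(graph=graph) as sess:
    sess.run(init_op)
    early_saver.save(sess, '/tmp/early.ckpt')
    full_saver.save(sess, '/tmp/full.ckpt')

# Inspect what each checkpoint actually stored:
print(tf.train.NewCheckpointReader('/tmp/early.ckpt').get_variable_to_shape_map())
# -> {'v1': [2]}             (v2 is missing, so restoring v2 from it fails)
print(tf.train.NewCheckpointReader('/tmp/full.ckpt').get_variable_to_shape_map())
# -> {'v1': [2], 'v2': [2]}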