概述
evaluate(self, input_fn, steps=None, hooks=None, checkpoint_path=None, name=None)
作用
使用验证集 input_fn 对 model 进行验证。
对于每一步,执行 input_fn(返回数据集的一个 batch)。
- 已经进行了
steps
个 batch,或者 input_fn
抛出了出界异常(OutOfRangeError
或StopIteration
)
参数
input_fn
:此函数构造出验证所需的输入数据,需要返回以下结构之一:
- 一个
tf.data.Dataset
对象:Dataset
对象的输出必须是一个元组 (features, labels),和下面的规格相同。 - 一个元组 (features, labels):
features
是一个Tensor
或者字典(a dictionary of string feature name toTensor
)。labels
是一个Tensor
或者字典(a dictionary of string label name toTensor
)。features
和labels
都被model_fn
所使用(model_fn
是tf.estimator.Estimator
的构造函数的参数之一)。他们应该满足model_fn
输入端的需求。
steps
:验证模型的步数。如果是 None
,则一直验证下去,直至input_fn
抛出了出界异常。
hooks
:SessionRunHook
子类实例的 list。作为验证的回调函数。
checkpoint_path
:特定检查点的路径。如果是 None
,则默认为 model_dir
中最近的检查点(model_dir
是 tf.estimator.Estimator
的构造函数的参数之一)
name
:验证的名字。使用者可以针对不同的数据集运行多个验证操作,比如训练集 vs 测试集。不同验证的结果被保存在不同的文件夹中,且分别出现在 tensorboard 中。
返回值
返回一个字典,包括 model_fn
中指定的评价指标、global_step
。
异常抛出
ValueError
:如果 step
小于等于0
ValueError
:如果 model_dir
指定的模型没有被训练,或者指定的 checkpoint_path
为空。
示例
先定义Estimator
:
cnn_model = tf.estimator.Estimator(
model_fn=model_function, model_dir=save_model_path
)
然后进行训练:
cnn_model.train(
input_fn=lambda: get_train_batch(train_file_path), steps=steps_per_eval)
最后进行验证:
evaluate_results = cnn_model.evaluate(
input_fn=lambda: get_val_batch(val_file_path),
steps=eval_steps_per_train_cycle)
其中,数据是从 tfrecords 中读取的:
def get_train_batch(data_dir, batch_size=conf.batch_size, set_name='train', use_distortion=True):
dataset = DataSet(data_dir, set_name, use_distortion)
return dataset.get_batch(data_dir, batch_size)
def get_val_batch(data_dir, batch_size=conf.batch_size, set_name='val', use_distortion=False):
dataset = DataSet(data_dir, set_name, use_distortion)
return dataset.get_batch(data_dir, batch_size)
class DataSet(object):
....
def get_batch(self, file_path, batch_size):
"""
:param batch_size: train, val, test batch_size is different
:param file_path:
:return:
"""
files = tf.data.Dataset.list_files(file_path)
dataset = files.apply(
tf.contrib.data.parallel_interleave(tf.data.TFRecordDataset, cycle_length=conf.num_parallel_readers,
sloppy=True))
if self.set_name == 'train':
dataset = dataset.repeat(conf.train_epochs)
dataset = dataset.shuffle(conf.shuffle_buffer_size)
dataset = dataset.apply(tf.contrib.data.map_and_batch(map_func=self.parser_single_img, batch_size=batch_size,
num_parallel_batches=conf.num_parallel_batches))
dataset = dataset.prefetch(conf.batch_size)
iterator = dataset.make_one_shot_iterator()
img_batch, label_batch = iterator.get_next()
return img_batch, label_batch
最后
以上就是激情往事为你收集整理的Tensorflow API 讲解——tf.estimator.Estimator.evaluate的全部内容,希望文章能够帮你解决Tensorflow API 讲解——tf.estimator.Estimator.evaluate所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
发表评论 取消回复