概述
TensorFlow运作方式入门中文文档, 这里是直译过来的,所以很多逻辑顺序不是很合理,刚开始看的时候一脸懵,需要整体开下来,多看几遍,然后才能理解一点,不过由于很多代码是基于python2实现的,换成python3实现起来对于刚入门的我们不是那么容易,下面是源码,里面有带解释,希望能帮助大家更好的理解:
# -*- coding: utf-8 -*-
import os
import sys
import time
import argparse
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.examples.tutorials.mnist import mnist
FLAGS = None
# placeholder_inputs()函数将生成两个tf.placeholder操作,定义传入图表中的shape参数,
# shape参数中包括batch_size值,后续还会将实际的训练用例传入图表。
# 在训练循环(training loop)的后续步骤中,传入的整个图像和标签数据集会被切片,
# 以符合每一个操作所设置的batch_size值,占位符操作将会填补以符合这个batch_size值。
# 然后使用feed_dict参数,将数据传入sess.run()函数。
def placeholder_inputs(batch_size):
images_placeholder = tf.placeholder(tf.float32, shape=(batch_size, mnist.IMAGE_PIXELS))
labels_placeholder = tf.placeholder(tf.int32, shape=(batch_size))
return images_placeholder, labels_placeholder
# fill_feed_dict函数会查询给定的DataSet,索要下一批次batch_size的图像和标签,
# 与占位符相匹配的Tensor则会包含下一批次的图像和标签。
def fill_feed_dict(data_set, images_placeholder, labels_placeholder):
images_feed, labels_feed = data_set.next_batch(FLAGS.batch_size)
# 然后,以占位符为哈希键,创建一个Python字典对象,键值则是其代表的反馈Tensor。
feed_dict = {
images_placeholder: images_feed,
labels_placeholder: labels_feed
}
return feed_dict
# 对模型进行评估(eval即evaluation),Eval Output
def do_eval(sess, eval_correct, images_placeholder, labels_placeholder,
data_set):
true_count = 0
steps_per_epoch = data_set.num_examples
num_examples = steps_per_epoch * FLAGS.batch_size
for step in range(num_examples):
feed_dict = fill_feed_dict(data_set,
images_placeholder,
labels_placeholder)
# 累加所有in_top_k操作判定为正确的预测之和
true_count += sess.run(eval_correct, feed_dict=feed_dict)
# 准确率 = 正确测试的总数 除以例子总数
precision = float(true_count) / num_examples
print('Num examples: %d Num correct: %d Precision @ 1: %0.04f' %
(num_examples, true_count, precision))
def run_training():
# 在run_training()方法的一开始,input_data.read_data_sets()函数会确保你的本地训练文件夹中,
# 已经下载了正确的数据,然后将这些数据解压并返回一个含有DataSet实例的字典。
data_sets = input_data.read_data_sets(FLAGS.input_data_dir, FLAGS.fake_data)
# 在默认图表中创建模型
with tf.Graph().as_default():
# 生成两个tf.placeholder操作
images_placeholder, labels_placeholder = placeholder_inputs(FLAGS.batch_size)
# 构建从推理模型计算预测的图
logits = mnist.inference(images_placeholder, FLAGS.hidden1, FLAGS.hidden2)
# 将计算损失的操作添加到图表中
loss = mnist.loss(logits, labels_placeholder)
# 向图中添加计算和应用梯度的操作
train_op = mnist.training(loss, FLAGS.learning_rate)
# 添加Op以在评估期间将逻辑与标签进行比较
eval_correct = mnist.evaluation(logits, labels_placeholder)
# 所有的即时数据(在这里只有一个)都要在图表构建阶段合并至一个操作(op)中。
summary_op = tf.summary.merge_all()
# 添加变量初始化器Op
init = tf.global_variables_initializer()
# 保存检查点(checkpoint)
# 为了得到可以用来后续恢复模型以进一步训练或评估的检查点文件(checkpoint file),
# 我们实例化一个tf.train.Saver
saver = tf.train.Saver()
sess = tf.Session()
# 用于写入包含了图表本身和即时数据具体值的事件文件
summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph)
sess.run(init)
for step in range(FLAGS.max_steps):
start_time = time.time()
feed_dict = fill_feed_dict(data_sets.train, images_placeholder, labels_placeholder)
_, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)
duration = time.time() - start_time
if step % 100 == 0 :
print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration))
# 每次运行summary_op时,都会往事件文件中写入最新的即时数据
summary_str = sess.run(summary_op, feed_dict=feed_dict)
# 函数的输出会传入事件文件读写器(writer)的add_summary()函数
summary_writer.add_summary(summary_str, step)
summary_writer.flush()
# 保存检查点并定期评估模型
if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps:
checkpoint_file = os.path.join(FLAGS.log_dir, 'model.ckpt')
# 向训练文件夹中写入包含了当前所有可训练变量值得检查点文件
saver.save(sess, checkpoint_file, global_step=step)
print('使用训练数据集对模型进行评估')
do_eval(
sess,
eval_correct,
images_placeholder,
labels_placeholder,
data_sets.train
)
print('使用验证数据集对模型进行评估')
do_eval(
sess,
eval_correct,
images_placeholder,
labels_placeholder,
data_sets.validation
)
print('使用测试数据集对模型进行评估')
do_eval(
sess,
eval_correct,
images_placeholder,
labels_placeholder,
data_sets.test
)
def main(_):
if tf.gfile.Exists(FLAGS.log_dir):
tf.gfile.DeleteRecursively(FLAGS.log_dir)
tf.gfile.MakeDirs(FLAGS.log_dir)
run_training()
if __name__ == '__main__':
#
# 创建一个命令行解析模块的解析对象
parser = argparse.ArgumentParser()
# add_argumen()第一个是选项,第二个是数据类型,第三个默认值,第四个是help命令时的说明
parser.add_argument(
'--learning_rate',
type=float,
default=0.01,
help='Initial leaning rate'
)
parser.add_argument(
'--max_steps',
type=int,
default=2000,
help='Number of steps to run trainer.'
)
parser.add_argument(
'--hidden1',
type=int,
default=128,
help='Number of units in hedden layer 1.'
)
parser.add_argument(
'--hidden2',
type=int,
default=32,
help='Number of units in hedden layer 2.'
)
parser.add_argument(
'--batch_size',
type=int,
default=10,
help='Batch size. divide evenly into the dataset sizes.'
)
parser.add_argument(
'--input_data_dir',
type=str,
default=os.path.join(os.getenv('TEST_TMPDIR', 'tmp'), 'tensorflow/mnist/input_data'),
help='Directory to put the input data.'
)
parser.add_argument(
'--log_dir',
type=str,
default=os.path.join(os.getenv('TEST_TMPDIR', 'tmp'), 'tensorflow/mnist/logs/fully_connected_feed'),
help='Directory to put the log data.'
)
parser.add_argument(
'--fake_data',
default=False,
help='If true, uses fake data for unit testing.',
action='store_true'
)
# 有时间一个脚本只需要解析所有命令行参数中的一小部分,
# 剩下的命令行参数给两一个脚本或者程序
# 在接受到多余的命令行参数时不报错
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
如果有碰到问题,请在下方进行留言,将第一时间为你解决!
【The End!】
最后
以上就是凶狠枕头为你收集整理的Python3实现TensorFlow运作方式入门的全部内容,希望文章能够帮你解决Python3实现TensorFlow运作方式入门所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复