概述
源码
# !/usr/bin/python
# -*- coding: UTF-8 -*-
import tensorflow as tf
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics
import numpy as np
(train_x,train_y),(test_x,test_y) = datasets.mnist.load_data()
train_x = tf.convert_to_tensor(train_x, dtype=tf.float32) / 255.
train_y = train_y.astype(np.int64)
train_dataset = tf.data.Dataset.from_tensor_slices((train_x,train_y))
# batch_size = 100,数据集最多重复10次
train_dataset = train_dataset.batch(100).repeat(10)
# 用keras.Sequential构建一个模型,并从keras.optimizers实例化一个随机梯度下降优化器。
model = tf.keras.Sequential([
layers.Reshape(target_shape = (28 * 28,), input_shape=(28, 28)),
layers.Dense(256, activation = tf.nn.relu),
layers.Dense(256, activation = tf.nn.relu),
layers.Dense(256, activation = tf.nn.relu),
layers.Dense(10)
])
model.summary()
optimizer = optimizers.Adam(lr=1e-3)
acc = metrics.Accuracy()
# @tf.function
# def compute_loss(logits,label):
#
return tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels = label))
for step,(x,y) in enumerate(train_dataset):
"""
使用tf.GradientTape相对于网络的可训练变量手动计算损耗的梯度。GradientTape只是TensorFlow 2.0中执行梯度步骤的多种方法之一
Tf.GradientTape:通过在上下文管理器中记录操作,针对给定变量手动计算损耗梯度。这是执行优化程序步骤的最灵活的方法,因为我们可以直接使用渐变,而无需预先定义的Keras模型或损失函数。
Model.train():Keras的内置函数,用于遍历数据集并在其上拟合Keras.Model。这通常是训练Keras模型的最佳选择,并带有进度条显示,验证拆分,多处理和生成器支持的选项。
Optimizer.minimize():通过给定的损失函数进行计算和微分,并执行一个步骤以通过梯度下降将其最小化。此方法易于实现,并且可以方便地应用于任何现有的计算图上,以进行有效的优化步骤。
"""
with tf.GradientTape() as tape:
# loss = compute_loss(logits=output, label=y)
output = model(x)
# [batch_size,28,28] => [batch_size,10]
y_onehot = tf.one_hot(y, depth=10) # [batch_size,1] => [batch_size,10]
loss = tf.square(output - y_onehot)
loss = tf.reduce_mean(loss)# [batch_size,10] => [batch_size,1]
# 更新准确率
acc.update_state(tf.argmax(output,axis = 1),y)
# 求梯度
grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
# 没200 步打印一次,并且重新统计准确率
if step % 200 == 0:
print(step, 'loss:', float(loss), 'acc:', acc.result().numpy())
acc.reset_states()
最后
以上就是尊敬心锁为你收集整理的tf2学习 mnist的全部内容,希望文章能够帮你解决tf2学习 mnist所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复