我是靠谱客的博主 失眠翅膀,最近开发中收集的这篇文章主要介绍TF2-callbacks,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

#CALLBACK(回调函数)
import tensorflow as tf
from tensorflow import keras
import matplotlib as mpl
from matplotlib import pyplot as plt
%matplotlib inline
import sklearn
import numpy as np
import pandas as pd
import os
import sys
import time
print(np.__version__)
1.16.4
#Get data:
fashion_mnist = keras.datasets.fashion_mnist
#download the mnist data
(x_train_all,y_train_all),(x_test,y_test) = fashion_mnist.load_data()
#split data to train and test
x_valid ,x_train = x_train_all[:5000],x_train_all[5000:]
#splot train data to train and valid
y_valid ,y_train = y_train_all[:5000],y_train_all[5000:]
#splot train data to train and valid
print(x_valid.shape,y_valid.shape)
print(x_train.shape,y_train.shape)
print(x_test.shape,y_test.shape)
(5000, 28, 28) (5000,)
(55000, 28, 28) (55000,)
(10000, 28, 28) (10000,)
#Normalization或StandarScaler(归一化)
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
#归一化训练集,归一化输入的是float类型的二维数据,所以先转化为二维数据,得到结果之后在把shape换回来
x_train_scaled = scaler.fit_transform(
x_train.astype(np.float32).reshape(-1,1)
).reshape(-1,28,28)
#归一化验证集
x_valid_scaled = scaler.transform(
x_valid.astype(np.float32).reshape(-1,1)
).reshape(-1,28,28)
#归一化测试集
x_test_scaled = scaler.transform(
x_test.astype(np.float32).reshape(-1,1)
).reshape(-1,28,28)
# Tf.keras.Sequential (bulit model)
model = keras.models.Sequential()
model.add(keras.layers.Flatten(input_shape = (28,28)))
model.add(keras.layers.Dense(300,activation = "relu"))
model.add(keras.layers.Dense(100,activation = "relu"))
model.add(keras.layers.Dense(10,activation = "softmax"))
#relu : y = max(x,0)
#softmax : change vertor to probability distributions
#why sparse: y-> index
,y ->one hot->[](vector)
model.compile(
loss = "sparse_categorical_crossentropy",
optimizer = "sgd",
metrics = ["accuracy"]
)
#TensorBoard , EarlyStopping , ModelCheckpoint
import os
logdir = './callbacks'
if not os.path.exists(logdir):
os.mkdir(logdir)
output_model_file = os.path.join(logdir,"fashion_mnist_model.h5")
callbacks = [
keras.callbacks.TensorBoard(logdir),
keras.callbacks.ModelCheckpoint(
output_model_file,
save_best_only=True
),
keras.callbacks.EarlyStopping(
patience=5,
min_delta = 1e-3
),
]
#Train
history = model.fit(
x_train_scaled,
y_train,
epochs=10,
validation_data=(x_valid_scaled,y_valid),
callbacks=callbacks
)
Train on 55000 samples, validate on 5000 samples
Epoch 1/10
55000/55000 [==============================] - 4s 70us/sample - loss: 0.9230 - accuracy: 0.6990 - val_loss: 0.6209 - val_accuracy: 0.7898
Epoch 2/10
55000/55000 [==============================] - 4s 65us/sample - loss: 0.5812 - accuracy: 0.7980 - val_loss: 0.5220 - val_accuracy: 0.8184
Epoch 3/10
55000/55000 [==============================] - 4s 65us/sample - loss: 0.5148 - accuracy: 0.8186 - val_loss: 0.4777 - val_accuracy: 0.8352
Epoch 4/10
55000/55000 [==============================] - 3s 63us/sample - loss: 0.4780 - accuracy: 0.8311 - val_loss: 0.4546 - val_accuracy: 0.8420
Epoch 5/10
55000/55000 [==============================] - 4s 64us/sample - loss: 0.4537 - accuracy: 0.8399 - val_loss: 0.4348 - val_accuracy: 0.8502
Epoch 6/10
55000/55000 [==============================] - 4s 64us/sample - loss: 0.4357 - accuracy: 0.8460 - val_loss: 0.4250 - val_accuracy: 0.8538
Epoch 7/10
55000/55000 [==============================] - 4s 64us/sample - loss: 0.4213 - accuracy: 0.8507 - val_loss: 0.4118 - val_accuracy: 0.8578
Epoch 8/10
55000/55000 [==============================] - 4s 64us/sample - loss: 0.4095 - accuracy: 0.8550 - val_loss: 0.4027 - val_accuracy: 0.8618
Epoch 9/10
55000/55000 [==============================] - 4s 67us/sample - loss: 0.3993 - accuracy: 0.8583 - val_loss: 0.4005 - val_accuracy: 0.8582
Epoch 10/10
55000/55000 [==============================] - 4s 64us/sample - loss: 0.3907 - accuracy: 0.8614 - val_loss: 0.3871 - val_accuracy: 0.8638
#在命令行激活python-tf2环境的情况下输入"tensorboard --logdir=callbacks" (callbacks是文件夹名称)
#返回
TensorBoard 1.14.0a20190301 at http://lowry-MS-7A37:6006 (Press CTRL+C to quit)
# 打开了一个本地服务器

最后

以上就是失眠翅膀为你收集整理的TF2-callbacks的全部内容,希望文章能够帮你解决TF2-callbacks所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(50)

评论列表共有 0 条评论

立即
投稿
返回
顶部