我是靠谱客的博主 善良蜗牛,这篇文章主要介绍Python深度学习入门之mnist-VGG(Tensorflow2.0实现),现在分享给大家,希望可以做个参考。

MNIST手写数字数据集是深度学习中最常用的数据集之一。本文以MNIST数据集为例,利用TensorFlow 2.0框架搭建VGG网络,完成MNIST手写数字识别任务,并绘制训练过程中损失和准确率的变化曲线。

Demo完整代码如下:

复制代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
"""Train a small VGG-style CNN on MNIST with TensorFlow 2 / Keras.

Produces the module-level ``history_VGG16`` (the ``model.fit`` History object)
that the plotting script reads, evaluates on the held-out test set and saves
the weights to ``save_model/VGG_minst/``.
"""
import os

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Sequential, datasets, layers, metrics, optimizers, regularizers

# ---------------------------------------------------------------------------
# Data: load MNIST
# ---------------------------------------------------------------------------
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Scale pixel values to [0, 1] as float32 and append a channel axis (axis=3)
# so the images match the (28, 28, 1) input shape used below.
x_train, x_test = x_train.astype(np.float32) / 255., x_test.astype(np.float32) / 255.
x_train, x_test = np.expand_dims(x_train, axis=3), np.expand_dims(x_test, axis=3)

# Carve the last 10000 training images off as a validation set:
# 50000 train / 10000 validation / 10000 test.
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_train = x_train[:-10000]
y_train = y_train[:-10000]

# One-hot labels, as required by CategoricalCrossentropy.
y_train = tf.one_hot(y_train, depth=10).numpy()
y_val = tf.one_hot(y_val, depth=10).numpy()
y_test = tf.one_hot(y_test, depth=10).numpy()

# tf.data pipelines, batch size 100. The training set is shuffled before
# batching (the original pipeline never shuffled, so every pass saw the
# samples in the same fixed order). shuffle() reshuffles on each iteration
# by default, so each repeat gets a new order.
train_dataset = (tf.data.Dataset.from_tensor_slices((x_train, y_train))
                 .shuffle(10000)
                 .batch(100)
                 .repeat())
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(100).repeat()
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(100).repeat()

# ---------------------------------------------------------------------------
# Model: reduced VGG-style network
# ---------------------------------------------------------------------------
input_shape = (28, 28, 1)
weight_decay = 0.001
num_classes = 10


def _add_conv_unit(net, filters, **conv_kwargs):
    """Append a 3x3 'same' Conv2D (L2-regularized) -> ReLU -> BatchNorm unit to *net*."""
    net.add(layers.Conv2D(filters, (3, 3), padding='same',
                          kernel_regularizer=regularizers.l2(weight_decay),
                          **conv_kwargs))
    net.add(layers.Activation('relu'))
    net.add(layers.BatchNormalization())


model = tf.keras.Sequential()

# Stage 1: 64 filters, 28x28 -> 14x14.
_add_conv_unit(model, 64, input_shape=input_shape)
model.add(layers.Dropout(0.3))
_add_conv_unit(model, 64)
model.add(layers.MaxPooling2D(pool_size=(2, 2)))

# Stage 2: 128 filters, 14x14 -> 7x7.
_add_conv_unit(model, 128)
model.add(layers.Dropout(0.4))
_add_conv_unit(model, 128)
model.add(layers.MaxPooling2D(pool_size=(2, 2)))

# Stage 3: 256 filters, 7x7 -> 3x3.
_add_conv_unit(model, 256)
model.add(layers.Dropout(0.4))
_add_conv_unit(model, 256)
model.add(layers.Dropout(0.4))
_add_conv_unit(model, 256)
model.add(layers.MaxPooling2D(pool_size=(2, 2)))

# Full VGG16 has two more 512-filter stages. They are left out here: each
# 2x2 max-pool halves the map (3x3 -> 1x1 -> 0x0), so a fifth pooling stage
# cannot be applied to a 28x28 MNIST input.

# Classifier head.
model.add(layers.Flatten())
model.add(layers.Dense(512, kernel_regularizer=regularizers.l2(weight_decay)))
model.add(layers.Activation('relu'))
model.add(layers.BatchNormalization())
model.add(layers.Dropout(0.5))
model.add(layers.Dense(num_classes, activation='softmax'))

# Optimizer, loss and metric.
model.compile(optimizer=tf.keras.optimizers.Adam(0.001),
              loss=tf.keras.losses.CategoricalCrossentropy(),
              metrics=['acc'])

# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------
# Each Keras "epoch" here is only steps_per_epoch=20 batches of 100 images
# (2000 images). 100 such epochs are about 4 passes over the 50000 training
# images. Each validation check uses 3 batches (300 images).
history_VGG16 = model.fit(train_dataset, epochs=100, steps_per_epoch=20,
                          validation_data=val_dataset, validation_steps=3)

# ---------------------------------------------------------------------------
# Evaluation and checkpoint
# ---------------------------------------------------------------------------
# 100 steps x 100 images = exactly one pass over the 10000 test images.
test_loss, test_acc = model.evaluate(test_dataset, steps=100)
print(f'test loss: {test_loss:.4f}, test acc: {test_acc:.4f}')

weights_path = 'save_model/VGG_minst/VGG_mnist_weights.ckpt'
# Make sure the target directory exists before writing the checkpoint.
os.makedirs(os.path.dirname(weights_path), exist_ok=True)
model.save_weights(weights_path)

网络参数

在这里插入图片描述

测试集评估结果

在这里插入图片描述

绘制曲线代码

复制代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
"""Plot the training curves recorded in ``history_VGG16`` and save them as a PNG.

Expects ``history_VGG16`` (the History object returned by ``model.fit`` in
the training script) to already exist in the same session.
"""
import os

import matplotlib.pyplot as plt

plt.figure(figsize=(12, 8), dpi=80)

# Training loss, training accuracy and validation accuracy for each epoch.
plt.plot(history_VGG16.epoch, history_VGG16.history.get('loss'), color='r', label='loss')
plt.plot(history_VGG16.epoch, history_VGG16.history.get('acc'), color='g', linestyle='-.', label='acc')
plt.plot(history_VGG16.epoch, history_VGG16.history.get('val_acc'), color='b', linestyle='--', label='val_acc')

# Legend (loc defaults to 'best').
plt.legend()
# Dashed grid; alpha sets its transparency.
plt.grid(True, linestyle='--', alpha=0.5)

# Axis labels and title.
plt.xlabel('epochs')
plt.ylabel('loss/acc')
plt.title('VGG16_Curve of loss/acc Change with epochs in Mnist')

# savefig does not create missing directories, so ensure the output folder exists.
os.makedirs('./save_png', exist_ok=True)
plt.savefig('./save_png/VGG16_mnist.png')
plt.show()

曲线图

在这里插入图片描述

最后

以上就是善良蜗牛最近收集整理的关于Python深度学习入门之mnist-VGG(Tensorflow2.0实现)的全部内容,更多相关Python深度学习入门之mnist-VGG(Tensorflow2.0实现)的内容请搜索靠谱客的其他文章。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(66)

评论列表共有 0 条评论

立即
投稿
返回
顶部