概述
TF2.0模型创建
-
- 概述
- 1、通过tensorflow.keras.Sequential构造器创建模型
- 2、使用函数式API创建模型
- 3、通过继承tensorflow.keras.Model类定义自己的模型
- 下一篇
-
- TF2.0模型训练
概述
这是TF2.0入门笔记【TF2.0模型创建、TF2.0模型训练、TF2.0模型保存】中第一篇【TF2.0模型创建】,本篇将介绍模型的创建。
- tensorflow2.0模型创建方法我划分为三种方式:
- 1、通过tensorflow.keras.Sequential构造器创建模型
- 2、使用函数式API创建模型
- 3、通过继承tensorflow.keras.Model类定义自己的模型
接下来将用代码分别演示去构建一个简单的模型
1、通过tensorflow.keras.Sequential构造器创建模型
第一种:通过tensorflow.keras.Sequential构造器创建模型
该方法就是不断堆叠你需要的层,如该模型带参数的一共有三层,一个卷积层,两个全连接层(卷积之后通过 Flatten 层将其展平,从而接全连接层 )。
需要注意的是要在第一层指定输入形状 input_shape 。
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Conv2D, Input
from tensorflow.keras import Model, Sequential
model1 = Sequential()
model1.add(Conv2D(32, 3, activation='relu', padding='same', input_shape=(28, 28, 1)))
model1.add(Flatten())
model1.add(Dense(128, activation='relu'))
model1.add(Dense(10, activation='softmax'))
model1.summary()
运行输出:
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d (Conv2D) (None, 28, 28, 32) 320
_________________________________________________________________
flatten (Flatten) (None, 25088) 0
_________________________________________________________________
dense (Dense) (None, 128) 3211392
_________________________________________________________________
dense_1 (Dense) (None, 10) 1290
=================================================================
Total params: 3,213,002
Trainable params: 3,213,002
Non-trainable params: 0
_________________________________________________________________
2、使用函数式API创建模型
第二种:使用函数式API创建模型
这种方法比较灵活、自由,你可以轻易的创建多输入、多输出的模型。
用的 Input 层指定每个样本的形状,不管批次大小。最后通过 Model 类根据输入和输出来创建模型
input = Input((28,28,1))
x=Conv2D(32, 3, activation='relu', padding='same')(input)
x=Flatten()(x)
x=Dense(128, activation='relu')(x)
output=Dense(10, activation='softmax')(x)
model2=Model(input,output)
model2.summary()
运行输出:
Model: "model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) [(None, 28, 28, 1)] 0
_________________________________________________________________
conv2d_1 (Conv2D) (None, 28, 28, 32) 320
_________________________________________________________________
flatten_1 (Flatten) (None, 25088) 0
_________________________________________________________________
dense_2 (Dense) (None, 128) 3211392
_________________________________________________________________
dense_3 (Dense) (None, 10) 1290
=================================================================
Total params: 3,213,002
Trainable params: 3,213,002
Non-trainable params: 0
_________________________________________________________________
3、通过继承tensorflow.keras.Model类定义自己的模型
第三种:通过继承tensorflow.keras.Model类定义自己的模型
在继承类中, 我们需要重写 __ init__()(构造函数)以及实现模型的前向传递 call(input)(模型调用)两个方法, 你也可以根据需要添加自定义的方法
class MyModel(Model):
def __init__(self):
super(MyModel, self).__init__()
self.conv1 = Conv2D(32, 3, activation='relu', padding='same')
self.flatten = Flatten()
self.d1 = Dense(128, activation='relu')
self.d2 = Dense(10, activation='softmax')
def call(self, x):
x = self.conv1(x)
x = self.flatten(x)
x = self.d1(x)
x = self.d2(x)
return x
model3 = MyModel()
input = Input((28,28,1))
_ = model3(input)
model3.summary()
运行输出:
Model: "my_model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d_2 (Conv2D) (None, 28, 28, 32) 320
_________________________________________________________________
flatten_2 (Flatten) (None, 25088) 0
_________________________________________________________________
dense_4 (Dense) (None, 128) 3211392
_________________________________________________________________
dense_5 (Dense) (None, 10) 1290
=================================================================
Total params: 3,213,002
Trainable params: 3,213,002
Non-trainable params: 0
_________________________________________________________________
下一篇
TF2.0模型训练
最后
以上就是贤惠月饼为你收集整理的TF2.0模型创建--卷积网络的全部内容,希望文章能够帮你解决TF2.0模型创建--卷积网络所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复