1. Sequential API
```python
from keras.models import Sequential
from keras.layers import Input, Dense

model = Sequential()
# declare the input shape (1 feature) with an Input object
model.add(Input(shape=(1,)))
model.add(Dense(2))
model.add(Dense(1))
```
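The plain-Python sketch below shows where the parameter counts in `model.summary()` come from for the stacked `Dense` layers above. It assumes the standard `Dense` layout: a weight matrix of shape `(n_in, units)` plus one bias per unit. `dense_params` is a hypothetical helper, not part of Keras.

```python
def dense_params(n_in, units):
    """Hypothetical helper: n_in * units weights plus one bias per unit."""
    return n_in * units + units

# Layer widths of the Sequential model: 1 input -> Dense(2) -> Dense(1)
layer_sizes = [1, 2, 1]
per_layer = [dense_params(a, b) for a, b in zip(layer_sizes, layer_sizes[1:])]
print(per_layer, sum(per_layer))  # [4, 3] 7
```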
But the Sequential API has a number of limitations. For example, it is not straightforward to define models that have multiple different input sources, produce multiple outputs, or re-use layers.
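The sketch below makes the last limitation concrete by re-using one layer. It has to use the functional API from the next section, because a Sequential stack has a single input and calls each layer once. The input width of 4 and the layer sizes are arbitrary choices for illustration.

```python
from keras.models import Model
from keras.layers import Input, Dense, concatenate

# Two separate inputs (widths chosen arbitrarily)
input_a = Input(shape=(4,))
input_b = Input(shape=(4,))

# ONE Dense layer object called on both inputs: the two calls share the same weights
shared = Dense(3, activation='relu')
encoded_a = shared(input_a)
encoded_b = shared(input_b)

merged = concatenate([encoded_a, encoded_b])
output = Dense(1, activation='sigmoid')(merged)
model = Model(inputs=[input_a, input_b], outputs=output)

# The shared layer's 4*3 + 3 = 15 parameters are counted once, plus 6*1 + 1 for the output
print(model.count_params())  # 22
```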
2. Functional API
The simplest example:
```python
from keras.models import Model
from keras.layers import Input
from keras.layers import Dense

visible = Input(shape=(2,))
hidden = Dense(2)(visible)
model = Model(inputs=visible, outputs=hidden)
```
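For comparison, the sketch below builds the two-layer Sequential model from section 1 and the same architecture in functional style. In the functional version, each layer object is called on the previous layer's output tensor, and `Model` is told where the graph starts and ends. It assumes a recent Keras version that accepts an `Input` object as the first element of a `Sequential` model.

```python
from keras.models import Model, Sequential
from keras.layers import Input, Dense

# Sequential style: layers are appended to a stack
seq = Sequential()
seq.add(Input(shape=(1,)))
seq.add(Dense(2))
seq.add(Dense(1))

# Functional style: each layer is called on the previous tensor
visible = Input(shape=(1,))
hidden = Dense(2)(visible)
output = Dense(1)(hidden)
func = Model(inputs=visible, outputs=output)

# Same architecture, so the same number of parameters
print(seq.count_params(), func.count_params())  # 7 7
```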
A slightly more complex example, a convolutional neural network:
```python
# Convolutional Neural Network
from keras.utils import plot_model
from keras.models import Model
from keras.layers import Input, Dense, Flatten, Conv2D, MaxPooling2D

visible = Input(shape=(64, 64, 1))
conv1 = Conv2D(32, kernel_size=4, activation='relu')(visible)
pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
conv2 = Conv2D(16, kernel_size=4, activation='relu')(pool1)
pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
flat = Flatten()(pool2)
hidden1 = Dense(10, activation='relu')(flat)
output = Dense(1, activation='sigmoid')(hidden1)
model = Model(inputs=visible, outputs=output)
# summarize layers (summary() prints directly and returns None)
model.summary()
# plot graph (requires the pydot and graphviz packages)
plot_model(model, to_file='convolutional_neural_network.png')
```
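The output shapes and parameter counts that `model.summary()` prints for this CNN can be reproduced by hand. The plain-Python sketch below assumes Keras defaults: stride 1 and `'valid'` padding for `Conv2D`, and strides equal to `pool_size` for `MaxPooling2D`. The helper functions are illustrative, not Keras API.

```python
def conv2d_valid(size, kernel):
    # stride 1, 'valid' padding: each spatial side shrinks by kernel - 1
    return size - kernel + 1

def maxpool(size, pool=2):
    # stride = pool size, 'valid' padding: floor division
    return size // pool

def conv_params(kernel, in_ch, filters):
    # kernel x kernel x in_ch weights per filter, plus one bias per filter
    return kernel * kernel * in_ch * filters + filters

def dense_params(n_in, units):
    return n_in * units + units

s = conv2d_valid(64, 4)   # 61 (conv1, 32 filters)
s = maxpool(s)            # 30
s = conv2d_valid(s, 4)    # 27 (conv2, 16 filters)
s = maxpool(s)            # 13
flat = s * s * 16         # 2704 features after Flatten

total = (conv_params(4, 1, 32)       # 544
         + conv_params(4, 32, 16)    # 8208
         + dense_params(flat, 10)    # 27050
         + dense_params(10, 1))      # 11
print(flat, total)  # 2704 35813
```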
A more complex example that shows the advantage of this API, a shared input layer:
```python
# Shared Input Layer
from keras.utils import plot_model
from keras.models import Model
from keras.layers import Input, Dense, Flatten, Conv2D, MaxPooling2D, concatenate

# input layer
visible = Input(shape=(64, 64, 1))
# first feature extractor
conv1 = Conv2D(32, kernel_size=4, activation='relu')(visible)
pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
flat1 = Flatten()(pool1)
# second feature extractor
conv2 = Conv2D(16, kernel_size=8, activation='relu')(visible)
pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
flat2 = Flatten()(pool2)
# merge feature extractors
merge = concatenate([flat1, flat2])
# interpretation layer
hidden1 = Dense(10, activation='relu')(merge)
# prediction output
output = Dense(1, activation='sigmoid')(hidden1)
model = Model(inputs=visible, outputs=output)
# summarize layers
model.summary()
# plot graph (requires the pydot and graphviz packages)
plot_model(model, to_file='shared_input_layer.png')
```
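This shows branching and merging: two convolutional feature extractors with different kernel sizes read the same input, and their flattened outputs are merged with concatenate.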
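A worked shape calculation shows what the merge does here. The plain-Python sketch below follows both branches from the 64x64x1 input. It makes the same assumptions about Keras defaults: stride 1 and `'valid'` padding for the convolutions, and a pooling stride equal to the pool size. The helpers are illustrative names, not Keras functions.

```python
def conv_out(size, kernel):
    # stride 1, 'valid' padding (Keras Conv2D defaults)
    return size - kernel + 1

def pool_out(size, pool=2):
    # MaxPooling2D with default strides (= pool_size) and 'valid' padding
    return size // pool

# Branch 1: Conv2D(32, kernel_size=4) -> MaxPooling2D(2) -> Flatten
s1 = pool_out(conv_out(64, 4))   # 30
flat1 = s1 * s1 * 32             # 28800
# Branch 2: Conv2D(16, kernel_size=8) -> MaxPooling2D(2) -> Flatten
s2 = pool_out(conv_out(64, 8))   # 28
flat2 = s2 * s2 * 16             # 12544

# concatenate joins along the last (feature) axis, so the widths add up
merged = flat1 + flat2           # 41344

params = {
    'conv1': 4 * 4 * 1 * 32 + 32,   # 544
    'conv2': 8 * 8 * 1 * 16 + 16,   # 1040
    'hidden1': merged * 10 + 10,    # 413450
    'output': 10 * 1 + 1,           # 11
}
print(merged, sum(params.values()))  # 41344 415045
```

Almost all of the parameters end up in `hidden1`, because the concatenated feature vector is very wide. That is the main cost of flattening two large feature maps before merging them.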
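The following structure is also possible: a single LSTM feature extractor feeds two interpretation branches of different depths, which are then merged.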
```python
# Shared Feature Extraction Layer
from keras.utils import plot_model
from keras.models import Model
from keras.layers import Input, Dense, LSTM, concatenate

# define input
visible = Input(shape=(100, 1))
# feature extraction
extract1 = LSTM(10)(visible)
# first interpretation model
interp1 = Dense(10, activation='relu')(extract1)
# second interpretation model
interp11 = Dense(10, activation='relu')(extract1)
interp12 = Dense(20, activation='relu')(interp11)
interp13 = Dense(10, activation='relu')(interp12)
# merge interpretation
merge = concatenate([interp1, interp13])
# output
output = Dense(1, activation='sigmoid')(merge)
model = Model(inputs=visible, outputs=output)
# summarize layers
model.summary()
# plot graph (requires the pydot and graphviz packages)
plot_model(model, to_file='shared_feature_extractor.png')
```
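The same kind of hand count works for the LSTM model above. The sketch assumes Keras's LSTM layout, in which each of the four gates has an input weight matrix, a recurrent weight matrix and a single bias vector. `lstm_params` and `dense_params` are hypothetical helper names.

```python
def lstm_params(input_dim, units):
    # 4 gates x (input weights + recurrent weights + bias)
    return 4 * (units * (input_dim + units) + units)

def dense_params(n_in, units):
    return n_in * units + units

params = {
    'extract1 (LSTM)': lstm_params(1, 10),   # 480
    'interp1': dense_params(10, 10),          # 110
    'interp11': dense_params(10, 10),         # 110
    'interp12': dense_params(10, 20),         # 220
    'interp13': dense_params(20, 10),         # 210
    # concatenate([interp1, interp13]) gives 10 + 10 = 20 features
    'output': dense_params(10 + 10, 1),       # 21
}
print(sum(params.values()))  # 1151
```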
Multiple inputs with a single output also work.
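Here is a minimal sketch of the multiple-input case. Two inputs of different widths (8 and 4, chosen arbitrarily) each get their own Dense branch. The branches are concatenated, and a single sigmoid unit makes the prediction. When training or predicting, the input arrays are passed as a list in the same order as `inputs=[...]`. The random data only demonstrates the call pattern.

```python
import numpy as np
from keras.models import Model
from keras.layers import Input, Dense, concatenate

# Two inputs of different (arbitrary) widths, each with its own branch
input_a = Input(shape=(8,))
input_b = Input(shape=(4,))
branch_a = Dense(8, activation='relu')(input_a)
branch_b = Dense(4, activation='relu')(input_b)
merged = concatenate([branch_a, branch_b])
output = Dense(1, activation='sigmoid')(merged)

model = Model(inputs=[input_a, input_b], outputs=output)
model.compile(optimizer='adam', loss='binary_crossentropy')

# Random placeholder data, fed as a list in the same order as `inputs`
xa = np.random.rand(16, 8).astype('float32')
xb = np.random.rand(16, 4).astype('float32')
y = np.random.randint(0, 2, size=(16, 1)).astype('float32')
model.fit([xa, xb], y, epochs=1, verbose=0)

preds = model.predict([xa, xb], verbose=0)
print(preds.shape)  # (16, 1)
```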
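So do multiple outputs. In the example below, one LSTM feature extractor feeds both a classification output and a per-timestep sequence output: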
```python
# Multiple Outputs
from keras.utils import plot_model
from keras.models import Model
from keras.layers import Input, Dense, LSTM, TimeDistributed

# input layer
visible = Input(shape=(100, 1))
# feature extraction
extract = LSTM(10, return_sequences=True)(visible)
# classification output
class11 = LSTM(10)(extract)
class12 = Dense(10, activation='relu')(class11)
output1 = Dense(1, activation='sigmoid')(class12)
# sequence output
output2 = TimeDistributed(Dense(1, activation='linear'))(extract)
# output
model = Model(inputs=visible, outputs=[output1, output2])
# summarize layers
model.summary()
# plot graph (requires the pydot and graphviz packages)
plot_model(model, to_file='multiple_outputs.png')
```
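To train the multiple-output model, `compile` can take one loss per output, given as a list in the same order as `outputs=[...]`. `fit` then expects the targets as a matching list. The sketch below rebuilds the model so it runs on its own. It uses binary cross-entropy for the classification head and mean squared error for the sequence head, which is a reasonable but assumed choice. It trains for one epoch on random data just to show the call pattern.

```python
import numpy as np
from keras.models import Model
from keras.layers import Input, Dense, LSTM, TimeDistributed

# Same architecture as the Multiple Outputs example
visible = Input(shape=(100, 1))
extract = LSTM(10, return_sequences=True)(visible)
class11 = LSTM(10)(extract)
class12 = Dense(10, activation='relu')(class11)
output1 = Dense(1, activation='sigmoid')(class12)
output2 = TimeDistributed(Dense(1, activation='linear'))(extract)
model = Model(inputs=visible, outputs=[output1, output2])

# One loss per output, in the same order as `outputs` (assumed loss choice)
model.compile(optimizer='adam', loss=['binary_crossentropy', 'mse'])

# Random placeholder data: one target array per output
x = np.random.rand(8, 100, 1).astype('float32')
y1 = np.random.randint(0, 2, size=(8, 1)).astype('float32')
y2 = np.random.rand(8, 100, 1).astype('float32')
model.fit(x, [y1, y2], epochs=1, verbose=0)

# predict returns one array per output
pred_class, pred_seq = model.predict(x, verbose=0)
print(pred_class.shape, pred_seq.shape)  # (8, 1) (8, 100, 1)
```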
Reference: https://machinelearningmastery.com/keras-functional-api-deep-learning/