概述
在fine tune Keras Applications中给出的分类CNN Model的时候,如果在Model的top层之上加入Flatten层就会出现错误。可能的报错信息类似下面的内容:
$ python3 ./train.py
Using TensorFlow backend.
Found 60000 images belonging to 200 classes.
Found 20000 images belonging to 200 classes.
# 略过一些信息...
Creating TensorFlow device (/device:GPU:0) ->
(device: 0, name: GeForce GTX 1080, pci bus id: 0000:02:00.0, compute capability: 6.1)
# ↓↓↓ 错误出现 ↓↓↓
Traceback (most recent call last):
File "./train.py", line 51, in <module>
x = Flatten()(x)
File "/usr/local/lib/python3.5/dist-packages/keras/engine/topology.py", line 636, in __call__
output_shape = self.compute_output_shape(input_shape)
File "/usr/local/lib/python3.5/dist-packages/keras/layers/core.py", line 490, in
compute_output_shape
'(got ' + str(input_shape[1:]) + '. '
ValueError: The shape of the input to "Flatten" is not fully defined (got (None, None, 1536).
Make sure to pass a complete "input_shape" or "batch_input_shape" argument
to the first layer in your model.
# ↑↑↑ 错误结束 ↑↑↑
出错的代码行是x = Flatten()(x)
,错误提示为ValueError: The shape of the input to "Flatten" is not fully defined (got (None, None, 1536). Make sure to pass a complete "input_shape" or "batch_input_shape" argument to the first layer in your model.
Flatten()(x)
希望参数拥有确定的shape
属性,实际得到的参数x
的shape
属性是(None, None, 1536)
,很明显不符合要求。同时,错误提示信息中也给出了修正错误的方法Make sure to pass a complete "input_shape" or "batch_input_shape" argument to the first layer in your model
。即,在Model的第一层给出确定的input_shape
或batch_input_shape
。那么,如何在Keras中解决该问题呢?
以Keras Applications中的VGG16为例,我们只需要在其初始化的时候,给出具体的input_shape
就可以了。例如,Keras给出的VGG16模型输入层图像尺寸是(224, 224)的,所以如果使用TensorFlow的channels_last
数据格式,则初始化代码为:
vgg16 = keras.applications.vgg16.VGG16(include_top=False, weights='imagenet', input_shape=(224, 224, 3))
x = vgg16.output
x = Flatten()(x)
...
注意,因为要fine tune模型,对模型分类的种类和类别数进行重新定义,所以include_top=False
,这样返回的模型不包括VGG16的全连接层和输出层。
所以该类似问题,只需指定input_shape参数
对于tensorflow后端:
vgg19_base = VGG19(weights ='imagenet',include_top = False,input_shape =(224,224,3))
对于theano后端:
vgg19_base = VGG19(weights ='imagenet',include_top = False,input_shape =(3,224,224))
参考:
- Unable to fine tune Keras vgg16 model - input shape issue
- Keras Applications
- 简书
主要参考作者:Aspirinrin
链接:https://www.jianshu.com/p/ec188fa1cca1
最后
以上就是酷酷寒风为你收集整理的Keras Flatten的input_shape问题的全部内容,希望文章能够帮你解决Keras Flatten的input_shape问题所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
发表评论 取消回复