在fine tune Keras Applications中给出的分类CNN Model的时候,如果在Model的top层之上加入Flatten层就会出现错误。可能的报错信息类似下面的内容:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20$ python3 ./train.py Using TensorFlow backend. Found 60000 images belonging to 200 classes. Found 20000 images belonging to 200 classes. # 略过一些信息... Creating TensorFlow device (/device:GPU:0) -> (device: 0, name: GeForce GTX 1080, pci bus id: 0000:02:00.0, compute capability: 6.1) # ↓↓↓ 错误出现 ↓↓↓ Traceback (most recent call last): File "./train.py", line 51, in <module> x = Flatten()(x) File "/usr/local/lib/python3.5/dist-packages/keras/engine/topology.py", line 636, in __call__ output_shape = self.compute_output_shape(input_shape) File "/usr/local/lib/python3.5/dist-packages/keras/layers/core.py", line 490, in compute_output_shape '(got ' + str(input_shape[1:]) + '. ' ValueError: The shape of the input to "Flatten" is not fully defined (got (None, None, 1536). Make sure to pass a complete "input_shape" or "batch_input_shape" argument to the first layer in your model. # ↑↑↑ 错误结束 ↑↑↑
出错的代码行是x = Flatten()(x)
,错误提示为ValueError: The shape of the input to "Flatten" is not fully defined (got (None, None, 1536). Make sure to pass a complete "input_shape" or "batch_input_shape" argument to the first layer in your model.
Flatten()(x)
希望参数拥有确定的shape
属性,实际得到的参数x
的shape
属性是(None, None, 1536)
,很明显不符合要求。同时,错误提示信息中也给出了修正错误的方法Make sure to pass a complete "input_shape" or "batch_input_shape" argument to the first layer in your model
。即,在Model的第一层给出确定的input_shape
或batch_input_shape
。那么,如何在Keras中解决该问题呢?
以Keras Applications中的VGG16为例,我们只需要在其初始化的时候,给出具体的input_shape
就可以了。例如,Keras给出的VGG16模型输入层图像尺寸是(224, 224)的,所以如果使用TensorFlow的channels_last
数据格式,则初始化代码为:
1
2
3
4vgg16 = keras.applications.vgg16.VGG16(include_top=False, weights='imagenet', input_shape=(224, 224, 3)) x = vgg16.output x = Flatten()(x) ...
注意,因为要fine tune模型,对模型分类的种类和类别数进行重新定义,所以include_top=False
,这样返回的模型不包括VGG16的全连接层和输出层。
所以该类似问题,只需指定input_shape参数
对于tensorflow后端:
vgg19_base = VGG19(weights ='imagenet',include_top = False,input_shape =(224,224,3))
对于theano后端:
vgg19_base = VGG19(weights ='imagenet',include_top = False,input_shape =(3,224,224))
参考:
- Unable to fine tune Keras vgg16 model - input shape issue
- Keras Applications
- 简书
主要参考作者:Aspirinrin
链接:https://www.jianshu.com/p/ec188fa1cca1
最后
以上就是酷酷寒风最近收集整理的关于Keras Flatten的input_shape问题的全部内容,更多相关Keras内容请搜索靠谱客的其他文章。
发表评论 取消回复