我是靠谱客的博主 酷酷寒风,最近开发中收集的这篇文章主要介绍Keras Flatten的input_shape问题,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

在fine tune Keras Applications中给出的分类CNN Model的时候,如果在Model的top层之上加入Flatten层就会出现错误。可能的报错信息类似下面的内容:

$ python3 ./train.py
Using TensorFlow backend.
Found 60000 images belonging to 200 classes.
Found 20000 images belonging to 200 classes.
# 略过一些信息...
Creating TensorFlow device (/device:GPU:0) ->
(device: 0, name: GeForce GTX 1080, pci bus id: 0000:02:00.0, compute capability: 6.1)
# ↓↓↓ 错误出现 ↓↓↓
Traceback (most recent call last):
File "./train.py", line 51, in <module>
x = Flatten()(x)
File "/usr/local/lib/python3.5/dist-packages/keras/engine/topology.py", line 636, in __call__
output_shape = self.compute_output_shape(input_shape)
File "/usr/local/lib/python3.5/dist-packages/keras/layers/core.py", line 490, in
compute_output_shape
'(got ' + str(input_shape[1:]) + '. '
ValueError: The shape of the input to "Flatten" is not fully defined (got (None, None, 1536).
Make sure to pass a complete "input_shape" or "batch_input_shape" argument
to the first layer in your model.
# ↑↑↑ 错误结束 ↑↑↑

出错的代码行是x = Flatten()(x),错误提示为ValueError: The shape of the input to "Flatten" is not fully defined (got (None, None, 1536). Make sure to pass a complete "input_shape" or "batch_input_shape" argument to the first layer in your model.

Flatten()(x)希望参数拥有确定shape属性,实际得到的参数xshape属性是(None, None, 1536),很明显不符合要求。同时,错误提示信息中也给出了修正错误的方法Make sure to pass a complete "input_shape" or "batch_input_shape" argument to the first layer in your model。即,在Model的第一层给出确定的input_shapebatch_input_shape。那么,如何在Keras中解决该问题呢?

以Keras Applications中的VGG16为例,我们只需要在其初始化的时候,给出具体的input_shape就可以了。例如,Keras给出的VGG16模型输入层图像尺寸是(224, 224)的,所以如果使用TensorFlow的channels_last数据格式,则初始化代码为:

vgg16 = keras.applications.vgg16.VGG16(include_top=False, weights='imagenet', input_shape=(224, 224, 3))
x = vgg16.output
x = Flatten()(x)
...

注意,因为要fine tune模型,对模型分类的种类和类别数进行重新定义,所以include_top=False,这样返回的模型不包括VGG16的全连接层和输出层。


所以该类似问题,只需指定input_shape参数

对于tensorflow后端:
vgg19_base = VGG19(weights ='imagenet',include_top = False,input_shape =(224,224,3))

对于theano后端:
vgg19_base = VGG19(weights ='imagenet',include_top = False,input_shape =(3,224,224))


参考:

  1. Unable to fine tune Keras vgg16 model - input shape issue
  2. Keras Applications
  3. 简书

主要参考作者:Aspirinrin
链接:https://www.jianshu.com/p/ec188fa1cca1
 

最后

以上就是酷酷寒风为你收集整理的Keras Flatten的input_shape问题的全部内容,希望文章能够帮你解决Keras Flatten的input_shape问题所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(105)

评论列表共有 0 条评论

立即
投稿
返回
顶部