概述
TensorFlow中的split函数
1. tf.split()函数
tf.split(
value,
num_or_size_splits,
axis=0,
num=None,
name='split'
)
value:传入的tensor,就是传入的矩阵或高维矩阵
num_or_size_splits:切成几份,比如输入的是2,那么会在传入的axis上切成2份。举个栗子:2 x 32 x 32的tensor,在axis = 0上切2份,就是两个1 x 32 x 32,具体顺序如何,一会再讨论。
axis:在哪个轴上进行切割,比如2 x 32 x 32,依次对应axis = 0 , 1 和 2
2. tf.split()在keras中切割张量
Lambda层
keras.layers.Lambda(function, output_shape=None, mask=None, arguments=None)
如果split()想要用到keras中,就必须套入Lambda,作为神经网络的一层出现。具体写法如下:
x = Lambda(tf.split, arguments={'axis': 2, 'num_or_size_splits': 4})(input_tensor)
Lambda层的第一个参数是要作为层出现的函数,第二个参数形式为字典,指的是要传入前面函数的参数,其中key是函数API中定义的参数名称,value是要传入的参数。
这里需要注意,tensor这个参数作为Lambda的层的输入,写在最后的(input_tensor)里
Lambda层的详细介绍见keras中文文档
切割张量的排列顺序
以 input_tensor.shape = (2, 32, 32) 为例,输入到网络的 shape 应该是
(?, 2, 32, 32)
用如下 split() 进行切割
x = Lambda(tf.split, arguments={'axis': 2, 'num_or_size_splits': 4})(input_tensor)
x 的形状应该是
# x.shape = (4,?,2,8,32)
即 split() 会将切割的片段放到axis = 0的位置
我的理解是split()先将 ? x 2 x 32 x 32 切成 ? x 2 x 4 x 8 x 32,再转置成4 x ? x 2 x 8 x32
为什么有如上的理解,和下面要说的还原有关
它们具体是如何排列的,只需要 print 一下 tensor 就可以了!
调整顺序
tf.transpose() 转置函数,第一个参数是tensor,第二个参数是axis的顺序
x = Lambda(tf.transpose, arguments={'perm': [1,2,0,3,4]})(x)
# x.shape = (?,2,4,8,32)
还原
还原成最初的样子以及顺序
# 直接reshape
x = Reshape((2, 32, 32,))(x)
这里由于调整顺序部分已经将tensor的顺序调整为2 x 4 x 8 x 32了,即我理解的split()切割的第一步,因此只需要reshape就可以还原回去了。
注:不知道是不是可以写成 2 x 8 x 4 x 32(也许可以…)
最后
以上就是默默乐曲为你收集整理的tf.split()在keras中切割张量TensorFlow中的split函数的全部内容,希望文章能够帮你解决tf.split()在keras中切割张量TensorFlow中的split函数所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
发表评论 取消回复