我是靠谱客的博主 默默乐曲,最近开发中收集的这篇文章主要介绍tf.split()在keras中切割张量TensorFlow中的split函数,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

TensorFlow中的split函数

1. tf.split()函数

tf.split(
value,
num_or_size_splits,
axis=0,
num=None,
name='split'
)

value:传入的tensor,就是传入的矩阵或高维矩阵
num_or_size_splits:切成几份,比如输入的是2,那么会在传入的axis上切成2份。举个栗子:2 x 32 x 32的tensor,在axis = 0上切2份,就是两个1 x 32 x 32,具体顺序如何,一会再讨论。
axis:在哪个轴上进行切割,比如2 x 32 x 32,依次对应axis = 0 , 1 和 2

2. tf.split()在keras中切割张量

Lambda层

keras.layers.Lambda(function, output_shape=None, mask=None, arguments=None)

如果split()想要用到keras中,就必须套入Lambda,作为神经网络的一层出现。具体写法如下:

x = Lambda(tf.split, arguments={'axis': 2, 'num_or_size_splits': 4})(input_tensor)

Lambda层的第一个参数是要作为层出现的函数,第二个参数形式为字典,指的是要传入前面函数的参数,其中key是函数API中定义的参数名称,value是要传入的参数。
这里需要注意,tensor这个参数作为Lambda的层的输入,写在最后的(input_tensor)里

Lambda层的详细介绍见keras中文文档

切割张量的排列顺序

以 input_tensor.shape = (2, 32, 32) 为例,输入到网络的 shape 应该是
(?, 2, 32, 32)
用如下 split() 进行切割

x = Lambda(tf.split, arguments={'axis': 2, 'num_or_size_splits': 4})(input_tensor)

x 的形状应该是

# x.shape = (4,?,2,8,32)

即 split() 会将切割的片段放到axis = 0的位置

我的理解是split()先将 ? x 2 x 32 x 32 切成 ? x 2 x 4 x 8 x 32,再转置成4 x ? x 2 x 8 x32
为什么有如上的理解,和下面要说的还原有关

它们具体是如何排列的,只需要 print 一下 tensor 就可以了!

调整顺序

tf.transpose() 转置函数,第一个参数是tensor,第二个参数是axis的顺序

x = Lambda(tf.transpose, arguments={'perm': [1,2,0,3,4]})(x)
# x.shape = (?,2,4,8,32)

还原

还原成最初的样子以及顺序

# 直接reshape
x = Reshape((2, 32, 32,))(x)

这里由于调整顺序部分已经将tensor的顺序调整为2 x 4 x 8 x 32了,即我理解的split()切割的第一步,因此只需要reshape就可以还原回去了。

注:不知道是不是可以写成 2 x 8 x 4 x 32(也许可以…)

最后

以上就是默默乐曲为你收集整理的tf.split()在keras中切割张量TensorFlow中的split函数的全部内容,希望文章能够帮你解决tf.split()在keras中切割张量TensorFlow中的split函数所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(49)

评论列表共有 0 条评论

立即
投稿
返回
顶部