1)flatten(x,1)是按照x的第1个维度拼接(按照列来拼接,横向拼接);
2)flatten(x,0)是按照x的第0个维度拼接(按照行来拼接,纵向拼接);
3)有时候会遇到flatten里面有两个维度参数,flatten(x, start_dim, end_dimension),此时flatten函数执行的功能是将从start_dim到end_dim之间的所有维度值乘起来,其他的维度保持不变。例如x是一个size为[4,5,6]的tensor, flatten(x, 0, 1)的结果是一个size为[20,6]的tensor。
In [1]: import torch
In [2]: A = torch.tensor([[[1,2,3,4],[5,6,7,8],[9,10,11,12]],[[13,14,15,16],[17,18,19,20],[21,22,23,24]]])
In [3]: A.size
Out[3]: <function Tensor.size>
In [4]: A.shape
Out[4]: torch.Size([2, 3, 4])
In [5]: B = torch.flatten(A,1)
In [6]: B
Out[6]:
tensor([[ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
[13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]])
In [7]: B.shape
Out[7]: torch.Size([2, 12])
In [8]: C = torch.flatten(A,0,1)
In [9]: C
Out[9]:
tensor([[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12],
[13, 14, 15, 16],
[17, 18, 19, 20],
[21, 22, 23, 24]])
In [10]: C.shape
Out[10]: torch.Size([6, 4])
最后
以上就是高兴小猫咪最近收集整理的关于torch.flatten()函数的全部内容,更多相关torch内容请搜索靠谱客的其他文章。
发表评论 取消回复