我是靠谱客的博主 高兴小猫咪,这篇文章主要介绍torch.flatten()函数,现在分享给大家,希望可以做个参考。

1)flatten(x,1)是按照x的第1个维度拼接(按照列来拼接,横向拼接);
2)flatten(x,0)是按照x的第0个维度拼接(按照行来拼接,纵向拼接);
3)有时候会遇到flatten里面有两个维度参数,flatten(x, start_dim, end_dimension),此时flatten函数执行的功能是将从start_dim到end_dim之间的所有维度值乘起来,其他的维度保持不变。例如x是一个size为[4,5,6]的tensor, flatten(x, 0, 1)的结果是一个size为[20,6]的tensor。

In [1]: import torch

In [2]: A = torch.tensor([[[1,2,3,4],[5,6,7,8],[9,10,11,12]],[[13,14,15,16],[17,18,19,20],[21,22,23,24]]])
In [3]: A.size
Out[3]: <function Tensor.size>

In [4]: A.shape
Out[4]: torch.Size([2, 3, 4])

In [5]: B = torch.flatten(A,1)

In [6]: B
Out[6]:
tensor([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12],
        [13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]])

In [7]: B.shape
Out[7]: torch.Size([2, 12])

In [8]: C = torch.flatten(A,0,1)

In [9]: C
Out[9]:
tensor([[ 1,  2,  3,  4],
        [ 5,  6,  7,  8],
        [ 9, 10, 11, 12],
        [13, 14, 15, 16],
        [17, 18, 19, 20],
        [21, 22, 23, 24]])

In [10]: C.shape
Out[10]: torch.Size([6, 4])
 

最后

以上就是高兴小猫咪最近收集整理的关于torch.flatten()函数的全部内容,更多相关torch内容请搜索靠谱客的其他文章。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(60)

评论列表共有 0 条评论

立即
投稿
返回
顶部