我是靠谱客的博主 无语店员,这篇文章主要介绍Pytorch阅读文档之flatten函数pytorch中flatten函数,现在分享给大家,希望可以做个参考。

pytorch中flatten函数

torch.flatten()

复制代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
#展平一个连续范围的维度,输出类型为Tensor torch.flatten(input, start_dim=0, end_dim=-1) → Tensor # Parameters:input (Tensor) – 输入为Tensor #start_dim (int) – 展平的开始维度 #end_dim (int) – 展平的最后维度 #example #一个3x2x2的三维张量 >>> t = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]]) #当开始维度为0,最后维度为-1,展开为一维 >>> torch.flatten(t) tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]) #当开始维度为0,最后维度为-1,展开为3x4,也就是说第一维度不变,后面的压缩 >>> torch.flatten(t, start_dim=1) tensor([[ 1, 2, 3, 4], [ 5, 6, 7, 8], [ 9, 10, 11, 12]]) >>> torch.flatten(t, start_dim=1).size() torch.Size([3, 4]) #下面的和上面进行对比应该就能看出是,当锁定最后的维度的时候 #前面的就会合并 >>> torch.flatten(t, start_dim=0, end_dim=1) tensor([[ 1, 2], [ 3, 4], [ 5, 6], [ 7, 8], [ 9, 10], [11, 12]]) >>> torch.flatten(t, start_dim=0, end_dim=1).size() torch.Size([6, 2])

torch.nn.Flatten()

复制代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
Class torch.nn.Flatten(start_dim=1, end_dim=-1) #Flattens a contiguous range of dims into a tensor. #For use with Sequential. : #param start_dim: first dim to flatten (default = 1). #param end_dim: last dim to flatten (default = -1). #能力有限,个人认为是用于卷积中的 #Shape: #Input: (N, *dims)(N,∗dims) #Output: (N, prod *dims)(N,∏∗dims) (for the default case). #官方example >>> m = nn.Sequential( >>> nn.Conv2d(1, 32, 5, 1, 1), >>> nn.Flatten() >>> ) #源代码为 TORCH.NN.MODULES.FLATTEN from .module import Module [docs]class Flatten(Module): r""" Flattens a contiguous range of dims into a tensor. For use with :class:`~nn.Sequential`. Args: start_dim: first dim to flatten (default = 1). end_dim: last dim to flatten (default = -1). Shape: - Input: :math:`(N, *dims)` - Output: :math:`(N, prod *dims)` (for the default case). Examples:: >>> m = nn.Sequential( >>> nn.Conv2d(1, 32, 5, 1, 1), >>> nn.Flatten() >>> ) """ __constants__ = ['start_dim', 'end_dim'] def __init__(self, start_dim=1, end_dim=-1): super(Flatten, self).__init__() self.start_dim = start_dim self.end_dim = end_dim def forward(self, input): return input.flatten(self.start_dim, self.end_dim)

torch.Tensor.flatten()

和torch.flatten()一样

最后

以上就是无语店员最近收集整理的关于Pytorch阅读文档之flatten函数pytorch中flatten函数的全部内容,更多相关Pytorch阅读文档之flatten函数pytorch中flatten函数内容请搜索靠谱客的其他文章。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(58)

评论列表共有 0 条评论

立即
投稿
返回
顶部