我是靠谱客的博主 纯真纸鹤,这篇文章主要介绍pytorch BatchNorm1d 输入二维和三维数据的区别,现在分享给大家,希望可以做个参考。

在阅读KPConv-PyTorch源码时,发现其对torch.nn.BatchNorm1d进行了封装。

复制代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
class BatchNormBlock(nn.Module): def __init__(self, in_dim, use_bn, bn_momentum): """ Initialize a batch normalization block. If network does not use batch normalization, replace with biases. :param in_dim: dimension input features :param use_bn: boolean indicating if we use Batch Norm :param bn_momentum: Batch norm momentum """ super(BatchNormBlock, self).__init__() self.bn_momentum = bn_momentum self.use_bn = use_bn self.in_dim = in_dim if self.use_bn: self.batch_norm = nn.BatchNorm1d(in_dim, momentum=bn_momentum) #self.batch_norm = nn.InstanceNorm1d(in_dim, momentum=bn_momentum) else: self.bias = Parameter(torch.zeros(in_dim, dtype=torch.float32), requires_grad=True) return def reset_parameters(self): nn.init.zeros_(self.bias) def forward(self, x): if self.use_bn: # x: [num_of_point, dim] x = x.unsqueeze(2) x = x.transpose(0, 2) # x: [1, dim, num_of_point] x = self.batch_norm(x) x = x.transpose(0, 2) return x.squeeze() # x: [num_of_point, dim] else: return x + self.bias def __repr__(self): return 'BatchNormBlock(in_feat: {:d}, momentum: {:.3f}, only_bias: {:s})'.format(self.in_dim,

x输入时维度为[num_of_point, dim]
经过变换为[1, dim, num_of_point]再输入到batch_norm,return时还原维度

通过阅读torch.nn.BatchNorm1d官方文档发现:

  • 当输入为(N, C, L)时,计算维度 (N, L) 上的统计数据进行归一化。按照维度C恢复全局方差偏置。
  • 当输入为(N, L)时,计算维度N切片上的统计数据进行归一化。按照维度L恢复全局方差偏置。

那么为什么不直接

复制代码
1
2
3
# x: [num_of_point, dim] x = self.batch_norm(x)

实验

data与batch_norm采用真实样本与预训练模型参数

复制代码
1
2
3
4
5
6
7
8
9
10
11
12
13
self.batch_norm = self.batch_norm.eval() # data: [num_of_point, dim] a = self.batch_norm(data) x = data.unsqueeze(2) x = x.transpose(0, 2) x = self.batch_norm(x) x = x.transpose(0, 2) b = x.squeeze() print(a-b)
复制代码
1
2
3
4
5
6
7
8
9
tensor([[0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], ..., [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.]], device='cuda:0', grad_fn=<SubBackward0>)

结果

当我对代码进行更改并运行后发现运行速度大幅下降,一开始为还以为是其他问题,后来测试发现:

复制代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import time t = time.time() for i in range(1000): # x: [num_of_point, dim] x = x.unsqueeze(2) x = x.transpose(0, 2) x = self.batch_norm(x) x = x.transpose(0, 2) x = x.squeeze() print(time.time() - t) # 0.7s t = time.time() for i in range(1000): # x: [num_of_point, dim] x = self.batch_norm(x) print(time.time() - t) # 2.1s

补充测试:

复制代码
1
2
3
4
5
6
7
8
t = time.time() for i in range(1000): # x: [num_of_point, dim] x = x.transpose(0, 2) x = self.batch_norm(x) x = x.transpose(0, 2) print(time.time() - t) # 0.7s

batch_norm处理2个维度数据要比处理3个维度慢3倍,官方代码中并没有提到输入2维和3维的有什么不同,但是既然这样就回退更改,并记录下这个问题。

补充

上述实验章节中的ab也并非完全相同

复制代码
1
2
print(torch.max(a-b)) # 7.4506e-09

最后

以上就是纯真纸鹤最近收集整理的关于pytorch BatchNorm1d 输入二维和三维数据的区别的全部内容,更多相关pytorch内容请搜索靠谱客的其他文章。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(46)

评论列表共有 0 条评论

立即
投稿
返回
顶部