概述
在阅读KPConv-PyTorch源码时,发现其对torch.nn.BatchNorm1d进行了封装。
class BatchNormBlock(nn.Module):
def __init__(self, in_dim, use_bn, bn_momentum):
"""
Initialize a batch normalization block. If network does not use batch normalization, replace with biases.
:param in_dim: dimension input features
:param use_bn: boolean indicating if we use Batch Norm
:param bn_momentum: Batch norm momentum
"""
super(BatchNormBlock, self).__init__()
self.bn_momentum = bn_momentum
self.use_bn = use_bn
self.in_dim = in_dim
if self.use_bn:
self.batch_norm = nn.BatchNorm1d(in_dim, momentum=bn_momentum)
#self.batch_norm = nn.InstanceNorm1d(in_dim, momentum=bn_momentum)
else:
self.bias = Parameter(torch.zeros(in_dim, dtype=torch.float32), requires_grad=True)
return
def reset_parameters(self):
nn.init.zeros_(self.bias)
def forward(self, x):
if self.use_bn:
# x: [num_of_point, dim]
x = x.unsqueeze(2)
x = x.transpose(0, 2)
# x: [1, dim, num_of_point]
x = self.batch_norm(x)
x = x.transpose(0, 2)
return x.squeeze() # x: [num_of_point, dim]
else:
return x + self.bias
def __repr__(self):
return 'BatchNormBlock(in_feat: {:d}, momentum: {:.3f}, only_bias: {:s})'.format(self.in_dim,
x输入时维度为[num_of_point, dim]
经过变换为[1, dim, num_of_point]再输入到batch_norm,return时还原维度
通过阅读torch.nn.BatchNorm1d官方文档发现:
- 当输入为(N, C, L)时,计算维度 (N, L) 上的统计数据进行归一化。按照维度C恢复全局方差偏置。
- 当输入为(N, L)时,计算维度N切片上的统计数据进行归一化。按照维度L恢复全局方差偏置。
那么为什么不直接
# x: [num_of_point, dim]
x = self.batch_norm(x)
实验
data与batch_norm采用真实样本与预训练模型参数
self.batch_norm = self.batch_norm.eval()
# data: [num_of_point, dim]
a = self.batch_norm(data)
x = data.unsqueeze(2)
x = x.transpose(0, 2)
x = self.batch_norm(x)
x = x.transpose(0, 2)
b = x.squeeze()
print(a-b)
tensor([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]], device='cuda:0',
grad_fn=<SubBackward0>)
结果
当我对代码进行更改并运行后发现运行速度大幅下降,一开始为还以为是其他问题,后来测试发现:
import time
t = time.time()
for i in range(1000):
# x: [num_of_point, dim]
x = x.unsqueeze(2)
x = x.transpose(0, 2)
x = self.batch_norm(x)
x = x.transpose(0, 2)
x = x.squeeze()
print(time.time() - t) # 0.7s
t = time.time()
for i in range(1000):
# x: [num_of_point, dim]
x = self.batch_norm(x)
print(time.time() - t) # 2.1s
补充测试:
t = time.time()
for i in range(1000):
# x: [num_of_point, dim]
x = x.transpose(0, 2)
x = self.batch_norm(x)
x = x.transpose(0, 2)
print(time.time() - t) # 0.7s
batch_norm处理2个维度数据要比处理3个维度慢3倍,官方代码中并没有提到输入2维和3维的有什么不同,但是既然这样就回退更改,并记录下这个问题。
补充
上述实验章节中的a
与b
也并非完全相同
print(torch.max(a-b)) # 7.4506e-09
最后
以上就是纯真纸鹤为你收集整理的pytorch BatchNorm1d 输入二维和三维数据的区别的全部内容,希望文章能够帮你解决pytorch BatchNorm1d 输入二维和三维数据的区别所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复