我是靠谱客的博主 纯真纸鹤,最近开发中收集的这篇文章主要介绍pytorch BatchNorm1d 输入二维和三维数据的区别,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

在阅读KPConv-PyTorch源码时,发现其对torch.nn.BatchNorm1d进行了封装。

class BatchNormBlock(nn.Module):

    def __init__(self, in_dim, use_bn, bn_momentum):
        """
        Initialize a batch normalization block. If network does not use batch normalization, replace with biases.
        :param in_dim: dimension input features
        :param use_bn: boolean indicating if we use Batch Norm
        :param bn_momentum: Batch norm momentum
        """
        super(BatchNormBlock, self).__init__()
        self.bn_momentum = bn_momentum
        self.use_bn = use_bn
        self.in_dim = in_dim
        if self.use_bn:
            self.batch_norm = nn.BatchNorm1d(in_dim, momentum=bn_momentum)
            #self.batch_norm = nn.InstanceNorm1d(in_dim, momentum=bn_momentum)
        else:
            self.bias = Parameter(torch.zeros(in_dim, dtype=torch.float32), requires_grad=True)
        return

    def reset_parameters(self):
        nn.init.zeros_(self.bias)

    def forward(self, x):	
        if self.use_bn:
			# x: [num_of_point, dim]
            x = x.unsqueeze(2)
            x = x.transpose(0, 2)
            # x: [1, dim, num_of_point]
            x = self.batch_norm(x)
            x = x.transpose(0, 2)
            return x.squeeze()	# x: [num_of_point, dim]
        else:
            return x + self.bias

    def __repr__(self):
        return 'BatchNormBlock(in_feat: {:d}, momentum: {:.3f}, only_bias: {:s})'.format(self.in_dim,
 

x输入时维度为[num_of_point, dim]
经过变换为[1, dim, num_of_point]再输入到batch_norm,return时还原维度

通过阅读torch.nn.BatchNorm1d官方文档发现:

  • 当输入为(N, C, L)时,计算维度 (N, L) 上的统计数据进行归一化。按照维度C恢复全局方差偏置。
  • 当输入为(N, L)时,计算维度N切片上的统计数据进行归一化。按照维度L恢复全局方差偏置。

那么为什么不直接

# x: [num_of_point, dim]
x = self.batch_norm(x)

实验

data与batch_norm采用真实样本与预训练模型参数

self.batch_norm = self.batch_norm.eval()
# data: [num_of_point, dim]

a = self.batch_norm(data)

x = data.unsqueeze(2)
x = x.transpose(0, 2)
x = self.batch_norm(x)
x = x.transpose(0, 2)
b = x.squeeze()

print(a-b)
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0',
       grad_fn=<SubBackward0>)

结果

当我对代码进行更改并运行后发现运行速度大幅下降,一开始为还以为是其他问题,后来测试发现:

import time
t = time.time()
for i in range(1000):
	# x: [num_of_point, dim]
	x = x.unsqueeze(2)
	x = x.transpose(0, 2)
	x = self.batch_norm(x)
	x = x.transpose(0, 2)
	x = x.squeeze()
print(time.time() - t)	# 0.7s

t = time.time()
for i in range(1000):
	# x: [num_of_point, dim]
	x = self.batch_norm(x)
print(time.time() - t)	# 2.1s

补充测试:

t = time.time()
for i in range(1000):
	# x: [num_of_point, dim]
	x = x.transpose(0, 2)
	x = self.batch_norm(x)
	x = x.transpose(0, 2)
print(time.time() - t)	# 0.7s

batch_norm处理2个维度数据要比处理3个维度慢3倍,官方代码中并没有提到输入2维和3维的有什么不同,但是既然这样就回退更改,并记录下这个问题。

补充

上述实验章节中的ab也并非完全相同

print(torch.max(a-b))	# 7.4506e-09

最后

以上就是纯真纸鹤为你收集整理的pytorch BatchNorm1d 输入二维和三维数据的区别的全部内容,希望文章能够帮你解决pytorch BatchNorm1d 输入二维和三维数据的区别所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(34)

评论列表共有 0 条评论

立即
投稿
返回
顶部