概述
bn的计算
训练时统计训练数据的均值和方差(前向传播的时候更新)。
测试时使用训练的均值和方差。
bn的作用
- 使模型不用去学习复杂的数据分布导致过拟合
- 为模型提供有效地正则化防止过拟合(和上面一条一个意思更专业)
- 加速网络收敛
- 防止梯度消失
多卡为什么同步BN
多卡的原理:
- 模型并行,将模型放在不同的卡上,可以不用同步bn
- 数据并行,将数据拆分放在不同的卡上需要同步bn
对于数据并行的情况,因为每张卡处理的数据不同因此每张卡在前向传播统计得到的bn参数(均值、方差)自然也不同,因为越大的batchsize统计得到的均值方差越符合整体数据集的均值方差,所以如果如果进行多卡实验,一共两张卡,整个batchsize为512,那么每张卡的batchsize是256,这其实也是一个比较大的batchsize了,其实可以不同步bn参数(不同步训练会稍微快一点,并且现在单卡的batchsize已经很大了,同步的意义不大)。但是如果整体的batchsize是4,那么单卡的batchsize是2,就很有必要同步bn。
怎么实现多卡同步BN
代码实现(如果模型包含bn层调用下面函数就会把所有bn层修改成syncbn):
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
实现原理:
两次同步:
- 每张卡单独计算均值,然后做一次同步得到全局均值。
- 用全局均值去算每张卡对应的方差,然后做一次同步,得到全局方差
一次同步
根据计算方差公式的变形,其实只需要计算一次即可:
σ
2
=
1
m
∑
i
=
1
m
(
x
i
−
μ
)
2
=
1
m
∑
i
=
1
m
(
x
i
2
+
μ
2
−
2
x
i
μ
)
=
1
m
∑
i
=
1
m
[
x
i
2
+
(
μ
−
2
x
i
)
μ
]
sigma^2 = frac{1}{m}sum_{i=1}^m(x_i-mu)^2=frac{1}{m}sum_{i=1}^m(x_i^2+mu^2-2x_imu)=frac{1}{m}sum_{i=1}^m[x_i^2+(mu-2x_i)mu]
σ2=m1i=1∑m(xi−μ)2=m1i=1∑m(xi2+μ2−2xiμ)=m1i=1∑m[xi2+(μ−2xi)μ]
=
1
m
∑
i
=
1
m
x
i
2
+
(
μ
−
2
⋅
1
m
∑
i
=
1
m
x
i
)
μ
=
1
m
∑
i
=
1
m
x
i
2
+
(
μ
−
2
⋅
1
m
⋅
m
⋅
μ
)
μ
=frac{1}{m}sum_{i=1}^mx_i^2+(mu-2cdotfrac{1}{m}sum^m_{i=1}x_i)mu=frac{1}{m}sum_{i=1}^mx_i^2+(mu-2cdotfrac{1}{m}cdot m cdot mu)mu
=m1i=1∑mxi2+(μ−2⋅m1i=1∑mxi)μ=m1i=1∑mxi2+(μ−2⋅m1⋅m⋅μ)μ
= 1 m ∑ i = 1 m x i 2 − μ 2 = 1 m ∑ i = 1 m x i 2 − ( 1 m ∑ i = 1 m x i ) 2 =frac{1}{m}sum_{i=1}^mx_i^2-mu^2=frac{1}{m}sum_{i=1}^mx_i^2-(frac{1}{m}sum^m_{i=1}x_i)^2 =m1i=1∑mxi2−μ2=m1i=1∑mxi2−(m1i=1∑mxi)2
所以有:
σ
2
=
1
m
∑
i
=
1
m
x
i
2
−
(
1
m
∑
i
=
1
m
x
i
)
2
sigma^2 =frac{1}{m}sum_{i=1}^mx_i^2-(frac{1}{m}sum^m_{i=1}x_i)^2
σ2=m1i=1∑mxi2−(m1i=1∑mxi)2
所以只需要每张卡求出
1
m
∑
i
=
1
m
x
i
2
frac{1}{m}sum_{i=1}^mx_i^2
m1∑i=1mxi2和
1
m
∑
i
=
1
m
x
i
frac{1}{m}sum_{i=1}^mx_i
m1∑i=1mxi然后多卡进行一次同步即可。
最后
以上就是受伤路人为你收集整理的多卡同步bn的原理与推导bn的计算bn的作用多卡为什么同步BN怎么实现多卡同步BN的全部内容,希望文章能够帮你解决多卡同步bn的原理与推导bn的计算bn的作用多卡为什么同步BN怎么实现多卡同步BN所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
发表评论 取消回复