官方文档: https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html
class torch.nn.BatchNorm1d(num_features, eps=1e-05, momentum=0.1, affine=True,
track_running_stats=True)
输入: (N,C)
或 (N,C,L)
输出: (N,C)
或 (N,C,L)
(与输入相同的形状)
参数:
num_features
: 输入(N,C,L)
中的特征数C
或 输入(N,L)
中的L
eps
: 用于稳定分母的极小值ε
, 默认为1E-5
momentum
: 动量, 用于累计移动平均值 (下面细说) , 默认为0.1
affine
: 仿射, 标识该模块是否具有可学习的仿射参数, 默认为True
track_running_stats
:
为True
时此模块跟踪运行的均值和方差.
为False
时此模块不跟踪此类统计信息, 而是初始化统计信息缓冲区.
为None
时running_mean
和running_var
皆为None
, 使得这个模块无论是训练模式还是评估模式总是使用批处理统计.
默认为True
Pytorch 中 BatchNorm 中参数更新规则为:
mean = momentum * new + (1 - momentum) * mean
值得注意的是, 在 Tensorflow 中, batch_normalization 的更新规则为:
mean = momentum * mean + (1 - momentum) * new
或用 decay (衰减率) 表示亦是如此:
mean = decay * mean + (1 - decay) * new
这意味着 PyTorch 中的 momentum
等于 Tensorflow 中 (1 - momentum)
最后
以上就是清新小笼包最近收集整理的关于pytorch 批标准化模块 torch.nn.BatchNorm1d的全部内容,更多相关pytorch内容请搜索靠谱客的其他文章。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复