我是靠谱客的博主 清新小笼包,这篇文章主要介绍pytorch 批标准化模块 torch.nn.BatchNorm1d,现在分享给大家,希望可以做个参考。

官方文档: https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html

class torch.nn.BatchNorm1d(num_features, eps=1e-05, momentum=0.1, affine=True,
	  track_running_stats=True)

输入: (N,C)(N,C,L)
输出: (N,C)(N,C,L) (与输入相同的形状)

参数:

  • num_features: 输入 (N,C,L) 中的特征数C 或 输入 (N,L) 中的L
  • eps: 用于稳定分母的极小值ε, 默认为1E-5
  • momentum: 动量, 用于累计移动平均值 (下面细说) , 默认为0.1
  • affine: 仿射, 标识该模块是否具有可学习的仿射参数, 默认为True
  • track_running_stats:
    True时此模块跟踪运行的均值和方差.
    False时此模块不跟踪此类统计信息, 而是初始化统计信息缓冲区.
    Nonerunning_meanrunning_var皆为None, 使得这个模块无论是训练模式还是评估模式总是使用批处理统计.
    默认为 True

Pytorch 中 BatchNorm 中参数更新规则为:

mean = momentum * new + (1 - momentum) * mean

值得注意的是, 在 Tensorflow 中, batch_normalization 的更新规则为:

mean = momentum * mean + (1 - momentum) * new

或用 decay (衰减率) 表示亦是如此:

mean = decay * mean + (1 - decay) * new

这意味着 PyTorch 中的 momentum 等于 Tensorflow 中 (1 - momentum)

最后

以上就是清新小笼包最近收集整理的关于pytorch 批标准化模块 torch.nn.BatchNorm1d的全部内容,更多相关pytorch内容请搜索靠谱客的其他文章。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(44)

评论列表共有 0 条评论

立即
投稿
返回
顶部