我是靠谱客的博主 灵巧戒指,最近开发中收集的这篇文章主要介绍pytorch加载预训练模型遇到的问题:KeyError: ‘bn1.num_batches_tracked‘问题,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

问题

最近在使用pytorch1.0加载resnet预训练模型时,遇到的一个问题,在此记录一下。
    KeyError: 'layer1.0.bn1.num_batches_tracked’

其实是使用的版本的问题,pytorch0.4.1之后在BN层加入了track_running_stats这个参数,
这个参数的作用如下:
  训练时用来统计训练时的forward过的min-batch数目,每经过一个min-batch, track_running_stats+=1
  如果没有指定momentum, 则使用1/num_batches_tracked 作为因数来计算均值和方差(running mean and variance).
原文链接:https://blog.csdn.net/qq_40821799/article/details/103079350

其实,这个参数没啥用.但因为官方提供的预训练模型是pytorch0.3版本训练出来的,因此没有这个参数.
所以,只要过滤一下预训练权重字典中的关键字即可,‘num_batches_tracked’.代码例子,如下.

有问题的代码:


def load_specific_param(self, state_dict, param_name, model_path):
param_dict = torch.load(model_path)
for i in state_dict:
key = param_name + '.' + i
state_dict[i].copy_(param_dict[key])
del param_dict

对’num_batches_tracked进行过滤:


def load_specific_param(self, state_dict, param_name, model_path):
param_dict = torch.load(model_path)
param_dict = {k: v for k, v in param_dict.items() if 'num_batches_tracked' not in k}
for i in state_dict:
key = param_name + '.' + i
if 'num_batches_tracked' in key:
continue
state_dict[i].copy_(param_dict[key])
del param_dict

最后

以上就是灵巧戒指为你收集整理的pytorch加载预训练模型遇到的问题:KeyError: ‘bn1.num_batches_tracked‘问题的全部内容,希望文章能够帮你解决pytorch加载预训练模型遇到的问题:KeyError: ‘bn1.num_batches_tracked‘问题所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(58)

评论列表共有 0 条评论

立即
投稿
返回
顶部