我是靠谱客的博主 灵巧戒指,最近开发中收集的这篇文章主要介绍pytorch加载预训练模型遇到的问题:KeyError: ‘bn1.num_batches_tracked‘问题,觉得挺不错的,现在分享给大家,希望可以做个参考。
概述
问题
最近在使用pytorch1.0加载resnet预训练模型时,遇到的一个问题,在此记录一下。
KeyError: 'layer1.0.bn1.num_batches_tracked’
其实是使用的版本的问题,pytorch0.4.1之后在BN层加入了track_running_stats这个参数,
这个参数的作用如下:
训练时用来统计训练时的forward过的min-batch数目,每经过一个min-batch, track_running_stats+=1
如果没有指定momentum, 则使用1/num_batches_tracked 作为因数来计算均值和方差(running mean and variance).
原文链接:https://blog.csdn.net/qq_40821799/article/details/103079350
其实,这个参数没啥用.但因为官方提供的预训练模型是pytorch0.3版本训练出来的,因此没有这个参数.
所以,只要过滤一下预训练权重字典中的关键字即可,‘num_batches_tracked’.代码例子,如下.
有问题的代码:
def load_specific_param(self, state_dict, param_name, model_path):
param_dict = torch.load(model_path)
for i in state_dict:
key = param_name + '.' + i
state_dict[i].copy_(param_dict[key])
del param_dict
对’num_batches_tracked进行过滤:
def load_specific_param(self, state_dict, param_name, model_path):
param_dict = torch.load(model_path)
param_dict = {k: v for k, v in param_dict.items() if 'num_batches_tracked' not in k}
for i in state_dict:
key = param_name + '.' + i
if 'num_batches_tracked' in key:
continue
state_dict[i].copy_(param_dict[key])
del param_dict
最后
以上就是灵巧戒指为你收集整理的pytorch加载预训练模型遇到的问题:KeyError: ‘bn1.num_batches_tracked‘问题的全部内容,希望文章能够帮你解决pytorch加载预训练模型遇到的问题:KeyError: ‘bn1.num_batches_tracked‘问题所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复