概述
模型的保存
'''
torch.save()函数保存序列化的对象。
'''
# 保存整个模型
torch.save(model, './path')
# 仅保存参数
torch.save(model.state_dict(), './path')
模型的加载
# 整个模型的加载
model = torch.load('./path')
# 获取权重值
checkpoint = model.state_dict()
当使用pytorch自己训练了一个模型并保存,下次想要直接加载使用时,必须清楚这个模型结构的所有内容来自PyTorch自带函数,还是有自定义的部分。若有自定义的部分则必须在使用它之前import或者写好自定义的部分,意即给出自定义的层、model类等。比如:
from TheModelByYourself import Layer1, Layer2, Function1, Function2
另外,model.state_dict()里面仅有定义为可训练的参数,可以自己打印出来看一下。
想要保存额外的参数,可以在保存时自定义保存内容,比如:
torch.save(
{'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss},
'./path'
)
最后
以上就是大方银耳汤为你收集整理的PyTorch预训练模型保存与加载的全部内容,希望文章能够帮你解决PyTorch预训练模型保存与加载所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复