我是靠谱客的博主 大方银耳汤,最近开发中收集的这篇文章主要介绍PyTorch预训练模型保存与加载,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

模型的保存

'''
torch.save()函数保存序列化的对象。
'''

# 保存整个模型
torch.save(model, './path')

# 仅保存参数
torch.save(model.state_dict(), './path')

模型的加载

# 整个模型的加载
model = torch.load('./path')

# 获取权重值
checkpoint = model.state_dict()

当使用pytorch自己训练了一个模型并保存,下次想要直接加载使用时,必须清楚这个模型结构的所有内容来自PyTorch自带函数,还是有自定义的部分。若有自定义的部分则必须在使用它之前import或者写好自定义的部分,意即给出自定义的层、model类等。比如:

from TheModelByYourself import Layer1, Layer2, Function1, Function2

另外,model.state_dict()里面仅有定义为可训练的参数,可以自己打印出来看一下。
想要保存额外的参数,可以在保存时自定义保存内容,比如:

torch.save(	
			{'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss}, 
            './path'
          )

最后

以上就是大方银耳汤为你收集整理的PyTorch预训练模型保存与加载的全部内容,希望文章能够帮你解决PyTorch预训练模型保存与加载所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(56)

评论列表共有 0 条评论

立即
投稿
返回
顶部