我是靠谱客的博主 大方银耳汤,这篇文章主要介绍PyTorch预训练模型保存与加载,现在分享给大家,希望可以做个参考。

模型的保存

复制代码
1
2
3
4
5
6
7
8
9
10
''' torch.save()函数保存序列化的对象。 ''' # 保存整个模型 torch.save(model, './path') # 仅保存参数 torch.save(model.state_dict(), './path')

模型的加载

复制代码
1
2
3
4
5
6
# 整个模型的加载 model = torch.load('./path') # 获取权重值 checkpoint = model.state_dict()

当使用pytorch自己训练了一个模型并保存,下次想要直接加载使用时,必须清楚这个模型结构的所有内容来自PyTorch自带函数,还是有自定义的部分。若有自定义的部分则必须在使用它之前import或者写好自定义的部分,意即给出自定义的层、model类等。比如:

复制代码
1
2
from TheModelByYourself import Layer1, Layer2, Function1, Function2

另外,model.state_dict()里面仅有定义为可训练的参数,可以自己打印出来看一下。
想要保存额外的参数,可以在保存时自定义保存内容,比如:

复制代码
1
2
3
4
5
6
7
torch.save( {'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': loss}, './path' )

最后

以上就是大方银耳汤最近收集整理的关于PyTorch预训练模型保存与加载的全部内容,更多相关PyTorch预训练模型保存与加载内容请搜索靠谱客的其他文章。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(64)

评论列表共有 0 条评论

立即
投稿
返回
顶部