概述
加载网络模型,pretrained=False 网络中模型的参数是没有训练的,初始化参数。
import torch
import torchvision
vgg16 = torchvision.models.vgg16(pretrained=False)
1.有两种保存方式:torch.save()
①保存网络的 结构+参数
torch.save(vgg16, "vgg16_method1.pth")
②方式2【官方推荐】,网络模型的参数保存为字典,不保存结构。空间小
state_dict()返回包含模块整个状态的字典
torch.save(vgg16.state_dict(), "vgg16_method2.pth")
运行后,出现保存的文件。
2.网络的加载方式:torch.load()
对应方式①打印出的是网络模型的结构
# 加载方式1模型
import torch
# 打印出的是网络模型的结构
model1 = torch.load("vgg16_method1.pth")
print(model1)
对应方式②打印出的参数是字典形式
model2 = torch.load("vgg16_method2.pth")
print(model2)
如何恢复成网络模型?
第一步:新建网络模型
vgg16 = torchvision.models.vgg16(pretrained=False)
第二步:加载网络参数模型
model2 = torch.load("vgg16_method2.pth")
第三步:调用load_state_dict,恢复成网络结构
vgg16.load_state_dict(model2)
最后
以上就是懵懂嚓茶为你收集整理的机器学习——网络模型的保存与读取的全部内容,希望文章能够帮你解决机器学习——网络模型的保存与读取所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复