我是靠谱客的博主 懵懂嚓茶,最近开发中收集的这篇文章主要介绍机器学习——网络模型的保存与读取,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

加载网络模型,pretrained=False 网络中模型的参数是没有训练的,初始化参数。

import torch
import torchvision
vgg16 = torchvision.models.vgg16(pretrained=False)

1.有两种保存方式:torch.save()

①保存网络的 结构+参数

torch.save(vgg16, "vgg16_method1.pth")

②方式2【官方推荐】,网络模型的参数保存为字典,不保存结构。空间小

state_dict()返回包含模块整个状态的字典
torch.save(vgg16.state_dict(), "vgg16_method2.pth")

运行后,出现保存的文件。

2.网络的加载方式:torch.load()

对应方式①打印出的是网络模型的结构

# 加载方式1模型
import torch
# 打印出的是网络模型的结构
model1 = torch.load("vgg16_method1.pth")
print(model1)

对应方式②打印出的参数是字典形式

model2 = torch.load("vgg16_method2.pth")
print(model2)

如何恢复成网络模型?

第一步:新建网络模型

vgg16 = torchvision.models.vgg16(pretrained=False)

第二步:加载网络参数模型

model2 = torch.load("vgg16_method2.pth")

第三步:调用load_state_dict,恢复成网络结构

vgg16.load_state_dict(model2)

最后

以上就是懵懂嚓茶为你收集整理的机器学习——网络模型的保存与读取的全部内容,希望文章能够帮你解决机器学习——网络模型的保存与读取所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(72)

评论列表共有 0 条评论

立即
投稿
返回
顶部