概述
文章目录
- 引入
- 1 读写Tensor
- 2 读写模型
- 2.1 state_dict
- 2.2 保存和加载模型
- 2.2.1 保存和加载state_dict (推荐)
- 2.2.2 保存整个模型
引入
本节介绍如何把在内存中训练好的模型参数进行存储,以及后续的读取 [ 1 ] ^{[1]} [1]。
1 读写Tensor
torch的save与load函数与numpy的类似:
import torch
from torch import nn
if __name__ == '__main__':
# Main
x = torch.ones(3)
torch.save(x, 'x.pt')
x = torch.load('x.pt')
print(x)
输出如下:
tensor([1., 1., 1.])
2 读写模型
torch中,Module的可学习参数 (权重和偏差),以及模块模型包含在参数中,可通过model.parameters()访问。
2.1 state_dict
state_dict是一个从参数名称映射到参数Tensor的字典:
import torch
from torch import nn
class Test(nn.Module):
def __init__(self):
super(Test, self).__init__()
self.hidden = nn.Linear(3, 2)
self.act = nn.ReLU()
self.output = nn.Linear(2, 1)
def forward(self, x):
"""
The forward function.
"""
return self.output(self.act(self.hidden(x)))
if __name__ == '__main__':
# Main
net = Test()
print(net.state_dict())
输出如下:
OrderedDict([('hidden.weight', tensor([[-0.1209, -0.1974, -0.2399],
[-0.3348, 0.5283, -0.5134]])), ('hidden.bias', tensor([ 0.1019, -0.1037])), ('output.weight', tensor([[-0.4111, -0.1848]])), ('output.bias', tensor([-0.4113]))])
注:只有具有可学习参数的层才有state_dict条目。优化器也有一个state_dict,其中包含关于优化器状态以及琐事有超参数的信息:
import torch
from torch import nn
class Test(nn.Module):
def __init__(self):
super(Test, self).__init__()
self.hidden = nn.Linear(3, 2)
self.act = nn.ReLU()
self.output = nn.Linear(2, 1)
def forward(self, x):
"""
The forward function.
"""
return self.output(self.act(self.hidden(x)))
if __name__ == '__main__':
# Main
net = Test()
optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.2)
print(optimizer.state_dict())
输出如下:
{'state': {}, 'param_groups': [{'lr': 0.1, 'momentum': 0.2, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [0, 1, 2, 3]}]}
2.2 保存和加载模型
torch保存和加载模型有两种常见的方法:
1)仅保存和加载模型参数,即state_dict;
2)保存和加载整个模型。
2.2.1 保存和加载state_dict (推荐)
import torch
from torch import nn
class Test(nn.Module):
def __init__(self):
super(Test, self).__init__()
self.hidden = nn.Linear(3, 2)
self.act = nn.ReLU()
self.output = nn.Linear(2, 1)
def forward(self, x):
"""
The forward function.
"""
return self.output(self.act(self.hidden(x)))
if __name__ == '__main__':
# Main
net = Test()
torch.save(net.state_dict(), 'net.pt') # .pt or .pth
model = Test()
model.load_state_dict(torch.load('net.pt'))
print(model)
输出如下:
Test(
(hidden): Linear(in_features=3, out_features=2, bias=True)
(act): ReLU()
(output): Linear(in_features=2, out_features=1, bias=True)
)
2.2.2 保存整个模型
import torch
from torch import nn
class Test(nn.Module):
def __init__(self):
super(Test, self).__init__()
self.hidden = nn.Linear(3, 2)
self.act = nn.ReLU()
self.output = nn.Linear(2, 1)
def forward(self, x):
"""
The forward function.
"""
return self.output(self.act(self.hidden(x)))
if __name__ == '__main__':
# Main
net = Test()
torch.save(net, 'net.pt') # .pt or .pth
net = torch.load('net.pt')
print(net)
输出如下:
Test(
(hidden): Linear(in_features=3, out_features=2, bias=True)
(act): ReLU()
(output): Linear(in_features=2, out_features=1, bias=True)
)
参考文献
[1] 李沐、Aston Zhang等老师的这本《动手学深度学习》一书。
最后
以上就是舒适自行车为你收集整理的torch学习 (十五):读取和存储模型引入1 读写Tensor2 读写模型的全部内容,希望文章能够帮你解决torch学习 (十五):读取和存储模型引入1 读写Tensor2 读写模型所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
发表评论 取消回复