我是靠谱客的博主 舒适自行车,这篇文章主要介绍torch学习 (十五):读取和存储模型引入1 读写Tensor2 读写模型,现在分享给大家,希望可以做个参考。

文章目录

  • 引入
  • 1 读写Tensor
  • 2 读写模型
    • 2.1 state_dict
    • 2.2 保存和加载模型
      • 2.2.1 保存和加载state_dict (推荐)
      • 2.2.2 保存整个模型

引入

  本节介绍如何把在内存中训练好的模型参数进行存储,以及后续的读取 [ 1 ] ^{[1]} [1]

1 读写Tensor

  torch的save与load函数与numpy的类似:

复制代码
1
2
3
4
5
6
7
8
9
10
11
import torch from torch import nn if __name__ == '__main__': # Main x = torch.ones(3) torch.save(x, 'x.pt') x = torch.load('x.pt') print(x)

  输出如下:

复制代码
1
2
tensor([1., 1., 1.])

2 读写模型

  torch中,Module的可学习参数 (权重和偏差),以及模块模型包含在参数中,可通过model.parameters()访问。

2.1 state_dict

  state_dict是一个从参数名称映射到参数Tensor的字典:

复制代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch from torch import nn class Test(nn.Module): def __init__(self): super(Test, self).__init__() self.hidden = nn.Linear(3, 2) self.act = nn.ReLU() self.output = nn.Linear(2, 1) def forward(self, x): """ The forward function. """ return self.output(self.act(self.hidden(x))) if __name__ == '__main__': # Main net = Test() print(net.state_dict())

  输出如下:

复制代码
1
2
3
OrderedDict([('hidden.weight', tensor([[-0.1209, -0.1974, -0.2399], [-0.3348, 0.5283, -0.5134]])), ('hidden.bias', tensor([ 0.1019, -0.1037])), ('output.weight', tensor([[-0.4111, -0.1848]])), ('output.bias', tensor([-0.4113]))])

  注:只有具有可学习参数的层才有state_dict条目。优化器也有一个state_dict,其中包含关于优化器状态以及琐事有超参数的信息:

复制代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import torch from torch import nn class Test(nn.Module): def __init__(self): super(Test, self).__init__() self.hidden = nn.Linear(3, 2) self.act = nn.ReLU() self.output = nn.Linear(2, 1) def forward(self, x): """ The forward function. """ return self.output(self.act(self.hidden(x))) if __name__ == '__main__': # Main net = Test() optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.2) print(optimizer.state_dict())

  输出如下:

复制代码
1
2
{'state': {}, 'param_groups': [{'lr': 0.1, 'momentum': 0.2, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [0, 1, 2, 3]}]}

2.2 保存和加载模型

  torch保存和加载模型有两种常见的方法:
  1)仅保存和加载模型参数,即state_dict;
  2)保存和加载整个模型。

2.2.1 保存和加载state_dict (推荐)

复制代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch from torch import nn class Test(nn.Module): def __init__(self): super(Test, self).__init__() self.hidden = nn.Linear(3, 2) self.act = nn.ReLU() self.output = nn.Linear(2, 1) def forward(self, x): """ The forward function. """ return self.output(self.act(self.hidden(x))) if __name__ == '__main__': # Main net = Test() torch.save(net.state_dict(), 'net.pt') # .pt or .pth model = Test() model.load_state_dict(torch.load('net.pt')) print(model)

  输出如下:

复制代码
1
2
3
4
5
6
Test( (hidden): Linear(in_features=3, out_features=2, bias=True) (act): ReLU() (output): Linear(in_features=2, out_features=1, bias=True) )

2.2.2 保存整个模型

复制代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch from torch import nn class Test(nn.Module): def __init__(self): super(Test, self).__init__() self.hidden = nn.Linear(3, 2) self.act = nn.ReLU() self.output = nn.Linear(2, 1) def forward(self, x): """ The forward function. """ return self.output(self.act(self.hidden(x))) if __name__ == '__main__': # Main net = Test() torch.save(net, 'net.pt') # .pt or .pth net = torch.load('net.pt') print(net)

  输出如下:

复制代码
1
2
3
4
5
6
Test( (hidden): Linear(in_features=3, out_features=2, bias=True) (act): ReLU() (output): Linear(in_features=2, out_features=1, bias=True) )

参考文献
[1] 李沐、Aston Zhang等老师的这本《动手学深度学习》一书。

最后

以上就是舒适自行车最近收集整理的关于torch学习 (十五):读取和存储模型引入1 读写Tensor2 读写模型的全部内容,更多相关torch学习内容请搜索靠谱客的其他文章。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(79)

评论列表共有 0 条评论

立即
投稿
返回
顶部