概述
这个可以用来冻结某些参数,也可以用来指定参数的学习
import torch
import torch.optim as optim
w1 = torch.randn(3, 3)
w1.requires_grad = True
w2 = torch.randn(3, 3)
w2.requires_grad = True
o = optim.Adam([w1])
print(o.param_groups)
以上代码给出下面结果
[{'amsgrad': False,
'betas': (0.9, 0.999),
'eps': 1e-08,
'lr': 0.001,
'params': [tensor([[ 2.9064, -0.2141, -0.4037],
[-0.5718, 1.0375, -0.6862],
[-0.8372, 0.4380, -0.1572]])],
'weight_decay': 0}]
Per the docs, the add_param_group method accepts a param_group parameter that is a dict. Example of use:
import torch
import torch.optim as optim
w1 = torch.randn(3, 3)
w1.requires_grad = True
w2 = torch.randn(3, 3)
w2.requires_grad = True
o = optim.Adam([w1])
print(o.param_groups)
gives
[{'amsgrad': False,
'betas': (0.9, 0.999),
'eps': 1e-08,
'lr': 0.001,
'params': [tensor([[ 2.9064, -0.2141, -0.4037],
[-0.5718, 1.0375, -0.6862],
[-0.8372, 0.4380, -0.1572]])],
'weight_decay': 0}]
now
o.add_param_group({'params': w2})
print(o.param_groups)
再继续给出结果:
[{'amsgrad': False,
'betas': (0.9, 0.999),
'eps': 1e-08,
'lr': 0.001,
'params': [tensor([[ 2.9064, -0.2141, -0.4037],
[-0.5718, 1.0375, -0.6862],
[-0.8372, 0.4380, -0.1572]])],
'weight_decay': 0},
{'amsgrad': False,
'betas': (0.9, 0.999),
'eps': 1e-08,
'lr': 0.001,
'params': [tensor([[-0.0560, 0.4585, -0.7589],
[-0.1994, 0.4557, 0.5648],
[-0.1280, -0.0333, -1.1886]])],
'weight_decay': 0}]
翻译
最后
以上就是孤独小蝴蝶为你收集整理的Pytorch的add_param_group使用说明的全部内容,希望文章能够帮你解决Pytorch的add_param_group使用说明所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复