我是靠谱客的博主 心灵美大象,最近开发中收集的这篇文章主要介绍torch.nn.Module.zero_grad()的使用,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

参考链接: torch.nn.Module.zero_grad()

在这里插入图片描述
函数功能:

将模型的所有参数的梯度清零.

代码展示:


import torch 
torch.manual_seed(seed=20200910)

class Model(torch.nn.Module):
    def __init__(self):
        super(Model,self).__init__()
        self.conv1 = torch.nn.Conv2d(1,64,kernel_size=3,stride=1,padding=1)
    def forward(self,x):  # torch.Size([64, 1, 28, 28])
        x = self.conv1(x)  
        x = torch.sum(x)
        return x

print('cuda(GPU)是否可用:',torch.cuda.is_available())
print('torch的版本:',torch.__version__)

model = Model() #.cuda()

print("训练之前".center(100,"-"))
print('遍历所有参数'.center(100,"-"))
for k,v in model.state_dict().items():
    print(k,v.size())

print('打印特定参数'.center(100,"-"))
print(model.conv1.weight.grad)  # 此时还没有梯度,因此输出的是None
print(model.conv1.bias.grad)  # 此时还没有梯度,因此输出的是None

data_input = torch.randn(64,1,28,28)
data_output = model(data_input)

print("反向传播之前".center(100,"-"))
print(model.conv1.weight.grad)  # 此时还没有梯度,因此输出的是None
print(model.conv1.bias.grad)  # 此时还没有梯度,因此输出的是None
data_output.backward()
print("反向传播之后".center(100,"-"))
print(model.conv1.weight.grad.shape)  # torch.Size([64, 1, 3, 3])
print(model.conv1.weight.grad[0])  # 打印其中一小部分
print(model.conv1.bias.grad.shape)  # torch.Size([64])
print(model.conv1.bias.grad)  # 打印全部的64个数

model.zero_grad()  # 将模型内的参数的梯度归零
print("模型的zero_grad()方法调用之后".center(100,"-"))
print(model.conv1.weight.grad.shape)  # torch.Size([64, 1, 3, 3])
print(model.conv1.weight.grad[0])  # 打印其中一小部分
print(model.conv1.bias.grad.shape)  # torch.Size([64])
print(model.conv1.bias.grad)  # 打印全部的64个数

控制台输出结果:

Windows PowerShell
版权所有 (C) Microsoft Corporation。保留所有权利。

尝试新的跨平台 PowerShell https://aka.ms/pscore6

加载个人及系统配置文件用了 1125 毫秒。
(base) PS C:UserschenxuqiDesktopNews4cxqtest4cxq> conda activate ssd4pytorch1_2_0
(ssd4pytorch1_2_0) PS C:UserschenxuqiDesktopNews4cxqtest4cxq>  & 'D:Anaconda3envsssd4pytorch1_2_0python.exe' 'c:Userschenxuqi.vscodeextensionsms-python.python-2020.12.424452561pythonFileslibpythondebugpylauncher' 
'51173' '--' 'c:UserschenxuqiDesktopNews4cxqtest4cxqtest27.py'
cuda(GPU)是否可用: True
torch的版本: 1.2.0+cu92
------------------------------------------------训练之前------------------------------------------------
-----------------------------------------------遍历所有参数-----------------------------------------------
conv1.weight torch.Size([64, 1, 3, 3])
conv1.bias torch.Size([64])
-----------------------------------------------打印特定参数-----------------------------------------------
None
None
-----------------------------------------------反向传播之前-----------------------------------------------
None
None
-----------------------------------------------反向传播之后-----------------------------------------------
torch.Size([64, 1, 3, 3])
tensor([[[ -2.8736,  -9.4509,  25.2211],
         [ -4.0062, -19.2819,  17.3172],
         [-16.1339, -34.2789,  22.5320]]])
torch.Size([64])
tensor([50176., 50176., 50176., 50176., 50176., 50176., 50176., 50176., 50176.,
        50176., 50176., 50176., 50176., 50176., 50176., 50176., 50176., 50176.,
        50176., 50176., 50176., 50176., 50176., 50176., 50176., 50176., 50176.,
        50176., 50176., 50176., 50176., 50176., 50176., 50176., 50176., 50176.,
        50176., 50176., 50176., 50176., 50176., 50176., 50176., 50176., 50176.,
        50176., 50176., 50176., 50176., 50176., 50176., 50176., 50176., 50176.,
        50176., 50176., 50176., 50176., 50176., 50176., 50176., 50176., 50176.,
        50176.])
----------------------------------------模型的zero_grad()方法调用之后----------------------------------------       
torch.Size([64, 1, 3, 3])
tensor([[[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]]])
torch.Size([64])
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
(ssd4pytorch1_2_0) PS C:UserschenxuqiDesktopNews4cxqtest4cxq>

最后

以上就是心灵美大象为你收集整理的torch.nn.Module.zero_grad()的使用的全部内容,希望文章能够帮你解决torch.nn.Module.zero_grad()的使用所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(50)

评论列表共有 0 条评论

立即
投稿
返回
顶部