我是靠谱客的博主 欣喜雪糕,最近开发中收集的这篇文章主要介绍PyTorch深度学习笔记(十六)优化器,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

课程学习笔记,课程链接

优化器:神经网络的学习的目的就是寻找合适的参数,使得损失函数的值尽可能小。解决这个问题的过程为称为最优化。解决这个问题使用的算法叫做优化器。在 PyTorch 官网中,将优化器放置在 torch.optim 中,并详细介绍了各种优化器的使用方法。

现以 CIFAR10 数据集为例,损失函数选取交叉熵函数,优化器选择 SGD 优化器,搭建神经网络,并计算其损失值,用优化器优化各个参数,使其朝梯度下降的方向调整。设置 epoch,让其执行 20 次,并将每一次完整的训练的损失函数值求和输出。

import torch
import torchvision
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear
from torch.utils.data import DataLoader
​
dataset = torchvision.datasets.CIFAR10("D:CodeProjectlearn_pytorchpytorch_p17-21data", train=False,
                                       download=True, transform=torchvision.transforms.ToTensor())
dataloader = DataLoader(dataset, batch_size=4)
​
class Jiaolong(nn.Module):
    def __init__(self):
        super(Jiaolong, self).__init__()
        self.model1 = Sequential(
            Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2),
            MaxPool2d(kernel_size=2),
            Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=2),
            MaxPool2d(kernel_size=2),
            Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2),
            MaxPool2d(kernel_size=2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )
​
    def forward(self, x):
        x = self.model1(x)
        return x
​
loss = nn.CrossEntropyLoss()
jiaolong = Jiaolong()
# 构建 SGD 优化器,其中 jiaolong.parameters() 表示:待优化参数的 iterable 或者是定义了参数组的 dict,lr=0.01 表示学习率
optim = torch.optim.SGD(jiaolong.parameters(), lr=0.01)
for epoch in range(20):
    running_loss = 0.0
    for data in dataloader:
        imgs, targets = data
        outputs = jiaolong(imgs)
        result_loss = loss(outputs, targets)
        # 将上一轮计算的梯度清零,避免上一轮的梯度值会影响下一轮的梯度值计算
        optim.zero_grad()
        # 反向传播过程,在反向传播过程中会计算每个参数的梯度值
        result_loss.backward()
        # 所有的 optimizer 都实现了 step() 方法,该方法会更新所有的参数。
        optim.step()
        running_loss = running_loss + result_loss
    print(running_loss)

最后

以上就是欣喜雪糕为你收集整理的PyTorch深度学习笔记(十六)优化器的全部内容,希望文章能够帮你解决PyTorch深度学习笔记(十六)优化器所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(44)

评论列表共有 0 条评论

立即
投稿
返回
顶部