概述
构造一个简单神经网络(线性model)的步骤:
- 构建数据集
x_data = torch.Tensor([[1.0] ,[2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])
- 设计模型
把模型定义成一个类,代码模板(以后写模型都可以参考这个模板):
import torch
class LinearModel(torch.nn.Module):
# 构造函数:初始化对象默认调用的函数
def __init__(self):
# 必写
super(LinearModel, self).__init__()
# 构造对象
self.linear = torch.nn.Linear(1, 1)
# 前馈任务所要进行的计算
def forward(self, x):
y_pred = self.linear(x)
return y_pred
# 实例化对象
model = LinearModel()
- 构造损失函数和优化器,代码如下:
#损失函数对象
criterion=torch.nn.MSELoss(size_average=False)
#优化器对象
optimizer=torch.optim.SGD(model.parameters(),lr=0.01)
- 写训练周期(前馈——反馈——更新)
训练过程代码如下:
# 进行训练
for epoch in range(100):
# 计算y hat
y_pred = model(x_data)
# 计算损失
loss = criterion(y_pred, y_data)
print(epoch, loss.item())
# 所有权重每次都梯度清零
optimizer.zero_grad()
# 反向传播求梯度
loss.backward()
# 更新,step()更新函数
optimizer.step()
完整代码如下:
import torch
x_data = torch.Tensor([[1.0] ,[2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])
class LinearModel(torch.nn.Module):
# 构造函数:初始化对象默认调用的函数
def __init__(self):
# 必写
super(LinearModel, self).__init__()
# 构造对象
self.linear = torch.nn.Linear(1, 1)
# 前馈任务所要进行的计算
def forward(self, x):
y_pred = self.linear(x)
return y_pred
# 实例化
model = LinearModel()
# 损失函数对象
criterion = torch.nn.MSELoss(size_average=False)
# 优化器对象
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# 进行训练
for epoch in range(100):
# 计算y hat
y_pred = model(x_data)
# 计算损失
loss = criterion(y_pred, y_data)
print(epoch, loss.item())
# 所有权重每次都梯度清零
optimizer.zero_grad()
# 反向传播求梯度
loss.backward()
# 更新,step()更新函数
optimizer.step()
# 输出权重和偏置
print('w=', model.linear.weight.item())
print('b=', model.linear.bias.item())
# 测试
x_test = torch.Tensor([[4.0]])
y_test = torch = model(x_test)
print('y_pred=', y_test.data)
运行结果如图:
因为是随机梯度下降算法,所以每次训练结果都是不同的,运行结果不用拘泥于数字。
补充:
- torch.nn.Linear类
- torch.nn.MSELoss类 :
size_average:是否求均值
reduce:最终是否需要降维(一般不考虑)
- torch.optim.SGD类:
params:模型中所有需要训练的参数
lr:自定义学习率
momentum:是否需要冲量
最后
以上就是紧张冬日为你收集整理的深度学习——用PyTorch实现线性回归(B站刘二大人P5学习笔记)的全部内容,希望文章能够帮你解决深度学习——用PyTorch实现线性回归(B站刘二大人P5学习笔记)所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复