我是靠谱客的博主 快乐滑板,最近开发中收集的这篇文章主要介绍反向传播理解,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

参考链接:

反向传播算法(过程及公式推导) - ZYVV - 博客园

RuntimeError: Trying to backward through the graph a second time..._Huiyu Blog-CSDN博客 

https://blog.csdn.net/weixin_44058333/article/details/99701876

参考链接的博主讲的特别细,保存一下,以便随时复习。

两个网络的两个loss需要分别执行backward进行回传的时候: loss1.backward(), loss1.backward().

两个网络的情况需要分别为两个网络分别定义optimizer
optimizer1= torch.optim.SGD(net1.parameters(), learning_rate, momentum,weight_decay)
optimizer2= torch.optim.SGD(net2.parameters(), learning_rate, momentum,weight_decay)
.....
#train 部分的loss回传处理
loss1 = loss()
loss2 = loss()

optimizer1.zero_grad() #set the grade to zero
loss1.backward(retain_graph=True) #保留backward后的中间参数。
optimizer1.step()

optimizer2.zero_grad() #set the grade to zero
loss2.backward() 
optimizer2.step()


步骤解释

optimizer.zero_grad()将梯度初始化为零

output = net(inputs)前向传播求出预测的值

loss = Loss(outputs, labels)求loss

loss.backward()反向传播求梯度

optimizer.step()更新所有参数

最后

以上就是快乐滑板为你收集整理的反向传播理解的全部内容,希望文章能够帮你解决反向传播理解所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(48)

评论列表共有 0 条评论

立即
投稿
返回
顶部