我是靠谱客的博主 从容樱桃,最近开发中收集的这篇文章主要介绍RNN的梯度消失问题,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

学习吴恩达老师的AI课程笔记,详细视频课程请移步https://mooc.study.163.com/smartSpec/detail/1001319001.htm

RNN模型结构:

在这里插入图片描述
可以看出 如果说有一个非常长的句子(则对应会有一个非常深的网络),这个句子中句尾的某个词严重受句首某些词的影响(英语中的单复数,时态),这时候由于网络非常的深,则在反向传播的过程中极易出现梯度消失或者梯度爆炸的情况(反向传播时)。
* 关于梯度消失的详细解释; --引自知乎 https://zhuanlan.zhihu.com/p/28687529
在这里插入图片描述
假设我们的时间序列只有三段S0为给定值,神经元没有激活函数,则RNN最简单的前向传播过程如下:
S 1 = W x X 1 + W s S 0 + b 1 , O 1 = W 0 S 1 + b 2 S_1 = W_xX_1+W_sS_0+b_1,O_1 =W_0S_1+b_2 S1=WxX1+WsS0+b1O1=W0S1+b2 S 2 = W x X 2 + W s S 1 + b 1 , O 2 = W 0 S 2 + b 2 S_2 = W_xX_2+W_sS_1+b_1,O_2 =W_0S_2+b_2 S2=WxX2+WsS1+b1O2=W0S2+b2 S 3 = W x X 3 + W s S 2 + b 1 , O 3 = W 0 S 3 + b 2 S_3 = W_xX_3+W_sS_2+b_1,O_3 =W_0S_3+b_2 S3=WxX3+WsS2+b1O3=W0S3+b2假设在t=3时刻,损失函数为 L 3 = 1 2 ( Y 3 − O 3 ) 2 L3 = frac{1}{2}(Y_3-O_3)^2 L3=21(Y3O3)2。 则对于一次训练任务的损失函数为 L = ∑ t = 0 T L t L=sum_{t=0}^{T}{L_t} L=t=0TLt ,即每一时刻损失值的累加。使用随机梯度下降法训练RNN其实就是对Wx、Ws、W0以及b1、b2求偏导,并不断调整它们以使L尽可能达到最小的过程。现在假设我们我们的时间序列只有三段,t1,t2,t3。我们只对t3时刻的Wx、Ws、W0求偏导(其他时刻类似): δ L 3 δ W 0 = δ L 3 δ O 3 δ O 3 δ W 0 frac{delta L_3}{delta W_0}=frac{delta L_3}{delta O_3}frac{delta O_3}{delta W_0} δW0δL3=δO3δL3δW0δO3 δ L 3 δ W x = δ L 3 δ O 3 δ O 3 δ S 3 δ S 3 δ W x + δ L 3 δ O 3 δ O 3 δ S 3 δ S 3 δ S 2 δ S 2 δ W x + δ L 3 δ O 3 δ O 3 δ S 3 δ S 3 δ S 2 δ S 2 δ S 1 δ S 1 δ W x frac{delta L_3}{delta W_x}=frac{delta L_3}{delta O_3}frac{delta O_3}{delta S_3}frac{delta S_3}{delta W_x}+frac{delta L_3}{delta O_3}frac{delta O_3}{delta S_3}frac{delta S_3}{delta S_2}frac{delta S_2}{delta W_x}+frac{delta L_3}{delta O_3}frac{delta O_3}{delta S_3}frac{delta S_3}{delta S_2}frac{delta S_2}{delta S_1}frac{delta S_1}{delta W_x} δWxδL3=δO3δL3δS3δO3δWxδS3+δO3δL3δS3δO3δS2δS3δWxδS2+δO3δL3δS3δO3δS2δS3δS1δS2δWxδS1 δ L 3 δ W s = δ L 3 δ O 3 δ O 3 δ S 3 δ S 3 δ W s + δ L 3 δ O 3 δ O 3 δ S 3 δ S 3 δ S 2 δ S 2 δ W s + δ L 3 δ O 3 δ O 3 δ S 3 δ S 3 δ S 2 δ S 2 δ S 1 δ S 1 δ W s frac{delta L_3}{delta W_s}=frac{delta L_3}{delta O_3}frac{delta O_3}{delta S_3}frac{delta S_3}{delta W_s}+frac{delta L_3}{delta O_3}frac{delta O_3}{delta S_3}frac{delta S_3}{delta S_2}frac{delta S_2}{delta W_s}+frac{delta L_3}{delta O_3}frac{delta O_3}{delta S_3}frac{delta S_3}{delta S_2}frac{delta S_2}{delta S_1}frac{delta S_1}{delta W_s} δWsδL3=δO3δL3δS3δO3δWsδS3+δO3δL3δS3δO3δS2δS3δWsδS2+δO3δL3δS3δO3δS2δS3δS1δS2δWsδS1 可以看出对于W0求偏导并没有长期依赖,但是对于Wx、Ws求偏导,会随着时间序列产生长期依赖。因为St随着时间序列向前传播,而St又是Wx、Ws的函数。根据上述求偏导的过程,我们可以得出任意时刻对Wx、Ws求偏导的公式: δ L t δ W x = ∑ k = 0 t δ L t δ O t δ O t δ S t ( ⋂ j = k + 1 t δ S j δ S j − 1 ) δ S k δ W x frac{delta L_t}{delta W_x}=sum_{k=0}^{t}{frac{delta L_t}{delta O_t}frac{delta O_t}{delta S_t}(bigcap_{j=k+1}^{t}{frac{delta S_j}{delta S_j-1}})frac{delta S_k}{delta W_x}} δWxδLt=k=0tδOtδLtδStδOt(j=k+1tδSj1δSj)δWxδSk任意时刻对Ws求偏导的公式同上。如果加上激活函数, S j = t a n h ( W x X j + W s S j − 1 + b 1 ) S_j=tanh(W_xX_j+W_sS_{j-1}+b_1) Sj=tanh(WxXj+WsSj1+b1),则 ⋂ j = k + 1 t δ S j δ S j − 1 = ⋂ j = k + 1 t t a n h ′ W s bigcap_{j=k+1}^{t}{frac{delta S_j}{delta S_j-1}}=bigcap_{j=k+1}^{t}{tanh'W_s} j=k+1tδSj1δSj=j=k+1ttanhWs解决方法:
*** 梯度爆炸:** 容易发现,对梯度设定一个最大的阈值,将梯度设定在阈值范围内
*** 梯度消失:**

    GRU: 

在这里插入图片描述
图中的C为记忆单元,他负责存储可能对后续字节有较大影响的字节,CN(t)则代表每步运行后的C值的候选值(是否需要把C值更新为候选值),Iu是门,用来决定是否需要更新,从上式c(t)可以看出 当lu为1是则代表需要更新,为0则无需更新。实际中的GRU 上述的各个环节都可能是多个元素的 如C(t)里可能有[a,b,c,d,e,g…]等100个元素,则Iu中对应的也应有100个元素[0,0,0,0,1,1,0,1,0…],两者的每个元素相对应,使得候选值被准确的定位是否需要更新。 完整的GRU
在这里插入图片描述
* LSTM:在这里插入图片描述
可以看看出LSTM有三个门控单元 分别是更新门Tu,Tf,To 对比GRU可以看出 更新值C(t)直接由更新门Tu和遗忘门Tf控制是否需要更新。

最后

以上就是从容樱桃为你收集整理的RNN的梯度消失问题的全部内容,希望文章能够帮你解决RNN的梯度消失问题所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(56)

评论列表共有 0 条评论

立即
投稿
返回
顶部