我是靠谱客的博主 老迟到云朵,最近开发中收集的这篇文章主要介绍RNN中的梯度消失/爆炸原因,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

RNN中的梯度消失/爆炸原因

梯度消失/梯度爆炸是深度学习中老生常谈的话题,这篇博客主要是对RNN中的梯度消失/梯度爆炸原因进行公式层面上的直观理解。

在这里插入图片描述

首先,上图是RNN的网络结构图, ( x 1 , x 2 , x 3 , … , ) (x_1, x_2, x_3, …, ) (x1,x2,x3,,)是输入的序列, X t X_t Xt表示时间步为 t t t时的输入向量。假设我们总共有 k k k个时间步,用第 k k k个时间步的输出 H k H_k Hk作为输出(实际上每个时间步都有输出,这里仅考虑 H k H_k Hk),用 E k E_k Ek表示损失。

其中, C t = tanh ⁡ ( W c C t − 1 + W x X t ) C_{t}=tanh left(W_{c} C_{t-1}+W_{x} X_{t}right) Ct=tanh(WcCt1+WxXt)

从上式可以看出 W x W_x Wx W c W_c Wc其实是差不多的,记 W = [ W c , W x ] W=[W_c, W_x] W=[Wc,Wx],那么求偏导可以得到:

∂ E k ∂ W = ∂ E k ∂ H k ∂ H k ∂ C k ∂ C k ∂ C k − 1 … ∂ C 2 ∂ C 1 ∂ C 1 ∂ W = ∂ E k ∂ H k ∂ H k ∂ C k ( ∏ t = 2 k ∂ C t ∂ C t − 1 ) ∂ C 1 ∂ W begin{aligned} frac{partial E_{k}}{partial W}=& frac{partial E_{k}}{partial H_{k}} frac{partial H_{k}}{partial C_{k}} frac{partial C_{k}}{partial C_{k-1}} ldots frac{partial C_{2}}{partial C_{1}} frac{partial C_{1}}{partial W}=\ & frac{partial E_{k}}{partial H_{k}} frac{partial H_{k}}{partial C_{k}}left(prod_{t=2}^{k} frac{partial C_{t}}{partial C_{t-1}}right) frac{partial C_{1}}{partial W} end{aligned} WEk=HkEkCkHkCk1CkC1C2WC1=HkEkCkHk(t=2kCt1Ct)WC1

其中的累乘部分为:

∂ C t ∂ c t − 1 = tanh ⁡ ′ ( W c C t − 1 + W x X t ) ⋅ d d C t − 1 [ W c C t − 1 + W x X t ] = tanh ⁡ ′ ( W c C t − 1 + W x X t ) ⋅ W c begin{aligned} frac{partial C_{t}}{partial c_{t-1}}=& tanh ^{prime}left(W_{c} C_{t-1}+W_{x} X_{t}right) cdot frac{d}{d C_{t-1}}left[W_{c} C_{t-1}+W_{x} X_{t}right]=\ & tanh ^{prime}left(W_{c} C_{t-1}+W_{x} X_{t}right) cdot W_{c} end{aligned} ct1Ct=tanh(WcCt1+WxXt)dCt1d[WcCt1+WxXt]=tanh(WcCt1+WxXt)Wc

将该式代入上式有:

∂ E k ∂ W = ∂ E k ∂ H k ∂ H k ∂ C k ( ∏ t = 2 k tanh ⁡ ′ ( W c C t − 1 + W x X t ) ⋅ W c ) ∂ c 1 ∂ W frac{partial E_{k}}{partial W}=frac{partial E_{k}}{partial H_{k}} frac{partial H_{k}}{partial C_{k}}left(prod_{t=2}^{k} tanh ^{prime}left(W_{c} C_{t-1}+W_{x} X_{t}right) cdot W_{c}right) frac{partial c_{1}}{partial W} WEk=HkEkCkHk(t=2ktanh(WcCt1+WxXt)Wc)Wc1

观察这个式子,和上篇文章中一样,因为链式法则,出现了累乘项,因为tanh的导数 <= 1,所以,当k很大的时候,上式的值是趋向于0的。(<1的数多次相乘),也就是:

Π t = 2 k tanh ⁡ ′ ( W c C t − 1 + w x X t ) ⋅ W c → 0 , Pi_{t=2}^{k} tanh ^{prime}left(W_{c} C_{t-1}+w_{x} X_{t}right) cdot W_{c} rightarrow 0, Πt=2ktanh(WcCt1+wxXt)Wc0, so ∂ E k ∂ W → 0 frac{partial E_{k}}{partial W} rightarrow 0 WEk0

此时,权重更新公式:

W ← W − α ∂ E k ∂ W ≈ W W leftarrow W-alpha frac{partial E_{k}}{partial W} approx W WWαWEkW

也就是说,RNN很容易出现梯度消失现象,使得参数更新缓慢,甚至是停止更新。

最后

以上就是老迟到云朵为你收集整理的RNN中的梯度消失/爆炸原因的全部内容,希望文章能够帮你解决RNN中的梯度消失/爆炸原因所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(49)

评论列表共有 0 条评论

立即
投稿
返回
顶部