我是靠谱客的博主 潇洒小蚂蚁,最近开发中收集的这篇文章主要介绍LSTM如何解决梯度消失与梯度爆炸,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

在这里插入图片描述
  这是一张经典的LSTM示意图,LSTM依靠  f t f_t ft i t i_t it o t o_t ot来控制输入输出, f t = σ ( W f ⋅ [ h t − 1 , x t ] + b f ) f_{t}=sigmaleft(W_{f} cdotleft[h_{t-1}, x_{t}right]+b_{f}right) ft=σ(Wf[ht1,xt]+bf) i t = σ ( W i ⋅ [ h t − 1 , x t ] + b i ) i_{t}=sigmaleft(W_{i} cdotleft[h_{t-1}, x_{t}right]+b_{i}right) it=σ(Wi[ht1,xt]+bi) o t = σ ( W o [ h t − 1 , x t ] + b o ) o_{t}=sigmaleft(W_{o}left[h_{t-1}, x_{t}right]+b_{o}right) ot=σ(Wo[ht1,xt]+bo)
  我们将其简化为: f t = σ ( W f X t + b f ) f_{t}=sigmaleft(W_{f} X_{t}+b_{f}right) ft=σ(WfXt+bf) i t = σ ( W i X t + b i ) i_{t}=sigmaleft(W_{i} X_{t}+b_{i}right) it=σ(WiXt+bi) o i = σ ( W o X t + b o ) o_{i}=sigmaleft(W_{o} X_{t}+b_{o}right) oi=σ(WoXt+bo)
  当前的状态  S t = f t S t − 1 + i t X t S_{t}=f_{t} S_{t-1}+i_{t} X_{t} St=ftSt1+itXt 类似与传统RNN  S t = W s S t − 1 + W x X t + b 1 S_{t}=W_{s} S_{t-1}+W_{x} X_{t}+b_{1} St=WsSt1+WxXt+b1 。将LSTM的状态表达式展开后得: S t = σ ( W f X t + b f ) S t − 1 + σ ( W i X t + b i ) X t S_{t}=sigmaleft(W_{f} X_{t}+b_{f}right) S_{t-1}+sigmaleft(W_{i} X_{t}+b_{i}right) X_{t} St=σ(WfXt+bf)St1+σ(WiXt+bi)Xt  如果加上激活函数 S t = tanh ⁡ [ σ ( W f X t + b f ) S t − 1 + σ ( W i X t + b i ) X t ] S_{t}=tanh left[sigmaleft(W_{f} X_{t}+b_{f}right) S_{t-1}+sigmaleft(W_{i} X_{t}+b_{i}right) X_{t}right] St=tanh[σ(WfXt+bf)St1+σ(WiXt+bi)Xt]  RNN梯度消失和爆炸的原因这篇文章中传统RNN求偏导的过程包含: ∏ j = k + 1 t ∂ S j ∂ S j − 1 = ∏ j = k + 1 t tanh ⁡ ′ W s prod_{j=k+1}^{t} frac{partial S_{j}}{partial S_{j-1}}=prod_{j=k+1}^{t} tanh ^{prime} W_{s} j=k+1tSj1Sj=j=k+1ttanhWs  对于LSTM同样也包含这样的一项,但是在LSTM中: ∏ j = k + 1 t ∂ S j ∂ S j − 1 = ∏ j = k + 1 t tanh ⁡ ′ σ ( W f X t + b f ) prod_{j=k+1}^{t} frac{partial S_{j}}{partial S_{j-1}}=prod_{j=k+1}^{t} tanh ^{prime} sigmaleft(W_{f} X_{t}+b_{f}right) j=k+1tSj1Sj=j=k+1ttanhσ(WfXt+bf) 假设   Z = tanh ⁡ ′ ( x ) σ ( y ) Z=tanh ^{prime}(x) sigma(y) Z=tanh(x)σ(y),则 Z Z Z的函数图像如下图所示:

在这里插入图片描述
  可以看到该函数值基本上不是0就是1。
  传统RNN的求偏导过程: ∂ L 3 ∂ W s = ∂ L 3 ∂ O 3 ∂ O 3 ∂ S 3 ∂ S 3 ∂ W s + ∂ L 3 ∂ O 3 ∂ O 3 ∂ S 3 ∂ S 3 ∂ S 2 ∂ S 2 ∂ W s + ∂ L 3 ∂ O 3 ∂ O 3 ∂ S 3 ∂ S 3 ∂ S 2 ∂ S 2 ∂ S 1 ∂ S 1 ∂ W s frac{partial L_{3}}{partial W_{s}}=frac{partial L_{3}}{partial O_{3}} frac{partial O_{3}}{partial S_{3}} frac{partial S_{3}}{partial W_{s}}+frac{partial L_{3}}{partial O_{3}} frac{partial O_{3}}{partial S_{3}} frac{partial S_{3}}{partial S_{2}} frac{partial S_{2}}{partial W_{s}}+frac{partial L_{3}}{partial O_{3}} frac{partial O_{3}}{partial S_{3}} frac{partial S_{3}}{partial S_{2}} frac{partial S_{2}}{partial S_{1}} frac{partial S_{1}}{partial W_{s}} WsL3=O3L3S3O3WsS3+O3L3S3O3S2S3WsS2+O3L3S3O3S2S3S1S2WsS1
  在LSTM中为: ∂ L 3 ∂ W s = ∂ L 3 ∂ O 3 ∂ O 3 ∂ S 3 ∂ S 3 ∂ W s + ∂ L 3 ∂ O 3 ∂ O 3 ∂ S 3 ∂ S 2 ∂ W s + ∂ L 3 ∂ O 3 ∂ O 3 ∂ S 3 ∂ S 1 ∂ W s frac{partial L_{3}}{partial W_{s}}=frac{partial L_{3}}{partial O_{3}} frac{partial O_{3}}{partial S_{3}} frac{partial S_{3}}{partial W_{s}}+frac{partial L_{3}}{partial O_{3}} frac{partial O_{3}}{partial S_{3}} frac{partial S_{2}}{partial W_{s}}+frac{partial L_{3}}{partial O_{3}} frac{partial O_{3}}{partial S_{3}} frac{partial S_{1}}{partial W_{s}} WsL3=O3L3S3O3WsS3+O3L3S3O3WsS2+O3L3S3O3WsS1
  因为 ∏ j = k + 1 t ∂ S j ∂ S j − 1 = ∏ j = k + 1 t tanh ⁡ ′ σ ( W f X t + b f ) ≈ 0 ∣ 1 prod_{j=k+1}^{t} frac{partial S_{j}}{partial S_{j-1}}=prod_{j=k+1}^{t} tanh ^{prime} sigmaleft(W_{f} X_{t}+b_{f}right) approx 0 | 1 j=k+1tSj1Sj=j=k+1ttanhσ(WfXt+bf)01
  这样就解决了传统RNN中梯度消失的问题。

最后

以上就是潇洒小蚂蚁为你收集整理的LSTM如何解决梯度消失与梯度爆炸的全部内容,希望文章能够帮你解决LSTM如何解决梯度消失与梯度爆炸所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(55)

评论列表共有 0 条评论

立即
投稿
返回
顶部