我是靠谱客的博主 刻苦画笔,最近开发中收集的这篇文章主要介绍循环神经网络(RNN)与长短期记忆网络(LSTM)循环神经网络(RNN)长短期记忆网络(LSTM),觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

文章目录

  • 循环神经网络(RNN)
    • RNN网络结构
    • RNN的神经元个数
    • RNN前向传播
    • RNN反向传播
    • RNN的梯度消失问题
  • 长短期记忆网络(LSTM)
    • LSTM结构
    • LSTM反向传播
    • LSTM神经元个数

循环神经网络(RNN)

  如果我们的数据是一个时间序列,且序列长短不一,在每一个时间点存在数据,如下 &lt; … , x t − 2 , x t − 1 , x t , x t + 1 , x t + 2 , ⋯ &gt; &lt;dots , x_{t-2}, x_{t-1}, x_{t}, x_{t+1},x_{t+2}, dots&gt; <,xt2,xt1,xt,xt+1,xt+2,>可以说不同时刻的数据相互之间十分可能存在某种联系。

  对于这种数据,如果考虑使用DNN或者CNN,首先,DNN与CNN的batch维度是不起作用的,因为batch中的数据在前向传播是独立进行,反向传播仅仅把误差相加,DNN与CNN显然无法很好的学习batch中的数据在时间上的相关规律。另一方面,如果我们明确知道这种数据的时间连续性在前100个时间点内,我们可以增加DNN的输入神经元或CNN的输入通道数到100,把前100个时间点的数据当做特征字段 x x x输入,来得到该时间点的目标输出 y y y。但是实际应用中,序列长短不一(可能还不够100),也有可能与超过100个时间点之前的数据相关系,这样我们的CNN和DNN就不好解决了,因此引入了RNN。

RNN网络结构

  RNN的一个cell往往占据一个隐层,只有一个隐层的RNN其网络结构如下图左,按时间轴展开后就得到了右边的图像。

其中的 h h h所代表的模块即是RNN的cell,它有一个箭头(权重是 W W W)指向自己,这就是"循环"二字的由来。一般看RNN要看其展开后的图像, x x x是网络的输入, h h h( h i d d e n hidden hidden的缩写)是隐藏层状态, o o o是网络的输出, y y y是期望输出, L L L就是 y y y o o o的损失,右上角的 t t t即是在 t t t时刻的数据,我们看 t t t时刻的隐藏层状态 h ( t ) h^{(t)} h(t),它是由 t t t时刻的输入 x ( t ) x^{(t)} x(t) t − 1 t-1 t1时刻的隐藏层状态 h ( t − 1 ) h^{(t-1)} h(t1)共同决定,可见RNN存在同一隐层内部的传播,这在CNN和DNN这种前馈神经网络结构中是不存在的,正是由于这种隐层内的传播,让RNN能够综合学习之前时刻的数据,来输出当前时刻的数据。

  上图中的权重 U , V , W U,V,W UVW,对于展开后的RNN,是参数共享的,不同时刻的 x 、 h 、 o x、h、o xho在计算时,共享相同的权重值。

  在不同的时刻,RNN有不同的输入 x ( t ) x^{(t)} x(t),同样就也有对应的输出 y ( t ) y^{(t)} y(t),但有时我们可能在输入完一个序列 x ( 1 ) … x ( n ) x^{(1)}dots x^{(n)} x(1)x(n)后,只要最后的输出 y ( n ) y^{(n)} y(n),或者只要部分时刻的输出,或者有其它更复杂的情况,根据输入输出,我们可以对RNN进行分类。

  当然RNN一般都不会只有一个隐层,我们可以通过叠加RNN cell来增加RNN的深度,构成深度RNN结构,一般RNN的深度越深,学习能力越强,但是学习速率会下降。下面给出一个有5个隐层的RNN结构。

红色部分表示RNN的输入部分,绿色部分表示RNN的输出部分。不难看出来,RNN的输入包含两个部分,除了数据输入 x x x外,我们还要初始化隐藏层状态的输入 h h h,同理,RNN的输出也包含两个部分。

RNN的神经元个数

  我们在比较不同的神经网络对某一具体问题的效果时,需要把隐层的神经元个数设置成一样的。我一直认为RNN是没有神经元这个概念的,只有RNN cell这个概念,但是一个前辈明确指出RNN是有神经元的概念。针对RNN的神经元个数,网上有不同的说法,blog认为RNN的神经元个数近似可以看成,待更新参数(权值和偏置)的个数,还有人认为RNN某一隐层的神经元个数就是把RNN展开后,RNN结构在时间上的迭代次数,即输入的时间序列的长度。这两种定义都不是很到位,这样定义神经元,无法很好地对接DNN中的神经元。

  在说RNN的神经元之前,先说明一下RNN的输入 x x x的size。

  RNN的输入 x x x的size可以表示成 ( b a t c h   s i z e , s e q u e n c e   l e n g t h , i n p u t   s i z e ) (batch~size, sequence~length, input~size) (batch size,sequence length,input size),如果我们的batch size是10,对于batch中的一条记录,其包含的时间步总长sequence length是20,即我们的时间序列长是20,batch中的一条记录在某一时刻的特征总数是30,那么我们输入的size是 ( 10 , 20 , 30 ) (10, 20, 30) (10,20,30),同样的,每一个隐藏层的隐藏状态 h h h可以定义成 ( b a t c h   s i z e , s e q u e n c e   l e n g t h , h i d d e n   s i z e ) (batch~size, sequence~length, hidden~size) (batch size,sequence length,hidden size)

  我们把只有一个隐层的RNN,其 t − 1 t-1 t1时刻到 t t t时刻展开的更加彻底,见下图:

m就是 x x x的size中的 i n p u t   s i z e input~size input size,n就是 h h h的size中的 h i d d e n   s i z e hidden~size hidden size,这样我们就可以按照看DNN结构的方式,来看RNN的 t t t时刻的结构,可见隐层的神经元个数就是 h i d d e n   s i z e hidden~size hidden size

RNN前向传播

  在 t t t时刻,RNN根据该时刻的输入 x t x_{t} xt,和上一时刻保留的信息 h t − 1 h_{t-1} ht1,按照下面的公式计算该时刻的输出 o t o_{t} ot h t = f ( U x t + W h t − 1 + b ) o t = f ( V h t + c ) h_{t} = f(Ux_{t}+Wh_{t-1}+b) \ o_{t} = f(Vh_{t}+c) \ ht=f(Uxt+Wht1+b)ot=f(Vht+c)

其中 f ( ⋅ ) f(·) f()是激活函数,最终的损失函数是对需要考虑的所有时刻的损失的叠加 L = ∑ t = 1 T l t = ∑ t = 1 T l ( o t , y t ) L= sumlimits_{t=1}^{T}l_{t}=sumlimits_{t=1}^{T}l(o_{t}, y_{t}) L=t=1Tlt=t=1Tl(ot,yt)

其中, y t y_{t} yt t t t时刻的期望输出。

RNN反向传播

  RNN的反向传播基本思路和DNN和CNN的反向传播类似,但是,由于RNN具有时间因素,最后的损失函数又考虑了所有时刻的误差,所以RNN的反向传播要沿着不同时刻的网络结构进行,所以叫做BPTT(Back-Propagation Through Time)。RNN的反向传播要更新的参数包括权重 U , V , W U,V,W U,V,W和偏置 b , c b,c b,c

  以一个隐层的RNN为例,下面的推导针对batch中的一个数据,多个数据只需要求和取平均即可。
  损失对 V V V c c c的偏导数如下: ∂ L ∂ V = ∑ t = 1 T ∂ l t ∂ o t ∂ o t ∂ V frac{partial L}{partial V} = sumlimits_{t=1}^{T}frac{partial l_{t}}{partial o_{t}}frac{partial o_{t}}{partial V} VL=t=1TotltVot ∂ L ∂ c = ∑ t = 1 T ∂ l t ∂ o t ∂ o t ∂ c frac{partial L}{partial c} = sumlimits_{t=1}^{T}frac{partial l_{t}}{partial o_{t}}frac{partial o_{t}}{partial c} cL=t=1Totltcot

  我们定义RNN中的残差结构如下: δ t = ∂ L ∂ h t = ∑ t = 1 T ∂ l t ∂ h t delta_{t}=frac{partial L}{partial h_{t}} =sumlimits_{t=1}^{T}frac{partial l_{t}}{partial h_{t}} δt=htL=t=1Thtlt由于 h t h_{t} ht只传播给 o t o_{t} ot以及 h t + 1 h_{t+1} ht+1(最后的时刻没有 h t + 1 h_{t+1} ht+1),所以 δ t = ∂ L ∂ h t = ∑ t = 1 T ( ∂ l t ∂ o t ∂ o t ∂ h t + ∂ l t ∂ h t + 1 ∂ h t + 1 ∂ h t ) = ∂ l t ∂ o t ∂ o t ∂ h t + ∑ t = 1 T ∂ l t ∂ h t + 1 ∂ h t + 1 ∂ h t = ∂ l t ∂ o t ∂ o t ∂ h t + δ t + 1 ∂ h t + 1 ∂ h t begin{aligned} delta_{t} &amp; =frac{partial L}{partial h_{t}} \ &amp; = sumlimits_{t=1}^{T}(frac{partial l_{t}}{partial o_{t}}frac{partial o_{t}}{partial h_{t}}+frac{partial l_{t}}{partial h_{t+1}}frac{partial h_{t+1}}{partial h_{t}})\ &amp; = frac{partial l_{t}}{partial o_{t}}frac{partial o_{t}}{partial h_{t}}+sumlimits_{t=1}^{T}frac{partial l_{t}}{partial h_{t+1}}frac{partial h_{t+1}}{partial h_{t}}\ &amp; = frac{partial l_{t}}{partial o_{t}}frac{partial o_{t}}{partial h_{t}}+delta_{t+1}frac{partial h_{t+1}}{partial h_{t}}\ end{aligned} δt=htL=t=1T(otlthtot+ht+1lththt+1)=otlthtot+t=1Tht+1lththt+1=otlthtot+δt+1htht+1

  损失对 W W W的偏导数,由于 h t h_{t} ht中含有 h t − 1 h_{t-1} ht1 h t − 1 h_{t-1} ht1中也有W,同理 h t − 1 h_{t-1} ht1中仍然含有 h t − 2 h_{t-2} ht2,所以损失对 W W W的偏导数要更复杂。具体公式如下: ∂ L ∂ W = ∑ t = 1 T ∂ l t ∂ h t ∂ h t ∂ W = ∑ t = 1 T δ t ( ∂ h t ∂ W + ∂ h t ∂ h t − 1 ∂ h t − 1 ∂ W + ∂ h t ∂ h t − 1 ∂ h t − 1 ∂ h t − 2 ∂ h t − 2 ∂ W + … &ThinSpace; ) = ∑ t = 1 T δ t ∑ k = 1 t ( ∏ j = k t − 1 ∂ h j + 1 ∂ h j ) ∂ h k ∂ W begin{aligned} frac{partial L}{partial W} &amp; = sumlimits_{t=1}^{T}frac{partial l_{t}}{partial h_{t}}frac{partial h_{t}}{partial W} \ &amp; = sumlimits_{t=1}^{T}delta_{t}(frac{partial h_{t}}{partial W}+frac{partial h_{t}}{partial h_{t-1}}frac{partial h_{t-1}}{partial W}+frac{partial h_{t}}{partial h_{t-1}}frac{partial h_{t-1}}{partial h_{t-2}}frac{partial h_{t-2}}{partial W}+dots ) \ &amp; = sumlimits_{t=1}^{T}delta_{t}sumlimits_{k=1}^{t}(prodlimits_{j=k}^{t-1}frac{partial h_{j+1}}{partial h_{j}})frac{partial h_{k}}{partial W} end{aligned} WL=t=1ThtltWht=t=1Tδt(Wht+ht1htWht1+ht1htht2ht1Wht2+)=t=1Tδtk=1t(j=kt1hjhj+1)Whk

注意第二个等号的右式,括号中的 ∂ h t ∂ W frac{partial h_{t}}{partial W} Wht,在对 W W W求偏导时, h t − 1 h_{t-1} ht1看成常数了。损失对 U U U b b b的偏导数如下:
∂ L ∂ U = ∑ t = 1 T ∂ l t ∂ h t ∂ h t ∂ U = ∑ t = 1 T δ t ( ∂ h t ∂ U + ∂ h t ∂ h t − 1 ∂ h t − 1 ∂ U + ∂ h t ∂ h t − 1 ∂ h t − 1 ∂ h t − 2 ∂ h t − 2 ∂ U + … &ThinSpace; ) = ∑ t = 1 T δ t ∑ k = 1 t ( ∏ j = k t − 1 ∂ h j + 1 ∂ h j ) ∂ h k ∂ U begin{aligned} frac{partial L}{partial U} &amp; = sumlimits_{t=1}^{T}frac{partial l_{t}}{partial h_{t}}frac{partial h_{t}}{partial U} \ &amp; = sumlimits_{t=1}^{T}delta_{t}(frac{partial h_{t}}{partial U}+frac{partial h_{t}}{partial h_{t-1}}frac{partial h_{t-1}}{partial U}+frac{partial h_{t}}{partial h_{t-1}}frac{partial h_{t-1}}{partial h_{t-2}}frac{partial h_{t-2}}{partial U}+dots ) \ &amp; = sumlimits_{t=1}^{T}delta_{t}sumlimits_{k=1}^{t}(prodlimits_{j=k}^{t-1}frac{partial h_{j+1}}{partial h_{j}})frac{partial h_{k}}{partial U} end{aligned} UL=t=1ThtltUht=t=1Tδt(Uht+ht1htUht1+ht1htht2ht1Uht2+)=t=1Tδtk=1t(j=kt1hjhj+1)Uhk

∂ L ∂ b = ∑ t = 1 T ∂ l t ∂ h t ∂ h t ∂ b = ∑ t = 1 T δ t ( ∂ h t ∂ b + ∂ h t ∂ h t − 1 ∂ h t − 1 ∂ b + ∂ h t ∂ h t − 1 ∂ h t − 1 ∂ h t − 2 ∂ h t − 2 ∂ b + … &ThinSpace; ) = ∑ t = 1 T δ t ∑ k = 1 t ( ∏ j = k t − 1 ∂ h j + 1 ∂ h j ) ∂ h k ∂ b begin{aligned} frac{partial L}{partial b} &amp; = sumlimits_{t=1}^{T}frac{partial l_{t}}{partial h_{t}}frac{partial h_{t}}{partial b} \ &amp; = sumlimits_{t=1}^{T}delta_{t}(frac{partial h_{t}}{partial b}+frac{partial h_{t}}{partial h_{t-1}}frac{partial h_{t-1}}{partial b}+frac{partial h_{t}}{partial h_{t-1}}frac{partial h_{t-1}}{partial h_{t-2}}frac{partial h_{t-2}}{partial b}+dots ) \ &amp; = sumlimits_{t=1}^{T}delta_{t}sumlimits_{k=1}^{t}(prodlimits_{j=k}^{t-1}frac{partial h_{j+1}}{partial h_{j}})frac{partial h_{k}}{partial b} end{aligned} bL=t=1Thtltbht=t=1Tδt(bht+ht1htbht1+ht1htht2ht1bht2+)=t=1Tδtk=1t(j=kt1hjhj+1)bhk

RNN的梯度消失问题

  同DNN一样,RNN的反向传播公式中含有对激活函数的求导,sigmoid激活函数导数小于等于0.25大于0,tanh的激活函数导数小于等于1大于0,这样就会导致,随着反向传播层数的增加以及反向传播经历时间的变长,我们的误差偏导数会不断的乘以激活函数的导数,进而不断变小,造成RNN不容易对靠近输入 x x x和输入 h h h的参数的更新。

  针对这一问题,学者们提出了很多RNN改进形式,最有名的就是LSTM。

长短期记忆网络(LSTM)

  长短期记忆模型(long short time memory,简称LSTM)由Schmidhuber等人在1997年提出,与高速公路网络(highway networks)有异曲同工之妙。LSTM主体结构沿用RNN的样子,只是把RNN cell改成了LSTM cell,引入了门控结构和细胞状态的概念,输入门作用于输入信息,遗忘门作用于之前的记忆信息,二者加权和,得到汇总信息,最后通过输出门决定输出信息。

LSTM结构

  RNN cell可以表示成下图,在该图中激活函数选用的 t a n h tanh tanh

改进后的LSTM cell有很多种不同的结构,最常见最基本的结构可以表示成下图

其中,

可见,LSTM的cell结构要更加的复杂,我们可以看到,LSTM相对于RNN中单一的隐藏状态 h h h,还维系了一个更加全局的细胞状态 c c c,具体看下图

细胞状态也是在隐层内随着时间进行传递,我们可以认为细胞状态就是"主线剧情",或者把LSTM理解成人的大脑,在 t t t时刻,传入的 c t − 1 c_{t-1} ct1细胞状态就代表大脑在经过前 t − 1 t-1 t1个cell后记住的东西,在 t t t时刻的cell中,我们会动态忘记(删去)细胞状态中的一些东西,并加上一些新的东西到我们的细胞状态中,更新后的细胞状态在同 h t − 1 h_{t-1} ht1 x t x_{t} xt组合后产生 t t t时刻该cell的隐藏层状态 h t h_{t} ht h t h_{t} ht会传给更深层的cell以及 t + 1 t+1 t+1时刻的cell。这分别就是LSTM中遗忘门,输入门,输出门的作用。注意在cell中传递的都是张量。

  下面给出了LSTM的遗忘门示意图,它根据 h t − 1 h_{t-1} ht1 x t x_{t} xt决定忘记细胞状态的哪些值。

下面给出了LSTM的输入门示意图,它根据 h t − 1 h_{t-1} ht1 x t x_{t} xt决定,把哪些东西加到细胞状态中。

下面给出了LSTM的输入门和遗忘门,是如何作用到细胞状态的示意图,可以看到,遗忘门经过激活函数,是乘在细胞状态上的,如果选用sigmoid激活函数,sigmoid激活函数输出是0,1之间,乘在细胞状态上,就相当于对细胞状态进行了一种"遗忘"操作,其中0表示"完全遗忘"之前 t − 1 t-1 t1个cell积累的状态值,1表示"完全保留"之前 t − 1 t-1 t1个cell积累的状态值;而输出门是加在细胞状态上的,就相当于对细胞状态进行了一种"写入"操作。

下面给出了LSTM输出门示意图,经过了遗忘门和输入门后的细胞状态,会进一步进行封装,即同 h t − 1 h_{t-1} ht1 x t x_{t} xt组合后产生 t t t时刻该cell的隐藏层状态 h t h_{t} ht h t h_{t} ht会传给更深层的cell以及 t + 1 t+1 t+1时刻的cell。

  第一次看到LSTM的结构时,就被这样的门控结构惊呆了,果然是大神设计出的神经网络结构,个人理解,由于细胞状态是贯穿整个隐层的所有cell的,所以一定程度上缓解了梯度消失问题。但还不是很能理解,在输入门和输出门中为何要设计成双激活函数相乘的结构,不是很清楚这样设计的目的何在。

  用三种门结构解释LSTM是目前很多blog用的方法,其实LSTM还可以用五种门结构来理解,新加入了输入调制门(input modulation gate)和输出调制门(output modulation gate),结构其实还是一样的,只不过用五门结构解释会更加的具体,见下图:

不难看出,输入调制和输出调制都是针对最中间的细胞状态说的。

LSTM反向传播

  这一部分可以看刘建平老师关于LSTM反向传播的推导,过程还是很复杂的。

LSTM神经元个数

  LSTM的神经元的看待方式,同上面RNN的神经元看待方式是一样。


参考资料:
RNN:https://www.cnblogs.com/pinard/p/6509630.html
LSTM:https://www.cnblogs.com/pinard/p/6519110.html
RNN+LSTM:https://blog.csdn.net/zhaojc1995/article/details/80572098
根据输入输出对RNN进行分类:https://blog.csdn.net/qq_38742161/article/details/87560232
循环神经网络:https://mp.weixin.qq.com/s?__biz=MzU4MjQ3MDkwNA==&mid=2247484310&idx=1&sn=0fc55a2784a894100a1ae64d7dbfa23d&chksm=fdb69e01cac1171758cb021fc8779952e55de41032a66ee5417bd3e826bf703247e243654bd0&scene=21#wechat_redirect

最后

以上就是刻苦画笔为你收集整理的循环神经网络(RNN)与长短期记忆网络(LSTM)循环神经网络(RNN)长短期记忆网络(LSTM)的全部内容,希望文章能够帮你解决循环神经网络(RNN)与长短期记忆网络(LSTM)循环神经网络(RNN)长短期记忆网络(LSTM)所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(48)

评论列表共有 0 条评论

立即
投稿
返回
顶部