概述
文章目录
- 循环神经网络(Recurrent Neural Network)基础
- RNN的前向传播
- RNN的反向传播
- RNN的梯度弥散与爆炸
- 长短期记忆网络(Long Short Term Memory)
- LSTM的前向传播
循环神经网络(Recurrent Neural Network)基础
在深度学习领域,神经网络已经被用于处理各类数据,如CNN在图像领域的应用,全连接神经网络在分类问题的应用等。随着神经网络在各个领域的渗透,传统以统计机器学习为主的NLP问题,也逐渐开始采用深度学习的方法来解决。如由Google Brain提出的Word2Vec模型,便将传统BoW等统计方法的词向量方法,带入到了以深度学习为基础的Distribution Representation的方法中来,真正地将NLP问题带入了深度学习的练兵场。当然,RNN的模型并非局限于NLP领域,而是为了解决一系列序列化数据的建模问题,如视频、语音等,而文本也只是序列化数据的一种典型案例。
RNN的特征在于,对于每个RNN神经元,其参数始终共享,即对于文本序列,任何一个输入都经过相同的处理,得到一个输出。在传统的全连接神经网络的结构中,神经元之间互不影响,并没有直接联系,神经元与神经元之间相互独立。而在RNN结构中,隐藏层的神经元开始通过一个隐藏状态所相连,通常会被表示为 h t h_t ht。在理解RNN与全连接神经网络时,需要对两者的结构加以区分,通常,FCN会采用水平方式进行可视化理解,即每一层的神经元垂直排列,而不同层之间以水平方式排布。但在RNN的模型图中,隐藏层的不同神经元之间通常水平排列,而隐藏层的不同层之间以垂直方式排列,如图所示,在FCN网络中,各层水平布局,隐藏层各神经元相互独立,在RNN中,各层以垂直布局,而水平方向上布局着各神经元。注意:RNN结构图只是为了使得结构直观易理解,而在水平方向上其实每个A都相同,对于每个时间步其都是采用同一个神经元进行前向传播。
RNN的前向传播
在RNN中,序列数据按照其时间顺序,依次输入到网络中,而时间顺序则表示时间步的概念。在RNN中,隐藏状态极为重要,隐藏状态是连接各隐藏层各神经元的中介值。如上图,在第一层中,在时间步 t t t,RNN隐藏层神经元得到隐藏状态 h t ( 1 ) h_{t}^{(1)} ht(1),在时间步 t + 1 t+1 t+1,则接受来自上一个时间步的隐藏层输出 h t ( 1 ) h_{t}^{(1)} ht(1),得到新的隐藏状态 h t + 1 ( 1 ) h_{t+1}^{(1)} ht+1(1)。而从垂直方向上看,各层之间,也通过隐藏状态所连接,对于 L 1 L_1 L1到 L 2 L_2 L2, L 2 L_2 L2在水平的时间轴上,各神经元通过隐藏状态 h t ( 2 ) h_{t}^{(2)} ht(2)连接,而层间还将接受前一层的 h t ( 1 ) h_{t}^{(1)} ht(1)的值来作为 x t x_t xt的值,从而获得到该层新的隐藏状态。因此,RNN是一个在水平方向和垂直方向上,均可扩展的结构(水平方向上只是人为添加的易于理解的状态,在工程实践中不存在水平方向的设置)。
根据RNN的定义,可以简单地给出RNN的前向传播过程:
h t = g ( W x t + V h t − 1 + b ) h_t=gleft(Wx_{t}+Vh_{t-1}+bright) ht=g(Wxt+Vht−1+b)
如上式,对于某一层, W 、 V 、 b W、V、b W、V、b均为模型需要学习的参数,通过上图RNN结构图的对应,则应为 L 1 L_1 L1层水平方向所有神经元的参数,**同一层的RNN单元参数相同,即参数共享。**若考虑多层RNN,则可将上式改为:
h t [ i ] = g ( W [ i ] h t [ i − 1 ] + V [ i − 1 ] h t − 1 [ i ] + b [ i ] ) h_{t}^{[i]}=gleft(W^{[i]}h_{t}^{[i-1]}+V^{[i-1]}h_{t-1}^{[i]}+b^{[i]}right) ht[i]=g(W[i]ht[i−1]+V[i−1]ht−1[i]+b[i])
为了简化研究,下文统一对单层RNN进行讨论。
值得注意的是,单层RNN前向传播可做如下变换:
W x t + V h t = [ W V ] × [ x t h t − 1 ] Wx_t+Vh_t=left[begin{array}{cc}W&Vend{array}right]timesleft[begin{array}{c}x_t\h_t-1end{array}right] Wxt+Vht=[WV]×[xtht−1]
为此,我们不妨将参数进行统一表示: W = [ W ; V ] W=left[W;Vright] W=[W;V],其中 [ ⋅ ; ⋅ ] [cdot;cdot] [⋅;⋅]表示拼接操作,则前向传播变为 h t = g ( W [ h t − 1 ; x t ] ⊤ + b ) h_t=gleft(W[h_{t-1};x_t]^{top}+bright) ht=g(W[ht−1;xt]⊤+b)。
再获得隐藏状态后,若需要获得每一个时间步的输出,则需要进一步进行线性变换:
o t = V h t + b o , y t = g ( o t ) o_t=Vh_t+b_o, ;;y_t=g(o_t) ot=Vht+bo,yt=g(ot),其中 V 、 b V、b V、b为参数, g ( ⋅ ) g(cdot) g(⋅)为激活函数,如softmax。
针对单层RNN,可采用上述结构进行描述。
RNN的反向传播
为简化分析,选用RNN的最后时间步的隐藏状态(无输出层)直接作为输出层,即 o u t p u t = h t = g ( W [ h t − 1 ; x t ] ⊤ + b ) output=h_t=gleft(Wleft[h_{t-1};x_{t}right]^{top}+bright) output=ht=g(W[ht−1;xt]⊤+b),若为分类问题,则 g ( ⋅ ) g(cdot) g(⋅)通常为Softmax。定义问题的损失函数为 J ( θ ) = L o s s ( o u t p u t , y ∣ θ ) J(theta)=Lossleft(output,y|thetaright) J(θ)=Loss(output,y∣θ),则在进行反向传播时,需要计算 W 、 b W、b W、b的梯度,可进行如下推导:
Δ W = ∂ J ( θ ) ∂ W = ∂ ∂ W L o s s ( o u t p u t , y ) = L o s s ( o u t p u t , y ) ′ ∂ g ( W [ h t − 1 ; x ] ⊤ + b ) ∂ W = L o s s ( o u t p u t , y ) ′ g ( ⋅ ) ′ [ h t − 1 ; x t ] Delta W=frac{partial J(theta)}{partial W}=frac{partial}{partial W} Loss(output,y)=Loss(output,y)'frac{partial gleft(Wleft[h_{t-1};xright]^{top}+bright)}{partial W}=Loss(output,y)'g(cdot)'[h_{t-1};x_{t}] ΔW=∂W∂J(θ)=∂W∂Loss(output,y)=Loss(output,y)′∂W∂g(W[ht−1;x]⊤+b)=Loss(output,y)′g(⋅)′[ht−1;xt]
然而,在RNN的反向传播中,不仅需要根据垂直方向进行梯度推导,同时需要根据水平方向,按照时间步进行梯度推导,即RNN中的BPTT(Back Propagation Through Time)反向传播。从公式中也可以看出,在前向传播过程中 h t h_t ht是关于 W W W和 h t − 1 h_{t-1} ht−1的函数值,即 h t = f ( W [ h t − 1 ; x t ] ⊤ + b ) h_{t}=fleft(Wleft[h_{t-1};x_{t}right]^top + bright) ht=f(W[ht−1;xt]⊤+b),则 h t h_t ht可以进一步进行微分,于是将 h t h_t ht关于 W W W求偏导,以循着时间轴更新 t − 1 t-1 t−1时刻的 W W W:
∂ h t ∂ [ h t − 1 ; x t ] ∂ [ h t − 1 ; x t ] ∂ W = f ( ⋅ ) ′ W [ h t − 2 ; x t − 1 ] ⇒ Δ W t − 1 = L o s s ( o u t p u t , y ) ′ g ( ⋅ ) ′ f ( ⋅ ) ′ W [ h t − 2 ; x t − 1 ] frac{partial h_t}{partial [h_{t-1};x_t]}frac{partial [h_{t-1};x_t]}{partial W}=f(cdot)'W[h_{t-2};x_{t-1}]Rightarrow Delta W_{t-1}=Loss(output,y)'g(cdot)'f(cdot)'W[h_{t-2};x_{t-1}] ∂[ht−1;xt]∂ht∂W∂[ht−1;xt]=f(⋅)′W[ht−2;xt−1]⇒ΔWt−1=Loss(output,y)′g(⋅)′f(⋅)′W[ht−2;xt−1]
根据反向传播的规则,每个在当前时间步 t t t应向前追溯直到 t 0 t_0 t0,计算梯度并更新参数,而在RNN中时间步中的 W W W参数被所有步共享,因此梯度是对同一个参数计算,为此可以将梯度作求和,一次性更新至 W W W,如图每个箭头表示一次梯度计算,则在 t 3 t_3 t3时刻计算梯度时,不仅需要直接计算当前时刻的梯度,还仍需根据时间轴,分别计算 t 2 , t 1 t_2,t_1 t2,t1时刻的梯度。
注:本推导在假设RNN仅使用一个输出,即最后一个时间步的输出为最终输出,而RNN在每个时间步均有输出,若考虑多个输出,则损失函数不同,即损失为各时间步损失的总和,而在计算梯度时,需要对每个时间步输出均计算一个输出,即 L o s s = ∑ t J ( θ ) Loss=sumlimits^t J(theta) Loss=∑tJ(θ).
则在 t → 1 trightarrow1 t→1的过程中, W W W更新的梯度为 Δ W = ( L o s s ( o u t p u t , y ) ′ g ( ⋅ ) ′ [ h t − 1 ; x t ] ) + ∑ k = 1 T − 1 ( ( ∏ t = k T L o s s ( o u t p u t , y ) ′ g ( ⋅ ) ′ f ( ⋅ ) ′ W ) [ h k − 1 ; x k ] ) Delta W=left(Loss(output,y)'g(cdot)'[h_{t-1};x_t]right)+sumlimits_{k=1}^{T-1}left(left(prod limits_{t=k}^{T}Loss(output,y)'g(cdot)'fleft(cdotright)'Wright)[h_{k-1};x_{k}]right) ΔW=(Loss(output,y)′g(⋅)′[ht−1;xt])+k=1∑T−1((t=k∏TLoss(output,y)′g(⋅)′f(⋅)′W)[hk−1;xk])
对于偏置 b b b采用相同方式推导,此处不再重复推导。
注意:此处和后文若无特殊说明,均只讨论单层RNN,多层RNN则将RNN单元视为FCN中层即可。
RNN的梯度弥散与爆炸
根据上节的推导,可知,在进行BPTT时,RNN单元的反向传播梯度如下:
Δ W t = ( L o s s ( o u t p u t , y ) ′ g ( ⋅ ) ′ ∏ t T − t f ( W [ h t − 1 ; x t ] ⊤ + b ) ′ ) [ h t − 1 ; x t ] Delta W_{t}=left(Loss(output,y)'g(cdot)'prod limits_{t}^{T-t}fleft(Wleft[h_{t-1};x_tright]^top + bright)'right)left[h_{t-1};x_{t}right] ΔWt=(Loss(output,y)′g(⋅)′t∏T−tf(W[ht−1;xt]⊤+b)′)[ht−1;xt]
若激活函数 f ( ⋅ ) f(cdot) f(⋅)采用 t a n h tanh tanh或 s i g m o i d sigmoid sigmoid,图像如图:
对激活函数求导,当 f ( x ) = 1 1 + e − x f(x)=frac{1}{1+e^{-x}} f(x)=1+e−x1时, f ( x ) ′ = − ( 1 + e − x ) − 2 e − x = − e − x ( 1 + e − x ) 2 = − e − x + 1 ( 1 + e − x ) 2 + 1 ( 1 + e − x ) 2 = − f ( x ) + f ( x ) 2 = f ( x ) ( 1 − f ( x ) ) f(x)'=-(1+e^{-x})^{-2}e^{-x}=-frac{e^{-x}}{(1+e^{-x})^2}=-frac{e^{-x}+1}{(1+e^{-x})^2}+frac{1}{(1+e^{-x})^2}=-f(x)+f(x)^2=f(x)left(1-fleft(xright)right) f(x)′=−(1+e−x)−2e−x=−(1+e−x)2e−x=−(1+e−x)2e−x+1+(1+e−x)21=−f(x)+f(x)2=f(x)(1−f(x))
当 f ( x ) = e x − e − x e x + e − x f(x)=frac{e^x-e^{-x}}{e^x+e^{-x}} f(x)=ex+e−xex−e−x时,
f ( x ) ′ = e x + e − x e x + e − x + ( e x − e − x ) ⋅ ( 1 e x + e − x ) 2 ⋅ ( e x − e − x ) ( − 1 ) = 1 − f ( x ) 2 begin{aligned}f(x)'&=frac{e^x+e^{-x}}{e^x+e^{-x}}+(e^x-e^{-x})cdotleft(frac{1}{e^x+e^{-x}}right)^2cdotleft(e^x-e^{-x}right)(-1)\&=1-f(x)^2end{aligned} f(x)′=ex+e−xex+e−x+(ex−e−x)⋅(ex+e−x1)2⋅(ex−e−x)(−1)=1−f(x)2
则 s i g m o i d sigmoid sigmoid与 t a n h tanh tanh导数图像如下图所示:
从图像可以看出,在激活函数的两端,导数均介接近于0,根据上述RNN梯度的推导,假设当前处于最后一个时间步 t t t,则在向前BPTT时,会得出 ∏ f ( ⋅ ) ′ prod f(cdot)' ∏f(⋅)′的计算,当 f ( x ) f(x) f(x)值接近于两端时,则其梯度异常接近于0,并且 s i g m o i d sigmoid sigmoid导数最大值才为 1 4 frac{1}{4} 41,多个接近于0的数相乘,将导致梯度呈指数下降趋势,接近于0,导致梯度弥散。随着序列的变长, ∏ f ( ⋅ ) ′ prod f(cdot)' ∏f(⋅)′的值越小,这便说明,RNN不具备长期记忆,而只具备短期记忆。
由于梯度弥散,导致在序列长度很长时,无法在较后的时间步中,按照梯度更新较前时间步的 W W W,导致无法根据后续序列来修改前向序列的参数,使得前向序列无法很好地做特征提取,使得在长时间步过后,模型将无法再获取有效的前向序列记忆信息。
梯度弥散,在RNN属于重要问题,为此便提出了以LSTM、GRU等结构的变种,来解决RNN短期记忆的瓶颈。同样的,根据上述梯度的推导,梯度中 ∏ W prod W ∏W将会导致参数累乘,若初始参数较大时,则较大数相乘,将导致梯度爆炸,然而梯度爆炸相对于梯度弥散较容易解决,通常加入梯度裁剪即可一定程度缓解。
长短期记忆网络(Long Short Term Memory)
前面说到,RNN单元在面对长序列数据时,很容易便遭遇梯度弥散,使得RNN只具备短期记忆,即RNN面对长序列数据,仅可获取较近的序列的信息,而对较早期的序列不具备记忆功能,从而丢失信息。为此,为解决该类问题,便提出了LSTM结构,其核心关键在于:
- 提出了门机制:遗忘门、输入门、输出门;
- 细胞状态:在RNN中只有隐藏状态的传播,而在LSTM中,引入了细胞状态。
LSTM的前向传播
如下图,为三个LSTM单元的连结,其中相较于传统RNN单元,其多了上下两条轴,分别用于承载细胞状态 C C C及隐藏状态 h h h的信息流动,而其中 σ sigma σ则被称为门,通过乘运算于和运算实现数据的合并于过滤。
为更好地比较LSTM与RNN的区别,再此将RNN前向传播记录如下:
h t = f ( U h t − 1 + W x t + b h ) o t = V h t + b o y t = g ( o t ) begin{aligned}h_t&=f(Uh_{t-1}+Wx_t+b_h)\o_t&=Vh_t+b_o\y_t&=g(o_t)end{aligned} htotyt=f(Uht−1+Wxt+bh)=Vht+bo=g(ot)
紧接着,对LSTM的门进行定义,其均为:
g a t e f , i , o ( h t − 1 , x t ) = σ ( U h t − 1 + W x t + b ) gate_{f,i,o}(h_{t-1}, x_t)=sigma(Uh_{t-1}+Wx_t+b) gatef,i,o(ht−1,xt)=σ(Uht−1+Wxt+b)
其中, f , i , o f,i,o f,i,o分别表示遗忘门、输入门、输出们,对应地, U , W , b U,W,b U,W,b在不同门中,也应为不同的参数。为此,可卸除LSTM详细的前向传播过程。
如图中各 σ sigma σ,则表示各门,其与 × , + times,+ ×,+运算做到了信息过滤和叠加。
在遗忘门: f t = σ ( U f h t − 1 + W f x t + b f ) f_t=sigma(U_fh_{t-1}+W_fx_t+b_f) ft=σ(Ufht−1+Wfxt+bf),由之前所介绍的 s i g m o i d sigmoid sigmoid函数可知,其函数值在 ( 0 , 1 ) (0,1) (0,1)范围内。这里可以思考一下计算机中,门电路的思想,在逻辑电路中,分为“与门”,“或门”,“非门”等,对于“与门”,只有当两者均为1时结果为1,同样地对于遗忘门的运算,其输出值为 ( 0 , 1 ) (0,1) (0,1),当进行乘法运算时,是否也能达到信息过滤的效果呢?
结果很显然,当任何一个数乘以0时,其值为0,那么在后续的线性运算过程中其仍然为0,便可表示,其信息被忽略,因为到下一层时,其未产生信息叠加。
同理,对于输入门,我们有:
i t = σ ( U i h t − 1 + W i x t + b i ) i_t=sigma(U_ih_{t-1}+W_ix_t+b_i) it=σ(Uiht−1+Wixt+bi)
而输入门主要控制对输入的信息进行过滤,即在输入时选择性地抛弃某些信息,而抛弃的信息,即为输入门中输出为0的特征维度。同时,在时间步 t t t,原输入应为: h t − 1 , x t h_{t-1},x_t ht−1,xt,按照传统的RNN的前向传播,输入应经过线性变换后进行激活,并且激活函数通常使用 t a n h tanh tanh,即: C t ~ = t a n h ( U x h t − 1 + W x x t + b x ) widetilde{C_t}=tanh(U_xh_{t-1}+W_xx_t+b_x) Ct =tanh(Uxht−1+Wxxt+bx)。
上述输入的变化,可以对应RNN的输入过程。
由于加了门机制,则需要对输入的信息,进行过滤,而输入信息在LSTM中包含:细胞状态、隐藏状态、当前时间步输入。其中隐藏状态、当前时间步,已经作为输入经过传统的RNN变换得到 C t ~ widetilde{C_t} Ct ,还剩下细胞状态,因此需要进一步将细胞状态与 C t ~ widetilde{C_t} Ct 融合,并得到新的细胞状态:
C t = f t ⊙ C t − 1 + i t ⊙ C t ~ C_t=f_todot C_{t-1} + i_t odot widetilde{C_t} Ct=ft⊙Ct−1+it⊙Ct ,其中 ⊙ odot ⊙表示element-wise乘积。
在输出门中,同样采用相同的方式得到门概率分布: o t = σ ( U o h t − 1 + W o x t + b o ) o_t=sigma(U_oh_{t-1}+W_ox_t+b_o) ot=σ(Uoht−1+Woxt+bo)。输出门的作用在于,对于要输出给下一个时间步的信息,进行一定地过滤,有选择性地保留和去除之前时间步的某些数据。因此,有 h t = o t ⊙ t a n h ( C t ) h_t=o_todot tanh(C_t) ht=ot⊙tanh(Ct) 。得到 h t h_t ht后,便可进一步得到 y t y_t yt,其过程与RNN一致。至此,LSTM的前向传播过程即以结束。
LSTM的结构有效地解决了RNN的短期依赖瓶颈。但是从模型结构可以看出,相较于RNN,LSTM含有更多的参数需要学习,从而导致LSTM的学习速度大大降低。
上述公式推导过程中,同样可以采用拼接的方式,使得 W : = [ U W ] W:=left[begin{array}{cc}U&Wend{array}right] W:=[UW],而 X : = [ h t − 1 x t ] X:=left[begin{array}{c}h_{t-1}\x_tend{array}right] X:=[ht−1xt]。
对前向传播的过程进行整理,可得:
C t = σ ( U i h t − 1 + W i x t + b i ) ⊙ t a n h ( U x h t − 1 + W x x t + b x ) + σ ( U f h t − 1 + W f x t + b f ) ⊙ C t − 1 h t = t a n h ( C t ) ⊙ σ ( U o h t − 1 + W o x t + b o ) y t = f ( V y h t + b y ) begin{aligned}C_t&=sigma(U_ih_{t-1}+W_ix_t+b_i)odot tanh(U_xh_{t-1}+W_xx_t+b_x) + sigma(U_fh_{t-1}+W_fx_t+b_f) odot C_{t-1}\ h_t&=tanh(C_t)odot sigma(U_oh_{t-1}+W_ox_t+b_o)\y_t&=f(V_yh_t+b_y)end{aligned} Cthtyt=σ(Uiht−1+Wixt+bi)⊙tanh(Uxht−1+Wxxt+bx)+σ(Ufht−1+Wfxt+bf)⊙Ct−1=tanh(Ct)⊙σ(Uoht−1+Woxt+bo)=f(Vyht+by)
最后
以上就是冷傲糖豆为你收集整理的深入理解RNN与LSTM循环神经网络(Recurrent Neural Network)基础长短期记忆网络(Long Short Term Memory)的全部内容,希望文章能够帮你解决深入理解RNN与LSTM循环神经网络(Recurrent Neural Network)基础长短期记忆网络(Long Short Term Memory)所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
发表评论 取消回复