我是靠谱客的博主 精明小蚂蚁,最近开发中收集的这篇文章主要介绍RNN介绍及梯度详细推导1. RNN简介2. RNN的几种常见结构3. RNN的梯度推导,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

RNN介绍及梯度详细推导

  • 1. RNN简介
  • 2. RNN的几种常见结构
    • 结构一
    • 结构二
    • 结构三
  • 3. RNN的梯度推导

1. RNN简介

RNN(Recurrent Neural Network, 循环神经网络)算得上是极具魅力的一类神经网络。同经典的前馈神经网络相比较(如多层感知器、深度置信网络、卷积神经网络等),RNN允许网络隐层(hidden layer)的输出再以输入的形式作用于该隐层自己。假设一个文本序列“I have a pen”,我们可以用一些方法将其中的每个单词进行向量化,如采用one-hot编码、ti-idf、词嵌入等方法。再假设我们有一个计算单元,它有两个输入。一个输入是上述提到的单词的向量,另一个输入是一个大小为m的向量。该计算单元的输出也是一个大小为m的向量。首先,我们随机初始化一个大小为m的向量,并将该向量和上述文本序列的第一个词“I”的词向量输入给计算单元。计算单元输出一个大小为m的向量。接着我们又将这个刚得到的大小为m的向量和第二个词“have”的向量输入给计算单元,又得到一个大小为m的向量。如此重复,直至对文本序列“I have a pen”进行了一次完整遍历,最终得到一个大小为m的向量。这个最终得到的向量就可以作为对原文本序列的一个表征,用于后续任务。例如,接一个输出层再加softmax用来做文本分类。

上述的计算单元有两个重要性质。首先,不论上述文本序列的长度如何变化,只要完成对该文本进行一次遍历,最终的输出向量大小均为m。因此它非常有利于处理不同长度的序列数据。其次,当前词总是伴随着前一个词的输出一起作为输入参与计算(递归),这就意味这序列本身的顺序能够被计算单元很好的捕捉到。因此,这样的计算单元特别适合于对序列顺序敏感的任务,如时序预测、词性标注等任务。

RNN的基本构成就是这样的计算单元。

从理论上看,RNN的一系列后续变体(如GRU、LSTM等)在设计上非常经验(实际上,必须承认目前大多数的网络在设计上都很经验),但却很容易被人们所接受。这些接受绝不仅仅是因为它们强大的性能表现,也包含其背后设计思路的“合理性”(符合直觉)。而经验式的合理,往往会催生大量后续更加“合理”的设计,这直接导致了一个研究领域的迅速崛起。RNN是这样,CNN是这样,Deep Learning也是这样。从实践上看,RNN在涉及到序列或自然语言的任务上获得了前所未有的成功,圈粉无数。本文的目的是详细介绍RNN的梯度推导过程,便于读者对RNN形成更加理性的认识。在正式进入推导之前,先简单介绍一下RNN的三种常见结构。

2. RNN的几种常见结构

为了方便表述,我们把一个序列中的数据以时刻(time stamp)来进行划分。简单来讲,我们可以认为时刻0对应的就是序列中第一个元素;时刻1对应的是第2个元素,以此类推。如果将之前的递归计算方式用时刻的概念来表述,便能更加容易地对网络进行理解。下面我们采用这种方法来介绍RNN的三种常见结构。

结构一

第一种RNN结构
上图所示的RNN结构中,输入序列为   X = ( x 0 , x 1 , . . . , x l − 1 ) X=(x^{0}, x^{1}, ...,x^{l-1} )  X=(x0,x1,...,xl1)。对于任意时刻   t t  t   h t h^t  ht   h t − 1 h^{t-1}  ht1   x t x^t  xt决定。同时,时刻   t t  t对应于目标输出(target)   y t y^t  yt。这样的结构很适合时间序列预测类的任务,因为模型能够通过之前的信息   h t − 1 h^{t-1}  ht1和当前的输入   x t x^t  xt来预测当前时刻的输出   y ^ t hat y^t  y^t。该结构最大的问题是时刻   t t  t的损失   L t L^{t}  Lt与所有之前时刻的状态   h t − 1 , h t − 2 , . . . , h 0 h^{t-1}, h^{t-2},... ,h^{0}  ht1,ht2,...,h0都相关。这意味着要求得   h t h^{t}  ht,首先需要求   h t − 1 h^{t-1}  ht1;而要求   h t − 1 h^{t-1}  ht1,又首先得求出   h t − 2 h^{t-2}  ht2,以此类推。我们知道,在使用BP(Back Propagation,反向传播)求梯度的时候,对每一个输入需要一次前向传播过程来计算网络中的activation。而上述RNN是顺序计算的,难以并行化,因此其训练过程通常比较耗时。

结构二

第二种RNN结构
与前文介绍的第一种RNN结构的区别在于:这种结构取消了相邻时刻间状态   h h  h之间 的连接。取而代之的是前一时刻的输出   o t − 1 o^{t-1}  ot1与当前时刻的状态   h t h^t  ht之间的连接。采用与结构一中类似的分析可知,该RNN结构仍是顺序执行的,似乎除了训练参数可能减少(输出   o { o}  o的维度通常比   h h  h要小),没有其它任何区别。不过实际上,这种网络可以采用名为teacher forcing的方法来训练。简单来讲,既然   h t h^t  ht的输入部分来自   o t − 1 o^{t-1}  ot1,而   o t − 1 o^{t-1}  ot1本身在训练过程中受   y t − 1 y^{t-1}  yt1约束(   o t − 1 o^{t-1}  ot1要尽可能逼近   y t − 1 y^{t-1}  yt1),因此可以直接使用   y t − 1 y^{t-1}  yt1来代替   o t − 1 o^{t-1}  ot1作为   h t h^t  ht的输入。而任意时刻的   y t y^{t}  yt在训练集中已经被提供,所以网络在训练过程中不再需要顺序执行!

不过,这样做的缺点就是   y t y^{t}  yt中通常包含的信息远少于   h t h^{t}  ht,因此模型对历史信息的学习能力远不如结构一。

结构三

在这里插入图片描述
与结构一的唯一区别是:该结构的RNN仅在最后一个时刻   l l  l时才会有输出   o l o^l  ol。显然,这种结构的典型应用就是分本分类。

3. RNN的梯度推导

就求解梯度而言,上述介绍的三种结构的异同如下:

  1. 在结构一中,当前状态   h t h^t  ht不但会影响当前时刻的损失   l t l^t  lt,也会影响所有后续时刻的损失;
  2. 在结构二中,如果采用teacher forcing的训练方法,当前时刻的状态   h t h^t  ht只会影响当前时刻的损失   l t l^t  lt,并不影响后续时刻的损失;
  3. 在结构三中,所有时刻的状态均只影响最后的损失(因为只有这一个损失)。

由上述分析可知,结构一的梯度求解是最复杂的。因此,本文仅给出结构一的梯度推导,其它两种结构读者可以自行推导。

我们首先假定输出   y ^ ( t ) hat y^{(t)}  y^(t)是一个经softmax得到的概率分布。可以把这样的结构想象成一个词性标注问题,即对于每一个输入的词   x ( t ) x^{(t)}  x(t),需要将其标注为“名词”、“动词”、“数量词”等有限的词性集合中的一个。那么,softmax得到的概率分布中的每一个概率对应于一个词性。在此假设下,给出结构一的形式化定义:

(1)   h ( t ) = t a n h ( b + W h ( t − 1 ) + U x ( t ) ) , pmb h^{(t)}=mathtt{tanh}(pmb b + pmb W pmb h^{(t-1)} + pmb U pmb x^{(t)}), tag{1}  hhh(t)=tanh(bbb+WWWhhh(t1)+UUUxxx(t)),(1)

(2)   y ^ ( t ) = s o f t m a x ( c + V h ( t ) ) . hat y^{(t)}=mathtt{softmax}(pmb c + pmb V pmb h^{(t)}).tag{2}  y^(t)=softmax(ccc+VVVhhh(t)).(2)

我们首先来梳理一下上述各矩阵的大小。观察公式(1),假设输入   x ( t ) x^{(t)}  x(t)是大小为n的向量(向量均指列向量;即可以看作是大小为(n,1)的矩阵);状态   h ( t ) pmb h^{(t)}  hhh(t)   h ( t − 1 ) pmb h^{(t-1)}  hhh(t1)是大小为k的向量。约定 t a n h mathtt {tanh} tanh作用于矩阵时相当于对矩阵中的每一个元素   x i , j x_{i,j}  xi,j分别执行 t a n h ( x i , j ) mathtt {tanh}(x_{i,j}) tanh(xi,j)

因为矩阵加法和 t a n h mathtt {tanh} tanh函数均不会影响矩阵的大小,故可知公式(1)中的   b pmb b  bbb   W h ( t − 1 ) pmb W pmb h^{(t-1)}  WWWhhh(t1)   U x ( t ) pmb U pmb x^{(t)}  UUUxxx(t)均与   h ( t ) {pmb h^{(t)}}  hhh(t)具有相同的大小,即(k, 1)。那么, U pmb U UUU的大小为(k, n),   W pmb W  WWW的大小为(k, k),   b pmb b  bbb的大小为(k, 1)。

类似的,公式(2)中的softmax也不会改变它输入矩阵的大小。那么假设   y ^ ( t ) hat y^{(t)}  y^(t)的大小为(m, 1),可知   c pmb c  ccc的大小为(m, 1),   V { pmb V}  VVV的大小为(m, n)。为了方便后续推导,将公式(1)和公式(2)进行如下改写:

(3)   a ( t ) = b + W h ( t − 1 ) + U x ( t ) , pmb a^{(t)}=pmb b + pmb W pmb h^{(t-1)} + pmb U pmb x^{(t)}, tag{3}  aaa(t)=bbb+WWWhhh(t1)+UUUxxx(t),(3)

(4)   h ( t ) = t a n h ( a ( t ) ) , pmb h^{(t)}=mathtt{tanh}(pmb a^{(t)}), tag{4}  hhh(t)=tanh(aaa(t)),(4)

(5)   o ( t ) = c + V h ( t ) , pmb o^{(t)}=pmb c + pmb V pmb h^{(t)}, tag{5}  ooo(t)=ccc+VVVhhh(t),(5)

(6)   y ^ ( t ) = s o f t m a x ( o ( t ) ) . hat y^{(t)}=mathtt{softmax}(pmb o^{(t)}). tag{6}  y^(t)=softmax(ooo(t)).(6)
(抱歉公式(2)和(6)中的   y ^ ( t ) hat y^{(t)}  y^(t)没有用加粗斜体,因为MathJax的渲染真的太难看!!!)

假设单个时刻的损失函数定义为交叉熵(Cross-entropy):
(7)   L ( t ) = l o s s ( y ^ ( t ) , y ( t ) ) = − ∑ i y i ( t ) l o g ( y ^ i ( t ) ) , L^{(t)}=loss(hat y^{(t)}, y^{(t)})=-sum_{i}y^{(t)}_i log (hat y^{(t)}_i), tag{7}  L(t)=loss(y^(t),y(t))=iyi(t)log(y^i(t)),(7)
其中   y ( t ) y^{(t)}  y(t)表示真实的标签分布;下标   i i  i表示对应向量的第   i i  i个分量。那么总体损失即为:
(8)   L = ∑ t = 0 l − 1 L ( t ) , L= sum_{t=0}^{l-1}L^{(t)}, tag{8}  L=t=0l1L(t)(8)

其中   l l  l为序列长度。

我们需要求解的梯度为   ∇ c L nabla_c L  cL   ∇ b L nabla_b L  bL   ∇ U L nabla_U L  UL   ∇ V L nabla_V L  VL   ∇ W L nabla_W L  WL (因为渲染问题,下标未使用加粗斜体)。

我们首先求解   ∇ V L nabla_V L  VL。由链式法则可知:
(9)   ∇ V L = ∑ t = 0 l − 1 ∂ L ∂ L ( t ) ∂ L ( t ) ∂ y ^ ( t ) ∂ y ^ ( t ) ∂ o ( t ) ∂ o ( t ) ∂ V . nabla_V L= sum_{t=0}^{l-1} frac{partial L}{partial L^{(t)}} frac{partial L^{(t)}}{partial hat y^{(t)}} frac{partial hat y^{(t)}}{partial pmb o^{(t)}} frac{partial pmb o^{(t)}}{partial pmb V}. tag{9}  VL=t=0l1L(t)Ly^(t)L(t)ooo(t)y^(t)VVVooo(t).(9)

由公式(8)可得:
(10)   ∂ L ∂ L ( t ) = 1. frac{partial L}{partial L^{(t)}}=1. tag{10}  L(t)L=1.(10)

接着考察第二项   ∂ L ( t ) ∂ y ^ ( t ) frac{partial L^{(t)}}{partial hat y^{(t)}}  y^(t)L(t)
(11)   ( ∂ L ( t ) ∂ y ^ ( t ) ) i = ( ∇ y ^ ( t ) L ( t ) ) i = − ∂ y i ( t ) l o g y ^ i ( t ) ∂ y ^ i ( t ) = − y i ( t ) y ^ i ( t ) . (frac{partial L^{(t)}}{partial hat y^{(t)}})_i=(nabla_{hat y^{(t)}}L^{(t)})_i=-frac{partial y_i^{(t)}log hat y_i^{(t)}}{partial hat y_i^{(t)}} = -frac{y_i^{(t)}}{hat y_i^{(t)}}.tag{11}  (y^(t)L(t))i=(y^(t)L(t))i=y^i(t)yi(t)logy^i(t)=y^i(t)yi(t).(11)

再来看第三项   ∂ y ^ ( t ) ∂ o ( t ) frac{partial hat y^{(t)}}{partial o^{(t)}}  o(t)y^(t)。由公式(6)知,
(12)   y ^ i ( t ) = e ( o i ( t ) ) ∑ k e ( o k ( t ) ) . hat y^{(t)}_i = frac{e^{(o^{(t)}_i)}}{sum_{k}e^{(o^{(t)}_k)}}.tag{12}  y^i(t)=ke(ok(t))e(oi(t)).(12)

注意到在公式(12)中,   o ( t ) o^{(t)}  o(t)的每一分量都在分母中出现,而只有一个分量在分子中出现,且该分量的下标   i i  i对应于所求   y ^ ( t ) hat y^{(t)}  y^(t)的下标。因此,在求解偏导数   ∂ y ^ i ( t ) ∂ o j ( t ) frac{partial hat y_i^{(t)}}{partial o^{(t)}_j}  oj(t)y^i(t)时,需要考虑   i i  i   j j  j是否相等。
  i i  i   j j  j不相等,
(13)   ∂ y ^ i ( t ) ∂ o j ( t ) = ∂ ∂ o j ( t ) ( e ( o i ( t ) ) ∑ k e ( o k ( t ) ) ) = − e ( o i ( t ) ) e ( o j ( t ) ) ( ∑ k e ( o k ( t ) ) ) 2 = − y ^ i ( t ) y ^ j ( t ) . frac{partial hat y_i^{(t)}}{partial o^{(t)}_j}= frac{partial}{partial o^{(t)}_j} (frac{e^{(o^{(t)}_i)}}{sum_{k}e^{(o^{(t)}_k)}})=-frac{e^{(o^{(t)}_i)}e^{(o^{(t)}_j)}}{(sum_{k}e^{(o^{(t)}_k)})^2}=-hat y^{(t)}_i hat y^{(t)}_j.tag{13}  oj(t)y^i(t)=oj(t)(ke(ok(t))e(oi(t)))=(ke(ok(t)))2e(oi(t))e(oj(t))=y^i(t)y^j(t).(13)
  i i  i   j j  j相等(都记为   i i  i),则有:
(14)   ∂ y ^ i ( t ) ∂ o i ( t ) = ∂ ∂ o i ( t ) ( e ( o i ( t ) ) ∑ k e ( o k ( t ) ) ) = − e ( o i ( t ) ) ∑ k e ( o k ( t ) ) − e ( o i ( t ) ) e ( o i ( t ) ) ( ∑ k e ( o k ( t ) ) ) 2 = e ( o i ( t ) ) ∑ k e ( o k ( t ) ) − ( e ( o i ( t ) ) ∑ k e ( o k ( t ) ) ) 2 = y ^ i ( t ) ( 1 − y ^ i ( t ) ) . frac{partial hat y_i^{(t)}}{partial o^{(t)}_i}= frac{partial}{partial o^{(t)}_i} (frac{e^{(o^{(t)}_i)}}{sum_{k}e^{(o^{(t)}_k)}})=-frac{e^{(o^{(t)}_i)}sum_{k}e^{(o^{(t)}_k)}-e^{(o_i^{(t)})}e^{(o_i^{(t)})}}{(sum_{k}e^{(o^{(t)}_k)})^2}=frac{e^{(o_i^{(t)})}}{sum_{k}e^{(o^{(t)}_k)}} - (frac{e^{(o_i^{(t)})}}{sum_{k}e^{(o^{(t)}_k)}})^2\ =hat y_i^{(t)}(1-hat y_i^{(t)}).tag{14}  oi(t)y^i(t)=oi(t)(ke(ok(t))e(oi(t)))=(ke(ok(t)))2e(oi(t))ke(ok(t))e(oi(t))e(oi(t))=ke(ok(t))e(oi(t))(ke(ok(t))e(oi(t)))2=y^i(t)(1y^i(t)).(14)

根据上述推导,我们来梳理一下   ∂ L ( t ) ∂ y ^ ( t ) ∂ y ^ ( t ) ∂ o ( t ) frac{partial L^{(t)}}{partial hat y^{(t)}} frac{partial hat y^{(t)}}{partial o^{(t)}}  y^(t)L(t)o(t)y^(t)的结果究竟是什么。因为   L ( t ) L^{(t)}  L(t) 是标量,   y ^ ( t ) hat y^{(t)}  y^(t)是大小为m的向量,因此   ∂ L ( t ) ∂ y ^ ( t ) frac{partial L^{(t)}}{partial hat y^{(t)}}  y^(t)L(t)同样是一个大小为m的向量(记该向量为 M pmb M MMM),该向量第   i i  i个元素的值是   − y i ( t ) y ^ i ( t ) -frac{y_i^{(t)}}{hat y_i^{(t)}}  y^i(t)yi(t)(参见公式(11))。因为   y ^ ( t ) hat y^{(t)}  y^(t)是大小为m的向量,   o ( t ) o^{(t)}  o(t)也是大小为m的向量,所以   ∂ y ^ ( t ) ∂ o ( t ) frac{partial hat y^{(t)}}{partial o^{(t)}}  o(t)y^(t)的结果是大小为(m, m)的矩阵(记该矩阵为 N pmb N NNN)。矩阵的第   i i  i列为   ∂ y ^ ( t ) ∂ o i ( t ) frac{partial hat y^{(t)}}{partial o^{(t) }_i}  oi(t)y^(t)。那么   ( M T × N ) T (M^T times N)^T  (MT×N)T为大小为m的向量,向量的每一个值为   ∂ L ( t ) ∂ o i ( t ) frac{partial L^{(t)}}{partial o^{(t)}_i}  oi(t)L(t),且这个值的大小为:
(15)   ∑ j , i ≠ j ( y j ( t ) y ^ i ( t ) ) − y i ( t ) ( 1 − y ^ i ( t ) ) = − y i ( t ) + y ^ i ( t ) ∑ j y j ( t ) . sum_{j,i neq j}(y_j^{(t)}hat y_i^{(t)}) - y_i^{(t)}(1-hat y_i^{(t)})=-y_i^{(t)}+hat y_i^{(t)}sum_j y_j^{(t)}.tag{15}  j,i̸=j(yj(t)y^i(t))yi(t)(1y^i(t))=yi(t)+y^i(t)jyj(t).(15)

这里需要注意两点。首先,   ∂ y ^ ( t ) ∂ o ( t ) frac{partial hat y^{(t)}}{partial o^{(t)}}  o(t)y^(t)是列向量对列向量求导,其计算结果虽然也为矩阵,但是行列的关系没有明确定义,即谁作为行和谁作为列不明确。因此通常需要结合实际情况来确定计算规则。上述的   ( M T × N ) T (M^T times N)^T  (MT×N)T就是为了保证计算的实际意义来设立的计算规则。它将   ∂ L ( t ) ∂ y ^ ( t ) frac{partial L^{(t)}}{partial hat y^{(t)}}  y^(t)L(t)   ∂ y ^ ( t ) ∂ o ( t ) frac{partial hat y^{(t)}}{partial o^{(t)}}  o(t)y^(t)的计算结构进行合并成   ∂ L ( t ) ∂ o ( t ) frac{partial L^{(t)}}{partial o^{(t)}}  o(t)L(t)的结果。其次,在做推导时,请特别注意   i i  i   j j  j在不同公式中对应的对象的差别。例如,在公式(11)中,   i i  i对应的是 y ^ ( t ) hat y^{(t)} y^(t),而在公式(15)中,   i i  i对应的是   o ( t ) o^{(t)}  o(t)

通常,   y ( t ) y^{(t)}  y(t)中仅有一个分量为1,表示真实类别,而其它分量都为0。因此,
(16)   ∑ j y j ( t ) = 1. sum_j y_j^{(t)}=1.tag{16}  jyj(t)=1.(16)
公式(15)就简化为:
(17)   − y i ( t ) + y ^ i ( t ) . -y_i^{(t)}+hat y_i^{(t)}.tag{17}  yi(t)+y^i(t).(17)

至此,我们把公式(9)中求和符号内的前三项的结果记为:
(18)   ∂ L ∂ L ( t ) ∂ L ( t ) ∂ y ^ ( t ) ∂ y ^ ( t ) ∂ o ( t ) = ∇ o ( t ) L , frac{partial L}{partial L^{(t)}} frac{partial L^{(t)}}{partial hat y^{(t)}} frac{partial hat y^{(t)}}{partial pmb o^{(t)}}=nabla_{o^{(t)}}L,tag {18}  L(t)Ly^(t)L(t)ooo(t)y^(t)=o(t)L,(18)
该列向量的第 i i i个分量的值为   − y i ( t ) + y ^ i ( t ) -y_i^{(t)}+hat y_i^{(t)}  yi(t)+y^i(t)

最后考察公式(9)中求和符号内的最后一项   ∂ o ( t ) ∂ V frac{partial o^{(t)}}{partial V}  Vo(t)。由公式(5)和矩阵求导公式可知:
(19)   ∂ o ( t ) ∂ V = ( h ( t ) ) T . frac{partial pmb o^{(t)}}{partial pmb V}=(h^{(t)})^{T}.tag{19}  VVVooo(t)=(h(t))T.(19)
因此,我们要求的第一个梯度   ∇ V L nabla_V L  VL的结果为:
(20)   ∇ V L = ∑ t = 0 l − 1 ∇ o ( t ) L ( h ( t ) ) T . nabla_V L=sum_{t=0}^{l-1} nabla_{o^{(t)}}L(h^{(t)})^T.tag{20}  VL=t=0l1o(t)L(h(t))T.(20)

(先到这,空了继续码)

最后

以上就是精明小蚂蚁为你收集整理的RNN介绍及梯度详细推导1. RNN简介2. RNN的几种常见结构3. RNN的梯度推导的全部内容,希望文章能够帮你解决RNN介绍及梯度详细推导1. RNN简介2. RNN的几种常见结构3. RNN的梯度推导所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(63)

评论列表共有 0 条评论

立即
投稿
返回
顶部