概述
目录
- RNN
- LSTM
参考一个很全的总结:
预训练语言模型的前世今生 - 从Word Embedding到BERT
RNN部分参考了这个:
循环神经网络
LSTM部分参考了这两个:
LSTM以及三重门,遗忘门,输入门,输出门
LSTM如何解决梯度消失与梯度爆炸
这儿对预训练模型又有了一点理解,也是之前在做VGG实验时在困惑的点,预训练模型在使用时可以有两种做法:一种是Frozen,将参数锁住,在下游应用时不再改变;另一种就是Fine-Tuning,即将参数初始化为预训练模型的参数,下游应用时这里的参数仍然可以改变。
好了进入正题:
RNN
RNN结构最大的特点就是融入了时序信息,其结构如下图所示:
左侧部分称为RNN的一个timestep,对于每一个时刻
t
t
t ,输入的
x
t
x_t
xt 都可以计算出一个
h
t
h_t
ht ,将该信息传入下一个时刻
t
+
1
t+1
t+1 ,这个过程是一个前馈神经网络;接收完一个序列中所有时刻的数据之后从
x
t
x_t
xt 时刻沿时间反向传播(BPTT)计算loss。
RNN的主体结构是
A
A
A ,
A
A
A 的结构如下图所示,输入为
(
h
t
−
1
,
x
t
)
(h_{t-1},x_t)
(ht−1,xt) ,两个权重矩阵
W
h
W_h
Wh 和
W
x
W_x
Wx 可以分开,也可以合并在一起是一个
W
W
W:
可以看到,RNN解决了时序依赖问题,但这里的时序一般是短距离的,短距离依赖影响较大,长距离依赖影响很小(一般超过10步就无能为力了)。
导致长期依赖的原因,在于RNN训练时容易发生梯度爆炸和梯度消失。
梯度爆炸相对友好,因为这时程序会收到NaN错误,同时处理上也可以设置一个梯度阈值,当梯度超过这个阈值时进行截断。
对于梯度消失,主要采用以下三种方式:
- 合理地初始化权重值,使每个神经元尽可能不要取极大或极小值,以避开梯度消失的区域。
- 用ReLU代替sigmoid和tanh作为激活函数。
- 采用其它结构的RNNs,比如LTSM和GRU,这也是最流行的方法。
梯度消失原因:
前向传播过程包括:
- 隐藏状态: h ( t ) = σ ( z ( t ) ) = σ ( U x ( t ) + W h ( t − 1 ) + b ) h^{(t)}=sigma (z^{(t)})=sigma(Ux^{(t)}+Wh^{(t-1)}+b) h(t)=σ(z(t))=σ(Ux(t)+Wh(t−1)+b) , 此处激活函数一般为 t a n h tanh tanh
- 模型输出: o ( t ) = V h ( t ) + c o^{(t)}=Vh^{(t)}+c o(t)=Vh(t)+c
- 预测输出: y ^ = σ ( o ( t ) ) hat{y}=sigma(o^{(t)}) y^=σ(o(t)) ,此处激活函数一般为 s o f t m a x softmax softmax
- 模型损失: L = ∑ t = 1 T L ( t ) L=sum^T_{t=1}L^{(t)} L=∑t=1TL(t)
RNN所有的timestep共享一套参数
U
,
V
,
W
U,V,W
U,V,W ,在RNN反向传播的过程中,需要计算
U
,
V
,
W
U,V,W
U,V,W 的梯度,以
W
W
W 为例,如下(这是一个链式求导…微积分全不会了好无语…):
∂
L
∂
W
=
∑
t
=
1
T
∂
L
∂
y
(
T
)
∂
y
(
T
)
∂
o
(
T
)
∂
o
(
T
)
∂
h
(
T
)
(
∏
k
=
t
+
1
T
∂
h
(
k
)
∂
h
(
k
−
1
)
)
∂
h
(
t
)
∂
W
=
∑
t
=
1
T
∂
L
∂
y
(
T
)
∂
y
(
T
)
∂
o
(
T
)
∂
o
(
T
)
∂
h
(
T
)
(
∏
k
=
t
+
1
T
tanh
′
(
z
(
k
)
)
W
)
∂
h
(
t
)
∂
W
begin{aligned} frac{partial L}{partial W} &= sum_{t=1}^Tfrac{partial L}{partial y^{(T)}} frac{partial y^{(T)}}{partial o^{(T)}} frac{partial o^{(T)}}{partial h^{(T)}}(prod_{k=t+1}^{T} frac{partial h^{(k)}}{partial h^{(k-1)}}) frac{partial h^{(t)}}{partial W}\ &=sum_{t=1}^Tfrac{partial L}{partial y^{(T)}} frac{partial y^{(T)}}{partial o^{(T)}} frac{partial o^{(T)}}{partial h^{(T)}}(prod_{k=t+1}^{T} tanh' (z^{(k)})W) frac{partial h^{(t)}}{partial W} end{aligned}
∂W∂L=t=1∑T∂y(T)∂L∂o(T)∂y(T)∂h(T)∂o(T)(k=t+1∏T∂h(k−1)∂h(k))∂W∂h(t)=t=1∑T∂y(T)∂L∂o(T)∂y(T)∂h(T)∂o(T)(k=t+1∏Ttanh′(z(k))W)∂W∂h(t)
对于公式中的
(
∏
k
=
t
+
1
T
∂
h
(
k
)
∂
h
(
k
−
1
)
)
=
(
∏
k
=
t
+
1
T
tanh
′
(
z
(
k
)
)
W
)
(prod_{k=t+1}^{T} frac{partial h^{(k)}}{partial h^{(k-1)}})=(prod_{k=t+1}^{T} tanh' (z^{(k)})W)
(∏k=t+1T∂h(k−1)∂h(k))=(∏k=t+1Ttanh′(z(k))W) ,tanh的导数总是小于1的,又因为是
(
T
−
(
t
−
+
1
)
)
(T-(t-+1))
(T−(t−+1)) 个timestep参数的连乘,所以如果
W
W
W 小于1,梯度就会消失;如果
W
W
W 的特征值大于1,梯度就会爆炸。
所以,RNN梯度消失的真正含义是,梯度被近距离(当
(
t
+
1
)
(t+1)
(t+1) 趋向于
T
T
T)的梯度主导,远距离会发生爆炸或消失,导致模型难以学到远距离的信息。
值得强调的是,RNN的这一缺陷并非理论上的,而是技术实践上的。换言之,RNN在理论上是一个优秀的模型,前提是我们能够找到一组合适的参数,然而实践上这组参数并不好找。
LSTM
先来大致看看LSTM相比RNN的结构改变是什么,多了一个传输状态:
这个图是LSTM的timestep:
根据这个图,LSTM的前向传播过程包括:
- 遗忘门:接收 t − 1 t-1 t−1 时刻的状态 h t − 1 h_{t-1} ht−1 以及当前的输入 x t x_t xt,经过sigmoid函数之后输出一个0到1之间的值,输出为: f t = σ ( W f h t − 1 + U f x t + b f ) f_t=sigma(W_fh_{t-1}+U_fx_t+b_f) ft=σ(Wfht−1+Ufxt+bf)
- 输入门:这里进行了两个操作,输出分别为: i t = σ ( W i h t − 1 + U i x t + b i ) i_t=sigma(W_ih_{t-1}+U_ix_t+b_i) it=σ(Wiht−1+Uixt+bi), C ~ t = tanh ( W a h t − 1 + U a x t + b a ) tilde C_t=tanh(W_ah_{t-1}+U_ax_t+b_a) C~t=tanh(Waht−1+Uaxt+ba)
- 当前状态:输出为: C t = C t − 1 ⊙ f t + i t ⊙ C ~ t C_t=C_{t-1} odot f_t+i_t odot tilde C_t Ct=Ct−1⊙ft+it⊙C~t
- 输出门:输出为: o t = σ ( W o h t − 1 + U o x t + b o ) o_t=sigma(W_oh_{t-1}+U_ox_t+b_o) ot=σ(Woht−1+Uoxt+bo), h t = o t ⊙ tanh C t h_t=o_t odot tanh C_t ht=ot⊙tanhCt
- 预测输出: y ^ = σ ( V h t + c ) hat y=sigma (Vh_t+c) y^=σ(Vht+c)
对于三个门的作用如下图所示:
关于LSTM如何RNN中解决梯度消失或爆炸:
如上文中所述,RNN中引起梯度消失或爆炸的点在于:
∏
k
=
t
+
1
T
∂
h
(
k
)
∂
h
(
k
−
1
)
=
∏
k
=
t
+
1
T
tanh
′
(
z
(
k
)
)
W
prod_{k=t+1}^{T} frac{partial h^{(k)}}{partial h^{(k-1)}}=prod_{k=t+1}^{T} tanh' (z^{(k)})W
k=t+1∏T∂h(k−1)∂h(k)=k=t+1∏Ttanh′(z(k))W
在LSTM中这个公式是这样的:
∏
k
=
t
+
1
T
∂
h
(
k
)
∂
h
(
k
−
1
)
=
∏
k
=
t
+
1
T
tanh
′
σ
(
W
f
X
t
+
b
f
)
prod_{k=t+1}^{T} frac{partial h^{(k)}}{partial h^{(k-1)}}=prod_{k=t+1}^{T} tanh' sigma(W_fX_t+b_f)
k=t+1∏T∂h(k−1)∂h(k)=k=t+1∏Ttanh′σ(WfXt+bf)
如果设
Z
=
tanh
(
x
)
σ
(
y
)
Z=tanh (x)sigma(y)
Z=tanh(x)σ(y),其函数图像如下所示:
可以看到这个函数的值基本可以近似为0或1,这样就可以解决多个小于1或多个大于1的数相乘导致的梯度消失或梯度爆炸问题。
通过LSTM这种方式,除了在结构上天然地克服了梯度消失的问题,更重要的是能够具有更多的参数来控制模型;其参数量是RNN的四倍,能够更加精细地预测时间序列变量。
最后
以上就是明理毛衣为你收集整理的预训练语言模型(三):RNN和LSTMRNNLSTM的全部内容,希望文章能够帮你解决预训练语言模型(三):RNN和LSTMRNNLSTM所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
发表评论 取消回复