概述
模型介绍
Transformer最大的问题在于没有办法建模超过最大长度的序列,Transformer-XL主要提出了两个优化点:段级递归和相对位置编码。
段级递归
为了解决固定长度的限制,Transformer-XL提出了一种递归机制,如下图,第一个segment计算完成后,把计算的结果保存下来,在计算第二个片段的时候,把第一个片段的hidden state和第二个片段的hidden state拼接在一起,再进行后续的计算。
我们看下具体的计算公式,其中h表示的是hidden state,
τ
tau
τ 表示第
τ
tau
τ 个segment,SG函数表示的是不更新梯度,[]表示的是向量的拼接。
第一个公式的意思即:第
τ
+
1
tau+1
τ+1个segment第n-1层的hidden state 等于第
τ
tau
τ 个segment第n - 1层的hidden state拼接上第
τ
+
1
tau +1
τ+1 个segment第n - 1层的hidden state,后续两个公式和vanilla版本类似,但要注意,q是未拼接的hidden state,k、v是拼接过后的,因为q表示的是当前的segment,所以不需要拼接。
可以看到,对于第一个segment来说,hidden state是没有额外需要拼接的值的,从第二个segment开始才需要拼接,在论文中,每次都是和上一个segment进行拼接,理论上来说每次可以拼接多个segment,第n个segment可以和前n-1个segment进行拼接,不过这个就取决于你自己的显存了,并且一个segment通常来说不会像上图中的这么短(一个segment可能长度就512了),文本自身的上下文依赖一般也不会超过一个segment的长度。
实现代码
def init_mems(self, bsz):
if self.mem_len > 0:
mems = []
for i in range(self.n_layer):
empty = tf.zeros([self.mem_len, bsz, self.d_model])
mems.append(empty)
return mems
else:
return None
def _update_mems(self, hids, mems, mlen, qlen):
# does not deal with None
if mems is None:
return None
# mems is not None
assert len(hids) == len(mems), "len(hids) != len(mems)"
# There are `mlen + qlen` steps that can be cached into mems
new_mems = []
end_idx = mlen + tf.math.maximum(0, qlen)
beg_idx = tf.math.maximum(0, end_idx - tf.convert_to_tensor(self.mem_len))
for i in range(len(hids)):
mems[i] = tf.cast(mems[i], dtype=hids[i].dtype)
cat = tf.concat([mems[i], hids[i]], axis=0)
tf.stop_gradient(cat)
new_mems.append(cat[beg_idx:end_idx])
return new_mems
相对位置编码
Vanilla的位置编码是和embedding相加后输入到下一层的,Transformer-XL的位置编码没有在输入上做处理,而是对attention score进行了修改。
考虑一下,当query与key进行计算时,实际上并不需要知道key的绝对位置编码,模型实际上需要的是一个“时间线索”即字词的一个先后顺序,因此,知道query与key的相对位置即可。根据以上的思路,Transformer-XL做了三个方面的改进,分别如下:
在新的参数下,每一项都有了一个具体的含义,a表示的是query与key的内容相关性,b表示的是query的内容和key的位置的相关性,c表示的是query的位置与key的内容的相关性,d表示的是quey与key的位置的相关性。
总结一下,对于一个N层1个head的Transformer-XL,其完整步骤如下:
实现代码
class RelativeMultiHeadAttention(layers.Layer):
def __init__(self, num_heads, embed_size):
super(RelativeMultiHeadAttention, self).__init__()
self.num_heads = num_heads
self.embed_size = embed_size
self.hidden_size = embed_size // num_heads
self.qvk_net = layers.Dense(3 * embed_size)
self.r_net = layers.Dense(embed_size)
self.o_net = layers.Dense(embed_size)
self.layer_norm = layers.LayerNormalization()
def _rel_shift(self, x):
x_size = tf.shape(x)
# shape:(seq_len_q, seq_len_k, batch_size, num_heads)=>(seq_len_q, seq_len_k + 1, batch_size, num_heads)
x = tf.pad(x, [[0, 0], [1, 0], [0, 0], [0, 0]])
# shape:(seq_len_q, seq_len_k + 1, batch_size, num_heads)=>(seq_len_q + 1, seq_len_k, batch_size, num_heads)
x = tf.reshape(x, (x_size[0] + 1, x_size[1], x_size[2], x_size[3]))
# shape:(seq_len_q + 1, seq_len_k, batch_size, num_heads)=>(seq_len_q, seq_len_k, batch_size, num_heads)
x = tf.slice(x, [0, 1, 0, 0], [-1, -1, -1, -1])
return x
# w表示token embedding,r表示relative position embedding
# r_w_bias表示uT,r_r_bias表示vT,形状和w的形状一致
def __call__(self, w, r, r_w_bias, r_r_bias, mask=None, mems=None, *args, **kwargs):
# w
# shape:(seq_len, batch_size, embed_size)
# r
# shape:(seq_len, 1, embed_size)
seq_len_q, batch_size, seq_len_r = tf.shape(w)[0], tf.shape(w)[1], tf.shape(r)[0]
if mems is not None:
cat = tf.concat([mems, w], axis=0)
w_heads = self.qvk_net(cat)
# 有mems时:
# w_head_q
# shape:(seq_len_q, batch_size, embed_size)
# w_head_k, w_head_v
# shape:(seq_len_k, batch_size, embed_size),其中seq_len_k = seq_len_q + seq_len_mems
w_head_q, w_head_k, w_head_v = tf.split(w_heads, 3, axis=-1)
w_head_q = w_head_q[-seq_len_q:]
r_head_k = self.r_net(r)
else:
w_heads = self.qvk_net(w)
# 没有mems时:(seq_len_q = seq_len)
# w_head_q, w_head_k, w_head_v
# shape:(seq_len_q, batch_size, embed_size)
w_head_q, w_head_k, w_head_v = tf.split(w_heads, 3, axis=-1)
r_head_k = self.r_net(r)
seq_len_k = tf.shape(w_head_k)[0]
# w_head_q
# shape:(seq_len_q, batch_size, embed_size)=>(seq_len_q, batch_size, num_heads, hidden_size)
# w_head_k, w_head_v
# shape:(seq_len_k, batch_size, embed_size)=>(seq_len_k, batch_size, num_heads, hidden_size)
# r_head_k
# shape:(seq_len_r, 1, embed_size)=>(seq_len_r, num_heads, hidden_size)
w_head_q = tf.reshape(w_head_q, (seq_len_q, batch_size, self.num_heads, self.hidden_size))
w_head_k = tf.reshape(w_head_k, (seq_len_k, batch_size, self.num_heads, self.hidden_size))
w_head_v = tf.reshape(w_head_v, (seq_len_k, batch_size, self.num_heads, self.hidden_size))
r_head_k = tf.reshape(r_head_k, (seq_len_r, self.num_heads, self.hidden_size))
# 计算A+C两项,(w_head_q + r_w_bias) * w_head_k = (qT + uT) * k
# w_head_q
# shape:(seq_len_q, batch_size, num_heads, hidden_size)
# r_w_bias
# shape:(seq_len_q, batch_size, num_heads, hidden_size)
# w_head_k
# shape:(seq_len_k, batch_size, num_heads, hidden_size)
wr_head_q = w_head_q + r_w_bias
# shape:(seq_len_q, seq_len_k, batch_size, num_heads)
AC = tf.einsum("ibnh,jbnh->ijbn", wr_head_q, w_head_k)
# 计算B+D两项,(w_head_q + r_r_bias) * r_head_k = (qT + vT) * r
wr_head_r = w_head_q + r_r_bias
# shape:(seq_len_q, seq_len_k, batch_size, num_heads)
BD = tf.einsum("ibnh,jnh->ijbn", wr_head_r, r_head_k)
BD = self.rel_shift(BD)
# 计算attention_score,attention_score = softmax((A+B+C+D)/dk[+mask])
attention_score = (AC + BD) / tf.sqrt(self.hidden_size)
# 如果有mask
if mask is not None:
attention_score += (mask * 1e-9)
# shape:(seq_len_q, seq_len_k, batch_size, num_heads)
attention_score = tf.nn.softmax(attention_score, axis=1)
# 计算attention,attention = attention_score * v
# shape:(seq_len_q, batch_size, num_heads, hidden_size)
attention = tf.einsum("ijbn,jbnh->ibnh", attention_score, w_head_v)
# shape:(seq_len_q, batch_size, num_heads, hidden_size)=>(seq_len_q, batch_size, embed_size)
attention = tf.reshape(attention, (seq_len_q, batch_size, self.embed_size))
attention = self.o_net(attention)
# residual connection
output = attention + w
# layer normalization
output = self.layer_norm(output)
return output
模型参考
论文地址:https://arxiv.org/abs/1901.02860
代码地址:https://github.com/kimiyoung/transformer-xl
最后
以上就是自由马里奥为你收集整理的Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context(2019-1-9)模型介绍模型参考的全部内容,希望文章能够帮你解决Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context(2019-1-9)模型介绍模型参考所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
发表评论 取消回复