概述
xlnet-pytorch版本代码解读
- 安装
- 参数分析
- data_utils.py代码分析
这篇不详细介绍xlnet的原理,有需要了解的请自行前往xlnet原理分析
安装
首先clone XLNet-pytorch的源码
git clone https://github.com/graykode/xlnet-Pytorch && cd xlnet-Pytorch
# To use Sentence Piece Tokenizer(pretrained-BERT Tokenizer)
$ pip install pytorch_pretrained_bert
参数分析
- data 数据存放路径
- tokenizer 分词
- seq_len 序列长度,
- reuse_len cache的长度,
- perm_size 最长的排列长度
- bi_data 是否双向的batch,
- mask_alpha 多少词组成一个group
- mask_beta 每个group里mask几个词
- num_predict 预测多少个词
- mem_len 缓存的长度
- num_epoch 训练轮数
data_utils.py代码分析
data_utils.py是用来生成训练数据的,首先调用的是_create_data函数,这个函数的核心代码一点点来分析
def _create_data(sp, input_paths, seq_len, reuse_len,
bi_data, num_predict, mask_alpha, mask_beta):
features = []
f = open(input_paths, 'r')
lines = f.readlines()
input_data, sent_ids, sent_id = [], [], True
for line in lines:
tokens = sp.tokenize(line)
cur_sent = sp.convert_tokens_to_ids(tokens)
input_data.extend(cur_sent)
sent_ids.extend([sent_id] * len(cur_sent))
sent_id = not sent_id
这里为了方便处理,作者只对单个文件进行了处理,在xlnet源代码中是对多个文件进行了处理,对于每一个文件(我们这里只有一个),最终是为了得到”input_data, sent_ids = [], []”两个list。input_data里是放到这个文件的每一个WordPiece对应的ID,而sent_ids用于判断句子的边界。
比如说对"This is the first sentence.this is the second sentence and also the end of the paragraph.",首先使用sp将其切分为[‘this’, ‘is’, ‘the’, ‘first’, ‘sentence’, ‘.’, ‘this’, ‘is’, ‘the’ ‘second’, ‘sentence’, ‘and’, ‘also’, ‘the’, ‘end’, ‘of’, ‘the’, ‘paragraph’, ‘.’],最后变成ID得到[2023, 2003, 1996, 2034, 6251, 1012, 2023, 2003, 1996, 2117, 6251, 1998, 2036, 1996 , 2203, 1997, 1996, 20423, 1012]。第一个句子"This is the first sentence"对应的sent_ids是[True, True, True, True, True, True],第二个句子对应的sent_ids是[False, … ,False]。于是,最后得到的input_data和sent_ids为:
input_data = [2023, 2003, 1996, 2034, 6251, 1012, 2023, 2003, 1996, 2117, 6251, 1998, 2036, 1996 , 2203, 1997, 1996, 20423, 1012]
sent_ids = [True, True, True, True, True, True, False, False, False, False, False, False, False, False, False, False, False, False]
因此input_data是每一个WordPiece对应的ID的数组,而sent_ids可以判断哪些ID是属于一个句子的,也就是sent_ids通过交替的True和False来告诉我们句子的边界,比如前面的sent_ids的前6个为True,因此我们可以知道前6个WordPiece属于第一个句子,而后面的12个连续False告诉我们第二个句子有12个WordPiece。那么如果第三个句子有5个WordPiece,则我们可以猜测后面应该出现连续5个True。
# shape of data : [1, 582]
data = np.array([input_data], dtype=np.int64)
sent_ids = np.array([sent_ids], dtype=np.bool)
assert reuse_len < seq_len - 3
data_len = data.shape[1]
sep_array = np.array([SEP_ID], dtype=np.int64)
cls_array = np.array([CLS_ID], dtype=np.int64)
i = 0
while i + seq_len <= data_len:
inp = data[0, i: i + reuse_len]
tgt = data[0, i + 1: i + reuse_len + 1]
results = _split_a_and_b(
data[0], # all line in one Text file.
sent_ids[0],
begin_idx=i + reuse_len,
tot_len=seq_len - reuse_len - 3,
extend_target=True)
函数_split_a_and_b此时的参数为:
data[0] = [2023, 2003, 1996, 2034, 6251, 1012, 2023, 2003, 1996, 2117, 6251, 1998, 2036, 1996 , 2203, 1997, 1996, 20423, 1012]
sent_ids[0] = [True, True, True, True, True, True, False, False, False, False, False, False, False, False, False, False, False, False]
begin_idx = 0 + 4 = 4
tot_len = 8 - 4 - 3 = 1
# unpack the results
(a_data, b_data, label, _, a_target, b_target) = tuple(results)
# sample ngram spans to predict
reverse = bi_data
if num_predict is None:
num_predict_0 = num_predict_1 = None
else:
num_predict_1 = num_predict // 2
num_predict_0 = num_predict - num_predict_1
mask_0 = _sample_mask(sp, inp, mask_alpha, mask_beta, reverse=reverse,
goal_num_predict=num_predict_0)
mask_1 = _sample_mask(sp, np.concatenate([a_data, sep_array, b_data,
sep_array, cls_array]),
mask_alpha, mask_beta,
reverse=reverse, goal_num_predict=num_predict_1)
# concatenate data
cat_data = np.concatenate([inp, a_data, sep_array, b_data,
sep_array, cls_array])
seg_id = ([0] * (reuse_len + a_data.shape[0]) + [0] +
[1] * b_data.shape[0] + [1] + [2])
assert cat_data.shape[0] == seq_len
assert mask_0.shape[0] == seq_len // 2
assert mask_1.shape[0] == seq_len // 2
# the last two CLS's are not used, just for padding purposes
tgt = np.concatenate([tgt, a_target, b_target, cls_array, cls_array])
assert tgt.shape[0] == seq_len
is_masked = np.concatenate([mask_0, mask_1], 0)
if num_predict is not None:
assert np.sum(is_masked) == num_predict
feature = {
"input": cat_data,
"is_masked": is_masked,
"target": tgt,
"seg_id": seg_id,
"label": [label],
}
features.append(feature)
i += reuse_len
f.close()
return features
最后
以上就是飘逸蛋挞为你收集整理的xlnet pytorch简易版代码解读安装参数分析data_utils.py代码分析的全部内容,希望文章能够帮你解决xlnet pytorch简易版代码解读安装参数分析data_utils.py代码分析所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
发表评论 取消回复