我是靠谱客的博主 飘逸蛋挞,最近开发中收集的这篇文章主要介绍xlnet pytorch简易版代码解读安装参数分析data_utils.py代码分析,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

xlnet-pytorch版本代码解读

  • 安装
  • 参数分析
  • data_utils.py代码分析

这篇不详细介绍xlnet的原理,有需要了解的请自行前往xlnet原理分析

安装

首先clone XLNet-pytorch的源码

git clone https://github.com/graykode/xlnet-Pytorch && cd xlnet-Pytorch

# To use Sentence Piece Tokenizer(pretrained-BERT Tokenizer)
$ pip install pytorch_pretrained_bert

参数分析

  • data 数据存放路径
  • tokenizer 分词
  • seq_len 序列长度,
  • reuse_len cache的长度,
  • perm_size 最长的排列长度
  • bi_data 是否双向的batch,
  • mask_alpha 多少词组成一个group
  • mask_beta 每个group里mask几个词
  • num_predict 预测多少个词
  • mem_len 缓存的长度
  • num_epoch 训练轮数

data_utils.py代码分析

data_utils.py是用来生成训练数据的,首先调用的是_create_data函数,这个函数的核心代码一点点来分析

def _create_data(sp, input_paths, seq_len, reuse_len,
                bi_data, num_predict, mask_alpha, mask_beta):
    features = []

    f = open(input_paths, 'r')
    lines = f.readlines()
    input_data, sent_ids, sent_id = [], [], True

    for line in lines:
        tokens = sp.tokenize(line)
        cur_sent = sp.convert_tokens_to_ids(tokens)
        input_data.extend(cur_sent)
        sent_ids.extend([sent_id] * len(cur_sent))
        sent_id = not sent_id

这里为了方便处理,作者只对单个文件进行了处理,在xlnet源代码中是对多个文件进行了处理,对于每一个文件(我们这里只有一个),最终是为了得到”input_data, sent_ids = [], []”两个list。input_data里是放到这个文件的每一个WordPiece对应的ID,而sent_ids用于判断句子的边界。
比如说对"This is the first sentence.this is the second sentence and also the end of the paragraph.",首先使用sp将其切分为[‘this’, ‘is’, ‘the’, ‘first’, ‘sentence’, ‘.’, ‘this’, ‘is’, ‘the’ ‘second’, ‘sentence’, ‘and’, ‘also’, ‘the’, ‘end’, ‘of’, ‘the’, ‘paragraph’, ‘.’],最后变成ID得到[2023, 2003, 1996, 2034, 6251, 1012, 2023, 2003, 1996, 2117, 6251, 1998, 2036, 1996 , 2203, 1997, 1996, 20423, 1012]。第一个句子"This is the first sentence"对应的sent_ids是[True, True, True, True, True, True],第二个句子对应的sent_ids是[False, … ,False]。于是,最后得到的input_data和sent_ids为:

input_data = [2023, 2003, 1996, 2034, 6251, 1012, 2023, 2003, 1996, 2117, 6251, 1998, 2036, 1996 , 2203, 1997, 1996, 20423, 1012]
sent_ids = [True, True, True, True, True, True, False, False, False, False, False, False, False, False, False, False, False, False]

因此input_data是每一个WordPiece对应的ID的数组,而sent_ids可以判断哪些ID是属于一个句子的,也就是sent_ids通过交替的True和False来告诉我们句子的边界,比如前面的sent_ids的前6个为True,因此我们可以知道前6个WordPiece属于第一个句子,而后面的12个连续False告诉我们第二个句子有12个WordPiece。那么如果第三个句子有5个WordPiece,则我们可以猜测后面应该出现连续5个True。

    # shape of data : [1, 582]
    data = np.array([input_data], dtype=np.int64)
    sent_ids = np.array([sent_ids], dtype=np.bool)

    assert reuse_len < seq_len - 3

    data_len = data.shape[1]
    sep_array = np.array([SEP_ID], dtype=np.int64)
    cls_array = np.array([CLS_ID], dtype=np.int64)

    i = 0
    while i + seq_len <= data_len:
        inp = data[0, i: i + reuse_len]
        tgt = data[0, i + 1: i + reuse_len + 1]

        results = _split_a_and_b(
            data[0], # all line in one Text file.
            sent_ids[0],
            begin_idx=i + reuse_len,
            tot_len=seq_len - reuse_len - 3,
            extend_target=True)

函数_split_a_and_b此时的参数为:
data[0] = [2023, 2003, 1996, 2034, 6251, 1012, 2023, 2003, 1996, 2117, 6251, 1998, 2036, 1996 , 2203, 1997, 1996, 20423, 1012]
sent_ids[0] = [True, True, True, True, True, True, False, False, False, False, False, False, False, False, False, False, False, False]
begin_idx = 0 + 4 = 4
tot_len = 8 - 4 - 3 = 1

        # unpack the results
        (a_data, b_data, label, _, a_target, b_target) = tuple(results)

        # sample ngram spans to predict
        reverse = bi_data
        if num_predict is None:
            num_predict_0 = num_predict_1 = None
        else:
            num_predict_1 = num_predict // 2
            num_predict_0 = num_predict - num_predict_1

        mask_0 = _sample_mask(sp, inp, mask_alpha, mask_beta, reverse=reverse,
                              goal_num_predict=num_predict_0)
        mask_1 = _sample_mask(sp, np.concatenate([a_data, sep_array, b_data,
                                                  sep_array, cls_array]),
                              mask_alpha, mask_beta,
                              reverse=reverse, goal_num_predict=num_predict_1)

        # concatenate data
        cat_data = np.concatenate([inp, a_data, sep_array, b_data,
                                   sep_array, cls_array])
        seg_id = ([0] * (reuse_len + a_data.shape[0]) + [0] +
                  [1] * b_data.shape[0] + [1] + [2])
        assert cat_data.shape[0] == seq_len
        assert mask_0.shape[0] == seq_len // 2
        assert mask_1.shape[0] == seq_len // 2

        # the last two CLS's are not used, just for padding purposes
        tgt = np.concatenate([tgt, a_target, b_target, cls_array, cls_array])
        assert tgt.shape[0] == seq_len

        is_masked = np.concatenate([mask_0, mask_1], 0)
        if num_predict is not None:
            assert np.sum(is_masked) == num_predict

        feature = {
            "input": cat_data,
            "is_masked": is_masked,
            "target": tgt,
            "seg_id": seg_id,
            "label": [label],
        }
        features.append(feature)
        i += reuse_len
    f.close()
    return features

最后

以上就是飘逸蛋挞为你收集整理的xlnet pytorch简易版代码解读安装参数分析data_utils.py代码分析的全部内容,希望文章能够帮你解决xlnet pytorch简易版代码解读安装参数分析data_utils.py代码分析所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(46)

评论列表共有 0 条评论

立即
投稿
返回
顶部