我是靠谱客的博主 结实火,最近开发中收集的这篇文章主要介绍自定义dataloader里的batch_sampler,你就不需要再做padding了!前言正文完整代码后记,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

前言

最近再看一份代码,居然已经水了三篇博客…and这是第四篇。
在说batch_sampler之前,这里不要脸的放一下之前关于dataset和dataloader的博客

正文

在自然语言处理中,由于一个batch中的句子有长有短,我们对短的做padding。但今天我遇到的fixedlengthsampler给了我另外一种思路,我们可以统计所有句子的长度,然后保证每一个batch在采样的时候,只采样句子长度相同的句子,这样就不需要padding了。这个fixedlengthsampler具体怎么实现呢?

第一步:建立length_map

在这里插入图片描述
这里的self.data_source是已经定义好的,集成了torch.utils.data.Dataset的dataset。假设我们这里的self.data_source.dataset是形似:[[单词1,单词2,…], [单词1,单词2,…], …]这样的句子列表。
我们遍历这个句子列表,记录长度,并以句子长度为length_map的键,而对应的值是一个列表,里面是句子在句子列表中对应的下标。也就是说:length_map形如下(随便举个例子):
{9:[10,9,11,1,2,3,4,5,6,7,8], 17:[18,20]}

第二步:建立state

在这里插入图片描述
遍历上述的length_map字典,每当取出一个length(句子长度)和arr(对应句子长度的句子下标列表)的时候,我们通过self.get_batch_size得到当前这个length对应的,我们设置的batch_size,这一步的意义我目前理解为: 对于一些句子长度很长的,那么batch_size可以调小一点,这样就不容易让显存爆掉,不管这么多,我们这里假设所有长度对应的batch_size都是10!
然后计算nbatches, 并用surplus标识是否有剩下的,不是完整的一个batch的句子,后续我们可以把它加上。
这个postion在后续会体现,可以看到它一开始是-1,那么对于某一个句子长度,我第一次对你取一个batch的时候,我先position+=1,就可以通过 句子下标列表[position*batch_size: (position+1)*batch_size]来取了。每一次取,我都给它+1。
好了,state的形式如下:
{9:{“nbatches”:1, “surplus”:True, “position”:-1}, 17:{“nbatches”:0, “surplus”:True, “position”:-1 }}
因为可以看第一步中的例子,长度为9的句子有11个,那么第一个batch可以取10个(nbatches=10),还剩下了1个(surplus=True); 而长度为17的句子只有2个,nbatches=len(arr)//10=0(nbatch=0), 还剩下了2个(surplus=True)…

第三步:建立order,确定访问次序

在这里插入图片描述
order是一个次序列表,里面放的是句子长度。
具体创建方式是:遍历state,假设当前length=9,那么根据上面,它的nbatches就是1,order里面就放一个9。可以看到遍历完state的order中是没有17的。我们可以接下来进行补充(是否补充最后一个不足batch_size的句子,这个选择也可以设置成一个参数(就比如下面的self.include_partial))
我们的order现在长这样:
[9]

第四步:补全不足batch_size的样本

在这里插入图片描述
这个不多讲。
如果include_partial=True,那我们的order现在长这样:
[9,9,17]
建立完这些东西,我们就可以根据order列表给定的次序,进行迭代生成了。

第五步:迭代生成

在这里插入图片描述
在这里插入图片描述
我不知道为什么这里的len(self)居然跟上面的order列表的长度一样!太神奇了。既然一样,就可以遍历所有的order里面的length。
上面已经说过关于position什么的细节了,所以get_next_batch的过程应该能看懂。

我们把这个sampler类定义为FixedLengthBatchSampler,那它的完整代码如下:

完整代码

from torch.utils.data import Sampler
class FixedLengthBatchSampler(Sampler):

    def __init__(self, data_source, batch_size, include_partial=False, rng=None, maxlen=None,
                 length_to_size=None):
        self.data_source = data_source
        self.active = False
        if rng is None:
            rng = np.random.RandomState(seed=11)
        self.rng = rng
        self.batch_size = batch_size
        self.maxlen = maxlen
        self.include_partial = include_partial
        self.length_to_size = length_to_size
        self._batch_size_cache = { 0: self.batch_size }
        self.logger = get_logger()

    def get_batch_size(self, length):
        if self.length_to_size is None:
            return self.batch_size
        if length in self._batch_size_cache:
            return self._batch_size_cache[length]
        start = max(self._batch_size_cache.keys())
        batch_size = self._batch_size_cache[start]
        for n in range(start+1, length+1):
            if n in self.length_to_size:
                batch_size = self.length_to_size[n]
            self._batch_size_cache[n] = batch_size
        return batch_size

    def reset(self):
        """
        Create a map of {length: List[example_id]} and maintain how much of
        each list has been seen.

        If include_partial is False, then do not provide batches that are below
        the batch_size.

        If length_to_size is set, then batch size is determined by length.

        """

        # Record the lengths of each example.
        length_map = {}
        for i in range(len(self.data_source)):
            x = self.data_source.dataset[i]
            length = len(x)

            if self.maxlen is not None and self.maxlen > 0 and length > self.maxlen:
                continue

            length_map.setdefault(length, []).append(i)

        # Shuffle the order.
        for length in length_map.keys():
            self.rng.shuffle(length_map[length])

        # Initialize state.
        state = {}
        for length, arr in length_map.items():
            batch_size = self.get_batch_size(length)
            nbatches = len(arr) // batch_size
            surplus = nbatches * batch_size < len(arr)
            state[length] = dict(nbatches=nbatches, surplus=surplus, position=-1)

        # Batch order, in terms of length.
        order = []
        for length, v in state.items():
            order += [length] * v['nbatches']

        ## Optionally, add partial batches.
        if self.include_partial:
            for length, v in state.items():
                if v['surplus']:
                    order += [length]

        self.rng.shuffle(order)

        self.length_map = length_map
        self.state = state
        self.order = order
        self.index = -1

    def get_next_batch(self):
        index = self.index + 1
        length = self.order[index]
        batch_size = self.get_batch_size(length)
        position = self.state[length]['position'] + 1
        start = position * batch_size
        batch_index = self.length_map[length][start:start+batch_size]

        self.state[length]['position'] = position
        self.index = index
        return batch_index

    def __iter__(self):
        self.reset()
        for _ in range(len(self)):
            yield self.get_next_batch()

    def __len__(self):
        return len(self.order)

把这个类传给torch.utils.data.Dataloader的batch_sampler参数就好了。

sampler = FixedLengthBatchSampler(dataset, batch_size=batch_size, rng=rng,
   maxlen=filter_length, include_partial=include_partial, length_to_size=length_to_size)
loader = torch.utils.data.DataLoader(dataset, shuffle=(sampler is None), num_workers=workers, pin_memory=pin_memory,batch_sampler=sampler, collate_fn=collate_fn)

当然,如果使用默认的sampler也是可以的,无非就是shuffle=False和shuffle=True的区别。

后记

如果之前的我再hard一点,现在的我是不是会更好过。

最后

以上就是结实火为你收集整理的自定义dataloader里的batch_sampler,你就不需要再做padding了!前言正文完整代码后记的全部内容,希望文章能够帮你解决自定义dataloader里的batch_sampler,你就不需要再做padding了!前言正文完整代码后记所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(47)

评论列表共有 0 条评论

立即
投稿
返回
顶部