概述
文章目录
- Sampler采样函数基类
- SequentialSampler顺序采样
- RandomSampler随机采样
- SubsetRandomSampler索引随机采样
- WeightedRandomSampler加权随机采样
- BatchSampler批采样
Sampler采样函数基类
torch.utils.data.Sampler(data_source)
所有采样器的基类。
每个采样器子类都必须提供一个__iter__()
方法,提供一种遍历dataset元素索引的方法,以及一个返回迭代器长度的__len__()
方法。
pytorch中提供的采样方法主要有SequentialSampler, RandomSampler, SubsetRandomSampler, WeightedRandomSampler
,关键是__iter__
的实现.
下面用一个简单的例子来分析各个采样函数的源码以及
import torch
from torch.utils.data.sampler import *
import numpy as np
t = np.arange(10)
SequentialSampler顺序采样
torch.utils.data.SequentialSampler(data_source)
其中__iter__
为:
iter(range(len(self.data_source)))
参数
data_source
为数据集
所以SequentialSampler
的功能是顺序逐个采样数据
for i in SequentialSampler(t):
print(i,end=',')
输出:
0,1,2,3,4,5,6,7,8,9,
RandomSampler随机采样
torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None)
其中__iter__
为:
n = len(self.data_source)
if self.replacement:
return iter(torch.randint(high=n,
size=(self.num_samples,),
dtype=torch.int64).tolist())
return iter(torch.randperm(n).tolist())
参数
data_source
为数据集- replacement:是否为有放回取样
RandomSampler
当replacement
开关关闭时,返回原始数据集长度下标数组随机打乱后采样值, 而当replacment
开关打开后,则根据num_samples
长度来生成采样序列长度。
具体可见如下代码,在replacement=False
时,RandomSampler对数组t下标随机打乱输出,迭代器长度与源数据长度一致。
当replacement=True
并设定num_samples=20
,这时迭代器长度大于源数据,故会出现重复值。
t = np.arange(10)
for i in RandomSampler(t):
print(i,end=',')
输出:
4,5,6,0,8,1,7,9,2,3,
输入
for i in RandomSampler(t,replacement=True,num_samples=20):
print(i,end=',')
输出:
8,0,4,6,4,0,1,5,3,1,6,8,9,0,4,7,0,8,7,4,
SubsetRandomSampler索引随机采样
torch.utils.data.SubsetRandomSampler(indices)
其中__iter__
为:
(self.indices[i] for i in torch.randperm(len(self.indices)))
其中
torch.randperm
对数组随机排序- indices为给定的下标数组
所以SubsetRandomSampler
的功能是在给定一个数据集下标后,对该下标数组随机排序,然后不放回取样
for i in SubsetRandomSampler(t):
print(i,end=',')
输出:
2,6,1,7,4,3,0,5,8,9,
WeightedRandomSampler加权随机采样
torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True)
其中__iter__
为:
iter(torch.multinomial(self.weights, self.num_samples, self.replacement).tolist())
其中
weights
为index权重,权重越大的取到的概率越高- num_samples: 生成的采样长度
- replacement:是否为有放回取样
- multinomial: 伯努利随机数生成函数,也就是根据概率设定生成{0,1,…,n}
weights = torch.tensor([0, 10, 3, 0], dtype=torch.float)
torch.multinomial(weights,1,replacement=False)
输出:
tensor([1])
weights = torch.tensor([0, 10, 3, 0], dtype=torch.float)
torch.multinomial(weights,2,replacement=False)
输出:
tensor([2, 1])
weights = torch.tensor([0, 10, 3, 0], dtype=torch.float)
torch.multinomial(weights,3,replacement=False)
输出:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-41-c641212fcbc8> in <module>
1 weights = torch.tensor([0, 10, 3, 0], dtype=torch.float)
----> 2 torch.multinomial(weights,3,replacement=False)
RuntimeError: invalid argument 2: invalid multinomial distribution (with replacement=False, not enough non-negative category to sample) at /opt/conda/conda-bld/pytorch_1565287148058/work/aten/src/TH/generic/THTensorRandom.cpp:378
weights = torch.tensor([0, 10, 3, 0], dtype=torch.float)
torch.multinomial(weights,2,replacement=True)
输出:
tensor([1, 1])
weights = torch.tensor([1, 10, 3, 0], dtype=torch.float)
torch.multinomial(weights,10,replacement=True)
输出:
tensor([1, 1, 0, 0, 1, 0, 2, 1, 2, 1])
通过上面几个例子可以看出,权重值为0的index不会被取到。
当不放回取样时,replacement=False
,若num_samplers
小于输入数组中权重非零值个数,那么非零权重大小基本不起什么作用,反正所有的值都会取到一次
当放回取样时,权重越大的取到的概率越高。
BatchSampler批采样
torch.utils.data.BatchSampler(sampler, batch_size, drop_last)
其中__iter__
为:
batch = []
for idx in self.sampler:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
if len(batch) > 0 and not self.drop_last:
yield batch
其中
drop_last
为布尔类型值,当其为真时,如果数据集长度不是batch_size整数倍时,最后一批数据将会丢弃。
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
代码中例子很清晰,数据总长度为10,如果drop_last设置为False,那么最后余下的一个作为新的batch.
最后
以上就是正直枕头为你收集整理的pytorch中数据采样方法Sampler源码解析的全部内容,希望文章能够帮你解决pytorch中数据采样方法Sampler源码解析所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
发表评论 取消回复