我是靠谱客的博主 善良橘子,最近开发中收集的这篇文章主要介绍Pytorch中数据采样方法Sampler(torch.utils.data)(二) —— WeightedRandomSampler & SubsetRandomSampler,觉得挺不错的,现在分享给大家,希望可以做个参考。
概述
WeightedRandomSampler加权随机采样
平衡不平衡数据的抽取
torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True)
其中__iter__为:
iter(torch.multinomial(self.weights, self.num_samples, self.replacement).tolist())
其中
- weights为index权重,权重越大的取到的概率越高
- num_samples: 生成的采样长度
- replacement:是否为有放回取样
- multinomial: 伯努利随机数生成函数,也就是根据概率设定生成{0,1,…,n}
如果label为1,那么对应的该类别被取出来的概率是另外一个类别的2倍
import torchvision from torchvision import transforms from torch.utils.data import sampler from torch.utils.data import DataLoader from torch.utils.data.sampler import * transform = transforms.Compose([ torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.1307,), (0.3081,)) ]) trainset = torchvision.datasets.MNIST( root='dataset/', train=True, #如果为True,从 training.pt 创建数据,否则从 test.pt 创建数据。 download=True, #如果为true,则从 Internet 下载数据集并将其放在根目录中。 如果已下载数据集,则不会再次下载。 transform=transform ) ## 如果label为1,那么对应的该类别被取出来的概率是另外一个类别的2倍 weights = [2 if label == 1 else 1 for data, label in trainset] sampler = WeightedRandomSampler(weights,num_samples=10, replacement=True) dataloader = DataLoader(trainset, batch_size=16, sampler=sampler)
SubsetRandomSampler索引随机采样
根据index从数据集中抽取这些index对应的图片,然后随机排序
torch.utils.data.SubsetRandomSampler(indices)
其中__iter__为:
(self.indices[i] for i in torch.randperm(len(self.indices)))
其中
- torch.randperm对数组随机排序
- indices为给定的下标数组
所以SubsetRandomSampler的功能是在给定一个数据集下标后,对该下标数组随机排序,然后不放回取样
如果我要划分train_set和test_set, 那么读进整个数据集来再split比较慢
不如我直接生成train_set的index和test_set的index这样就可以很快了,所以就出现了SubsetRandomSampler
import torchvision from torchvision import transforms from torch.utils.data import sampler from torch.utils.data import DataLoader transform = transforms.Compose([ torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.1307,), (0.3081,)) ]) trainset = torchvision.datasets.MNIST( root='dataset/', train=True, #如果为True,从 training.pt 创建数据,否则从 test.pt 创建数据。 download=True, #如果为true,则从 Internet 下载数据集并将其放在根目录中。 如果已下载数据集,则不会再次下载。 transform=transform ) testset = torchvision.datasets.MNIST( root='dataset/', train=False, download=True, transform=transform ) split_num = int(len(trainset) * 0.8) index_list = list(range(len(trainset))) train_idx, val_idx = index_list[:split_num], index_list[split_num:] train_sampler = sampler.SubsetRandomSampler(train_idx) val_sampler = sampler.SubsetRandomSampler(val_idx) loader_train = DataLoader(trainset, batch_size=100, sampler=train_sampler) loader_val = DataLoader(trainset, batch_size=100, sampler=val_sampler)
最后
以上就是善良橘子为你收集整理的Pytorch中数据采样方法Sampler(torch.utils.data)(二) —— WeightedRandomSampler & SubsetRandomSampler的全部内容,希望文章能够帮你解决Pytorch中数据采样方法Sampler(torch.utils.data)(二) —— WeightedRandomSampler & SubsetRandomSampler所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复