我是靠谱客的博主 健壮康乃馨,最近开发中收集的这篇文章主要介绍Pytorch数据加载模块:Dataset,Sampler和DataLoader总结1 Dataset2 Sampler3 DataLoader,觉得挺不错的,现在分享给大家,希望可以做个参考。
概述
官网教程示例:
https://pytorch.org/tutorials/beginner/basics/data_tutorial.html
Pytorch加载数据三步走:
-
Dataset:解析单个样本,把数据映射成(x,y)的形式;
- map-style:实现__getitem__和__len__接口,随机取数据代价小(大多数情况用map-stype);
- iterable-style:实现__iter__接口,随机取数据代价大,适合处理流数据(比如文本流数据);
-
Sampler:提供一种遍历数据集所有元素索引的方式,有默认值;
-
DataLoader:将当个样本变成训练时需要的batch形式;
1 Dataset
1.1 源码
# 接口
from torch.utils.data import Dataset
# 源码位置
# ../torch/utils/data/dataset.py
# 查看torch安装位置
import torch
print(torch.__file__)
源码
# Dataset抽象类 对外暴露一些接口
# map-style
class Dataset(Generic[T_co]):
def __getitem__(self, index) -> T_co:
# 基类中没有实现 需要自己实现
raise NotImplementedError
def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
return ConcatDataset([self, other])
# iter-style
class IterableDataset(Dataset[T_co]):
def __iter__(self) -> Iterator[T_co]:
raise NotImplementedError
def __add__(self, other: Dataset[T_co]):
return ChainDataset([self, other])
1.2 创建自己的Dataset
定义自己的Dataset,继承Dataset类后,需要(必须)实现三个方法:
- _init_
- _len_
- _getitem_
示例:
import os
import pandas as pd
from torchvision.io import read_image
class CustomImageDataset(Dataset):
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
self.img_labels = pd.read_csv(annotations_file)
# 保存图像的根路径
self.img_dir = img_dir
# 对数据的处理 数据增强之类的
self.transform = transform
# 对标签的处理
self.target_transform = target_transform
def __len__(self):
# 返回一共有多少个数据
return len(self.img_labels)
def __getitem__(self, idx):
# 拼凑图像的完整路径
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
# 读取图像
image = read_image(img_path)
# 从csv中读取的信息分割出标签
label = self.img_labels.iloc[idx, 1]
# 对数据及标签进行处理
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label
# csv保存图片名 大概长这样
# tshirt1.jpg, 0
# tshirt2.jpg, 0
# ......
# ankleboot999.jpg, 9
1.3 加载数据集
ann_csv = '../ann.csv'
img_root = '/'
# 实例化
myDataset = CustomImageDataset(ann_csv,img_root)
# 获取该类的属性
print(myDataset.img_dir)
# 获取数据的数量 可以用 但是一般不这么用
print(myDataset.__len__())
# 获取第1个数据的img和label(下标0)
# 可以用 但是一般不这么用
img,lab = myDataset.__getitem__(0)
print(img.shape, lab)
# 一般这么用...
print(len(myDataset))
img,lab = myDataset[0]
print(img.shape, lab)
# 一般不会单独用Dataset
# 扔到DataLoader里 构成batch数据
1.4 Dataset的子类
1.4.1 TensorDataset
如果数据本身已经是tensor形式了
# 数据转为tensor格式
x_train, y_train = torch.tensor(x_train), torch.tensor(y_train)
# 直接用TensorDataset封装即可
train_dataset = TensorDataset(x_train, y_train)
1.4.2 IterableDataset
根据两个数start和end生成数据集;
# 继承IterableDataset
class MyIterableDataset(torch.utils.data.IterableDataset):
def __init__(self, start, end):
super(MyIterableDataset).__init__()
assert end > start, "this example code only works with end >= start"
self.start = start
self.end = end
def __iter__(self):
worker_info = torch.utils.data.get_worker_info()
if worker_info is None: # single-process data loading, return the full iterator
iter_start = self.start
iter_end = self.end
else: # in a worker process
# split workload
per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
worker_id = worker_info.id
iter_start = self.start + worker_id * per_worker
iter_end = min(iter_start + per_worker, self.end)
# !核心就是根据range生成的数
return iter(range(iter_start, iter_end))
# 实例化
# 结果:[3, 4, 5, 6].
ds = MyIterableDataset(start=3, end=7)
# 用DataLoader 单线程进行加载
# [tensor([3]), tensor([4]), tensor([5]), tensor([6])]
print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
# 用DataLoader 多线程进行加载
print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
1.4.3 ConcatDataset
将多个数据集拼接成一个;
用法如下:
# 第一个数据集 len 60000
mnist_data = MNIST('./data', train=True, download=True)
# 第二个数据集 len 50000
cifar10_data = CIFAR100('./data', train=True, download=True)
# 两个数据集拼接 len 110000
concat_data = ConcatDataset([mnist_data, cifar10_data])
1.4.4 ChainDataset
将IterableDataset类的多个数据集拼接成一个数据集;
1.4.5 Subset
将一个数据集划分为子数据集,比如划分训练集和验证集;
# 训练集和验证集的索引
train_indices, val_indices = indices[split:], indices[:split]
# 根据索引随机划分
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)
2 Sampler
就是遍历数据集的方式,默认方式有两种:
- shuffle = True:sampler = RandomSampler(dataset, generator=generator),随机打乱;
- shuffle = False:sampler = SequentialSampler(dataset),不打乱;
- 也可以自定义Sampler传入,但是Sampler与shuffle互斥;
2.1 RandomSampler
class RandomSampler(Sampler[int]):
data_source: Sized
replacement: bool
def __init__(self, data_source: Sized, replacement: bool = False,
num_samples: Optional[int] = None, generator=None) -> None:
# slef. = ...
def __iter__(self) -> Iterator[int]:
n = len(self.data_source)
if self.replacement:
for _ in range(self.num_samples // 32):
yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
else:
for _ in range(self.num_samples // n):
# 核心就是torch.randperm函数
# 生成0~n-1的随机数列(索引)
yield from torch.randperm(n, generator=generator).tolist()
yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]
def __len__(self) -> int:
return self.num_samples
2.2 SequentialSampler
SequentialSampler其实什么也没做,不破坏数据集原有的顺序;
class SequentialSampler(Sampler[int]):
def __init__(self, data_source: Sized) -> None:
self.data_source = data_source
def __iter__(self) -> Iterator[int]:
return iter(range(len(self.data_source)))
def __len__(self) -> int:
return len(self.data_source)
2.3 自定义Sampler
import random
from torch.utils.data.sampler import Sampler
# 自定义必须先继承Sample类
# 必须实现__init__,__iter__,__len__方法
class MySampler(Sampler):
def __init__(self, dataset):
# 将数据集均分为两部分
halfway_point = int(len(dataset)/2)
self.first_half_indices = list(range(halfway_point))
self.second_half_indices = list(range(halfway_point, len(dataset)))
def __iter__(self):
# 每次从前一半和后一半各返回一个
# 假设前一半为 1 2 3 4 5
# 后一半为 6 7 8 9 10
# 则依次返回(1,6)(2,7)(3,8)...
random.shuffle(self.first_half_indices)
random.shuffle(self.second_half_indices)
return iter(self.first_half_indices + self.second_half_indices)
def __len__(self):
return len(self.first_half_indices) + len(self.second_half_indices)
3 DataLoader
3.1 使用DataLoader
from torch.utils.data import DataLoader
training_data = myDataset(...)
test_data = myDataset(...)
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
# 一般测试集不打乱 没有意义
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=False)
train_features, train_labels = next(iter(train_dataloader))
# Feature batch shape: torch.Size([64, 1, 28, 28])
print(f"Feature batch shape: {train_features.size()}")
# Labels batch shape: torch.Size([64])
print(f"Labels batch shape: {train_labels.size()}")
# ... 一些处理
3.2 源码及参数
# DataLoader源码位置
# /torch/utils/data/dataloader.py
# 参数们
# dataset: Dataset实例对象
# batch_size:批量大小 默认为1
# shuffle:每周期后是否对数据进行打乱
# sampler:遍历数据集的方式 有默认值 和shuffle互斥
# batch_sampler:同上 和shuffle sampler drop_last batch_size互斥
# num_workers:默认为0 加载数据(batch)的进程数目
# num_workers的经验设置值是自己电脑/服务器的CPU核心数
# 0意味着所有的数据都会被load进主进程
# collate_fn: 对batch数据再处理
# pin_memory: 锁页内存 数据放到GPU上
# drop_last: 非整数batch时 最后一个batch丢掉
# timeout: 如果是正数,表明等待从worker进程中收集一个batch等待的时间
# 若超出设定的时间还没有收集到,那就不收集这个内容了
class DataLoader(Generic[T_co]):
dataset: Dataset[T_co]
batch_size: Optional[int]
num_workers: int
pin_memory: bool
drop_last: bool
timeout: float
sampler: Union[Sampler, Iterable]
prefetch_factor: int
_iterator : Optional['_BaseDataLoaderIter']
__initialized = False
def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
shuffle: bool = False, sampler: Union[Sampler, Iterable, None] = None,
batch_sampler: Union[Sampler[Sequence], Iterable[Sequence], None] = None,
num_workers: int = 0, collate_fn: Optional[_collate_fn_t] = None,
pin_memory: bool = False, drop_last: bool = False,
timeout: float = 0, worker_init_fn: Optional[_worker_init_fn_t] = None,
multiprocessing_context=None, generator=None,
*, prefetch_factor: int = 2,
persistent_workers: bool = False):
# 一堆成员变量设置
# self. = ...
# sampler的设置
if sampler is None:
if self._dataset_kind == _DatasetKind.Iterable:
sampler = _InfiniteConstantSampler()
else: # map-style
if shuffle:
# 原理:通过torch.randperm实现 打乱
sampler = RandomSampler(dataset, generator=generator)
else:
# 原理:iter(range()) 有序
sampler = SequentialSampler(dataset)
# 在__iter__调用
# 复写基类方法 实现iter函数
# 可以调用为iter(train_dataloader)
def _get_iterator(self) -> '_BaseDataLoaderIter':
if self.num_workers == 0:
# 获取下一个索引 根据索引获得并返回数据
return _SingleProcessDataLoaderIter(self)
else:
self.check_worker_number_rationality()
# 多进程进行处理
return _MultiProcessingDataLoaderIter(self)
# 变成迭代器
def __iter__(self) -> '_BaseDataLoaderIter':
if self.persistent_workers and self.num_workers > 0:
if self._iterator is None:
self._iterator = self._get_iterator()
else:
self._iterator._reset(self)
return self._iterator
else:
return self._get_iterator()
# 在_BaseDataLoaderIter类调用
# 其实就是复写基类方法,实现next函数
# next(iter(train_dataloader))
@property
def _index_sampler(self):
if self._auto_collation:
return self.batch_sampler
else:
return self.sampler
# 返回有多少batch
def __len__(self) -> int:
# ...
# 对num_workers设定合理性进行检查
def check_worker_number_rationality(self):
# ...
最后
以上就是健壮康乃馨为你收集整理的Pytorch数据加载模块:Dataset,Sampler和DataLoader总结1 Dataset2 Sampler3 DataLoader的全部内容,希望文章能够帮你解决Pytorch数据加载模块:Dataset,Sampler和DataLoader总结1 Dataset2 Sampler3 DataLoader所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复