我是靠谱客的博主 花痴鸡翅,最近开发中收集的这篇文章主要介绍DataLoader详解,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

   在深度学习加载模型的时候,会对数据进行处理,今天主要介绍pytorch中Dateset和DataLoader的使用方法。

目录

一、基础概念

二、Dataset使用方法

1.torch.utils.data里面的dataset使用方法

2.torchvision.datasets的使用方法

三、DateLoader详解


一、基础概念

  1.torch.utils.data.datasets-抽象类可以创建数据集,但是抽象类不能实例化,所以需要构建这个抽象类的子类来创建数据集,并且我们还可以定义自己的继承和重写方法。其中最重要的是len和getitem这两个函数,len能够给出数据集的大小,getitem用于查找数据和标签。

 2.torch.utils.data.DataLoader是一个迭代器,主要是用于多线程的读取数据,并且可以实现batch和shuffle的读取。

二、Dataset使用方法

1.torch.utils.data里面的dataset使用方法

当我们继承了一个Dataset类之后,我们需要重写里面的len方法,该方法提供了dataset的大小,getitem(),该方法支持从0-len(self)的索引。

 from torch.utils.data import Dataset, DataLoader
 import torch
        
        class MyDataset(Dataset):
            def __init__(self):
                self.x = torch.linspace(11, 20, 10)
                self.y = torch.linspace(1, 10, 10)
                self.len = len(self.x)

            def __getitem__(self, index):
                return self.x[index], self.y[index]

            def __len__(self):
                return self.len
        mydataset = MyDataset()  
        train_loader2 = DataLoader(dataset=mydataset,batch_size=5,shuffle=False)
                    
                           

2.torchvision.datasets的使用方法

torchvisiondatasets中所有封装的数据集都是torch.utils.data.Dataset的子类,它们都实现了__getitem__和__len__方法。因此,它们都可以用torch.utils.data.DataLoader进行数据加载。

import torchvision
import torch

# 导入FashionMNIST数据集
mnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=True, download=True, transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=False, download=True, transform=transforms.ToTensor())

train_data = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_data = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)

三、DateLoader详解

torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,

各个参数的介绍: 

1.dataset(Dataset): 传入的数据集
2.batch_size(int, optional): 每个batch有多少个样本
3.shuffle(bool, optional): 在每个epoch开始的时候,对数据进行重新排序
4.sampler(Sampler, optional): 自定义从数据集中取样本的策略,如果指定这个参数,那么       shuffle必须为False

5.batch_sampler(Sampler, optional): 与sampler类似,但是一次只返回一个batch的indices(索引),需要注意的是,一旦指定了这个参数,那么batch_size,shuffle,sampler,drop_last    就不能再制定了(互斥——Mutually exclusive)

6.num_workers (int, optional): 这个参数决定了有几个进程来处理data loading。0意味着所      有的数据都会被load进主进程。(默认为0)
7.collate_fn (callable, optional): 将一个list的sample组成一个mini-batch的函数

8.pin_memory (bool, optional): 如果设置为True,那么data loader将会在返回它们之前,       将tensors拷贝到CUDA中的固定内存(CUDA pinned memory)中

9.drop_last (bool, optional): 如果设置为True:这个是对最后的未完成的batch来说的,比如     你的batch_size设置为64,而一个epoch只有100个样本,那么训练的时候后面的36个就被     扔掉了…如果为False(默认),那么会继续正常执行,只是最后的batch_size会小一点。10.timeout(numeric, optional): 如果是正数,表明等待从worker进程中收集一个batch等待的     时间,若超出设定的时间还没有收集到,那就不收集这个内容了。这个numeric应总是           大于等于0。默认为0
11.worker_init_fn (callable, optional): 每个worker初始化函数 If not None, this will be called       on each

# 处理数据集,把数据转换成张量,使数据可以输入下面我们搭建的网络
def load_data_fashion_mnist(mnist_train, mnist_test, batch_size):
    if sys.platform.startswith('win'):
        num_workers = 0
    else:
        num_workers = 4
    train_data = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_data = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return train_data, test_data


 

最后

以上就是花痴鸡翅为你收集整理的DataLoader详解的全部内容,希望文章能够帮你解决DataLoader详解所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(38)

评论列表共有 0 条评论

立即
投稿
返回
顶部