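I am 健康自行车, a blogger on 靠谱客. I collected this article on using DataLoader during recent development work, found it quite good, and am sharing it here in the hope that it serves as a useful reference.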

Overview

Signature from the official documentation

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)
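Conceptually, shuffle, batch_size and drop_last are shortcuts: DataLoader turns them into a sampler (which yields indices) and a batch_sampler (which groups indices into batches), then calls __getitem__ for each index and merges the samples with collate_fn. A minimal sketch, assuming a plain Python list as the map-style dataset (a list already has __getitem__ and __len__), checks that building the batch_sampler by hand gives the same batches:

```python
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler

data = list("abcdefg")  # a plain list works as a map-style dataset

# Form 1: let DataLoader build the samplers from batch_size/shuffle/drop_last
loader_a = DataLoader(data, batch_size=2, shuffle=False, drop_last=False)

# Form 2: build the same index batches by hand and pass them as batch_sampler
index_batches = BatchSampler(SequentialSampler(data), batch_size=2, drop_last=False)
loader_b = DataLoader(data, batch_sampler=index_batches)

print(list(index_batches))               # [[0, 1], [2, 3], [4, 5], [6]]
print(list(loader_a))                    # [['a', 'b'], ['c', 'd'], ['e', 'f'], ['g']]
print(list(loader_a) == list(loader_b))  # True
```

Note that batch_sampler is mutually exclusive with batch_size, shuffle, sampler and drop_last, which is why form 2 passes none of them.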

The most important argument is dataset. PyTorch supports two types of dataset:

  • map-style datasets
  • iterable-style datasets

A map-style dataset must implement the two methods __getitem__() and __len__().
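All examples below use map-style datasets. For comparison, an iterable-style dataset subclasses torch.utils.data.IterableDataset and implements __iter__() instead of __getitem__()/__len__(); DataLoader then reads samples in stream order, so shuffle=True is not supported for it. A minimal sketch with the same 'abcdefg' data; the class name ExampleIterableDataset is hypothetical:

```python
from torch.utils.data import DataLoader, IterableDataset

class ExampleIterableDataset(IterableDataset):
    def __init__(self):
        self.data = "abcdefg"

    def __iter__(self):  # yield samples one at a time, in stream order
        for ch in self.data:
            yield ch

loader = DataLoader(ExampleIterableDataset(), batch_size=3)
batches = list(loader)
print(batches)  # [['a', 'b', 'c'], ['d', 'e', 'f'], ['g']]
```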

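Below we use the raw data 'abcdefg' as an example. Pay attention to how the two methods are written and how the arguments passed to torch.utils.data.DataLoader() change from one example to the next.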

Iterating over 'abcdefg' in order (shuffle=False)

import torch
import torch.utils.data 
class ExampleDataset(torch.utils.data.Dataset):
  def __init__(self):
    self.data = "abcdefg"
  
  def __getitem__(self,idx): # if the index is idx, what will be the data?
    return self.data[idx]
  
  def __len__(self): # What is the length of the dataset
    return len(self.data)

dataset1 = ExampleDataset() # create the dataset
dataloader = torch.utils.data.DataLoader(dataset=dataset1, shuffle=False, batch_size=1)  # keep the original order
for datapoint in dataloader:
  print(datapoint)
  
  
------output----------
['a']
['b']
['c']
['d']
['e']
['f']
['g']
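Each item prints as a one-element list such as ['a'] because, with automatic batching, the default collate_fn gathers the samples of a batch into a list (strings are not converted to tensors). If you want the raw samples instead, passing batch_size=None disables automatic batching. A minimal sketch, reusing the ExampleDataset from above (redefined here so the snippet runs on its own):

```python
import torch.utils.data

class ExampleDataset(torch.utils.data.Dataset):
    def __init__(self):
        self.data = "abcdefg"

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)

# batch_size=None turns off automatic batching: each __getitem__ result is yielded as-is
loader = torch.utils.data.DataLoader(ExampleDataset(), batch_size=None, shuffle=False)
items = list(loader)
print(items)  # ['a', 'b', 'c', 'd', 'e', 'f', 'g']
```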

shuffle=True: the data is shuffled, so the samples come out in random order

import torch
import torch.utils.data 
class ExampleDataset(torch.utils.data.Dataset):
  def __init__(self):
    self.data = "abcdefg"
  
  def __getitem__(self,idx): # if the index is idx, what will be the data?
    return self.data[idx]
  
  def __len__(self): # What is the length of the dataset
    return len(self.data)

dataset1 = ExampleDataset() # create the dataset
dataloader = torch.utils.data.DataLoader(dataset = dataset1,shuffle = True,batch_size = 1)
for datapoint in dataloader:
  print(datapoint)
  
------output----------------
['f']
['a']
['d']
['e']
['c']
['g']
['b']
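With shuffle=True the order differs from run to run (and from one pass over the loader to the next). If you need a reproducible order, recent PyTorch versions also accept a generator argument on DataLoader (it is not listed in the signature above); seeding it fixes the permutation. A minimal sketch; the helper shuffled_order and the seed 0 are arbitrary choices for illustration:

```python
import torch
from torch.utils.data import DataLoader

data = list("abcdefg")  # a plain list works as a map-style dataset

def shuffled_order(seed):
    # A fresh, seeded generator drives the RandomSampler behind shuffle=True
    g = torch.Generator().manual_seed(seed)
    loader = DataLoader(data, batch_size=1, shuffle=True, generator=g)
    return [batch[0] for batch in loader]

run1 = shuffled_order(0)
run2 = shuffled_order(0)
print(run1 == run2)                  # True: same seed, same order
print(sorted(run1) == sorted(data))  # True: still a permutation of the data
```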

Changing batch_size

import torch
import torch.utils.data 
class ExampleDataset(torch.utils.data.Dataset):
  def __init__(self):
    self.data = "abcdefg"
  
  def __getitem__(self,idx): # if the index is idx, what will be the data?
    return self.data[idx]
  
  def __len__(self): # What is the length of the dataset
    return len(self.data)

dataset1 = ExampleDataset() # create the dataset
dataloader = torch.utils.data.DataLoader(dataset = dataset1,shuffle = True,batch_size = 2)
for datapoint in dataloader:
  print(datapoint)


-----------output-------------
['d', 'c']
['f', 'b']
['e', 'a']
['g']
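Because 7 is not divisible by 2, the last batch contains only one sample. If an incomplete final batch is a problem (for example, when the model expects a fixed batch size), drop_last=True discards it. A minimal sketch comparing the two settings, with shuffle=False so the batches are predictable:

```python
from torch.utils.data import DataLoader

data = list("abcdefg")  # 7 samples; a plain list works as a map-style dataset

loader_keep = DataLoader(data, batch_size=2, shuffle=False, drop_last=False)
loader_drop = DataLoader(data, batch_size=2, shuffle=False, drop_last=True)

print(len(loader_keep), list(loader_keep))  # 4 [['a', 'b'], ['c', 'd'], ['e', 'f'], ['g']]
print(len(loader_drop), list(loader_drop))  # 3 [['a', 'b'], ['c', 'd'], ['e', 'f']]
```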

Rewriting __getitem__() and __len__() to get the result you want

import torch
import torch.utils.data 
class ExampleDataset(torch.utils.data.Dataset):
  def __init__(self):
    self.data = "abcdefg"
  
  def __getitem__(self,idx): # if the index is idx, what will be the data?
    return self.data[idx], self.data[idx].upper()
  
  def __len__(self): # What is the length of the dataset
    return len(self.data)

dataset1 = ExampleDataset() # create the dataset
dataloader = torch.utils.data.DataLoader(dataset = dataset1,shuffle = False,batch_size = 2)
for datapoint in dataloader:
  print(datapoint)


-----------output-----------
[('a', 'b'), ('A', 'B')]
[('c', 'd'), ('C', 'D')]
[('e', 'f'), ('E', 'F')]
[('g',), ('G',)]
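Note that the default collate_fn transposes the batch: instead of [('a', 'A'), ('b', 'B')] you get one tuple of lowercase letters and one of uppercase letters. If you would rather keep each sample's pair together, you can pass your own collate_fn. A minimal sketch; keep_pairs is a hypothetical helper name:

```python
import torch.utils.data

class ExampleDataset(torch.utils.data.Dataset):
    def __init__(self):
        self.data = "abcdefg"

    def __getitem__(self, idx):  # each sample is a (lowercase, uppercase) pair
        return self.data[idx], self.data[idx].upper()

    def __len__(self):
        return len(self.data)

def keep_pairs(batch):
    # batch is the list of samples returned by __getitem__; return it unchanged
    return batch

loader = torch.utils.data.DataLoader(ExampleDataset(), batch_size=2, shuffle=False,
                                     collate_fn=keep_pairs)
batches = list(loader)
print(batches[0])   # [('a', 'A'), ('b', 'B')]
print(batches[-1])  # [('g', 'G')]
```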
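
Another variation: __len__() reports twice the original length, and __getitem__() maps the extra indices back onto the original data as uppercase letters.

import torch
import torch.utils.data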
class ExampleDataset(torch.utils.data.Dataset):
  def __init__(self):
    self.data = "abcdefg"
  
  def __getitem__(self,idx): # if the index is idx, what will be the data?
    if idx >= len(self.data): # second half of the indices: return the uppercase letter
      return self.data[idx % len(self.data)].upper()
    else: # first half of the indices: return the lowercase letter
      return self.data[idx]
  
  def __len__(self): # What is the length of the dataset
    return 2 * len(self.data) # The length is now twice as large

dataset1 = ExampleDataset() # create the dataset
dataloader = torch.utils.data.DataLoader(dataset = dataset1,shuffle = False,batch_size = 2)
for datapoint in dataloader:
  print(datapoint)

-----------output------------
['a', 'b']
['c', 'd']
['e', 'f']
['g', 'A']
['B', 'C']
['D', 'E']
['F', 'G']
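The same 14-sample sequence can also be produced without index arithmetic in __getitem__(), by concatenating two datasets with torch.utils.data.ConcatDataset. This is an alternative technique, not the one used above; the LetterDataset class and its upper flag are hypothetical:

```python
from torch.utils.data import ConcatDataset, DataLoader, Dataset

class LetterDataset(Dataset):
    def __init__(self, upper=False):
        self.data = "ABCDEFG" if upper else "abcdefg"

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)

# ConcatDataset maps indices 0-6 to the first dataset and 7-13 to the second
combined = ConcatDataset([LetterDataset(), LetterDataset(upper=True)])
loader = DataLoader(combined, batch_size=2, shuffle=False)
batches = list(loader)
print(len(combined))  # 14
print(batches[3])     # ['g', 'A']
```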

A Dataset that reads local images and applies a transform

import os

import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

DATA_DIR = 'data/CIFAR-10'
DATABASE_FILE = 'database_img.txt'     # one image filename per line
DATABASE_LABEL = 'database_label.txt'  # one integer label per line, same order

# Different network architectures may impose format requirements on the input images; handle them here.
# Besides meeting the network's input requirements, transforms can also be used for data augmentation.
transformations = transforms.Compose([
    transforms.Resize(256),  # the older transforms.Scale is deprecated in favor of Resize
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

class ExampleDataset(Dataset):
    def __init__(self, data_path, img_filename, label_filename, transform=None):
        self.img_path = data_path
        self.transform = transform
        # read the list of image filenames
        img_filepath = os.path.join(data_path, img_filename)
        with open(img_filepath, 'r') as fp:
            self.img_filename = [x.strip() for x in fp]
        # read the labels (line i belongs to image i)
        label_filepath = os.path.join(data_path, label_filename)
        with open(label_filepath, 'r') as fp_label:
            self.label = [int(x.strip()) for x in fp_label]

    def __getitem__(self, index):
        img = Image.open(os.path.join(self.img_path, self.img_filename[index]))
        img = img.convert('RGB')  # make sure every image has 3 channels
        if self.transform is not None:
            img = self.transform(img)
        label = torch.LongTensor([self.label[index]])
        return img, label, index

    def __len__(self):
        return len(self.img_filename)

# the class must be defined before it is instantiated
dset_database = ExampleDataset(
    DATA_DIR, DATABASE_FILE, DATABASE_LABEL, transformations)
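When this dataset is wrapped in a DataLoader, the default collate_fn batches each field of the (img, label, index) tuple separately. Because the label is stored as torch.LongTensor([label]) with shape (1,), the batched labels get an extra dimension. A minimal sketch, assuming a hypothetical FakeImageDataset that returns random 3x224x224 tensors in place of real CIFAR-10 files:

```python
import torch
from torch.utils.data import DataLoader, Dataset

class FakeImageDataset(Dataset):
    """Hypothetical stand-in: same (img, label, index) output as ExampleDataset, no files needed."""

    def __init__(self, n=10):
        self.n = n

    def __getitem__(self, index):
        img = torch.rand(3, 224, 224)           # shape the transform pipeline would produce
        label = torch.LongTensor([index % 10])  # shape (1,), as in ExampleDataset
        return img, label, index

    def __len__(self):
        return self.n

loader = DataLoader(FakeImageDataset(), batch_size=4, shuffle=True)
imgs, labels, indices = next(iter(loader))
print(imgs.shape)     # torch.Size([4, 3, 224, 224])
print(labels.shape)   # torch.Size([4, 1])
print(indices.shape)  # torch.Size([4])

# nn.CrossEntropyLoss expects class-index targets of shape (N,), so drop the extra dimension
targets = labels.squeeze(1)
print(targets.shape)  # torch.Size([4])
```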

Further reading: https://www.daimajiaoliu.com/daima/4ede05ecd1003fc


This content was provided by users or collected from the web for learning and reference; copyright belongs to the original author.