我是靠谱客的博主 还单身冬天,最近开发中收集的这篇文章主要介绍【PyTorch】数据加载官方数据集自定义数据加载numpy类型数据加载,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

  • pytorch使用torch.utils.data对常用的数据加载进行封装,可以实现多线程预读取批量加载
  • 主要包括两个方面:1)把数据包装成Dataset类;2)用DataLoader加载。
  • TensorDataset可以直接接受Tensor类型的输入,并用DataLoader进行加载;省去自定义的过程。

官方数据集

  • torchvision中实现了一些常用的数据集,可以通过torchvision.datasets直接调用。如:MNIST,COCO,Captions,Detection,LSUN,ImageFolder,Imagenet-12,CIFAR,STL10,SVHN,PhotoTour。
  • torchvision.transforms提供了许多图像操作,可以很方便的进行数据增强。

一个典型的CIFAR10数据加载过程如下:

import torchvision.transforms as transforms
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# 可以加入更多数据增强处理
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)

自定义数据加载

如果不使用官方数据集,想要加载自己的数据集,需要自定义一个dataset类。它需要继承torch.utils.data.Dataset类,并实现__getitem__()__len__()两个成员方法。

下面是一个自定义的视频数据集的例子:

from torch.utils.data import Dataset
class FrameDataset(Dataset):
# 初始化时候要有一个数据加载,可以是数据路径的列表,或者直接把数据全加载进来。
# 后者实际上没有用到批量加载的功能,需要注意内存占用。
def __init__(self, data_dir, transform):
with open(data_dir, 'r') as fr:
reader = csv.reader(fr)
self.video_files = [video for video, label in reader]
self.transform = transform
print("dataset size: ", len(self.video_files))
def __getitem__(self, index):
video_file = self.video_files[index]
# 读取视频的帧并返回
return imgs, num_img
def __len__(self):
return len(self.video_files)

注意:

  • 如果数据类型是图像、视频,我们可以把原始数据保存在一个文件夹中,再用一个列表保存图像或视频的路径。这样数据集初始化时候加载的其实只是路径列表,在训练和测试时才会分批把原始数据读入。
  • 如果原始数据是直接以数字形式存储在一个文件中,无法通过索引单个读取,可以在初始化时候把整个矩阵读入,然后每次getitem时返回其中一行。

自定义的一大优点是处理更灵活,例如对于视频或文本数据,getitem函数中返回的帧序列或句子序列往往是长度不固定的,默认情况下DataLoaderstack时会出错,这时可以用collate_fn指定batch数据的连接方式:

def collate_fn(batch):
imgs, num_img = zip(*batch)
return torch.cat(imgs), num_img

然后就可以正常加载数据了:

dataset = FrameDataset(csv_file, transform=tfms)
videoloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

numpy类型数据加载

如果数据需要一次性全部读入,而且不需要额外的复杂处理的话可以不用自定义数据集Dataset类。

比如通常情况下,我们的输入可以很容易处理成一个numpy类型。这时可以不用定义Dataset类,直接使用TensorDataset,只要把读入的数据转化成一个tensor传入即可。

random_split是一个可以自动划分数据集的函数,实现随机不重复划分的功能。

from torch.utils.data import TensorDataset,DataLoader,random_split
dataset = TensorDataset(torch.from_numpy(data))
n_train = int(len(dataset) * 0.9)
n_test = len(dataset) - n_train
trainset, testset = random_split(dataset, [n_train, n_test])
train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False)

最后

以上就是还单身冬天为你收集整理的【PyTorch】数据加载官方数据集自定义数据加载numpy类型数据加载的全部内容,希望文章能够帮你解决【PyTorch】数据加载官方数据集自定义数据加载numpy类型数据加载所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(52)

评论列表共有 0 条评论

立即
投稿
返回
顶部