我是靠谱客的博主 孝顺泥猴桃,最近开发中收集的这篇文章主要介绍pytorch数据导入以及预处理,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

Pythorch 数据初始化

pytorch在数据从原始数据集里面获取以后(一般处理成numpy数组),需要以下步骤:

1.构造DataClass

torch.utils.data.Dataset是一个表示数据的抽象类,在构造自己的数据集时候,首先应该构造一个他的子类,并且该子类需要重写下面两个方法:

__len__,这个方法在使用len(Dataset)时候会被调动,用于返回数据集Dataset数据集条目数
__getitem__,该方法用于支持Dataset[i]这样的索引操作一般来讲,在__init__函数中载入比较小的数据,像比较大的数据,如图像等可以在__getitem__里面再载入,这样会节省内存开销。

如下栗子:
class Mydataset(Dataset):
    """ 自己的数据集"""

    def __init__(self, labelFiles, root_dir, transform=None):
        """
        Args:
            labels (string): 此处是labels的文件路径.
            root_dir (string): 此处是图像文件路径.
            transform (callable, optional): 应用与图像之上的各种变换,诸如数据增强等。.
        """
        self.labels = read(labelFiles)  #这里加载了标签文件
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        #TODO:获取一个指定的样本,组合成字典
        sample={"image":image,"labels":labels}
        
        #如果有变换则加变换
        if self.transform:
            sample = self.transform(sample)

        return sample
2.数据变换以及预处理:transforms

通常的预处理有:1.改变图像的尺寸,2.数据增强(随机切片,镜像…),3.数据格式的变换(维度变化)这里定义的transforms对象是一个可调用的类,当该类的对象被调用的时候,方法:__call__将被调用,所以这个类一般都要实现这个方法。
下面给出一个栗子:

class Rescale(object):
    """Rescale the image in a sample to a given size.

    Args:
        output_size (tuple or int): Desired output size. If tuple, output is
            matched to output_size. If int, smaller of image edges is matched
            to output_size keeping aspect ratio the same.
    """

    def __init__(self, output_size):    #这个在初始化时候调用,一般做传参用
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):         #当对象被当做函数使用时候调用,输入一个样本对,返回处理后的样本对(字典)
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)

        img = transform.resize(image, (new_h, new_w))

        # h and w are swapped for landmarks because for images,
        # x and y axes are axis 1 and 0 respectively
        landmarks = landmarks * [new_w / w, new_h / h]

        return {'image': img, 'label': landmarks}

3.多种transforms组合使用

可以使用torchversion.transforms.Compose类来完成组合,这个类是一个可调用的类,使用方法同样是实例化以后当做函数一样使用,具体方法如下:

scale = Rescale()       #这里有两个变换
crop = RandomCrop()
composed = transforms.Compose([scale,crop])
dataset=Mydataset("labels路径","图片路径",transform=composed)
#这里的dataset是已经处理过的数据。可以通过索引来访问。

Python数据载入

训练模型时候,需要的数据通常都是以miniBatch方式载入,而且Pytorch也同样只能以batch方式载入数据,为此,需要我们把以上处理好的数据做以下处理:
1.组合为Batch
2.打乱数据
3.使用并行处理工具:multprocessing载入数据
torch.utils.data.DataLoader函数可以完成以上工作,其完整的参数如下:

class torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    sampler=None,
    batch_sampler=None,
    num_workers=0,
    collate_fn=<function default_collate>,
    pin_memory=False,
    drop_last=False,
    timeout=0,
    worker_init_fn=None)

该函数返回一个迭代器,举个使用的栗子:


dataloader=DataLoader(Mydataset,batch_size=1,shuffle=True,num_workers=2)
for i_batch,sample_batched in enumrate(dataloader):
    images=sample_batched['image']
    labels=sample_batched['label']
    #TODO:训练模型
References:Pytorch Documentations

最后

以上就是孝顺泥猴桃为你收集整理的pytorch数据导入以及预处理的全部内容,希望文章能够帮你解决pytorch数据导入以及预处理所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(43)

评论列表共有 0 条评论

立即
投稿
返回
顶部