我是靠谱客的博主 内向电话,最近开发中收集的这篇文章主要介绍Pytorch 读取自定义数据集,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

本文将涉及以下几个方面:

  • 自定义数据集基础方法
  • 使用 Torchvision Transforms
  • 换一种方法使用 Torchvision Transforms
  • 结合 Pandas 读取 csv 文件
  • 结合 Pandas 使用__getitem__()
  • 使用 Dataloader 读取自定义数据集

自定义数据集基础方法

首先要创建一个 Dataset 类:

from torch.utils.data.dataset import Dataset
 
class MyCustomDataset(Dataset):
    def __init__(self, ...):
        # stuff
        
    def __getitem__(self, index):
        # stuff
        return (img, label)
 
    def __len__(self):
        return count

这个代码中:

  • __init__() 一些初始化过程写在这里
  • __len__() 返回所有数据的数量
  • __getitem__() 返回数据和标签,可以这样显示调用:
img, label = MyCustomDataset.__getitem__(99)

使用 Torchvision Transforms

Transform 最常见的使用方法是:

from torch.utils.data.dataset import Dataset
from torchvision import transforms
 
class MyCustomDataset(Dataset):
    def __init__(self, ..., transforms=None):
        # stuff
        ...
        self.transforms = transforms
        
    def __getitem__(self, index):
        # stuff
        ...
        data = # 一些读取的数据
        if self.transforms is not None:
            data = self.transforms(data)
        # 如果 transform 不为 None,则进行 transform 操作
        return (img, label)
 
    def __len__(self):
        return count 
        
if __name__ == '__main__':
    # 定义我们的 transforms (1)
    transformations = transforms.Compose([transforms.CenterCrop(100), transforms.ToTensor()])
    # 创建 dataset
    custom_dataset = MyCustomDataset(..., transformations)

换一种方法使用 Torchvision Transforms

有些人不喜欢把 transform 操作写在 Dataset 外面(上面代码里的注释 1),所以还有一种写法:

from torch.utils.data.dataset import Dataset
from torchvision import transforms
 
class MyCustomDataset(Dataset):
    def __init__(self, ...):
        # stuff
        ...
        # (2) 一种方法是单独定义 transform
        self.center_crop = transforms.CenterCrop(100)
        self.to_tensor = transforms.ToTensor()
        
        # (3) 或者写成下面这样 
        self.transformations = 
            transforms.Compose([transforms.CenterCrop(100),
                                transforms.ToTensor()])
        
    def __getitem__(self, index):
        # stuff
        ...
        data = #一些读取的数据
        
        # 当第二次调用 transform 时,调用的是 __call__()
        data = self.center_crop(data)  # (2)
        data = self.to_tensor(data)  # (2)
        
        # 或者写成下面这样
        data = self.trasnformations(data)  # (3)
        
        # 注意 (2) 和 (3) 中只需要实现一种
        return (img, label)
 
    def __len__(self):
        return count
        
if __name__ == '__main__':
    custom_dataset = MyCustomDataset(...)

结合 Pandas 读取 csv 文件

假如说我们想从一个 csv 文件中用 Pandas 读取数据。一个 csv 示例如下:

File Name			Label		Extra Operation
tr_0.png			  5				TRUE
tr_1.png			  0				FALSE
tr_1.png			  4				FALSE

如果我们需要在自定义数据集里从这个 csv 文件读取文件名,可以这样做:

class CustomDatasetFromImages(Dataset):
    def __init__(self, csv_path):
        """
        Args:
            csv_path (string): csv 文件路径
            img_path (string): 图像文件所在路径
            transform: transform 操作
        """
        # Transforms
        self.to_tensor = transforms.ToTensor()
        # 读取 csv 文件
        self.data_info = pd.read_csv(csv_path, header=None)
        # 文件第一列包含图像文件的名称
        self.image_arr = np.asarray(self.data_info.iloc[:, 0])
        # 第二列是图像的 label
        self.label_arr = np.asarray(self.data_info.iloc[:, 1])
        # 第三列是决定是否进行额外操作
        self.operation_arr = np.asarray(self.data_info.iloc[:, 2])
        # 计算 length
        self.data_len = len(self.data_info.index)
 
    def __getitem__(self, index):
        # 从 pandas df 中得到文件名
        single_image_name = self.image_arr[index]
        # 读取图像文件
        img_as_img = Image.open(single_image_name)
 
        # 检查需不需要额外操作
        some_operation = self.operation_arr[index]
        # 如果需要额外操作
        if some_operation:
            # ...
            # ...
            pass
        # 把图像转换成 tensor
        img_as_tensor = self.to_tensor(img_as_img)
 
        # 得到图像的 label
        single_image_label = self.label_arr[index]
 
        return (img_as_tensor, single_image_label)
 
    def __len__(self):
        return self.data_len
 
if __name__ == "__main__":
    custom_mnist_from_images =  
        CustomDatasetFromImages('../data/mnist_labels.csv')

结合 Pandas 使用 __getitem__()

另一种情况是 csv 文件中保存了我们需要的图像文件的像素值(比如有些 MNIST 教程就是这样的)。我们需要改动一下 __getitem__() 函数。

Label	pixel_1		pixel_2		…
 1		 50			  990		 21			  2239		 44		 	  112

代码如下:

class CustomDatasetFromCSV(Dataset):
    def __init__(self, csv_path, height, width, transforms=None):
        """
        Args:
            csv_path (string): csv 文件路径
            height (int): 图像高度
            width (int): 图像宽度
            transform: transform 操作
        """
        self.data = pd.read_csv(csv_path)
        self.labels = np.asarray(self.data.iloc[:, 0])
        self.height = height
        self.width = width
        self.transforms = transform
 
    def __getitem__(self, index):
        single_image_label = self.labels[index]
        # 读取所有像素值,并将 1D array ([784]) reshape 成为 2D array ([28,28]) 
        img_as_np = np.asarray(self.data.iloc[index][1:]).reshape(28,28).astype('uint8')
 # 把 numpy array 格式的图像转换成灰度 PIL image
        img_as_img = Image.fromarray(img_as_np)
        img_as_img = img_as_img.convert('L')
        # 将图像转换成 tensor
        if self.transforms is not None:
            img_as_tensor = self.transforms(img_as_img)
        # 返回图像及其 label
        return (img_as_tensor, single_image_label)
 
    def __len__(self):
        return len(self.data.index)
        
 
if __name__ == "__main__":
    transformations = transforms.Compose([transforms.ToTensor()])
    custom_mnist_from_csv = 
        CustomDatasetFromCSV('../data/mnist_in_csv.csv', 28, 28, transformations)

使用 Dataloader 读取自定义数据集

PyTorch 中的 Dataloader 只是调用 __getitem__() 方法并组合成 batch,我们可以这样调用:

if __name__ == "__main__":
    # 定义 transforms
    transformations = transforms.Compose([transforms.ToTensor()])
    # 自定义数据集
    custom_mnist_from_csv = 
        CustomDatasetFromCSV('../data/mnist_in_csv.csv',
                             28, 28,
                             transformations)
    # 定义 data loader
    mn_dataset_loader = torch.utils.data.DataLoader(dataset=custom_mnist_from_csv,
                                                    batch_size=10,
                                                    shuffle=False)
    
    for images, labels in mn_dataset_loader:
        # 将数据传给网络模型

需要注意的是使用多卡训练时,PyTorch dataloader 会将每个 batch 平均分配到各个 GPU。所以如果 batch size 过小,可能发挥不了多卡的效果。

最后

以上就是内向电话为你收集整理的Pytorch 读取自定义数据集的全部内容,希望文章能够帮你解决Pytorch 读取自定义数据集所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(81)

评论列表共有 0 条评论

立即
投稿
返回
顶部