概述
PyTorch 中自定义数据集的读取方法小结
作者: PyTorch 中文网 发布: 2018年9月20日 5,922阅读 0评论
虽然说网上关于 PyTorch 数据集读取的文章和教程多的很,但总觉得哪里不对,尤其是对新手来说,可能需要很长一段时间来钻研和尝试。所以这里我们 PyTorch 中文网为大家总结常用的几种自定义数据集(Custom Dataset)的读取方式(采用 Dataloader)。
本文将涉及以下几个方面:
- 自定义数据集基础方法
- 使用 Torchvision Transforms
- 换一种方法使用 Torchvision Transforms
- 结合 Pandas 读取 csv 文件
- 结合 Pandas 使用
__getitem__()
- 使用 Dataloader 读取自定义数据集
文章目录 [隐藏]
- 1 自定义数据集基础方法
- 2 使用 Torchvision Transforms
- 3 换一种方法使用 Torchvision Transforms
- 4 结合 Pandas 读取 csv 文件
- 5 结合 Pandas 使用 __getitem__()
- 6 使用 Dataloader 读取自定义数据集
自定义数据集基础方法
首先要创建一个 Dataset 类:
from torch.utils.data.dataset import Dataset class MyCustomDataset(Dataset): def __init__(self, ...): # stuff def __getitem__(self, index): # stuff return (img, label) def __len__(self): return count
1 2 3 4 5 6 7 8 9 10 11 12 | from torch.utils.data.dataset import Dataset
class MyCustomDataset(Dataset): def __init__(self, ...): # stuff
def __getitem__(self, index): # stuff return (img, label)
def __len__(self): return count |
这个代码中:
__init__()
一些初始化过程写在这里__len__()
返回所有数据的数量__getitem__()
返回数据和标签,可以这样显示调用:
img, label = MyCustomDataset.__getitem__(99)
1 | img, label = MyCustomDataset.__getitem__(99) |
使用 Torchvision Transforms
Transform 最常见的使用方法是:
from torch.utils.data.dataset import Dataset from torchvision import transforms class MyCustomDataset(Dataset): def __init__(self, ..., transforms=None): # stuff ... self.transforms = transforms def __getitem__(self, index): # stuff ... data = # 一些读取的数据 if self.transforms is not None: data = self.transforms(data) # 如果 transform 不为 None,则进行 transform 操作 return (img, label) def __len__(self): return count if __name__ == '__main__': # 定义我们的 transforms (1) transformations = transforms.Compose([transforms.CenterCrop(100), transforms.ToTensor()]) # 创建 dataset custom_dataset = MyCustomDataset(..., transformations)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 | from torch.utils.data.dataset import Dataset from torchvision import transforms
class MyCustomDataset(Dataset): def __init__(self, ..., transforms=None): # stuff ... self.transforms = transforms
def __getitem__(self, index): # stuff ... data = # 一些读取的数据 if self.transforms is not None: data = self.transforms(data) # 如果 transform 不为 None,则进行 transform 操作 return (img, label)
def __len__(self): return count
if __name__ == '__main__': # 定义我们的 transforms (1) transformations = transforms.Compose([transforms.CenterCrop(100), transforms.ToTensor()]) # 创建 dataset custom_dataset = MyCustomDataset(..., transformations) |
换一种方法使用 Torchvision Transforms
有些人不喜欢把 transform 操作写在 Dataset 外面(上面代码里的注释 1),所以还有一种写法:
from torch.utils.data.dataset import Dataset from torchvision import transforms class MyCustomDataset(Dataset): def __init__(self, ...): # stuff ... # (2) 一种方法是单独定义 transform self.center_crop = transforms.CenterCrop(100) self.to_tensor = transforms.ToTensor() # (3) 或者写成下面这样 self.transformations = transforms.Compose([transforms.CenterCrop(100), transforms.ToTensor()]) def __getitem__(self, index): # stuff ... data = #一些读取的数据 # 当第二次调用 transform 时,调用的是 __call__() data = self.center_crop(data) # (2) data = self.to_tensor(data) # (2) # 或者写成下面这样 data = self.trasnformations(data) # (3) # 注意 (2) 和 (3) 中只需要实现一种 return (img, label) def __len__(self): return count if __name__ == '__main__': custom_dataset = MyCustomDataset(...)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 | from torch.utils.data.dataset import Dataset from torchvision import transforms
class MyCustomDataset(Dataset): def __init__(self, ...): # stuff ... # (2) 一种方法是单独定义 transform self.center_crop = transforms.CenterCrop(100) self.to_tensor = transforms.ToTensor()
# (3) 或者写成下面这样 self.transformations = transforms.Compose([transforms.CenterCrop(100), transforms.ToTensor()])
def __getitem__(self, index): # stuff ... data = #一些读取的数据
# 当第二次调用 transform 时,调用的是 __call__() data = self.center_crop(data) # (2) data = self.to_tensor(data) # (2)
# 或者写成下面这样 data = self.trasnformations(data) # (3)
# 注意 (2) 和 (3) 中只需要实现一种 return (img, label)
def __len__(self): return count
if __name__ == '__main__': custom_dataset = MyCustomDataset(...) |
结合 Pandas 读取 csv 文件
假如说我们想从一个 csv 文件中用 Pandas 读取数据。一个 csv 示例如下:
File Name | Label | Extra Operation |
---|---|---|
tr_0.png | 5 | TRUE |
tr_1.png | 0 | FALSE |
tr_1.png | 4 | FALSE |
如果我们需要在自定义数据集里从这个 csv 文件读取文件名,可以这样做:
class CustomDatasetFromImages(Dataset): def __init__(self, csv_path): """ Args: csv_path (string): csv 文件路径 img_path (string): 图像文件所在路径 transform: transform 操作 """ # Transforms self.to_tensor = transforms.ToTensor() # 读取 csv 文件 self.data_info = pd.read_csv(csv_path, header=None) # 文件第一列包含图像文件的名称 self.image_arr = np.asarray(self.data_info.iloc[:, 0]) # 第二列是图像的 label self.label_arr = np.asarray(self.data_info.iloc[:, 1]) # 第三列是决定是否进行额外操作 self.operation_arr = np.asarray(self.data_info.iloc[:, 2]) # 计算 length self.data_len = len(self.data_info.index) def __getitem__(self, index): # 从 pandas df 中得到文件名 single_image_name = self.image_arr[index] # 读取图像文件 img_as_img = Image.open(single_image_name) # 检查需不需要额外操作 some_operation = self.operation_arr[index] # 如果需要额外操作 if some_operation: # ... # ... pass # 把图像转换成 tensor img_as_tensor = self.to_tensor(img_as_img) # 得到图像的 label single_image_label = self.label_arr[index] return (img_as_tensor, single_image_label) def __len__(self): return self.data_len if __name__ == "__main__": custom_mnist_from_images = CustomDatasetFromImages('../data/mnist_labels.csv')
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 | class CustomDatasetFromImages(Dataset): def __init__(self, csv_path): """ Args: csv_path (string): csv 文件路径 img_path (string): 图像文件所在路径 transform: transform 操作 """ # Transforms self.to_tensor = transforms.ToTensor() # 读取 csv 文件 self.data_info = pd.read_csv(csv_path, header=None) # 文件第一列包含图像文件的名称 self.image_arr = np.asarray(self.data_info.iloc[:, 0]) # 第二列是图像的 label self.label_arr = np.asarray(self.data_info.iloc[:, 1]) # 第三列是决定是否进行额外操作 self.operation_arr = np.asarray(self.data_info.iloc[:, 2]) # 计算 length self.data_len = len(self.data_info.index)
def __getitem__(self, index): # 从 pandas df 中得到文件名 single_image_name = self.image_arr[index] # 读取图像文件 img_as_img = Image.open(single_image_name)
# 检查需不需要额外操作 some_operation = self.operation_arr[index] # 如果需要额外操作 if some_operation: # ... # ... pass # 把图像转换成 tensor img_as_tensor = self.to_tensor(img_as_img)
# 得到图像的 label single_image_label = self.label_arr[index]
return (img_as_tensor, single_image_label)
def __len__(self): return self.data_len
if __name__ == "__main__": custom_mnist_from_images = CustomDatasetFromImages('../data/mnist_labels.csv') |
结合 Pandas 使用 __getitem__()
另一种情况是 csv 文件中保存了我们需要的图像文件的像素值(比如有些 MNIST 教程就是这样的)。我们需要改动一下 __getitem__()
函数。
Label | pixel_1 | pixel_2 | … |
---|---|---|---|
1 | 50 | 99 | … |
0 | 21 | 223 | … |
9 | 44 | 112 | … |
代码如下:
class CustomDatasetFromCSV(Dataset): def __init__(self, csv_path, height, width, transforms=None): """ Args: csv_path (string): csv 文件路径 height (int): 图像高度 width (int): 图像宽度 transform: transform 操作 """ self.data = pd.read_csv(csv_path) self.labels = np.asarray(self.data.iloc[:, 0]) self.height = height self.width = width self.transforms = transform def __getitem__(self, index): single_image_label = self.labels[index] # 读取所有像素值,并将 1D array ([784]) reshape 成为 2D array ([28,28]) img_as_np = np.asarray(self.data.iloc[index][1:]).reshape(28,28).astype('uint8') # 把 numpy array 格式的图像转换成灰度 PIL image img_as_img = Image.fromarray(img_as_np) img_as_img = img_as_img.convert('L') # 将图像转换成 tensor if self.transforms is not None: img_as_tensor = self.transforms(img_as_img) # 返回图像及其 label return (img_as_tensor, single_image_label) def __len__(self): return len(self.data.index) if __name__ == "__main__": transformations = transforms.Compose([transforms.ToTensor()]) custom_mnist_from_csv = CustomDatasetFromCSV('../data/mnist_in_csv.csv', 28, 28, transformations)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 | class CustomDatasetFromCSV(Dataset): def __init__(self, csv_path, height, width, transforms=None): """ Args: csv_path (string): csv 文件路径 height (int): 图像高度 width (int): 图像宽度 transform: transform 操作 """ self.data = pd.read_csv(csv_path) self.labels = np.asarray(self.data.iloc[:, 0]) self.height = height self.width = width self.transforms = transform
def __getitem__(self, index): single_image_label = self.labels[index] # 读取所有像素值,并将 1D array ([784]) reshape 成为 2D array ([28,28]) img_as_np = np.asarray(self.data.iloc[index][1:]).reshape(28,28).astype('uint8') # 把 numpy array 格式的图像转换成灰度 PIL image img_as_img = Image.fromarray(img_as_np) img_as_img = img_as_img.convert('L') # 将图像转换成 tensor if self.transforms is not None: img_as_tensor = self.transforms(img_as_img) # 返回图像及其 label return (img_as_tensor, single_image_label)
def __len__(self): return len(self.data.index)
if __name__ == "__main__": transformations = transforms.Compose([transforms.ToTensor()]) custom_mnist_from_csv = CustomDatasetFromCSV('../data/mnist_in_csv.csv', 28, 28, transformations) |
使用 Dataloader 读取自定义数据集
PyTorch 中的 Dataloader 只是调用 __getitem__()
方法并组合成 batch,我们可以这样调用:
... if __name__ == "__main__": # 定义 transforms transformations = transforms.Compose([transforms.ToTensor()]) # 自定义数据集 custom_mnist_from_csv = CustomDatasetFromCSV('../data/mnist_in_csv.csv', 28, 28, transformations) # 定义 data loader mn_dataset_loader = torch.utils.data.DataLoader(dataset=custom_mnist_from_csv, batch_size=10, shuffle=False) for images, labels in mn_dataset_loader: # 将数据传给网络模型
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 | ... if __name__ == "__main__": # 定义 transforms transformations = transforms.Compose([transforms.ToTensor()]) # 自定义数据集 custom_mnist_from_csv = CustomDatasetFromCSV('../data/mnist_in_csv.csv', 28, 28, transformations) # 定义 data loader mn_dataset_loader = torch.utils.data.DataLoader(dataset=custom_mnist_from_csv, batch_size=10, shuffle=False)
for images, labels in mn_dataset_loader: # 将数据传给网络模型 |
需要注意的是使用多卡训练时,PyTorch dataloader 会将每个 batch 平均分配到各个 GPU。所以如果 batch size 过小,可能发挥不了多卡的效果。
最后
以上就是失眠草莓为你收集整理的PyTorch 中自定义数据集的读取方法小结PyTorch 中自定义数据集的读取方法小结的全部内容,希望文章能够帮你解决PyTorch 中自定义数据集的读取方法小结PyTorch 中自定义数据集的读取方法小结所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
发表评论 取消回复