概述
本文将涉及以下几个方面:
- 自定义数据集基础方法
- 使用 Torchvision Transforms
- 换一种方法使用 Torchvision Transforms
- 结合 Pandas 读取 csv 文件
- 结合 Pandas 使用__getitem__()
- 使用 Dataloader 读取自定义数据集
自定义数据集基础方法
首先要创建一个 Dataset 类:
from torch.utils.data.dataset import Dataset
class MyCustomDataset(Dataset):
def __init__(self, ...):
# stuff
def __getitem__(self, index):
# stuff
return (img, label)
def __len__(self):
return count
这个代码中:
- __init__() 一些初始化过程写在这里
- __len__() 返回所有数据的数量
- __getitem__() 返回数据和标签,可以这样显示调用:
img, label = MyCustomDataset.__getitem__(99)
使用 Torchvision Transforms
Transform 最常见的使用方法是:
from torch.utils.data.dataset import Dataset
from torchvision import transforms
class MyCustomDataset(Dataset):
def __init__(self, ..., transforms=None):
# stuff
...
self.transforms = transforms
def __getitem__(self, index):
# stuff
...
data = # 一些读取的数据
if self.transforms is not None:
data = self.transforms(data)
# 如果 transform 不为 None,则进行 transform 操作
return (img, label)
def __len__(self):
return count
if __name__ == '__main__':
# 定义我们的 transforms (1)
transformations = transforms.Compose([transforms.CenterCrop(100), transforms.ToTensor()])
# 创建 dataset
custom_dataset = MyCustomDataset(..., transformations)
换一种方法使用 Torchvision Transforms
有些人不喜欢把 transform 操作写在 Dataset 外面(上面代码里的注释 1),所以还有一种写法:
from torch.utils.data.dataset import Dataset
from torchvision import transforms
class MyCustomDataset(Dataset):
def __init__(self, ...):
# stuff
...
# (2) 一种方法是单独定义 transform
self.center_crop = transforms.CenterCrop(100)
self.to_tensor = transforms.ToTensor()
# (3) 或者写成下面这样
self.transformations =
transforms.Compose([transforms.CenterCrop(100),
transforms.ToTensor()])
def __getitem__(self, index):
# stuff
...
data = #一些读取的数据
# 当第二次调用 transform 时,调用的是 __call__()
data = self.center_crop(data) # (2)
data = self.to_tensor(data) # (2)
# 或者写成下面这样
data = self.trasnformations(data) # (3)
# 注意 (2) 和 (3) 中只需要实现一种
return (img, label)
def __len__(self):
return count
if __name__ == '__main__':
custom_dataset = MyCustomDataset(...)
结合 Pandas 读取 csv 文件
假如说我们想从一个 csv 文件中用 Pandas 读取数据。一个 csv 示例如下:
File Name Label Extra Operation
tr_0.png 5 TRUE
tr_1.png 0 FALSE
tr_1.png 4 FALSE
如果我们需要在自定义数据集里从这个 csv 文件读取文件名,可以这样做:
class CustomDatasetFromImages(Dataset):
def __init__(self, csv_path):
"""
Args:
csv_path (string): csv 文件路径
img_path (string): 图像文件所在路径
transform: transform 操作
"""
# Transforms
self.to_tensor = transforms.ToTensor()
# 读取 csv 文件
self.data_info = pd.read_csv(csv_path, header=None)
# 文件第一列包含图像文件的名称
self.image_arr = np.asarray(self.data_info.iloc[:, 0])
# 第二列是图像的 label
self.label_arr = np.asarray(self.data_info.iloc[:, 1])
# 第三列是决定是否进行额外操作
self.operation_arr = np.asarray(self.data_info.iloc[:, 2])
# 计算 length
self.data_len = len(self.data_info.index)
def __getitem__(self, index):
# 从 pandas df 中得到文件名
single_image_name = self.image_arr[index]
# 读取图像文件
img_as_img = Image.open(single_image_name)
# 检查需不需要额外操作
some_operation = self.operation_arr[index]
# 如果需要额外操作
if some_operation:
# ...
# ...
pass
# 把图像转换成 tensor
img_as_tensor = self.to_tensor(img_as_img)
# 得到图像的 label
single_image_label = self.label_arr[index]
return (img_as_tensor, single_image_label)
def __len__(self):
return self.data_len
if __name__ == "__main__":
custom_mnist_from_images =
CustomDatasetFromImages('../data/mnist_labels.csv')
结合 Pandas 使用 __getitem__()
另一种情况是 csv 文件中保存了我们需要的图像文件的像素值(比如有些 MNIST 教程就是这样的)。我们需要改动一下 __getitem__() 函数。
Label pixel_1 pixel_2 …
1 50 99 …
0 21 223 …
9 44 112 …
代码如下:
class CustomDatasetFromCSV(Dataset):
def __init__(self, csv_path, height, width, transforms=None):
"""
Args:
csv_path (string): csv 文件路径
height (int): 图像高度
width (int): 图像宽度
transform: transform 操作
"""
self.data = pd.read_csv(csv_path)
self.labels = np.asarray(self.data.iloc[:, 0])
self.height = height
self.width = width
self.transforms = transform
def __getitem__(self, index):
single_image_label = self.labels[index]
# 读取所有像素值,并将 1D array ([784]) reshape 成为 2D array ([28,28])
img_as_np = np.asarray(self.data.iloc[index][1:]).reshape(28,28).astype('uint8')
# 把 numpy array 格式的图像转换成灰度 PIL image
img_as_img = Image.fromarray(img_as_np)
img_as_img = img_as_img.convert('L')
# 将图像转换成 tensor
if self.transforms is not None:
img_as_tensor = self.transforms(img_as_img)
# 返回图像及其 label
return (img_as_tensor, single_image_label)
def __len__(self):
return len(self.data.index)
if __name__ == "__main__":
transformations = transforms.Compose([transforms.ToTensor()])
custom_mnist_from_csv =
CustomDatasetFromCSV('../data/mnist_in_csv.csv', 28, 28, transformations)
使用 Dataloader 读取自定义数据集
PyTorch 中的 Dataloader 只是调用 __getitem__() 方法并组合成 batch,我们可以这样调用:
if __name__ == "__main__":
# 定义 transforms
transformations = transforms.Compose([transforms.ToTensor()])
# 自定义数据集
custom_mnist_from_csv =
CustomDatasetFromCSV('../data/mnist_in_csv.csv',
28, 28,
transformations)
# 定义 data loader
mn_dataset_loader = torch.utils.data.DataLoader(dataset=custom_mnist_from_csv,
batch_size=10,
shuffle=False)
for images, labels in mn_dataset_loader:
# 将数据传给网络模型
需要注意的是使用多卡训练时,PyTorch dataloader 会将每个 batch 平均分配到各个 GPU。所以如果 batch size 过小,可能发挥不了多卡的效果。
最后
以上就是内向电话为你收集整理的Pytorch 读取自定义数据集的全部内容,希望文章能够帮你解决Pytorch 读取自定义数据集所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复