我是靠谱客的博主 英俊西牛,这篇文章主要介绍pytorch学习之旅(一)——自定义数据读取,现在分享给大家,希望可以做个参考。

最近在研究显著性检测,学着使用pytorch框架,以下纯属个人见解,如有错误请指出

(一)自定义数据读取

首先官方案例:

PyTorch读取图片,主要是通过Dataset类,所以先简单了解一下Dataset类。Dataset类作为所有的datasets的基类存在,所有的datasets都需要继承它,类似于C++中的虚基类。

复制代码
1
2
3
4
5
6
7
8
class Dataset(object): def __getitem__(self, index): raise NotImplementedError def __len__(self): raise NotImplementedError def __add__(self, other): return ConcatDataset([self, other])

这里重点看 getitem函数,getitem接收一个index,然后返回图片数据和标签,这个index通常指的是一个list的index,这个list的每个元素就包含了图片数据的路径和标签信息。

然而,如何制作这个list呢,通常的方法是将图片的路径和标签信息存储在一个txt中,然后从该txt中读取。 那么读取自己数据的基本流程就是: 1. 制作存储了图片的路径和标签信息的txt 2. 将这些信息转化为list,该list每一个元素对应一个样本 3. 通过getitem函数,读取数据和标签,并返回数据和标签

在训练代码里是感觉不到这些操作的,只会看到通过DataLoader就可以获取一个batch的数据,其实触发去读取图片这些操作的是DataLoader里的iter(self),后面会详细讲解读取过程。在本小节,主要讲Dataset子类。 因此,要让PyTorch能读取自己的数据集,只需要两步: 1. 制作图片数据的索引 2. 构建Dataset子类

下面是我做显著性检测时自定义的(我纠结label的定义足足两天,总算明白了:label 在官网给出的是分类问题,因此标签是对应的类别要么是文字要么手写体表示的数字,而我需要的是图片,这里就发一下他们之间的对比,就很容易理解到pytorch这个自定义的类是有多么方便)
下面是分类问题常用模板(显著性检测用的比较少,所以我就没有运行过代码,仅作为对比帮助理解)

复制代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
from PIL import Image from torch.utils.data import Dataset class MyDataset(Dataset): def __init__(self, txt_path, transform = None, target_transform = None): fh = open(txt_path, 'r') imgs = [] for line in fh: line = line.rstrip() words = line.split() imgs.append((words[0], int(words[1]))) self.imgs = imgs self.transform = transform self.target_transform = target_transform def __getitem__(self, index): fn, label = self.imgs[index] img = Image.open(fn).convert('RGB') if self.transform is not None: img = self.transform(img) return img, label def __len__(self): return len(self.imgs)

下面是我自己的数据读取,最后生成一个dataset的类

主要思路将地址对应的image,label,通过地址列表形式,一个一个的导入,不过也有一个弊端,这个只能一张图片的输入到网络中,正好我们的batch_size = 1,最后一个代码我将用官方给出的例子改写,这样方便后续设置出我们需要的batch_size, 这样还有一个坏处,我的内存会溢出,一次性把全部图片读取出来,内存不够用,后续可以考虑把图片一张一张的读取,然后再一张一张的送进去,这样内存应该会轻松些

(最后的代码由于时间紧张,后续再补,其实很简单的说一下思路:

1.在__init__()中改写代码,最后返回index
2.打开image和label存放的txt,读取里面的地址生成list,两个list具有相同的index,最后return index就好,比较简单
3.在__getitem__()改写代码,把返回的index打开相应的地址,把对应的image和label转换成tensor,同时返回
4__len__()不变都行
可以在我的代码基础上,不相关的模块改写进去就好

复制代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def readtxt_into_list(address): file = open(address) addressMat = [] namelMat = [] for line in file.readlines(): curLine = line.strip().split(" ") addressMat.append(curLine[0]) namelMat.append(curLine[2]) number_of_lines = len(namelMat) # 返回值包括图片地址名,文件名,已经这个list的大小 return addressMat, namelMat,number_of_lines def img_tensor(address): img = Image.open(address).convert('RGB') img_np1 = numpy.transpose(img, (2, 0, 1)) img3_tensor = torch.Tensor(img_np1) four_dims = img3_tensor.unsqueeze(0) return four_dims # 取出lable和img的相关信息 dataset = [] # 用来存放lable 和img 的tensor 四维格式(B x C x H x W) add_img = 'F:dataMSRA10K_Imgs_GTdir.txt' address_img, name_img,lines = readtxt_into_list(add_img) add_lable = 'F:dataMSRA10K_Imgs_GTdir1.txt' address_lable, name_lable,lines = readtxt_into_list(add_lable) for index in range(lines): # 取出地址 img_add = str(address_img[index] + name_img[index]) address1 = img_add lable_add = str(address_lable[index] + name_lable[index]) address2 = lable_add # 读取图片转化成tensor input =img_tensor(address1) lable = img_tensor(address2) dataset.append([input, lable])

有问题,有错误,请指正,大家一起学习一起进步!

最后

以上就是英俊西牛最近收集整理的关于pytorch学习之旅(一)——自定义数据读取的全部内容,更多相关pytorch学习之旅(一)——自定义数据读取内容请搜索靠谱客的其他文章。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(79)

评论列表共有 0 条评论

立即
投稿
返回
顶部