我是靠谱客的博主 英俊西牛,最近开发中收集的这篇文章主要介绍pytorch学习之旅(一)——自定义数据读取,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

最近在研究显著性检测,学着使用pytorch框架,以下纯属个人见解,如有错误请指出

(一)自定义数据读取

首先官方案例:

PyTorch读取图片,主要是通过Dataset类,所以先简单了解一下Dataset类。Dataset类作为所有的datasets的基类存在,所有的datasets都需要继承它,类似于C++中的虚基类。

class Dataset(object):
def __getitem__(self, index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
def __add__(self, other):
return ConcatDataset([self, other])

这里重点看 getitem函数,getitem接收一个index,然后返回图片数据和标签,这个index通常指的是一个list的index,这个list的每个元素就包含了图片数据的路径和标签信息。

然而,如何制作这个list呢,通常的方法是将图片的路径和标签信息存储在一个txt中,然后从该txt中读取。 那么读取自己数据的基本流程就是: 1. 制作存储了图片的路径和标签信息的txt 2. 将这些信息转化为list,该list每一个元素对应一个样本 3. 通过getitem函数,读取数据和标签,并返回数据和标签

在训练代码里是感觉不到这些操作的,只会看到通过DataLoader就可以获取一个batch的数据,其实触发去读取图片这些操作的是DataLoader里的iter(self),后面会详细讲解读取过程。在本小节,主要讲Dataset子类。 因此,要让PyTorch能读取自己的数据集,只需要两步: 1. 制作图片数据的索引 2. 构建Dataset子类

下面是我做显著性检测时自定义的(我纠结label的定义足足两天,总算明白了:label 在官网给出的是分类问题,因此标签是对应的类别要么是文字要么手写体表示的数字,而我需要的是图片,这里就发一下他们之间的对比,就很容易理解到pytorch这个自定义的类是有多么方便)
下面是分类问题常用模板(显著性检测用的比较少,所以我就没有运行过代码,仅作为对比帮助理解)

from PIL import Image
from torch.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self, txt_path, transform = None, target_transform = None):
fh = open(txt_path, 'r')
imgs = []
for line in fh:
line = line.rstrip()
words = line.split()
imgs.append((words[0], int(words[1])))
self.imgs = imgs
self.transform = transform
self.target_transform = target_transform
def __getitem__(self, index):
fn, label = self.imgs[index]
img = Image.open(fn).convert('RGB')
if self.transform is not None:
img = self.transform(img)
return img, label
def __len__(self):
return len(self.imgs)

下面是我自己的数据读取,最后生成一个dataset的类

主要思路将地址对应的image,label,通过地址列表形式,一个一个的导入,不过也有一个弊端,这个只能一张图片的输入到网络中,正好我们的batch_size = 1,最后一个代码我将用官方给出的例子改写,这样方便后续设置出我们需要的batch_size, 这样还有一个坏处,我的内存会溢出,一次性把全部图片读取出来,内存不够用,后续可以考虑把图片一张一张的读取,然后再一张一张的送进去,这样内存应该会轻松些

(最后的代码由于时间紧张,后续再补,其实很简单的说一下思路:

1.在__init__()中改写代码,最后返回index
2.打开image和label存放的txt,读取里面的地址生成list,两个list具有相同的index,最后return index就好,比较简单
3.在__getitem__()改写代码,把返回的index打开相应的地址,把对应的image和label转换成tensor,同时返回
4__len__()不变都行
可以在我的代码基础上,不相关的模块改写进去就好

def readtxt_into_list(address):
file = open(address)
addressMat = []
namelMat = []
for line in file.readlines():
curLine = line.strip().split(" ")
addressMat.append(curLine[0])
namelMat.append(curLine[2])
number_of_lines = len(namelMat)
# 返回值包括图片地址名,文件名,已经这个list的大小
return
addressMat, namelMat,number_of_lines
def img_tensor(address):
img = Image.open(address).convert('RGB')
img_np1 = numpy.transpose(img, (2, 0, 1))
img3_tensor = torch.Tensor(img_np1)
four_dims = img3_tensor.unsqueeze(0)
return four_dims
# 取出lable和img的相关信息
dataset = []
# 用来存放lable 和img 的tensor 四维格式(B x C x H x W)
add_img = 'F:dataMSRA10K_Imgs_GTdir.txt'
address_img, name_img,lines = readtxt_into_list(add_img)
add_lable = 'F:dataMSRA10K_Imgs_GTdir1.txt'
address_lable, name_lable,lines = readtxt_into_list(add_lable)
for index in range(lines):
# 取出地址
img_add = str(address_img[index] + name_img[index])
address1 = img_add
lable_add = str(address_lable[index] + name_lable[index])
address2 = lable_add
# 读取图片转化成tensor
input =img_tensor(address1)
lable = img_tensor(address2)
dataset.append([input, lable])

有问题,有错误,请指正,大家一起学习一起进步!

最后

以上就是英俊西牛为你收集整理的pytorch学习之旅(一)——自定义数据读取的全部内容,希望文章能够帮你解决pytorch学习之旅(一)——自定义数据读取所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(62)

评论列表共有 0 条评论

立即
投稿
返回
顶部