概述
Pythorch 数据初始化
pytorch在数据从原始数据集里面获取以后(一般处理成numpy数组),需要以下步骤:
1.构造DataClass
torch.utils.data.Dataset
是一个表示数据的抽象类,在构造自己的数据集时候,首先应该构造一个他的子类,并且该子类需要重写下面两个方法:
__len__
,这个方法在使用len(Dataset)时候会被调动,用于返回数据集Dataset数据集条目数
__getitem__
,该方法用于支持Dataset[i]这样的索引操作一般来讲,在__init__
函数中载入比较小的数据,像比较大的数据,如图像等可以在__getitem__
里面再载入,这样会节省内存开销。
如下栗子:
class Mydataset(Dataset):
""" 自己的数据集"""
def __init__(self, labelFiles, root_dir, transform=None):
"""
Args:
labels (string): 此处是labels的文件路径.
root_dir (string): 此处是图像文件路径.
transform (callable, optional): 应用与图像之上的各种变换,诸如数据增强等。.
"""
self.labels = read(labelFiles) #这里加载了标签文件
self.root_dir = root_dir
self.transform = transform
def __len__(self):
return len(self.labels)
def __getitem__(self, idx):
#TODO:获取一个指定的样本,组合成字典
sample={"image":image,"labels":labels}
#如果有变换则加变换
if self.transform:
sample = self.transform(sample)
return sample
2.数据变换以及预处理:transforms
通常的预处理有:1.改变图像的尺寸,2.数据增强(随机切片,镜像…),3.数据格式的变换(维度变化)这里定义的transforms对象是一个可调用的类,当该类的对象被调用的时候,方法:__call__
将被调用,所以这个类一般都要实现这个方法。
下面给出一个栗子:
class Rescale(object):
"""Rescale the image in a sample to a given size.
Args:
output_size (tuple or int): Desired output size. If tuple, output is
matched to output_size. If int, smaller of image edges is matched
to output_size keeping aspect ratio the same.
"""
def __init__(self, output_size): #这个在初始化时候调用,一般做传参用
assert isinstance(output_size, (int, tuple))
self.output_size = output_size
def __call__(self, sample): #当对象被当做函数使用时候调用,输入一个样本对,返回处理后的样本对(字典)
image, landmarks = sample['image'], sample['landmarks']
h, w = image.shape[:2]
if isinstance(self.output_size, int):
if h > w:
new_h, new_w = self.output_size * h / w, self.output_size
else:
new_h, new_w = self.output_size, self.output_size * w / h
else:
new_h, new_w = self.output_size
new_h, new_w = int(new_h), int(new_w)
img = transform.resize(image, (new_h, new_w))
# h and w are swapped for landmarks because for images,
# x and y axes are axis 1 and 0 respectively
landmarks = landmarks * [new_w / w, new_h / h]
return {'image': img, 'label': landmarks}
3.多种transforms组合使用
可以使用torchversion.transforms.Compose
类来完成组合,这个类是一个可调用的类,使用方法同样是实例化以后当做函数一样使用,具体方法如下:
scale = Rescale() #这里有两个变换
crop = RandomCrop()
composed = transforms.Compose([scale,crop])
dataset=Mydataset("labels路径","图片路径",transform=composed)
#这里的dataset是已经处理过的数据。可以通过索引来访问。
Python数据载入
训练模型时候,需要的数据通常都是以miniBatch方式载入,而且Pytorch也同样只能以batch方式载入数据,为此,需要我们把以上处理好的数据做以下处理:
1.组合为Batch
2.打乱数据
3.使用并行处理工具:multprocessing
载入数据
torch.utils.data.DataLoader
函数可以完成以上工作,其完整的参数如下:
class torch.utils.data.DataLoader(
dataset,
batch_size=1,
shuffle=False,
sampler=None,
batch_sampler=None,
num_workers=0,
collate_fn=<function default_collate>,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None)
该函数返回一个迭代器,举个使用的栗子:
dataloader=DataLoader(Mydataset,batch_size=1,shuffle=True,num_workers=2)
for i_batch,sample_batched in enumrate(dataloader):
images=sample_batched['image']
labels=sample_batched['label']
#TODO:训练模型
References:Pytorch Documentations
最后
以上就是孝顺泥猴桃为你收集整理的pytorch数据导入以及预处理的全部内容,希望文章能够帮你解决pytorch数据导入以及预处理所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
发表评论 取消回复