概述
一、加载数据集
加载数据集需要继承torch.utils.data 的 Dataset类,并实现 __len__和__getitem__方法。其中__len__返回数据集总数,
__getitem__返回指定的数的矩阵和标签。
二、数据集批处理
需要torch.utils.data 的 DataLoader类,有batch_size(批处理尺寸),num_workers(多进程),Sampler(多种取样器)等参数
from torch.utils.data import Dataset
#加载数据集
from torch.utils.data import DataLoader #数据管道,批处理数据
import glob
import os
import numpy as np
from PIL import Image
#加载数据集
class Picture(Dataset):
def __init__(self,paths,size = (10,10)):
self.paths = glob.glob(paths)
self.size = size
"""图片总数量"""
def __len__(self):
return len(self.paths)
"""根据item得到相应的图片矩阵"""
def __getitem__(self, item):
img = np.asarray(Image.open(self.paths[item]).resize(self.size))
lable = self.paths[item].split('\')[-1]
return img,lable
if __name__ == '__main__':
"""图片路径"""
root_path = os.path.join(os.path.dirname(os.getcwd()), "cap")
pic_paths = root_path + '\*.jpg'
"""实例化"""
picture = Picture(pic_paths)
"""数据集批量处理"""
dataloader = DataLoader(picture,batch_size=32,num_workers=2,timeout=2)
for a,b in dataloader:
print(b,a)
最后
以上就是疯狂毛衣为你收集整理的pytorch 加载并批处理数据集的全部内容,希望文章能够帮你解决pytorch 加载并批处理数据集所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复