我是靠谱客的博主 疯狂毛衣,最近开发中收集的这篇文章主要介绍pytorch 加载并批处理数据集,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

一、加载数据集

加载数据集需要继承torch.utils.data 的 Dataset类,并实现 __len__和__getitem__方法。其中__len__返回数据集总数,

__getitem__返回指定的数的矩阵和标签。

二、数据集批处理

需要torch.utils.data 的 DataLoader类,有batch_size(批处理尺寸),num_workers(多进程),Sampler(多种取样器)等参数

from torch.utils.data import Dataset
#加载数据集
from torch.utils.data import DataLoader #数据管道,批处理数据
import glob
import os
import numpy as np
from PIL import Image
#加载数据集
class Picture(Dataset):
def __init__(self,paths,size = (10,10)):
self.paths = glob.glob(paths)
self.size = size
"""图片总数量"""
def __len__(self):
return len(self.paths)
"""根据item得到相应的图片矩阵"""
def __getitem__(self, item):
img = np.asarray(Image.open(self.paths[item]).resize(self.size))
lable = self.paths[item].split('\')[-1]
return img,lable
if __name__ == '__main__':
"""图片路径"""
root_path = os.path.join(os.path.dirname(os.getcwd()), "cap")
pic_paths = root_path + '\*.jpg'
"""实例化"""
picture = Picture(pic_paths)
"""数据集批量处理"""
dataloader = DataLoader(picture,batch_size=32,num_workers=2,timeout=2)
for a,b in dataloader:
print(b,a)

 

最后

以上就是疯狂毛衣为你收集整理的pytorch 加载并批处理数据集的全部内容,希望文章能够帮你解决pytorch 加载并批处理数据集所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(39)

评论列表共有 0 条评论

立即
投稿
返回
顶部