我是靠谱客的博主 美丽毛巾,这篇文章主要介绍PyTorch学习笔记(三)总结篇 --------自建数据集的载入,现在分享给大家,希望可以做个参考。

前言

经过这几天学习,我算是把数据集这一块给摸清楚了,前面分布分支的学习总是有点模棱两可,不清楚这步到底要干啥,在网上找资料学习时,总是拿的pytorch官网给的数据集,没有针对性和专一性。这里教大家如何使用咱们自己的数据集,当然,在做实验时数据集是通过爬虫来获取的,关于爬虫的相关知识可以留言私信,或者看我第一篇博客哦

一、MyData类的定义

在自建数据集时需要自己去定义一个dataset类来继承torch.utils.data.Dataset

来看代码

复制代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
class MyData(Dataset): def __init__(self, root_dir, label_dir, transform=None): # 初始化类,为class提供全局变量 self.transform = transform self.root_dir = root_dir # 根文件位置 self.label_dir = label_dir # 子文件名 self.path = os.path.join(self.root_dir, self.label_dir) # 合并,即具体位置 self.img_path = os.listdir(self.path) # 转换成列表的形式 def __getitem__(self, idx): # 获取列表中每一个图片 img_name = self.img_path[idx] # idx表示下标,即对应位置 img_item_path = os.path.join(self.root_dir, self.label_dir, img_name) # 每一个图片的位置 img = Image.open(img_item_path) # 调用方法,拿到该图像 img = img.convert("RGB") img = self.transform(img) label = self.label_dir # 标签 return img, label # 返回img 图片 label 标签 def __len__(self): # 返回长度 return len(self.img_path)

这里要注意一下,跟第一篇学习笔记的不同在于第一篇没有定义transform导致返回的就是PIL类型,这边加入了几行代码,目的就是为了返回tensor类型,这里返回的是img,label两个对象

二、数据集的实际运用

上面的类是通法,在任何研究中都可以套用,下面来看在本次实验中的实际运用

复制代码
1
2
3
4
5
tensor_trans = transforms.Compose([transforms.ToTensor(), transforms.Resize([512, 512])]) root_dir = 'D://情绪图片' # 根目录 happy_label_dir = '开心' # 子目录 happy_dataset = MyData(root_dir, happy_label_dir, transform=tensor_trans) # 开心数据集创建完成

这边是通过totensor转换为tensor类型,同时将图片尺寸变为512*512

也算so easy吧

三、DataLoader类

dataloader是用来load数据集,其中batch_size=4是为了每次抓取4张;shuffle是按需求来是否需要打乱,即在等于True的时候是打乱的,False的时候是不打乱的;drop_last是表示在最后一次抓取时不满4个是否需要保留,比如一共10张图片,每次抓取4个,最后一次不满4个可以选择保留或者舍弃

复制代码
1
2
3
4
5
6
7
8
9
10
11
test_loader = DataLoader(dataset=happy_dataset, batch_size=4, shuffle=True, num_workers=0, drop_last=False) # img, label = happy_dataset[0] step = 0 writer = SummaryWriter("dataloader") for epoch in range(2): for data in test_loader: imgs, label = data writer.add_images('{}'.format(epoch), imgs, step) step = step + 1 writer.close()

四、源码

这边是最终的源码,大家可以按自己的要求选取哦

复制代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# -*- coding = utf-8 -*- import torchvision from torch.utils.data import Dataset import cv2 from PIL import Image # 图像处理的库 import os from torch.utils.tensorboard import SummaryWriter from torchvision import transforms from torch.utils.data import DataLoader class MyData(Dataset): def __init__(self, root_dir, label_dir, transform=None): # 初始化类,为class提供全局变量 self.transform = transform self.root_dir = root_dir # 根文件位置 self.label_dir = label_dir # 子文件名 self.path = os.path.join(self.root_dir, self.label_dir) # 合并,即具体位置 self.img_path = os.listdir(self.path) # 转换成列表的形式 def __getitem__(self, idx): # 获取列表中每一个图片 img_name = self.img_path[idx] # idx表示下标,即对应位置 img_item_path = os.path.join(self.root_dir, self.label_dir, img_name) # 每一个图片的位置 img = Image.open(img_item_path) # 调用方法,拿到该图像 img = img.convert("RGB") img = self.transform(img) label = self.label_dir # 标签 return img, label # 返回img 图片 label 标签 def __len__(self): # 返回长度 return len(self.img_path) tensor_trans = transforms.Compose([transforms.ToTensor(), transforms.Resize([512, 512])]) root_dir = 'D://情绪图片' # 根目录 happy_label_dir = '开心' # 子目录 happy_dataset = MyData(root_dir, happy_label_dir, transform=tensor_trans) # 开心数据集创建完成 # img, label = happy_dataset[2] # 由上面可知返回的是两个值 # print(label) # 分别调用 # img.show() # batch_size每次取dataset的四个数据集并打包, shuffle是是否打乱,drop_last为False即最后一步不满4个时不舍,反之舍 # print(happy_dataset[0]) # tensor_trans = transforms.Compose([transforms.ToTensor(), transforms.Resize([224, 224])]) # test_data = torchvision.datasets(datasets=happy_dataset, transforms=tensor_trans) test_loader = DataLoader(dataset=happy_dataset, batch_size=4, shuffle=True, num_workers=0, drop_last=False) # img, label = happy_dataset[0] step = 0 writer = SummaryWriter("dataloader") for epoch in range(2): for data in test_loader: imgs, label = data writer.add_images('{}'.format(epoch), imgs, step) step = step + 1 writer.close()

下面来看结果

 

五、总结

这几天对数据集的学习对我这个初次接触pytorch的人来说也是挺头疼的,各种报错,包括loader的基本使用当时也是不太熟悉。现在大部分教学用的数据集都是通过torchvision.dataset.来获取pytorch自带的数据集,这篇也算是给大家提供另一种方法吧

最后有什么不明白或者报错不知道怎么解决的可以留言私信哦~说不定我曾经也经历过

最后

以上就是美丽毛巾最近收集整理的关于PyTorch学习笔记(三)总结篇 --------自建数据集的载入的全部内容,更多相关PyTorch学习笔记(三)总结篇内容请搜索靠谱客的其他文章。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(42)

评论列表共有 0 条评论

立即
投稿
返回
顶部