接下来几篇博文开始,介绍pytorch五大模块中的数据模块,所有概念都会以第四代人民币1元和100元纸币的二分类问题为例来具体介绍,在实例中明白相关知识。
数据模块的结构体系
数据模块分为数据的收集、划分、读取、预处理四部分,其中收集和划分是人工可以设定,而读取部分和预处理部分,pytorch有相应的函数和运行机制来实现。读取部分中pytorch靠dataloader这个数据读取机制来读取数据。
Dataloader
Dataloader涉及两个部分,一是sampler部分,用于生成数据的索引(即序号),二是dataset,根据索引来读取相应的数据和标签。
torch.utils.data.Dataloader
功能:构建可迭代的数据装载器
主要属性:
dataset:Dataset类,决定数据从哪里读取以及如何读取
batchsize:批大小
num_works:是否以多进程读取数据
shuffle:每个epoch是否乱序
drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据
epoch:所有训练样本都已输入到模型中,称为一个epoch
iteration:一批样本输入到模型中,称之为一个iteration
batchsize:批大小,决定一个epoch有多少个iteration
举例:样本总数:80,batchsize:8,则 1 epoch = 10 iteration
样本总数:85,batchsize:8,则 1 epoch = {设定drop_last:10 iteration;不设定:11 iteration}
torch.utils.data.Dataset
功能:抽象类,所有自定义的Dataset需要继承它,并且复写__getitem__()
getitem:接收一个索引,返回一个样本
实例体现
下面介绍一下代码构建的流程,主要涉及数据模块
1.数据收集(img,label)
由于是二分类,所以可以构建两个文件夹进行简单区分
并划分训练、验证和数据集
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48def makedir(new_dir): if not os.path.exists(new_dir): os.makedirs(new_dir) if __name__ == '__main__': random.seed(1) dataset_dir = os.path.join("路径", "data", "RMB_data") split_dir = os.path.join("路径", "data", "rmb_split") train_dir = os.path.join(split_dir, "train") valid_dir = os.path.join(split_dir, "valid") test_dir = os.path.join(split_dir, "test") train_pct = 0.8 valid_pct = 0.1 test_pct = 0.1 for root, dirs, files in os.walk(dataset_dir): for sub_dir in dirs: imgs = os.listdir(os.path.join(root, sub_dir)) imgs = list(filter(lambda x: x.endswith('.jpg'), imgs)) random.shuffle(imgs) img_count = len(imgs) train_point = int(img_count * train_pct) valid_point = int(img_count * (train_pct + valid_pct)) for i in range(img_count): if i < train_point: out_dir = os.path.join(train_dir, sub_dir) elif i < valid_point: out_dir = os.path.join(valid_dir, sub_dir) else: out_dir = os.path.join(test_dir, sub_dir) makedir(out_dir) target_path = os.path.join(out_dir, imgs[i]) src_path = os.path.join(dataset_dir, sub_dir, imgs[i]) shutil.copy(src_path, target_path) print('Class:{}, train:{}, valid:{}, test:{}'.format(sub_dir, train_point, valid_point-train_point, img_count-valid_point))
以8:1:1的比例划分train valid test三个数据集,接下来设置好各数据路径
以及数据各通道的均值和标准差(这个需要自己计算得出)
下面就是数据模块中预处理中transform方法的建立,这个会在下一篇博文中展开
接下来为构建自定义Dataset实例
以及构建Dataloader
其中Dataset必须是用户自己写的
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43rmb_label = {"1": 0, "100": 1} class RMBDataset(Dataset): def __init__(self, data_dir, transform=None): """ rmb面额分类任务的Dataset :param data_dir: str, 数据集所在路径 :param transform: torch.transform,数据预处理 """ self.label_name = {"1": 0, "100": 1} self.data_info = self.get_img_info(data_dir) # data_info存储所有图片路径和标签,在DataLoader中通过index读取样本 self.transform = transform def __getitem__(self, index): path_img, label = self.data_info[index] img = Image.open(path_img).convert('RGB') # 0~255 if self.transform is not None: img = self.transform(img) # 在这里做transform,转为tensor等等 return img, label def __len__(self): return len(self.data_info) @staticmethod def get_img_info(data_dir): data_info = list() for root, dirs, _ in os.walk(data_dir): # 遍历类别 for sub_dir in dirs: img_names = os.listdir(os.path.join(root, sub_dir)) img_names = list(filter(lambda x: x.endswith('.jpg'), img_names)) # 遍历图片 for i in range(len(img_names)): img_name = img_names[i] path_img = os.path.join(root, sub_dir, img_name) label = rmb_label[sub_dir] data_info.append((path_img, int(label))) return data_info
接下来便是模型模块、损失函数模块、优化器模块、迭代训练模块
在迭代训练中,数据的获取为 for i, data in enumerate(train_loader)
主要探究enumerate(train_loader)其中的机制
阅读Dataloader源码可知:
- 迭代dataloader首先会进入是否多线程运行的判断(比如单进程singleprocess)
- 然后进入_SingleProcessDataloaderIter.__next__中获取index和通过index获取data
- index列表由sampler生成,长度为一个batch_size
- 再由self.dataset_fetcher.fetch(index)去获取data的路径和标签,fetch会一步步跳转到自定义dataset中的__getitem__(self, index)
- 采用Image.open读取路径中的数据,如果有transform方法,则进行transform后再返回img及label
- 当fetch进行return时,会采用collate_fn(data)方法将所有单个数据整理成一个batch(字典样式:label - img.tensor)的形式并返回
可以归纳得到:
另附CNY二分类模型代码示例:(为节省篇幅,以下只展示与数据模块有关的步骤)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142import os import random import numpy as np import torch import torch.nn as nn from torch.utils.data import DataLoader import torchvision.transforms as transforms import torch.optim as optim from matplotlib import pyplot as plt from model.lenet import LeNet from tools.my_dataset import RMBDataset def set_seed(seed=1): random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) set_seed() # 设置随机种子 rmb_label = {"1": 0, "100": 1} # 参数设置 MAX_EPOCH = 10 BATCH_SIZE = 16 LR = 0.01 log_interval = 10 val_interval = 1 # ============================ step 1/5 数据 ============================ split_dir = os.path.join("..", "..", "data", "rmb_split") train_dir = os.path.join(split_dir, "train") valid_dir = os.path.join(split_dir, "valid") norm_mean = [0.485, 0.456, 0.406] norm_std = [0.229, 0.224, 0.225] train_transform = transforms.Compose([ transforms.Resize((32, 32)), transforms.RandomCrop(32, padding=4), transforms.ToTensor(), transforms.Normalize(norm_mean, norm_std), ]) valid_transform = transforms.Compose([ transforms.Resize((32, 32)), transforms.ToTensor(), transforms.Normalize(norm_mean, norm_std), ]) # 构建MyDataset实例 train_data = RMBDataset(data_dir=train_dir, transform=train_transform) valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform) # 构建DataLoder train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True) valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE) # ================== step 2、3、4 暂且略过 后续涉及 ====================== # ============================ step 5/5 训练 ============================ train_curve = list() valid_curve = list() for epoch in range(MAX_EPOCH): loss_mean = 0. correct = 0. total = 0. net.train() for i, data in enumerate(train_loader): # forward inputs, labels = data outputs = net(inputs) # backward optimizer.zero_grad() loss = criterion(outputs, labels) loss.backward() # update weights optimizer.step() # 统计分类情况 _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).squeeze().sum().numpy() # 打印训练信息 loss_mean += loss.item() train_curve.append(loss.item()) if (i+1) % log_interval == 0: loss_mean = loss_mean / log_interval print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format( epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total)) loss_mean = 0. scheduler.step() # 更新学习率 # validate the model if (epoch+1) % val_interval == 0: correct_val = 0. total_val = 0. loss_val = 0. net.eval() with torch.no_grad(): for j, data in enumerate(valid_loader): inputs, labels = data outputs = net(inputs) loss = criterion(outputs, labels) _, predicted = torch.max(outputs.data, 1) total_val += labels.size(0) correct_val += (predicted == labels).squeeze().sum().numpy() loss_val += loss.item() loss_val_epoch = loss_val / len(valid_loader) valid_curve.append(loss_val_epoch) print("Valid:t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format( epoch, MAX_EPOCH, j+1, len(valid_loader), loss_val_epoch, correct_val / total_val)) train_x = range(len(train_curve)) train_y = train_curve train_iters = len(train_loader) valid_x = np.arange(1, len(valid_curve)+1) * train_iters*val_interval # 由于valid中记录的是epochloss,需要对记录点进行转换到iterations valid_y = valid_curve plt.plot(train_x, train_y, label='Train') plt.plot(valid_x, valid_y, label='Valid') plt.legend(loc='upper right') plt.ylabel('loss value') plt.xlabel('Iteration') plt.show()
最后
以上就是斯文手套最近收集整理的关于Pytorch:Dataloader和Dataset以及搭建数据部分的步骤Dataloader的全部内容,更多相关Pytorch:Dataloader和Dataset以及搭建数据部分内容请搜索靠谱客的其他文章。
发表评论 取消回复