一、概述
初始化DataLoader类时必须注入一个参数dataset,而dataset为自己定义。DataSet类可以继承,但是必须重载__len__()和__getitem__
使用Pytoch封装的DataLoader有以下好处:
①可以自动实现多进程加载
②自动惰性加载,不会占用过多内存
③封装有数据预处理和数据增强等操作,避免重复造轮子
二、自定义DataSet
以Faster R-CNN为例,一般建议至少传入以下参数,方便后续使用:
1
2
3
4
5
6class FRCNNDataset(Dataset): def __init__(self, annotation_lines, input_shape = [600, 600], train = True): self.annotation_lines = annotation_lines #数据集列表 self.length = len(annotation_lines) #数据集大小 self.input_shape = input_shape #输出尺寸 self.train = train #是否训练
然后重载__len__()和__getitem__
1
2def __len__(self): return self.length #直接返回长度
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16def __getitem__(self, index): index = index % self.length #训练时候对数据进行随机增强,但验证时不进行 image, y = self.get_random_data(self.annotation_lines[index], self.input_shape[0:2], random = self.train) #将图片转换成矩阵 image = np.transpose(preprocess_input(np.array(image, dtype=np.float32)), (2, 0, 1)) #编码先验框 box_data = np.zeros((len(y), 5)) if len(y) > 0: box_data[:len(y)] = y box = box_data[:, :4] label = box_data[:, -1] return image, box, label
关于数据增强函数get_random_data(),其中还包含了对图片的无变形缩放功能
1
2
3
4
5
6
7
8
9
10
11def get_random_data(self, annotation_line, input_shape, jitter=.3, hue=.1, sat=0.7, val=0.4, random=True): # 数据经过处理后格式为:地址——(空格)——预测框,使用split函数即可切割出地址和先验框 line = annotation_line.split() # 读取图像并转换为RGB格式 image = Image.open(line[0]) image = cvtColor(image) # 获得图像的高宽与目标高宽 iw, ih = image.size h, w = input_shape # 读取先验框 box = np.array([np.array(list(map(int,box.split(',')))) for box in line[1:]])
仅缩放的无变形缩放功(非训练模式)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26# 在不进行随机数据增强的情况下(非训练模式),直接变形后输出 if not random: #获取变形比例 scale = min(w/iw, h/ih) nw = int(iw*scale) nh = int(ih*scale) dx = (w-nw)//2 dy = (h-nh)//2 # 将图像多余的部分加上灰条 image = image.resize((nw,nh), Image.BICUBIC) new_image = Image.new('RGB', (w,h), (128,128,128)) new_image.paste(image, (dx, dy)) image_data = np.array(new_image, np.float32) # 对真实框进行调整 if len(box)>0: np.random.shuffle(box) box[:, [0,2]] = box[:, [0,2]]*nw/iw + dx box[:, [1,3]] = box[:, [1,3]]*nh/ih + dy box[:, 0:2][box[:, 0:2]<0] = 0 box[:, 2][box[:, 2]>w] = w box[:, 3][box[:, 3]>h] = h box_w = box[:, 2] - box[:, 0] box_h = box[:, 3] - box[:, 1] box = box[np.logical_and(box_w>1, box_h>1)] # discard invalid box #返回图片和先验框 return image_data, box
带数据增强的无变形缩放(训练模式)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55# 对图像进行缩放并且进行长和宽的扭曲 new_ar = iw/ih * self.rand(1-jitter,1+jitter) / self.rand(1-jitter,1+jitter) scale = self.rand(.25, 2) if new_ar < 1: nh = int(scale*h) nw = int(nh*new_ar) else: nw = int(scale*w) nh = int(nw/new_ar) image = image.resize((nw,nh), Image.BICUBIC) # 将图像多余的部分加上灰条 dx = int(self.rand(0, w-nw)) dy = int(self.rand(0, h-nh)) new_image = Image.new('RGB', (w,h), (128,128,128)) new_image.paste(image, (dx, dy)) image = new_image # 翻转图像 flip = self.rand()<.5 if flip: image = image.transpose(Image.FLIP_LEFT_RIGHT) image_data = np.array(image, np.uint8) # 对图像进行色域变换 # 计算色域变换的参数 r = np.random.uniform(-1, 1, 3) * [hue, sat, val] + 1 # 将图像转到HSV上 hue, sat, val = cv2.split(cv2.cvtColor(image_data, cv2.COLOR_RGB2HSV)) dtype = image_data.dtype # 应用变换 x = np.arange(0, 256, dtype=r.dtype) lut_hue = ((x * r[0]) % 180).astype(dtype) lut_sat = np.clip(x * r[1], 0, 255).astype(dtype) lut_val = np.clip(x * r[2], 0, 255).astype(dtype) image_data = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val))) image_data = cv2.cvtColor(image_data, cv2.COLOR_HSV2RGB) # 对真实框进行调整 if len(box)>0: np.random.shuffle(box) box[:, [0,2]] = box[:, [0,2]]*nw/iw + dx box[:, [1,3]] = box[:, [1,3]]*nh/ih + dy if flip: box[:, [0,2]] = w - box[:, [2,0]] box[:, 0:2][box[:, 0:2]<0] = 0 box[:, 2][box[:, 2]>w] = w box[:, 3][box[:, 3]>h] = h box_w = box[:, 2] - box[:, 0] box_h = box[:, 3] - box[:, 1] box = box[np.logical_and(box_w>1, box_h>1)] return image_data, box
关于collate_fn参数
__getitem__一般返回(image,label)样本对,而DataLoader需要一个batch_size用于处理batch样本,以便于批量训练。
默认的default_collate(batch)函数仅能对尺寸一致且batch_size相同的image进行整理,如将(img0,lbl0),(img1,lbl1),(img2,lbl2)整合为([img0,img1,img2],[lbl0,lbl1,lbl2]),如图像中含有box等参数则需要自定义处理
1
2
3
4
5
6
7
8
9
10def frcnn_dataset_collate(batch): images = [] bboxes = [] labels = [] for img, box, label in batch: images.append(img) bboxes.append(box) labels.append(label) images = torch.from_numpy(np.array(images)) return images, bboxes, labels
三、语义分割与目标检测DataSet的区别
①在__getitem__中不需要获取box值,转而获取标志图png。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21def __getitem__(self, index): annotation_line = self.annotation_lines[index] name = annotation_line.split()[0] # 从文件中读取图像 jpg = Image.open(os.path.join(os.path.join(self.dataset_path, "VOC2007/JPEGImages"), name + ".jpg")) png = Image.open(os.path.join(os.path.join(self.dataset_path, "VOC2007/SegmentationClass"), name + ".png")) # 数据增强 jpg, png = self.get_random_data(jpg, png, self.input_shape, random = self.train) jpg = np.transpose(preprocess_input(np.array(jpg, np.float64)), [2,0,1]) png = np.array(png) png[png >= self.num_classes] = self.num_classes # 转化成one_hot的形式 # 在这里需要+1是因为voc数据集有些标签具有白边部分 seg_labels = np.eye(self.num_classes + 1)[png.reshape([-1])] seg_labels = seg_labels.reshape((int(self.input_shape[0]), int(self.input_shape[1]), self.num_classes + 1)) return jpg, png, seg_labels
②get_random_data变形时需要对两张图做同样的变换
1
2
3
4
5
6
7
8
9
10
11
12
13
14if not random: iw, ih = image.size scale = min(w/iw, h/ih) nw = int(iw*scale) nh = int(ih*scale) image = image.resize((nw,nh), Image.BICUBIC) new_image = Image.new('RGB', [w, h], (128,128,128)) new_image.paste(image, ((w-nw)//2, (h-nh)//2)) label = label.resize((nw,nh), Image.NEAREST) new_label = Image.new('L', [w, h], (0)) new_label.paste(label, ((w-nw)//2, (h-nh)//2)) return new_image, new_label
③collate_fn需要进行修改
1
2
3
4
5
6
7
8
9
10
11
12def deeplab_dataset_collate(batch): images = [] pngs = [] seg_labels = [] for img, png, labels in batch: images.append(img) pngs.append(png) seg_labels.append(labels) images = torch.from_numpy(np.array(images)).type(torch.FloatTensor) pngs = torch.from_numpy(np.array(pngs)).long() seg_labels = torch.from_numpy(np.array(seg_labels)).type(torch.FloatTensor) return images, pngs, seg_labels
四、在训练过程中的调用
①读取文件集(经处理的txt文件)
1
2
3
4
5
6
7with open(train_annotation_path, encoding='utf-8') as f: train_lines = f.readlines() with open(val_annotation_path, encoding='utf-8') as f: val_lines = f.readlines() #获取数据集长度 num_train = len(train_lines) num_val = len(val_lines)
②检查数据集是否符合要求
这里一般检查数据集是否足够大,也可不检查
③将数据集装入DataSet中
1
2train_dataset = MyDataset(train_lines, input_shape, anchors, batch_size, num_classes, train = True) val_dataset = MyDataset(val_lines, input_shape, anchors, batch_size, num_classes, train = False)
④将DataSet放入DataLoader中
关于dataloader:一般有以下5个参数:
1.dataset:数据集对象,dataset型
2.batch_size:批大小,int型
3.shuffe:每一轮epoch是否重新洗牌,bool型
4.num_workers:多进程读取
5.drop_last:当样本不能被batch_size取整时,是否丢弃最后一批数据,bool型
1
2
3
4gen = DataLoader(train_dataset, shuffle = shuffle, batch_size = batch_size, num_workers = num_workers, pin_memory=True, drop_last=True, collate_fn=ssd_dataset_collate, sampler=train_sampler) gen_val = DataLoader(val_dataset , shuffle = shuffle, batch_size = batch_size, num_workers = num_workers, pin_memory=True, drop_last=True, collate_fn=ssd_dataset_collate, sampler=val_sampler)
最后
以上就是健忘老师最近收集整理的关于[Pytorch]将自己的数据集载入dataloader一、概述二、自定义DataSet三、语义分割与目标检测DataSet的区别四、在训练过程中的调用的全部内容,更多相关[Pytorch]将自己内容请搜索靠谱客的其他文章。
发表评论 取消回复