1.main.py
复制代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52import os import torch from torchvision import transforms from my_dataset import MyDataSet from utils import read_split_data, plot_data_loader_image # http://download.tensorflow.org/example_images/flower_photos.tgz root = "C:/Users/dell/Desktop/dataset/flower_photos" # 数据集所在根目录 os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' def main(): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print("using {} device.".format(device)) train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(root) # 字典 data_transform = { "train": transforms.Compose([transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]), "val": transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])} # 实例化dataset train_data_set = MyDataSet(images_path=train_images_path, # 训练集图像列表 images_class=train_images_label, # 训练集所有图像对应的标签 transform=data_transform["train"]) # 预处理方法 batch_size = 8 nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers print('Using {} dataloader workers'.format(nw)) train_loader = torch.utils.data.DataLoader(train_data_set, batch_size=batch_size, shuffle=True, # 打乱数据集 num_workers=nw, # 训练时用nw,调试时候建议使用0 collate_fn=train_data_set.collate_fn) # plot_data_loader_image(train_loader) for step, data in enumerate(train_loader): images, labels = data if __name__ == '__main__': main()
2.utils.py
复制代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116import os import json import pickle import random import matplotlib.pyplot as plt # 遍历文件夹,抽取一定比例训练集和测试集 def read_split_data(root: str, val_rate: float = 0.2): random.seed(0) # 随便那个电脑保证随机结果一样额 assert os.path.exists(root), "dataset root: {} does not exist.".format(root) # 遍历文件夹,一个文件夹对应一个类别 # 遍历root下的每个文件夹,如果这个root下有这个cla文件,那么赋给flower_class,也就是个列表,这个样子就是个列表 flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))] # 排序,保证顺序一致 flower_class.sort() # 生成类别名称以及对应的数字索引 # dict是创建字典dict(key:value),这儿是对应文件名和文件名所对应的数字索引 class_indices = dict((k, v) for v, k in enumerate(flower_class)) # 编码为json的对象,为了生成class_indices.json json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4) with open('class_indices.json', 'w') as json_file: json_file.write(json_str) train_images_path = [] # 存储训练集的所有图片路径 train_images_label = [] # 存储训练集图片对应索引信息 val_images_path = [] # 存储验证集的所有图片路径 val_images_label = [] # 存储验证集图片对应索引信息 every_class_num = [] # 存储每个类别的样本总数 supported = [".jpg", ".JPG", ".png", ".PNG"] # 支持的文件后缀类型 # 遍历每个文件夹下的文件 # 与上面的写法等价 for cla in flower_class: cla_path = os.path.join(root, cla) # 遍历获取supported支持的所有文件路径 # 进一步到i嘛,for就是给i赋值,有个判断条件 # os.path.splitext(i)返回路径名和扩展名,加个-1就是扩展名,因为要求是上面support中文件名 images = [os.path.join(root, cla, i) for i in os.listdir(cla_path) if os.path.splitext(i)[-1] in supported] # 获取该类别对应的索引,上面字典中的cla image_class = class_indices[cla] # 记录该类别的样本数量,并加载在every_class_num里面 every_class_num.append(len(images)) # 按比例随机采样验证样本 val_path = random.sample(images, k=int(len(images) * val_rate)) # 存入训练集,验证集中 for img_path in images: if img_path in val_path: # 如果该路径在采样的验证集样本中则存入验证集 val_images_path.append(img_path) val_images_label.append(image_class) else: # 否则存入训练集 train_images_path.append(img_path) train_images_label.append(image_class) print("{} images were found in the dataset.".format(sum(every_class_num))) print("{} images for training.".format(len(train_images_path))) print("{} images for validation.".format(len(val_images_path))) plot_image = False if plot_image: # 绘制每种类别个数柱状图 plt.bar(range(len(flower_class)), every_class_num, align='center') # 将横坐标0,1,2,3,4替换为相应的类别名称 plt.xticks(range(len(flower_class)), flower_class) # 在柱状图上添加数值标签 for i, v in enumerate(every_class_num): plt.text(x=i, y=v + 5, s=str(v), ha='center') # 设置x坐标 plt.xlabel('image class') # 设置y坐标 plt.ylabel('number of images') # 设置柱状图的标题 plt.title('flower class distribution') plt.show() return train_images_path, train_images_label, val_images_path, val_images_label def plot_data_loader_image(data_loader): batch_size = data_loader.batch_size plot_num = min(batch_size, 4) json_path = './class_indices.json' assert os.path.exists(json_path), json_path + " does not exist." json_file = open(json_path, 'r') class_indices = json.load(json_file) for data in data_loader: images, labels = data for i in range(plot_num): # [C, H, W] -> [H, W, C] img = images[i].numpy().transpose(1, 2, 0) # 反Normalize操作 img = (img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255 label = labels[i].item() plt.subplot(1, plot_num, i+1) plt.xlabel(class_indices[str(label)]) plt.xticks([]) # 去掉x轴的刻度 plt.yticks([]) # 去掉y轴的刻度 plt.imshow(img.astype('uint8')) plt.show() def write_pickle(list_info: list, file_name: str): with open(file_name, 'wb') as f: pickle.dump(list_info, f) def read_pickle(file_name: str) -> list: with open(file_name, 'rb') as f: info_list = pickle.load(f) return info_list
3.my_dataset.py
复制代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40from PIL import Image import torch from torch.utils.data import Dataset class MyDataSet(Dataset): """自定义数据集""" def __init__(self, images_path: list, images_class: list, transform=None): self.images_path = images_path self.images_class = images_class self.transform = transform def __len__(self): return len(self.images_path) def __getitem__(self, item): img = Image.open(self.images_path[item]) # RGB为彩色图片,L为灰度图片 if img.mode != 'RGB': raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item])) label = self.images_class[item] if self.transform is not None: img = self.transform(img) return img, label @staticmethod def collate_fn(batch): # 官方实现的default_collate可以参考 # https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py images, labels = tuple(zip(*batch)) # 将图片和图片放在一起,标签和标签放在一起 images = torch.stack(images, dim=0) # 拼接成一个矩阵 labels = torch.as_tensor(labels) # 标签转化为tensor return images, labels
参考:https://blog.csdn.net/qq_37541097?type=blog
最后
以上就是大方哈密瓜最近收集整理的关于数据集的划分的全部内容,更多相关数据集内容请搜索靠谱客的其他文章。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复