我是靠谱客的博主 大方哈密瓜,最近开发中收集的这篇文章主要介绍数据集的划分,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

1.main.py

import os

import torch
from torchvision import transforms

from my_dataset import MyDataSet
from utils import read_split_data, plot_data_loader_image

# http://download.tensorflow.org/example_images/flower_photos.tgz
root = "C:/Users/dell/Desktop/dataset/flower_photos"  # 数据集所在根目录

os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(root)
    # 字典
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}
    # 实例化dataset
    train_data_set = MyDataSet(images_path=train_images_path,       # 训练集图像列表
                               images_class=train_images_label,     # 训练集所有图像对应的标签
                               transform=data_transform["train"])   # 预处理方法

    batch_size = 8
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_data_set,
                                               batch_size=batch_size,
                                               shuffle=True,   # 打乱数据集
                                               num_workers=nw,  # 训练时用nw,调试时候建议使用0
                                               collate_fn=train_data_set.collate_fn)

    # plot_data_loader_image(train_loader)

    for step, data in enumerate(train_loader):
        images, labels = data


if __name__ == '__main__':
    main()

2.utils.py

import os
import json
import pickle
import random

import matplotlib.pyplot as plt

# 遍历文件夹,抽取一定比例训练集和测试集
def read_split_data(root: str, val_rate: float = 0.2):
    random.seed(0)  # 随便那个电脑保证随机结果一样额
    assert os.path.exists(root), "dataset root: {} does not exist.".format(root)

    # 遍历文件夹,一个文件夹对应一个类别
    # 遍历root下的每个文件夹,如果这个root下有这个cla文件,那么赋给flower_class,也就是个列表,这个样子就是个列表
    flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
    # 排序,保证顺序一致
    flower_class.sort()
    # 生成类别名称以及对应的数字索引
    # dict是创建字典dict(key:value),这儿是对应文件名和文件名所对应的数字索引
    class_indices = dict((k, v)
                         for v, k in enumerate(flower_class))
    # 编码为json的对象,为了生成class_indices.json
    json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    train_images_path = []  # 存储训练集的所有图片路径
    train_images_label = []  # 存储训练集图片对应索引信息
    val_images_path = []  # 存储验证集的所有图片路径
    val_images_label = []  # 存储验证集图片对应索引信息
    every_class_num = []  # 存储每个类别的样本总数
    supported = [".jpg", ".JPG", ".png", ".PNG"]  # 支持的文件后缀类型
    # 遍历每个文件夹下的文件
    # 与上面的写法等价
    for cla in flower_class:
        cla_path = os.path.join(root, cla)
        # 遍历获取supported支持的所有文件路径
        # 进一步到i嘛,for就是给i赋值,有个判断条件
        # os.path.splitext(i)返回路径名和扩展名,加个-1就是扩展名,因为要求是上面support中文件名
        images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
                  if os.path.splitext(i)[-1] in supported]
        # 获取该类别对应的索引,上面字典中的cla
        image_class = class_indices[cla]
        # 记录该类别的样本数量,并加载在every_class_num里面
        every_class_num.append(len(images))
        # 按比例随机采样验证样本
        val_path = random.sample(images, k=int(len(images) * val_rate))
        # 存入训练集,验证集中
        for img_path in images:
            if img_path in val_path:  # 如果该路径在采样的验证集样本中则存入验证集
                val_images_path.append(img_path)
                val_images_label.append(image_class)
            else:  # 否则存入训练集
                train_images_path.append(img_path)
                train_images_label.append(image_class)

    print("{} images were found in the dataset.".format(sum(every_class_num)))
    print("{} images for training.".format(len(train_images_path)))
    print("{} images for validation.".format(len(val_images_path)))

    plot_image = False
    if plot_image:
        # 绘制每种类别个数柱状图
        plt.bar(range(len(flower_class)), every_class_num, align='center')
        # 将横坐标0,1,2,3,4替换为相应的类别名称
        plt.xticks(range(len(flower_class)), flower_class)
        # 在柱状图上添加数值标签
        for i, v in enumerate(every_class_num):
            plt.text(x=i, y=v + 5, s=str(v), ha='center')
        # 设置x坐标
        plt.xlabel('image class')
        # 设置y坐标
        plt.ylabel('number of images')
        # 设置柱状图的标题
        plt.title('flower class distribution')
        plt.show()

    return train_images_path, train_images_label, val_images_path, val_images_label


def plot_data_loader_image(data_loader):
    batch_size = data_loader.batch_size
    plot_num = min(batch_size, 4)

    json_path = './class_indices.json'
    assert os.path.exists(json_path), json_path + " does not exist."
    json_file = open(json_path, 'r')
    class_indices = json.load(json_file)

    for data in data_loader:
        images, labels = data
        for i in range(plot_num):
            # [C, H, W] -> [H, W, C]
            img = images[i].numpy().transpose(1, 2, 0)
            # 反Normalize操作
            img = (img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255
            label = labels[i].item()
            plt.subplot(1, plot_num, i+1)
            plt.xlabel(class_indices[str(label)])
            plt.xticks([])  # 去掉x轴的刻度
            plt.yticks([])  # 去掉y轴的刻度
            plt.imshow(img.astype('uint8'))
        plt.show()


def write_pickle(list_info: list, file_name: str):
    with open(file_name, 'wb') as f:
        pickle.dump(list_info, f)


def read_pickle(file_name: str) -> list:
    with open(file_name, 'rb') as f:
        info_list = pickle.load(f)
        return info_list

3.my_dataset.py

from PIL import Image
import torch
from torch.utils.data import Dataset


class MyDataSet(Dataset):
    """自定义数据集"""

    def __init__(self, images_path: list, images_class: list, transform=None):
        self.images_path = images_path
        self.images_class = images_class
        self.transform = transform

    def __len__(self):
        return len(self.images_path)

    def __getitem__(self, item):
        img = Image.open(self.images_path[item])
        # RGB为彩色图片,L为灰度图片
        if img.mode != 'RGB':
            raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item]))
        label = self.images_class[item]

        if self.transform is not None:
            img = self.transform(img)

        return img, label

    @staticmethod
    def collate_fn(batch):
        # 官方实现的default_collate可以参考
        # https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
        images, labels = tuple(zip(*batch))  # 将图片和图片放在一起,标签和标签放在一起

        images = torch.stack(images, dim=0)  # 拼接成一个矩阵
        labels = torch.as_tensor(labels)  # 标签转化为tensor
        return images, labels


参考:https://blog.csdn.net/qq_37541097?type=blog

最后

以上就是大方哈密瓜为你收集整理的数据集的划分的全部内容,希望文章能够帮你解决数据集的划分所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(38)

评论列表共有 0 条评论

立即
投稿
返回
顶部