3. DataSet and DataLoader

Overview

The script below trains a LeNet on the RMB (renminbi) banknote classification task in five steps (data, model, loss function, optimizer, training), and the debugging notes that follow trace how DataLoader actually reads a batch of data.

# -*- coding: utf-8 -*-
"""
# @file name  : train_lenet.py
# @author     : tingsongyu
# @date       : 2019-09-07 10:08:00
# @brief      : RMB (banknote) classification model training
"""
import os
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from matplotlib import pyplot as plt
from model.lenet import LeNet
from tools.my_dataset import RMBDataset


def set_seed(seed=1):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


set_seed()  # set the random seed
rmb_label = {"1": 0, "100": 1}

# hyperparameter settings
MAX_EPOCH = 10
BATCH_SIZE = 16
LR = 0.01
log_interval = 10
val_interval = 1

# ============================ step 1/5 data ============================

split_dir = os.path.join("..", "data", "rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")

norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

valid_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# [1] Build the MyDataset instances (the Dataset class must be written by the user) --> Ctrl+click --> my_dataset.py
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

# [2] With the Dataset built, the DataLoader can be constructed
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)
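# Other commonly used DataLoader arguments (left at their defaults here):
# num_workers sets how many worker subprocesses load data (0 = load in the main
# process, which is the code path traced in the notes below), and drop_last
# decides whether an incomplete final batch is discarded.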

# ============================ step 2/5 model ============================

net = LeNet(classes=2)
net.initialize_weights()

# ============================ step 3/5 loss function ============================
criterion = nn.CrossEntropyLoss()                                                   # choose the loss function

# ============================ step 4/5 optimizer ============================
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)                        # choose the optimizer
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)     # LR schedule: multiply LR by gamma=0.1 every step_size=10 epochs

# ============================ step 5/5 training ============================
train_curve = list()
valid_curve = list()

# [3] Training runs in epochs; each epoch consists of multiple iterations
for epoch in range(MAX_EPOCH):

    loss_mean = 0.
    correct = 0.
    total = 0.

    net.train()

    # [4] Data fetching --> step through the debugger to see how PyTorch reads data --> dataloader.py
    for i, data in enumerate(train_loader):

        # forward
        inputs, labels = data
        outputs = net(inputs)

        # backward: compute the gradients
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()

        # update weights
        optimizer.step()

        # accumulate classification statistics
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).squeeze().sum().numpy()

        # print training information
        loss_mean += loss.item()
        train_curve.append(loss.item())
        if (i+1) % log_interval == 0:
            loss_mean = loss_mean / log_interval
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))
            loss_mean = 0.

    scheduler.step()  # update the learning rate

    # validate the model: run the validation set every epoch to watch for overfitting
    if (epoch+1) % val_interval == 0:

        correct_val = 0.
        total_val = 0.
        loss_val = 0.
        net.eval()
        with torch.no_grad():
            for j, data in enumerate(valid_loader):
                inputs, labels = data
                outputs = net(inputs)
                loss = criterion(outputs, labels)

                _, predicted = torch.max(outputs.data, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).squeeze().sum().numpy()

                loss_val += loss.item()

            loss_val_epoch = loss_val / len(valid_loader)
            valid_curve.append(loss_val_epoch)
            # valid_curve.append(loss.item())    # changed 2019-10-22: record the loss over all samples in the epoch, remember to average
            print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, j+1, len(valid_loader), loss_val_epoch, correct_val / total_val))


train_x = range(len(train_curve))
train_y = train_curve

train_iters = len(train_loader)
valid_x = np.arange(1, len(valid_curve)+1) * train_iters * val_interval  # valid_curve stores one loss per epoch, so convert the record points to iterations
valid_y = valid_curve

plt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')

plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()

# ============================ inference ============================

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
test_dir = os.path.join(BASE_DIR, "test_data")

test_data = RMBDataset(data_dir=test_dir, transform=valid_transform)
valid_loader = DataLoader(dataset=test_data, batch_size=1)

for i, data in enumerate(valid_loader):
    # forward
    inputs, labels = data
    outputs = net(inputs)
    _, predicted = torch.max(outputs.data, 1)

    rmb = 1 if predicted.numpy()[0] == 0 else 100
    print("模型获得{}元".format(rmb))

How the DataLoader reads data (tracing dataloader.py in the debugger)

1. In dataloader.py, __iter__ decides between the single-process and multi-process paths; we follow the single-process case (a minimal reproduction follows the snippet):
def __iter__(self):
    if self.num_workers == 0:
        return _SingleProcessDataLoaderIter(self)
    else:
        return _MultiProcessingDataLoaderIter(self)
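A minimal way to hit this single-process path yourself (illustrative only, using a throwaway TensorDataset instead of the RMB data): with num_workers=0, the default, iteration goes through _SingleProcessDataLoaderIter.

import torch
from torch.utils.data import TensorDataset, DataLoader

ds = TensorDataset(torch.arange(10).float().unsqueeze(1), torch.arange(10))
loader = DataLoader(ds, batch_size=4, shuffle=False, num_workers=0)  # single-process path
for x, y in loader:
    print(x.shape, y)  # batches of 4; the last batch holds the remaining 2 samples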

2. Step Into _SingleProcessDataLoaderIter(self)
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_SingleProcessDataLoaderIter, self).__init__(loader)
        assert self.timeout == 0
        assert self.num_workers == 0

        self.dataset_fetcher = _DatasetKind.create_fetcher(
            self.dataset_kind, self.dataset, self.auto_collation, self.collate_fn, self.drop_last)

    def __next__(self):
        index = self._next_index()  # may raise StopIteration
        data = self.dataset_fetcher.fetch(index)  # may raise StopIteration
        if self.pin_memory:
            data = _utils.pin_memory.pin_memory(data)
        return data

    next = __next__  # Python 2 compatibility

3. In the single-process iterator the key function is __next__(self);
it determines which samples are read in each iteration.

4. At index = self._next_index()  # may raise StopIteration --> Run to Cursor
Step Into --> sampler.py: the sampler tells us which indices the batch should read in each iteration (a standalone check follows the snippet below).

def _next_index(self):
    return next(self.sampler_iter)  # may raise StopIteration --> Step Into

def __iter__(self):
    batch = []
    for idx in self.sampler:
        batch.append(idx)
        if len(batch) == self.batch_size:
            yield batch
            batch = []
    if len(batch) > 0 and not self.drop_last:
        yield batch
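A standalone check of that behavior using torch.utils.data.BatchSampler (the RMB train_loader, built with shuffle=True, wraps a RandomSampler in a BatchSampler in the same way):

from torch.utils.data import SequentialSampler, BatchSampler

batches = BatchSampler(SequentialSampler(range(10)), batch_size=4, drop_last=False)
print(list(batches))  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]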


5. Step Out returns to dataloader.py --> data = self.dataset_fetcher.fetch(index):
the indices are fed to the dataset to fetch the data --> Step Into --> fetch.py (a standalone check of collate_fn follows the fetcher code).

class _MapDatasetFetcher(_BaseDatasetFetcher):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)


    # This is where the actual data reading happens
    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            # collect the individual samples into a list
            data = [self.dataset[idx] for idx in possibly_batched_index]
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)
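What collate_fn then does with that list can be checked in isolation: the default collate function stacks the per-sample tensors into batch tensors. A hypothetical two-sample batch (not the RMB data):

import torch
from torch.utils.data._utils.collate import default_collate  # also exposed as torch.utils.data.default_collate in newer releases

samples = [(torch.zeros(3, 32, 32), 0), (torch.ones(3, 32, 32), 1)]  # (img, label) pairs as returned by __getitem__
imgs, labels = default_collate(samples)
print(imgs.shape, labels)  # torch.Size([2, 3, 32, 32]) tensor([0, 1])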

At data = [self.dataset[idx] for idx in possibly_batched_index] --> Step Into --> my_dataset.py --> def __getitem__(self, index):

A second pass through the same call chain, this time starting from the training loop:

1. Set a breakpoint at for i, data in enumerate(train_loader): --> Step Into --> the single/multi-process dispatch

2. def __iter__(self): --> this takes the single-process branch --> Step Into
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            return _MultiProcessingDataLoaderIter(self)

3.    def __next__(self):
        index = self._next_index()  # may raise StopIteration --> Run to Cursor
        data = self.dataset_fetcher.fetch(index)  # may raise StopIteration --> Step Over
        # fetch the data for the given indices

        if self.pin_memory:
            data = _utils.pin_memory.pin_memory(data)
        return data
At this point we have the index list, which tells us which samples will be read --> Step Over

4. We arrive at fetch.py

    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            
            # the dataset is called here --> Step Into --> my_dataset.py
            data = [self.dataset[idx] for idx in possibly_batched_index]
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)    
 
5. my_dataset.py (a fuller sketch of this Dataset follows the method below)
    def __getitem__(self, index):
        # use the index to fetch the image path and label
        path_img, label = self.data_info[index]
        img = Image.open(path_img).convert('RGB')     # 0~255, read as a PIL image

        if self.transform is not None:
            img = self.transform(img)   # transforms run here: convert to tensor, etc. --> Run to Cursor --> Step Into
        return img, label
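For reference, a minimal sketch of what the full RMBDataset in my_dataset.py most likely looks like; the __getitem__ above comes from the course code, while __init__, __len__ and get_img_info below are an assumption reconstructed from how the split directories and rmb_label are used:

import os
from PIL import Image
from torch.utils.data import Dataset

rmb_label = {"1": 0, "100": 1}

class RMBDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        # data_info is a list of (image_path, label) pairs built once up front
        self.data_info = self.get_img_info(data_dir)
        self.transform = transform

    def __getitem__(self, index):
        path_img, label = self.data_info[index]
        img = Image.open(path_img).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        # the DataLoader's sampler needs the dataset length
        return len(self.data_info)

    @staticmethod
    def get_img_info(data_dir):
        # assumes data_dir holds one sub-directory per class ("1", "100")
        data_info = []
        for sub_dir in os.listdir(data_dir):
            class_dir = os.path.join(data_dir, sub_dir)
            if not os.path.isdir(class_dir):
                continue
            for name in os.listdir(class_dir):
                if name.endswith('.jpg'):
                    data_info.append((os.path.join(class_dir, name), rmb_label[sub_dir]))
        return data_info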
        
6. transforms.py
    def __call__(self, img):
        for t in self.transforms:  # call the preprocessing methods from Compose, in order
            img = t(img)            # --> Step Over
        return img
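The loop above is all Compose does; a standalone illustration, fabricating a PIL image rather than reading one from disk:

from PIL import Image
import torchvision.transforms as transforms

img = Image.new("RGB", (100, 60), color=(128, 64, 32))  # stand-in for Image.open(path_img)
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
print(transform(img).shape)  # torch.Size([3, 32, 32])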
The transform is called inside the __getitem__ function; that is where data preprocessing happens and a single sample is returned --> back to fetch.py

7.    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            # loop over the indices to gather one batch_size worth of samples
            data = [self.dataset[idx] for idx in possibly_batched_index]
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)
Here collate_fn is called to assemble the samples into batch data (see the default_collate check earlier).

 
