Overview
# -*- coding: utf-8 -*-
"""
# @file name : train_lenet.py
# @author : tingsongyu
# @date : 2019-09-07 10:08:00
# @brief : RMB classification model training
"""
import os
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from matplotlib import pyplot as plt
from model.lenet import LeNet
from tools.my_dataset import RMBDataset
def set_seed(seed=1):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
set_seed()  # set the random seed
rmb_label = {"1": 0, "100": 1}
# hyperparameters
MAX_EPOCH = 10
BATCH_SIZE = 16
LR = 0.01
log_interval = 10
val_interval = 1
# ============================ step 1/5 data ============================
split_dir = os.path.join("..", "data", "rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]
train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])
valid_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])
# 【1】Build the MyDataset instances (these must be user-defined) --> Ctrl+click --> my_dataset.py
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)
# 【2】With the Datasets in place, build the DataLoaders
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)
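
# Sanity check (hypothetical snippet, not part of the original script): with
# BATCH_SIZE = 16 and the 32x32 RGB transforms above, one training batch should be
#   inputs, labels = next(iter(train_loader))
#   inputs.shape  -> torch.Size([16, 3, 32, 32])
#   labels.shape  -> torch.Size([16])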
# ============================ step 2/5 model ============================
net = LeNet(classes=2)  # (a sketch of a plausible LeNet is given after the script)
net.initialize_weights()
# ============================ step 3/5 loss function ============================
criterion = nn.CrossEntropyLoss()  # choose the loss function
# ============================ step 4/5 optimizer ============================
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)  # choose the optimizer
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)  # learning-rate decay policy
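
# Note: with step_size=10 and MAX_EPOCH=10, StepLR multiplies the LR by gamma=0.1
# only at the 10th scheduler.step(), i.e. after the last epoch, so this run
# effectively trains at a constant lr=0.01. A standalone check (hypothetical snippet):
#   sch = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
#   for _ in range(10): optimizer.step(); sch.step()
#   print(sch.get_last_lr())  # [0.001] -- decayed only after the 10th step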
# ============================ step 5/5 training ============================
train_curve = list()
valid_curve = list()
# 【3】Training proceeds epoch by epoch; each epoch runs multiple iterations
for epoch in range(MAX_EPOCH):

    loss_mean = 0.
    correct = 0.
    total = 0.

    net.train()
    # 【4】Fetching the data --> debug to see how PyTorch loads it --> dataloader.py
    for i, data in enumerate(train_loader):

        # forward
        inputs, labels = data
        outputs = net(inputs)

        # backward: compute the gradients
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()

        # update weights
        optimizer.step()

        # accumulate classification statistics
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).squeeze().sum().numpy()

        # print training info
        loss_mean += loss.item()
        train_curve.append(loss.item())
        if (i+1) % log_interval == 0:
            loss_mean = loss_mean / log_interval
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))
            loss_mean = 0.

    scheduler.step()  # update the learning rate

    # validate the model: evaluate on the validation set each epoch to watch for overfitting
    if (epoch+1) % val_interval == 0:

        correct_val = 0.
        total_val = 0.
        loss_val = 0.
        net.eval()
        with torch.no_grad():
            for j, data in enumerate(valid_loader):
                inputs, labels = data
                outputs = net(inputs)
                loss = criterion(outputs, labels)

                _, predicted = torch.max(outputs.data, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).squeeze().sum().numpy()

                loss_val += loss.item()

            loss_val_epoch = loss_val / len(valid_loader)
            valid_curve.append(loss_val_epoch)
            # valid_curve.append(loss.item())  # changed 2019-10-22: record the loss over the whole epoch, remembering to average it
            print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, j+1, len(valid_loader), loss_val_epoch, correct_val / total_val))
train_x = range(len(train_curve))
train_y = train_curve
train_iters = len(train_loader)
valid_x = np.arange(1, len(valid_curve)+1) * train_iters * val_interval  # valid_curve holds one loss per epoch, so map each record point to its iteration index
valid_y = valid_curve
plt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')
plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()
# ============================ inference ============================
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
test_dir = os.path.join(BASE_DIR, "test_data")
test_data = RMBDataset(data_dir=test_dir, transform=valid_transform)
test_loader = DataLoader(dataset=test_data, batch_size=1)

for i, data in enumerate(test_loader):
    # forward
    inputs, labels = data
    outputs = net(inputs)
    _, predicted = torch.max(outputs.data, 1)

    rmb = 1 if predicted.numpy()[0] == 0 else 100
    print("The model predicts a {}-yuan note".format(rmb))
1. In dataloader.py --> the loader dispatches to a single-process or multi-process iterator; take the single-process path as the example:
def __iter__(self):
    if self.num_workers == 0:
        return _SingleProcessDataLoaderIter(self)
    else:
        return _MultiProcessingDataLoaderIter(self)
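
A quick way to observe this dispatch from user code (the iterator class names are
private implementation details and may differ across PyTorch versions):

import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.randn(8, 3), torch.zeros(8))
it = iter(DataLoader(ds, batch_size=4, num_workers=0))
print(type(it).__name__)  # _SingleProcessDataLoaderIter
# with num_workers > 0, iter() returns a _MultiProcessingDataLoaderIter instead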
2. Step Into _SingleProcessDataLoaderIter(self):

class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_SingleProcessDataLoaderIter, self).__init__(loader)
        assert self.timeout == 0
        assert self.num_workers == 0
        self.dataset_fetcher = _DatasetKind.create_fetcher(
            self.dataset_kind, self.dataset, self.auto_collation, self.collate_fn, self.drop_last)

    def __next__(self):
        index = self._next_index()  # may raise StopIteration
        data = self.dataset_fetcher.fetch(index)  # may raise StopIteration
        if self.pin_memory:
            data = _utils.pin_memory.pin_memory(data)
        return data

    next = __next__  # Python 2 compatibility
3. The key method in the single-process iterator is __next__, which determines which data gets read in each iteration.
4. At index = self._next_index()  # may raise StopIteration --> Run to Cursor,
then Step Into --> sampler.py: the sampler tells us which indices each iteration's batch should read:
def _next_index(self):
    return next(self.sampler_iter)  # may raise StopIteration --> Step Into
def __iter__(self):
    batch = []
    for idx in self.sampler:
        batch.append(idx)
        if len(batch) == self.batch_size:
            yield batch
            batch = []
    if len(batch) > 0 and not self.drop_last:
        yield batch
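
This batching logic can be exercised in isolation (a minimal sketch using
SequentialSampler so the output is deterministic; the DataLoader above uses a
random sampler because shuffle=True):

from torch.utils.data import BatchSampler, SequentialSampler

sampler = BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False)
print(list(sampler))  # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
# with drop_last=True the trailing [9] would be discarded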
5. Step Out back to dataloader.py --> data = self.dataset_fetcher.fetch(index), which feeds the indices into the dataset to fetch the data --> Step Into --> fetch.py:
class _MapDatasetFetcher(_BaseDatasetFetcher):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)

    # the actual data reading happens here
    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            # gather the individual samples into a list
            data = [self.dataset[idx] for idx in possibly_batched_index]
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)
At data = [self.dataset[idx] for idx in possibly_batched_index] --> Step Into --> my_dataset.py --> def __getitem__(self, index).
To recap the whole debug session end to end:

1. Set a breakpoint at for i, data in enumerate(train_loader): --> Step Into --> the process-count check.
2. def __iter__(self): --> here we enter the single-process branch --> Step Into:

    if self.num_workers == 0:
        return _SingleProcessDataLoaderIter(self)
    else:
        return _MultiProcessingDataLoaderIter(self)
3. def __next__(self):
       index = self._next_index()  # may raise StopIteration --> Run to Cursor
       data = self.dataset_fetcher.fetch(index)  # may raise StopIteration --> Step Over
       # fetch the data according to the indices
       if self.pin_memory:
           data = _utils.pin_memory.pin_memory(data)
       return data
   Here the indices are obtained; they tell us which samples to read --> Step Over.
4. We arrive in fetch.py:

    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            # here the dataset is indexed --> Step Into --> my_dataset.py
            data = [self.dataset[idx] for idx in possibly_batched_index]
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)
5. my_dataset.py:

    def __getitem__(self, index):
        # look up the image path and label for this index
        path_img, label = self.data_info[index]
        img = Image.open(path_img).convert('RGB')  # pixel values 0~255, loaded as a PIL image
        if self.transform is not None:
            # apply the transforms here: convert to tensor, etc. --> Run to Cursor --> Step Into
            img = self.transform(img)
        return img, label
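
The rest of RMBDataset is not shown in this post. A minimal sketch consistent with
the __getitem__ above (the directory layout -- one subfolder per class, keyed by
rmb_label, holding .jpg files -- is an assumption):

import os
from PIL import Image
from torch.utils.data import Dataset

rmb_label = {"1": 0, "100": 1}

class RMBDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        # data_info: list of (image_path, class_index) pairs
        self.data_info = self.get_img_info(data_dir)
        self.transform = transform

    def __getitem__(self, index):
        path_img, label = self.data_info[index]
        img = Image.open(path_img).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        return len(self.data_info)

    @staticmethod
    def get_img_info(data_dir):
        # assumption: data_dir contains one subfolder per class ("1", "100")
        data_info = []
        for sub_dir in os.listdir(data_dir):
            class_dir = os.path.join(data_dir, sub_dir)
            if not os.path.isdir(class_dir) or sub_dir not in rmb_label:
                continue
            for name in os.listdir(class_dir):
                if name.endswith('.jpg'):
                    data_info.append((os.path.join(class_dir, name), rmb_label[sub_dir]))
        return data_info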
6. transforms.py:

    def __call__(self, img):
        for t in self.transforms:  # call each preprocessing step registered in Compose, in order
            img = t(img)           # --> Step Over
        return img
The transform is thus invoked inside __getitem__: that is where the preprocessing runs and where a single sample is returned --> Step Out back to fetch.py.
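
To reproduce what Compose.__call__ just did outside the debugger, the valid-side
pipeline defined earlier can be applied to a dummy image (a minimal sketch; the
solid-color PIL image stands in for a real RMB photo):

from PIL import Image
import torchvision.transforms as transforms

norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]
valid_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

img = Image.new('RGB', (224, 224), color=(128, 64, 32))  # dummy PIL image
x = valid_transform(img)  # Compose applies each transform in order
print(x.shape, x.dtype)   # torch.Size([3, 32, 32]) torch.float32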
7. Back in fetch:

    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            # loop over the indices, collecting one batch-size worth of samples
            data = [self.dataset[idx] for idx in possibly_batched_index]
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)
Here collate_fn is called to assemble the list of samples into batch form, completing the path from index sampling to a ready batch.
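
The default collate_fn stacks per-sample tensors into batch tensors. A minimal
sketch (default_collate is importable from torch.utils.data in recent PyTorch
versions; older ones keep it in torch.utils.data.dataloader):

import torch
from torch.utils.data import default_collate

samples = [(torch.randn(3, 32, 32), 0), (torch.randn(3, 32, 32), 1)]  # two (img, label) pairs
imgs, labels = default_collate(samples)
print(imgs.shape)  # torch.Size([2, 3, 32, 32])
print(labels)      # tensor([0, 1])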