我是靠谱客的博主 忧伤皮带,最近开发中收集的这篇文章主要介绍PyTorch - 多任务网络之年龄与性别预测数据集类的构建多任务网络的搭建与训练,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

目录

  • 数据集类的构建
    • torch.utils.data.Dataset
    • implementation
      • 导入包
      • __init__()
      • __len__()
      • __getitem__()
    • 检验数据集类
      • 手动提取数据
      • DataLoader
  • 多任务网络的搭建与训练
    • 导入包
    • 检查是否可使用GPU
    • 多任务网络的搭建
      • 设置网络(__init__())
      • forward()
    • 训练过程

数据集类的构建

本实例使用的是UTKFace数据集,包含了两万多张不同种族的不同年龄的人脸图片

torch.utils.data.Dataset

是一个抽象类, 自定义的Dataset需要继承它并且实现两个成员方法:

  1. getitem()
  2. len()

第一个最为重要,即每次怎么读数据;
第二个比较简单, 就是返回整个数据集的长度。

implementation

导入包

import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import os
import cv2 as cv

_init_()

torchvision.transforms.Normalize(mean, std) 用法
torchvision.transforms.ToTensor() 用法
os.listdir() 用法
.split() 用法
os.path() 模块

class AgeGenderDataset(Dataset):
def __init__(self, root_dir):
# Normalize: image => [-1, 1]
(利于更好的训练)
# ToTensor() => Converts a PIL Image or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
self.transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5]),
transforms.Resize((64, 64))
])
img_files = os.listdir(root_dir) #存放的是所有图片的文件名
# age: 0 ~116, 0 :male, 1 :female
self.ages = []
self.genders = []
# 注意:self.image存放的是图片的路径
self.images = []
for file_name in img_files:
age_gender_group = file_name.split("_")
age_ = age_gender_group[0]
gender_ = age_gender_group[1]
self.genders.append(np.float32(gender_))
# 将age缩小到[0, 1]的范围内
self.ages.append(np.float32(age_)/max_age)
# os.path.join() 将路径和文件名合成为一个路径
self.images.append(os.path.join(root_dir, file_name))

_len_()


def __len__(self):
return len(self.images)

_getitem_()

imread() 用法
torch.from_numpy(ndarray) 用法


def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()
image_path = self.images[idx]
else:
image_path = self.images[idx]
img = cv.imread(image_path)
# BGR order
sample = {'image': self.transform(img), 'age': self.ages[idx], 'gender': self.genders[idx]}
# 返回一个字典形式
return sample

检验数据集类

手动提取数据

if __name__ == "__main__":
ds = AgeGenderDataset("W:/data_PyTorch/UTKFace")
for i in range(len(ds)):
sample = ds[i]
print(i, sample['image'].size(), sample['age'])
# 提取一个batch的数据
if i == 3:
break

DataLoader


dataloader = DataLoader(ds, batch_size=4, shuffle=True, num_workers=4)
# enumerate将可迭代对象组合为索引序列
例如:[(0, 'Tom'), (1, 'Jerry')]
for i_batch, sample_batched in enumerate(dataloader):
print(i_batch, sample_batched['image'].size(), sample_batched['gender'])
break

多任务网络的搭建与训练

导入包

import torch
from age_gender_dataset import AgeGenderDataset
from torch.utils.data import DataLoader

检查是否可使用GPU

# 检查是否可以利用GPU
train_on_gpu = torch.cuda.is_available()
if not train_on_gpu:
print('CUDA is not available.')
else:
print('CUDA is available!')

多任务网络的搭建

设置网络(_init_())

BatchNorm2d(num_channel) 用法
torch.nn.AdaptiveMaxPool2d(output_size) 用法


class MyMulitpleTaskNet(torch.nn.Module):
def __init__(self):
super(MyMulitpleTaskNet, self).__init__()
self.cnn_layers = torch.nn.Sequential(
# x => [3, 64, 64]
torch.nn.Conv2d(3, 32, 3, padding=1),
torch.nn.ReLU(),
torch.nn.BatchNorm2d(32),
torch.nn.MaxPool2d(2, 2),
# x => [32, 32, 32]
torch.nn.Conv2d(32, 64, 3, padding=1),
torch.nn.ReLU(),
torch.nn.BatchNorm2d(64),
torch.nn.MaxPool2d(2, 2),
# x => [64, 16, 16]
torch.nn.Conv2d(64, 96, 3, padding=1),
torch.nn.ReLU(),
torch.nn.BatchNorm2d(96),
torch.nn.MaxPool2d(2, 2),
# x => [96, 8, 8]
torch.nn.Conv2d(96, 128, 3, padding=1),
torch.nn.ReLU(),
torch.nn.BatchNorm2d(128),
torch.nn.MaxPool2d(2, 2),
# x => [128, 4, 4]
torch.nn.Conv2d(128, 196, 3, padding=1),
torch.nn.ReLU(),
torch.nn.BatchNorm2d(196),
torch.nn.MaxPool2d(2, 2)
# x => [196, 2, 2]
)
# 全局最大池化
self.global_max_pooling = torch.nn.AdaptiveMaxPool2d((1, 1))
# x => [196, 1, 1]
# 预测age(回归)
self.age_fc_layers = torch.nn.Sequential(
torch.nn.Linear(196, 25),
torch.nn.ReLU(),
torch.nn.Linear(25, 1),
torch.nn.Sigmoid()
)
# 预测gender(分类)
self.gender_fc_layers = torch.nn.Sequential(
torch.nn.Linear(196, 25),
torch.nn.ReLU(),
torch.nn.Linear(25, 2)
)

forward()


def forward(self, x):
# x => [3, 64, 64]
x = self.cnn_layers(x)
# x => [196, 2, 2]
B, C, H, W = x.size()
out = self.global_max_pooling(x).view(B, -1)
# -1的值由其他层推断出来
# 全连接层
out_age = self.age_fc_layers(out)
out_gender = self.gender_fc_layers(out)
return out_age, out_gender

训练过程

if __name__ == "__main__":
model = MyMulitpleTaskNet()
print(model)
# 使用GPU
if train_on_gpu:
model.cuda()
ds = AgeGenderDataset("W:/data_PyTorch/UTKFace")
num_train_samples = ds.__len__()
bs = 16
dataloader = DataLoader(ds, batch_size=bs, shuffle=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
# sets the module in training mode.
model.train()
# 损失函数
mse_loss = torch.nn.MSELoss()
cross_loss = torch.nn.CrossEntropyLoss()
index = 0
num_epochs = 25
for epoch in
range(num_epochs):
train_loss = 0.0
# 依次取出每一个图片与label
for i_batch, sample_batched in enumerate(dataloader):
images_batch, age_batch, gender_batch = 
sample_batched['image'], sample_batched['age'], sample_batched['gender']
if train_on_gpu:
images_batch, age_batch, gender_batch = images_batch.cuda(), age_batch.cuda(), gender_batch.cuda()
optimizer.zero_grad()
# forward pass
m_age_out_, m_gender_out_ = model(images_batch)
age_batch = age_batch.view(-1, 1).float()
gender_batch = gender_batch.long()
# calculate the batch loss
loss = mse_loss(m_age_out_, age_batch) + cross_loss(m_gender_out_, gender_batch)
# backward pass
loss.backward()
# perform a single optimization step (parameter update)
optimizer.step()
# update training loss
train_loss += loss.item()
if index % 100 == 0:
print('step: {} tTraining Loss: {:.6f} '.format(index, loss.item()))
index += 1
# 计算平均损失
train_loss = train_loss / num_train_samples
# 显示训练集与验证集的损失函数
print('Epoch: {} tTraining Loss: {:.6f} '.format(epoch, train_loss))
# save model
# sets the module in evaluation mode.
model.eval()
torch.save(model, 'age_gender_model.pt')

最后

以上就是忧伤皮带为你收集整理的PyTorch - 多任务网络之年龄与性别预测数据集类的构建多任务网络的搭建与训练的全部内容,希望文章能够帮你解决PyTorch - 多任务网络之年龄与性别预测数据集类的构建多任务网络的搭建与训练所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(37)

评论列表共有 0 条评论

立即
投稿
返回
顶部