概述
Data Processing
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms
import torch
import numpy as np
# Run on the default CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Stand-alone PIL -> tensor converter.
# NOTE(review): not referenced by MyDataset below, which builds its own transform_img.
to_tensor = transforms.ToTensor()
class MyDataset(Dataset):
    """Segmentation dataset that yields normalized images with their label maps.

    Expects ``root_path`` to contain ``JPEGImages/<name>.jpg`` and
    ``SegmentationClass/<name>.png``.

    Args:
        root_path: Dataset root. It must end with a path separator, because the
            sub-directory names are appended by string concatenation.
        resize: Target size, given as an int (square) or an ``(height, width)``
            pair. ``None`` keeps the original size. With ``None``, the default
            DataLoader collate needs every image in a batch to have the same size.
        img_names: Optional sample names, without extension. When omitted, every
            ``.png`` file in ``SegmentationClass/`` is used, because only those
            images have labels.
    """

    def __init__(self, root_path, resize, img_names=None):
        import os

        self.img_dir = root_path + "JPEGImages/"  # path for image
        # Labels live under root_path too. The original hard-coded
        # "./SegmentationClass/" and ignored root_path.
        self.label_dir = root_path + "SegmentationClass/"  # path for label
        self.resize = self._to_hw(resize)
        if img_names is None:
            img_names = sorted(
                os.path.splitext(fname)[0]
                for fname in os.listdir(self.label_dir)
                if fname.endswith(".png")
            )
        self.img_names = list(img_names)

        # Image transform: optional resize, then to tensor, then normalize with
        # the standard ImageNet channel mean/std.
        steps = []
        if self.resize is not None:
            steps.append(transforms.Resize(self.resize))  # takes (height, width)
        steps += [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ]
        self.transform_img = transforms.Compose(steps)

    @staticmethod
    def _to_hw(resize):
        """Normalize ``resize`` to an ``(height, width)`` tuple, or return ``None``."""
        if resize is None:
            return None
        if isinstance(resize, int):
            return (resize, resize)
        height, width = resize
        return (int(height), int(width))

    def __getitem__(self, index):
        """Return ``{'image': float tensor (3, H, W), 'label': uint8 tensor (H, W)}``."""
        name = self.img_names[index]
        img_path = self.img_dir + name + '.jpg'
        label_path = self.label_dir + name + '.png'
        # Use `with` so the underlying file handles are closed promptly.
        with Image.open(img_path) as img_file:
            img = self.transform_img(img_file.convert('RGB'))
        with Image.open(label_path) as label_file:
            label_img = label_file
            if self.resize is not None:
                height, width = self.resize
                # PIL takes (width, height), the reverse of transforms.Resize.
                # NEAREST keeps the class ids intact instead of blending them.
                label_img = label_file.resize((width, height), Image.NEAREST)
            label = np.array(label_img)
        return {'image': img, 'label': torch.from_numpy(label).type(torch.uint8)}

    def __len__(self):
        return len(self.img_names)
Define Network
class MyNet(nn.Module):
    """Segmentation network: backbone -> MyBlock -> 3x3 conv/BN/ReLU -> 1x1 classifier.

    The class logits are bilinearly resized back to the input's spatial size.

    Args:
        args: Parsed options. ``args.backbone`` is used as the feature extractor.
        num_classes: Number of output channels (one per class).
    """

    def __init__(self, args, num_classes):
        super(MyNet, self).__init__()
        self.args = args
        # NOTE(review): parse_args() supplies a string here ('MobileNet' / 'ResNet').
        # It must be replaced by an nn.Module that, after MyBlock, yields
        # 256-channel features (as conv1 expects) before forward() can run. TODO confirm.
        self.backbone = args.backbone
        # NOTE(review): MyBlock.__init__ declares in_channels/out_channels, so
        # this zero-argument call fails until they are supplied.
        self.myblock = MyBlock()
        self.conv1 = nn.Conv2d(256, 256, 3, padding=1, bias=False)
        self.conv2 = nn.Conv2d(256, num_classes, 1)
        self.bn = nn.BatchNorm2d(256)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map an image batch ``(N, C, H, W)`` to class logits ``(N, num_classes, H, W)``."""
        # Record the input (H, W) before the backbone can change it. The original
        # referenced an undefined `input_shape`.
        input_shape = x.shape[-2:]
        x = self.backbone(x)
        x = self.myblock(x)
        x = self.relu(self.bn(self.conv1(x)))
        x = self.conv2(x)
        x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=True)
        return x
class MyBlock(nn.Module):
    """Custom feature block used by MyNet (implementation elided in the source).

    NOTE(review): the layer definitions and forward pass were replaced by
    ``... ...`` placeholders. These are not valid Python, so fill them in
    before use.
    """

    def __init__(self, in_channels, out_channels):
        super(MyBlock, self).__init__()
        # NOTE(review): MyLayer.__init__ requires in_channels, out_channels,
        # kernel_size, padding and dilation, so this zero-argument call will
        # fail as written.
        self.mylayer = MyLayer()
        ... ...

    def forward(self, x):
        # Placeholder: the forward computation is elided in the source.
        ... ...
class MyLayer(nn.Module):
    """Custom layer used inside MyBlock (implementation elided in the source).

    NOTE(review): the constructor parameters (in_channels, out_channels,
    kernel_size, padding, dilation) suggest a convolution-style layer. The
    body is hidden behind ``... ...`` placeholders, which are not valid
    Python, so fill them in before use.
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding, dilation):
        super(MyLayer, self).__init__()
        ... ...

    def forward(self, x):
        # Placeholder: the forward computation is elided in the source.
        ... ...
Define Loss Function
class MyLoss(nn.Module):
    """Template for a custom training loss (its terms are elided in the source).

    The training loop in this file calls it as ``criterion(out, batch_label.long())``,
    so ``y`` receives the network output and ``C`` the integer label map.
    """

    def __init__(self):
        super(MyLoss, self).__init__()

    def forward(self, y, C):
        # NOTE(review): this tensor is overwritten on the next line, so its
        # requires_grad=True has no effect on the returned value.
        loss = torch.zeros(1, requires_grad=True)
        # Placeholder: replace each `...` with a loss term computed from y and C.
        # As written, Ellipsis + Ellipsis raises TypeError.
        loss = ... + ... + ...
        # Note: must return a Tensor that carries a gradient, i.e. loss.requires_grad == True
        return loss
Training
# Select the default CUDA device when present; fall back to the CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def _save_checkpoint(path, epoch, model, optimizer):
    """Save epoch, model and optimizer state to ``path``, creating its directory if needed."""
    import os

    os.makedirs(os.path.dirname(path) or '.', exist_ok=True)
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, path)


def train(args, model, criterion, optimizer, train_loader, test_loader, starting=0):
    """Train ``model`` for epochs ``starting .. args.epoch_num - 1`` and checkpoint it.

    For each epoch the mean loss and pixel accuracy are printed and recorded.
    The latest weights go to ``./checkpoints/<exp_id>_epoch_checkpoint.pth``
    every epoch. Every 10th epoch, and the final one, the model is evaluated
    on ``test_loader``. The best-scoring weights go to
    ``./checkpoints/<exp_id>_best_checkpoint.pth``.

    Args:
        args: Options; reads ``epoch_num`` and ``exp_id``.
        model: Network mapping an image batch to class logits ``(N, C, H, W)``.
        criterion: Loss, called as ``criterion(logits, labels.long())``.
        optimizer: Optimizer over ``model``'s parameters.
        train_loader: Iterable of dicts with ``'image'`` and ``'label'`` tensors.
        test_loader: Same format as ``train_loader``; passed to ``evaluate``.
        starting: First epoch index (non-zero when resuming).
    """
    train_loss_list = []
    train_acc_list = []
    test_acc_list = []
    # Accuracy lies in [0, 1]. The original started at 1.0, so no model ever
    # counted as better and the best checkpoint was never written.
    best_test = float('-inf')
    last_epoch = args.epoch_num - 1
    epoch_path = './checkpoints/{}_epoch_checkpoint.pth'.format(args.exp_id)
    best_path = './checkpoints/{}_best_checkpoint.pth'.format(args.exp_id)
    for epoch in range(starting, args.epoch_num):
        train_loss = 0.0
        train_acc = 0
        adding_time = 0  # number of label pixels scored this epoch
        n_batches = 0    # batches that actually contributed to train_loss
        # train
        model.train()
        for batch_index, batch in enumerate(train_loader):
            batch_img = batch['image'].to(device)
            batch_label = batch['label'].to(device)
            if len(batch_img) == 1:
                # The network uses BatchNorm, so a one-sample batch is skipped.
                # The original `break` abandoned the rest of the epoch.
                continue
            out = model(batch_img)  # out(n, num_classes, H, W)
            loss = criterion(out, batch_label.long())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            n_batches += 1
            pred = torch.max(out, 1)[1]
            train_acc += (pred == batch_label).sum().item()
            adding_time += batch_label.numel()
            if batch_index % 100 == 0 and batch_index != 0:
                print('epoch:{}, iter:{}, loss:{:.5f}'.format(epoch, batch_index, loss.item()))
        # Average over batches that were trained on; guard against empty loaders.
        epoch_train_loss = train_loss / max(n_batches, 1)
        epoch_train_acc = train_acc / max(adding_time, 1)
        train_loss_list.append(epoch_train_loss)
        train_acc_list.append(epoch_train_acc)
        print('Epoch: {} : Train Loss: {:.6f}, Acc: {:.6f}'.format(epoch, epoch_train_loss, epoch_train_acc))
        # Save the last checkpoint every epoch so resume() finds the latest state.
        _save_checkpoint(epoch_path, epoch, model, optimizer)
        print('saved at epoch: {}'.format(epoch))
        print('-------------------------------------------------')
        # evaluate the model and save the best
        if epoch % 10 == 0 or epoch == last_epoch:
            # evaluate() returns (pixel_accuracy, mIoU); the original formatted
            # that tuple with '{:.3f}', which raises TypeError.
            epoch_test_acc, epoch_test_miou = evaluate(args, model, test_loader)
            print('testing accuracy: {:.3f}'.format(epoch_test_acc))
            print('testing mIoU: {}'.format(epoch_test_miou))
            test_acc_list.append(epoch_test_acc)
            if epoch_test_acc > best_test:
                best_test = epoch_test_acc
                _save_checkpoint(best_path, epoch, model, optimizer)
                print('new best model saved at epoch: {}'.format(epoch))
                print('-------------------------------------------------')
    print('-------------------------------------------------')
    print('best testing achieved: {:.3f}'.format(best_test))
    print("train_loss: ", train_loss_list)
    print("train_acc: ", train_acc_list)
    print("test_acc: ", test_acc_list)
def evaluate(args, model, test_loader, num_classes=21):
    """Compute pixel accuracy and mean IoU of ``model`` over ``test_loader``.

    The model is switched to eval mode while scoring. Its previous
    train/eval mode is restored afterwards.

    Args:
        args: Options (not read here; kept for a uniform call signature).
        model: Network producing class logits ``(N, num_classes, H, W)``.
        test_loader: Iterable of dicts with ``'image'`` and ``'label'`` tensors.
        num_classes: Size of the confusion matrix. The default 21 matches
            Pascal VOC (20 classes + background).

    Returns:
        ``(pixel_accuracy, mIoU)``. Pixel accuracy is 0.0 when no pixels were
        scored. mIoU is whatever ``Mean_Intersection_over_Union`` returns.
    """
    was_training = model.training
    # Use BatchNorm running statistics (and disable dropout) while scoring.
    # The original evaluated in whatever mode the caller left, i.e. train mode.
    model.eval()
    test_acc = 0
    adding_time = 0
    hist = np.zeros((num_classes, num_classes))
    try:
        with torch.no_grad():
            for batch in test_loader:
                # In eval mode BatchNorm does not need batch statistics, so
                # single-sample batches are scored too. The original `break`
                # silently dropped them.
                batch_x = batch['image'].to(device)
                batch_y = batch['label'].to(device)
                out = model(batch_x)
                pred = torch.max(out, 1)[1]
                # confusion matrix
                # NOTE(review): generate_matrix is defined elsewhere; argument
                # order (target, prediction, n_classes) is kept from the original.
                hist = hist + generate_matrix(batch_y.cpu().numpy(), pred.cpu().numpy(), num_classes)
                test_acc += (pred == batch_y).sum().item()
                adding_time += batch_y.numel()
    finally:
        model.train(was_training)
    # mIoU for one epoch
    mIoU = Mean_Intersection_over_Union(hist)
    epoch_val_acc = test_acc / adding_time if adding_time else 0.0
    return epoch_val_acc, mIoU
def resume(args, model, optimizer):
    """Restore model and optimizer state from ``./checkpoints/<args.resume_path>``.

    Args:
        args: Options; reads ``resume_path`` (file name inside ./checkpoints).
        model: Module whose state dict is overwritten in place.
        optimizer: Optimizer whose state dict is overwritten in place.

    Returns:
        ``(model, optimizer, epoch)``, where ``epoch`` is the value stored in
        the checkpoint.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
    """
    import os  # the original used os without importing it

    checkpoint_path = os.path.join('./checkpoints', args.resume_path)
    # Raise an explicit exception rather than assert, so the check survives
    # `python -O`. isfile also rejects an empty resume_path (the directory itself).
    if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError('checkpoint does not exist for %s' % checkpoint_path)
    # Load onto the CPU so checkpoints saved on a GPU also open on CPU-only hosts.
    # load_state_dict copies the values onto the existing parameters' devices.
    checkpoint_saved = torch.load(checkpoint_path, map_location='cpu')
    epoch = checkpoint_saved['epoch']
    model.load_state_dict(checkpoint_saved['model_state_dict'])
    optimizer.load_state_dict(checkpoint_saved['optimizer_state_dict'])
    print('Resume completed for the model\n')
    return model, optimizer, epoch
主函数
import data_loader
from torch.utils.data import DataLoader
import argparse
# Fall back to the CPU unless a CUDA device is available.
device = torch.device(
    "cpu" if not torch.cuda.is_available() else "cuda"
)
def parse_args():
    """Parse the command-line options for training and testing.

    Returns:
        argparse.Namespace holding every option declared below.
    """
    parser = argparse.ArgumentParser()
    # (flag, add_argument keyword options), in declaration order.
    option_specs = [
        ('--exp_id', dict(type=str, default='exp_test')),
        ('--resume_path', dict(type=str, default='', help='resume path')),
        ('--resume', dict(type=int, default=0, help='resume the trained model')),
        ('--test', dict(type=int, default=0, help='test with trained model')),
        ('--seed', dict(type=int, default=1, help='random seed')),
        ('--root_path', dict(type=str, default='./VOC2012/', help='VOC2012 data path')),
        ('--batch_size', dict(type=int, default=16)),
        ('--epoch_num', dict(type=int, default=10, help='number of training epochs')),
        # Hyper-parameters of primary interest.
        ('--resize', dict(type=int, default=320, help='resize')),
        ('--lr', dict(type=float, default=0.001, help='learning rate')),
        ('--backbone', dict(type=str, default='MobileNet', help='[MobileNet, ResNet]')),
    ]
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()
if __name__ == '__main__':
    # NOTE(review): MyNet, MyLoss and the `train` module (providing train /
    # evaluate / resume) are not imported by this script as shown. Add the
    # matching project imports.
    args = parse_args()
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    # MyDataset indexes resize[0] / resize[1], so build an (H, W) pair from
    # --resize. The original passed resize=None, which crashed on every sample.
    resize = (args.resize, args.resize)
    # NOTE(review): both datasets read the same files. There is no train/val
    # split, so "test" accuracy is measured on the training data. TODO.
    data_train = data_loader.MyDataset(root_path=args.root_path, resize=resize)
    data_test = data_loader.MyDataset(root_path=args.root_path, resize=resize)
    # Shuffle the training data each epoch. drop_last avoids a trailing
    # one-sample batch, which the training loop skips because of BatchNorm.
    train_loader = DataLoader(data_train, batch_size=args.batch_size, shuffle=True,
                              drop_last=True, num_workers=2)
    test_loader = DataLoader(data_test, batch_size=args.batch_size, num_workers=2)
    # network
    model = MyNet(args, 21).to(device)
    # loss
    criterion = MyLoss()
    # optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999))
    # resume the trained model or train again
    if args.resume:
        model, optimizer, starting = train.resume(args, model, optimizer)
        starting += 1  # continue after the epoch stored in the checkpoint
    else:
        starting = 0
    if args.test == 1:  # test mode: evaluate the (optionally resumed) model
        # evaluate() returns (pixel_accuracy, mIoU); the original formatted the
        # whole tuple with '{:.3f}', which raises TypeError.
        testing_accuracy, testing_miou = train.evaluate(args, model, test_loader)
        print('testing finished, accuracy: {:.3f}'.format(testing_accuracy))
        print('testing mIoU: {}'.format(testing_miou))
    else:  # train mode, from scratch or from the resumed epoch
        train.train(args, model, criterion, optimizer, train_loader, test_loader, starting)
        print('training finished')
最后
以上就是专一长颈鹿为你收集整理的神经网络框架——从加载数据到展示结果的全部内容,希望文章能够帮你解决神经网络框架——从加载数据到展示结果所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复