概述
Pytorch猫狗大战系列:
猫狗大战1-训练和测试自己的数据集
猫狗大战2-AlexNet
猫狗大战3-MobileNet_V1&V2
猫狗大战3-MobileNet_V3
TensorFlow 2.0猫狗大战系列
猫狗大战1、制作与读取record数据
猫狗大战2、训练与保存模型
文章目录
- 一、网络结构
- 二、使用nn.DataParallel (不推荐)训练
- 二、使用DistributedDataParallel训练
一、网络结构
net = models.resnet18(pretrained=False)
num_ftrs = net.fc.in_features
net.fc = nn.Linear(num_ftrs, 2) # 更新resnet18模型的fc模型,
二、使用nn.DataParallel (不推荐)训练
数据加载部分和单机单卡是一样的, 这里只需要修改
net = nn.DataParallel(net)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net.to(device)
即可,至于为啥不推荐,详见知乎
二、使用DistributedDataParallel训练
- 首先需要初始化后端
# 初始化使用nccl后端
torch.distributed.init_process_group(backend="nccl")
- 加载网络
# 加载resnet18 模型,
net = models.resnet18(pretrained=False)
num_ftrs = net.fc.in_features
net.fc = nn.Linear(num_ftrs, 2) # 更新resnet18模型的fc模型,
local_rank = torch.distributed.get_rank()
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
net.cuda(local_rank)
print("Let's use", torch.cuda.device_count(), "GPUs!")
# 5) 封装
net = torch.nn.parallel.DistributedDataParallel(net,
device_ids=[local_rank],
output_device=local_rank)
- 加载数据,注意,和单机单卡加载数据有所区别,不过也就是几行代码的是
# 数据的批处理,尺寸大小为batch_size,
# 在训练集中,shuffle 必须设置为True, 表示次序是随机的
train_dataset = datasets.ImageFolder(root='data/train/', transform=data_transform)
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
train_loader = torch.utils.data.DataLoader(train_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
sampler=train_sampler
)
test_dataset = datasets.ImageFolder(root='data/validation/', transform=data_transform)
test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset)
test_loader = torch.utils.data.DataLoader(test_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
sampler=test_sampler
)
训练部分不变
启动脚本的时候,别忘记加上参数
python -m torch.distributed.launch train_host1_gpu2_ddp.py
否则会报Error initializing torch.distributed using env:// rendezvous: environment variable RANK expected, but not set
的错误。
如果在docker里遇到DataLoader worker (pid 179) is killed by signal: Bus error. It is possible that dataloader's workers are out of shared memory. Please try to raise your shared memory limit.
这应该是docker里的sharememory不够引起的,修改 sharememory
此时可以重新创建一个容器
或者修改容器的share_memory
可参考
如何修改容docker容器的shmsize共享内存大小
最后
以上就是迷路大侠为你收集整理的pytorch系列(八):猫狗大战3-单机多卡无脑训练的全部内容,希望文章能够帮你解决pytorch系列(八):猫狗大战3-单机多卡无脑训练所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
发表评论 取消回复