我是靠谱客的博主 闪闪日记本,最近开发中收集的这篇文章主要介绍pytorch多节点分布式训练,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

本文为代码结构梳理。不提供理论知识。
顺便说一点,nccl好像只支持linux。

1.参数输入(选)

parser.add_argument('--distributed', default=True, help="Whether to turn on the distribution")
parser.add_argument('--rank', type=int, default=0, help='node rank for distributed training')
parser.add_argument('--world_size', type=int, default=2, help='number of nodes for distributed traning')
parser.add_argument('--init_method', type=str, default='tcp://192.168.80.156:65530', help='url used to set up distributed training')
parser.add_argument('--backend', type=str, default='nccl', help='distributed backend')

输入的格式

主节点:python xxx.py --rank 0 --world_size 2 --init_method tcp://192.168.80.156:65530 -- backend nccl
分节点1:python xxx.py --rank 1 --world_size 2 --init_method tcp://192.168.80.156:65530 -- backend nccl
分节点2:python xxx.py --rank 2 --world_size 2 --init_method tcp://192.168.80.156:65530 -- backend nccl

2.参数初始化

放在以下所有操作之前

dist.init_process_group(backend=opt.backend,  # distributed backend
                                init_method=opt.init_method,  # init method
                                world_size=opt.world_size,  # number of nodes
                                rank=opt.rank)    # node rank

3.分发数据

data_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             num_workers=nw,
                                             pin_memory=True,
                                             shuffle=(data_sampler is None),
                                             sampler=data_sampler)

4.分发模型

model = torch.nn.parallel.DistributedDataParallel(model)

5.BatchSize

多机单卡不用改batchsize也可。

args.batch_size = int(args.batch_size / ngpus_per_node)

6.同步数据

一般用在加载模型之前和optim.step()

dist.barrier()

7.混洗数据

放在for epoch in range(start_epoch, epochs):下面第一行

data_sampler.set_epoch(epoch)

8.保存模型

可以选在单在主机保存,也可以分别报错。
只在主机保存时用if rank==0:

9.整理资源

放在最外层最后

dist.destroy_process_group()

最后

以上就是闪闪日记本为你收集整理的pytorch多节点分布式训练的全部内容,希望文章能够帮你解决pytorch多节点分布式训练所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(35)

评论列表共有 0 条评论

立即
投稿
返回
顶部