pytorch 模型训练时多卡负载不均衡(GPU的0卡显存过高)解决办法(简单有效)
本文主要解决pytorch在进行模型训练时出现GPU的0卡占用显存比其他卡要多的问题。如下图所示:本机GPU卡为TITAN RTX,显存24220M,batch_size = 9,用了三张卡。第0卡显存占用24207M,这时仅仅是刚开始运行,数据只是少量的移到显卡上,如果数据在多点,0卡的显存肯定撑爆。出现0卡显存更高的原因:网络在反向传播的时候,计算loss的梯度默认都在0卡上计算。因此会比其他显卡多用一些显存,具体多用多少,主要还要看网络的结构。因此,为了防止训练由于 out of memo