我是靠谱客的博主 甜美麦片,最近开发中收集的这篇文章主要介绍记录一个在使用 Pytorch 0.4.0 过程中遇到的 RuntimeError,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

最近在使用新版 Pytorch 0.4.0(听说1.0版本马上就要出了…)训练 GAN 的时候遇到了这样一个BUG,RuntimeError: Expected object of type torch.DoubleTensor but found type torch.cuda.FloatTensor for argument #2 ‘weight’ 。相信很多初学者都会在刚开始的时候遇到这样的问题,我就 debug 了一天的时间。首先我们看看问题本身,报错信息提示我们某个变量的类型错了,应该是 torch.cuda.FloatTensor 但是我们给的是 torch.DoubleTensor ,这个地方很容易理解反 这里说的是第二个参数 weight 的数据类型是 torch.cuda.FloatTensor,但应该是 torch.DoubleTensor,大家在前期学习深度学习基础理论的时候应该知道我们传入的数据是要和每层网络的 filter 进行卷积运算的,卷积核上的参数就是 weight,它的数据类型是你实例化你建立的模型的时候决定的(见下面的代码)。那么,这里的正确类型也就是 torch.DoubleTensor 是哪里来的呢?其实这个就是我们输入数据的数据类型,所以这里实际上是我们输入数据的类型错了而不是模型需要的类型错了,这也是这个报错令人疑惑的地方,它告诉我们一个错误但其实它是另一个错误引起的,确实有点反直觉。下面我们结合代码看看为什么会触发这个BUG。


以下是有关这个错误的部分代码:

main.py

    cudnn.benchmark = True
    device = torch.device('cuda:3')
    G = Generator().to(device)
    D = Discriminator().to(device)

main.py

    for epoch in range(num_epochs):
        for t, x in enumerate(loader):
            optimizerD.zero_grad()
            optimizerG.zero_grad()
            x.requires_grad_().to(device)
            noise_size = x.shape[0]
            noise = torch.randn(noise_size, 1, 28, 28).requires_grad_().to(device)

这里我们可以看到 Pytorch 0.4.0 新增了一个 torch.device 属性,这样一来我们就可以更方便的指定模型和变量运行的设备, 并且新版本弃用了 torch.Variable 类,现在我们的数据只要转成 torch.Tensor 就可以跑了。回到正题,上面的代码有几个地方需要注意。

  • 现在 torch.Tensor 现在可以直接用.requires_grad_() 来进行原地(Inplace)设置来让张量可以被反向求导,就像之前的Variable 一样,换言之,.requires_grad_() 是一个原地(Inplace)操作
  • 现在 nn.Module 对象和 torch.Tensor 对象都有一个 .to() 方法它可以接受一个 torch.Device 类型的变量用以指定模型或者张量运行在哪个设备上,但 .to() 不是一个原地(Inplace)操作
  • nn.Module 调用.to(device)后它内部参数的数据类型会从 torch.FloatTensor 变成 torch.cuda.FloatTensor, torch.Tensor 调用 .to(device)方法后数据类型会从torch.*Tensor 变成 torch.cuda.*Tensor

所以解决这个 bug 只需要修改

x.requires_grad_().to(device)

x.requires_grad_()
x = x.float().to(device)

就可以了。


这个问题虽然解决起来很简单,但是在查找 bug 原因的过程中可是一点不容易,这也告诉我们数据类型真的很重要,我们在复现论文或者自己做实验的时候要时刻注意自己手中数据在网络中的类型变化。希望这篇文章可以帮到有类似疑问的同学。另外,有兴趣的同学可以看看我在 Pytorch论坛上提问 的过程,Pytorch论坛 确实是一个学习Pytorch的好地方,但需要吐槽的是 Pytorch 官网可以正常访问,但是 Pytorch 论坛却被墙了,有需要的同学请自备梯子。

最后

以上就是甜美麦片为你收集整理的记录一个在使用 Pytorch 0.4.0 过程中遇到的 RuntimeError的全部内容,希望文章能够帮你解决记录一个在使用 Pytorch 0.4.0 过程中遇到的 RuntimeError所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(45)

评论列表共有 0 条评论

立即
投稿
返回
顶部