我是靠谱客的博主 大方篮球,最近开发中收集的这篇文章主要介绍用pytorch踩过的坑,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

背景

帮师妹改毕设代码,第一次接触pytorch,没看手册,直接开肛,遇到些坑,在这里记录一下,大佬勿喷。

坑1:进入网络训练的数据必须归一化

如果数据没有归一化,可能得到的loss会成为负数,在此参考了Crazy_Omais的一段归一化代码

def data_in_one(inputdata):
min = np.nanmin(inputdata)
max = np.nanmax(inputdata)
outputdata = (inputdata-min)/(max-min)
return outputdata

坑2:torchvision.transforms.ToTensor()

torchvision.transforms.ToTensor()不能用于处理一维数据,如果要处理的话,可以使用torch.from_numpy()

def __getitem__(self, index):
data = self.datas[:][index]
data = torch.from_numpy(data)

坑3:网络训练的数据需要是Dataloader类型

网络训练的数据需要是Dataloader类型,而输入必须是一个Dataset的子类,因此我们有必要定义一个类,以装载我们自己的数据,本代码数据手动分训练集和测试集,两个成员函数是必要的,getitem 函数是在torch.utils.data.DataLoader()分batch的时候循环调用的:

class MyDataset(torch.utils.data.Dataset):
def __init__(self, train_data_flag=0):
super(MyDataset, self).__init__()
file_path = '/Users/sophia/Downloads/****.mat'
fh = scio.loadmat(file_path)
fh = fh['Y']
fh = data_in_one(fh)
# 由uint16->float64
fh_array = np.array(fh, dtype='float')
fh_array_t = fh_array.T
if train_data_flag == 0:
self.datas = fh_array_t[:][0:90000]
else:
self.datas = fh_array_t[:][90001:94001]
def __getitem__(self, index):
data = self.datas[:][index]
data = torch.from_numpy(data)
return data
def __len__(self):
return len(self.datas)
#主函数调用
train_dataset = MyDataset(train_data_flag=0)
test_dataset = MyDataset(train_data_flag=1)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=batch_size,
shuffle=shuffle)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
batch_size=batch_size,
shuffle=shuffle)

坑4:报错RuntimeError: Expected object of scalar type Float but got scalar type Double for argument #2 ‘mat1’ in call to _th_addmm

在网上看到大佬们写的demo可以了解到,要将tenor类型的数据用tenor.float()转换为浮点型即可,如果是有已有的数据,建议在出错行网上搜索输入的tenor类变量,然后对它进行操作。

for batch_index, train_data in enumerate(train_loader):
if torch.cuda.is_available():
train_data = train_data.cuda()
train_data = train_data.float()

总结

调试真的太方便了,不知道比tensorflow方便多少倍,爱了爱了,但是要运用熟练还得好好学习一下人家的框架hhhh

参考链接

  1. https://blog.csdn.net/weixin_42214565/article/details/102381380
  2. https://blog.csdn.net/Teeyohuang/article/details/79587125

最后

以上就是大方篮球为你收集整理的用pytorch踩过的坑的全部内容,希望文章能够帮你解决用pytorch踩过的坑所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(53)

评论列表共有 0 条评论

立即
投稿
返回
顶部