概述
背景
帮师妹改毕设代码,第一次接触pytorch,没看手册,直接开肛,遇到些坑,在这里记录一下,大佬勿喷。
坑1:进入网络训练的数据必须归一化
如果数据没有归一化,可能得到的loss会成为负数,在此参考了Crazy_Omais的一段归一化代码
def data_in_one(inputdata):
min = np.nanmin(inputdata)
max = np.nanmax(inputdata)
outputdata = (inputdata-min)/(max-min)
return outputdata
坑2:torchvision.transforms.ToTensor()
torchvision.transforms.ToTensor()不能用于处理一维数据,如果要处理的话,可以使用torch.from_numpy()
def __getitem__(self, index):
data = self.datas[:][index]
data = torch.from_numpy(data)
坑3:网络训练的数据需要是Dataloader类型
网络训练的数据需要是Dataloader类型,而输入必须是一个Dataset的子类,因此我们有必要定义一个类,以装载我们自己的数据,本代码数据手动分训练集和测试集,两个成员函数是必要的,getitem 函数是在torch.utils.data.DataLoader()分batch的时候循环调用的:
class MyDataset(torch.utils.data.Dataset):
def __init__(self, train_data_flag=0):
super(MyDataset, self).__init__()
file_path = '/Users/sophia/Downloads/****.mat'
fh = scio.loadmat(file_path)
fh = fh['Y']
fh = data_in_one(fh)
# 由uint16->float64
fh_array = np.array(fh, dtype='float')
fh_array_t = fh_array.T
if train_data_flag == 0:
self.datas = fh_array_t[:][0:90000]
else:
self.datas = fh_array_t[:][90001:94001]
def __getitem__(self, index):
data = self.datas[:][index]
data = torch.from_numpy(data)
return data
def __len__(self):
return len(self.datas)
#主函数调用
train_dataset = MyDataset(train_data_flag=0)
test_dataset = MyDataset(train_data_flag=1)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=batch_size,
shuffle=shuffle)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
batch_size=batch_size,
shuffle=shuffle)
坑4:报错RuntimeError: Expected object of scalar type Float but got scalar type Double for argument #2 ‘mat1’ in call to _th_addmm
在网上看到大佬们写的demo可以了解到,要将tenor类型的数据用tenor.float()转换为浮点型即可,如果是有已有的数据,建议在出错行网上搜索输入的tenor类变量,然后对它进行操作。
for batch_index, train_data in enumerate(train_loader):
if torch.cuda.is_available():
train_data = train_data.cuda()
train_data = train_data.float()
总结
调试真的太方便了,不知道比tensorflow方便多少倍,爱了爱了,但是要运用熟练还得好好学习一下人家的框架hhhh
参考链接
- https://blog.csdn.net/weixin_42214565/article/details/102381380
- https://blog.csdn.net/Teeyohuang/article/details/79587125
最后
以上就是大方篮球为你收集整理的用pytorch踩过的坑的全部内容,希望文章能够帮你解决用pytorch踩过的坑所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复