概述
最近在学pytorch,今天晚上用pytorch的数据加载部分,一开始一直在纠结怎么划分数据集,后来还是手动分了,开始是用torch.utils.data.random_split
但是后来一直报错,我也不知道哪里有错,解决不了,后来暴力解决了
1.重写dataset类,这是必须要写的
- 主要继承Dataset类,重写
__getitem__
,and__len__
的方法 - 我的问题:针对一个文件夹有n张图片,然后一个csv文件中有每个图片对应的label,具体样式如下
步骤1:将image和label对应加载到一个数据集中
class SkinDataset(Dataset):
def __init__(self,csv_file,root_dir,transform=None):
self.csv=pd.read_csv(csv_file)
self.root_dir=root_dir
self.transform=transform
def __len__(self):
return len(self.csv)
def __getitem__(self,idx):
image_path=os.path.join(self.root_dir+self.csv.ix[idx,0]+'.jpg')
image=io.imread(image_path)
label=self.csv.ix[idx,1:].as_matrix()
label=label.reshape(-1,1)
sample={"image":image,"label":label}
return sample
步骤2:将得到的数据集划分为train和test
- 得到整个数据集
dataset=SkinDataset(csv_file="...",root_dir="...")
- 划分数据集
train_size=0.8*len(dataset)
#因为得到的dataset是一个数组字典,所以只能一个个往数组里添加
train_dataset=[]
teat_train=[]
for i in range(train_size):
train_dataset.append(dataset[i])
for i in range(train_size,len(dataset)):
test_dataset.append(dataset[i])
步骤3:对不同数据进行transform,并将其加载到dataloader中
train_transform=transforms.Compose([
transforms.RandomSizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
val_trainsform=transforms.Compose([
transforms.Scale(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
class Dataset2(Dataset):
def __init__(self,dataset,transform):
self.dataset=dataset
self.transform=transform
def __len__(self):
return len(self.dataset)
def __getitem__(self, idx):
img,label=self.dataset
return img,label
#因为要将train和test分别进行trainsform,所以只能重新写一个类进行transform,实在想不到好办法了
train_dataset2=Dataset2(train_dataset,transform=train_transform)
test_dataset2=Dataset2(test_dataset,transform=val_trainsform)
traindata=DataLoader(train_dataset2,batch_size=32,shuffle=True,num_workers=4)
traindata=DataLoader(test_dataset2,batch_size=32,shuffle=True,num_workers=4)
哈哈哈,终于不报错了,是的random_split太坑了,出坑太难
最后
以上就是鲤鱼小刺猬为你收集整理的pytorch用自己的数据集进行Dataloader,并对其划分数据集的全部内容,希望文章能够帮你解决pytorch用自己的数据集进行Dataloader,并对其划分数据集所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复