概述
鉴于数据太大,肯定要用S3来存
S3的读取速度应该还是很快的
from torch.utils.data import Dataset,DataLoader
import torch
import numpy as np
import nibabel as nbl
# Path to the text file listing one HCP data file name per line.
HCPDataListPath = '/home/ec2-user/SageMaker/Models_HCP/dt.txt'
# Root directory containing the HCP dtseries files named in that list.
HCPDataRootPath = '/home/ec2-user/SageMaker/HCP_dataset/'
# Mapping from HCP task name (as embedded in the file name) to integer class label.
labelNumber = {'EMOTION': 0, 'GAMBLING': 1, 'LANGUAGE': 2, 'MOTOR': 3,
               'RELATIONAL': 4, 'SOCIAL': 5, 'WM': 6}


def getLabelList(HCPDataPathList):
    """Map each HCP data file name to its integer task label.

    Each path is expected to look like '<subject>_<TASK>.<ext...>'
    (e.g. '100307_EMOTION.dtseries.nii'); the task token between the
    first '_' and the following '.' is looked up in ``labelNumber``.

    Args:
        HCPDataPathList: iterable of file-name strings.

    Returns:
        np.ndarray of integer labels, one per input path.

    Raises:
        KeyError: if a path's task token is not a key of ``labelNumber``.
    """
    return np.array([labelNumber[path.split('_')[1].split('.')[0]]
                     for path in HCPDataPathList])
def getData(HCPDataRootPath, dataPath):
    """Load one HCP CIFTI dtseries file and return its volume (voxel) data
    as a (1, 91, 109, 91) float32 array.

    Surface brain models are skipped; only volume models are scattered
    into the output grid at their (i, j, k) voxel coordinates.

    NOTE(review): ``time_series`` is the fully flattened fdata, so the
    ``off:off + cnt`` indexing assumes the file holds a single time point /
    scalar map — confirm against the files actually listed in dt.txt.
    """
    dtseries = nbl.load(HCPDataRootPath + dataPath)
    # Flatten all values into one vector (see NOTE above).
    time_series = dtseries.get_fdata().reshape((-1))
    # Volume dimensions declared by the CIFTI header's brain-model index map.
    shape = dtseries.header.matrix.get_index_map(1).volume.volume_dimensions
    nifti = np.zeros(shape)
    for bm in dtseries.header.matrix.get_index_map(1).brain_models:
        if bm.model_type == 'CIFTI_MODEL_TYPE_SURFACE':
            # Surface (cortical mesh) models carry no voxel coordinates.
            continue
        voxels = bm.voxel_indices_ijk
        off, cnt = bm.index_offset, bm.index_count
        # Scatter this model's slice of values into its voxel positions.
        nifti[tuple(np.transpose(voxels))] = time_series[off:off + cnt]
    # Hard-coded MNI-space grid; assumes `shape` equals (91, 109, 91) — TODO confirm.
    return np.array(nifti).reshape((1, 91, 109, 91)).astype(np.float32)
class HCPDataSet(Dataset):
    """PyTorch map-style Dataset over a list of HCP dtseries files.

    Loads one volume per item via ``getData`` and pairs it with its
    integer task label.

    BUG FIX: the base class was spelled ``DataSet``, but the imported name
    is ``torch.utils.data.Dataset`` — the original raised NameError at
    class-creation time.
    """

    def __init__(self, HCPDataRootPath, HCPDataPathList, HCPLabelList):
        self.HCPDataRootPath = HCPDataRootPath
        self.HCPDataPathList = np.array(HCPDataPathList)
        self.HCPLabelList = np.array(HCPLabelList)
        self.total = len(HCPDataPathList)

    def __getitem__(self, index):
        """Return (data, label) tensors for the sample at ``index``."""
        dataPath = self.HCPDataPathList[index]
        # getData returns a (1, 91, 109, 91) float32 array loaded from disk.
        data = getData(self.HCPDataRootPath, dataPath)
        label = np.array(self.HCPLabelList[index])
        return torch.from_numpy(data), torch.from_numpy(label)

    def __len__(self):
        return self.total
class WholeDataSet():
    """Container bundling the train / eval / test dataset splits.

    The three splits are exposed as plain attributes
    (``trainDataSet``, ``evalDataSet``, ``testDataSet``).
    """

    def __init__(self, trainDataSet, evalDataSet, testDataSet):
        # Store the splits in one unpacking assignment.
        self.trainDataSet, self.evalDataSet, self.testDataSet = (
            trainDataSet, evalDataSet, testDataSet)
def getHCPDataSet(HCPDataRootPath, HCPDataListPath, evalRate=0.2, testRate=0.2, tiny_data=0):
    """Read the file list and build train / eval / test HCPDataSet splits.

    Args:
        HCPDataRootPath: directory holding the dtseries files.
        HCPDataListPath: text file with one data file name per line.
        evalRate: fraction of the list used for evaluation.
        testRate: fraction of the list used for testing.
        tiny_data: if non-zero, truncate the list to this many entries
            (useful for quick debugging runs).

    Returns:
        WholeDataSet with .trainDataSet / .evalDataSet / .testDataSet.
    """
    # Read the file list; one entry per line, whitespace stripped.
    with open(HCPDataListPath, 'r') as fr:
        dataPathList = [line.strip() for line in fr]
    if tiny_data != 0:
        dataPathList = dataPathList[:int(tiny_data)]
    totalNumber = len(dataPathList)
    totalTraining = int(totalNumber * (1 - evalRate - testRate))
    totalEvaluation = int(totalNumber * evalRate)
    print('Training : {},Evaluation: {},Test: {}'.format(
        totalTraining, totalEvaluation, totalNumber - totalTraining - totalEvaluation))
    # Split the ordered list as [train | eval | test] (no shuffling).
    trainDataPathList = dataPathList[:totalTraining]
    evalDataPathList = dataPathList[totalTraining:totalTraining + totalEvaluation]
    testDataPathList = dataPathList[totalTraining + totalEvaluation:]
    trainDataSet = HCPDataSet(HCPDataRootPath, trainDataPathList, getLabelList(trainDataPathList))
    evalDataSet = HCPDataSet(HCPDataRootPath, evalDataPathList, getLabelList(evalDataPathList))
    testDataSet = HCPDataSet(HCPDataRootPath, testDataPathList, getLabelList(testDataPathList))
    return WholeDataSet(trainDataSet, evalDataSet, testDataSet)
# NOTE(review): module-level side effect — the splits are built (reading
# dt.txt and printing counts) at import time; consider guarding with
# `if __name__ == '__main__':`.
dataSet = getHCPDataSet(HCPDataRootPath, HCPDataListPath, evalRate=0.05, testRate=0.15)
def getData(HCPDataRootPath, dataPath):
    """Load a dtseries file and return it min-max normalized to [0, 1]
    as a (1, 91, 109, 91) float32 array.

    NOTE(review): this redefinition shadows the CIFTI-aware ``getData``
    defined earlier in this file; only this version is visible after the
    module finishes loading — confirm which one is intended.
    """
    data = nbl.load(HCPDataRootPath + dataPath).get_fdata().reshape((-1))
    data = np.array(data).reshape((1, 91, 109, 91)).astype(np.float32)
    # BUG FIX: the denominator must be (max - min). The original computed
    # (data - min)/max - min, i.e. divided by max alone and then
    # subtracted min, which is not a [0, 1] normalization.
    dmin, dmax = np.min(data), np.max(data)
    data = (data - dmin) / (dmax - dmin)
    return data
最后
以上就是爱笑书本为你收集整理的HCP_DataLoader的全部内容,希望文章能够帮你解决HCP_DataLoader所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复