概述
nnUNet最大的优点,就是可以根据数据集的情况自适应地选择模型配置,而这一切的基础在于nnUNet拥有一套自成体系的数据分析框架。下面就开始一步一步解析nnUNet的数据分析代码,相关代码位于 nnUNet_plan_and_preprocess.py 文件中。
参数解析
该文件首先进行参数配置,包括 task_id, planner3d, planner2d, no_pp, tl, tf, verify_dataset_integrity, overwrite_plans, overwrite_plans_identifier。其中 no_pp 设置为 True 时不会进行预处理;tl、tf 设置预处理时所用的进程数;verify_dataset_integrity 用于设置是否检查数据完整性,如果设置为 True,则将对数据完整性进行检查。
完整性检查
将 verify_dataset_integrity 设置为 True 后,首先进行数据完整性检查,判断输入是否正确,检查内容如下:
- 指定文件夹中的数据与json文件能否一一对应,是否存在重复
- 多模态数据是否存在缺失
- label 文件以及各模态图像中是否存在 NaN
- img与label的几何属性是否一致,包括原点、方向、大小、spacing
- 再对所有的label进行检查,包括json中声明的label是否连续(0, 1, 2, ...),以及label文件中是否出现了未声明的值
- 如果存在测试集,再对测试集的数据进行检查
- 最后检查所有训练数据的坐标轴方向(orientation)是否一致
该段内容位于 nnUNet/preprocessing/sanity_checks.py 文件中,主要用到的子函数有:
# A label map and its image(s) must share the same origin, spacing, direction and size.
def verify_same_geometry(img_1: "sitk.Image", img_2: "sitk.Image") -> bool:
    """Check whether two images are defined on the same voxel grid.

    Origin, spacing, direction and size are compared element-wise with
    ``np.isclose``. Every mismatching property is printed together with both
    values, so all problems of a pair are reported in one call.

    Images of different dimensionality (e.g. 2D vs 3D) produce getter tuples
    of different length; these are reported as a mismatch instead of letting
    ``np.isclose`` fail with a broadcasting error (or silently broadcast a
    length-1 tuple).

    :param img_1: first image (anything exposing SimpleITK's ``GetOrigin``,
        ``GetSpacing``, ``GetDirection`` and ``GetSize``)
    :param img_2: second image
    :return: True if all four properties match, False otherwise
    """
    comparisons = (
        ("the origin does not match between the images:", img_1.GetOrigin(), img_2.GetOrigin()),
        ("the spacing does not match between the images", img_1.GetSpacing(), img_2.GetSpacing()),
        ("the direction does not match between the images", img_1.GetDirection(), img_2.GetDirection()),
        ("the size does not match between the images", img_1.GetSize(), img_2.GetSize()),
    )
    all_match = True
    for message, value_1, value_2 in comparisons:
        arr_1 = np.asarray(value_1, dtype=float)
        arr_2 = np.asarray(value_2, dtype=float)
        # Shape check first: tuples of different length are a mismatch, not an error.
        matches = arr_1.shape == arr_2.shape and bool(np.all(np.isclose(arr_1, arr_2)))
        if not matches:
            all_match = False
            print(message)
            print(value_1)
            print(value_2)
    return all_match
# Validate a label file: it may use fewer labels than declared, but never undeclared ones.
def verify_contains_only_expected_labels(itk_img: str, valid_labels: (tuple, list)):
    """Report label values in ``itk_img`` that are not listed in ``valid_labels``.

    Declared labels that do not occur in the file are fine; only values that
    occur in the file without being declared are reported.

    :param itk_img: path of the label image, read with ``sitk.ReadImage``
    :param valid_labels: allowed label values
    :return: ``(ok, invalid_uniques)`` where ``ok`` is True when no undeclared
        value was found and ``invalid_uniques`` lists the undeclared values in
        the ascending order produced by ``np.unique``
    """
    label_array = sitk.GetArrayFromImage(sitk.ReadImage(itk_img))
    unexpected = []
    for value in np.unique(label_array):
        if value not in valid_labels:
            unexpected.append(value)
    return not unexpected, unexpected
整个函数如下,英文注释非常完整,我在此基础上有所补充。其中用到的 isfile、join、subfiles 是对 os / os.path 模块中相应函数的封装。
def verify_dataset_integrity(folder):
    """Sanity-check a raw nnU-Net task folder before planning/preprocessing.

    ``folder`` must contain ``dataset.json`` and the ``imagesTr`` and
    ``labelsTr`` subfolders; ``imagesTs`` is only read when dataset.json lists
    test cases. The following is verified:

    * no training case is listed twice in dataset.json;
    * every training case has a label file and one image file per modality
      (``<case>_0000.nii.gz``, ``<case>_0001.nii.gz``, ...);
    * imagesTr / labelsTr / imagesTs contain no ``.nii.gz`` file that
      dataset.json does not list;
    * every training modality has the same geometry as its label (a mismatch is
      printed per case and reported once at the end);
    * training labels and images contain no NaN values;
    * the labels declared in dataset.json are 0, 1, 2, ... (0 = background) and
      the label files contain no other values;
    * all modalities of each test case share the same geometry;
    * all training images share the same axis ordering (printed warning only).

    :param folder: path of the raw task folder
    :raises AssertionError: missing/stray files, invalid label declarations,
        undeclared label values, or unregistered test-case modalities
    :raises RuntimeError: duplicate training cases, or NaN values in the data
    :raises Warning: image/label geometry mismatch in the training set
    """
    assert isfile(join(folder, "dataset.json")), "There needs to be a dataset.json file in folder, folder=%s" % folder
    assert isdir(join(folder, "imagesTr")), "There needs to be a imagesTr subfolder in folder, folder=%s" % folder
    assert isdir(join(folder, "labelsTr")), "There needs to be a labelsTr subfolder in folder, folder=%s" % folder
    dataset = load_json(join(folder, "dataset.json"))
    training_cases = dataset['training']
    num_modalities = len(dataset['modality'].keys())
    test_cases = dataset['test']
    # Case identifier = file name without directory and without ".nii.gz" (7 characters).
    expected_train_identifiers = [i['image'].split("/")[-1][:-7] for i in training_cases]
    expected_test_identifiers = [i.split("/")[-1][:-7] for i in test_cases]

    # ---------------- training set ----------------
    nii_files_in_imagesTr = subfiles(join(folder, "imagesTr"), suffix=".nii.gz", join=False)
    nii_files_in_labelsTr = subfiles(join(folder, "labelsTr"), suffix=".nii.gz", join=False)

    label_files = []
    geometries_OK = True
    has_nan = False

    if len(expected_train_identifiers) != len(np.unique(expected_train_identifiers)):
        raise RuntimeError("found duplicate training cases in dataset.json")

    print("Verifying training set")
    for c in expected_train_identifiers:
        print("checking case", c)
        # check that the label file and one image file per modality exist
        expected_label_file = join(folder, "labelsTr", c + ".nii.gz")
        label_files.append(expected_label_file)
        expected_image_files = [join(folder, "imagesTr", c + "_%04d.nii.gz" % i) for i in range(num_modalities)]
        assert isfile(expected_label_file), "could not find label file for case %s. Expected file: \n%s" % (c, expected_label_file)
        assert all(isfile(i) for i in expected_image_files), "some image files are missing for case %s. Expected files:\n %s" % (c, expected_image_files)

        # NaNs in the label map
        label_itk = sitk.ReadImage(expected_label_file)
        nans_in_seg = np.any(np.isnan(sitk.GetArrayFromImage(label_itk)))
        has_nan = has_nan | nans_in_seg
        if nans_in_seg:
            print("There are NAN values in segmentation %s" % expected_label_file)

        # NaNs in each modality, and each modality's geometry vs. the label's.
        # A mismatch does not skip the case: it is recorded in geometries_OK
        # and reported once after all cases were checked.
        images_itk = [sitk.ReadImage(i) for i in expected_image_files]
        for i, img in enumerate(images_itk):
            nans_in_image = np.any(np.isnan(sitk.GetArrayFromImage(img)))
            has_nan = has_nan | nans_in_image
            if not verify_same_geometry(img, label_itk):
                geometries_OK = False
                # [:-12] strips "_XXXX.nii.gz", i.e. prints the case path prefix
                print("The geometry of the image %s does not match the geometry of the label file. The pixel arrays will not be aligned and nnU-Net cannot use this data. Please make sure your image modalities are coregistered and have the same geometry as the label" % expected_image_files[0][:-12])
            if nans_in_image:
                print("There are NAN values in image %s" % expected_image_files[i])

        # Tick off this case's files; whatever is left over after the loop is
        # present on disk but not listed in dataset.json.
        for i in expected_image_files:
            nii_files_in_imagesTr.remove(os.path.basename(i))
        nii_files_in_labelsTr.remove(os.path.basename(expected_label_file))

    # check for stragglers
    assert len(nii_files_in_imagesTr) == 0, "there are training cases in imagesTr that are not listed in dataset.json: %s" % nii_files_in_imagesTr
    assert len(nii_files_in_labelsTr) == 0, "there are training cases in labelsTr that are not listed in dataset.json: %s" % nii_files_in_labelsTr

    # ---------------- label values ----------------
    print("Verifying label values")
    # Sorted so that the consecutive-order check does not depend on the key
    # order used in dataset.json.
    expected_labels = sorted(int(i) for i in dataset['labels'].keys())
    assert expected_labels and expected_labels[0] == 0, 'The first label must be 0 and maps to the background'
    labels_valid_consecutive = np.ediff1d(expected_labels) == 1
    assert all(labels_valid_consecutive), f'Labels must be in consecutive order (0, 1, 2, ...). The labels {np.array(expected_labels)[1:][~labels_valid_consecutive]} do not satisfy this restriction'

    # Check all label files in parallel; each result is (ok, unexpected_values).
    # Pool is imported outside this view (presumably multiprocessing.Pool).
    p = Pool(default_num_threads)
    try:
        results = p.starmap(verify_contains_only_expected_labels, zip(label_files, [expected_labels] * len(label_files)))
    finally:
        p.close()
        p.join()

    fail = False
    print("Expected label values are", expected_labels)
    for label_file, (ok, unexpected_values) in zip(label_files, results):
        if not ok:
            print("Unexpected labels found in file %s. Found these unexpected values (they should not be there) %s" % (label_file, unexpected_values))
            fail = True
    if fail:
        raise AssertionError("Found unexpected labels in the training dataset. Please correct that or adjust your dataset.json accordingly")
    else:
        print("Labels OK")

    # ---------------- test set (images only, no labels) ----------------
    if len(expected_test_identifiers) > 0:
        print("Verifying test set")
        nii_files_in_imagesTs = subfiles(join(folder, "imagesTs"), suffix=".nii.gz", join=False)

        for c in expected_test_identifiers:
            expected_image_files = [join(folder, "imagesTs", c + "_%04d.nii.gz" % i) for i in range(num_modalities)]
            assert all(isfile(i) for i in expected_image_files), "some image files are missing for case %s. Expected files:\n %s" % (c, expected_image_files)

            # all modalities of a test case must be co-registered with the first one
            if num_modalities > 1:
                images_itk = [sitk.ReadImage(i) for i in expected_image_files]
                reference_img = images_itk[0]
                # start=1 so that i indexes expected_image_files of the image being checked
                for i, img in enumerate(images_itk[1:], start=1):
                    assert verify_same_geometry(img, reference_img), "The modalities of the image %s do not seem to be registered. Please coregister your modalities." % (expected_image_files[i])

            for i in expected_image_files:
                nii_files_in_imagesTs.remove(os.path.basename(i))
        assert len(nii_files_in_imagesTs) == 0, "there are test cases in imagesTs that are not listed in dataset.json: %s" % nii_files_in_imagesTs

    # ---------------- dataset-wide orientation check ----------------
    # unique_orientations is currently not used further.
    all_same, unique_orientations = verify_all_same_orientation(join(folder, "imagesTr"))
    if not all_same:
        print("WARNING: Not all images in the dataset have the same axis ordering. We very strongly recommend you correct that by reorienting the data. fslreorient2std should do the trick")

    # NaNs are fatal for training, so they are reported before the geometry
    # problem; otherwise the `raise Warning` below would stop execution first
    # and hide them.
    if has_nan:
        raise RuntimeError("Some images have nan values in them. This will break the training. See text output above to see which ones")
    if not geometries_OK:
        # NOTE(review): despite the wording of the message, raising Warning (an
        # Exception subclass) does abort the caller; kept because callers may
        # rely on preprocessing stopping here.
        raise Warning("GEOMETRY MISMATCH FOUND! CHECK THE TEXT OUTPUT! This does not cause an error at this point but you should definitely check whether your geometries are alright!")
    print("Dataset OK")
Crop去除黑边
如果不进行数据完整性检查,其实第一步就是进入crop部分,用来对数据进行去黑边操作
# Crop away the zero-valued border of every case. The saving happens in
# ImageCropper.load_crop_save: data and seg are stacked into <case>.npz and
# the properties dict is pickled to <case>.pkl.
# NOTE(review): `crop` itself is not shown here; False is presumably the
# "override" flag and tf the number of worker processes -- confirm.
crop(task_name, False, tf)
crop 的具体实现在 nnunet/experiment_planning/utils.py 文件中,其核心是 ImageCropper 类,该类实现了以下几种功能:
- 读取多模态数据,以及其Mask,使用np.vstack对数据进行拼接
- 获取多模态数据的非零区域,生成nonzero_mask
- 根据nonzero_mask对数据以及seg进行crop,去黑边
- 使用 np.savez_compressed 将数据与 seg 一起保存为 npz 文件,属性用 pickle 保存为 pkl 文件,以加快后期数据加载速度。
下面重点对该类进行解析
# Read the data (all modalities) and the segmentation of a single case.
def load_case_from_list_of_files(data_files, seg_file=None):
    """Load every modality of one case, plus its segmentation if given.

    :param data_files: list or tuple of image paths, one per modality
    :param seg_file: path of the segmentation image, or None
    :return: ``(data, seg, properties)``:

        * ``data`` -- float32 array with the modalities stacked along a new
          leading axis;
        * ``seg`` -- float32 array with a leading axis of length 1, or None;
        * ``properties`` -- OrderedDict with the original size/spacing
          (reordered with ``[2, 1, 0]``, which requires at least 3 axes), the
          input file paths and the raw ITK origin/spacing/direction of the
          first modality.
    """
    assert isinstance(data_files, (list, tuple)), "case must be either a list or a tuple"
    images = [sitk.ReadImage(path) for path in data_files]
    first = images[0]

    # SimpleITK reports size/spacing as (x, y, z) while GetArrayFromImage
    # returns arrays indexed (z, y, x); [2, 1, 0] puts them in array order.
    to_array_order = [2, 1, 0]
    properties = OrderedDict()
    properties["original_size_of_raw_data"] = np.array(first.GetSize())[to_array_order]
    properties["original_spacing"] = np.array(first.GetSpacing())[to_array_order]
    properties["list_of_data_files"] = data_files
    properties["seg_file"] = seg_file
    # Raw ITK geometry, kept unmodified (e.g. for writing results back later).
    properties["itk_origin"] = first.GetOrigin()
    properties["itk_spacing"] = first.GetSpacing()
    properties["itk_direction"] = first.GetDirection()

    # One leading axis per modality -> (C, ...) array.
    stacked = np.vstack([sitk.GetArrayFromImage(img)[None] for img in images])

    if seg_file is None:
        seg_npy = None
    else:
        seg_npy = sitk.GetArrayFromImage(sitk.ReadImage(seg_file))[None].astype(np.float32)

    return stacked.astype(np.float32), seg_npy, properties
# Nonzero region of a case: union over all modalities, with holes filled.
def create_nonzero_mask(data):
    """Return a boolean mask of voxels that are nonzero in at least one channel.

    Enclosed holes in that union are filled with
    ``scipy.ndimage.binary_fill_holes``.

    :param data: array of shape (C, X, Y, Z) or (C, X, Y)
    :return: boolean array of shape ``data.shape[1:]``
    :raises AssertionError: if ``data`` is neither 3- nor 4-dimensional
    """
    from scipy.ndimage import binary_fill_holes

    assert len(data.shape) in (4, 3), "data must have shape (C, X, Y, Z) or shape (C, X, Y)"
    any_channel_nonzero = np.any(data != 0, axis=0)
    return binary_fill_holes(any_channel_nonzero)
# Bounding box of the region of interest, used for cropping.
# Works for masks of any dimensionality, e.g. (Z, X, Y) or (X, Y).
def get_bbox_from_mask(mask, outside_value=0):
    """Return the tight bounding box of all entries of ``mask`` != ``outside_value``.

    :param mask: numpy array of any dimensionality (3D in the nnU-Net
        pipeline; 2D masks are supported as well)
    :param outside_value: value marking entries outside the region
    :return: list with one ``[min, max + 1]`` pair of Python ints per axis,
        i.e. half-open ranges that can be used directly as slice bounds
    :raises ValueError: if no entry differs from ``outside_value``
    """
    coords = np.nonzero(mask != outside_value)
    if coords[0].size == 0:
        raise ValueError("mask contains no voxel different from outside_value=%r; "
                         "cannot compute a bounding box" % (outside_value,))
    return [[int(axis_coords.min()), int(axis_coords.max()) + 1] for axis_coords in coords]
def crop_to_bbox(image, bbox):
    """Return the part of ``image`` inside ``bbox``.

    :param image: numpy array with exactly one axis per bbox entry (3D in the
        nnU-Net pipeline; 2D works too)
    :param bbox: sequence of ``[start, stop)`` pairs, one per axis of ``image``
        (the format returned by ``get_bbox_from_mask``)
    :return: ``image`` sliced to the box (basic slicing, i.e. a view)
    :raises AssertionError: if the number of bbox pairs differs from the
        number of image axes
    """
    assert len(image.shape) == len(bbox), "bbox must have one [start, stop) pair per image axis"
    return image[tuple(slice(start, stop) for start, stop in bbox)]
def crop_to_nonzero(data, seg=None, nonzero_label=-1):
    """Crop ``data`` (and ``seg``) to the bounding box of the nonzero region.

    The region is the union over all channels of ``data`` with holes filled
    (``create_nonzero_mask``). After cropping, voxels that lie outside that
    region are marked with ``nonzero_label`` in the returned segmentation.

    :param data: array of shape (C, ...) -- one entry per modality
    :param seg: optional segmentation of shape (C_seg, ...), cropped with the
        same box; every channel is cropped, although the loader in this file
        produces a single channel
    :param nonzero_label: value written into the segmentation for voxels
        outside the nonzero region
    :return: ``(data, seg, bbox)``. When ``seg`` was given, its background
        voxels (== 0) outside the region become ``nonzero_label``. When ``seg``
        was None, a new integer map of shape (1, ...) is returned that is 0
        inside the region and ``nonzero_label`` outside it.
    """
    nonzero_mask = create_nonzero_mask(data)
    bbox = get_bbox_from_mask(nonzero_mask, 0)  # one [start, stop) pair per spatial axis

    # crop every modality with the same box
    data = np.vstack([crop_to_bbox(data[c], bbox)[None] for c in range(data.shape[0])])
    if seg is not None:
        seg = np.vstack([crop_to_bbox(seg[c], bbox)[None] for c in range(seg.shape[0])])

    nonzero_mask = crop_to_bbox(nonzero_mask, bbox)[None]
    if seg is not None:
        # only background voxels outside the region are relabelled
        seg[(seg == 0) & (nonzero_mask == 0)] = nonzero_label
    else:
        # Built directly instead of via in-place relabelling: the previous
        # "set outside to nonzero_label, then set everything > 0 to 0" sequence
        # erased a positive nonzero_label.
        seg = np.full(nonzero_mask.shape, nonzero_label, dtype=int)
        seg[nonzero_mask.astype(bool)] = 0
    return data, seg, bbox
class ImageCropper(object):
    """Crop cases to their nonzero region and cache the result on disk.

    The nonzero region is the union over all modalities (a voxel is kept if it
    is nonzero in *any* modality), with holes filled; see
    ``create_nonzero_mask``. For datasets with large empty borders (the
    original authors mention BraTS and ISLES) this shrinks the images
    considerably.
    """

    def __init__(self, num_threads, output_folder=None):
        """
        :param num_threads: number of worker processes used by run_cropping
        :param output_folder: where to store the cropped data; created if given
        """
        self.output_folder = output_folder
        self.num_threads = num_threads
        if self.output_folder is not None:
            maybe_mkdir_p(self.output_folder)

    @staticmethod
    def crop(data, properties, seg=None):
        """Crop one loaded case and record crop information in ``properties``.

        :param data: array (C, ...) with one entry per modality
        :param properties: dict from load_case_from_list_of_files; must contain
            "original_spacing". It is updated in place with "crop_bbox",
            "classes" and "size_after_cropping".
        :param seg: optional segmentation (C_seg, ...)
        :return: ``(data, seg, properties)`` after cropping
        """
        shape_before = data.shape
        data, seg, bbox = crop_to_nonzero(data, seg, nonzero_label=-1)
        shape_after = data.shape
        print("before crop:", shape_before, "after crop:", shape_after, "spacing:",
              np.array(properties["original_spacing"]), "\n")
        properties["crop_bbox"] = bbox
        # label values present after cropping, including the -1 marker that
        # crop_to_nonzero writes outside the nonzero region
        properties['classes'] = np.unique(seg)
        # NOTE(review): resets values below the -1 marker to 0; kept as in upstream
        seg[seg < -1] = 0
        properties["size_after_cropping"] = data[0].shape  # spatial shape after cropping
        return data, seg, properties

    @staticmethod
    def crop_from_list_of_files(data_files, seg_file=None):
        """Load a case from its files and crop it (see ``crop``)."""
        data, seg, properties = load_case_from_list_of_files(data_files, seg_file)
        return ImageCropper.crop(data, properties, seg)

    def load_crop_save(self, case, case_identifier, overwrite_existing=False):
        """Crop one case and save ``<id>.npz`` (data and seg stacked) and ``<id>.pkl``.

        Nothing is done if both output files exist, unless ``overwrite_existing``.

        :param case: list of paths; all entries but the last are the modality
            images, the last one is the segmentation file (or None)
        :param case_identifier: base name of the output files
        :param overwrite_existing: recompute even if the outputs exist
        """
        try:
            print(case_identifier)
            npz_file = os.path.join(self.output_folder, "%s.npz" % case_identifier)
            pkl_file = os.path.join(self.output_folder, "%s.pkl" % case_identifier)
            if overwrite_existing or not (os.path.isfile(npz_file) and os.path.isfile(pkl_file)):
                data, seg, properties = self.crop_from_list_of_files(case[:-1], case[-1])
                # Data and segmentation are stored in one array (seg = last
                # channel) so that later stages need to load a single file.
                all_data = np.vstack((data, seg))
                np.savez_compressed(npz_file, data=all_data)
                with open(pkl_file, 'wb') as f:
                    pickle.dump(properties, f)
        except Exception as e:
            # Worker processes: print which case failed before propagating.
            print("Exception in", case_identifier, ":")
            print(e)
            raise

    def run_cropping(self, list_of_files, overwrite_existing=False, output_folder=None):
        """Crop all cases in parallel and copy the ground-truth segmentations.

        The original segmentation files are also copied to
        ``<output_folder>/gt_segmentations`` (the original authors use them for
        evaluation).

        :param list_of_files: list of cases; each case is a list of file paths
            whose last entry is the segmentation file (or None) and whose other
            entries are the modality images
        :param overwrite_existing: recompute cases whose outputs already exist
        :param output_folder: if given, replaces ``self.output_folder``
        """
        if output_folder is not None:
            self.output_folder = output_folder

        output_folder_gt = os.path.join(self.output_folder, "gt_segmentations")
        maybe_mkdir_p(output_folder_gt)
        for case in list_of_files:
            if case[-1] is not None:
                shutil.copy(case[-1], output_folder_gt)

        list_of_args = [(case, get_case_identifier(case), overwrite_existing) for case in list_of_files]

        p = Pool(self.num_threads)
        try:
            p.starmap(self.load_crop_save, list_of_args)
        finally:
            # also release the workers if a case raised
            p.close()
            p.join()
最后
以上就是鳗鱼小兔子为你收集整理的nnUNet代码学习——数据预处理部分(一)的全部内容,希望文章能够帮你解决nnUNet代码学习——数据预处理部分(一)所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
发表评论 取消回复