Slide-SAM/datasets/eval_dataloader/loader_abstract.py
"""
DataLoader only for evaluation
"""
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from tutils.nn.data import read, itk_to_np, np_to_itk
from einops import reduce, repeat, rearrange
# from tutils.nn.data.tsitk.preprocess import resampleImage
from trans_utils.data_utils import Data3dSolver
from tutils.nn.data.tsitk.preprocess import resampleImage
import SimpleITK as sitk

# Example
DATASET_CONFIG = {
    'split': 'test',
    'data_root_path': '/quanquan/datasets/',
    'dataset_list': ["ours"],
    'data_txt_path': './datasets/dataset_list/',
    'label_idx': 0,
}


class AbstractLoader(Dataset):
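    """Evaluation loader base class: each item is one 3-D volume cropped to the
    labelled slice range, with a binary mask for a single label index and a
    template slice (the slice with the largest mask area)."""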
    def __init__(self, config, split="test") -> None:
        super().__init__()
        self.config = config
        self.split = split
        self.img_names = self.prepare_datalist()

    def __len__(self):
        return len(self.img_names)

    def prepare_datalist(self):
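        """Collect image/label path pairs from the split files.

        Each line of ``<dataset>_{split}.txt`` is expected to contain an image
        path and a label path (relative to ``data_root_path``) separated by
        whitespace.
        """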
        config = self.config
        data_paths = []
        for item in config['dataset_list']:
            print("Load datalist from ", item)
            for line in open(config["data_txt_path"] + item + f"_{self.split}.txt"):
                name = line.strip().split()[1].split('.')[0]
                img_path = config['data_root_path'] + line.strip().split()[0]
                label_path = config['data_root_path'] + line.strip().split()[1]
                data_paths.append({'img_path': img_path, 'label_path': label_path, 'name': name})
        print('train len {}'.format(len(data_paths)))
        return data_paths

    def _get_data(self, index, debug=False):
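        """Load one volume, binarize the mask for ``label_idx``, window the
        intensities, crop to the labelled slice range, and pick a template slice."""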
        label_idx = self.config['label_idx']
        name = self.img_names[index]['img_path']
        img_itk = read(name)
        spacing = img_itk.GetSpacing()
        img_ori = itk_to_np(img_itk)
        scan_orientation = np.argmin(img_ori.shape)
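        # (the axis with the fewest voxels is treated as the through-plane / slice axis)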
        label_ori = itk_to_np(read(self.img_names[index]['label_path']))
        label = label_ori == label_idx
        # img_ori, new_spacing = Data3dSolver().read(self.img_names[index]['img_path'])
        # label_itk = read(self.img_names[index]['label_path'])
        # ori_spacing = label_itk.GetSpacing()
        # label = itk_to_np(label_itk) == label_idx
        # print("[loader_abstract.DEBUG] size", img_ori.shape, label.shape)
        # label = self._get_resized_label(label, new_size=img_ori.shape)
        if debug:
            Data3dSolver().simple_write(label)
            Data3dSolver().simple_write(img_ori, "tmp_img.nii.gz")
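        # per-slice foreground voxel counts; the template slice is the one with
        # the largest foreground area, indexed relative to the cropped slice range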
        s = reduce(label, "c h w -> c", reduction="sum")
        coords = np.nonzero(s)
        x_min = np.min(coords[0])
        x_max = np.max(coords[0])
        template_slice_id = s.argmax() - x_min
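        # heuristic intensity windowing: CT-like volumes (negative values present)
        # are clipped to [-200, 400], otherwise to [0, 600]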
        if img_ori.min() < -10:
            img_ori = np.clip(img_ori, -200, 400)
        else:
            img_ori = np.clip(img_ori, 0, 600)
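        # keep only the slice range that contains foreground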
        img_ori = img_ori[x_min:x_max + 1, :, :]
        label = label[x_min:x_max + 1, :, :]
        assert label.shape[0] >= 3
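        # if the largest slice sits at the volume border, fall back to the middle slice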
        if template_slice_id <= 1 or template_slice_id >= label.shape[0] - 2:
            template_slice_id = label.shape[0] // 2
        dataset_name = name.replace(self.config['data_root_path'], "").split("/")[0]
        template_slice = label[template_slice_id, :, :]
        print("template_slice.area ", template_slice.sum(), template_slice.sum() / (template_slice.shape[0] * template_slice.shape[1]))
        d = {
            "name": name,
            "dataset_name": dataset_name,
            "img": np.array(img_ori).astype(np.float32),
            "label_idx": label_idx,
            "label": np.array(label).astype(np.float32),
            "template_slice_id": template_slice_id,
            "template_slice": np.array(label[template_slice_id, :, :]).astype(np.float32),
            "spacing": np.array(spacing),
        }
        return d

    def __getitem__(self, index):
        return self._get_data(index)


if __name__ == "__main__":
    from tutils.new.manager import ConfigManager
    EX_CONFIG = {
        'dataset': {
            'prompt': 'box',
            'dataset_list': ['guangdong'],  # ["sabs"], chaos, word
            'label_idx': 2,
        }
    }
    config = ConfigManager()
    config.add_config("configs/vit_sub_rectum.yaml")
    config.add_config(EX_CONFIG)
    dataset = AbstractLoader(config['dataset'], split="test")
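    # Sketch (not part of the original evaluation flow): the imported DataLoader
    # could wrap this dataset; batch_size=1 is assumed since volumes differ in shape.
    # loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=2)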
    for i in range(len(dataset)):
        dataset._get_data(i, debug=False)

    # label_path = "/home1/quanquan/datasets/01_BCV-Abdomen/Training/label/label0001.nii.gz"
    # from monai.transforms import SpatialResample
    # resample = SpatialResample()
    # label = itk_to_np(read(label_path)) == 1
    # print(label.shape)
    # # resampled = resample(label, spatial_size=(label.shape[0]*7, label.shape[1], label.shape[2]))
    # print(label.shape)
    exit(0)
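    # The visualization below is dead code unless the exit(0) above is removed.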
    data = itk_to_np(read("tmp_img.nii.gz"))
    data = torch.Tensor(data)
    maxlen = data.shape[0]
    slices = []
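    # build overlapping stacks of 3 consecutive slices (previous, current, next)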
    for i in range(1, maxlen - 1):
        slices.append(data[i - 1:i + 2, :, :])
    input_slices = torch.stack(slices, dim=0)
    input_slices = torch.clip(input_slices, -200, 600)
    from torchvision.utils import save_image
    save_image(input_slices, "tmp.jpg")