# Slide-SAM/datasets/dataset3d_2dmask.py

# from torchvision import transforms
from monai import transforms
import numpy as np
import SimpleITK as sitk
import torch
from torch.utils.data import Dataset as dataset
import glob
import os
from einops import rearrange, repeat
from tutils.nn.data.tsitk import read
from tqdm import tqdm
# from monai.transforms import SpatialPadd, CenterSpatialCropd, Resized, NormalizeIntensityd
# from monai.transforms import RandAdjustContrastd, RandShiftIntensityd, Rotated, RandAffined
# from datasets.common_2d_aug import RandomRotation, RandomResizedCrop, RandomHorizontalFlip, ColorJitter, ToTensor, Normalize
from tutils import tfilename, tdir
import random
from tutils.nn.data import itk_to_np
from scipy.sparse import csr_matrix
import cv2
DEFAULT_PATH="/quanquan/datasets/08_AbdomenCT-1K/"
class Dataset2D(dataset):
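    """2D training set for Slide-SAM: each sample is a 3-slice grayscale stack
    (cached as JPEGs) paired with one CSR-compressed 2D pseudo-mask drawn from
    several sources (gt, SAM auto-seg, superpixel-derived point/box prompts)."""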
def __init__(self, dirpath=None, is_train=True, getting_multi_mask=False) -> None:
super().__init__()
self.dirpath = dirpath
self.is_train = is_train
self.getting_multi_mask = getting_multi_mask
self.img_names = self.prepare_datalist()
self.prepare_transforms()
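        # Sampling weight per pseudo-mask source; 'superpixel' gets weight 0 so it
        # is never drawn in _get_data (see the assert on metadata['source']).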
self.weights_dict = {"gt":2, "sam_auto_seg":2, "prompt_point_from_superpixel":1, "prompt_box_from_superpixel":1, "superpixel":0}
def prepare_transforms(self):
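        # Training pipeline: resize the 3-slice stack to 3x1024x1024, random shear,
        # a label-centred positive crop (neg=0), resize back to 1024, then contrast
        # and intensity jitter on the image only.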
self.transform = transforms.Compose([
transforms.Resized(keys=['img', 'label'], spatial_size=(3,1024,1024)),
# transforms.RandSpatialCropd(keys=["img"], roi_size=(448,448,1)),
transforms.RandAffined(keys=['img', 'label'], prob=0.5, shear_range=(0.2,0.2)),
transforms.RandCropByPosNegLabeld(keys=['img', 'label'], spatial_size=(3,960,960), label_key='label', neg=0),
# transforms.RandSmoothFieldAdjustContrastd(keys=['img', 'label'], )
transforms.Resized(keys=['img', 'label'], spatial_size=(3,1024,1024)),
            transforms.RandAdjustContrastd(keys=['img']),
transforms.RandShiftIntensityd(keys=['img'], prob=0.8, offsets=(-5, 5)),
])
self.test_transform = transforms.Compose([
transforms.Resized(keys=['img'], spatial_size=(3,1024,1024)),
])
def __len__(self):
return len(self.img_names)
def to_RGB(self, img):
        # Min-max normalize the image to the [0, 255] (RGB-style) intensity range.
        img = ((img - img.min()) / (img.max() - img.min()) * 255).astype(np.uint8)
return img
def prepare_datalist(self):
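        # Collect sample names of the form "<case>_s<slice_idx:04d>" from the cached
        # pseudo-mask directory (one "*_mask.npz" per annotated slice).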
dirpath_img = os.path.join(self.dirpath, 'preprocessed', 'cache_2d_various_pseudo_masks')
names = glob.glob(os.path.join(dirpath_img, "*_mask.npz"))
names = [os.path.split(name)[-1].replace("_mask.npz", "") for name in names]
names.sort()
# names = names[:15000]
print(f"[Dataset2d] Load {len(names)} paths.")
        assert len(names) > 0, f"{__file__} Got dirpath: {self.dirpath}"
return names
def _get_data(self, index, debug=False, iternum=0):
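        """Assemble one sample: a 3-slice grayscale window around an annotated slice
        plus a single pseudo-mask drawn according to self.weights_dict."""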
img_names = self.img_names
img_info = os.path.split(img_names[index])[-1].split('_s')
filename, slice_idx = img_info[0], int(img_info[-1][:4])
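        # Randomly place the annotated slice first, middle or last in the 3-slice window.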
mask_loc = np.random.randint(0,3)
if mask_loc == 0:
slices_indices = [slice_idx, slice_idx+1, slice_idx+2]
elif mask_loc == 1:
slices_indices = [slice_idx-1, slice_idx, slice_idx+1]
elif mask_loc == 2:
slices_indices = [slice_idx-2, slice_idx-1, slice_idx]
        # Load the three neighbouring slices as grayscale JPEGs
filenames = [os.path.join(self.dirpath, "preprocessed", "cache_jpg", f"{filename}_s{i:04}_img.jpg") for i in slices_indices]
for name in filenames:
if not os.path.exists(name):
                return self._get_data((index + 1) % len(self))
imgs = [cv2.imread(name, cv2.IMREAD_GRAYSCALE) for name in filenames]
img_rgb = np.stack(imgs, axis=0)
        # Load the CSR-compressed pseudo-mask stack
compressed = np.load(os.path.join(self.dirpath, "preprocessed", "cache_2d_various_pseudo_masks", img_names[index]+"_mask.npz"))
csr = csr_matrix((compressed['data'], compressed['indices'], compressed['indptr']), shape=compressed['shape'])
label_ori = csr.toarray()
label_ori = rearrange(label_ori, "c (h w) -> c h w", h=1024, w=1024)
metadata = np.load(os.path.join(self.dirpath, "preprocessed", "cache_2d_various_pseudo_masks", img_names[index]+"_metadata.npy"), allow_pickle=True)
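        # metadata holds one dict per mask channel (with at least a 'source' key);
        # draw a single mask with probability proportional to its source weight.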
label_prob = np.array([self.weights_dict[item['source']] for item in metadata]).astype(float)
label_prob = label_prob / label_prob.sum()
label_idx = np.random.choice(a=np.arange(len(metadata)), p=label_prob)
label_ori = label_ori[label_idx]
metadata = metadata[label_idx]
assert metadata['source'] != 'superpixel'
        assert len(img_rgb.shape) == 3, f"{__file__} Got {img_rgb.shape}"
bundle_ori = {"img":torch.Tensor(rearrange(img_rgb, "c h w -> 1 c h w")), "label":torch.Tensor(repeat(label_ori, "h w -> 1 3 h w"))}
# import ipdb; ipdb.set_trace()
if self.is_train:
bundle = self.transform(bundle_ori)[0]
else:
bundle = self.test_transform(bundle_ori)
bundle['label'] = (bundle['label']>0.5).float()
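        # Resample if the transformed mask is (almost) empty after cropping.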
if bundle['label'][0].sum() < 100:
return self._get_data((index+1)%len(self), iternum=iternum+1)
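        # One-hot indicator of which slice in the window carries the mask.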
vector = np.zeros(3)
vector[mask_loc] = 1
if debug:
ret_dict = {
"name": img_names[index],
"img": bundle['img'][0],
"label": bundle['label'][0],
"img_ori":img_rgb,
"label_ori":label_ori,
"weight": self.weights_dict[metadata['source']],
"iternum": iternum,
"mask_loc": mask_loc,
"indicators": vector,
}
return ret_dict
ret_dict = {
"name": img_names[index],
"img": bundle['img'][0],
"label": bundle['label'][0],
"mask_loc": mask_loc,
"indicators": vector,
}
return ret_dict
def __getitem__(self, index):
return self._get_data(index)
class Testset2d(Dataset2D):
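    """Held-out split: same loading logic as Dataset2D, but keeps only the names
    after the first 15000 (complementing the commented-out names[:15000] slice in
    Dataset2D.prepare_datalist)."""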
def __init__(self, dirpath=None, is_train=False, getting_multi_mask=False) -> None:
super().__init__(dirpath, is_train, getting_multi_mask)
self.test_names = self.prepare_datalist()
def prepare_datalist(self):
dirpath_img = os.path.join(self.dirpath, 'cache_2d_various_pseudo_masks')
names = glob.glob(os.path.join(dirpath_img, "*_mask.npz"))
names = [os.path.split(name)[-1].replace("_mask.npz", "") for name in names]
names.sort()
names = names[15000:]
print(f"[Dataset2d] Load {len(names)} paths.")
        assert len(names) > 0, f"{__file__} Got dirpath: {self.dirpath}"
return names
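
# The writer side of the "*_mask.npz" cache is not part of this file. The helper
# below is a minimal sketch (hypothetical name: save_mask_stack_npz) of the assumed
# layout that _get_data decodes: a scipy CSR matrix of masks flattened to
# (num_masks, 1024*1024), stored under the keys data / indices / indptr / shape.
def save_mask_stack_npz(masks, out_path):
    """Save a (num_masks, 1024, 1024) binary mask stack in the CSR .npz format
    expected by Dataset2D._get_data."""
    flat = rearrange(masks.astype(np.uint8), "c h w -> c (h w)")
    csr = csr_matrix(flat)
    np.savez_compressed(out_path, data=csr.data, indices=csr.indices,
                        indptr=csr.indptr, shape=csr.shape)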
if __name__ == "__main__":
from torch.utils.data import DataLoader
dataset = Dataset2D(dirpath=DEFAULT_PATH)
loader = DataLoader(dataset=dataset, batch_size=1, shuffle=False)
iternums = 0
for i, data in enumerate(loader):
# iternums += data['iternum'].item()
print(i, iternums / (i+1), data['img'].shape, data['label'].shape)
        assert data['label'].sum() >= 100, f"{__file__} Got {data['label'].sum()}"
        assert (data['label'] == 1).sum() >= 100, f"{__file__} Got {(data['label'] == 1).sum()}"
import ipdb; ipdb.set_trace()