Slide-SAM/datasets/dataset3d.py
transcendentsky e04459c6fe first commit
2023-12-05 14:58:38 +08:00

240 lines
9.4 KiB
Python

from torchvision import transforms
from monai import transforms as monai_transforms
import numpy as np
import SimpleITK as sitk
import torch
from torch.utils.data import Dataset as dataset
import torch.nn.functional as F
import glob
import os
from einops import rearrange
from tutils.nn.data import read, itk_to_np
from tqdm import tqdm
# from monai.transforms import SpatialPadd, CenterSpatialCropd, Resized, NormalizeIntensityd
# from monai.transforms import RandAdjustContrastd, RandShiftIntensityd, Rotated, RandAffined
# from datasets.common_2d_aug import RandomRotation, RandomResizedCrop, RandomHorizontalFlip, ColorJitter, ToTensor, Normalize
from tutils import tfilename, tdir
import random
import time
import cv2
from scipy.sparse import csr_matrix
def tfunctime(func):
    """Decorator that prints the wall-clock running time of each call to *func*.

    Returns the wrapped function's result unchanged; only adds a timing print.
    """
    from functools import wraps

    @wraps(func)  # preserve __name__/__doc__ of the wrapped function for introspection
    def run(*argv, **kargs):
        t1 = time.time()
        ret = func(*argv, **kargs)
        t2 = time.time()
        # Keep the original report format so any downstream log parsing is unaffected.
        print(f"[Function {func.__name__}] Running time:{(t2-t1):.6f}s")
        return ret
    return run
# DEFAULT_PATH='/home1/quanquan/datasets/KiTS/'
# Legacy default data root (superseded by DATASET_CONFIG['data_root_path']).
DEFAULT_PATH="/home1/quanquan/datasets/BCV-Abdomen/Training/"
# Human-readable class names per sequence type; list position == label value in the mask.
LABEL_INDICES={
    "t2sag": ["bg","kidney", "label 2", "label 3", "rectum", "tumor", "other"],
}
# CACHE_DISK_DIR="/home1/quanquan/code/projects/medical-guangdong/cache/data2d_3/"
# Directory for pre-exported 2D slices; None disables the on-disk cache path.
CACHE_DISK_DIR=None
# DEFAULT_CONFIG={
#     "pad": (512,512),
#     "crop": (384,384),
#     "resize": (512,512),
# }
# Default dataset configuration used when Dataset3D is constructed without a config:
# dataset_list names select the `<name>_<split>.txt` lists under data_txt_path,
# whose relative paths are joined onto data_root_path.
DATASET_CONFIG={
    'split': 'train',
    'data_root_path':'/quanquan/datasets/',
    'dataset_list': ['sam', "their", "ours"],
    'data_txt_path':'./datasets/dataset_list/',
}
# Per-dataset label-id -> organ-name maps (currently only the WORD dataset).
DATASET_METAINFO={
    "WORD": {0:"Background", 1:"Liver", 2:"Spleen", 3:"Left Kidney", 4:"Right Kidney", 5:"Stomach", 6:"Gallbladder", 7:"Esophagus", 8:"Pancreas", 9:"Duodenum", 10:"Colon", 11:"Intestine", 12:"Adrenal", 13:"Rectum", 14:"Bladder", 15:"left head of femur", 16:"right head of femur"}
}
class Dataset3D(dataset):
    """2.5D slice dataset over 3D CT volumes.

    Each sample stacks 3 adjacent slices of a volume (channel dim = 3) together
    with the binary mask of one randomly chosen foreground label, both resized
    to 1024x1024 (SAM input size).
    """

    def __init__(self, config=DATASET_CONFIG, is_train=True, split='train', getting_multi_mask=False, use_cache=False) -> None:
        super().__init__()
        self.config = config
        self.is_train = is_train
        self.split = split
        self.getting_multi_mask = getting_multi_mask
        self.use_cache = use_cache
        # Cached mode expects pre-exported 2D slices on disk; the default path
        # parses the txt file lists under config['data_txt_path'].
        self.img_names = self.prepare_cached_datalist() if use_cache else self.prepare_datalist()
        self.prepare_transforms()

    def prepare_cached_datalist(self):
        # Must be overridden by subclasses that support use_cache=True.
        raise NotImplementedError

    def prepare_transforms(self):
        """Build the train (contrast jitter + resize) and test (resize only) pipelines."""
        self.transform = monai_transforms.Compose([
            monai_transforms.RandAdjustContrastd(keys=['img'], ),
            monai_transforms.Resized(keys=['img', 'label'], spatial_size=(3,1024,1024)),
        ])
        self.test_transform = transforms.Compose([
            monai_transforms.Resized(keys=['img', 'label'], spatial_size=(3,1024,1024)),
        ])

    def _get_image(self, index):
        """Load a full volume + label; skips to the next index if the file is missing."""
        name = self.img_names[index]['img_path']
        if not os.path.exists(name):
            print("Path not exists!", name)
            # BUGFIX: was `index+1%len(self)`, which parses as `index + (1 % len)`
            # == index+1 and never wraps, so it could run past the end of the list.
            return self._get_image((index + 1) % len(self))
        img_itk = read(self.img_names[index]['img_path'])
        img_ori = itk_to_np(img_itk)
        # Clip to a soft-tissue CT window, then scale towards [0, 255].
        img = np.clip(img_ori, -200, 400).astype(np.float32)
        img = (img - img.min()) / img.max() * 255
        label_ori = itk_to_np(read(self.img_names[index]['label_path']))
        return {"img":img, "name":name.replace(self.config['data_root_path'], ""), "label":label_ori}

    def __len__(self):
        return len(self.img_names)

    def prepare_datalist(self):
        """Parse `<dataset>_<split>.txt` lists into {'img_path','label_path','name'} dicts.

        Each line holds two whitespace-separated relative paths: image then label.
        """
        config = self.config
        data_paths = []
        for item in config['dataset_list']:
            print("Load datalist from ", item)
            for line in open(config["data_txt_path"]+ item + f"_{self.split}.txt"):
                name = line.strip().split()[1].split('.')[0]
                img_path = config['data_root_path'] + line.strip().split()[0]
                label_path = config['data_root_path'] + line.strip().split()[1]
                data_paths.append({'img_path': img_path, 'label_path': label_path, 'name': name})
        print('train len {}'.format(len(data_paths)))
        return data_paths

    def _get_data(self, index, debug=False):
        """Sample a random 3-slice stack and one foreground-label mask from volume `index`.

        Returns (slices, mask, image_path, label_idx) where slices/mask are (3, H, W).
        """
        name = self.img_names[index]['img_path']
        img_itk = read(self.img_names[index]['img_path'])
        img_ori = itk_to_np(img_itk)
        scan_orientation = np.argmin(img_ori.shape)
        label_ori = itk_to_np(read(self.img_names[index]['label_path']))
        # Strongly anisotropic volumes are sliced only along the (thin) scan axis;
        # near-isotropic ones may be sliced along any of the three axes.
        if min(img_ori.shape) * 2 < max(img_ori.shape):
            orientation = scan_orientation
        else:
            orientation = np.random.randint(3)
        # Keep a one-slice margin on both ends so slice_idx-1 / +2 stay in bounds.
        slice_idx = np.random.randint(2, img_ori.shape[orientation]-2)
        if orientation == 0:
            s = img_ori[slice_idx-1:slice_idx+2, :,:]
            lb = label_ori[slice_idx-1:slice_idx+2, :,:]
        if orientation == 1:
            s = img_ori[:,slice_idx-1:slice_idx+2,:]
            s = rearrange(s, "h c w -> c h w")
            lb = label_ori[:,slice_idx-1:slice_idx+2,:]
            lb = rearrange(lb, "h c w -> c h w")
        if orientation == 2:
            s = img_ori[:,:,slice_idx-1:slice_idx+2]
            s = rearrange(s, "h w c -> c h w")
            lb = label_ori[:,:,slice_idx-1:slice_idx+2]
            lb = rearrange(lb, "h w c -> c h w")
        assert s.shape[0] == 3
        # Resample from the next volume if the middle slice has (almost) no foreground.
        if np.float32(lb[1,:,:]>0).sum() <= 200:
            return self._get_data((index+1)%len(self))
        # Choose one label uniformly among those with at least 50 foreground voxels.
        label_num = int(lb.max())
        is_good_mask = []
        for label_idx in range(1,label_num+1):
            one_lb = np.float32(lb==label_idx)
            is_good_mask.append(one_lb.sum()>=50)
        label_idx = np.random.choice(range(1,label_num+1), p=np.array(is_good_mask)/np.sum(is_good_mask))
        lb = np.float32(lb==label_idx)
        return s, lb, name, label_idx

    def _get_cached_data(self, index):
        """Load a pre-exported jpg slice stack and its CSR-compressed label masks."""
        name = self.img_names[index]
        img = cv2.imread(name)
        # Labels are stored sparse (one row per mask) next to the image file.
        compressed = np.load(name.replace("image/image_", "label/label_").replace(".jpg", ".npz"))
        csr = csr_matrix((compressed['data'], compressed['indices'], compressed['indptr']), shape=compressed['shape'])
        label_ori = csr.toarray()
        label_ori = rearrange(label_ori, "n (c h w) -> n c h w", c=3, h=1024, w=1024)
        # Pick one of the n stored masks at random; -1 marks "label id unknown".
        label_idx = np.random.randint(0, label_ori.shape[0])
        label_ori = label_ori[label_idx]
        return rearrange(img, "h w c -> c h w"), label_ori, name, -1

    def to_RGB(self, img):
        # Min-max normalize to integer values in [0, 255] (RGB-style intensities).
        img = ((img - img.min()) / (img.max() - img.min()) * 255).astype(int)
        return img

    def __getitem__(self, index, debug=False):
        """Return a training dict: 3x1024x1024 image, binary mask, name, indicators."""
        img_ori, label_ori, name, label_idx = self._get_data(index)
        img_ori = np.clip(img_ori, -200,400)  # soft-tissue CT window
        img_rgb = self.to_RGB(img_ori)
        assert len(img_rgb.shape) == 3, f"{__file__} Got{img_rgb.shape}"
        bundle_ori = {"img":torch.Tensor(img_rgb).unsqueeze(0), "label":torch.Tensor(label_ori).unsqueeze(0)}
        if self.is_train:
            bundle = self.transform(bundle_ori)
        else:
            bundle = self.test_transform(bundle_ori)
        if not self.use_cache:
            # Re-binarize: the resize interpolation blurs the 0/1 mask.
            bundle['label'] = (bundle['label']>0.5).float()
        vector = np.ones(3)
        if debug:
            ret_dict = {
                "name": name,
                "img": bundle['img'],
                "label": bundle['label'],
                "img_ori":img_ori,
                "label_ori":label_ori,
                "label_idx": label_idx,
                "indicators": vector,
            }
            return ret_dict
        ret_dict = {
            "name": name,
            "img": bundle['img'][0].float(),
            "label": bundle['label'][0].float(),
            "indicators": vector,
        }
        # If the middle-slice mask vanished after transforms, try the next sample.
        if bundle['label'][0][1,:,:].sum() <= 0:
            # BUGFIX: was `index+1 % len(self)` == index+1 (no wrap-around),
            # which could index past the end of the dataset.
            return self.__getitem__((index + 1) % len(self))
        return ret_dict
if __name__ == "__main__":
    # Manual smoke test: load the project config, pull one sample, then batch-iterate.
    from tutils.new.manager import ConfigManager
    from torch.utils.data import DataLoader

    cfg = ConfigManager()
    cfg.add_config("configs/vit_b_103.yaml")
    ds = Dataset3D(config=cfg['dataset'])  # , use_cache=True
    sample = ds.__getitem__(0)
    import ipdb; ipdb.set_trace()
    dl = DataLoader(ds, batch_size=8)
    for batch in dl:
        print(batch['img'].shape, batch['label'].shape)