Slide-SAM/datasets/cache_dataset3d3.py
transcendentsky ce31f13be0 .
2024-03-20 22:07:35 +08:00

98 lines
3.4 KiB
Python

"""
re-index by masks! not images
"""
import numpy as np
import os
from einops import rearrange, reduce, repeat
from tutils.nn.data import read, itk_to_np, np_to_itk, write
from tutils import tfilename
from .cache_dataset3d import Dataset3D as basic_3d_dataset
from monai import transforms
import torch
import cv2
from scipy.sparse import csr_matrix
import torch.nn.functional as F
from torchvision import transforms
from einops import rearrange
import glob
from torchvision import transforms
from monai import transforms as monai_transforms
class Dataset3D(basic_3d_dataset):
def __init__(self, config=..., use_cache=True, *args, **kwargs) -> None:
super().__init__(config, use_cache=use_cache, *args, **kwargs)
self.basic_dir = config['data_root_path']
self.cache_dir = config['cache_data_path']
def prepare_cached_datalist(self):
config = self.config
data_paths = []
for dirpath in glob.glob(config['cache_data_path'] + "/*"):
if not os.path.isdir(dirpath):
continue
prefix = dirpath.split("/")[-1]
if prefix.split("_")[0] in config['cache_prefix']:
data_paths += glob.glob(dirpath + "/label_jpg/*.jpg")
print("Load ", dirpath)
print('Masks len {}'.format(len(data_paths)))
print('Examples: ', data_paths[:2])
return data_paths
def _get_cached_data(self, index):
mask_path = self.img_names[index]
# print(name)
mask = np.int32(cv2.imread(mask_path) > 0)
prefix = mask_path[:-9]
img_path = prefix.replace("/label_jpg/label_", "/image/image_") + ".jpg"
img = cv2.imread(img_path)
number = int(mask_path[-8:-4])
meta = np.load(prefix.replace("/label_jpg/label_", "/meta/meta_")+".npy", allow_pickle=True).tolist()
# label_idx = np.random.randint(0, label_ori.shape[0])
return rearrange(img, "h w c -> c h w"), rearrange(mask, "h w c -> c h w"), mask_path, meta['labels'][number], meta['label_idx'][number]
# @tfunctime
def __getitem__(self, index, debug=False):
index = index % len(self)
img_rgb, label_ori, name, label_idx, local_idx = self._get_cached_data(index)
# assert label_ori.sum() > 0
if label_ori.sum() <= 0:
print("[Label Error] ", name)
return self.__getitem__(index+1)
# assert len(img_rgb.shape) == 3, f"{__file__} Got{img_rgb.shape}"
# img_rgb = self.transform((img_rgb[None,:,:,:]))
img_rgb = F.interpolate(torch.Tensor(img_rgb).unsqueeze(0), size=(1024,1024)).squeeze().numpy()
vector = np.ones(3)
ret_dict = {
"name": name,
"img": img_rgb,
"label": label_ori,
"indicators": vector,
"class": label_idx,
"local_idx": local_idx,
"is_problem": label_ori.sum() <= 30,
}
return ret_dict
if __name__ == "__main__":
# def go_cache():
from tutils.new.manager import ConfigManager
from tqdm import tqdm
config = ConfigManager()
config.add_config("configs/vit_b.yaml")
config.print()
dataset = Dataset3D(config=config['dataset'], use_cache=True)
# dataset.caching_data()
# dataset.convert_masks_types()
for i in tqdm(range(len(dataset))):
data = dataset.__getitem__(i)
# assert