# from utils.predict_automasks import predict
from segment_anything import sam_model_registry, SamPredictor
from core.automask import SamAutomaticMaskGenerator
from core.trainer import SamLearner
# from datasets.dataset2d import Dataset2D
from einops import rearrange, repeat
import torch
import numpy as np
from skimage.segmentation import slic
from skimage.util import img_as_float
from skimage import io
import os
from tutils import timer, tdir, tfilename
from scipy.sparse import csr_matrix
import cv2
import torch.nn.functional as F
import time
import matplotlib.pylab as plt
from .utils import load_compressed_data, show_anns, img_to_show


def tfunctime(func):
    """Decorator that prints the wall-clock running time of the wrapped function."""
    def run(*argv, **kargs):
        t1 = time.time()
        ret = func(*argv, **kargs)
        t2 = time.time()
        print(f"[Function <{func.__name__}>] Running time: {(t2 - t1):.6f}s")
        return ret
    return run


def get_predictors():
    sam_checkpoint = "/quanquan/code/segment-anything/segment_anything/sam_vit_h_4b8939.pth"  # for A100
    # sam_checkpoint = "/home1/quanquan/code/projects/medical-guangdong/segment-anything/sam_vit_h_4b8939.pth"  # for server 103
    device = "cuda"
    model_type = "default"
    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    sam.to(device=device)
    mask_generator = SamAutomaticMaskGenerator(sam)
    predictor = SamLearner(sam_model=sam)
    return mask_generator, predictor


def center_clip(img, eta=3):
    """Clip intensities to mean +/- eta*std, with statistics computed over values > -199."""
    count_values = img[torch.where(img > -199)]
    mean = count_values.mean()
    std = count_values.std()
    img3 = torch.clip(img, mean - eta * std, mean + eta * std)
    return img3


def find_not_exist_masks(masks_repo, masks_new, iou_threshold=0.2):
    """Return the masks in `masks_new` whose IoU with every mask in `masks_repo` is below `iou_threshold`."""
    if len(masks_repo) == 0:
        return masks_new

    def compute_iou(m1, m2):
        m1 = m1['segmentation']
        m2 = m2['segmentation']
        intersection = m1 * m2
        union = np.float32((m1 + m2) > 0)
        return intersection.sum() / union.sum()

    to_append = []
    for mask_new in masks_new:
        assert isinstance(mask_new, dict), f"{__file__} Got {type(mask_new)}"
        intersec_count = 0
        for mask_in_repo in masks_repo:
            assert isinstance(mask_in_repo, dict), f"{__file__} Got {type(mask_in_repo)}"
            iou = compute_iou(mask_in_repo, mask_new)
            if iou > iou_threshold:
                intersec_count += 1
        if intersec_count == 0:
            to_append.append(mask_new)
    check_keys(to_append)
    return to_append


def merge_masks(masks_repo, masks_new, iou_threshold=0.2):
    to_append = find_not_exist_masks(masks_repo, masks_new, iou_threshold)
    # print(f"DEBUG: {len(masks_new) - len(to_append)} masks are deleted, remaining {len(to_append)}. "
    #       f"The total achieves {len(masks_repo) + len(to_append)}")
    return masks_repo + to_append


def dilate_erode(mask):
    """Apply a small morphological closing to smooth the mask."""
    kernel = np.ones((4, 4), dtype=np.uint8)
    mask = cv2.morphologyEx(mask.astype(float), cv2.MORPH_CLOSE, kernel, iterations=1)
    return mask


def get_superpixel(image, hu_mask, hu_threshold=-50):
    """Generate SLIC superpixels and keep those whose mean intensity exceeds `hu_threshold`."""
    if isinstance(image, torch.Tensor):
        image = image.detach().cpu().numpy()
    if isinstance(hu_mask, torch.Tensor):
        hu_mask = hu_mask.detach().cpu().numpy()
    segments = slic(image, n_segments=100, compactness=9)
    mask_collect = []
    image2 = image[:, :, 0]
    for i in range(1, segments.max() + 1):
        mask = torch.Tensor(segments == i).detach().cpu().numpy()
        if image2[segments == i].mean() <= hu_threshold:
            continue
        mask = mask * hu_mask
        mask = dilate_erode(mask)
        mask_data = {
            'segmentation': mask,
            'area': mask.sum(),
            'source': "superpixel",
            'mean_value': (mask * image[:, :, 0]).mean(),
        }
        mask_collect.append(mask_data)
    check_keys(mask_collect)
    return mask_collect


# def resize_masks(masks):
#     for m in masks:
#         m['segmentation'] = F.interpolate(torch.Tensor(m['segmentation'])[None, None, :, :], size=(512, 512)).squeeze().numpy()
#     return masks


# @tfunctime
def get_masks_via_color_changing(img, label, mask_generator, predictor=None):
    """Run SAM automatic mask generation at several intensity-clipping levels and merge the results with the GT label masks."""
    img = repeat(img, "h w -> 1 3 h w")
    img = torch.Tensor(img)
    img = F.interpolate(img, size=(1024, 1024))
    if label is not None:
        label = torch.Tensor(label)
        label_masks = []
        for i in range(1, label.max().int() + 1):
            labeli = torch.Tensor(label == i).float()
            labeli = F.interpolate(labeli[None, None, :, :], size=(1024, 1024)).squeeze().numpy()
            area = labeli.sum()
            if area <= 10:
                continue
            mask = {
                "segmentation": labeli,
                "area": area,
                "source": "gt",
            }
            label_masks.append(mask)
    else:
        label_masks = []

    mask_generator.reset_image()
    predictor.reset_image()

    # masks from the unclipped image
    masks = mask_generator.generate(img.cuda())
    masks = filter_large_masks(masks)
    for mask in masks:
        mask['source'] = "sam_auto_seg"
    label_masks = merge_masks(label_masks, masks)
    del masks
    check_keys(label_masks)
    # import ipdb; ipdb.set_trace()
    # return img, label_masks

    img2 = center_clip(img, 2.5)
    masks = mask_generator.generate(img2.cuda())
    masks = filter_large_masks(masks)
    for mask in masks:
        mask['source'] = "sam_auto_seg"
    label_masks = merge_masks(label_masks, masks)
    del masks
    del img2

    # img2 = center_clip(img, 2)
    # masks = mask_generator.generate(img2.cuda())
    # masks = filter_large_masks(masks)
    # for mask in masks:
    #     mask['source'] = "sam_auto_seg"
    # label_masks = merge_masks(label_masks, masks)
    # del masks
    # del img2

    img2 = center_clip(img, 1)
    masks = mask_generator.generate(img2.cuda())
    masks = filter_large_masks(masks)
    for mask in masks:
        mask['source'] = "sam_auto_seg"
    label_masks = merge_masks(label_masks, masks)
    del masks
    del img2

    img2 = center_clip(img, 0.5)
    masks = mask_generator.generate(img2.cuda())
    masks = filter_large_masks(masks)
    for mask in masks:
        mask['source'] = "sam_auto_seg"
    label_masks = merge_masks(label_masks, masks)
    del masks
    del img2

    check_keys(label_masks)
    # import ipdb; ipdb.set_trace()
    return img, label_masks


def check_keys(masks):
    for m in masks:
        assert m['segmentation'] is not None
        assert m['area'] is not None
        assert m['source'] is not None


def filter_large_masks(masks):
    """Drop masks that cover more than a quarter of the 1024x1024 image."""
    filtered_masks = []
    for mask in masks:
        if mask['area'] > 0.25 * 1024 * 1024:
            continue
        filtered_masks.append(mask)
    del masks
    return filtered_masks


def mix_masks(masks):
    """Union all mask segmentations into a single binary mask (None if `masks` is empty)."""
    if len(masks) == 0:
        return None
    mixed_mask = None
    for item in masks:
        m = item['segmentation']
        mixed_mask = np.zeros_like(m) if mixed_mask is None else mixed_mask
        mixed_mask += m
    mixed_mask = np.float32(mixed_mask > 0)
    return mixed_mask


def select_random_point_from_mask(gt_mask):
    """Sample a random foreground pixel (x, y) from a binary mask."""
    size = gt_mask.shape
    assert len(size) == 2
    xy = np.arange(0, size[0] * size[1])
    gt_mask = np.float32(gt_mask > 0)
    prob = rearrange(gt_mask, "h w -> (h w)")
    prob = prob / prob.sum()
    loc = np.random.choice(a=xy, size=1, replace=True, p=prob)[0]
    x, y = loc % size[1], loc // size[1]
    return x, y


def select_center_point_from_mask(gt_mask):
    # get indices of all the foreground pixels
    indices = np.argwhere(gt_mask > 0)
    # calculate the center point by taking the mean of the foreground pixel indices
    center = np.mean(indices, axis=0).astype(int)
    y, x = center
    return x, y


@tfunctime
def get_masks_via_points_from_superpixels(img, label_masks, superpixels, predictor, hu_mask=None):
    """Prompt SAM with points sampled from superpixels that are not yet covered by `label_masks`."""
    if isinstance(hu_mask, torch.Tensor):
        hu_mask = hu_mask.detach().cpu().numpy()
    total_mask = mix_masks(label_masks)
    # superpixels = get_superpixel(rearrange(img, "1 c h w -> h w c"))
    points = []
    ex_masks = []
    for seg in superpixels:
        mask_collect = []
        for i in range(5):
            mm = np.nan_to_num(seg['segmentation'], nan=0)
            if mm.sum() <= 0:
                continue
            x, y = select_center_point_from_mask(mm)  # or select_random_point_from_mask
            if total_mask[y, x] == 1:
                continue
            point = torch.Tensor([x, y])
            points.append(point)
            mask = predictor.generate(img.cuda(), point.cuda().unsqueeze(0).unsqueeze(0))[0, 0].detach().cpu().numpy()
            mask = mask * hu_mask
            mask = dilate_erode(mask)
            mask = {
                "segmentation": mask,
                "area": mask.sum(),
                "source": "prompt_point_from_superpixel",
            }
            mask_collect = merge_masks(mask_collect, [mask])
        ex_masks += mask_collect
    check_keys(ex_masks)
    return ex_masks


def mask_to_bbox(mask):
    """copied from data_engine"""
    # Find the indices of all non-zero elements in the mask
    coords = np.nonzero(mask)
    # Compute the minimum and maximum values of the row and column indices
    x_min = np.min(coords[1])
    y_min = np.min(coords[0])
    x_max = np.max(coords[1])
    y_max = np.max(coords[0])
    # Return the coordinates of the bounding box
    return (x_min, y_min, x_max, y_max)


@tfunctime
def get_masks_via_boxes_from_superpixels(img, label_masks, superpixels, predictor, hu_mask=None):
    """Prompt SAM with the bounding boxes of sufficiently large superpixels."""
    if isinstance(hu_mask, torch.Tensor):
        hu_mask = hu_mask.detach().cpu().numpy()
    ex_masks = []
    for seg in superpixels:
        if seg['segmentation'].sum() < 100:
            continue
        x_min, y_min, x_max, y_max = mask_to_bbox(seg['segmentation'])
        box = torch.Tensor([x_min, y_min, x_max, y_max])
        mask = predictor.generate_by_box(img.cuda(), box.cuda().unsqueeze(0).unsqueeze(0))[0, 0].detach().cpu().numpy()
        mask = mask * hu_mask
        mask = dilate_erode(mask)
        mask = {
            "segmentation": mask,
            "area": mask.sum(),
            "source": "prompt_box_from_superpixel",
        }
        ex_masks = merge_masks(ex_masks, [mask])
    check_keys(ex_masks)
    return ex_masks


class VariousPseudoMasksGenerator:
    """Generate pseudo masks for every slice of every volume in `dataset` and save them as compressed .npz files."""

    def __init__(self, dataset, label_path: str = None) -> None:
        self.dataset = dataset
        # assert dirpath.split("/")[-1] == "cache_2d", f"{__file__} Got {dirpath.split('/')[-1]} ; dirpath: {dirpath}"
        self.label_path = label_path if label_path is not None else dataset.dirpath.replace("cache_2d", "cache_2d_various_pseudo_masks")

    def example(self, mask_generator=None, predictor=None):
        return self.generate(mask_generator=mask_generator, predictor=predictor, is_example=True)

    def generate(self, mask_generator=None, predictor=None, is_example=False):
        tt = timer()
        if mask_generator is None:
            mask_generator, predictor = get_predictors()
        self.mask_generator, self.predictor = mask_generator, predictor
        for img_idx in range(len(self.dataset)):
            tt()
            data = self.dataset._get_image(img_idx)
            masks_all_layers = []
            words = data['name'].split("/")
            dataset_name = words[0]
            filename = words[-1].replace(".nii.gz", "")
            volume_rgb = np.clip(data['img'], -200, 400)
            volume_rgb = (volume_rgb - volume_rgb.min()) / (volume_rgb.max() - volume_rgb.min())
            for slice_idx in range(data['img'].shape[0]):
                path = tfilename(self.label_path, f'{dataset_name}/{filename}_s{slice_idx}_mask.npz')
                # if os.path.exists(path):
                #     continue
                masks = self.get_various_masks(data['img'][slice_idx], data['label'][slice_idx], mask_generator, predictor)
                self.save_slice_mask(masks, path, slice_idx)
                img_path = tfilename(self.label_path, f'image_{dataset_name}/{filename}_s{slice_idx}.jpg')
                self.save_img_rgb(volume_rgb[slice_idx], img_path)
                # display
                # plt.figure(figsize=(6, 6))
                # plt.imshow(img_to_show(data['img'][slice_idx]), cmap='gray')
                # show_anns(masks)
                # plt.axis('off')
                # plt.show()
                print(f"Save to {img_path} and {path}")
            print(f"Processing {img_idx}, {len(masks_all_layers)} saved, time used: {tt()}", end='\r')
            if is_example:
                break

    def save_slice_mask(self, masks, path, slice_idx):
        """Resize masks to 512x512, binarize them, and store them as a flattened CSR matrix in an .npz file."""
        masks_data = np.array([m['segmentation'] for m in masks]).astype(int)
        masks_data = F.interpolate(torch.Tensor(masks_data).unsqueeze(1), size=(512, 512)).squeeze().numpy()
        masks_data = np.int8(masks_data > 0)
        if len(masks_data.shape) == 2:
            masks_data = masks_data[None, :, :]
        assert masks_data.shape[1:] == (512, 512), f"{__file__} Got {masks_data.shape}"
        masks_data = rearrange(masks_data, "n h w -> n (h w)")
        csr = csr_matrix(masks_data)
        np.savez_compressed(path, data=csr.data, indices=csr.indices, indptr=csr.indptr, shape=csr.shape)

    def save_img_rgb(self, img, path):
        img = (img * 255).astype(np.uint8)
        cv2.imwrite(path, img)

    @staticmethod
    @tfunctime
    def get_various_masks(img_ori, label, mask_generator, predictor):
        mask_generator.reset_image()
        predictor.reset_image()
        img, label_masks = get_masks_via_color_changing(img_ori, label, mask_generator, predictor)
        # # return label_masks
        # hu_mask = img_ori > img_ori.min() + 10
        # hu_mask = F.interpolate(torch.Tensor(hu_mask)[None, None, :, :], size=(1024, 1024)).squeeze()
        # superpixels = get_superpixel(rearrange(img, "1 c h w -> h w c"), hu_mask=hu_mask)
        # exclu_superpixels = find_not_exist_masks(label_masks, superpixels)
        # point_prompt_masks = get_masks_via_points_from_superpixels(img, label_masks, exclu_superpixels, predictor, hu_mask=hu_mask)
        # box_prompt_masks = get_masks_via_boxes_from_superpixels(img, label_masks, exclu_superpixels, predictor, hu_mask=hu_mask)
        return label_masks  # + box_prompt_masks + point_prompt_masks


if __name__ == "__main__":
    # CUDA_VISIBLE_DEVICES=6 python -m datasets.predict_various_masks
    from datasets.dataset3d import Dataset3D

    dataset = Dataset3D()
    gen = VariousPseudoMasksGenerator(dataset=dataset, label_path=tdir("/quanquan/datasets/all_datasets/various_masks_3/"))
    gen.generate()
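

# Illustrative only, not part of the original pipeline: save_slice_mask() above writes each
# slice's masks as a flattened CSR matrix inside an .npz file. Below is a minimal sketch of
# reading such a file back. The helper name `load_slice_mask` and the 512x512 default are
# assumptions for illustration; the repo's own reader is `load_compressed_data` from .utils.
def load_slice_mask(path, size=(512, 512)):
    d = np.load(path)
    csr = csr_matrix((d['data'], d['indices'], d['indptr']), shape=tuple(d['shape']))
    # each row is one mask flattened to h*w; reshape back to (n_masks, h, w)
    return csr.toarray().reshape(-1, *size)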