Slide-SAM/datasets/predict_various_masks.py

# from utils.predict_automasks import predict
from segment_anything import sam_model_registry, SamPredictor
from core.automask import SamAutomaticMaskGenerator
from core.trainer import SamLearner
# from datasets.dataset2d import Dataset2D
from einops import rearrange, repeat
import torch
import numpy as np
from skimage.segmentation import slic
from skimage.util import img_as_float
from skimage import io
import os
from tutils import timer, tdir
from scipy.sparse import csr_matrix
import cv2
import torch.nn.functional as F
import time
from tutils import tfilename
import matplotlib.pylab as plt
from .utils import load_compressed_data, show_anns, img_to_show
def tfunctime(func):
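    """Decorator that prints the wall-clock running time of the wrapped function."""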
def run(*argv, **kargs):
t1 = time.time()
ret = func(*argv, **kargs)
t2 = time.time()
print(f"[Function <{func.__name__}>] Running time:{(t2-t1):.6f}s")
return ret
return run
def get_predictors():
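    """Load a SAM ViT-H checkpoint and build the two predictors used below.

    Returns a (SamAutomaticMaskGenerator, SamLearner) pair. The checkpoint
    path is hard-coded for a specific machine and will likely need to be
    adjusted for other environments.
    """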
sam_checkpoint = "/quanquan/code/segment-anything/segment_anything/sam_vit_h_4b8939.pth" # for A100
# sam_checkpoint = "/home1/quanquan/code/projects/medical-guangdong/segment-anything/sam_vit_h_4b8939.pth" # for server 103
device = "cuda"
model_type = "default"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
mask_generator = SamAutomaticMaskGenerator(sam)
predictor = SamLearner(sam_model=sam)
return mask_generator, predictor
def center_clip(img, eta=3):
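    """Clip intensities to mean +/- eta*std, where the statistics are computed
    only over voxels above -199 (roughly excluding air/background in CT images).
    """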
count_values = img[torch.where(img>-199)]
mean = count_values.mean()
std = count_values.std()
img3 = torch.clip(img, mean-eta*std, mean+eta*std)
return img3
def find_not_exist_masks(masks_repo, masks_new, iou_threshold=0.2):
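    """Return the masks in `masks_new` whose IoU with every mask already in
    `masks_repo` is at most `iou_threshold`, i.e. masks that are not yet
    represented in the repository.
    """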
if len(masks_repo) == 0:
return masks_new
def compute_iou(m1, m2):
m1 = m1['segmentation']
m2 = m2['segmentation']
intersection = m1*m2
union = np.float32((m1 + m2) > 0)
return intersection.sum() / union.sum()
to_append = []
for mask_new in masks_new:
        assert isinstance(mask_new, dict), f"{__file__} Got {type(mask_new)}"
        intersec_count = 0
        for mask_in_repo in masks_repo:
            assert isinstance(mask_in_repo, dict), f"{__file__} Got {type(mask_in_repo)}"
iou = compute_iou(mask_in_repo, mask_new)
if iou > iou_threshold:
intersec_count += 1
if intersec_count == 0:
to_append.append(mask_new)
check_keys(to_append)
return to_append
def merge_masks(masks_repo, masks_new, iou_threshold=0.2):
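    """Append to `masks_repo` only those masks from `masks_new` that do not
    overlap (IoU > iou_threshold) with any existing mask.
    """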
to_append = find_not_exist_masks(masks_repo, masks_new, iou_threshold)
# print(f"DEBUG: {len(masks_new) - len(to_append)} masks are deleted, remaining {len(to_append)}. The total achieves {len(masks_repo) + len(to_append)}")
return masks_repo + to_append
def dilate_erode(mask):
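    """Smooth a binary mask with a morphological closing (4x4 kernel)."""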
kernel = np.ones((4, 4), dtype=np.uint8)
mask = cv2.morphologyEx(mask.astype(float), cv2.MORPH_CLOSE, kernel, iterations=1)
return mask
def get_superpixel(image, hu_mask, hu_threshold=-50):
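    """Generate SLIC superpixels and turn each one into a candidate mask.

    Superpixels whose mean intensity is below `hu_threshold` are skipped
    (treated as background); the remaining ones are intersected with
    `hu_mask` and cleaned up with a morphological closing.
    """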
if isinstance(image, torch.Tensor):
image = image.detach().cpu().numpy()
if isinstance(hu_mask, torch.Tensor):
hu_mask = hu_mask.detach().cpu().numpy()
segments = slic(image, n_segments=100, compactness=9)
mask_collect = []
image2 = image[:,:,0]
for i in range(1, segments.max()+1):
mask = torch.Tensor(segments==i).detach().cpu().numpy()
# assert img
if image2[segments==i].mean() <= hu_threshold:
continue
mask = mask * hu_mask
mask = dilate_erode(mask)
mask_data = {
'segmentation':mask,
'area':mask.sum(),
'source': "superpixel",
'mean_value':(mask * image[:,:,0]).mean(),
}
mask_collect.append(mask_data)
check_keys(mask_collect)
return mask_collect
# def resize_masks(masks):
# for m in masks:
# m['segmentation'] = F.interpolate(torch.Tensor(m['segmentation'])[None,None,:,:], size=(512,512)).squeeze().numpy()
# return masks
# @tfunctime
def get_masks_via_color_changing(img, label, mask_generator, predictor=None):
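    """Collect pseudo masks for one slice by combining GT labels with SAM
    automatic masks generated at several contrast settings.

    The slice is resized to 1024x1024, the ground-truth label (if any) is split
    into per-class masks, and SamAutomaticMaskGenerator is then run on
    center-clipped versions of the image (eta = 2.5, 1 and 0.5) to capture
    structures that only become salient at different contrasts. New masks are
    merged in only if they do not overlap existing ones.
    """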
img = repeat(img, " h w -> 1 3 h w")
img = torch.Tensor(img)
img = F.interpolate(img, size=(1024,1024))
    if label is not None:
        label = torch.Tensor(label)
        label_masks = []
        for i in range(1, label.max().int() + 1):
labeli = torch.Tensor(label==i).float()
labeli = F.interpolate(labeli[None, None, :,:], size=(1024,1024)).squeeze().numpy()
area = labeli.sum()
if area <=10:
continue
mask = {
"segmentation":labeli,
"area":area,
"source": "gt",
}
label_masks.append(mask)
else:
label_masks = []
mask_generator.reset_image()
predictor.reset_image()
masks = mask_generator.generate(img.cuda())
masks = filter_large_masks(masks)
for mask in masks:
mask['source'] = "sam_auto_seg"
label_masks = merge_masks(label_masks, masks)
del masks
check_keys(label_masks)
# import ipdb; ipdb.set_trace()
# return img, label_masks
    img2 = center_clip(img, 2.5)
    masks = mask_generator.generate(img2.cuda())
    masks = filter_large_masks(masks)
    for mask in masks:
        mask['source'] = "sam_auto_seg"
    label_masks = merge_masks(label_masks, masks)
    del masks
    del img2
# img2 = center_clip(img, 2)
# masks = mask_generator.generate(img2.cuda())
# masks = filter_large_masks(masks)
# for mask in masks:
# mask['source'] = "sam_auto_seg"
# label_masks = merge_masks(label_masks, masks)
# del masks
# del img2
img2 = center_clip(img, 1)
masks = mask_generator.generate(img2.cuda())
masks = filter_large_masks(masks)
for mask in masks:
mask['source'] = "sam_auto_seg"
label_masks = merge_masks(label_masks, masks)
del masks
del img2
img2 = center_clip(img, 0.5)
masks = mask_generator.generate(img2.cuda())
masks = filter_large_masks(masks)
for mask in masks:
mask['source'] = "sam_auto_seg"
label_masks = merge_masks(label_masks, masks)
del masks
del img2
check_keys(label_masks)
# import ipdb; ipdb.set_trace()
return img, label_masks
def check_keys(masks):
for m in masks:
assert m['segmentation'] is not None
assert m['area'] is not None
assert m['source'] is not None
def filter_large_masks(masks):
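    """Drop masks covering more than 25% of the 1024x1024 canvas; these are
    usually background or whole-body regions.
    """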
filtered_masks = []
for mask in masks:
        if mask['area'] > 0.25 * 1024 * 1024:
continue
filtered_masks.append(mask)
del masks
return filtered_masks
def mix_masks(masks):
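    """Union all mask segmentations into a single binary foreground map."""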
if len(masks) == 0:
return None
mixed_mask = None
for item in masks:
m = item['segmentation']
mixed_mask = np.zeros_like(m) if mixed_mask is None else mixed_mask
mixed_mask += m
mixed_mask = np.float32(mixed_mask>0)
return mixed_mask
def select_random_point_from_mask(gt_mask):
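    """Sample a random (x, y) pixel uniformly from the foreground of `gt_mask`."""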
size = gt_mask.shape
assert len(size) == 2
xy = np.arange(0, size[0] * size[1])
gt_mask = np.float32(gt_mask>0)
prob = rearrange(gt_mask, "h w -> (h w)")
prob = prob / prob.sum()
loc = np.random.choice(a=xy, size=1, replace=True, p=prob)[0]
x, y = loc % size[1], loc // size[1]
return x, y
def select_center_point_from_mask(gt_mask):
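    """Return the (x, y) centroid of the foreground pixels of `gt_mask`.

    Note that for non-convex masks the centroid may fall outside the mask.
    """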
# get indices of all the foreground pixels
indices = np.argwhere(gt_mask > 0)
# calculate the center point by taking the mean of the foreground pixel indices
center = np.mean(indices, axis=0).astype(int)
y, x = center
return x, y
@tfunctime
def get_masks_via_points_from_superpixels(img, label_masks, superpixels, predictor, hu_mask=None):
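    """Prompt SAM with center points of superpixels that are not yet covered
    by `label_masks`, and keep the resulting masks.

    `hu_mask` is expected to be a foreground mask of the same size as the
    predictions; outputs are intersected with it and closed morphologically.
    This stage is currently disabled in `get_various_masks` below.
    """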
if isinstance(hu_mask, torch.Tensor):
hu_mask = hu_mask.detach().cpu().numpy()
total_mask = mix_masks(label_masks)
# superpixels = get_superpixel(rearrange(img, "1 c h w -> h w c"))
points = []
ex_masks = []
for seg in superpixels:
mask_collect = []
for i in range(5):
            mm = np.nan_to_num(seg['segmentation'], nan=0.0)
if mm.sum() <= 0:
continue
x,y = select_center_point_from_mask(mm) # select_random_point_from_mask
# print(x,y)
if total_mask[y,x] == 1:
continue
# else:
# print(total_mask[y,x])
point = torch.Tensor([x,y])
points.append(point)
mask = predictor.generate(img.cuda(), point.cuda().unsqueeze(0).unsqueeze(0))[0,0].detach().cpu().numpy()
mask = mask * hu_mask
mask = dilate_erode(mask)
mask = {
"segmentation": mask,
"area": mask.sum(),
"source": "prompt_point_from_superpixel",
}
mask_collect = merge_masks(mask_collect, [mask])
ex_masks += mask_collect
check_keys(ex_masks)
return ex_masks
def mask_to_bbox(mask):
""" copied from data_engine """
# Find the indices of all non-zero elements in the mask
coords = np.nonzero(mask)
# Compute the minimum and maximum values of the row and column indices
x_min = np.min(coords[1])
y_min = np.min(coords[0])
x_max = np.max(coords[1])
y_max = np.max(coords[0])
# Return the coordinates of the bounding box
return (x_min, y_min, x_max, y_max)
@tfunctime
def get_masks_via_boxes_from_superpixels(img, label_masks, superpixels, predictor, hu_mask=None):
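    """Prompt SAM with bounding boxes of sufficiently large superpixels and
    keep the resulting masks (intersected with `hu_mask` and closed
    morphologically). Like the point-prompt variant, this stage is currently
    disabled in `get_various_masks`.
    """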
if isinstance(hu_mask, torch.Tensor):
hu_mask = hu_mask.detach().cpu().numpy()
ex_masks = []
for seg in superpixels:
if seg['segmentation'].sum() < 100:
continue
x_min, y_min, x_max, y_max = mask_to_bbox(seg['segmentation'])
box = torch.Tensor([x_min, y_min, x_max, y_max])
mask = predictor.generate_by_box(img.cuda(), box.cuda().unsqueeze(0).unsqueeze(0))[0,0].detach().cpu().numpy()
mask = mask * hu_mask
mask = dilate_erode(mask)
mask = {
"segmentation": mask,
"area": mask.sum(),
"source": "prompt_box_from_superpixel",
}
ex_masks = merge_masks(ex_masks, [mask])
check_keys(ex_masks)
return ex_masks
class VariousPseudoMasksGenerator:
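    """Generate and cache pseudo masks for every slice of a 3D dataset.

    For each volume, every slice is passed through `get_various_masks` and the
    resulting binary masks are stored as compressed sparse matrices (.npz),
    together with a JPEG of the normalized slice.
    """
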
def __init__(self, dataset, label_path:str=None) -> None:
self.dataset = dataset
# assert dirpath.split("/")[-1] == "cache_2d", f"{__file__} Got{dirpath.split('/')[-1]} ; dirpath:{dirpath}"
self.label_path = label_path if label_path is not None else dataset.dirpath.replace("cache_2d", "cache_2d_various_pseudo_masks")
def example(self, mask_generator=None, predictor=None):
return self.generate(mask_generator=mask_generator, predictor=predictor, is_example=True)
def generate(self, mask_generator=None, predictor=None, is_example=False):
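        """Iterate over the dataset and write per-slice masks and images.

        Masks go to `<label_path>/<dataset_name>/<filename>_s<slice>_mask.npz`
        and the normalized slices to
        `<label_path>/image_<dataset_name>/<filename>_s<slice>.jpg`.
        If `is_example` is True, only the first volume is processed.
        """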
tt = timer()
if mask_generator is None:
mask_generator, predictor = get_predictors()
self.mask_generator, self.predictor = mask_generator, predictor
for img_idx in range(len(self.dataset)):
tt()
data = self.dataset._get_image(img_idx)
masks_all_layers = []
words = data['name'].split("/")
dataset_name = words[0]
filename = words[-1].replace(".nii.gz","")
volume_rgb = np.clip(data['img'], -200, 400)
            volume_rgb = (volume_rgb - volume_rgb.min()) / (volume_rgb.max() - volume_rgb.min() + 1e-8)
for slice_idx in range(data['img'].shape[0]):
path = tfilename(self.label_path, f'{dataset_name}/{filename}_s{slice_idx}_mask.npz')
# if os.path.exists(path):
# continue
masks = self.get_various_masks(data['img'][slice_idx], data['label'][slice_idx], mask_generator, predictor)
self.save_slice_mask(masks, path, slice_idx)
img_path = tfilename(self.label_path, f'image_{dataset_name}/{filename}_s{slice_idx}.jpg')
self.save_img_rgb(volume_rgb[slice_idx], img_path)
# display
# plt.figure(figsize=(6,6))
# plt.imshow(img_to_show(data['img'][slice_idx]), cmap='gray')
# show_anns(masks)
# plt.axis('off')
# plt.show()
print(f"Save to {img_path} and {path}")
print(f"Processing {img_idx}, {len(masks_all_layers)} saved, time used:{tt()}", end='\r')
if is_example:
break
def save_slice_mask(self, masks, path, slice_idx):
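        """Resize all masks to 512x512, binarize them, flatten to (n, h*w) and
        store them as a compressed CSR matrix in an .npz file.

        The masks can presumably be restored with
        scipy.sparse.csr_matrix((data, indices, indptr), shape=shape) followed
        by a reshape back to (n, 512, 512); see `load_compressed_data`
        imported from `.utils`.
        """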
masks_data = np.array([m['segmentation'] for m in masks]).astype(int)
# if len(masks_data) <= 1:
# import ipdb; ipdb.set_trace()
masks_data = F.interpolate(torch.Tensor(masks_data).unsqueeze(1), size=(512,512)).squeeze().numpy()
masks_data = np.int8(masks_data>0)
if len(masks_data.shape) == 2:
masks_data = masks_data[None,:,:]
        assert masks_data.shape[1:] == (512, 512), f"{__file__} Got {masks_data.shape}"
masks_data = rearrange(masks_data, "n h w -> n (h w)")
csr = csr_matrix(masks_data)
np.savez_compressed(path, data=csr.data, indices=csr.indices, indptr=csr.indptr, shape=csr.shape)
def save_img_rgb(self, img, path):
img = (img * 255).astype(np.uint8)
cv2.imwrite(path, img)
@staticmethod
@tfunctime
def get_various_masks(img_ori, label, mask_generator, predictor):
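        """Produce the pseudo masks for a single slice.

        Only the color-changing (multi-contrast SAM) stage is active; the
        superpixel point/box prompting stages are left commented out below.
        """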
mask_generator.reset_image()
predictor.reset_image()
img, label_masks = get_masks_via_color_changing(img_ori, label, mask_generator, predictor)
# # return label_masks
# hu_mask = img_ori>img_ori.min() + 10
# hu_mask = F.interpolate(torch.Tensor(hu_mask)[None,None,:,:], size=(1024,1024)).squeeze()
# superpixels = get_superpixel(rearrange(img, "1 c h w -> h w c"), hu_mask=hu_mask)
# exclu_superpixels = find_not_exist_masks(label_masks, superpixels)
# # point_prompt_masks = get_masks_via_boxes_from_superpixels(img, label_masks, exclu_superpixels, predictor, hu_mask=hu_mask)
# box_prompt_masks = get_masks_via_points_from_superpixels(img, label_masks, exclu_superpixels, predictor, hu_mask=hu_mask)
return label_masks # + box_prompt_masks # + point_prompt_masks #
if __name__ == "__main__":
# CUDA_VISIBLE_DEVICES=6 python -m datasets.predict_various_masks
from datasets.dataset3d import Dataset3D
dataset = Dataset3D()
gen = VariousPseudoMasksGenerator(dataset=dataset,
label_path=tdir("/quanquan/datasets/all_datasets/various_masks_3/"))
gen.generate()