Slide-SAM/datasets/data_engine.py
transcendentsky e04459c6fe first commit
2023-12-05 14:58:38 +08:00

274 lines
10 KiB
Python

import numpy as np
import torch
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
from einops import repeat, rearrange
def show_mask(mask, ax, random_color=False):
    """Draw ``mask`` on ``ax`` as a translucent RGBA overlay.

    Args:
        mask: Array whose last two dimensions are (height, width); it is
            reshaped to ``(height, width, 1)`` before colouring.
        ax: Object exposing ``imshow`` (e.g. a matplotlib Axes).
        random_color: When true, the RGB part of the overlay colour is drawn
            from ``np.random.random``; otherwise a fixed blue is used.
            Alpha is 0.6 in both cases.
    """
    if random_color:
        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        rgba = np.array([30/255, 144/255, 255/255, 0.6])

    height, width = mask.shape[-2:]
    # Broadcast (h, w, 1) * (1, 1, 4) -> (h, w, 4): every mask value scales
    # all four colour channels.
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)
def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on ``ax``: label 1 in green, label 0 in red.

    Args:
        coords: Array indexed as ``coords[:, 0]`` (x) and ``coords[:, 1]`` (y).
        labels: Array compared element-wise against 1 and 0 to select rows
            of ``coords``.
        ax: Object exposing ``scatter`` (e.g. a matplotlib Axes).
        marker_size: Marker size forwarded as ``s``.
    """
    # Select both groups up front, then draw positives before negatives.
    groups = (
        (coords[labels == 1], 'green'),
        (coords[labels == 0], 'red'),
    )
    for selected, colour in groups:
        ax.scatter(
            selected[:, 0],
            selected[:, 1],
            color=colour,
            marker='*',
            s=marker_size,
            edgecolor='white',
            linewidth=1.25,
        )
def show_box(box, ax):
    """Draw ``box`` on ``ax`` as an unfilled green rectangle.

    Args:
        box: Indexable as ``(x0, y0, x1, y1)``; width and height are derived
            as ``x1 - x0`` and ``y1 - y0``.
        ax: Object exposing ``add_patch`` (e.g. a matplotlib Axes).
    """
    left, top = box[0], box[1]
    width = box[2] - box[0]
    height = box[3] - box[1]
    outline = plt.Rectangle(
        (left, top),
        width,
        height,
        edgecolor='green',
        facecolor=(0,0,0,0),  # fully transparent fill
        lw=2,
    )
    ax.add_patch(outline)
class PointPromptGenerator(object):
    """Sample point prompts (pixel coordinates) from 2-D segmentation masks.

    All sampling goes through ``np.random.choice``, so results depend on the
    global NumPy random state.
    """

    def __init__(self, size=None) -> None:
        # ``size`` is accepted but not stored; get_prompt_point() derives the
        # size from the mask it is given.
        pass

    @staticmethod
    def _sample_pixel(weights):
        """Draw one pixel with probability proportional to ``weights``.

        Args:
            weights: 2-D array-like of non-negative weights with a positive sum.

        Returns:
            ``(x, y)`` = (column, row) of the sampled pixel.
        """
        weights = np.asarray(weights, dtype=np.float64)
        height, width = weights.shape
        prob = weights.reshape(-1)
        prob = prob / prob.sum()
        loc = np.random.choice(a=np.arange(height * width), size=1, replace=True, p=prob)[0]
        # Row-major flat index -> (column, row).
        return loc % width, loc // width

    def get_prompt_point(self, gt_mask):
        """Sample one foreground pixel of ``gt_mask`` uniformly at random.

        Foreground means ``gt_mask > 0``. A warning is printed (but sampling
        still proceeds) when the mask shape is not one of 1024x1024, 512x512
        or 256x256.

        Side effects:
            Sets ``self.size`` to the mask shape and ``self.xy`` to the flat
            index range of the mask.

        Returns:
            ``(x, y)`` = (column, row) of the sampled pixel.

        Raises:
            AssertionError: if the mask has no foreground pixel.
        """
        if gt_mask.shape not in ((1024,1024), (512,512), (256,256)):
            print(f"[Warning] [data_engine] {__file__} Got{gt_mask.shape}")

        foreground = np.asarray(gt_mask) > 0
        # Check the binarised foreground (not the raw sum) so the check agrees
        # with what is actually sampled below.
        assert foreground.any()

        self.size = gt_mask.shape
        self.xy = np.arange(0, self.size[0] * self.size[1])
        return self._sample_pixel(foreground)

    @staticmethod
    def select_random_subsequent_point(pred_mask, gt_mask):
        """Sample a corrective point from where prediction and ground truth differ.

        A pixel is sampled uniformly from the set where ``pred_mask`` and
        ``gt_mask`` differ. If the two masks agree everywhere, a pixel is
        sampled uniformly from the ground-truth foreground (``gt_mask > 0``)
        and labelled positive.

        Args:
            pred_mask: 2-D mask (numeric or boolean).
            gt_mask: 2-D mask (numeric or boolean) with a positive sum.

        Returns:
            ``((x, y), label)`` where ``label`` is 1 when the sampled pixel has
            ``gt == 1`` and ``pred == 0`` (a missed foreground pixel), else 0.
            In the fallback case the label is always 1.

        Raises:
            AssertionError: if either mask is not 2-D or ``gt_mask`` sums to 0.
        """
        assert len(pred_mask.shape) == 2
        assert len(gt_mask.shape) == 2
        assert gt_mask.sum() > 0, f"[data_engine] Got {gt_mask.sum()}==0 "

        # Work in float64 so boolean masks (for which NumPy forbids ``-``) and
        # unsigned integer masks are handled uniformly.
        pred = np.asarray(pred_mask, dtype=np.float64)
        gt = np.asarray(gt_mask, dtype=np.float64)
        # Comparisons with NaN are False, so NaN pixels never count as differing.
        diff = np.abs(pred - gt) > 0

        if not diff.any():
            # Perfect prediction: fall back to a uniformly sampled foreground
            # pixel (binarised, matching get_prompt_point).
            x, y = PointPromptGenerator._sample_pixel(gt > 0)
            return (x, y), 1

        x, y = PointPromptGenerator._sample_pixel(diff)
        if gt[y, x] == 1 and pred[y, x] == 0:
            classification = 1
        else:
            classification = 0
        return (x, y), classification
class BoxPromptGenerator(object):
    """Derive (optionally jittered) bounding-box prompts from masks.

    Boxes are ``(x_min, y_min, x_max, y_max)``.

    Attributes are documented inline; ``size`` may be ``None``, in which case
    no upper clamping is applied.
    """

    # Largest shift, in pixels, that add_random_noise_to_bbox may apply in
    # either direction.
    MAX_NOISE_PX = 20

    def __init__(self, size) -> None:
        # Upper bounds used for clamping: size[0] bounds x, size[1] bounds y.
        # NOTE(review): every shape this file accepts is square, so whether
        # size is (w, h) or (h, w) cannot be told from here - confirm before
        # using non-square images.
        self.size = size

    @staticmethod
    def mask_to_bbox(mask):
        """Return the tight box around the non-zero elements of ``mask``.

        Returns:
            ``(x_min, y_min, x_max, y_max)`` with inclusive pixel indices,
            where x comes from axis 1 and y from axis 0.

        Raises:
            ValueError: if ``mask`` has no non-zero element.
        """
        coords = np.nonzero(mask)
        if coords[0].size == 0:
            raise ValueError("[data_engine] mask_to_bbox got a mask with no non-zero element")
        rows, cols = coords[0], coords[1]
        return (np.min(cols), np.min(rows), np.max(cols), np.max(rows))

    def add_random_noise_to_bbox(self, bbox):
        """Return ``bbox`` shifted by a small random offset.

        One x offset and one y offset are drawn from a zero-mean normal whose
        standard deviation is 1% of the mean side length; each offset is
        rounded and limited to ``[-MAX_NOISE_PX, MAX_NOISE_PX]``. The same
        offset is applied to both corners, so the box is translated, not
        resized. The result is clamped to be non-negative and, when
        ``self.size`` is set, to at most ``size[0]`` (x) / ``size[1]`` (y).

        Returns:
            A new list ``[x_min, y_min, x_max, y_max]``; the input is not
            modified.
        """
        bbox = list(bbox)
        x_side_length = bbox[2] - bbox[0]
        y_side_length = bbox[3] - bbox[1]
        std_dev = 0.01 * (x_side_length + y_side_length) / 2

        x_noise = np.random.normal(scale=std_dev)
        y_noise = np.random.normal(scale=std_dev)

        # Limit the shift in BOTH directions (a bare min() would leave large
        # negative shifts unbounded).
        cap = self.MAX_NOISE_PX
        dx = max(-cap, min(int(round(x_noise)), cap))
        dy = max(-cap, min(int(round(y_noise)), cap))

        bbox[0] += dx
        bbox[1] += dy
        bbox[2] += dx
        bbox[3] += dy

        # Keep every coordinate inside the image.
        bbox = [max(value, 0) for value in bbox]
        if self.size is not None:
            bbox[0] = min(bbox[0], self.size[0])
            bbox[1] = min(bbox[1], self.size[1])
            bbox[2] = min(bbox[2], self.size[0])
            bbox[3] = min(bbox[3], self.size[1])
        return bbox

    def get_prompt_box(self, gt_mask):
        """ return (x_min, y_min, x_max, y_max)

        The tight box of ``gt_mask`` with random translation noise applied.

        Raises:
            AssertionError: if the mask shape is not 1024x1024, 512x512 or
                256x256.
            ValueError: if the mask has no non-zero element.
        """
        assert gt_mask.shape == (1024,1024) or gt_mask.shape == (512,512) or gt_mask.shape == (256,256), f"[data_engine] {__file__} Got{gt_mask.shape}"
        box = self.mask_to_bbox(gt_mask)
        return self.add_random_noise_to_bbox(box)

    def enlarge(self, bbox, margin=0):
        """Grow ``bbox`` by 5% of its width/height on every side.

        Args:
            bbox: ``(x0, y0, x1, y1)``.
            margin: Currently ignored; kept so existing call sites keep working.

        Returns:
            ``(x0, y0, x1, y1)`` clamped to ``>= 0`` and, when ``self.size``
            is set, to ``<= size - 1`` on the respective axis.
        """
        x0, y0, x1, y1 = bbox[0], bbox[1], bbox[2], bbox[3]
        margin_x = int((x1 - x0)*0.05)
        margin_y = int((y1 - y0)*0.05)

        # Near corner moves out by subtracting, far corner by adding; each
        # axis uses its own margin.
        x0 = max(x0 - margin_x, 0)
        y0 = max(y0 - margin_y, 0)
        x1 = x1 + margin_x
        y1 = y1 + margin_y
        if self.size is not None:
            x1 = min(x1, self.size[0]-1)
            y1 = min(y1, self.size[1]-1)
        return (x0,y0,x1,y1)
class DataEngine(Dataset):
    """Dataset wrapper that attaches point and box prompts to each sample.

    Each item of the wrapped ``dataset`` must be a dict with a ``'label'``
    entry indexable as ``label[1, :, :]`` (the original author's comment
    describes it as shape ``(3, h, w)``).
    """

    def __init__(self, dataset=None, img_size=None) -> None:
        super().__init__()
        self.point_prompt_generator = PointPromptGenerator(img_size)
        self.box_prompt_generator = BoxPromptGenerator(img_size)
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def _get_true_index(self, idx):
        """Map an external index to an index of the wrapped dataset (identity here)."""
        return idx

    def __getitem__(self, idx):
        return self.get_prompt(idx)

    def get_prompt_point(self, gt_mask):
        """Delegate to the point prompt generator."""
        return self.point_prompt_generator.get_prompt_point(gt_mask)

    def get_prompt_box(self, gt_mask):
        """Delegate to the box prompt generator."""
        return self.box_prompt_generator.get_prompt_box(gt_mask)

    def get_prompt(self, idx):
        """Fetch sample ``idx`` and add prompt entries to it.

        The prompts are computed from ``data['label'][1, :, :]`` (converted to
        NumPy when it is a torch tensor). The sample dict is modified in place
        and returned with three extra float32 arrays:

        * ``'prompt_point'``: the sampled point,
        * ``'prompt_box'``: the sampled box,
        * ``'point_label'``: ``[1.0]``.

        Any error raised while slicing the label (e.g. ``IndexError`` for a
        label with fewer than two entries on axis 0) propagates to the caller.
        """
        idx = self._get_true_index(idx)
        data = self.dataset[idx]
        mask = data['label']
        # No debugger trap here: a failing slice simply raises, which is what
        # the caller (and DataLoader worker processes) can actually handle.
        gt_mask = mask[1, :, :]
        if isinstance(gt_mask, torch.Tensor):
            gt_mask = gt_mask.numpy()

        prompt_point = self.get_prompt_point(gt_mask)
        prompt_box = self.get_prompt_box(gt_mask)

        data['prompt_point'] = np.array(prompt_point).astype(np.float32)
        data['prompt_box'] = np.array(prompt_box).astype(np.float32)
        data['point_label'] = np.ones((1,)).astype(np.float32)
        return data

    def get_subsequent_prompt_point(self, pred_mask, gt_mask):
        """Sample one corrective point per batch element.

        For each ``i`` in ``range(pred_mask.shape[0])`` the point generator is
        called on ``pred_mask[i][0]`` and ``gt_mask[i][0]``.

        Returns:
            ``(coords, labels)`` stacked along axis 0, or ``(None, None)`` as
            soon as the generator reports a label of ``-1``.
        """
        coord_collect = []
        label_collect = []
        for i in range(pred_mask.shape[0]):
            coords, label = self.point_prompt_generator.select_random_subsequent_point(pred_mask[i][0], gt_mask[i][0])
            # -1 is treated as "no point available" and aborts the whole batch.
            if label == -1:
                return None, None
            coord_collect.append(coords)
            label_collect.append(label)

        coord_collect = np.stack(coord_collect, axis=0)
        label_collect = np.stack(label_collect, axis=0)
        return coord_collect, label_collect

    def get_noisy_box_from_box(self, box):
        """Return ``box`` with the box generator's random noise applied."""
        return self.box_prompt_generator.add_random_noise_to_bbox(box)
# def get_prompt_mask(self, )
# class ValidEngine(DataEngine):
# def __init__(self, dataset=None, img_size=(1024,1024), is_train=False) -> None:
# # assert dataset is not None
# self.dataset = dataset
# self.is_train = is_train
# super().__init__(dataset=dataset, img_size=img_size)
# self.expand_dataset_ratio = 1
# # def _get_dataset(self, dirpath):
# # self.dataset = Dataset3D(dirpath=dirpath, is_train=self.is_train)
# def __len__(self):
# return len(self.dataset)
# def _get_true_index(self, idx):
# return idx
class DataManager:
    """Bundle a point and a box prompt generator behind one object."""

    def __init__(self, img_size=None) -> None:
        self.point_prompt_generator = PointPromptGenerator(img_size)
        self.box_prompt_generator = BoxPromptGenerator(img_size)

    def get_prompt_point(self, gt_mask):
        """Delegate to the point prompt generator."""
        point_source = self.point_prompt_generator
        return point_source.get_prompt_point(gt_mask)

    def get_prompt_box(self, gt_mask):
        """Delegate to the box prompt generator."""
        box_source = self.box_prompt_generator
        return box_source.get_prompt_box(gt_mask)

    def get_subsequent_prompt_point(self, pred_mask, gt_mask):
        """Sample one corrective point per batch element.

        The point generator is called on ``pred_mask[i][0]`` / ``gt_mask[i][0]``
        for every ``i`` in ``range(pred_mask.shape[0])``.

        Returns:
            ``(coords, labels)`` stacked along axis 0, or ``(None, None)`` as
            soon as the generator reports a label of ``-1``.
        """
        sample_point = self.point_prompt_generator.select_random_subsequent_point
        batch_size = pred_mask.shape[0]
        points, point_labels = [], []
        for batch_idx in range(batch_size):
            point, point_label = sample_point(pred_mask[batch_idx][0], gt_mask[batch_idx][0])
            if point_label == -1:
                # A -1 label aborts the whole batch.
                return None, None
            points.append(point)
            point_labels.append(point_label)
        return np.stack(points, axis=0), np.stack(point_labels, axis=0)

    def get_noisy_box_from_box(self, box):
        """Return ``box`` with the box generator's random noise applied."""
        box_source = self.box_prompt_generator
        return box_source.add_random_noise_to_bbox(box)
if __name__ == "__main__":
    # Ad-hoc manual check: build an engine, fetch the first sample, then drop
    # into the ipdb debugger to inspect `dataset` and `data`.
    # NOTE(review): DataEngine() is created without a dataset, so
    # self.dataset is None and fetching a sample cannot succeed as written.
    dataset = DataEngine()
    data = dataset[0]
    import ipdb
    ipdb.set_trace()