274 lines
10 KiB
Python
274 lines
10 KiB
Python
import numpy as np
|
|
import torch
|
|
from torch.utils.data import Dataset
|
|
import matplotlib.pyplot as plt
|
|
from einops import repeat, rearrange
|
|
|
|
def show_mask(mask, ax, random_color=False):
    """Overlay a mask on a matplotlib-style axes as a translucent RGBA image.

    Args:
        mask: Array whose last two dimensions are (height, width); it is
            reshaped to (height, width, 1) before colouring.
        ax: Object exposing ``imshow`` (e.g. a matplotlib ``Axes``).
        random_color: When true, draw a random RGB colour from
            ``np.random.random``; otherwise use a fixed blue. Alpha is
            0.6 in both cases.
    """
    alpha = np.array([0.6])
    if not random_color:
        rgba = np.array([30/255, 144/255, 255/255, 0.6])
    else:
        rgba = np.concatenate([np.random.random(3), alpha], axis=0)

    height, width = mask.shape[-2:]
    # Broadcast (h, w, 1) against (1, 1, 4) so every mask pixel scales the colour.
    overlay = mask.reshape(height, width, 1) * rgba[np.newaxis, np.newaxis, :]
    ax.imshow(overlay)
|
|
|
|
def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on an axes, coloured by their label.

    Args:
        coords: Array of point coordinates; column 0 is plotted as x and
            column 1 as y.
        labels: Array aligned with ``coords``; entries equal to 1 are drawn
            in green, entries equal to 0 in red.
        ax: Object exposing ``scatter`` (e.g. a matplotlib ``Axes``).
        marker_size: Marker area forwarded to ``scatter`` as ``s``.
    """
    # Select both groups up front, then draw positives before negatives.
    groups = [
        (coords[labels == 1], 'green'),
        (coords[labels == 0], 'red'),
    ]
    for points, colour in groups:
        ax.scatter(
            points[:, 0],
            points[:, 1],
            color=colour,
            marker='*',
            s=marker_size,
            edgecolor='white',
            linewidth=1.25,
        )
|
|
|
|
def show_box(box, ax):
    """Draw a box outline on an axes.

    Args:
        box: Indexable holding ``(x0, y0, x1, y1)`` at positions 0-3.
        ax: Object exposing ``add_patch`` (e.g. a matplotlib ``Axes``).
    """
    left, top = box[0], box[1]
    width = box[2] - box[0]
    height = box[3] - box[1]
    # Fully transparent face so only the green outline is visible.
    outline = plt.Rectangle(
        (left, top), width, height,
        edgecolor='green', facecolor=(0, 0, 0, 0), lw=2,
    )
    ax.add_patch(outline)
|
|
|
|
class PointPromptGenerator(object):
    """Sample point prompts (clicks) from 2-D segmentation masks.

    Every point is returned as ``(x, y)``, i.e. ``(column, row)`` of the mask.
    Sampling uses the global ``np.random`` state.
    """

    # Mask shapes that do not trigger the warning in ``get_prompt_point``.
    _EXPECTED_SHAPES = ((1024, 1024), (512, 512), (256, 256))

    def __init__(self, size=None) -> None:
        # ``size`` is accepted for interface compatibility but not used here;
        # ``self.size`` is set from the mask on each ``get_prompt_point`` call.
        pass

    @staticmethod
    def _sample_xy(weights, indices=None):
        """Draw one (x, y) location with probability proportional to ``weights``.

        Args:
            weights: 2-D array of non-negative weights with a positive sum.
            indices: Optional precomputed ``np.arange(h * w)``; built when
                omitted.
        """
        if indices is None:
            indices = np.arange(0, weights.shape[0] * weights.shape[1])
        prob = weights.reshape(-1)  # row-major flatten, "h w -> (h w)"
        prob = prob / prob.sum()
        loc = np.random.choice(a=indices, size=1, replace=True, p=prob)[0]
        width = weights.shape[1]
        # Flat row-major index -> (column, row).
        return loc % width, loc // width

    def get_prompt_point(self, gt_mask):
        """Sample one foreground pixel of ``gt_mask`` uniformly at random.

        Args:
            gt_mask: 2-D mask; every value ``> 0`` counts as foreground.

        Returns:
            ``(x, y)`` of the sampled pixel.

        Raises:
            AssertionError: If the mask has no foreground pixel.

        Side effects:
            Prints a warning for shapes outside the expected square sizes and
            stores ``self.size`` (mask shape) and ``self.xy`` (flat indices).
        """
        if gt_mask.shape not in self._EXPECTED_SHAPES:
            print(f"[Warning] [data_engine] {__file__} Got{gt_mask.shape}")

        # Binarise first so the emptiness check matches what is sampled from.
        foreground = np.float32(gt_mask > 0)
        assert foreground.sum() > 0

        self.size = gt_mask.shape
        self.xy = np.arange(0, self.size[0] * self.size[1])
        return self._sample_xy(foreground, self.xy)

    @staticmethod
    def select_random_subsequent_point(pred_mask, gt_mask):
        """Sample a corrective click from where prediction and truth disagree.

        A pixel is an error wherever ``pred_mask`` and ``gt_mask`` differ.
        One error pixel is drawn uniformly. When there is no error pixel, a
        point is drawn from ``gt_mask`` instead (weighted by its values) and
        labelled positive.

        Args:
            pred_mask: 2-D predicted mask.
            gt_mask: 2-D ground-truth mask with a positive sum.

        Returns:
            ``((x, y), label)`` where ``label`` is 1 when the sampled pixel is
            ground-truth foreground (``> 0``) and 0 otherwise.

        Raises:
            AssertionError: If either mask is not 2-D or ``gt_mask`` sums to
                a non-positive value.
        """
        assert len(pred_mask.shape) == 2
        assert len(gt_mask.shape) == 2
        assert gt_mask.sum() > 0, f"[data_engine] Got {gt_mask.sum()}==0 "

        # Work in float so boolean masks (numpy forbids bool subtraction) and
        # unsigned integer masks are handled uniformly.
        pred = np.asarray(pred_mask, dtype=np.float64)
        gt = np.asarray(gt_mask, dtype=np.float64)

        # NaN differences compare False here, so they never count as errors.
        diff = np.float32(np.abs(pred - gt) > 0)

        if diff.sum() == 0:
            # Prediction already matches: fall back to a ground-truth point.
            xy = np.arange(0, diff.shape[0] * diff.shape[1])
            return PointPromptGenerator._sample_xy(gt, xy), 1

        x, y = PointPromptGenerator._sample_xy(diff)
        # The label follows the ground truth at the sampled error pixel:
        # missed foreground -> positive click, spurious foreground -> negative.
        classification = 1 if gt[y, x] > 0 else 0
        return (x, y), classification
|
|
|
|
|
|
|
|
class BoxPromptGenerator(object):
    """Derive bounding-box prompts from masks, optionally jittered or enlarged.

    Boxes are ``(x_min, y_min, x_max, y_max)`` with x indexing mask columns
    and y indexing mask rows.
    """

    # Largest shift, in pixels, that ``add_random_noise_to_bbox`` applies per axis.
    _MAX_NOISE = 20

    def __init__(self, size) -> None:
        # Image extent used for clipping: x is clipped against size[0] and y
        # against size[1]. May be None, in which case no upper clipping is done.
        # NOTE(review): x-vs-size[0] only matters for non-square images —
        # confirm the intended (w, h) / (h, w) convention with callers.
        self.size = size

    @staticmethod
    def mask_to_bbox(mask):
        """Return the tight box around all non-zero elements of ``mask``.

        Args:
            mask: Array with at least two dimensions; axis 0 is treated as y
                and axis 1 as x.

        Returns:
            ``(x_min, y_min, x_max, y_max)`` with inclusive maxima.

        Raises:
            ValueError: If ``mask`` has no non-zero element.
        """
        coords = np.nonzero(mask)
        if coords[0].size == 0:
            raise ValueError("mask_to_bbox: mask has no non-zero elements")

        ys, xs = coords[0], coords[1]
        return (np.min(xs), np.min(ys), np.max(xs), np.max(ys))

    def add_random_noise_to_bbox(self, bbox):
        """Randomly translate ``bbox`` by a small, bounded offset.

        One Gaussian offset is drawn per axis with a standard deviation of 1%
        of the mean side length. Each offset is rounded and limited to
        ``[-20, 20]`` pixels, then applied to both corners of that axis, so
        the box size is preserved before clipping.

        Args:
            bbox: Sequence ``(x_min, y_min, x_max, y_max)``.

        Returns:
            A new list ``[x_min, y_min, x_max, y_max]``; the minima are
            clipped to ``>= 0`` and, when ``self.size`` is set, the maxima to
            ``<= self.size``.
        """
        bbox = list(bbox)
        x_side_length = bbox[2] - bbox[0]
        y_side_length = bbox[3] - bbox[1]

        std_dev = 0.01 * (x_side_length + y_side_length) / 2

        x_noise = np.random.normal(scale=std_dev)
        y_noise = np.random.normal(scale=std_dev)

        # Bound the shift symmetrically; capping only the positive side would
        # leave negative shifts unbounded.
        limit = self._MAX_NOISE
        dx = max(-limit, min(int(round(x_noise)), limit))
        dy = max(-limit, min(int(round(y_noise)), limit))

        bbox[0] += dx
        bbox[1] += dy
        bbox[2] += dx
        bbox[3] += dy

        bbox[0] = max(bbox[0], 0)
        bbox[1] = max(bbox[1], 0)
        if self.size is not None:
            bbox[2] = min(bbox[2], self.size[0])
            bbox[3] = min(bbox[3], self.size[1])

        return bbox

    def get_prompt_box(self, gt_mask):
        """ return (x_min, y_min, x_max, y_max)

        Computes the tight box of ``gt_mask`` and passes it through
        ``add_random_noise_to_bbox``. Only 1024x1024, 512x512 and 256x256
        masks are accepted.

        Raises:
            AssertionError: For any other mask shape.
            ValueError: If the mask has no non-zero element.
        """
        assert gt_mask.shape == (1024,1024) or gt_mask.shape == (512,512) or gt_mask.shape == (256,256), f"[data_engine] {__file__} Got{gt_mask.shape}"
        box = self.mask_to_bbox(gt_mask)
        box_w_noise = self.add_random_noise_to_bbox(box)
        return box_w_noise

    def enlarge(self, bbox, margin=0):
        """Grow ``bbox`` by 5% of its side length on every side.

        Args:
            bbox: Sequence ``(x0, y0, x1, y1)``.
            margin: Unused; kept for interface compatibility.

        Returns:
            ``(x0, y0, x1, y1)`` with the minima clipped to ``>= 0`` and, when
            ``self.size`` is set, the maxima clipped to ``self.size - 1``.
        """
        x0, y0, x1, y1 = bbox[0], bbox[1], bbox[2], bbox[3]
        margin_x = int((x1 - x0)*0.05)
        margin_y = int((y1 - y0)*0.05)

        # Move the min corner outwards (down) and the max corner outwards (up),
        # each axis using its own margin.
        x0 = max(x0 - margin_x, 0)
        y0 = max(y0 - margin_y, 0)
        x1 = x1 + margin_x
        y1 = y1 + margin_y
        if self.size is not None:
            x1 = min(x1, self.size[0]-1)
            y1 = min(y1, self.size[1]-1)

        return (x0,y0,x1,y1)
|
|
|
|
class DataEngine(Dataset):
    """Dataset wrapper that attaches point and box prompts to each sample.

    Each item of the wrapped ``dataset`` must be a mapping with a ``'label'``
    entry indexable as ``label[1, :, :]``.
    """

    def __init__(self, dataset=None, img_size=None) -> None:
        """
        Args:
            dataset: Wrapped dataset supporting ``len()`` and ``[idx]``.
            img_size: Forwarded to both prompt generators.
        """
        super().__init__()
        self.point_prompt_generator = PointPromptGenerator(img_size)
        self.box_prompt_generator = BoxPromptGenerator(img_size)
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def _get_true_index(self, idx):
        # Identity mapping; a hook for subclasses that remap indices.
        return idx

    def __getitem__(self, idx):
        return self.get_prompt(idx)

    def get_prompt_point(self, gt_mask):
        """Delegate to the point generator's ``get_prompt_point``."""
        return self.point_prompt_generator.get_prompt_point(gt_mask)

    def get_prompt_box(self, gt_mask):
        """Delegate to the box generator's ``get_prompt_box``."""
        return self.box_prompt_generator.get_prompt_box(gt_mask)

    def get_prompt(self, idx):
        """Fetch sample ``idx`` and add prompt entries to it.

        The sample dict returned by the wrapped dataset is updated in place
        with ``'prompt_point'``, ``'prompt_box'`` and ``'point_label'``
        (all float32 arrays; the label is always 1) and then returned.

        Indexing errors on ``data['label']`` propagate to the caller.
        """
        idx = self._get_true_index(idx)
        data = self.dataset[idx]

        mask = data['label']
        # Index 1 of the first axis is used as the ground-truth mask.
        # presumably label is (3, h, w) — TODO confirm against the dataset.
        gt_mask = mask[1, :, :]
        if isinstance(gt_mask, torch.Tensor):
            # detach/cpu so tensors on an accelerator or tracking gradients
            # can still be converted.
            gt_mask = gt_mask.detach().cpu().numpy()

        prompt_point = self.get_prompt_point(gt_mask)
        prompt_box = self.get_prompt_box(gt_mask)

        data['prompt_point'] = np.array(prompt_point).astype(np.float32)
        data['prompt_box'] = np.array(prompt_box).astype(np.float32)
        data['point_label'] = np.ones((1,)).astype(np.float32)
        return data

    def get_subsequent_prompt_point(self, pred_mask, gt_mask):
        """Sample one corrective point per batch element.

        Args:
            pred_mask: Batched predictions; ``pred_mask[i][0]`` is passed to
                the point generator for element ``i``.
            gt_mask: Batched ground truth, indexed the same way.

        Returns:
            ``(coords, labels)`` stacked along axis 0, or ``(None, None)`` as
            soon as the generator reports a label of ``-1``.
        """
        coord_collect = []
        label_collect = []
        for i in range(pred_mask.shape[0]):
            coords, label = self.point_prompt_generator.select_random_subsequent_point(pred_mask[i][0], gt_mask[i][0])
            if label == -1:
                return None, None
            coord_collect.append(coords)
            label_collect.append(label)

        coord_collect = np.stack(coord_collect, axis=0)
        label_collect = np.stack(label_collect, axis=0)
        return coord_collect, label_collect

    def get_noisy_box_from_box(self, box):
        """Return a jittered copy of an already-labelled box."""
        return self.box_prompt_generator.add_random_noise_to_bbox(box)
|
|
|
|
# def get_prompt_mask(self, )
|
|
|
|
# class ValidEngine(DataEngine):
|
|
# def __init__(self, dataset=None, img_size=(1024,1024), is_train=False) -> None:
|
|
# # assert dataset is not None
|
|
# self.dataset = dataset
|
|
# self.is_train = is_train
|
|
# super().__init__(dataset=dataset, img_size=img_size)
|
|
# self.expand_dataset_ratio = 1
|
|
|
|
# # def _get_dataset(self, dirpath):
|
|
# # self.dataset = Dataset3D(dirpath=dirpath, is_train=self.is_train)
|
|
|
|
# def __len__(self):
|
|
# return len(self.dataset)
|
|
|
|
# def _get_true_index(self, idx):
|
|
# return idx
|
|
|
|
class DataManager:
    """Bundle a point and a box prompt generator behind one object.

    Unlike a dataset wrapper, this holds no data; callers pass masks and
    boxes in directly.
    """

    def __init__(self, img_size=None) -> None:
        # Both generators receive the same image size.
        self.point_prompt_generator = PointPromptGenerator(img_size)
        self.box_prompt_generator = BoxPromptGenerator(img_size)

    def get_prompt_point(self, gt_mask):
        """Forward ``gt_mask`` to the point generator's ``get_prompt_point``."""
        point_gen = self.point_prompt_generator
        return point_gen.get_prompt_point(gt_mask)

    def get_prompt_box(self, gt_mask):
        """Forward ``gt_mask`` to the box generator's ``get_prompt_box``."""
        box_gen = self.box_prompt_generator
        return box_gen.get_prompt_box(gt_mask)

    def get_subsequent_prompt_point(self, pred_mask, gt_mask):
        """Collect one follow-up point per batch element.

        For every index ``i`` along axis 0, ``pred_mask[i][0]`` and
        ``gt_mask[i][0]`` are handed to the point generator's
        ``select_random_subsequent_point``.

        Returns:
            ``(coords, labels)`` stacked along axis 0, or ``(None, None)`` as
            soon as any label equals ``-1``.
        """
        picked = []
        n_items = pred_mask.shape[0]
        for i in range(n_items):
            point, label = self.point_prompt_generator.select_random_subsequent_point(
                pred_mask[i][0], gt_mask[i][0]
            )
            if label == -1:
                # Abort the whole batch on the sentinel label.
                return None, None
            picked.append((point, label))

        coord_collect = np.stack([point for point, _ in picked], axis=0)
        label_collect = np.stack([label for _, label in picked], axis=0)
        return coord_collect, label_collect

    def get_noisy_box_from_box(self, box):
        """Jitter an already-labelled box via the box generator."""
        box_gen = self.box_prompt_generator
        return box_gen.add_random_noise_to_bbox(box)
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual debugging entry point: build an engine, fetch the first sample,
    # then drop into an interactive ipdb session to inspect `data`.
    # NOTE(review): DataEngine() is created without a dataset here, so the
    # fetch below only works once a dataset is supplied.
    dataset = DataEngine()
    data = dataset[0]

    import ipdb
    ipdb.set_trace()