add mbox
This commit is contained in:
parent
96a2bd15a6
commit
72a73cb4b9
@ -191,7 +191,7 @@ class VolumePredictor:
|
||||
raise NotImplementedError
|
||||
|
||||
# Preprocess prompts
|
||||
self.original_size = x.shape[1:]
|
||||
# self.original_size = x.shape[1:]
|
||||
if point_coords is not None:
|
||||
assert (
|
||||
point_labels is not None
|
||||
@ -217,6 +217,7 @@ class VolumePredictor:
|
||||
center_masks = self._predict_center_slice(center_idx, point_coords, box)
|
||||
return center_masks['masks']
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_volume(
|
||||
self,
|
||||
x,
|
||||
@ -255,8 +256,37 @@ class VolumePredictor:
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
# set 3d image
|
||||
self.set_image(x)
|
||||
# Preprocess prompts
|
||||
self.original_size = x.shape[1:]
|
||||
if self.masks3d is None:
|
||||
self.masks3d = np.zeros_like(x)
|
||||
self.slice_count = x.shape[0]
|
||||
return self.predict_with_prompt(
|
||||
point_coords = point_coords,
|
||||
point_labels = point_labels,
|
||||
box = box,
|
||||
mask_input = mask_input,
|
||||
multimask_output = multimask_output,
|
||||
return_logits = return_logits,
|
||||
template_slice_id = template_slice_id,
|
||||
return_stability = return_stability
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_with_prompt(
|
||||
self,
|
||||
point_coords: Optional[np.ndarray] = None,
|
||||
point_labels: Optional[np.ndarray] = None,
|
||||
box: Optional[np.ndarray] = None,
|
||||
mask_input: Optional[np.ndarray] = None,
|
||||
multimask_output: bool = True,
|
||||
return_logits: bool = False,
|
||||
template_slice_id:int = None,
|
||||
return_stability: bool = False,
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
|
||||
if point_coords is not None:
|
||||
assert (
|
||||
point_labels is not None
|
||||
@ -273,38 +303,39 @@ class VolumePredictor:
|
||||
mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device)
|
||||
mask_input_torch = mask_input_torch[None, :, :, :]
|
||||
|
||||
# set 3d image
|
||||
self.set_image(x)
|
||||
|
||||
self.all_prompts = {}
|
||||
|
||||
# predict center slice
|
||||
center_idx = template_slice_id if template_slice_id is not None else x.shape[0] // 2
|
||||
center_idx = template_slice_id if template_slice_id is not None else self.slice_count // 2
|
||||
# print("Processing ", center_idx)
|
||||
center_masks = self._predict_center_slice(center_idx, point_coords, box)
|
||||
if center_masks._stats == {}:
|
||||
print("Ends for no mask.")
|
||||
raise ValueError
|
||||
self.merge_to_mask3d(center_idx, center_masks)
|
||||
center_idx = center_idx.item() if not isinstance(center_idx, int) else center_idx
|
||||
self.all_prompts[center_idx] = box if box is not None else point_coords
|
||||
|
||||
previous_masks = center_masks
|
||||
for i in range(center_idx+1, x.shape[0]-1):
|
||||
for i in range(center_idx+1, self.slice_count-1):
|
||||
# print("Processing downward", i)
|
||||
previous_masks = self._predict_slice(i, previous_masks, orientation="down")
|
||||
previous_masks, scaled_boxes = self._predict_slice(i, previous_masks, orientation="down")
|
||||
if previous_masks._stats == {}:
|
||||
print("Ends for no mask.")
|
||||
break
|
||||
self.merge_to_mask3d(i, previous_masks)
|
||||
self.all_prompts[i] = scaled_boxes
|
||||
|
||||
previous_masks = center_masks
|
||||
for i in np.arange(1, center_idx)[::-1]:
|
||||
# print("Processing upward", i)
|
||||
previous_masks = self._predict_slice(i, previous_masks, orientation="up")
|
||||
previous_masks, scaled_boxes = self._predict_slice(i, previous_masks, orientation="up")
|
||||
if previous_masks._stats == {}:
|
||||
print("Ends for no mask.")
|
||||
break
|
||||
self.merge_to_mask3d(i, previous_masks)
|
||||
self.all_prompts[i] = scaled_boxes
|
||||
|
||||
if self.masks3d is None:
|
||||
self.masks3d = np.zeros_like(x)
|
||||
if return_stability:
|
||||
return self.postprocess_3d(self.masks3d), self.stability_score_2d
|
||||
return self.postprocess_3d(self.masks3d)
|
||||
@ -324,7 +355,7 @@ class VolumePredictor:
|
||||
scaled_boxes, tags = self.generate_prompts_from_previous_masks(previous_masks, orientation)
|
||||
masks = self.genetate_masks_from_boxes(idx, all_boxes=scaled_boxes, tags=tags)
|
||||
masks.to_numpy()
|
||||
return masks
|
||||
return masks, scaled_boxes
|
||||
|
||||
def generate_prompts_from_previous_masks(self, previous_masks: MaskData, orientation):
|
||||
if orientation == "down":
|
||||
@ -486,10 +517,6 @@ class VolumePredictor:
|
||||
data["boxes"] = batched_mask_to_box(data["masks"][:,1,:,:]>0)
|
||||
return data
|
||||
|
||||
# @staticmethod
|
||||
# def calculate_
|
||||
|
||||
|
||||
@staticmethod
|
||||
def batched_remove_noise(masks):
|
||||
ori_shape = masks.shape
|
||||
@ -539,26 +566,80 @@ class VolumePredictor:
|
||||
)
|
||||
|
||||
# Upscale the masks to the original image resolution
|
||||
masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size[1:])
|
||||
masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)
|
||||
|
||||
if not return_logits:
|
||||
masks = masks > self.model.mask_threshold
|
||||
|
||||
return masks, iou_predictions, low_res_masks
|
||||
|
||||
def valid_box(self, data, batch_idx):
|
||||
# valid image with box, or point prompt
|
||||
assert data['img'].shape[0] == 1, f"shape {data['img'].shape}"
|
||||
image = data['img']
|
||||
label = data['label']
|
||||
# def valid_box(self, data, batch_idx):
|
||||
# # valid image with box, or point prompt
|
||||
# assert data['img'].shape[0] == 1, f"shape {data['img'].shape}"
|
||||
# image = data['img']
|
||||
# label = data['label']
|
||||
|
||||
box = BoxPromptGenerator().mask_to_bbox(label)
|
||||
box_mask3d = self.predict_volume(
|
||||
x=image,
|
||||
box=box,
|
||||
)
|
||||
dice = compute_dice_np(box_mask3d, label.detach().cpu().numpy())
|
||||
# box = BoxPromptGenerator().mask_to_bbox(label)
|
||||
# box_mask3d = self.predict_volume(
|
||||
# x=image,
|
||||
# box=box,
|
||||
# )
|
||||
# dice = compute_dice_np(box_mask3d, label.detach().cpu().numpy())
|
||||
|
||||
def get_confidence(self):
    """Return one confidence value per slice of ``self.masks3d``.

    For every interior slice ``i`` (``1 <= i <= n - 2``) that has an entry
    in ``self.all_prompts``, the confidence is the IoU between the bounding
    box of the post-processed mask on that slice and the first prompt box
    stored for that slice. Interior slices without a stored prompt get 0.
    The first and last slices are never evaluated; they reuse the value of
    their nearest interior neighbour so the result has ``n`` entries.

    Returns:
        list: confidences, one per slice along axis 0 of ``self.masks3d``.
        For volumes with fewer than three slices there is no interior
        slice to evaluate, so a list of zeros is returned.
    """
    masks = self.postprocess_3d(self.masks3d)
    slice_count = self.masks3d.shape[0]

    if slice_count < 3:
        # The loop below would not execute, which previously left
        # ``confidence`` unbound and crashed on the final append.
        conf_collect = [0] * slice_count
        print(conf_collect)
        return conf_collect

    conf_collect = []
    for i in range(1, slice_count - 1):
        prompt_box = self.all_prompts.get(i, None)
        if prompt_box is not None:
            mask = masks[i, :, :]
            if mask.sum() > 0:
                bbox = BoxPromptGenerator(size=None).mask_to_bbox(mask)
                # NOTE(review): presumably rescales the box into the same
                # coordinate frame as the stored prompts -- confirm.
                bbox = self.transform.apply_boxes(np.array([bbox]), self.original_size)[0]
            else:
                # Empty mask: a degenerate box gives zero overlap.
                bbox = [0, 0, 0, 0]
            prompt_box = self.all_prompts[i][0]
            confidence = calculate_iou(bbox, prompt_box)
        else:
            confidence = 0

        if i == 1:
            # Pad slice 0 with the value of its neighbour, slice 1.
            conf_collect.append(confidence)
        conf_collect.append(confidence)
        assert len(conf_collect) == i + 1

    # Pad the last slice with the value of the last interior slice.
    conf_collect.append(confidence)
    print(conf_collect)
    return conf_collect
|
||||
|
||||
def calculate_iou(box1, box2):
    """Compute the Intersection over Union (IoU) of two boxes.

    Args:
        box1, box2: boxes given as four values ``(x1, y1, x2, y2)``, where
            ``(x1, y1)`` is the top-left corner and ``(x2, y2)`` is the
            bottom-right corner.

    Returns:
        The IoU of the two boxes. Returns ``0.0`` when the union area is
        not positive (e.g. both boxes are degenerate), instead of dividing
        by zero.
    """
    # Top-left and bottom-right corners of the intersection rectangle.
    x1_i = max(box1[0], box2[0])
    y1_i = max(box1[1], box2[1])
    x2_i = min(box1[2], box2[2])
    y2_i = min(box1[3], box2[3])

    # Intersection area; clamped to zero when the boxes do not overlap.
    intersection_area = max(0, x2_i - x1_i) * max(0, y2_i - y1_i)

    # Union area = sum of both areas minus the shared part.
    box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
    box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union_area = box1_area + box2_area - intersection_area

    # Guard the division: two zero-area boxes have no meaningful overlap.
    if union_area <= 0:
        return 0.0

    return intersection_area / union_area
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@ -573,7 +654,7 @@ if __name__ == "__main__":
|
||||
sam = sam_model_registry[model_type](checkpoint=None)
|
||||
learner = SamLearner(sam_model=sam, config=config, data_engine=DataManager(img_size=(1024,1024)))
|
||||
learner.use_lora()
|
||||
pth = "model_iter_360000.pth"
|
||||
pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b9/lora3/ckpt/model_iter_360000.pth"
|
||||
learner.load_well_trained_model(pth)
|
||||
learner.cuda()
|
||||
|
||||
@ -588,7 +669,7 @@ if __name__ == "__main__":
|
||||
volume = itk_to_np(read(img_path)) # test several slices
|
||||
label_itk = read(label_path)
|
||||
spacing = label_itk.GetSpacing()
|
||||
label = itk_to_np(label_itk) == 1
|
||||
label = itk_to_np(label_itk) == 13
|
||||
volume = np.clip(volume, -200, 400)
|
||||
|
||||
# Select the slice with the largest mask
|
||||
@ -598,8 +679,14 @@ if __name__ == "__main__":
|
||||
x_max = np.max(coords[0])
|
||||
template_slice_id = s.argmax()
|
||||
|
||||
box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id])
|
||||
box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id]) # (115, 207, 309, 339)
|
||||
# import ipdb; ipdb.set_trace()
|
||||
# box = (125, 210, 300, 310)
|
||||
box = np.array([box])
|
||||
box[0][0] += 10
|
||||
box[0][1] += 10
|
||||
box[0][2] -= 10
|
||||
box[0][3] -= 10
|
||||
|
||||
pred = predictor.predict_volume(
|
||||
x=volume,
|
||||
@ -609,4 +696,9 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
Data3dSolver().simple_write(pred, path="mask.nii.gz", spacing=spacing)
|
||||
Data3dSolver().simple_write(label, path="gt.nii.gz", spacing=spacing)
|
||||
Data3dSolver().simple_write(label, path="gt.nii.gz", spacing=spacing)
|
||||
|
||||
dice = compute_dice_np(pred, label)
|
||||
print("Dice ", dice, " box: ", box, "slice id", template_slice_id)
|
||||
print(tuple(box))
|
||||
# import ipdb; ipdb.set_trace()
|
Loading…
x
Reference in New Issue
Block a user