From 72a73cb4b973d4fb4a506ee2f73da52c445c99a9 Mon Sep 17 00:00:00 2001 From: transcendentsiki Date: Tue, 2 Apr 2024 16:57:10 +0800 Subject: [PATCH] add mbox --- core/volume_predictor.py | 154 +++++++++++++++++++++++++++++++-------- 1 file changed, 123 insertions(+), 31 deletions(-) diff --git a/core/volume_predictor.py b/core/volume_predictor.py index 88e77cc..677862b 100644 --- a/core/volume_predictor.py +++ b/core/volume_predictor.py @@ -191,7 +191,7 @@ class VolumePredictor: raise NotImplementedError # Preprocess prompts - self.original_size = x.shape[1:] + # self.original_size = x.shape[1:] if point_coords is not None: assert ( point_labels is not None @@ -217,6 +217,7 @@ class VolumePredictor: center_masks = self._predict_center_slice(center_idx, point_coords, box) return center_masks['masks'] + @torch.no_grad() def predict_volume( self, x, @@ -255,8 +256,37 @@ class VolumePredictor: else: raise NotImplementedError + # set 3d image + self.set_image(x) # Preprocess prompts self.original_size = x.shape[1:] + if self.masks3d is None: + self.masks3d = np.zeros_like(x) + self.slice_count = x.shape[0] + return self.predict_with_prompt( + point_coords = point_coords, + point_labels = point_labels, + box = box, + mask_input = mask_input, + multimask_output = multimask_output, + return_logits = return_logits, + template_slice_id = template_slice_id, + return_stability = return_stability + ) + + @torch.no_grad() + def predict_with_prompt( + self, + point_coords: Optional[np.ndarray] = None, + point_labels: Optional[np.ndarray] = None, + box: Optional[np.ndarray] = None, + mask_input: Optional[np.ndarray] = None, + multimask_output: bool = True, + return_logits: bool = False, + template_slice_id:int = None, + return_stability: bool = False, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + if point_coords is not None: assert ( point_labels is not None @@ -273,38 +303,39 @@ class VolumePredictor: mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device) mask_input_torch = mask_input_torch[None, :, :, :] - # set 3d image - self.set_image(x) - + self.all_prompts = {} + # predict center slice - center_idx = template_slice_id if template_slice_id is not None else x.shape[0] // 2 + center_idx = template_slice_id if template_slice_id is not None else self.slice_count // 2 # print("Processing ", center_idx) center_masks = self._predict_center_slice(center_idx, point_coords, box) if center_masks._stats == {}: print("Ends for no mask.") raise ValueError self.merge_to_mask3d(center_idx, center_masks) + center_idx = center_idx.item() if not isinstance(center_idx, int) else center_idx + self.all_prompts[center_idx] = box if box is not None else point_coords previous_masks = center_masks - for i in range(center_idx+1, x.shape[0]-1): + for i in range(center_idx+1, self.slice_count-1): # print("Processing downward", i) - previous_masks = self._predict_slice(i, previous_masks, orientation="down") + previous_masks, scaled_boxes = self._predict_slice(i, previous_masks, orientation="down") if previous_masks._stats == {}: print("Ends for no mask.") break self.merge_to_mask3d(i, previous_masks) + self.all_prompts[i] = scaled_boxes previous_masks = center_masks for i in np.arange(1, center_idx)[::-1]: # print("Processing upward", i) - previous_masks = self._predict_slice(i, previous_masks, orientation="up") + previous_masks, scaled_boxes = self._predict_slice(i, previous_masks, orientation="up") if previous_masks._stats == {}: print("Ends for no mask.") break self.merge_to_mask3d(i, previous_masks) + self.all_prompts[i] = scaled_boxes - if self.masks3d is None: - self.masks3d = np.zeros_like(x) if return_stability: return self.postprocess_3d(self.masks3d), self.stability_score_2d return self.postprocess_3d(self.masks3d) @@ -324,7 +355,7 @@ class VolumePredictor: scaled_boxes, tags = self.generate_prompts_from_previous_masks(previous_masks, orientation) masks = self.genetate_masks_from_boxes(idx, all_boxes=scaled_boxes, tags=tags) masks.to_numpy() - return masks + return masks, scaled_boxes def generate_prompts_from_previous_masks(self, previous_masks: MaskData, orientation): if orientation == "down": @@ -486,10 +517,6 @@ class VolumePredictor: data["boxes"] = batched_mask_to_box(data["masks"][:,1,:,:]>0) return data - # @staticmethod - # def calculate_ - - @staticmethod def batched_remove_noise(masks): ori_shape = masks.shape @@ -539,26 +566,80 @@ class VolumePredictor: ) # Upscale the masks to the original image resolution - masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size[1:]) + masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size) if not return_logits: masks = masks > self.model.mask_threshold return masks, iou_predictions, low_res_masks - def valid_box(self, data, batch_idx): - # valid image with box, or point prompt - assert data['img'].shape[0] == 1, f"shape {data['img'].shape}" - image = data['img'] - label = data['label'] + # def valid_box(self, data, batch_idx): + # # valid image with box, or point prompt + # assert data['img'].shape[0] == 1, f"shape {data['img'].shape}" + # image = data['img'] + # label = data['label'] - box = BoxPromptGenerator().mask_to_bbox(label) - box_mask3d = self.predict_volume( - x=image, - box=box, - ) - dice = compute_dice_np(box_mask3d, label.detach().cpu().numpy()) + # box = BoxPromptGenerator().mask_to_bbox(label) + # box_mask3d = self.predict_volume( + # x=image, + # box=box, + # ) + # dice = compute_dice_np(box_mask3d, label.detach().cpu().numpy()) + def get_confidence(self): + masks = self.postprocess_3d(self.masks3d) + conf_collect = [] + for i in range(1,self.masks3d.shape[0]-1): + prompt_box = self.all_prompts.get(i, None) + if prompt_box is not None: + mask = masks[i,:,:] + if mask.sum() > 0: + bbox = BoxPromptGenerator(size=None).mask_to_bbox(mask) + bbox = self.transform.apply_boxes(np.array([bbox]), self.original_size)[0] + else: + bbox = [0,0,0,0] + prompt_box = self.all_prompts[i][0] + confidence = calculate_iou(bbox, prompt_box) + else: + confidence = 0 + if i == 1: + conf_collect.append(confidence) + conf_collect.append(confidence) + assert len(conf_collect) == i+1 + conf_collect.append(confidence) + print(conf_collect) + return conf_collect + +def calculate_iou(box1, box2): + """ + 计算两个框的IoU Intersection over Union。 + + 参数: + box1 和 box2 是两个框,每个框表示为四个值 (x1, y1, x2, y2),其中 (x1, y1) 是左上角的坐标, + (x2, y2) 是右下角的坐标。 + + 返回: + 返回两个框的IoU。 + """ + # 计算交集的左上角和右下角坐标 + x1_i = max(box1[0], box2[0]) + y1_i = max(box1[1], box2[1]) + x2_i = min(box1[2], box2[2]) + y2_i = min(box1[3], box2[3]) + + # 计算交集的面积 + intersection_area = max(0, x2_i - x1_i) * max(0, y2_i - y1_i) + + # 计算并集的面积 + box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1]) + box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1]) + union_area = box1_area + box2_area - intersection_area + # print(intersection_area, union_area) + + # 计算IoU + iou = intersection_area / union_area + + return iou if __name__ == "__main__": @@ -573,7 +654,7 @@ if __name__ == "__main__": sam = sam_model_registry[model_type](checkpoint=None) learner = SamLearner(sam_model=sam, config=config, data_engine=DataManager(img_size=(1024,1024))) learner.use_lora() - pth = "model_iter_360000.pth" + pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b9/lora3/ckpt/model_iter_360000.pth" learner.load_well_trained_model(pth) learner.cuda() @@ -588,7 +669,7 @@ if __name__ == "__main__": volume = itk_to_np(read(img_path)) # test several slices label_itk = read(label_path) spacing = label_itk.GetSpacing() - label = itk_to_np(label_itk) == 1 + label = itk_to_np(label_itk) == 13 volume = np.clip(volume, -200, 400) # Select the slice with the largest mask @@ -598,8 +679,14 @@ if __name__ == "__main__": x_max = np.max(coords[0]) template_slice_id = s.argmax() - box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id]) + box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id]) # (115, 207, 309, 339) + # import ipdb; ipdb.set_trace() + # box = (125, 210, 300, 310) box = np.array([box]) + box[0][0] += 10 + box[0][1] += 10 + box[0][2] -= 10 + box[0][3] -= 10 pred = predictor.predict_volume( x=volume, @@ -609,4 +696,9 @@ if __name__ == "__main__": ) Data3dSolver().simple_write(pred, path="mask.nii.gz", spacing=spacing) - Data3dSolver().simple_write(label, path="gt.nii.gz", spacing=spacing) \ No newline at end of file + Data3dSolver().simple_write(label, path="gt.nii.gz", spacing=spacing) + + dice = compute_dice_np(pred, label) + print("Dice ", dice, " box: ", box, "slice id", template_slice_id) + print(tuple(box)) + # import ipdb; ipdb.set_trace() \ No newline at end of file