This commit is contained in:
transcendentsiki 2024-04-02 16:57:10 +08:00
parent 96a2bd15a6
commit 72a73cb4b9

View File

@ -191,7 +191,7 @@ class VolumePredictor:
raise NotImplementedError
# Preprocess prompts
self.original_size = x.shape[1:]
# self.original_size = x.shape[1:]
if point_coords is not None:
assert (
point_labels is not None
@ -217,6 +217,7 @@ class VolumePredictor:
center_masks = self._predict_center_slice(center_idx, point_coords, box)
return center_masks['masks']
@torch.no_grad()
def predict_volume(
self,
x,
@ -255,8 +256,37 @@ class VolumePredictor:
else:
raise NotImplementedError
# set 3d image
self.set_image(x)
# Preprocess prompts
self.original_size = x.shape[1:]
if self.masks3d is None:
self.masks3d = np.zeros_like(x)
self.slice_count = x.shape[0]
return self.predict_with_prompt(
point_coords = point_coords,
point_labels = point_labels,
box = box,
mask_input = mask_input,
multimask_output = multimask_output,
return_logits = return_logits,
template_slice_id = template_slice_id,
return_stability = return_stability
)
@torch.no_grad()
def predict_with_prompt(
self,
point_coords: Optional[np.ndarray] = None,
point_labels: Optional[np.ndarray] = None,
box: Optional[np.ndarray] = None,
mask_input: Optional[np.ndarray] = None,
multimask_output: bool = True,
return_logits: bool = False,
template_slice_id:int = None,
return_stability: bool = False,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
if point_coords is not None:
assert (
point_labels is not None
@ -273,38 +303,39 @@ class VolumePredictor:
mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device)
mask_input_torch = mask_input_torch[None, :, :, :]
# set 3d image
self.set_image(x)
self.all_prompts = {}
# predict center slice
center_idx = template_slice_id if template_slice_id is not None else x.shape[0] // 2
center_idx = template_slice_id if template_slice_id is not None else self.slice_count // 2
# print("Processing ", center_idx)
center_masks = self._predict_center_slice(center_idx, point_coords, box)
if center_masks._stats == {}:
print("Ends for no mask.")
raise ValueError
self.merge_to_mask3d(center_idx, center_masks)
center_idx = center_idx.item() if not isinstance(center_idx, int) else center_idx
self.all_prompts[center_idx] = box if box is not None else point_coords
previous_masks = center_masks
for i in range(center_idx+1, x.shape[0]-1):
for i in range(center_idx+1, self.slice_count-1):
# print("Processing downward", i)
previous_masks = self._predict_slice(i, previous_masks, orientation="down")
previous_masks, scaled_boxes = self._predict_slice(i, previous_masks, orientation="down")
if previous_masks._stats == {}:
print("Ends for no mask.")
break
self.merge_to_mask3d(i, previous_masks)
self.all_prompts[i] = scaled_boxes
previous_masks = center_masks
for i in np.arange(1, center_idx)[::-1]:
# print("Processing upward", i)
previous_masks = self._predict_slice(i, previous_masks, orientation="up")
previous_masks, scaled_boxes = self._predict_slice(i, previous_masks, orientation="up")
if previous_masks._stats == {}:
print("Ends for no mask.")
break
self.merge_to_mask3d(i, previous_masks)
self.all_prompts[i] = scaled_boxes
if self.masks3d is None:
self.masks3d = np.zeros_like(x)
if return_stability:
return self.postprocess_3d(self.masks3d), self.stability_score_2d
return self.postprocess_3d(self.masks3d)
@ -324,7 +355,7 @@ class VolumePredictor:
scaled_boxes, tags = self.generate_prompts_from_previous_masks(previous_masks, orientation)
masks = self.genetate_masks_from_boxes(idx, all_boxes=scaled_boxes, tags=tags)
masks.to_numpy()
return masks
return masks, scaled_boxes
def generate_prompts_from_previous_masks(self, previous_masks: MaskData, orientation):
if orientation == "down":
@ -486,10 +517,6 @@ class VolumePredictor:
data["boxes"] = batched_mask_to_box(data["masks"][:,1,:,:]>0)
return data
# @staticmethod
# def calculate_
@staticmethod
def batched_remove_noise(masks):
ori_shape = masks.shape
@ -539,26 +566,80 @@ class VolumePredictor:
)
# Upscale the masks to the original image resolution
masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size[1:])
masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)
if not return_logits:
masks = masks > self.model.mask_threshold
return masks, iou_predictions, low_res_masks
def valid_box(self, data, batch_idx):
# valid image with box, or point prompt
assert data['img'].shape[0] == 1, f"shape {data['img'].shape}"
image = data['img']
label = data['label']
# def valid_box(self, data, batch_idx):
# # valid image with box, or point prompt
# assert data['img'].shape[0] == 1, f"shape {data['img'].shape}"
# image = data['img']
# label = data['label']
box = BoxPromptGenerator().mask_to_bbox(label)
box_mask3d = self.predict_volume(
x=image,
box=box,
)
dice = compute_dice_np(box_mask3d, label.detach().cpu().numpy())
# box = BoxPromptGenerator().mask_to_bbox(label)
# box_mask3d = self.predict_volume(
# x=image,
# box=box,
# )
# dice = compute_dice_np(box_mask3d, label.detach().cpu().numpy())
def get_confidence(self):
masks = self.postprocess_3d(self.masks3d)
conf_collect = []
for i in range(1,self.masks3d.shape[0]-1):
prompt_box = self.all_prompts.get(i, None)
if prompt_box is not None:
mask = masks[i,:,:]
if mask.sum() > 0:
bbox = BoxPromptGenerator(size=None).mask_to_bbox(mask)
bbox = self.transform.apply_boxes(np.array([bbox]), self.original_size)[0]
else:
bbox = [0,0,0,0]
prompt_box = self.all_prompts[i][0]
confidence = calculate_iou(bbox, prompt_box)
else:
confidence = 0
if i == 1:
conf_collect.append(confidence)
conf_collect.append(confidence)
assert len(conf_collect) == i+1
conf_collect.append(confidence)
print(conf_collect)
return conf_collect
def calculate_iou(box1, box2):
"""
计算两个框的IoU Intersection over Union
参数
box1 box2 是两个框每个框表示为四个值 (x1, y1, x2, y2)其中 (x1, y1) 是左上角的坐标
(x2, y2) 是右下角的坐标
返回
返回两个框的IoU
"""
# 计算交集的左上角和右下角坐标
x1_i = max(box1[0], box2[0])
y1_i = max(box1[1], box2[1])
x2_i = min(box1[2], box2[2])
y2_i = min(box1[3], box2[3])
# 计算交集的面积
intersection_area = max(0, x2_i - x1_i) * max(0, y2_i - y1_i)
# 计算并集的面积
box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
union_area = box1_area + box2_area - intersection_area
# print(intersection_area, union_area)
# 计算IoU
iou = intersection_area / union_area
return iou
if __name__ == "__main__":
@ -573,7 +654,7 @@ if __name__ == "__main__":
sam = sam_model_registry[model_type](checkpoint=None)
learner = SamLearner(sam_model=sam, config=config, data_engine=DataManager(img_size=(1024,1024)))
learner.use_lora()
pth = "model_iter_360000.pth"
pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b9/lora3/ckpt/model_iter_360000.pth"
learner.load_well_trained_model(pth)
learner.cuda()
@ -588,7 +669,7 @@ if __name__ == "__main__":
volume = itk_to_np(read(img_path)) # test several slices
label_itk = read(label_path)
spacing = label_itk.GetSpacing()
label = itk_to_np(label_itk) == 1
label = itk_to_np(label_itk) == 13
volume = np.clip(volume, -200, 400)
# Select the slice with the largest mask
@ -598,8 +679,14 @@ if __name__ == "__main__":
x_max = np.max(coords[0])
template_slice_id = s.argmax()
box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id])
box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id]) # (115, 207, 309, 339)
# import ipdb; ipdb.set_trace()
# box = (125, 210, 300, 310)
box = np.array([box])
box[0][0] += 10
box[0][1] += 10
box[0][2] -= 10
box[0][3] -= 10
pred = predictor.predict_volume(
x=volume,
@ -609,4 +696,9 @@ if __name__ == "__main__":
)
Data3dSolver().simple_write(pred, path="mask.nii.gz", spacing=spacing)
Data3dSolver().simple_write(label, path="gt.nii.gz", spacing=spacing)
Data3dSolver().simple_write(label, path="gt.nii.gz", spacing=spacing)
dice = compute_dice_np(pred, label)
print("Dice ", dice, " box: ", box, "slice id", template_slice_id)
print(tuple(box))
# import ipdb; ipdb.set_trace()