155 lines
7.1 KiB
Python
155 lines
7.1 KiB
Python
import torch
|
|
import torchvision
|
|
import numpy as np
|
|
from tutils.trainer import Trainer, LearnerModule
|
|
from einops import rearrange, repeat, reduce
|
|
import torch.optim.lr_scheduler as lr_scheduler
|
|
|
|
from core.loss import ranked_combined_loss_with_indicators
|
|
from .learner2 import SamLearner as basic_learner
|
|
from .loss import compute_all_loss, ranked_combined_loss, compute_iou, combined_loss
|
|
|
|
|
|
class SamLearner(basic_learner):
    """Learner that prompts the model with a point and with a box per batch.

    Both steps read the keys ``img``, ``label``, ``prompt_point``,
    ``prompt_box``, ``indicators`` and ``name`` from the batch dict, run
    ``self.predict_torch`` once with the point prompt and once with the box
    prompt, and score each prediction with
    ``ranked_combined_loss_with_indicators``.

    ``set_torch_image`` and ``predict_torch`` are not defined in this class;
    presumably they are inherited from the base learner.
    """

    @staticmethod
    def _unpack_prompt_batch(data):
        """Pull the tensors out of a batch dict and validate their shapes.

        Args:
            data: Batch dict with keys ``img``, ``label``, ``prompt_point``,
                ``prompt_box`` and ``indicators``.

        Returns:
            Tuple ``(img, gt_mask, prompt_point, point_label, prompt_box,
            indicators)`` where ``prompt_point`` is ``(b, 1, 2)``,
            ``point_label`` is an all-ones ``(b, 1)`` tensor and
            ``prompt_box`` is ``(b, 1, 4)``.

        Raises:
            AssertionError: If ``img`` is not ``(b, 3, 1024, 1024)`` or a
                prompt does not have the shape listed above.
        """
        img = data['img']
        gt_mask = data['label']
        prompt_point = data['prompt_point']  # shape: (b, 2)
        prompt_box = data['prompt_box']
        indicators = data['indicators']

        batch_size = prompt_point.shape[0]
        # NOTE(review): created on the default device, not on
        # prompt_point.device (the original left that move commented out) --
        # confirm predict_torch transfers it where needed.
        point_label = torch.ones((batch_size, 1))

        # Insert a "number of prompts" axis of length 1.
        prompt_point = rearrange(prompt_point, "b c -> b 1 c")
        prompt_box = rearrange(prompt_box, "b c -> b 1 c")

        assert img.shape[1:] == (3,1024,1024),f"{__file__} Got{img.shape}"
        assert prompt_point.shape[1:] == (1,2), f"{__file__} Got{prompt_point.shape}"
        assert point_label.shape[1:] == (1,), f"{__file__} Got{point_label.shape}"
        assert prompt_box.shape[1:] == (1,4), f"{__file__} Got{prompt_box.shape}"

        return img, gt_mask, prompt_point, point_label, prompt_box, indicators

    @staticmethod
    def _mask_iou_metrics(pred_masks, iou_predictions, gt_mask):
        """Compute the two IoU summaries reported by ``validation_step``.

        Args:
            pred_masks: Mask logits indexed as ``[b, candidate, ...]``;
                thresholded at 0 to obtain binary masks.
            iou_predictions: Per-candidate scores; ``argmax`` over axis 1
                picks the candidate the model itself ranks highest.
            gt_mask: Ground truth indexed as ``[b, candidate, ...]`` (5-D).

        Returns:
            Tuple ``(iou_selected, iou_max)``: the batch-mean IoU of the
            highest-scored candidate, and the batch-mean of the best IoU
            over all candidates.
        """
        best = iou_predictions.argmax(axis=1)
        chosen = torch.stack(
            [masks[i] for masks, i in zip(pred_masks, best)], dim=0
        )
        # `.to(dtype)` instead of `torch.tensor(tensor, dtype=...)`: same
        # values, without the copy-construct UserWarning.
        chosen_bin = (chosen > 0).to(gt_mask.dtype)
        iou_selected = compute_iou2(chosen_bin, gt_mask[:, 0, :, :, :]).detach()

        all_bin = (pred_masks > 0).to(gt_mask.dtype)
        iou_all = compute_iou(all_bin, gt_mask).detach()
        iou_best, _ = torch.max(iou_all, axis=1)

        return iou_selected.mean(), iou_best.mean()

    def training_step(self, data, batch_idx, **kwargs):
        """Run one training step with a point pass and a box pass.

        Args:
            data: Batch dict (see ``_unpack_prompt_batch``); ``name`` is
                additionally read when the loss looks invalid.
            batch_idx: Unused here.
            **kwargs: Unused here.

        Returns:
            Dict with ``loss`` (point loss + box loss), ``point_loss``,
            ``box_loss`` and the ``fl``, ``dl``, ``min_losses`` and
            ``selected_losses`` values of the box pass only (the point-pass
            values of those four are discarded, as in the original code).
        """
        (img, gt_mask, prompt_point, point_label,
         prompt_box, indicators) = self._unpack_prompt_batch(data)

        self.set_torch_image(img, img.shape[2:])

        # Pass 1: point prompt.
        pred_masks, iou_predictions, _ = self.predict_torch(
            point_coords=prompt_point,
            point_labels=point_label,
            multimask_output=True,
            return_logits=True,
        )
        loss_1, _, _, _, _, _ = ranked_combined_loss_with_indicators(
            pred_mask=pred_masks,
            gt_mask=gt_mask,
            iou_pred=iou_predictions,
            indicators=indicators,
        )

        # Pass 2: box prompt.
        pred_masks, iou_predictions, _ = self.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=prompt_box,
            multimask_output=True,
            return_logits=True,
        )
        loss_2, fl, dl, min_losses, selected_losses, _ = ranked_combined_loss_with_indicators(
            pred_mask=pred_masks,
            gt_mask=gt_mask,
            iou_pred=iou_predictions,
            indicators=indicators,
        )

        loss = loss_1 + loss_2
        # Only report suspicious losses; training continues regardless.
        if loss < -999 or torch.isnan(loss):
            print("Warning! Loss Error! ")
            print(data['name'])

        return {
            "loss": loss,
            "point_loss": loss_1.mean(),
            "box_loss": loss_2,
            "fl": fl.mean(),
            "dl": dl.mean(),
            "min_losses": min_losses,
            "selected_losses": selected_losses,
        }

    def validation_step(self, data, batch_idx=0, **kwargs):
        """Evaluate the point prompt and the box prompt on one batch.

        Prints ``data['name']`` and then computes, for each prompt type, the
        loss terms and two IoU summaries.

        Args:
            data: Batch dict (see ``_unpack_prompt_batch``) plus ``name``.
            batch_idx: Unused here.
            **kwargs: Unused here.

        Returns:
            Dict with keys ``loss_point``, ``loss_point_fl``,
            ``loss_point_dl``, ``iou_point``, ``iou_point_max``,
            ``loss_box``, ``loss_box_fl``, ``loss_box_dl``, ``iou_box`` and
            ``iou_box_max``.
        """
        print(data['name'])
        (img, gt_mask, prompt_point, point_label,
         prompt_box, indicators) = self._unpack_prompt_batch(data)

        self.set_torch_image(img, img.shape[2:])

        iou_details = {}

        # Stage 1: point prompt.
        pred_masks1, iou_predictions1, _ = self.predict_torch(
            point_coords=prompt_point,
            point_labels=point_label,
            multimask_output=True,
            return_logits=True,
        )
        if pred_masks1.ndim == 4:
            # Split the 9 channels into 3 candidates x 3 slices.
            pred_masks1 = rearrange(pred_masks1, "b (c1 c2) h w -> b c1 c2 h w", c1=3, c2=3)

        loss_point, fl, dl, _, _, _ = ranked_combined_loss_with_indicators(
            pred_mask=pred_masks1,
            gt_mask=gt_mask,
            iou_pred=iou_predictions1,
            indicators=indicators,
        )
        iou_details['loss_point'] = loss_point.mean()
        iou_details['loss_point_fl'] = fl.mean()
        iou_details['loss_point_dl'] = dl.mean()

        if gt_mask.ndim == 4:
            # Replicate the ground truth once per candidate mask.
            gt_mask = repeat(gt_mask, "b d h w -> b c d h w", c=3)

        iou_details['iou_point'], iou_details['iou_point_max'] = (
            self._mask_iou_metrics(pred_masks1, iou_predictions1, gt_mask)
        )

        # Stage 2: box prompt.
        pred_masks2, iou_predictions2, _ = self.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=prompt_box,
            multimask_output=True,
            return_logits=True,
        )
        # NOTE(review): statement order is kept exactly as in the original.
        # Here the loss receives the not-yet-rearranged prediction together
        # with the (possibly already repeated) ground truth, whereas stage 1
        # passes a rearranged prediction and the raw ground truth -- confirm
        # the loss function accepts both layouts.
        loss_box, fl, dl, _, _, _ = ranked_combined_loss_with_indicators(
            pred_mask=pred_masks2,
            gt_mask=gt_mask,
            iou_pred=iou_predictions2,
            indicators=indicators,
        )
        iou_details['loss_box'] = loss_box.mean()
        iou_details['loss_box_fl'] = fl.mean()
        iou_details['loss_box_dl'] = dl.mean()

        # gt_mask can no longer be 4-D at this point (stage 1 already expanded
        # it if it was), so the original's second expansion check was dead
        # code and is omitted.
        if pred_masks2.ndim == 4:
            pred_masks2 = rearrange(pred_masks2, "b (c1 c2) h w -> b c1 c2 h w", c1=3, c2=3)

        iou_details['iou_box'], iou_details['iou_box_max'] = (
            self._mask_iou_metrics(pred_masks2, iou_predictions2, gt_mask)
        )
        return iou_details
|
|
|
|
|
|
def compute_iou2(pred_mask, gt_mask):
    """Return the per-sample intersection-over-union of two masks.

    Both inputs are interpreted as boolean masks (non-zero means "set") via
    ``torch.logical_and`` / ``torch.logical_or``. Every axis after the first
    is summed out, so the result has one IoU value per batch element.

    Args:
        pred_mask: Predicted mask, batch axis first, at least 2-D.
        gt_mask: Ground-truth mask, broadcast-compatible with ``pred_mask``.

    Returns:
        1-D tensor of length ``b``. Its dtype is ``pred_mask.dtype`` when
        that is a floating-point type; for bool/integer masks the default
        float result is kept so the ratio is not truncated to 0 or 1.
    """
    intersection = torch.logical_and(pred_mask, gt_mask).flatten(start_dim=1).sum(dim=1)
    # The epsilon keeps the division finite when both masks are empty
    # (0 / 1e-8 == 0); it also promotes the integer count to float.
    union = torch.logical_or(pred_mask, gt_mask).flatten(start_dim=1).sum(dim=1) + 1e-8
    iou = intersection / union

    # Cast with `.to()` rather than `torch.tensor(tensor, ...)`, which
    # copy-constructs and emits a UserWarning.
    if pred_mask.dtype.is_floating_point:
        iou = iou.to(pred_mask.dtype)
    return iou
|
|
|
|
# def save(img, mask, mask2):
|