# Slide-SAM/core/learner2.py
import torch
import torchvision
import numpy as np
from tutils.trainer import Trainer, LearnerModule
from torch.utils.data import DataLoader
from torch import optim
import torch.optim.lr_scheduler as lr_scheduler
from einops import rearrange, repeat
from torch.nn import functional as F
import os
from typing import Optional, Tuple
from modeling.sam3d import Sam
# from segment_anything.utils.transforms import ResizeLongestSide
from utils.transforms import ResizeLongestSide
from .loss import compute_all_loss, ranked_combined_loss, compute_iou, combined_loss
from .lora_sam import LoRA_Sam
from safetensors import safe_open
from datasets.data_engine import DataEngine
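
# This module defines `SamLearner`, a `LearnerModule` that wraps the 3-slice `Sam`
# model from modeling.sam3d. It keeps the predictor-style interface of the original
# segment-anything repo (set_image / predict / predict_torch) and adds the Trainer
# hooks (training_step / validation_step) together with checkpoint, optimizer and
# LoRA loading helpers.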
# def lr_schedule(epoch):
#     if epoch < 250:
#         return (epoch + 1) / 250 * 0.0008 + 0.00004
#     elif epoch < 500:
#         return 0.0001
#     else:
#         return 0.0001


def lr_schedule(epoch):
    if epoch < 250:
        return (epoch + 1) / 250 * 0.1
    elif epoch < 500:
        return 0.01
    else:
        return 0.001
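
# The active schedule ramps the learning-rate multiplier linearly up to 0.1 over the
# first 250 epochs, then steps down to 0.01 and finally 0.001. It only takes effect
# if the commented-out LambdaLR line in `configure_optimizers` below is re-enabled;
# as written there, `scheduler` is None.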


class SamLearner(LearnerModule):
    def __init__(
        self,
        sam_model: Sam,
        config=None,
        logger=None,
        data_engine=DataEngine(None, img_size=(1024, 1024)),
        lora_module=None,
    ) -> None:
        """
        Uses SAM to calculate the image embedding for an image, and then
        allows repeated, efficient mask prediction given prompts.

        Arguments:
            sam_model (Sam): The model to use for mask prediction.
        """
        super().__init__()
        self.config = config
        self.logger = logger
        self.model = sam_model
        self.net = self.model
        self.transform = ResizeLongestSide(sam_model.image_encoder.img_size)
        self.reset_image()
        self.data_engine = data_engine
        self.features = None
        self.lora_module = lora_module

    def save(self, pth, *args, **kwargs):
        # Default: "/model_epoch_{}.pth".format(epoch)
        torch.save(self.net.state_dict(), pth)
        lora_path = pth.replace(".pth", "_lora.safetensors")
        self.lora_module.save_lora_parameters(lora_path)
        return True

    def load_pretrained_model(self, pth, *args, **kwargs):
        """
        Unmatched: prompt_encoder.mask_downscaling.0.weight
            their: torch.Size([4, 1, 2, 2])
            our:   torch.Size([4, 3, 2, 2])
        Unmatched: mask_decoder.mask_tokens.weight
            their: torch.Size([4, 256])
            our:   torch.Size([12, 256])
        """
        state_dict = torch.load(pth)
        model_state_dict = self.model.state_dict()
        model_state_dict.update(state_dict)
        model_state_dict['prompt_encoder.mask_downscaling.0.weight'] = repeat(
            state_dict['prompt_encoder.mask_downscaling.0.weight'], "a 1 c d -> a b c d", b=3)
        model_state_dict['mask_decoder.mask_tokens.weight'] = repeat(
            state_dict['mask_decoder.mask_tokens.weight'], "a d -> (a 3) d")
        hyper_params_names = [k for k in model_state_dict.keys()
                              if k.startswith("mask_decoder.output_hypernetworks_mlps")]
        for name in hyper_params_names:
            words = name.split('.')
            words[2] = str(int(words[2]) // 3)
            name_to_copy = ".".join(words)
            model_state_dict[name] = state_dict[name_to_copy]
        # for k, v in state_dict.items():
        #     if model_state_dict[k].shape != state_dict[k].shape:
        #         print("Unmatched:", k)
        self.model.load_state_dict(model_state_dict)
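
    # Index mapping used above when inflating the 2D SAM checkpoint to the 3-slice
    # decoder (shapes as listed in the docstring):
    #   prompt_encoder.mask_downscaling.0.weight : (4, 1, 2, 2) -> (4, 3, 2, 2)   # mask channel repeated per slice
    #   mask_decoder.mask_tokens.weight          : (4, 256)     -> (12, 256)      # each token repeated 3x
    #   mask_decoder.output_hypernetworks_mlps.i.* : copied from pretrained mlps (i // 3).*
    #       (one MLP per mask token, so i presumably runs over 0..11)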

    def load_well_trained_model(self, pth=None):
        pth = self.config['training']['breakpoint_path'] + "/ckpt_v/model_latest.pth" if pth is None else pth
        print("Loading from ", pth)
        state_dict = torch.load(pth, map_location="cpu")
        # print(state_dict.keys())
        # for k in state_dict.keys():
        #     print(k)
        # exit(0)
        self.model.load_state_dict(state_dict)
        # self.lora_module.load_lora_parameters(pth.replace(".pth", "_lora.safetensors"))

    def use_lora(self):
        lora_r = 8
        lora_sam = LoRA_Sam(self.model, lora_r, freeze_prompt_encoder=True)
        self.lora_module = lora_sam
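
    # `use_lora` wraps the current model with rank-8 LoRA adapters (prompt encoder
    # frozen) via LoRA_Sam; `save` above then writes the adapter weights to a
    # separate "*_lora.safetensors" file next to the full checkpoint.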

    def configure_optimizers(self, **kwargs):
        optimizer = optim.AdamW(
            params=self.model.parameters(),
            lr=self.config['training']['lr'], betas=(0.9, 0.999), eps=1e-08,
            weight_decay=self.config['training']['weight_decay'])
        # scheduler = lr_scheduler.LambdaLR(optimizer, lr_schedule)
        scheduler = None
        return {'optimizer': optimizer, "scheduler": scheduler}

    def load_optim(self, optimizer, pth=None, *args):
        pth = self.config['training']['breakpoint_path'] + "/ckpt/optim_latest.pth" if pth is None else pth
        print("Load Optimizer from ", pth)
        state_dict = torch.load(pth)
        optimizer.load_state_dict(state_dict['optimizer'])
        start_epoch = state_dict.get('epoch', 0) + 1
        return start_epoch

    def training_step(self, data, batch_idx, **kwargs):
        img = data['img']
        gt_mask = data['label']
        prompt_point = data['prompt_point']  # shape: (b, 2)
        batch_size = prompt_point.shape[0]
        point_label = torch.ones((batch_size, 1))  # .to(prompt_point.device)
        prompt_box = data['prompt_box']
        prompt_point = rearrange(prompt_point, "b c -> b 1 c")
        prompt_box = rearrange(prompt_box, "b c -> b 1 c")
        assert img.shape[1:] == (3, 1024, 1024), f"{__file__} Got {img.shape}"
        assert prompt_point.shape[1:] == (1, 2), f"{__file__} Got {prompt_point.shape}"
        assert point_label.shape[1:] == (1,), f"{__file__} Got {point_label.shape}"
        assert prompt_box.shape[1:] == (1, 4), f"{__file__} Got {prompt_box.shape}"
        self.set_torch_image(img, img.shape[2:])
        if np.random.random() > 0.5:
            pred_masks, iou_predictions, logits = self.predict_torch(
                point_coords=prompt_point,
                point_labels=point_label,
                multimask_output=True,
                return_logits=True,
            )
        else:
            pred_masks, iou_predictions, logits = self.predict_torch(
                point_coords=None,
                point_labels=None,
                boxes=prompt_box,
                multimask_output=True,
                return_logits=True,
            )
        # assert pred_masks.shape == gt_mask.shape, f"Got {pred_masks.shape}, {gt_mask.shape}"
        loss_1, fl, dl = ranked_combined_loss(pred_mask=pred_masks, gt_mask=gt_mask, iou_pred=iou_predictions)
        # print("Debug trainer: 2", prompt_point.shape, point_label.shape, prompt_box.shape)
        # Stage 2: based on the above, add more points as prompts
        return {"loss": loss_1, "fl": fl.mean(), "dl": dl.mean()}

    # @torch.no_grad()
    def generate(self, image, prompt_point):
        orig_size = image.shape[2:]
        assert image.shape[1:] == (3, 1024, 1024), f"{__file__} Got {image.shape}"
        if not self.is_image_set:
            self.set_torch_image(image, orig_size)
        assert prompt_point.shape[1:] == (1, 2), f"{__file__} Got {prompt_point.shape}"
        # assert point_label.shape[1:] == (1,), f"{__file__} Got {point_label.shape}"
        point_label = torch.ones(prompt_point.size()[:-1])
        pred_masks, scores, logits = self.predict_torch(
            point_coords=prompt_point,
            point_labels=point_label,
            mask_input=None,
            multimask_output=True,
        )
        return pred_masks

    # @torch.no_grad()
    def generate_by_box(self, image, prompt_box):
        orig_size = image.shape[2:]
        assert image.shape[1:] == (3, 1024, 1024), f"{__file__} Got {image.shape}"
        if not self.is_image_set:
            self.set_torch_image(image, orig_size)
        assert prompt_box.shape[1:] == (1, 4), f"{__file__} Got {prompt_box.shape}"
        pred_masks, scores, logits = self.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=prompt_box,
            mask_input=None,
            multimask_output=True,
        )
        return pred_masks

    @staticmethod
    def select_best_mask(predictions, ground_truth):
        # Move tensors to the same device (if not already on the same device)
        # if predictions.device != ground_truth.device:
        #     predictions = predictions.to(ground_truth.device)
        # Compute IoU between each prediction and ground truth
        if predictions.shape[1] == 9:
            predictions = rearrange(predictions, "b (c1 c2) h w -> b c1 c2 h w", c1=3, c2=3)
            ground_truth = repeat(ground_truth, "b d h w -> b c d h w", c=3)
        else:
            predictions = rearrange(predictions, "b d h w -> b 1 d h w")
            ground_truth = rearrange(ground_truth, "b d h w -> b 1 d h w")
        intersection = torch.sum(predictions * ground_truth, dim=(-3, -2, -1))
        union = torch.sum(predictions + ground_truth, dim=(-3, -2, -1)) - intersection
        iou = intersection / (union + 1e-6)
        # Select the prediction with maximum IoU for each image in the batch
        best_indices = torch.argmax(iou, dim=1)
        best_masks = torch.gather(
            predictions, 1,
            best_indices.unsqueeze(1).unsqueeze(2).unsqueeze(3).unsqueeze(4).repeat(
                1, 1, predictions.shape[-3], predictions.shape[-2], predictions.shape[-1]))
        return best_masks
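
    # Shape walk-through for `select_best_mask` (illustrative): with 9 prediction
    # channels the tensor is read as 3 candidate masks x 3 slices, so
    #   predictions  : (B, 9, H, W) -> (B, 3, 3, H, W)
    #   ground_truth : (B, 3, H, W) -> (B, 3, 3, H, W)   # repeated per candidate
    # IoU is reduced over the last three dims, argmax over the candidate dim picks
    # the best one, and the returned tensor has shape (B, 1, 3, H, W).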

    # ===============================================

    def predict_multi_prompt(
        self,
        point_coords: Optional[torch.Tensor],
        point_labels: Optional[torch.Tensor],
        mask_logits: Optional[torch.Tensor],
    ):
        if point_coords is not None:
            assert (
                point_labels is not None
            ), "point_labels must be supplied if point_coords is supplied."
            point_coords = self.transform.apply_coords(point_coords, self.original_size)
            coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
            labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
            coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
            points = (coords_torch, labels_torch)
        else:
            points = None
        sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
            points=points,
            boxes=None,
            masks=mask_logits,
        )
        low_res_masks, iou_predictions = self.model.mask_decoder(
            image_embeddings=self.features,
            image_pe=self.model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=False,
        )
        masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)
        return masks, iou_predictions, low_res_masks
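
    # `predict_multi_prompt` mirrors the point handling of `predict` (coordinates are
    # rescaled with ResizeLongestSide and batched), but it also feeds previous low-res
    # mask logits back through the prompt encoder and asks the decoder for a single
    # mask, which suits iterative point-refinement loops.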

    def set_image(
        self,
        image: np.ndarray,
        image_format: str = "RGB",
    ) -> None:
        # Transform the image to the form expected by the model
        input_image = self.transform.apply_image(image)
        input_image_torch = torch.as_tensor(input_image, device=self.device)
        input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
        self.set_torch_image(input_image_torch, image.shape[:2])

    # @torch.no_grad()
    def set_torch_image(
        self,
        transformed_image: torch.Tensor,
        original_image_size: Tuple[int, ...],
    ) -> None:
        """
        Calculates the image embeddings for the provided image, allowing
        masks to be predicted with the 'predict' method. Expects the input
        image to be already transformed to the format expected by the model.

        Arguments:
            transformed_image (torch.Tensor): The input image, with shape
                1x3xHxW, which has been transformed with ResizeLongestSide.
            original_image_size (tuple(int, int)): The size of the image
                before transformation, in (H, W) format.
        """
        assert (
            len(transformed_image.shape) == 4
            and transformed_image.shape[1] == 3
            and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size
        ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}."
        self.reset_image()
        self.original_size = original_image_size
        self.input_size = tuple(transformed_image.shape[-2:])
        input_image = self.model.preprocess(transformed_image)
        self.features = self.model.image_encoder(input_image)
        self.is_image_set = True

    def predict(
        self,
        point_coords: Optional[np.ndarray] = None,
        point_labels: Optional[np.ndarray] = None,
        box: Optional[np.ndarray] = None,
        mask_input: Optional[np.ndarray] = None,
        multimask_output: bool = True,
        return_logits: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        if not self.is_image_set:
            raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
        # Transform input prompts
        coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
        if point_coords is not None:
            assert (
                point_labels is not None
            ), "point_labels must be supplied if point_coords is supplied."
            point_coords = self.transform.apply_coords(point_coords, self.original_size)
            coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
            labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
            coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
        if box is not None:
            box = self.transform.apply_boxes(box, self.original_size)
            box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
            box_torch = box_torch[None, :]
        if mask_input is not None:
            mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device)
            mask_input_torch = mask_input_torch[None, :, :, :]
        masks, iou_predictions, low_res_masks = self.predict_torch(
            coords_torch,
            labels_torch,
            box_torch,
            mask_input_torch,
            multimask_output,
            return_logits=return_logits,
        )
        # masks = masks[0].detach().cpu().numpy()
        # iou_predictions = iou_predictions[0].detach().cpu().numpy()
        # low_res_masks = low_res_masks[0].detach().cpu().numpy()
        return masks, iou_predictions, low_res_masks
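
    # Typical call pattern for the numpy front-end (an illustrative sketch; `learner`,
    # `image_rgb` and the click coordinates are placeholders):
    #
    #   learner.set_image(image_rgb)                  # H x W x 3 array, RGB
    #   masks, ious, low_res = learner.predict(
    #       point_coords=np.array([[x, y]]),          # N x 2, in original-image pixels
    #       point_labels=np.array([1]),               # 1 = foreground, 0 = background
    #       multimask_output=True,
    #   )
    #   # Note: unlike the original SamPredictor, the outputs stay as torch tensors,
    #   # since the numpy conversion lines above are commented out.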

    # @torch.no_grad()
    def predict_torch(
        self,
        point_coords: Optional[torch.Tensor],
        point_labels: Optional[torch.Tensor],
        boxes: Optional[torch.Tensor] = None,
        mask_input: Optional[torch.Tensor] = None,
        multimask_output: bool = True,
        return_logits: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        if not self.is_image_set:
            raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
        if point_coords is not None:
            points = (point_coords, point_labels)
        else:
            points = None
        sparse_embeddings, dense_embeddings = self._get_prompt_embedding(points, boxes, mask_input)
        # Predict masks
        low_res_masks, iou_predictions = self.model.mask_decoder(
            image_embeddings=self.features,
            image_pe=self.model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
        )
        # Upscale the masks to the original image resolution
        masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)
        # import ipdb; ipdb.set_trace()
        if not return_logits:
            masks = masks > self.model.mask_threshold
        return masks, iou_predictions, low_res_masks

    # @torch.no_grad()
    def _get_prompt_embedding(self, points, boxes, mask_input):
        # Embed prompts
        sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
            points=points,
            boxes=boxes,
            masks=mask_input,
        )
        return sparse_embeddings, dense_embeddings

    def get_image_embedding(self) -> torch.Tensor:
        """
        Returns the image embeddings for the currently set image, with
        shape 1xCxHxW, where C is the embedding dimension and (H, W) are
        the embedding spatial dimensions of SAM (typically C=256, H=W=64).
        """
        if not self.is_image_set:
            raise RuntimeError(
                "An image must be set with .set_image(...) to generate an embedding."
            )
        assert self.features is not None, "Features must exist if an image has been set."
        return self.features

    @property
    def device(self) -> torch.device:
        return self.model.device

    def reset_image(self) -> None:
        """Resets the currently set image."""
        self.is_image_set = False
        self.features = None
        self.orig_h = None
        self.orig_w = None
        self.input_h = None
        self.input_w = None

    def validation_step(self, data, batch_idx=0, **kwargs):
        img = data['img']
        gt_mask = data['label']
        prompt_point = data['prompt_point']  # shape: (b, 2)
        batch_size = prompt_point.shape[0]
        point_label = torch.ones((batch_size, 1))  # .to(prompt_point.device)
        prompt_box = data['prompt_box']
        gt_mask = repeat(gt_mask, "b d h w -> b c d h w", c=3)
        prompt_point = rearrange(prompt_point, "b c -> b 1 c")
        prompt_box = rearrange(prompt_box, "b c -> b 1 c")
        assert img.shape[1:] == (3, 1024, 1024), f"{__file__} Got {img.shape}"
        assert prompt_point.shape[1:] == (1, 2), f"{__file__} Got {prompt_point.shape}"
        assert point_label.shape[1:] == (1,), f"{__file__} Got {point_label.shape}"
        assert prompt_box.shape[1:] == (1, 4), f"{__file__} Got {prompt_box.shape}"
        self.set_torch_image(img, img.shape[2:])
        # Stage 1: use the 1st prompt, box or point
        iou_details = {}
        pred_masks1, iou_predictions1, logits1 = self.predict_torch(
            point_coords=prompt_point,
            point_labels=point_label,
            multimask_output=True,
            return_logits=True,
        )
        pred_masks1 = rearrange(pred_masks1, "b (c1 c2) h w -> b c1 c2 h w", c1=3, c2=3)
        loss_point, fl, dl, _, _ = ranked_combined_loss(pred_mask=pred_masks1, gt_mask=gt_mask, iou_pred=iou_predictions1)
        iou_details['loss_point'] = loss_point.mean()
        iou_details['loss_point_fl'] = fl.mean()
        iou_details['loss_point_dl'] = dl.mean()
        iou = compute_iou((pred_masks1 > 0).float(), gt_mask)
        iou, _ = torch.max(iou, axis=1)
        iou_details['iou_point'] = iou.mean()
        pred_masks2, iou_predictions2, logits2 = self.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=prompt_box,
            multimask_output=True,
            return_logits=True,
        )
        pred_masks2 = rearrange(pred_masks2, "b (c1 c2) h w -> b c1 c2 h w", c1=3, c2=3)
        loss_box, fl, dl, _, _ = ranked_combined_loss(pred_mask=pred_masks2, gt_mask=gt_mask, iou_pred=iou_predictions2)
        iou_details['loss_box'] = loss_box.mean()
        iou_details['loss_box_fl'] = fl.mean()
        iou_details['loss_box_dl'] = dl.mean()
        iou = compute_iou((pred_masks2 > 0).float(), gt_mask)
        iou, _ = torch.max(iou, axis=1)
        iou_details['iou_box'] = iou.mean()
        # import ipdb; ipdb.set_trace()
        # gt_mask_np = gt_mask.detach().cpu().numpy()
        # for step in range(8):
        #     continue
        #     # n
        #     best_pred_masks = self.select_best_mask(pred_masks, gt_mask)
        #     best_pred_masks_np = best_pred_masks.detach().cpu().numpy()
        #     # import ipdb; ipdb.set_trace()
        #     mask_input = logits[0, np.argmax(scores[0].detach().cpu().numpy()), :, :]  # Choose the model's best mask
        #     sub_points, sub_labels = self.data_engine.get_subsequent_prompt_point(best_pred_masks_np, gt_mask_np)
        #     # sub_points, sub_labels = self.data_engine.point_prompt_generator.select_random_subsequent_point(best_pred_masks_np[0][0], gt_mask_np[0][0])
        #     y, x = sub_points[0][1], sub_points[0][0]
        #     assert gt_mask_np[0][0][y,x] + best_pred_masks_np[0][0][y,x] == 1, f"{__file__} Got {gt_mask_np[0][0][y,x], best_pred_masks_np[0][0][y,x]}"
        #     assert gt_mask_np[0][0][y,x] == sub_labels, f"{__file__} Got {gt_mask_np[0][0][y,x]}, {sub_labels}"
        #     assert best_pred_masks_np[0][0][y,x] == (1-sub_labels), f"{__file__} Got {gt_mask_np[0][0][y,x]}, {1-sub_labels}"
        #     # import ipdb; ipdb.set_trace()
        #     # assert sub_points
        #     # sub_points = np.array(sub_points)[None, ...].astype(int)
        #     # sub_labels = np.array(sub_labels)[None, ...]
        #     prompt_point = np.concatenate([prompt_point, sub_points], axis=0)
        #     point_label = np.concatenate([point_label, sub_labels], axis=0)
        #     # import ipdb; ipdb.set_trace()
        #     pred_masks2, scores, logits = model.predict(
        #         point_coords=prompt_point,
        #         point_labels=point_label,
        #         mask_input=mask_input[None, ...],
        #         multimask_output=False,
        #     )
        #     iou = compute_iou(pred_masks2, gt_mask)
        #     iou, _ = torch.max(iou, axis=1)
        #     iou_details[f'point_{step+2}'] = iou
        return iou_details
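

# Rough end-to-end wiring (a sketch only; `build_sam3d`, the checkpoint path and the
# config dict are placeholders for however the surrounding Slide-SAM scripts actually
# construct the Sam model and configuration):
#
#   sam = build_sam3d(...)                                    # a modeling.sam3d.Sam instance
#   learner = SamLearner(sam, config=config,
#                        data_engine=DataEngine(None, img_size=(1024, 1024)))
#   learner.use_lora()                                        # optional rank-8 adapters
#   learner.load_pretrained_model("sam_vit_b.pth")            # inflate the 2D SAM weights
#   # training is then driven by tutils.trainer.Trainer through
#   # configure_optimizers / training_step / validation_step.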