commit e04459c6fe16f19b02ad380e6066ab9998307d2d Author: transcendentsky Date: Tue Dec 5 14:58:38 2023 +0800 first commit diff --git a/configs/vit_b.yaml b/configs/vit_b.yaml new file mode 100644 index 0000000..13ce881 --- /dev/null +++ b/configs/vit_b.yaml @@ -0,0 +1,53 @@ +#### basic configs +# dataset: +# name: 'Cephalometric' +# pth: '/home1/quanquan/datasets/Cephalometric/' + + +# ---------------------- Common Configs -------------------------- +base: + base_dir: "../runs/sam/" + tag: '' + stage: '' +logger: + mode: ['tb', ] +# mode: '' + recorder_reduction: 'mean' + +training: + save_mode: ['all','best', 'latest'] # + batch_size : 8 # 20 for A100 + num_workers : 16 + num_epochs : 500 # epochs + use_amp: true + save_interval : 4 + val_check_interval: 6 + load_pretrain_model: false + + # optim: + lr: 0.0002 + decay_step: 2000 + decay_gamma: 0.8 + weight_decay: 0.0001 + alpha: 0.99 + validation_interval: 100 + +dataset: + types: ['3d'] # ['3d', '2d'] + split: 'train' + data_root_path: '/quanquan/datasets/' + dataset_list: ["alp", "word", "debug"] # ['sam', "their", "ours"] + data_txt_path: './datasets/dataset_list/' + dataset2d_path: "/home1/quanquan/datasets/08_AbdomenCT-1K/" + cache_data_path: '/home1/quanquan/datasets/cached_dataset2/' + + # sam_checkpoint: "/quanquan/code/projects/medical-guangdong/segment-anything/sam_vit_b_01ec64.pth" # 103 server + # model_type: "vit_b" + # Continue training + # continue_training: true + # load_optimizer: true + # breakpoint_path: "/quanquan/code/segment-anything/runs/sam/ddp_b1/lora_3d_2dm" + +test: + batch_size: 1 + diff --git a/core/__init__.py b/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/core/__pycache__/__init__.cpython-38.pyc b/core/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000..5318d45 Binary files /dev/null and b/core/__pycache__/__init__.cpython-38.pyc differ diff --git a/core/__pycache__/ddp.cpython-38.pyc b/core/__pycache__/ddp.cpython-38.pyc new file mode 100644 index 0000000..abfdcc4 Binary files /dev/null and b/core/__pycache__/ddp.cpython-38.pyc differ diff --git a/core/__pycache__/ddp_b10.cpython-38.pyc b/core/__pycache__/ddp_b10.cpython-38.pyc new file mode 100644 index 0000000..59d673d Binary files /dev/null and b/core/__pycache__/ddp_b10.cpython-38.pyc differ diff --git a/core/__pycache__/learner2.cpython-38.pyc b/core/__pycache__/learner2.cpython-38.pyc new file mode 100644 index 0000000..591e1fc Binary files /dev/null and b/core/__pycache__/learner2.cpython-38.pyc differ diff --git a/core/__pycache__/learner3.cpython-38.pyc b/core/__pycache__/learner3.cpython-38.pyc new file mode 100644 index 0000000..c049ba7 Binary files /dev/null and b/core/__pycache__/learner3.cpython-38.pyc differ diff --git a/core/__pycache__/learner5.cpython-38.pyc b/core/__pycache__/learner5.cpython-38.pyc new file mode 100644 index 0000000..e93f8eb Binary files /dev/null and b/core/__pycache__/learner5.cpython-38.pyc differ diff --git a/core/__pycache__/lora_sam.cpython-38.pyc b/core/__pycache__/lora_sam.cpython-38.pyc new file mode 100644 index 0000000..d3d3919 Binary files /dev/null and b/core/__pycache__/lora_sam.cpython-38.pyc differ diff --git a/core/__pycache__/loss.cpython-38.pyc b/core/__pycache__/loss.cpython-38.pyc new file mode 100644 index 0000000..26a913b Binary files /dev/null and b/core/__pycache__/loss.cpython-38.pyc differ diff --git a/core/ddp.py b/core/ddp.py new file mode 100644 index 0000000..0588528 --- /dev/null +++ b/core/ddp.py @@ -0,0 +1,142 @@ +""" + from ddp_b9.py + + Train the whole ViT +""" + +import os +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +from tutils import tfilename, tdir + +from datasets.dataset3d_2dmask import Dataset2D +# from datasets.dataset3d import Dataset3D +from datasets.cache_dataset3d3 import Dataset3D +from datasets.dataset_merged import DatasetMerged, TestsetMerged +from datasets.data_engine import DataEngine +from modeling.build_sam3d2 import sam_model_registry + +from .learner5 import SamLearner +# from tutils.new.trainer.trainer_ddp import DDPTrainer +from trans_utils.trainer_ddp import DDPTrainer +# from .lora_sam import LoRA_Sam + +import warnings +warnings.filterwarnings("ignore") + + +def setup(rank, world_size): + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '12355' + + # initialize the process group + dist.init_process_group("gloo", rank=rank, world_size=world_size) + +def cleanup(): + dist.destroy_process_group() + +def load_well_trained_model(model, lora_module, pth=None): + print("Loading from ", pth) + state_dict = torch.load(pth, map_location='cpu') + model.load_state_dict(state_dict) + +def ddp_train(rank, world_size, config): + setup(rank, world_size) + + sam_checkpoint = "/quanquan/code/segment-anything/segment_anything/sam_vit_b_01ec64.pth" # A800 server + # sam_checkpoint = "/home1/quanquan/code/projects/medical-guangdong/segment-anything/sam_vit_b_01ec64.pth" # 103 server + model_type = "vit_b" + device = rank + + config_data = config['dataset'] + data_type = config_data.get("types", ["3d", "2d"]) + data_type = [data_type] if isinstance(data_type, str) else data_type + if '2d' in data_type: + if '3d' in data_type: + dataset = DatasetMerged(config_data, is_train=True) + else: + dataset = Dataset2D(dirpath=config_data['dataset2d_path'], is_train=True) + elif '3d' in data_type: + dataset = Dataset3D(config_data, split='train') + else: + raise NotImplementedError + # assert len(validset) > 0 + data_engine = DataEngine(dataset=dataset, img_size=(1024,1024)) + sam = sam_model_registry[model_type](checkpoint=None) + + learner = SamLearner(sam_model=sam, config=config, data_engine=data_engine) + config_train = config['training'] + if config_train.get('continue_training', False): + learner.use_lora() + learner.load_well_trained_model(config['training']['breakpoint_path']) # use preset path + else: + learner.load_pretrained_model(sam_checkpoint) + learner.use_lora() + + print('Before Setting', get_parameter_number(learner.model)) + for param in learner.model.image_encoder.parameters(): + param.requires_grad = True + print('After Setting', get_parameter_number(learner.model)) + + ddp_trainer = DDPTrainer(config=config, rank=rank, world_size=world_size) + ddp_trainer.fit(learner, trainset=data_engine, validset=None) + + cleanup() + + +def get_parameter_number(model): + total_num = sum(p.numel() for p in model.parameters()) + trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad) + return {'Total': total_num, 'Trainable': trainable_num} + +def run_demo(demo_fn, world_size, config): + mp.spawn(demo_fn, + args=(world_size,config), + nprocs=world_size, + join=True) + +from collections import OrderedDict +import yaml +import yamlloader +def _ordereddict_to_dict(d): + if not isinstance(d, dict): + return d + for k, v in d.items(): + if isinstance(v, OrderedDict): + v = _ordereddict_to_dict(v) + d[k] = dict(v) + elif type(v) == list: + d[k] = _ordereddict_to_dict(v) + elif isinstance(v, dict): + d[k] = _ordereddict_to_dict(v) + return d + +# CUDA_VISIBLE_DEVICES=4,5,6,7 python -m core.ddp_b3 --tag lora --config configs/vit_b_103.yaml + +if __name__ == "__main__": + import argparse + from tutils.new.manager import trans_args, trans_init, ConfigManager + + n_gpus = torch.cuda.device_count() + # assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but {__file__} Got{n_gpus}" + if n_gpus == 1: + print("Warning! Running on only 1 GPU! just for debug") + world_size = n_gpus + + parser = argparse.ArgumentParser() + parser.add_argument("--config", default="./configs/vit_b.yaml") + parser.add_argument("--func", default="train") + + args = trans_args(parser=parser) + config = ConfigManager() + config.auto_init(file=__file__, args=args, ex_config=None) + # config.save() + path = tfilename(config['base']['runs_dir'], "config.yaml") + with open(path, "w") as f: + yaml.dump(_ordereddict_to_dict(config), f) + print("Save config file to ", path) + + if n_gpus < 1: exit(0) + run_demo(ddp_train, world_size, config) diff --git a/core/learner2.py b/core/learner2.py new file mode 100644 index 0000000..d5475f7 --- /dev/null +++ b/core/learner2.py @@ -0,0 +1,522 @@ +import torch +import torchvision +import numpy as np +from tutils.trainer import Trainer, LearnerModule +from torch.utils.data import DataLoader +from torch import optim +import torch.optim.lr_scheduler as lr_scheduler +from einops import rearrange, repeat +from torch.nn import functional as F +import os +from typing import Optional, Tuple +import torch.optim.lr_scheduler as lr_scheduler + +from modeling.sam3d import Sam +# from segment_anything.utils.transforms import ResizeLongestSide +from utils.transforms import ResizeLongestSide +from .loss import compute_all_loss, ranked_combined_loss, compute_iou, combined_loss +from .lora_sam import LoRA_Sam +from safetensors import safe_open +from datasets.data_engine import DataEngine + + +# def lr_schedule(epoch): +# if epoch < 250: +# return (epoch + 1) / 250 * 0.0008 + 0.00004 +# elif epoch < 500: +# return 0.0001 +# else: +# return 0.0001 + +def lr_schedule(epoch): + if epoch < 250: + return (epoch + 1) / 250 * 0.1 + elif epoch < 500: + return 0.01 + else: + return 0.001 + +class SamLearner(LearnerModule): + def __init__( + self, + sam_model: Sam, + config=None, + logger=None, + data_engine=DataEngine(None, img_size=(1024,1024)), + lora_module=None, + ) -> None: + """ + Uses SAM to calculate the image embedding for an image, and then + allow repeated, efficient mask prediction given prompts. + + Arguments: + sam_model (Sam): The model to use for mask prediction. + """ + super().__init__() + self.config = config + self.logger = logger + self.model = sam_model + self.net = self.model + self.transform = ResizeLongestSide(sam_model.image_encoder.img_size) + self.reset_image() + self.data_engine = data_engine + self.features = None + self.lora_module = lora_module + + def save(self, pth, *args, **kwargs): + # Default: "/model_epoch_{}.pth".format(epoch) + torch.save(self.net.state_dict(), pth) + lora_path = pth.replace(".pth", "_lora.safetensors") + self.lora_module.save_lora_parameters(lora_path) + return True + + def load_pretrained_model(self, pth, *args, **kwargs): + """ + Unmatched: prompt_encoder.mask_downscaling.0.weight + their: torch.Size([4, 1, 2, 2]) + our: torch.Size([4, 3, 2, 2]) + Unmatched: mask_decoder.mask_tokens.weight + their: torch.Size([4, 256]) + our: torch.Size([12, 256]) + + """ + state_dict = torch.load(pth) + model_state_dict = self.model.state_dict() + model_state_dict.update(state_dict) + model_state_dict['prompt_encoder.mask_downscaling.0.weight'] = repeat(state_dict['prompt_encoder.mask_downscaling.0.weight'], "a 1 c d -> a b c d", b=3) + model_state_dict['mask_decoder.mask_tokens.weight'] = repeat(state_dict['mask_decoder.mask_tokens.weight'], "a d -> (a 3) d") + hyper_params_names = [k for k in model_state_dict.keys() if k.startswith("mask_decoder.output_hypernetworks_mlps")] + for name in hyper_params_names: + words = name.split('.') + words[2] = str(int(words[2]) // 3) + name_to_copy = ".".join(words) + model_state_dict[name] = state_dict[name_to_copy] + # for k, v in state_dict.items(): + # if model_state_dict[k].shape != state_dict[k].shape: + # print("Unmatched:", k) + self.model.load_state_dict(model_state_dict) + + def load_well_trained_model(self, pth=None): + pth = self.config['training']['breakpoint_path'] + "/ckpt_v/model_latest.pth" if pth is None else pth + print("Loading from ", pth) + state_dict = torch.load(pth, map_location="cpu") + # print(state_dict.keys()) + # for k in state_dict.keys(): + # print(k) + # exit(0) + self.model.load_state_dict(state_dict) + # self.lora_module.load_lora_parameters(pth.replace(".pth", "_lora.safetensors")) + + def use_lora(self): + lora_r = 8 + lora_sam = LoRA_Sam(self.model, lora_r, freeze_prompt_encoder=True) + self.lora_module = lora_sam + + def configure_optimizers(self, **kwargs): + optimizer = optim.AdamW(params=self.model.parameters(), \ + lr=self.config['training']['lr'], betas=(0.9, 0.999), eps=1e-08, + weight_decay=self.config['training']['weight_decay']) + # scheduler = lr_scheduler.LambdaLR(optimizer, lr_schedule) + scheduler = None + return {'optimizer': optimizer, "scheduler": scheduler} + + def load_optim(self, optimizer, pth=None, *args): + pth = self.config['training']['breakpoint_path'] + "/ckpt/optim_latest.pth" + print("Load Optimizer from ", pth) + state_dict = torch.load(pth) + optimizer.load_state_dict(state_dict['optimizer']) + start_epoch = state_dict.get('epoch', 0) + 1 + return start_epoch + + def training_step(self, data, batch_idx, **kwargs): + img = data['img'] + gt_mask = data['label'] + prompt_point = data['prompt_point'] # shape: (b, 2) + batch_size = prompt_point.shape[0] + point_label = torch.ones((batch_size, 1)) #.to(prompt_point.device) + prompt_box = data['prompt_box'] + + prompt_point = rearrange(prompt_point, "b c -> b 1 c") + prompt_box = rearrange(prompt_box, "b c -> b 1 c") + assert img.shape[1:] == (3,1024,1024),f"{__file__} Got{img.shape}" + assert prompt_point.shape[1:] == (1,2), f"{__file__} Got{prompt_point.shape}" + assert point_label.shape[1:] == (1,), f"{__file__} Got{point_label.shape}" + assert prompt_box.shape[1:] == (1,4), f"{__file__} Got{prompt_box.shape}" + + self.set_torch_image(img, img.shape[2:]) + if np.random.random() > 0.5: + pred_masks, iou_predictions, logits = self.predict_torch( + point_coords=prompt_point, + point_labels=point_label, + multimask_output=True, + return_logits=True, + ) + else: + pred_masks, iou_predictions, logits = self.predict_torch( + point_coords=None, + point_labels=None, + boxes=prompt_box, + multimask_output=True, + return_logits=True, + ) + # assert pred_masks.shape == gt_mask.shape, f"Got {pred_masks.shape}, {gt_mask.shape}" + loss_1, fl, dl = ranked_combined_loss(pred_mask=pred_masks, gt_mask=gt_mask, iou_pred=iou_predictions) + + # print("Debug trainer: 2", prompt_point.shape, point_label.shape, prompt_box.shape) + # Stage 2: based on the above, add more points as prompts + return {"loss": loss_1, "fl": fl.mean(), "dl": dl.mean()} + + # @torch.no_grad() + def generate(self, image, prompt_point): + orig_size = image.shape[2:] + assert image.shape[1:] == (3,1024,1024),f"{__file__} Got{image.shape}" + if not self.is_image_set: + self.set_torch_image(image, orig_size) + + assert prompt_point.shape[1:] == (1,2), f"{__file__} Got{prompt_point.shape}" + # assert point_label.shape[1:] == (1,), f"{__file__} Got{point_label.shape}" + point_label = torch.ones(prompt_point.size()[:-1]) + pred_masks, scores, logits = self.predict_torch( + point_coords=prompt_point, + point_labels=point_label, + mask_input=None, + multimask_output=True, + ) + return pred_masks + + # @torch.no_grad() + def generate_by_box(self, image, prompt_box): + orig_size = image.shape[2:] + assert image.shape[1:] == (3,1024,1024),f"{__file__} Got{image.shape}" + if not self.is_image_set: + self.set_torch_image(image, orig_size) + + assert prompt_box.shape[1:] == (1,4), f"{__file__} Got{prompt_box.shape}" + pred_masks, scores, logits = self.predict_torch( + point_coords=None, + point_labels=None, + boxes=prompt_box, + mask_input=None, + multimask_output=True, + ) + return pred_masks + + @staticmethod + def select_best_mask(predictions, ground_truth): + # Move tensors to the same device (if not already on the same device) + # if predictions.device != ground_truth.device: + # predictions = predictions.to(ground_truth.device) + + # Compute IoU between each prediction and ground truth + if predictions.shape[1] == 9: + predictions = rearrange(predictions, "b (c1 c2) h w -> b c1 c2 h w", c1=3, c2=3) + ground_truth = repeat(ground_truth, "b d h w -> b c d h w", c=3) + else: + predictions = rearrange(predictions, "b d h w -> b 1 d h w") + ground_truth = rearrange(ground_truth, "b d h w -> b 1 d h w") + intersection = torch.sum(predictions * ground_truth, dim=(-3, -2, -1)) + union = torch.sum(predictions + ground_truth, dim=(-3, -2, -1)) - intersection + iou = intersection / (union + 1e-6) + + # Select the prediction with maximum IoU for each image in the batch + best_indices = torch.argmax(iou, dim=1) + best_masks = torch.gather(predictions, 1, best_indices.unsqueeze(1).unsqueeze(2).unsqueeze(3).unsqueeze(4).repeat(1, 1, predictions.shape[-3], predictions.shape[-2], predictions.shape[-1])) + + return best_masks + + # =============================================== + def predict_multi_prompt( + self, + point_coords: Optional[torch.Tensor], + point_labels: Optional[torch.Tensor], + mask_logits: Optional[torch.Tensor], + ): + if point_coords is not None: + assert ( + point_labels is not None + ), "point_labels must be supplied if point_coords is supplied." + point_coords = self.transform.apply_coords(point_coords, self.original_size) + coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device) + labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) + coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] + + if point_coords is not None: + points = (coords_torch, point_labels) + else: + points = None + + sparse_embeddings, dense_embeddings = self.model.prompt_encoder( + points=points, + boxes=None, + masks=mask_logits, + ) + + low_res_masks, iou_predictions = self.model.mask_decoder( + image_embeddings=self.features, + image_pe=self.model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=False, + ) + masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size) + return masks, iou_predictions, low_res_masks + + def set_image( + self, + image: np.ndarray, + image_format: str = "RGB", + ) -> None: + # Transform the image to the form expected by the model + input_image = self.transform.apply_image(image) + input_image_torch = torch.as_tensor(input_image, device=self.device) + input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :] + + self.set_torch_image(input_image_torch, image.shape[:2]) + + # @torch.no_grad() + def set_torch_image( + self, + transformed_image: torch.Tensor, + original_image_size: Tuple[int, ...], + ) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. Expects the input + image to be already transformed to the format expected by the model. + + Arguments: + transformed_image (torch.Tensor): The input image, with shape + 1x3xHxW, which has been transformed with ResizeLongestSide. + original_image_size (tuple(int, int)): The size of the image + before transformation, in (H, W) format. + """ + assert ( + len(transformed_image.shape) == 4 + and transformed_image.shape[1] == 3 + and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size + ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}." + self.reset_image() + + self.original_size = original_image_size + self.input_size = tuple(transformed_image.shape[-2:]) + input_image = self.model.preprocess(transformed_image) + self.features = self.model.image_encoder(input_image) + self.is_image_set = True + + def predict( + self, + point_coords: Optional[np.ndarray] = None, + point_labels: Optional[np.ndarray] = None, + box: Optional[np.ndarray] = None, + mask_input: Optional[np.ndarray] = None, + multimask_output: bool = True, + return_logits: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if not self.is_image_set: + raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") + + # Transform input prompts + coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None + if point_coords is not None: + assert ( + point_labels is not None + ), "point_labels must be supplied if point_coords is supplied." + point_coords = self.transform.apply_coords(point_coords, self.original_size) + coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device) + labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) + coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] + if box is not None: + box = self.transform.apply_boxes(box, self.original_size) + box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) + box_torch = box_torch[None, :] + if mask_input is not None: + mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device) + mask_input_torch = mask_input_torch[None, :, :, :] + + masks, iou_predictions, low_res_masks = self.predict_torch( + coords_torch, + labels_torch, + box_torch, + mask_input_torch, + multimask_output, + return_logits=return_logits, + ) + + # masks = masks[0].detach().cpu().numpy() + # iou_predictions = iou_predictions[0].detach().cpu().numpy() + # low_res_masks = low_res_masks[0].detach().cpu().numpy() + return masks, iou_predictions, low_res_masks + + # @torch.no_grad() + def predict_torch( + self, + point_coords: Optional[torch.Tensor], + point_labels: Optional[torch.Tensor], + boxes: Optional[torch.Tensor] = None, + mask_input: Optional[torch.Tensor] = None, + multimask_output: bool = True, + return_logits: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if not self.is_image_set: + raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") + + if point_coords is not None: + points = (point_coords, point_labels) + else: + points = None + + sparse_embeddings, dense_embeddings = self._get_prompt_embedding(points, boxes, mask_input) + + # Predict masks + low_res_masks, iou_predictions = self.model.mask_decoder( + image_embeddings=self.features, + image_pe=self.model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + + # Upscale the masks to the original image resolution + masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size) + # import ipdb; ipdb.set_trace() + + if not return_logits: + masks = masks > self.model.mask_threshold + + return masks, iou_predictions, low_res_masks + + # @torch.no_grad() + def _get_prompt_embedding(self, points, boxes, mask_input): + # Embed prompts + sparse_embeddings, dense_embeddings = self.model.prompt_encoder( + points=points, + boxes=boxes, + masks=mask_input, + ) + return sparse_embeddings, dense_embeddings + + + def get_image_embedding(self) -> torch.Tensor: + """ + Returns the image embeddings for the currently set image, with + shape 1xCxHxW, where C is the embedding dimension and (H,W) are + the embedding spatial dimension of SAM (typically C=256, H=W=64). + """ + if not self.is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) to generate an embedding." + ) + assert self.features is not None, "Features must exist if an image has been set." + return self.features + + @property + def device(self) -> torch.device: + return self.model.device + + def reset_image(self) -> None: + """Resets the currently set image.""" + self.is_image_set = False + self.features = None + self.orig_h = None + self.orig_w = None + self.input_h = None + self.input_w = None + + + def validation_step(self, data, batch_idx=0, **kwargs): + img = data['img'] + gt_mask = data['label'] + prompt_point = data['prompt_point'] # shape: (b, 2) + batch_size = prompt_point.shape[0] + point_label = torch.ones((batch_size, 1)) #.to(prompt_point.device) + prompt_box = data['prompt_box'] + gt_mask = repeat(gt_mask, "b d h w -> b c d h w", c=3) + + prompt_point = rearrange(prompt_point, "b c -> b 1 c") + prompt_box = rearrange(prompt_box, "b c -> b 1 c") + assert img.shape[1:] == (3,1024,1024),f"{__file__} Got{img.shape}" + assert prompt_point.shape[1:] == (1,2), f"{__file__} Got{prompt_point.shape}" + assert point_label.shape[1:] == (1,), f"{__file__} Got{point_label.shape}" + assert prompt_box.shape[1:] == (1,4), f"{__file__} Got{prompt_box.shape}" + + self.set_torch_image(img, img.shape[2:]) + + # Stage 1: use the 1st prompt, box or point + iou_details = {} + pred_masks1, iou_predictions1, logits1 = self.predict_torch( + point_coords=prompt_point, + point_labels=point_label, + multimask_output=True, + return_logits=True, + ) + + pred_masks1 = rearrange(pred_masks1, "b (c1 c2) h w -> b c1 c2 h w", c1=3, c2=3) + loss_point, fl, dl, _, _ = ranked_combined_loss(pred_mask=pred_masks1, gt_mask=gt_mask, iou_pred=iou_predictions1) + iou_details['loss_point'] = loss_point.mean() + iou_details['loss_point_fl'] = fl.mean() + iou_details['loss_point_dl'] = dl.mean() + + iou = compute_iou((pred_masks1>0).float(), gt_mask) + iou, _ = torch.max(iou, axis=1) + iou_details['iou_point'] = iou.mean() + + pred_masks2, iou_predictions2, logits2 = self.predict_torch( + point_coords=None, + point_labels=None, + boxes=prompt_box, + multimask_output=True, + return_logits=True, + ) + pred_masks2 = rearrange(pred_masks2, "b (c1 c2) h w -> b c1 c2 h w", c1=3, c2=3) + loss_box, fl, dl, _, _ = ranked_combined_loss(pred_mask=pred_masks2, gt_mask=gt_mask, iou_pred=iou_predictions2) + iou_details['loss_box'] = loss_box.mean() + iou_details['loss_box_fl'] = fl.mean() + iou_details['loss_box_dl'] = dl.mean() + + iou = compute_iou((pred_masks2>0).float(), gt_mask) + iou, _ = torch.max(iou, axis=1) + iou_details['iou_box'] = iou.mean() + + # import ipdb; ipdb.set_trace() + + + # gt_mask_np = gt_mask.detach().cpu().numpy() + # for step in range(8): + # continue + # # n + # best_pred_masks = self.select_best_mask(pred_masks, gt_mask) + # best_pred_masks_np = best_pred_masks.detach().cpu().numpy() + + # # import ipdb; ipdb.set_trace() + # mask_input = logits[0, np.argmax(scores[0].detach().cpu().numpy()), :, :] # Choose the model's best mask + + # sub_points, sub_labels = self.data_engine.get_subsequent_prompt_point(best_pred_masks_np, gt_mask_np) + # # sub_points, sub_labels = self.data_engine.point_prompt_generator.select_random_subsequent_point(best_pred_masks_np[0][0], gt_mask_np[0][0]) + + # y, x = sub_points[0][1], sub_points[0][0] + # assert gt_mask_np[0][0][y,x] + best_pred_masks_np[0][0][y,x] == 1, f"{__file__} Got{gt_mask_np[0][0][y,x], best_pred_masks_np[0][0][y,x]}" + # assert gt_mask_np[0][0][y,x] == sub_labels, f"{__file__} Got{ gt_mask_np[0][0][y,x]}, {sub_labels}" + # assert best_pred_masks_np[0][0][y,x] == (1-sub_labels), f"{__file__} Got{ gt_mask_np[0][0][y,x]}, {1-sub_labels}" + # # import ipdb; ipdb.set_trace() + # # assert sub_points + + # # sub_points = np.array(sub_points)[None,...].astype(int) + # # sub_labels = np.array(sub_labels)[None,...] + # prompt_point = np.concatenate([prompt_point, sub_points], axis=0) + # point_label = np.concatenate([point_label, sub_labels], axis=0) + + # # import ipdb; ipdb.set_trace() + + # pred_masks2, scores, logits = model.predict( + # point_coords=prompt_point, + # point_labels=point_label, + # mask_input=mask_input[None,...], + # multimask_output=False, + # ) + + # iou = compute_iou(pred_masks2, gt_mask) + # iou, _ = torch.max(iou, axis=1) + # iou_details[f'point_{step+2}'] = iou + + return iou_details + \ No newline at end of file diff --git a/core/learner3.py b/core/learner3.py new file mode 100644 index 0000000..023c524 --- /dev/null +++ b/core/learner3.py @@ -0,0 +1,154 @@ +import torch +import torchvision +import numpy as np +from tutils.trainer import Trainer, LearnerModule +from einops import rearrange, repeat, reduce +import torch.optim.lr_scheduler as lr_scheduler + +from core.loss import ranked_combined_loss_with_indicators +from .learner2 import SamLearner as basic_learner +from .loss import compute_all_loss, ranked_combined_loss, compute_iou, combined_loss + + +class SamLearner(basic_learner): + + def training_step(self, data, batch_idx, **kwargs): + img = data['img'] + gt_mask = data['label'] + prompt_point = data['prompt_point'] # shape: (b, 2) + batch_size = prompt_point.shape[0] + point_label = torch.ones((batch_size, 1)) #.to(prompt_point.device) + prompt_box = data['prompt_box'] + indicators = data['indicators'] + # print(data['name']) + + prompt_point = rearrange(prompt_point, "b c -> b 1 c") + prompt_box = rearrange(prompt_box, "b c -> b 1 c") + assert img.shape[1:] == (3,1024,1024),f"{__file__} Got{img.shape}" + assert prompt_point.shape[1:] == (1,2), f"{__file__} Got{prompt_point.shape}" + assert point_label.shape[1:] == (1,), f"{__file__} Got{point_label.shape}" + assert prompt_box.shape[1:] == (1,4), f"{__file__} Got{prompt_box.shape}" + + self.set_torch_image(img, img.shape[2:]) + # if np.random.random() > 0.5: + pred_masks, iou_predictions, logits = self.predict_torch( + point_coords=prompt_point, + point_labels=point_label, + multimask_output=True, + return_logits=True, + ) + loss_1, fl, dl, min_losses, selected_losses, _ = ranked_combined_loss_with_indicators(pred_mask=pred_masks, gt_mask=gt_mask, iou_pred=iou_predictions, indicators=indicators) + # else: + pred_masks, iou_predictions, logits = self.predict_torch( + point_coords=None, + point_labels=None, + boxes=prompt_box, + multimask_output=True, + return_logits=True, + ) + # assert pred_masks.shape == gt_mask.shape, f"Got {pred_masks.shape}, {gt_mask.shape}" + loss_2, fl, dl, min_losses, selected_losses, _ = ranked_combined_loss_with_indicators(pred_mask=pred_masks, gt_mask=gt_mask, iou_pred=iou_predictions, indicators=indicators) + + loss = loss_1 + loss_2 + if loss < -999 or torch.isnan(loss): + print("Warning! Loss Error! ") + print(data['name']) + + # print("Debug trainer: 2", prompt_point.shape, point_label.shape, prompt_box.shape) + # Stage 2: based on the above, add more points as prompts + return {"loss": loss_1 + loss_2, "point_loss": loss_1.mean(), "box_loss": loss_2, "fl": fl.mean(), "dl": dl.mean(), "min_losses": min_losses, "selected_losses": selected_losses} + + + def validation_step(self, data, batch_idx=0, **kwargs): + img = data['img'] + gt_mask = data['label'] + prompt_point = data['prompt_point'] # shape: (b, 2) + batch_size = prompt_point.shape[0] + point_label = torch.ones((batch_size, 1)) #.to(prompt_point.device) + prompt_box = data['prompt_box'] + indicators = data['indicators'] + print(data['name']) + prompt_point = rearrange(prompt_point, "b c -> b 1 c") + prompt_box = rearrange(prompt_box, "b c -> b 1 c") + assert img.shape[1:] == (3,1024,1024),f"{__file__} Got{img.shape}" + assert prompt_point.shape[1:] == (1,2), f"{__file__} Got{prompt_point.shape}" + assert point_label.shape[1:] == (1,), f"{__file__} Got{point_label.shape}" + assert prompt_box.shape[1:] == (1,4), f"{__file__} Got{prompt_box.shape}" + + self.set_torch_image(img, img.shape[2:]) + + # Stage 1: use the 1st prompt, box or point + iou_details = {} + pred_masks1, iou_predictions1, logits1 = self.predict_torch( + point_coords=prompt_point, + point_labels=point_label, + multimask_output=True, + return_logits=True, + ) + + if len(pred_masks1.shape) == 4: + pred_masks1 = rearrange(pred_masks1, "b (c1 c2) h w -> b c1 c2 h w", c1=3, c2=3) + + loss_point, fl, dl, min_losses, selected_losses, min_indices = ranked_combined_loss_with_indicators(pred_mask=pred_masks1, gt_mask=gt_mask, iou_pred=iou_predictions1, indicators=indicators) + iou_details['loss_point'] = loss_point.mean() + iou_details['loss_point_fl'] = fl.mean() + iou_details['loss_point_dl'] = dl.mean() + + if len(gt_mask.shape) == 4: + gt_mask = repeat(gt_mask, "b d h w -> b c d h w", c=3) + + indices = iou_predictions1.argmax(axis=1) + pred_maxiou = [] + for pred, i in zip(pred_masks1, indices): + pred_maxiou.append(pred[i,:,:,:]) + pred_maxiou = torch.stack(pred_maxiou, axis=0) + iou = compute_iou2(torch.tensor(pred_maxiou>0, dtype=gt_mask.dtype), gt_mask[:,0,:,:,:]).detach() + iou_details['iou_point'] = iou.mean() + + iou = compute_iou(torch.tensor(pred_masks1>0, dtype=gt_mask.dtype), gt_mask).detach() + iou, _ = torch.max(iou, axis=1) + iou_details['iou_point_max'] = iou.mean() + + pred_masks2, iou_predictions2, logits2 = self.predict_torch( + point_coords=None, + point_labels=None, + boxes=prompt_box, + multimask_output=True, + return_logits=True, + ) + loss_box, fl, dl, min_losses, selected_losses, min_indices = ranked_combined_loss_with_indicators(pred_mask=pred_masks2, gt_mask=gt_mask, iou_pred=iou_predictions2, indicators=indicators) + iou_details['loss_box'] = loss_box.mean() + iou_details['loss_box_fl'] = fl.mean() + iou_details['loss_box_dl'] = dl.mean() + + if len(gt_mask.shape) == 4: + gt_mask = repeat(gt_mask, "b d h w -> b c d h w", c=3) + if len(pred_masks2.shape) == 4: + pred_masks2 = rearrange(pred_masks2, "b (c1 c2) h w -> b c1 c2 h w", c1=3, c2=3) + + indices = iou_predictions2.argmax(axis=1) + pred_maxiou = [] + for pred, i in zip(pred_masks2, indices): + pred_maxiou.append(pred[i,:,:,:]) + pred_maxiou = torch.stack(pred_maxiou, axis=0) + iou = compute_iou2(torch.tensor(pred_maxiou>0, dtype=gt_mask.dtype), gt_mask[:,0,:,:,:]).detach() + iou_details['iou_box'] = iou.mean() + + iou = compute_iou(torch.tensor(pred_masks2>0, dtype=gt_mask.dtype), gt_mask).detach() + iou, _ = torch.max(iou, axis=1) + iou_details['iou_box_max'] = iou.mean() + return iou_details + + +def compute_iou2(pred_mask, gt_mask): + dtype = pred_mask.dtype + intersection = torch.logical_and(pred_mask, gt_mask) + intersection = reduce(intersection, "b d h w -> b", reduction='sum') + union = torch.logical_or(pred_mask, gt_mask) + union = reduce(union, "b d h w -> b", reduction='sum') + 1e-8 + iou = intersection / union # if union > 0 else 0 + iou = torch.tensor(iou, dtype=dtype) + # print("ranked_combined_loss: compute_iou ", intersection.dtype, union.dtype, iou.dtype) + return iou + +# def save(img, mask, mask2): diff --git a/core/learner5.py b/core/learner5.py new file mode 100644 index 0000000..5eb8201 --- /dev/null +++ b/core/learner5.py @@ -0,0 +1,55 @@ +""" + Use mask_decoder3d_2.py +""" + +import torch +import torchvision +import numpy as np +from tutils.trainer import Trainer, LearnerModule +from einops import rearrange, repeat, reduce +import torch.optim.lr_scheduler as lr_scheduler + +from core.loss import ranked_combined_loss_with_indicators +from .learner3 import SamLearner as basic_learner +from .loss import compute_all_loss, ranked_combined_loss, compute_iou, combined_loss + + +class SamLearner(basic_learner): + + def load_pretrained_model(self, pth, *args, **kwargs): + """ + Unmatched: prompt_encoder.mask_downscaling.0.weight + their: torch.Size([4, 1, 2, 2]) + our: torch.Size([4, 3, 2, 2]) + Unmatched: mask_decoder.mask_tokens.weight + their: torch.Size([4, 256]) + our: torch.Size([12, 256]) + """ + print("Load pretrained model for mask_decoder3d_2 !!") + + state_dict = torch.load(pth) + model_state_dict = self.model.state_dict() + model_state_dict.update(state_dict) + model_state_dict['prompt_encoder.mask_downscaling.0.weight'] = repeat(state_dict['prompt_encoder.mask_downscaling.0.weight'], "a 1 c d -> a b c d", b=3) + # model_state_dict['mask_decoder.mask_tokens.weight'] = repeat(state_dict['mask_decoder.mask_tokens.weight'], "a d -> (a 3) d") + + for k, v in model_state_dict.items(): + if k.startswith("mask_decoder.output_upscaling2"): + k2 = k.replace("output_upscaling2.", "output_upscaling." ) + model_state_dict[k] = model_state_dict[k2] + print("Load weights: ", k) + if k.startswith("mask_decoder.output_upscaling3"): + k2 = k.replace("output_upscaling3.", "output_upscaling." ) + model_state_dict[k] = model_state_dict[k2] + print("Load weights: ", k) + + hyper_params_names = [k for k in model_state_dict.keys() if k.startswith("mask_decoder.output_hypernetworks_mlps")] + for name in hyper_params_names: + words = name.split('.') + words[2] = str(int(words[2]) // 3) + name_to_copy = ".".join(words) + model_state_dict[name] = state_dict[name_to_copy] + # for k, v in state_dict.items(): + # if model_state_dict[k].shape != state_dict[k].shape: + # print("Unmatched:", k) + self.model.load_state_dict(model_state_dict) \ No newline at end of file diff --git a/core/lora_sam.py b/core/lora_sam.py new file mode 100644 index 0000000..81f8595 --- /dev/null +++ b/core/lora_sam.py @@ -0,0 +1,196 @@ +# Modified from https://github.com/JamesQFreeman/Sam_LoRA + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.nn.parameter import Parameter +from safetensors import safe_open +from safetensors.torch import save_file +# from modeling.sam3d import Sam + + +class _LoRA_qkv(nn.Module): + """In Sam it is implemented as + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + """ + + def __init__( + self, + qkv: nn.Module, + linear_a_q: nn.Module, + linear_b_q: nn.Module, + linear_a_v: nn.Module, + linear_b_v: nn.Module, + ): + super().__init__() + self.qkv = qkv + self.linear_a_q = linear_a_q + self.linear_b_q = linear_b_q + self.linear_a_v = linear_a_v + self.linear_b_v = linear_b_v + self.dim = qkv.in_features + self.w_identity = torch.eye(qkv.in_features) + + def forward(self, x): + qkv = self.qkv(x) # B,N,N,3*org_C + new_q = self.linear_b_q(self.linear_a_q(x)) + new_v = self.linear_b_v(self.linear_a_v(x)) + qkv[:, :, :, : self.dim] += new_q + qkv[:, :, :, -self.dim :] += new_v + return qkv + +class LoRA_Sam(nn.Module): + """Applies low-rank adaptation to a Sam model's image encoder. + + Args: + sam_model: a vision transformer model, see base_vit.py + r: rank of LoRA + num_classes: how many classes the model output, default to the vit model + lora_layer: which layer we apply LoRA. + freeze_all: freeze whole sam, otherwise only image encoder (VIT) + + Examples:: + >>> model = ViT('B_16_imagenet1k') + >>> lora_model = LoRA_ViT(model, r=4) + >>> preds = lora_model(img) + >>> print(preds.shape) + torch.Size([1, 1000]) + """ + + def __init__(self, sam_model, r: int, lora_layer:[int]=None, freeze_all:bool=False, freeze_prompt_encoder=True): + super(LoRA_Sam, self).__init__() + + assert r > 0 + # base_vit_dim = sam_model.image_encoder.patch_embed.proj.out_channels + # dim = base_vit_dim + if lora_layer: + self.lora_layer = lora_layer + else: + self.lora_layer = list(range(len(sam_model.image_encoder.blocks))) + # create for storage, then we can init them or load weights + self.w_As = [] # These are linear layers + self.w_Bs = [] + + # lets freeze first + if freeze_all: + for param in sam_model.parameters(): + param.requires_grad = False + else: + for param in sam_model.image_encoder.parameters(): + param.requires_grad = False + for param in sam_model.image_encoder.patch_embed.parameters(): + param.requires_grad = True + if freeze_prompt_encoder: + for param in sam_model.prompt_encoder.parameters(): + param.requires_grad = False + + # Here, we do the surgery + for t_layer_i, blk in enumerate(sam_model.image_encoder.blocks): + # If we only want few lora layer instead of all + if t_layer_i not in self.lora_layer: + continue + w_qkv_linear = blk.attn.qkv + self.dim = w_qkv_linear.in_features + w_a_linear_q = nn.Linear(self.dim, r, bias=False) + w_b_linear_q = nn.Linear(r, self.dim, bias=False) + w_a_linear_v = nn.Linear(self.dim, r, bias=False) + w_b_linear_v = nn.Linear(r, self.dim, bias=False) + self.w_As.append(w_a_linear_q) + self.w_Bs.append(w_b_linear_q) + self.w_As.append(w_a_linear_v) + self.w_Bs.append(w_b_linear_v) + blk.attn.qkv = _LoRA_qkv( + w_qkv_linear, + w_a_linear_q, + w_b_linear_q, + w_a_linear_v, + w_b_linear_v, + ) + self.reset_parameters() + self.sam = sam_model + # with open('vit_named_para.txt', 'w') as f: + # for k, v in sam_model.image_encoder.named_parameters(): + # f.write(f'{k} {v.shape}\n') + + + def save_lora_parameters(self, filename: str) -> None: + r"""Only safetensors is supported now. + + pip install safetensor if you do not have one installed yet. + + save both lora and fc parameters. + """ + # assert filename.endswith(".safetensors") + + num_layer = len(self.w_As) # actually, it is half + a_tensors = {f"w_a_{i:03d}": self.w_As[i].weight for i in range(num_layer)} + b_tensors = {f"w_b_{i:03d}": self.w_Bs[i].weight for i in range(num_layer)} + + + merged_dict = {**a_tensors, **b_tensors} + save_file(merged_dict, filename) + + def load_lora_parameters(self, filename: str) -> None: + r"""Only safetensors is supported now. + + pip install safetensor if you do not have one installed yet.\ + + load both lora and fc parameters. + """ + + assert filename.endswith(".safetensors") + + with safe_open(filename, framework="pt") as f: + for i, w_A_linear in enumerate(self.w_As): + saved_key = f"w_a_{i:03d}" + saved_tensor = f.get_tensor(saved_key) + w_A_linear.weight = Parameter(saved_tensor) + + for i, w_B_linear in enumerate(self.w_Bs): + saved_key = f"w_b_{i:03d}" + saved_tensor = f.get_tensor(saved_key) + w_B_linear.weight = Parameter(saved_tensor) + # import ipdb; ipdb.set_trace() + + def reset_parameters(self) -> None: + for w_A in self.w_As: + nn.init.kaiming_uniform_(w_A.weight, a=5**0.5) + for w_B in self.w_Bs: + nn.init.zeros_(w_B.weight) + + +def get_parameter_number(model): + total_num = sum(p.numel() for p in model.parameters()) + trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad) + return {'Total': total_num, 'Trainable': trainable_num} + + +if __name__ == "__main__": + from segment_anything import sam_model_registry + from segment_anything import SamPredictor, SamAutomaticMaskGenerator # prompt and every mode + import numpy as np + lora_r = 8 + path = "../checkpoints/sam_vit_b_01ec64.pth" + sam = sam_model_registry["vit_b"](checkpoint=None) + print('before lora', get_parameter_number(sam)) + lora_sam = LoRA_Sam(sam, lora_r) + print('after lora', get_parameter_number(sam)) + x = torch.rand(size=(3,1024,1024)) + path = '../data/cache/data2d_3/0007_s0069_img.npy' + img = np.load(path) + print('img shape', img.shape) + mask_generator = SamAutomaticMaskGenerator(sam) + masks = mask_generator.generate(x) + print('mask num', len(masks)) + # loss = np.sum([mask.mean(-1).mean(-1) for mask in masks]) + # predictor = SamPredictor(sam) + # predictor.set_image(x) + # masks, _, _ = predictor.predict([50,50]) + + for f_name in ['save_lora_parameters', 'load_lora_parameters']: + print(f_name) + getattr(lora_sam, f_name)('tmp.safetensors') diff --git a/core/loss.py b/core/loss.py new file mode 100644 index 0000000..510cc8a --- /dev/null +++ b/core/loss.py @@ -0,0 +1,155 @@ +import torch +import torch.nn.functional as F +import torch.nn as nn +from einops import repeat, rearrange, reduce +import numpy as np + + +def compute_dice_np(pred_mask, gt_mask): + """ numpy values + """ + pred_mask = np.array(pred_mask>0) + gt_mask = np.array(gt_mask>0) + intersection = np.array(pred_mask * gt_mask).sum() + union = pred_mask.sum() + gt_mask.sum() + dice = intersection * 2 / union # if union > 0 else 0 + return dice + + +def combined_loss(logits, targets, alpha=0.2, gamma=2.0, smooth=1e-5, reduction='mean'): + # Calculate the focal loss + fl = F.binary_cross_entropy_with_logits(logits, targets, reduction='none') + pt = torch.exp(-fl) + focal_loss = alpha * (1 - pt) ** gamma * fl + + if reduction == 'mean': + fl = torch.mean(focal_loss) + elif reduction == 'sum': + fl = torch.sum(focal_loss) + + # Calculate the Dice loss + prob = torch.sigmoid(logits) + intersection = torch.sum(prob * targets, dim=(-2, -1)) + union = torch.sum(prob + targets, dim=(-2, -1)) + dice_loss = 1 - (2 * intersection + smooth) / (union + smooth) + + return focal_loss, dice_loss + + if reduction == 'mean': + dl = torch.mean(dice_loss) + elif reduction == 'sum': + dl = torch.sum(dice_loss) + + # Combine the losses using the specified ratio + loss = 20 * fl + dl + return loss + +# Assuming your prediction and ground truth tensors are named `pred` and `gt`, respectively +def mse_loss(pred, gt): + mse_loss = nn.MSELoss(reduction='none') + loss = mse_loss(pred, gt) + return loss + +def compute_iou(pred_mask, gt_mask): + dtype = pred_mask.dtype + intersection = torch.logical_and(pred_mask, gt_mask) + intersection = reduce(intersection, "b c d h w -> b c", reduction='sum') + union = torch.logical_or(pred_mask, gt_mask) + union = reduce(union, "b c d h w -> b c", reduction='sum') + 1e-8 + iou = intersection / union # if union > 0 else 0 + iou = torch.tensor(iou, dtype=dtype) + # print("ranked_combined_loss: compute_iou ", intersection.dtype, union.dtype, iou.dtype) + return iou + +def ranked_combined_loss(pred_mask, gt_mask, iou_pred): + # (b c1 c2 h w), c1: num_prediction; c2: num_slices + if len(gt_mask.shape) == 4: + gt_mask = repeat(gt_mask, "b d h w -> b c d h w", c=3) + if len(pred_mask.shape) == 4: + pred_mask = rearrange(pred_mask, "b (c1 c2) h w -> b c1 c2 h w", c1=3, c2=3) + fl, dl = combined_loss(pred_mask, gt_mask) + fl = reduce(fl, "b c d h w -> b c", reduction="mean") + dl = reduce(dl, "b c d-> b c", reduction="mean") + segment_loss = 20*fl + dl + min_losses, min_loss_indices = torch.min(segment_loss, dim=1) + iou = compute_iou(torch.tensor(torch.tensor(pred_mask>0, dtype=gt_mask.dtype)>0, dtype=gt_mask.dtype), gt_mask).detach().detach() + # print("ranked_combined_loss ", iou.dtype) + iou_loss = mse_loss(iou_pred, iou) + + selected_losses = torch.gather(iou_loss, 1, min_loss_indices.unsqueeze(1)) + selected_fl = torch.gather(fl, 1, min_loss_indices.unsqueeze(1)) + selected_dl = torch.gather(dl, 1, min_loss_indices.unsqueeze(1)) + # print(min_losses.shape, selected_losses.shape) + + total_loss = min_losses.mean() + selected_losses.mean() + # return total_loss, min_losses, selected_losses + return total_loss, selected_fl, selected_dl, min_losses.mean(), selected_losses.mean(), min_loss_indices + +def ranked_combined_loss_one_slice(pred_mask, gt_mask, iou_pred, mask_loc): + if len(gt_mask.shape) == 4: + # assert gt_mask.shape[1] == 1, f"Got {gt_mask.shape}" + gt_mask = repeat(gt_mask, "b d h w -> b c d h w", c=3) + if len(pred_mask.shape) == 4: + pred_mask = rearrange(pred_mask, "b (c1 c2) h w -> b c1 c2 h w", c1=3, c2=3) + gt_mask = gt_mask[:,:,mask_loc,:,:] + pred_mask = pred_mask[:,:,mask_loc,:,:] + assert len(pred_mask.shape) == 5 + return ranked_combined_loss(pred_mask, gt_mask, iou_pred) + +def ranked_combined_loss_with_indicators(pred_mask, gt_mask, iou_pred, indicators): + # indicators: indicate which slice are with the mask + # (b c1 c2 h w), c1: num_prediction; c2: num_slices + if len(gt_mask.shape) == 4: + gt_mask = repeat(gt_mask, "b d h w -> b c d h w", c=3) + if len(pred_mask.shape) == 4: + pred_mask = rearrange(pred_mask, "b (c1 c2) h w -> b c1 c2 h w", c1=3, c2=3) + + b, c1, c2, h, w = pred_mask.shape + indicators = torch.tensor(indicators, dtype=pred_mask.dtype) + indicators = repeat(indicators, "b d -> b c d h w", c=3, h=h, w=w) + pred_mask = pred_mask * indicators + gt_mask = gt_mask * indicators + + # Same as "ranked_combined_loss" + return ranked_combined_loss(pred_mask, gt_mask, iou_pred) + + +def compute_all_loss_with_indicators(pred_mask, gt_mask, iou_pred, indicators): + # indicators: indicate which slice are with the mask + # (b c1 c2 h w), c1: num_prediction; c2: num_slices + if len(gt_mask.shape) == 4: + gt_mask = repeat(gt_mask, "b d h w -> b c d h w", c=1) + if len(pred_mask.shape) == 4: + pred_mask = rearrange(pred_mask, "b (c1 c2) h w -> b c1 c2 h w", c1=1, c2=3) + + b, c1, c2, h, w = pred_mask.shape + indicators = torch.tensor(indicators, dtype=pred_mask.dtype) + indicators = repeat(indicators, "b d -> b c d h w", c=1, h=h, w=w) + pred_mask = pred_mask * indicators + gt_mask = gt_mask * indicators + + # Same as "compute_all_loss" + return compute_all_loss(pred_mask, gt_mask, iou_pred) + +def compute_all_loss(pred_mask, gt_mask, iou_pred): + if len(pred_mask.shape) == 4: + pred_mask = pred_mask.unsqueeze(1) + if len(gt_mask.shape) == 4: + gt_mask = gt_mask.unsqueeze(1) + # import ipdb; ipdb.set_trace() + fl, dl = combined_loss(pred_mask, gt_mask) + segment_loss = 20*fl.mean() + dl.mean() + iou_loss = mse_loss(iou_pred, compute_iou(torch.tensor(pred_mask>0, dtype=gt_mask.dtype), gt_mask)) + total_loss = segment_loss.mean() + iou_loss.mean() + return total_loss, fl, dl, iou_loss + + +# def compute_ + + +if __name__ == "__main__": + pred_mask = torch.ones((1,9,1024,1024))*9 + pred_mask[:,:,:200,:] = -1 + gt_mask = torch.ones((1,3,1024,1024)) + loss = ranked_combined_loss(pred_mask, gt_mask, iou_pred=torch.ones(gt_mask.shape[:1])) + print(loss) \ No newline at end of file diff --git a/core/volume_predictor.py b/core/volume_predictor.py new file mode 100644 index 0000000..ecfe397 --- /dev/null +++ b/core/volume_predictor.py @@ -0,0 +1,612 @@ +import torch +import numpy as np +from typing import Any, List, Dict, Tuple, Optional +from einops import rearrange +from utils.amg import MaskData, batched_mask_to_box, batched_mask3d_to_box +from torchvision.ops.boxes import batched_nms, box_area # type: ignore +from utils.amg3d import build_point_grid, calculate_stability_score_3d, MaskData3d, batch_iterator +from utils.amg import calculate_stability_score +from utils.transforms3d import ResizeLongestSide, SimpleResize +from einops import rearrange, repeat +# from datasets.data_engine import ValidEngine, BoxPromptGenerator +from datasets.data_engine import DataEngine, DataManager, BoxPromptGenerator, PointPromptGenerator +import cv2 +import torch.nn.functional as F +from tutils.new.manager import ConfigManager +from tutils.nn.data import read, itk_to_np, write, np_to_itk +from torchvision.utils import save_image +from einops import rearrange, reduce, repeat +from core.loss import compute_dice_np + + +class PseudoPredictor: + # def __init__(self) -> None: + # self.image_encoder = + def predict(self, *args, **kwargs): + mask = np.zeros((1024,1024)) + mask[:500,:500] = 1 + return mask + + +class VolumePredictor: + def __init__( + self, + model, + slice_per_batch: int = 4, + points_per_side: Optional[int] = 32, + points_per_batch: int = 16, + pred_iou_thresh: float = 0.5, # debug, standard value is 0.88, + stability_score_thresh: float = 0.6, # debug, standard value is 0.95, + stability_score_offset: float = 1.0, + box_nms_thresh: float = 0.7, + use_postprocess = True, + use_noise_remove = True, + ) -> None: + self.model = model + self.im_size = (model.image_encoder.img_size, model.image_encoder.img_size) + self.slice_per_batch = slice_per_batch + self.point_grids = build_point_grid(points_per_side, self.im_size) + self.features = None + self.is_image_set = False + self.transform = SimpleResize(model.image_encoder.img_size) + + self.points_per_batch = points_per_batch + self.pred_iou_thresh = pred_iou_thresh + self.stability_score_thresh = stability_score_thresh + self.stability_score_offset = stability_score_offset + self.box_nms_thresh = box_nms_thresh + self.masks3d = dict() + + self.box_prompt_generator = BoxPromptGenerator(size=(1024,1024)) + self.masks3d = None + self.stability_score_2d = None + self.input_size = model.image_encoder.img_size + self.use_postprocess = use_postprocess + self.use_noise_remove = use_noise_remove + if not use_postprocess: + print("Warning! No postprocess") + if not use_noise_remove: + print("Warning! No use_noise_remove") + # self.original_size = (1024,1024) + + @property + def device(self) -> torch.device: + return self.model.device + + def reset_image(self) -> None: + """Resets the currently set image.""" + self.is_image_set = False + self.features = None + self.orig_h = None + self.orig_w = None + self.input_h = None + self.input_w = None + self.masks3d = None + self.stability_score_2d = None + + def set_image( + self, + image: np.ndarray, + image_format: str = "nifti", + ) -> None: + # Transform the image to the form expected by the model + self.original_size = image.shape + input_image_torch = torch.as_tensor(image, device=self.device) + input_image_torch = self.transform.apply_image(input_image_torch.float()) + assert np.argmin(input_image_torch.shape) == 0, f"Got image.shape: {input_image_torch.shape}" + maxlen = input_image_torch.shape[0] + slices = [] + for i in range(1, maxlen-1): + slices.append(input_image_torch[i-1:i+2, :, :]) + input_slices = torch.stack(slices, axis=0) + self.set_torch_image(input_slices) + + def batched_to_RGB(self, images): + for i in range(images.shape[0]): + images[i] = (images[i] - images[i].min()) / (images[i].max() - images[i].min()) * 255 + return images + + @torch.no_grad() + def set_torch_image( + self, + transformed_image: torch.Tensor, + ) -> None: + assert ( + len(transformed_image.shape) == 4 + and transformed_image.shape[1] == 3 + and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size + ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}." + self.reset_image() + + self.input_size = tuple(transformed_image.shape[-2:]) + transformed_image = self.batched_to_RGB(transformed_image) + input_image = self.model.preprocess(transformed_image) + features = [] + for input_image_batch in batch_iterator(self.slice_per_batch, input_image): + # print(input_image_batch[0].shape) + features_batch = self.model.image_encoder(input_image_batch[0]).cpu() + features.append(features_batch) + self.features = torch.cat(features, axis=0) + self.is_image_set = True + + def __call__(self, x, *args: Any, **kwds: Any) -> Any: + return self.predict_volume(x) + + def merge_to_mask3d(self, idx, masks:MaskData): + if masks._stats == {} or len(masks['masks']) == 0: + print("No masks") + return + if self.masks3d is None: + self.masks3d = np.zeros(self.original_size) + if self.stability_score_2d is None: + self.stability_score_2d = np.zeros(self.original_size[0]) + masks_values = masks['masks'] + for mask_value in zip(masks_values): + old_mask = self.masks3d[idx-1:idx+2] + # self.masks3d[idx-1:idx+2] = np.where(mask_value > old_mask, mask_value, old_mask) + self.masks3d[idx-1:idx+2] = mask_value + old_mask + + self.stability_score_2d[idx] = masks['stability_score_2d'][0,0] + + def postprocess_3d(self, masks3d): + # add removing noise ? + return masks3d > 0 + + def _debug_predict_slice( + self, + x, + point_coords: Optional[np.ndarray] = None, + point_labels: Optional[np.ndarray] = None, + box: Optional[np.ndarray] = None, + mask_input: Optional[np.ndarray] = None, + multimask_output: bool = True, + return_logits: bool = False, + template_slice_id:int = None, + return_stability: bool = False, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + main entrence + predict 3d volume + + x: volume: (c h w) + box: [[x,y,x,y]] + """ + # Check Input + assert len(x.shape) == 3 + assert box is None or len(box.shape) == 2 + + # preprocess + # x = np.clip(x, -200,400) + print(f"Checking Data range: [{x.min()}, {x.max()}]" ) + + # Adjust direction + indicator = np.argmin(x.shape) + if indicator == 0: + pass + elif indicator == 1: + x = rearrange(x, "h c w -> c h w") + elif indicator == 2: + x = rearrange(x, "h w c -> c h w") + else: + raise NotImplementedError + + # Preprocess prompts + self.original_size = x.shape[1:] + if point_coords is not None: + assert ( + point_labels is not None + ), "point_labels must be supplied if point_coords is supplied." + point_coords = self.transform.apply_coords(point_coords, self.original_size) + coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device) + labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) + coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] + if box is not None: + box = self.transform.apply_boxes(box, self.original_size) + box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) + box_torch = box_torch[None, :] + if mask_input is not None: + mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device) + mask_input_torch = mask_input_torch[None, :, :, :] + + # set 3d image + self.set_image(x) + + # predict center slice + center_idx = template_slice_id if template_slice_id is not None else x.shape[0] // 2 + # print("Processing ", center_idx) + center_masks = self._predict_center_slice(center_idx, point_coords, box) + return center_masks['masks'] + + def predict_volume( + self, + x, + point_coords: Optional[np.ndarray] = None, + point_labels: Optional[np.ndarray] = None, + box: Optional[np.ndarray] = None, + mask_input: Optional[np.ndarray] = None, + multimask_output: bool = True, + return_logits: bool = False, + template_slice_id:int = None, + return_stability: bool = False, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + main entrence + predict 3d volume + + x: volume: (c h w) + box: [[x,y,x,y]] + """ + # Check Input + assert len(x.shape) == 3 + assert box is None or len(box.shape) == 2 + + # preprocess + # x = np.clip(x, -200,400) + print(f"Checking Data range: [{x.min()}, {x.max()}]" ) + + # Adjust direction + indicator = np.argmin(x.shape) + if indicator == 0: + pass + elif indicator == 1: + x = rearrange(x, "h c w -> c h w") + elif indicator == 2: + x = rearrange(x, "h w c -> c h w") + else: + raise NotImplementedError + + # Preprocess prompts + self.original_size = x.shape[1:] + if point_coords is not None: + assert ( + point_labels is not None + ), "point_labels must be supplied if point_coords is supplied." + point_coords = self.transform.apply_coords(point_coords, self.original_size) + coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device) + labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) + coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] + if box is not None: + box = self.transform.apply_boxes(box, self.original_size) + box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) + box_torch = box_torch[None, :] + if mask_input is not None: + mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device) + mask_input_torch = mask_input_torch[None, :, :, :] + + # set 3d image + self.set_image(x) + + # predict center slice + center_idx = template_slice_id if template_slice_id is not None else x.shape[0] // 2 + # print("Processing ", center_idx) + center_masks = self._predict_center_slice(center_idx, point_coords, box) + if center_masks._stats == {}: + print("Ends for no mask.") + raise ValueError + self.merge_to_mask3d(center_idx, center_masks) + + previous_masks = center_masks + for i in range(center_idx+1, x.shape[0]-1): + # print("Processing downward", i) + previous_masks = self._predict_slice(i, previous_masks, orientation="down") + if previous_masks._stats == {}: + print("Ends for no mask.") + break + self.merge_to_mask3d(i, previous_masks) + + previous_masks = center_masks + for i in np.arange(1, center_idx)[::-1]: + # print("Processing upward", i) + previous_masks = self._predict_slice(i, previous_masks, orientation="up") + if previous_masks._stats == {}: + print("Ends for no mask.") + break + self.merge_to_mask3d(i, previous_masks) + + if self.masks3d is None: + self.masks3d = np.zeros_like(x) + if return_stability: + return self.postprocess_3d(self.masks3d), self.stability_score_2d + return self.postprocess_3d(self.masks3d) + + def _predict_center_slice(self, idx, point_prompt=None, box_prompt=None): + if box_prompt is not None: + masks = self.genetate_masks_from_boxes(idx, all_boxes=box_prompt, tags=["center_slice"]) + masks.to_numpy() + return masks + if point_prompt is not None: + masks = self.genetate_masks_from_point_grids(idx, point_prompt) + masks.to_numpy() + return masks + raise ValueError("No prompts! ?") + + def _predict_slice(self, idx, previous_masks, orientation): + scaled_boxes, tags = self.generate_prompts_from_previous_masks(previous_masks, orientation) + masks = self.genetate_masks_from_boxes(idx, all_boxes=scaled_boxes, tags=tags) + masks.to_numpy() + return masks + + def generate_prompts_from_previous_masks(self, previous_masks: MaskData, orientation): + if orientation == "down": + masks = previous_masks['masks'][:,2,:,:] + elif orientation == "up": + masks = previous_masks['masks'][:,0,:,:] + else: + raise ValueError + raw_tags = previous_masks['tags'] + + scaled_boxes = [] + tags = [] + for mask, tag in zip(masks, raw_tags): + if mask.sum() <= 50: + continue + # mask = self.remove_mask_noise(mask) + # if mask.sum() <= 50: + # continue + mask = F.interpolate(torch.Tensor(mask).float()[None,None,:,:], self.input_size).squeeze().numpy() + # scaled_boxes.append(self.box_prompt_generator.mask_to_bbox(mask)) + scaled_boxes.append(self.box_prompt_generator.enlarge(self.box_prompt_generator.mask_to_bbox(mask))) + tags.append(tag) + scaled_boxes = np.array(scaled_boxes) + return scaled_boxes, tags + + def genetate_masks_from_point_grids(self, idx, points_for_image): + idx = idx - 1 # ignore the head and tail slices + # Get points for this crop + data = MaskData() + tags = [f"s{idx}_p{p}" for p in range(points_for_image.shape[0])] + for (points, batched_tags) in batch_iterator(self.points_per_batch, points_for_image, tags): + batch_data = self._process_batch(idx, points=points, tags=batched_tags) + data.cat(batch_data) + del batch_data + + # Remove duplicates within this crop. + # keep_by_nms = batched_nms( + # data["boxes"].float(), + # data["iou_preds"], + # torch.zeros_like(data["boxes"][:, 0]), # categories + # iou_threshold=self.box_nms_thresh, + # ) + # data.filter(keep_by_nms) + return data + + def genetate_masks_from_boxes(self, idx, all_boxes, tags): + idx = idx - 1 + data = MaskData() + for (batched_boxes, batched_tags) in batch_iterator(self.points_per_batch, all_boxes, tags): + batch_data = self._process_batch(idx, boxes=batched_boxes, tags=batched_tags) + data.cat(batch_data) + del batch_data + return data + + def _process_batch( + self, + fea_slice_idx, + points=None, + boxes=None, + multimask_output=True, + tags=None + ) -> MaskData: + """ + Process with a subset of points. (bacause so many points can not be feed in one time) + """ + if points is not None: + in_points = torch.as_tensor(points, device=self.model.device) + in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device) + masks, iou_preds, _ = self.predict_torch_by_sliceidx( + fea_slice_idx=fea_slice_idx, + point_coords=in_points[:, None, :], + point_labels=in_labels[:, None], + multimask_output=multimask_output, + return_logits=True, + ) + masks = rearrange(masks, "b (c1 c2) h w -> b c1 c2 h w", c1=3, c2=3) + elif boxes is not None: + # in_points = torch.as_tensor(points, device=self.model.device) + # in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device) + boxes = torch.as_tensor(boxes, device=self.model.device) + masks, iou_preds, _ = self.predict_torch_by_sliceidx( + fea_slice_idx=fea_slice_idx, + boxes=boxes, + multimask_output=multimask_output, + return_logits=True, + ) + masks = rearrange(masks, "b (c1 c2) h w -> b c1 c2 h w", c1=3, c2=3) + else: + raise ValueError(f"No points or boxes") + + indices = iou_preds.argmax(axis=1) + # indices = torch.tensor([2], device=iou_preds.device) + pred_maxiou = [] + for pred, i in zip(masks, indices): + pred_maxiou.append(pred[i,:,:,:]) + masks = torch.stack(pred_maxiou, axis=0) + + iou_maxiou = [] + for iou, i in zip(iou_preds, indices): + iou_maxiou.append(iou[i]) + iou_preds = torch.stack(iou_maxiou) + + # Serialize predictions and store in MaskData + data = MaskData( + masks=masks.detach().cpu(), + iou_preds=iou_preds.detach().cpu(), + # points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)), + # tags=repeat(np.array(tags), "c -> (c 3)") + tags=np.array(tags), + ) + del masks + + # Filter by area + # if True: + # keep_mask = (data['masks']>0).sum(-1).sum(-1).sum(-1) < (data['masks']<=0).sum(-1).sum(-1).sum(-1) * 0.4 + # data.filter(keep_mask) + # print("keep mask / pred", keep_mask.sum()) + + # Filter Background + # if True: + # keep_mask = + + # Filter by predicted IoU + if self.use_postprocess: + if self.pred_iou_thresh > -0.0: + keep_mask = data["iou_preds"] > self.pred_iou_thresh + # print("pred_iou", data["iou_preds"], (data["masks"]>0).sum(-1).sum(-1)) + data.filter(keep_mask) + # print("keep mask / pred", keep_mask.sum()) + + # Calculate stability score + data["stability_score"] = calculate_stability_score_3d( + data["masks"], self.model.mask_threshold, self.stability_score_offset + ) # .mean(axis=-1) + if self.stability_score_thresh > 0.0: + # print("stability", data["stability_score"], (data["masks"]>0).sum(-1).sum(-1)) + keep_mask = data["stability_score"] >= self.stability_score_thresh + data.filter(keep_mask) + # print("keep mask / stable", keep_mask.sum()) + + data["stability_score_2d"] = calculate_stability_score( + data["masks"][:,1:2,:,:], self.model.mask_threshold, self.stability_score_offset + ) + + # Threshold masks and calculate boxes + data['logits'] = data['masks'] + data["noisy_masks"] = data["logits"] > self.model.mask_threshold + + # data['masks'] = torch.zeros_like(data['noisy_masks'], dtype=data['noisy_masks'].dtype, device=data['noisy_masks'].device) + b, c,_,_ = data["noisy_masks"].shape + data['masks'] = data["noisy_masks"].float() + + if self.use_noise_remove: + for i in range(b): + for j in range(c): + data['masks'][i,j,:,:] = torch.Tensor(self.remove_mask_noise(data['noisy_masks'][i,j,:,:])) + + # data["boxes"] = batched_mask_to_box(reduce(data["masks"], "b c h w -> b h w", reduction="sum")>0) + data["boxes"] = batched_mask_to_box(data["masks"][:,1,:,:]>0) + return data + + # @staticmethod + # def calculate_ + + + @staticmethod + def batched_remove_noise(masks): + ori_shape = masks.shape + + def remove_mask_noise(self, mask): + # mask_sum = mask.sum() + kerner_size = min(mask.sum() // 20, 8) + if isinstance(mask, torch.Tensor): + mask = mask.detach().cpu().numpy() + kernel = np.ones((kerner_size,kerner_size), dtype=np.uint8) + opening = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, kernel, 1) + return opening + + @torch.no_grad() + def predict_torch_by_sliceidx( + self, + fea_slice_idx: int, + point_coords: Optional[torch.Tensor] = None, + point_labels: Optional[torch.Tensor] = None, + boxes: Optional[torch.Tensor] = None, + mask_input: Optional[torch.Tensor] = None, + multimask_output: bool = True, + return_logits: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if not self.is_image_set: + raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") + + if point_coords is not None: + points = (point_coords, point_labels) + else: + points = None + + # Embed prompts + sparse_embeddings, dense_embeddings = self.model.prompt_encoder( + points=points, + boxes=boxes, + masks=mask_input, + ) + + # Predict masks + low_res_masks, iou_predictions = self.model.mask_decoder( + image_embeddings=self.features[fea_slice_idx].to(self.model.device), + image_pe=self.model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + + # Upscale the masks to the original image resolution + masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size[1:]) + + if not return_logits: + masks = masks > self.model.mask_threshold + + return masks, iou_predictions, low_res_masks + + def valid_box(self, data, batch_idx): + # valid image with box, or point prompt + assert data['img'].shape[0] == 1, f"shape {data['img'].shape}" + image = data['img'] + label = data['label'] + + box = BoxPromptGenerator().mask_to_bbox(label) + box_mask3d = self.predict_volume( + x=image, + box=box, + ) + dice = compute_dice_np(box_mask3d, label.detach().cpu().numpy()) + + + +if __name__ == "__main__": + from core.learner3 import SamLearner + from modeling.build_sam3d2 import sam_model_registry + from trans_utils.data_utils import Data3dSolver + + config = ConfigManager() + config.add_config("configs/vit_b_103.yaml") + + model_type = "vit_b" + sam = sam_model_registry[model_type](checkpoint=None) + learner = SamLearner(sam_model=sam, config=config, data_engine=DataManager(img_size=(1024,1024))) + learner.use_lora() + pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b9/lora3/ckpt/model_iter_360000.pth" + learner.load_well_trained_model(pth) + learner.cuda() + + predictor = VolumePredictor( + model=learner.model, + use_postprocess=True, + use_noise_remove=True,) + + # Load data + img_path = "/home1/quanquan/datasets/07_WORD/WORD-V0.1.0/imagesVa/word_0001.nii.gz" # "/home1/quanquan/datasets/59_SABS/sabs_CT_normalized/image_5.nii.gz" + label_path = "/home1/quanquan/datasets/07_WORD/WORD-V0.1.0/labelsVa/word_0001.nii.gz" + volume = itk_to_np(read(img_path)) # test several slices + label_itk = read(label_path) + spacing = label_itk.GetSpacing() + label = itk_to_np(label_itk) == 1 + volume = np.clip(volume, -200, 400) + + # Select the slice with the largest mask + s = reduce(label, "c h w -> c", reduction="sum") + coords = np.nonzero(s) + x_min = np.min(coords[0]) + x_max = np.max(coords[0]) + template_slice_id = s.argmax() + + box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id]) + box = np.array([box]) + + pred = predictor.predict_volume( + x=volume, + box=box, + template_slice_id=template_slice_id, + return_stability=False, + ) + + Data3dSolver().simple_write(pred, path="mask.nii.gz", spacing=spacing) + Data3dSolver().simple_write(label, path="gt.nii.gz", spacing=spacing) \ No newline at end of file diff --git a/datasets/__init__.py b/datasets/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/datasets/__pycache__/__init__.cpython-38.pyc b/datasets/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000..f7c9c27 Binary files /dev/null and b/datasets/__pycache__/__init__.cpython-38.pyc differ diff --git a/datasets/__pycache__/cache_dataset3d.cpython-38.pyc b/datasets/__pycache__/cache_dataset3d.cpython-38.pyc new file mode 100644 index 0000000..6058f75 Binary files /dev/null and b/datasets/__pycache__/cache_dataset3d.cpython-38.pyc differ diff --git a/datasets/__pycache__/cache_dataset3d3.cpython-38.pyc b/datasets/__pycache__/cache_dataset3d3.cpython-38.pyc new file mode 100644 index 0000000..e12e44f Binary files /dev/null and b/datasets/__pycache__/cache_dataset3d3.cpython-38.pyc differ diff --git a/datasets/__pycache__/data_engine.cpython-38.pyc b/datasets/__pycache__/data_engine.cpython-38.pyc new file mode 100644 index 0000000..fe680a9 Binary files /dev/null and b/datasets/__pycache__/data_engine.cpython-38.pyc differ diff --git a/datasets/__pycache__/dataset3d.cpython-38.pyc b/datasets/__pycache__/dataset3d.cpython-38.pyc new file mode 100644 index 0000000..673ba22 Binary files /dev/null and b/datasets/__pycache__/dataset3d.cpython-38.pyc differ diff --git a/datasets/__pycache__/dataset3d_2dmask.cpython-38.pyc b/datasets/__pycache__/dataset3d_2dmask.cpython-38.pyc new file mode 100644 index 0000000..1e1a4c8 Binary files /dev/null and b/datasets/__pycache__/dataset3d_2dmask.cpython-38.pyc differ diff --git a/datasets/__pycache__/dataset_merged.cpython-38.pyc b/datasets/__pycache__/dataset_merged.cpython-38.pyc new file mode 100644 index 0000000..34766f1 Binary files /dev/null and b/datasets/__pycache__/dataset_merged.cpython-38.pyc differ diff --git a/datasets/cache_dataset3d.py b/datasets/cache_dataset3d.py new file mode 100644 index 0000000..a06d55e --- /dev/null +++ b/datasets/cache_dataset3d.py @@ -0,0 +1,267 @@ +""" + Slow Loading directly + + So we pre-precess data +""" + +import numpy as np +import os +from einops import rearrange, reduce, repeat +from tutils.nn.data import read, itk_to_np, np_to_itk, write +from tutils import tfilename +from .dataset3d import DATASET_CONFIG, Dataset3D as basic_3d_dataset +from monai import transforms +import torch +import cv2 +from scipy.sparse import csr_matrix +import torch.nn.functional as F +from torchvision import transforms +from einops import rearrange +import glob +from torchvision import transforms +from monai import transforms as monai_transforms + +# "class": ["spleen", "right kidney", "left kidney", "gallbladder", "esophagus", "liver", "stomach", "aorta", "postcava", "portal vein and splenic vein", "pancrease", "right adrenal gland", "left adrenal gland"], +# "class": ["liver", "right kidney", "left kidney", "spleen"], +TEMPLATE={ + '01': [1,2,3,4,5,6,7,8,9,10,11,12,13,14], + '02': [1,0,3,4,5,6,7,0,0,0,11,0,0,14], + '03': [6], + '04': [6,27], # post process + '05': [2,26,32], # post process + '07': [6,1,3,2,7,4,5,11,14,18,19,12,20,21,23,24], + '08': [6, 2, 1, 11], + '09': [1,2,3,4,5,6,7,8,9,11,12,13,14,21,22], + '12': [6,21,16,2], + '13': [6,2,1,11,8,9,7,4,5,12,13,25], + '14': [11,11,28,28,28], # Felix data, post process + '10_03': [6, 27], # post process + '10_06': [30], + '10_07': [11, 28], # post process + '10_08': [15, 29], # post process + '10_09': [1], + '10_10': [31], + '58': [6,2,3,1], + '59': [1,2,3,4,5,6,7,8,9,10,11,12,13,14], + '60': np.arange(200).tolist(), # for debug +} + +class Dataset3D(basic_3d_dataset): + def __init__(self, config=..., use_cache=True, *args, **kwargs) -> None: + super().__init__(config, use_cache=use_cache, *args, **kwargs) + self.basic_dir = config['data_root_path'] + self.cache_dir = config['cache_data_path'] + + def prepare_transforms(self): + self.transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Resize((1024,1024)), + ]) + self.test_transform = transforms.Compose([ + monai_transforms.Resized(keys=['img', 'label'], spatial_size=(3,1024,1024)), + ]) + + # @tfunctime + # def prepare_datalist(self): + def prepare_cached_datalist(self): + raise DeprecationWarning("[Warning] Please use cache_dataset3d new version instead!") + config = self.config + data_paths = [] + for dirpath in glob.glob(config['cache_data_path'] + "/*"): + data_paths += glob.glob(dirpath + "/image/*.jpg") + print("Load ", dirpath) + print('train len {}'.format(len(data_paths))) + print('Examples: ', data_paths[:2]) + return data_paths + + def caching_data(self): + assert self.use_cache == False + for index in range(len(self)): + self.cache_one_sample(index) + + def cache_one_sample(self, index, debug=False): + # LABEL_INDICES + name = self.img_names[index]['img_path'] + img_itk = read(self.img_names[index]['img_path']) + img_ori = itk_to_np(img_itk) + + img_ori = np.clip(img_ori, -200,400) + + # spacing = img_itk.GetSpacing() + scan_orientation = np.argmin(img_ori.shape) + label_ori = itk_to_np(read(self.img_names[index]['label_path'])) + + dataset_name = self.img_names[index]['img_path'].replace(self.basic_dir,"").split("/")[0] + assert dataset_name[0] in ['0','1','2','3','4','5','6','7','8','9'], f"Got {dataset_name}" + all_labels = TEMPLATE[dataset_name[:2]] + + num = 0 + + # if min(img_ori.shape) * 1.2 < max(img_ori.shape): + # orientation_all = [scan_orientation] + # else: + # orientation_all = [0,1,2] + orientation_all = [scan_orientation] + + for orientation in orientation_all: + for slice_idx in range(2, img_ori.shape[orientation]-2): + # slice_idx = np.random.randint(2, img_ori.shape[orientation]-2) + if orientation == 0: + s = img_ori[slice_idx-1:slice_idx+2, :,:] + lb = label_ori[slice_idx-1:slice_idx+2, :,:] + # spacing = (spacing[1], spacing[2]) + if orientation == 1: + s = img_ori[:,slice_idx-1:slice_idx+2,:] + s = rearrange(s, "h c w -> c h w") + lb = label_ori[:,slice_idx-1:slice_idx+2,:] + lb = rearrange(lb, "h c w -> c h w") + # spacing = (spacing[0], spacing[2]) + if orientation == 2: + s = img_ori[:,:,slice_idx-1:slice_idx+2] + s = rearrange(s, "h w c -> c h w") + lb = label_ori[:,:,slice_idx-1:slice_idx+2] + lb = rearrange(lb, "h w c -> c h w") + # spacing = (spacing[0], spacing[1]) + assert s.shape[0] == 3 + + # if np.float32(lb[1,:,:]>0).sum() <= 200: + # # return self._get_data((index+1)%len(self)) + # continue + # Choose one label + label_num = int(lb.max()) + + masks_data = [] + meta = {"img_name": name, "slice": slice_idx, "orientation": orientation, "label_idx": [], "labels": [], "id": f"{num:08d}" } + for label_idx in range(1,label_num+1): + one_lb = np.float32(lb==label_idx) + if one_lb[1,:,:].sum() <= (one_lb.shape[-1] * one_lb.shape[-2] * 0.0014): + continue + # if one_lb[0,:,:].sum()<=50 or one_lb[2,:,:].sum()<=50: + + masks_data.append(one_lb) + meta['label_idx'].append(label_idx) + meta['labels'].append(all_labels[label_idx-1]) + + if len(masks_data) <= 0: + continue + + img_rgb = s + img_rgb = F.interpolate(torch.Tensor(img_rgb).unsqueeze(0), size=(1024,1024)).squeeze().numpy() + img_rgb = self.to_RGB(img_rgb) + save_image_name = tfilename(self.cache_dir, dataset_name, f"image/image_{index:04d}_{num:08d}.jpg") + self.save_img_rgb(rearrange(img_rgb, "c h w -> h w c"), save_image_name) + + # Save cache data + save_label_name = tfilename(self.cache_dir, dataset_name, f"label/label_{index:04d}_{num:08d}.npz") + self.save_slice_mask(masks_data, save_label_name) + print("Save ", save_image_name) + + self.save_meta(meta, tfilename(self.cache_dir, dataset_name, f"meta/meta_{index:04d}_{num:08d}.npy")) + + num += 1 + + def save_meta(self, meta, path): + assert path.endswith(".npy") + np.save(path, meta) + + def save_slice_mask(self, masks_data, prefix): + masks_data = F.interpolate(torch.Tensor(masks_data), size=(1024,1024)).numpy() + assert masks_data.shape[1:] == (3,1024,1024), f"{__file__} Got{masks_data.shape}" + for i in range(masks_data.shape[0]): + labeli = masks_data[i].astype(np.uint8) * 255 + assert labeli.sum() > 0 + path = tfilename(prefix+f"_{i:04d}.jpg") + cv2.imwrite(path, rearrange(labeli, "c h w -> h w c")) + print("save to ", path) + + def _old_save_slice_mask(self, masks_data, path): + raise DeprecationWarning() + exit(0) + assert path.endswith(".npz") + # masks_data = np.array([m['segmentation'] for m in masks]).astype(int) + masks_data = F.interpolate(torch.Tensor(masks_data), size=(1024,1024)).numpy() + # masks_data = np.int8(masks_data>0) + assert masks_data.shape[1:] == (3,1024,1024), f"{__file__} Got{masks_data.shape}" + masks_data = rearrange(masks_data, "n c h w -> n (c h w)") + csr = csr_matrix(masks_data) + np.savez_compressed(path, data=csr.data, indices=csr.indices, indptr=csr.indptr, shape=csr.shape) + + def save_img_rgb(self, img, path): + assert path.endswith(".jpg") + assert img.shape == (1024,1024,3) + cv2.imwrite(path, img.astype(np.uint8)) + + def _get_cached_data(self, index): + name = self.img_names[index] + # print(name) + img = cv2.imread(name) + compressed = np.load(name.replace("image/image_", "label/label_").replace(".jpg", ".npz")) + csr = csr_matrix((compressed['data'], compressed['indices'], compressed['indptr']), shape=compressed['shape']) + label_ori = csr.toarray() + label_ori = rearrange(label_ori, "n (c h w) -> n c h w", c=3, h=1024, w=1024) + meta = np.load(name.replace("image/image_", "meta/meta_").replace(".jpg", ".npy"), allow_pickle=True).tolist() + # print(meta) + pp = reduce(label_ori[:,1,:,:], "n h w -> n", reduction="sum") > 500 + if pp.sum() == 0: + return self._get_cached_data((index+1)%len(self)) + + label_idx = np.random.choice(a=np.arange(len(pp)), p=pp/pp.sum()) + # label_idx = np.random.randint(0, label_ori.shape[0]) + label_ori = label_ori[label_idx] + is_edge = meta.get('is_edge', 0) + return rearrange(img, "h w c -> c h w"), label_ori, name, meta['labels'][label_idx], meta['label_idx'][label_idx] + + # @tfunctime + def __getitem__(self, index, debug=False): + # print("Dataset warning", index, len(self)) + index = index % len(self) + img_rgb, label_ori, name, label_idx, local_idx = self._get_cached_data(index) + + if label_ori.sum() <= 0: + print("[Label Error] ", name) + return self.__getitem__(index+1) + + # assert len(img_rgb.shape) == 3, f"{__file__} Got{img_rgb.shape}" + # img_rgb = self.transform((img_rgb[None,:,:,:])) + img_rgb = F.interpolate(torch.Tensor(img_rgb).unsqueeze(0), size=(1024,1024)).squeeze().numpy() + + vector = np.ones(3) + ret_dict = { + "name": name, + "img": img_rgb, + "label": label_ori, + "indicators": vector, + "class": label_idx, + "local_idx": local_idx, + } + return ret_dict + + def _convert_one_mask_from_npz_to_jpg(self, path1=None): + # path1 = "/home1/quanquan/datasets/cached_dataset2/01_BCV-Abdomen/label/label_0129_00000043.npz" # 32K + prefix = path1.replace(".npz", "").replace("/label/", "/label_jpg/") + compressed = np.load(path1) + csr = csr_matrix((compressed['data'], compressed['indices'], compressed['indptr']), shape=compressed['shape']) + label_ori = csr.toarray() + label_ori = rearrange(label_ori, "n (c h w) -> n c h w", c=3, h=1024, w=1024) + # print(label_ori.shape) + for i in range(label_ori.shape[0]): + labeli = label_ori[i] + path = tfilename(prefix+f"_{i:04d}.jpg") + cv2.imwrite(path, rearrange(labeli, "c h w -> h w c").astype(np.uint8)) + print("save to ", path) + + def convert_masks_types(self): + assert self.use_cache == True + for index in range(len(self)): + name = self.img_names[index] + label_path = name.replace("image/image_", "label/label_").replace(".jpg", ".npz") + self._convert_one_mask_from_npz_to_jpg(label_path) + +if __name__ == "__main__": + # def go_cache(): + from tutils.new.manager import ConfigManager + config = ConfigManager() + config.add_config("configs/vit_b.yaml") + dataset = Dataset3D(config=config['dataset'], use_cache=True) + dataset.caching_data() + # dataset.convert_masks_types() diff --git a/datasets/cache_dataset3d3.py b/datasets/cache_dataset3d3.py new file mode 100644 index 0000000..2f5c972 --- /dev/null +++ b/datasets/cache_dataset3d3.py @@ -0,0 +1,97 @@ +""" + re-index by masks! not images +""" + +import numpy as np +import os +from einops import rearrange, reduce, repeat +from tutils.nn.data import read, itk_to_np, np_to_itk, write +from tutils import tfilename +from .cache_dataset3d import Dataset3D as basic_3d_dataset +from monai import transforms +import torch +import cv2 +from scipy.sparse import csr_matrix +import torch.nn.functional as F +from torchvision import transforms +from einops import rearrange +import glob +from torchvision import transforms +from monai import transforms as monai_transforms + + +class Dataset3D(basic_3d_dataset): + def __init__(self, config=..., use_cache=True, *args, **kwargs) -> None: + super().__init__(config, use_cache=use_cache, *args, **kwargs) + self.basic_dir = config['data_root_path'] + self.cache_dir = config['cache_data_path'] + + def prepare_cached_datalist(self): + config = self.config + data_paths = [] + for dirpath in glob.glob(config['cache_data_path'] + "/*"): + if not os.path.isdir(dirpath): + continue + prefix = dirpath.split("/")[-1] + if prefix[:2] in config['cache_prefix']: + data_paths += glob.glob(dirpath + "/label_jpg/*.jpg") + print("Load ", dirpath) + print('Masks len {}'.format(len(data_paths))) + print('Examples: ', data_paths[:2]) + return data_paths + + def _get_cached_data(self, index): + mask_path = self.img_names[index] + # print(name) + mask = np.int32(cv2.imread(mask_path) > 0) + + prefix = mask_path[:-9] + img_path = prefix.replace("/label_jpg/label_", "/image/image_") + ".jpg" + img = cv2.imread(img_path) + number = int(mask_path[-8:-4]) + + meta = np.load(prefix.replace("/label_jpg/label_", "/meta/meta_")+".npy", allow_pickle=True).tolist() + # label_idx = np.random.randint(0, label_ori.shape[0]) + + return rearrange(img, "h w c -> c h w"), rearrange(mask, "h w c -> c h w"), mask_path, meta['labels'][number], meta['label_idx'][number] + + # @tfunctime + def __getitem__(self, index, debug=False): + index = index % len(self) + img_rgb, label_ori, name, label_idx, local_idx = self._get_cached_data(index) + + # assert label_ori.sum() > 0 + if label_ori.sum() <= 0: + print("[Label Error] ", name) + return self.__getitem__(index+1) + + # assert len(img_rgb.shape) == 3, f"{__file__} Got{img_rgb.shape}" + # img_rgb = self.transform((img_rgb[None,:,:,:])) + img_rgb = F.interpolate(torch.Tensor(img_rgb).unsqueeze(0), size=(1024,1024)).squeeze().numpy() + + vector = np.ones(3) + ret_dict = { + "name": name, + "img": img_rgb, + "label": label_ori, + "indicators": vector, + "class": label_idx, + "local_idx": local_idx, + "is_problem": label_ori.sum() <= 30, + } + return ret_dict + +if __name__ == "__main__": + # def go_cache(): + from tutils.new.manager import ConfigManager + from tqdm import tqdm + config = ConfigManager() + config.add_config("configs/vit_b.yaml") + config.print() + dataset = Dataset3D(config=config['dataset'], use_cache=True) + # dataset.caching_data() + # dataset.convert_masks_types() + for i in tqdm(range(len(dataset))): + data = dataset.__getitem__(i) + # assert + diff --git a/datasets/data_engine.py b/datasets/data_engine.py new file mode 100644 index 0000000..e6ce2ca --- /dev/null +++ b/datasets/data_engine.py @@ -0,0 +1,274 @@ +import numpy as np +import torch +from torch.utils.data import Dataset +import matplotlib.pyplot as plt +from einops import repeat, rearrange + +def show_mask(mask, ax, random_color=False): + if random_color: + color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) + else: + color = np.array([30/255, 144/255, 255/255, 0.6]) + h, w = mask.shape[-2:] + mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) + ax.imshow(mask_image) + +def show_points(coords, labels, ax, marker_size=375): + pos_points = coords[labels==1] + neg_points = coords[labels==0] + ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25) + ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25) + +def show_box(box, ax): + x0, y0 = box[0], box[1] + w, h = box[2] - box[0], box[3] - box[1] + ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2)) + +class PointPromptGenerator(object): + def __init__(self, size=None) -> None: + pass + + def get_prompt_point(self, gt_mask): + # assert gt_mask.shape == (1024,1024) or gt_mask.shape == (512,512), f"[data_engine] {__file__} Got{gt_mask.shape}" + if not (gt_mask.shape == (1024,1024) or gt_mask.shape == (512,512) or gt_mask.shape == (256,256) ): + print(f"[Warning] [data_engine] {__file__} Got{gt_mask.shape}") + assert gt_mask.sum() > 0 + self.size = gt_mask.shape + self.xy = np.arange(0, self.size[0] * self.size[1]) + + gt_mask = np.float32(gt_mask>0) + prob = rearrange(gt_mask, "h w -> (h w)") + prob = prob / prob.sum() + loc = np.random.choice(a=self.xy, size=1, replace=True, p=prob)[0] + x, y = loc % self.size[1], loc // self.size[1] + return x, y + + @staticmethod + def select_random_subsequent_point(pred_mask, gt_mask): + # union = np.float32((pred_mask + gt_mask)>0) + # diff = union - intersect + assert len(pred_mask.shape) == 2 + assert len(gt_mask.shape) == 2 + assert gt_mask.sum() > 0, f"[data_engine] Got {gt_mask.sum()}==0 " + diff = np.float32(np.abs(pred_mask - gt_mask)>0) + diff = np.nan_to_num(diff, nan=0) + # print(diff.shape) + xy = np.arange(0, diff.shape[0] * diff.shape[1]) + + if diff.sum() == 0: + prob = rearrange(gt_mask, "h w -> (h w)") + prob = prob / prob.sum() + loc = np.random.choice(a=xy, size=1, replace=True, p=prob)[0] + x, y = loc % diff.shape[1], loc // diff.shape[1] + return (x,y), 1 + # Get_prompt_point + prob = rearrange(diff, "h w -> (h w)") + prob = prob / prob.sum() + loc = np.random.choice(a=xy, size=1, replace=True, p=prob)[0] + x, y = loc % diff.shape[1], loc // diff.shape[1] + + if gt_mask[y, x] == 1 and pred_mask[y, x] == 0: + classification = 1 + else: + classification = 0 + # raise ValueError + return (x, y), classification + + + +class BoxPromptGenerator(object): + def __init__(self, size) -> None: + self.size = size + + @staticmethod + def mask_to_bbox(mask): + # Find the indices of all non-zero elements in the mask + coords = np.nonzero(mask) + + # Compute the minimum and maximum values of the row and column indices + x_min = np.min(coords[1]) + y_min = np.min(coords[0]) + x_max = np.max(coords[1]) + y_max = np.max(coords[0]) + + # Return the coordinates of the bounding box + return (x_min, y_min, x_max, y_max) + # return (y_min, x_min, y_max, x_max) + + def add_random_noise_to_bbox(self, bbox): + bbox = list(bbox) + # Calculate the side lengths of the box in the x and y directions + x_side_length = bbox[2] - bbox[0] + y_side_length = bbox[3] - bbox[1] + + # Calculate the standard deviation of the noise + std_dev = 0.01 * (x_side_length + y_side_length) / 2 + + # Generate random noise for each coordinate + x_noise = np.random.normal(scale=std_dev) + y_noise = np.random.normal(scale=std_dev) + + # Add the random noise to each coordinate, but make sure it is not larger than 20 pixels + bbox[0] += min(int(round(x_noise)), 20) + bbox[1] += min(int(round(y_noise)), 20) + bbox[2] += min(int(round(x_noise)), 20) + bbox[3] += min(int(round(y_noise)), 20) + + # Make sure the modified coordinates do not exceed the maximum possible values + bbox[0] = max(bbox[0], 0) + bbox[1] = max(bbox[1], 0) + bbox[2] = min(bbox[2], self.size[0]) + bbox[3] = min(bbox[3], self.size[1]) + + # Return the modified bounding box + return bbox + + def get_prompt_box(self, gt_mask): + """ return (x_min, y_min, x_max, y_max) """ + assert gt_mask.shape == (1024,1024) or gt_mask.shape == (512,512) or gt_mask.shape == (256,256), f"[data_engine] {__file__} Got{gt_mask.shape}" + box = self.mask_to_bbox(gt_mask) + box_w_noise = self.add_random_noise_to_bbox(box) + return box_w_noise + + def enlarge(self, bbox, margin=0): + x0, y0, x1, y1 = bbox[0], bbox[1], bbox[2], bbox[3] + margin_x = int((x1 - x0)*0.05) + margin_y = int((y1 - y0)*0.05) + x0 = max(x0 - margin_x, 0) + y0 = max(y0 - margin_x, 0) + x1 = min(x1 - margin_y, self.size[0]-1) + y1 = min(y1 - margin_y, self.size[1]-1) + + # print("[DEBUG] , enlarge size: ", margin_x, margin_y) + # print("[DEBUG] from", bbox, "to", (x0,y0,x1,y1)) + return (x0,y0,x1,y1) + +class DataEngine(Dataset): + def __init__(self, dataset=None, img_size=None) -> None: + # CACHE_DISK_DIR="/home1/quanquan/code/projects/medical-guangdong/cache/data2d_3/" + super().__init__() + self.point_prompt_generator = PointPromptGenerator(img_size) + self.box_prompt_generator = BoxPromptGenerator(img_size) + # self._get_dataset(dirpath=dirpath) + self.dataset = dataset + + # def _get_dataset(self, dirpath): + # self.dataset = Dataset2D(dirpath=dirpath, is_train=True) + + def __len__(self): + return len(self.dataset) + + def _get_true_index(self, idx): + return idx + + def __getitem__(self, idx): + return self.get_prompt(idx) + + def get_prompt_point(self, gt_mask): + return self.point_prompt_generator.get_prompt_point(gt_mask) + + def get_prompt_box(self, gt_mask): + return self.box_prompt_generator.get_prompt_box(gt_mask) + + # def _get_data_from_dataset(self, idx): + + def get_prompt(self, idx): + idx = self._get_true_index(idx) + data = self.dataset.__getitem__(idx) + + img = data['img'] # (3,h,w) d=3 + mask = data['label'] # (3,h,w) d=3 + + try: + gt_mask = mask[1,:,:] + except Exception as e: + import ipdb; ipdb.set_trace() + gt_mask = mask[1,:,:] + gt_mask = gt_mask.numpy() if isinstance(gt_mask, torch.Tensor) else gt_mask + + # if np.random.rand() > 0.5: + prompt_point = self.get_prompt_point(gt_mask) + # else: + prompt_box = self.get_prompt_box(gt_mask) + + data['prompt_point'] = np.array(prompt_point).astype(np.float32) + data['prompt_box'] = np.array(prompt_box).astype(np.float32) + data['point_label'] = np.ones((1,)).astype(np.float32) + return data + + def get_subsequent_prompt_point(self, pred_mask, gt_mask): + # return self.point_prompt_generator.select_random_subsequent_point_torch(pred_mask, gt_mask) + # return self.point_prompt_generator.select_random_subsequent_point(pred_mask=pred_mask, gt_mask=gt_mask) + coord_collect = [] + label_collect = [] + for i in range(pred_mask.shape[0]): + coords, label = self.point_prompt_generator.select_random_subsequent_point(pred_mask[i][0], gt_mask[i][0]) + if label == -1: + return None, None + coord_collect.append(coords) + label_collect.append(label) + + coord_collect = np.stack(coord_collect, axis=0) + label_collect = np.stack(label_collect, axis=0) + return coord_collect, label_collect + + def get_noisy_box_from_box(self, box): + # Get noisy box from labeled box + return self.box_prompt_generator.add_random_noise_to_bbox(box) + + # def get_prompt_mask(self, ) + +# class ValidEngine(DataEngine): +# def __init__(self, dataset=None, img_size=(1024,1024), is_train=False) -> None: +# # assert dataset is not None +# self.dataset = dataset +# self.is_train = is_train +# super().__init__(dataset=dataset, img_size=img_size) +# self.expand_dataset_ratio = 1 + +# # def _get_dataset(self, dirpath): +# # self.dataset = Dataset3D(dirpath=dirpath, is_train=self.is_train) + +# def __len__(self): +# return len(self.dataset) + +# def _get_true_index(self, idx): +# return idx + +class DataManager: + def __init__(self, img_size=None) -> None: + self.point_prompt_generator = PointPromptGenerator(img_size) + self.box_prompt_generator = BoxPromptGenerator(img_size) + + def get_prompt_point(self, gt_mask): + return self.point_prompt_generator.get_prompt_point(gt_mask) + + def get_prompt_box(self, gt_mask): + return self.box_prompt_generator.get_prompt_box(gt_mask) + + def get_subsequent_prompt_point(self, pred_mask, gt_mask): + # return self.point_prompt_generator.select_random_subsequent_point_torch(pred_mask, gt_mask) + # return self.point_prompt_generator.select_random_subsequent_point(pred_mask=pred_mask, gt_mask=gt_mask) + coord_collect = [] + label_collect = [] + for i in range(pred_mask.shape[0]): + coords, label = self.point_prompt_generator.select_random_subsequent_point(pred_mask[i][0], gt_mask[i][0]) + if label == -1: + return None, None + coord_collect.append(coords) + label_collect.append(label) + + coord_collect = np.stack(coord_collect, axis=0) + label_collect = np.stack(label_collect, axis=0) + return coord_collect, label_collect + + def get_noisy_box_from_box(self, box): + # Get noisy box from labeled box + return self.box_prompt_generator.add_random_noise_to_bbox(box) + + +if __name__ == "__main__": + dataset = DataEngine() + data = dataset.__getitem__(0) + + import ipdb; ipdb.set_trace() \ No newline at end of file diff --git a/datasets/dataset3d.py b/datasets/dataset3d.py new file mode 100644 index 0000000..9983c01 --- /dev/null +++ b/datasets/dataset3d.py @@ -0,0 +1,239 @@ +from torchvision import transforms +from monai import transforms as monai_transforms +import numpy as np +import SimpleITK as sitk +import torch +from torch.utils.data import Dataset as dataset +import torch.nn.functional as F +import glob +import os +from einops import rearrange + +from tutils.nn.data import read, itk_to_np + +from tqdm import tqdm +# from monai.transforms import SpatialPadd, CenterSpatialCropd, Resized, NormalizeIntensityd +# from monai.transforms import RandAdjustContrastd, RandShiftIntensityd, Rotated, RandAffined +# from datasets.common_2d_aug import RandomRotation, RandomResizedCrop, RandomHorizontalFlip, ColorJitter, ToTensor, Normalize +from tutils import tfilename, tdir +import random +import time +import cv2 +from scipy.sparse import csr_matrix + +def tfunctime(func): + def run(*argv, **kargs): + t1 = time.time() + ret = func(*argv, **kargs) + t2 = time.time() + print(f"[Function {func.__name__}] Running time:{(t2-t1):.6f}s") + return ret + return run + +# DEFAULT_PATH='/home1/quanquan/datasets/KiTS/' +DEFAULT_PATH="/home1/quanquan/datasets/BCV-Abdomen/Training/" + +LABEL_INDICES={ + "t2sag": ["bg","kidney", "label 2", "label 3", "rectum", "tumor", "other"], + } + +# CACHE_DISK_DIR="/home1/quanquan/code/projects/medical-guangdong/cache/data2d_3/" +CACHE_DISK_DIR=None +# DEFAULT_CONFIG={ +# "pad": (512,512), +# "crop": (384,384), +# "resize": (512,512), +# } + +DATASET_CONFIG={ + 'split': 'train', + 'data_root_path':'/quanquan/datasets/', + 'dataset_list': ['sam', "their", "ours"], + 'data_txt_path':'./datasets/dataset_list/', +} + +DATASET_METAINFO={ + "WORD": {0:"Background", 1:"Liver", 2:"Spleen", 3:"Left Kidney", 4:"Right Kidney", 5:"Stomach", 6:"Gallbladder", 7:"Esophagus", 8:"Pancreas", 9:"Duodenum", 10:"Colon", 11:"Intestine", 12:"Adrenal", 13:"Rectum", 14:"Bladder", 15:"left head of femur", 16:"right head of femur"} +} + + +class Dataset3D(dataset): + def __init__(self, config=DATASET_CONFIG, is_train=True, split='train', getting_multi_mask=False, use_cache=False) -> None: + super().__init__() + self.config = config + self.is_train = is_train + self.split = split + self.getting_multi_mask = getting_multi_mask + self.use_cache = use_cache + self.img_names = self.prepare_cached_datalist() if use_cache else self.prepare_datalist() + # self.img_names = self.prepare_datalist() + self.prepare_transforms() + + def prepare_cached_datalist(self): + raise NotImplementedError + + def prepare_transforms(self): + self.transform = monai_transforms.Compose([ + # transforms.Resized(keys=['img', 'label'], spatial_size=(3,512,512)), + # transforms.RandSpatialCropd(keys=["img", 'label'], roi_size=(3,448,448)), + # transforms.RandAffined(keys=['img', 'label'], prob=0.5, shear_range=(0.2,0.2)), + # transforms.RandCropByPosNegLabeld(keys=['img', 'label'], spatial_size=(3,448,448), label_key='label', neg=0), + # transforms.RandSmoothFieldAdjustContrastd(keys=['img', 'label'], ) + monai_transforms.RandAdjustContrastd(keys=['img'], ), + # transforms.RandShiftIntensityd(keys=['img'], prob=0.8, offsets=(0, 20)), + monai_transforms.Resized(keys=['img', 'label'], spatial_size=(3,1024,1024)), + ]) + self.test_transform = transforms.Compose([ + monai_transforms.Resized(keys=['img', 'label'], spatial_size=(3,1024,1024)), + ]) + + def _get_image(self, index): + name = self.img_names[index]['img_path'] + if not os.path.exists(name): + print("Path not exists!", name) + return self._get_image(index+1%len(self)) + img_itk = read(self.img_names[index]['img_path']) + img_ori = itk_to_np(img_itk) + img = np.clip(img_ori, -200, 400).astype(np.float32) + img = (img - img.min()) / img.max() * 255 + label_ori = itk_to_np(read(self.img_names[index]['label_path'])) + return {"img":img, "name":name.replace(self.config['data_root_path'], ""), "label":label_ori} + + def __len__(self): + return len(self.img_names) + + # @tfunctime + def prepare_datalist(self): + config = self.config + data_paths = [] + for item in config['dataset_list']: + print("Load datalist from ", item) + for line in open(config["data_txt_path"]+ item + f"_{self.split}.txt"): + name = line.strip().split()[1].split('.')[0] + img_path = config['data_root_path'] + line.strip().split()[0] + label_path = config['data_root_path'] + line.strip().split()[1] + data_paths.append({'img_path': img_path, 'label_path': label_path, 'name': name}) + print('train len {}'.format(len(data_paths))) + return data_paths + + # @tfunctime + def _get_data(self, index, debug=False): + # LABEL_INDICES + name = self.img_names[index]['img_path'] + img_itk = read(self.img_names[index]['img_path']) + img_ori = itk_to_np(img_itk) + # spacing = img_itk.GetSpacing() + scan_orientation = np.argmin(img_ori.shape) + label_ori = itk_to_np(read(self.img_names[index]['label_path'])) + + if min(img_ori.shape) * 2 < max(img_ori.shape): + orientation = scan_orientation + else: + orientation = np.random.randint(3) + slice_idx = np.random.randint(2, img_ori.shape[orientation]-2) + if orientation == 0: + s = img_ori[slice_idx-1:slice_idx+2, :,:] + lb = label_ori[slice_idx-1:slice_idx+2, :,:] + # spacing = (spacing[1], spacing[2]) + if orientation == 1: + s = img_ori[:,slice_idx-1:slice_idx+2,:] + s = rearrange(s, "h c w -> c h w") + lb = label_ori[:,slice_idx-1:slice_idx+2,:] + lb = rearrange(lb, "h c w -> c h w") + # spacing = (spacing[0], spacing[2]) + if orientation == 2: + s = img_ori[:,:,slice_idx-1:slice_idx+2] + s = rearrange(s, "h w c -> c h w") + lb = label_ori[:,:,slice_idx-1:slice_idx+2] + lb = rearrange(lb, "h w c -> c h w") + # spacing = (spacing[0], spacing[1]) + assert s.shape[0] == 3 + + if np.float32(lb[1,:,:]>0).sum() <= 200: + return self._get_data((index+1)%len(self)) + # Choose one label + label_num = int(lb.max()) + is_good_mask = [] + for label_idx in range(1,label_num+1): + one_lb = np.float32(lb==label_idx) + is_good_mask.append(one_lb.sum()>=50) + label_idx = np.random.choice(range(1,label_num+1), p=np.array(is_good_mask)/np.sum(is_good_mask)) + lb = np.float32(lb==label_idx) + return s, lb, name, label_idx + + # @tfunctime + def _get_cached_data(self, index): + name = self.img_names[index] + img = cv2.imread(name) + + compressed = np.load(name.replace("image/image_", "label/label_").replace(".jpg", ".npz")) + csr = csr_matrix((compressed['data'], compressed['indices'], compressed['indptr']), shape=compressed['shape']) + label_ori = csr.toarray() + label_ori = rearrange(label_ori, "n (c h w) -> n c h w", c=3, h=1024, w=1024) + + label_idx = np.random.randint(0, label_ori.shape[0]) + label_ori = label_ori[label_idx] + return rearrange(img, "h w c -> c h w"), label_ori, name, -1 + + def to_RGB(self, img): + # transform images to RGB style + img = ((img - img.min()) / (img.max() - img.min()) * 255).astype(int) + return img + + # @tfunctime + def __getitem__(self, index, debug=False): + img_ori, label_ori, name, label_idx = self._get_data(index) + img_ori = np.clip(img_ori, -200,400) + img_rgb = self.to_RGB(img_ori) + + assert len(img_rgb.shape) == 3, f"{__file__} Got{img_rgb.shape}" + bundle_ori = {"img":torch.Tensor(img_rgb).unsqueeze(0), "label":torch.Tensor(label_ori).unsqueeze(0)} + # import ipdb; ipdb.set_trace() + if self.is_train: + # bundle = self.transform(bundle_ori)[0] # use with transforms.RandCropByPosNegLabeld + bundle = self.transform(bundle_ori) + else: + bundle = self.test_transform(bundle_ori) + + if not self.use_cache: + bundle['label'] = (bundle['label']>0.5).float() + vector = np.ones(3) + if debug: + ret_dict = { + "name": name, + "img": bundle['img'], + "label": bundle['label'], + "img_ori":img_ori, + "label_ori":label_ori, + "label_idx": label_idx, + "indicators": vector, + # "label_name": + } + return ret_dict + + ret_dict = { + "name": name, + "img": bundle['img'][0].float(), + "label": bundle['label'][0].float(), + "indicators": vector, + } + if bundle['label'][0][1,:,:].sum() <= 0: + return self.__getitem__(index+1 % len(self)) + return ret_dict + + +if __name__ == "__main__": + from tutils.new.manager import ConfigManager + config = ConfigManager() + config.add_config("configs/vit_b_103.yaml") + dataset = Dataset3D(config=config['dataset']) # , use_cache=True + data = dataset.__getitem__(0) + + import ipdb; ipdb.set_trace() + from torch.utils.data import DataLoader + loader = DataLoader(dataset, batch_size=8) + for batch in loader: + print(batch['img'].shape, batch['label'].shape) + # print(data['label'].max()) + # import ipdb; ipdb.set_trace() + diff --git a/datasets/dataset3d_2dmask.py b/datasets/dataset3d_2dmask.py new file mode 100644 index 0000000..808ec24 --- /dev/null +++ b/datasets/dataset3d_2dmask.py @@ -0,0 +1,178 @@ +# from torchvision import transforms +from monai import transforms +import numpy as np +import SimpleITK as sitk +import torch +from torch.utils.data import Dataset as dataset +import glob +import os +from einops import rearrange, repeat +from tutils.nn.data.tsitk import read +from tqdm import tqdm +# from monai.transforms import SpatialPadd, CenterSpatialCropd, Resized, NormalizeIntensityd +# from monai.transforms import RandAdjustContrastd, RandShiftIntensityd, Rotated, RandAffined +# from datasets.common_2d_aug import RandomRotation, RandomResizedCrop, RandomHorizontalFlip, ColorJitter, ToTensor, Normalize +from tutils import tfilename, tdir +import random +from tutils.nn.data import itk_to_np +from scipy.sparse import csr_matrix +import cv2 + + +DEFAULT_PATH="/quanquan/datasets/08_AbdomenCT-1K/" + + +class Dataset2D(dataset): + def __init__(self, dirpath=None, is_train=True, getting_multi_mask=False) -> None: + super().__init__() + self.dirpath = dirpath + self.is_train = is_train + self.getting_multi_mask = getting_multi_mask + self.img_names = self.prepare_datalist() + self.prepare_transforms() + self.weights_dict = {"gt":2, "sam_auto_seg":2, "prompt_point_from_superpixel":1, "prompt_box_from_superpixel":1, "superpixel":0} + + def prepare_transforms(self): + self.transform = transforms.Compose([ + transforms.Resized(keys=['img', 'label'], spatial_size=(3,1024,1024)), + # transforms.RandSpatialCropd(keys=["img"], roi_size=(448,448,1)), + transforms.RandAffined(keys=['img', 'label'], prob=0.5, shear_range=(0.2,0.2)), + transforms.RandCropByPosNegLabeld(keys=['img', 'label'], spatial_size=(3,960,960), label_key='label', neg=0), + # transforms.RandSmoothFieldAdjustContrastd(keys=['img', 'label'], ) + transforms.Resized(keys=['img', 'label'], spatial_size=(3,1024,1024)), + transforms.RandAdjustContrastd(keys=['img'], ), + transforms.RandShiftIntensityd(keys=['img'], prob=0.8, offsets=(-5, 5)), + ]) + self.test_transform = transforms.Compose([ + transforms.Resized(keys=['img'], spatial_size=(3,1024,1024)), + ]) + + def __len__(self): + return len(self.img_names) + + def to_RGB(self, img): + # transform images to RGB style + img = ((img - img.min()) / img.max() * 255).astype(int) + return img + + def prepare_datalist(self): + dirpath_img = os.path.join(self.dirpath, 'preprocessed', 'cache_2d_various_pseudo_masks') + names = glob.glob(os.path.join(dirpath_img, "*_mask.npz")) + names = [os.path.split(name)[-1].replace("_mask.npz", "") for name in names] + names.sort() + # names = names[:15000] + print(f"[Dataset2d] Load {len(names)} paths.") + assert len(names) > 0, f"{__file__} Gotdirpath: {self.dirpath}" + return names + + + def _get_data(self, index, debug=False, iternum=0): + img_names = self.img_names + img_info = os.path.split(img_names[index])[-1].split('_s') + filename, slice_idx = img_info[0], int(img_info[-1][:4]) + mask_loc = np.random.randint(0,3) + if mask_loc == 0: + slices_indices = [slice_idx, slice_idx+1, slice_idx+2] + elif mask_loc == 1: + slices_indices = [slice_idx-1, slice_idx, slice_idx+1] + elif mask_loc == 2: + slices_indices = [slice_idx-2, slice_idx-1, slice_idx] + + # Load .npy data + filenames = [os.path.join(self.dirpath, "preprocessed", "cache_jpg", f"{filename}_s{i:04}_img.jpg") for i in slices_indices] + for name in filenames: + if not os.path.exists(name): + return self._get_data(index+1 % len(self)) + + imgs = [cv2.imread(name, cv2.IMREAD_GRAYSCALE) for name in filenames] + img_rgb = np.stack(imgs, axis=0) + + # Load RGB data + + compressed = np.load(os.path.join(self.dirpath, "preprocessed", "cache_2d_various_pseudo_masks", img_names[index]+"_mask.npz")) + csr = csr_matrix((compressed['data'], compressed['indices'], compressed['indptr']), shape=compressed['shape']) + label_ori = csr.toarray() + label_ori = rearrange(label_ori, "c (h w) -> c h w", h=1024, w=1024) + metadata = np.load(os.path.join(self.dirpath, "preprocessed", "cache_2d_various_pseudo_masks", img_names[index]+"_metadata.npy"), allow_pickle=True) + + label_prob = np.array([self.weights_dict[item['source']] for item in metadata]).astype(float) + label_prob = label_prob / label_prob.sum() + + label_idx = np.random.choice(a=np.arange(len(metadata)), p=label_prob) + label_ori = label_ori[label_idx] + metadata = metadata[label_idx] + assert metadata['source'] != 'superpixel' + + assert len(img_rgb.shape) == 3, f"{__file__} Got{img_rgb.shape}" + bundle_ori = {"img":torch.Tensor(rearrange(img_rgb, "c h w -> 1 c h w")), "label":torch.Tensor(repeat(label_ori, "h w -> 1 3 h w"))} + # import ipdb; ipdb.set_trace() + if self.is_train: + bundle = self.transform(bundle_ori)[0] + else: + bundle = self.test_transform(bundle_ori) + + bundle['label'] = (bundle['label']>0.5).float() + if bundle['label'][0].sum() < 100: + return self._get_data((index+1)%len(self), iternum=iternum+1) + + vector = np.zeros(3) + vector[mask_loc] = 1 + + if debug: + ret_dict = { + "name": img_names[index], + "img": bundle['img'][0], + "label": bundle['label'][0], + "img_ori":img_rgb, + "label_ori":label_ori, + "weight": self.weights_dict[metadata['source']], + "iternum": iternum, + "mask_loc": mask_loc, + "indicators": vector, + } + return ret_dict + + ret_dict = { + "name": img_names[index], + "img": bundle['img'][0], + "label": bundle['label'][0], + "mask_loc": mask_loc, + "indicators": vector, + } + return ret_dict + + + def __getitem__(self, index): + return self._get_data(index) + + +class Testset2d(Dataset2D): + def __init__(self, dirpath=None, is_train=False, getting_multi_mask=False) -> None: + super().__init__(dirpath, is_train, getting_multi_mask) + self.test_names = self.prepare_datalist() + + def prepare_datalist(self): + dirpath_img = os.path.join(self.dirpath, 'cache_2d_various_pseudo_masks') + names = glob.glob(os.path.join(dirpath_img, "*_mask.npz")) + names = [os.path.split(name)[-1].replace("_mask.npz", "") for name in names] + names.sort() + names = names[15000:] + print(f"[Dataset2d] Load {len(names)} paths.") + assert len(names) > 0, f"{__file__} Gotdirpath: {self.dirpath}" + return names + + +if __name__ == "__main__": + from torch.utils.data import DataLoader + + dataset = Dataset2D(dirpath=DEFAULT_PATH) + loader = DataLoader(dataset=dataset, batch_size=1, shuffle=False) + iternums = 0 + for i, data in enumerate(loader): + # iternums += data['iternum'].item() + print(i, iternums / (i+1), data['img'].shape, data['label'].shape) + assert data['label'].sum() >= 100, f"{__file__} Got{data['label'].sum()}" + assert torch.Tensor(data['label']==1).sum() >= 100, f"{__file__} Got {torch.Tensor(data['label']==1).sum().sum()}" + + import ipdb; ipdb.set_trace() + diff --git a/datasets/dataset_merged.py b/datasets/dataset_merged.py new file mode 100644 index 0000000..01e3804 --- /dev/null +++ b/datasets/dataset_merged.py @@ -0,0 +1,74 @@ +# from torchvision import transforms +from monai import transforms +import numpy as np +import SimpleITK as sitk +import torch +from torch.utils.data import Dataset as dataset +import torch.nn.functional as F +import glob +import os +from einops import rearrange + +from tutils.nn.data import read, itk_to_np + +from tqdm import tqdm +from tutils import tfilename, tdir +import random +# from .dataset2d import Dataset2D +from .dataset3d_2dmask import Dataset2D +from .dataset3d import Dataset3D + + +class DatasetMerged(dataset): + def __init__(self, config=None, is_train=True, getting_multi_mask=False) -> None: + super().__init__() + self.dataset2d = Dataset2D(dirpath="/quanquan/datasets/08_AbdomenCT-1K/", is_train=True) + self.dataset3d = Dataset3D(config=config, is_train=True) + self.len_2d = len(self.dataset2d) + self.len_3d = len(self.dataset3d) + + def __getitem__(self, index, debug=False): + index = index % len(self) + # print("DEBUG! is_2d:", index < self.len_2d) + if index < self.len_2d: + return self.dataset2d.__getitem__(index) + else: + index = (index - self.len_2d) % self.len_3d + return self.dataset3d.__getitem__(index) + + def __len__(self): + return len(self.dataset2d) + len(self.dataset3d) * 200 + + + +class TestsetMerged(dataset): + def __init__(self, config=None, is_train=False) -> None: + super().__init__() + self.dataset2d = Dataset2D(dirpath="/quanquan/datasets/08_AbdomenCT-1K/preprocessed/", is_train=False) + self.dataset3d = Dataset3D(config=config, is_train=False, split='val') + self.len_2d = len(self.dataset2d) + self.len_3d = len(self.dataset3d) + + def __getitem__(self, index, debug=False): + index = index % len(self) + if index < self.len_2d: + return self.dataset2d.__getitem__(index) + else: + index = (index - self.len_2d) % self.len_3d + return self.dataset3d.__getitem__(index) + + def __len__(self): + return len(self.dataset2d) + len(self.dataset3d) * 2 + + +if __name__ == "__main__": + from tutils import timer + from tutils.new.manager import trans_args, trans_init, ConfigManager + config = ConfigManager() + config.add_basic_config() + config.add_config("configs/vit_b.yaml") + dataset = DatasetMerged(config['dataset']) + tt = timer() + for i in range(20000,len(dataset)): + data = dataset.__getitem__(i) + print("time: ", tt()) \ No newline at end of file diff --git a/datasets/generate_txt.py b/datasets/generate_txt.py new file mode 100644 index 0000000..655059a --- /dev/null +++ b/datasets/generate_txt.py @@ -0,0 +1,435 @@ +import os +import glob +import shutil +from tutils import tfilename + + +HOME_PATH="/quanquan/datasets/" + +def check_existing(img_path, label_path): + if os.path.exists(img_path) and os.path.exists(label_path): + return True + else: + if not os.path.exists(img_path): + print("IMAGE Not exist: ", img_path) + if not os.path.exists(label_path): + print("LABEL Not exist: ", label_path) + return False + +def get_availabel_files(names, label_names): + llist = [[n,n2] for n,n2 in zip(names, label_names) if check_existing(n, n2)] + names = [n[0] for n in llist] + label_names = [n[1] for n in llist] + return names, label_names + +def write_txt(img_paths, label_paths, meta_info, split, writing_mode='a+'): + dataset_id = meta_info['dataset_id'] + assert split in ['train', 'val', 'test'], f" split in ['train', 'val', 'test'] , but Got {split}" + save_path = meta_info["save_txt_path"].replace("_train.txt", f"_{split}.txt") + + count = 0 + with open(save_path, writing_mode) as f: + for p1, p2 in zip(img_paths, label_paths): + p1 = p1.replace(meta_info['home_path'], "") + p2 = p2.replace(meta_info['home_path'], "") + line = f"{p1}\t{p2}\n" + # print(line, end=" ") + f.write(line) + count += 1 + if count <= 0: + raise ValueError(f"ID: {meta_info['dataset_id']}, \tTask: {meta_info['dataset_name']}\t, {count} files are writen.") + print(f"ID: {meta_info['dataset_id']}, \tTask: {meta_info['dataset_name']}\t, {count} files are writen.\t Writing Over! write into ", save_path) + + +def organize_in_nnunet_style(meta_info): + dirpath = os.path.join(meta_info['home_path'], meta_info['dirpath']) + if os.path.exists(os.path.join(dirpath, "imagesTr")) and os.path.exists(os.path.join(dirpath, "labelsTr")): + img_paths = glob.glob(os.path.join(dirpath, "imagesTr", "*.nii.gz")) + img_paths.sort() + label_paths = [p.replace("imagesTr", "labelsTr")[:-12]+".nii.gz" for p in img_paths] + img_paths, label_paths = get_availabel_files(img_paths, label_paths) + write_txt(img_paths, label_paths, meta_info=meta_info, split='train', writing_mode="a+") + + if os.path.exists(os.path.join(dirpath, "imagesVa")) and os.path.exists(os.path.join(dirpath, "labelsVa")): + img_paths = glob.glob(os.path.join(dirpath, "imagesVa", "*.nii.gz")) + img_paths.sort() + label_paths = [p.replace("imagesVa", "labelsVa")[:-12]+".nii.gz" for p in img_paths] + img_paths, label_paths = get_availabel_files(img_paths, label_paths) + write_txt(img_paths, label_paths, meta_info=meta_info, split='val', writing_mode="a+") + + if os.path.exists(os.path.join(dirpath, "imagesTs")) and os.path.exists(os.path.join(dirpath, "labelsTs")): + img_paths = glob.glob(os.path.join(dirpath, "imagesTs", "*.nii.gz")) + img_paths.sort() + label_paths = [p.replace("imagesTs", "labelsTs")[:-12]+".nii.gz" for p in img_paths] + img_paths, label_paths = get_availabel_files(img_paths, label_paths) + write_txt(img_paths, label_paths, meta_info=meta_info, split='test', writing_mode="a+") + + +def organize_in_style2(meta_info): + dirpath = os.path.join(meta_info['home_path'], meta_info['dirpath']) + if os.path.exists(os.path.join(dirpath, "imagesTr")) and os.path.exists(os.path.join(dirpath, "labelsTr")): + img_paths = glob.glob(os.path.join(dirpath, "imagesTr", "*.nii.gz")) + img_paths.sort() + label_paths = [p.replace("imagesTr", "labelsTr") for p in img_paths] + img_paths, label_paths = get_availabel_files(img_paths, label_paths) + write_txt(img_paths, label_paths, meta_info=meta_info, split='train', writing_mode="a+") + + if os.path.exists(os.path.join(dirpath, "imagesVa")) and os.path.exists(os.path.join(dirpath, "labelsVa")): + img_paths = glob.glob(os.path.join(dirpath, "imagesVa", "*.nii.gz")) + img_paths.sort() + label_paths = [p.replace("imagesVa", "labelsVa") for p in img_paths] + img_paths, label_paths = get_availabel_files(img_paths, label_paths) + write_txt(img_paths, label_paths, meta_info=meta_info, split='val', writing_mode="a+") + + if os.path.exists(os.path.join(dirpath, "imagesTs")) and os.path.exists(os.path.join(dirpath, "labelsTs")): + img_paths = glob.glob(os.path.join(dirpath, "imagesTs", "*.nii.gz")) + img_paths.sort() + label_paths = [p.replace("imagesTs", "labelsTs") for p in img_paths] + img_paths, label_paths = get_availabel_files(img_paths, label_paths) + write_txt(img_paths, label_paths, meta_info=meta_info, split='test', writing_mode="a+") + + +def organize_by_names(names_in, label_names_in, meta_info): + assert len(names_in) > 0, f"Meta info: {meta_info}" + names, label_names = get_availabel_files(names_in, label_names_in) + assert len(names) > 0, f"Meta info: {meta_info}, \n {names_in[:2]} \n {label_names_in[:2]}" + assert len(label_names) > 0, f"Meta info: {meta_info}, \n {names_in[:2]} \n {label_names_in[:2]}" + + # print("debug files", len(names)) + if len(names) > 10: + num_valid = min(int(len(names) // 10), 10) + # print("num valid", num_valid) + train_names = names[:-num_valid*2] + valid_names = names[-num_valid*2:-num_valid] + test_names = names[-num_valid:] + + train_labels = label_names[:-num_valid*2] + valid_labels = label_names[-num_valid*2:-num_valid] + test_labels = label_names[-num_valid:] + + write_txt(train_names, train_labels, meta_info=meta_info, split="train") + write_txt(valid_names, valid_labels, meta_info=meta_info, split="val") + write_txt(test_names, test_labels, meta_info=meta_info, split="test") + else: + write_txt(names, label_names, meta_info=meta_info, split="train") + + +def clear_files(train_path): + if os.path.exists(train_path): + parent, name = os.path.split(train_path) + shutil.move(train_path, tfilename(parent, "misc", name)) + + val_path = train_path.replace("_train.txt", "_val.txt") + if os.path.exists(val_path): + parent, name = os.path.split(val_path) + shutil.move(val_path, os.path.join(parent, "misc", name)) + + test_path = train_path.replace("_train.txt", "_test.txt") + if os.path.exists(test_path): + parent, name = os.path.split(test_path) + shutil.move(test_path, os.path.join(parent, "misc", name)) + print("Files cleared!") + +# from tutils.nn.data import read +# def convert_to_nii(paths): + + +################################################################################### + +################################################################################### + +def get_BCV_Abdomen(save_path=None): + meta_info = { + "dataset_name": "BTCV", + "dataset_id": "01", + "modality": "CT", + "home_path": HOME_PATH, + "dirpath": "01_BCV-Abdomen/Training/", + "save_txt_path": save_path, + } + names = glob.glob(os.path.join(meta_info['home_path'], meta_info['dirpath'], "img/*.nii.gz")) + names.sort() + label_names = [p.replace("img", "label") for p in names] + organize_by_names(names, label_names, meta_info=meta_info) + +def get_AbdomenCT_1K(save_path): + meta_info = { + "dataset_name": "AbdomenCT-1K", + "dataset_id": "08", + "modality": "CT", + "home_path": HOME_PATH, + "dirpath": "08_AbdomenCT-1K", + "save_txt_path": save_path, + } + # print(names) + organize_in_nnunet_style(meta_info=meta_info) + +def get_AMOS(save_path): + meta_info = { + "dataset_name": "AMOS", + "dataset_id": "09", + "modality": "CT", + "home_path": HOME_PATH, + "dirpath": "09_AMOS", + "save_txt_path": save_path, + } + organize_in_style2(meta_info) + +def get_MSD(save_path): + meta_info = { + "dataset_name": "MSD", # Decathlon + "dataset_id": "10", + "modality": "CT", + "home_path": HOME_PATH, + "parent_dirpath": "10_Decathlon", + "dirpath": "", + "save_txt_path": save_path, + } + subtasks = ["Task06_Lung", "Task08_HepaticVessel", "Task09_Spleen", "Task10_Colon"] + for task in subtasks: + # print("Processing ", task) + meta_info_subtask = { + "dataset_name":task, + "dataset_id": f"{meta_info['dataset_id']}_{task[4:6]}", + "home_path":HOME_PATH, + "dirpath": f"{meta_info['parent_dirpath']}/{task}", + "save_txt_path": save_path, + } + # print(meta_info_subtask) + organize_in_style2(meta_info=meta_info_subtask) + + +def get_MSD_MRI(save_path): + meta_info = { + "dataset_name": "MSD", # Decathlon + "dataset_id": "10", + "modality": "MRI", + "home_path": HOME_PATH, + "parent_dirpath": "10_Decathlon", + "dirpath": "", + "save_txt_path": save_path, + } + subtasks = ["Task02_Heart", "Task05_Prostate"] + for task in subtasks: + # print("Processing ", task) + meta_info_subtask = { + "dataset_name":task, + "dataset_id": f"{meta_info['dataset_id']}_{task[4:6]}", + "home_path":HOME_PATH, + "dirpath": f"{meta_info['parent_dirpath']}/{task}", + "save_txt_path": save_path, + } + # print(meta_info_subtask) + organize_in_style2(meta_info=meta_info_subtask) + + +def get_ASOCA(save_path): + meta_info = { + "dataset_name": "ASOCA", + "dataset_id": "51", + "modality": "CT", + "home_path": HOME_PATH, + "dirpath": "51_ASOCA", + "save_txt_path": save_path, + } + names = glob.glob(os.path.join(meta_info['home_path'], meta_info['dirpath'], "image/*.nii.gz")) + names.sort() + label_names = [p.replace("/image", "/label") for p in names] + # print(os.path.join(meta_info['home_path'], meta_info['dirpath'], "image/*.nii.gz") + # print("debug ,", names) + organize_by_names(names, label_names, meta_info=meta_info) + +def get_BCV_Cervix(save_path): + meta_info = { + "dataset_name": "BCV-Cervix", + "dataset_id": "52", + "modality": "CT", + "home_path": HOME_PATH, + "dirpath": "52_BCV-Cervix/Training/", + "save_txt_path": save_path, + } + names = glob.glob(os.path.join(meta_info['home_path'], meta_info['dirpath'], "img/*.nii.gz")) + names.sort() + label_names = [p.replace("/img/", "/label/").replace("-Image", "-Mask") for p in names] + organize_by_names(names, label_names, meta_info=meta_info) + +def get_NIHPancrease(save_path): + meta_info = { + "dataset_name": "NIHPancrease", + "dataset_id": "53", + "modality": "CT", + "home_path": HOME_PATH, + "dirpath": "53_NIHPancrease", + "save_txt_path": save_path, + } + names = glob.glob(os.path.join(meta_info['home_path'], meta_info['dirpath'], "data/*.nii.gz") ) + names.sort() + label_names = [p.replace("/data/PANCREAS_", "/label/label") for p in names] + organize_by_names(names, label_names, meta_info=meta_info) + +def get_CTPelvic(save_path): + meta_info = { + "dataset_name": "CTPelvic1K", + "dataset_id": "54", + "modality": "CT", + "home_path": HOME_PATH, + "dirpath": "54_CTPelvic1K", + "save_txt_path": save_path, + } + names = [] + names += glob.glob(os.path.join(meta_info['home_path'], meta_info['dirpath'], "CTPelvic1K_dataset1_data/*.nii.gz")) + names += glob.glob(os.path.join(meta_info['home_path'], meta_info['dirpath'], "CTPelvic1K_dataset2_data/*.nii.gz")) + names += glob.glob(os.path.join(meta_info['home_path'], meta_info['dirpath'], "CTPelvic1K_dataset3_data/*.nii.gz")) + names += glob.glob(os.path.join(meta_info['home_path'], meta_info['dirpath'], "CTPelvic1K_dataset4_data/*.nii.gz")) + names += glob.glob(os.path.join(meta_info['home_path'], meta_info['dirpath'], "CTPelvic1K_dataset5_data/*.nii.gz")) + names += glob.glob(os.path.join(meta_info['home_path'], meta_info['dirpath'], "CTPelvic1K_dataset7_data/*.nii.gz")) + names.sort() + # xx_data.nii.gz xx_mask_4label.nii.gz + label_names = [p.replace("_data/", "_mask/").replace("_data.nii.gz", "_mask_4label.nii.gz") for p in names] + organize_by_names(names, label_names, meta_info=meta_info) + +def get_FLARE(save_path): + meta_info = { + "dataset_name": "FLARE", + "dataset_id": "55", + "modality": "CT", + "home_path": HOME_PATH, + "dirpath": "55_FLARE22Train", + "save_txt_path": save_path, + "class": ['liver', 'right kidney', 'spleen', 'pancrease', 'aorta','postcava','right adrenal gland','left darenal gland','gallbladder','esophagus','stomach','duodenum','left kidney'], + } + organize_in_nnunet_style(meta_info=meta_info) + +# def get_HAN(save_path): +# meta_info = { +# "dataset_name": "Head-and-neck", +# "dataset_id": "56", +# "modality": "CT", +# "home_path": HOME_PATH, +# "dirpath": "56_Head-and-Neck-challenge", +# "save_txt_path": save_path, +# } +# names = glob.glob(os.path.join(meta_info['home_path'], meta_info['dirpath'], "data/*.nii.gz")) +# names.sort() +# label_names = [p.replace("/data/", "/label/") for p in names] +# organize_by_names(names, label_names, meta_info=meta_info) + +# def get_StructSeg(save_path): +# meta_info = { +# "dataset_name": "StructSeg2019", +# "dataset_id": "57", +# "modality": "CT", +# "home_path": HOME_PATH, +# "dirpath": "57_StructSeg", +# "save_txt_path": save_path, +# } +# names = glob.glob(os.path.join(meta_info['home_path'], meta_info['dirpath'], "HaN_OAR/data/*")) +# names = [f"{name}/data.nii.gz" for name in names] +# names.sort() +# label_names = [p.replace("/data.nii.gz", "/label.nii.gz") for p in names] +# organize_by_names(names, label_names, meta_info=meta_info) + +def get_CHAOS(save_path): + meta_info = { + "dataset_name": "CHAOS", + "dataset_id": "58", + "modality": "MRI", + "home_path": HOME_PATH, + "dirpath": "58_CHAOST2/chaos_MR_T2_normalized/", + "save_txt_path": save_path, + "class": ["liver", "right kidney", "left kidney", "spleen"], + } + names = glob.glob(os.path.join(meta_info['home_path'], meta_info['dirpath'], "image*.nii.gz")) + names.sort() + label_names = [p.replace("/image_", "/label_") for p in names] + organize_by_names(names, label_names, meta_info=meta_info) + +def get_SABS(save_path): + meta_info = { + "dataset_name": "SABS", # BTCV ? + "dataset_id": "59", + "modality": "CT", + "home_path": HOME_PATH, + "dirpath": "59_SABS/sabs_CT_normalized/", + "save_txt_path": save_path, + "class": ["spleen", "right kidney", "left kidney", "gallbladder", "esophagus", "liver", "stomach", "aorta", "postcava", "portal vein and splenic vein", "pancrease", "right adrenal gland", "left adrenal gland"], + } + names = glob.glob(os.path.join(meta_info['home_path'], meta_info['dirpath'], "image_*.nii.gz")) + names.sort() + label_names = [p.replace("/image_", "/label_") for p in names] + organize_by_names(names, label_names, meta_info=meta_info) + + +def get_Totalseg(save_path): + meta_info = { + "dataset_name": "Totalseg", + "dataset_id": "60", + "modality": "CT", + "home_path": HOME_PATH, + # "dirpath": "nnUNet_raw/Dataset101_Totalseg", + "dirpath": "60_Totalseg", + "save_txt_path": save_path, + "class": [], + } + organize_in_nnunet_style(meta_info=meta_info) + + +def get_WORD(save_path): + meta_info = { + "dataset_name": "WORDs", # BTCV ? + "dataset_id": "07", + "modality": "CT", + "home_path": HOME_PATH, + "dirpath": "07_WORD/WORD-V0.1.0/", + "save_txt_path": save_path, + } + organize_in_style2(meta_info=meta_info) + +def generate_all(): + save_path="./datasets/dataset_list/all_train.txt" + clear_files(save_path) + get_BCV_Abdomen(save_path) + get_AbdomenCT_1K(save_path) + get_AMOS(save_path) + get_MSD(save_path) + # get_ASOCA() + # get_BCV_Cervix() + # # get_NIHPancrease() # bug in data ? + # get_CTPelvic() + # get_FLARE() + # get_SABS() + +def generate_their(): + save_path="./datasets/dataset_list/their_train.txt" + clear_files(save_path) + save_path="./datasets/dataset_list/their_train.txt" + get_BCV_Abdomen(save_path) + get_AbdomenCT_1K(save_path) + get_AMOS(save_path) + get_MSD(save_path) + +def generate_ours(): + save_path="./datasets/dataset_list/ours_train.txt" + get_ASOCA(save_path) + get_BCV_Cervix(save_path) + # get_NIHPancrease() # bug in data ? + get_CTPelvic(save_path) + get_FLARE(save_path) + get_SABS(save_path) + +def generate_alp_dataset(): + save_path = "./datasets/dataset_list/alp_train.txt" + clear_files(save_path) + get_SABS(save_path) + get_CHAOS(save_path) + + +if __name__ == "__main__": + print(__file__) + # generate_alp_dataset() + save_path ="./datasets/dataset_list/totalseg_train.txt" + clear_files(save_path) + get_Totalseg(save_path) + + # save_path="./datasets/dataset_list/word_train.txt" + print("Over") \ No newline at end of file diff --git a/datasets/predict_various_masks.py b/datasets/predict_various_masks.py new file mode 100644 index 0000000..2ecdd75 --- /dev/null +++ b/datasets/predict_various_masks.py @@ -0,0 +1,408 @@ +# from utils.predict_automasks import predict +from segment_anything import sam_model_registry, SamPredictor +from core.automask import SamAutomaticMaskGenerator +from core.trainer import SamLearner +# from datasets.dataset2d import Dataset2D +from einops import rearrange, repeat +import torch +import numpy as np + +from skimage.segmentation import slic +from skimage.util import img_as_float +from skimage import io +import os +from tutils import timer, tdir +from scipy.sparse import csr_matrix +# from datasets.dataset2d import Dataset2D +import cv2 +import torch.nn.functional as F +import time +from tutils import tfilename + +import matplotlib.pylab as plt +from .utils import load_compressed_data, show_anns, img_to_show +import cv2 + +def tfunctime(func): + def run(*argv, **kargs): + t1 = time.time() + ret = func(*argv, **kargs) + t2 = time.time() + print(f"[Function <{func.__name__}>] Running time:{(t2-t1):.6f}s") + return ret + return run + +def get_predictors(): + sam_checkpoint = "/quanquan/code/segment-anything/segment_anything/sam_vit_h_4b8939.pth" # for A100 + # sam_checkpoint = "/home1/quanquan/code/projects/medical-guangdong/segment-anything/sam_vit_h_4b8939.pth" # for server 103 + device = "cuda" + model_type = "default" + + sam = sam_model_registry[model_type](checkpoint=sam_checkpoint) + sam.to(device=device) + + mask_generator = SamAutomaticMaskGenerator(sam) + predictor = SamLearner(sam_model=sam) + return mask_generator, predictor + +def center_clip(img, eta=3): + count_values = img[torch.where(img>-199)] + mean = count_values.mean() + std = count_values.std() + img3 = torch.clip(img, mean-eta*std, mean+eta*std) + + return img3 + + +def find_not_exist_masks(masks_repo, masks_new, iou_threshold=0.2): + if len(masks_repo) == 0: + return masks_new + def compute_iou(m1, m2): + m1 = m1['segmentation'] + m2 = m2['segmentation'] + intersection = m1*m2 + union = np.float32((m1 + m2) > 0) + return intersection.sum() / union.sum() + to_append = [] + for mask_new in masks_new: + assert isinstance(mask_new, dict), f"{__file__} Got{type(mask_new)}" + intersec_count = 0 + for mask_in_repo in masks_repo: + assert isinstance(mask_in_repo, dict), f"{__file__} Got{type(mask_in_repo)}" + iou = compute_iou(mask_in_repo, mask_new) + .3 + if iou > iou_threshold: + intersec_count += 1 + if intersec_count == 0: + to_append.append(mask_new) + check_keys(to_append) + return to_append + + +def merge_masks(masks_repo, masks_new, iou_threshold=0.2): + to_append = find_not_exist_masks(masks_repo, masks_new, iou_threshold) + # print(f"DEBUG: {len(masks_new) - len(to_append)} masks are deleted, remaining {len(to_append)}. The total achieves {len(masks_repo) + len(to_append)}") + return masks_repo + to_append + + +def dilate_erode(mask): + kernel = np.ones((4, 4), dtype=np.uint8) + mask = cv2.morphologyEx(mask.astype(float), cv2.MORPH_CLOSE, kernel, iterations=1) + return mask + + +def get_superpixel(image, hu_mask, hu_threshold=-50): + if isinstance(image, torch.Tensor): + image = image.detach().cpu().numpy() + if isinstance(hu_mask, torch.Tensor): + hu_mask = hu_mask.detach().cpu().numpy() + segments = slic(image, n_segments=100, compactness=9) + mask_collect = [] + image2 = image[:,:,0] + for i in range(1, segments.max()+1): + mask = torch.Tensor(segments==i).detach().cpu().numpy() + # assert img + if image2[segments==i].mean() <= hu_threshold: + continue + mask = mask * hu_mask + mask = dilate_erode(mask) + mask_data = { + 'segmentation':mask, + 'area':mask.sum(), + 'source': "superpixel", + 'mean_value':(mask * image[:,:,0]).mean(), + } + mask_collect.append(mask_data) + check_keys(mask_collect) + return mask_collect + +# def resize_masks(masks): +# for m in masks: +# m['segmentation'] = F.interpolate(torch.Tensor(m['segmentation'])[None,None,:,:], size=(512,512)).squeeze().numpy() +# return masks + +# @tfunctime +def get_masks_via_color_changing(img, label, mask_generator, predictor=None): + img = repeat(img, " h w -> 1 3 h w") + img = torch.Tensor(img) + img = F.interpolate(img, size=(1024,1024)) + + label = torch.Tensor(label) + if label is not None: + label_masks = [] + for i in range(1,label.max().int()+1): + labeli = torch.Tensor(label==i).float() + labeli = F.interpolate(labeli[None, None, :,:], size=(1024,1024)).squeeze().numpy() + area = labeli.sum() + if area <=10: + continue + mask = { + "segmentation":labeli, + "area":area, + "source": "gt", + } + label_masks.append(mask) + else: + label_masks = [] + + mask_generator.reset_image() + predictor.reset_image() + + masks = mask_generator.generate(img.cuda()) + masks = filter_large_masks(masks) + for mask in masks: + mask['source'] = "sam_auto_seg" + label_masks = merge_masks(label_masks, masks) + del masks + + check_keys(label_masks) + # import ipdb; ipdb.set_trace() + # return img, label_masks + + img2 = center_clip(img, 2.5) + masks = mask_generator.generate(img.cuda()) + masks = filter_large_masks(masks) + label_masks = merge_masks(label_masks, masks) + del masks + del img2 + + # img2 = center_clip(img, 2) + # masks = mask_generator.generate(img2.cuda()) + # masks = filter_large_masks(masks) + # for mask in masks: + # mask['source'] = "sam_auto_seg" + # label_masks = merge_masks(label_masks, masks) + # del masks + # del img2 + + img2 = center_clip(img, 1) + masks = mask_generator.generate(img2.cuda()) + masks = filter_large_masks(masks) + for mask in masks: + mask['source'] = "sam_auto_seg" + label_masks = merge_masks(label_masks, masks) + del masks + del img2 + + img2 = center_clip(img, 0.5) + masks = mask_generator.generate(img2.cuda()) + masks = filter_large_masks(masks) + for mask in masks: + mask['source'] = "sam_auto_seg" + label_masks = merge_masks(label_masks, masks) + del masks + del img2 + + check_keys(label_masks) + # import ipdb; ipdb.set_trace() + return img, label_masks + +def check_keys(masks): + for m in masks: + assert m['segmentation'] is not None + assert m['area'] is not None + assert m['source'] is not None + +def filter_large_masks(masks): + filtered_masks = [] + for mask in masks: + if mask['area'] > 0.25 *1024 * 1024: + continue + filtered_masks.append(mask) + del masks + return filtered_masks + + +def mix_masks(masks): + if len(masks) == 0: + return None + mixed_mask = None + for item in masks: + m = item['segmentation'] + mixed_mask = np.zeros_like(m) if mixed_mask is None else mixed_mask + mixed_mask += m + mixed_mask = np.float32(mixed_mask>0) + return mixed_mask + + +def select_random_point_from_mask(gt_mask): + size = gt_mask.shape + assert len(size) == 2 + xy = np.arange(0, size[0] * size[1]) + gt_mask = np.float32(gt_mask>0) + prob = rearrange(gt_mask, "h w -> (h w)") + prob = prob / prob.sum() + loc = np.random.choice(a=xy, size=1, replace=True, p=prob)[0] + x, y = loc % size[1], loc // size[1] + return x, y + + +def select_center_point_from_mask(gt_mask): + # get indices of all the foreground pixels + indices = np.argwhere(gt_mask > 0) + # calculate the center point by taking the mean of the foreground pixel indices + center = np.mean(indices, axis=0).astype(int) + y, x = center + return x, y + +@tfunctime +def get_masks_via_points_from_superpixels(img, label_masks, superpixels, predictor, hu_mask=None): + if isinstance(hu_mask, torch.Tensor): + hu_mask = hu_mask.detach().cpu().numpy() + total_mask = mix_masks(label_masks) + # superpixels = get_superpixel(rearrange(img, "1 c h w -> h w c")) + points = [] + ex_masks = [] + for seg in superpixels: + mask_collect = [] + for i in range(5): + mm = np.nan_to_num(seg['segmentation'], 0) + if mm.sum() <= 0: + continue + x,y = select_center_point_from_mask(mm) # select_random_point_from_mask + # print(x,y) + if total_mask[y,x] == 1: + continue + # else: + # print(total_mask[y,x]) + point = torch.Tensor([x,y]) + points.append(point) + mask = predictor.generate(img.cuda(), point.cuda().unsqueeze(0).unsqueeze(0))[0,0].detach().cpu().numpy() + mask = mask * hu_mask + mask = dilate_erode(mask) + mask = { + "segmentation": mask, + "area": mask.sum(), + "source": "prompt_point_from_superpixel", + } + mask_collect = merge_masks(mask_collect, [mask]) + ex_masks += mask_collect + check_keys(ex_masks) + return ex_masks + +def mask_to_bbox(mask): + """ copied from data_engine """ + # Find the indices of all non-zero elements in the mask + coords = np.nonzero(mask) + + # Compute the minimum and maximum values of the row and column indices + x_min = np.min(coords[1]) + y_min = np.min(coords[0]) + x_max = np.max(coords[1]) + y_max = np.max(coords[0]) + + # Return the coordinates of the bounding box + return (x_min, y_min, x_max, y_max) + +@tfunctime +def get_masks_via_boxes_from_superpixels(img, label_masks, superpixels, predictor, hu_mask=None): + if isinstance(hu_mask, torch.Tensor): + hu_mask = hu_mask.detach().cpu().numpy() + ex_masks = [] + for seg in superpixels: + if seg['segmentation'].sum() < 100: + continue + x_min, y_min, x_max, y_max = mask_to_bbox(seg['segmentation']) + box = torch.Tensor([x_min, y_min, x_max, y_max]) + mask = predictor.generate_by_box(img.cuda(), box.cuda().unsqueeze(0).unsqueeze(0))[0,0].detach().cpu().numpy() + mask = mask * hu_mask + mask = dilate_erode(mask) + mask = { + "segmentation": mask, + "area": mask.sum(), + "source": "prompt_box_from_superpixel", + } + ex_masks = merge_masks(ex_masks, [mask]) + check_keys(ex_masks) + return ex_masks + + +class VariousPseudoMasksGenerator: + def __init__(self, dataset, label_path:str=None) -> None: + self.dataset = dataset + # assert dirpath.split("/")[-1] == "cache_2d", f"{__file__} Got{dirpath.split('/')[-1]} ; dirpath:{dirpath}" + self.label_path = label_path if label_path is not None else dataset.dirpath.replace("cache_2d", "cache_2d_various_pseudo_masks") + + def example(self, mask_generator=None, predictor=None): + return self.generate(mask_generator=mask_generator, predictor=predictor, is_example=True) + + def generate(self, mask_generator=None, predictor=None, is_example=False): + tt = timer() + if mask_generator is None: + mask_generator, predictor = get_predictors() + self.mask_generator, self.predictor = mask_generator, predictor + for img_idx in range(len(self.dataset)): + tt() + data = self.dataset._get_image(img_idx) + masks_all_layers = [] + words = data['name'].split("/") + dataset_name = words[0] + filename = words[-1].replace(".nii.gz","") + volume_rgb = np.clip(data['img'], -200, 400) + volume_rgb = (volume_rgb - volume_rgb.min()) / volume_rgb.max() + for slice_idx in range(data['img'].shape[0]): + path = tfilename(self.label_path, f'{dataset_name}/{filename}_s{slice_idx}_mask.npz') + # if os.path.exists(path): + # continue + masks = self.get_various_masks(data['img'][slice_idx], data['label'][slice_idx], mask_generator, predictor) + self.save_slice_mask(masks, path, slice_idx) + img_path = tfilename(self.label_path, f'image_{dataset_name}/{filename}_s{slice_idx}.jpg') + self.save_img_rgb(volume_rgb[slice_idx], img_path) + # display + # plt.figure(figsize=(6,6)) + # plt.imshow(img_to_show(data['img'][slice_idx]), cmap='gray') + # show_anns(masks) + # plt.axis('off') + # plt.show() + print(f"Save to {img_path} and {path}") + + print(f"Processing {img_idx}, {len(masks_all_layers)} saved, time used:{tt()}", end='\r') + if is_example: + break + + def save_slice_mask(self, masks, path, slice_idx): + masks_data = np.array([m['segmentation'] for m in masks]).astype(int) + # if len(masks_data) <= 1: + # import ipdb; ipdb.set_trace() + masks_data = F.interpolate(torch.Tensor(masks_data).unsqueeze(1), size=(512,512)).squeeze().numpy() + masks_data = np.int8(masks_data>0) + if len(masks_data.shape) == 2: + masks_data = masks_data[None,:,:] + assert masks_data.shape[1:] == (512,512), f"{__file__} Got{masks_data.shape}" + masks_data = rearrange(masks_data, "n h w -> n (h w)") + csr = csr_matrix(masks_data) + np.savez_compressed(path, data=csr.data, indices=csr.indices, indptr=csr.indptr, shape=csr.shape) + + def save_img_rgb(self, img, path): + img = (img * 255).astype(np.uint8) + cv2.imwrite(path, img) + + + @staticmethod + @tfunctime + def get_various_masks(img_ori, label, mask_generator, predictor): + mask_generator.reset_image() + predictor.reset_image() + img, label_masks = get_masks_via_color_changing(img_ori, label, mask_generator, predictor) + + # # return label_masks + # hu_mask = img_ori>img_ori.min() + 10 + # hu_mask = F.interpolate(torch.Tensor(hu_mask)[None,None,:,:], size=(1024,1024)).squeeze() + + # superpixels = get_superpixel(rearrange(img, "1 c h w -> h w c"), hu_mask=hu_mask) + # exclu_superpixels = find_not_exist_masks(label_masks, superpixels) + # # point_prompt_masks = get_masks_via_boxes_from_superpixels(img, label_masks, exclu_superpixels, predictor, hu_mask=hu_mask) + # box_prompt_masks = get_masks_via_points_from_superpixels(img, label_masks, exclu_superpixels, predictor, hu_mask=hu_mask) + + return label_masks # + box_prompt_masks # + point_prompt_masks # + + +if __name__ == "__main__": + # CUDA_VISIBLE_DEVICES=6 python -m datasets.predict_various_masks + from datasets.dataset3d import Dataset3D + dataset = Dataset3D() + gen = VariousPseudoMasksGenerator(dataset=dataset, + label_path=tdir("/quanquan/datasets/all_datasets/various_masks_3/")) + gen.generate() + + \ No newline at end of file diff --git a/modeling/__init__.py b/modeling/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/modeling/__pycache__/__init__.cpython-38.pyc b/modeling/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000..e1ebe84 Binary files /dev/null and b/modeling/__pycache__/__init__.cpython-38.pyc differ diff --git a/modeling/__pycache__/build_sam3d2.cpython-38.pyc b/modeling/__pycache__/build_sam3d2.cpython-38.pyc new file mode 100644 index 0000000..97d9ba2 Binary files /dev/null and b/modeling/__pycache__/build_sam3d2.cpython-38.pyc differ diff --git a/modeling/__pycache__/common.cpython-38.pyc b/modeling/__pycache__/common.cpython-38.pyc new file mode 100644 index 0000000..ea313d8 Binary files /dev/null and b/modeling/__pycache__/common.cpython-38.pyc differ diff --git a/modeling/__pycache__/image_encoder.cpython-38.pyc b/modeling/__pycache__/image_encoder.cpython-38.pyc new file mode 100644 index 0000000..e4ed32f Binary files /dev/null and b/modeling/__pycache__/image_encoder.cpython-38.pyc differ diff --git a/modeling/__pycache__/mask_decoder3d_2.cpython-38.pyc b/modeling/__pycache__/mask_decoder3d_2.cpython-38.pyc new file mode 100644 index 0000000..957ccf0 Binary files /dev/null and b/modeling/__pycache__/mask_decoder3d_2.cpython-38.pyc differ diff --git a/modeling/__pycache__/prompt_encoder3d.cpython-38.pyc b/modeling/__pycache__/prompt_encoder3d.cpython-38.pyc new file mode 100644 index 0000000..d1f317f Binary files /dev/null and b/modeling/__pycache__/prompt_encoder3d.cpython-38.pyc differ diff --git a/modeling/__pycache__/sam3d.cpython-38.pyc b/modeling/__pycache__/sam3d.cpython-38.pyc new file mode 100644 index 0000000..4fed89b Binary files /dev/null and b/modeling/__pycache__/sam3d.cpython-38.pyc differ diff --git a/modeling/__pycache__/transformer.cpython-38.pyc b/modeling/__pycache__/transformer.cpython-38.pyc new file mode 100644 index 0000000..0f4fcdc Binary files /dev/null and b/modeling/__pycache__/transformer.cpython-38.pyc differ diff --git a/modeling/build_sam3d2.py b/modeling/build_sam3d2.py new file mode 100644 index 0000000..beced3d --- /dev/null +++ b/modeling/build_sam3d2.py @@ -0,0 +1,119 @@ +""" + Differences with build_sam3d.py + use mask_decoder3d_2 instead of mask_decoder3d + +""" + +import torch + +from functools import partial + +# from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer +from .image_encoder import ImageEncoderViT +from .mask_decoder3d_2 import MaskDecoder +from .prompt_encoder3d import PromptEncoder +from .sam3d import Sam +from .transformer import TwoWayTransformer + + +def build_sam_vit_h(checkpoint=None): + return _build_sam( + encoder_embed_dim=1280, + encoder_depth=32, + encoder_num_heads=16, + encoder_global_attn_indexes=[7, 15, 23, 31], + checkpoint=checkpoint, + ) + + +build_sam = build_sam_vit_h + + +def build_sam_vit_l(checkpoint=None): + return _build_sam( + encoder_embed_dim=1024, + encoder_depth=24, + encoder_num_heads=16, + encoder_global_attn_indexes=[5, 11, 17, 23], + checkpoint=checkpoint, + ) + + +def build_sam_vit_b(checkpoint=None): + return _build_sam( + encoder_embed_dim=768, + encoder_depth=12, + encoder_num_heads=12, + encoder_global_attn_indexes=[2, 5, 8, 11], + checkpoint=checkpoint, + ) + + +sam_model_registry = { + "default": build_sam_vit_h, + "vit_h": build_sam_vit_h, + "vit_l": build_sam_vit_l, + "vit_b": build_sam_vit_b, +} + + +def _build_sam( + encoder_embed_dim, + encoder_depth, + encoder_num_heads, + encoder_global_attn_indexes, + checkpoint=None, +): + prompt_embed_dim = 256 + image_size = 1024 + vit_patch_size = 16 + image_embedding_size = image_size // vit_patch_size + sam = Sam( + image_encoder=ImageEncoderViT( + depth=encoder_depth, + embed_dim=encoder_embed_dim, + img_size=image_size, + mlp_ratio=4, + norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), + num_heads=encoder_num_heads, + patch_size=vit_patch_size, + qkv_bias=True, + use_rel_pos=True, + global_attn_indexes=encoder_global_attn_indexes, + window_size=14, + out_chans=prompt_embed_dim, + ), + prompt_encoder=PromptEncoder( + embed_dim=prompt_embed_dim, + image_embedding_size=(image_embedding_size, image_embedding_size), + input_image_size=(image_size, image_size), + mask_in_chans=16, + ), + mask_decoder=MaskDecoder( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + ), + pixel_mean=[123.675, 116.28, 103.53], + pixel_std=[58.395, 57.12, 57.375], + ) + sam.eval() + if checkpoint is not None: + with open(checkpoint, "rb") as f: + state_dict = torch.load(f) + sam.load_state_dict(state_dict) + return sam + + +if __name__ == "__main__": + model = build_sam_vit_l() + data = torch.ones((1,3,1024,1024)) + out = model(data, multimask_output=True) + print(out.shape) \ No newline at end of file diff --git a/modeling/common.py b/modeling/common.py new file mode 100644 index 0000000..2bf1523 --- /dev/null +++ b/modeling/common.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn + +from typing import Type + + +class MLPBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + mlp_dim: int, + act: Type[nn.Module] = nn.GELU, + ) -> None: + super().__init__() + self.lin1 = nn.Linear(embedding_dim, mlp_dim) + self.lin2 = nn.Linear(mlp_dim, embedding_dim) + self.act = act() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.lin2(self.act(self.lin1(x))) + + +# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa +# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x diff --git a/modeling/image_encoder.py b/modeling/image_encoder.py new file mode 100644 index 0000000..8aced9e --- /dev/null +++ b/modeling/image_encoder.py @@ -0,0 +1,402 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from typing import Optional, Tuple, Type + +from .common import LayerNorm2d, MLPBlock + + +# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa +class ImageEncoderViT(nn.Module): + def __init__( + self, + img_size: int = 1024, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + out_chans: int = 256, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_abs_pos: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + global_attn_indexes: Tuple[int, ...] = (), + ) -> None: + """ + Args: + img_size (int): Input image size. + patch_size (int): Patch size. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + depth (int): Depth of ViT. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_abs_pos (bool): If True, use absolute positional embeddings. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. + global_attn_indexes (list): Indexes for blocks using global attention. + """ + super().__init__() + self.img_size = img_size + + self.patch_embed = PatchEmbed( + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + in_chans=in_chans, + embed_dim=embed_dim, + ) + + self.pos_embed: Optional[nn.Parameter] = None + if use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. + self.pos_embed = nn.Parameter( + torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) + ) + + self.blocks = nn.ModuleList() + for i in range(depth): + block = Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + act_layer=act_layer, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + window_size=window_size if i not in global_attn_indexes else 0, + input_size=(img_size // patch_size, img_size // patch_size), + ) + self.blocks.append(block) + + self.neck = nn.Sequential( + nn.Conv2d( + embed_dim, + out_chans, + kernel_size=1, + bias=False, + ), + LayerNorm2d(out_chans), + nn.Conv2d( + out_chans, + out_chans, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(out_chans), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.patch_embed(x) + if self.pos_embed is not None: + x = x + self.pos_embed + for blk in self.blocks: + x = blk(x) + + x = self.neck(x.permute(0, 3, 1, 2)) + + return x + + +class Block(nn.Module): + """Transformer blocks with support of window attention and residual propagation blocks""" + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. If it equals 0, then + use global attention. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. + """ + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + input_size=input_size if window_size == 0 else (window_size, window_size), + ) + + self.norm2 = norm_layer(dim) + self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) + + self.window_size = window_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x + x = self.norm1(x) + # Window partition + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + + x = self.attn(x) + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + + x = shortcut + x + x = x + self.mlp(self.norm2(x)) + + return x + + +class Attention(nn.Module): + """Multi-head Attention block with relative position embeddings.""" + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. + """ + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.use_rel_pos = use_rel_pos + if self.use_rel_pos: + assert ( + input_size is not None + ), "Input size must be provided if using relative positional encoding." + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, _ = x.shape + # qkv with shape (3, B, nHead, H * W, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + # q, k, v with shape (B * nHead, H * W, C) + q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) + + attn = (q * self.scale) @ k.transpose(-2, -1) + + if self.use_rel_pos: + attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) + + attn = attn.softmax(dim=-1) + x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) + x = self.proj(x) + + return x + + +def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. + (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows, (Hp, Wp) + + +def window_unpartition( + windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] +) -> torch.Tensor: + """ + Window unpartition into original sequences and removing padding. + Args: + windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def add_decomposed_rel_pos( + attn: torch.Tensor, + q: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], +) -> torch.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 + Args: + attn (Tensor): attention map. + q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). + rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. + rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. + q_size (Tuple): spatial sequence size of query q with (q_h, q_w). + k_size (Tuple): spatial sequence size of key k with (k_h, k_w). + + Returns: + attn (Tensor): attention map with added relative positional embeddings. + """ + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + + attn = ( + attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] + ).view(B, q_h * q_w, k_h * k_w) + + return attn + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, + kernel_size: Tuple[int, int] = (16, 16), + stride: Tuple[int, int] = (16, 16), + padding: Tuple[int, int] = (0, 0), + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + """ + super().__init__() + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x + + +if __name__ == "__main__": + model = ImageEncoderViT() + data = torch.ones(1,3,1024,1024) + print(model) + model(data) + \ No newline at end of file diff --git a/modeling/mask_decoder3d_2.py b/modeling/mask_decoder3d_2.py new file mode 100644 index 0000000..a4f25e2 --- /dev/null +++ b/modeling/mask_decoder3d_2.py @@ -0,0 +1,204 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn +from torch.nn import functional as F +from einops import rearrange +from typing import List, Tuple, Type + +from .common import LayerNorm2d + + +class MaskDecoder(nn.Module): + def __init__( + self, + *, + transformer_dim: int, + transformer: nn.Module, + num_multimask_outputs: int = 3, + activation: Type[nn.Module] = nn.GELU, + iou_head_depth: int = 3, + iou_head_hidden_dim: int = 256, + num_slices: int = 3, + ) -> None: + """ + Predicts masks given an image and prompt embeddings, using a + transformer architecture. + + Arguments: + transformer_dim (int): the channel dimension of the transformer + transformer (nn.Module): the transformer used to predict masks + num_multimask_outputs (int): the number of masks to predict + when disambiguating masks + activation (nn.Module): the type of activation to use when + upscaling masks + iou_head_depth (int): the depth of the MLP used to predict + mask quality + iou_head_hidden_dim (int): the hidden dimension of the MLP + used to predict mask quality + """ + super().__init__() + self.transformer_dim = transformer_dim + self.transformer = transformer + + self.num_multimask_outputs = num_multimask_outputs + + self.iou_token = nn.Embedding(1, transformer_dim) + self.num_mask_tokens = num_multimask_outputs + 1 + self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), + LayerNorm2d(transformer_dim // 4), + activation(), + nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), + activation(), + ) + self.output_upscaling2 = nn.Sequential( + nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), + LayerNorm2d(transformer_dim // 4), + activation(), + nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), + activation(), + ) + self.output_upscaling3 = nn.Sequential( + nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), + LayerNorm2d(transformer_dim // 4), + activation(), + nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), + activation(), + ) + self.output_hypernetworks_mlps = nn.ModuleList( + [ + MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) + for i in range(self.num_mask_tokens) + ] + ) + + self.iou_prediction_head = MLP( + transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth + ) + + def forward( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Arguments: + image_embeddings (torch.Tensor): the embeddings from the image encoder + image_pe (torch.Tensor): positional encoding with the shape of image_embeddings + sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes + dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs + multimask_output (bool): Whether to return multiple masks or a single + mask. + + Returns: + torch.Tensor: batched predicted masks + torch.Tensor: batched predictions of mask quality + """ + masks, iou_pred = self.predict_masks( + image_embeddings=image_embeddings, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + dense_prompt_embeddings=dense_prompt_embeddings, + ) + + # Select the correct mask or masks for output + if multimask_output: + # mask_slice = slice(1, None) + masks = masks[:, 3:, :, :] + iou_pred = iou_pred[:, 1:] + else: + # mask_slice = slice(0, 1) + masks = masks[:, :3, :, :] + iou_pred = iou_pred[:, 0:1] + + # Prepare output + return masks, iou_pred + + def predict_masks( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Predicts masks. See 'forward' for more details.""" + # Concatenate output tokens + output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) # (1,5,256) + output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) # (1,7,256) + + # Expand per-image data in batch direction to be per-mask + # TODO: Why this code??? + # src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) + src = image_embeddings + src = src + dense_prompt_embeddings + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + b, c, h, w = src.shape + + # Run the transformer + hs, src = self.transformer(src, pos_src, tokens) + iou_token_out = hs[:, 0, :] + mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + src = src.transpose(1, 2).view(b, c, h, w) + upscaled_embedding = self.output_upscaling(src) + upscaled_embedding2 = self.output_upscaling2(src) + upscaled_embedding3 = self.output_upscaling3(src) + + hyper_in_list: List[torch.Tensor] = [] + for i in range(self.num_mask_tokens): + hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) + hyper_in = torch.stack(hyper_in_list, dim=1) + b, c, h, w = upscaled_embedding.shape + masks1 = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) + masks2 = (hyper_in @ upscaled_embedding2.view(b, c, h * w)).view(b, -1, h, w) + masks3 = (hyper_in @ upscaled_embedding3.view(b, c, h * w)).view(b, -1, h, w) + masks = torch.stack([masks1, masks2, masks3], axis=2) + assert masks.shape[1] == 4 and masks.shape[2] == 3, f"Got {masks.shape}" + masks = rearrange(masks, "b c1 c2 h w -> b (c1 c2) h w") + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + + # masks = rearrange(masks, "b (c1 c2) h w -> b c1 c2 h w", c1=4, c2=3) + return masks, iou_pred + + +# Lightly adapted from +# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa +class MLP(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + sigmoid_output: bool = False, + ) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + self.sigmoid_output = sigmoid_output + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.sigmoid_output: + x = F.sigmoid(x) + return x diff --git a/modeling/prompt_encoder3d.py b/modeling/prompt_encoder3d.py new file mode 100644 index 0000000..e4df03e --- /dev/null +++ b/modeling/prompt_encoder3d.py @@ -0,0 +1,215 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torch import nn + +from typing import Any, Optional, Tuple, Type + +from .common import LayerNorm2d + + +class PromptEncoder(nn.Module): + def __init__( + self, + embed_dim: int, + image_embedding_size: Tuple[int, int], + input_image_size: Tuple[int, int], + mask_in_chans: int, + activation: Type[nn.Module] = nn.GELU, + ) -> None: + """ + Encodes prompts for input to SAM's mask decoder. + + Arguments: + embed_dim (int): The prompts' embedding dimension + image_embedding_size (tuple(int, int)): The spatial size of the + image embedding, as (H, W). + input_image_size (int): The padded size of the image as input + to the image encoder, as (H, W). + mask_in_chans (int): The number of hidden channels used for + encoding input masks. + activation (nn.Module): The activation to use when encoding + input masks. + """ + super().__init__() + self.embed_dim = embed_dim + self.input_image_size = input_image_size + self.image_embedding_size = image_embedding_size + self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) + + self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners + point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] + self.point_embeddings = nn.ModuleList(point_embeddings) + self.not_a_point_embed = nn.Embedding(1, embed_dim) + + self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) + self.mask_downscaling = nn.Sequential( + # nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), + nn.Conv2d(3, mask_in_chans // 4, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans // 4), + activation(), + nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans), + activation(), + nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), + ) + self.no_mask_embed = nn.Embedding(1, embed_dim) + + def get_dense_pe(self) -> torch.Tensor: + """ + Returns the positional encoding used to encode point prompts, + applied to a dense set of points the shape of the image encoding. + + Returns: + torch.Tensor: Positional encoding with shape + 1x(embed_dim)x(embedding_h)x(embedding_w) + """ + return self.pe_layer(self.image_embedding_size).unsqueeze(0) + + def _embed_points( + self, + points: torch.Tensor, + labels: torch.Tensor, + pad: bool, + ) -> torch.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) + padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) + points = torch.cat([points, padding_point], dim=1) + labels = torch.cat([labels, padding_label], dim=1) + point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) + point_embedding[labels == -1] = 0.0 + point_embedding[labels == -1] += self.not_a_point_embed.weight + point_embedding[labels == 0] += self.point_embeddings[0].weight + point_embedding[labels == 1] += self.point_embeddings[1].weight + return point_embedding + + def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: + """Embeds box prompts.""" + boxes = boxes + 0.5 # Shift to center of pixel + coords = boxes.reshape(-1, 2, 2) + corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) + corner_embedding[:, 0, :] += self.point_embeddings[2].weight + corner_embedding[:, 1, :] += self.point_embeddings[3].weight + return corner_embedding + + def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: + """Embeds mask inputs.""" + mask_embedding = self.mask_downscaling(masks) + return mask_embedding + + def _get_batch_size( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> int: + """ + Gets the batch size of the output given the batch size of the input prompts. + """ + if points is not None: + return points[0].shape[0] + elif boxes is not None: + return boxes.shape[0] + elif masks is not None: + return masks.shape[0] + else: + return 1 + + def _get_device(self) -> torch.device: + return self.point_embeddings[0].weight.device + + def forward( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense + embeddings. + + Arguments: + points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates + and labels to embed. + boxes (torch.Tensor or none): boxes to embed + masks (torch.Tensor or none): masks to embed + + Returns: + torch.Tensor: sparse embeddings for the points and boxes, with shape + BxNx(embed_dim), where N is determined by the number of input points + and boxes. + torch.Tensor: dense embeddings for the masks, in the shape + Bx(embed_dim)x(embed_H)x(embed_W) + """ + bs = self._get_batch_size(points, boxes, masks) + sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) + if points is not None: + coords, labels = points + point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) + sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) + if boxes is not None: + box_embeddings = self._embed_boxes(boxes) + sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) + + if masks is not None: + dense_embeddings = self._embed_masks(masks) + else: + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) + + return sparse_embeddings, dense_embeddings + + +class PositionEmbeddingRandom(nn.Module): + """ + Positional encoding using random spatial frequencies. + """ + + def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: + super().__init__() + if scale is None or scale <= 0.0: + scale = 1.0 + self.register_buffer( + "positional_encoding_gaussian_matrix", + scale * torch.randn((2, num_pos_feats)), + ) + + def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: + """Positionally encode points that are normalized to [0,1].""" + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coords = 2 * coords - 1 + coords = coords @ self.positional_encoding_gaussian_matrix + coords = 2 * np.pi * coords + # outputs d_1 x ... x d_n x C shape + return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) + + def forward(self, size: Tuple[int, int]) -> torch.Tensor: + """Generate positional encoding for a grid of the specified size.""" + h, w = size + device: Any = self.positional_encoding_gaussian_matrix.device + grid = torch.ones((h, w), device=device, dtype=torch.float32) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / h + x_embed = x_embed / w + + pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) + return pe.permute(2, 0, 1) # C x H x W + + def forward_with_coords( + self, coords_input: torch.Tensor, image_size: Tuple[int, int] + ) -> torch.Tensor: + """Positionally encode points that are not normalized to [0,1].""" + coords = coords_input.clone() + coords[:, :, 0] = coords[:, :, 0] / image_size[1] + coords[:, :, 1] = coords[:, :, 1] / image_size[0] + return self._pe_encoding(coords.to(torch.float)) # B x N x C diff --git a/modeling/sam3d.py b/modeling/sam3d.py new file mode 100644 index 0000000..c388c22 --- /dev/null +++ b/modeling/sam3d.py @@ -0,0 +1,174 @@ + # Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn +from torch.nn import functional as F + +from typing import Any, Dict, List, Tuple + +from .image_encoder import ImageEncoderViT +from .mask_decoder3d_2 import MaskDecoder +from .prompt_encoder3d import PromptEncoder + + +class Sam(nn.Module): + mask_threshold: float = 0.0 + image_format: str = "RGB" + + def __init__( + self, + image_encoder: ImageEncoderViT, + prompt_encoder: PromptEncoder, + mask_decoder: MaskDecoder, + pixel_mean: List[float] = [123.675, 116.28, 103.53], + pixel_std: List[float] = [58.395, 57.12, 57.375], + ) -> None: + """ + SAM predicts object masks from an image and input prompts. + + Arguments: + image_encoder (ImageEncoderViT): The backbone used to encode the + image into image embeddings that allow for efficient mask prediction. + prompt_encoder (PromptEncoder): Encodes various types of input prompts. + mask_decoder (MaskDecoder): Predicts masks from the image embeddings + and encoded prompts. + pixel_mean (list(float)): Mean values for normalizing pixels in the input image. + pixel_std (list(float)): Std values for normalizing pixels in the input image. + """ + super().__init__() + self.image_encoder = image_encoder + self.prompt_encoder = prompt_encoder + self.mask_decoder = mask_decoder + self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) + self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) + + @property + def device(self) -> Any: + return self.pixel_mean.device + + @torch.no_grad() + def forward( + self, + batched_input: List[Dict[str, Any]], + multimask_output: bool, + ) -> List[Dict[str, torch.Tensor]]: + """ + Predicts masks end-to-end from provided images and prompts. + If prompts are not known in advance, using SamPredictor is + recommended over calling the model directly. + + Arguments: + batched_input (list(dict)): A list over input images, each a + dictionary with the following keys. A prompt key can be + excluded if it is not present. + 'image': The image as a torch tensor in 3xHxW format, + already transformed for input to the model. + 'original_size': (tuple(int, int)) The original size of + the image before transformation, as (H, W). + 'point_coords': (torch.Tensor) Batched point prompts for + this image, with shape BxNx2. Already transformed to the + input frame of the model. + 'point_labels': (torch.Tensor) Batched labels for point prompts, + with shape BxN. + 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. + Already transformed to the input frame of the model. + 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, + in the form Bx1xHxW. + multimask_output (bool): Whether the model should predict multiple + disambiguating masks, or return a single mask. + + Returns: + (list(dict)): A list over input images, where each element is + as dictionary with the following keys. + 'masks': (torch.Tensor) Batched binary mask predictions, + with shape BxCxHxW, where B is the number of input prompts, + C is determined by multimask_output, and (H, W) is the + original size of the image. + 'iou_predictions': (torch.Tensor) The model's predictions + of mask quality, in shape BxC. + 'low_res_logits': (torch.Tensor) Low resolution logits with + shape BxCxHxW, where H=W=256. Can be passed as mask input + to subsequent iterations of prediction. + """ + input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) + image_embeddings = self.image_encoder(input_images) + + outputs = [] + for image_record, curr_embedding in zip(batched_input, image_embeddings): + if "point_coords" in image_record: + points = (image_record["point_coords"], image_record["point_labels"]) + else: + points = None + sparse_embeddings, dense_embeddings = self.prompt_encoder( + points=points, + boxes=image_record.get("boxes", None), + masks=image_record.get("mask_inputs", None), + ) + low_res_masks, iou_predictions = self.mask_decoder( + image_embeddings=curr_embedding.unsqueeze(0), + image_pe=self.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + masks = self.postprocess_masks( + low_res_masks, + input_size=image_record["image"].shape[-2:], + original_size=image_record["original_size"], + ) + masks = masks > self.mask_threshold + outputs.append( + { + "masks": masks, + "iou_predictions": iou_predictions, + "low_res_logits": low_res_masks, + } + ) + return outputs + + def postprocess_masks( + self, + masks: torch.Tensor, + input_size: Tuple[int, ...], + original_size: Tuple[int, ...], + ) -> torch.Tensor: + """ + Remove padding and upscale masks to the original image size. + + Arguments: + masks (torch.Tensor): Batched masks from the mask_decoder, + in BxCxHxW format. + input_size (tuple(int, int)): The size of the image input to the + model, in (H, W) format. Used to remove padding. + original_size (tuple(int, int)): The original size of the image + before resizing for input to the model, in (H, W) format. + + Returns: + (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) + is given by original_size. + """ + masks = F.interpolate( + masks, + (self.image_encoder.img_size, self.image_encoder.img_size), + mode="bilinear", + align_corners=False, + ) + masks = masks[..., : input_size[0], : input_size[1]] + masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) + return masks + + def preprocess(self, x: torch.Tensor) -> torch.Tensor: + """Normalize pixel values and pad to a square input.""" + # Normalize colors + x = (x - self.pixel_mean) / self.pixel_std + + # Pad + h, w = x.shape[-2:] + padh = self.image_encoder.img_size - h + padw = self.image_encoder.img_size - w + x = F.pad(x, (0, padw, 0, padh)) + return x diff --git a/modeling/transformer.py b/modeling/transformer.py new file mode 100644 index 0000000..28fafea --- /dev/null +++ b/modeling/transformer.py @@ -0,0 +1,240 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import Tensor, nn + +import math +from typing import Tuple, Type + +from .common import MLPBlock + + +class TwoWayTransformer(nn.Module): + def __init__( + self, + depth: int, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + ) -> None: + """ + A transformer decoder that attends to an input image using + queries whose positional embedding is supplied. + + Args: + depth (int): number of layers in the transformer + embedding_dim (int): the channel dimension for the input embeddings + num_heads (int): the number of heads for multihead attention. Must + divide embedding_dim + mlp_dim (int): the channel dimension internal to the MLP block + activation (nn.Module): the activation to use in the MLP block + """ + super().__init__() + self.depth = depth + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.mlp_dim = mlp_dim + self.layers = nn.ModuleList() + + for i in range(depth): + self.layers.append( + TwoWayAttentionBlock( + embedding_dim=embedding_dim, + num_heads=num_heads, + mlp_dim=mlp_dim, + activation=activation, + attention_downsample_rate=attention_downsample_rate, + skip_first_layer_pe=(i == 0), + ) + ) + + self.final_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm_final_attn = nn.LayerNorm(embedding_dim) + + def forward( + self, + image_embedding: Tensor, + image_pe: Tensor, + point_embedding: Tensor, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + image_embedding (torch.Tensor): image to attend to. Should be shape + B x embedding_dim x h x w for any h and w. + image_pe (torch.Tensor): the positional encoding to add to the image. Must + have the same shape as image_embedding. + point_embedding (torch.Tensor): the embedding to add to the query points. + Must have shape B x N_points x embedding_dim for any N_points. + + Returns: + torch.Tensor: the processed point_embedding + torch.Tensor: the processed image_embedding + """ + # BxCxHxW -> BxHWxC == B x N_image_tokens x C + bs, c, h, w = image_embedding.shape + image_embedding = image_embedding.flatten(2).permute(0, 2, 1) + image_pe = image_pe.flatten(2).permute(0, 2, 1) + + # Prepare queries + queries = point_embedding + keys = image_embedding + + # Apply transformer blocks and final layernorm + for layer in self.layers: + queries, keys = layer( + queries=queries, + keys=keys, + query_pe=point_embedding, + key_pe=image_pe, + ) + + # Apply the final attention layer from the points to the image + q = queries + point_embedding + k = keys + image_pe + attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm_final_attn(queries) + + return queries, keys + + +class TwoWayAttentionBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + num_heads: int, + mlp_dim: int = 2048, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False, + ) -> None: + """ + A transformer block with four layers: (1) self-attention of sparse + inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp + block on sparse inputs, and (4) cross attention of dense inputs to sparse + inputs. + + Arguments: + embedding_dim (int): the channel dimension of the embeddings + num_heads (int): the number of heads in the attention layers + mlp_dim (int): the hidden dimension of the mlp block + activation (nn.Module): the activation of the mlp block + skip_first_layer_pe (bool): skip the PE on the first layer + """ + super().__init__() + self.self_attn = Attention(embedding_dim, num_heads) + self.norm1 = nn.LayerNorm(embedding_dim) + + self.cross_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm2 = nn.LayerNorm(embedding_dim) + + self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) + self.norm3 = nn.LayerNorm(embedding_dim) + + self.norm4 = nn.LayerNorm(embedding_dim) + self.cross_attn_image_to_token = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward( + self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor + ) -> Tuple[Tensor, Tensor]: + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(q=queries, k=queries, v=queries) + else: + q = queries + query_pe + attn_out = self.self_attn(q=q, k=q, v=queries) + queries = queries + attn_out + queries = self.norm1(queries) + + # Cross attention block, tokens attending to image embedding + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.norm3(queries) + + # Cross attention block, image embedding attending to tokens + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) + keys = keys + attn_out + keys = self.norm4(keys) + + return queries, keys + + +class Attention(nn.Module): + """ + An attention layer that allows for downscaling the size of the embedding + after projection to queries, keys, and values. + """ + + def __init__( + self, + embedding_dim: int, + num_heads: int, + downsample_rate: int = 1, + ) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.internal_dim = embedding_dim // downsample_rate + self.num_heads = num_heads + assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." + + self.q_proj = nn.Linear(embedding_dim, self.internal_dim) + self.k_proj = nn.Linear(embedding_dim, self.internal_dim) + self.v_proj = nn.Linear(embedding_dim, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, embedding_dim) + + def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: + b, n, c = x.shape + x = x.reshape(b, n, num_heads, c // num_heads) + return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head + + def _recombine_heads(self, x: Tensor) -> Tensor: + b, n_heads, n_tokens, c_per_head = x.shape + x = x.transpose(1, 2) + return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Attention + _, _, _, c_per_head = q.shape + attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens + attn = attn / math.sqrt(c_per_head) + attn = torch.softmax(attn, dim=-1) + + # Get output + out = attn @ v + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out diff --git a/readme.md b/readme.md new file mode 100644 index 0000000..2f03a92 --- /dev/null +++ b/readme.md @@ -0,0 +1,19 @@ + +# Slide-SAM: Medical SAM meets sliding window + + +## Training +prepare datasets +``` +python -m datasets.generate_txt +``` + +cache 3d data into slices +``` +python -m datasets.cache_datasets3d +``` + +run training +``` +CUDA_VISIBLE_DEVICES=0,1,2,3 python -m core.ddp --tag debug +``` \ No newline at end of file diff --git a/test/__init__.py b/test/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/volume_eval.py b/test/volume_eval.py new file mode 100644 index 0000000..d18ea18 --- /dev/null +++ b/test/volume_eval.py @@ -0,0 +1,173 @@ +""" + Volume evalutaion + +""" +import torch +import numpy as np +from torch.utils.data import DataLoader +# from datasets.dataset3d import Dataset3D +from tutils.new.manager import ConfigManager +from datasets.eval_dataloader.loader_abstract import AbstractLoader + +from core.volume_predictor import VolumePredictor +from datasets.data_engine import DataManager, BoxPromptGenerator, PointPromptGenerator + +from tutils import tfilename +from tutils.new.trainer.recorder import Recorder +from trans_utils.metrics import compute_dice_np +from trans_utils.data_utils import Data3dSolver + + + +class Evaluater: + def __init__(self, config) -> None: + self.config = config + self.recorder = Recorder() + + def solve(self, model, dataset): + # model.eval() + self.predictor = model + dataloader = DataLoader(dataset, batch_size=1, shuffle=False) + + for i, data in enumerate(dataloader): + # if i <4: + # print + # continue + # for k, v in data.items(): + # if isinstance(v, torch.Tensor): + # data[k] = v.to(self.rank) + if self.config['dataset']['prompt'] == 'box': + res = self.eval_step(data, batch_idx=i) + if self.config['dataset']['prompt'] == 'point': + res = self.eval_step_point(data, batch_idx=i) + self.recorder.record(res) + res = self.recorder.cal_metrics() + print(res) + print("prompt:", self.config['dataset']['prompt'], " class_idx:", self.config['dataset']['label_idx']) + + def eval_step(self, data, batch_idx=0): + name = data['name'] + dataset_name = data['dataset_name'][0] + label_idx = data['label_idx'][0] + template_slice_id = data['template_slice_id'][0] + + assert data['img'].shape[1] >= 3, f" Got img.shape {data['img'].shape}" + if template_slice_id == 0: + template_slice_id += 1 + elif template_slice_id == (data['img'].shape[0] - 1): + template_slice_id -= 1 + + spacing = data['spacing'].numpy().tolist()[0] + if data['img'].shape[-1] < 260: + # assert data['img'].shape[-1] < 260, f"Got {data['img'].shape}" + img = data['img'][0][:,:256,:256] + label = data['label'][0][:,:256,:256] + else: + img = data['img'][0] + label = data['label'][0] + # img = torch.clip(img, -200, 600) + box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id].detach().cpu().numpy()) + box = np.array([box]) + pred, stability = self.predictor.predict_volume( + x=img, + box=box, + template_slice_id=template_slice_id, + return_stability=True, + ) + prompt_type = 'box' + dice = compute_dice_np(pred, label.detach().cpu().numpy()) + Data3dSolver().simple_write(pred, path=tfilename(f"visual/{dataset_name}/pred_{batch_idx}_label_{label_idx}_{prompt_type}.nii.gz"), spacing=spacing) + Data3dSolver().simple_write(label.detach().cpu().numpy(), path=tfilename(f"visual/{dataset_name}/label_{batch_idx}.nii.gz")) + # Data3dSolver().simple_write(img.detach().cpu().numpy(), path=tfilename(f"visual/{dataset_name}/img_{batch_idx}.nii.gz")) + # np.save(tfilename(f"meta/{dataset_name}/stability_{batch_idx}.npy"), stability) + print(dataset_name, name, dice) + return {"dice": dice} + + def eval_step_point(self, data, batch_idx=0): + name = data['name'] + dataset_name = data['dataset_name'][0] + label_idx = data['label_idx'][0] + template_slice_id = data['template_slice_id'][0] + spacing = data['spacing'].numpy().tolist()[0] + + assert data['img'].shape[1] >= 3, f" Got img.shape {data['img'].shape}" + if template_slice_id == 0: + template_slice_id += 1 + elif template_slice_id == (data['img'].shape[0] - 1): + template_slice_id -= 1 + + if data['img'].shape[-1] < 260: + # assert data['img'].shape[-1] < 260, f"Got {data['img'].shape}" + img = data['img'][0][:,:256,:256] + label = data['label'][0][:,:256,:256] + else: + img = data['img'][0] + label = data['label'][0] + + box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id].detach().cpu().numpy()) + point = (box[0]+box[2])*0.5 , (box[1]+box[3])*0.5 + point = np.array([point]).astype(int) + if label[template_slice_id][point[0,1], point[0,0]] == 0: + print("Use random point instead !!!") + point = PointPromptGenerator().get_prompt_point(label[template_slice_id]) + point = np.array([point]).astype(int) + # box = np.array([box]) + pred = self.predictor.predict_volume( + x=img, + point_coords=point, + point_labels=np.ones_like(point)[:,:1], + template_slice_id=template_slice_id, + ) + dice = compute_dice_np(pred, label.detach().cpu().numpy()) + prompt_type = 'point' + Data3dSolver().simple_write(pred, path=tfilename(f"visual/{dataset_name}/pred_{batch_idx}_label_{label_idx}_{prompt_type}.nii.gz"), spacing=spacing) + # Data3dSolver().simple_write(pred, path=tfilename(f"visual/{dataset_name}/pred_{batch_idx}.nii.gz")) + print(dataset_name, name, dice) + return {"dice": dice} + +def to_RGB(img): + pass + +if __name__ == "__main__": + # from core.learner3 import SamLearner + # from modeling.build_sam3d import sam_model_registry + + from core.learner3 import SamLearner + from modeling.build_sam3d2 import sam_model_registry + + EX_CONFIG = { + 'dataset':{ + 'prompt': 'box', + 'dataset_list': ['word'], # ["sabs"], chaos, word + 'label_idx': 1, + } + } + + config = ConfigManager() + config.add_config("configs/vit_b_103.yaml") + config.add_config(EX_CONFIG) + + # Init Model + model_type = "vit_b" + sam = sam_model_registry[model_type](checkpoint=None) + learner = SamLearner(sam_model=sam, config=config, data_engine=DataManager(img_size=(1024,1024))) + learner.use_lora() + # pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b3/lora+edge2/ckpt_v/model_latest.pth" + # pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b3/lora+edge2/ckpt/model_epoch_20.pth" + # pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b3/lora+edge2/ckpt/model_epoch_16.pth" + # pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b3/lora_small/ckpt/model_epoch_6.pth" + # pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b9/lora/ckpt/model_epoch_50.pth" + # pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b11/spec_8/ckpt_v/model_latest.pth" + # pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b11/spec_5/ckpt/model_epoch_100.pth" + pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b9/lora3/ckpt/model_iter_360000.pth" + # pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b9/lora3/ckpt/model_iter_500000.pth" + learner.load_well_trained_model(pth) + learner.cuda() + predictor = VolumePredictor( + model=learner.model, + use_postprocess=True, + use_noise_remove=True,) + + solver = Evaluater(config) + dataset = AbstractLoader(config['dataset'], split="test") + solver.solve(predictor, dataset) \ No newline at end of file diff --git a/trans_utils/__init__.py b/trans_utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/trans_utils/__pycache__/__init__.cpython-38.pyc b/trans_utils/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000..2ee66e7 Binary files /dev/null and b/trans_utils/__pycache__/__init__.cpython-38.pyc differ diff --git a/trans_utils/__pycache__/trainer_ddp.cpython-38.pyc b/trans_utils/__pycache__/trainer_ddp.cpython-38.pyc new file mode 100644 index 0000000..f69322e Binary files /dev/null and b/trans_utils/__pycache__/trainer_ddp.cpython-38.pyc differ diff --git a/trans_utils/data_utils.py b/trans_utils/data_utils.py new file mode 100644 index 0000000..f8e8ce6 --- /dev/null +++ b/trans_utils/data_utils.py @@ -0,0 +1,97 @@ +import torch +import numpy as np +from tutils import tfilename +from tutils.nn.data import read, itk_to_np, np_to_itk, write +from torchvision.utils import save_image +import SimpleITK as sitk +from tutils.nn.data.tsitk.preprocess import resampleImage + + +class Data3dSolver: + def __init__(self) -> None: + pass + + def simple_write(self, data_np, path="tmp.nii.gz", spacing=None): + assert len(data_np.shape) == 3, f"Got {data_np.shape}" + data_np = data_np.astype(np.int16) + data_itk = np_to_itk(data_np) + if spacing is not None: + data_itk.SetSpacing(spacing) + write(data_itk, path=tfilename(path)) + print("Save to ", path) + + def write_slices(self, data, path="tmp_masks.jpg"): + if isinstance(data, torch.Tensor): + pass + if isinstance(data, np.ndarray): + data = torch.Tensor(data) + assert len(data.shape) == 4, f"Shape should be (b c h w) c=1/3, Got {data.shape}" + assert data.shape[1] == 1 or data.shape[1] == 3, f"Shape should be (b c h w) c=1/3, Got {data.shape}" + assert path.endswith(".jpg") or path.endswith(".png") + save_image(torch.Tensor(data).unsqueeze(1), tfilename(path)) + print("Save to ", path) + + def write_multilabel_nii(self, data, path, meta=None): + if isinstance(data, dict): + data_all = [v for k,v in data.items()] + data = np.stack(data_all, axis=0) + assert len(data.shape) == 4, f"Shape should be (b c h w) , Got {data.shape}" + # Merge labels to one + merged = np.zeros_like(data[0]) + for i, datai in enumerate(data): + merged = np.where(datai > 0, datai * (i+1), merged) + + merged = merged.astype(np.int16) + data_itk = np_to_itk(merged) + if meta is not None: + data_itk = formalize(data_itk, meta) + write(data_itk, path=tfilename(path)) + print("Save to ", path) + + def fwrite(self, data, path, meta): + data = data.astype(np.int16) + data_itk = np_to_itk(data) + data_itk = formalize(data_itk, meta) + write(data_itk, path=tfilename(path)) + + def read(self, path, spacing_norm=True): + data_itk = read(path) + if spacing_norm: + ori_size = data_itk.GetSize() + ori_spacing = data_itk.GetSpacing() + data_itk = self.normalize_spacing(data_itk) + new_size = data_itk.GetSize() + new_spacing = data_itk.GetSpacing() + print("Change size from ", ori_size, new_size) + print("Change spacing from ", ori_spacing, new_spacing) + data_np = itk_to_np(data_itk) + print("[data_utils.DEBUG]", data_np.shape) + return data_np, data_itk.GetSpacing() + + def normalize_spacing(self, data_itk): + spacing = data_itk.GetSpacing() + new_spacing = (min(spacing),min(spacing),min(spacing)) + data_itk = resampleImage(data_itk, NewSpacing=new_spacing) + return data_itk + + +def formalize(img:sitk.SimpleITK.Image, meta:sitk.SimpleITK.Image): + # Size = meta.GetSize() + Spacing = meta.GetSpacing() + Origin = meta.GetOrigin() + Direction = meta.GetDirection() + + img.SetSpacing(Spacing) + img.SetOrigin(Origin) + img.SetDirection(Direction) + return img + + +def write(img:sitk.SimpleITK.Image, path:str, mode:str="nifti"): + """ + Path: (example) os.path.join(jpg_dir, f"trans_{random_name}.nii.gz") + """ + mode = mode.lower() + writer = sitk.ImageFileWriter() + writer.SetFileName(path) + writer.Execute(img) \ No newline at end of file diff --git a/trans_utils/metrics.py b/trans_utils/metrics.py new file mode 100644 index 0000000..104aae2 --- /dev/null +++ b/trans_utils/metrics.py @@ -0,0 +1,24 @@ +import numpy as np + + + +def compute_dice_np(pred_mask, gt_mask): + """ numpy values + """ + assert gt_mask.max() == 1, f"Got gt_mask.max():{gt_mask.max()} Error!!" + pred_mask = np.array(pred_mask>0) + gt_mask = np.array(gt_mask>0) + intersection = np.array(pred_mask * gt_mask).sum() + union = pred_mask.sum() + gt_mask.sum() + dice = intersection * 2 / union # if union > 0 else 0 + return dice + + +def compute_prec_np(pred_mask, gt_mask): + true_pos = (np.int32(pred_mask>0) * np.int32(gt_mask>0)).sum() + return true_pos / np.int32(pred_mask>0).sum() + +def compute_recall_np(pred_mask, gt_mask): + true_pos = (np.int32(pred_mask>0) * np.int32(gt_mask>0)).sum() + false_neg = ((gt_mask - pred_mask)>0).sum() + return true_pos / (true_pos + false_neg) \ No newline at end of file diff --git a/trans_utils/trainer_ddp.py b/trans_utils/trainer_ddp.py new file mode 100644 index 0000000..cf8bab0 --- /dev/null +++ b/trans_utils/trainer_ddp.py @@ -0,0 +1,418 @@ +import os +import torch +import torch.nn as nn +import torch.optim as optim +import torch.distributed as dist +import torch.multiprocessing as mp +from torch.nn.parallel import DistributedDataParallel +from torch.utils.data.distributed import DistributedSampler +from torch.utils.data import DataLoader, Dataset + +import torch.multiprocessing as mp +import torch.distributed as dist +import tempfile +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.cuda.amp import autocast, GradScaler + +from tutils.new.trainer.trainer_abstract import AbstractTrainer +from tutils.new.manager.loggers import MultiLogger +from tutils.new.manager.csv_recorder import CSVLogger +from tutils.new.trainer.recorder import Recorder +from tutils.new.utils.core_utils import _get_time_str +from tutils.new.utils.public_utils import dict_to_str + +# Waiting for update +from tutils import tfilename, tenum + +# export MASTER_ADDR=192.168.1.100 +# export MASTER_PORT=12345 + + +def setup(rank, world_size): + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '12355' + + # initialize the process group + dist.init_process_group("gloo", rank=rank, world_size=world_size) + +def cleanup(): + dist.destroy_process_group() + + +def ddp(rank, world_size): + print(f"Running basic DDP example on rank {rank}.") + + # create model and move it to GPU with id rank + model = model.to(rank) + ddp_model = DDP(model, device_ids=[rank]) + + loss_fn = nn.MSELoss() + optimizer = optim.SGD(ddp_model.parameters(), lr=0.001) + + optimizer.zero_grad() + outputs = ddp_model(torch.randn(20, 10)) + labels = torch.randn(20, 5).to(rank) + loss_fn(outputs, labels).backward() + optimizer.step() + + cleanup() + + +def get_logger(config): + config_base = config['base'] + config_logger = config['logger'] + logger = MultiLogger(logdir=config_base['runs_dir'], + record_mode=config_logger.get('record_mode', None), + tag=config_base['tag'], + extag=config_base.get('experiment', None), + action=config_logger.get('action', 'k')) # backup config.yaml + return logger + +class DDPTrainer(AbstractTrainer): + def __init__(self, config, tester=None, monitor=None, rank='cuda', world_size=0, logger=None): + super().__init__(config, tester, monitor, rank, world_size) + self.rank = rank + self.logger = logger + self.logging_available = (self.rank == 0 or self.rank == 'cuda') + print("Running on ", rank) + self.global_iteration = 0 + if self.logging_available: + print(f"Logger at Process(rank={rank})") + self.recorder = Recorder(reduction=self.recorder_mode) + self.recorder_valid = Recorder(reduction=self.recorder_mode) + self.recorder_test = Recorder(reduction=self.recorder_mode) + self.logger = None + self.csvlogger = CSVLogger(tfilename(self.runs_dir, "best_record")) + self.csvlogger_all = CSVLogger(tfilename(self.runs_dir, "all_record")) + self.monitor = monitor + self.tester = tester + + self.logger = get_logger(config) + assert self.logger is not None, f"{__file__} Gotrank {self.rank}" + + if self.use_amp: + self.scalar = GradScaler() + print("Debug settings: use amp=",self.use_amp) + + def init_model(self, model, trainset, validset=None, **kwargs): + # Use CacheDataset + # trainset = CacheDataset(trainset, num_workers=12, cache_rate=0.5) + if trainset is not None: + assert len(trainset) > 0 , f"{__file__} Got{len(trainset)}" + self.trainloader = DataLoader(dataset=trainset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + drop_last=True, + pin_memory=True) + if validset is not None: + self.validloader = DataLoader(dataset=validset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + drop_last=True, + pin_memory=True) + if self.load_pretrain_model: + model.module.load() + rank = self.rank + model = model.to(rank) + ddp_model = DDP(model, device_ids=[rank]) + return ddp_model + + def configure_optim(self, model, **kwargs): + # Set optimizer and scheduler + optim_configs = model.module.configure_optimizers() + assert isinstance(optim_configs, dict) + optimizer = optim_configs['optimizer'] + scheduler = optim_configs['scheduler'] + + if self.load_optimizer: + start_epoch = model.module.load_optim(optimizer) + print(f"[DDPTrainer] Continue training, from epoch {start_epoch}") + else: + start_epoch = self.start_epoch + return optimizer, scheduler, start_epoch + + def fit(self, model, trainset, validset=None): + model = self.init_model(model, trainset, validset=validset, rank=self.rank) + self.init_timers() + optimizer, scheduler, start_epoch = self.configure_optim(model) + + for epoch in range(start_epoch, self.max_epochs): + self.on_before_zero_grad() + # Training + self.timer_epoch() + do_training_log = (epoch % self.training_log_interval == 0) + if self.validloader is not None and self.validation_interval > 0 and epoch % self.validation_interval == 0: + self.valid(model, self.validloader, epoch, do_training_log) + + if self.trainloader is not None: + self.train(model, self.trainloader, epoch, optimizer, scheduler, do_training_log) + + if self.logging_available: + self.test(model, epoch=epoch) + + if epoch % self.save_interval == 0 and self.logging_available: + if 'latest' in self.save_mode: + self.save(model, epoch, 'latest', optimizer) + if 'all' in self.save_mode: + self.save(model, epoch, None, optimizer) + # time_save_model = self.timer_epoch() + + print("Training is Over for GPU rank ", self.rank) + self.cleanup() + + def test(self, model, epoch): + # Evaluation + if epoch % self.val_check_interval == 0 and self.logging_available: + print("Note: Tester runs on only") + if self.tester is not None: + out = self.tester.test(model=model, epoch=epoch, rank=self.rank) + if self.monitor is not None: + best_dict = self.monitor.record(out, epoch) + self.recorder_test.record({**best_dict, **out}) + if best_dict['isbest']: + if 'best' in self.save_mode: + self.save(model, epoch, type='best') + self.csvlogger.record({**best_dict, **out, "time": _get_time_str()}) + if self.save_all_records: + self.csvlogger_all.record({**best_dict, **out, "time": _get_time_str()}) + self.logger.info(f"\n[*] {dict_to_str(best_dict)}[*] Epoch {epoch}: \n{dict_to_str(out)}") + self.logger.add_scalars(out, step=epoch, tag='test') + # if '' + else: + self.logger.info(f"\n[*] Epoch {epoch}: {dict_to_str(out)}") + self.on_after_testing(d=out) + + def save(self, model, epoch, type=None, optimizer=None, **kwargs): + if self.logging_available: + if type is None: + # if self.save_interval > 0 and epoch % self.save_interval == 0: + save_name = "/ckpt/model_epoch_{}.pth".format(epoch) + model.module.save(tfilename(self.runs_dir, save_name), epoch=epoch) + self.logger.info(f"Epoch {epoch}: Save model to ``{save_name}``! ") + elif type == 'best': + # save_name = "/ckpt/best_model_epoch_{}.pth".format(epoch) + save_name2 = "/ckpt_v/model_best.pth" + # model.save(tfilename(self.runs_dir, save_name), epoch=epoch, is_best=True) + model.module.save(tfilename(self.runs_dir, save_name2), epoch=epoch, is_best=True) + self.logger.info(f"[Best model] Epoch {epoch}: Save model to ``{save_name2}``! ") + elif type == 'latest': + if self.save_interval > 0 and epoch % self.save_interval == 0: + save_name = "/ckpt_v/model_latest.pth" + model.module.save(tfilename(self.runs_dir, save_name), epoch=epoch, is_latest=True) + save_optim_name = "/ckpt/optim_latest.pth" + model.module.save_optim(tfilename(self.runs_dir, save_optim_name), optimizer=optimizer, epoch=epoch) + self.logger.info(f"Epoch {epoch}: Save checkpoint to ``{save_name}``") + elif type == "iteration": + save_name = "/ckpt/model_iter_{}.pth".format(self.global_iteration) + model.module.save(tfilename(self.runs_dir, save_name), epoch=self.global_iteration) + self.logger.info(f"Epoch {epoch}: Save model to ``{save_name}``! ") + + + def train(self, model, trainloader, epoch, optimizer, scheduler=None, do_training_log=True): + model.train() + out = {} + if do_training_log and self.logging_available: + self.recorder.clear() + time_record = 0.1111 + self.timer_batch() + + success_count = 0 + failed_count = 0 + for load_time, batch_idx, data in tenum(trainloader): + optimizer.zero_grad() + self.timer_data() + # training steps + for k, v in data.items(): + if isinstance(v, torch.Tensor): + data[k] = v.to(self.rank) + time_data_cuda = self.timer_data() + if self.use_amp: + with autocast(): + self.timer_net() + out = model.module.training_step(data, batch_idx, epoch=epoch) + # try: + # out = model.module.training_step(data, batch_idx, epoch=epoch) + # except Exception as e: + # msg = f"Ignore Error! {e}" + # if self.logging_available: + # self.logger.info(msg) + # else: + # print(msg) + # continue + assert isinstance(out, dict) + time_fd = self.timer_net() + loss = out['loss'] + self.scalar.scale(loss).backward() + self.scalar.step(optimizer) + self.scalar.update() + time_bp = self.timer_net() + else: + self.timer_net() + try: + out = model.module.training_step(data, batch_idx, epoch=epoch) + except Exception as e: + msg = f"Ignore Error! {e}" + if self.logging_available: + self.logger.info(msg) + else: + print(msg) + continue + if out['loss'] is None: + failed_count += 1 + continue + if torch.isnan(out['loss']): + print("Ignore Nan Value: ", out['loss']) + failed_count += 1 + # raise ValueError(f"Get loss: {out['loss']}") + assert isinstance(out, dict) + time_fd = self.timer_net() + loss = out['loss'] + loss.backward() + optimizer.step() + time_bp = self.timer_net() + success_count += 1 + + time_batch = self.timer_batch() + # batch logger ! + if self.logging_available and do_training_log: + out['time_load'] = load_time + out['time_cuda'] = time_data_cuda + out['time_forward'] = time_fd + out['time_bp'] = time_bp + out['time_record'] = time_record + out['time_batch'] = time_batch + self.timer_data() + self.recorder.record(out) + time_record = self.timer_data() + + # for debug ! + if epoch == 0: + if self.logging_available: + self.logger.info("[*] Debug Checking Pipeline !!!") + del out + return + if self.global_iteration % 100 == 0: + print(f"Epoch: {epoch} | batch:{batch_idx}/{len(trainloader)}, Iteration:{self.global_iteration}, results: {to_item(out)}", end='\n') + if self.global_iteration % 5000 == 0: + self.save(model, epoch, "iteration", optimizer) + # print("") + self.global_iteration += 1 + + if scheduler is not None: + scheduler.step() + + # epoch logger ! + if self.logging_available: + if do_training_log : + _dict = self.recorder.cal_metrics() + _dict['time_total'] = self.timer_epoch() + + # print(_dict) + # assert isinstance(lr, float), f"Got lr={lr}, type: {type(lr)}" + loss_str = "" + for k, v in _dict.items(): + loss_str += "{}:{:.4f} ".format(k, v) + # lr = optimizer.param_groups[0]['lr'] + lr = self.get_lr(optimizer) + _dict['lr'] = lr + loss_str += "{}:{:.6e} ".format('lr', lr) + self.logger.info(f"Epoch {epoch}: {loss_str}") + # _dict_with_train_tag = {f"train/{k}":v for k,v in _dict.items()} + self.logger.add_scalars(_dict, step=epoch, tag='train') + time_log_scalars = self.timer_epoch() + self.on_after_training(d=_dict) + # Clear + del out + del data + + def valid(self, model, validloader, epoch, do_training_log=True): + model.eval() + out = {} + if do_training_log and self.logging_available: + self.recorder_valid.clear() + time_record = 0.1 + self.timer_batch() + + success_count = 1 + failed_count = 1 + for load_time, batch_idx, data in tenum(validloader): + # model.on_before_zero_grad() + self.timer_data() + # training steps + for k, v in data.items(): + if isinstance(v, torch.Tensor): + data[k] = v.to(self.rank) + time_data_cuda = self.timer_data() + if self.use_amp: + with autocast(): + self.timer_net() + out = model.module.validation_step(data, batch_idx, epoch=epoch) + assert isinstance(out, dict) + time_fd = self.timer_net() + else: + self.timer_net() + out = model.module.validation_step(data, batch_idx, epoch=epoch) + if out['loss'] is None: + failed_count += 1 + continue + if torch.isnan(out['loss']): + self.logger.info("Nan Value: ", out['loss']) + failed_count += 1 + raise ValueError(f"Get loss: {out['loss']}") + assert isinstance(out, dict) + time_fd = self.timer_net() + success_count += 1 + + time_batch = self.timer_batch() + # batch logger ! + if do_training_log and self.logging_available: + out['time_load'] = load_time + out['time_cuda'] = time_data_cuda + out['time_forward'] = time_fd + out['time_record'] = time_record + out['time_batch'] = time_batch + self.timer_data() + self.recorder_valid.record(out) + time_record = self.timer_data() + + if batch_idx % 2 == 0: + print(f"Valid Epoch: {epoch}. Processing batch_idx:{batch_idx} / {len(validloader)}, time_load: {load_time}, results: {to_item(out)}", end='\r') + + if epoch == 0: + if self.logging_available: + self.logger.info("[*] Debug Checking validation Pipeline !!!") + del out + return + # model.on_after_zero_grad(d=out) + if self.logging_available: + self.logger.info(f"Training Success Ratio: {success_count / (success_count + failed_count)}") + + # epoch logger ! + if self.logging_available: + if do_training_log : + _dict = self.recorder_valid.cal_metrics() + _dict['time_total'] = self.timer_epoch() + + # print(_dict) + # assert isinstance(lr, float), f"Got lr={lr}, type: {type(lr)}" + loss_str = "" + for k, v in _dict.items(): + loss_str += "{}:{:.4f} ".format(k, v) + self.logger.info(f"Epoch {epoch}: {loss_str}") + # _dict_with_val_tag = {f"val/{k}":v for k,v in _dict.items()} + self.logger.add_scalars(_dict, step=epoch, tag='val') + time_log_scalars = self.timer_epoch() + self.on_after_training(d=_dict) + # Clear + del out + del data + + + +def to_item(tensors): + for k,v in tensors.items(): + if isinstance(v, torch.Tensor): + tensors[k] = v.detach().cpu().item() + return tensors \ No newline at end of file diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..5277f46 --- /dev/null +++ b/utils/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/utils/__pycache__/__init__.cpython-310.pyc b/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..c0d0778 Binary files /dev/null and b/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/utils/__pycache__/__init__.cpython-38.pyc b/utils/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000..45a429b Binary files /dev/null and b/utils/__pycache__/__init__.cpython-38.pyc differ diff --git a/utils/__pycache__/amg.cpython-38.pyc b/utils/__pycache__/amg.cpython-38.pyc new file mode 100644 index 0000000..2ac6acd Binary files /dev/null and b/utils/__pycache__/amg.cpython-38.pyc differ diff --git a/utils/__pycache__/amg3d.cpython-38.pyc b/utils/__pycache__/amg3d.cpython-38.pyc new file mode 100644 index 0000000..d942b76 Binary files /dev/null and b/utils/__pycache__/amg3d.cpython-38.pyc differ diff --git a/utils/__pycache__/masks3d_utils.cpython-38.pyc b/utils/__pycache__/masks3d_utils.cpython-38.pyc new file mode 100644 index 0000000..6b84347 Binary files /dev/null and b/utils/__pycache__/masks3d_utils.cpython-38.pyc differ diff --git a/utils/__pycache__/transforms.cpython-310.pyc b/utils/__pycache__/transforms.cpython-310.pyc new file mode 100644 index 0000000..ab49768 Binary files /dev/null and b/utils/__pycache__/transforms.cpython-310.pyc differ diff --git a/utils/__pycache__/transforms.cpython-38.pyc b/utils/__pycache__/transforms.cpython-38.pyc new file mode 100644 index 0000000..e94db31 Binary files /dev/null and b/utils/__pycache__/transforms.cpython-38.pyc differ diff --git a/utils/__pycache__/transforms3d.cpython-38.pyc b/utils/__pycache__/transforms3d.cpython-38.pyc new file mode 100644 index 0000000..f04488b Binary files /dev/null and b/utils/__pycache__/transforms3d.cpython-38.pyc differ diff --git a/utils/amg.py b/utils/amg.py new file mode 100644 index 0000000..c0124c2 --- /dev/null +++ b/utils/amg.py @@ -0,0 +1,399 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch + +import math +from copy import deepcopy +from itertools import product +from typing import Any, Dict, Generator, ItemsView, List, Tuple + + +class MaskData: + """ + A structure for storing masks and their related data in batched format. + Implements basic filtering and concatenation. + """ + + def __init__(self, **kwargs) -> None: + for v in kwargs.values(): + assert isinstance( + v, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." + self._stats = dict(**kwargs) + + def __setitem__(self, key: str, item: Any) -> None: + assert isinstance( + item, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." + self._stats[key] = item + + def __delitem__(self, key: str) -> None: + del self._stats[key] + + def __getitem__(self, key: str) -> Any: + return self._stats[key] + + def items(self) -> ItemsView[str, Any]: + return self._stats.items() + + def filter(self, keep: torch.Tensor) -> None: + for k, v in self._stats.items(): + if v is None: + self._stats[k] = None + elif isinstance(v, torch.Tensor): + self._stats[k] = v[torch.as_tensor(keep, device=v.device)] + elif isinstance(v, np.ndarray): + self._stats[k] = v[keep.detach().cpu().numpy()] + elif isinstance(v, list) and keep.dtype == torch.bool: + self._stats[k] = [a for i, a in enumerate(v) if keep[i]] + elif isinstance(v, list): + self._stats[k] = [v[i] for i in keep] + else: + raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") + + def cat(self, new_stats: "MaskData") -> None: + for k, v in new_stats.items(): + if k not in self._stats or self._stats[k] is None: + self._stats[k] = deepcopy(v) + elif isinstance(v, torch.Tensor): + if k == "boxes" and len(v.shape) == 1: + v = v[None,:] + if k == "boxes" and len(self._stats[k].shape) == 1: + self._stats[k] = self._stats[k][None, :] + self._stats[k] = torch.cat([self._stats[k], v], dim=0) + elif isinstance(v, np.ndarray): + self._stats[k] = np.concatenate([self._stats[k], v], axis=0) + elif isinstance(v, list): + self._stats[k] = self._stats[k] + deepcopy(v) + else: + raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") + + def to_numpy(self) -> None: + for k, v in self._stats.items(): + if isinstance(v, torch.Tensor): + self._stats[k] = v.detach().cpu().numpy() + + +def is_box_near_crop_edge( + boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 +) -> torch.Tensor: + """Filter masks at the edge of a crop, but not at the edge of the original image.""" + crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) + orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) + boxes = uncrop_boxes_xyxy(boxes, crop_box).float() + near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) + near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) + near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) + return torch.any(near_crop_edge, dim=1) + + +def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor: + box_xywh = deepcopy(box_xyxy) + box_xywh[2] = box_xywh[2] - box_xywh[0] + box_xywh[3] = box_xywh[3] - box_xywh[1] + return box_xywh + + +def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: + assert len(args) > 0 and all( + len(a) == len(args[0]) for a in args + ), "Batched iteration must have inputs of all the same size." + n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) + for b in range(n_batches): + yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] + + +def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: + """ + Encodes masks to an uncompressed RLE, in the format expected by + pycoco tools. + """ + # Put in fortran order and flatten h,w + b, h, w = tensor.shape + tensor = tensor.permute(0, 2, 1).flatten(1) + + # Compute change indices + diff = tensor[:, 1:] ^ tensor[:, :-1] + change_indices = diff.nonzero() + + # Encode run length + out = [] + for i in range(b): + cur_idxs = change_indices[change_indices[:, 0] == i, 1] + cur_idxs = torch.cat( + [ + torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device), + cur_idxs + 1, + torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device), + ] + ) + btw_idxs = cur_idxs[1:] - cur_idxs[:-1] + counts = [] if tensor[i, 0] == 0 else [0] + counts.extend(btw_idxs.detach().cpu().tolist()) + out.append({"size": [h, w], "counts": counts}) + return out + + +def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: + """Compute a binary mask from an uncompressed RLE.""" + h, w = rle["size"] + mask = np.empty(h * w, dtype=bool) + idx = 0 + parity = False + for count in rle["counts"]: + mask[idx : idx + count] = parity + idx += count + parity ^= True + mask = mask.reshape(w, h) + return mask.transpose() # Put in C order + + +def area_from_rle(rle: Dict[str, Any]) -> int: + return sum(rle["counts"][1::2]) + + +def calculate_stability_score( + masks: torch.Tensor, mask_threshold: float, threshold_offset: float +) -> torch.Tensor: + """ + Computes the stability score for a batch of masks. The stability + score is the IoU between the binary masks obtained by thresholding + the predicted mask logits at high and low values. + """ + # One mask is always contained inside the other. + # Save memory by preventing unnecessary cast to torch.int64 + intersections = ( + (masks > (mask_threshold + threshold_offset)) + .sum(-1, dtype=torch.int16) + .sum(-1, dtype=torch.int32) + ) + unions = ( + (masks > (mask_threshold - threshold_offset)) + .sum(-1, dtype=torch.int16) + .sum(-1, dtype=torch.int32) + ) + return intersections / unions + + +def build_point_grid(n_per_side: int) -> np.ndarray: + """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" + offset = 1 / (2 * n_per_side) + points_one_side = np.linspace(offset, 1 - offset, n_per_side) + points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) + points_y = np.tile(points_one_side[:, None], (1, n_per_side)) + points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) + return points + + +def build_all_layer_point_grids( + n_per_side: int, n_layers: int, scale_per_layer: int +) -> List[np.ndarray]: + """Generates point grids for all crop layers.""" + points_by_layer = [] + for i in range(n_layers + 1): + n_points = int(n_per_side / (scale_per_layer**i)) + points_by_layer.append(build_point_grid(n_points)) + return points_by_layer + + +def generate_crop_boxes( + im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float +) -> Tuple[List[List[int]], List[int]]: + """ + Generates a list of crop boxes of different sizes. Each layer + has (2**i)**2 boxes for the ith layer. + """ + crop_boxes, layer_idxs = [], [] + im_h, im_w = im_size + short_side = min(im_h, im_w) + + # Original image + crop_boxes.append([0, 0, im_w, im_h]) + layer_idxs.append(0) + + def crop_len(orig_len, n_crops, overlap): + return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)) + + for i_layer in range(n_layers): + n_crops_per_side = 2 ** (i_layer + 1) + overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) + + crop_w = crop_len(im_w, n_crops_per_side, overlap) + crop_h = crop_len(im_h, n_crops_per_side, overlap) + + crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)] + crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)] + + # Crops in XYWH format + for x0, y0 in product(crop_box_x0, crop_box_y0): + box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] + crop_boxes.append(box) + layer_idxs.append(i_layer + 1) + + return crop_boxes, layer_idxs + + +def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: + x0, y0, _, _ = crop_box + offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device) + # Check if boxes has a channel dimension + if len(boxes.shape) == 3: + offset = offset.unsqueeze(1) + return boxes + offset + + +def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: + x0, y0, _, _ = crop_box + offset = torch.tensor([[x0, y0]], device=points.device) + # Check if points has a channel dimension + if len(points.shape) == 3: + offset = offset.unsqueeze(1) + return points + offset + + +def uncrop_masks( + masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int +) -> torch.Tensor: + x0, y0, x1, y1 = crop_box + if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: + return masks + # Coordinate transform masks + pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0) + pad = (x0, pad_x - x0, y0, pad_y - y0) + return torch.nn.functional.pad(masks, pad, value=0) + + +def remove_small_regions( + mask: np.ndarray, area_thresh: float, mode: str +) -> Tuple[np.ndarray, bool]: + """ + Removes small disconnected regions and holes in a mask. Returns the + mask and an indicator of if the mask has been modified. + """ + import cv2 # type: ignore + + assert mode in ["holes", "islands"] + correct_holes = mode == "holes" + working_mask = (correct_holes ^ mask).astype(np.uint8) + n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) + sizes = stats[:, -1][1:] # Row 0 is background label + small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] + if len(small_regions) == 0: + return mask, False + fill_labels = [0] + small_regions + if not correct_holes: + fill_labels = [i for i in range(n_labels) if i not in fill_labels] + # If every region is below threshold, keep largest + if len(fill_labels) == 0: + fill_labels = [int(np.argmax(sizes)) + 1] + mask = np.isin(regions, fill_labels) + return mask, True + + +def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]: + from pycocotools import mask as mask_utils # type: ignore + + h, w = uncompressed_rle["size"] + rle = mask_utils.frPyObjects(uncompressed_rle, h, w) + rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json + return rle + + +def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: + """ + Calculates boxes in XYXY format around masks. Return [0,0,0,0] for + an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. + """ + # torch.max below raises an error on empty inputs, just skip in this case + if torch.numel(masks) == 0: + return torch.zeros(*masks.shape[:-2], 4, device=masks.device) + + # Normalize shape to CxHxW + masks = masks.squeeze() + assert len(masks.shape) <= 3 + shape = masks.shape + h, w = shape[-2:] + if len(shape) > 2: + masks = masks.flatten(0, -3) + else: + masks = masks.unsqueeze(0) + + # Get top and bottom edges + in_height, _ = torch.max(masks, dim=-1) + in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] + bottom_edges, _ = torch.max(in_height_coords, dim=-1) + in_height_coords = in_height_coords + h * (~in_height) + top_edges, _ = torch.min(in_height_coords, dim=-1) + + # Get left and right edges + in_width, _ = torch.max(masks, dim=-2) + in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] + right_edges, _ = torch.max(in_width_coords, dim=-1) + in_width_coords = in_width_coords + w * (~in_width) + left_edges, _ = torch.min(in_width_coords, dim=-1) + + # If the mask is empty the right edge will be to the left of the left edge. + # Replace these boxes with [0, 0, 0, 0] + empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) + out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) + out = out * (~empty_filter).unsqueeze(-1) + + # Return to original shape + if len(shape) > 2: + out = out.reshape(*shape[:-2], 4) + else: + out = out[0] + + return out + +def batched_mask3d_to_box(masks: torch.Tensor) -> torch.Tensor: + """ + Calculates boxes in XYXY format around masks. Return [0,0,0,0] for + an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. + """ + # torch.max below raises an error on empty inputs, just skip in this case + if torch.numel(masks) == 0: + return torch.zeros(*masks.shape[:-2], 4, device=masks.device) + + # Normalize shape to CxHxW + assert masks.shape[1] == 3 + masks = masks[:,1,:,:] + shape = masks.shape + h, w = shape[-2:] + if len(shape) > 2: + masks = masks.flatten(0, -3) + else: + masks = masks.unsqueeze(0) + + # Get top and bottom edges + in_height, _ = torch.max(masks, dim=-1) + in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] + bottom_edges, _ = torch.max(in_height_coords, dim=-1) + in_height_coords = in_height_coords + h * (~in_height) + top_edges, _ = torch.min(in_height_coords, dim=-1) + + # Get left and right edges + in_width, _ = torch.max(masks, dim=-2) + in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] + right_edges, _ = torch.max(in_width_coords, dim=-1) + in_width_coords = in_width_coords + w * (~in_width) + left_edges, _ = torch.min(in_width_coords, dim=-1) + + # If the mask is empty the right edge will be to the left of the left edge. + # Replace these boxes with [0, 0, 0, 0] + empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) + out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) + out = out * (~empty_filter).unsqueeze(-1) + + # Return to original shape + if len(shape) > 2: + out = out.reshape(*shape[:-2], 4) + else: + out = out[0] + + return out diff --git a/utils/amg3d.py b/utils/amg3d.py new file mode 100644 index 0000000..94a5dab --- /dev/null +++ b/utils/amg3d.py @@ -0,0 +1,204 @@ + +import numpy as np +import torch + +import math +from copy import deepcopy +from itertools import product +from typing import Any, Dict, Generator, ItemsView, List, Tuple + + +class MaskData3d: + def __init__(self, size, **kwargs) -> None: + self.size = size + for v in kwargs.values(): + assert isinstance( + v, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." + self._stats = dict(**kwargs) + + def __setitem__(self, slice_id, num, item): + assert isinstance( + item, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." + key = str(slice_id) + str(num) + if self._stats.get(key, None) is None: + self._stats[key] = np.zeros(self.size) + self._stats[key][slice_id-1:slice_id+2] = item + + def __delitem__(self, key: str) -> None: + del self._stats[key] + + def __getitem__(self, key: str) -> Any: + return self._stats[key] + + def items(self) -> ItemsView[str, Any]: + return self._stats.items() + + def merge(self, slice_ids, num, item): + pass + + +def build_all_layer_point_grids( + n_per_side: int = 32, n_layers: int = 0, scale_per_layer: int = 1) -> List[np.ndarray]: + """Generates point grids for all crop layers.""" + points_by_layer = [] + for i in range(n_layers + 1): + n_points = int(n_per_side / (scale_per_layer**i)) + points_by_layer.append(build_point_grid(n_points)) + return points_by_layer + + +def build_point_grid(n_per_side: int, size) -> np.ndarray: + """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" + offset = 1 / (2 * n_per_side) + points_one_side = np.linspace(offset, 1 - offset, n_per_side) + points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) + points_y = np.tile(points_one_side[:, None], (1, n_per_side)) + points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) + return points * np.array(size) + +# def calculate_stability_score( +# masks: torch.Tensor, mask_threshold: float, threshold_offset: float +# ) -> torch.Tensor: +# """ +# Computes the stability score for a batch of masks. The stability +# score is the IoU between the binary masks obtained by thresholding +# the predicted mask logits at high and low values. +# """ +# # One mask is always contained inside the other. +# # Save memory by preventing unnecessary cast to torch.int64 +# intersections = ( +# (masks > (mask_threshold + threshold_offset)) +# .sum(-1, dtype=torch.int16) +# .sum(-1, dtype=torch.int32) +# ) +# unions = ( +# (masks > (mask_threshold - threshold_offset)) +# .sum(-1, dtype=torch.int16) +# .sum(-1, dtype=torch.int32) +# ) +# return intersections / unions + + +def calculate_stability_score_3d( + masks: torch.Tensor, mask_threshold: float, threshold_offset: float +) -> torch.Tensor: + """ + Computes the stability score for a batch of masks. The stability + score is the IoU between the binary masks obtained by thresholding + the predicted mask logits at high and low values. + """ + # One mask is always contained inside the other. + # Save memory by preventing unnecessary cast to torch.int64 + intersections = ( + (masks > (mask_threshold + threshold_offset)) + .sum(-1, dtype=torch.int16) + .sum(-1, dtype=torch.int32) + .sum(-1, dtype=torch.int32) + ) + # intersections = intersections2d.sum(-1, dtype=torch.int32) + unions = ( + (masks > (mask_threshold - threshold_offset)) + .sum(-1, dtype=torch.int16) + .sum(-1, dtype=torch.int32) + .sum(-1, dtype=torch.int32) + ) + return (intersections / unions) + + +def remove_small_regions( + mask: np.ndarray, area_thresh: float, mode: str +) -> Tuple[np.ndarray, bool]: + """ + Removes small disconnected regions and holes in a mask. Returns the + mask and an indicator of if the mask has been modified. + """ + import cv2 # type: ignore + + assert mode in ["holes", "islands"] + correct_holes = mode == "holes" + working_mask = (correct_holes ^ mask).astype(np.uint8) + n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) + sizes = stats[:, -1][1:] # Row 0 is background label + small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] + if len(small_regions) == 0: + return mask, False + fill_labels = [0] + small_regions + if not correct_holes: + fill_labels = [i for i in range(n_labels) if i not in fill_labels] + # If every region is below threshold, keep largest + if len(fill_labels) == 0: + fill_labels = [int(np.argmax(sizes)) + 1] + mask = np.isin(regions, fill_labels) + return mask, True + +def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: + assert len(args) > 0 and all( + len(a) == len(args[0]) for a in args + ), "Batched iteration must have inputs of all the same size." + n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) + for b in range(n_batches): + yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] + + +class MaskData: + """ + A structure for storing masks and their related data in batched format. + Implements basic filtering and concatenation. + """ + + def __init__(self, **kwargs) -> None: + for v in kwargs.values(): + assert isinstance( + v, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." + self._stats = dict(**kwargs) + + def __setitem__(self, key: str, item: Any) -> None: + assert isinstance( + item, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." + self._stats[key] = item + + def __delitem__(self, key: str) -> None: + del self._stats[key] + + def __getitem__(self, key: str) -> Any: + return self._stats[key] + + def items(self) -> ItemsView[str, Any]: + return self._stats.items() + + def filter(self, keep: torch.Tensor) -> None: + for k, v in self._stats.items(): + if v is None: + self._stats[k] = None + elif isinstance(v, torch.Tensor): + self._stats[k] = v[torch.as_tensor(keep, device=v.device)] + elif isinstance(v, np.ndarray): + self._stats[k] = v[keep.detach().cpu().numpy()] + elif isinstance(v, list) and keep.dtype == torch.bool: + self._stats[k] = [a for i, a in enumerate(v) if keep[i]] + elif isinstance(v, list): + self._stats[k] = [v[i] for i in keep] + else: + raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") + + def cat(self, new_stats: "MaskData") -> None: + for k, v in new_stats.items(): + if k not in self._stats or self._stats[k] is None: + self._stats[k] = deepcopy(v) + elif isinstance(v, torch.Tensor): + self._stats[k] = torch.cat([self._stats[k], v], dim=0) + elif isinstance(v, np.ndarray): + self._stats[k] = np.concatenate([self._stats[k], v], axis=0) + elif isinstance(v, list): + self._stats[k] = self._stats[k] + deepcopy(v) + else: + raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") + + def to_numpy(self) -> None: + for k, v in self._stats.items(): + if isinstance(v, torch.Tensor): + self._stats[k] = v.detach().cpu().numpy() diff --git a/utils/masks3d_utils.py b/utils/masks3d_utils.py new file mode 100644 index 0000000..7eb5772 --- /dev/null +++ b/utils/masks3d_utils.py @@ -0,0 +1,59 @@ +from tutils.nn.data import write, np_to_itk +import numpy as np + + +class Masks3D: + def __init__(self) -> None: + pass + + def from_dict(self, masks): + self.masks = masks + # self.tags = masks + + def to_2dmasks(self): + pass + + def filter_by_bg(self, volume, threshold=None): + threshold = (volume.max() - volume.min()) * 0.1 + volume.min() + keys = self.masks.keys() + for k in keys: + v = self.masks[k] + assert v.shape == volume.shape, f"Got shape ERROR, {v.shape, volume.shape}" + if (v * volume).mean() <= threshold: + self.masks.pop(k) + + # def filter_by_area(self,): + + + def sort_by_logits(self): + self.confidences = [] + self.tags_by_conf = [] + for k, v in self.masks.items(): + confidence = v[v>0].mean() + self.confidences.append(confidence) + self.tags_by_conf.append(k) + indices = np.argsort(self.confidences)[::-1] + self.tags_by_conf = np.array(self.tags_by_conf)[indices].tolist() + self.confidences = np.array(self.confidences)[indices] + + def to_nii(self, path="tmp.nii.gz"): + self.sort_by_logits() + total = None + for i, k in enumerate(self.tags_by_conf): + mask = np.int32(self.masks[k]>0) + if total is None: + total = mask * i + else: + total = np.where(total>0, total, mask * i) + mask_itk = np_to_itk(total) + write(mask_itk, path) + + +if __name__ == "__main__": + p = "/home1/quanquan/code/projects/finetune_large/segment_anything/tmp.npy" + data = np.load(p, allow_pickle=True).tolist() + # import ipdb; ipdb.set_trace() + print(data.keys()) + mm = Masks3D() + mm.from_dict(data) + mm.to_nii() \ No newline at end of file diff --git a/utils/onnx.py b/utils/onnx.py new file mode 100644 index 0000000..3196bdf --- /dev/null +++ b/utils/onnx.py @@ -0,0 +1,144 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from torch.nn import functional as F + +from typing import Tuple + +from ..modeling import Sam +from .amg import calculate_stability_score + + +class SamOnnxModel(nn.Module): + """ + This model should not be called directly, but is used in ONNX export. + It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, + with some functions modified to enable model tracing. Also supports extra + options controlling what information. See the ONNX export script for details. + """ + + def __init__( + self, + model: Sam, + return_single_mask: bool, + use_stability_score: bool = False, + return_extra_metrics: bool = False, + ) -> None: + super().__init__() + self.mask_decoder = model.mask_decoder + self.model = model + self.img_size = model.image_encoder.img_size + self.return_single_mask = return_single_mask + self.use_stability_score = use_stability_score + self.stability_score_offset = 1.0 + self.return_extra_metrics = return_extra_metrics + + @staticmethod + def resize_longest_image_size( + input_image_size: torch.Tensor, longest_side: int + ) -> torch.Tensor: + input_image_size = input_image_size.to(torch.float32) + scale = longest_side / torch.max(input_image_size) + transformed_size = scale * input_image_size + transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) + return transformed_size + + def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: + point_coords = point_coords + 0.5 + point_coords = point_coords / self.img_size + point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) + point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) + + point_embedding = point_embedding * (point_labels != -1) + point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * ( + point_labels == -1 + ) + + for i in range(self.model.prompt_encoder.num_point_embeddings): + point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[ + i + ].weight * (point_labels == i) + + return point_embedding + + def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor: + mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) + mask_embedding = mask_embedding + ( + 1 - has_mask_input + ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) + return mask_embedding + + def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: + masks = F.interpolate( + masks, + size=(self.img_size, self.img_size), + mode="bilinear", + align_corners=False, + ) + + prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64) + masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore + + orig_im_size = orig_im_size.to(torch.int64) + h, w = orig_im_size[0], orig_im_size[1] + masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) + return masks + + def select_masks( + self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Determine if we should return the multiclick mask or not from the number of points. + # The reweighting is used to avoid control flow. + score_reweight = torch.tensor( + [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)] + ).to(iou_preds.device) + score = iou_preds + (num_points - 2.5) * score_reweight + best_idx = torch.argmax(score, dim=1) + masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1) + iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1) + + return masks, iou_preds + + @torch.no_grad() + def forward( + self, + image_embeddings: torch.Tensor, + point_coords: torch.Tensor, + point_labels: torch.Tensor, + mask_input: torch.Tensor, + has_mask_input: torch.Tensor, + orig_im_size: torch.Tensor, + ): + sparse_embedding = self._embed_points(point_coords, point_labels) + dense_embedding = self._embed_masks(mask_input, has_mask_input) + + masks, scores = self.model.mask_decoder.predict_masks( + image_embeddings=image_embeddings, + image_pe=self.model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embedding, + dense_prompt_embeddings=dense_embedding, + ) + + if self.use_stability_score: + scores = calculate_stability_score( + masks, self.model.mask_threshold, self.stability_score_offset + ) + + if self.return_single_mask: + masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) + + upscaled_masks = self.mask_postprocessing(masks, orig_im_size) + + if self.return_extra_metrics: + stability_scores = calculate_stability_score( + upscaled_masks, self.model.mask_threshold, self.stability_score_offset + ) + areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) + return upscaled_masks, scores, stability_scores, areas, masks + + return upscaled_masks, scores, masks diff --git a/utils/transforms.py b/utils/transforms.py new file mode 100644 index 0000000..c08ba1e --- /dev/null +++ b/utils/transforms.py @@ -0,0 +1,102 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torch.nn import functional as F +from torchvision.transforms.functional import resize, to_pil_image # type: ignore + +from copy import deepcopy +from typing import Tuple + + +class ResizeLongestSide: + """ + Resizes images to the longest side 'target_length', as well as provides + methods for resizing coordinates and boxes. Provides methods for + transforming both numpy array and batched torch tensors. + """ + + def __init__(self, target_length: int) -> None: + self.target_length = target_length + + def apply_image(self, image: np.ndarray) -> np.ndarray: + """ + Expects a numpy array with shape HxWxC in uint8 format. + """ + target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) + return np.array(resize(to_pil_image(image), target_size)) + + def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: + """ + Expects a numpy array of length 2 in the final dimension. Requires the + original image size in (H, W) format. + """ + old_h, old_w = original_size + new_h, new_w = self.get_preprocess_shape( + original_size[0], original_size[1], self.target_length + ) + coords = deepcopy(coords).astype(float) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords + + def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: + """ + Expects a numpy array shape Bx4. Requires the original image size + in (H, W) format. + """ + boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) + return boxes.reshape(-1, 4) + + def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: + """ + Expects batched images with shape BxCxHxW and float format. This + transformation may not exactly match apply_image. apply_image is + the transformation expected by the model. + """ + # Expects an image in BCHW format. May not exactly match apply_image. + target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) + return F.interpolate( + image, target_size, mode="bilinear", align_corners=False, antialias=True + ) + + def apply_coords_torch( + self, coords: torch.Tensor, original_size: Tuple[int, ...] + ) -> torch.Tensor: + """ + Expects a torch tensor with length 2 in the last dimension. Requires the + original image size in (H, W) format. + """ + old_h, old_w = original_size + new_h, new_w = self.get_preprocess_shape( + original_size[0], original_size[1], self.target_length + ) + coords = deepcopy(coords).to(torch.float) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords + + def apply_boxes_torch( + self, boxes: torch.Tensor, original_size: Tuple[int, ...] + ) -> torch.Tensor: + """ + Expects a torch tensor with shape Bx4. Requires the original image + size in (H, W) format. + """ + boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) + return boxes.reshape(-1, 4) + + @staticmethod + def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: + """ + Compute the output size given input size and target long side length. + """ + scale = long_side_length * 1.0 / max(oldh, oldw) + newh, neww = oldh * scale, oldw * scale + neww = int(neww + 0.5) + newh = int(newh + 0.5) + return (newh, neww) diff --git a/utils/transforms3d.py b/utils/transforms3d.py new file mode 100644 index 0000000..d20f8c8 --- /dev/null +++ b/utils/transforms3d.py @@ -0,0 +1,157 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torch.nn import functional as F +from torchvision.transforms.functional import resize, to_pil_image # type: ignore, +from torch.nn.functional import interpolate + +from copy import deepcopy +from typing import Tuple + + + +class SimpleResize: + """ + Keep the same with training, + maybe fixed in future. + """ + def __init__(self, target_length: int = 1024) -> None: + self.target_length = (target_length, target_length) + + def apply_image(self, image: torch.tensor) -> torch.tensor: + target_size = self.target_length + return interpolate(image.unsqueeze(1), size=(target_size[0], target_size[1])).squeeze() + # return np.array(resize(to_pil_image(image), target_size)) + + def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: + old_h, old_w = original_size + new_h, new_w = self.target_length + coords = deepcopy(coords).astype(float) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords + + def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: + boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) + return boxes.reshape(-1, 4) + + def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: + # Expects an image in BCHW format. May not exactly match apply_image. + target_size = self.target_length + return F.interpolate( + image, target_size, mode="bilinear", align_corners=False, antialias=True + ) + + def apply_coords_torch( + self, coords: torch.Tensor, original_size: Tuple[int, ...] + ) -> torch.Tensor: + old_h, old_w = original_size + new_h, new_w = self.target_length + coords = deepcopy(coords).to(torch.float) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords + + def apply_boxes_torch( + self, boxes: torch.Tensor, original_size: Tuple[int, ...] + ) -> torch.Tensor: + boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) + return boxes.reshape(-1, 4) + + + + + +class ResizeLongestSide: + """ + Resizes images to the longest side 'target_length', as well as provides + methods for resizing coordinates and boxes. Provides methods for + transforming both numpy array and batched torch tensors. + """ + + def __init__(self, target_length: int = 1024) -> None: + self.target_length = target_length + + def apply_image(self, image: torch.tensor) -> torch.tensor: + """ + Expects a numpy array with shape HxWxC in uint8 format. + """ + target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) + return interpolate(image, size=(target_size[0], target_size[1], image.shape[-1])) + # return np.array(resize(to_pil_image(image), target_size)) + + def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: + """ + Expects a numpy array of length 2 in the final dimension. Requires the + original image size in (H, W) format. + """ + old_h, old_w = original_size + new_h, new_w = self.get_preprocess_shape( + original_size[0], original_size[1], self.target_length + ) + coords = deepcopy(coords).astype(float) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords + + def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: + """ + Expects a numpy array shape Bx4. Requires the original image size + in (H, W) format. + """ + boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) + return boxes.reshape(-1, 4) + + def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: + """ + Expects batched images with shape BxCxHxW and float format. This + transformation may not exactly match apply_image. apply_image is + the transformation expected by the model. + """ + # Expects an image in BCHW format. May not exactly match apply_image. + target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) + return F.interpolate( + image, target_size, mode="bilinear", align_corners=False, antialias=True + ) + + def apply_coords_torch( + self, coords: torch.Tensor, original_size: Tuple[int, ...] + ) -> torch.Tensor: + """ + Expects a torch tensor with length 2 in the last dimension. Requires the + original image size in (H, W) format. + """ + old_h, old_w = original_size + new_h, new_w = self.get_preprocess_shape( + original_size[0], original_size[1], self.target_length + ) + coords = deepcopy(coords).to(torch.float) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords + + def apply_boxes_torch( + self, boxes: torch.Tensor, original_size: Tuple[int, ...] + ) -> torch.Tensor: + """ + Expects a torch tensor with shape Bx4. Requires the original image + size in (H, W) format. + """ + boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) + return boxes.reshape(-1, 4) + + @staticmethod + def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int): + """ + Compute the output size given input size and target long side length. + """ + scale = long_side_length * 1.0 / max(oldh, oldw) + newh, neww = oldh * scale, oldw * scale + neww = int(neww + 0.5) + newh = int(newh + 0.5) + return (newh, neww)