first commit
commit e04459c6fe
53
configs/vit_b.yaml
Normal file
@@ -0,0 +1,53 @@
#### basic configs
# dataset:
#   name: 'Cephalometric'
#   pth: '/home1/quanquan/datasets/Cephalometric/'


# ---------------------- Common Configs --------------------------
base:
  base_dir: "../runs/sam/"
  tag: ''
  stage: ''

logger:
  mode: ['tb', ]
  # mode: ''
  recorder_reduction: 'mean'

training:
  save_mode: ['all', 'best', 'latest']
  batch_size: 8          # 20 for A100
  num_workers: 16
  num_epochs: 500        # epochs
  use_amp: true
  save_interval: 4
  val_check_interval: 6
  load_pretrain_model: false

  # optim:
  lr: 0.0002
  decay_step: 2000
  decay_gamma: 0.8
  weight_decay: 0.0001
  alpha: 0.99
  validation_interval: 100

dataset:
  types: ['3d']  # ['3d', '2d']
  split: 'train'
  data_root_path: '/quanquan/datasets/'
  dataset_list: ["alp", "word", "debug"]  # ['sam', "their", "ours"]
  data_txt_path: './datasets/dataset_list/'
  dataset2d_path: "/home1/quanquan/datasets/08_AbdomenCT-1K/"
  cache_data_path: '/home1/quanquan/datasets/cached_dataset2/'

# sam_checkpoint: "/quanquan/code/projects/medical-guangdong/segment-anything/sam_vit_b_01ec64.pth" # 103 server
# model_type: "vit_b"
# Continue training
# continue_training: true
# load_optimizer: true
# breakpoint_path: "/quanquan/code/segment-anything/runs/sam/ddp_b1/lora_3d_2dm"

test:
  batch_size: 1
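The nesting above (`base`, `logger`, `training`, `dataset`, `test`) is consumed in `core/ddp.py` through tutils' `ConfigManager`. As a rough illustration of the layout the trainer expects, here is a minimal sketch that reads the file with plain PyYAML instead of `ConfigManager` (the tutils loader also injects extra keys such as `base.runs_dir`, which this sketch does not reproduce):

```python
# Minimal sketch: read configs/vit_b.yaml with plain PyYAML and pull out the
# fields that core/ddp.py and core/learner2.py actually use. Illustration of
# the expected layout only, not the project's own loader.
import yaml

with open("configs/vit_b.yaml") as f:
    config = yaml.safe_load(f)

train_cfg = config["training"]
print(train_cfg["batch_size"], train_cfg["lr"], train_cfg["weight_decay"])

data_cfg = config["dataset"]
print(data_cfg["types"], data_cfg["dataset_list"])
```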
0
core/__init__.py
Normal file
BIN
core/__pycache__/__init__.cpython-38.pyc
Normal file
Binary file not shown.
BIN
core/__pycache__/ddp.cpython-38.pyc
Normal file
Binary file not shown.
BIN
core/__pycache__/ddp_b10.cpython-38.pyc
Normal file
Binary file not shown.
BIN
core/__pycache__/learner2.cpython-38.pyc
Normal file
Binary file not shown.
BIN
core/__pycache__/learner3.cpython-38.pyc
Normal file
Binary file not shown.
BIN
core/__pycache__/learner5.cpython-38.pyc
Normal file
Binary file not shown.
BIN
core/__pycache__/lora_sam.cpython-38.pyc
Normal file
Binary file not shown.
BIN
core/__pycache__/loss.cpython-38.pyc
Normal file
Binary file not shown.
142
core/ddp.py
Normal file
@@ -0,0 +1,142 @@
"""
from ddp_b9.py

Train the whole ViT
"""

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from tutils import tfilename, tdir

from datasets.dataset3d_2dmask import Dataset2D
# from datasets.dataset3d import Dataset3D
from datasets.cache_dataset3d3 import Dataset3D
from datasets.dataset_merged import DatasetMerged, TestsetMerged
from datasets.data_engine import DataEngine
from modeling.build_sam3d2 import sam_model_registry

from .learner5 import SamLearner
# from tutils.new.trainer.trainer_ddp import DDPTrainer
from trans_utils.trainer_ddp import DDPTrainer
# from .lora_sam import LoRA_Sam

import warnings
warnings.filterwarnings("ignore")


def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


def load_well_trained_model(model, lora_module, pth=None):
    print("Loading from ", pth)
    state_dict = torch.load(pth, map_location='cpu')
    model.load_state_dict(state_dict)


def ddp_train(rank, world_size, config):
    setup(rank, world_size)

    sam_checkpoint = "/quanquan/code/segment-anything/segment_anything/sam_vit_b_01ec64.pth"  # A800 server
    # sam_checkpoint = "/home1/quanquan/code/projects/medical-guangdong/segment-anything/sam_vit_b_01ec64.pth"  # 103 server
    model_type = "vit_b"
    device = rank

    config_data = config['dataset']
    data_type = config_data.get("types", ["3d", "2d"])
    data_type = [data_type] if isinstance(data_type, str) else data_type
    if '2d' in data_type:
        if '3d' in data_type:
            dataset = DatasetMerged(config_data, is_train=True)
        else:
            dataset = Dataset2D(dirpath=config_data['dataset2d_path'], is_train=True)
    elif '3d' in data_type:
        dataset = Dataset3D(config_data, split='train')
    else:
        raise NotImplementedError
    # assert len(validset) > 0
    data_engine = DataEngine(dataset=dataset, img_size=(1024, 1024))
    sam = sam_model_registry[model_type](checkpoint=None)

    learner = SamLearner(sam_model=sam, config=config, data_engine=data_engine)
    config_train = config['training']
    if config_train.get('continue_training', False):
        learner.use_lora()
        learner.load_well_trained_model(config['training']['breakpoint_path'])  # use preset path
    else:
        learner.load_pretrained_model(sam_checkpoint)
        learner.use_lora()

    print('Before Setting', get_parameter_number(learner.model))
    for param in learner.model.image_encoder.parameters():
        param.requires_grad = True
    print('After Setting', get_parameter_number(learner.model))

    ddp_trainer = DDPTrainer(config=config, rank=rank, world_size=world_size)
    ddp_trainer.fit(learner, trainset=data_engine, validset=None)

    cleanup()


def get_parameter_number(model):
    total_num = sum(p.numel() for p in model.parameters())
    trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return {'Total': total_num, 'Trainable': trainable_num}


def run_demo(demo_fn, world_size, config):
    mp.spawn(demo_fn,
             args=(world_size, config),
             nprocs=world_size,
             join=True)


from collections import OrderedDict
import yaml
import yamlloader


def _ordereddict_to_dict(d):
    if not isinstance(d, dict):
        return d
    for k, v in d.items():
        if isinstance(v, OrderedDict):
            v = _ordereddict_to_dict(v)
            d[k] = dict(v)
        elif type(v) == list:
            d[k] = _ordereddict_to_dict(v)
        elif isinstance(v, dict):
            d[k] = _ordereddict_to_dict(v)
    return d


# CUDA_VISIBLE_DEVICES=4,5,6,7 python -m core.ddp_b3 --tag lora --config configs/vit_b_103.yaml

if __name__ == "__main__":
    import argparse
    from tutils.new.manager import trans_args, trans_init, ConfigManager

    n_gpus = torch.cuda.device_count()
    # assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but {__file__} got {n_gpus}"
    if n_gpus == 1:
        print("Warning! Running on only 1 GPU; this is intended for debugging only.")
    world_size = n_gpus

    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="./configs/vit_b.yaml")
    parser.add_argument("--func", default="train")

    args = trans_args(parser=parser)
    config = ConfigManager()
    config.auto_init(file=__file__, args=args, ex_config=None)
    # config.save()
    path = tfilename(config['base']['runs_dir'], "config.yaml")
    with open(path, "w") as f:
        yaml.dump(_ordereddict_to_dict(config), f)
    print("Saved config file to ", path)

    if n_gpus < 1: exit(0)
    run_demo(ddp_train, world_size, config)
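The dataset branch in `ddp_train` maps `config['dataset']['types']` onto one of three dataset classes. A small standalone sketch of that selection logic, with plain strings standing in for the real `Dataset2D` / `Dataset3D` / `DatasetMerged` classes (illustration only):

```python
# Sketch of the dataset selection used in ddp_train: '2d'+'3d' -> merged set,
# '2d' only -> 2D slices, '3d' only -> cached 3D volumes.
def pick_dataset(types):
    types = [types] if isinstance(types, str) else types
    if '2d' in types:
        return "DatasetMerged" if '3d' in types else "Dataset2D"
    if '3d' in types:
        return "Dataset3D"
    raise NotImplementedError(f"unsupported dataset types: {types}")

assert pick_dataset(['3d']) == "Dataset3D"
assert pick_dataset(['3d', '2d']) == "DatasetMerged"
assert pick_dataset('2d') == "Dataset2D"
```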
522
core/learner2.py
Normal file
@@ -0,0 +1,522 @@
|
||||
import torch
|
||||
import torchvision
|
||||
import numpy as np
|
||||
from tutils.trainer import Trainer, LearnerModule
|
||||
from torch.utils.data import DataLoader
|
||||
from torch import optim
|
||||
import torch.optim.lr_scheduler as lr_scheduler
|
||||
from einops import rearrange, repeat
|
||||
from torch.nn import functional as F
|
||||
import os
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from modeling.sam3d import Sam
|
||||
# from segment_anything.utils.transforms import ResizeLongestSide
|
||||
from utils.transforms import ResizeLongestSide
|
||||
from .loss import compute_all_loss, ranked_combined_loss, compute_iou, combined_loss
|
||||
from .lora_sam import LoRA_Sam
|
||||
from safetensors import safe_open
|
||||
from datasets.data_engine import DataEngine
|
||||
|
||||
|
||||
# def lr_schedule(epoch):
|
||||
# if epoch < 250:
|
||||
# return (epoch + 1) / 250 * 0.0008 + 0.00004
|
||||
# elif epoch < 500:
|
||||
# return 0.0001
|
||||
# else:
|
||||
# return 0.0001
|
||||
|
||||
def lr_schedule(epoch):
|
||||
if epoch < 250:
|
||||
return (epoch + 1) / 250 * 0.1
|
||||
elif epoch < 500:
|
||||
return 0.01
|
||||
else:
|
||||
return 0.001
|
||||
|
||||
class SamLearner(LearnerModule):
|
||||
def __init__(
|
||||
self,
|
||||
sam_model: Sam,
|
||||
config=None,
|
||||
logger=None,
|
||||
data_engine=DataEngine(None, img_size=(1024,1024)),
|
||||
lora_module=None,
|
||||
) -> None:
|
||||
"""
|
||||
Uses SAM to calculate the image embedding for an image, and then
|
||||
allow repeated, efficient mask prediction given prompts.
|
||||
|
||||
Arguments:
|
||||
sam_model (Sam): The model to use for mask prediction.
|
||||
"""
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.logger = logger
|
||||
self.model = sam_model
|
||||
self.net = self.model
|
||||
self.transform = ResizeLongestSide(sam_model.image_encoder.img_size)
|
||||
self.reset_image()
|
||||
self.data_engine = data_engine
|
||||
self.features = None
|
||||
self.lora_module = lora_module
|
||||
|
||||
def save(self, pth, *args, **kwargs):
|
||||
# Default: "/model_epoch_{}.pth".format(epoch)
|
||||
torch.save(self.net.state_dict(), pth)
|
||||
lora_path = pth.replace(".pth", "_lora.safetensors")
|
||||
self.lora_module.save_lora_parameters(lora_path)
|
||||
return True
|
||||
|
||||
def load_pretrained_model(self, pth, *args, **kwargs):
|
||||
"""
|
||||
Unmatched: prompt_encoder.mask_downscaling.0.weight
|
||||
their: torch.Size([4, 1, 2, 2])
|
||||
our: torch.Size([4, 3, 2, 2])
|
||||
Unmatched: mask_decoder.mask_tokens.weight
|
||||
their: torch.Size([4, 256])
|
||||
our: torch.Size([12, 256])
|
||||
|
||||
"""
|
||||
state_dict = torch.load(pth)
|
||||
model_state_dict = self.model.state_dict()
|
||||
model_state_dict.update(state_dict)
|
||||
model_state_dict['prompt_encoder.mask_downscaling.0.weight'] = repeat(state_dict['prompt_encoder.mask_downscaling.0.weight'], "a 1 c d -> a b c d", b=3)
|
||||
model_state_dict['mask_decoder.mask_tokens.weight'] = repeat(state_dict['mask_decoder.mask_tokens.weight'], "a d -> (a 3) d")
|
||||
hyper_params_names = [k for k in model_state_dict.keys() if k.startswith("mask_decoder.output_hypernetworks_mlps")]
|
||||
for name in hyper_params_names:
|
||||
words = name.split('.')
|
||||
words[2] = str(int(words[2]) // 3)
|
||||
name_to_copy = ".".join(words)
|
||||
model_state_dict[name] = state_dict[name_to_copy]
|
||||
# for k, v in state_dict.items():
|
||||
# if model_state_dict[k].shape != state_dict[k].shape:
|
||||
# print("Unmatched:", k)
|
||||
self.model.load_state_dict(model_state_dict)
|
||||
|
||||
def load_well_trained_model(self, pth=None):
|
||||
pth = self.config['training']['breakpoint_path'] + "/ckpt_v/model_latest.pth" if pth is None else pth
|
||||
print("Loading from ", pth)
|
||||
state_dict = torch.load(pth, map_location="cpu")
|
||||
# print(state_dict.keys())
|
||||
# for k in state_dict.keys():
|
||||
# print(k)
|
||||
# exit(0)
|
||||
self.model.load_state_dict(state_dict)
|
||||
# self.lora_module.load_lora_parameters(pth.replace(".pth", "_lora.safetensors"))
|
||||
|
||||
def use_lora(self):
|
||||
lora_r = 8
|
||||
lora_sam = LoRA_Sam(self.model, lora_r, freeze_prompt_encoder=True)
|
||||
self.lora_module = lora_sam
|
||||
|
||||
def configure_optimizers(self, **kwargs):
|
||||
optimizer = optim.AdamW(params=self.model.parameters(), \
|
||||
lr=self.config['training']['lr'], betas=(0.9, 0.999), eps=1e-08,
|
||||
weight_decay=self.config['training']['weight_decay'])
|
||||
# scheduler = lr_scheduler.LambdaLR(optimizer, lr_schedule)
|
||||
scheduler = None
|
||||
return {'optimizer': optimizer, "scheduler": scheduler}
|
||||
|
||||
def load_optim(self, optimizer, pth=None, *args):
|
||||
pth = self.config['training']['breakpoint_path'] + "/ckpt/optim_latest.pth"
|
||||
print("Load Optimizer from ", pth)
|
||||
state_dict = torch.load(pth)
|
||||
optimizer.load_state_dict(state_dict['optimizer'])
|
||||
start_epoch = state_dict.get('epoch', 0) + 1
|
||||
return start_epoch
|
||||
|
||||
def training_step(self, data, batch_idx, **kwargs):
|
||||
img = data['img']
|
||||
gt_mask = data['label']
|
||||
prompt_point = data['prompt_point'] # shape: (b, 2)
|
||||
batch_size = prompt_point.shape[0]
|
||||
point_label = torch.ones((batch_size, 1)) #.to(prompt_point.device)
|
||||
prompt_box = data['prompt_box']
|
||||
|
||||
prompt_point = rearrange(prompt_point, "b c -> b 1 c")
|
||||
prompt_box = rearrange(prompt_box, "b c -> b 1 c")
|
||||
assert img.shape[1:] == (3,1024,1024),f"{__file__} Got{img.shape}"
|
||||
assert prompt_point.shape[1:] == (1,2), f"{__file__} Got{prompt_point.shape}"
|
||||
assert point_label.shape[1:] == (1,), f"{__file__} Got{point_label.shape}"
|
||||
assert prompt_box.shape[1:] == (1,4), f"{__file__} Got{prompt_box.shape}"
|
||||
|
||||
self.set_torch_image(img, img.shape[2:])
|
||||
if np.random.random() > 0.5:
|
||||
pred_masks, iou_predictions, logits = self.predict_torch(
|
||||
point_coords=prompt_point,
|
||||
point_labels=point_label,
|
||||
multimask_output=True,
|
||||
return_logits=True,
|
||||
)
|
||||
else:
|
||||
pred_masks, iou_predictions, logits = self.predict_torch(
|
||||
point_coords=None,
|
||||
point_labels=None,
|
||||
boxes=prompt_box,
|
||||
multimask_output=True,
|
||||
return_logits=True,
|
||||
)
|
||||
# assert pred_masks.shape == gt_mask.shape, f"Got {pred_masks.shape}, {gt_mask.shape}"
|
||||
loss_1, fl, dl, _, _, _ = ranked_combined_loss(pred_mask=pred_masks, gt_mask=gt_mask, iou_pred=iou_predictions)  # ranked_combined_loss returns six values
|
||||
|
||||
# print("Debug trainer: 2", prompt_point.shape, point_label.shape, prompt_box.shape)
|
||||
# Stage 2: based on the above, add more points as prompts
|
||||
return {"loss": loss_1, "fl": fl.mean(), "dl": dl.mean()}
|
||||
|
||||
# @torch.no_grad()
|
||||
def generate(self, image, prompt_point):
|
||||
orig_size = image.shape[2:]
|
||||
assert image.shape[1:] == (3,1024,1024),f"{__file__} Got{image.shape}"
|
||||
if not self.is_image_set:
|
||||
self.set_torch_image(image, orig_size)
|
||||
|
||||
assert prompt_point.shape[1:] == (1,2), f"{__file__} Got{prompt_point.shape}"
|
||||
# assert point_label.shape[1:] == (1,), f"{__file__} Got{point_label.shape}"
|
||||
point_label = torch.ones(prompt_point.size()[:-1])
|
||||
pred_masks, scores, logits = self.predict_torch(
|
||||
point_coords=prompt_point,
|
||||
point_labels=point_label,
|
||||
mask_input=None,
|
||||
multimask_output=True,
|
||||
)
|
||||
return pred_masks
|
||||
|
||||
# @torch.no_grad()
|
||||
def generate_by_box(self, image, prompt_box):
|
||||
orig_size = image.shape[2:]
|
||||
assert image.shape[1:] == (3,1024,1024),f"{__file__} Got{image.shape}"
|
||||
if not self.is_image_set:
|
||||
self.set_torch_image(image, orig_size)
|
||||
|
||||
assert prompt_box.shape[1:] == (1,4), f"{__file__} Got{prompt_box.shape}"
|
||||
pred_masks, scores, logits = self.predict_torch(
|
||||
point_coords=None,
|
||||
point_labels=None,
|
||||
boxes=prompt_box,
|
||||
mask_input=None,
|
||||
multimask_output=True,
|
||||
)
|
||||
return pred_masks
|
||||
|
||||
@staticmethod
|
||||
def select_best_mask(predictions, ground_truth):
|
||||
# Move tensors to the same device (if not already on the same device)
|
||||
# if predictions.device != ground_truth.device:
|
||||
# predictions = predictions.to(ground_truth.device)
|
||||
|
||||
# Compute IoU between each prediction and ground truth
|
||||
if predictions.shape[1] == 9:
|
||||
predictions = rearrange(predictions, "b (c1 c2) h w -> b c1 c2 h w", c1=3, c2=3)
|
||||
ground_truth = repeat(ground_truth, "b d h w -> b c d h w", c=3)
|
||||
else:
|
||||
predictions = rearrange(predictions, "b d h w -> b 1 d h w")
|
||||
ground_truth = rearrange(ground_truth, "b d h w -> b 1 d h w")
|
||||
intersection = torch.sum(predictions * ground_truth, dim=(-3, -2, -1))
|
||||
union = torch.sum(predictions + ground_truth, dim=(-3, -2, -1)) - intersection
|
||||
iou = intersection / (union + 1e-6)
|
||||
|
||||
# Select the prediction with maximum IoU for each image in the batch
|
||||
best_indices = torch.argmax(iou, dim=1)
|
||||
best_masks = torch.gather(predictions, 1, best_indices.unsqueeze(1).unsqueeze(2).unsqueeze(3).unsqueeze(4).repeat(1, 1, predictions.shape[-3], predictions.shape[-2], predictions.shape[-1]))
|
||||
|
||||
return best_masks
|
||||
|
||||
# ===============================================
|
||||
def predict_multi_prompt(
|
||||
self,
|
||||
point_coords: Optional[torch.Tensor],
|
||||
point_labels: Optional[torch.Tensor],
|
||||
mask_logits: Optional[torch.Tensor],
|
||||
):
|
||||
if point_coords is not None:
|
||||
assert (
|
||||
point_labels is not None
|
||||
), "point_labels must be supplied if point_coords is supplied."
|
||||
point_coords = self.transform.apply_coords(point_coords, self.original_size)
|
||||
coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
|
||||
labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
|
||||
coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
|
||||
|
||||
if point_coords is not None:
|
||||
points = (coords_torch, labels_torch)  # use the tensor labels built above, not the raw point_labels
|
||||
else:
|
||||
points = None
|
||||
|
||||
sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
|
||||
points=points,
|
||||
boxes=None,
|
||||
masks=mask_logits,
|
||||
)
|
||||
|
||||
low_res_masks, iou_predictions = self.model.mask_decoder(
|
||||
image_embeddings=self.features,
|
||||
image_pe=self.model.prompt_encoder.get_dense_pe(),
|
||||
sparse_prompt_embeddings=sparse_embeddings,
|
||||
dense_prompt_embeddings=dense_embeddings,
|
||||
multimask_output=False,
|
||||
)
|
||||
masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)
|
||||
return masks, iou_predictions, low_res_masks
|
||||
|
||||
def set_image(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
image_format: str = "RGB",
|
||||
) -> None:
|
||||
# Transform the image to the form expected by the model
|
||||
input_image = self.transform.apply_image(image)
|
||||
input_image_torch = torch.as_tensor(input_image, device=self.device)
|
||||
input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
|
||||
|
||||
self.set_torch_image(input_image_torch, image.shape[:2])
|
||||
|
||||
# @torch.no_grad()
|
||||
def set_torch_image(
|
||||
self,
|
||||
transformed_image: torch.Tensor,
|
||||
original_image_size: Tuple[int, ...],
|
||||
) -> None:
|
||||
"""
|
||||
Calculates the image embeddings for the provided image, allowing
|
||||
masks to be predicted with the 'predict' method. Expects the input
|
||||
image to be already transformed to the format expected by the model.
|
||||
|
||||
Arguments:
|
||||
transformed_image (torch.Tensor): The input image, with shape
|
||||
1x3xHxW, which has been transformed with ResizeLongestSide.
|
||||
original_image_size (tuple(int, int)): The size of the image
|
||||
before transformation, in (H, W) format.
|
||||
"""
|
||||
assert (
|
||||
len(transformed_image.shape) == 4
|
||||
and transformed_image.shape[1] == 3
|
||||
and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size
|
||||
), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}."
|
||||
self.reset_image()
|
||||
|
||||
self.original_size = original_image_size
|
||||
self.input_size = tuple(transformed_image.shape[-2:])
|
||||
input_image = self.model.preprocess(transformed_image)
|
||||
self.features = self.model.image_encoder(input_image)
|
||||
self.is_image_set = True
|
||||
|
||||
def predict(
|
||||
self,
|
||||
point_coords: Optional[np.ndarray] = None,
|
||||
point_labels: Optional[np.ndarray] = None,
|
||||
box: Optional[np.ndarray] = None,
|
||||
mask_input: Optional[np.ndarray] = None,
|
||||
multimask_output: bool = True,
|
||||
return_logits: bool = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
if not self.is_image_set:
|
||||
raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
|
||||
|
||||
# Transform input prompts
|
||||
coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
|
||||
if point_coords is not None:
|
||||
assert (
|
||||
point_labels is not None
|
||||
), "point_labels must be supplied if point_coords is supplied."
|
||||
point_coords = self.transform.apply_coords(point_coords, self.original_size)
|
||||
coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
|
||||
labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
|
||||
coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
|
||||
if box is not None:
|
||||
box = self.transform.apply_boxes(box, self.original_size)
|
||||
box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
|
||||
box_torch = box_torch[None, :]
|
||||
if mask_input is not None:
|
||||
mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device)
|
||||
mask_input_torch = mask_input_torch[None, :, :, :]
|
||||
|
||||
masks, iou_predictions, low_res_masks = self.predict_torch(
|
||||
coords_torch,
|
||||
labels_torch,
|
||||
box_torch,
|
||||
mask_input_torch,
|
||||
multimask_output,
|
||||
return_logits=return_logits,
|
||||
)
|
||||
|
||||
# masks = masks[0].detach().cpu().numpy()
|
||||
# iou_predictions = iou_predictions[0].detach().cpu().numpy()
|
||||
# low_res_masks = low_res_masks[0].detach().cpu().numpy()
|
||||
return masks, iou_predictions, low_res_masks
|
||||
|
||||
# @torch.no_grad()
|
||||
def predict_torch(
|
||||
self,
|
||||
point_coords: Optional[torch.Tensor],
|
||||
point_labels: Optional[torch.Tensor],
|
||||
boxes: Optional[torch.Tensor] = None,
|
||||
mask_input: Optional[torch.Tensor] = None,
|
||||
multimask_output: bool = True,
|
||||
return_logits: bool = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
if not self.is_image_set:
|
||||
raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
|
||||
|
||||
if point_coords is not None:
|
||||
points = (point_coords, point_labels)
|
||||
else:
|
||||
points = None
|
||||
|
||||
sparse_embeddings, dense_embeddings = self._get_prompt_embedding(points, boxes, mask_input)
|
||||
|
||||
# Predict masks
|
||||
low_res_masks, iou_predictions = self.model.mask_decoder(
|
||||
image_embeddings=self.features,
|
||||
image_pe=self.model.prompt_encoder.get_dense_pe(),
|
||||
sparse_prompt_embeddings=sparse_embeddings,
|
||||
dense_prompt_embeddings=dense_embeddings,
|
||||
multimask_output=multimask_output,
|
||||
)
|
||||
|
||||
# Upscale the masks to the original image resolution
|
||||
masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)
|
||||
# import ipdb; ipdb.set_trace()
|
||||
|
||||
if not return_logits:
|
||||
masks = masks > self.model.mask_threshold
|
||||
|
||||
return masks, iou_predictions, low_res_masks
|
||||
|
||||
# @torch.no_grad()
|
||||
def _get_prompt_embedding(self, points, boxes, mask_input):
|
||||
# Embed prompts
|
||||
sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
|
||||
points=points,
|
||||
boxes=boxes,
|
||||
masks=mask_input,
|
||||
)
|
||||
return sparse_embeddings, dense_embeddings
|
||||
|
||||
|
||||
def get_image_embedding(self) -> torch.Tensor:
|
||||
"""
|
||||
Returns the image embeddings for the currently set image, with
|
||||
shape 1xCxHxW, where C is the embedding dimension and (H,W) are
|
||||
the embedding spatial dimension of SAM (typically C=256, H=W=64).
|
||||
"""
|
||||
if not self.is_image_set:
|
||||
raise RuntimeError(
|
||||
"An image must be set with .set_image(...) to generate an embedding."
|
||||
)
|
||||
assert self.features is not None, "Features must exist if an image has been set."
|
||||
return self.features
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return self.model.device
|
||||
|
||||
def reset_image(self) -> None:
|
||||
"""Resets the currently set image."""
|
||||
self.is_image_set = False
|
||||
self.features = None
|
||||
self.orig_h = None
|
||||
self.orig_w = None
|
||||
self.input_h = None
|
||||
self.input_w = None
|
||||
|
||||
|
||||
def validation_step(self, data, batch_idx=0, **kwargs):
|
||||
img = data['img']
|
||||
gt_mask = data['label']
|
||||
prompt_point = data['prompt_point'] # shape: (b, 2)
|
||||
batch_size = prompt_point.shape[0]
|
||||
point_label = torch.ones((batch_size, 1)) #.to(prompt_point.device)
|
||||
prompt_box = data['prompt_box']
|
||||
gt_mask = repeat(gt_mask, "b d h w -> b c d h w", c=3)
|
||||
|
||||
prompt_point = rearrange(prompt_point, "b c -> b 1 c")
|
||||
prompt_box = rearrange(prompt_box, "b c -> b 1 c")
|
||||
assert img.shape[1:] == (3,1024,1024),f"{__file__} Got{img.shape}"
|
||||
assert prompt_point.shape[1:] == (1,2), f"{__file__} Got{prompt_point.shape}"
|
||||
assert point_label.shape[1:] == (1,), f"{__file__} Got{point_label.shape}"
|
||||
assert prompt_box.shape[1:] == (1,4), f"{__file__} Got{prompt_box.shape}"
|
||||
|
||||
self.set_torch_image(img, img.shape[2:])
|
||||
|
||||
# Stage 1: use the 1st prompt, box or point
|
||||
iou_details = {}
|
||||
pred_masks1, iou_predictions1, logits1 = self.predict_torch(
|
||||
point_coords=prompt_point,
|
||||
point_labels=point_label,
|
||||
multimask_output=True,
|
||||
return_logits=True,
|
||||
)
|
||||
|
||||
pred_masks1 = rearrange(pred_masks1, "b (c1 c2) h w -> b c1 c2 h w", c1=3, c2=3)
|
||||
loss_point, fl, dl, _, _, _ = ranked_combined_loss(pred_mask=pred_masks1, gt_mask=gt_mask, iou_pred=iou_predictions1)  # six return values
|
||||
iou_details['loss_point'] = loss_point.mean()
|
||||
iou_details['loss_point_fl'] = fl.mean()
|
||||
iou_details['loss_point_dl'] = dl.mean()
|
||||
|
||||
iou = compute_iou((pred_masks1>0).float(), gt_mask)
|
||||
iou, _ = torch.max(iou, axis=1)
|
||||
iou_details['iou_point'] = iou.mean()
|
||||
|
||||
pred_masks2, iou_predictions2, logits2 = self.predict_torch(
|
||||
point_coords=None,
|
||||
point_labels=None,
|
||||
boxes=prompt_box,
|
||||
multimask_output=True,
|
||||
return_logits=True,
|
||||
)
|
||||
pred_masks2 = rearrange(pred_masks2, "b (c1 c2) h w -> b c1 c2 h w", c1=3, c2=3)
|
||||
loss_box, fl, dl, _, _, _ = ranked_combined_loss(pred_mask=pred_masks2, gt_mask=gt_mask, iou_pred=iou_predictions2)  # six return values
|
||||
iou_details['loss_box'] = loss_box.mean()
|
||||
iou_details['loss_box_fl'] = fl.mean()
|
||||
iou_details['loss_box_dl'] = dl.mean()
|
||||
|
||||
iou = compute_iou((pred_masks2>0).float(), gt_mask)
|
||||
iou, _ = torch.max(iou, axis=1)
|
||||
iou_details['iou_box'] = iou.mean()
|
||||
|
||||
# import ipdb; ipdb.set_trace()
|
||||
|
||||
|
||||
# gt_mask_np = gt_mask.detach().cpu().numpy()
|
||||
# for step in range(8):
|
||||
# continue
|
||||
# # n
|
||||
# best_pred_masks = self.select_best_mask(pred_masks, gt_mask)
|
||||
# best_pred_masks_np = best_pred_masks.detach().cpu().numpy()
|
||||
|
||||
# # import ipdb; ipdb.set_trace()
|
||||
# mask_input = logits[0, np.argmax(scores[0].detach().cpu().numpy()), :, :] # Choose the model's best mask
|
||||
|
||||
# sub_points, sub_labels = self.data_engine.get_subsequent_prompt_point(best_pred_masks_np, gt_mask_np)
|
||||
# # sub_points, sub_labels = self.data_engine.point_prompt_generator.select_random_subsequent_point(best_pred_masks_np[0][0], gt_mask_np[0][0])
|
||||
|
||||
# y, x = sub_points[0][1], sub_points[0][0]
|
||||
# assert gt_mask_np[0][0][y,x] + best_pred_masks_np[0][0][y,x] == 1, f"{__file__} Got{gt_mask_np[0][0][y,x], best_pred_masks_np[0][0][y,x]}"
|
||||
# assert gt_mask_np[0][0][y,x] == sub_labels, f"{__file__} Got{ gt_mask_np[0][0][y,x]}, {sub_labels}"
|
||||
# assert best_pred_masks_np[0][0][y,x] == (1-sub_labels), f"{__file__} Got{ gt_mask_np[0][0][y,x]}, {1-sub_labels}"
|
||||
# # import ipdb; ipdb.set_trace()
|
||||
# # assert sub_points
|
||||
|
||||
# # sub_points = np.array(sub_points)[None,...].astype(int)
|
||||
# # sub_labels = np.array(sub_labels)[None,...]
|
||||
# prompt_point = np.concatenate([prompt_point, sub_points], axis=0)
|
||||
# point_label = np.concatenate([point_label, sub_labels], axis=0)
|
||||
|
||||
# # import ipdb; ipdb.set_trace()
|
||||
|
||||
# pred_masks2, scores, logits = model.predict(
|
||||
# point_coords=prompt_point,
|
||||
# point_labels=point_label,
|
||||
# mask_input=mask_input[None,...],
|
||||
# multimask_output=False,
|
||||
# )
|
||||
|
||||
# iou = compute_iou(pred_masks2, gt_mask)
|
||||
# iou, _ = torch.max(iou, axis=1)
|
||||
# iou_details[f'point_{step+2}'] = iou
|
||||
|
||||
return iou_details
|
||||
|
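`select_best_mask` above keeps, per image, the candidate whose IoU with the ground truth is highest. The gather/indexing is easier to follow on a toy example; this sketch reproduces the same idea for a batch of flat 2D masks (shapes simplified, not the 3x3-candidate layout used above):

```python
# Toy version of best-of-N mask selection: compute IoU of each candidate
# against the ground truth and keep the candidate with the highest IoU.
import torch

def select_best_mask_2d(candidates, gt):          # candidates: (b, n, h, w), gt: (b, h, w)
    gt = gt.unsqueeze(1)                          # (b, 1, h, w), broadcast over candidates
    inter = (candidates * gt).sum(dim=(-2, -1))
    union = (candidates + gt).clamp(max=1).sum(dim=(-2, -1))
    iou = inter / (union + 1e-6)                  # (b, n)
    best = iou.argmax(dim=1)                      # index of best candidate per image
    return candidates[torch.arange(candidates.shape[0]), best], iou

cands = torch.zeros(1, 3, 8, 8)
cands[0, 1, :4, :4] = 1                           # candidate 1 matches the ground truth
gt = torch.zeros(1, 8, 8)
gt[0, :4, :4] = 1
best_mask, iou = select_best_mask_2d(cands, gt)
print(iou, best_mask.shape)                       # candidate 1 has IoU 1.0
```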
154
core/learner3.py
Normal file
@@ -0,0 +1,154 @@
|
||||
import torch
|
||||
import torchvision
|
||||
import numpy as np
|
||||
from tutils.trainer import Trainer, LearnerModule
|
||||
from einops import rearrange, repeat, reduce
|
||||
import torch.optim.lr_scheduler as lr_scheduler
|
||||
|
||||
from core.loss import ranked_combined_loss_with_indicators
|
||||
from .learner2 import SamLearner as basic_learner
|
||||
from .loss import compute_all_loss, ranked_combined_loss, compute_iou, combined_loss
|
||||
|
||||
|
||||
class SamLearner(basic_learner):
|
||||
|
||||
def training_step(self, data, batch_idx, **kwargs):
|
||||
img = data['img']
|
||||
gt_mask = data['label']
|
||||
prompt_point = data['prompt_point'] # shape: (b, 2)
|
||||
batch_size = prompt_point.shape[0]
|
||||
point_label = torch.ones((batch_size, 1)) #.to(prompt_point.device)
|
||||
prompt_box = data['prompt_box']
|
||||
indicators = data['indicators']
|
||||
# print(data['name'])
|
||||
|
||||
prompt_point = rearrange(prompt_point, "b c -> b 1 c")
|
||||
prompt_box = rearrange(prompt_box, "b c -> b 1 c")
|
||||
assert img.shape[1:] == (3,1024,1024),f"{__file__} Got{img.shape}"
|
||||
assert prompt_point.shape[1:] == (1,2), f"{__file__} Got{prompt_point.shape}"
|
||||
assert point_label.shape[1:] == (1,), f"{__file__} Got{point_label.shape}"
|
||||
assert prompt_box.shape[1:] == (1,4), f"{__file__} Got{prompt_box.shape}"
|
||||
|
||||
self.set_torch_image(img, img.shape[2:])
|
||||
# if np.random.random() > 0.5:
|
||||
pred_masks, iou_predictions, logits = self.predict_torch(
|
||||
point_coords=prompt_point,
|
||||
point_labels=point_label,
|
||||
multimask_output=True,
|
||||
return_logits=True,
|
||||
)
|
||||
loss_1, fl, dl, min_losses, selected_losses, _ = ranked_combined_loss_with_indicators(pred_mask=pred_masks, gt_mask=gt_mask, iou_pred=iou_predictions, indicators=indicators)
|
||||
# else:
|
||||
pred_masks, iou_predictions, logits = self.predict_torch(
|
||||
point_coords=None,
|
||||
point_labels=None,
|
||||
boxes=prompt_box,
|
||||
multimask_output=True,
|
||||
return_logits=True,
|
||||
)
|
||||
# assert pred_masks.shape == gt_mask.shape, f"Got {pred_masks.shape}, {gt_mask.shape}"
|
||||
loss_2, fl, dl, min_losses, selected_losses, _ = ranked_combined_loss_with_indicators(pred_mask=pred_masks, gt_mask=gt_mask, iou_pred=iou_predictions, indicators=indicators)
|
||||
|
||||
loss = loss_1 + loss_2
|
||||
if loss < -999 or torch.isnan(loss):
|
||||
print("Warning! Loss Error! ")
|
||||
print(data['name'])
|
||||
|
||||
# print("Debug trainer: 2", prompt_point.shape, point_label.shape, prompt_box.shape)
|
||||
# Stage 2: based on the above, add more points as prompts
|
||||
return {"loss": loss_1 + loss_2, "point_loss": loss_1.mean(), "box_loss": loss_2, "fl": fl.mean(), "dl": dl.mean(), "min_losses": min_losses, "selected_losses": selected_losses}
|
||||
|
||||
|
||||
def validation_step(self, data, batch_idx=0, **kwargs):
|
||||
img = data['img']
|
||||
gt_mask = data['label']
|
||||
prompt_point = data['prompt_point'] # shape: (b, 2)
|
||||
batch_size = prompt_point.shape[0]
|
||||
point_label = torch.ones((batch_size, 1)) #.to(prompt_point.device)
|
||||
prompt_box = data['prompt_box']
|
||||
indicators = data['indicators']
|
||||
print(data['name'])
|
||||
prompt_point = rearrange(prompt_point, "b c -> b 1 c")
|
||||
prompt_box = rearrange(prompt_box, "b c -> b 1 c")
|
||||
assert img.shape[1:] == (3,1024,1024),f"{__file__} Got{img.shape}"
|
||||
assert prompt_point.shape[1:] == (1,2), f"{__file__} Got{prompt_point.shape}"
|
||||
assert point_label.shape[1:] == (1,), f"{__file__} Got{point_label.shape}"
|
||||
assert prompt_box.shape[1:] == (1,4), f"{__file__} Got{prompt_box.shape}"
|
||||
|
||||
self.set_torch_image(img, img.shape[2:])
|
||||
|
||||
# Stage 1: use the 1st prompt, box or point
|
||||
iou_details = {}
|
||||
pred_masks1, iou_predictions1, logits1 = self.predict_torch(
|
||||
point_coords=prompt_point,
|
||||
point_labels=point_label,
|
||||
multimask_output=True,
|
||||
return_logits=True,
|
||||
)
|
||||
|
||||
if len(pred_masks1.shape) == 4:
|
||||
pred_masks1 = rearrange(pred_masks1, "b (c1 c2) h w -> b c1 c2 h w", c1=3, c2=3)
|
||||
|
||||
loss_point, fl, dl, min_losses, selected_losses, min_indices = ranked_combined_loss_with_indicators(pred_mask=pred_masks1, gt_mask=gt_mask, iou_pred=iou_predictions1, indicators=indicators)
|
||||
iou_details['loss_point'] = loss_point.mean()
|
||||
iou_details['loss_point_fl'] = fl.mean()
|
||||
iou_details['loss_point_dl'] = dl.mean()
|
||||
|
||||
if len(gt_mask.shape) == 4:
|
||||
gt_mask = repeat(gt_mask, "b d h w -> b c d h w", c=3)
|
||||
|
||||
indices = iou_predictions1.argmax(axis=1)
|
||||
pred_maxiou = []
|
||||
for pred, i in zip(pred_masks1, indices):
|
||||
pred_maxiou.append(pred[i,:,:,:])
|
||||
pred_maxiou = torch.stack(pred_maxiou, axis=0)
|
||||
iou = compute_iou2(torch.tensor(pred_maxiou>0, dtype=gt_mask.dtype), gt_mask[:,0,:,:,:]).detach()
|
||||
iou_details['iou_point'] = iou.mean()
|
||||
|
||||
iou = compute_iou(torch.tensor(pred_masks1>0, dtype=gt_mask.dtype), gt_mask).detach()
|
||||
iou, _ = torch.max(iou, axis=1)
|
||||
iou_details['iou_point_max'] = iou.mean()
|
||||
|
||||
pred_masks2, iou_predictions2, logits2 = self.predict_torch(
|
||||
point_coords=None,
|
||||
point_labels=None,
|
||||
boxes=prompt_box,
|
||||
multimask_output=True,
|
||||
return_logits=True,
|
||||
)
|
||||
loss_box, fl, dl, min_losses, selected_losses, min_indices = ranked_combined_loss_with_indicators(pred_mask=pred_masks2, gt_mask=gt_mask, iou_pred=iou_predictions2, indicators=indicators)
|
||||
iou_details['loss_box'] = loss_box.mean()
|
||||
iou_details['loss_box_fl'] = fl.mean()
|
||||
iou_details['loss_box_dl'] = dl.mean()
|
||||
|
||||
if len(gt_mask.shape) == 4:
|
||||
gt_mask = repeat(gt_mask, "b d h w -> b c d h w", c=3)
|
||||
if len(pred_masks2.shape) == 4:
|
||||
pred_masks2 = rearrange(pred_masks2, "b (c1 c2) h w -> b c1 c2 h w", c1=3, c2=3)
|
||||
|
||||
indices = iou_predictions2.argmax(axis=1)
|
||||
pred_maxiou = []
|
||||
for pred, i in zip(pred_masks2, indices):
|
||||
pred_maxiou.append(pred[i,:,:,:])
|
||||
pred_maxiou = torch.stack(pred_maxiou, axis=0)
|
||||
iou = compute_iou2(torch.tensor(pred_maxiou>0, dtype=gt_mask.dtype), gt_mask[:,0,:,:,:]).detach()
|
||||
iou_details['iou_box'] = iou.mean()
|
||||
|
||||
iou = compute_iou(torch.tensor(pred_masks2>0, dtype=gt_mask.dtype), gt_mask).detach()
|
||||
iou, _ = torch.max(iou, axis=1)
|
||||
iou_details['iou_box_max'] = iou.mean()
|
||||
return iou_details
|
||||
|
||||
|
||||
def compute_iou2(pred_mask, gt_mask):
|
||||
dtype = pred_mask.dtype
|
||||
intersection = torch.logical_and(pred_mask, gt_mask)
|
||||
intersection = reduce(intersection, "b d h w -> b", reduction='sum')
|
||||
union = torch.logical_or(pred_mask, gt_mask)
|
||||
union = reduce(union, "b d h w -> b", reduction='sum') + 1e-8
|
||||
iou = intersection / union # if union > 0 else 0
|
||||
iou = torch.tensor(iou, dtype=dtype)
|
||||
# print("ranked_combined_loss: compute_iou ", intersection.dtype, union.dtype, iou.dtype)
|
||||
return iou
|
||||
|
||||
# def save(img, mask, mask2):
|
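The validation above reports two numbers per prompt type: the IoU of the mask the model itself ranks highest (argmax over `iou_predictions`, used for `iou_point` / `iou_box`) and the best achievable IoU over all candidates (`iou_point_max` / `iou_box_max`). A toy sketch of that difference, with made-up IoU values:

```python
# Predicted-best vs. oracle-best candidate selection, mirroring the
# iou_point / iou_point_max metrics above (illustrative numbers only).
import torch

true_iou = torch.tensor([[0.50, 0.80, 0.60]])     # actual IoU of each candidate
pred_iou = torch.tensor([[0.70, 0.65, 0.40]])     # model's own IoU estimates

picked = true_iou.gather(1, pred_iou.argmax(dim=1, keepdim=True))
oracle, _ = true_iou.max(dim=1)
print(picked.item(), oracle.item())               # 0.5 (model's pick) vs 0.8 (oracle)
```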
55
core/learner5.py
Normal file
@@ -0,0 +1,55 @@
"""
Use mask_decoder3d_2.py
"""

import torch
import torchvision
import numpy as np
from tutils.trainer import Trainer, LearnerModule
from einops import rearrange, repeat, reduce
import torch.optim.lr_scheduler as lr_scheduler

from core.loss import ranked_combined_loss_with_indicators
from .learner3 import SamLearner as basic_learner
from .loss import compute_all_loss, ranked_combined_loss, compute_iou, combined_loss


class SamLearner(basic_learner):

    def load_pretrained_model(self, pth, *args, **kwargs):
        """
        Unmatched: prompt_encoder.mask_downscaling.0.weight
            their: torch.Size([4, 1, 2, 2])
            our:   torch.Size([4, 3, 2, 2])
        Unmatched: mask_decoder.mask_tokens.weight
            their: torch.Size([4, 256])
            our:   torch.Size([12, 256])
        """
        print("Load pretrained model for mask_decoder3d_2 !!")

        state_dict = torch.load(pth)
        model_state_dict = self.model.state_dict()
        model_state_dict.update(state_dict)
        model_state_dict['prompt_encoder.mask_downscaling.0.weight'] = repeat(
            state_dict['prompt_encoder.mask_downscaling.0.weight'], "a 1 c d -> a b c d", b=3)
        # model_state_dict['mask_decoder.mask_tokens.weight'] = repeat(state_dict['mask_decoder.mask_tokens.weight'], "a d -> (a 3) d")

        for k, v in model_state_dict.items():
            if k.startswith("mask_decoder.output_upscaling2"):
                k2 = k.replace("output_upscaling2.", "output_upscaling.")
                model_state_dict[k] = model_state_dict[k2]
                print("Load weights: ", k)
            if k.startswith("mask_decoder.output_upscaling3"):
                k2 = k.replace("output_upscaling3.", "output_upscaling.")
                model_state_dict[k] = model_state_dict[k2]
                print("Load weights: ", k)

        hyper_params_names = [k for k in model_state_dict.keys() if k.startswith("mask_decoder.output_hypernetworks_mlps")]
        for name in hyper_params_names:
            words = name.split('.')
            words[2] = str(int(words[2]) // 3)
            name_to_copy = ".".join(words)
            model_state_dict[name] = state_dict[name_to_copy]
        # for k, v in state_dict.items():
        #     if model_state_dict[k].shape != state_dict[k].shape:
        #         print("Unmatched:", k)
        self.model.load_state_dict(model_state_dict)
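The key remapping above seeds the duplicated decoder branches (`output_upscaling2`, `output_upscaling3`) from the single pretrained `output_upscaling` branch. The same renaming pattern on a plain dictionary, just to make the string manipulation explicit (toy keys and values, not real SAM checkpoints):

```python
# Toy illustration of the checkpoint-key remapping in load_pretrained_model:
# new branches "output_upscaling2"/"output_upscaling3" start from the weights
# of the pretrained "output_upscaling" branch.
pretrained = {"mask_decoder.output_upscaling.0.weight": "W0"}
model_keys = [
    "mask_decoder.output_upscaling.0.weight",
    "mask_decoder.output_upscaling2.0.weight",
    "mask_decoder.output_upscaling3.0.weight",
]
state = {k: pretrained.get(k) for k in model_keys}
for k in model_keys:
    for new in ("output_upscaling2.", "output_upscaling3."):
        if new in k:
            state[k] = state[k.replace(new, "output_upscaling.")]
print(state)   # all three branches now hold "W0"
```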
196
core/lora_sam.py
Normal file
@@ -0,0 +1,196 @@
|
||||
# Modified from https://github.com/JamesQFreeman/Sam_LoRA
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
from torch.nn.parameter import Parameter
|
||||
from safetensors import safe_open
|
||||
from safetensors.torch import save_file
|
||||
# from modeling.sam3d import Sam
|
||||
|
||||
|
||||
class _LoRA_qkv(nn.Module):
|
||||
"""In Sam it is implemented as
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv.unbind(0)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
qkv: nn.Module,
|
||||
linear_a_q: nn.Module,
|
||||
linear_b_q: nn.Module,
|
||||
linear_a_v: nn.Module,
|
||||
linear_b_v: nn.Module,
|
||||
):
|
||||
super().__init__()
|
||||
self.qkv = qkv
|
||||
self.linear_a_q = linear_a_q
|
||||
self.linear_b_q = linear_b_q
|
||||
self.linear_a_v = linear_a_v
|
||||
self.linear_b_v = linear_b_v
|
||||
self.dim = qkv.in_features
|
||||
self.w_identity = torch.eye(qkv.in_features)
|
||||
|
||||
def forward(self, x):
|
||||
qkv = self.qkv(x) # B,N,N,3*org_C
|
||||
new_q = self.linear_b_q(self.linear_a_q(x))
|
||||
new_v = self.linear_b_v(self.linear_a_v(x))
|
||||
qkv[:, :, :, : self.dim] += new_q
|
||||
qkv[:, :, :, -self.dim :] += new_v
|
||||
return qkv
|
||||
|
||||
class LoRA_Sam(nn.Module):
|
||||
"""Applies low-rank adaptation to a Sam model's image encoder.
|
||||
|
||||
Args:
|
||||
sam_model: a vision transformer model, see base_vit.py
|
||||
r: rank of LoRA
|
||||
num_classes: how many classes the model output, default to the vit model
|
||||
lora_layer: which layer we apply LoRA.
|
||||
freeze_all: freeze whole sam, otherwise only image encoder (VIT)
|
||||
|
||||
Examples::
|
||||
>>> model = ViT('B_16_imagenet1k')
|
||||
>>> lora_model = LoRA_ViT(model, r=4)
|
||||
>>> preds = lora_model(img)
|
||||
>>> print(preds.shape)
|
||||
torch.Size([1, 1000])
|
||||
"""
|
||||
|
||||
def __init__(self, sam_model, r: int, lora_layer:[int]=None, freeze_all:bool=False, freeze_prompt_encoder=True):
|
||||
super(LoRA_Sam, self).__init__()
|
||||
|
||||
assert r > 0
|
||||
# base_vit_dim = sam_model.image_encoder.patch_embed.proj.out_channels
|
||||
# dim = base_vit_dim
|
||||
if lora_layer:
|
||||
self.lora_layer = lora_layer
|
||||
else:
|
||||
self.lora_layer = list(range(len(sam_model.image_encoder.blocks)))
|
||||
# create for storage, then we can init them or load weights
|
||||
self.w_As = [] # These are linear layers
|
||||
self.w_Bs = []
|
||||
|
||||
# lets freeze first
|
||||
if freeze_all:
|
||||
for param in sam_model.parameters():
|
||||
param.requires_grad = False
|
||||
else:
|
||||
for param in sam_model.image_encoder.parameters():
|
||||
param.requires_grad = False
|
||||
for param in sam_model.image_encoder.patch_embed.parameters():
|
||||
param.requires_grad = True
|
||||
if freeze_prompt_encoder:
|
||||
for param in sam_model.prompt_encoder.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
# Here, we do the surgery
|
||||
for t_layer_i, blk in enumerate(sam_model.image_encoder.blocks):
|
||||
# If we only want few lora layer instead of all
|
||||
if t_layer_i not in self.lora_layer:
|
||||
continue
|
||||
w_qkv_linear = blk.attn.qkv
|
||||
self.dim = w_qkv_linear.in_features
|
||||
w_a_linear_q = nn.Linear(self.dim, r, bias=False)
|
||||
w_b_linear_q = nn.Linear(r, self.dim, bias=False)
|
||||
w_a_linear_v = nn.Linear(self.dim, r, bias=False)
|
||||
w_b_linear_v = nn.Linear(r, self.dim, bias=False)
|
||||
self.w_As.append(w_a_linear_q)
|
||||
self.w_Bs.append(w_b_linear_q)
|
||||
self.w_As.append(w_a_linear_v)
|
||||
self.w_Bs.append(w_b_linear_v)
|
||||
blk.attn.qkv = _LoRA_qkv(
|
||||
w_qkv_linear,
|
||||
w_a_linear_q,
|
||||
w_b_linear_q,
|
||||
w_a_linear_v,
|
||||
w_b_linear_v,
|
||||
)
|
||||
self.reset_parameters()
|
||||
self.sam = sam_model
|
||||
# with open('vit_named_para.txt', 'w') as f:
|
||||
# for k, v in sam_model.image_encoder.named_parameters():
|
||||
# f.write(f'{k} {v.shape}\n')
|
||||
|
||||
|
||||
def save_lora_parameters(self, filename: str) -> None:
|
||||
r"""Only safetensors is supported now.
|
||||
|
||||
pip install safetensor if you do not have one installed yet.
|
||||
|
||||
save both lora and fc parameters.
|
||||
"""
|
||||
# assert filename.endswith(".safetensors")
|
||||
|
||||
num_layer = len(self.w_As) # actually, it is half
|
||||
a_tensors = {f"w_a_{i:03d}": self.w_As[i].weight for i in range(num_layer)}
|
||||
b_tensors = {f"w_b_{i:03d}": self.w_Bs[i].weight for i in range(num_layer)}
|
||||
|
||||
|
||||
merged_dict = {**a_tensors, **b_tensors}
|
||||
save_file(merged_dict, filename)
|
||||
|
||||
def load_lora_parameters(self, filename: str) -> None:
|
||||
r"""Only safetensors is supported now.
|
||||
|
||||
pip install safetensor if you do not have one installed yet.\
|
||||
|
||||
load both lora and fc parameters.
|
||||
"""
|
||||
|
||||
assert filename.endswith(".safetensors")
|
||||
|
||||
with safe_open(filename, framework="pt") as f:
|
||||
for i, w_A_linear in enumerate(self.w_As):
|
||||
saved_key = f"w_a_{i:03d}"
|
||||
saved_tensor = f.get_tensor(saved_key)
|
||||
w_A_linear.weight = Parameter(saved_tensor)
|
||||
|
||||
for i, w_B_linear in enumerate(self.w_Bs):
|
||||
saved_key = f"w_b_{i:03d}"
|
||||
saved_tensor = f.get_tensor(saved_key)
|
||||
w_B_linear.weight = Parameter(saved_tensor)
|
||||
# import ipdb; ipdb.set_trace()
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
for w_A in self.w_As:
|
||||
nn.init.kaiming_uniform_(w_A.weight, a=5**0.5)
|
||||
for w_B in self.w_Bs:
|
||||
nn.init.zeros_(w_B.weight)
|
||||
|
||||
|
||||
def get_parameter_number(model):
|
||||
total_num = sum(p.numel() for p in model.parameters())
|
||||
trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
return {'Total': total_num, 'Trainable': trainable_num}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from segment_anything import sam_model_registry
|
||||
from segment_anything import SamPredictor, SamAutomaticMaskGenerator # prompt and every mode
|
||||
import numpy as np
|
||||
lora_r = 8
|
||||
path = "../checkpoints/sam_vit_b_01ec64.pth"
|
||||
sam = sam_model_registry["vit_b"](checkpoint=None)
|
||||
print('before lora', get_parameter_number(sam))
|
||||
lora_sam = LoRA_Sam(sam, lora_r)
|
||||
print('after lora', get_parameter_number(sam))
|
||||
x = torch.rand(size=(3,1024,1024))
|
||||
path = '../data/cache/data2d_3/0007_s0069_img.npy'
|
||||
img = np.load(path)
|
||||
print('img shape', img.shape)
|
||||
mask_generator = SamAutomaticMaskGenerator(sam)
|
||||
masks = mask_generator.generate(x)
|
||||
print('mask num', len(masks))
|
||||
# loss = np.sum([mask.mean(-1).mean(-1) for mask in masks])
|
||||
# predictor = SamPredictor(sam)
|
||||
# predictor.set_image(x)
|
||||
# masks, _, _ = predictor.predict([50,50])
|
||||
|
||||
for f_name in ['save_lora_parameters', 'load_lora_parameters']:
|
||||
print(f_name)
|
||||
getattr(lora_sam, f_name)('tmp.safetensors')
|
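`_LoRA_qkv` above adds the low-rank update only to the q and v slices of the fused qkv projection. The core LoRA idea on a single linear layer, as a toy sketch (rank and sizes are arbitrary, chosen only for illustration):

```python
# Minimal LoRA sketch for one frozen linear layer: y = W x + B(A x), where A
# (dim -> r) and B (r -> dim) are the only trainable parameters. Mirrors the
# q/v updates applied inside _LoRA_qkv, without the qkv slicing.
import torch
import torch.nn as nn

dim, r = 64, 8
base = nn.Linear(dim, dim)
for p in base.parameters():
    p.requires_grad_(False)                       # frozen pretrained projection
lora_a = nn.Linear(dim, r, bias=False)            # trainable low-rank factors
lora_b = nn.Linear(r, dim, bias=False)
nn.init.kaiming_uniform_(lora_a.weight, a=5 ** 0.5)
nn.init.zeros_(lora_b.weight)                     # zero init: no change at start

x = torch.randn(2, dim)
y = base(x) + lora_b(lora_a(x))                   # equals base(x) at initialization
print(torch.allclose(y, base(x)))                 # True, because B is zero-initialized
```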
155
core/loss.py
Normal file
@@ -0,0 +1,155 @@
import torch
import torch.nn.functional as F
import torch.nn as nn
from einops import repeat, rearrange, reduce
import numpy as np


def compute_dice_np(pred_mask, gt_mask):
    """Dice score for numpy inputs."""
    pred_mask = np.array(pred_mask > 0)
    gt_mask = np.array(gt_mask > 0)
    intersection = np.array(pred_mask * gt_mask).sum()
    union = pred_mask.sum() + gt_mask.sum()
    dice = intersection * 2 / union  # if union > 0 else 0
    return dice


def combined_loss(logits, targets, alpha=0.2, gamma=2.0, smooth=1e-5, reduction='mean'):
    # Calculate the focal loss
    fl = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
    pt = torch.exp(-fl)
    focal_loss = alpha * (1 - pt) ** gamma * fl

    if reduction == 'mean':
        fl = torch.mean(focal_loss)
    elif reduction == 'sum':
        fl = torch.sum(focal_loss)

    # Calculate the Dice loss
    prob = torch.sigmoid(logits)
    intersection = torch.sum(prob * targets, dim=(-2, -1))
    union = torch.sum(prob + targets, dim=(-2, -1))
    dice_loss = 1 - (2 * intersection + smooth) / (union + smooth)

    # NOTE: callers take the two unreduced losses below and combine them
    # themselves (e.g. ranked_combined_loss computes 20 * fl + dl), so the
    # reduction branch that follows this return is never reached.
    return focal_loss, dice_loss

    if reduction == 'mean':
        dl = torch.mean(dice_loss)
    elif reduction == 'sum':
        dl = torch.sum(dice_loss)

    # Combine the losses using the specified ratio
    loss = 20 * fl + dl
    return loss


# Assuming your prediction and ground truth tensors are named `pred` and `gt`, respectively
def mse_loss(pred, gt):
    mse_loss = nn.MSELoss(reduction='none')
    loss = mse_loss(pred, gt)
    return loss


def compute_iou(pred_mask, gt_mask):
    dtype = pred_mask.dtype
    intersection = torch.logical_and(pred_mask, gt_mask)
    intersection = reduce(intersection, "b c d h w -> b c", reduction='sum')
    union = torch.logical_or(pred_mask, gt_mask)
    union = reduce(union, "b c d h w -> b c", reduction='sum') + 1e-8
    iou = intersection / union  # if union > 0 else 0
    iou = torch.tensor(iou, dtype=dtype)
    # print("ranked_combined_loss: compute_iou ", intersection.dtype, union.dtype, iou.dtype)
    return iou


def ranked_combined_loss(pred_mask, gt_mask, iou_pred):
    # (b c1 c2 h w), c1: num_predictions; c2: num_slices
    if len(gt_mask.shape) == 4:
        gt_mask = repeat(gt_mask, "b d h w -> b c d h w", c=3)
    if len(pred_mask.shape) == 4:
        pred_mask = rearrange(pred_mask, "b (c1 c2) h w -> b c1 c2 h w", c1=3, c2=3)
    fl, dl = combined_loss(pred_mask, gt_mask)
    fl = reduce(fl, "b c d h w -> b c", reduction="mean")
    dl = reduce(dl, "b c d -> b c", reduction="mean")
    segment_loss = 20 * fl + dl
    min_losses, min_loss_indices = torch.min(segment_loss, dim=1)
    iou = compute_iou(torch.tensor(pred_mask > 0, dtype=gt_mask.dtype), gt_mask).detach()
    # print("ranked_combined_loss ", iou.dtype)
    iou_loss = mse_loss(iou_pred, iou)

    selected_losses = torch.gather(iou_loss, 1, min_loss_indices.unsqueeze(1))
    selected_fl = torch.gather(fl, 1, min_loss_indices.unsqueeze(1))
    selected_dl = torch.gather(dl, 1, min_loss_indices.unsqueeze(1))
    # print(min_losses.shape, selected_losses.shape)

    total_loss = min_losses.mean() + selected_losses.mean()
    # return total_loss, min_losses, selected_losses
    return total_loss, selected_fl, selected_dl, min_losses.mean(), selected_losses.mean(), min_loss_indices


def ranked_combined_loss_one_slice(pred_mask, gt_mask, iou_pred, mask_loc):
    if len(gt_mask.shape) == 4:
        # assert gt_mask.shape[1] == 1, f"Got {gt_mask.shape}"
        gt_mask = repeat(gt_mask, "b d h w -> b c d h w", c=3)
    if len(pred_mask.shape) == 4:
        pred_mask = rearrange(pred_mask, "b (c1 c2) h w -> b c1 c2 h w", c1=3, c2=3)
    gt_mask = gt_mask[:, :, mask_loc, :, :]
    pred_mask = pred_mask[:, :, mask_loc, :, :]
    assert len(pred_mask.shape) == 5
    return ranked_combined_loss(pred_mask, gt_mask, iou_pred)


def ranked_combined_loss_with_indicators(pred_mask, gt_mask, iou_pred, indicators):
    # indicators: indicate which slices carry a mask
    # (b c1 c2 h w), c1: num_predictions; c2: num_slices
    if len(gt_mask.shape) == 4:
        gt_mask = repeat(gt_mask, "b d h w -> b c d h w", c=3)
    if len(pred_mask.shape) == 4:
        pred_mask = rearrange(pred_mask, "b (c1 c2) h w -> b c1 c2 h w", c1=3, c2=3)

    b, c1, c2, h, w = pred_mask.shape
    indicators = torch.tensor(indicators, dtype=pred_mask.dtype)
    indicators = repeat(indicators, "b d -> b c d h w", c=3, h=h, w=w)
    pred_mask = pred_mask * indicators
    gt_mask = gt_mask * indicators

    # Same as "ranked_combined_loss"
    return ranked_combined_loss(pred_mask, gt_mask, iou_pred)


def compute_all_loss_with_indicators(pred_mask, gt_mask, iou_pred, indicators):
    # indicators: indicate which slices carry a mask
    # (b c1 c2 h w), c1: num_predictions; c2: num_slices
    if len(gt_mask.shape) == 4:
        gt_mask = repeat(gt_mask, "b d h w -> b c d h w", c=1)
    if len(pred_mask.shape) == 4:
        pred_mask = rearrange(pred_mask, "b (c1 c2) h w -> b c1 c2 h w", c1=1, c2=3)

    b, c1, c2, h, w = pred_mask.shape
    indicators = torch.tensor(indicators, dtype=pred_mask.dtype)
    indicators = repeat(indicators, "b d -> b c d h w", c=1, h=h, w=w)
    pred_mask = pred_mask * indicators
    gt_mask = gt_mask * indicators

    # Same as "compute_all_loss"
    return compute_all_loss(pred_mask, gt_mask, iou_pred)


def compute_all_loss(pred_mask, gt_mask, iou_pred):
    if len(pred_mask.shape) == 4:
        pred_mask = pred_mask.unsqueeze(1)
    if len(gt_mask.shape) == 4:
        gt_mask = gt_mask.unsqueeze(1)
    # import ipdb; ipdb.set_trace()
    fl, dl = combined_loss(pred_mask, gt_mask)
    segment_loss = 20 * fl.mean() + dl.mean()
    iou_loss = mse_loss(iou_pred, compute_iou(torch.tensor(pred_mask > 0, dtype=gt_mask.dtype), gt_mask))
    total_loss = segment_loss.mean() + iou_loss.mean()
    return total_loss, fl, dl, iou_loss


# def compute_


if __name__ == "__main__":
    pred_mask = torch.ones((1, 9, 1024, 1024)) * 9
    pred_mask[:, :, :200, :] = -1
    gt_mask = torch.ones((1, 3, 1024, 1024))
    loss = ranked_combined_loss(pred_mask, gt_mask, iou_pred=torch.ones(gt_mask.shape[:1]))
    print(loss)
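`ranked_combined_loss` optimizes only the best of the three candidate predictions (minimum over the candidate axis) and supervises the IoU head at that same index via `torch.gather`. The index bookkeeping in isolation, on made-up per-candidate losses:

```python
# Toy view of the ranking step: per-candidate segmentation losses of shape
# (batch, n_candidates); keep the minimum per sample and gather the matching
# IoU-regression loss at the same candidate index.
import torch

seg_loss = torch.tensor([[0.9, 0.2, 0.5],
                         [0.3, 0.7, 0.4]])        # (b=2, candidates=3)
iou_loss = torch.tensor([[0.10, 0.02, 0.08],
                         [0.01, 0.09, 0.05]])

min_losses, min_idx = torch.min(seg_loss, dim=1)  # best candidate per sample
selected_iou = torch.gather(iou_loss, 1, min_idx.unsqueeze(1))
total = min_losses.mean() + selected_iou.mean()
print(min_idx.tolist(), total.item())             # [1, 0], ~0.265
```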
612
core/volume_predictor.py
Normal file
@@ -0,0 +1,612 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from typing import Any, List, Dict, Tuple, Optional
|
||||
from einops import rearrange
|
||||
from utils.amg import MaskData, batched_mask_to_box, batched_mask3d_to_box
|
||||
from torchvision.ops.boxes import batched_nms, box_area # type: ignore
|
||||
from utils.amg3d import build_point_grid, calculate_stability_score_3d, MaskData3d, batch_iterator
|
||||
from utils.amg import calculate_stability_score
|
||||
from utils.transforms3d import ResizeLongestSide, SimpleResize
|
||||
from einops import rearrange, repeat
|
||||
# from datasets.data_engine import ValidEngine, BoxPromptGenerator
|
||||
from datasets.data_engine import DataEngine, DataManager, BoxPromptGenerator, PointPromptGenerator
|
||||
import cv2
|
||||
import torch.nn.functional as F
|
||||
from tutils.new.manager import ConfigManager
|
||||
from tutils.nn.data import read, itk_to_np, write, np_to_itk
|
||||
from torchvision.utils import save_image
|
||||
from einops import rearrange, reduce, repeat
|
||||
from core.loss import compute_dice_np
|
||||
|
||||
|
||||
class PseudoPredictor:
|
||||
# def __init__(self) -> None:
|
||||
# self.image_encoder =
|
||||
def predict(self, *args, **kwargs):
|
||||
mask = np.zeros((1024,1024))
|
||||
mask[:500,:500] = 1
|
||||
return mask
|
||||
|
||||
|
||||
class VolumePredictor:
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
slice_per_batch: int = 4,
|
||||
points_per_side: Optional[int] = 32,
|
||||
points_per_batch: int = 16,
|
||||
pred_iou_thresh: float = 0.5, # debug, standard value is 0.88,
|
||||
stability_score_thresh: float = 0.6, # debug, standard value is 0.95,
|
||||
stability_score_offset: float = 1.0,
|
||||
box_nms_thresh: float = 0.7,
|
||||
use_postprocess = True,
|
||||
use_noise_remove = True,
|
||||
) -> None:
|
||||
self.model = model
|
||||
self.im_size = (model.image_encoder.img_size, model.image_encoder.img_size)
|
||||
self.slice_per_batch = slice_per_batch
|
||||
self.point_grids = build_point_grid(points_per_side, self.im_size)
|
||||
self.features = None
|
||||
self.is_image_set = False
|
||||
self.transform = SimpleResize(model.image_encoder.img_size)
|
||||
|
||||
self.points_per_batch = points_per_batch
|
||||
self.pred_iou_thresh = pred_iou_thresh
|
||||
self.stability_score_thresh = stability_score_thresh
|
||||
self.stability_score_offset = stability_score_offset
|
||||
self.box_nms_thresh = box_nms_thresh
|
||||
self.masks3d = dict()
|
||||
|
||||
self.box_prompt_generator = BoxPromptGenerator(size=(1024,1024))
|
||||
self.masks3d = None
|
||||
self.stability_score_2d = None
|
||||
self.input_size = model.image_encoder.img_size
|
||||
self.use_postprocess = use_postprocess
|
||||
self.use_noise_remove = use_noise_remove
|
||||
if not use_postprocess:
|
||||
print("Warning! No postprocess")
|
||||
if not use_noise_remove:
|
||||
print("Warning! No use_noise_remove")
|
||||
# self.original_size = (1024,1024)
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return self.model.device
|
||||
|
||||
def reset_image(self) -> None:
|
||||
"""Resets the currently set image."""
|
||||
self.is_image_set = False
|
||||
self.features = None
|
||||
self.orig_h = None
|
||||
self.orig_w = None
|
||||
self.input_h = None
|
||||
self.input_w = None
|
||||
self.masks3d = None
|
||||
self.stability_score_2d = None
|
||||
|
||||
def set_image(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
image_format: str = "nifti",
|
||||
) -> None:
|
||||
# Transform the image to the form expected by the model
|
||||
self.original_size = image.shape
|
||||
input_image_torch = torch.as_tensor(image, device=self.device)
|
||||
input_image_torch = self.transform.apply_image(input_image_torch.float())
|
||||
assert np.argmin(input_image_torch.shape) == 0, f"Got image.shape: {input_image_torch.shape}"
|
||||
maxlen = input_image_torch.shape[0]
|
||||
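# Stack each slice with its two neighbours into a 3-channel (pseudo-RGB) input so the
# encoder sees local through-plane context; the first and last slices are skipped.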
slices = []
|
||||
for i in range(1, maxlen-1):
|
||||
slices.append(input_image_torch[i-1:i+2, :, :])
|
||||
input_slices = torch.stack(slices, axis=0)
|
||||
self.set_torch_image(input_slices)
|
||||
|
||||
def batched_to_RGB(self, images):
|
||||
for i in range(images.shape[0]):
|
||||
images[i] = (images[i] - images[i].min()) / (images[i].max() - images[i].min()) * 255
|
||||
return images
|
||||
|
||||
@torch.no_grad()
|
||||
def set_torch_image(
|
||||
self,
|
||||
transformed_image: torch.Tensor,
|
||||
) -> None:
|
||||
assert (
|
||||
len(transformed_image.shape) == 4
|
||||
and transformed_image.shape[1] == 3
|
||||
and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size
|
||||
), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}."
|
||||
self.reset_image()
|
||||
|
||||
self.input_size = tuple(transformed_image.shape[-2:])
|
||||
transformed_image = self.batched_to_RGB(transformed_image)
|
||||
input_image = self.model.preprocess(transformed_image)
|
||||
features = []
|
||||
for input_image_batch in batch_iterator(self.slice_per_batch, input_image):
|
||||
# print(input_image_batch[0].shape)
|
||||
features_batch = self.model.image_encoder(input_image_batch[0]).cpu()
|
||||
features.append(features_batch)
|
||||
self.features = torch.cat(features, axis=0)
|
||||
self.is_image_set = True
|
||||
|
||||
def __call__(self, x, *args: Any, **kwds: Any) -> Any:
|
||||
return self.predict_volume(x)
|
||||
|
||||
def merge_to_mask3d(self, idx, masks:MaskData):
|
||||
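# Accumulate this slice's 3-channel prediction into the running 3D mask: values are
# summed over slices [idx-1, idx+1] and thresholded later in postprocess_3d().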
if masks._stats == {} or len(masks['masks']) == 0:
|
||||
print("No masks")
|
||||
return
|
||||
if self.masks3d is None:
|
||||
self.masks3d = np.zeros(self.original_size)
|
||||
if self.stability_score_2d is None:
|
||||
self.stability_score_2d = np.zeros(self.original_size[0])
|
||||
masks_values = masks['masks']
|
||||
for mask_value in masks_values:  # iterate masks directly (zip() would wrap each in a 1-tuple)
|
||||
old_mask = self.masks3d[idx-1:idx+2]
|
||||
# self.masks3d[idx-1:idx+2] = np.where(mask_value > old_mask, mask_value, old_mask)
|
||||
self.masks3d[idx-1:idx+2] = mask_value + old_mask
|
||||
|
||||
self.stability_score_2d[idx] = masks['stability_score_2d'][0,0]
|
||||
|
||||
def postprocess_3d(self, masks3d):
|
||||
# add removing noise ?
|
||||
return masks3d > 0
|
||||
|
||||
def _debug_predict_slice(
|
||||
self,
|
||||
x,
|
||||
point_coords: Optional[np.ndarray] = None,
|
||||
point_labels: Optional[np.ndarray] = None,
|
||||
box: Optional[np.ndarray] = None,
|
||||
mask_input: Optional[np.ndarray] = None,
|
||||
multimask_output: bool = True,
|
||||
return_logits: bool = False,
|
||||
template_slice_id:int = None,
|
||||
return_stability: bool = False,
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Debug helper: set the volume, but predict masks for the template (center) slice only.
|
||||
|
||||
x: volume: (c h w)
|
||||
box: [[x,y,x,y]]
|
||||
"""
|
||||
# Check Input
|
||||
assert len(x.shape) == 3
|
||||
assert box is None or len(box.shape) == 2
|
||||
|
||||
# preprocess
|
||||
# x = np.clip(x, -200,400)
|
||||
print(f"Checking Data range: [{x.min()}, {x.max()}]" )
|
||||
|
||||
# Adjust direction
|
||||
indicator = np.argmin(x.shape)
|
||||
if indicator == 0:
|
||||
pass
|
||||
elif indicator == 1:
|
||||
x = rearrange(x, "h c w -> c h w")
|
||||
elif indicator == 2:
|
||||
x = rearrange(x, "h w c -> c h w")
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
# Preprocess prompts
|
||||
self.original_size = x.shape[1:]
|
||||
if point_coords is not None:
|
||||
assert (
|
||||
point_labels is not None
|
||||
), "point_labels must be supplied if point_coords is supplied."
|
||||
point_coords = self.transform.apply_coords(point_coords, self.original_size)
|
||||
coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
|
||||
labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
|
||||
coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
|
||||
if box is not None:
|
||||
box = self.transform.apply_boxes(box, self.original_size)
|
||||
box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
|
||||
box_torch = box_torch[None, :]
|
||||
if mask_input is not None:
|
||||
mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device)
|
||||
mask_input_torch = mask_input_torch[None, :, :, :]
|
||||
|
||||
# set 3d image
|
||||
self.set_image(x)
|
||||
|
||||
# predict center slice
|
||||
center_idx = template_slice_id if template_slice_id is not None else x.shape[0] // 2
|
||||
# print("Processing ", center_idx)
|
||||
center_masks = self._predict_center_slice(center_idx, point_coords, box)
|
||||
return center_masks['masks']
|
||||
|
||||
def predict_volume(
|
||||
self,
|
||||
x,
|
||||
point_coords: Optional[np.ndarray] = None,
|
||||
point_labels: Optional[np.ndarray] = None,
|
||||
box: Optional[np.ndarray] = None,
|
||||
mask_input: Optional[np.ndarray] = None,
|
||||
multimask_output: bool = True,
|
||||
return_logits: bool = False,
|
||||
template_slice_id:int = None,
|
||||
return_stability: bool = False,
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Main entrance: predict a full 3D volume by propagating masks
slice-by-slice from the template slice.
|
||||
|
||||
x: volume: (c h w)
|
||||
box: [[x,y,x,y]]
|
||||
"""
|
||||
# Check Input
|
||||
assert len(x.shape) == 3
|
||||
assert box is None or len(box.shape) == 2
|
||||
|
||||
# preprocess
|
||||
# x = np.clip(x, -200,400)
|
||||
print(f"Checking Data range: [{x.min()}, {x.max()}]" )
|
||||
|
||||
# Adjust direction
|
||||
indicator = np.argmin(x.shape)
|
||||
if indicator == 0:
|
||||
pass
|
||||
elif indicator == 1:
|
||||
x = rearrange(x, "h c w -> c h w")
|
||||
elif indicator == 2:
|
||||
x = rearrange(x, "h w c -> c h w")
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
# Preprocess prompts
|
||||
self.original_size = x.shape[1:]
|
||||
if point_coords is not None:
|
||||
assert (
|
||||
point_labels is not None
|
||||
), "point_labels must be supplied if point_coords is supplied."
|
||||
point_coords = self.transform.apply_coords(point_coords, self.original_size)
|
||||
coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
|
||||
labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
|
||||
coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
|
||||
if box is not None:
|
||||
box = self.transform.apply_boxes(box, self.original_size)
|
||||
box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
|
||||
box_torch = box_torch[None, :]
|
||||
if mask_input is not None:
|
||||
mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device)
|
||||
mask_input_torch = mask_input_torch[None, :, :, :]
|
||||
|
||||
# set 3d image
|
||||
self.set_image(x)
|
||||
|
||||
# predict center slice
|
||||
center_idx = template_slice_id if template_slice_id is not None else x.shape[0] // 2
|
||||
# print("Processing ", center_idx)
|
||||
center_masks = self._predict_center_slice(center_idx, point_coords, box)
|
||||
if center_masks._stats == {}:
|
||||
print("Ends for no mask.")
|
||||
raise ValueError
|
||||
self.merge_to_mask3d(center_idx, center_masks)
|
||||
|
||||
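# Propagate from the template slice towards the last slice: each slice's masks seed the
# box prompts for the next slice, and propagation stops once no mask survives filtering.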
previous_masks = center_masks
|
||||
for i in range(center_idx+1, x.shape[0]-1):
|
||||
# print("Processing downward", i)
|
||||
previous_masks = self._predict_slice(i, previous_masks, orientation="down")
|
||||
if previous_masks._stats == {}:
|
||||
print("Ends for no mask.")
|
||||
break
|
||||
self.merge_to_mask3d(i, previous_masks)
|
||||
|
||||
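# Repeat the propagation towards the first slice, again seeded by the template-slice masks.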
previous_masks = center_masks
|
||||
for i in np.arange(1, center_idx)[::-1]:
|
||||
# print("Processing upward", i)
|
||||
previous_masks = self._predict_slice(i, previous_masks, orientation="up")
|
||||
if previous_masks._stats == {}:
|
||||
print("Ends for no mask.")
|
||||
break
|
||||
self.merge_to_mask3d(i, previous_masks)
|
||||
|
||||
if self.masks3d is None:
|
||||
self.masks3d = np.zeros_like(x)
|
||||
if return_stability:
|
||||
return self.postprocess_3d(self.masks3d), self.stability_score_2d
|
||||
return self.postprocess_3d(self.masks3d)
|
||||
|
||||
def _predict_center_slice(self, idx, point_prompt=None, box_prompt=None):
|
||||
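# The template slice can be prompted with boxes or with a grid of points; boxes take priority.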
if box_prompt is not None:
|
||||
masks = self.genetate_masks_from_boxes(idx, all_boxes=box_prompt, tags=["center_slice"])
|
||||
masks.to_numpy()
|
||||
return masks
|
||||
if point_prompt is not None:
|
||||
masks = self.genetate_masks_from_point_grids(idx, point_prompt)
|
||||
masks.to_numpy()
|
||||
return masks
|
||||
raise ValueError("No prompts! ?")
|
||||
|
||||
def _predict_slice(self, idx, previous_masks, orientation):
|
||||
scaled_boxes, tags = self.generate_prompts_from_previous_masks(previous_masks, orientation)
|
||||
masks = self.genetate_masks_from_boxes(idx, all_boxes=scaled_boxes, tags=tags)
|
||||
masks.to_numpy()
|
||||
return masks
|
||||
|
||||
def generate_prompts_from_previous_masks(self, previous_masks: MaskData, orientation):
|
||||
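# The decoder predicts a 3-slice mask stack; take the channel facing the next slice
# (index 2 when moving down, index 0 when moving up) and turn it into an enlarged box prompt.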
if orientation == "down":
|
||||
masks = previous_masks['masks'][:,2,:,:]
|
||||
elif orientation == "up":
|
||||
masks = previous_masks['masks'][:,0,:,:]
|
||||
else:
|
||||
raise ValueError
|
||||
raw_tags = previous_masks['tags']
|
||||
|
||||
scaled_boxes = []
|
||||
tags = []
|
||||
for mask, tag in zip(masks, raw_tags):
|
||||
if mask.sum() <= 50:
|
||||
continue
|
||||
# mask = self.remove_mask_noise(mask)
|
||||
# if mask.sum() <= 50:
|
||||
# continue
|
||||
mask = F.interpolate(torch.Tensor(mask).float()[None,None,:,:], self.input_size).squeeze().numpy()
|
||||
# scaled_boxes.append(self.box_prompt_generator.mask_to_bbox(mask))
|
||||
scaled_boxes.append(self.box_prompt_generator.enlarge(self.box_prompt_generator.mask_to_bbox(mask)))
|
||||
tags.append(tag)
|
||||
scaled_boxes = np.array(scaled_boxes)
|
||||
return scaled_boxes, tags
|
||||
|
||||
def genetate_masks_from_point_grids(self, idx, points_for_image):
|
||||
idx = idx - 1 # ignore the head and tail slices
|
||||
# Get points for this crop
|
||||
data = MaskData()
|
||||
tags = [f"s{idx}_p{p}" for p in range(points_for_image.shape[0])]
|
||||
for (points, batched_tags) in batch_iterator(self.points_per_batch, points_for_image, tags):
|
||||
batch_data = self._process_batch(idx, points=points, tags=batched_tags)
|
||||
data.cat(batch_data)
|
||||
del batch_data
|
||||
|
||||
# Remove duplicates within this crop.
|
||||
# keep_by_nms = batched_nms(
|
||||
# data["boxes"].float(),
|
||||
# data["iou_preds"],
|
||||
# torch.zeros_like(data["boxes"][:, 0]), # categories
|
||||
# iou_threshold=self.box_nms_thresh,
|
||||
# )
|
||||
# data.filter(keep_by_nms)
|
||||
return data
|
||||
|
||||
def genetate_masks_from_boxes(self, idx, all_boxes, tags):
|
||||
idx = idx - 1
|
||||
data = MaskData()
|
||||
for (batched_boxes, batched_tags) in batch_iterator(self.points_per_batch, all_boxes, tags):
|
||||
batch_data = self._process_batch(idx, boxes=batched_boxes, tags=batched_tags)
|
||||
data.cat(batch_data)
|
||||
del batch_data
|
||||
return data
|
||||
|
||||
def _process_batch(
|
||||
self,
|
||||
fea_slice_idx,
|
||||
points=None,
|
||||
boxes=None,
|
||||
multimask_output=True,
|
||||
tags=None
|
||||
) -> MaskData:
|
||||
"""
|
||||
Process a subset of prompts (points or boxes), because all of them cannot be fed to the decoder at once.
|
||||
"""
|
||||
if points is not None:
|
||||
in_points = torch.as_tensor(points, device=self.model.device)
|
||||
in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
|
||||
masks, iou_preds, _ = self.predict_torch_by_sliceidx(
|
||||
fea_slice_idx=fea_slice_idx,
|
||||
point_coords=in_points[:, None, :],
|
||||
point_labels=in_labels[:, None],
|
||||
multimask_output=multimask_output,
|
||||
return_logits=True,
|
||||
)
|
||||
masks = rearrange(masks, "b (c1 c2) h w -> b c1 c2 h w", c1=3, c2=3)
|
||||
elif boxes is not None:
|
||||
# in_points = torch.as_tensor(points, device=self.model.device)
|
||||
# in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
|
||||
boxes = torch.as_tensor(boxes, device=self.model.device)
|
||||
masks, iou_preds, _ = self.predict_torch_by_sliceidx(
|
||||
fea_slice_idx=fea_slice_idx,
|
||||
boxes=boxes,
|
||||
multimask_output=multimask_output,
|
||||
return_logits=True,
|
||||
)
|
||||
masks = rearrange(masks, "b (c1 c2) h w -> b c1 c2 h w", c1=3, c2=3)
|
||||
else:
|
||||
raise ValueError(f"No points or boxes")
|
||||
|
||||
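# multimask_output yields several candidate masks per prompt; keep only the candidate
# with the highest predicted IoU, for both the mask logits and the IoU scores.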
indices = iou_preds.argmax(axis=1)
|
||||
# indices = torch.tensor([2], device=iou_preds.device)
|
||||
pred_maxiou = []
|
||||
for pred, i in zip(masks, indices):
|
||||
pred_maxiou.append(pred[i,:,:,:])
|
||||
masks = torch.stack(pred_maxiou, axis=0)
|
||||
|
||||
iou_maxiou = []
|
||||
for iou, i in zip(iou_preds, indices):
|
||||
iou_maxiou.append(iou[i])
|
||||
iou_preds = torch.stack(iou_maxiou)
|
||||
|
||||
# Serialize predictions and store in MaskData
|
||||
data = MaskData(
|
||||
masks=masks.detach().cpu(),
|
||||
iou_preds=iou_preds.detach().cpu(),
|
||||
# points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
|
||||
# tags=repeat(np.array(tags), "c -> (c 3)")
|
||||
tags=np.array(tags),
|
||||
)
|
||||
del masks
|
||||
|
||||
# Filter by area
|
||||
# if True:
|
||||
# keep_mask = (data['masks']>0).sum(-1).sum(-1).sum(-1) < (data['masks']<=0).sum(-1).sum(-1).sum(-1) * 0.4
|
||||
# data.filter(keep_mask)
|
||||
# print("keep mask / pred", keep_mask.sum())
|
||||
|
||||
# Filter Background
|
||||
# if True:
|
||||
# keep_mask =
|
||||
|
||||
# Filter by predicted IoU
|
||||
if self.use_postprocess:
|
||||
if self.pred_iou_thresh > -0.0:
|
||||
keep_mask = data["iou_preds"] > self.pred_iou_thresh
|
||||
# print("pred_iou", data["iou_preds"], (data["masks"]>0).sum(-1).sum(-1))
|
||||
data.filter(keep_mask)
|
||||
# print("keep mask / pred", keep_mask.sum())
|
||||
|
||||
# Calculate stability score
|
||||
data["stability_score"] = calculate_stability_score_3d(
|
||||
data["masks"], self.model.mask_threshold, self.stability_score_offset
|
||||
) # .mean(axis=-1)
|
||||
if self.stability_score_thresh > 0.0:
|
||||
# print("stability", data["stability_score"], (data["masks"]>0).sum(-1).sum(-1))
|
||||
keep_mask = data["stability_score"] >= self.stability_score_thresh
|
||||
data.filter(keep_mask)
|
||||
# print("keep mask / stable", keep_mask.sum())
|
||||
|
||||
data["stability_score_2d"] = calculate_stability_score(
|
||||
data["masks"][:,1:2,:,:], self.model.mask_threshold, self.stability_score_offset
|
||||
)
|
||||
|
||||
# Threshold masks and calculate boxes
|
||||
data['logits'] = data['masks']
|
||||
data["noisy_masks"] = data["logits"] > self.model.mask_threshold
|
||||
|
||||
# data['masks'] = torch.zeros_like(data['noisy_masks'], dtype=data['noisy_masks'].dtype, device=data['noisy_masks'].device)
|
||||
b, c,_,_ = data["noisy_masks"].shape
|
||||
data['masks'] = data["noisy_masks"].float()
|
||||
|
||||
if self.use_noise_remove:
|
||||
for i in range(b):
|
||||
for j in range(c):
|
||||
data['masks'][i,j,:,:] = torch.Tensor(self.remove_mask_noise(data['noisy_masks'][i,j,:,:]))
|
||||
|
||||
# data["boxes"] = batched_mask_to_box(reduce(data["masks"], "b c h w -> b h w", reduction="sum")>0)
|
||||
data["boxes"] = batched_mask_to_box(data["masks"][:,1,:,:]>0)
|
||||
return data
|
||||
|
||||
# @staticmethod
|
||||
# def calculate_
|
||||
|
||||
|
||||
@staticmethod
|
||||
def batched_remove_noise(masks):
|
||||
ori_shape = masks.shape
|
||||
|
||||
def remove_mask_noise(self, mask):
|
||||
# mask_sum = mask.sum()
|
||||
kernel_size = min(int(mask.sum()) // 20, 8)
|
||||
if isinstance(mask, torch.Tensor):
|
||||
mask = mask.detach().cpu().numpy()
|
||||
kernel = np.ones((kernel_size, kernel_size), dtype=np.uint8)
|
||||
opening = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, kernel, iterations=1)
|
||||
return opening
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_torch_by_sliceidx(
|
||||
self,
|
||||
fea_slice_idx: int,
|
||||
point_coords: Optional[torch.Tensor] = None,
|
||||
point_labels: Optional[torch.Tensor] = None,
|
||||
boxes: Optional[torch.Tensor] = None,
|
||||
mask_input: Optional[torch.Tensor] = None,
|
||||
multimask_output: bool = True,
|
||||
return_logits: bool = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
if not self.is_image_set:
|
||||
raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
|
||||
|
||||
if point_coords is not None:
|
||||
points = (point_coords, point_labels)
|
||||
else:
|
||||
points = None
|
||||
|
||||
# Embed prompts
|
||||
sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
|
||||
points=points,
|
||||
boxes=boxes,
|
||||
masks=mask_input,
|
||||
)
|
||||
|
||||
# Predict masks
|
||||
low_res_masks, iou_predictions = self.model.mask_decoder(
|
||||
image_embeddings=self.features[fea_slice_idx].to(self.model.device),
|
||||
image_pe=self.model.prompt_encoder.get_dense_pe(),
|
||||
sparse_prompt_embeddings=sparse_embeddings,
|
||||
dense_prompt_embeddings=dense_embeddings,
|
||||
multimask_output=multimask_output,
|
||||
)
|
||||
|
||||
# Upscale the masks to the original image resolution
|
||||
masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size[1:])
|
||||
|
||||
if not return_logits:
|
||||
masks = masks > self.model.mask_threshold
|
||||
|
||||
return masks, iou_predictions, low_res_masks
|
||||
|
||||
def valid_box(self, data, batch_idx):
|
||||
# valid image with box, or point prompt
|
||||
assert data['img'].shape[0] == 1, f"shape {data['img'].shape}"
|
||||
image = data['img'][0]  # drop the batch dimension; predict_volume expects a (c, h, w) array
|
||||
label = data['label'][0]
|
||||
|
||||
box = np.array([BoxPromptGenerator(size=None).mask_to_bbox(label[1])])  # box prompt from the centre slice of the GT stack
|
||||
box_mask3d = self.predict_volume(
|
||||
x=image,
|
||||
box=box,
|
||||
)
|
||||
dice = compute_dice_np(box_mask3d, label.detach().cpu().numpy())
return dice
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from core.learner3 import SamLearner
|
||||
from modeling.build_sam3d2 import sam_model_registry
|
||||
from trans_utils.data_utils import Data3dSolver
|
||||
|
||||
config = ConfigManager()
|
||||
config.add_config("configs/vit_b_103.yaml")
|
||||
|
||||
model_type = "vit_b"
|
||||
sam = sam_model_registry[model_type](checkpoint=None)
|
||||
learner = SamLearner(sam_model=sam, config=config, data_engine=DataManager(img_size=(1024,1024)))
|
||||
learner.use_lora()
|
||||
pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b9/lora3/ckpt/model_iter_360000.pth"
|
||||
learner.load_well_trained_model(pth)
|
||||
learner.cuda()
|
||||
|
||||
predictor = VolumePredictor(
|
||||
model=learner.model,
|
||||
use_postprocess=True,
|
||||
use_noise_remove=True,)
|
||||
|
||||
# Load data
|
||||
img_path = "/home1/quanquan/datasets/07_WORD/WORD-V0.1.0/imagesVa/word_0001.nii.gz" # "/home1/quanquan/datasets/59_SABS/sabs_CT_normalized/image_5.nii.gz"
|
||||
label_path = "/home1/quanquan/datasets/07_WORD/WORD-V0.1.0/labelsVa/word_0001.nii.gz"
|
||||
volume = itk_to_np(read(img_path)) # test several slices
|
||||
label_itk = read(label_path)
|
||||
spacing = label_itk.GetSpacing()
|
||||
label = itk_to_np(label_itk) == 1
|
||||
volume = np.clip(volume, -200, 400)
|
||||
|
||||
# Select the slice with the largest mask
|
||||
s = reduce(label, "c h w -> c", reduction="sum")
|
||||
coords = np.nonzero(s)
|
||||
x_min = np.min(coords[0])
|
||||
x_max = np.max(coords[0])
|
||||
template_slice_id = s.argmax()
|
||||
|
||||
box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id])
|
||||
box = np.array([box])
|
||||
|
||||
pred = predictor.predict_volume(
|
||||
x=volume,
|
||||
box=box,
|
||||
template_slice_id=template_slice_id,
|
||||
return_stability=False,
|
||||
)
|
||||
|
||||
Data3dSolver().simple_write(pred, path="mask.nii.gz", spacing=spacing)
|
||||
Data3dSolver().simple_write(label, path="gt.nii.gz", spacing=spacing)
|
0
datasets/__init__.py
Normal file
0
datasets/__init__.py
Normal file
BIN
datasets/__pycache__/__init__.cpython-38.pyc
Normal file
BIN
datasets/__pycache__/__init__.cpython-38.pyc
Normal file
Binary file not shown.
BIN
datasets/__pycache__/cache_dataset3d.cpython-38.pyc
Normal file
BIN
datasets/__pycache__/cache_dataset3d.cpython-38.pyc
Normal file
Binary file not shown.
BIN
datasets/__pycache__/cache_dataset3d3.cpython-38.pyc
Normal file
BIN
datasets/__pycache__/cache_dataset3d3.cpython-38.pyc
Normal file
Binary file not shown.
BIN
datasets/__pycache__/data_engine.cpython-38.pyc
Normal file
BIN
datasets/__pycache__/data_engine.cpython-38.pyc
Normal file
Binary file not shown.
BIN
datasets/__pycache__/dataset3d.cpython-38.pyc
Normal file
BIN
datasets/__pycache__/dataset3d.cpython-38.pyc
Normal file
Binary file not shown.
BIN
datasets/__pycache__/dataset3d_2dmask.cpython-38.pyc
Normal file
BIN
datasets/__pycache__/dataset3d_2dmask.cpython-38.pyc
Normal file
Binary file not shown.
BIN
datasets/__pycache__/dataset_merged.cpython-38.pyc
Normal file
BIN
datasets/__pycache__/dataset_merged.cpython-38.pyc
Normal file
Binary file not shown.
267
datasets/cache_dataset3d.py
Normal file
267
datasets/cache_dataset3d.py
Normal file
@ -0,0 +1,267 @@
|
||||
"""
|
||||
Loading full volumes directly is slow,
so we pre-process the data into cached 2D slices.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import os
|
||||
from einops import rearrange, reduce, repeat
|
||||
from tutils.nn.data import read, itk_to_np, np_to_itk, write
|
||||
from tutils import tfilename
|
||||
from .dataset3d import DATASET_CONFIG, Dataset3D as basic_3d_dataset
|
||||
from monai import transforms
|
||||
import torch
|
||||
import cv2
|
||||
from scipy.sparse import csr_matrix
|
||||
import torch.nn.functional as F
|
||||
from torchvision import transforms
|
||||
from einops import rearrange
|
||||
import glob
|
||||
from torchvision import transforms
|
||||
from monai import transforms as monai_transforms
|
||||
|
||||
# "class": ["spleen", "right kidney", "left kidney", "gallbladder", "esophagus", "liver", "stomach", "aorta", "postcava", "portal vein and splenic vein", "pancrease", "right adrenal gland", "left adrenal gland"],
|
||||
# "class": ["liver", "right kidney", "left kidney", "spleen"],
|
||||
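# Maps each dataset's ID prefix (the leading digits of its directory name) to the global
# organ label IDs of its annotation classes; entries marked "post process" need extra handling.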
TEMPLATE={
|
||||
'01': [1,2,3,4,5,6,7,8,9,10,11,12,13,14],
|
||||
'02': [1,0,3,4,5,6,7,0,0,0,11,0,0,14],
|
||||
'03': [6],
|
||||
'04': [6,27], # post process
|
||||
'05': [2,26,32], # post process
|
||||
'07': [6,1,3,2,7,4,5,11,14,18,19,12,20,21,23,24],
|
||||
'08': [6, 2, 1, 11],
|
||||
'09': [1,2,3,4,5,6,7,8,9,11,12,13,14,21,22],
|
||||
'12': [6,21,16,2],
|
||||
'13': [6,2,1,11,8,9,7,4,5,12,13,25],
|
||||
'14': [11,11,28,28,28], # Felix data, post process
|
||||
'10_03': [6, 27], # post process
|
||||
'10_06': [30],
|
||||
'10_07': [11, 28], # post process
|
||||
'10_08': [15, 29], # post process
|
||||
'10_09': [1],
|
||||
'10_10': [31],
|
||||
'58': [6,2,3,1],
|
||||
'59': [1,2,3,4,5,6,7,8,9,10,11,12,13,14],
|
||||
'60': np.arange(200).tolist(), # for debug
|
||||
}
|
||||
|
||||
class Dataset3D(basic_3d_dataset):
|
||||
def __init__(self, config=..., use_cache=True, *args, **kwargs) -> None:
|
||||
super().__init__(config, use_cache=use_cache, *args, **kwargs)
|
||||
self.basic_dir = config['data_root_path']
|
||||
self.cache_dir = config['cache_data_path']
|
||||
|
||||
def prepare_transforms(self):
|
||||
self.transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Resize((1024,1024)),
|
||||
])
|
||||
self.test_transform = transforms.Compose([
|
||||
monai_transforms.Resized(keys=['img', 'label'], spatial_size=(3,1024,1024)),
|
||||
])
|
||||
|
||||
# @tfunctime
|
||||
# def prepare_datalist(self):
|
||||
def prepare_cached_datalist(self):
|
||||
raise DeprecationWarning("[Warning] Please use cache_dataset3d new version instead!")
|
||||
config = self.config
|
||||
data_paths = []
|
||||
for dirpath in glob.glob(config['cache_data_path'] + "/*"):
|
||||
data_paths += glob.glob(dirpath + "/image/*.jpg")
|
||||
print("Load ", dirpath)
|
||||
print('train len {}'.format(len(data_paths)))
|
||||
print('Examples: ', data_paths[:2])
|
||||
return data_paths
|
||||
|
||||
def caching_data(self):
|
||||
assert self.use_cache == False
|
||||
for index in range(len(self)):
|
||||
self.cache_one_sample(index)
|
||||
|
||||
def cache_one_sample(self, index, debug=False):
|
||||
# LABEL_INDICES
|
||||
name = self.img_names[index]['img_path']
|
||||
img_itk = read(self.img_names[index]['img_path'])
|
||||
img_ori = itk_to_np(img_itk)
|
||||
|
||||
img_ori = np.clip(img_ori, -200,400)
|
||||
|
||||
# spacing = img_itk.GetSpacing()
|
||||
scan_orientation = np.argmin(img_ori.shape)
|
||||
label_ori = itk_to_np(read(self.img_names[index]['label_path']))
|
||||
|
||||
dataset_name = self.img_names[index]['img_path'].replace(self.basic_dir,"").split("/")[0]
|
||||
assert dataset_name[0] in ['0','1','2','3','4','5','6','7','8','9'], f"Got {dataset_name}"
|
||||
all_labels = TEMPLATE[dataset_name[:2]]
|
||||
|
||||
num = 0
|
||||
|
||||
# if min(img_ori.shape) * 1.2 < max(img_ori.shape):
|
||||
# orientation_all = [scan_orientation]
|
||||
# else:
|
||||
# orientation_all = [0,1,2]
|
||||
orientation_all = [scan_orientation]
|
||||
|
||||
for orientation in orientation_all:
|
||||
for slice_idx in range(2, img_ori.shape[orientation]-2):
|
||||
# slice_idx = np.random.randint(2, img_ori.shape[orientation]-2)
|
||||
if orientation == 0:
|
||||
s = img_ori[slice_idx-1:slice_idx+2, :,:]
|
||||
lb = label_ori[slice_idx-1:slice_idx+2, :,:]
|
||||
# spacing = (spacing[1], spacing[2])
|
||||
if orientation == 1:
|
||||
s = img_ori[:,slice_idx-1:slice_idx+2,:]
|
||||
s = rearrange(s, "h c w -> c h w")
|
||||
lb = label_ori[:,slice_idx-1:slice_idx+2,:]
|
||||
lb = rearrange(lb, "h c w -> c h w")
|
||||
# spacing = (spacing[0], spacing[2])
|
||||
if orientation == 2:
|
||||
s = img_ori[:,:,slice_idx-1:slice_idx+2]
|
||||
s = rearrange(s, "h w c -> c h w")
|
||||
lb = label_ori[:,:,slice_idx-1:slice_idx+2]
|
||||
lb = rearrange(lb, "h w c -> c h w")
|
||||
# spacing = (spacing[0], spacing[1])
|
||||
assert s.shape[0] == 3
|
||||
|
||||
# if np.float32(lb[1,:,:]>0).sum() <= 200:
|
||||
# # return self._get_data((index+1)%len(self))
|
||||
# continue
|
||||
# Choose one label
|
||||
label_num = int(lb.max())
|
||||
|
||||
masks_data = []
|
||||
meta = {"img_name": name, "slice": slice_idx, "orientation": orientation, "label_idx": [], "labels": [], "id": f"{num:08d}" }
|
||||
for label_idx in range(1,label_num+1):
|
||||
one_lb = np.float32(lb==label_idx)
|
||||
if one_lb[1,:,:].sum() <= (one_lb.shape[-1] * one_lb.shape[-2] * 0.0014):
|
||||
continue
|
||||
# if one_lb[0,:,:].sum()<=50 or one_lb[2,:,:].sum()<=50:
|
||||
|
||||
masks_data.append(one_lb)
|
||||
meta['label_idx'].append(label_idx)
|
||||
meta['labels'].append(all_labels[label_idx-1])
|
||||
|
||||
if len(masks_data) <= 0:
|
||||
continue
|
||||
|
||||
img_rgb = s
|
||||
img_rgb = F.interpolate(torch.Tensor(img_rgb).unsqueeze(0), size=(1024,1024)).squeeze().numpy()
|
||||
img_rgb = self.to_RGB(img_rgb)
|
||||
save_image_name = tfilename(self.cache_dir, dataset_name, f"image/image_{index:04d}_{num:08d}.jpg")
|
||||
self.save_img_rgb(rearrange(img_rgb, "c h w -> h w c"), save_image_name)
|
||||
|
||||
# Save cache data
|
||||
save_label_name = tfilename(self.cache_dir, dataset_name, f"label/label_{index:04d}_{num:08d}.npz")
|
||||
self.save_slice_mask(masks_data, save_label_name)
|
||||
print("Save ", save_image_name)
|
||||
|
||||
self.save_meta(meta, tfilename(self.cache_dir, dataset_name, f"meta/meta_{index:04d}_{num:08d}.npy"))
|
||||
|
||||
num += 1
|
||||
|
||||
def save_meta(self, meta, path):
|
||||
assert path.endswith(".npy")
|
||||
np.save(path, meta)
|
||||
|
||||
def save_slice_mask(self, masks_data, prefix):
|
||||
masks_data = F.interpolate(torch.Tensor(masks_data), size=(1024,1024)).numpy()
|
||||
assert masks_data.shape[1:] == (3,1024,1024), f"{__file__} Got{masks_data.shape}"
|
||||
for i in range(masks_data.shape[0]):
|
||||
labeli = masks_data[i].astype(np.uint8) * 255
|
||||
assert labeli.sum() > 0
|
||||
path = tfilename(prefix+f"_{i:04d}.jpg")
|
||||
cv2.imwrite(path, rearrange(labeli, "c h w -> h w c"))
|
||||
print("save to ", path)
|
||||
|
||||
def _old_save_slice_mask(self, masks_data, path):
|
||||
raise DeprecationWarning()
|
||||
exit(0)
|
||||
assert path.endswith(".npz")
|
||||
# masks_data = np.array([m['segmentation'] for m in masks]).astype(int)
|
||||
masks_data = F.interpolate(torch.Tensor(masks_data), size=(1024,1024)).numpy()
|
||||
# masks_data = np.int8(masks_data>0)
|
||||
assert masks_data.shape[1:] == (3,1024,1024), f"{__file__} Got{masks_data.shape}"
|
||||
masks_data = rearrange(masks_data, "n c h w -> n (c h w)")
|
||||
csr = csr_matrix(masks_data)
|
||||
np.savez_compressed(path, data=csr.data, indices=csr.indices, indptr=csr.indptr, shape=csr.shape)
|
||||
|
||||
def save_img_rgb(self, img, path):
|
||||
assert path.endswith(".jpg")
|
||||
assert img.shape == (1024,1024,3)
|
||||
cv2.imwrite(path, img.astype(np.uint8))
|
||||
|
||||
def _get_cached_data(self, index):
|
||||
name = self.img_names[index]
|
||||
# print(name)
|
||||
img = cv2.imread(name)
|
||||
compressed = np.load(name.replace("image/image_", "label/label_").replace(".jpg", ".npz"))
|
||||
csr = csr_matrix((compressed['data'], compressed['indices'], compressed['indptr']), shape=compressed['shape'])
|
||||
label_ori = csr.toarray()
|
||||
label_ori = rearrange(label_ori, "n (c h w) -> n c h w", c=3, h=1024, w=1024)
|
||||
meta = np.load(name.replace("image/image_", "meta/meta_").replace(".jpg", ".npy"), allow_pickle=True).tolist()
|
||||
# print(meta)
|
||||
pp = reduce(label_ori[:,1,:,:], "n h w -> n", reduction="sum") > 500
|
||||
if pp.sum() == 0:
|
||||
return self._get_cached_data((index+1)%len(self))
|
||||
|
||||
label_idx = np.random.choice(a=np.arange(len(pp)), p=pp/pp.sum())
|
||||
# label_idx = np.random.randint(0, label_ori.shape[0])
|
||||
label_ori = label_ori[label_idx]
|
||||
is_edge = meta.get('is_edge', 0)
|
||||
return rearrange(img, "h w c -> c h w"), label_ori, name, meta['labels'][label_idx], meta['label_idx'][label_idx]
|
||||
|
||||
# @tfunctime
|
||||
def __getitem__(self, index, debug=False):
|
||||
# print("Dataset warning", index, len(self))
|
||||
index = index % len(self)
|
||||
img_rgb, label_ori, name, label_idx, local_idx = self._get_cached_data(index)
|
||||
|
||||
if label_ori.sum() <= 0:
|
||||
print("[Label Error] ", name)
|
||||
return self.__getitem__(index+1)
|
||||
|
||||
# assert len(img_rgb.shape) == 3, f"{__file__} Got{img_rgb.shape}"
|
||||
# img_rgb = self.transform((img_rgb[None,:,:,:]))
|
||||
img_rgb = F.interpolate(torch.Tensor(img_rgb).unsqueeze(0), size=(1024,1024)).squeeze().numpy()
|
||||
|
||||
vector = np.ones(3)
|
||||
ret_dict = {
|
||||
"name": name,
|
||||
"img": img_rgb,
|
||||
"label": label_ori,
|
||||
"indicators": vector,
|
||||
"class": label_idx,
|
||||
"local_idx": local_idx,
|
||||
}
|
||||
return ret_dict
|
||||
|
||||
def _convert_one_mask_from_npz_to_jpg(self, path1=None):
|
||||
# path1 = "/home1/quanquan/datasets/cached_dataset2/01_BCV-Abdomen/label/label_0129_00000043.npz" # 32K
|
||||
prefix = path1.replace(".npz", "").replace("/label/", "/label_jpg/")
|
||||
compressed = np.load(path1)
|
||||
csr = csr_matrix((compressed['data'], compressed['indices'], compressed['indptr']), shape=compressed['shape'])
|
||||
label_ori = csr.toarray()
|
||||
label_ori = rearrange(label_ori, "n (c h w) -> n c h w", c=3, h=1024, w=1024)
|
||||
# print(label_ori.shape)
|
||||
for i in range(label_ori.shape[0]):
|
||||
labeli = label_ori[i]
|
||||
path = tfilename(prefix+f"_{i:04d}.jpg")
|
||||
cv2.imwrite(path, rearrange(labeli, "c h w -> h w c").astype(np.uint8))
|
||||
print("save to ", path)
|
||||
|
||||
def convert_masks_types(self):
|
||||
assert self.use_cache == True
|
||||
for index in range(len(self)):
|
||||
name = self.img_names[index]
|
||||
label_path = name.replace("image/image_", "label/label_").replace(".jpg", ".npz")
|
||||
self._convert_one_mask_from_npz_to_jpg(label_path)
|
||||
|
||||
if __name__ == "__main__":
|
||||
# def go_cache():
|
||||
from tutils.new.manager import ConfigManager
|
||||
config = ConfigManager()
|
||||
config.add_config("configs/vit_b.yaml")
|
||||
dataset = Dataset3D(config=config['dataset'], use_cache=True)
|
||||
dataset.caching_data()
|
||||
# dataset.convert_masks_types()
|
97
datasets/cache_dataset3d3.py
Normal file
97
datasets/cache_dataset3d3.py
Normal file
@ -0,0 +1,97 @@
|
||||
"""
|
||||
Re-index the cache by masks (one sample per mask file) rather than by images.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import os
|
||||
from einops import rearrange, reduce, repeat
|
||||
from tutils.nn.data import read, itk_to_np, np_to_itk, write
|
||||
from tutils import tfilename
|
||||
from .cache_dataset3d import Dataset3D as basic_3d_dataset
|
||||
from monai import transforms
|
||||
import torch
|
||||
import cv2
|
||||
from scipy.sparse import csr_matrix
|
||||
import torch.nn.functional as F
|
||||
from torchvision import transforms
|
||||
from einops import rearrange
|
||||
import glob
|
||||
from torchvision import transforms
|
||||
from monai import transforms as monai_transforms
|
||||
|
||||
|
||||
class Dataset3D(basic_3d_dataset):
|
||||
def __init__(self, config=..., use_cache=True, *args, **kwargs) -> None:
|
||||
super().__init__(config, use_cache=use_cache, *args, **kwargs)
|
||||
self.basic_dir = config['data_root_path']
|
||||
self.cache_dir = config['cache_data_path']
|
||||
|
||||
def prepare_cached_datalist(self):
|
||||
config = self.config
|
||||
data_paths = []
|
||||
for dirpath in glob.glob(config['cache_data_path'] + "/*"):
|
||||
if not os.path.isdir(dirpath):
|
||||
continue
|
||||
prefix = dirpath.split("/")[-1]
|
||||
if prefix[:2] in config['cache_prefix']:
|
||||
data_paths += glob.glob(dirpath + "/label_jpg/*.jpg")
|
||||
print("Load ", dirpath)
|
||||
print('Masks len {}'.format(len(data_paths)))
|
||||
print('Examples: ', data_paths[:2])
|
||||
return data_paths
|
||||
|
||||
def _get_cached_data(self, index):
|
||||
mask_path = self.img_names[index]
|
||||
# print(name)
|
||||
mask = np.int32(cv2.imread(mask_path) > 0)
|
||||
|
||||
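# Mask filenames end with "_NNNN.jpg"; stripping that suffix gives the shared prefix from
# which the paired image and meta paths (and the mask index) are recovered.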
prefix = mask_path[:-9]
|
||||
img_path = prefix.replace("/label_jpg/label_", "/image/image_") + ".jpg"
|
||||
img = cv2.imread(img_path)
|
||||
number = int(mask_path[-8:-4])
|
||||
|
||||
meta = np.load(prefix.replace("/label_jpg/label_", "/meta/meta_")+".npy", allow_pickle=True).tolist()
|
||||
# label_idx = np.random.randint(0, label_ori.shape[0])
|
||||
|
||||
return rearrange(img, "h w c -> c h w"), rearrange(mask, "h w c -> c h w"), mask_path, meta['labels'][number], meta['label_idx'][number]
|
||||
|
||||
# @tfunctime
|
||||
def __getitem__(self, index, debug=False):
|
||||
index = index % len(self)
|
||||
img_rgb, label_ori, name, label_idx, local_idx = self._get_cached_data(index)
|
||||
|
||||
# assert label_ori.sum() > 0
|
||||
if label_ori.sum() <= 0:
|
||||
print("[Label Error] ", name)
|
||||
return self.__getitem__(index+1)
|
||||
|
||||
# assert len(img_rgb.shape) == 3, f"{__file__} Got{img_rgb.shape}"
|
||||
# img_rgb = self.transform((img_rgb[None,:,:,:]))
|
||||
img_rgb = F.interpolate(torch.Tensor(img_rgb).unsqueeze(0), size=(1024,1024)).squeeze().numpy()
|
||||
|
||||
vector = np.ones(3)
|
||||
ret_dict = {
|
||||
"name": name,
|
||||
"img": img_rgb,
|
||||
"label": label_ori,
|
||||
"indicators": vector,
|
||||
"class": label_idx,
|
||||
"local_idx": local_idx,
|
||||
"is_problem": label_ori.sum() <= 30,
|
||||
}
|
||||
return ret_dict
|
||||
|
||||
if __name__ == "__main__":
|
||||
# def go_cache():
|
||||
from tutils.new.manager import ConfigManager
|
||||
from tqdm import tqdm
|
||||
config = ConfigManager()
|
||||
config.add_config("configs/vit_b.yaml")
|
||||
config.print()
|
||||
dataset = Dataset3D(config=config['dataset'], use_cache=True)
|
||||
# dataset.caching_data()
|
||||
# dataset.convert_masks_types()
|
||||
for i in tqdm(range(len(dataset))):
|
||||
data = dataset.__getitem__(i)
|
||||
# assert
|
||||
|
274
datasets/data_engine.py
Normal file
274
datasets/data_engine.py
Normal file
@ -0,0 +1,274 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
import matplotlib.pyplot as plt
|
||||
from einops import repeat, rearrange
|
||||
|
||||
def show_mask(mask, ax, random_color=False):
|
||||
if random_color:
|
||||
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
|
||||
else:
|
||||
color = np.array([30/255, 144/255, 255/255, 0.6])
|
||||
h, w = mask.shape[-2:]
|
||||
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
|
||||
ax.imshow(mask_image)
|
||||
|
||||
def show_points(coords, labels, ax, marker_size=375):
|
||||
pos_points = coords[labels==1]
|
||||
neg_points = coords[labels==0]
|
||||
ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
|
||||
ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
|
||||
|
||||
def show_box(box, ax):
|
||||
x0, y0 = box[0], box[1]
|
||||
w, h = box[2] - box[0], box[3] - box[1]
|
||||
ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
|
||||
|
||||
class PointPromptGenerator(object):
|
||||
def __init__(self, size=None) -> None:
|
||||
pass
|
||||
|
||||
def get_prompt_point(self, gt_mask):
|
||||
# assert gt_mask.shape == (1024,1024) or gt_mask.shape == (512,512), f"[data_engine] {__file__} Got{gt_mask.shape}"
|
||||
if not (gt_mask.shape == (1024,1024) or gt_mask.shape == (512,512) or gt_mask.shape == (256,256) ):
|
||||
print(f"[Warning] [data_engine] {__file__} Got{gt_mask.shape}")
|
||||
assert gt_mask.sum() > 0
|
||||
self.size = gt_mask.shape
|
||||
self.xy = np.arange(0, self.size[0] * self.size[1])
|
||||
|
||||
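# Sample one prompt point uniformly at random from the foreground pixels of the GT mask.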
gt_mask = np.float32(gt_mask>0)
|
||||
prob = rearrange(gt_mask, "h w -> (h w)")
|
||||
prob = prob / prob.sum()
|
||||
loc = np.random.choice(a=self.xy, size=1, replace=True, p=prob)[0]
|
||||
x, y = loc % self.size[1], loc // self.size[1]
|
||||
return x, y
|
||||
|
||||
@staticmethod
|
||||
def select_random_subsequent_point(pred_mask, gt_mask):
|
||||
# union = np.float32((pred_mask + gt_mask)>0)
|
||||
# diff = union - intersect
|
||||
assert len(pred_mask.shape) == 2
|
||||
assert len(gt_mask.shape) == 2
|
||||
assert gt_mask.sum() > 0, f"[data_engine] Got {gt_mask.sum()}==0 "
|
||||
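# Subsequent prompt points are drawn from the error region |pred - gt|: the point is labelled
# positive (1) if it is missed foreground, otherwise negative (0).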
diff = np.float32(np.abs(pred_mask - gt_mask)>0)
|
||||
diff = np.nan_to_num(diff, nan=0)
|
||||
# print(diff.shape)
|
||||
xy = np.arange(0, diff.shape[0] * diff.shape[1])
|
||||
|
||||
if diff.sum() == 0:
|
||||
prob = rearrange(gt_mask, "h w -> (h w)")
|
||||
prob = prob / prob.sum()
|
||||
loc = np.random.choice(a=xy, size=1, replace=True, p=prob)[0]
|
||||
x, y = loc % diff.shape[1], loc // diff.shape[1]
|
||||
return (x,y), 1
|
||||
# Get_prompt_point
|
||||
prob = rearrange(diff, "h w -> (h w)")
|
||||
prob = prob / prob.sum()
|
||||
loc = np.random.choice(a=xy, size=1, replace=True, p=prob)[0]
|
||||
x, y = loc % diff.shape[1], loc // diff.shape[1]
|
||||
|
||||
if gt_mask[y, x] == 1 and pred_mask[y, x] == 0:
|
||||
classification = 1
|
||||
else:
|
||||
classification = 0
|
||||
# raise ValueError
|
||||
return (x, y), classification
|
||||
|
||||
|
||||
|
||||
class BoxPromptGenerator(object):
|
||||
def __init__(self, size) -> None:
|
||||
self.size = size
|
||||
|
||||
@staticmethod
|
||||
def mask_to_bbox(mask):
|
||||
# Find the indices of all non-zero elements in the mask
|
||||
coords = np.nonzero(mask)
|
||||
|
||||
# Compute the minimum and maximum values of the row and column indices
|
||||
x_min = np.min(coords[1])
|
||||
y_min = np.min(coords[0])
|
||||
x_max = np.max(coords[1])
|
||||
y_max = np.max(coords[0])
|
||||
|
||||
# Return the coordinates of the bounding box
|
||||
return (x_min, y_min, x_max, y_max)
|
||||
# return (y_min, x_min, y_max, x_max)
|
||||
|
||||
def add_random_noise_to_bbox(self, bbox):
|
||||
bbox = list(bbox)
|
||||
# Calculate the side lengths of the box in the x and y directions
|
||||
x_side_length = bbox[2] - bbox[0]
|
||||
y_side_length = bbox[3] - bbox[1]
|
||||
|
||||
# Calculate the standard deviation of the noise
|
||||
std_dev = 0.01 * (x_side_length + y_side_length) / 2
|
||||
|
||||
# Generate random noise for each coordinate
|
||||
x_noise = np.random.normal(scale=std_dev)
|
||||
y_noise = np.random.normal(scale=std_dev)
|
||||
|
||||
# Add the random noise to each coordinate, but make sure it is not larger than 20 pixels
|
||||
bbox[0] += min(int(round(x_noise)), 20)
|
||||
bbox[1] += min(int(round(y_noise)), 20)
|
||||
bbox[2] += min(int(round(x_noise)), 20)
|
||||
bbox[3] += min(int(round(y_noise)), 20)
|
||||
|
||||
# Make sure the modified coordinates do not exceed the maximum possible values
|
||||
bbox[0] = max(bbox[0], 0)
|
||||
bbox[1] = max(bbox[1], 0)
|
||||
bbox[2] = min(bbox[2], self.size[0])
|
||||
bbox[3] = min(bbox[3], self.size[1])
|
||||
|
||||
# Return the modified bounding box
|
||||
return bbox
|
||||
|
||||
def get_prompt_box(self, gt_mask):
|
||||
""" return (x_min, y_min, x_max, y_max) """
|
||||
assert gt_mask.shape == (1024,1024) or gt_mask.shape == (512,512) or gt_mask.shape == (256,256), f"[data_engine] {__file__} Got{gt_mask.shape}"
|
||||
box = self.mask_to_bbox(gt_mask)
|
||||
box_w_noise = self.add_random_noise_to_bbox(box)
|
||||
return box_w_noise
|
||||
|
||||
def enlarge(self, bbox, margin=0):
|
||||
x0, y0, x1, y1 = bbox[0], bbox[1], bbox[2], bbox[3]
|
||||
margin_x = int((x1 - x0)*0.05)
|
||||
margin_y = int((y1 - y0)*0.05)
|
||||
x0 = max(x0 - margin_x, 0)
|
||||
y0 = max(y0 - margin_y, 0)
|
||||
x1 = min(x1 + margin_x, self.size[0]-1)
|
||||
y1 = min(y1 + margin_y, self.size[1]-1)
|
||||
|
||||
# print("[DEBUG] , enlarge size: ", margin_x, margin_y)
|
||||
# print("[DEBUG] from", bbox, "to", (x0,y0,x1,y1))
|
||||
return (x0,y0,x1,y1)
|
||||
|
||||
class DataEngine(Dataset):
|
||||
def __init__(self, dataset=None, img_size=None) -> None:
|
||||
# CACHE_DISK_DIR="/home1/quanquan/code/projects/medical-guangdong/cache/data2d_3/"
|
||||
super().__init__()
|
||||
self.point_prompt_generator = PointPromptGenerator(img_size)
|
||||
self.box_prompt_generator = BoxPromptGenerator(img_size)
|
||||
# self._get_dataset(dirpath=dirpath)
|
||||
self.dataset = dataset
|
||||
|
||||
# def _get_dataset(self, dirpath):
|
||||
# self.dataset = Dataset2D(dirpath=dirpath, is_train=True)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
|
||||
def _get_true_index(self, idx):
|
||||
return idx
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.get_prompt(idx)
|
||||
|
||||
def get_prompt_point(self, gt_mask):
|
||||
return self.point_prompt_generator.get_prompt_point(gt_mask)
|
||||
|
||||
def get_prompt_box(self, gt_mask):
|
||||
return self.box_prompt_generator.get_prompt_box(gt_mask)
|
||||
|
||||
# def _get_data_from_dataset(self, idx):
|
||||
|
||||
def get_prompt(self, idx):
|
||||
idx = self._get_true_index(idx)
|
||||
data = self.dataset.__getitem__(idx)
|
||||
|
||||
img = data['img'] # (3,h,w) d=3
|
||||
mask = data['label'] # (3,h,w) d=3
|
||||
|
||||
try:
|
||||
gt_mask = mask[1,:,:]
|
||||
except Exception as e:
|
||||
import ipdb; ipdb.set_trace()
|
||||
gt_mask = mask[1,:,:]
|
||||
gt_mask = gt_mask.numpy() if isinstance(gt_mask, torch.Tensor) else gt_mask
|
||||
|
||||
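# Generate both a point prompt and a box prompt for every sample; the trainer decides which to use.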
# if np.random.rand() > 0.5:
|
||||
prompt_point = self.get_prompt_point(gt_mask)
|
||||
# else:
|
||||
prompt_box = self.get_prompt_box(gt_mask)
|
||||
|
||||
data['prompt_point'] = np.array(prompt_point).astype(np.float32)
|
||||
data['prompt_box'] = np.array(prompt_box).astype(np.float32)
|
||||
data['point_label'] = np.ones((1,)).astype(np.float32)
|
||||
return data
|
||||
|
||||
def get_subsequent_prompt_point(self, pred_mask, gt_mask):
|
||||
# return self.point_prompt_generator.select_random_subsequent_point_torch(pred_mask, gt_mask)
|
||||
# return self.point_prompt_generator.select_random_subsequent_point(pred_mask=pred_mask, gt_mask=gt_mask)
|
||||
coord_collect = []
|
||||
label_collect = []
|
||||
for i in range(pred_mask.shape[0]):
|
||||
coords, label = self.point_prompt_generator.select_random_subsequent_point(pred_mask[i][0], gt_mask[i][0])
|
||||
if label == -1:
|
||||
return None, None
|
||||
coord_collect.append(coords)
|
||||
label_collect.append(label)
|
||||
|
||||
coord_collect = np.stack(coord_collect, axis=0)
|
||||
label_collect = np.stack(label_collect, axis=0)
|
||||
return coord_collect, label_collect
|
||||
|
||||
def get_noisy_box_from_box(self, box):
|
||||
# Get noisy box from labeled box
|
||||
return self.box_prompt_generator.add_random_noise_to_bbox(box)
|
||||
|
||||
# def get_prompt_mask(self, )
|
||||
|
||||
# class ValidEngine(DataEngine):
|
||||
# def __init__(self, dataset=None, img_size=(1024,1024), is_train=False) -> None:
|
||||
# # assert dataset is not None
|
||||
# self.dataset = dataset
|
||||
# self.is_train = is_train
|
||||
# super().__init__(dataset=dataset, img_size=img_size)
|
||||
# self.expand_dataset_ratio = 1
|
||||
|
||||
# # def _get_dataset(self, dirpath):
|
||||
# # self.dataset = Dataset3D(dirpath=dirpath, is_train=self.is_train)
|
||||
|
||||
# def __len__(self):
|
||||
# return len(self.dataset)
|
||||
|
||||
# def _get_true_index(self, idx):
|
||||
# return idx
|
||||
|
||||
class DataManager:
|
||||
def __init__(self, img_size=None) -> None:
|
||||
self.point_prompt_generator = PointPromptGenerator(img_size)
|
||||
self.box_prompt_generator = BoxPromptGenerator(img_size)
|
||||
|
||||
def get_prompt_point(self, gt_mask):
|
||||
return self.point_prompt_generator.get_prompt_point(gt_mask)
|
||||
|
||||
def get_prompt_box(self, gt_mask):
|
||||
return self.box_prompt_generator.get_prompt_box(gt_mask)
|
||||
|
||||
def get_subsequent_prompt_point(self, pred_mask, gt_mask):
|
||||
# return self.point_prompt_generator.select_random_subsequent_point_torch(pred_mask, gt_mask)
|
||||
# return self.point_prompt_generator.select_random_subsequent_point(pred_mask=pred_mask, gt_mask=gt_mask)
|
||||
coord_collect = []
|
||||
label_collect = []
|
||||
for i in range(pred_mask.shape[0]):
|
||||
coords, label = self.point_prompt_generator.select_random_subsequent_point(pred_mask[i][0], gt_mask[i][0])
|
||||
if label == -1:
|
||||
return None, None
|
||||
coord_collect.append(coords)
|
||||
label_collect.append(label)
|
||||
|
||||
coord_collect = np.stack(coord_collect, axis=0)
|
||||
label_collect = np.stack(label_collect, axis=0)
|
||||
return coord_collect, label_collect
|
||||
|
||||
def get_noisy_box_from_box(self, box):
|
||||
# Get noisy box from labeled box
|
||||
return self.box_prompt_generator.add_random_noise_to_bbox(box)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
dataset = DataEngine()
|
||||
data = dataset.__getitem__(0)
|
||||
|
||||
import ipdb; ipdb.set_trace()
|
239
datasets/dataset3d.py
Normal file
239
datasets/dataset3d.py
Normal file
@ -0,0 +1,239 @@
|
||||
from torchvision import transforms
|
||||
from monai import transforms as monai_transforms
|
||||
import numpy as np
|
||||
import SimpleITK as sitk
|
||||
import torch
|
||||
from torch.utils.data import Dataset as dataset
|
||||
import torch.nn.functional as F
|
||||
import glob
|
||||
import os
|
||||
from einops import rearrange
|
||||
|
||||
from tutils.nn.data import read, itk_to_np
|
||||
|
||||
from tqdm import tqdm
|
||||
# from monai.transforms import SpatialPadd, CenterSpatialCropd, Resized, NormalizeIntensityd
|
||||
# from monai.transforms import RandAdjustContrastd, RandShiftIntensityd, Rotated, RandAffined
|
||||
# from datasets.common_2d_aug import RandomRotation, RandomResizedCrop, RandomHorizontalFlip, ColorJitter, ToTensor, Normalize
|
||||
from tutils import tfilename, tdir
|
||||
import random
|
||||
import time
|
||||
import cv2
|
||||
from scipy.sparse import csr_matrix
|
||||
|
||||
def tfunctime(func):
|
||||
def run(*argv, **kargs):
|
||||
t1 = time.time()
|
||||
ret = func(*argv, **kargs)
|
||||
t2 = time.time()
|
||||
print(f"[Function {func.__name__}] Running time:{(t2-t1):.6f}s")
|
||||
return ret
|
||||
return run
|
||||
|
||||
# DEFAULT_PATH='/home1/quanquan/datasets/KiTS/'
|
||||
DEFAULT_PATH="/home1/quanquan/datasets/BCV-Abdomen/Training/"
|
||||
|
||||
LABEL_INDICES={
|
||||
"t2sag": ["bg","kidney", "label 2", "label 3", "rectum", "tumor", "other"],
|
||||
}
|
||||
|
||||
# CACHE_DISK_DIR="/home1/quanquan/code/projects/medical-guangdong/cache/data2d_3/"
|
||||
CACHE_DISK_DIR=None
|
||||
# DEFAULT_CONFIG={
|
||||
# "pad": (512,512),
|
||||
# "crop": (384,384),
|
||||
# "resize": (512,512),
|
||||
# }
|
||||
|
||||
DATASET_CONFIG={
|
||||
'split': 'train',
|
||||
'data_root_path':'/quanquan/datasets/',
|
||||
'dataset_list': ['sam', "their", "ours"],
|
||||
'data_txt_path':'./datasets/dataset_list/',
|
||||
}
|
||||
|
||||
DATASET_METAINFO={
|
||||
"WORD": {0:"Background", 1:"Liver", 2:"Spleen", 3:"Left Kidney", 4:"Right Kidney", 5:"Stomach", 6:"Gallbladder", 7:"Esophagus", 8:"Pancreas", 9:"Duodenum", 10:"Colon", 11:"Intestine", 12:"Adrenal", 13:"Rectum", 14:"Bladder", 15:"left head of femur", 16:"right head of femur"}
|
||||
}
|
||||
|
||||
|
||||
class Dataset3D(dataset):
|
||||
def __init__(self, config=DATASET_CONFIG, is_train=True, split='train', getting_multi_mask=False, use_cache=False) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.is_train = is_train
|
||||
self.split = split
|
||||
self.getting_multi_mask = getting_multi_mask
|
||||
self.use_cache = use_cache
|
||||
self.img_names = self.prepare_cached_datalist() if use_cache else self.prepare_datalist()
|
||||
# self.img_names = self.prepare_datalist()
|
||||
self.prepare_transforms()
|
||||
|
||||
def prepare_cached_datalist(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def prepare_transforms(self):
|
||||
self.transform = monai_transforms.Compose([
|
||||
# transforms.Resized(keys=['img', 'label'], spatial_size=(3,512,512)),
|
||||
# transforms.RandSpatialCropd(keys=["img", 'label'], roi_size=(3,448,448)),
|
||||
# transforms.RandAffined(keys=['img', 'label'], prob=0.5, shear_range=(0.2,0.2)),
|
||||
# transforms.RandCropByPosNegLabeld(keys=['img', 'label'], spatial_size=(3,448,448), label_key='label', neg=0),
|
||||
# transforms.RandSmoothFieldAdjustContrastd(keys=['img', 'label'], )
|
||||
monai_transforms.RandAdjustContrastd(keys=['img'], ),
|
||||
# transforms.RandShiftIntensityd(keys=['img'], prob=0.8, offsets=(0, 20)),
|
||||
monai_transforms.Resized(keys=['img', 'label'], spatial_size=(3,1024,1024)),
|
||||
])
|
||||
self.test_transform = transforms.Compose([
|
||||
monai_transforms.Resized(keys=['img', 'label'], spatial_size=(3,1024,1024)),
|
||||
])
|
||||
|
||||
def _get_image(self, index):
|
||||
name = self.img_names[index]['img_path']
|
||||
if not os.path.exists(name):
|
||||
print("Path not exists!", name)
|
||||
return self._get_image((index + 1) % len(self))
|
||||
img_itk = read(self.img_names[index]['img_path'])
|
||||
img_ori = itk_to_np(img_itk)
|
||||
img = np.clip(img_ori, -200, 400).astype(np.float32)
|
||||
img = (img - img.min()) / (img.max() - img.min()) * 255
|
||||
label_ori = itk_to_np(read(self.img_names[index]['label_path']))
|
||||
return {"img":img, "name":name.replace(self.config['data_root_path'], ""), "label":label_ori}
|
||||
|
||||
def __len__(self):
|
||||
return len(self.img_names)
|
||||
|
||||
# @tfunctime
|
||||
def prepare_datalist(self):
|
||||
config = self.config
|
||||
data_paths = []
|
||||
for item in config['dataset_list']:
|
||||
print("Load datalist from ", item)
|
||||
for line in open(config["data_txt_path"]+ item + f"_{self.split}.txt"):
|
||||
name = line.strip().split()[1].split('.')[0]
|
||||
img_path = config['data_root_path'] + line.strip().split()[0]
|
||||
label_path = config['data_root_path'] + line.strip().split()[1]
|
||||
data_paths.append({'img_path': img_path, 'label_path': label_path, 'name': name})
|
||||
print('train len {}'.format(len(data_paths)))
|
||||
return data_paths
|
||||
|
||||
# @tfunctime
|
||||
def _get_data(self, index, debug=False):
|
||||
# LABEL_INDICES
|
||||
name = self.img_names[index]['img_path']
|
||||
img_itk = read(self.img_names[index]['img_path'])
|
||||
img_ori = itk_to_np(img_itk)
|
||||
# spacing = img_itk.GetSpacing()
|
||||
scan_orientation = np.argmin(img_ori.shape)
|
||||
label_ori = itk_to_np(read(self.img_names[index]['label_path']))
|
||||
|
||||
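# Choose a slicing orientation: follow the scan axis for clearly anisotropic volumes,
# otherwise pick a random axis, then sample a random slice and a random organ label.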
if min(img_ori.shape) * 2 < max(img_ori.shape):
|
||||
orientation = scan_orientation
|
||||
else:
|
||||
orientation = np.random.randint(3)
|
||||
slice_idx = np.random.randint(2, img_ori.shape[orientation]-2)
|
||||
if orientation == 0:
|
||||
s = img_ori[slice_idx-1:slice_idx+2, :,:]
|
||||
lb = label_ori[slice_idx-1:slice_idx+2, :,:]
|
||||
# spacing = (spacing[1], spacing[2])
|
||||
if orientation == 1:
|
||||
s = img_ori[:,slice_idx-1:slice_idx+2,:]
|
||||
s = rearrange(s, "h c w -> c h w")
|
||||
lb = label_ori[:,slice_idx-1:slice_idx+2,:]
|
||||
lb = rearrange(lb, "h c w -> c h w")
|
||||
# spacing = (spacing[0], spacing[2])
|
||||
if orientation == 2:
|
||||
s = img_ori[:,:,slice_idx-1:slice_idx+2]
|
||||
s = rearrange(s, "h w c -> c h w")
|
||||
lb = label_ori[:,:,slice_idx-1:slice_idx+2]
|
||||
lb = rearrange(lb, "h w c -> c h w")
|
||||
# spacing = (spacing[0], spacing[1])
|
||||
assert s.shape[0] == 3
|
||||
|
||||
if np.float32(lb[1,:,:]>0).sum() <= 200:
|
||||
return self._get_data((index+1)%len(self))
|
||||
# Choose one label
|
||||
label_num = int(lb.max())
|
||||
is_good_mask = []
|
||||
for label_idx in range(1,label_num+1):
|
||||
one_lb = np.float32(lb==label_idx)
|
||||
is_good_mask.append(one_lb.sum()>=50)
|
||||
label_idx = np.random.choice(range(1,label_num+1), p=np.array(is_good_mask)/np.sum(is_good_mask))
|
||||
lb = np.float32(lb==label_idx)
|
||||
return s, lb, name, label_idx
|
||||
|
||||
# @tfunctime
|
||||
def _get_cached_data(self, index):
|
||||
name = self.img_names[index]
|
||||
img = cv2.imread(name)
|
||||
|
||||
compressed = np.load(name.replace("image/image_", "label/label_").replace(".jpg", ".npz"))
|
||||
csr = csr_matrix((compressed['data'], compressed['indices'], compressed['indptr']), shape=compressed['shape'])
|
||||
label_ori = csr.toarray()
|
||||
label_ori = rearrange(label_ori, "n (c h w) -> n c h w", c=3, h=1024, w=1024)
|
||||
|
||||
label_idx = np.random.randint(0, label_ori.shape[0])
|
||||
label_ori = label_ori[label_idx]
|
||||
return rearrange(img, "h w c -> c h w"), label_ori, name, -1
|
||||
|
||||
def to_RGB(self, img):
|
||||
# transform images to RGB style
|
||||
img = ((img - img.min()) / (img.max() - img.min()) * 255).astype(int)
|
||||
return img
|
||||
|
||||
# @tfunctime
|
||||
def __getitem__(self, index, debug=False):
|
||||
img_ori, label_ori, name, label_idx = self._get_data(index)
|
||||
img_ori = np.clip(img_ori, -200,400)
|
||||
img_rgb = self.to_RGB(img_ori)
|
||||
|
||||
assert len(img_rgb.shape) == 3, f"{__file__} Got{img_rgb.shape}"
|
||||
bundle_ori = {"img":torch.Tensor(img_rgb).unsqueeze(0), "label":torch.Tensor(label_ori).unsqueeze(0)}
|
||||
# import ipdb; ipdb.set_trace()
|
||||
if self.is_train:
|
||||
# bundle = self.transform(bundle_ori)[0] # use with transforms.RandCropByPosNegLabeld
|
||||
bundle = self.transform(bundle_ori)
|
||||
else:
|
||||
bundle = self.test_transform(bundle_ori)
|
||||
|
||||
if not self.use_cache:
|
||||
bundle['label'] = (bundle['label']>0.5).float()
|
||||
vector = np.ones(3)
|
||||
if debug:
|
||||
ret_dict = {
|
||||
"name": name,
|
||||
"img": bundle['img'],
|
||||
"label": bundle['label'],
|
||||
"img_ori":img_ori,
|
||||
"label_ori":label_ori,
|
||||
"label_idx": label_idx,
|
||||
"indicators": vector,
|
||||
# "label_name":
|
||||
}
|
||||
return ret_dict
|
||||
|
||||
ret_dict = {
|
||||
"name": name,
|
||||
"img": bundle['img'][0].float(),
|
||||
"label": bundle['label'][0].float(),
|
||||
"indicators": vector,
|
||||
}
|
||||
if bundle['label'][0][1,:,:].sum() <= 0:
|
||||
return self.__getitem__((index+1) % len(self))
|
||||
return ret_dict
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from tutils.new.manager import ConfigManager
|
||||
config = ConfigManager()
|
||||
config.add_config("configs/vit_b_103.yaml")
|
||||
dataset = Dataset3D(config=config['dataset']) # , use_cache=True
|
||||
data = dataset.__getitem__(0)
|
||||
|
||||
import ipdb; ipdb.set_trace()
|
||||
from torch.utils.data import DataLoader
|
||||
loader = DataLoader(dataset, batch_size=8)
|
||||
for batch in loader:
|
||||
print(batch['img'].shape, batch['label'].shape)
|
||||
# print(data['label'].max())
|
||||
# import ipdb; ipdb.set_trace()
|
||||
|
178
datasets/dataset3d_2dmask.py
Normal file
178
datasets/dataset3d_2dmask.py
Normal file
@ -0,0 +1,178 @@
|
||||
# from torchvision import transforms
|
||||
from monai import transforms
|
||||
import numpy as np
|
||||
import SimpleITK as sitk
|
||||
import torch
|
||||
from torch.utils.data import Dataset as dataset
|
||||
import glob
|
||||
import os
|
||||
from einops import rearrange, repeat
|
||||
from tutils.nn.data.tsitk import read
|
||||
from tqdm import tqdm
|
||||
# from monai.transforms import SpatialPadd, CenterSpatialCropd, Resized, NormalizeIntensityd
|
||||
# from monai.transforms import RandAdjustContrastd, RandShiftIntensityd, Rotated, RandAffined
|
||||
# from datasets.common_2d_aug import RandomRotation, RandomResizedCrop, RandomHorizontalFlip, ColorJitter, ToTensor, Normalize
|
||||
from tutils import tfilename, tdir
|
||||
import random
|
||||
from tutils.nn.data import itk_to_np
|
||||
from scipy.sparse import csr_matrix
|
||||
import cv2
|
||||
|
||||
|
||||
DEFAULT_PATH="/quanquan/datasets/08_AbdomenCT-1K/"
|
||||
|
||||
|
||||
class Dataset2D(dataset):
|
||||
def __init__(self, dirpath=None, is_train=True, getting_multi_mask=False) -> None:
|
||||
super().__init__()
|
||||
self.dirpath = dirpath
|
||||
self.is_train = is_train
|
||||
self.getting_multi_mask = getting_multi_mask
|
||||
self.img_names = self.prepare_datalist()
|
||||
self.prepare_transforms()
|
||||
self.weights_dict = {"gt":2, "sam_auto_seg":2, "prompt_point_from_superpixel":1, "prompt_box_from_superpixel":1, "superpixel":0}
|
||||
|
||||
def prepare_transforms(self):
|
||||
self.transform = transforms.Compose([
|
||||
transforms.Resized(keys=['img', 'label'], spatial_size=(3,1024,1024)),
|
||||
# transforms.RandSpatialCropd(keys=["img"], roi_size=(448,448,1)),
|
||||
transforms.RandAffined(keys=['img', 'label'], prob=0.5, shear_range=(0.2,0.2)),
|
||||
transforms.RandCropByPosNegLabeld(keys=['img', 'label'], spatial_size=(3,960,960), label_key='label', neg=0),
|
||||
# transforms.RandSmoothFieldAdjustContrastd(keys=['img', 'label'], )
|
||||
transforms.Resized(keys=['img', 'label'], spatial_size=(3,1024,1024)),
|
||||
transforms.RandAdjustContrastd(keys=['img'], ),
|
||||
transforms.RandShiftIntensityd(keys=['img'], prob=0.8, offsets=(-5, 5)),
|
||||
])
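# Augmentation summary (editor's note): resize to the 1024x1024 SAM input size, random
# shear, a foreground-biased crop (neg=0 keeps only crops centered on labelled pixels),
# resize back to 1024, then contrast and intensity jitter applied to the image only.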
|
||||
self.test_transform = transforms.Compose([
|
||||
transforms.Resized(keys=['img'], spatial_size=(3,1024,1024)),
|
||||
])
|
||||
|
||||
def __len__(self):
|
||||
return len(self.img_names)
|
||||
|
||||
def to_RGB(self, img):
|
||||
# transform images to RGB style
|
||||
img = ((img - img.min()) / (img.max() - img.min()) * 255).astype(int)
|
||||
return img
|
||||
|
||||
def prepare_datalist(self):
|
||||
dirpath_img = os.path.join(self.dirpath, 'preprocessed', 'cache_2d_various_pseudo_masks')
|
||||
names = glob.glob(os.path.join(dirpath_img, "*_mask.npz"))
|
||||
names = [os.path.split(name)[-1].replace("_mask.npz", "") for name in names]
|
||||
names.sort()
|
||||
# names = names[:15000]
|
||||
print(f"[Dataset2d] Load {len(names)} paths.")
|
||||
assert len(names) > 0, f"{__file__} Gotdirpath: {self.dirpath}"
|
||||
return names
|
||||
|
||||
|
||||
def _get_data(self, index, debug=False, iternum=0):
|
||||
img_names = self.img_names
|
||||
img_info = os.path.split(img_names[index])[-1].split('_s')
|
||||
filename, slice_idx = img_info[0], int(img_info[-1][:4])
|
||||
mask_loc = np.random.randint(0,3)
|
||||
if mask_loc == 0:
|
||||
slices_indices = [slice_idx, slice_idx+1, slice_idx+2]
|
||||
elif mask_loc == 1:
|
||||
slices_indices = [slice_idx-1, slice_idx, slice_idx+1]
|
||||
elif mask_loc == 2:
|
||||
slices_indices = [slice_idx-2, slice_idx-1, slice_idx]
|
||||
|
||||
# Load .npy data
|
||||
filenames = [os.path.join(self.dirpath, "preprocessed", "cache_jpg", f"{filename}_s{i:04}_img.jpg") for i in slices_indices]
|
||||
for name in filenames:
|
||||
if not os.path.exists(name):
|
||||
return self._get_data((index+1) % len(self))
|
||||
|
||||
imgs = [cv2.imread(name, cv2.IMREAD_GRAYSCALE) for name in filenames]
|
||||
img_rgb = np.stack(imgs, axis=0)
|
||||
|
||||
# Load RGB data
|
||||
|
||||
compressed = np.load(os.path.join(self.dirpath, "preprocessed", "cache_2d_various_pseudo_masks", img_names[index]+"_mask.npz"))
|
||||
csr = csr_matrix((compressed['data'], compressed['indices'], compressed['indptr']), shape=compressed['shape'])
|
||||
label_ori = csr.toarray()
|
||||
label_ori = rearrange(label_ori, "c (h w) -> c h w", h=1024, w=1024)
|
||||
metadata = np.load(os.path.join(self.dirpath, "preprocessed", "cache_2d_various_pseudo_masks", img_names[index]+"_metadata.npy"), allow_pickle=True)
|
||||
|
||||
label_prob = np.array([self.weights_dict[item['source']] for item in metadata]).astype(float)
|
||||
label_prob = label_prob / label_prob.sum()
|
||||
|
||||
label_idx = np.random.choice(a=np.arange(len(metadata)), p=label_prob)
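# Pseudo-masks are sampled in proportion to how trustworthy their source is (see
# self.weights_dict): GT and SAM auto-segmentation masks get weight 2, prompt-derived
# masks weight 1, and raw superpixels weight 0, so superpixels are never drawn --
# which is what the assertion below relies on.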
|
||||
label_ori = label_ori[label_idx]
|
||||
metadata = metadata[label_idx]
|
||||
assert metadata['source'] != 'superpixel'
|
||||
|
||||
assert len(img_rgb.shape) == 3, f"{__file__} Got{img_rgb.shape}"
|
||||
bundle_ori = {"img":torch.Tensor(rearrange(img_rgb, "c h w -> 1 c h w")), "label":torch.Tensor(repeat(label_ori, "h w -> 1 3 h w"))}
|
||||
# import ipdb; ipdb.set_trace()
|
||||
if self.is_train:
|
||||
bundle = self.transform(bundle_ori)[0]
|
||||
else:
|
||||
bundle = self.test_transform(bundle_ori)
|
||||
|
||||
bundle['label'] = (bundle['label']>0.5).float()
|
||||
if bundle['label'][0].sum() < 100:
|
||||
return self._get_data((index+1)%len(self), iternum=iternum+1)
|
||||
|
||||
vector = np.zeros(3)
|
||||
vector[mask_loc] = 1
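# One-hot indicator marking which of the three stacked slices carries the 2D mask
# (mask_loc), presumably consumed by the model as a slice-position hint (assumption).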
|
||||
|
||||
if debug:
|
||||
ret_dict = {
|
||||
"name": img_names[index],
|
||||
"img": bundle['img'][0],
|
||||
"label": bundle['label'][0],
|
||||
"img_ori":img_rgb,
|
||||
"label_ori":label_ori,
|
||||
"weight": self.weights_dict[metadata['source']],
|
||||
"iternum": iternum,
|
||||
"mask_loc": mask_loc,
|
||||
"indicators": vector,
|
||||
}
|
||||
return ret_dict
|
||||
|
||||
ret_dict = {
|
||||
"name": img_names[index],
|
||||
"img": bundle['img'][0],
|
||||
"label": bundle['label'][0],
|
||||
"mask_loc": mask_loc,
|
||||
"indicators": vector,
|
||||
}
|
||||
return ret_dict
|
||||
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self._get_data(index)
|
||||
|
||||
|
||||
class Testset2d(Dataset2D):
|
||||
def __init__(self, dirpath=None, is_train=False, getting_multi_mask=False) -> None:
|
||||
super().__init__(dirpath, is_train, getting_multi_mask)
|
||||
self.test_names = self.prepare_datalist()
|
||||
|
||||
def prepare_datalist(self):
|
||||
dirpath_img = os.path.join(self.dirpath, 'cache_2d_various_pseudo_masks')
|
||||
names = glob.glob(os.path.join(dirpath_img, "*_mask.npz"))
|
||||
names = [os.path.split(name)[-1].replace("_mask.npz", "") for name in names]
|
||||
names.sort()
|
||||
names = names[15000:]
|
||||
print(f"[Dataset2d] Load {len(names)} paths.")
|
||||
assert len(names) > 0, f"{__file__} Gotdirpath: {self.dirpath}"
|
||||
return names
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
dataset = Dataset2D(dirpath=DEFAULT_PATH)
|
||||
loader = DataLoader(dataset=dataset, batch_size=1, shuffle=False)
|
||||
iternums = 0
|
||||
for i, data in enumerate(loader):
|
||||
# iternums += data['iternum'].item()
|
||||
print(i, iternums / (i+1), data['img'].shape, data['label'].shape)
|
||||
assert data['label'].sum() >= 100, f"{__file__} Got{data['label'].sum()}"
|
||||
assert torch.Tensor(data['label']==1).sum() >= 100, f"{__file__} Got {torch.Tensor(data['label']==1).sum().sum()}"
|
||||
|
||||
import ipdb; ipdb.set_trace()
|
||||
|
74
datasets/dataset_merged.py
Normal file
74
datasets/dataset_merged.py
Normal file
@ -0,0 +1,74 @@
|
||||
# from torchvision import transforms
|
||||
from monai import transforms
|
||||
import numpy as np
|
||||
import SimpleITK as sitk
|
||||
import torch
|
||||
from torch.utils.data import Dataset as dataset
|
||||
import torch.nn.functional as F
|
||||
import glob
|
||||
import os
|
||||
from einops import rearrange
|
||||
|
||||
from tutils.nn.data import read, itk_to_np
|
||||
|
||||
from tqdm import tqdm
|
||||
from tutils import tfilename, tdir
|
||||
import random
|
||||
# from .dataset2d import Dataset2D
|
||||
from .dataset3d_2dmask import Dataset2D
|
||||
from .dataset3d import Dataset3D
|
||||
|
||||
|
||||
class DatasetMerged(dataset):
|
||||
def __init__(self, config=None, is_train=True, getting_multi_mask=False) -> None:
|
||||
super().__init__()
|
||||
self.dataset2d = Dataset2D(dirpath="/quanquan/datasets/08_AbdomenCT-1K/", is_train=True)
|
||||
self.dataset3d = Dataset3D(config=config, is_train=True)
|
||||
self.len_2d = len(self.dataset2d)
|
||||
self.len_3d = len(self.dataset3d)
|
||||
|
||||
def __getitem__(self, index, debug=False):
|
||||
index = index % len(self)
|
||||
# print("DEBUG! is_2d:", index < self.len_2d)
|
||||
if index < self.len_2d:
|
||||
return self.dataset2d.__getitem__(index)
|
||||
else:
|
||||
index = (index - self.len_2d) % self.len_3d
|
||||
return self.dataset3d.__getitem__(index)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset2d) + len(self.dataset3d) * 200
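# The 3D dataset is far smaller than the 2D cache, so its nominal length is inflated
# 200x here; combined with the modulo indexing in __getitem__, this oversamples 3D
# volumes within each epoch (editor's reading of the intent).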
|
||||
|
||||
|
||||
|
||||
class TestsetMerged(dataset):
|
||||
def __init__(self, config=None, is_train=False) -> None:
|
||||
super().__init__()
|
||||
self.dataset2d = Dataset2D(dirpath="/quanquan/datasets/08_AbdomenCT-1K/preprocessed/", is_train=False)
|
||||
self.dataset3d = Dataset3D(config=config, is_train=False, split='val')
|
||||
self.len_2d = len(self.dataset2d)
|
||||
self.len_3d = len(self.dataset3d)
|
||||
|
||||
def __getitem__(self, index, debug=False):
|
||||
index = index % len(self)
|
||||
if index < self.len_2d:
|
||||
return self.dataset2d.__getitem__(index)
|
||||
else:
|
||||
index = (index - self.len_2d) % self.len_3d
|
||||
return self.dataset3d.__getitem__(index)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset2d) + len(self.dataset3d) * 2
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from tutils import timer
|
||||
from tutils.new.manager import trans_args, trans_init, ConfigManager
|
||||
config = ConfigManager()
|
||||
config.add_basic_config()
|
||||
config.add_config("configs/vit_b.yaml")
|
||||
dataset = DatasetMerged(config['dataset'])
|
||||
tt = timer()
|
||||
for i in range(20000,len(dataset)):
|
||||
data = dataset.__getitem__(i)
|
||||
print("time: ", tt())
|
435
datasets/generate_txt.py
Normal file
435
datasets/generate_txt.py
Normal file
@ -0,0 +1,435 @@
|
||||
import os
|
||||
import glob
|
||||
import shutil
|
||||
from tutils import tfilename
|
||||
|
||||
|
||||
HOME_PATH="/quanquan/datasets/"
|
||||
|
||||
def check_existing(img_path, label_path):
|
||||
if os.path.exists(img_path) and os.path.exists(label_path):
|
||||
return True
|
||||
else:
|
||||
if not os.path.exists(img_path):
|
||||
print("IMAGE Not exist: ", img_path)
|
||||
if not os.path.exists(label_path):
|
||||
print("LABEL Not exist: ", label_path)
|
||||
return False
|
||||
|
||||
def get_availabel_files(names, label_names):
|
||||
llist = [[n,n2] for n,n2 in zip(names, label_names) if check_existing(n, n2)]
|
||||
names = [n[0] for n in llist]
|
||||
label_names = [n[1] for n in llist]
|
||||
return names, label_names
|
||||
|
||||
def write_txt(img_paths, label_paths, meta_info, split, writing_mode='a+'):
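# Writes one "<image path>\t<label path>" line per case, with both paths made relative
# to meta_info['home_path']; the default append mode lets successive dataset collectors
# accumulate into the same split txt file.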
|
||||
dataset_id = meta_info['dataset_id']
|
||||
assert split in ['train', 'val', 'test'], f" split in ['train', 'val', 'test'] , but Got {split}"
|
||||
save_path = meta_info["save_txt_path"].replace("_train.txt", f"_{split}.txt")
|
||||
|
||||
count = 0
|
||||
with open(save_path, writing_mode) as f:
|
||||
for p1, p2 in zip(img_paths, label_paths):
|
||||
p1 = p1.replace(meta_info['home_path'], "")
|
||||
p2 = p2.replace(meta_info['home_path'], "")
|
||||
line = f"{p1}\t{p2}\n"
|
||||
# print(line, end=" ")
|
||||
f.write(line)
|
||||
count += 1
|
||||
if count <= 0:
|
||||
raise ValueError(f"ID: {meta_info['dataset_id']}, \tTask: {meta_info['dataset_name']}\t, {count} files are writen.")
|
||||
print(f"ID: {meta_info['dataset_id']}, \tTask: {meta_info['dataset_name']}\t, {count} files are writen.\t Writing Over! write into ", save_path)
|
||||
|
||||
|
||||
def organize_in_nnunet_style(meta_info):
|
||||
dirpath = os.path.join(meta_info['home_path'], meta_info['dirpath'])
|
||||
if os.path.exists(os.path.join(dirpath, "imagesTr")) and os.path.exists(os.path.join(dirpath, "labelsTr")):
|
||||
img_paths = glob.glob(os.path.join(dirpath, "imagesTr", "*.nii.gz"))
|
||||
img_paths.sort()
|
||||
label_paths = [p.replace("imagesTr", "labelsTr")[:-12]+".nii.gz" for p in img_paths]
|
||||
img_paths, label_paths = get_availabel_files(img_paths, label_paths)
|
||||
write_txt(img_paths, label_paths, meta_info=meta_info, split='train', writing_mode="a+")
|
||||
|
||||
if os.path.exists(os.path.join(dirpath, "imagesVa")) and os.path.exists(os.path.join(dirpath, "labelsVa")):
|
||||
img_paths = glob.glob(os.path.join(dirpath, "imagesVa", "*.nii.gz"))
|
||||
img_paths.sort()
|
||||
label_paths = [p.replace("imagesVa", "labelsVa")[:-12]+".nii.gz" for p in img_paths]
|
||||
img_paths, label_paths = get_availabel_files(img_paths, label_paths)
|
||||
write_txt(img_paths, label_paths, meta_info=meta_info, split='val', writing_mode="a+")
|
||||
|
||||
if os.path.exists(os.path.join(dirpath, "imagesTs")) and os.path.exists(os.path.join(dirpath, "labelsTs")):
|
||||
img_paths = glob.glob(os.path.join(dirpath, "imagesTs", "*.nii.gz"))
|
||||
img_paths.sort()
|
||||
label_paths = [p.replace("imagesTs", "labelsTs")[:-12]+".nii.gz" for p in img_paths]
|
||||
img_paths, label_paths = get_availabel_files(img_paths, label_paths)
|
||||
write_txt(img_paths, label_paths, meta_info=meta_info, split='test', writing_mode="a+")
|
||||
|
||||
|
||||
def organize_in_style2(meta_info):
|
||||
dirpath = os.path.join(meta_info['home_path'], meta_info['dirpath'])
|
||||
if os.path.exists(os.path.join(dirpath, "imagesTr")) and os.path.exists(os.path.join(dirpath, "labelsTr")):
|
||||
img_paths = glob.glob(os.path.join(dirpath, "imagesTr", "*.nii.gz"))
|
||||
img_paths.sort()
|
||||
label_paths = [p.replace("imagesTr", "labelsTr") for p in img_paths]
|
||||
img_paths, label_paths = get_availabel_files(img_paths, label_paths)
|
||||
write_txt(img_paths, label_paths, meta_info=meta_info, split='train', writing_mode="a+")
|
||||
|
||||
if os.path.exists(os.path.join(dirpath, "imagesVa")) and os.path.exists(os.path.join(dirpath, "labelsVa")):
|
||||
img_paths = glob.glob(os.path.join(dirpath, "imagesVa", "*.nii.gz"))
|
||||
img_paths.sort()
|
||||
label_paths = [p.replace("imagesVa", "labelsVa") for p in img_paths]
|
||||
img_paths, label_paths = get_availabel_files(img_paths, label_paths)
|
||||
write_txt(img_paths, label_paths, meta_info=meta_info, split='val', writing_mode="a+")
|
||||
|
||||
if os.path.exists(os.path.join(dirpath, "imagesTs")) and os.path.exists(os.path.join(dirpath, "labelsTs")):
|
||||
img_paths = glob.glob(os.path.join(dirpath, "imagesTs", "*.nii.gz"))
|
||||
img_paths.sort()
|
||||
label_paths = [p.replace("imagesTs", "labelsTs") for p in img_paths]
|
||||
img_paths, label_paths = get_availabel_files(img_paths, label_paths)
|
||||
write_txt(img_paths, label_paths, meta_info=meta_info, split='test', writing_mode="a+")
|
||||
|
||||
|
||||
def organize_by_names(names_in, label_names_in, meta_info):
|
||||
assert len(names_in) > 0, f"Meta info: {meta_info}"
|
||||
names, label_names = get_availabel_files(names_in, label_names_in)
|
||||
assert len(names) > 0, f"Meta info: {meta_info}, \n {names_in[:2]} \n {label_names_in[:2]}"
|
||||
assert len(label_names) > 0, f"Meta info: {meta_info}, \n {names_in[:2]} \n {label_names_in[:2]}"
|
||||
|
||||
# print("debug files", len(names))
|
||||
if len(names) > 10:
|
||||
num_valid = min(int(len(names) // 10), 10)
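# Hold out the last 2*num_valid cases (at most 20): the first half becomes the
# validation split, the second half the test split, and everything before them is train.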
|
||||
# print("num valid", num_valid)
|
||||
train_names = names[:-num_valid*2]
|
||||
valid_names = names[-num_valid*2:-num_valid]
|
||||
test_names = names[-num_valid:]
|
||||
|
||||
train_labels = label_names[:-num_valid*2]
|
||||
valid_labels = label_names[-num_valid*2:-num_valid]
|
||||
test_labels = label_names[-num_valid:]
|
||||
|
||||
write_txt(train_names, train_labels, meta_info=meta_info, split="train")
|
||||
write_txt(valid_names, valid_labels, meta_info=meta_info, split="val")
|
||||
write_txt(test_names, test_labels, meta_info=meta_info, split="test")
|
||||
else:
|
||||
write_txt(names, label_names, meta_info=meta_info, split="train")
|
||||
|
||||
|
||||
def clear_files(train_path):
|
||||
if os.path.exists(train_path):
|
||||
parent, name = os.path.split(train_path)
|
||||
shutil.move(train_path, tfilename(parent, "misc", name))
|
||||
|
||||
val_path = train_path.replace("_train.txt", "_val.txt")
|
||||
if os.path.exists(val_path):
|
||||
parent, name = os.path.split(val_path)
|
||||
shutil.move(val_path, os.path.join(parent, "misc", name))
|
||||
|
||||
test_path = train_path.replace("_train.txt", "_test.txt")
|
||||
if os.path.exists(test_path):
|
||||
parent, name = os.path.split(test_path)
|
||||
shutil.move(test_path, os.path.join(parent, "misc", name))
|
||||
print("Files cleared!")
|
||||
|
||||
# from tutils.nn.data import read
|
||||
# def convert_to_nii(paths):
|
||||
|
||||
|
||||
###################################################################################
|
||||
|
||||
###################################################################################
|
||||
|
||||
def get_BCV_Abdomen(save_path=None):
|
||||
meta_info = {
|
||||
"dataset_name": "BTCV",
|
||||
"dataset_id": "01",
|
||||
"modality": "CT",
|
||||
"home_path": HOME_PATH,
|
||||
"dirpath": "01_BCV-Abdomen/Training/",
|
||||
"save_txt_path": save_path,
|
||||
}
|
||||
names = glob.glob(os.path.join(meta_info['home_path'], meta_info['dirpath'], "img/*.nii.gz"))
|
||||
names.sort()
|
||||
label_names = [p.replace("img", "label") for p in names]
|
||||
organize_by_names(names, label_names, meta_info=meta_info)
|
||||
|
||||
def get_AbdomenCT_1K(save_path):
|
||||
meta_info = {
|
||||
"dataset_name": "AbdomenCT-1K",
|
||||
"dataset_id": "08",
|
||||
"modality": "CT",
|
||||
"home_path": HOME_PATH,
|
||||
"dirpath": "08_AbdomenCT-1K",
|
||||
"save_txt_path": save_path,
|
||||
}
|
||||
# print(names)
|
||||
organize_in_nnunet_style(meta_info=meta_info)
|
||||
|
||||
def get_AMOS(save_path):
|
||||
meta_info = {
|
||||
"dataset_name": "AMOS",
|
||||
"dataset_id": "09",
|
||||
"modality": "CT",
|
||||
"home_path": HOME_PATH,
|
||||
"dirpath": "09_AMOS",
|
||||
"save_txt_path": save_path,
|
||||
}
|
||||
organize_in_style2(meta_info)
|
||||
|
||||
def get_MSD(save_path):
|
||||
meta_info = {
|
||||
"dataset_name": "MSD", # Decathlon
|
||||
"dataset_id": "10",
|
||||
"modality": "CT",
|
||||
"home_path": HOME_PATH,
|
||||
"parent_dirpath": "10_Decathlon",
|
||||
"dirpath": "",
|
||||
"save_txt_path": save_path,
|
||||
}
|
||||
subtasks = ["Task06_Lung", "Task08_HepaticVessel", "Task09_Spleen", "Task10_Colon"]
|
||||
for task in subtasks:
|
||||
# print("Processing ", task)
|
||||
meta_info_subtask = {
|
||||
"dataset_name":task,
|
||||
"dataset_id": f"{meta_info['dataset_id']}_{task[4:6]}",
|
||||
"home_path":HOME_PATH,
|
||||
"dirpath": f"{meta_info['parent_dirpath']}/{task}",
|
||||
"save_txt_path": save_path,
|
||||
}
|
||||
# print(meta_info_subtask)
|
||||
organize_in_style2(meta_info=meta_info_subtask)
|
||||
|
||||
|
||||
def get_MSD_MRI(save_path):
|
||||
meta_info = {
|
||||
"dataset_name": "MSD", # Decathlon
|
||||
"dataset_id": "10",
|
||||
"modality": "MRI",
|
||||
"home_path": HOME_PATH,
|
||||
"parent_dirpath": "10_Decathlon",
|
||||
"dirpath": "",
|
||||
"save_txt_path": save_path,
|
||||
}
|
||||
subtasks = ["Task02_Heart", "Task05_Prostate"]
|
||||
for task in subtasks:
|
||||
# print("Processing ", task)
|
||||
meta_info_subtask = {
|
||||
"dataset_name":task,
|
||||
"dataset_id": f"{meta_info['dataset_id']}_{task[4:6]}",
|
||||
"home_path":HOME_PATH,
|
||||
"dirpath": f"{meta_info['parent_dirpath']}/{task}",
|
||||
"save_txt_path": save_path,
|
||||
}
|
||||
# print(meta_info_subtask)
|
||||
organize_in_style2(meta_info=meta_info_subtask)
|
||||
|
||||
|
||||
def get_ASOCA(save_path):
|
||||
meta_info = {
|
||||
"dataset_name": "ASOCA",
|
||||
"dataset_id": "51",
|
||||
"modality": "CT",
|
||||
"home_path": HOME_PATH,
|
||||
"dirpath": "51_ASOCA",
|
||||
"save_txt_path": save_path,
|
||||
}
|
||||
names = glob.glob(os.path.join(meta_info['home_path'], meta_info['dirpath'], "image/*.nii.gz"))
|
||||
names.sort()
|
||||
label_names = [p.replace("/image", "/label") for p in names]
|
||||
# print(os.path.join(meta_info['home_path'], meta_info['dirpath'], "image/*.nii.gz")
|
||||
# print("debug ,", names)
|
||||
organize_by_names(names, label_names, meta_info=meta_info)
|
||||
|
||||
def get_BCV_Cervix(save_path):
|
||||
meta_info = {
|
||||
"dataset_name": "BCV-Cervix",
|
||||
"dataset_id": "52",
|
||||
"modality": "CT",
|
||||
"home_path": HOME_PATH,
|
||||
"dirpath": "52_BCV-Cervix/Training/",
|
||||
"save_txt_path": save_path,
|
||||
}
|
||||
names = glob.glob(os.path.join(meta_info['home_path'], meta_info['dirpath'], "img/*.nii.gz"))
|
||||
names.sort()
|
||||
label_names = [p.replace("/img/", "/label/").replace("-Image", "-Mask") for p in names]
|
||||
organize_by_names(names, label_names, meta_info=meta_info)
|
||||
|
||||
def get_NIHPancrease(save_path):
|
||||
meta_info = {
|
||||
"dataset_name": "NIHPancrease",
|
||||
"dataset_id": "53",
|
||||
"modality": "CT",
|
||||
"home_path": HOME_PATH,
|
||||
"dirpath": "53_NIHPancrease",
|
||||
"save_txt_path": save_path,
|
||||
}
|
||||
names = glob.glob(os.path.join(meta_info['home_path'], meta_info['dirpath'], "data/*.nii.gz") )
|
||||
names.sort()
|
||||
label_names = [p.replace("/data/PANCREAS_", "/label/label") for p in names]
|
||||
organize_by_names(names, label_names, meta_info=meta_info)
|
||||
|
||||
def get_CTPelvic(save_path):
|
||||
meta_info = {
|
||||
"dataset_name": "CTPelvic1K",
|
||||
"dataset_id": "54",
|
||||
"modality": "CT",
|
||||
"home_path": HOME_PATH,
|
||||
"dirpath": "54_CTPelvic1K",
|
||||
"save_txt_path": save_path,
|
||||
}
|
||||
names = []
|
||||
names += glob.glob(os.path.join(meta_info['home_path'], meta_info['dirpath'], "CTPelvic1K_dataset1_data/*.nii.gz"))
|
||||
names += glob.glob(os.path.join(meta_info['home_path'], meta_info['dirpath'], "CTPelvic1K_dataset2_data/*.nii.gz"))
|
||||
names += glob.glob(os.path.join(meta_info['home_path'], meta_info['dirpath'], "CTPelvic1K_dataset3_data/*.nii.gz"))
|
||||
names += glob.glob(os.path.join(meta_info['home_path'], meta_info['dirpath'], "CTPelvic1K_dataset4_data/*.nii.gz"))
|
||||
names += glob.glob(os.path.join(meta_info['home_path'], meta_info['dirpath'], "CTPelvic1K_dataset5_data/*.nii.gz"))
|
||||
names += glob.glob(os.path.join(meta_info['home_path'], meta_info['dirpath'], "CTPelvic1K_dataset7_data/*.nii.gz"))
|
||||
names.sort()
|
||||
# xx_data.nii.gz xx_mask_4label.nii.gz
|
||||
label_names = [p.replace("_data/", "_mask/").replace("_data.nii.gz", "_mask_4label.nii.gz") for p in names]
|
||||
organize_by_names(names, label_names, meta_info=meta_info)
|
||||
|
||||
def get_FLARE(save_path):
|
||||
meta_info = {
|
||||
"dataset_name": "FLARE",
|
||||
"dataset_id": "55",
|
||||
"modality": "CT",
|
||||
"home_path": HOME_PATH,
|
||||
"dirpath": "55_FLARE22Train",
|
||||
"save_txt_path": save_path,
|
||||
"class": ['liver', 'right kidney', 'spleen', 'pancrease', 'aorta','postcava','right adrenal gland','left darenal gland','gallbladder','esophagus','stomach','duodenum','left kidney'],
|
||||
}
|
||||
organize_in_nnunet_style(meta_info=meta_info)
|
||||
|
||||
# def get_HAN(save_path):
|
||||
# meta_info = {
|
||||
# "dataset_name": "Head-and-neck",
|
||||
# "dataset_id": "56",
|
||||
# "modality": "CT",
|
||||
# "home_path": HOME_PATH,
|
||||
# "dirpath": "56_Head-and-Neck-challenge",
|
||||
# "save_txt_path": save_path,
|
||||
# }
|
||||
# names = glob.glob(os.path.join(meta_info['home_path'], meta_info['dirpath'], "data/*.nii.gz"))
|
||||
# names.sort()
|
||||
# label_names = [p.replace("/data/", "/label/") for p in names]
|
||||
# organize_by_names(names, label_names, meta_info=meta_info)
|
||||
|
||||
# def get_StructSeg(save_path):
|
||||
# meta_info = {
|
||||
# "dataset_name": "StructSeg2019",
|
||||
# "dataset_id": "57",
|
||||
# "modality": "CT",
|
||||
# "home_path": HOME_PATH,
|
||||
# "dirpath": "57_StructSeg",
|
||||
# "save_txt_path": save_path,
|
||||
# }
|
||||
# names = glob.glob(os.path.join(meta_info['home_path'], meta_info['dirpath'], "HaN_OAR/data/*"))
|
||||
# names = [f"{name}/data.nii.gz" for name in names]
|
||||
# names.sort()
|
||||
# label_names = [p.replace("/data.nii.gz", "/label.nii.gz") for p in names]
|
||||
# organize_by_names(names, label_names, meta_info=meta_info)
|
||||
|
||||
def get_CHAOS(save_path):
|
||||
meta_info = {
|
||||
"dataset_name": "CHAOS",
|
||||
"dataset_id": "58",
|
||||
"modality": "MRI",
|
||||
"home_path": HOME_PATH,
|
||||
"dirpath": "58_CHAOST2/chaos_MR_T2_normalized/",
|
||||
"save_txt_path": save_path,
|
||||
"class": ["liver", "right kidney", "left kidney", "spleen"],
|
||||
}
|
||||
names = glob.glob(os.path.join(meta_info['home_path'], meta_info['dirpath'], "image*.nii.gz"))
|
||||
names.sort()
|
||||
label_names = [p.replace("/image_", "/label_") for p in names]
|
||||
organize_by_names(names, label_names, meta_info=meta_info)
|
||||
|
||||
def get_SABS(save_path):
|
||||
meta_info = {
|
||||
"dataset_name": "SABS", # BTCV ?
|
||||
"dataset_id": "59",
|
||||
"modality": "CT",
|
||||
"home_path": HOME_PATH,
|
||||
"dirpath": "59_SABS/sabs_CT_normalized/",
|
||||
"save_txt_path": save_path,
|
||||
"class": ["spleen", "right kidney", "left kidney", "gallbladder", "esophagus", "liver", "stomach", "aorta", "postcava", "portal vein and splenic vein", "pancrease", "right adrenal gland", "left adrenal gland"],
|
||||
}
|
||||
names = glob.glob(os.path.join(meta_info['home_path'], meta_info['dirpath'], "image_*.nii.gz"))
|
||||
names.sort()
|
||||
label_names = [p.replace("/image_", "/label_") for p in names]
|
||||
organize_by_names(names, label_names, meta_info=meta_info)
|
||||
|
||||
|
||||
def get_Totalseg(save_path):
|
||||
meta_info = {
|
||||
"dataset_name": "Totalseg",
|
||||
"dataset_id": "60",
|
||||
"modality": "CT",
|
||||
"home_path": HOME_PATH,
|
||||
# "dirpath": "nnUNet_raw/Dataset101_Totalseg",
|
||||
"dirpath": "60_Totalseg",
|
||||
"save_txt_path": save_path,
|
||||
"class": [],
|
||||
}
|
||||
organize_in_nnunet_style(meta_info=meta_info)
|
||||
|
||||
|
||||
def get_WORD(save_path):
|
||||
meta_info = {
|
||||
"dataset_name": "WORDs", # BTCV ?
|
||||
"dataset_id": "07",
|
||||
"modality": "CT",
|
||||
"home_path": HOME_PATH,
|
||||
"dirpath": "07_WORD/WORD-V0.1.0/",
|
||||
"save_txt_path": save_path,
|
||||
}
|
||||
organize_in_style2(meta_info=meta_info)
|
||||
|
||||
def generate_all():
|
||||
save_path="./datasets/dataset_list/all_train.txt"
|
||||
clear_files(save_path)
|
||||
get_BCV_Abdomen(save_path)
|
||||
get_AbdomenCT_1K(save_path)
|
||||
get_AMOS(save_path)
|
||||
get_MSD(save_path)
|
||||
# get_ASOCA()
|
||||
# get_BCV_Cervix()
|
||||
# # get_NIHPancrease() # bug in data ?
|
||||
# get_CTPelvic()
|
||||
# get_FLARE()
|
||||
# get_SABS()
|
||||
|
||||
def generate_their():
|
||||
save_path="./datasets/dataset_list/their_train.txt"
|
||||
clear_files(save_path)
|
||||
save_path="./datasets/dataset_list/their_train.txt"
|
||||
get_BCV_Abdomen(save_path)
|
||||
get_AbdomenCT_1K(save_path)
|
||||
get_AMOS(save_path)
|
||||
get_MSD(save_path)
|
||||
|
||||
def generate_ours():
|
||||
save_path="./datasets/dataset_list/ours_train.txt"
|
||||
get_ASOCA(save_path)
|
||||
get_BCV_Cervix(save_path)
|
||||
# get_NIHPancrease() # bug in data ?
|
||||
get_CTPelvic(save_path)
|
||||
get_FLARE(save_path)
|
||||
get_SABS(save_path)
|
||||
|
||||
def generate_alp_dataset():
|
||||
save_path = "./datasets/dataset_list/alp_train.txt"
|
||||
clear_files(save_path)
|
||||
get_SABS(save_path)
|
||||
get_CHAOS(save_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(__file__)
|
||||
# generate_alp_dataset()
|
||||
save_path ="./datasets/dataset_list/totalseg_train.txt"
|
||||
clear_files(save_path)
|
||||
get_Totalseg(save_path)
|
||||
|
||||
# save_path="./datasets/dataset_list/word_train.txt"
|
||||
print("Over")
|
408
datasets/predict_various_masks.py
Normal file
408
datasets/predict_various_masks.py
Normal file
@ -0,0 +1,408 @@
|
||||
# from utils.predict_automasks import predict
|
||||
from segment_anything import sam_model_registry, SamPredictor
|
||||
from core.automask import SamAutomaticMaskGenerator
|
||||
from core.trainer import SamLearner
|
||||
# from datasets.dataset2d import Dataset2D
|
||||
from einops import rearrange, repeat
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from skimage.segmentation import slic
|
||||
from skimage.util import img_as_float
|
||||
from skimage import io
|
||||
import os
|
||||
from tutils import timer, tdir
|
||||
from scipy.sparse import csr_matrix
|
||||
# from datasets.dataset2d import Dataset2D
|
||||
import cv2
|
||||
import torch.nn.functional as F
|
||||
import time
|
||||
from tutils import tfilename
|
||||
|
||||
import matplotlib.pylab as plt
|
||||
from .utils import load_compressed_data, show_anns, img_to_show
|
||||
import cv2
|
||||
|
||||
def tfunctime(func):
|
||||
def run(*argv, **kargs):
|
||||
t1 = time.time()
|
||||
ret = func(*argv, **kargs)
|
||||
t2 = time.time()
|
||||
print(f"[Function <{func.__name__}>] Running time:{(t2-t1):.6f}s")
|
||||
return ret
|
||||
return run
|
||||
|
||||
def get_predictors():
|
||||
sam_checkpoint = "/quanquan/code/segment-anything/segment_anything/sam_vit_h_4b8939.pth" # for A100
|
||||
# sam_checkpoint = "/home1/quanquan/code/projects/medical-guangdong/segment-anything/sam_vit_h_4b8939.pth" # for server 103
|
||||
device = "cuda"
|
||||
model_type = "default"
|
||||
|
||||
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
|
||||
sam.to(device=device)
|
||||
|
||||
mask_generator = SamAutomaticMaskGenerator(sam)
|
||||
predictor = SamLearner(sam_model=sam)
|
||||
return mask_generator, predictor
|
||||
|
||||
def center_clip(img, eta=3):
|
||||
count_values = img[torch.where(img>-199)]
|
||||
mean = count_values.mean()
|
||||
std = count_values.std()
|
||||
img3 = torch.clip(img, mean-eta*std, mean+eta*std)
|
||||
|
||||
return img3
|
||||
|
||||
|
||||
def find_not_exist_masks(masks_repo, masks_new, iou_threshold=0.2):
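# Returns the subset of masks_new whose IoU with every mask already in masks_repo does
# not exceed iou_threshold; merge_masks below uses this to drop near-duplicate masks
# when accumulating results from multiple generation passes.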
|
||||
if len(masks_repo) == 0:
|
||||
return masks_new
|
||||
def compute_iou(m1, m2):
|
||||
m1 = m1['segmentation']
|
||||
m2 = m2['segmentation']
|
||||
intersection = m1*m2
|
||||
union = np.float32((m1 + m2) > 0)
|
||||
return intersection.sum() / union.sum()
|
||||
to_append = []
|
||||
for mask_new in masks_new:
|
||||
assert isinstance(mask_new, dict), f"{__file__} Got{type(mask_new)}"
|
||||
intersec_count = 0
|
||||
for mask_in_repo in masks_repo:
|
||||
assert isinstance(mask_in_repo, dict), f"{__file__} Got{type(mask_in_repo)}"
|
||||
iou = compute_iou(mask_in_repo, mask_new)
|
||||
|
||||
if iou > iou_threshold:
|
||||
intersec_count += 1
|
||||
if intersec_count == 0:
|
||||
to_append.append(mask_new)
|
||||
check_keys(to_append)
|
||||
return to_append
|
||||
|
||||
|
||||
def merge_masks(masks_repo, masks_new, iou_threshold=0.2):
|
||||
to_append = find_not_exist_masks(masks_repo, masks_new, iou_threshold)
|
||||
# print(f"DEBUG: {len(masks_new) - len(to_append)} masks are deleted, remaining {len(to_append)}. The total achieves {len(masks_repo) + len(to_append)}")
|
||||
return masks_repo + to_append
|
||||
|
||||
|
||||
def dilate_erode(mask):
|
||||
kernel = np.ones((4, 4), dtype=np.uint8)
|
||||
mask = cv2.morphologyEx(mask.astype(float), cv2.MORPH_CLOSE, kernel, iterations=1)
|
||||
return mask
|
||||
|
||||
|
||||
def get_superpixel(image, hu_mask, hu_threshold=-50):
|
||||
if isinstance(image, torch.Tensor):
|
||||
image = image.detach().cpu().numpy()
|
||||
if isinstance(hu_mask, torch.Tensor):
|
||||
hu_mask = hu_mask.detach().cpu().numpy()
|
||||
segments = slic(image, n_segments=100, compactness=9)
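# Each SLIC superpixel whose mean intensity exceeds hu_threshold becomes a candidate
# pseudo-mask: it is restricted to the HU body mask and smoothed with a morphological
# close before being collected.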
|
||||
mask_collect = []
|
||||
image2 = image[:,:,0]
|
||||
for i in range(1, segments.max()+1):
|
||||
mask = torch.Tensor(segments==i).detach().cpu().numpy()
|
||||
# assert img
|
||||
if image2[segments==i].mean() <= hu_threshold:
|
||||
continue
|
||||
mask = mask * hu_mask
|
||||
mask = dilate_erode(mask)
|
||||
mask_data = {
|
||||
'segmentation':mask,
|
||||
'area':mask.sum(),
|
||||
'source': "superpixel",
|
||||
'mean_value':(mask * image[:,:,0]).mean(),
|
||||
}
|
||||
mask_collect.append(mask_data)
|
||||
check_keys(mask_collect)
|
||||
return mask_collect
|
||||
|
||||
# def resize_masks(masks):
|
||||
# for m in masks:
|
||||
# m['segmentation'] = F.interpolate(torch.Tensor(m['segmentation'])[None,None,:,:], size=(512,512)).squeeze().numpy()
|
||||
# return masks
|
||||
|
||||
# @tfunctime
|
||||
def get_masks_via_color_changing(img, label, mask_generator, predictor=None):
|
||||
img = repeat(img, " h w -> 1 3 h w")
|
||||
img = torch.Tensor(img)
|
||||
img = F.interpolate(img, size=(1024,1024))
|
||||
|
||||
label = torch.Tensor(label)
|
||||
if label is not None:
|
||||
label_masks = []
|
||||
for i in range(1,label.max().int()+1):
|
||||
labeli = torch.Tensor(label==i).float()
|
||||
labeli = F.interpolate(labeli[None, None, :,:], size=(1024,1024)).squeeze().numpy()
|
||||
area = labeli.sum()
|
||||
if area <=10:
|
||||
continue
|
||||
mask = {
|
||||
"segmentation":labeli,
|
||||
"area":area,
|
||||
"source": "gt",
|
||||
}
|
||||
label_masks.append(mask)
|
||||
else:
|
||||
label_masks = []
|
||||
|
||||
mask_generator.reset_image()
|
||||
predictor.reset_image()
|
||||
|
||||
masks = mask_generator.generate(img.cuda())
|
||||
masks = filter_large_masks(masks)
|
||||
for mask in masks:
|
||||
mask['source'] = "sam_auto_seg"
|
||||
label_masks = merge_masks(label_masks, masks)
|
||||
del masks
|
||||
|
||||
check_keys(label_masks)
|
||||
# import ipdb; ipdb.set_trace()
|
||||
# return img, label_masks
|
||||
|
||||
img2 = center_clip(img, 2.5)
masks = mask_generator.generate(img2.cuda())
masks = filter_large_masks(masks)
for mask in masks:
mask['source'] = "sam_auto_seg"
label_masks = merge_masks(label_masks, masks)
|
||||
del masks
|
||||
del img2
|
||||
|
||||
# img2 = center_clip(img, 2)
|
||||
# masks = mask_generator.generate(img2.cuda())
|
||||
# masks = filter_large_masks(masks)
|
||||
# for mask in masks:
|
||||
# mask['source'] = "sam_auto_seg"
|
||||
# label_masks = merge_masks(label_masks, masks)
|
||||
# del masks
|
||||
# del img2
|
||||
|
||||
img2 = center_clip(img, 1)
|
||||
masks = mask_generator.generate(img2.cuda())
|
||||
masks = filter_large_masks(masks)
|
||||
for mask in masks:
|
||||
mask['source'] = "sam_auto_seg"
|
||||
label_masks = merge_masks(label_masks, masks)
|
||||
del masks
|
||||
del img2
|
||||
|
||||
img2 = center_clip(img, 0.5)
|
||||
masks = mask_generator.generate(img2.cuda())
|
||||
masks = filter_large_masks(masks)
|
||||
for mask in masks:
|
||||
mask['source'] = "sam_auto_seg"
|
||||
label_masks = merge_masks(label_masks, masks)
|
||||
del masks
|
||||
del img2
|
||||
|
||||
check_keys(label_masks)
|
||||
# import ipdb; ipdb.set_trace()
|
||||
return img, label_masks
|
||||
|
||||
def check_keys(masks):
|
||||
for m in masks:
|
||||
assert m['segmentation'] is not None
|
||||
assert m['area'] is not None
|
||||
assert m['source'] is not None
|
||||
|
||||
def filter_large_masks(masks):
|
||||
filtered_masks = []
|
||||
for mask in masks:
|
||||
if mask['area'] > 0.25 *1024 * 1024:
|
||||
continue
|
||||
filtered_masks.append(mask)
|
||||
del masks
|
||||
return filtered_masks
|
||||
|
||||
|
||||
def mix_masks(masks):
|
||||
if len(masks) == 0:
|
||||
return None
|
||||
mixed_mask = None
|
||||
for item in masks:
|
||||
m = item['segmentation']
|
||||
mixed_mask = np.zeros_like(m) if mixed_mask is None else mixed_mask
|
||||
mixed_mask += m
|
||||
mixed_mask = np.float32(mixed_mask>0)
|
||||
return mixed_mask
|
||||
|
||||
|
||||
def select_random_point_from_mask(gt_mask):
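# Samples one foreground pixel uniformly at random by treating the binary mask as a
# probability map over flattened pixel indices, then converting the flat index back to
# (x, y) coordinates.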
|
||||
size = gt_mask.shape
|
||||
assert len(size) == 2
|
||||
xy = np.arange(0, size[0] * size[1])
|
||||
gt_mask = np.float32(gt_mask>0)
|
||||
prob = rearrange(gt_mask, "h w -> (h w)")
|
||||
prob = prob / prob.sum()
|
||||
loc = np.random.choice(a=xy, size=1, replace=True, p=prob)[0]
|
||||
x, y = loc % size[1], loc // size[1]
|
||||
return x, y
|
||||
|
||||
|
||||
def select_center_point_from_mask(gt_mask):
|
||||
# get indices of all the foreground pixels
|
||||
indices = np.argwhere(gt_mask > 0)
|
||||
# calculate the center point by taking the mean of the foreground pixel indices
|
||||
center = np.mean(indices, axis=0).astype(int)
|
||||
y, x = center
|
||||
return x, y
|
||||
|
||||
@tfunctime
|
||||
def get_masks_via_points_from_superpixels(img, label_masks, superpixels, predictor, hu_mask=None):
|
||||
if isinstance(hu_mask, torch.Tensor):
|
||||
hu_mask = hu_mask.detach().cpu().numpy()
|
||||
total_mask = mix_masks(label_masks)
|
||||
# superpixels = get_superpixel(rearrange(img, "1 c h w -> h w c"))
|
||||
points = []
|
||||
ex_masks = []
|
||||
for seg in superpixels:
|
||||
mask_collect = []
|
||||
for i in range(5):
|
||||
mm = np.nan_to_num(seg['segmentation'], 0)
|
||||
if mm.sum() <= 0:
|
||||
continue
|
||||
x,y = select_center_point_from_mask(mm) # select_random_point_from_mask
|
||||
# print(x,y)
|
||||
if total_mask[y,x] == 1:
|
||||
continue
|
||||
# else:
|
||||
# print(total_mask[y,x])
|
||||
point = torch.Tensor([x,y])
|
||||
points.append(point)
|
||||
mask = predictor.generate(img.cuda(), point.cuda().unsqueeze(0).unsqueeze(0))[0,0].detach().cpu().numpy()
|
||||
mask = mask * hu_mask
|
||||
mask = dilate_erode(mask)
|
||||
mask = {
|
||||
"segmentation": mask,
|
||||
"area": mask.sum(),
|
||||
"source": "prompt_point_from_superpixel",
|
||||
}
|
||||
mask_collect = merge_masks(mask_collect, [mask])
|
||||
ex_masks += mask_collect
|
||||
check_keys(ex_masks)
|
||||
return ex_masks
|
||||
|
||||
def mask_to_bbox(mask):
|
||||
""" copied from data_engine """
|
||||
# Find the indices of all non-zero elements in the mask
|
||||
coords = np.nonzero(mask)
|
||||
|
||||
# Compute the minimum and maximum values of the row and column indices
|
||||
x_min = np.min(coords[1])
|
||||
y_min = np.min(coords[0])
|
||||
x_max = np.max(coords[1])
|
||||
y_max = np.max(coords[0])
|
||||
|
||||
# Return the coordinates of the bounding box
|
||||
return (x_min, y_min, x_max, y_max)
|
||||
|
||||
@tfunctime
|
||||
def get_masks_via_boxes_from_superpixels(img, label_masks, superpixels, predictor, hu_mask=None):
|
||||
if isinstance(hu_mask, torch.Tensor):
|
||||
hu_mask = hu_mask.detach().cpu().numpy()
|
||||
ex_masks = []
|
||||
for seg in superpixels:
|
||||
if seg['segmentation'].sum() < 100:
|
||||
continue
|
||||
x_min, y_min, x_max, y_max = mask_to_bbox(seg['segmentation'])
|
||||
box = torch.Tensor([x_min, y_min, x_max, y_max])
|
||||
mask = predictor.generate_by_box(img.cuda(), box.cuda().unsqueeze(0).unsqueeze(0))[0,0].detach().cpu().numpy()
|
||||
mask = mask * hu_mask
|
||||
mask = dilate_erode(mask)
|
||||
mask = {
|
||||
"segmentation": mask,
|
||||
"area": mask.sum(),
|
||||
"source": "prompt_box_from_superpixel",
|
||||
}
|
||||
ex_masks = merge_masks(ex_masks, [mask])
|
||||
check_keys(ex_masks)
|
||||
return ex_masks
|
||||
|
||||
|
||||
class VariousPseudoMasksGenerator:
|
||||
def __init__(self, dataset, label_path:str=None) -> None:
|
||||
self.dataset = dataset
|
||||
# assert dirpath.split("/")[-1] == "cache_2d", f"{__file__} Got{dirpath.split('/')[-1]} ; dirpath:{dirpath}"
|
||||
self.label_path = label_path if label_path is not None else dataset.dirpath.replace("cache_2d", "cache_2d_various_pseudo_masks")
|
||||
|
||||
def example(self, mask_generator=None, predictor=None):
|
||||
return self.generate(mask_generator=mask_generator, predictor=predictor, is_example=True)
|
||||
|
||||
def generate(self, mask_generator=None, predictor=None, is_example=False):
|
||||
tt = timer()
|
||||
if mask_generator is None:
|
||||
mask_generator, predictor = get_predictors()
|
||||
self.mask_generator, self.predictor = mask_generator, predictor
|
||||
for img_idx in range(len(self.dataset)):
|
||||
tt()
|
||||
data = self.dataset._get_image(img_idx)
|
||||
masks_all_layers = []
|
||||
words = data['name'].split("/")
|
||||
dataset_name = words[0]
|
||||
filename = words[-1].replace(".nii.gz","")
|
||||
volume_rgb = np.clip(data['img'], -200, 400)
|
||||
volume_rgb = (volume_rgb - volume_rgb.min()) / (volume_rgb.max() - volume_rgb.min())
|
||||
for slice_idx in range(data['img'].shape[0]):
|
||||
path = tfilename(self.label_path, f'{dataset_name}/{filename}_s{slice_idx}_mask.npz')
|
||||
# if os.path.exists(path):
|
||||
# continue
|
||||
masks = self.get_various_masks(data['img'][slice_idx], data['label'][slice_idx], mask_generator, predictor)
|
||||
self.save_slice_mask(masks, path, slice_idx)
|
||||
img_path = tfilename(self.label_path, f'image_{dataset_name}/{filename}_s{slice_idx}.jpg')
|
||||
self.save_img_rgb(volume_rgb[slice_idx], img_path)
|
||||
# display
|
||||
# plt.figure(figsize=(6,6))
|
||||
# plt.imshow(img_to_show(data['img'][slice_idx]), cmap='gray')
|
||||
# show_anns(masks)
|
||||
# plt.axis('off')
|
||||
# plt.show()
|
||||
print(f"Save to {img_path} and {path}")
|
||||
|
||||
print(f"Processing {img_idx}, {len(masks_all_layers)} saved, time used:{tt()}", end='\r')
|
||||
if is_example:
|
||||
break
|
||||
|
||||
def save_slice_mask(self, masks, path, slice_idx):
|
||||
masks_data = np.array([m['segmentation'] for m in masks]).astype(int)
|
||||
# if len(masks_data) <= 1:
|
||||
# import ipdb; ipdb.set_trace()
|
||||
masks_data = F.interpolate(torch.Tensor(masks_data).unsqueeze(1), size=(512,512)).squeeze().numpy()
|
||||
masks_data = np.int8(masks_data>0)
|
||||
if len(masks_data.shape) == 2:
|
||||
masks_data = masks_data[None,:,:]
|
||||
assert masks_data.shape[1:] == (512,512), f"{__file__} Got{masks_data.shape}"
|
||||
masks_data = rearrange(masks_data, "n h w -> n (h w)")
|
||||
csr = csr_matrix(masks_data)
|
||||
np.savez_compressed(path, data=csr.data, indices=csr.indices, indptr=csr.indptr, shape=csr.shape)
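# Masks are downsampled to 512x512, binarized, flattened per mask and written as raw CSR
# components; loaders elsewhere in this repo rebuild them with
# csr_matrix((data, indices, indptr), shape=shape).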
|
||||
|
||||
def save_img_rgb(self, img, path):
|
||||
img = (img * 255).astype(np.uint8)
|
||||
cv2.imwrite(path, img)
|
||||
|
||||
|
||||
@staticmethod
|
||||
@tfunctime
|
||||
def get_various_masks(img_ori, label, mask_generator, predictor):
|
||||
mask_generator.reset_image()
|
||||
predictor.reset_image()
|
||||
img, label_masks = get_masks_via_color_changing(img_ori, label, mask_generator, predictor)
|
||||
|
||||
# # return label_masks
|
||||
# hu_mask = img_ori>img_ori.min() + 10
|
||||
# hu_mask = F.interpolate(torch.Tensor(hu_mask)[None,None,:,:], size=(1024,1024)).squeeze()
|
||||
|
||||
# superpixels = get_superpixel(rearrange(img, "1 c h w -> h w c"), hu_mask=hu_mask)
|
||||
# exclu_superpixels = find_not_exist_masks(label_masks, superpixels)
|
||||
# # point_prompt_masks = get_masks_via_boxes_from_superpixels(img, label_masks, exclu_superpixels, predictor, hu_mask=hu_mask)
|
||||
# box_prompt_masks = get_masks_via_points_from_superpixels(img, label_masks, exclu_superpixels, predictor, hu_mask=hu_mask)
|
||||
|
||||
return label_masks # + box_prompt_masks # + point_prompt_masks #
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# CUDA_VISIBLE_DEVICES=6 python -m datasets.predict_various_masks
|
||||
from datasets.dataset3d import Dataset3D
|
||||
dataset = Dataset3D()
|
||||
gen = VariousPseudoMasksGenerator(dataset=dataset,
|
||||
label_path=tdir("/quanquan/datasets/all_datasets/various_masks_3/"))
|
||||
gen.generate()
|
||||
|
||||
|
0
modeling/__init__.py
Normal file
0
modeling/__init__.py
Normal file
BIN
modeling/__pycache__/__init__.cpython-38.pyc
Normal file
BIN
modeling/__pycache__/__init__.cpython-38.pyc
Normal file
Binary file not shown.
BIN
modeling/__pycache__/build_sam3d2.cpython-38.pyc
Normal file
BIN
modeling/__pycache__/build_sam3d2.cpython-38.pyc
Normal file
Binary file not shown.
BIN
modeling/__pycache__/common.cpython-38.pyc
Normal file
BIN
modeling/__pycache__/common.cpython-38.pyc
Normal file
Binary file not shown.
BIN
modeling/__pycache__/image_encoder.cpython-38.pyc
Normal file
BIN
modeling/__pycache__/image_encoder.cpython-38.pyc
Normal file
Binary file not shown.
BIN
modeling/__pycache__/mask_decoder3d_2.cpython-38.pyc
Normal file
BIN
modeling/__pycache__/mask_decoder3d_2.cpython-38.pyc
Normal file
Binary file not shown.
BIN
modeling/__pycache__/prompt_encoder3d.cpython-38.pyc
Normal file
BIN
modeling/__pycache__/prompt_encoder3d.cpython-38.pyc
Normal file
Binary file not shown.
BIN
modeling/__pycache__/sam3d.cpython-38.pyc
Normal file
BIN
modeling/__pycache__/sam3d.cpython-38.pyc
Normal file
Binary file not shown.
BIN
modeling/__pycache__/transformer.cpython-38.pyc
Normal file
BIN
modeling/__pycache__/transformer.cpython-38.pyc
Normal file
Binary file not shown.
119
modeling/build_sam3d2.py
Normal file
119
modeling/build_sam3d2.py
Normal file
@ -0,0 +1,119 @@
|
||||
"""
|
||||
Differences with build_sam3d.py
|
||||
use mask_decoder3d_2 instead of mask_decoder3d
|
||||
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from functools import partial
|
||||
|
||||
# from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
|
||||
from .image_encoder import ImageEncoderViT
|
||||
from .mask_decoder3d_2 import MaskDecoder
|
||||
from .prompt_encoder3d import PromptEncoder
|
||||
from .sam3d import Sam
|
||||
from .transformer import TwoWayTransformer
|
||||
|
||||
|
||||
def build_sam_vit_h(checkpoint=None):
|
||||
return _build_sam(
|
||||
encoder_embed_dim=1280,
|
||||
encoder_depth=32,
|
||||
encoder_num_heads=16,
|
||||
encoder_global_attn_indexes=[7, 15, 23, 31],
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
|
||||
build_sam = build_sam_vit_h
|
||||
|
||||
|
||||
def build_sam_vit_l(checkpoint=None):
|
||||
return _build_sam(
|
||||
encoder_embed_dim=1024,
|
||||
encoder_depth=24,
|
||||
encoder_num_heads=16,
|
||||
encoder_global_attn_indexes=[5, 11, 17, 23],
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
|
||||
def build_sam_vit_b(checkpoint=None):
|
||||
return _build_sam(
|
||||
encoder_embed_dim=768,
|
||||
encoder_depth=12,
|
||||
encoder_num_heads=12,
|
||||
encoder_global_attn_indexes=[2, 5, 8, 11],
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
|
||||
sam_model_registry = {
|
||||
"default": build_sam_vit_h,
|
||||
"vit_h": build_sam_vit_h,
|
||||
"vit_l": build_sam_vit_l,
|
||||
"vit_b": build_sam_vit_b,
|
||||
}
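# Minimal usage sketch (editor's addition, not part of the original file): pick the
# variant named in the config, e.g. "vit_b" as in configs/vit_b.yaml, and optionally
# pass a checkpoint path; left commented out to avoid import-time side effects.
# sam = sam_model_registry["vit_b"](checkpoint=None)
# sam = sam.cuda().eval()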
|
||||
|
||||
|
||||
def _build_sam(
|
||||
encoder_embed_dim,
|
||||
encoder_depth,
|
||||
encoder_num_heads,
|
||||
encoder_global_attn_indexes,
|
||||
checkpoint=None,
|
||||
):
|
||||
prompt_embed_dim = 256
|
||||
image_size = 1024
|
||||
vit_patch_size = 16
|
||||
image_embedding_size = image_size // vit_patch_size
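# With the default 1024 input and 16-pixel patches this is 64, i.e. prompts are encoded
# against a 64x64 image-embedding grid.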
|
||||
sam = Sam(
|
||||
image_encoder=ImageEncoderViT(
|
||||
depth=encoder_depth,
|
||||
embed_dim=encoder_embed_dim,
|
||||
img_size=image_size,
|
||||
mlp_ratio=4,
|
||||
norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
|
||||
num_heads=encoder_num_heads,
|
||||
patch_size=vit_patch_size,
|
||||
qkv_bias=True,
|
||||
use_rel_pos=True,
|
||||
global_attn_indexes=encoder_global_attn_indexes,
|
||||
window_size=14,
|
||||
out_chans=prompt_embed_dim,
|
||||
),
|
||||
prompt_encoder=PromptEncoder(
|
||||
embed_dim=prompt_embed_dim,
|
||||
image_embedding_size=(image_embedding_size, image_embedding_size),
|
||||
input_image_size=(image_size, image_size),
|
||||
mask_in_chans=16,
|
||||
),
|
||||
mask_decoder=MaskDecoder(
|
||||
num_multimask_outputs=3,
|
||||
transformer=TwoWayTransformer(
|
||||
depth=2,
|
||||
embedding_dim=prompt_embed_dim,
|
||||
mlp_dim=2048,
|
||||
num_heads=8,
|
||||
),
|
||||
transformer_dim=prompt_embed_dim,
|
||||
iou_head_depth=3,
|
||||
iou_head_hidden_dim=256,
|
||||
),
|
||||
pixel_mean=[123.675, 116.28, 103.53],
|
||||
pixel_std=[58.395, 57.12, 57.375],
|
||||
)
|
||||
sam.eval()
|
||||
if checkpoint is not None:
|
||||
with open(checkpoint, "rb") as f:
|
||||
state_dict = torch.load(f)
|
||||
sam.load_state_dict(state_dict)
|
||||
return sam
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = build_sam_vit_l()
|
||||
data = torch.ones((1,3,1024,1024))
|
||||
out = model(data, multimask_output=True)
|
||||
print(out.shape)
|
43
modeling/common.py
Normal file
43
modeling/common.py
Normal file
@ -0,0 +1,43 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from typing import Type
|
||||
|
||||
|
||||
class MLPBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
mlp_dim: int,
|
||||
act: Type[nn.Module] = nn.GELU,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.lin1 = nn.Linear(embedding_dim, mlp_dim)
|
||||
self.lin2 = nn.Linear(mlp_dim, embedding_dim)
|
||||
self.act = act()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.lin2(self.act(self.lin1(x)))
|
||||
|
||||
|
||||
# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
|
||||
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
|
||||
class LayerNorm2d(nn.Module):
|
||||
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(num_channels))
|
||||
self.bias = nn.Parameter(torch.zeros(num_channels))
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
u = x.mean(1, keepdim=True)
|
||||
s = (x - u).pow(2).mean(1, keepdim=True)
|
||||
x = (x - u) / torch.sqrt(s + self.eps)
|
||||
x = self.weight[:, None, None] * x + self.bias[:, None, None]
|
||||
return x
|
402
modeling/image_encoder.py
Normal file
402
modeling/image_encoder.py
Normal file
@ -0,0 +1,402 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from typing import Optional, Tuple, Type
|
||||
|
||||
from .common import LayerNorm2d, MLPBlock
|
||||
|
||||
|
||||
# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
|
||||
class ImageEncoderViT(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
img_size: int = 1024,
|
||||
patch_size: int = 16,
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 768,
|
||||
depth: int = 12,
|
||||
num_heads: int = 12,
|
||||
mlp_ratio: float = 4.0,
|
||||
out_chans: int = 256,
|
||||
qkv_bias: bool = True,
|
||||
norm_layer: Type[nn.Module] = nn.LayerNorm,
|
||||
act_layer: Type[nn.Module] = nn.GELU,
|
||||
use_abs_pos: bool = True,
|
||||
use_rel_pos: bool = False,
|
||||
rel_pos_zero_init: bool = True,
|
||||
window_size: int = 0,
|
||||
global_attn_indexes: Tuple[int, ...] = (),
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
img_size (int): Input image size.
|
||||
patch_size (int): Patch size.
|
||||
in_chans (int): Number of input image channels.
|
||||
embed_dim (int): Patch embedding dimension.
|
||||
depth (int): Depth of ViT.
|
||||
num_heads (int): Number of attention heads in each ViT block.
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
||||
qkv_bias (bool): If True, add a learnable bias to query, key, value.
|
||||
norm_layer (nn.Module): Normalization layer.
|
||||
act_layer (nn.Module): Activation layer.
|
||||
use_abs_pos (bool): If True, use absolute positional embeddings.
|
||||
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
|
||||
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
|
||||
window_size (int): Window size for window attention blocks.
|
||||
global_attn_indexes (list): Indexes for blocks using global attention.
|
||||
"""
|
||||
super().__init__()
|
||||
self.img_size = img_size
|
||||
|
||||
self.patch_embed = PatchEmbed(
|
||||
kernel_size=(patch_size, patch_size),
|
||||
stride=(patch_size, patch_size),
|
||||
in_chans=in_chans,
|
||||
embed_dim=embed_dim,
|
||||
)
|
||||
|
||||
self.pos_embed: Optional[nn.Parameter] = None
|
||||
if use_abs_pos:
|
||||
# Initialize absolute positional embedding with pretrain image size.
|
||||
self.pos_embed = nn.Parameter(
|
||||
torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
|
||||
)
|
||||
|
||||
self.blocks = nn.ModuleList()
|
||||
for i in range(depth):
|
||||
block = Block(
|
||||
dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
norm_layer=norm_layer,
|
||||
act_layer=act_layer,
|
||||
use_rel_pos=use_rel_pos,
|
||||
rel_pos_zero_init=rel_pos_zero_init,
|
||||
window_size=window_size if i not in global_attn_indexes else 0,
|
||||
input_size=(img_size // patch_size, img_size // patch_size),
|
||||
)
|
||||
self.blocks.append(block)
|
||||
|
||||
self.neck = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
embed_dim,
|
||||
out_chans,
|
||||
kernel_size=1,
|
||||
bias=False,
|
||||
),
|
||||
LayerNorm2d(out_chans),
|
||||
nn.Conv2d(
|
||||
out_chans,
|
||||
out_chans,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
bias=False,
|
||||
),
|
||||
LayerNorm2d(out_chans),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.patch_embed(x)
|
||||
if self.pos_embed is not None:
|
||||
x = x + self.pos_embed
|
||||
for blk in self.blocks:
|
||||
x = blk(x)
|
||||
|
||||
x = self.neck(x.permute(0, 3, 1, 2))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
"""Transformer blocks with support of window attention and residual propagation blocks"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
qkv_bias: bool = True,
|
||||
norm_layer: Type[nn.Module] = nn.LayerNorm,
|
||||
act_layer: Type[nn.Module] = nn.GELU,
|
||||
use_rel_pos: bool = False,
|
||||
rel_pos_zero_init: bool = True,
|
||||
window_size: int = 0,
|
||||
input_size: Optional[Tuple[int, int]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
num_heads (int): Number of attention heads in each ViT block.
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
||||
qkv_bias (bool): If True, add a learnable bias to query, key, value.
|
||||
norm_layer (nn.Module): Normalization layer.
|
||||
act_layer (nn.Module): Activation layer.
|
||||
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
|
||||
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
|
||||
window_size (int): Window size for window attention blocks. If it equals 0, then
|
||||
use global attention.
|
||||
input_size (tuple(int, int) or None): Input resolution for calculating the relative
|
||||
positional parameter size.
|
||||
"""
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.attn = Attention(
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
use_rel_pos=use_rel_pos,
|
||||
rel_pos_zero_init=rel_pos_zero_init,
|
||||
input_size=input_size if window_size == 0 else (window_size, window_size),
|
||||
)
|
||||
|
||||
self.norm2 = norm_layer(dim)
|
||||
self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
|
||||
|
||||
self.window_size = window_size
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
shortcut = x
|
||||
x = self.norm1(x)
|
||||
# Window partition
|
||||
if self.window_size > 0:
|
||||
H, W = x.shape[1], x.shape[2]
|
||||
x, pad_hw = window_partition(x, self.window_size)
|
||||
|
||||
x = self.attn(x)
|
||||
# Reverse window partition
|
||||
if self.window_size > 0:
|
||||
x = window_unpartition(x, self.window_size, pad_hw, (H, W))
|
||||
|
||||
x = shortcut + x
|
||||
x = x + self.mlp(self.norm2(x))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
"""Multi-head Attention block with relative position embeddings."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = True,
|
||||
use_rel_pos: bool = False,
|
||||
rel_pos_zero_init: bool = True,
|
||||
input_size: Optional[Tuple[int, int]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
num_heads (int): Number of attention heads.
|
||||
qkv_bias (bool): If True, add a learnable bias to query, key, value.
|
||||
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
|
||||
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
|
||||
input_size (tuple(int, int) or None): Input resolution for calculating the relative
|
||||
positional parameter size.
|
||||
"""
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
|
||||
self.use_rel_pos = use_rel_pos
|
||||
if self.use_rel_pos:
|
||||
assert (
|
||||
input_size is not None
|
||||
), "Input size must be provided if using relative positional encoding."
|
||||
# initialize relative positional embeddings
|
||||
self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
|
||||
self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
B, H, W, _ = x.shape
|
||||
# qkv with shape (3, B, nHead, H * W, C)
|
||||
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
# q, k, v with shape (B * nHead, H * W, C)
|
||||
q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
|
||||
|
||||
attn = (q * self.scale) @ k.transpose(-2, -1)
|
||||
|
||||
if self.use_rel_pos:
|
||||
attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
|
||||
|
||||
attn = attn.softmax(dim=-1)
|
||||
x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
|
||||
x = self.proj(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
|
||||
"""
|
||||
Partition into non-overlapping windows with padding if needed.
|
||||
Args:
|
||||
x (tensor): input tokens with [B, H, W, C].
|
||||
window_size (int): window size.
|
||||
|
||||
Returns:
|
||||
windows: windows after partition with [B * num_windows, window_size, window_size, C].
|
||||
(Hp, Wp): padded height and width before partition
|
||||
"""
|
||||
B, H, W, C = x.shape
|
||||
|
||||
pad_h = (window_size - H % window_size) % window_size
|
||||
pad_w = (window_size - W % window_size) % window_size
|
||||
if pad_h > 0 or pad_w > 0:
|
||||
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
|
||||
Hp, Wp = H + pad_h, W + pad_w
|
||||
|
||||
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
|
||||
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
||||
return windows, (Hp, Wp)
|
||||
|
||||
|
||||
def window_unpartition(
|
||||
windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Window unpartition into original sequences and remove padding.
|
||||
Args:
|
||||
windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
|
||||
window_size (int): window size.
|
||||
pad_hw (Tuple): padded height and width (Hp, Wp).
|
||||
hw (Tuple): original height and width (H, W) before padding.
|
||||
|
||||
Returns:
|
||||
x: unpartitioned sequences with [B, H, W, C].
|
||||
"""
|
||||
Hp, Wp = pad_hw
|
||||
H, W = hw
|
||||
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
|
||||
x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
|
||||
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
|
||||
|
||||
if Hp > H or Wp > W:
|
||||
x = x[:, :H, :W, :].contiguous()
|
||||
return x
|
||||
|
||||
|
||||
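# Worked example for window_partition / window_unpartition (an illustrative note, not part
# of the original commit; it assumes the usual SAM ViT-B setting window_size=14 and a
# 64x64 token map from a 1024-pixel input with 16-pixel patches): 64 % 14 = 8, so the map
# is padded by 6 tokens to 70x70 and split into 5 * 5 = 25 windows of 14x14 tokens per
# image. window_unpartition reverses the reshape and crops the padding off again to
# recover the original 64x64 layout.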
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Get relative positional embeddings according to the relative positions of
|
||||
query and key sizes.
|
||||
Args:
|
||||
q_size (int): size of query q.
|
||||
k_size (int): size of key k.
|
||||
rel_pos (Tensor): relative position embeddings (L, C).
|
||||
|
||||
Returns:
|
||||
Extracted positional embeddings according to relative positions.
|
||||
"""
|
||||
max_rel_dist = int(2 * max(q_size, k_size) - 1)
|
||||
# Interpolate rel pos if needed.
|
||||
if rel_pos.shape[0] != max_rel_dist:
|
||||
# Interpolate rel pos.
|
||||
rel_pos_resized = F.interpolate(
|
||||
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
|
||||
size=max_rel_dist,
|
||||
mode="linear",
|
||||
)
|
||||
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
|
||||
else:
|
||||
rel_pos_resized = rel_pos
|
||||
|
||||
# Scale the coords with short length if shapes for q and k are different.
|
||||
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
|
||||
k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
|
||||
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
|
||||
|
||||
return rel_pos_resized[relative_coords.long()]
|
||||
|
||||
|
||||
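# Indexing note for get_rel_pos (illustrative, not part of the original commit): for
# windowed attention with q_size = k_size = 14, max_rel_dist = 2 * 14 - 1 = 27, which
# matches the (2 * input_size - 1, head_dim) tables created in Attention, so no
# interpolation is needed; relative_coords = (q - k) + 13 then indexes rows 0..26.
# Interpolation only kicks in when a pretrained table of a different length is loaded.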
def add_decomposed_rel_pos(
|
||||
attn: torch.Tensor,
|
||||
q: torch.Tensor,
|
||||
rel_pos_h: torch.Tensor,
|
||||
rel_pos_w: torch.Tensor,
|
||||
q_size: Tuple[int, int],
|
||||
k_size: Tuple[int, int],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
|
||||
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
|
||||
Args:
|
||||
attn (Tensor): attention map.
|
||||
q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
|
||||
rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
|
||||
rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
|
||||
q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
|
||||
k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
|
||||
|
||||
Returns:
|
||||
attn (Tensor): attention map with added relative positional embeddings.
|
||||
"""
|
||||
q_h, q_w = q_size
|
||||
k_h, k_w = k_size
|
||||
Rh = get_rel_pos(q_h, k_h, rel_pos_h)
|
||||
Rw = get_rel_pos(q_w, k_w, rel_pos_w)
|
||||
|
||||
B, _, dim = q.shape
|
||||
r_q = q.reshape(B, q_h, q_w, dim)
|
||||
rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
|
||||
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
|
||||
|
||||
attn = (
|
||||
attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
|
||||
).view(B, q_h * q_w, k_h * k_w)
|
||||
|
||||
return attn
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""
|
||||
Image to Patch Embedding.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kernel_size: Tuple[int, int] = (16, 16),
|
||||
stride: Tuple[int, int] = (16, 16),
|
||||
padding: Tuple[int, int] = (0, 0),
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 768,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
kernel_size (Tuple): kernel size of the projection layer.
|
||||
stride (Tuple): stride of the projection layer.
|
||||
padding (Tuple): padding size of the projection layer.
|
||||
in_chans (int): Number of input image channels.
|
||||
embed_dim (int): Patch embedding dimension.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.proj = nn.Conv2d(
|
||||
in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.proj(x)
|
||||
# B C H W -> B H W C
|
||||
x = x.permute(0, 2, 3, 1)
|
||||
return x
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = ImageEncoderViT()
|
||||
data = torch.ones(1,3,1024,1024)
|
||||
print(model)
|
||||
model(data)
|
||||
|
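# Shape note for the demo above (illustrative; it assumes the default SAM ViT-B settings
# img_size=1024, patch_size=16, out_chans=256): the 1x3x1024x1024 input becomes a
# 1x256x64x64 image embedding, since 1024 / 16 = 64 tokens per side after PatchEmbed and
# the neck projects the embed_dim channels down to 256.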
204
modeling/mask_decoder3d_2.py
Normal file
@ -0,0 +1,204 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from einops import rearrange
|
||||
from typing import List, Tuple, Type
|
||||
|
||||
from .common import LayerNorm2d
|
||||
|
||||
|
||||
class MaskDecoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
transformer_dim: int,
|
||||
transformer: nn.Module,
|
||||
num_multimask_outputs: int = 3,
|
||||
activation: Type[nn.Module] = nn.GELU,
|
||||
iou_head_depth: int = 3,
|
||||
iou_head_hidden_dim: int = 256,
|
||||
num_slices: int = 3,
|
||||
) -> None:
|
||||
"""
|
||||
Predicts masks given an image and prompt embeddings, using a
|
||||
transformer architecture.
|
||||
|
||||
Arguments:
|
||||
transformer_dim (int): the channel dimension of the transformer
|
||||
transformer (nn.Module): the transformer used to predict masks
|
||||
num_multimask_outputs (int): the number of masks to predict
|
||||
when disambiguating masks
|
||||
activation (nn.Module): the type of activation to use when
|
||||
upscaling masks
|
||||
iou_head_depth (int): the depth of the MLP used to predict
|
||||
mask quality
|
||||
iou_head_hidden_dim (int): the hidden dimension of the MLP
|
||||
used to predict mask quality
|
||||
"""
|
||||
super().__init__()
|
||||
self.transformer_dim = transformer_dim
|
||||
self.transformer = transformer
|
||||
|
||||
self.num_multimask_outputs = num_multimask_outputs
|
||||
|
||||
self.iou_token = nn.Embedding(1, transformer_dim)
|
||||
self.num_mask_tokens = num_multimask_outputs + 1
|
||||
self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
|
||||
|
||||
self.output_upscaling = nn.Sequential(
|
||||
nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
|
||||
LayerNorm2d(transformer_dim // 4),
|
||||
activation(),
|
||||
nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
|
||||
activation(),
|
||||
)
|
||||
self.output_upscaling2 = nn.Sequential(
|
||||
nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
|
||||
LayerNorm2d(transformer_dim // 4),
|
||||
activation(),
|
||||
nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
|
||||
activation(),
|
||||
)
|
||||
self.output_upscaling3 = nn.Sequential(
|
||||
nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
|
||||
LayerNorm2d(transformer_dim // 4),
|
||||
activation(),
|
||||
nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
|
||||
activation(),
|
||||
)
|
||||
self.output_hypernetworks_mlps = nn.ModuleList(
|
||||
[
|
||||
MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
|
||||
for i in range(self.num_mask_tokens)
|
||||
]
|
||||
)
|
||||
|
||||
self.iou_prediction_head = MLP(
|
||||
transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
image_embeddings: torch.Tensor,
|
||||
image_pe: torch.Tensor,
|
||||
sparse_prompt_embeddings: torch.Tensor,
|
||||
dense_prompt_embeddings: torch.Tensor,
|
||||
multimask_output: bool,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Predict masks given image and prompt embeddings.
|
||||
|
||||
Arguments:
|
||||
image_embeddings (torch.Tensor): the embeddings from the image encoder
|
||||
image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
|
||||
sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
|
||||
dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
|
||||
multimask_output (bool): Whether to return multiple masks or a single
|
||||
mask.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: batched predicted masks
|
||||
torch.Tensor: batched predictions of mask quality
|
||||
"""
|
||||
masks, iou_pred = self.predict_masks(
|
||||
image_embeddings=image_embeddings,
|
||||
image_pe=image_pe,
|
||||
sparse_prompt_embeddings=sparse_prompt_embeddings,
|
||||
dense_prompt_embeddings=dense_prompt_embeddings,
|
||||
)
|
||||
|
||||
# Select the correct mask or masks for output
|
||||
if multimask_output:
|
||||
# mask_slice = slice(1, None)
|
||||
masks = masks[:, 3:, :, :]
|
||||
iou_pred = iou_pred[:, 1:]
|
||||
else:
|
||||
# mask_slice = slice(0, 1)
|
||||
masks = masks[:, :3, :, :]
|
||||
iou_pred = iou_pred[:, 0:1]
|
||||
|
||||
# Prepare output
|
||||
return masks, iou_pred
|
||||
|
||||
def predict_masks(
|
||||
self,
|
||||
image_embeddings: torch.Tensor,
|
||||
image_pe: torch.Tensor,
|
||||
sparse_prompt_embeddings: torch.Tensor,
|
||||
dense_prompt_embeddings: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Predicts masks. See 'forward' for more details."""
|
||||
# Concatenate output tokens
|
||||
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) # (1,5,256)
|
||||
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
|
||||
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) # (1,7,256)
|
||||
|
||||
# Expand per-image data in batch direction to be per-mask
|
||||
# TODO: Why this code???
|
||||
# src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
|
||||
src = image_embeddings
|
||||
src = src + dense_prompt_embeddings
|
||||
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
|
||||
b, c, h, w = src.shape
|
||||
|
||||
# Run the transformer
|
||||
hs, src = self.transformer(src, pos_src, tokens)
|
||||
iou_token_out = hs[:, 0, :]
|
||||
mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
|
||||
|
||||
# Upscale mask embeddings and predict masks using the mask tokens
|
||||
src = src.transpose(1, 2).view(b, c, h, w)
|
||||
upscaled_embedding = self.output_upscaling(src)
|
||||
upscaled_embedding2 = self.output_upscaling2(src)
|
||||
upscaled_embedding3 = self.output_upscaling3(src)
|
||||
|
||||
hyper_in_list: List[torch.Tensor] = []
|
||||
for i in range(self.num_mask_tokens):
|
||||
hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
|
||||
hyper_in = torch.stack(hyper_in_list, dim=1)
|
||||
b, c, h, w = upscaled_embedding.shape
|
||||
masks1 = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
|
||||
masks2 = (hyper_in @ upscaled_embedding2.view(b, c, h * w)).view(b, -1, h, w)
|
||||
masks3 = (hyper_in @ upscaled_embedding3.view(b, c, h * w)).view(b, -1, h, w)
|
||||
masks = torch.stack([masks1, masks2, masks3], dim=2)
|
||||
assert masks.shape[1] == 4 and masks.shape[2] == 3, f"Got {masks.shape}"
|
||||
masks = rearrange(masks, "b c1 c2 h w -> b (c1 c2) h w")
|
||||
|
||||
# Generate mask quality predictions
|
||||
iou_pred = self.iou_prediction_head(iou_token_out)
|
||||
|
||||
# masks = rearrange(masks, "b (c1 c2) h w -> b c1 c2 h w", c1=4, c2=3)
|
||||
return masks, iou_pred
|
||||
|
||||
|
||||
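# Channel layout note (illustrative, inferred from the code above): with
# num_multimask_outputs = 3 there are num_mask_tokens = 4 mask tokens, and the three
# upscaling branches each produce one prediction per slice of the 3-slice input, so
# predict_masks returns B x (4 * 3) = B x 12 channels after the rearrange. forward() then
# keeps channels 3: (the 3 multimask tokens x 3 slices) when multimask_output is True, or
# channels :3 (the single-mask token x 3 slices) otherwise.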
# Lightly adapted from
|
||||
# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
|
||||
class MLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int,
|
||||
hidden_dim: int,
|
||||
output_dim: int,
|
||||
num_layers: int,
|
||||
sigmoid_output: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.num_layers = num_layers
|
||||
h = [hidden_dim] * (num_layers - 1)
|
||||
self.layers = nn.ModuleList(
|
||||
nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
|
||||
)
|
||||
self.sigmoid_output = sigmoid_output
|
||||
|
||||
def forward(self, x):
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
|
||||
if self.sigmoid_output:
|
||||
x = F.sigmoid(x)
|
||||
return x
|
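# Usage note for MLP (illustrative, not part of the original commit): the hypernetwork
# heads above are built as MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3),
# e.g. 256 -> 256 -> 256 -> 32 when transformer_dim = 256, so each mask token is projected
# to the channel width of the upscaled embedding before the per-pixel dot product.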
215
modeling/prompt_encoder3d.py
Normal file
@ -0,0 +1,215 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from typing import Any, Optional, Tuple, Type
|
||||
|
||||
from .common import LayerNorm2d
|
||||
|
||||
|
||||
class PromptEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
image_embedding_size: Tuple[int, int],
|
||||
input_image_size: Tuple[int, int],
|
||||
mask_in_chans: int,
|
||||
activation: Type[nn.Module] = nn.GELU,
|
||||
) -> None:
|
||||
"""
|
||||
Encodes prompts for input to SAM's mask decoder.
|
||||
|
||||
Arguments:
|
||||
embed_dim (int): The prompts' embedding dimension
|
||||
image_embedding_size (tuple(int, int)): The spatial size of the
|
||||
image embedding, as (H, W).
|
||||
input_image_size (int): The padded size of the image as input
|
||||
to the image encoder, as (H, W).
|
||||
mask_in_chans (int): The number of hidden channels used for
|
||||
encoding input masks.
|
||||
activation (nn.Module): The activation to use when encoding
|
||||
input masks.
|
||||
"""
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.input_image_size = input_image_size
|
||||
self.image_embedding_size = image_embedding_size
|
||||
self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
|
||||
|
||||
self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
|
||||
point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]
|
||||
self.point_embeddings = nn.ModuleList(point_embeddings)
|
||||
self.not_a_point_embed = nn.Embedding(1, embed_dim)
|
||||
|
||||
self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])
|
||||
self.mask_downscaling = nn.Sequential(
|
||||
# nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
|
||||
nn.Conv2d(3, mask_in_chans // 4, kernel_size=2, stride=2),
|
||||
LayerNorm2d(mask_in_chans // 4),
|
||||
activation(),
|
||||
nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
|
||||
LayerNorm2d(mask_in_chans),
|
||||
activation(),
|
||||
nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
|
||||
)
|
||||
self.no_mask_embed = nn.Embedding(1, embed_dim)
|
||||
|
||||
def get_dense_pe(self) -> torch.Tensor:
|
||||
"""
|
||||
Returns the positional encoding used to encode point prompts,
|
||||
applied to a dense set of points the shape of the image encoding.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Positional encoding with shape
|
||||
1x(embed_dim)x(embedding_h)x(embedding_w)
|
||||
"""
|
||||
return self.pe_layer(self.image_embedding_size).unsqueeze(0)
|
||||
|
||||
def _embed_points(
|
||||
self,
|
||||
points: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
pad: bool,
|
||||
) -> torch.Tensor:
|
||||
"""Embeds point prompts."""
|
||||
points = points + 0.5 # Shift to center of pixel
|
||||
if pad:
|
||||
padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
|
||||
padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
|
||||
points = torch.cat([points, padding_point], dim=1)
|
||||
labels = torch.cat([labels, padding_label], dim=1)
|
||||
point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
|
||||
point_embedding[labels == -1] = 0.0
|
||||
point_embedding[labels == -1] += self.not_a_point_embed.weight
|
||||
point_embedding[labels == 0] += self.point_embeddings[0].weight
|
||||
point_embedding[labels == 1] += self.point_embeddings[1].weight
|
||||
return point_embedding
|
||||
|
||||
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
|
||||
"""Embeds box prompts."""
|
||||
boxes = boxes + 0.5 # Shift to center of pixel
|
||||
coords = boxes.reshape(-1, 2, 2)
|
||||
corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
|
||||
corner_embedding[:, 0, :] += self.point_embeddings[2].weight
|
||||
corner_embedding[:, 1, :] += self.point_embeddings[3].weight
|
||||
return corner_embedding
|
||||
|
||||
def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
|
||||
"""Embeds mask inputs."""
|
||||
mask_embedding = self.mask_downscaling(masks)
|
||||
return mask_embedding
|
||||
|
||||
def _get_batch_size(
|
||||
self,
|
||||
points: Optional[Tuple[torch.Tensor, torch.Tensor]],
|
||||
boxes: Optional[torch.Tensor],
|
||||
masks: Optional[torch.Tensor],
|
||||
) -> int:
|
||||
"""
|
||||
Gets the batch size of the output given the batch size of the input prompts.
|
||||
"""
|
||||
if points is not None:
|
||||
return points[0].shape[0]
|
||||
elif boxes is not None:
|
||||
return boxes.shape[0]
|
||||
elif masks is not None:
|
||||
return masks.shape[0]
|
||||
else:
|
||||
return 1
|
||||
|
||||
def _get_device(self) -> torch.device:
|
||||
return self.point_embeddings[0].weight.device
|
||||
|
||||
def forward(
|
||||
self,
|
||||
points: Optional[Tuple[torch.Tensor, torch.Tensor]],
|
||||
boxes: Optional[torch.Tensor],
|
||||
masks: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Embeds different types of prompts, returning both sparse and dense
|
||||
embeddings.
|
||||
|
||||
Arguments:
|
||||
points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
|
||||
and labels to embed.
|
||||
boxes (torch.Tensor or none): boxes to embed
|
||||
masks (torch.Tensor or none): masks to embed
|
||||
|
||||
Returns:
|
||||
torch.Tensor: sparse embeddings for the points and boxes, with shape
|
||||
BxNx(embed_dim), where N is determined by the number of input points
|
||||
and boxes.
|
||||
torch.Tensor: dense embeddings for the masks, in the shape
|
||||
Bx(embed_dim)x(embed_H)x(embed_W)
|
||||
"""
|
||||
bs = self._get_batch_size(points, boxes, masks)
|
||||
sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
|
||||
if points is not None:
|
||||
coords, labels = points
|
||||
point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
|
||||
sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
|
||||
if boxes is not None:
|
||||
box_embeddings = self._embed_boxes(boxes)
|
||||
sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
|
||||
|
||||
if masks is not None:
|
||||
dense_embeddings = self._embed_masks(masks)
|
||||
else:
|
||||
dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
|
||||
bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
|
||||
)
|
||||
|
||||
return sparse_embeddings, dense_embeddings
|
||||
|
||||
|
||||
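# Shape sketch for forward() (illustrative; the sizes assume the usual 1024-input SAM
# setup with embed_dim = 256 and image_embedding_size = (64, 64)): a single box prompt
# gives sparse embeddings of shape (B, 2, 256) (two corner tokens); a point prompt with
# pad=True adds a padding point, giving (B, N+1, 256). If a mask is supplied it must have
# 3 channels here (the first conv of mask_downscaling takes 3 channels for the 3-slice
# input) and is downscaled by 4x to (B, 256, 64, 64); otherwise no_mask_embed is broadcast
# to that same dense shape.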
class PositionEmbeddingRandom(nn.Module):
|
||||
"""
|
||||
Positional encoding using random spatial frequencies.
|
||||
"""
|
||||
|
||||
def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
|
||||
super().__init__()
|
||||
if scale is None or scale <= 0.0:
|
||||
scale = 1.0
|
||||
self.register_buffer(
|
||||
"positional_encoding_gaussian_matrix",
|
||||
scale * torch.randn((2, num_pos_feats)),
|
||||
)
|
||||
|
||||
def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
|
||||
"""Positionally encode points that are normalized to [0,1]."""
|
||||
# assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
|
||||
coords = 2 * coords - 1
|
||||
coords = coords @ self.positional_encoding_gaussian_matrix
|
||||
coords = 2 * np.pi * coords
|
||||
# outputs d_1 x ... x d_n x C shape
|
||||
return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
|
||||
|
||||
def forward(self, size: Tuple[int, int]) -> torch.Tensor:
|
||||
"""Generate positional encoding for a grid of the specified size."""
|
||||
h, w = size
|
||||
device: Any = self.positional_encoding_gaussian_matrix.device
|
||||
grid = torch.ones((h, w), device=device, dtype=torch.float32)
|
||||
y_embed = grid.cumsum(dim=0) - 0.5
|
||||
x_embed = grid.cumsum(dim=1) - 0.5
|
||||
y_embed = y_embed / h
|
||||
x_embed = x_embed / w
|
||||
|
||||
pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
|
||||
return pe.permute(2, 0, 1) # C x H x W
|
||||
|
||||
def forward_with_coords(
|
||||
self, coords_input: torch.Tensor, image_size: Tuple[int, int]
|
||||
) -> torch.Tensor:
|
||||
"""Positionally encode points that are not normalized to [0,1]."""
|
||||
coords = coords_input.clone()
|
||||
coords[:, :, 0] = coords[:, :, 0] / image_size[1]
|
||||
coords[:, :, 1] = coords[:, :, 1] / image_size[0]
|
||||
return self._pe_encoding(coords.to(torch.float)) # B x N x C
|
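# Encoding note (illustrative, not part of the original commit): PositionEmbeddingRandom
# maps (x, y) coordinates normalized to [0, 1] through a fixed random Gaussian matrix of
# shape (2, num_pos_feats) and returns the concatenated sin/cos features, so each position
# gets 2 * num_pos_feats channels; with num_pos_feats = embed_dim // 2, as used in
# PromptEncoder, this matches the prompt embedding dimension exactly.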
174
modeling/sam3d.py
Normal file
@ -0,0 +1,174 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from .image_encoder import ImageEncoderViT
|
||||
from .mask_decoder3d_2 import MaskDecoder
|
||||
from .prompt_encoder3d import PromptEncoder
|
||||
|
||||
|
||||
class Sam(nn.Module):
|
||||
mask_threshold: float = 0.0
|
||||
image_format: str = "RGB"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_encoder: ImageEncoderViT,
|
||||
prompt_encoder: PromptEncoder,
|
||||
mask_decoder: MaskDecoder,
|
||||
pixel_mean: List[float] = [123.675, 116.28, 103.53],
|
||||
pixel_std: List[float] = [58.395, 57.12, 57.375],
|
||||
) -> None:
|
||||
"""
|
||||
SAM predicts object masks from an image and input prompts.
|
||||
|
||||
Arguments:
|
||||
image_encoder (ImageEncoderViT): The backbone used to encode the
|
||||
image into image embeddings that allow for efficient mask prediction.
|
||||
prompt_encoder (PromptEncoder): Encodes various types of input prompts.
|
||||
mask_decoder (MaskDecoder): Predicts masks from the image embeddings
|
||||
and encoded prompts.
|
||||
pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
|
||||
pixel_std (list(float)): Std values for normalizing pixels in the input image.
|
||||
"""
|
||||
super().__init__()
|
||||
self.image_encoder = image_encoder
|
||||
self.prompt_encoder = prompt_encoder
|
||||
self.mask_decoder = mask_decoder
|
||||
self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
|
||||
self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
|
||||
|
||||
@property
|
||||
def device(self) -> Any:
|
||||
return self.pixel_mean.device
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
batched_input: List[Dict[str, Any]],
|
||||
multimask_output: bool,
|
||||
) -> List[Dict[str, torch.Tensor]]:
|
||||
"""
|
||||
Predicts masks end-to-end from provided images and prompts.
|
||||
If prompts are not known in advance, using SamPredictor is
|
||||
recommended over calling the model directly.
|
||||
|
||||
Arguments:
|
||||
batched_input (list(dict)): A list over input images, each a
|
||||
dictionary with the following keys. A prompt key can be
|
||||
excluded if it is not present.
|
||||
'image': The image as a torch tensor in 3xHxW format,
|
||||
already transformed for input to the model.
|
||||
'original_size': (tuple(int, int)) The original size of
|
||||
the image before transformation, as (H, W).
|
||||
'point_coords': (torch.Tensor) Batched point prompts for
|
||||
this image, with shape BxNx2. Already transformed to the
|
||||
input frame of the model.
|
||||
'point_labels': (torch.Tensor) Batched labels for point prompts,
|
||||
with shape BxN.
|
||||
'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
|
||||
Already transformed to the input frame of the model.
|
||||
'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
|
||||
in the form Bx1xHxW.
|
||||
multimask_output (bool): Whether the model should predict multiple
|
||||
disambiguating masks, or return a single mask.
|
||||
|
||||
Returns:
|
||||
(list(dict)): A list over input images, where each element is
|
||||
as dictionary with the following keys.
|
||||
'masks': (torch.Tensor) Batched binary mask predictions,
|
||||
with shape BxCxHxW, where B is the number of input prompts,
|
||||
C is determined by multimask_output, and (H, W) is the
|
||||
original size of the image.
|
||||
'iou_predictions': (torch.Tensor) The model's predictions
|
||||
of mask quality, in shape BxC.
|
||||
'low_res_logits': (torch.Tensor) Low resolution logits with
|
||||
shape BxCxHxW, where H=W=256. Can be passed as mask input
|
||||
to subsequent iterations of prediction.
|
||||
"""
|
||||
input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
|
||||
image_embeddings = self.image_encoder(input_images)
|
||||
|
||||
outputs = []
|
||||
for image_record, curr_embedding in zip(batched_input, image_embeddings):
|
||||
if "point_coords" in image_record:
|
||||
points = (image_record["point_coords"], image_record["point_labels"])
|
||||
else:
|
||||
points = None
|
||||
sparse_embeddings, dense_embeddings = self.prompt_encoder(
|
||||
points=points,
|
||||
boxes=image_record.get("boxes", None),
|
||||
masks=image_record.get("mask_inputs", None),
|
||||
)
|
||||
low_res_masks, iou_predictions = self.mask_decoder(
|
||||
image_embeddings=curr_embedding.unsqueeze(0),
|
||||
image_pe=self.prompt_encoder.get_dense_pe(),
|
||||
sparse_prompt_embeddings=sparse_embeddings,
|
||||
dense_prompt_embeddings=dense_embeddings,
|
||||
multimask_output=multimask_output,
|
||||
)
|
||||
masks = self.postprocess_masks(
|
||||
low_res_masks,
|
||||
input_size=image_record["image"].shape[-2:],
|
||||
original_size=image_record["original_size"],
|
||||
)
|
||||
masks = masks > self.mask_threshold
|
||||
outputs.append(
|
||||
{
|
||||
"masks": masks,
|
||||
"iou_predictions": iou_predictions,
|
||||
"low_res_logits": low_res_masks,
|
||||
}
|
||||
)
|
||||
return outputs
|
||||
|
||||
def postprocess_masks(
|
||||
self,
|
||||
masks: torch.Tensor,
|
||||
input_size: Tuple[int, ...],
|
||||
original_size: Tuple[int, ...],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Remove padding and upscale masks to the original image size.
|
||||
|
||||
Arguments:
|
||||
masks (torch.Tensor): Batched masks from the mask_decoder,
|
||||
in BxCxHxW format.
|
||||
input_size (tuple(int, int)): The size of the image input to the
|
||||
model, in (H, W) format. Used to remove padding.
|
||||
original_size (tuple(int, int)): The original size of the image
|
||||
before resizing for input to the model, in (H, W) format.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
|
||||
is given by original_size.
|
||||
"""
|
||||
masks = F.interpolate(
|
||||
masks,
|
||||
(self.image_encoder.img_size, self.image_encoder.img_size),
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
)
|
||||
masks = masks[..., : input_size[0], : input_size[1]]
|
||||
masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
|
||||
return masks
|
||||
|
||||
def preprocess(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Normalize pixel values and pad to a square input."""
|
||||
# Normalize colors
|
||||
x = (x - self.pixel_mean) / self.pixel_std
|
||||
|
||||
# Pad
|
||||
h, w = x.shape[-2:]
|
||||
padh = self.image_encoder.img_size - h
|
||||
padw = self.image_encoder.img_size - w
|
||||
x = F.pad(x, (0, padw, 0, padh))
|
||||
return x
|
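# Usage sketch (illustrative, not part of the original commit): forward() is wrapped in
# @torch.no_grad, so it is inference-only; the training code drives the sub-modules
# directly. A minimal call, assuming `image` is a 3xHxW tensor already resized to the
# 1024-pixel input frame and `boxes` is a Bx4 tensor in the same frame:
#
#     out = sam([{"image": image, "original_size": (H, W), "boxes": boxes}],
#               multimask_output=True)
#     masks = out[0]["masks"]              # B x C x H x W boolean masks
#     scores = out[0]["iou_predictions"]   # B x C mask-quality predictions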
240
modeling/transformer.py
Normal file
@ -0,0 +1,240 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
import math
|
||||
from typing import Tuple, Type
|
||||
|
||||
from .common import MLPBlock
|
||||
|
||||
|
||||
class TwoWayTransformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
depth: int,
|
||||
embedding_dim: int,
|
||||
num_heads: int,
|
||||
mlp_dim: int,
|
||||
activation: Type[nn.Module] = nn.ReLU,
|
||||
attention_downsample_rate: int = 2,
|
||||
) -> None:
|
||||
"""
|
||||
A transformer decoder that attends to an input image using
|
||||
queries whose positional embedding is supplied.
|
||||
|
||||
Args:
|
||||
depth (int): number of layers in the transformer
|
||||
embedding_dim (int): the channel dimension for the input embeddings
|
||||
num_heads (int): the number of heads for multihead attention. Must
|
||||
divide embedding_dim
|
||||
mlp_dim (int): the channel dimension internal to the MLP block
|
||||
activation (nn.Module): the activation to use in the MLP block
|
||||
"""
|
||||
super().__init__()
|
||||
self.depth = depth
|
||||
self.embedding_dim = embedding_dim
|
||||
self.num_heads = num_heads
|
||||
self.mlp_dim = mlp_dim
|
||||
self.layers = nn.ModuleList()
|
||||
|
||||
for i in range(depth):
|
||||
self.layers.append(
|
||||
TwoWayAttentionBlock(
|
||||
embedding_dim=embedding_dim,
|
||||
num_heads=num_heads,
|
||||
mlp_dim=mlp_dim,
|
||||
activation=activation,
|
||||
attention_downsample_rate=attention_downsample_rate,
|
||||
skip_first_layer_pe=(i == 0),
|
||||
)
|
||||
)
|
||||
|
||||
self.final_attn_token_to_image = Attention(
|
||||
embedding_dim, num_heads, downsample_rate=attention_downsample_rate
|
||||
)
|
||||
self.norm_final_attn = nn.LayerNorm(embedding_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
image_embedding: Tensor,
|
||||
image_pe: Tensor,
|
||||
point_embedding: Tensor,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""
|
||||
Args:
|
||||
image_embedding (torch.Tensor): image to attend to. Should be shape
|
||||
B x embedding_dim x h x w for any h and w.
|
||||
image_pe (torch.Tensor): the positional encoding to add to the image. Must
|
||||
have the same shape as image_embedding.
|
||||
point_embedding (torch.Tensor): the embedding to add to the query points.
|
||||
Must have shape B x N_points x embedding_dim for any N_points.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: the processed point_embedding
|
||||
torch.Tensor: the processed image_embedding
|
||||
"""
|
||||
# BxCxHxW -> BxHWxC == B x N_image_tokens x C
|
||||
bs, c, h, w = image_embedding.shape
|
||||
image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
|
||||
image_pe = image_pe.flatten(2).permute(0, 2, 1)
|
||||
|
||||
# Prepare queries
|
||||
queries = point_embedding
|
||||
keys = image_embedding
|
||||
|
||||
# Apply transformer blocks and final layernorm
|
||||
for layer in self.layers:
|
||||
queries, keys = layer(
|
||||
queries=queries,
|
||||
keys=keys,
|
||||
query_pe=point_embedding,
|
||||
key_pe=image_pe,
|
||||
)
|
||||
|
||||
# Apply the final attention layer from the points to the image
|
||||
q = queries + point_embedding
|
||||
k = keys + image_pe
|
||||
attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
|
||||
queries = queries + attn_out
|
||||
queries = self.norm_final_attn(queries)
|
||||
|
||||
return queries, keys
|
||||
|
||||
|
||||
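# Shape walk-through (illustrative; C = 256 in the usual SAM configuration): the image
# embedding B x C x 64 x 64 is flattened to B x 4096 x C keys, while the prompt tokens
# (e.g. the 7 tokens built in the mask decoder: 1 IoU token + 4 mask tokens + 2 box
# corners) form the B x N x C queries. Each TwoWayAttentionBlock runs token
# self-attention, token-to-image cross-attention, an MLP, and image-to-token
# cross-attention; shapes are preserved, so the transformer returns B x N x C queries and
# B x 4096 x C keys.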
class TwoWayAttentionBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
num_heads: int,
|
||||
mlp_dim: int = 2048,
|
||||
activation: Type[nn.Module] = nn.ReLU,
|
||||
attention_downsample_rate: int = 2,
|
||||
skip_first_layer_pe: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
A transformer block with four layers: (1) self-attention of sparse
|
||||
inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
|
||||
block on sparse inputs, and (4) cross attention of dense inputs to sparse
|
||||
inputs.
|
||||
|
||||
Arguments:
|
||||
embedding_dim (int): the channel dimension of the embeddings
|
||||
num_heads (int): the number of heads in the attention layers
|
||||
mlp_dim (int): the hidden dimension of the mlp block
|
||||
activation (nn.Module): the activation of the mlp block
|
||||
skip_first_layer_pe (bool): skip the PE on the first layer
|
||||
"""
|
||||
super().__init__()
|
||||
self.self_attn = Attention(embedding_dim, num_heads)
|
||||
self.norm1 = nn.LayerNorm(embedding_dim)
|
||||
|
||||
self.cross_attn_token_to_image = Attention(
|
||||
embedding_dim, num_heads, downsample_rate=attention_downsample_rate
|
||||
)
|
||||
self.norm2 = nn.LayerNorm(embedding_dim)
|
||||
|
||||
self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
|
||||
self.norm3 = nn.LayerNorm(embedding_dim)
|
||||
|
||||
self.norm4 = nn.LayerNorm(embedding_dim)
|
||||
self.cross_attn_image_to_token = Attention(
|
||||
embedding_dim, num_heads, downsample_rate=attention_downsample_rate
|
||||
)
|
||||
|
||||
self.skip_first_layer_pe = skip_first_layer_pe
|
||||
|
||||
def forward(
|
||||
self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
# Self attention block
|
||||
if self.skip_first_layer_pe:
|
||||
queries = self.self_attn(q=queries, k=queries, v=queries)
|
||||
else:
|
||||
q = queries + query_pe
|
||||
attn_out = self.self_attn(q=q, k=q, v=queries)
|
||||
queries = queries + attn_out
|
||||
queries = self.norm1(queries)
|
||||
|
||||
# Cross attention block, tokens attending to image embedding
|
||||
q = queries + query_pe
|
||||
k = keys + key_pe
|
||||
attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
|
||||
queries = queries + attn_out
|
||||
queries = self.norm2(queries)
|
||||
|
||||
# MLP block
|
||||
mlp_out = self.mlp(queries)
|
||||
queries = queries + mlp_out
|
||||
queries = self.norm3(queries)
|
||||
|
||||
# Cross attention block, image embedding attending to tokens
|
||||
q = queries + query_pe
|
||||
k = keys + key_pe
|
||||
attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
|
||||
keys = keys + attn_out
|
||||
keys = self.norm4(keys)
|
||||
|
||||
return queries, keys
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
"""
|
||||
An attention layer that allows for downscaling the size of the embedding
|
||||
after projection to queries, keys, and values.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
num_heads: int,
|
||||
downsample_rate: int = 1,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.embedding_dim = embedding_dim
|
||||
self.internal_dim = embedding_dim // downsample_rate
|
||||
self.num_heads = num_heads
|
||||
assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
|
||||
|
||||
self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
|
||||
self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
|
||||
self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
|
||||
self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
|
||||
|
||||
def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
|
||||
b, n, c = x.shape
|
||||
x = x.reshape(b, n, num_heads, c // num_heads)
|
||||
return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
|
||||
|
||||
def _recombine_heads(self, x: Tensor) -> Tensor:
|
||||
b, n_heads, n_tokens, c_per_head = x.shape
|
||||
x = x.transpose(1, 2)
|
||||
return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
|
||||
|
||||
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
|
||||
# Input projections
|
||||
q = self.q_proj(q)
|
||||
k = self.k_proj(k)
|
||||
v = self.v_proj(v)
|
||||
|
||||
# Separate into heads
|
||||
q = self._separate_heads(q, self.num_heads)
|
||||
k = self._separate_heads(k, self.num_heads)
|
||||
v = self._separate_heads(v, self.num_heads)
|
||||
|
||||
# Attention
|
||||
_, _, _, c_per_head = q.shape
|
||||
attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
|
||||
attn = attn / math.sqrt(c_per_head)
|
||||
attn = torch.softmax(attn, dim=-1)
|
||||
|
||||
# Get output
|
||||
out = attn @ v
|
||||
out = self._recombine_heads(out)
|
||||
out = self.out_proj(out)
|
||||
|
||||
return out
|
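# Dimension note (illustrative, not part of the original commit): with embedding_dim = 256,
# num_heads = 8 and downsample_rate = 2, the values used for the cross-attention layers
# above, internal_dim is 128, so each head works on 16 channels and the logits are scaled
# by 1 / sqrt(16) = 0.25 before the softmax; out_proj maps the concatenated heads back to
# 256 channels.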
19
readme.md
Normal file
@ -0,0 +1,19 @@
<!-- # Slide-SAM -->
# Slide-SAM: Medical SAM meets sliding window


## Training
Prepare the datasets (generate the dataset list files):
```
python -m datasets.generate_txt
```

Cache the 3D volumes into 2D slices:
```
python -m datasets.cache_datasets3d
```

Run training:
```
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m core.ddp --tag debug
```
|
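
## Testing
Run volume evaluation with box or point prompts (a minimal sketch based on `test/volume_eval.py` in this commit; set the checkpoint path `pth` and the `EX_CONFIG` options in that script first):
```
python -m test.volume_eval
```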
0
test/__init__.py
Normal file
173
test/volume_eval.py
Normal file
@ -0,0 +1,173 @@
|
||||
"""
|
||||
Volume evaluation
|
||||
|
||||
"""
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch.utils.data import DataLoader
|
||||
# from datasets.dataset3d import Dataset3D
|
||||
from tutils.new.manager import ConfigManager
|
||||
from datasets.eval_dataloader.loader_abstract import AbstractLoader
|
||||
|
||||
from core.volume_predictor import VolumePredictor
|
||||
from datasets.data_engine import DataManager, BoxPromptGenerator, PointPromptGenerator
|
||||
|
||||
from tutils import tfilename
|
||||
from tutils.new.trainer.recorder import Recorder
|
||||
from trans_utils.metrics import compute_dice_np
|
||||
from trans_utils.data_utils import Data3dSolver
|
||||
|
||||
|
||||
|
||||
class Evaluater:
|
||||
def __init__(self, config) -> None:
|
||||
self.config = config
|
||||
self.recorder = Recorder()
|
||||
|
||||
def solve(self, model, dataset):
|
||||
# model.eval()
|
||||
self.predictor = model
|
||||
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
|
||||
|
||||
for i, data in enumerate(dataloader):
|
||||
# if i <4:
|
||||
# print
|
||||
# continue
|
||||
# for k, v in data.items():
|
||||
# if isinstance(v, torch.Tensor):
|
||||
# data[k] = v.to(self.rank)
|
||||
if self.config['dataset']['prompt'] == 'box':
|
||||
res = self.eval_step(data, batch_idx=i)
|
||||
if self.config['dataset']['prompt'] == 'point':
|
||||
res = self.eval_step_point(data, batch_idx=i)
|
||||
self.recorder.record(res)
|
||||
res = self.recorder.cal_metrics()
|
||||
print(res)
|
||||
print("prompt:", self.config['dataset']['prompt'], " class_idx:", self.config['dataset']['label_idx'])
|
||||
|
||||
def eval_step(self, data, batch_idx=0):
|
||||
name = data['name']
|
||||
dataset_name = data['dataset_name'][0]
|
||||
label_idx = data['label_idx'][0]
|
||||
template_slice_id = data['template_slice_id'][0]
|
||||
|
||||
assert data['img'].shape[1] >= 3, f" Got img.shape {data['img'].shape}"
|
||||
if template_slice_id == 0:
|
||||
template_slice_id += 1
|
||||
elif template_slice_id == (data['img'].shape[0] - 1):
|
||||
template_slice_id -= 1
|
||||
|
||||
spacing = data['spacing'].numpy().tolist()[0]
|
||||
if data['img'].shape[-1] < 260:
|
||||
# assert data['img'].shape[-1] < 260, f"Got {data['img'].shape}"
|
||||
img = data['img'][0][:,:256,:256]
|
||||
label = data['label'][0][:,:256,:256]
|
||||
else:
|
||||
img = data['img'][0]
|
||||
label = data['label'][0]
|
||||
# img = torch.clip(img, -200, 600)
|
||||
box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id].detach().cpu().numpy())
|
||||
box = np.array([box])
|
||||
pred, stability = self.predictor.predict_volume(
|
||||
x=img,
|
||||
box=box,
|
||||
template_slice_id=template_slice_id,
|
||||
return_stability=True,
|
||||
)
|
||||
prompt_type = 'box'
|
||||
dice = compute_dice_np(pred, label.detach().cpu().numpy())
|
||||
Data3dSolver().simple_write(pred, path=tfilename(f"visual/{dataset_name}/pred_{batch_idx}_label_{label_idx}_{prompt_type}.nii.gz"), spacing=spacing)
|
||||
Data3dSolver().simple_write(label.detach().cpu().numpy(), path=tfilename(f"visual/{dataset_name}/label_{batch_idx}.nii.gz"))
|
||||
# Data3dSolver().simple_write(img.detach().cpu().numpy(), path=tfilename(f"visual/{dataset_name}/img_{batch_idx}.nii.gz"))
|
||||
# np.save(tfilename(f"meta/{dataset_name}/stability_{batch_idx}.npy"), stability)
|
||||
print(dataset_name, name, dice)
|
||||
return {"dice": dice}
|
||||
|
||||
def eval_step_point(self, data, batch_idx=0):
|
||||
name = data['name']
|
||||
dataset_name = data['dataset_name'][0]
|
||||
label_idx = data['label_idx'][0]
|
||||
template_slice_id = data['template_slice_id'][0]
|
||||
spacing = data['spacing'].numpy().tolist()[0]
|
||||
|
||||
assert data['img'].shape[1] >= 3, f" Got img.shape {data['img'].shape}"
|
||||
if template_slice_id == 0:
|
||||
template_slice_id += 1
|
||||
elif template_slice_id == (data['img'].shape[0] - 1):
|
||||
template_slice_id -= 1
|
||||
|
||||
if data['img'].shape[-1] < 260:
|
||||
# assert data['img'].shape[-1] < 260, f"Got {data['img'].shape}"
|
||||
img = data['img'][0][:,:256,:256]
|
||||
label = data['label'][0][:,:256,:256]
|
||||
else:
|
||||
img = data['img'][0]
|
||||
label = data['label'][0]
|
||||
|
||||
box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id].detach().cpu().numpy())
|
||||
point = (box[0]+box[2])*0.5 , (box[1]+box[3])*0.5
|
||||
point = np.array([point]).astype(int)
|
||||
if label[template_slice_id][point[0,1], point[0,0]] == 0:
|
||||
print("Use random point instead !!!")
|
||||
point = PointPromptGenerator().get_prompt_point(label[template_slice_id])
|
||||
point = np.array([point]).astype(int)
|
||||
# box = np.array([box])
|
||||
pred = self.predictor.predict_volume(
|
||||
x=img,
|
||||
point_coords=point,
|
||||
point_labels=np.ones_like(point)[:,:1],
|
||||
template_slice_id=template_slice_id,
|
||||
)
|
||||
dice = compute_dice_np(pred, label.detach().cpu().numpy())
|
||||
prompt_type = 'point'
|
||||
Data3dSolver().simple_write(pred, path=tfilename(f"visual/{dataset_name}/pred_{batch_idx}_label_{label_idx}_{prompt_type}.nii.gz"), spacing=spacing)
|
||||
# Data3dSolver().simple_write(pred, path=tfilename(f"visual/{dataset_name}/pred_{batch_idx}.nii.gz"))
|
||||
print(dataset_name, name, dice)
|
||||
return {"dice": dice}
|
||||
|
||||
def to_RGB(img):
|
||||
pass
|
||||
|
||||
if __name__ == "__main__":
|
||||
# from core.learner3 import SamLearner
|
||||
# from modeling.build_sam3d import sam_model_registry
|
||||
|
||||
from core.learner3 import SamLearner
|
||||
from modeling.build_sam3d2 import sam_model_registry
|
||||
|
||||
EX_CONFIG = {
|
||||
'dataset':{
|
||||
'prompt': 'box',
|
||||
'dataset_list': ['word'], # ["sabs"], chaos, word
|
||||
'label_idx': 1,
|
||||
}
|
||||
}
|
||||
|
||||
config = ConfigManager()
|
||||
config.add_config("configs/vit_b_103.yaml")
|
||||
config.add_config(EX_CONFIG)
|
||||
|
||||
# Init Model
|
||||
model_type = "vit_b"
|
||||
sam = sam_model_registry[model_type](checkpoint=None)
|
||||
learner = SamLearner(sam_model=sam, config=config, data_engine=DataManager(img_size=(1024,1024)))
|
||||
learner.use_lora()
|
||||
# pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b3/lora+edge2/ckpt_v/model_latest.pth"
|
||||
# pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b3/lora+edge2/ckpt/model_epoch_20.pth"
|
||||
# pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b3/lora+edge2/ckpt/model_epoch_16.pth"
|
||||
# pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b3/lora_small/ckpt/model_epoch_6.pth"
|
||||
# pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b9/lora/ckpt/model_epoch_50.pth"
|
||||
# pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b11/spec_8/ckpt_v/model_latest.pth"
|
||||
# pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b11/spec_5/ckpt/model_epoch_100.pth"
|
||||
pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b9/lora3/ckpt/model_iter_360000.pth"
|
||||
# pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b9/lora3/ckpt/model_iter_500000.pth"
|
||||
learner.load_well_trained_model(pth)
|
||||
learner.cuda()
|
||||
predictor = VolumePredictor(
|
||||
model=learner.model,
|
||||
use_postprocess=True,
|
||||
use_noise_remove=True,)
|
||||
|
||||
solver = Evaluater(config)
|
||||
dataset = AbstractLoader(config['dataset'], split="test")
|
||||
solver.solve(predictor, dataset)
|
0
trans_utils/__init__.py
Normal file
BIN
trans_utils/__pycache__/__init__.cpython-38.pyc
Normal file
Binary file not shown.
BIN
trans_utils/__pycache__/trainer_ddp.cpython-38.pyc
Normal file
Binary file not shown.
97
trans_utils/data_utils.py
Normal file
@ -0,0 +1,97 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from tutils import tfilename
|
||||
from tutils.nn.data import read, itk_to_np, np_to_itk, write
|
||||
from torchvision.utils import save_image
|
||||
import SimpleITK as sitk
|
||||
from tutils.nn.data.tsitk.preprocess import resampleImage
|
||||
|
||||
|
||||
class Data3dSolver:
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def simple_write(self, data_np, path="tmp.nii.gz", spacing=None):
|
||||
assert len(data_np.shape) == 3, f"Got {data_np.shape}"
|
||||
data_np = data_np.astype(np.int16)
|
||||
data_itk = np_to_itk(data_np)
|
||||
if spacing is not None:
|
||||
data_itk.SetSpacing(spacing)
|
||||
write(data_itk, path=tfilename(path))
|
||||
print("Save to ", path)
|
||||
|
||||
def write_slices(self, data, path="tmp_masks.jpg"):
|
||||
if isinstance(data, torch.Tensor):
|
||||
pass
|
||||
if isinstance(data, np.ndarray):
|
||||
data = torch.Tensor(data)
|
||||
assert len(data.shape) == 4, f"Shape should be (b c h w) c=1/3, Got {data.shape}"
|
||||
assert data.shape[1] == 1 or data.shape[1] == 3, f"Shape should be (b c h w) c=1/3, Got {data.shape}"
|
||||
assert path.endswith(".jpg") or path.endswith(".png")
|
||||
save_image(torch.Tensor(data).unsqueeze(1), tfilename(path))
|
||||
print("Save to ", path)
|
||||
|
||||
def write_multilabel_nii(self, data, path, meta=None):
|
||||
if isinstance(data, dict):
|
||||
data_all = [v for k,v in data.items()]
|
||||
data = np.stack(data_all, axis=0)
|
||||
assert len(data.shape) == 4, f"Shape should be (b c h w) , Got {data.shape}"
|
||||
# Merge labels to one
|
||||
merged = np.zeros_like(data[0])
|
||||
for i, datai in enumerate(data):
|
||||
merged = np.where(datai > 0, datai * (i+1), merged)
|
||||
|
||||
merged = merged.astype(np.int16)
|
||||
data_itk = np_to_itk(merged)
|
||||
if meta is not None:
|
||||
data_itk = formalize(data_itk, meta)
|
||||
write(data_itk, path=tfilename(path))
|
||||
print("Save to ", path)
|
||||
|
||||
def fwrite(self, data, path, meta):
|
||||
data = data.astype(np.int16)
|
||||
data_itk = np_to_itk(data)
|
||||
data_itk = formalize(data_itk, meta)
|
||||
write(data_itk, path=tfilename(path))
|
||||
|
||||
def read(self, path, spacing_norm=True):
|
||||
data_itk = read(path)
|
||||
if spacing_norm:
|
||||
ori_size = data_itk.GetSize()
|
||||
ori_spacing = data_itk.GetSpacing()
|
||||
data_itk = self.normalize_spacing(data_itk)
|
||||
new_size = data_itk.GetSize()
|
||||
new_spacing = data_itk.GetSpacing()
|
||||
print("Change size from ", ori_size, new_size)
|
||||
print("Change spacing from ", ori_spacing, new_spacing)
|
||||
data_np = itk_to_np(data_itk)
|
||||
print("[data_utils.DEBUG]", data_np.shape)
|
||||
return data_np, data_itk.GetSpacing()
|
||||
|
||||
def normalize_spacing(self, data_itk):
|
||||
spacing = data_itk.GetSpacing()
|
||||
new_spacing = (min(spacing),min(spacing),min(spacing))
|
||||
data_itk = resampleImage(data_itk, NewSpacing=new_spacing)
|
||||
return data_itk
|
||||
|
||||
|
||||
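# Worked example for normalize_spacing (illustrative, not part of the original commit):
# a CT volume stored with spacing (0.8, 0.8, 3.0) mm is resampled to the isotropic minimum
# spacing (0.8, 0.8, 0.8) mm, so the coarse axis is upsampled by roughly 3.75x before
# prediction; simple_write() can then save the result with whatever spacing is passed in.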
def formalize(img:sitk.SimpleITK.Image, meta:sitk.SimpleITK.Image):
|
||||
# Size = meta.GetSize()
|
||||
Spacing = meta.GetSpacing()
|
||||
Origin = meta.GetOrigin()
|
||||
Direction = meta.GetDirection()
|
||||
|
||||
img.SetSpacing(Spacing)
|
||||
img.SetOrigin(Origin)
|
||||
img.SetDirection(Direction)
|
||||
return img
|
||||
|
||||
|
||||
def write(img:sitk.SimpleITK.Image, path:str, mode:str="nifti"):
|
||||
"""
|
||||
Path: (example) os.path.join(jpg_dir, f"trans_{random_name}.nii.gz")
|
||||
"""
|
||||
mode = mode.lower()
|
||||
writer = sitk.ImageFileWriter()
|
||||
writer.SetFileName(path)
|
||||
writer.Execute(img)
|
24
trans_utils/metrics.py
Normal file
@ -0,0 +1,24 @@
import numpy as np


def compute_dice_np(pred_mask, gt_mask):
    """Dice coefficient between two binary masks (numpy arrays)."""
    assert gt_mask.max() == 1, f"Got gt_mask.max():{gt_mask.max()} Error!!"
    pred_mask = np.array(pred_mask > 0)
    gt_mask = np.array(gt_mask > 0)
    intersection = np.array(pred_mask * gt_mask).sum()
    union = pred_mask.sum() + gt_mask.sum()
    dice = intersection * 2 / union  # union > 0 is guaranteed by the assert above
    return dice


def compute_prec_np(pred_mask, gt_mask):
    """Precision: true positives over all predicted-positive pixels."""
    true_pos = (np.int32(pred_mask > 0) * np.int32(gt_mask > 0)).sum()
    return true_pos / np.int32(pred_mask > 0).sum()


def compute_recall_np(pred_mask, gt_mask):
    """Recall: true positives over all ground-truth-positive pixels."""
    true_pos = (np.int32(pred_mask > 0) * np.int32(gt_mask > 0)).sum()
    false_neg = ((gt_mask - pred_mask) > 0).sum()
    return true_pos / (true_pos + false_neg)
|
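
# Illustrative self-check (a sketch, not part of the original commit): two toy masks with
# one overlapping foreground pixel, two predicted and one ground-truth foreground pixels,
# so Dice = 2 * 1 / (2 + 1) ~= 0.667, precision = 0.5, recall = 1.0.
if __name__ == "__main__":
    pred = np.array([[1, 1, 0], [0, 0, 0]])
    gt = np.array([[1, 0, 0], [0, 0, 0]])
    print(compute_dice_np(pred, gt))    # ~0.6667
    print(compute_prec_np(pred, gt))    # 0.5
    print(compute_recall_np(pred, gt))  # 1.0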
418
trans_utils/trainer_ddp.py
Normal file
@ -0,0 +1,418 @@
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
|
||||
import tempfile
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.cuda.amp import autocast, GradScaler
|
||||
|
||||
from tutils.new.trainer.trainer_abstract import AbstractTrainer
|
||||
from tutils.new.manager.loggers import MultiLogger
|
||||
from tutils.new.manager.csv_recorder import CSVLogger
|
||||
from tutils.new.trainer.recorder import Recorder
|
||||
from tutils.new.utils.core_utils import _get_time_str
|
||||
from tutils.new.utils.public_utils import dict_to_str
|
||||
|
||||
# Waiting for update
|
||||
from tutils import tfilename, tenum
|
||||
|
||||
# export MASTER_ADDR=192.168.1.100
|
||||
# export MASTER_PORT=12345
|
||||
|
||||
|
||||
def setup(rank, world_size):
|
||||
os.environ['MASTER_ADDR'] = 'localhost'
|
||||
os.environ['MASTER_PORT'] = '12355'
|
||||
|
||||
# initialize the process group
|
||||
dist.init_process_group("gloo", rank=rank, world_size=world_size)
|
||||
|
||||
def cleanup():
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
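# Launch sketch for the helpers above (illustrative, not part of the original commit):
# every worker process calls setup(rank, world_size) before wrapping its model in DDP and
# cleanup() afterwards. With the ddp() demo below, the usual entry point would be
# something like
#
#     world_size = torch.cuda.device_count()
#     mp.spawn(ddp, args=(world_size, MyModel()), nprocs=world_size, join=True)
#
# where MyModel is a placeholder nn.Module, not a class defined in this repository.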
def ddp(rank, world_size, model):
|
||||
print(f"Running basic DDP example on rank {rank}.")
|
||||
|
||||
# create model and move it to GPU with id rank
|
||||
model = model.to(rank)
|
||||
ddp_model = DDP(model, device_ids=[rank])
|
||||
|
||||
loss_fn = nn.MSELoss()
|
||||
optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
|
||||
|
||||
optimizer.zero_grad()
|
||||
outputs = ddp_model(torch.randn(20, 10))
|
||||
labels = torch.randn(20, 5).to(rank)
|
||||
loss_fn(outputs, labels).backward()
|
||||
optimizer.step()
|
||||
|
||||
cleanup()
|
||||
|
||||
|
||||
def get_logger(config):
|
||||
config_base = config['base']
|
||||
config_logger = config['logger']
|
||||
logger = MultiLogger(logdir=config_base['runs_dir'],
|
||||
record_mode=config_logger.get('record_mode', None),
|
||||
tag=config_base['tag'],
|
||||
extag=config_base.get('experiment', None),
|
||||
action=config_logger.get('action', 'k')) # backup config.yaml
|
||||
return logger
|
||||
|
||||
class DDPTrainer(AbstractTrainer):
|
||||
def __init__(self, config, tester=None, monitor=None, rank='cuda', world_size=0, logger=None):
|
||||
super().__init__(config, tester, monitor, rank, world_size)
|
||||
self.rank = rank
|
||||
self.logger = logger
|
||||
self.logging_available = (self.rank == 0 or self.rank == 'cuda')
|
||||
print("Running on ", rank)
|
||||
self.global_iteration = 0
|
||||
if self.logging_available:
|
||||
print(f"Logger at Process(rank={rank})")
|
||||
self.recorder = Recorder(reduction=self.recorder_mode)
|
||||
self.recorder_valid = Recorder(reduction=self.recorder_mode)
|
||||
self.recorder_test = Recorder(reduction=self.recorder_mode)
|
||||
self.logger = None
|
||||
self.csvlogger = CSVLogger(tfilename(self.runs_dir, "best_record"))
|
||||
self.csvlogger_all = CSVLogger(tfilename(self.runs_dir, "all_record"))
|
||||
self.monitor = monitor
|
||||
self.tester = tester
|
||||
|
||||
self.logger = get_logger(config)
|
||||
assert self.logger is not None, f"{__file__} Got rank {self.rank}"
|
||||
|
||||
if self.use_amp:
|
||||
self.scalar = GradScaler()
|
||||
print("Debug settings: use amp=",self.use_amp)
|
||||
|
||||
def init_model(self, model, trainset, validset=None, **kwargs):
|
||||
# Use CacheDataset
|
||||
# trainset = CacheDataset(trainset, num_workers=12, cache_rate=0.5)
|
||||
if trainset is not None:
|
||||
assert len(trainset) > 0, f"{__file__} Got {len(trainset)}"
|
||||
self.trainloader = DataLoader(dataset=trainset,
|
||||
batch_size=self.batch_size,
|
||||
num_workers=self.num_workers,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
pin_memory=True)
|
||||
if validset is not None:
|
||||
self.validloader = DataLoader(dataset=validset,
|
||||
batch_size=self.batch_size,
|
||||
num_workers=self.num_workers,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
pin_memory=True)
|
||||
if self.load_pretrain_model:
|
||||
model.module.load()
|
||||
rank = self.rank
|
||||
model = model.to(rank)
|
||||
ddp_model = DDP(model, device_ids=[rank])
|
||||
return ddp_model
|
||||
|
||||
def configure_optim(self, model, **kwargs):
|
||||
# Set optimizer and scheduler
|
||||
optim_configs = model.module.configure_optimizers()
|
||||
assert isinstance(optim_configs, dict)
|
||||
optimizer = optim_configs['optimizer']
|
||||
scheduler = optim_configs['scheduler']
|
||||
|
||||
if self.load_optimizer:
|
||||
start_epoch = model.module.load_optim(optimizer)
|
||||
print(f"[DDPTrainer] Continue training, from epoch {start_epoch}")
|
||||
else:
|
||||
start_epoch = self.start_epoch
|
||||
return optimizer, scheduler, start_epoch
|
||||
|
||||
def fit(self, model, trainset, validset=None):
|
||||
model = self.init_model(model, trainset, validset=validset, rank=self.rank)
|
||||
self.init_timers()
|
||||
optimizer, scheduler, start_epoch = self.configure_optim(model)
|
||||
|
||||
for epoch in range(start_epoch, self.max_epochs):
|
||||
self.on_before_zero_grad()
|
||||
# Training
|
||||
self.timer_epoch()
|
||||
do_training_log = (epoch % self.training_log_interval == 0)
|
||||
if self.validloader is not None and self.validation_interval > 0 and epoch % self.validation_interval == 0:
|
||||
self.valid(model, self.validloader, epoch, do_training_log)
|
||||
|
||||
if self.trainloader is not None:
|
||||
self.train(model, self.trainloader, epoch, optimizer, scheduler, do_training_log)
|
||||
|
||||
if self.logging_available:
|
||||
self.test(model, epoch=epoch)
|
||||
|
||||
if epoch % self.save_interval == 0 and self.logging_available:
|
||||
if 'latest' in self.save_mode:
|
||||
self.save(model, epoch, 'latest', optimizer)
|
||||
if 'all' in self.save_mode:
|
||||
self.save(model, epoch, None, optimizer)
|
||||
# time_save_model = self.timer_epoch()
|
||||
|
||||
print("Training is Over for GPU rank ", self.rank)
|
||||
self.cleanup()
|
||||
|
||||
def test(self, model, epoch):
|
||||
# Evaluation
|
||||
if epoch % self.val_check_interval == 0 and self.logging_available:
|
||||
print("Note: Tester runs on <rank 0> only")
|
||||
if self.tester is not None:
|
||||
out = self.tester.test(model=model, epoch=epoch, rank=self.rank)
|
||||
if self.monitor is not None:
|
||||
best_dict = self.monitor.record(out, epoch)
|
||||
self.recorder_test.record({**best_dict, **out})
|
||||
if best_dict['isbest']:
|
||||
if 'best' in self.save_mode:
|
||||
self.save(model, epoch, type='best')
|
||||
self.csvlogger.record({**best_dict, **out, "time": _get_time_str()})
|
||||
if self.save_all_records:
|
||||
self.csvlogger_all.record({**best_dict, **out, "time": _get_time_str()})
|
||||
self.logger.info(f"\n[*] {dict_to_str(best_dict)}[*] Epoch {epoch}: \n{dict_to_str(out)}")
|
||||
self.logger.add_scalars(out, step=epoch, tag='test')
|
||||
# if ''
|
||||
else:
|
||||
self.logger.info(f"\n[*] Epoch {epoch}: {dict_to_str(out)}")
|
||||
self.on_after_testing(d=out)
|
||||
|
||||
def save(self, model, epoch, type=None, optimizer=None, **kwargs):
|
||||
if self.logging_available:
|
||||
if type is None:
|
||||
# if self.save_interval > 0 and epoch % self.save_interval == 0:
|
||||
save_name = "/ckpt/model_epoch_{}.pth".format(epoch)
|
||||
model.module.save(tfilename(self.runs_dir, save_name), epoch=epoch)
|
||||
self.logger.info(f"Epoch {epoch}: Save model to ``{save_name}``! ")
|
||||
elif type == 'best':
|
||||
# save_name = "/ckpt/best_model_epoch_{}.pth".format(epoch)
|
||||
save_name2 = "/ckpt_v/model_best.pth"
|
||||
# model.save(tfilename(self.runs_dir, save_name), epoch=epoch, is_best=True)
|
||||
model.module.save(tfilename(self.runs_dir, save_name2), epoch=epoch, is_best=True)
|
||||
self.logger.info(f"[Best model] Epoch {epoch}: Save model to ``{save_name2}``! ")
|
||||
elif type == 'latest':
|
||||
if self.save_interval > 0 and epoch % self.save_interval == 0:
|
||||
save_name = "/ckpt_v/model_latest.pth"
|
||||
model.module.save(tfilename(self.runs_dir, save_name), epoch=epoch, is_latest=True)
|
||||
save_optim_name = "/ckpt/optim_latest.pth"
|
||||
model.module.save_optim(tfilename(self.runs_dir, save_optim_name), optimizer=optimizer, epoch=epoch)
|
||||
self.logger.info(f"Epoch {epoch}: Save checkpoint to ``{save_name}``")
|
||||
elif type == "iteration":
|
||||
save_name = "/ckpt/model_iter_{}.pth".format(self.global_iteration)
|
||||
model.module.save(tfilename(self.runs_dir, save_name), epoch=self.global_iteration)
|
||||
self.logger.info(f"Epoch {epoch}: Save model to ``{save_name}``! ")
|
||||
|
||||
|
||||
def train(self, model, trainloader, epoch, optimizer, scheduler=None, do_training_log=True):
|
||||
model.train()
|
||||
out = {}
|
||||
if do_training_log and self.logging_available:
|
||||
self.recorder.clear()
|
||||
time_record = 0.1111
|
||||
self.timer_batch()
|
||||
|
||||
success_count = 0
|
||||
failed_count = 0
|
||||
for load_time, batch_idx, data in tenum(trainloader):
|
||||
optimizer.zero_grad()
|
||||
self.timer_data()
|
||||
# training steps
|
||||
for k, v in data.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
data[k] = v.to(self.rank)
|
||||
time_data_cuda = self.timer_data()
|
||||
if self.use_amp:
|
||||
with autocast():
|
||||
self.timer_net()
|
||||
out = model.module.training_step(data, batch_idx, epoch=epoch)
|
||||
# try:
|
||||
# out = model.module.training_step(data, batch_idx, epoch=epoch)
|
||||
# except Exception as e:
|
||||
# msg = f"Ignore Error! {e}"
|
||||
# if self.logging_available:
|
||||
# self.logger.info(msg)
|
||||
# else:
|
||||
# print(msg)
|
||||
# continue
|
||||
assert isinstance(out, dict)
|
||||
time_fd = self.timer_net()
|
||||
loss = out['loss']
|
||||
self.scalar.scale(loss).backward()
|
||||
self.scalar.step(optimizer)
|
||||
self.scalar.update()
|
||||
time_bp = self.timer_net()
|
||||
else:
|
||||
self.timer_net()
|
||||
try:
|
||||
out = model.module.training_step(data, batch_idx, epoch=epoch)
|
||||
except Exception as e:
|
||||
msg = f"Ignore Error! {e}"
|
||||
if self.logging_available:
|
||||
self.logger.info(msg)
|
||||
else:
|
||||
print(msg)
|
||||
continue
|
||||
if out['loss'] is None:
|
||||
failed_count += 1
|
||||
continue
|
||||
if torch.isnan(out['loss']):
|
||||
print("Ignore Nan Value: ", out['loss'])
|
||||
failed_count += 1
|
||||
# raise ValueError(f"Get loss: {out['loss']}")
|
||||
assert isinstance(out, dict)
|
||||
time_fd = self.timer_net()
|
||||
loss = out['loss']
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
time_bp = self.timer_net()
|
||||
success_count += 1
|
||||
|
||||
time_batch = self.timer_batch()
|
||||
# batch logger !
|
||||
if self.logging_available and do_training_log:
|
||||
out['time_load'] = load_time
|
||||
out['time_cuda'] = time_data_cuda
|
||||
out['time_forward'] = time_fd
|
||||
out['time_bp'] = time_bp
|
||||
out['time_record'] = time_record
|
||||
out['time_batch'] = time_batch
|
||||
self.timer_data()
|
||||
self.recorder.record(out)
|
||||
time_record = self.timer_data()
|
||||
|
||||
# for debug !
|
||||
if epoch == 0:
|
||||
if self.logging_available:
|
||||
self.logger.info("[*] Debug Checking Pipeline !!!")
|
||||
del out
|
||||
return
|
||||
if self.global_iteration % 100 == 0:
|
||||
print(f"Epoch: {epoch} | batch:{batch_idx}/{len(trainloader)}, Iteration:{self.global_iteration}, results: {to_item(out)}", end='\n')
|
||||
if self.global_iteration % 5000 == 0:
|
||||
self.save(model, epoch, "iteration", optimizer)
|
||||
# print("")
|
||||
self.global_iteration += 1
|
||||
|
||||
if scheduler is not None:
|
||||
scheduler.step()
|
||||
|
||||
# epoch logger !
|
||||
if self.logging_available:
|
||||
if do_training_log :
|
||||
_dict = self.recorder.cal_metrics()
|
||||
_dict['time_total'] = self.timer_epoch()
|
||||
|
||||
# print(_dict)
|
||||
# assert isinstance(lr, float), f"Got lr={lr}, type: {type(lr)}"
|
||||
loss_str = ""
|
||||
for k, v in _dict.items():
|
||||
loss_str += "{}:{:.4f} ".format(k, v)
|
||||
# lr = optimizer.param_groups[0]['lr']
|
||||
lr = self.get_lr(optimizer)
|
||||
_dict['lr'] = lr
|
||||
loss_str += "{}:{:.6e} ".format('lr', lr)
|
||||
self.logger.info(f"Epoch {epoch}: {loss_str}")
|
||||
# _dict_with_train_tag = {f"train/{k}":v for k,v in _dict.items()}
|
||||
self.logger.add_scalars(_dict, step=epoch, tag='train')
|
||||
time_log_scalars = self.timer_epoch()
|
||||
self.on_after_training(d=_dict)
|
||||
# Clear
|
||||
del out
|
||||
del data
|
||||
|
||||
def valid(self, model, validloader, epoch, do_training_log=True):
|
||||
model.eval()
|
||||
out = {}
|
||||
if do_training_log and self.logging_available:
|
||||
self.recorder_valid.clear()
|
||||
time_record = 0.1
|
||||
self.timer_batch()
|
||||
|
||||
success_count = 1
|
||||
failed_count = 1
|
||||
for load_time, batch_idx, data in tenum(validloader):
|
||||
# model.on_before_zero_grad()
|
||||
self.timer_data()
|
||||
# training steps
|
||||
for k, v in data.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
data[k] = v.to(self.rank)
|
||||
time_data_cuda = self.timer_data()
|
||||
if self.use_amp:
|
||||
with autocast():
|
||||
self.timer_net()
|
||||
out = model.module.validation_step(data, batch_idx, epoch=epoch)
|
||||
assert isinstance(out, dict)
|
||||
time_fd = self.timer_net()
|
||||
else:
|
||||
self.timer_net()
|
||||
out = model.module.validation_step(data, batch_idx, epoch=epoch)
|
||||
if out['loss'] is None:
|
||||
failed_count += 1
|
||||
continue
|
||||
if torch.isnan(out['loss']):
|
||||
self.logger.info("Nan Value: ", out['loss'])
|
||||
failed_count += 1
|
||||
raise ValueError(f"Get loss: {out['loss']}")
|
||||
assert isinstance(out, dict)
|
||||
time_fd = self.timer_net()
|
||||
success_count += 1
|
||||
|
||||
time_batch = self.timer_batch()
|
||||
# batch logger !
|
||||
if do_training_log and self.logging_available:
|
||||
out['time_load'] = load_time
|
||||
out['time_cuda'] = time_data_cuda
|
||||
out['time_forward'] = time_fd
|
||||
out['time_record'] = time_record
|
||||
out['time_batch'] = time_batch
|
||||
self.timer_data()
|
||||
self.recorder_valid.record(out)
|
||||
time_record = self.timer_data()
|
||||
|
||||
if batch_idx % 2 == 0:
|
||||
print(f"Valid Epoch: {epoch}. Processing batch_idx:{batch_idx} / {len(validloader)}, time_load: {load_time}, results: {to_item(out)}", end='\r')
|
||||
|
||||
if epoch == 0:
|
||||
if self.logging_available:
|
||||
self.logger.info("[*] Debug Checking validation Pipeline !!!")
|
||||
del out
|
||||
return
|
||||
# model.on_after_zero_grad(d=out)
|
||||
if self.logging_available:
|
||||
self.logger.info(f"Training Success Ratio: {success_count / (success_count + failed_count)}")
|
||||
|
||||
# epoch logger !
|
||||
if self.logging_available:
|
||||
if do_training_log :
|
||||
_dict = self.recorder_valid.cal_metrics()
|
||||
_dict['time_total'] = self.timer_epoch()
|
||||
|
||||
# print(_dict)
|
||||
# assert isinstance(lr, float), f"Got lr={lr}, type: {type(lr)}"
|
||||
loss_str = ""
|
||||
for k, v in _dict.items():
|
||||
loss_str += "{}:{:.4f} ".format(k, v)
|
||||
self.logger.info(f"Epoch {epoch}: {loss_str}")
|
||||
# _dict_with_val_tag = {f"val/{k}":v for k,v in _dict.items()}
|
||||
self.logger.add_scalars(_dict, step=epoch, tag='val')
|
||||
time_log_scalars = self.timer_epoch()
|
||||
self.on_after_training(d=_dict)
|
||||
# Clear
|
||||
del out
|
||||
del data
|
||||
|
||||
|
||||
|
||||
def to_item(tensors):
|
||||
for k,v in tensors.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
tensors[k] = v.detach().cpu().item()
|
||||
return tensors
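# A hedged launch sketch (added for illustration): one process per GPU via mp.spawn.
# `build_model`, `build_dataset`, and `config` are placeholders, not names defined
# in this repository.
def _spawn_worker(rank, world_size, config):
    setup(rank, world_size)  # process-group init defined above
    trainer = DDPTrainer(config, rank=rank, world_size=world_size)
    trainer.fit(build_model(config), build_dataset(config, split="train"))

if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    # mp.spawn(_spawn_worker, args=(n_gpus, config), nprocs=n_gpus, join=True)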
|
5
utils/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
BIN
utils/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/__init__.cpython-38.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/amg.cpython-38.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/amg3d.cpython-38.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/masks3d_utils.cpython-38.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/transforms.cpython-310.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/transforms.cpython-38.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/transforms3d.cpython-38.pyc
Normal file
Binary file not shown.
399
utils/amg.py
Normal file
@ -0,0 +1,399 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import math
|
||||
from copy import deepcopy
|
||||
from itertools import product
|
||||
from typing import Any, Dict, Generator, ItemsView, List, Tuple
|
||||
|
||||
|
||||
class MaskData:
|
||||
"""
|
||||
A structure for storing masks and their related data in batched format.
|
||||
Implements basic filtering and concatenation.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
for v in kwargs.values():
|
||||
assert isinstance(
|
||||
v, (list, np.ndarray, torch.Tensor)
|
||||
), "MaskData only supports list, numpy arrays, and torch tensors."
|
||||
self._stats = dict(**kwargs)
|
||||
|
||||
def __setitem__(self, key: str, item: Any) -> None:
|
||||
assert isinstance(
|
||||
item, (list, np.ndarray, torch.Tensor)
|
||||
), "MaskData only supports list, numpy arrays, and torch tensors."
|
||||
self._stats[key] = item
|
||||
|
||||
def __delitem__(self, key: str) -> None:
|
||||
del self._stats[key]
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
return self._stats[key]
|
||||
|
||||
def items(self) -> ItemsView[str, Any]:
|
||||
return self._stats.items()
|
||||
|
||||
def filter(self, keep: torch.Tensor) -> None:
|
||||
for k, v in self._stats.items():
|
||||
if v is None:
|
||||
self._stats[k] = None
|
||||
elif isinstance(v, torch.Tensor):
|
||||
self._stats[k] = v[torch.as_tensor(keep, device=v.device)]
|
||||
elif isinstance(v, np.ndarray):
|
||||
self._stats[k] = v[keep.detach().cpu().numpy()]
|
||||
elif isinstance(v, list) and keep.dtype == torch.bool:
|
||||
self._stats[k] = [a for i, a in enumerate(v) if keep[i]]
|
||||
elif isinstance(v, list):
|
||||
self._stats[k] = [v[i] for i in keep]
|
||||
else:
|
||||
raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
|
||||
|
||||
def cat(self, new_stats: "MaskData") -> None:
|
||||
for k, v in new_stats.items():
|
||||
if k not in self._stats or self._stats[k] is None:
|
||||
self._stats[k] = deepcopy(v)
|
||||
elif isinstance(v, torch.Tensor):
|
||||
if k == "boxes" and len(v.shape) == 1:
|
||||
v = v[None,:]
|
||||
if k == "boxes" and len(self._stats[k].shape) == 1:
|
||||
self._stats[k] = self._stats[k][None, :]
|
||||
self._stats[k] = torch.cat([self._stats[k], v], dim=0)
|
||||
elif isinstance(v, np.ndarray):
|
||||
self._stats[k] = np.concatenate([self._stats[k], v], axis=0)
|
||||
elif isinstance(v, list):
|
||||
self._stats[k] = self._stats[k] + deepcopy(v)
|
||||
else:
|
||||
raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
|
||||
|
||||
def to_numpy(self) -> None:
|
||||
for k, v in self._stats.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
self._stats[k] = v.detach().cpu().numpy()
|
||||
|
||||
|
||||
def is_box_near_crop_edge(
|
||||
boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
|
||||
) -> torch.Tensor:
|
||||
"""Filter masks at the edge of a crop, but not at the edge of the original image."""
|
||||
crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
|
||||
orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
|
||||
boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
|
||||
near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
|
||||
near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
|
||||
near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
|
||||
return torch.any(near_crop_edge, dim=1)
|
||||
|
||||
|
||||
def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor:
|
||||
box_xywh = deepcopy(box_xyxy)
|
||||
box_xywh[2] = box_xywh[2] - box_xywh[0]
|
||||
box_xywh[3] = box_xywh[3] - box_xywh[1]
|
||||
return box_xywh
|
||||
|
||||
|
||||
def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
|
||||
assert len(args) > 0 and all(
|
||||
len(a) == len(args[0]) for a in args
|
||||
), "Batched iteration must have inputs of all the same size."
|
||||
n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
|
||||
for b in range(n_batches):
|
||||
yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
|
||||
|
||||
|
||||
def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Encodes masks to an uncompressed RLE, in the format expected by
|
||||
pycoco tools.
|
||||
"""
|
||||
# Put in fortran order and flatten h,w
|
||||
b, h, w = tensor.shape
|
||||
tensor = tensor.permute(0, 2, 1).flatten(1)
|
||||
|
||||
# Compute change indices
|
||||
diff = tensor[:, 1:] ^ tensor[:, :-1]
|
||||
change_indices = diff.nonzero()
|
||||
|
||||
# Encode run length
|
||||
out = []
|
||||
for i in range(b):
|
||||
cur_idxs = change_indices[change_indices[:, 0] == i, 1]
|
||||
cur_idxs = torch.cat(
|
||||
[
|
||||
torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device),
|
||||
cur_idxs + 1,
|
||||
torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device),
|
||||
]
|
||||
)
|
||||
btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
|
||||
counts = [] if tensor[i, 0] == 0 else [0]
|
||||
counts.extend(btw_idxs.detach().cpu().tolist())
|
||||
out.append({"size": [h, w], "counts": counts})
|
||||
return out
|
||||
|
||||
|
||||
def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
|
||||
"""Compute a binary mask from an uncompressed RLE."""
|
||||
h, w = rle["size"]
|
||||
mask = np.empty(h * w, dtype=bool)
|
||||
idx = 0
|
||||
parity = False
|
||||
for count in rle["counts"]:
|
||||
mask[idx : idx + count] = parity
|
||||
idx += count
|
||||
parity ^= True
|
||||
mask = mask.reshape(w, h)
|
||||
return mask.transpose() # Put in C order
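# Quick round-trip check for the RLE helpers above (illustrative addition, not
# part of the original file).
if __name__ == "__main__":
    m = torch.zeros(1, 4, 6, dtype=torch.bool)
    m[0, 1:3, 2:5] = True
    rles = mask_to_rle_pytorch(m)
    recovered = rle_to_mask(rles[0])
    assert (recovered == m[0].numpy()).all()  # encoding and decoding are inverse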
|
||||
|
||||
|
||||
def area_from_rle(rle: Dict[str, Any]) -> int:
|
||||
return sum(rle["counts"][1::2])
|
||||
|
||||
|
||||
def calculate_stability_score(
|
||||
masks: torch.Tensor, mask_threshold: float, threshold_offset: float
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Computes the stability score for a batch of masks. The stability
|
||||
score is the IoU between the binary masks obtained by thresholding
|
||||
the predicted mask logits at high and low values.
|
||||
"""
|
||||
# One mask is always contained inside the other.
|
||||
# Save memory by preventing unnecessary cast to torch.int64
|
||||
intersections = (
|
||||
(masks > (mask_threshold + threshold_offset))
|
||||
.sum(-1, dtype=torch.int16)
|
||||
.sum(-1, dtype=torch.int32)
|
||||
)
|
||||
unions = (
|
||||
(masks > (mask_threshold - threshold_offset))
|
||||
.sum(-1, dtype=torch.int16)
|
||||
.sum(-1, dtype=torch.int32)
|
||||
)
|
||||
return intersections / unions
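# Illustrative example (added): a logit map that stays positive under +/- the offset
# is "stable" (score near 1), while values close to the threshold lower the score.
if __name__ == "__main__":
    logits = torch.full((1, 8, 8), 5.0)
    logits[0, :, :4] = 0.5  # half the mask sits just above the threshold
    print(calculate_stability_score(logits, mask_threshold=0.0, threshold_offset=1.0))  # tensor([0.5])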
|
||||
|
||||
|
||||
def build_point_grid(n_per_side: int) -> np.ndarray:
|
||||
"""Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
|
||||
offset = 1 / (2 * n_per_side)
|
||||
points_one_side = np.linspace(offset, 1 - offset, n_per_side)
|
||||
points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
|
||||
points_y = np.tile(points_one_side[:, None], (1, n_per_side))
|
||||
points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
|
||||
return points
|
||||
|
||||
|
||||
def build_all_layer_point_grids(
|
||||
n_per_side: int, n_layers: int, scale_per_layer: int
|
||||
) -> List[np.ndarray]:
|
||||
"""Generates point grids for all crop layers."""
|
||||
points_by_layer = []
|
||||
for i in range(n_layers + 1):
|
||||
n_points = int(n_per_side / (scale_per_layer**i))
|
||||
points_by_layer.append(build_point_grid(n_points))
|
||||
return points_by_layer
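# Illustrative check (added): 32 points per side gives a 1024x2 array of (x, y)
# coordinates in [0, 1], one grid per crop layer.
if __name__ == "__main__":
    grids = build_all_layer_point_grids(n_per_side=32, n_layers=1, scale_per_layer=2)
    print(len(grids), grids[0].shape, grids[1].shape)  # 2 (1024, 2) (256, 2)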
|
||||
|
||||
|
||||
def generate_crop_boxes(
|
||||
im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
|
||||
) -> Tuple[List[List[int]], List[int]]:
|
||||
"""
|
||||
Generates a list of crop boxes of different sizes. Each layer
|
||||
has (2**i)**2 boxes for the ith layer.
|
||||
"""
|
||||
crop_boxes, layer_idxs = [], []
|
||||
im_h, im_w = im_size
|
||||
short_side = min(im_h, im_w)
|
||||
|
||||
# Original image
|
||||
crop_boxes.append([0, 0, im_w, im_h])
|
||||
layer_idxs.append(0)
|
||||
|
||||
def crop_len(orig_len, n_crops, overlap):
|
||||
return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))
|
||||
|
||||
for i_layer in range(n_layers):
|
||||
n_crops_per_side = 2 ** (i_layer + 1)
|
||||
overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
|
||||
|
||||
crop_w = crop_len(im_w, n_crops_per_side, overlap)
|
||||
crop_h = crop_len(im_h, n_crops_per_side, overlap)
|
||||
|
||||
crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
|
||||
crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]
|
||||
|
||||
# Crops in XYWH format
|
||||
for x0, y0 in product(crop_box_x0, crop_box_y0):
|
||||
box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
|
||||
crop_boxes.append(box)
|
||||
layer_idxs.append(i_layer + 1)
|
||||
|
||||
return crop_boxes, layer_idxs
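# Illustrative check (added): layer i contributes (2**i)**2 crops plus the full image
# at layer 0, so n_layers=2 on a 512x768 image yields 1 + 4 + 16 = 21 boxes.
if __name__ == "__main__":
    boxes, layers = generate_crop_boxes((512, 768), n_layers=2, overlap_ratio=512 / 1500)
    print(len(boxes), layers.count(0), layers.count(1), layers.count(2))  # 21 1 4 16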
|
||||
|
||||
|
||||
def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
|
||||
x0, y0, _, _ = crop_box
|
||||
offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
|
||||
# Check if boxes has a channel dimension
|
||||
if len(boxes.shape) == 3:
|
||||
offset = offset.unsqueeze(1)
|
||||
return boxes + offset
|
||||
|
||||
|
||||
def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
|
||||
x0, y0, _, _ = crop_box
|
||||
offset = torch.tensor([[x0, y0]], device=points.device)
|
||||
# Check if points has a channel dimension
|
||||
if len(points.shape) == 3:
|
||||
offset = offset.unsqueeze(1)
|
||||
return points + offset
|
||||
|
||||
|
||||
def uncrop_masks(
|
||||
masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int
|
||||
) -> torch.Tensor:
|
||||
x0, y0, x1, y1 = crop_box
|
||||
if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
|
||||
return masks
|
||||
# Coordinate transform masks
|
||||
pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
|
||||
pad = (x0, pad_x - x0, y0, pad_y - y0)
|
||||
return torch.nn.functional.pad(masks, pad, value=0)
|
||||
|
||||
|
||||
def remove_small_regions(
|
||||
mask: np.ndarray, area_thresh: float, mode: str
|
||||
) -> Tuple[np.ndarray, bool]:
|
||||
"""
|
||||
Removes small disconnected regions and holes in a mask. Returns the
|
||||
mask and an indicator of if the mask has been modified.
|
||||
"""
|
||||
import cv2 # type: ignore
|
||||
|
||||
assert mode in ["holes", "islands"]
|
||||
correct_holes = mode == "holes"
|
||||
working_mask = (correct_holes ^ mask).astype(np.uint8)
|
||||
n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
|
||||
sizes = stats[:, -1][1:] # Row 0 is background label
|
||||
small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
|
||||
if len(small_regions) == 0:
|
||||
return mask, False
|
||||
fill_labels = [0] + small_regions
|
||||
if not correct_holes:
|
||||
fill_labels = [i for i in range(n_labels) if i not in fill_labels]
|
||||
# If every region is below threshold, keep largest
|
||||
if len(fill_labels) == 0:
|
||||
fill_labels = [int(np.argmax(sizes)) + 1]
|
||||
mask = np.isin(regions, fill_labels)
|
||||
return mask, True
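# Illustrative use of remove_small_regions (added): drop isolated specks ("islands")
# below the area threshold while keeping the main component. Requires OpenCV.
if __name__ == "__main__":
    demo = np.zeros((32, 32), dtype=bool)
    demo[4:20, 4:20] = True   # large component (kept)
    demo[28, 28] = True       # 1-pixel speck (removed)
    cleaned, changed = remove_small_regions(demo, area_thresh=10, mode="islands")
    print(changed, cleaned.sum())  # True 256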
|
||||
|
||||
|
||||
def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]:
|
||||
from pycocotools import mask as mask_utils # type: ignore
|
||||
|
||||
h, w = uncompressed_rle["size"]
|
||||
rle = mask_utils.frPyObjects(uncompressed_rle, h, w)
|
||||
rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json
|
||||
return rle
|
||||
|
||||
|
||||
def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Calculates boxes in XYXY format around masks. Return [0,0,0,0] for
|
||||
an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
|
||||
"""
|
||||
# torch.max below raises an error on empty inputs, just skip in this case
|
||||
if torch.numel(masks) == 0:
|
||||
return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
|
||||
|
||||
# Normalize shape to CxHxW
|
||||
masks = masks.squeeze()
|
||||
assert len(masks.shape) <= 3
|
||||
shape = masks.shape
|
||||
h, w = shape[-2:]
|
||||
if len(shape) > 2:
|
||||
masks = masks.flatten(0, -3)
|
||||
else:
|
||||
masks = masks.unsqueeze(0)
|
||||
|
||||
# Get top and bottom edges
|
||||
in_height, _ = torch.max(masks, dim=-1)
|
||||
in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]
|
||||
bottom_edges, _ = torch.max(in_height_coords, dim=-1)
|
||||
in_height_coords = in_height_coords + h * (~in_height)
|
||||
top_edges, _ = torch.min(in_height_coords, dim=-1)
|
||||
|
||||
# Get left and right edges
|
||||
in_width, _ = torch.max(masks, dim=-2)
|
||||
in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]
|
||||
right_edges, _ = torch.max(in_width_coords, dim=-1)
|
||||
in_width_coords = in_width_coords + w * (~in_width)
|
||||
left_edges, _ = torch.min(in_width_coords, dim=-1)
|
||||
|
||||
# If the mask is empty the right edge will be to the left of the left edge.
|
||||
# Replace these boxes with [0, 0, 0, 0]
|
||||
empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
|
||||
out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
|
||||
out = out * (~empty_filter).unsqueeze(-1)
|
||||
|
||||
# Return to original shape
|
||||
if len(shape) > 2:
|
||||
out = out.reshape(*shape[:-2], 4)
|
||||
else:
|
||||
out = out[0]
|
||||
|
||||
return out
|
||||
|
||||
def batched_mask3d_to_box(masks: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Calculates boxes in XYXY format around masks. Return [0,0,0,0] for
|
||||
an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
|
||||
"""
|
||||
# torch.max below raises an error on empty inputs, just skip in this case
|
||||
if torch.numel(masks) == 0:
|
||||
return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
|
||||
|
||||
# Normalize shape to CxHxW
|
||||
assert masks.shape[1] == 3
|
||||
masks = masks[:,1,:,:]
|
||||
shape = masks.shape
|
||||
h, w = shape[-2:]
|
||||
if len(shape) > 2:
|
||||
masks = masks.flatten(0, -3)
|
||||
else:
|
||||
masks = masks.unsqueeze(0)
|
||||
|
||||
# Get top and bottom edges
|
||||
in_height, _ = torch.max(masks, dim=-1)
|
||||
in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]
|
||||
bottom_edges, _ = torch.max(in_height_coords, dim=-1)
|
||||
in_height_coords = in_height_coords + h * (~in_height)
|
||||
top_edges, _ = torch.min(in_height_coords, dim=-1)
|
||||
|
||||
# Get left and right edges
|
||||
in_width, _ = torch.max(masks, dim=-2)
|
||||
in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]
|
||||
right_edges, _ = torch.max(in_width_coords, dim=-1)
|
||||
in_width_coords = in_width_coords + w * (~in_width)
|
||||
left_edges, _ = torch.min(in_width_coords, dim=-1)
|
||||
|
||||
# If the mask is empty the right edge will be to the left of the left edge.
|
||||
# Replace these boxes with [0, 0, 0, 0]
|
||||
empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
|
||||
out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
|
||||
out = out * (~empty_filter).unsqueeze(-1)
|
||||
|
||||
# Return to original shape
|
||||
if len(shape) > 2:
|
||||
out = out.reshape(*shape[:-2], 4)
|
||||
else:
|
||||
out = out[0]
|
||||
|
||||
return out
|
204
utils/amg3d.py
Normal file
@ -0,0 +1,204 @@
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import math
|
||||
from copy import deepcopy
|
||||
from itertools import product
|
||||
from typing import Any, Dict, Generator, ItemsView, List, Tuple
|
||||
|
||||
|
||||
class MaskData3d:
|
||||
def __init__(self, size, **kwargs) -> None:
|
||||
self.size = size
|
||||
for v in kwargs.values():
|
||||
assert isinstance(
|
||||
v, (list, np.ndarray, torch.Tensor)
|
||||
), "MaskData only supports list, numpy arrays, and torch tensors."
|
||||
self._stats = dict(**kwargs)
|
||||
|
||||
def __setitem__(self, slice_id, num, item):
|
||||
assert isinstance(
|
||||
item, (list, np.ndarray, torch.Tensor)
|
||||
), "MaskData only supports list, numpy arrays, and torch tensors."
|
||||
key = str(slice_id) + str(num)
|
||||
if self._stats.get(key, None) is None:
|
||||
self._stats[key] = np.zeros(self.size)
|
||||
self._stats[key][slice_id-1:slice_id+2] = item
|
||||
|
||||
def __delitem__(self, key: str) -> None:
|
||||
del self._stats[key]
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
return self._stats[key]
|
||||
|
||||
def items(self) -> ItemsView[str, Any]:
|
||||
return self._stats.items()
|
||||
|
||||
def merge(self, slice_ids, num, item):
|
||||
pass
|
||||
|
||||
|
||||
def build_all_layer_point_grids(
|
||||
n_per_side: int = 32, n_layers: int = 0, scale_per_layer: int = 1, size=(1024, 1024)) -> List[np.ndarray]:  # size added because the local build_point_grid requires it; the default is an assumption
|
||||
"""Generates point grids for all crop layers."""
|
||||
points_by_layer = []
|
||||
for i in range(n_layers + 1):
|
||||
n_points = int(n_per_side / (scale_per_layer**i))
|
||||
points_by_layer.append(build_point_grid(n_points, size))
|
||||
return points_by_layer
|
||||
|
||||
|
||||
def build_point_grid(n_per_side: int, size) -> np.ndarray:
|
||||
"""Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
|
||||
offset = 1 / (2 * n_per_side)
|
||||
points_one_side = np.linspace(offset, 1 - offset, n_per_side)
|
||||
points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
|
||||
points_y = np.tile(points_one_side[:, None], (1, n_per_side))
|
||||
points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
|
||||
return points * np.array(size)
|
||||
|
||||
# def calculate_stability_score(
|
||||
# masks: torch.Tensor, mask_threshold: float, threshold_offset: float
|
||||
# ) -> torch.Tensor:
|
||||
# """
|
||||
# Computes the stability score for a batch of masks. The stability
|
||||
# score is the IoU between the binary masks obtained by thresholding
|
||||
# the predicted mask logits at high and low values.
|
||||
# """
|
||||
# # One mask is always contained inside the other.
|
||||
# # Save memory by preventing unnecessary cast to torch.int64
|
||||
# intersections = (
|
||||
# (masks > (mask_threshold + threshold_offset))
|
||||
# .sum(-1, dtype=torch.int16)
|
||||
# .sum(-1, dtype=torch.int32)
|
||||
# )
|
||||
# unions = (
|
||||
# (masks > (mask_threshold - threshold_offset))
|
||||
# .sum(-1, dtype=torch.int16)
|
||||
# .sum(-1, dtype=torch.int32)
|
||||
# )
|
||||
# return intersections / unions
|
||||
|
||||
|
||||
def calculate_stability_score_3d(
|
||||
masks: torch.Tensor, mask_threshold: float, threshold_offset: float
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Computes the stability score for a batch of masks. The stability
|
||||
score is the IoU between the binary masks obtained by thresholding
|
||||
the predicted mask logits at high and low values.
|
||||
"""
|
||||
# One mask is always contained inside the other.
|
||||
# Save memory by preventing unnecessary cast to torch.int64
|
||||
intersections = (
|
||||
(masks > (mask_threshold + threshold_offset))
|
||||
.sum(-1, dtype=torch.int16)
|
||||
.sum(-1, dtype=torch.int32)
|
||||
.sum(-1, dtype=torch.int32)
|
||||
)
|
||||
# intersections = intersections2d.sum(-1, dtype=torch.int32)
|
||||
unions = (
|
||||
(masks > (mask_threshold - threshold_offset))
|
||||
.sum(-1, dtype=torch.int16)
|
||||
.sum(-1, dtype=torch.int32)
|
||||
.sum(-1, dtype=torch.int32)
|
||||
)
|
||||
return (intersections / unions)
|
||||
|
||||
|
||||
def remove_small_regions(
|
||||
mask: np.ndarray, area_thresh: float, mode: str
|
||||
) -> Tuple[np.ndarray, bool]:
|
||||
"""
|
||||
Removes small disconnected regions and holes in a mask. Returns the
|
||||
mask and an indicator of if the mask has been modified.
|
||||
"""
|
||||
import cv2 # type: ignore
|
||||
|
||||
assert mode in ["holes", "islands"]
|
||||
correct_holes = mode == "holes"
|
||||
working_mask = (correct_holes ^ mask).astype(np.uint8)
|
||||
n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
|
||||
sizes = stats[:, -1][1:] # Row 0 is background label
|
||||
small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
|
||||
if len(small_regions) == 0:
|
||||
return mask, False
|
||||
fill_labels = [0] + small_regions
|
||||
if not correct_holes:
|
||||
fill_labels = [i for i in range(n_labels) if i not in fill_labels]
|
||||
# If every region is below threshold, keep largest
|
||||
if len(fill_labels) == 0:
|
||||
fill_labels = [int(np.argmax(sizes)) + 1]
|
||||
mask = np.isin(regions, fill_labels)
|
||||
return mask, True
|
||||
|
||||
def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
|
||||
assert len(args) > 0 and all(
|
||||
len(a) == len(args[0]) for a in args
|
||||
), "Batched iteration must have inputs of all the same size."
|
||||
n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
|
||||
for b in range(n_batches):
|
||||
yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
|
||||
|
||||
|
||||
class MaskData:
|
||||
"""
|
||||
A structure for storing masks and their related data in batched format.
|
||||
Implements basic filtering and concatenation.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
for v in kwargs.values():
|
||||
assert isinstance(
|
||||
v, (list, np.ndarray, torch.Tensor)
|
||||
), "MaskData only supports list, numpy arrays, and torch tensors."
|
||||
self._stats = dict(**kwargs)
|
||||
|
||||
def __setitem__(self, key: str, item: Any) -> None:
|
||||
assert isinstance(
|
||||
item, (list, np.ndarray, torch.Tensor)
|
||||
), "MaskData only supports list, numpy arrays, and torch tensors."
|
||||
self._stats[key] = item
|
||||
|
||||
def __delitem__(self, key: str) -> None:
|
||||
del self._stats[key]
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
return self._stats[key]
|
||||
|
||||
def items(self) -> ItemsView[str, Any]:
|
||||
return self._stats.items()
|
||||
|
||||
def filter(self, keep: torch.Tensor) -> None:
|
||||
for k, v in self._stats.items():
|
||||
if v is None:
|
||||
self._stats[k] = None
|
||||
elif isinstance(v, torch.Tensor):
|
||||
self._stats[k] = v[torch.as_tensor(keep, device=v.device)]
|
||||
elif isinstance(v, np.ndarray):
|
||||
self._stats[k] = v[keep.detach().cpu().numpy()]
|
||||
elif isinstance(v, list) and keep.dtype == torch.bool:
|
||||
self._stats[k] = [a for i, a in enumerate(v) if keep[i]]
|
||||
elif isinstance(v, list):
|
||||
self._stats[k] = [v[i] for i in keep]
|
||||
else:
|
||||
raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
|
||||
|
||||
def cat(self, new_stats: "MaskData") -> None:
|
||||
for k, v in new_stats.items():
|
||||
if k not in self._stats or self._stats[k] is None:
|
||||
self._stats[k] = deepcopy(v)
|
||||
elif isinstance(v, torch.Tensor):
|
||||
self._stats[k] = torch.cat([self._stats[k], v], dim=0)
|
||||
elif isinstance(v, np.ndarray):
|
||||
self._stats[k] = np.concatenate([self._stats[k], v], axis=0)
|
||||
elif isinstance(v, list):
|
||||
self._stats[k] = self._stats[k] + deepcopy(v)
|
||||
else:
|
||||
raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
|
||||
|
||||
def to_numpy(self) -> None:
|
||||
for k, v in self._stats.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
self._stats[k] = v.detach().cpu().numpy()
|
59
utils/masks3d_utils.py
Normal file
@ -0,0 +1,59 @@
|
||||
from tutils.nn.data import write, np_to_itk
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Masks3D:
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def from_dict(self, masks):
|
||||
self.masks = masks
|
||||
# self.tags = masks
|
||||
|
||||
def to_2dmasks(self):
|
||||
pass
|
||||
|
||||
def filter_by_bg(self, volume, threshold=None):
|
||||
if threshold is None:
    threshold = (volume.max() - volume.min()) * 0.1 + volume.min()
keys = list(self.masks.keys())  # copy the keys; entries are popped while iterating
|
||||
for k in keys:
|
||||
v = self.masks[k]
|
||||
assert v.shape == volume.shape, f"Got shape ERROR, {v.shape, volume.shape}"
|
||||
if (v * volume).mean() <= threshold:
|
||||
self.masks.pop(k)
|
||||
|
||||
# def filter_by_area(self,):
|
||||
|
||||
|
||||
def sort_by_logits(self):
|
||||
self.confidences = []
|
||||
self.tags_by_conf = []
|
||||
for k, v in self.masks.items():
|
||||
confidence = v[v>0].mean()
|
||||
self.confidences.append(confidence)
|
||||
self.tags_by_conf.append(k)
|
||||
indices = np.argsort(self.confidences)[::-1]
|
||||
self.tags_by_conf = np.array(self.tags_by_conf)[indices].tolist()
|
||||
self.confidences = np.array(self.confidences)[indices]
|
||||
|
||||
def to_nii(self, path="tmp.nii.gz"):
|
||||
self.sort_by_logits()
|
||||
total = None
|
||||
for i, k in enumerate(self.tags_by_conf):
|
||||
mask = np.int32(self.masks[k]>0)
|
||||
if total is None:
|
||||
total = mask * (i + 1)  # label from 1 so the highest-confidence mask is not written as background
else:
total = np.where(total>0, total, mask * (i + 1))
|
||||
mask_itk = np_to_itk(total)
|
||||
write(mask_itk, path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
p = "/home1/quanquan/code/projects/finetune_large/segment_anything/tmp.npy"
|
||||
data = np.load(p, allow_pickle=True).tolist()
|
||||
# import ipdb; ipdb.set_trace()
|
||||
print(data.keys())
|
||||
mm = Masks3D()
|
||||
mm.from_dict(data)
|
||||
mm.to_nii()
|
144
utils/onnx.py
Normal file
@ -0,0 +1,144 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
from ..modeling import Sam
|
||||
from .amg import calculate_stability_score
|
||||
|
||||
|
||||
class SamOnnxModel(nn.Module):
|
||||
"""
|
||||
This model should not be called directly, but is used in ONNX export.
|
||||
It combines the prompt encoder, mask decoder, and mask postprocessing of Sam,
|
||||
with some functions modified to enable model tracing. Also supports extra
|
||||
options controlling what information is returned. See the ONNX export script for details.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Sam,
|
||||
return_single_mask: bool,
|
||||
use_stability_score: bool = False,
|
||||
return_extra_metrics: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.mask_decoder = model.mask_decoder
|
||||
self.model = model
|
||||
self.img_size = model.image_encoder.img_size
|
||||
self.return_single_mask = return_single_mask
|
||||
self.use_stability_score = use_stability_score
|
||||
self.stability_score_offset = 1.0
|
||||
self.return_extra_metrics = return_extra_metrics
|
||||
|
||||
@staticmethod
|
||||
def resize_longest_image_size(
|
||||
input_image_size: torch.Tensor, longest_side: int
|
||||
) -> torch.Tensor:
|
||||
input_image_size = input_image_size.to(torch.float32)
|
||||
scale = longest_side / torch.max(input_image_size)
|
||||
transformed_size = scale * input_image_size
|
||||
transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
|
||||
return transformed_size
|
||||
|
||||
def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
|
||||
point_coords = point_coords + 0.5
|
||||
point_coords = point_coords / self.img_size
|
||||
point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords)
|
||||
point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)
|
||||
|
||||
point_embedding = point_embedding * (point_labels != -1)
|
||||
point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * (
|
||||
point_labels == -1
|
||||
)
|
||||
|
||||
for i in range(self.model.prompt_encoder.num_point_embeddings):
|
||||
point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[
|
||||
i
|
||||
].weight * (point_labels == i)
|
||||
|
||||
return point_embedding
|
||||
|
||||
def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor:
|
||||
mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask)
|
||||
mask_embedding = mask_embedding + (
|
||||
1 - has_mask_input
|
||||
) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
|
||||
return mask_embedding
|
||||
|
||||
def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor:
|
||||
masks = F.interpolate(
|
||||
masks,
|
||||
size=(self.img_size, self.img_size),
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
)
|
||||
|
||||
prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64)
|
||||
masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore
|
||||
|
||||
orig_im_size = orig_im_size.to(torch.int64)
|
||||
h, w = orig_im_size[0], orig_im_size[1]
|
||||
masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
|
||||
return masks
|
||||
|
||||
def select_masks(
|
||||
self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Determine if we should return the multiclick mask or not from the number of points.
|
||||
# The reweighting is used to avoid control flow.
|
||||
score_reweight = torch.tensor(
|
||||
[[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]
|
||||
).to(iou_preds.device)
|
||||
score = iou_preds + (num_points - 2.5) * score_reweight
|
||||
best_idx = torch.argmax(score, dim=1)
|
||||
masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
|
||||
iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)
|
||||
|
||||
return masks, iou_preds
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
image_embeddings: torch.Tensor,
|
||||
point_coords: torch.Tensor,
|
||||
point_labels: torch.Tensor,
|
||||
mask_input: torch.Tensor,
|
||||
has_mask_input: torch.Tensor,
|
||||
orig_im_size: torch.Tensor,
|
||||
):
|
||||
sparse_embedding = self._embed_points(point_coords, point_labels)
|
||||
dense_embedding = self._embed_masks(mask_input, has_mask_input)
|
||||
|
||||
masks, scores = self.model.mask_decoder.predict_masks(
|
||||
image_embeddings=image_embeddings,
|
||||
image_pe=self.model.prompt_encoder.get_dense_pe(),
|
||||
sparse_prompt_embeddings=sparse_embedding,
|
||||
dense_prompt_embeddings=dense_embedding,
|
||||
)
|
||||
|
||||
if self.use_stability_score:
|
||||
scores = calculate_stability_score(
|
||||
masks, self.model.mask_threshold, self.stability_score_offset
|
||||
)
|
||||
|
||||
if self.return_single_mask:
|
||||
masks, scores = self.select_masks(masks, scores, point_coords.shape[1])
|
||||
|
||||
upscaled_masks = self.mask_postprocessing(masks, orig_im_size)
|
||||
|
||||
if self.return_extra_metrics:
|
||||
stability_scores = calculate_stability_score(
|
||||
upscaled_masks, self.model.mask_threshold, self.stability_score_offset
|
||||
)
|
||||
areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
|
||||
return upscaled_masks, scores, stability_scores, areas, masks
|
||||
|
||||
return upscaled_masks, scores, masks
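# A hedged export sketch (added for illustration). This module uses relative imports,
# so the snippet is meant for a separate export script at the repo root; the registry
# call, checkpoint handling, and tensor shapes below follow the ViT-B SAM convention
# (1x256x64x64 image embeddings, 1x1x256x256 low-res mask input) and are assumptions here.
if __name__ == "__main__":
    from modeling.build_sam3d2 import sam_model_registry  # builder used elsewhere in this repo
    sam = sam_model_registry["vit_b"](checkpoint=None)     # assumed to build without weights
    onnx_model = SamOnnxModel(sam, return_single_mask=True)
    dummy_inputs = {
        "image_embeddings": torch.randn(1, 256, 64, 64),
        "point_coords": torch.randint(0, 1024, (1, 5, 2), dtype=torch.float),
        "point_labels": torch.randint(0, 4, (1, 5), dtype=torch.float),
        "mask_input": torch.randn(1, 1, 256, 256),
        "has_mask_input": torch.tensor([1], dtype=torch.float),
        "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
    }
    torch.onnx.export(
        onnx_model,
        tuple(dummy_inputs.values()),
        "sam_example.onnx",
        input_names=list(dummy_inputs.keys()),
        output_names=["masks", "iou_predictions", "low_res_masks"],
        opset_version=17,
    )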
|
102
utils/transforms.py
Normal file
@ -0,0 +1,102 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
from torchvision.transforms.functional import resize, to_pil_image # type: ignore
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
class ResizeLongestSide:
|
||||
"""
|
||||
Resizes images to the longest side 'target_length', as well as provides
|
||||
methods for resizing coordinates and boxes. Provides methods for
|
||||
transforming both numpy array and batched torch tensors.
|
||||
"""
|
||||
|
||||
def __init__(self, target_length: int) -> None:
|
||||
self.target_length = target_length
|
||||
|
||||
def apply_image(self, image: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Expects a numpy array with shape HxWxC in uint8 format.
|
||||
"""
|
||||
target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
|
||||
return np.array(resize(to_pil_image(image), target_size))
|
||||
|
||||
def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
|
||||
"""
|
||||
Expects a numpy array of length 2 in the final dimension. Requires the
|
||||
original image size in (H, W) format.
|
||||
"""
|
||||
old_h, old_w = original_size
|
||||
new_h, new_w = self.get_preprocess_shape(
|
||||
original_size[0], original_size[1], self.target_length
|
||||
)
|
||||
coords = deepcopy(coords).astype(float)
|
||||
coords[..., 0] = coords[..., 0] * (new_w / old_w)
|
||||
coords[..., 1] = coords[..., 1] * (new_h / old_h)
|
||||
return coords
|
||||
|
||||
def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
|
||||
"""
|
||||
Expects a numpy array shape Bx4. Requires the original image size
|
||||
in (H, W) format.
|
||||
"""
|
||||
boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
|
||||
return boxes.reshape(-1, 4)
|
||||
|
||||
def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Expects batched images with shape BxCxHxW and float format. This
|
||||
transformation may not exactly match apply_image. apply_image is
|
||||
the transformation expected by the model.
|
||||
"""
|
||||
# Expects an image in BCHW format. May not exactly match apply_image.
|
||||
target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length)
|
||||
return F.interpolate(
|
||||
image, target_size, mode="bilinear", align_corners=False, antialias=True
|
||||
)
|
||||
|
||||
def apply_coords_torch(
|
||||
self, coords: torch.Tensor, original_size: Tuple[int, ...]
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Expects a torch tensor with length 2 in the last dimension. Requires the
|
||||
original image size in (H, W) format.
|
||||
"""
|
||||
old_h, old_w = original_size
|
||||
new_h, new_w = self.get_preprocess_shape(
|
||||
original_size[0], original_size[1], self.target_length
|
||||
)
|
||||
coords = deepcopy(coords).to(torch.float)
|
||||
coords[..., 0] = coords[..., 0] * (new_w / old_w)
|
||||
coords[..., 1] = coords[..., 1] * (new_h / old_h)
|
||||
return coords
|
||||
|
||||
def apply_boxes_torch(
|
||||
self, boxes: torch.Tensor, original_size: Tuple[int, ...]
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Expects a torch tensor with shape Bx4. Requires the original image
|
||||
size in (H, W) format.
|
||||
"""
|
||||
boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
|
||||
return boxes.reshape(-1, 4)
|
||||
|
||||
@staticmethod
|
||||
def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
|
||||
"""
|
||||
Compute the output size given input size and target long side length.
|
||||
"""
|
||||
scale = long_side_length * 1.0 / max(oldh, oldw)
|
||||
newh, neww = oldh * scale, oldw * scale
|
||||
neww = int(neww + 0.5)
|
||||
newh = int(newh + 0.5)
|
||||
return (newh, neww)
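# Illustrative usage (added): resize an image so its longest side is 1024 and map
# coordinates into the resized frame, mirroring how a SAM predictor uses this transform.
if __name__ == "__main__":
    transform = ResizeLongestSide(1024)
    image = np.zeros((600, 800, 3), dtype=np.uint8)  # H x W x C
    resized = transform.apply_image(image)
    coords = transform.apply_coords(np.array([[400.0, 300.0]]), image.shape[:2])
    print(resized.shape, coords)  # (768, 1024, 3) [[512. 384.]]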
|
157
utils/transforms3d.py
Normal file
@ -0,0 +1,157 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
from torchvision.transforms.functional import resize, to_pil_image  # type: ignore
|
||||
from torch.nn.functional import interpolate
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
|
||||
class SimpleResize:
|
||||
"""
|
||||
Resizes to a fixed square size to stay consistent with training;
may be revised in the future.
|
||||
"""
|
||||
def __init__(self, target_length: int = 1024) -> None:
|
||||
self.target_length = (target_length, target_length)
|
||||
|
||||
def apply_image(self, image: torch.Tensor) -> torch.Tensor:
|
||||
target_size = self.target_length
|
||||
return interpolate(image.unsqueeze(1), size=(target_size[0], target_size[1])).squeeze()
|
||||
# return np.array(resize(to_pil_image(image), target_size))
|
||||
|
||||
def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
|
||||
old_h, old_w = original_size
|
||||
new_h, new_w = self.target_length
|
||||
coords = deepcopy(coords).astype(float)
|
||||
coords[..., 0] = coords[..., 0] * (new_w / old_w)
|
||||
coords[..., 1] = coords[..., 1] * (new_h / old_h)
|
||||
return coords
|
||||
|
||||
def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
|
||||
boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
|
||||
return boxes.reshape(-1, 4)
|
||||
|
||||
def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
|
||||
# Expects an image in BCHW format. May not exactly match apply_image.
|
||||
target_size = self.target_length
|
||||
return F.interpolate(
|
||||
image, target_size, mode="bilinear", align_corners=False, antialias=True
|
||||
)
|
||||
|
||||
def apply_coords_torch(
|
||||
self, coords: torch.Tensor, original_size: Tuple[int, ...]
|
||||
) -> torch.Tensor:
|
||||
old_h, old_w = original_size
|
||||
new_h, new_w = self.target_length
|
||||
coords = deepcopy(coords).to(torch.float)
|
||||
coords[..., 0] = coords[..., 0] * (new_w / old_w)
|
||||
coords[..., 1] = coords[..., 1] * (new_h / old_h)
|
||||
return coords
|
||||
|
||||
def apply_boxes_torch(
|
||||
self, boxes: torch.Tensor, original_size: Tuple[int, ...]
|
||||
) -> torch.Tensor:
|
||||
boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
|
||||
return boxes.reshape(-1, 4)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class ResizeLongestSide:
|
||||
"""
|
||||
Resizes images to the longest side 'target_length', as well as provides
|
||||
methods for resizing coordinates and boxes. Provides methods for
|
||||
transforming both numpy array and batched torch tensors.
|
||||
"""
|
||||
|
||||
def __init__(self, target_length: int = 1024) -> None:
|
||||
self.target_length = target_length
|
||||
|
||||
def apply_image(self, image: torch.Tensor) -> torch.Tensor:
"""
Expects a torch tensor volume; resizes so the longest of the first two spatial dims matches target_length.
|
||||
"""
|
||||
target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
|
||||
return interpolate(image, size=(target_size[0], target_size[1], image.shape[-1]))
|
||||
# return np.array(resize(to_pil_image(image), target_size))
|
||||
|
||||
def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
|
||||
"""
|
||||
Expects a numpy array of length 2 in the final dimension. Requires the
|
||||
original image size in (H, W) format.
|
||||
"""
|
||||
old_h, old_w = original_size
|
||||
new_h, new_w = self.get_preprocess_shape(
|
||||
original_size[0], original_size[1], self.target_length
|
||||
)
|
||||
coords = deepcopy(coords).astype(float)
|
||||
coords[..., 0] = coords[..., 0] * (new_w / old_w)
|
||||
coords[..., 1] = coords[..., 1] * (new_h / old_h)
|
||||
return coords
|
||||
|
||||
def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
|
||||
"""
|
||||
Expects a numpy array shape Bx4. Requires the original image size
|
||||
in (H, W) format.
|
||||
"""
|
||||
boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
|
||||
return boxes.reshape(-1, 4)
|
||||
|
||||
def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Expects batched images with shape BxCxHxW and float format. This
|
||||
transformation may not exactly match apply_image. apply_image is
|
||||
the transformation expected by the model.
|
||||
"""
|
||||
# Expects an image in BCHW format. May not exactly match apply_image.
|
||||
target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length)
|
||||
return F.interpolate(
|
||||
image, target_size, mode="bilinear", align_corners=False, antialias=True
|
||||
)
|
||||
|
||||
def apply_coords_torch(
|
||||
self, coords: torch.Tensor, original_size: Tuple[int, ...]
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Expects a torch tensor with length 2 in the last dimension. Requires the
|
||||
original image size in (H, W) format.
|
||||
"""
|
||||
old_h, old_w = original_size
|
||||
new_h, new_w = self.get_preprocess_shape(
|
||||
original_size[0], original_size[1], self.target_length
|
||||
)
|
||||
coords = deepcopy(coords).to(torch.float)
|
||||
coords[..., 0] = coords[..., 0] * (new_w / old_w)
|
||||
coords[..., 1] = coords[..., 1] * (new_h / old_h)
|
||||
return coords
|
||||
|
||||
def apply_boxes_torch(
|
||||
self, boxes: torch.Tensor, original_size: Tuple[int, ...]
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Expects a torch tensor with shape Bx4. Requires the original image
|
||||
size in (H, W) format.
|
||||
"""
|
||||
boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
|
||||
return boxes.reshape(-1, 4)
|
||||
|
||||
@staticmethod
|
||||
def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int):
|
||||
"""
|
||||
Compute the output size given input size and target long side length.
|
||||
"""
|
||||
scale = long_side_length * 1.0 / max(oldh, oldw)
|
||||
newh, neww = oldh * scale, oldw * scale
|
||||
neww = int(neww + 0.5)
|
||||
newh = int(newh + 0.5)
|
||||
return (newh, neww)
|