first commit

transcendentsky 2023-12-05 14:58:38 +08:00
commit e04459c6fe
73 changed files with 7059 additions and 0 deletions

configs/vit_b.yaml (new file, 53 lines)

@@ -0,0 +1,53 @@
#### basic configs
# dataset:
# name: 'Cephalometric'
# pth: '/home1/quanquan/datasets/Cephalometric/'
# ---------------------- Common Configs --------------------------
base:
base_dir: "../runs/sam/"
tag: ''
stage: ''
logger:
mode: ['tb', ]
# mode: ''
recorder_reduction: 'mean'
training:
save_mode: ['all','best', 'latest'] #
batch_size : 8 # 20 for A100
num_workers : 16
num_epochs : 500 # epochs
use_amp: true
save_interval : 4
val_check_interval: 6
load_pretrain_model: false
# optim:
lr: 0.0002
decay_step: 2000
decay_gamma: 0.8
weight_decay: 0.0001
alpha: 0.99
validation_interval: 100
dataset:
types: ['3d'] # ['3d', '2d']
split: 'train'
data_root_path: '/quanquan/datasets/'
dataset_list: ["alp", "word", "debug"] # ['sam', "their", "ours"]
data_txt_path: './datasets/dataset_list/'
dataset2d_path: "/home1/quanquan/datasets/08_AbdomenCT-1K/"
cache_data_path: '/home1/quanquan/datasets/cached_dataset2/'
# sam_checkpoint: "/quanquan/code/projects/medical-guangdong/segment-anything/sam_vit_b_01ec64.pth" # 103 server
# model_type: "vit_b"
# Continue training
# continue_training: true
# load_optimizer: true
# breakpoint_path: "/quanquan/code/segment-anything/runs/sam/ddp_b1/lora_3d_2dm"
test:
batch_size: 1
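
As a quick sanity check, the config above can be loaded with plain PyYAML, and the nested sections (base, training, dataset, test) come back as ordinary dicts. A minimal sketch, assuming the relative path below; the actual entry points load it through tutils' ConfigManager instead (see core/ddp.py):

import yaml

with open("configs/vit_b.yaml") as f:  # assumed relative path
    cfg = yaml.safe_load(f)
print(cfg["training"]["batch_size"])   # -> 8
print(cfg["dataset"]["dataset_list"])  # -> ['alp', 'word', 'debug']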

core/__init__.py (new file, empty)

Binary files not shown (8 files).

core/ddp.py (new file, 142 lines)

@@ -0,0 +1,142 @@
"""
from ddp_b9.py
Train the whole ViT
"""
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from tutils import tfilename, tdir
from datasets.dataset3d_2dmask import Dataset2D
# from datasets.dataset3d import Dataset3D
from datasets.cache_dataset3d3 import Dataset3D
from datasets.dataset_merged import DatasetMerged, TestsetMerged
from datasets.data_engine import DataEngine
from modeling.build_sam3d2 import sam_model_registry
from .learner5 import SamLearner
# from tutils.new.trainer.trainer_ddp import DDPTrainer
from trans_utils.trainer_ddp import DDPTrainer
# from .lora_sam import LoRA_Sam
import warnings
warnings.filterwarnings("ignore")
def setup(rank, world_size):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
# initialize the process group
dist.init_process_group("gloo", rank=rank, world_size=world_size)
def cleanup():
dist.destroy_process_group()
def load_well_trained_model(model, lora_module, pth=None):
print("Loading from ", pth)
state_dict = torch.load(pth, map_location='cpu')
model.load_state_dict(state_dict)
def ddp_train(rank, world_size, config):
setup(rank, world_size)
sam_checkpoint = "/quanquan/code/segment-anything/segment_anything/sam_vit_b_01ec64.pth" # A800 server
# sam_checkpoint = "/home1/quanquan/code/projects/medical-guangdong/segment-anything/sam_vit_b_01ec64.pth" # 103 server
model_type = "vit_b"
device = rank
config_data = config['dataset']
data_type = config_data.get("types", ["3d", "2d"])
data_type = [data_type] if isinstance(data_type, str) else data_type
if '2d' in data_type:
if '3d' in data_type:
dataset = DatasetMerged(config_data, is_train=True)
else:
dataset = Dataset2D(dirpath=config_data['dataset2d_path'], is_train=True)
elif '3d' in data_type:
dataset = Dataset3D(config_data, split='train')
else:
raise NotImplementedError
# assert len(validset) > 0
data_engine = DataEngine(dataset=dataset, img_size=(1024,1024))
sam = sam_model_registry[model_type](checkpoint=None)
learner = SamLearner(sam_model=sam, config=config, data_engine=data_engine)
config_train = config['training']
if config_train.get('continue_training', False):
learner.use_lora()
learner.load_well_trained_model(config['training']['breakpoint_path']) # use preset path
else:
learner.load_pretrained_model(sam_checkpoint)
learner.use_lora()
print('Before Setting', get_parameter_number(learner.model))
for param in learner.model.image_encoder.parameters():
param.requires_grad = True
print('After Setting', get_parameter_number(learner.model))
ddp_trainer = DDPTrainer(config=config, rank=rank, world_size=world_size)
ddp_trainer.fit(learner, trainset=data_engine, validset=None)
cleanup()
def get_parameter_number(model):
total_num = sum(p.numel() for p in model.parameters())
trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
return {'Total': total_num, 'Trainable': trainable_num}
def run_demo(demo_fn, world_size, config):
mp.spawn(demo_fn,
args=(world_size,config),
nprocs=world_size,
join=True)
from collections import OrderedDict
import yaml
import yamlloader
def _ordereddict_to_dict(d):
if not isinstance(d, dict):
return d
for k, v in d.items():
if isinstance(v, OrderedDict):
v = _ordereddict_to_dict(v)
d[k] = dict(v)
elif type(v) == list:
d[k] = _ordereddict_to_dict(v)
elif isinstance(v, dict):
d[k] = _ordereddict_to_dict(v)
return d
# CUDA_VISIBLE_DEVICES=4,5,6,7 python -m core.ddp_b3 --tag lora --config configs/vit_b_103.yaml
if __name__ == "__main__":
import argparse
from tutils.new.manager import trans_args, trans_init, ConfigManager
n_gpus = torch.cuda.device_count()
# assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but {__file__} Got{n_gpus}"
if n_gpus == 1:
print("Warning! Running on only 1 GPU! just for debug")
world_size = n_gpus
parser = argparse.ArgumentParser()
parser.add_argument("--config", default="./configs/vit_b.yaml")
parser.add_argument("--func", default="train")
args = trans_args(parser=parser)
config = ConfigManager()
config.auto_init(file=__file__, args=args, ex_config=None)
# config.save()
path = tfilename(config['base']['runs_dir'], "config.yaml")
with open(path, "w") as f:
yaml.dump(_ordereddict_to_dict(config), f)
print("Save config file to ", path)
if n_gpus < 1: exit(0)
run_demo(ddp_train, world_size, config)
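
The _ordereddict_to_dict helper above exists so that the nested OrderedDicts produced by the config manager can be serialized by yaml.dump as plain mappings. A small illustrative check with hypothetical values, assuming the function from core/ddp.py is in scope:

from collections import OrderedDict

cfg = {"training": OrderedDict(lr=0.0002, batch_size=8)}
print(_ordereddict_to_dict(cfg))
# -> {'training': {'lr': 0.0002, 'batch_size': 8}}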

core/learner2.py (new file, 522 lines)

@@ -0,0 +1,522 @@
import torch
import torchvision
import numpy as np
from tutils.trainer import Trainer, LearnerModule
from torch.utils.data import DataLoader
from torch import optim
import torch.optim.lr_scheduler as lr_scheduler
from einops import rearrange, repeat
from torch.nn import functional as F
import os
from typing import Optional, Tuple
import torch.optim.lr_scheduler as lr_scheduler
from modeling.sam3d import Sam
# from segment_anything.utils.transforms import ResizeLongestSide
from utils.transforms import ResizeLongestSide
from .loss import compute_all_loss, ranked_combined_loss, compute_iou, combined_loss
from .lora_sam import LoRA_Sam
from safetensors import safe_open
from datasets.data_engine import DataEngine
# def lr_schedule(epoch):
# if epoch < 250:
# return (epoch + 1) / 250 * 0.0008 + 0.00004
# elif epoch < 500:
# return 0.0001
# else:
# return 0.0001
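# Warmup-then-decay multiplier intended for LambdaLR (currently disabled in configure_optimizers):
# linear warmup to 0.1 over the first 250 epochs, then 0.01 until epoch 500, then 0.001.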
def lr_schedule(epoch):
if epoch < 250:
return (epoch + 1) / 250 * 0.1
elif epoch < 500:
return 0.01
else:
return 0.001
class SamLearner(LearnerModule):
def __init__(
self,
sam_model: Sam,
config=None,
logger=None,
data_engine=DataEngine(None, img_size=(1024,1024)),
lora_module=None,
) -> None:
"""
Uses SAM to calculate the image embedding for an image, and then
allow repeated, efficient mask prediction given prompts.
Arguments:
sam_model (Sam): The model to use for mask prediction.
"""
super().__init__()
self.config = config
self.logger = logger
self.model = sam_model
self.net = self.model
self.transform = ResizeLongestSide(sam_model.image_encoder.img_size)
self.reset_image()
self.data_engine = data_engine
self.features = None
self.lora_module = lora_module
def save(self, pth, *args, **kwargs):
# Default: "/model_epoch_{}.pth".format(epoch)
torch.save(self.net.state_dict(), pth)
lora_path = pth.replace(".pth", "_lora.safetensors")
self.lora_module.save_lora_parameters(lora_path)
return True
def load_pretrained_model(self, pth, *args, **kwargs):
"""
Unmatched: prompt_encoder.mask_downscaling.0.weight
their: torch.Size([4, 1, 2, 2])
our: torch.Size([4, 3, 2, 2])
Unmatched: mask_decoder.mask_tokens.weight
their: torch.Size([4, 256])
our: torch.Size([12, 256])
"""
state_dict = torch.load(pth)
model_state_dict = self.model.state_dict()
model_state_dict.update(state_dict)
model_state_dict['prompt_encoder.mask_downscaling.0.weight'] = repeat(state_dict['prompt_encoder.mask_downscaling.0.weight'], "a 1 c d -> a b c d", b=3)
model_state_dict['mask_decoder.mask_tokens.weight'] = repeat(state_dict['mask_decoder.mask_tokens.weight'], "a d -> (a 3) d")
hyper_params_names = [k for k in model_state_dict.keys() if k.startswith("mask_decoder.output_hypernetworks_mlps")]
for name in hyper_params_names:
words = name.split('.')
words[2] = str(int(words[2]) // 3)
name_to_copy = ".".join(words)
model_state_dict[name] = state_dict[name_to_copy]
# for k, v in state_dict.items():
# if model_state_dict[k].shape != state_dict[k].shape:
# print("Unmatched:", k)
self.model.load_state_dict(model_state_dict)
def load_well_trained_model(self, pth=None):
pth = self.config['training']['breakpoint_path'] + "/ckpt_v/model_latest.pth" if pth is None else pth
print("Loading from ", pth)
state_dict = torch.load(pth, map_location="cpu")
# print(state_dict.keys())
# for k in state_dict.keys():
# print(k)
# exit(0)
self.model.load_state_dict(state_dict)
# self.lora_module.load_lora_parameters(pth.replace(".pth", "_lora.safetensors"))
def use_lora(self):
lora_r = 8
lora_sam = LoRA_Sam(self.model, lora_r, freeze_prompt_encoder=True)
self.lora_module = lora_sam
def configure_optimizers(self, **kwargs):
optimizer = optim.AdamW(params=self.model.parameters(), \
lr=self.config['training']['lr'], betas=(0.9, 0.999), eps=1e-08,
weight_decay=self.config['training']['weight_decay'])
# scheduler = lr_scheduler.LambdaLR(optimizer, lr_schedule)
scheduler = None
return {'optimizer': optimizer, "scheduler": scheduler}
def load_optim(self, optimizer, pth=None, *args):
pth = self.config['training']['breakpoint_path'] + "/ckpt/optim_latest.pth" if pth is None else pth
print("Load Optimizer from ", pth)
state_dict = torch.load(pth)
optimizer.load_state_dict(state_dict['optimizer'])
start_epoch = state_dict.get('epoch', 0) + 1
return start_epoch
def training_step(self, data, batch_idx, **kwargs):
img = data['img']
gt_mask = data['label']
prompt_point = data['prompt_point'] # shape: (b, 2)
batch_size = prompt_point.shape[0]
point_label = torch.ones((batch_size, 1)) #.to(prompt_point.device)
prompt_box = data['prompt_box']
prompt_point = rearrange(prompt_point, "b c -> b 1 c")
prompt_box = rearrange(prompt_box, "b c -> b 1 c")
assert img.shape[1:] == (3,1024,1024),f"{__file__} Got{img.shape}"
assert prompt_point.shape[1:] == (1,2), f"{__file__} Got{prompt_point.shape}"
assert point_label.shape[1:] == (1,), f"{__file__} Got{point_label.shape}"
assert prompt_box.shape[1:] == (1,4), f"{__file__} Got{prompt_box.shape}"
self.set_torch_image(img, img.shape[2:])
if np.random.random() > 0.5:
pred_masks, iou_predictions, logits = self.predict_torch(
point_coords=prompt_point,
point_labels=point_label,
multimask_output=True,
return_logits=True,
)
else:
pred_masks, iou_predictions, logits = self.predict_torch(
point_coords=None,
point_labels=None,
boxes=prompt_box,
multimask_output=True,
return_logits=True,
)
# assert pred_masks.shape == gt_mask.shape, f"Got {pred_masks.shape}, {gt_mask.shape}"
loss_1, fl, dl, _, _, _ = ranked_combined_loss(pred_mask=pred_masks, gt_mask=gt_mask, iou_pred=iou_predictions)
# print("Debug trainer: 2", prompt_point.shape, point_label.shape, prompt_box.shape)
# Stage 2: based on the above, add more points as prompts
return {"loss": loss_1, "fl": fl.mean(), "dl": dl.mean()}
# @torch.no_grad()
def generate(self, image, prompt_point):
orig_size = image.shape[2:]
assert image.shape[1:] == (3,1024,1024),f"{__file__} Got{image.shape}"
if not self.is_image_set:
self.set_torch_image(image, orig_size)
assert prompt_point.shape[1:] == (1,2), f"{__file__} Got{prompt_point.shape}"
# assert point_label.shape[1:] == (1,), f"{__file__} Got{point_label.shape}"
point_label = torch.ones(prompt_point.size()[:-1])
pred_masks, scores, logits = self.predict_torch(
point_coords=prompt_point,
point_labels=point_label,
mask_input=None,
multimask_output=True,
)
return pred_masks
# @torch.no_grad()
def generate_by_box(self, image, prompt_box):
orig_size = image.shape[2:]
assert image.shape[1:] == (3,1024,1024),f"{__file__} Got{image.shape}"
if not self.is_image_set:
self.set_torch_image(image, orig_size)
assert prompt_box.shape[1:] == (1,4), f"{__file__} Got{prompt_box.shape}"
pred_masks, scores, logits = self.predict_torch(
point_coords=None,
point_labels=None,
boxes=prompt_box,
mask_input=None,
multimask_output=True,
)
return pred_masks
@staticmethod
def select_best_mask(predictions, ground_truth):
# Move tensors to the same device (if not already on the same device)
# if predictions.device != ground_truth.device:
# predictions = predictions.to(ground_truth.device)
# Compute IoU between each prediction and ground truth
if predictions.shape[1] == 9:
predictions = rearrange(predictions, "b (c1 c2) h w -> b c1 c2 h w", c1=3, c2=3)
ground_truth = repeat(ground_truth, "b d h w -> b c d h w", c=3)
else:
predictions = rearrange(predictions, "b d h w -> b 1 d h w")
ground_truth = rearrange(ground_truth, "b d h w -> b 1 d h w")
intersection = torch.sum(predictions * ground_truth, dim=(-3, -2, -1))
union = torch.sum(predictions + ground_truth, dim=(-3, -2, -1)) - intersection
iou = intersection / (union + 1e-6)
# Select the prediction with maximum IoU for each image in the batch
best_indices = torch.argmax(iou, dim=1)
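# Expand the argmax indices to (b, 1, d, h, w) so gather extracts the complete best-scoring mask per sample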
best_masks = torch.gather(predictions, 1, best_indices.unsqueeze(1).unsqueeze(2).unsqueeze(3).unsqueeze(4).repeat(1, 1, predictions.shape[-3], predictions.shape[-2], predictions.shape[-1]))
return best_masks
# ===============================================
def predict_multi_prompt(
self,
point_coords: Optional[torch.Tensor],
point_labels: Optional[torch.Tensor],
mask_logits: Optional[torch.Tensor],
):
if point_coords is not None:
assert (
point_labels is not None
), "point_labels must be supplied if point_coords is supplied."
point_coords = self.transform.apply_coords(point_coords, self.original_size)
coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
if point_coords is not None:
points = (coords_torch, labels_torch)
else:
points = None
sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
points=points,
boxes=None,
masks=mask_logits,
)
low_res_masks, iou_predictions = self.model.mask_decoder(
image_embeddings=self.features,
image_pe=self.model.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=False,
)
masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)
return masks, iou_predictions, low_res_masks
def set_image(
self,
image: np.ndarray,
image_format: str = "RGB",
) -> None:
# Transform the image to the form expected by the model
input_image = self.transform.apply_image(image)
input_image_torch = torch.as_tensor(input_image, device=self.device)
input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
self.set_torch_image(input_image_torch, image.shape[:2])
# @torch.no_grad()
def set_torch_image(
self,
transformed_image: torch.Tensor,
original_image_size: Tuple[int, ...],
) -> None:
"""
Calculates the image embeddings for the provided image, allowing
masks to be predicted with the 'predict' method. Expects the input
image to be already transformed to the format expected by the model.
Arguments:
transformed_image (torch.Tensor): The input image, with shape
1x3xHxW, which has been transformed with ResizeLongestSide.
original_image_size (tuple(int, int)): The size of the image
before transformation, in (H, W) format.
"""
assert (
len(transformed_image.shape) == 4
and transformed_image.shape[1] == 3
and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size
), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}."
self.reset_image()
self.original_size = original_image_size
self.input_size = tuple(transformed_image.shape[-2:])
input_image = self.model.preprocess(transformed_image)
self.features = self.model.image_encoder(input_image)
self.is_image_set = True
def predict(
self,
point_coords: Optional[np.ndarray] = None,
point_labels: Optional[np.ndarray] = None,
box: Optional[np.ndarray] = None,
mask_input: Optional[np.ndarray] = None,
multimask_output: bool = True,
return_logits: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if not self.is_image_set:
raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
# Transform input prompts
coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
if point_coords is not None:
assert (
point_labels is not None
), "point_labels must be supplied if point_coords is supplied."
point_coords = self.transform.apply_coords(point_coords, self.original_size)
coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
if box is not None:
box = self.transform.apply_boxes(box, self.original_size)
box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
box_torch = box_torch[None, :]
if mask_input is not None:
mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device)
mask_input_torch = mask_input_torch[None, :, :, :]
masks, iou_predictions, low_res_masks = self.predict_torch(
coords_torch,
labels_torch,
box_torch,
mask_input_torch,
multimask_output,
return_logits=return_logits,
)
# masks = masks[0].detach().cpu().numpy()
# iou_predictions = iou_predictions[0].detach().cpu().numpy()
# low_res_masks = low_res_masks[0].detach().cpu().numpy()
return masks, iou_predictions, low_res_masks
# @torch.no_grad()
def predict_torch(
self,
point_coords: Optional[torch.Tensor],
point_labels: Optional[torch.Tensor],
boxes: Optional[torch.Tensor] = None,
mask_input: Optional[torch.Tensor] = None,
multimask_output: bool = True,
return_logits: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if not self.is_image_set:
raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
if point_coords is not None:
points = (point_coords, point_labels)
else:
points = None
sparse_embeddings, dense_embeddings = self._get_prompt_embedding(points, boxes, mask_input)
# Predict masks
low_res_masks, iou_predictions = self.model.mask_decoder(
image_embeddings=self.features,
image_pe=self.model.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=multimask_output,
)
# Upscale the masks to the original image resolution
masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)
# import ipdb; ipdb.set_trace()
if not return_logits:
masks = masks > self.model.mask_threshold
return masks, iou_predictions, low_res_masks
# @torch.no_grad()
def _get_prompt_embedding(self, points, boxes, mask_input):
# Embed prompts
sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
points=points,
boxes=boxes,
masks=mask_input,
)
return sparse_embeddings, dense_embeddings
def get_image_embedding(self) -> torch.Tensor:
"""
Returns the image embeddings for the currently set image, with
shape 1xCxHxW, where C is the embedding dimension and (H,W) are
the embedding spatial dimension of SAM (typically C=256, H=W=64).
"""
if not self.is_image_set:
raise RuntimeError(
"An image must be set with .set_image(...) to generate an embedding."
)
assert self.features is not None, "Features must exist if an image has been set."
return self.features
@property
def device(self) -> torch.device:
return self.model.device
def reset_image(self) -> None:
"""Resets the currently set image."""
self.is_image_set = False
self.features = None
self.orig_h = None
self.orig_w = None
self.input_h = None
self.input_w = None
def validation_step(self, data, batch_idx=0, **kwargs):
img = data['img']
gt_mask = data['label']
prompt_point = data['prompt_point'] # shape: (b, 2)
batch_size = prompt_point.shape[0]
point_label = torch.ones((batch_size, 1)) #.to(prompt_point.device)
prompt_box = data['prompt_box']
gt_mask = repeat(gt_mask, "b d h w -> b c d h w", c=3)
prompt_point = rearrange(prompt_point, "b c -> b 1 c")
prompt_box = rearrange(prompt_box, "b c -> b 1 c")
assert img.shape[1:] == (3,1024,1024),f"{__file__} Got{img.shape}"
assert prompt_point.shape[1:] == (1,2), f"{__file__} Got{prompt_point.shape}"
assert point_label.shape[1:] == (1,), f"{__file__} Got{point_label.shape}"
assert prompt_box.shape[1:] == (1,4), f"{__file__} Got{prompt_box.shape}"
self.set_torch_image(img, img.shape[2:])
# Stage 1: use the 1st prompt, box or point
iou_details = {}
pred_masks1, iou_predictions1, logits1 = self.predict_torch(
point_coords=prompt_point,
point_labels=point_label,
multimask_output=True,
return_logits=True,
)
pred_masks1 = rearrange(pred_masks1, "b (c1 c2) h w -> b c1 c2 h w", c1=3, c2=3)
loss_point, fl, dl, _, _, _ = ranked_combined_loss(pred_mask=pred_masks1, gt_mask=gt_mask, iou_pred=iou_predictions1)
iou_details['loss_point'] = loss_point.mean()
iou_details['loss_point_fl'] = fl.mean()
iou_details['loss_point_dl'] = dl.mean()
iou = compute_iou((pred_masks1>0).float(), gt_mask)
iou, _ = torch.max(iou, axis=1)
iou_details['iou_point'] = iou.mean()
pred_masks2, iou_predictions2, logits2 = self.predict_torch(
point_coords=None,
point_labels=None,
boxes=prompt_box,
multimask_output=True,
return_logits=True,
)
pred_masks2 = rearrange(pred_masks2, "b (c1 c2) h w -> b c1 c2 h w", c1=3, c2=3)
loss_box, fl, dl, _, _, _ = ranked_combined_loss(pred_mask=pred_masks2, gt_mask=gt_mask, iou_pred=iou_predictions2)
iou_details['loss_box'] = loss_box.mean()
iou_details['loss_box_fl'] = fl.mean()
iou_details['loss_box_dl'] = dl.mean()
iou = compute_iou((pred_masks2>0).float(), gt_mask)
iou, _ = torch.max(iou, axis=1)
iou_details['iou_box'] = iou.mean()
# import ipdb; ipdb.set_trace()
# gt_mask_np = gt_mask.detach().cpu().numpy()
# for step in range(8):
# continue
# # n
# best_pred_masks = self.select_best_mask(pred_masks, gt_mask)
# best_pred_masks_np = best_pred_masks.detach().cpu().numpy()
# # import ipdb; ipdb.set_trace()
# mask_input = logits[0, np.argmax(scores[0].detach().cpu().numpy()), :, :] # Choose the model's best mask
# sub_points, sub_labels = self.data_engine.get_subsequent_prompt_point(best_pred_masks_np, gt_mask_np)
# # sub_points, sub_labels = self.data_engine.point_prompt_generator.select_random_subsequent_point(best_pred_masks_np[0][0], gt_mask_np[0][0])
# y, x = sub_points[0][1], sub_points[0][0]
# assert gt_mask_np[0][0][y,x] + best_pred_masks_np[0][0][y,x] == 1, f"{__file__} Got{gt_mask_np[0][0][y,x], best_pred_masks_np[0][0][y,x]}"
# assert gt_mask_np[0][0][y,x] == sub_labels, f"{__file__} Got{ gt_mask_np[0][0][y,x]}, {sub_labels}"
# assert best_pred_masks_np[0][0][y,x] == (1-sub_labels), f"{__file__} Got{ gt_mask_np[0][0][y,x]}, {1-sub_labels}"
# # import ipdb; ipdb.set_trace()
# # assert sub_points
# # sub_points = np.array(sub_points)[None,...].astype(int)
# # sub_labels = np.array(sub_labels)[None,...]
# prompt_point = np.concatenate([prompt_point, sub_points], axis=0)
# point_label = np.concatenate([point_label, sub_labels], axis=0)
# # import ipdb; ipdb.set_trace()
# pred_masks2, scores, logits = model.predict(
# point_coords=prompt_point,
# point_labels=point_label,
# mask_input=mask_input[None,...],
# multimask_output=False,
# )
# iou = compute_iou(pred_masks2, gt_mask)
# iou, _ = torch.max(iou, axis=1)
# iou_details[f'point_{step+2}'] = iou
return iou_details
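
load_pretrained_model above adapts the released 2D SAM checkpoint to this 3-slice variant by tiling the single-channel mask-downscaling kernel across 3 input channels and tripling the mask tokens. A shape check of the two einops calls on dummy tensors (illustrative only):

import torch
from einops import repeat

w = torch.randn(4, 1, 2, 2)   # prompt_encoder.mask_downscaling.0.weight in the 2D checkpoint
print(repeat(w, "a 1 c d -> a b c d", b=3).shape)   # torch.Size([4, 3, 2, 2])

t = torch.randn(4, 256)       # mask_decoder.mask_tokens.weight in the 2D checkpoint
print(repeat(t, "a d -> (a 3) d").shape)            # torch.Size([12, 256])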

core/learner3.py (new file, 154 lines)

@@ -0,0 +1,154 @@
import torch
import torchvision
import numpy as np
from tutils.trainer import Trainer, LearnerModule
from einops import rearrange, repeat, reduce
import torch.optim.lr_scheduler as lr_scheduler
from core.loss import ranked_combined_loss_with_indicators
from .learner2 import SamLearner as basic_learner
from .loss import compute_all_loss, ranked_combined_loss, compute_iou, combined_loss
class SamLearner(basic_learner):
def training_step(self, data, batch_idx, **kwargs):
img = data['img']
gt_mask = data['label']
prompt_point = data['prompt_point'] # shape: (b, 2)
batch_size = prompt_point.shape[0]
point_label = torch.ones((batch_size, 1)) #.to(prompt_point.device)
prompt_box = data['prompt_box']
indicators = data['indicators']
# print(data['name'])
prompt_point = rearrange(prompt_point, "b c -> b 1 c")
prompt_box = rearrange(prompt_box, "b c -> b 1 c")
assert img.shape[1:] == (3,1024,1024),f"{__file__} Got{img.shape}"
assert prompt_point.shape[1:] == (1,2), f"{__file__} Got{prompt_point.shape}"
assert point_label.shape[1:] == (1,), f"{__file__} Got{point_label.shape}"
assert prompt_box.shape[1:] == (1,4), f"{__file__} Got{prompt_box.shape}"
self.set_torch_image(img, img.shape[2:])
# if np.random.random() > 0.5:
pred_masks, iou_predictions, logits = self.predict_torch(
point_coords=prompt_point,
point_labels=point_label,
multimask_output=True,
return_logits=True,
)
loss_1, fl, dl, min_losses, selected_losses, _ = ranked_combined_loss_with_indicators(pred_mask=pred_masks, gt_mask=gt_mask, iou_pred=iou_predictions, indicators=indicators)
# else:
pred_masks, iou_predictions, logits = self.predict_torch(
point_coords=None,
point_labels=None,
boxes=prompt_box,
multimask_output=True,
return_logits=True,
)
# assert pred_masks.shape == gt_mask.shape, f"Got {pred_masks.shape}, {gt_mask.shape}"
loss_2, fl, dl, min_losses, selected_losses, _ = ranked_combined_loss_with_indicators(pred_mask=pred_masks, gt_mask=gt_mask, iou_pred=iou_predictions, indicators=indicators)
loss = loss_1 + loss_2
if loss < -999 or torch.isnan(loss):
print("Warning! Loss Error! ")
print(data['name'])
# print("Debug trainer: 2", prompt_point.shape, point_label.shape, prompt_box.shape)
# Stage 2: based on the above, add more points as prompts
return {"loss": loss_1 + loss_2, "point_loss": loss_1.mean(), "box_loss": loss_2, "fl": fl.mean(), "dl": dl.mean(), "min_losses": min_losses, "selected_losses": selected_losses}
def validation_step(self, data, batch_idx=0, **kwargs):
img = data['img']
gt_mask = data['label']
prompt_point = data['prompt_point'] # shape: (b, 2)
batch_size = prompt_point.shape[0]
point_label = torch.ones((batch_size, 1)) #.to(prompt_point.device)
prompt_box = data['prompt_box']
indicators = data['indicators']
print(data['name'])
prompt_point = rearrange(prompt_point, "b c -> b 1 c")
prompt_box = rearrange(prompt_box, "b c -> b 1 c")
assert img.shape[1:] == (3,1024,1024),f"{__file__} Got{img.shape}"
assert prompt_point.shape[1:] == (1,2), f"{__file__} Got{prompt_point.shape}"
assert point_label.shape[1:] == (1,), f"{__file__} Got{point_label.shape}"
assert prompt_box.shape[1:] == (1,4), f"{__file__} Got{prompt_box.shape}"
self.set_torch_image(img, img.shape[2:])
# Stage 1: use the 1st prompt, box or point
iou_details = {}
pred_masks1, iou_predictions1, logits1 = self.predict_torch(
point_coords=prompt_point,
point_labels=point_label,
multimask_output=True,
return_logits=True,
)
if len(pred_masks1.shape) == 4:
pred_masks1 = rearrange(pred_masks1, "b (c1 c2) h w -> b c1 c2 h w", c1=3, c2=3)
loss_point, fl, dl, min_losses, selected_losses, min_indices = ranked_combined_loss_with_indicators(pred_mask=pred_masks1, gt_mask=gt_mask, iou_pred=iou_predictions1, indicators=indicators)
iou_details['loss_point'] = loss_point.mean()
iou_details['loss_point_fl'] = fl.mean()
iou_details['loss_point_dl'] = dl.mean()
if len(gt_mask.shape) == 4:
gt_mask = repeat(gt_mask, "b d h w -> b c d h w", c=3)
indices = iou_predictions1.argmax(axis=1)
pred_maxiou = []
for pred, i in zip(pred_masks1, indices):
pred_maxiou.append(pred[i,:,:,:])
pred_maxiou = torch.stack(pred_maxiou, axis=0)
iou = compute_iou2(torch.tensor(pred_maxiou>0, dtype=gt_mask.dtype), gt_mask[:,0,:,:,:]).detach()
iou_details['iou_point'] = iou.mean()
iou = compute_iou(torch.tensor(pred_masks1>0, dtype=gt_mask.dtype), gt_mask).detach()
iou, _ = torch.max(iou, axis=1)
iou_details['iou_point_max'] = iou.mean()
pred_masks2, iou_predictions2, logits2 = self.predict_torch(
point_coords=None,
point_labels=None,
boxes=prompt_box,
multimask_output=True,
return_logits=True,
)
loss_box, fl, dl, min_losses, selected_losses, min_indices = ranked_combined_loss_with_indicators(pred_mask=pred_masks2, gt_mask=gt_mask, iou_pred=iou_predictions2, indicators=indicators)
iou_details['loss_box'] = loss_box.mean()
iou_details['loss_box_fl'] = fl.mean()
iou_details['loss_box_dl'] = dl.mean()
if len(gt_mask.shape) == 4:
gt_mask = repeat(gt_mask, "b d h w -> b c d h w", c=3)
if len(pred_masks2.shape) == 4:
pred_masks2 = rearrange(pred_masks2, "b (c1 c2) h w -> b c1 c2 h w", c1=3, c2=3)
indices = iou_predictions2.argmax(axis=1)
pred_maxiou = []
for pred, i in zip(pred_masks2, indices):
pred_maxiou.append(pred[i,:,:,:])
pred_maxiou = torch.stack(pred_maxiou, axis=0)
iou = compute_iou2(torch.tensor(pred_maxiou>0, dtype=gt_mask.dtype), gt_mask[:,0,:,:,:]).detach()
iou_details['iou_box'] = iou.mean()
iou = compute_iou(torch.tensor(pred_masks2>0, dtype=gt_mask.dtype), gt_mask).detach()
iou, _ = torch.max(iou, axis=1)
iou_details['iou_box_max'] = iou.mean()
return iou_details
def compute_iou2(pred_mask, gt_mask):
dtype = pred_mask.dtype
intersection = torch.logical_and(pred_mask, gt_mask)
intersection = reduce(intersection, "b d h w -> b", reduction='sum')
union = torch.logical_or(pred_mask, gt_mask)
union = reduce(union, "b d h w -> b", reduction='sum') + 1e-8
iou = intersection / union # if union > 0 else 0
iou = torch.tensor(iou, dtype=dtype)
# print("ranked_combined_loss: compute_iou ", intersection.dtype, union.dtype, iou.dtype)
return iou
# def save(img, mask, mask2):
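
compute_iou2 above reduces (b, d, h, w) masks to one IoU value per volume. A toy check on 2x2 float masks, assuming compute_iou2 from this file is in scope:

import torch

pred = torch.tensor([[[[1., 1.], [0., 0.]]]])  # (b=1, d=1, h=2, w=2)
gt   = torch.tensor([[[[1., 0.], [0., 0.]]]])
print(compute_iou2(pred, gt))  # intersection 1, union 2 -> tensor([0.5000])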

core/learner5.py (new file, 55 lines)

@@ -0,0 +1,55 @@
"""
Use mask_decoder3d_2.py
"""
import torch
import torchvision
import numpy as np
from tutils.trainer import Trainer, LearnerModule
from einops import rearrange, repeat, reduce
import torch.optim.lr_scheduler as lr_scheduler
from core.loss import ranked_combined_loss_with_indicators
from .learner3 import SamLearner as basic_learner
from .loss import compute_all_loss, ranked_combined_loss, compute_iou, combined_loss
class SamLearner(basic_learner):
def load_pretrained_model(self, pth, *args, **kwargs):
"""
Unmatched: prompt_encoder.mask_downscaling.0.weight
their: torch.Size([4, 1, 2, 2])
our: torch.Size([4, 3, 2, 2])
Unmatched: mask_decoder.mask_tokens.weight
their: torch.Size([4, 256])
our: torch.Size([12, 256])
"""
print("Load pretrained model for mask_decoder3d_2 !!")
state_dict = torch.load(pth)
model_state_dict = self.model.state_dict()
model_state_dict.update(state_dict)
model_state_dict['prompt_encoder.mask_downscaling.0.weight'] = repeat(state_dict['prompt_encoder.mask_downscaling.0.weight'], "a 1 c d -> a b c d", b=3)
# model_state_dict['mask_decoder.mask_tokens.weight'] = repeat(state_dict['mask_decoder.mask_tokens.weight'], "a d -> (a 3) d")
for k, v in model_state_dict.items():
if k.startswith("mask_decoder.output_upscaling2"):
k2 = k.replace("output_upscaling2.", "output_upscaling." )
model_state_dict[k] = model_state_dict[k2]
print("Load weights: ", k)
if k.startswith("mask_decoder.output_upscaling3"):
k2 = k.replace("output_upscaling3.", "output_upscaling." )
model_state_dict[k] = model_state_dict[k2]
print("Load weights: ", k)
hyper_params_names = [k for k in model_state_dict.keys() if k.startswith("mask_decoder.output_hypernetworks_mlps")]
for name in hyper_params_names:
words = name.split('.')
words[2] = str(int(words[2]) // 3)
name_to_copy = ".".join(words)
model_state_dict[name] = state_dict[name_to_copy]
# for k, v in state_dict.items():
# if model_state_dict[k].shape != state_dict[k].shape:
# print("Unmatched:", k)
self.model.load_state_dict(model_state_dict)
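
The key remapping above initialises the extra upscaling branches (output_upscaling2/3 expected by mask_decoder3d_2) from the original SAM output_upscaling weights. A minimal sketch of that renaming on a dummy state dict (keys shortened, values hypothetical):

import torch

sd = {
    "mask_decoder.output_upscaling.0.weight": torch.zeros(1),
    "mask_decoder.output_upscaling2.0.weight": torch.ones(1),
}
for k in list(sd.keys()):
    if k.startswith("mask_decoder.output_upscaling2"):
        sd[k] = sd[k.replace("output_upscaling2.", "output_upscaling.")]
print(sd["mask_decoder.output_upscaling2.0.weight"])  # tensor([0.]) -- copied from output_upscaling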

core/lora_sam.py (new file, 196 lines)

@@ -0,0 +1,196 @@
# Modified from https://github.com/JamesQFreeman/Sam_LoRA
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn.parameter import Parameter
from safetensors import safe_open
from safetensors.torch import save_file
# from modeling.sam3d import Sam
class _LoRA_qkv(nn.Module):
"""In Sam it is implemented as
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
"""
def __init__(
self,
qkv: nn.Module,
linear_a_q: nn.Module,
linear_b_q: nn.Module,
linear_a_v: nn.Module,
linear_b_v: nn.Module,
):
super().__init__()
self.qkv = qkv
self.linear_a_q = linear_a_q
self.linear_b_q = linear_b_q
self.linear_a_v = linear_a_v
self.linear_b_v = linear_b_v
self.dim = qkv.in_features
self.w_identity = torch.eye(qkv.in_features)
def forward(self, x):
qkv = self.qkv(x) # B,N,N,3*org_C
new_q = self.linear_b_q(self.linear_a_q(x))
new_v = self.linear_b_v(self.linear_a_v(x))
qkv[:, :, :, : self.dim] += new_q
qkv[:, :, :, -self.dim :] += new_v
return qkv
class LoRA_Sam(nn.Module):
"""Applies low-rank adaptation to a Sam model's image encoder.
Args:
sam_model: a vision transformer model, see base_vit.py
r: rank of LoRA
num_classes: how many classes the model output, default to the vit model
lora_layer: which layer we apply LoRA.
freeze_all: freeze whole sam, otherwise only image encoder (VIT)
Examples::
>>> model = ViT('B_16_imagenet1k')
>>> lora_model = LoRA_ViT(model, r=4)
>>> preds = lora_model(img)
>>> print(preds.shape)
torch.Size([1, 1000])
"""
def __init__(self, sam_model, r: int, lora_layer:[int]=None, freeze_all:bool=False, freeze_prompt_encoder=True):
super(LoRA_Sam, self).__init__()
assert r > 0
# base_vit_dim = sam_model.image_encoder.patch_embed.proj.out_channels
# dim = base_vit_dim
if lora_layer:
self.lora_layer = lora_layer
else:
self.lora_layer = list(range(len(sam_model.image_encoder.blocks)))
# create for storage, then we can init them or load weights
self.w_As = [] # These are linear layers
self.w_Bs = []
# lets freeze first
if freeze_all:
for param in sam_model.parameters():
param.requires_grad = False
else:
for param in sam_model.image_encoder.parameters():
param.requires_grad = False
for param in sam_model.image_encoder.patch_embed.parameters():
param.requires_grad = True
if freeze_prompt_encoder:
for param in sam_model.prompt_encoder.parameters():
param.requires_grad = False
# Here, we do the surgery
for t_layer_i, blk in enumerate(sam_model.image_encoder.blocks):
# If we only want few lora layer instead of all
if t_layer_i not in self.lora_layer:
continue
w_qkv_linear = blk.attn.qkv
self.dim = w_qkv_linear.in_features
w_a_linear_q = nn.Linear(self.dim, r, bias=False)
w_b_linear_q = nn.Linear(r, self.dim, bias=False)
w_a_linear_v = nn.Linear(self.dim, r, bias=False)
w_b_linear_v = nn.Linear(r, self.dim, bias=False)
self.w_As.append(w_a_linear_q)
self.w_Bs.append(w_b_linear_q)
self.w_As.append(w_a_linear_v)
self.w_Bs.append(w_b_linear_v)
blk.attn.qkv = _LoRA_qkv(
w_qkv_linear,
w_a_linear_q,
w_b_linear_q,
w_a_linear_v,
w_b_linear_v,
)
self.reset_parameters()
self.sam = sam_model
# with open('vit_named_para.txt', 'w') as f:
# for k, v in sam_model.image_encoder.named_parameters():
# f.write(f'{k} {v.shape}\n')
def save_lora_parameters(self, filename: str) -> None:
r"""Only safetensors is supported now.
pip install safetensors if you do not have it installed yet.
Saves only the LoRA A/B linear weights.
"""
# assert filename.endswith(".safetensors")
num_layer = len(self.w_As) # actually, it is half
a_tensors = {f"w_a_{i:03d}": self.w_As[i].weight for i in range(num_layer)}
b_tensors = {f"w_b_{i:03d}": self.w_Bs[i].weight for i in range(num_layer)}
merged_dict = {**a_tensors, **b_tensors}
save_file(merged_dict, filename)
def load_lora_parameters(self, filename: str) -> None:
r"""Only safetensors is supported now.
pip install safetensors if you do not have it installed yet.
Loads the LoRA A/B linear weights saved by save_lora_parameters.
"""
assert filename.endswith(".safetensors")
with safe_open(filename, framework="pt") as f:
for i, w_A_linear in enumerate(self.w_As):
saved_key = f"w_a_{i:03d}"
saved_tensor = f.get_tensor(saved_key)
w_A_linear.weight = Parameter(saved_tensor)
for i, w_B_linear in enumerate(self.w_Bs):
saved_key = f"w_b_{i:03d}"
saved_tensor = f.get_tensor(saved_key)
w_B_linear.weight = Parameter(saved_tensor)
# import ipdb; ipdb.set_trace()
def reset_parameters(self) -> None:
for w_A in self.w_As:
nn.init.kaiming_uniform_(w_A.weight, a=5**0.5)
for w_B in self.w_Bs:
nn.init.zeros_(w_B.weight)
def get_parameter_number(model):
total_num = sum(p.numel() for p in model.parameters())
trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
return {'Total': total_num, 'Trainable': trainable_num}
if __name__ == "__main__":
from segment_anything import sam_model_registry
from segment_anything import SamPredictor, SamAutomaticMaskGenerator # prompt and every mode
import numpy as np
lora_r = 8
path = "../checkpoints/sam_vit_b_01ec64.pth"
sam = sam_model_registry["vit_b"](checkpoint=None)
print('before lora', get_parameter_number(sam))
lora_sam = LoRA_Sam(sam, lora_r)
print('after lora', get_parameter_number(sam))
x = (np.random.rand(1024, 1024, 3) * 255).astype(np.uint8)  # HWC uint8 image, as expected by the mask generator
path = '../data/cache/data2d_3/0007_s0069_img.npy'
img = np.load(path)
print('img shape', img.shape)
mask_generator = SamAutomaticMaskGenerator(sam)
masks = mask_generator.generate(x)
print('mask num', len(masks))
# loss = np.sum([mask.mean(-1).mean(-1) for mask in masks])
# predictor = SamPredictor(sam)
# predictor.set_image(x)
# masks, _, _ = predictor.predict([50,50])
for f_name in ['save_lora_parameters', 'load_lora_parameters']:
print(f_name)
getattr(lora_sam, f_name)('tmp.safetensors')
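
_LoRA_qkv adds a rank-r update B(A(x)) to the query and value slices of the frozen qkv projection while leaving the key slice untouched. A self-contained sketch of that update on a single token (hypothetical dimensions, r=8):

import torch
import torch.nn as nn

dim, r = 768, 8
qkv = nn.Linear(dim, dim * 3)                  # frozen base projection
a_q, b_q = nn.Linear(dim, r, bias=False), nn.Linear(r, dim, bias=False)
a_v, b_v = nn.Linear(dim, r, bias=False), nn.Linear(r, dim, bias=False)
nn.init.zeros_(b_q.weight)                     # LoRA init: B starts at zero,
nn.init.zeros_(b_v.weight)                     # so the update is initially a no-op

x = torch.randn(1, 1, 1, dim)                  # (B, H, W, C) as in SAM's ViT blocks
out = qkv(x)
out[..., :dim] += b_q(a_q(x))                  # query slice gets the low-rank update
out[..., -dim:] += b_v(a_v(x))                 # value slice gets the low-rank update
print(out.shape)                               # torch.Size([1, 1, 1, 2304])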

core/loss.py (new file, 155 lines)

@@ -0,0 +1,155 @@
import torch
import torch.nn.functional as F
import torch.nn as nn
from einops import repeat, rearrange, reduce
import numpy as np
def compute_dice_np(pred_mask, gt_mask):
""" numpy values
"""
pred_mask = np.array(pred_mask>0)
gt_mask = np.array(gt_mask>0)
intersection = np.array(pred_mask * gt_mask).sum()
union = pred_mask.sum() + gt_mask.sum()
dice = intersection * 2 / union # if union > 0 else 0
return dice
def combined_loss(logits, targets, alpha=0.2, gamma=2.0, smooth=1e-5, reduction='mean'):
# Calculate the focal loss
fl = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
pt = torch.exp(-fl)
focal_loss = alpha * (1 - pt) ** gamma * fl
if reduction == 'mean':
fl = torch.mean(focal_loss)
elif reduction == 'sum':
fl = torch.sum(focal_loss)
# Calculate the Dice loss
prob = torch.sigmoid(logits)
intersection = torch.sum(prob * targets, dim=(-2, -1))
union = torch.sum(prob + targets, dim=(-2, -1))
dice_loss = 1 - (2 * intersection + smooth) / (union + smooth)
# NOTE: the per-element focal and dice losses are returned here; callers reduce them
# and combine as 20 * focal + dice (see ranked_combined_loss / compute_all_loss).
return focal_loss, dice_loss
# Assuming your prediction and ground truth tensors are named `pred` and `gt`, respectively
def mse_loss(pred, gt):
mse_loss = nn.MSELoss(reduction='none')
loss = mse_loss(pred, gt)
return loss
def compute_iou(pred_mask, gt_mask):
dtype = pred_mask.dtype
intersection = torch.logical_and(pred_mask, gt_mask)
intersection = reduce(intersection, "b c d h w -> b c", reduction='sum')
union = torch.logical_or(pred_mask, gt_mask)
union = reduce(union, "b c d h w -> b c", reduction='sum') + 1e-8
iou = intersection / union # if union > 0 else 0
iou = torch.tensor(iou, dtype=dtype)
# print("ranked_combined_loss: compute_iou ", intersection.dtype, union.dtype, iou.dtype)
return iou
def ranked_combined_loss(pred_mask, gt_mask, iou_pred):
# (b c1 c2 h w), c1: num_prediction; c2: num_slices
if len(gt_mask.shape) == 4:
gt_mask = repeat(gt_mask, "b d h w -> b c d h w", c=3)
if len(pred_mask.shape) == 4:
pred_mask = rearrange(pred_mask, "b (c1 c2) h w -> b c1 c2 h w", c1=3, c2=3)
fl, dl = combined_loss(pred_mask, gt_mask)
fl = reduce(fl, "b c d h w -> b c", reduction="mean")
dl = reduce(dl, "b c d-> b c", reduction="mean")
segment_loss = 20*fl + dl
min_losses, min_loss_indices = torch.min(segment_loss, dim=1)
iou = compute_iou((pred_mask > 0).to(gt_mask.dtype), gt_mask).detach()
# print("ranked_combined_loss ", iou.dtype)
iou_loss = mse_loss(iou_pred, iou)
selected_losses = torch.gather(iou_loss, 1, min_loss_indices.unsqueeze(1))
selected_fl = torch.gather(fl, 1, min_loss_indices.unsqueeze(1))
selected_dl = torch.gather(dl, 1, min_loss_indices.unsqueeze(1))
# print(min_losses.shape, selected_losses.shape)
total_loss = min_losses.mean() + selected_losses.mean()
# return total_loss, min_losses, selected_losses
return total_loss, selected_fl, selected_dl, min_losses.mean(), selected_losses.mean(), min_loss_indices
def ranked_combined_loss_one_slice(pred_mask, gt_mask, iou_pred, mask_loc):
if len(gt_mask.shape) == 4:
# assert gt_mask.shape[1] == 1, f"Got {gt_mask.shape}"
gt_mask = repeat(gt_mask, "b d h w -> b c d h w", c=3)
if len(pred_mask.shape) == 4:
pred_mask = rearrange(pred_mask, "b (c1 c2) h w -> b c1 c2 h w", c1=3, c2=3)
gt_mask = gt_mask[:,:,mask_loc,:,:]
pred_mask = pred_mask[:,:,mask_loc,:,:]
assert len(pred_mask.shape) == 5
return ranked_combined_loss(pred_mask, gt_mask, iou_pred)
def ranked_combined_loss_with_indicators(pred_mask, gt_mask, iou_pred, indicators):
# indicators: indicate which slice are with the mask
# (b c1 c2 h w), c1: num_prediction; c2: num_slices
if len(gt_mask.shape) == 4:
gt_mask = repeat(gt_mask, "b d h w -> b c d h w", c=3)
if len(pred_mask.shape) == 4:
pred_mask = rearrange(pred_mask, "b (c1 c2) h w -> b c1 c2 h w", c1=3, c2=3)
b, c1, c2, h, w = pred_mask.shape
indicators = torch.tensor(indicators, dtype=pred_mask.dtype)
indicators = repeat(indicators, "b d -> b c d h w", c=3, h=h, w=w)
pred_mask = pred_mask * indicators
gt_mask = gt_mask * indicators
# Same as "ranked_combined_loss"
return ranked_combined_loss(pred_mask, gt_mask, iou_pred)
def compute_all_loss_with_indicators(pred_mask, gt_mask, iou_pred, indicators):
# indicators: indicate which slice are with the mask
# (b c1 c2 h w), c1: num_prediction; c2: num_slices
if len(gt_mask.shape) == 4:
gt_mask = repeat(gt_mask, "b d h w -> b c d h w", c=1)
if len(pred_mask.shape) == 4:
pred_mask = rearrange(pred_mask, "b (c1 c2) h w -> b c1 c2 h w", c1=1, c2=3)
b, c1, c2, h, w = pred_mask.shape
indicators = torch.tensor(indicators, dtype=pred_mask.dtype)
indicators = repeat(indicators, "b d -> b c d h w", c=1, h=h, w=w)
pred_mask = pred_mask * indicators
gt_mask = gt_mask * indicators
# Same as "compute_all_loss"
return compute_all_loss(pred_mask, gt_mask, iou_pred)
def compute_all_loss(pred_mask, gt_mask, iou_pred):
if len(pred_mask.shape) == 4:
pred_mask = pred_mask.unsqueeze(1)
if len(gt_mask.shape) == 4:
gt_mask = gt_mask.unsqueeze(1)
# import ipdb; ipdb.set_trace()
fl, dl = combined_loss(pred_mask, gt_mask)
segment_loss = 20*fl.mean() + dl.mean()
iou_loss = mse_loss(iou_pred, compute_iou(torch.tensor(pred_mask>0, dtype=gt_mask.dtype), gt_mask))
total_loss = segment_loss.mean() + iou_loss.mean()
return total_loss, fl, dl, iou_loss
# def compute_
if __name__ == "__main__":
pred_mask = torch.ones((1,9,1024,1024))*9
pred_mask[:,:,:200,:] = -1
gt_mask = torch.ones((1,3,1024,1024))
loss = ranked_combined_loss(pred_mask, gt_mask, iou_pred=torch.ones((gt_mask.shape[0], 3)))
print(loss)
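
The losses above follow SAM's recipe: a focal term and a soft dice term combined with a 20:1 weighting, plus an MSE between the predicted IoU and the actual IoU of the chosen mask. A quick check of compute_dice_np on toy masks, assuming this module is in scope:

import numpy as np

pred = np.array([[1, 1], [0, 0]])
gt   = np.array([[1, 0], [0, 0]])
print(compute_dice_np(pred, gt))  # 2 * 1 / (2 + 1) = 0.666...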

core/volume_predictor.py (new file, 612 lines)

@@ -0,0 +1,612 @@
import torch
import numpy as np
from typing import Any, List, Dict, Tuple, Optional
from einops import rearrange
from utils.amg import MaskData, batched_mask_to_box, batched_mask3d_to_box
from torchvision.ops.boxes import batched_nms, box_area # type: ignore
from utils.amg3d import build_point_grid, calculate_stability_score_3d, MaskData3d, batch_iterator
from utils.amg import calculate_stability_score
from utils.transforms3d import ResizeLongestSide, SimpleResize
from einops import rearrange, repeat
# from datasets.data_engine import ValidEngine, BoxPromptGenerator
from datasets.data_engine import DataEngine, DataManager, BoxPromptGenerator, PointPromptGenerator
import cv2
import torch.nn.functional as F
from tutils.new.manager import ConfigManager
from tutils.nn.data import read, itk_to_np, write, np_to_itk
from torchvision.utils import save_image
from einops import rearrange, reduce, repeat
from core.loss import compute_dice_np
class PseudoPredictor:
# def __init__(self) -> None:
# self.image_encoder =
def predict(self, *args, **kwargs):
mask = np.zeros((1024,1024))
mask[:500,:500] = 1
return mask
class VolumePredictor:
def __init__(
self,
model,
slice_per_batch: int = 4,
points_per_side: Optional[int] = 32,
points_per_batch: int = 16,
pred_iou_thresh: float = 0.5, # debug, standard value is 0.88,
stability_score_thresh: float = 0.6, # debug, standard value is 0.95,
stability_score_offset: float = 1.0,
box_nms_thresh: float = 0.7,
use_postprocess = True,
use_noise_remove = True,
) -> None:
self.model = model
self.im_size = (model.image_encoder.img_size, model.image_encoder.img_size)
self.slice_per_batch = slice_per_batch
self.point_grids = build_point_grid(points_per_side, self.im_size)
self.features = None
self.is_image_set = False
self.transform = SimpleResize(model.image_encoder.img_size)
self.points_per_batch = points_per_batch
self.pred_iou_thresh = pred_iou_thresh
self.stability_score_thresh = stability_score_thresh
self.stability_score_offset = stability_score_offset
self.box_nms_thresh = box_nms_thresh
self.masks3d = dict()
self.box_prompt_generator = BoxPromptGenerator(size=(1024,1024))
self.masks3d = None
self.stability_score_2d = None
self.input_size = model.image_encoder.img_size
self.use_postprocess = use_postprocess
self.use_noise_remove = use_noise_remove
if not use_postprocess:
print("Warning! No postprocess")
if not use_noise_remove:
print("Warning! No use_noise_remove")
# self.original_size = (1024,1024)
@property
def device(self) -> torch.device:
return self.model.device
def reset_image(self) -> None:
"""Resets the currently set image."""
self.is_image_set = False
self.features = None
self.orig_h = None
self.orig_w = None
self.input_h = None
self.input_w = None
self.masks3d = None
self.stability_score_2d = None
def set_image(
self,
image: np.ndarray,
image_format: str = "nifti",
) -> None:
# Transform the image to the form expected by the model
self.original_size = image.shape
input_image_torch = torch.as_tensor(image, device=self.device)
input_image_torch = self.transform.apply_image(input_image_torch.float())
assert np.argmin(input_image_torch.shape) == 0, f"Got image.shape: {input_image_torch.shape}"
maxlen = input_image_torch.shape[0]
slices = []
for i in range(1, maxlen-1):
slices.append(input_image_torch[i-1:i+2, :, :])
input_slices = torch.stack(slices, axis=0)
self.set_torch_image(input_slices)
def batched_to_RGB(self, images):
for i in range(images.shape[0]):
images[i] = (images[i] - images[i].min()) / (images[i].max() - images[i].min()) * 255
return images
@torch.no_grad()
def set_torch_image(
self,
transformed_image: torch.Tensor,
) -> None:
assert (
len(transformed_image.shape) == 4
and transformed_image.shape[1] == 3
and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size
), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}."
self.reset_image()
self.input_size = tuple(transformed_image.shape[-2:])
transformed_image = self.batched_to_RGB(transformed_image)
input_image = self.model.preprocess(transformed_image)
features = []
for input_image_batch in batch_iterator(self.slice_per_batch, input_image):
# print(input_image_batch[0].shape)
features_batch = self.model.image_encoder(input_image_batch[0]).cpu()
features.append(features_batch)
self.features = torch.cat(features, axis=0)
self.is_image_set = True
def __call__(self, x, *args: Any, **kwds: Any) -> Any:
return self.predict_volume(x)
def merge_to_mask3d(self, idx, masks:MaskData):
if masks._stats == {} or len(masks['masks']) == 0:
print("No masks")
return
if self.masks3d is None:
self.masks3d = np.zeros(self.original_size)
if self.stability_score_2d is None:
self.stability_score_2d = np.zeros(self.original_size[0])
masks_values = masks['masks']
for mask_value in masks_values:
old_mask = self.masks3d[idx-1:idx+2]
# self.masks3d[idx-1:idx+2] = np.where(mask_value > old_mask, mask_value, old_mask)
self.masks3d[idx-1:idx+2] = mask_value + old_mask
self.stability_score_2d[idx] = masks['stability_score_2d'][0,0]
def postprocess_3d(self, masks3d):
# add removing noise ?
return masks3d > 0
def _debug_predict_slice(
self,
x,
point_coords: Optional[np.ndarray] = None,
point_labels: Optional[np.ndarray] = None,
box: Optional[np.ndarray] = None,
mask_input: Optional[np.ndarray] = None,
multimask_output: bool = True,
return_logits: bool = False,
template_slice_id:int = None,
return_stability: bool = False,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
debug entry point: predicts only the template (center) slice of a 3d volume
x: volume: (c h w)
box: [[x,y,x,y]]
"""
# Check Input
assert len(x.shape) == 3
assert box is None or len(box.shape) == 2
# preprocess
# x = np.clip(x, -200,400)
print(f"Checking Data range: [{x.min()}, {x.max()}]" )
# Adjust direction
indicator = np.argmin(x.shape)
if indicator == 0:
pass
elif indicator == 1:
x = rearrange(x, "h c w -> c h w")
elif indicator == 2:
x = rearrange(x, "h w c -> c h w")
else:
raise NotImplementedError
# Preprocess prompts
self.original_size = x.shape[1:]
if point_coords is not None:
assert (
point_labels is not None
), "point_labels must be supplied if point_coords is supplied."
point_coords = self.transform.apply_coords(point_coords, self.original_size)
coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
if box is not None:
box = self.transform.apply_boxes(box, self.original_size)
box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
box_torch = box_torch[None, :]
if mask_input is not None:
mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device)
mask_input_torch = mask_input_torch[None, :, :, :]
# set 3d image
self.set_image(x)
# predict center slice
center_idx = template_slice_id if template_slice_id is not None else x.shape[0] // 2
# print("Processing ", center_idx)
center_masks = self._predict_center_slice(center_idx, point_coords, box)
return center_masks['masks']
def predict_volume(
self,
x,
point_coords: Optional[np.ndarray] = None,
point_labels: Optional[np.ndarray] = None,
box: Optional[np.ndarray] = None,
mask_input: Optional[np.ndarray] = None,
multimask_output: bool = True,
return_logits: bool = False,
template_slice_id:int = None,
return_stability: bool = False,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
main entry point
predict 3d volume
x: volume: (c h w)
box: [[x,y,x,y]]
"""
# Check Input
assert len(x.shape) == 3
assert box is None or len(box.shape) == 2
# preprocess
# x = np.clip(x, -200,400)
print(f"Checking Data range: [{x.min()}, {x.max()}]" )
# Adjust direction
indicator = np.argmin(x.shape)
if indicator == 0:
pass
elif indicator == 1:
x = rearrange(x, "h c w -> c h w")
elif indicator == 2:
x = rearrange(x, "h w c -> c h w")
else:
raise NotImplementedError
# Preprocess prompts
self.original_size = x.shape[1:]
if point_coords is not None:
assert (
point_labels is not None
), "point_labels must be supplied if point_coords is supplied."
point_coords = self.transform.apply_coords(point_coords, self.original_size)
coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
if box is not None:
box = self.transform.apply_boxes(box, self.original_size)
box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
box_torch = box_torch[None, :]
if mask_input is not None:
mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device)
mask_input_torch = mask_input_torch[None, :, :, :]
# set 3d image
self.set_image(x)
# predict center slice
center_idx = template_slice_id if template_slice_id is not None else x.shape[0] // 2
# print("Processing ", center_idx)
center_masks = self._predict_center_slice(center_idx, point_coords, box)
if center_masks._stats == {}:
print("Ends for no mask.")
raise ValueError
self.merge_to_mask3d(center_idx, center_masks)
previous_masks = center_masks
for i in range(center_idx+1, x.shape[0]-1):
# print("Processing downward", i)
previous_masks = self._predict_slice(i, previous_masks, orientation="down")
if previous_masks._stats == {}:
print("Ends for no mask.")
break
self.merge_to_mask3d(i, previous_masks)
previous_masks = center_masks
for i in np.arange(1, center_idx)[::-1]:
# print("Processing upward", i)
previous_masks = self._predict_slice(i, previous_masks, orientation="up")
if previous_masks._stats == {}:
print("Ends for no mask.")
break
self.merge_to_mask3d(i, previous_masks)
if self.masks3d is None:
self.masks3d = np.zeros_like(x)
if return_stability:
return self.postprocess_3d(self.masks3d), self.stability_score_2d
return self.postprocess_3d(self.masks3d)
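# Propagation strategy: segment the template (center) slice from the user prompt, then sweep
# downward and upward through the volume, turning each slice's predicted mask into an enlarged
# box prompt for the next slice; a sweep stops as soon as a slice yields no mask.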
def _predict_center_slice(self, idx, point_prompt=None, box_prompt=None):
if box_prompt is not None:
masks = self.genetate_masks_from_boxes(idx, all_boxes=box_prompt, tags=["center_slice"])
masks.to_numpy()
return masks
if point_prompt is not None:
masks = self.genetate_masks_from_point_grids(idx, point_prompt)
masks.to_numpy()
return masks
raise ValueError("No prompts! ?")
def _predict_slice(self, idx, previous_masks, orientation):
scaled_boxes, tags = self.generate_prompts_from_previous_masks(previous_masks, orientation)
masks = self.genetate_masks_from_boxes(idx, all_boxes=scaled_boxes, tags=tags)
masks.to_numpy()
return masks
def generate_prompts_from_previous_masks(self, previous_masks: MaskData, orientation):
if orientation == "down":
masks = previous_masks['masks'][:,2,:,:]
elif orientation == "up":
masks = previous_masks['masks'][:,0,:,:]
else:
raise ValueError
raw_tags = previous_masks['tags']
scaled_boxes = []
tags = []
for mask, tag in zip(masks, raw_tags):
if mask.sum() <= 50:
continue
# mask = self.remove_mask_noise(mask)
# if mask.sum() <= 50:
# continue
mask = F.interpolate(torch.Tensor(mask).float()[None,None,:,:], self.input_size).squeeze().numpy()
# scaled_boxes.append(self.box_prompt_generator.mask_to_bbox(mask))
scaled_boxes.append(self.box_prompt_generator.enlarge(self.box_prompt_generator.mask_to_bbox(mask)))
tags.append(tag)
scaled_boxes = np.array(scaled_boxes)
return scaled_boxes, tags
def genetate_masks_from_point_grids(self, idx, points_for_image):
idx = idx - 1 # ignore the head and tail slices
# Get points for this crop
data = MaskData()
tags = [f"s{idx}_p{p}" for p in range(points_for_image.shape[0])]
for (points, batched_tags) in batch_iterator(self.points_per_batch, points_for_image, tags):
batch_data = self._process_batch(idx, points=points, tags=batched_tags)
data.cat(batch_data)
del batch_data
# Remove duplicates within this crop.
# keep_by_nms = batched_nms(
# data["boxes"].float(),
# data["iou_preds"],
# torch.zeros_like(data["boxes"][:, 0]), # categories
# iou_threshold=self.box_nms_thresh,
# )
# data.filter(keep_by_nms)
return data
def genetate_masks_from_boxes(self, idx, all_boxes, tags):
idx = idx - 1
data = MaskData()
for (batched_boxes, batched_tags) in batch_iterator(self.points_per_batch, all_boxes, tags):
batch_data = self._process_batch(idx, boxes=batched_boxes, tags=batched_tags)
data.cat(batch_data)
del batch_data
return data
def _process_batch(
self,
fea_slice_idx,
points=None,
boxes=None,
multimask_output=True,
tags=None
) -> MaskData:
"""
Process a batch of prompts (because all points cannot be fed in at one time).
"""
if points is not None:
in_points = torch.as_tensor(points, device=self.model.device)
in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
masks, iou_preds, _ = self.predict_torch_by_sliceidx(
fea_slice_idx=fea_slice_idx,
point_coords=in_points[:, None, :],
point_labels=in_labels[:, None],
multimask_output=multimask_output,
return_logits=True,
)
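# The decoder returns 3 mask candidates x 3 adjacent slices for each prompt; split them
# into (batch, candidate, slice, H, W) so the best candidate can be picked by predicted IoU below.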
masks = rearrange(masks, "b (c1 c2) h w -> b c1 c2 h w", c1=3, c2=3)
elif boxes is not None:
# in_points = torch.as_tensor(points, device=self.model.device)
# in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
boxes = torch.as_tensor(boxes, device=self.model.device)
masks, iou_preds, _ = self.predict_torch_by_sliceidx(
fea_slice_idx=fea_slice_idx,
boxes=boxes,
multimask_output=multimask_output,
return_logits=True,
)
masks = rearrange(masks, "b (c1 c2) h w -> b c1 c2 h w", c1=3, c2=3)
else:
raise ValueError(f"No points or boxes")
indices = iou_preds.argmax(axis=1)
# indices = torch.tensor([2], device=iou_preds.device)
pred_maxiou = []
for pred, i in zip(masks, indices):
pred_maxiou.append(pred[i,:,:,:])
masks = torch.stack(pred_maxiou, axis=0)
iou_maxiou = []
for iou, i in zip(iou_preds, indices):
iou_maxiou.append(iou[i])
iou_preds = torch.stack(iou_maxiou)
# Serialize predictions and store in MaskData
data = MaskData(
masks=masks.detach().cpu(),
iou_preds=iou_preds.detach().cpu(),
# points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
# tags=repeat(np.array(tags), "c -> (c 3)")
tags=np.array(tags),
)
del masks
# Filter by area
# if True:
# keep_mask = (data['masks']>0).sum(-1).sum(-1).sum(-1) < (data['masks']<=0).sum(-1).sum(-1).sum(-1) * 0.4
# data.filter(keep_mask)
# print("keep mask / pred", keep_mask.sum())
# Filter Background
# if True:
# keep_mask =
# Filter by predicted IoU
if self.use_postprocess:
if self.pred_iou_thresh > -0.0:
keep_mask = data["iou_preds"] > self.pred_iou_thresh
# print("pred_iou", data["iou_preds"], (data["masks"]>0).sum(-1).sum(-1))
data.filter(keep_mask)
# print("keep mask / pred", keep_mask.sum())
# Calculate stability score
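# (As in SAM, the stability score measures how little the binarized mask changes when the
# logit threshold is shifted by +/- stability_score_offset; low-stability masks are dropped.)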
data["stability_score"] = calculate_stability_score_3d(
data["masks"], self.model.mask_threshold, self.stability_score_offset
) # .mean(axis=-1)
if self.stability_score_thresh > 0.0:
# print("stability", data["stability_score"], (data["masks"]>0).sum(-1).sum(-1))
keep_mask = data["stability_score"] >= self.stability_score_thresh
data.filter(keep_mask)
# print("keep mask / stable", keep_mask.sum())
data["stability_score_2d"] = calculate_stability_score(
data["masks"][:,1:2,:,:], self.model.mask_threshold, self.stability_score_offset
)
# Threshold masks and calculate boxes
data['logits'] = data['masks']
data["noisy_masks"] = data["logits"] > self.model.mask_threshold
# data['masks'] = torch.zeros_like(data['noisy_masks'], dtype=data['noisy_masks'].dtype, device=data['noisy_masks'].device)
b, c,_,_ = data["noisy_masks"].shape
data['masks'] = data["noisy_masks"].float()
if self.use_noise_remove:
for i in range(b):
for j in range(c):
data['masks'][i,j,:,:] = torch.Tensor(self.remove_mask_noise(data['noisy_masks'][i,j,:,:]))
# data["boxes"] = batched_mask_to_box(reduce(data["masks"], "b c h w -> b h w", reduction="sum")>0)
data["boxes"] = batched_mask_to_box(data["masks"][:,1,:,:]>0)
return data
# @staticmethod
# def calculate_
@staticmethod
def batched_remove_noise(masks):
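# NOTE: unfinished placeholder; per-mask denoising is currently done by remove_mask_noise() below.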
ori_shape = masks.shape
def remove_mask_noise(self, mask):
# Remove small, noisy components with a morphological opening; the kernel size
# scales with the mask area (capped at 8 and floored at 1 to avoid an empty kernel).
if isinstance(mask, torch.Tensor):
mask = mask.detach().cpu().numpy()
kernel_size = int(max(min(mask.sum() // 20, 8), 1))
kernel = np.ones((kernel_size, kernel_size), dtype=np.uint8)
opening = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, kernel, 1)
return opening
@torch.no_grad()
def predict_torch_by_sliceidx(
self,
fea_slice_idx: int,
point_coords: Optional[torch.Tensor] = None,
point_labels: Optional[torch.Tensor] = None,
boxes: Optional[torch.Tensor] = None,
mask_input: Optional[torch.Tensor] = None,
multimask_output: bool = True,
return_logits: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
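# Decode masks for one pre-computed slice embedding (self.features[fea_slice_idx]);
# only the prompt encoder and mask decoder are run here, the image encoder is not re-run.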
if not self.is_image_set:
raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
if point_coords is not None:
points = (point_coords, point_labels)
else:
points = None
# Embed prompts
sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
points=points,
boxes=boxes,
masks=mask_input,
)
# Predict masks
low_res_masks, iou_predictions = self.model.mask_decoder(
image_embeddings=self.features[fea_slice_idx].to(self.model.device),
image_pe=self.model.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=multimask_output,
)
# Upscale the masks to the original image resolution
masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size[1:])
if not return_logits:
masks = masks > self.model.mask_threshold
return masks, iou_predictions, low_res_masks
def valid_box(self, data, batch_idx):
# valid image with box, or point prompt
assert data['img'].shape[0] == 1, f"shape {data['img'].shape}"
image = data['img']
label = data['label']
box = BoxPromptGenerator(size=None).mask_to_bbox(label)
box_mask3d = self.predict_volume(
x=image,
box=box,
)
dice = compute_dice_np(box_mask3d, label.detach().cpu().numpy())
return dice
if __name__ == "__main__":
from core.learner3 import SamLearner
from modeling.build_sam3d2 import sam_model_registry
from trans_utils.data_utils import Data3dSolver
config = ConfigManager()
config.add_config("configs/vit_b_103.yaml")
model_type = "vit_b"
sam = sam_model_registry[model_type](checkpoint=None)
learner = SamLearner(sam_model=sam, config=config, data_engine=DataManager(img_size=(1024,1024)))
learner.use_lora()
pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b9/lora3/ckpt/model_iter_360000.pth"
learner.load_well_trained_model(pth)
learner.cuda()
predictor = VolumePredictor(
model=learner.model,
use_postprocess=True,
use_noise_remove=True,)
# Load data
img_path = "/home1/quanquan/datasets/07_WORD/WORD-V0.1.0/imagesVa/word_0001.nii.gz" # "/home1/quanquan/datasets/59_SABS/sabs_CT_normalized/image_5.nii.gz"
label_path = "/home1/quanquan/datasets/07_WORD/WORD-V0.1.0/labelsVa/word_0001.nii.gz"
volume = itk_to_np(read(img_path)) # test several slices
label_itk = read(label_path)
spacing = label_itk.GetSpacing()
label = itk_to_np(label_itk) == 1
volume = np.clip(volume, -200, 400)
# Select the slice with the largest mask
s = reduce(label, "c h w -> c", reduction="sum")
coords = np.nonzero(s)
x_min = np.min(coords[0])
x_max = np.max(coords[0])
template_slice_id = s.argmax()
box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id])
box = np.array([box])
pred = predictor.predict_volume(
x=volume,
box=box,
template_slice_id=template_slice_id,
return_stability=False,
)
Data3dSolver().simple_write(pred, path="mask.nii.gz", spacing=spacing)
Data3dSolver().simple_write(label, path="gt.nii.gz", spacing=spacing)

0
datasets/__init__.py Normal file
View File


267
datasets/cache_dataset3d.py Normal file
View File

@ -0,0 +1,267 @@
"""
Loading the 3D volumes directly is slow,
so we pre-process them into cached 2D slices.
"""
import numpy as np
import os
from einops import rearrange, reduce, repeat
from tutils.nn.data import read, itk_to_np, np_to_itk, write
from tutils import tfilename
from .dataset3d import DATASET_CONFIG, Dataset3D as basic_3d_dataset
from monai import transforms
import torch
import cv2
from scipy.sparse import csr_matrix
import torch.nn.functional as F
from torchvision import transforms
from einops import rearrange
import glob
from torchvision import transforms
from monai import transforms as monai_transforms
# "class": ["spleen", "right kidney", "left kidney", "gallbladder", "esophagus", "liver", "stomach", "aorta", "postcava", "portal vein and splenic vein", "pancrease", "right adrenal gland", "left adrenal gland"],
# "class": ["liver", "right kidney", "left kidney", "spleen"],
TEMPLATE={
'01': [1,2,3,4,5,6,7,8,9,10,11,12,13,14],
'02': [1,0,3,4,5,6,7,0,0,0,11,0,0,14],
'03': [6],
'04': [6,27], # post process
'05': [2,26,32], # post process
'07': [6,1,3,2,7,4,5,11,14,18,19,12,20,21,23,24],
'08': [6, 2, 1, 11],
'09': [1,2,3,4,5,6,7,8,9,11,12,13,14,21,22],
'12': [6,21,16,2],
'13': [6,2,1,11,8,9,7,4,5,12,13,25],
'14': [11,11,28,28,28], # Felix data, post process
'10_03': [6, 27], # post process
'10_06': [30],
'10_07': [11, 28], # post process
'10_08': [15, 29], # post process
'10_09': [1],
'10_10': [31],
'58': [6,2,3,1],
'59': [1,2,3,4,5,6,7,8,9,10,11,12,13,14],
'60': np.arange(200).tolist(), # for debug
}
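# TEMPLATE maps each dataset id (the numeric prefix of the directory name) to the unified class
# ids of its local label values: local label i corresponds to TEMPLATE[id][i-1]; entries of 0
# appear to mark labels that are unused, and entries flagged "post process" need extra handling.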
class Dataset3D(basic_3d_dataset):
def __init__(self, config=..., use_cache=True, *args, **kwargs) -> None:
super().__init__(config, use_cache=use_cache, *args, **kwargs)
self.basic_dir = config['data_root_path']
self.cache_dir = config['cache_data_path']
def prepare_transforms(self):
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Resize((1024,1024)),
])
self.test_transform = transforms.Compose([
monai_transforms.Resized(keys=['img', 'label'], spatial_size=(3,1024,1024)),
])
# @tfunctime
# def prepare_datalist(self):
def prepare_cached_datalist(self):
raise DeprecationWarning("[Warning] Please use cache_dataset3d new version instead!")
config = self.config
data_paths = []
for dirpath in glob.glob(config['cache_data_path'] + "/*"):
data_paths += glob.glob(dirpath + "/image/*.jpg")
print("Load ", dirpath)
print('train len {}'.format(len(data_paths)))
print('Examples: ', data_paths[:2])
return data_paths
def caching_data(self):
assert self.use_cache == False
for index in range(len(self)):
self.cache_one_sample(index)
def cache_one_sample(self, index, debug=False):
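# Cache one volume: every window of 3 adjacent slices along the scan axis is resized to
# 1024x1024 and saved as an RGB jpg, together with the binary masks of the sufficiently
# large organs and a meta .npy recording the slice index, orientation and label ids.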
# LABEL_INDICES
name = self.img_names[index]['img_path']
img_itk = read(self.img_names[index]['img_path'])
img_ori = itk_to_np(img_itk)
img_ori = np.clip(img_ori, -200,400)
# spacing = img_itk.GetSpacing()
scan_orientation = np.argmin(img_ori.shape)
label_ori = itk_to_np(read(self.img_names[index]['label_path']))
dataset_name = self.img_names[index]['img_path'].replace(self.basic_dir,"").split("/")[0]
assert dataset_name[0] in ['0','1','2','3','4','5','6','7','8','9'], f"Got {dataset_name}"
all_labels = TEMPLATE[dataset_name[:2]]
num = 0
# if min(img_ori.shape) * 1.2 < max(img_ori.shape):
# orientation_all = [scan_orientation]
# else:
# orientation_all = [0,1,2]
orientation_all = [scan_orientation]
for orientation in orientation_all:
for slice_idx in range(2, img_ori.shape[orientation]-2):
# slice_idx = np.random.randint(2, img_ori.shape[orientation]-2)
if orientation == 0:
s = img_ori[slice_idx-1:slice_idx+2, :,:]
lb = label_ori[slice_idx-1:slice_idx+2, :,:]
# spacing = (spacing[1], spacing[2])
if orientation == 1:
s = img_ori[:,slice_idx-1:slice_idx+2,:]
s = rearrange(s, "h c w -> c h w")
lb = label_ori[:,slice_idx-1:slice_idx+2,:]
lb = rearrange(lb, "h c w -> c h w")
# spacing = (spacing[0], spacing[2])
if orientation == 2:
s = img_ori[:,:,slice_idx-1:slice_idx+2]
s = rearrange(s, "h w c -> c h w")
lb = label_ori[:,:,slice_idx-1:slice_idx+2]
lb = rearrange(lb, "h w c -> c h w")
# spacing = (spacing[0], spacing[1])
assert s.shape[0] == 3
# if np.float32(lb[1,:,:]>0).sum() <= 200:
# # return self._get_data((index+1)%len(self))
# continue
# Choose one label
label_num = int(lb.max())
masks_data = []
meta = {"img_name": name, "slice": slice_idx, "orientation": orientation, "label_idx": [], "labels": [], "id": f"{num:08d}" }
for label_idx in range(1,label_num+1):
one_lb = np.float32(lb==label_idx)
if one_lb[1,:,:].sum() <= (one_lb.shape[-1] * one_lb.shape[-2] * 0.0014):
continue
# if one_lb[0,:,:].sum()<=50 or one_lb[2,:,:].sum()<=50:
masks_data.append(one_lb)
meta['label_idx'].append(label_idx)
meta['labels'].append(all_labels[label_idx-1])
if len(masks_data) <= 0:
continue
img_rgb = s
img_rgb = F.interpolate(torch.Tensor(img_rgb).unsqueeze(0), size=(1024,1024)).squeeze().numpy()
img_rgb = self.to_RGB(img_rgb)
save_image_name = tfilename(self.cache_dir, dataset_name, f"image/image_{index:04d}_{num:08d}.jpg")
self.save_img_rgb(rearrange(img_rgb, "c h w -> h w c"), save_image_name)
# Save cache data
save_label_name = tfilename(self.cache_dir, dataset_name, f"label/label_{index:04d}_{num:08d}.npz")
self.save_slice_mask(masks_data, save_label_name)
print("Save ", save_image_name)
self.save_meta(meta, tfilename(self.cache_dir, dataset_name, f"meta/meta_{index:04d}_{num:08d}.npy"))
num += 1
def save_meta(self, meta, path):
assert path.endswith(".npy")
np.save(path, meta)
def save_slice_mask(self, masks_data, prefix):
masks_data = F.interpolate(torch.Tensor(masks_data), size=(1024,1024)).numpy()
assert masks_data.shape[1:] == (3,1024,1024), f"{__file__} Got{masks_data.shape}"
for i in range(masks_data.shape[0]):
labeli = masks_data[i].astype(np.uint8) * 255
assert labeli.sum() > 0
path = tfilename(prefix+f"_{i:04d}.jpg")
cv2.imwrite(path, rearrange(labeli, "c h w -> h w c"))
print("save to ", path)
def _old_save_slice_mask(self, masks_data, path):
raise DeprecationWarning()
exit(0)
assert path.endswith(".npz")
# masks_data = np.array([m['segmentation'] for m in masks]).astype(int)
masks_data = F.interpolate(torch.Tensor(masks_data), size=(1024,1024)).numpy()
# masks_data = np.int8(masks_data>0)
assert masks_data.shape[1:] == (3,1024,1024), f"{__file__} Got{masks_data.shape}"
masks_data = rearrange(masks_data, "n c h w -> n (c h w)")
csr = csr_matrix(masks_data)
np.savez_compressed(path, data=csr.data, indices=csr.indices, indptr=csr.indptr, shape=csr.shape)
def save_img_rgb(self, img, path):
assert path.endswith(".jpg")
assert img.shape == (1024,1024,3)
cv2.imwrite(path, img.astype(np.uint8))
def _get_cached_data(self, index):
name = self.img_names[index]
# print(name)
img = cv2.imread(name)
compressed = np.load(name.replace("image/image_", "label/label_").replace(".jpg", ".npz"))
csr = csr_matrix((compressed['data'], compressed['indices'], compressed['indptr']), shape=compressed['shape'])
label_ori = csr.toarray()
label_ori = rearrange(label_ori, "n (c h w) -> n c h w", c=3, h=1024, w=1024)
meta = np.load(name.replace("image/image_", "meta/meta_").replace(".jpg", ".npy"), allow_pickle=True).tolist()
# print(meta)
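# Keep only the cached masks whose middle slice has more than 500 foreground pixels
# and sample one of them uniformly at random; fall back to the next sample if none qualify.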
pp = reduce(label_ori[:,1,:,:], "n h w -> n", reduction="sum") > 500
if pp.sum() == 0:
return self._get_cached_data((index+1)%len(self))
label_idx = np.random.choice(a=np.arange(len(pp)), p=pp/pp.sum())
# label_idx = np.random.randint(0, label_ori.shape[0])
label_ori = label_ori[label_idx]
is_edge = meta.get('is_edge', 0)
return rearrange(img, "h w c -> c h w"), label_ori, name, meta['labels'][label_idx], meta['label_idx'][label_idx]
# @tfunctime
def __getitem__(self, index, debug=False):
# print("Dataset warning", index, len(self))
index = index % len(self)
img_rgb, label_ori, name, label_idx, local_idx = self._get_cached_data(index)
if label_ori.sum() <= 0:
print("[Label Error] ", name)
return self.__getitem__(index+1)
# assert len(img_rgb.shape) == 3, f"{__file__} Got{img_rgb.shape}"
# img_rgb = self.transform((img_rgb[None,:,:,:]))
img_rgb = F.interpolate(torch.Tensor(img_rgb).unsqueeze(0), size=(1024,1024)).squeeze().numpy()
vector = np.ones(3)
ret_dict = {
"name": name,
"img": img_rgb,
"label": label_ori,
"indicators": vector,
"class": label_idx,
"local_idx": local_idx,
}
return ret_dict
def _convert_one_mask_from_npz_to_jpg(self, path1=None):
# path1 = "/home1/quanquan/datasets/cached_dataset2/01_BCV-Abdomen/label/label_0129_00000043.npz" # 32K
prefix = path1.replace(".npz", "").replace("/label/", "/label_jpg/")
compressed = np.load(path1)
csr = csr_matrix((compressed['data'], compressed['indices'], compressed['indptr']), shape=compressed['shape'])
label_ori = csr.toarray()
label_ori = rearrange(label_ori, "n (c h w) -> n c h w", c=3, h=1024, w=1024)
# print(label_ori.shape)
for i in range(label_ori.shape[0]):
labeli = label_ori[i]
path = tfilename(prefix+f"_{i:04d}.jpg")
cv2.imwrite(path, rearrange(labeli, "c h w -> h w c").astype(np.uint8))
print("save to ", path)
def convert_masks_types(self):
assert self.use_cache == True
for index in range(len(self)):
name = self.img_names[index]
label_path = name.replace("image/image_", "label/label_").replace(".jpg", ".npz")
self._convert_one_mask_from_npz_to_jpg(label_path)
if __name__ == "__main__":
# def go_cache():
from tutils.new.manager import ConfigManager
config = ConfigManager()
config.add_config("configs/vit_b.yaml")
dataset = Dataset3D(config=config['dataset'], use_cache=True)
dataset.caching_data()
# dataset.convert_masks_types()

View File

@ -0,0 +1,97 @@
"""
Re-index the cached data by individual masks rather than by images.
"""
import numpy as np
import os
from einops import rearrange, reduce, repeat
from tutils.nn.data import read, itk_to_np, np_to_itk, write
from tutils import tfilename
from .cache_dataset3d import Dataset3D as basic_3d_dataset
from monai import transforms
import torch
import cv2
from scipy.sparse import csr_matrix
import torch.nn.functional as F
from torchvision import transforms
from einops import rearrange
import glob
from torchvision import transforms
from monai import transforms as monai_transforms
class Dataset3D(basic_3d_dataset):
def __init__(self, config=..., use_cache=True, *args, **kwargs) -> None:
super().__init__(config, use_cache=use_cache, *args, **kwargs)
self.basic_dir = config['data_root_path']
self.cache_dir = config['cache_data_path']
def prepare_cached_datalist(self):
config = self.config
data_paths = []
for dirpath in glob.glob(config['cache_data_path'] + "/*"):
if not os.path.isdir(dirpath):
continue
prefix = dirpath.split("/")[-1]
if prefix[:2] in config['cache_prefix']:
data_paths += glob.glob(dirpath + "/label_jpg/*.jpg")
print("Load ", dirpath)
print('Masks len {}'.format(len(data_paths)))
print('Examples: ', data_paths[:2])
return data_paths
def _get_cached_data(self, index):
mask_path = self.img_names[index]
# print(name)
mask = np.int32(cv2.imread(mask_path) > 0)
prefix = mask_path[:-9]
img_path = prefix.replace("/label_jpg/label_", "/image/image_") + ".jpg"
img = cv2.imread(img_path)
number = int(mask_path[-8:-4])
meta = np.load(prefix.replace("/label_jpg/label_", "/meta/meta_")+".npy", allow_pickle=True).tolist()
# label_idx = np.random.randint(0, label_ori.shape[0])
return rearrange(img, "h w c -> c h w"), rearrange(mask, "h w c -> c h w"), mask_path, meta['labels'][number], meta['label_idx'][number]
# @tfunctime
def __getitem__(self, index, debug=False):
index = index % len(self)
img_rgb, label_ori, name, label_idx, local_idx = self._get_cached_data(index)
# assert label_ori.sum() > 0
if label_ori.sum() <= 0:
print("[Label Error] ", name)
return self.__getitem__(index+1)
# assert len(img_rgb.shape) == 3, f"{__file__} Got{img_rgb.shape}"
# img_rgb = self.transform((img_rgb[None,:,:,:]))
img_rgb = F.interpolate(torch.Tensor(img_rgb).unsqueeze(0), size=(1024,1024)).squeeze().numpy()
vector = np.ones(3)
ret_dict = {
"name": name,
"img": img_rgb,
"label": label_ori,
"indicators": vector,
"class": label_idx,
"local_idx": local_idx,
"is_problem": label_ori.sum() <= 30,
}
return ret_dict
if __name__ == "__main__":
# def go_cache():
from tutils.new.manager import ConfigManager
from tqdm import tqdm
config = ConfigManager()
config.add_config("configs/vit_b.yaml")
config.print()
dataset = Dataset3D(config=config['dataset'], use_cache=True)
# dataset.caching_data()
# dataset.convert_masks_types()
for i in tqdm(range(len(dataset))):
data = dataset.__getitem__(i)
# assert

274
datasets/data_engine.py Normal file
View File

@ -0,0 +1,274 @@
import numpy as np
import torch
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
from einops import repeat, rearrange
def show_mask(mask, ax, random_color=False):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
color = np.array([30/255, 144/255, 255/255, 0.6])
h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
ax.imshow(mask_image)
def show_points(coords, labels, ax, marker_size=375):
pos_points = coords[labels==1]
neg_points = coords[labels==0]
ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
def show_box(box, ax):
x0, y0 = box[0], box[1]
w, h = box[2] - box[0], box[3] - box[1]
ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
class PointPromptGenerator(object):
def __init__(self, size=None) -> None:
pass
def get_prompt_point(self, gt_mask):
# assert gt_mask.shape == (1024,1024) or gt_mask.shape == (512,512), f"[data_engine] {__file__} Got{gt_mask.shape}"
if not (gt_mask.shape == (1024,1024) or gt_mask.shape == (512,512) or gt_mask.shape == (256,256) ):
print(f"[Warning] [data_engine] {__file__} Got{gt_mask.shape}")
assert gt_mask.sum() > 0
self.size = gt_mask.shape
self.xy = np.arange(0, self.size[0] * self.size[1])
gt_mask = np.float32(gt_mask>0)
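# Sample a single positive point uniformly from the foreground: the mask is flattened into a
# probability distribution and one (x, y) location is drawn from it.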
prob = rearrange(gt_mask, "h w -> (h w)")
prob = prob / prob.sum()
loc = np.random.choice(a=self.xy, size=1, replace=True, p=prob)[0]
x, y = loc % self.size[1], loc // self.size[1]
return x, y
@staticmethod
def select_random_subsequent_point(pred_mask, gt_mask):
# union = np.float32((pred_mask + gt_mask)>0)
# diff = union - intersect
assert len(pred_mask.shape) == 2
assert len(gt_mask.shape) == 2
assert gt_mask.sum() > 0, f"[data_engine] Got {gt_mask.sum()}==0 "
diff = np.float32(np.abs(pred_mask - gt_mask)>0)
diff = np.nan_to_num(diff, nan=0)
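# Sample the next click from the region where prediction and ground truth disagree:
# it becomes a positive click (label 1) if the ground truth is foreground there but the
# prediction missed it, otherwise a negative click (label 0).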
# print(diff.shape)
xy = np.arange(0, diff.shape[0] * diff.shape[1])
if diff.sum() == 0:
prob = rearrange(gt_mask, "h w -> (h w)")
prob = prob / prob.sum()
loc = np.random.choice(a=xy, size=1, replace=True, p=prob)[0]
x, y = loc % diff.shape[1], loc // diff.shape[1]
return (x,y), 1
# Get_prompt_point
prob = rearrange(diff, "h w -> (h w)")
prob = prob / prob.sum()
loc = np.random.choice(a=xy, size=1, replace=True, p=prob)[0]
x, y = loc % diff.shape[1], loc // diff.shape[1]
if gt_mask[y, x] == 1 and pred_mask[y, x] == 0:
classification = 1
else:
classification = 0
# raise ValueError
return (x, y), classification
class BoxPromptGenerator(object):
def __init__(self, size) -> None:
self.size = size
@staticmethod
def mask_to_bbox(mask):
# Find the indices of all non-zero elements in the mask
coords = np.nonzero(mask)
# Compute the minimum and maximum values of the row and column indices
x_min = np.min(coords[1])
y_min = np.min(coords[0])
x_max = np.max(coords[1])
y_max = np.max(coords[0])
# Return the coordinates of the bounding box
return (x_min, y_min, x_max, y_max)
# return (y_min, x_min, y_max, x_max)
def add_random_noise_to_bbox(self, bbox):
bbox = list(bbox)
# Calculate the side lengths of the box in the x and y directions
x_side_length = bbox[2] - bbox[0]
y_side_length = bbox[3] - bbox[1]
# Calculate the standard deviation of the noise
std_dev = 0.01 * (x_side_length + y_side_length) / 2
# Generate random noise for each coordinate
x_noise = np.random.normal(scale=std_dev)
y_noise = np.random.normal(scale=std_dev)
# Add the noise to each coordinate, clipping its magnitude to at most 20 pixels
x_noise = int(np.clip(round(x_noise), -20, 20))
y_noise = int(np.clip(round(y_noise), -20, 20))
bbox[0] += x_noise
bbox[1] += y_noise
bbox[2] += x_noise
bbox[3] += y_noise
# Make sure the modified coordinates do not exceed the maximum possible values
bbox[0] = max(bbox[0], 0)
bbox[1] = max(bbox[1], 0)
bbox[2] = min(bbox[2], self.size[0])
bbox[3] = min(bbox[3], self.size[1])
# Return the modified bounding box
return bbox
def get_prompt_box(self, gt_mask):
""" return (x_min, y_min, x_max, y_max) """
assert gt_mask.shape == (1024,1024) or gt_mask.shape == (512,512) or gt_mask.shape == (256,256), f"[data_engine] {__file__} Got{gt_mask.shape}"
box = self.mask_to_bbox(gt_mask)
box_w_noise = self.add_random_noise_to_bbox(box)
return box_w_noise
def enlarge(self, bbox, margin=0):
""" Enlarge a bbox by a 5% margin on every side, clipped to the image size. """
x0, y0, x1, y1 = bbox[0], bbox[1], bbox[2], bbox[3]
margin_x = int((x1 - x0)*0.05)
margin_y = int((y1 - y0)*0.05)
x0 = max(x0 - margin_x, 0)
y0 = max(y0 - margin_y, 0)
x1 = min(x1 + margin_x, self.size[0]-1)
y1 = min(y1 + margin_y, self.size[1]-1)
# print("[DEBUG] , enlarge size: ", margin_x, margin_y)
# print("[DEBUG] from", bbox, "to", (x0,y0,x1,y1))
return (x0,y0,x1,y1)
class DataEngine(Dataset):
def __init__(self, dataset=None, img_size=None) -> None:
# CACHE_DISK_DIR="/home1/quanquan/code/projects/medical-guangdong/cache/data2d_3/"
super().__init__()
self.point_prompt_generator = PointPromptGenerator(img_size)
self.box_prompt_generator = BoxPromptGenerator(img_size)
# self._get_dataset(dirpath=dirpath)
self.dataset = dataset
# def _get_dataset(self, dirpath):
# self.dataset = Dataset2D(dirpath=dirpath, is_train=True)
def __len__(self):
return len(self.dataset)
def _get_true_index(self, idx):
return idx
def __getitem__(self, idx):
return self.get_prompt(idx)
def get_prompt_point(self, gt_mask):
return self.point_prompt_generator.get_prompt_point(gt_mask)
def get_prompt_box(self, gt_mask):
return self.box_prompt_generator.get_prompt_box(gt_mask)
# def _get_data_from_dataset(self, idx):
def get_prompt(self, idx):
idx = self._get_true_index(idx)
data = self.dataset.__getitem__(idx)
img = data['img'] # (3,h,w) d=3
mask = data['label'] # (3,h,w) d=3
try:
gt_mask = mask[1,:,:]
except Exception as e:
import ipdb; ipdb.set_trace()
gt_mask = mask[1,:,:]
gt_mask = gt_mask.numpy() if isinstance(gt_mask, torch.Tensor) else gt_mask
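# Both a point prompt and a box prompt are generated from the middle-slice GT mask,
# presumably so the training loop can choose which type of prompt to feed the model.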
# if np.random.rand() > 0.5:
prompt_point = self.get_prompt_point(gt_mask)
# else:
prompt_box = self.get_prompt_box(gt_mask)
data['prompt_point'] = np.array(prompt_point).astype(np.float32)
data['prompt_box'] = np.array(prompt_box).astype(np.float32)
data['point_label'] = np.ones((1,)).astype(np.float32)
return data
def get_subsequent_prompt_point(self, pred_mask, gt_mask):
# return self.point_prompt_generator.select_random_subsequent_point_torch(pred_mask, gt_mask)
# return self.point_prompt_generator.select_random_subsequent_point(pred_mask=pred_mask, gt_mask=gt_mask)
coord_collect = []
label_collect = []
for i in range(pred_mask.shape[0]):
coords, label = self.point_prompt_generator.select_random_subsequent_point(pred_mask[i][0], gt_mask[i][0])
if label == -1:
return None, None
coord_collect.append(coords)
label_collect.append(label)
coord_collect = np.stack(coord_collect, axis=0)
label_collect = np.stack(label_collect, axis=0)
return coord_collect, label_collect
def get_noisy_box_from_box(self, box):
# Get noisy box from labeled box
return self.box_prompt_generator.add_random_noise_to_bbox(box)
# def get_prompt_mask(self, )
# class ValidEngine(DataEngine):
# def __init__(self, dataset=None, img_size=(1024,1024), is_train=False) -> None:
# # assert dataset is not None
# self.dataset = dataset
# self.is_train = is_train
# super().__init__(dataset=dataset, img_size=img_size)
# self.expand_dataset_ratio = 1
# # def _get_dataset(self, dirpath):
# # self.dataset = Dataset3D(dirpath=dirpath, is_train=self.is_train)
# def __len__(self):
# return len(self.dataset)
# def _get_true_index(self, idx):
# return idx
class DataManager:
def __init__(self, img_size=None) -> None:
self.point_prompt_generator = PointPromptGenerator(img_size)
self.box_prompt_generator = BoxPromptGenerator(img_size)
def get_prompt_point(self, gt_mask):
return self.point_prompt_generator.get_prompt_point(gt_mask)
def get_prompt_box(self, gt_mask):
return self.box_prompt_generator.get_prompt_box(gt_mask)
def get_subsequent_prompt_point(self, pred_mask, gt_mask):
# return self.point_prompt_generator.select_random_subsequent_point_torch(pred_mask, gt_mask)
# return self.point_prompt_generator.select_random_subsequent_point(pred_mask=pred_mask, gt_mask=gt_mask)
coord_collect = []
label_collect = []
for i in range(pred_mask.shape[0]):
coords, label = self.point_prompt_generator.select_random_subsequent_point(pred_mask[i][0], gt_mask[i][0])
if label == -1:
return None, None
coord_collect.append(coords)
label_collect.append(label)
coord_collect = np.stack(coord_collect, axis=0)
label_collect = np.stack(label_collect, axis=0)
return coord_collect, label_collect
def get_noisy_box_from_box(self, box):
# Get noisy box from labeled box
return self.box_prompt_generator.add_random_noise_to_bbox(box)
if __name__ == "__main__":
dataset = DataEngine()
data = dataset.__getitem__(0)
import ipdb; ipdb.set_trace()

239
datasets/dataset3d.py Normal file
View File

@ -0,0 +1,239 @@
from torchvision import transforms
from monai import transforms as monai_transforms
import numpy as np
import SimpleITK as sitk
import torch
from torch.utils.data import Dataset as dataset
import torch.nn.functional as F
import glob
import os
from einops import rearrange
from tutils.nn.data import read, itk_to_np
from tqdm import tqdm
# from monai.transforms import SpatialPadd, CenterSpatialCropd, Resized, NormalizeIntensityd
# from monai.transforms import RandAdjustContrastd, RandShiftIntensityd, Rotated, RandAffined
# from datasets.common_2d_aug import RandomRotation, RandomResizedCrop, RandomHorizontalFlip, ColorJitter, ToTensor, Normalize
from tutils import tfilename, tdir
import random
import time
import cv2
from scipy.sparse import csr_matrix
def tfunctime(func):
def run(*argv, **kargs):
t1 = time.time()
ret = func(*argv, **kargs)
t2 = time.time()
print(f"[Function {func.__name__}] Running time:{(t2-t1):.6f}s")
return ret
return run
# DEFAULT_PATH='/home1/quanquan/datasets/KiTS/'
DEFAULT_PATH="/home1/quanquan/datasets/BCV-Abdomen/Training/"
LABEL_INDICES={
"t2sag": ["bg","kidney", "label 2", "label 3", "rectum", "tumor", "other"],
}
# CACHE_DISK_DIR="/home1/quanquan/code/projects/medical-guangdong/cache/data2d_3/"
CACHE_DISK_DIR=None
# DEFAULT_CONFIG={
# "pad": (512,512),
# "crop": (384,384),
# "resize": (512,512),
# }
DATASET_CONFIG={
'split': 'train',
'data_root_path':'/quanquan/datasets/',
'dataset_list': ['sam', "their", "ours"],
'data_txt_path':'./datasets/dataset_list/',
}
DATASET_METAINFO={
"WORD": {0:"Background", 1:"Liver", 2:"Spleen", 3:"Left Kidney", 4:"Right Kidney", 5:"Stomach", 6:"Gallbladder", 7:"Esophagus", 8:"Pancreas", 9:"Duodenum", 10:"Colon", 11:"Intestine", 12:"Adrenal", 13:"Rectum", 14:"Bladder", 15:"left head of femur", 16:"right head of femur"}
}
class Dataset3D(dataset):
def __init__(self, config=DATASET_CONFIG, is_train=True, split='train', getting_multi_mask=False, use_cache=False) -> None:
super().__init__()
self.config = config
self.is_train = is_train
self.split = split
self.getting_multi_mask = getting_multi_mask
self.use_cache = use_cache
self.img_names = self.prepare_cached_datalist() if use_cache else self.prepare_datalist()
# self.img_names = self.prepare_datalist()
self.prepare_transforms()
def prepare_cached_datalist(self):
raise NotImplementedError
def prepare_transforms(self):
self.transform = monai_transforms.Compose([
# transforms.Resized(keys=['img', 'label'], spatial_size=(3,512,512)),
# transforms.RandSpatialCropd(keys=["img", 'label'], roi_size=(3,448,448)),
# transforms.RandAffined(keys=['img', 'label'], prob=0.5, shear_range=(0.2,0.2)),
# transforms.RandCropByPosNegLabeld(keys=['img', 'label'], spatial_size=(3,448,448), label_key='label', neg=0),
# transforms.RandSmoothFieldAdjustContrastd(keys=['img', 'label'], )
monai_transforms.RandAdjustContrastd(keys=['img'], ),
# transforms.RandShiftIntensityd(keys=['img'], prob=0.8, offsets=(0, 20)),
monai_transforms.Resized(keys=['img', 'label'], spatial_size=(3,1024,1024)),
])
self.test_transform = transforms.Compose([
monai_transforms.Resized(keys=['img', 'label'], spatial_size=(3,1024,1024)),
])
def _get_image(self, index):
name = self.img_names[index]['img_path']
if not os.path.exists(name):
print("Path not exists!", name)
return self._get_image(index+1%len(self))
img_itk = read(self.img_names[index]['img_path'])
img_ori = itk_to_np(img_itk)
img = np.clip(img_ori, -200, 400).astype(np.float32)
img = (img - img.min()) / img.max() * 255
label_ori = itk_to_np(read(self.img_names[index]['label_path']))
return {"img":img, "name":name.replace(self.config['data_root_path'], ""), "label":label_ori}
def __len__(self):
return len(self.img_names)
# @tfunctime
def prepare_datalist(self):
config = self.config
data_paths = []
for item in config['dataset_list']:
print("Load datalist from ", item)
for line in open(config["data_txt_path"]+ item + f"_{self.split}.txt"):
name = line.strip().split()[1].split('.')[0]
img_path = config['data_root_path'] + line.strip().split()[0]
label_path = config['data_root_path'] + line.strip().split()[1]
data_paths.append({'img_path': img_path, 'label_path': label_path, 'name': name})
print('train len {}'.format(len(data_paths)))
return data_paths
# @tfunctime
def _get_data(self, index, debug=False):
# LABEL_INDICES
name = self.img_names[index]['img_path']
img_itk = read(self.img_names[index]['img_path'])
img_ori = itk_to_np(img_itk)
# spacing = img_itk.GetSpacing()
scan_orientation = np.argmin(img_ori.shape)
label_ori = itk_to_np(read(self.img_names[index]['label_path']))
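# Slice along the scan's thinnest axis when the volume is strongly anisotropic,
# otherwise pick a random orientation; three adjacent slices form one pseudo-RGB sample.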
if min(img_ori.shape) * 2 < max(img_ori.shape):
orientation = scan_orientation
else:
orientation = np.random.randint(3)
slice_idx = np.random.randint(2, img_ori.shape[orientation]-2)
if orientation == 0:
s = img_ori[slice_idx-1:slice_idx+2, :,:]
lb = label_ori[slice_idx-1:slice_idx+2, :,:]
# spacing = (spacing[1], spacing[2])
if orientation == 1:
s = img_ori[:,slice_idx-1:slice_idx+2,:]
s = rearrange(s, "h c w -> c h w")
lb = label_ori[:,slice_idx-1:slice_idx+2,:]
lb = rearrange(lb, "h c w -> c h w")
# spacing = (spacing[0], spacing[2])
if orientation == 2:
s = img_ori[:,:,slice_idx-1:slice_idx+2]
s = rearrange(s, "h w c -> c h w")
lb = label_ori[:,:,slice_idx-1:slice_idx+2]
lb = rearrange(lb, "h w c -> c h w")
# spacing = (spacing[0], spacing[1])
assert s.shape[0] == 3
if np.float32(lb[1,:,:]>0).sum() <= 200:
return self._get_data((index+1)%len(self))
# Choose one label
label_num = int(lb.max())
is_good_mask = []
for label_idx in range(1,label_num+1):
one_lb = np.float32(lb==label_idx)
is_good_mask.append(one_lb.sum()>=50)
label_idx = np.random.choice(range(1,label_num+1), p=np.array(is_good_mask)/np.sum(is_good_mask))
lb = np.float32(lb==label_idx)
return s, lb, name, label_idx
# @tfunctime
def _get_cached_data(self, index):
name = self.img_names[index]
img = cv2.imread(name)
compressed = np.load(name.replace("image/image_", "label/label_").replace(".jpg", ".npz"))
csr = csr_matrix((compressed['data'], compressed['indices'], compressed['indptr']), shape=compressed['shape'])
label_ori = csr.toarray()
label_ori = rearrange(label_ori, "n (c h w) -> n c h w", c=3, h=1024, w=1024)
label_idx = np.random.randint(0, label_ori.shape[0])
label_ori = label_ori[label_idx]
return rearrange(img, "h w c -> c h w"), label_ori, name, -1
def to_RGB(self, img):
# transform images to RGB style
img = ((img - img.min()) / (img.max() - img.min()) * 255).astype(int)
return img
# @tfunctime
def __getitem__(self, index, debug=False):
img_ori, label_ori, name, label_idx = self._get_data(index)
img_ori = np.clip(img_ori, -200,400)
img_rgb = self.to_RGB(img_ori)
assert len(img_rgb.shape) == 3, f"{__file__} Got{img_rgb.shape}"
bundle_ori = {"img":torch.Tensor(img_rgb).unsqueeze(0), "label":torch.Tensor(label_ori).unsqueeze(0)}
# import ipdb; ipdb.set_trace()
if self.is_train:
# bundle = self.transform(bundle_ori)[0] # use with transforms.RandCropByPosNegLabeld
bundle = self.transform(bundle_ori)
else:
bundle = self.test_transform(bundle_ori)
if not self.use_cache:
bundle['label'] = (bundle['label']>0.5).float()
vector = np.ones(3)
if debug:
ret_dict = {
"name": name,
"img": bundle['img'],
"label": bundle['label'],
"img_ori":img_ori,
"label_ori":label_ori,
"label_idx": label_idx,
"indicators": vector,
# "label_name":
}
return ret_dict
ret_dict = {
"name": name,
"img": bundle['img'][0].float(),
"label": bundle['label'][0].float(),
"indicators": vector,
}
if bundle['label'][0][1,:,:].sum() <= 0:
return self.__getitem__((index + 1) % len(self))
return ret_dict
if __name__ == "__main__":
from tutils.new.manager import ConfigManager
config = ConfigManager()
config.add_config("configs/vit_b_103.yaml")
dataset = Dataset3D(config=config['dataset']) # , use_cache=True
data = dataset.__getitem__(0)
import ipdb; ipdb.set_trace()
from torch.utils.data import DataLoader
loader = DataLoader(dataset, batch_size=8)
for batch in loader:
print(batch['img'].shape, batch['label'].shape)
# print(data['label'].max())
# import ipdb; ipdb.set_trace()

View File

@ -0,0 +1,178 @@
# from torchvision import transforms
from monai import transforms
import numpy as np
import SimpleITK as sitk
import torch
from torch.utils.data import Dataset as dataset
import glob
import os
from einops import rearrange, repeat
from tutils.nn.data.tsitk import read
from tqdm import tqdm
# from monai.transforms import SpatialPadd, CenterSpatialCropd, Resized, NormalizeIntensityd
# from monai.transforms import RandAdjustContrastd, RandShiftIntensityd, Rotated, RandAffined
# from datasets.common_2d_aug import RandomRotation, RandomResizedCrop, RandomHorizontalFlip, ColorJitter, ToTensor, Normalize
from tutils import tfilename, tdir
import random
from tutils.nn.data import itk_to_np
from scipy.sparse import csr_matrix
import cv2
DEFAULT_PATH="/quanquan/datasets/08_AbdomenCT-1K/"
class Dataset2D(dataset):
def __init__(self, dirpath=None, is_train=True, getting_multi_mask=False) -> None:
super().__init__()
self.dirpath = dirpath
self.is_train = is_train
self.getting_multi_mask = getting_multi_mask
self.img_names = self.prepare_datalist()
self.prepare_transforms()
self.weights_dict = {"gt":2, "sam_auto_seg":2, "prompt_point_from_superpixel":1, "prompt_box_from_superpixel":1, "superpixel":0}
def prepare_transforms(self):
self.transform = transforms.Compose([
transforms.Resized(keys=['img', 'label'], spatial_size=(3,1024,1024)),
# transforms.RandSpatialCropd(keys=["img"], roi_size=(448,448,1)),
transforms.RandAffined(keys=['img', 'label'], prob=0.5, shear_range=(0.2,0.2)),
transforms.RandCropByPosNegLabeld(keys=['img', 'label'], spatial_size=(3,960,960), label_key='label', neg=0),
# transforms.RandSmoothFieldAdjustContrastd(keys=['img', 'label'], )
transforms.Resized(keys=['img', 'label'], spatial_size=(3,1024,1024)),
transforms.RandAdjustContrastd(keys=['img'], ),
transforms.RandShiftIntensityd(keys=['img'], prob=0.8, offsets=(-5, 5)),
])
self.test_transform = transforms.Compose([
transforms.Resized(keys=['img'], spatial_size=(3,1024,1024)),
])
def __len__(self):
return len(self.img_names)
def to_RGB(self, img):
# transform images to RGB style
img = ((img - img.min()) / img.max() * 255).astype(int)
return img
def prepare_datalist(self):
dirpath_img = os.path.join(self.dirpath, 'preprocessed', 'cache_2d_various_pseudo_masks')
names = glob.glob(os.path.join(dirpath_img, "*_mask.npz"))
names = [os.path.split(name)[-1].replace("_mask.npz", "") for name in names]
names.sort()
# names = names[:15000]
print(f"[Dataset2d] Load {len(names)} paths.")
assert len(names) > 0, f"{__file__} Gotdirpath: {self.dirpath}"
return names
def _get_data(self, index, debug=False, iternum=0):
img_names = self.img_names
img_info = os.path.split(img_names[index])[-1].split('_s')
filename, slice_idx = img_info[0], int(img_info[-1][:4])
mask_loc = np.random.randint(0,3)
if mask_loc == 0:
slices_indices = [slice_idx, slice_idx+1, slice_idx+2]
elif mask_loc == 1:
slices_indices = [slice_idx-1, slice_idx, slice_idx+1]
elif mask_loc == 2:
slices_indices = [slice_idx-2, slice_idx-1, slice_idx]
# Load .npy data
filenames = [os.path.join(self.dirpath, "preprocessed", "cache_jpg", f"{filename}_s{i:04}_img.jpg") for i in slices_indices]
for name in filenames:
if not os.path.exists(name):
return self._get_data((index + 1) % len(self))
imgs = [cv2.imread(name, cv2.IMREAD_GRAYSCALE) for name in filenames]
img_rgb = np.stack(imgs, axis=0)
# Load RGB data
compressed = np.load(os.path.join(self.dirpath, "preprocessed", "cache_2d_various_pseudo_masks", img_names[index]+"_mask.npz"))
csr = csr_matrix((compressed['data'], compressed['indices'], compressed['indptr']), shape=compressed['shape'])
label_ori = csr.toarray()
label_ori = rearrange(label_ori, "c (h w) -> c h w", h=1024, w=1024)
metadata = np.load(os.path.join(self.dirpath, "preprocessed", "cache_2d_various_pseudo_masks", img_names[index]+"_metadata.npy"), allow_pickle=True)
label_prob = np.array([self.weights_dict[item['source']] for item in metadata]).astype(float)
label_prob = label_prob / label_prob.sum()
label_idx = np.random.choice(a=np.arange(len(metadata)), p=label_prob)
label_ori = label_ori[label_idx]
metadata = metadata[label_idx]
assert metadata['source'] != 'superpixel'
assert len(img_rgb.shape) == 3, f"{__file__} Got{img_rgb.shape}"
bundle_ori = {"img":torch.Tensor(rearrange(img_rgb, "c h w -> 1 c h w")), "label":torch.Tensor(repeat(label_ori, "h w -> 1 3 h w"))}
# import ipdb; ipdb.set_trace()
if self.is_train:
bundle = self.transform(bundle_ori)[0]
else:
bundle = self.test_transform(bundle_ori)
bundle['label'] = (bundle['label']>0.5).float()
if bundle['label'][0].sum() < 100:
return self._get_data((index+1)%len(self), iternum=iternum+1)
vector = np.zeros(3)
vector[mask_loc] = 1
if debug:
ret_dict = {
"name": img_names[index],
"img": bundle['img'][0],
"label": bundle['label'][0],
"img_ori":img_rgb,
"label_ori":label_ori,
"weight": self.weights_dict[metadata['source']],
"iternum": iternum,
"mask_loc": mask_loc,
"indicators": vector,
}
return ret_dict
ret_dict = {
"name": img_names[index],
"img": bundle['img'][0],
"label": bundle['label'][0],
"mask_loc": mask_loc,
"indicators": vector,
}
return ret_dict
def __getitem__(self, index):
return self._get_data(index)
class Testset2d(Dataset2D):
def __init__(self, dirpath=None, is_train=False, getting_multi_mask=False) -> None:
super().__init__(dirpath, is_train, getting_multi_mask)
self.test_names = self.prepare_datalist()
def prepare_datalist(self):
dirpath_img = os.path.join(self.dirpath, 'cache_2d_various_pseudo_masks')
names = glob.glob(os.path.join(dirpath_img, "*_mask.npz"))
names = [os.path.split(name)[-1].replace("_mask.npz", "") for name in names]
names.sort()
names = names[15000:]
print(f"[Dataset2d] Load {len(names)} paths.")
assert len(names) > 0, f"{__file__} Gotdirpath: {self.dirpath}"
return names
if __name__ == "__main__":
from torch.utils.data import DataLoader
dataset = Dataset2D(dirpath=DEFAULT_PATH)
loader = DataLoader(dataset=dataset, batch_size=1, shuffle=False)
iternums = 0
for i, data in enumerate(loader):
# iternums += data['iternum'].item()
print(i, iternums / (i+1), data['img'].shape, data['label'].shape)
assert data['label'].sum() >= 100, f"{__file__} Got{data['label'].sum()}"
assert torch.Tensor(data['label']==1).sum() >= 100, f"{__file__} Got {torch.Tensor(data['label']==1).sum().sum()}"
import ipdb; ipdb.set_trace()

View File

@ -0,0 +1,74 @@
# from torchvision import transforms
from monai import transforms
import numpy as np
import SimpleITK as sitk
import torch
from torch.utils.data import Dataset as dataset
import torch.nn.functional as F
import glob
import os
from einops import rearrange
from tutils.nn.data import read, itk_to_np
from tqdm import tqdm
from tutils import tfilename, tdir
import random
# from .dataset2d import Dataset2D
from .dataset3d_2dmask import Dataset2D
from .dataset3d import Dataset3D
class DatasetMerged(dataset):
def __init__(self, config=None, is_train=True, getting_multi_mask=False) -> None:
super().__init__()
self.dataset2d = Dataset2D(dirpath="/quanquan/datasets/08_AbdomenCT-1K/", is_train=True)
self.dataset3d = Dataset3D(config=config, is_train=True)
self.len_2d = len(self.dataset2d)
self.len_3d = len(self.dataset3d)
def __getitem__(self, index, debug=False):
index = index % len(self)
# print("DEBUG! is_2d:", index < self.len_2d)
if index < self.len_2d:
return self.dataset2d.__getitem__(index)
else:
index = (index - self.len_2d) % self.len_3d
return self.dataset3d.__getitem__(index)
def __len__(self):
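# The 3D dataset returns a random slice window per call, so it is oversampled ~200x
# relative to the fixed 2D cache within one epoch.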
return len(self.dataset2d) + len(self.dataset3d) * 200
class TestsetMerged(dataset):
def __init__(self, config=None, is_train=False) -> None:
super().__init__()
self.dataset2d = Dataset2D(dirpath="/quanquan/datasets/08_AbdomenCT-1K/preprocessed/", is_train=False)
self.dataset3d = Dataset3D(config=config, is_train=False, split='val')
self.len_2d = len(self.dataset2d)
self.len_3d = len(self.dataset3d)
def __getitem__(self, index, debug=False):
index = index % len(self)
if index < self.len_2d:
return self.dataset2d.__getitem__(index)
else:
index = (index - self.len_2d) % self.len_3d
return self.dataset3d.__getitem__(index)
def __len__(self):
return len(self.dataset2d) + len(self.dataset3d) * 2
if __name__ == "__main__":
from tutils import timer
from tutils.new.manager import trans_args, trans_init, ConfigManager
config = ConfigManager()
config.add_basic_config()
config.add_config("configs/vit_b.yaml")
dataset = DatasetMerged(config['dataset'])
tt = timer()
for i in range(20000,len(dataset)):
data = dataset.__getitem__(i)
print("time: ", tt())

435
datasets/generate_txt.py Normal file
View File

@ -0,0 +1,435 @@
import os
import glob
import shutil
from tutils import tfilename
HOME_PATH="/quanquan/datasets/"
def check_existing(img_path, label_path):
if os.path.exists(img_path) and os.path.exists(label_path):
return True
else:
if not os.path.exists(img_path):
print("IMAGE Not exist: ", img_path)
if not os.path.exists(label_path):
print("LABEL Not exist: ", label_path)
return False
def get_availabel_files(names, label_names):
llist = [[n,n2] for n,n2 in zip(names, label_names) if check_existing(n, n2)]
names = [n[0] for n in llist]
label_names = [n[1] for n in llist]
return names, label_names
def write_txt(img_paths, label_paths, meta_info, split, writing_mode='a+'):
dataset_id = meta_info['dataset_id']
assert split in ['train', 'val', 'test'], f" split in ['train', 'val', 'test'] , but Got {split}"
save_path = meta_info["save_txt_path"].replace("_train.txt", f"_{split}.txt")
count = 0
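# Each line of the split txt is "<image path>\t<label path>", with both paths made
# relative to the dataset home directory.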
with open(save_path, writing_mode) as f:
for p1, p2 in zip(img_paths, label_paths):
p1 = p1.replace(meta_info['home_path'], "")
p2 = p2.replace(meta_info['home_path'], "")
line = f"{p1}\t{p2}\n"
# print(line, end=" ")
f.write(line)
count += 1
if count <= 0:
raise ValueError(f"ID: {meta_info['dataset_id']}, \tTask: {meta_info['dataset_name']}\t, {count} files are writen.")
print(f"ID: {meta_info['dataset_id']}, \tTask: {meta_info['dataset_name']}\t, {count} files are writen.\t Writing Over! write into ", save_path)
def organize_in_nnunet_style(meta_info):
dirpath = os.path.join(meta_info['home_path'], meta_info['dirpath'])
if os.path.exists(os.path.join(dirpath, "imagesTr")) and os.path.exists(os.path.join(dirpath, "labelsTr")):
img_paths = glob.glob(os.path.join(dirpath, "imagesTr", "*.nii.gz"))
img_paths.sort()
label_paths = [p.replace("imagesTr", "labelsTr")[:-12]+".nii.gz" for p in img_paths]
img_paths, label_paths = get_availabel_files(img_paths, label_paths)
write_txt(img_paths, label_paths, meta_info=meta_info, split='train', writing_mode="a+")
if os.path.exists(os.path.join(dirpath, "imagesVa")) and os.path.exists(os.path.join(dirpath, "labelsVa")):
img_paths = glob.glob(os.path.join(dirpath, "imagesVa", "*.nii.gz"))
img_paths.sort()
label_paths = [p.replace("imagesVa", "labelsVa")[:-12]+".nii.gz" for p in img_paths]
img_paths, label_paths = get_availabel_files(img_paths, label_paths)
write_txt(img_paths, label_paths, meta_info=meta_info, split='val', writing_mode="a+")
if os.path.exists(os.path.join(dirpath, "imagesTs")) and os.path.exists(os.path.join(dirpath, "labelsTs")):
img_paths = glob.glob(os.path.join(dirpath, "imagesTs", "*.nii.gz"))
img_paths.sort()
label_paths = [p.replace("imagesTs", "labelsTs")[:-12]+".nii.gz" for p in img_paths]
img_paths, label_paths = get_availabel_files(img_paths, label_paths)
write_txt(img_paths, label_paths, meta_info=meta_info, split='test', writing_mode="a+")
def organize_in_style2(meta_info):
dirpath = os.path.join(meta_info['home_path'], meta_info['dirpath'])
if os.path.exists(os.path.join(dirpath, "imagesTr")) and os.path.exists(os.path.join(dirpath, "labelsTr")):
img_paths = glob.glob(os.path.join(dirpath, "imagesTr", "*.nii.gz"))
img_paths.sort()
label_paths = [p.replace("imagesTr", "labelsTr") for p in img_paths]
img_paths, label_paths = get_availabel_files(img_paths, label_paths)
write_txt(img_paths, label_paths, meta_info=meta_info, split='train', writing_mode="a+")
if os.path.exists(os.path.join(dirpath, "imagesVa")) and os.path.exists(os.path.join(dirpath, "labelsVa")):
img_paths = glob.glob(os.path.join(dirpath, "imagesVa", "*.nii.gz"))
img_paths.sort()
label_paths = [p.replace("imagesVa", "labelsVa") for p in img_paths]
img_paths, label_paths = get_availabel_files(img_paths, label_paths)
write_txt(img_paths, label_paths, meta_info=meta_info, split='val', writing_mode="a+")
if os.path.exists(os.path.join(dirpath, "imagesTs")) and os.path.exists(os.path.join(dirpath, "labelsTs")):
img_paths = glob.glob(os.path.join(dirpath, "imagesTs", "*.nii.gz"))
img_paths.sort()
label_paths = [p.replace("imagesTs", "labelsTs") for p in img_paths]
img_paths, label_paths = get_availabel_files(img_paths, label_paths)
write_txt(img_paths, label_paths, meta_info=meta_info, split='test', writing_mode="a+")
def organize_by_names(names_in, label_names_in, meta_info):
assert len(names_in) > 0, f"Meta info: {meta_info}"
names, label_names = get_availabel_files(names_in, label_names_in)
assert len(names) > 0, f"Meta info: {meta_info}, \n {names_in[:2]} \n {label_names_in[:2]}"
assert len(label_names) > 0, f"Meta info: {meta_info}, \n {names_in[:2]} \n {label_names_in[:2]}"
# print("debug files", len(names))
if len(names) > 10:
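# Hold out up to 10 volumes each for validation and test (taken from the end of the
# sorted list); everything else goes to the training split.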
num_valid = min(int(len(names) // 10), 10)
# print("num valid", num_valid)
train_names = names[:-num_valid*2]
valid_names = names[-num_valid*2:-num_valid]
test_names = names[-num_valid:]
train_labels = label_names[:-num_valid*2]
valid_labels = label_names[-num_valid*2:-num_valid]
test_labels = label_names[-num_valid:]
write_txt(train_names, train_labels, meta_info=meta_info, split="train")
write_txt(valid_names, valid_labels, meta_info=meta_info, split="val")
write_txt(test_names, test_labels, meta_info=meta_info, split="test")
else:
write_txt(names, label_names, meta_info=meta_info, split="train")
def clear_files(train_path):
if os.path.exists(train_path):
parent, name = os.path.split(train_path)
shutil.move(train_path, tfilename(parent, "misc", name))
val_path = train_path.replace("_train.txt", "_val.txt")
if os.path.exists(val_path):
parent, name = os.path.split(val_path)
shutil.move(val_path, os.path.join(parent, "misc", name))
test_path = train_path.replace("_train.txt", "_test.txt")
if os.path.exists(test_path):
parent, name = os.path.split(test_path)
shutil.move(test_path, os.path.join(parent, "misc", name))
print("Files cleared!")
# from tutils.nn.data import read
# def convert_to_nii(paths):
###################################################################################
###################################################################################
def get_BCV_Abdomen(save_path=None):
meta_info = {
"dataset_name": "BTCV",
"dataset_id": "01",
"modality": "CT",
"home_path": HOME_PATH,
"dirpath": "01_BCV-Abdomen/Training/",
"save_txt_path": save_path,
}
names = glob.glob(os.path.join(meta_info['home_path'], meta_info['dirpath'], "img/*.nii.gz"))
names.sort()
label_names = [p.replace("img", "label") for p in names]
organize_by_names(names, label_names, meta_info=meta_info)
def get_AbdomenCT_1K(save_path):
meta_info = {
"dataset_name": "AbdomenCT-1K",
"dataset_id": "08",
"modality": "CT",
"home_path": HOME_PATH,
"dirpath": "08_AbdomenCT-1K",
"save_txt_path": save_path,
}
# print(names)
organize_in_nnunet_style(meta_info=meta_info)
def get_AMOS(save_path):
meta_info = {
"dataset_name": "AMOS",
"dataset_id": "09",
"modality": "CT",
"home_path": HOME_PATH,
"dirpath": "09_AMOS",
"save_txt_path": save_path,
}
organize_in_style2(meta_info)
def get_MSD(save_path):
meta_info = {
"dataset_name": "MSD", # Decathlon
"dataset_id": "10",
"modality": "CT",
"home_path": HOME_PATH,
"parent_dirpath": "10_Decathlon",
"dirpath": "",
"save_txt_path": save_path,
}
subtasks = ["Task06_Lung", "Task08_HepaticVessel", "Task09_Spleen", "Task10_Colon"]
for task in subtasks:
# print("Processing ", task)
meta_info_subtask = {
"dataset_name":task,
"dataset_id": f"{meta_info['dataset_id']}_{task[4:6]}",
"home_path":HOME_PATH,
"dirpath": f"{meta_info['parent_dirpath']}/{task}",
"save_txt_path": save_path,
}
# print(meta_info_subtask)
organize_in_style2(meta_info=meta_info_subtask)
def get_MSD_MRI(save_path):
meta_info = {
"dataset_name": "MSD", # Decathlon
"dataset_id": "10",
"modality": "MRI",
"home_path": HOME_PATH,
"parent_dirpath": "10_Decathlon",
"dirpath": "",
"save_txt_path": save_path,
}
subtasks = ["Task02_Heart", "Task05_Prostate"]
for task in subtasks:
# print("Processing ", task)
meta_info_subtask = {
"dataset_name":task,
"dataset_id": f"{meta_info['dataset_id']}_{task[4:6]}",
"home_path":HOME_PATH,
"dirpath": f"{meta_info['parent_dirpath']}/{task}",
"save_txt_path": save_path,
}
# print(meta_info_subtask)
organize_in_style2(meta_info=meta_info_subtask)
def get_ASOCA(save_path):
meta_info = {
"dataset_name": "ASOCA",
"dataset_id": "51",
"modality": "CT",
"home_path": HOME_PATH,
"dirpath": "51_ASOCA",
"save_txt_path": save_path,
}
names = glob.glob(os.path.join(meta_info['home_path'], meta_info['dirpath'], "image/*.nii.gz"))
names.sort()
label_names = [p.replace("/image", "/label") for p in names]
# print(os.path.join(meta_info['home_path'], meta_info['dirpath'], "image/*.nii.gz")
# print("debug ,", names)
organize_by_names(names, label_names, meta_info=meta_info)
def get_BCV_Cervix(save_path):
meta_info = {
"dataset_name": "BCV-Cervix",
"dataset_id": "52",
"modality": "CT",
"home_path": HOME_PATH,
"dirpath": "52_BCV-Cervix/Training/",
"save_txt_path": save_path,
}
names = glob.glob(os.path.join(meta_info['home_path'], meta_info['dirpath'], "img/*.nii.gz"))
names.sort()
label_names = [p.replace("/img/", "/label/").replace("-Image", "-Mask") for p in names]
organize_by_names(names, label_names, meta_info=meta_info)
def get_NIHPancrease(save_path):
meta_info = {
"dataset_name": "NIHPancrease",
"dataset_id": "53",
"modality": "CT",
"home_path": HOME_PATH,
"dirpath": "53_NIHPancrease",
"save_txt_path": save_path,
}
names = glob.glob(os.path.join(meta_info['home_path'], meta_info['dirpath'], "data/*.nii.gz") )
names.sort()
label_names = [p.replace("/data/PANCREAS_", "/label/label") for p in names]
organize_by_names(names, label_names, meta_info=meta_info)
def get_CTPelvic(save_path):
meta_info = {
"dataset_name": "CTPelvic1K",
"dataset_id": "54",
"modality": "CT",
"home_path": HOME_PATH,
"dirpath": "54_CTPelvic1K",
"save_txt_path": save_path,
}
names = []
names += glob.glob(os.path.join(meta_info['home_path'], meta_info['dirpath'], "CTPelvic1K_dataset1_data/*.nii.gz"))
names += glob.glob(os.path.join(meta_info['home_path'], meta_info['dirpath'], "CTPelvic1K_dataset2_data/*.nii.gz"))
names += glob.glob(os.path.join(meta_info['home_path'], meta_info['dirpath'], "CTPelvic1K_dataset3_data/*.nii.gz"))
names += glob.glob(os.path.join(meta_info['home_path'], meta_info['dirpath'], "CTPelvic1K_dataset4_data/*.nii.gz"))
names += glob.glob(os.path.join(meta_info['home_path'], meta_info['dirpath'], "CTPelvic1K_dataset5_data/*.nii.gz"))
names += glob.glob(os.path.join(meta_info['home_path'], meta_info['dirpath'], "CTPelvic1K_dataset7_data/*.nii.gz"))
names.sort()
# xx_data.nii.gz xx_mask_4label.nii.gz
label_names = [p.replace("_data/", "_mask/").replace("_data.nii.gz", "_mask_4label.nii.gz") for p in names]
organize_by_names(names, label_names, meta_info=meta_info)
def get_FLARE(save_path):
meta_info = {
"dataset_name": "FLARE",
"dataset_id": "55",
"modality": "CT",
"home_path": HOME_PATH,
"dirpath": "55_FLARE22Train",
"save_txt_path": save_path,
"class": ['liver', 'right kidney', 'spleen', 'pancrease', 'aorta','postcava','right adrenal gland','left darenal gland','gallbladder','esophagus','stomach','duodenum','left kidney'],
}
organize_in_nnunet_style(meta_info=meta_info)
# def get_HAN(save_path):
# meta_info = {
# "dataset_name": "Head-and-neck",
# "dataset_id": "56",
# "modality": "CT",
# "home_path": HOME_PATH,
# "dirpath": "56_Head-and-Neck-challenge",
# "save_txt_path": save_path,
# }
# names = glob.glob(os.path.join(meta_info['home_path'], meta_info['dirpath'], "data/*.nii.gz"))
# names.sort()
# label_names = [p.replace("/data/", "/label/") for p in names]
# organize_by_names(names, label_names, meta_info=meta_info)
# def get_StructSeg(save_path):
# meta_info = {
# "dataset_name": "StructSeg2019",
# "dataset_id": "57",
# "modality": "CT",
# "home_path": HOME_PATH,
# "dirpath": "57_StructSeg",
# "save_txt_path": save_path,
# }
# names = glob.glob(os.path.join(meta_info['home_path'], meta_info['dirpath'], "HaN_OAR/data/*"))
# names = [f"{name}/data.nii.gz" for name in names]
# names.sort()
# label_names = [p.replace("/data.nii.gz", "/label.nii.gz") for p in names]
# organize_by_names(names, label_names, meta_info=meta_info)
def get_CHAOS(save_path):
meta_info = {
"dataset_name": "CHAOS",
"dataset_id": "58",
"modality": "MRI",
"home_path": HOME_PATH,
"dirpath": "58_CHAOST2/chaos_MR_T2_normalized/",
"save_txt_path": save_path,
"class": ["liver", "right kidney", "left kidney", "spleen"],
}
names = glob.glob(os.path.join(meta_info['home_path'], meta_info['dirpath'], "image*.nii.gz"))
names.sort()
label_names = [p.replace("/image_", "/label_") for p in names]
organize_by_names(names, label_names, meta_info=meta_info)
def get_SABS(save_path):
meta_info = {
"dataset_name": "SABS", # BTCV ?
"dataset_id": "59",
"modality": "CT",
"home_path": HOME_PATH,
"dirpath": "59_SABS/sabs_CT_normalized/",
"save_txt_path": save_path,
"class": ["spleen", "right kidney", "left kidney", "gallbladder", "esophagus", "liver", "stomach", "aorta", "postcava", "portal vein and splenic vein", "pancrease", "right adrenal gland", "left adrenal gland"],
}
names = glob.glob(os.path.join(meta_info['home_path'], meta_info['dirpath'], "image_*.nii.gz"))
names.sort()
label_names = [p.replace("/image_", "/label_") for p in names]
organize_by_names(names, label_names, meta_info=meta_info)
def get_Totalseg(save_path):
meta_info = {
"dataset_name": "Totalseg",
"dataset_id": "60",
"modality": "CT",
"home_path": HOME_PATH,
# "dirpath": "nnUNet_raw/Dataset101_Totalseg",
"dirpath": "60_Totalseg",
"save_txt_path": save_path,
"class": [],
}
organize_in_nnunet_style(meta_info=meta_info)
def get_WORD(save_path):
meta_info = {
"dataset_name": "WORDs", # BTCV ?
"dataset_id": "07",
"modality": "CT",
"home_path": HOME_PATH,
"dirpath": "07_WORD/WORD-V0.1.0/",
"save_txt_path": save_path,
}
organize_in_style2(meta_info=meta_info)
def generate_all():
save_path="./datasets/dataset_list/all_train.txt"
clear_files(save_path)
get_BCV_Abdomen(save_path)
get_AbdomenCT_1K(save_path)
get_AMOS(save_path)
get_MSD(save_path)
# get_ASOCA()
# get_BCV_Cervix()
# # get_NIHPancrease() # bug in data ?
# get_CTPelvic()
# get_FLARE()
# get_SABS()
def generate_their():
save_path="./datasets/dataset_list/their_train.txt"
clear_files(save_path)
save_path="./datasets/dataset_list/their_train.txt"
get_BCV_Abdomen(save_path)
get_AbdomenCT_1K(save_path)
get_AMOS(save_path)
get_MSD(save_path)
def generate_ours():
    save_path = "./datasets/dataset_list/ours_train.txt"
    clear_files(save_path)  # match the other generate_* helpers so the list is rebuilt from scratch
    get_ASOCA(save_path)
get_BCV_Cervix(save_path)
# get_NIHPancrease() # bug in data ?
get_CTPelvic(save_path)
get_FLARE(save_path)
get_SABS(save_path)
def generate_alp_dataset():
save_path = "./datasets/dataset_list/alp_train.txt"
clear_files(save_path)
get_SABS(save_path)
get_CHAOS(save_path)
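# --- Hedged example (not part of the original script) -------------------------
# A sketch of how an additional dataset could be registered with the same
# meta_info / txt-list pattern used by the helpers above; "MyCTDataset", its id,
# and its folder layout are hypothetical placeholders.
def get_MyCTDataset_example(save_path):
    meta_info = {
        "dataset_name": "MyCTDataset",   # hypothetical
        "dataset_id": "99",              # hypothetical
        "modality": "CT",
        "home_path": HOME_PATH,
        "dirpath": "99_MyCTDataset",     # hypothetical folder under HOME_PATH
        "save_txt_path": save_path,
    }
    names = sorted(glob.glob(os.path.join(meta_info['home_path'], meta_info['dirpath'], "image/*.nii.gz")))
    label_names = [p.replace("/image/", "/label/") for p in names]
    organize_by_names(names, label_names, meta_info=meta_info)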
if __name__ == "__main__":
print(__file__)
# generate_alp_dataset()
save_path ="./datasets/dataset_list/totalseg_train.txt"
clear_files(save_path)
get_Totalseg(save_path)
# save_path="./datasets/dataset_list/word_train.txt"
print("Over")

408
datasets/predict_various_masks.py Normal file
View File

@ -0,0 +1,408 @@
# from utils.predict_automasks import predict
from segment_anything import sam_model_registry, SamPredictor
from core.automask import SamAutomaticMaskGenerator
from core.trainer import SamLearner
# from datasets.dataset2d import Dataset2D
from einops import rearrange, repeat
import torch
import numpy as np
from skimage.segmentation import slic
from skimage.util import img_as_float
from skimage import io
import os
from tutils import timer, tdir
from scipy.sparse import csr_matrix
# from datasets.dataset2d import Dataset2D
import cv2
import torch.nn.functional as F
import time
from tutils import tfilename
import matplotlib.pylab as plt
from .utils import load_compressed_data, show_anns, img_to_show
def tfunctime(func):
def run(*argv, **kargs):
t1 = time.time()
ret = func(*argv, **kargs)
t2 = time.time()
print(f"[Function <{func.__name__}>] Running time:{(t2-t1):.6f}s")
return ret
return run
def get_predictors():
sam_checkpoint = "/quanquan/code/segment-anything/segment_anything/sam_vit_h_4b8939.pth" # for A100
# sam_checkpoint = "/home1/quanquan/code/projects/medical-guangdong/segment-anything/sam_vit_h_4b8939.pth" # for server 103
device = "cuda"
model_type = "default"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
mask_generator = SamAutomaticMaskGenerator(sam)
predictor = SamLearner(sam_model=sam)
return mask_generator, predictor
def center_clip(img, eta=3):
count_values = img[torch.where(img>-199)]
mean = count_values.mean()
std = count_values.std()
img3 = torch.clip(img, mean-eta*std, mean+eta*std)
return img3
def find_not_exist_masks(masks_repo, masks_new, iou_threshold=0.2):
if len(masks_repo) == 0:
return masks_new
def compute_iou(m1, m2):
m1 = m1['segmentation']
m2 = m2['segmentation']
intersection = m1*m2
union = np.float32((m1 + m2) > 0)
return intersection.sum() / union.sum()
to_append = []
for mask_new in masks_new:
assert isinstance(mask_new, dict), f"{__file__} Got{type(mask_new)}"
intersec_count = 0
for mask_in_repo in masks_repo:
assert isinstance(mask_in_repo, dict), f"{__file__} Got{type(mask_in_repo)}"
iou = compute_iou(mask_in_repo, mask_new)
if iou > iou_threshold:
intersec_count += 1
if intersec_count == 0:
to_append.append(mask_new)
check_keys(to_append)
return to_append
def merge_masks(masks_repo, masks_new, iou_threshold=0.2):
to_append = find_not_exist_masks(masks_repo, masks_new, iou_threshold)
# print(f"DEBUG: {len(masks_new) - len(to_append)} masks are deleted, remaining {len(to_append)}. The total achieves {len(masks_repo) + len(to_append)}")
return masks_repo + to_append
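# A minimal sketch (not part of the original pipeline) of how `merge_masks`
# deduplicates proposals by IoU; the toy 8x8 masks below are hypothetical.
def _merge_masks_example():
    a = np.zeros((8, 8)); a[:4, :4] = 1
    b = np.zeros((8, 8)); b[:4, :4] = 1   # duplicate of a -> dropped (IoU = 1)
    c = np.zeros((8, 8)); c[4:, 4:] = 1   # disjoint from a -> kept
    repo = [{'segmentation': a, 'area': a.sum(), 'source': 'gt'}]
    new = [{'segmentation': b, 'area': b.sum(), 'source': 'sam_auto_seg'},
           {'segmentation': c, 'area': c.sum(), 'source': 'sam_auto_seg'}]
    merged = merge_masks(repo, new, iou_threshold=0.2)
    assert len(merged) == 2
    return merged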
def dilate_erode(mask):
kernel = np.ones((4, 4), dtype=np.uint8)
mask = cv2.morphologyEx(mask.astype(float), cv2.MORPH_CLOSE, kernel, iterations=1)
return mask
def get_superpixel(image, hu_mask, hu_threshold=-50):
if isinstance(image, torch.Tensor):
image = image.detach().cpu().numpy()
if isinstance(hu_mask, torch.Tensor):
hu_mask = hu_mask.detach().cpu().numpy()
segments = slic(image, n_segments=100, compactness=9)
mask_collect = []
image2 = image[:,:,0]
for i in range(1, segments.max()+1):
mask = torch.Tensor(segments==i).detach().cpu().numpy()
# assert img
if image2[segments==i].mean() <= hu_threshold:
continue
mask = mask * hu_mask
mask = dilate_erode(mask)
mask_data = {
'segmentation':mask,
'area':mask.sum(),
'source': "superpixel",
'mean_value':(mask * image[:,:,0]).mean(),
}
mask_collect.append(mask_data)
check_keys(mask_collect)
return mask_collect
# def resize_masks(masks):
# for m in masks:
# m['segmentation'] = F.interpolate(torch.Tensor(m['segmentation'])[None,None,:,:], size=(512,512)).squeeze().numpy()
# return masks
# @tfunctime
def get_masks_via_color_changing(img, label, mask_generator, predictor=None):
img = repeat(img, " h w -> 1 3 h w")
img = torch.Tensor(img)
img = F.interpolate(img, size=(1024,1024))
label = torch.Tensor(label)
if label is not None:
label_masks = []
for i in range(1,label.max().int()+1):
labeli = torch.Tensor(label==i).float()
labeli = F.interpolate(labeli[None, None, :,:], size=(1024,1024)).squeeze().numpy()
area = labeli.sum()
if area <=10:
continue
mask = {
"segmentation":labeli,
"area":area,
"source": "gt",
}
label_masks.append(mask)
else:
label_masks = []
mask_generator.reset_image()
predictor.reset_image()
masks = mask_generator.generate(img.cuda())
masks = filter_large_masks(masks)
for mask in masks:
mask['source'] = "sam_auto_seg"
label_masks = merge_masks(label_masks, masks)
del masks
check_keys(label_masks)
# import ipdb; ipdb.set_trace()
# return img, label_masks
    img2 = center_clip(img, 2.5)
    masks = mask_generator.generate(img2.cuda())  # use the clipped image, not the original
    masks = filter_large_masks(masks)
    for mask in masks:
        mask['source'] = "sam_auto_seg"
    label_masks = merge_masks(label_masks, masks)
    del masks
    del img2
# img2 = center_clip(img, 2)
# masks = mask_generator.generate(img2.cuda())
# masks = filter_large_masks(masks)
# for mask in masks:
# mask['source'] = "sam_auto_seg"
# label_masks = merge_masks(label_masks, masks)
# del masks
# del img2
img2 = center_clip(img, 1)
masks = mask_generator.generate(img2.cuda())
masks = filter_large_masks(masks)
for mask in masks:
mask['source'] = "sam_auto_seg"
label_masks = merge_masks(label_masks, masks)
del masks
del img2
img2 = center_clip(img, 0.5)
masks = mask_generator.generate(img2.cuda())
masks = filter_large_masks(masks)
for mask in masks:
mask['source'] = "sam_auto_seg"
label_masks = merge_masks(label_masks, masks)
del masks
del img2
check_keys(label_masks)
# import ipdb; ipdb.set_trace()
return img, label_masks
def check_keys(masks):
for m in masks:
assert m['segmentation'] is not None
assert m['area'] is not None
assert m['source'] is not None
def filter_large_masks(masks):
filtered_masks = []
for mask in masks:
if mask['area'] > 0.25 *1024 * 1024:
continue
filtered_masks.append(mask)
del masks
return filtered_masks
def mix_masks(masks):
if len(masks) == 0:
return None
mixed_mask = None
for item in masks:
m = item['segmentation']
mixed_mask = np.zeros_like(m) if mixed_mask is None else mixed_mask
mixed_mask += m
mixed_mask = np.float32(mixed_mask>0)
return mixed_mask
def select_random_point_from_mask(gt_mask):
size = gt_mask.shape
assert len(size) == 2
xy = np.arange(0, size[0] * size[1])
gt_mask = np.float32(gt_mask>0)
prob = rearrange(gt_mask, "h w -> (h w)")
prob = prob / prob.sum()
loc = np.random.choice(a=xy, size=1, replace=True, p=prob)[0]
x, y = loc % size[1], loc // size[1]
return x, y
def select_center_point_from_mask(gt_mask):
# get indices of all the foreground pixels
indices = np.argwhere(gt_mask > 0)
# calculate the center point by taking the mean of the foreground pixel indices
center = np.mean(indices, axis=0).astype(int)
y, x = center
return x, y
@tfunctime
def get_masks_via_points_from_superpixels(img, label_masks, superpixels, predictor, hu_mask=None):
if isinstance(hu_mask, torch.Tensor):
hu_mask = hu_mask.detach().cpu().numpy()
total_mask = mix_masks(label_masks)
# superpixels = get_superpixel(rearrange(img, "1 c h w -> h w c"))
points = []
ex_masks = []
for seg in superpixels:
mask_collect = []
for i in range(5):
            mm = np.nan_to_num(seg['segmentation'], nan=0)
if mm.sum() <= 0:
continue
x,y = select_center_point_from_mask(mm) # select_random_point_from_mask
# print(x,y)
if total_mask[y,x] == 1:
continue
# else:
# print(total_mask[y,x])
point = torch.Tensor([x,y])
points.append(point)
mask = predictor.generate(img.cuda(), point.cuda().unsqueeze(0).unsqueeze(0))[0,0].detach().cpu().numpy()
mask = mask * hu_mask
mask = dilate_erode(mask)
mask = {
"segmentation": mask,
"area": mask.sum(),
"source": "prompt_point_from_superpixel",
}
mask_collect = merge_masks(mask_collect, [mask])
ex_masks += mask_collect
check_keys(ex_masks)
return ex_masks
def mask_to_bbox(mask):
""" copied from data_engine """
# Find the indices of all non-zero elements in the mask
coords = np.nonzero(mask)
# Compute the minimum and maximum values of the row and column indices
x_min = np.min(coords[1])
y_min = np.min(coords[0])
x_max = np.max(coords[1])
y_max = np.max(coords[0])
# Return the coordinates of the bounding box
return (x_min, y_min, x_max, y_max)
@tfunctime
def get_masks_via_boxes_from_superpixels(img, label_masks, superpixels, predictor, hu_mask=None):
if isinstance(hu_mask, torch.Tensor):
hu_mask = hu_mask.detach().cpu().numpy()
ex_masks = []
for seg in superpixels:
if seg['segmentation'].sum() < 100:
continue
x_min, y_min, x_max, y_max = mask_to_bbox(seg['segmentation'])
box = torch.Tensor([x_min, y_min, x_max, y_max])
mask = predictor.generate_by_box(img.cuda(), box.cuda().unsqueeze(0).unsqueeze(0))[0,0].detach().cpu().numpy()
mask = mask * hu_mask
mask = dilate_erode(mask)
mask = {
"segmentation": mask,
"area": mask.sum(),
"source": "prompt_box_from_superpixel",
}
ex_masks = merge_masks(ex_masks, [mask])
check_keys(ex_masks)
return ex_masks
class VariousPseudoMasksGenerator:
def __init__(self, dataset, label_path:str=None) -> None:
self.dataset = dataset
# assert dirpath.split("/")[-1] == "cache_2d", f"{__file__} Got{dirpath.split('/')[-1]} ; dirpath:{dirpath}"
self.label_path = label_path if label_path is not None else dataset.dirpath.replace("cache_2d", "cache_2d_various_pseudo_masks")
def example(self, mask_generator=None, predictor=None):
return self.generate(mask_generator=mask_generator, predictor=predictor, is_example=True)
def generate(self, mask_generator=None, predictor=None, is_example=False):
tt = timer()
if mask_generator is None:
mask_generator, predictor = get_predictors()
self.mask_generator, self.predictor = mask_generator, predictor
for img_idx in range(len(self.dataset)):
tt()
data = self.dataset._get_image(img_idx)
masks_all_layers = []
words = data['name'].split("/")
dataset_name = words[0]
filename = words[-1].replace(".nii.gz","")
volume_rgb = np.clip(data['img'], -200, 400)
volume_rgb = (volume_rgb - volume_rgb.min()) / volume_rgb.max()
for slice_idx in range(data['img'].shape[0]):
path = tfilename(self.label_path, f'{dataset_name}/{filename}_s{slice_idx}_mask.npz')
# if os.path.exists(path):
# continue
masks = self.get_various_masks(data['img'][slice_idx], data['label'][slice_idx], mask_generator, predictor)
self.save_slice_mask(masks, path, slice_idx)
img_path = tfilename(self.label_path, f'image_{dataset_name}/{filename}_s{slice_idx}.jpg')
self.save_img_rgb(volume_rgb[slice_idx], img_path)
# display
# plt.figure(figsize=(6,6))
# plt.imshow(img_to_show(data['img'][slice_idx]), cmap='gray')
# show_anns(masks)
# plt.axis('off')
# plt.show()
print(f"Save to {img_path} and {path}")
print(f"Processing {img_idx}, {len(masks_all_layers)} saved, time used:{tt()}", end='\r')
if is_example:
break
def save_slice_mask(self, masks, path, slice_idx):
masks_data = np.array([m['segmentation'] for m in masks]).astype(int)
# if len(masks_data) <= 1:
# import ipdb; ipdb.set_trace()
masks_data = F.interpolate(torch.Tensor(masks_data).unsqueeze(1), size=(512,512)).squeeze().numpy()
masks_data = np.int8(masks_data>0)
if len(masks_data.shape) == 2:
masks_data = masks_data[None,:,:]
assert masks_data.shape[1:] == (512,512), f"{__file__} Got{masks_data.shape}"
masks_data = rearrange(masks_data, "n h w -> n (h w)")
csr = csr_matrix(masks_data)
np.savez_compressed(path, data=csr.data, indices=csr.indices, indptr=csr.indptr, shape=csr.shape)
def save_img_rgb(self, img, path):
img = (img * 255).astype(np.uint8)
cv2.imwrite(path, img)
@staticmethod
@tfunctime
def get_various_masks(img_ori, label, mask_generator, predictor):
mask_generator.reset_image()
predictor.reset_image()
img, label_masks = get_masks_via_color_changing(img_ori, label, mask_generator, predictor)
# # return label_masks
# hu_mask = img_ori>img_ori.min() + 10
# hu_mask = F.interpolate(torch.Tensor(hu_mask)[None,None,:,:], size=(1024,1024)).squeeze()
# superpixels = get_superpixel(rearrange(img, "1 c h w -> h w c"), hu_mask=hu_mask)
# exclu_superpixels = find_not_exist_masks(label_masks, superpixels)
# # point_prompt_masks = get_masks_via_boxes_from_superpixels(img, label_masks, exclu_superpixels, predictor, hu_mask=hu_mask)
# box_prompt_masks = get_masks_via_points_from_superpixels(img, label_masks, exclu_superpixels, predictor, hu_mask=hu_mask)
return label_masks # + box_prompt_masks # + point_prompt_masks #
if __name__ == "__main__":
# CUDA_VISIBLE_DEVICES=6 python -m datasets.predict_various_masks
from datasets.dataset3d import Dataset3D
dataset = Dataset3D()
gen = VariousPseudoMasksGenerator(dataset=dataset,
label_path=tdir("/quanquan/datasets/all_datasets/various_masks_3/"))
gen.generate()

0
modeling/__init__.py Normal file
View File

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

119
modeling/build_sam3d2.py Normal file
View File

@ -0,0 +1,119 @@
"""
Differences with build_sam3d.py
use mask_decoder3d_2 instead of mask_decoder3d
"""
import torch
from functools import partial
# from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
from .image_encoder import ImageEncoderViT
from .mask_decoder3d_2 import MaskDecoder
from .prompt_encoder3d import PromptEncoder
from .sam3d import Sam
from .transformer import TwoWayTransformer
def build_sam_vit_h(checkpoint=None):
return _build_sam(
encoder_embed_dim=1280,
encoder_depth=32,
encoder_num_heads=16,
encoder_global_attn_indexes=[7, 15, 23, 31],
checkpoint=checkpoint,
)
build_sam = build_sam_vit_h
def build_sam_vit_l(checkpoint=None):
return _build_sam(
encoder_embed_dim=1024,
encoder_depth=24,
encoder_num_heads=16,
encoder_global_attn_indexes=[5, 11, 17, 23],
checkpoint=checkpoint,
)
def build_sam_vit_b(checkpoint=None):
return _build_sam(
encoder_embed_dim=768,
encoder_depth=12,
encoder_num_heads=12,
encoder_global_attn_indexes=[2, 5, 8, 11],
checkpoint=checkpoint,
)
sam_model_registry = {
"default": build_sam_vit_h,
"vit_h": build_sam_vit_h,
"vit_l": build_sam_vit_l,
"vit_b": build_sam_vit_b,
}
def _build_sam(
encoder_embed_dim,
encoder_depth,
encoder_num_heads,
encoder_global_attn_indexes,
checkpoint=None,
):
prompt_embed_dim = 256
image_size = 1024
vit_patch_size = 16
image_embedding_size = image_size // vit_patch_size
sam = Sam(
image_encoder=ImageEncoderViT(
depth=encoder_depth,
embed_dim=encoder_embed_dim,
img_size=image_size,
mlp_ratio=4,
norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
num_heads=encoder_num_heads,
patch_size=vit_patch_size,
qkv_bias=True,
use_rel_pos=True,
global_attn_indexes=encoder_global_attn_indexes,
window_size=14,
out_chans=prompt_embed_dim,
),
prompt_encoder=PromptEncoder(
embed_dim=prompt_embed_dim,
image_embedding_size=(image_embedding_size, image_embedding_size),
input_image_size=(image_size, image_size),
mask_in_chans=16,
),
mask_decoder=MaskDecoder(
num_multimask_outputs=3,
transformer=TwoWayTransformer(
depth=2,
embedding_dim=prompt_embed_dim,
mlp_dim=2048,
num_heads=8,
),
transformer_dim=prompt_embed_dim,
iou_head_depth=3,
iou_head_hidden_dim=256,
),
pixel_mean=[123.675, 116.28, 103.53],
pixel_std=[58.395, 57.12, 57.375],
)
sam.eval()
if checkpoint is not None:
with open(checkpoint, "rb") as f:
state_dict = torch.load(f)
sam.load_state_dict(state_dict)
return sam
if __name__ == "__main__":
model = build_sam_vit_l()
data = torch.ones((1,3,1024,1024))
out = model(data, multimask_output=True)
print(out.shape)

43
modeling/common.py Normal file
View File

@ -0,0 +1,43 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
from typing import Type
class MLPBlock(nn.Module):
def __init__(
self,
embedding_dim: int,
mlp_dim: int,
act: Type[nn.Module] = nn.GELU,
) -> None:
super().__init__()
self.lin1 = nn.Linear(embedding_dim, mlp_dim)
self.lin2 = nn.Linear(mlp_dim, embedding_dim)
self.act = act()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.lin2(self.act(self.lin1(x)))
# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
class LayerNorm2d(nn.Module):
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(num_channels))
self.bias = nn.Parameter(torch.zeros(num_channels))
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x
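# A hedged usage sketch (not in the original file): MLPBlock preserves the token
# dimension, and LayerNorm2d normalizes over the channel axis of an NCHW tensor.
if __name__ == "__main__":
    tokens = MLPBlock(embedding_dim=256, mlp_dim=2048)(torch.randn(1, 10, 256))
    feats = LayerNorm2d(64)(torch.randn(1, 64, 8, 8))
    print(tokens.shape, feats.shape)  # torch.Size([1, 10, 256]) torch.Size([1, 64, 8, 8])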

402
modeling/image_encoder.py Normal file
View File

@ -0,0 +1,402 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Type
from .common import LayerNorm2d, MLPBlock
# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
class ImageEncoderViT(nn.Module):
def __init__(
self,
img_size: int = 1024,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
mlp_ratio: float = 4.0,
out_chans: int = 256,
qkv_bias: bool = True,
norm_layer: Type[nn.Module] = nn.LayerNorm,
act_layer: Type[nn.Module] = nn.GELU,
use_abs_pos: bool = True,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
window_size: int = 0,
global_attn_indexes: Tuple[int, ...] = (),
) -> None:
"""
Args:
img_size (int): Input image size.
patch_size (int): Patch size.
in_chans (int): Number of input image channels.
embed_dim (int): Patch embedding dimension.
depth (int): Depth of ViT.
num_heads (int): Number of attention heads in each ViT block.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
norm_layer (nn.Module): Normalization layer.
act_layer (nn.Module): Activation layer.
use_abs_pos (bool): If True, use absolute positional embeddings.
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
window_size (int): Window size for window attention blocks.
global_attn_indexes (list): Indexes for blocks using global attention.
"""
super().__init__()
self.img_size = img_size
self.patch_embed = PatchEmbed(
kernel_size=(patch_size, patch_size),
stride=(patch_size, patch_size),
in_chans=in_chans,
embed_dim=embed_dim,
)
self.pos_embed: Optional[nn.Parameter] = None
if use_abs_pos:
# Initialize absolute positional embedding with pretrain image size.
self.pos_embed = nn.Parameter(
torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
)
self.blocks = nn.ModuleList()
for i in range(depth):
block = Block(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
norm_layer=norm_layer,
act_layer=act_layer,
use_rel_pos=use_rel_pos,
rel_pos_zero_init=rel_pos_zero_init,
window_size=window_size if i not in global_attn_indexes else 0,
input_size=(img_size // patch_size, img_size // patch_size),
)
self.blocks.append(block)
self.neck = nn.Sequential(
nn.Conv2d(
embed_dim,
out_chans,
kernel_size=1,
bias=False,
),
LayerNorm2d(out_chans),
nn.Conv2d(
out_chans,
out_chans,
kernel_size=3,
padding=1,
bias=False,
),
LayerNorm2d(out_chans),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.patch_embed(x)
if self.pos_embed is not None:
x = x + self.pos_embed
for blk in self.blocks:
x = blk(x)
x = self.neck(x.permute(0, 3, 1, 2))
return x
class Block(nn.Module):
"""Transformer blocks with support of window attention and residual propagation blocks"""
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
norm_layer: Type[nn.Module] = nn.LayerNorm,
act_layer: Type[nn.Module] = nn.GELU,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
window_size: int = 0,
input_size: Optional[Tuple[int, int]] = None,
) -> None:
"""
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads in each ViT block.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
norm_layer (nn.Module): Normalization layer.
act_layer (nn.Module): Activation layer.
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
window_size (int): Window size for window attention blocks. If it equals 0, then
use global attention.
input_size (tuple(int, int) or None): Input resolution for calculating the relative
positional parameter size.
"""
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
use_rel_pos=use_rel_pos,
rel_pos_zero_init=rel_pos_zero_init,
input_size=input_size if window_size == 0 else (window_size, window_size),
)
self.norm2 = norm_layer(dim)
self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
self.window_size = window_size
def forward(self, x: torch.Tensor) -> torch.Tensor:
shortcut = x
x = self.norm1(x)
# Window partition
if self.window_size > 0:
H, W = x.shape[1], x.shape[2]
x, pad_hw = window_partition(x, self.window_size)
x = self.attn(x)
# Reverse window partition
if self.window_size > 0:
x = window_unpartition(x, self.window_size, pad_hw, (H, W))
x = shortcut + x
x = x + self.mlp(self.norm2(x))
return x
class Attention(nn.Module):
"""Multi-head Attention block with relative position embeddings."""
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = True,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
input_size: Optional[Tuple[int, int]] = None,
) -> None:
"""
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
input_size (tuple(int, int) or None): Input resolution for calculating the relative
positional parameter size.
"""
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)
self.use_rel_pos = use_rel_pos
if self.use_rel_pos:
assert (
input_size is not None
), "Input size must be provided if using relative positional encoding."
# initialize relative positional embeddings
self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, H, W, _ = x.shape
# qkv with shape (3, B, nHead, H * W, C)
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
# q, k, v with shape (B * nHead, H * W, C)
q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
attn = (q * self.scale) @ k.transpose(-2, -1)
if self.use_rel_pos:
attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
attn = attn.softmax(dim=-1)
x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
x = self.proj(x)
return x
def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
"""
Partition into non-overlapping windows with padding if needed.
Args:
x (tensor): input tokens with [B, H, W, C].
window_size (int): window size.
Returns:
windows: windows after partition with [B * num_windows, window_size, window_size, C].
(Hp, Wp): padded height and width before partition
"""
B, H, W, C = x.shape
pad_h = (window_size - H % window_size) % window_size
pad_w = (window_size - W % window_size) % window_size
if pad_h > 0 or pad_w > 0:
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
Hp, Wp = H + pad_h, W + pad_w
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows, (Hp, Wp)
def window_unpartition(
windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
) -> torch.Tensor:
"""
Window unpartition into original sequences and removing padding.
Args:
windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
window_size (int): window size.
pad_hw (Tuple): padded height and width (Hp, Wp).
hw (Tuple): original height and width (H, W) before padding.
Returns:
x: unpartitioned sequences with [B, H, W, C].
"""
Hp, Wp = pad_hw
H, W = hw
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
if Hp > H or Wp > W:
x = x[:, :H, :W, :].contiguous()
return x
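# A small sanity-check sketch (hypothetical, not part of the original file):
# windows produced by `window_partition` should reassemble to the original
# input after `window_unpartition`, once the padding is cropped away.
def _window_roundtrip_example():
    x = torch.randn(1, 64, 64, 8)                       # B, H, W, C
    windows, pad_hw = window_partition(x, window_size=7)  # (100, 7, 7, 8), padded to 70x70
    y = window_unpartition(windows, 7, pad_hw, (64, 64))
    assert torch.allclose(x, y)
    return y.shape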
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
"""
Get relative positional embeddings according to the relative positions of
query and key sizes.
Args:
q_size (int): size of query q.
k_size (int): size of key k.
rel_pos (Tensor): relative position embeddings (L, C).
Returns:
Extracted positional embeddings according to relative positions.
"""
max_rel_dist = int(2 * max(q_size, k_size) - 1)
# Interpolate rel pos if needed.
if rel_pos.shape[0] != max_rel_dist:
# Interpolate rel pos.
rel_pos_resized = F.interpolate(
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
size=max_rel_dist,
mode="linear",
)
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
else:
rel_pos_resized = rel_pos
# Scale the coords with short length if shapes for q and k are different.
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
return rel_pos_resized[relative_coords.long()]
def add_decomposed_rel_pos(
attn: torch.Tensor,
q: torch.Tensor,
rel_pos_h: torch.Tensor,
rel_pos_w: torch.Tensor,
q_size: Tuple[int, int],
k_size: Tuple[int, int],
) -> torch.Tensor:
"""
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
Args:
attn (Tensor): attention map.
q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
Returns:
attn (Tensor): attention map with added relative positional embeddings.
"""
q_h, q_w = q_size
k_h, k_w = k_size
Rh = get_rel_pos(q_h, k_h, rel_pos_h)
Rw = get_rel_pos(q_w, k_w, rel_pos_w)
B, _, dim = q.shape
r_q = q.reshape(B, q_h, q_w, dim)
rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
attn = (
attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
).view(B, q_h * q_w, k_h * k_w)
return attn
class PatchEmbed(nn.Module):
"""
Image to Patch Embedding.
"""
def __init__(
self,
kernel_size: Tuple[int, int] = (16, 16),
stride: Tuple[int, int] = (16, 16),
padding: Tuple[int, int] = (0, 0),
in_chans: int = 3,
embed_dim: int = 768,
) -> None:
"""
Args:
kernel_size (Tuple): kernel size of the projection layer.
stride (Tuple): stride of the projection layer.
padding (Tuple): padding size of the projection layer.
in_chans (int): Number of input image channels.
embed_dim (int): Patch embedding dimension.
"""
super().__init__()
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x)
# B C H W -> B H W C
x = x.permute(0, 2, 3, 1)
return x
if __name__ == "__main__":
model = ImageEncoderViT()
data = torch.ones(1,3,1024,1024)
print(model)
model(data)

204
modeling/mask_decoder3d_2.py Normal file
View File

@ -0,0 +1,204 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torch import nn
from torch.nn import functional as F
from einops import rearrange
from typing import List, Tuple, Type
from .common import LayerNorm2d
class MaskDecoder(nn.Module):
def __init__(
self,
*,
transformer_dim: int,
transformer: nn.Module,
num_multimask_outputs: int = 3,
activation: Type[nn.Module] = nn.GELU,
iou_head_depth: int = 3,
iou_head_hidden_dim: int = 256,
num_slices: int = 3,
) -> None:
"""
Predicts masks given an image and prompt embeddings, using a
transformer architecture.
Arguments:
transformer_dim (int): the channel dimension of the transformer
transformer (nn.Module): the transformer used to predict masks
num_multimask_outputs (int): the number of masks to predict
when disambiguating masks
activation (nn.Module): the type of activation to use when
upscaling masks
iou_head_depth (int): the depth of the MLP used to predict
mask quality
iou_head_hidden_dim (int): the hidden dimension of the MLP
used to predict mask quality
"""
super().__init__()
self.transformer_dim = transformer_dim
self.transformer = transformer
self.num_multimask_outputs = num_multimask_outputs
self.iou_token = nn.Embedding(1, transformer_dim)
self.num_mask_tokens = num_multimask_outputs + 1
self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
self.output_upscaling = nn.Sequential(
nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
LayerNorm2d(transformer_dim // 4),
activation(),
nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
activation(),
)
self.output_upscaling2 = nn.Sequential(
nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
LayerNorm2d(transformer_dim // 4),
activation(),
nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
activation(),
)
self.output_upscaling3 = nn.Sequential(
nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
LayerNorm2d(transformer_dim // 4),
activation(),
nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
activation(),
)
self.output_hypernetworks_mlps = nn.ModuleList(
[
MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
for i in range(self.num_mask_tokens)
]
)
self.iou_prediction_head = MLP(
transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
)
def forward(
self,
image_embeddings: torch.Tensor,
image_pe: torch.Tensor,
sparse_prompt_embeddings: torch.Tensor,
dense_prompt_embeddings: torch.Tensor,
multimask_output: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Predict masks given image and prompt embeddings.
Arguments:
image_embeddings (torch.Tensor): the embeddings from the image encoder
image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
multimask_output (bool): Whether to return multiple masks or a single
mask.
Returns:
torch.Tensor: batched predicted masks
torch.Tensor: batched predictions of mask quality
"""
masks, iou_pred = self.predict_masks(
image_embeddings=image_embeddings,
image_pe=image_pe,
sparse_prompt_embeddings=sparse_prompt_embeddings,
dense_prompt_embeddings=dense_prompt_embeddings,
)
# Select the correct mask or masks for output
if multimask_output:
# mask_slice = slice(1, None)
masks = masks[:, 3:, :, :]
iou_pred = iou_pred[:, 1:]
else:
# mask_slice = slice(0, 1)
masks = masks[:, :3, :, :]
iou_pred = iou_pred[:, 0:1]
# Prepare output
return masks, iou_pred
def predict_masks(
self,
image_embeddings: torch.Tensor,
image_pe: torch.Tensor,
sparse_prompt_embeddings: torch.Tensor,
dense_prompt_embeddings: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Predicts masks. See 'forward' for more details."""
# Concatenate output tokens
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) # (1,5,256)
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) # (1,7,256)
# Expand per-image data in batch direction to be per-mask
# TODO: Why this code???
# src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
src = image_embeddings
src = src + dense_prompt_embeddings
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
b, c, h, w = src.shape
# Run the transformer
hs, src = self.transformer(src, pos_src, tokens)
iou_token_out = hs[:, 0, :]
mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
# Upscale mask embeddings and predict masks using the mask tokens
src = src.transpose(1, 2).view(b, c, h, w)
upscaled_embedding = self.output_upscaling(src)
upscaled_embedding2 = self.output_upscaling2(src)
upscaled_embedding3 = self.output_upscaling3(src)
hyper_in_list: List[torch.Tensor] = []
for i in range(self.num_mask_tokens):
hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
hyper_in = torch.stack(hyper_in_list, dim=1)
b, c, h, w = upscaled_embedding.shape
masks1 = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
masks2 = (hyper_in @ upscaled_embedding2.view(b, c, h * w)).view(b, -1, h, w)
masks3 = (hyper_in @ upscaled_embedding3.view(b, c, h * w)).view(b, -1, h, w)
masks = torch.stack([masks1, masks2, masks3], axis=2)
assert masks.shape[1] == 4 and masks.shape[2] == 3, f"Got {masks.shape}"
masks = rearrange(masks, "b c1 c2 h w -> b (c1 c2) h w")
# Generate mask quality predictions
iou_pred = self.iou_prediction_head(iou_token_out)
# masks = rearrange(masks, "b (c1 c2) h w -> b c1 c2 h w", c1=4, c2=3)
return masks, iou_pred
# Lightly adapted from
# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
class MLP(nn.Module):
def __init__(
self,
input_dim: int,
hidden_dim: int,
output_dim: int,
num_layers: int,
sigmoid_output: bool = False,
) -> None:
super().__init__()
self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(
nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
)
self.sigmoid_output = sigmoid_output
def forward(self, x):
for i, layer in enumerate(self.layers):
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
if self.sigmoid_output:
x = F.sigmoid(x)
return x
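# A hedged smoke test (not in the original file); it assumes the module is run
# as a package (`python -m modeling.mask_decoder3d_2`) so the relative import
# resolves, and illustrates the 4-token x 3-slice output layout of this decoder.
if __name__ == "__main__":
    from .transformer import TwoWayTransformer
    decoder = MaskDecoder(
        transformer_dim=256,
        transformer=TwoWayTransformer(depth=2, embedding_dim=256, mlp_dim=2048, num_heads=8),
        num_multimask_outputs=3,
    )
    image_embeddings = torch.randn(1, 256, 64, 64)
    image_pe = torch.randn(1, 256, 64, 64)
    sparse = torch.randn(1, 2, 256)    # e.g. one box prompt -> two corner tokens
    dense = torch.randn(1, 256, 64, 64)
    masks, iou_pred = decoder(image_embeddings, image_pe, sparse, dense, multimask_output=True)
    print(masks.shape, iou_pred.shape)  # (1, 9, 256, 256), (1, 3)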

215
modeling/prompt_encoder3d.py Normal file
View File

@ -0,0 +1,215 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
from torch import nn
from typing import Any, Optional, Tuple, Type
from .common import LayerNorm2d
class PromptEncoder(nn.Module):
def __init__(
self,
embed_dim: int,
image_embedding_size: Tuple[int, int],
input_image_size: Tuple[int, int],
mask_in_chans: int,
activation: Type[nn.Module] = nn.GELU,
) -> None:
"""
Encodes prompts for input to SAM's mask decoder.
Arguments:
embed_dim (int): The prompts' embedding dimension
image_embedding_size (tuple(int, int)): The spatial size of the
image embedding, as (H, W).
input_image_size (int): The padded size of the image as input
to the image encoder, as (H, W).
mask_in_chans (int): The number of hidden channels used for
encoding input masks.
activation (nn.Module): The activation to use when encoding
input masks.
"""
super().__init__()
self.embed_dim = embed_dim
self.input_image_size = input_image_size
self.image_embedding_size = image_embedding_size
self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]
self.point_embeddings = nn.ModuleList(point_embeddings)
self.not_a_point_embed = nn.Embedding(1, embed_dim)
self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])
self.mask_downscaling = nn.Sequential(
# nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
nn.Conv2d(3, mask_in_chans // 4, kernel_size=2, stride=2),
LayerNorm2d(mask_in_chans // 4),
activation(),
nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
LayerNorm2d(mask_in_chans),
activation(),
nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
)
self.no_mask_embed = nn.Embedding(1, embed_dim)
def get_dense_pe(self) -> torch.Tensor:
"""
Returns the positional encoding used to encode point prompts,
applied to a dense set of points the shape of the image encoding.
Returns:
torch.Tensor: Positional encoding with shape
1x(embed_dim)x(embedding_h)x(embedding_w)
"""
return self.pe_layer(self.image_embedding_size).unsqueeze(0)
def _embed_points(
self,
points: torch.Tensor,
labels: torch.Tensor,
pad: bool,
) -> torch.Tensor:
"""Embeds point prompts."""
points = points + 0.5 # Shift to center of pixel
if pad:
padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
points = torch.cat([points, padding_point], dim=1)
labels = torch.cat([labels, padding_label], dim=1)
point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
point_embedding[labels == -1] = 0.0
point_embedding[labels == -1] += self.not_a_point_embed.weight
point_embedding[labels == 0] += self.point_embeddings[0].weight
point_embedding[labels == 1] += self.point_embeddings[1].weight
return point_embedding
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
"""Embeds box prompts."""
boxes = boxes + 0.5 # Shift to center of pixel
coords = boxes.reshape(-1, 2, 2)
corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
corner_embedding[:, 0, :] += self.point_embeddings[2].weight
corner_embedding[:, 1, :] += self.point_embeddings[3].weight
return corner_embedding
def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
"""Embeds mask inputs."""
mask_embedding = self.mask_downscaling(masks)
return mask_embedding
def _get_batch_size(
self,
points: Optional[Tuple[torch.Tensor, torch.Tensor]],
boxes: Optional[torch.Tensor],
masks: Optional[torch.Tensor],
) -> int:
"""
Gets the batch size of the output given the batch size of the input prompts.
"""
if points is not None:
return points[0].shape[0]
elif boxes is not None:
return boxes.shape[0]
elif masks is not None:
return masks.shape[0]
else:
return 1
def _get_device(self) -> torch.device:
return self.point_embeddings[0].weight.device
def forward(
self,
points: Optional[Tuple[torch.Tensor, torch.Tensor]],
boxes: Optional[torch.Tensor],
masks: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Embeds different types of prompts, returning both sparse and dense
embeddings.
Arguments:
points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
and labels to embed.
boxes (torch.Tensor or none): boxes to embed
masks (torch.Tensor or none): masks to embed
Returns:
torch.Tensor: sparse embeddings for the points and boxes, with shape
BxNx(embed_dim), where N is determined by the number of input points
and boxes.
torch.Tensor: dense embeddings for the masks, in the shape
Bx(embed_dim)x(embed_H)x(embed_W)
"""
bs = self._get_batch_size(points, boxes, masks)
sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
if points is not None:
coords, labels = points
point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
if boxes is not None:
box_embeddings = self._embed_boxes(boxes)
sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
if masks is not None:
dense_embeddings = self._embed_masks(masks)
else:
dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
)
return sparse_embeddings, dense_embeddings
class PositionEmbeddingRandom(nn.Module):
"""
Positional encoding using random spatial frequencies.
"""
def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
super().__init__()
if scale is None or scale <= 0.0:
scale = 1.0
self.register_buffer(
"positional_encoding_gaussian_matrix",
scale * torch.randn((2, num_pos_feats)),
)
def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
"""Positionally encode points that are normalized to [0,1]."""
# assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
coords = 2 * coords - 1
coords = coords @ self.positional_encoding_gaussian_matrix
coords = 2 * np.pi * coords
# outputs d_1 x ... x d_n x C shape
return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
def forward(self, size: Tuple[int, int]) -> torch.Tensor:
"""Generate positional encoding for a grid of the specified size."""
h, w = size
device: Any = self.positional_encoding_gaussian_matrix.device
grid = torch.ones((h, w), device=device, dtype=torch.float32)
y_embed = grid.cumsum(dim=0) - 0.5
x_embed = grid.cumsum(dim=1) - 0.5
y_embed = y_embed / h
x_embed = x_embed / w
pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
return pe.permute(2, 0, 1) # C x H x W
def forward_with_coords(
self, coords_input: torch.Tensor, image_size: Tuple[int, int]
) -> torch.Tensor:
"""Positionally encode points that are not normalized to [0,1]."""
coords = coords_input.clone()
coords[:, :, 0] = coords[:, :, 0] / image_size[1]
coords[:, :, 1] = coords[:, :, 1] / image_size[0]
return self._pe_encoding(coords.to(torch.float)) # B x N x C
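# A hedged usage sketch (not in the original file): encode a single box prompt
# and check the sparse/dense embedding shapes consumed by the mask decoder.
if __name__ == "__main__":
    pe = PromptEncoder(
        embed_dim=256,
        image_embedding_size=(64, 64),
        input_image_size=(1024, 1024),
        mask_in_chans=16,
    )
    boxes = torch.tensor([[100.0, 100.0, 400.0, 400.0]])
    sparse, dense = pe(points=None, boxes=boxes, masks=None)
    print(sparse.shape, dense.shape)  # (1, 2, 256), (1, 256, 64, 64)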

174
modeling/sam3d.py Normal file
View File

@ -0,0 +1,174 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torch import nn
from torch.nn import functional as F
from typing import Any, Dict, List, Tuple
from .image_encoder import ImageEncoderViT
from .mask_decoder3d_2 import MaskDecoder
from .prompt_encoder3d import PromptEncoder
class Sam(nn.Module):
mask_threshold: float = 0.0
image_format: str = "RGB"
def __init__(
self,
image_encoder: ImageEncoderViT,
prompt_encoder: PromptEncoder,
mask_decoder: MaskDecoder,
pixel_mean: List[float] = [123.675, 116.28, 103.53],
pixel_std: List[float] = [58.395, 57.12, 57.375],
) -> None:
"""
SAM predicts object masks from an image and input prompts.
Arguments:
image_encoder (ImageEncoderViT): The backbone used to encode the
image into image embeddings that allow for efficient mask prediction.
prompt_encoder (PromptEncoder): Encodes various types of input prompts.
mask_decoder (MaskDecoder): Predicts masks from the image embeddings
and encoded prompts.
pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
pixel_std (list(float)): Std values for normalizing pixels in the input image.
"""
super().__init__()
self.image_encoder = image_encoder
self.prompt_encoder = prompt_encoder
self.mask_decoder = mask_decoder
self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
@property
def device(self) -> Any:
return self.pixel_mean.device
@torch.no_grad()
def forward(
self,
batched_input: List[Dict[str, Any]],
multimask_output: bool,
) -> List[Dict[str, torch.Tensor]]:
"""
Predicts masks end-to-end from provided images and prompts.
If prompts are not known in advance, using SamPredictor is
recommended over calling the model directly.
Arguments:
batched_input (list(dict)): A list over input images, each a
dictionary with the following keys. A prompt key can be
excluded if it is not present.
'image': The image as a torch tensor in 3xHxW format,
already transformed for input to the model.
'original_size': (tuple(int, int)) The original size of
the image before transformation, as (H, W).
'point_coords': (torch.Tensor) Batched point prompts for
this image, with shape BxNx2. Already transformed to the
input frame of the model.
'point_labels': (torch.Tensor) Batched labels for point prompts,
with shape BxN.
'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
Already transformed to the input frame of the model.
'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
in the form Bx1xHxW.
multimask_output (bool): Whether the model should predict multiple
disambiguating masks, or return a single mask.
Returns:
(list(dict)): A list over input images, where each element is
a dictionary with the following keys.
'masks': (torch.Tensor) Batched binary mask predictions,
with shape BxCxHxW, where B is the number of input prompts,
C is determined by multimask_output, and (H, W) is the
original size of the image.
'iou_predictions': (torch.Tensor) The model's predictions
of mask quality, in shape BxC.
'low_res_logits': (torch.Tensor) Low resolution logits with
shape BxCxHxW, where H=W=256. Can be passed as mask input
to subsequent iterations of prediction.
"""
input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
image_embeddings = self.image_encoder(input_images)
outputs = []
for image_record, curr_embedding in zip(batched_input, image_embeddings):
if "point_coords" in image_record:
points = (image_record["point_coords"], image_record["point_labels"])
else:
points = None
sparse_embeddings, dense_embeddings = self.prompt_encoder(
points=points,
boxes=image_record.get("boxes", None),
masks=image_record.get("mask_inputs", None),
)
low_res_masks, iou_predictions = self.mask_decoder(
image_embeddings=curr_embedding.unsqueeze(0),
image_pe=self.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=multimask_output,
)
masks = self.postprocess_masks(
low_res_masks,
input_size=image_record["image"].shape[-2:],
original_size=image_record["original_size"],
)
masks = masks > self.mask_threshold
outputs.append(
{
"masks": masks,
"iou_predictions": iou_predictions,
"low_res_logits": low_res_masks,
}
)
return outputs
def postprocess_masks(
self,
masks: torch.Tensor,
input_size: Tuple[int, ...],
original_size: Tuple[int, ...],
) -> torch.Tensor:
"""
Remove padding and upscale masks to the original image size.
Arguments:
masks (torch.Tensor): Batched masks from the mask_decoder,
in BxCxHxW format.
input_size (tuple(int, int)): The size of the image input to the
model, in (H, W) format. Used to remove padding.
original_size (tuple(int, int)): The original size of the image
before resizing for input to the model, in (H, W) format.
Returns:
(torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
is given by original_size.
"""
masks = F.interpolate(
masks,
(self.image_encoder.img_size, self.image_encoder.img_size),
mode="bilinear",
align_corners=False,
)
masks = masks[..., : input_size[0], : input_size[1]]
masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
return masks
def preprocess(self, x: torch.Tensor) -> torch.Tensor:
"""Normalize pixel values and pad to a square input."""
# Normalize colors
x = (x - self.pixel_mean) / self.pixel_std
# Pad
h, w = x.shape[-2:]
padh = self.image_encoder.img_size - h
padw = self.image_encoder.img_size - w
x = F.pad(x, (0, padw, 0, padh))
return x

240
modeling/transformer.py Normal file
View File

@ -0,0 +1,240 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torch import Tensor, nn
import math
from typing import Tuple, Type
from .common import MLPBlock
class TwoWayTransformer(nn.Module):
def __init__(
self,
depth: int,
embedding_dim: int,
num_heads: int,
mlp_dim: int,
activation: Type[nn.Module] = nn.ReLU,
attention_downsample_rate: int = 2,
) -> None:
"""
A transformer decoder that attends to an input image using
queries whose positional embedding is supplied.
Args:
depth (int): number of layers in the transformer
embedding_dim (int): the channel dimension for the input embeddings
num_heads (int): the number of heads for multihead attention. Must
divide embedding_dim
mlp_dim (int): the channel dimension internal to the MLP block
activation (nn.Module): the activation to use in the MLP block
"""
super().__init__()
self.depth = depth
self.embedding_dim = embedding_dim
self.num_heads = num_heads
self.mlp_dim = mlp_dim
self.layers = nn.ModuleList()
for i in range(depth):
self.layers.append(
TwoWayAttentionBlock(
embedding_dim=embedding_dim,
num_heads=num_heads,
mlp_dim=mlp_dim,
activation=activation,
attention_downsample_rate=attention_downsample_rate,
skip_first_layer_pe=(i == 0),
)
)
self.final_attn_token_to_image = Attention(
embedding_dim, num_heads, downsample_rate=attention_downsample_rate
)
self.norm_final_attn = nn.LayerNorm(embedding_dim)
def forward(
self,
image_embedding: Tensor,
image_pe: Tensor,
point_embedding: Tensor,
) -> Tuple[Tensor, Tensor]:
"""
Args:
image_embedding (torch.Tensor): image to attend to. Should be shape
B x embedding_dim x h x w for any h and w.
image_pe (torch.Tensor): the positional encoding to add to the image. Must
have the same shape as image_embedding.
point_embedding (torch.Tensor): the embedding to add to the query points.
Must have shape B x N_points x embedding_dim for any N_points.
Returns:
torch.Tensor: the processed point_embedding
torch.Tensor: the processed image_embedding
"""
# BxCxHxW -> BxHWxC == B x N_image_tokens x C
bs, c, h, w = image_embedding.shape
image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
image_pe = image_pe.flatten(2).permute(0, 2, 1)
# Prepare queries
queries = point_embedding
keys = image_embedding
# Apply transformer blocks and final layernorm
for layer in self.layers:
queries, keys = layer(
queries=queries,
keys=keys,
query_pe=point_embedding,
key_pe=image_pe,
)
# Apply the final attention layer from the points to the image
q = queries + point_embedding
k = keys + image_pe
attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
queries = queries + attn_out
queries = self.norm_final_attn(queries)
return queries, keys
class TwoWayAttentionBlock(nn.Module):
def __init__(
self,
embedding_dim: int,
num_heads: int,
mlp_dim: int = 2048,
activation: Type[nn.Module] = nn.ReLU,
attention_downsample_rate: int = 2,
skip_first_layer_pe: bool = False,
) -> None:
"""
A transformer block with four layers: (1) self-attention of sparse
inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
block on sparse inputs, and (4) cross attention of dense inputs to sparse
inputs.
Arguments:
embedding_dim (int): the channel dimension of the embeddings
num_heads (int): the number of heads in the attention layers
mlp_dim (int): the hidden dimension of the mlp block
activation (nn.Module): the activation of the mlp block
skip_first_layer_pe (bool): skip the PE on the first layer
"""
super().__init__()
self.self_attn = Attention(embedding_dim, num_heads)
self.norm1 = nn.LayerNorm(embedding_dim)
self.cross_attn_token_to_image = Attention(
embedding_dim, num_heads, downsample_rate=attention_downsample_rate
)
self.norm2 = nn.LayerNorm(embedding_dim)
self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
self.norm3 = nn.LayerNorm(embedding_dim)
self.norm4 = nn.LayerNorm(embedding_dim)
self.cross_attn_image_to_token = Attention(
embedding_dim, num_heads, downsample_rate=attention_downsample_rate
)
self.skip_first_layer_pe = skip_first_layer_pe
def forward(
self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
) -> Tuple[Tensor, Tensor]:
# Self attention block
if self.skip_first_layer_pe:
queries = self.self_attn(q=queries, k=queries, v=queries)
else:
q = queries + query_pe
attn_out = self.self_attn(q=q, k=q, v=queries)
queries = queries + attn_out
queries = self.norm1(queries)
# Cross attention block, tokens attending to image embedding
q = queries + query_pe
k = keys + key_pe
attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
queries = queries + attn_out
queries = self.norm2(queries)
# MLP block
mlp_out = self.mlp(queries)
queries = queries + mlp_out
queries = self.norm3(queries)
# Cross attention block, image embedding attending to tokens
q = queries + query_pe
k = keys + key_pe
attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
keys = keys + attn_out
keys = self.norm4(keys)
return queries, keys
class Attention(nn.Module):
"""
An attention layer that allows for downscaling the size of the embedding
after projection to queries, keys, and values.
"""
def __init__(
self,
embedding_dim: int,
num_heads: int,
downsample_rate: int = 1,
) -> None:
super().__init__()
self.embedding_dim = embedding_dim
self.internal_dim = embedding_dim // downsample_rate
self.num_heads = num_heads
assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
b, n, c = x.shape
x = x.reshape(b, n, num_heads, c // num_heads)
return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
def _recombine_heads(self, x: Tensor) -> Tensor:
b, n_heads, n_tokens, c_per_head = x.shape
x = x.transpose(1, 2)
return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
# Input projections
q = self.q_proj(q)
k = self.k_proj(k)
v = self.v_proj(v)
# Separate into heads
q = self._separate_heads(q, self.num_heads)
k = self._separate_heads(k, self.num_heads)
v = self._separate_heads(v, self.num_heads)
# Attention
_, _, _, c_per_head = q.shape
attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
attn = attn / math.sqrt(c_per_head)
attn = torch.softmax(attn, dim=-1)
# Get output
out = attn @ v
out = self._recombine_heads(out)
out = self.out_proj(out)
return out
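# Hedged usage sketch (added for illustration, not part of the upstream file):
# it shows how `downsample_rate` shrinks the internal attention width while the
# output projection restores `embedding_dim`. All shapes below are assumptions.
if __name__ == "__main__":
    _attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
    _q = _k = _v = torch.randn(2, 5, 256)   # (batch, tokens, embedding_dim)
    _out = _attn(q=_q, k=_k, v=_v)
    # internal_dim = 256 // 2 = 128, i.e. 16 channels per head; out_proj maps back to 256
    assert _out.shape == (2, 5, 256)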

19
readme.md Normal file
View File

@ -0,0 +1,19 @@
<!-- # Slide-SAM -->
# Slide-SAM: Medical SAM meets sliding window
## Training
Prepare the datasets:
```
python -m datasets.generate_txt
```
Cache the 3D volumes into slices:
```
python -m datasets.cache_datasets3d
```
Run training:
```
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m core.ddp --tag debug
```
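## Evaluation
For reference, a hedged single-volume evaluation example (the checkpoint and config paths are hard-coded in `test/volume_eval.py`, so adjust them to your environment first):
```
python -m test.volume_eval
```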

0
test/__init__.py Normal file
View File

173
test/volume_eval.py Normal file
View File

@ -0,0 +1,173 @@
"""
Volume evaluation
"""
import torch
import numpy as np
from torch.utils.data import DataLoader
# from datasets.dataset3d import Dataset3D
from tutils.new.manager import ConfigManager
from datasets.eval_dataloader.loader_abstract import AbstractLoader
from core.volume_predictor import VolumePredictor
from datasets.data_engine import DataManager, BoxPromptGenerator, PointPromptGenerator
from tutils import tfilename
from tutils.new.trainer.recorder import Recorder
from trans_utils.metrics import compute_dice_np
from trans_utils.data_utils import Data3dSolver
class Evaluater:
def __init__(self, config) -> None:
self.config = config
self.recorder = Recorder()
def solve(self, model, dataset):
# model.eval()
self.predictor = model
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
for i, data in enumerate(dataloader):
# if i <4:
# print
# continue
# for k, v in data.items():
# if isinstance(v, torch.Tensor):
# data[k] = v.to(self.rank)
if self.config['dataset']['prompt'] == 'box':
res = self.eval_step(data, batch_idx=i)
if self.config['dataset']['prompt'] == 'point':
res = self.eval_step_point(data, batch_idx=i)
self.recorder.record(res)
res = self.recorder.cal_metrics()
print(res)
print("prompt:", self.config['dataset']['prompt'], " class_idx:", self.config['dataset']['label_idx'])
def eval_step(self, data, batch_idx=0):
name = data['name']
dataset_name = data['dataset_name'][0]
label_idx = data['label_idx'][0]
template_slice_id = data['template_slice_id'][0]
assert data['img'].shape[1] >= 3, f" Got img.shape {data['img'].shape}"
if template_slice_id == 0:
template_slice_id += 1
        elif template_slice_id == (data['img'].shape[1] - 1):  # last slice along the slice axis
template_slice_id -= 1
spacing = data['spacing'].numpy().tolist()[0]
if data['img'].shape[-1] < 260:
# assert data['img'].shape[-1] < 260, f"Got {data['img'].shape}"
img = data['img'][0][:,:256,:256]
label = data['label'][0][:,:256,:256]
else:
img = data['img'][0]
label = data['label'][0]
# img = torch.clip(img, -200, 600)
box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id].detach().cpu().numpy())
box = np.array([box])
pred, stability = self.predictor.predict_volume(
x=img,
box=box,
template_slice_id=template_slice_id,
return_stability=True,
)
prompt_type = 'box'
dice = compute_dice_np(pred, label.detach().cpu().numpy())
Data3dSolver().simple_write(pred, path=tfilename(f"visual/{dataset_name}/pred_{batch_idx}_label_{label_idx}_{prompt_type}.nii.gz"), spacing=spacing)
Data3dSolver().simple_write(label.detach().cpu().numpy(), path=tfilename(f"visual/{dataset_name}/label_{batch_idx}.nii.gz"))
# Data3dSolver().simple_write(img.detach().cpu().numpy(), path=tfilename(f"visual/{dataset_name}/img_{batch_idx}.nii.gz"))
# np.save(tfilename(f"meta/{dataset_name}/stability_{batch_idx}.npy"), stability)
print(dataset_name, name, dice)
return {"dice": dice}
def eval_step_point(self, data, batch_idx=0):
name = data['name']
dataset_name = data['dataset_name'][0]
label_idx = data['label_idx'][0]
template_slice_id = data['template_slice_id'][0]
spacing = data['spacing'].numpy().tolist()[0]
assert data['img'].shape[1] >= 3, f" Got img.shape {data['img'].shape}"
if template_slice_id == 0:
template_slice_id += 1
        elif template_slice_id == (data['img'].shape[1] - 1):  # last slice along the slice axis
template_slice_id -= 1
if data['img'].shape[-1] < 260:
# assert data['img'].shape[-1] < 260, f"Got {data['img'].shape}"
img = data['img'][0][:,:256,:256]
label = data['label'][0][:,:256,:256]
else:
img = data['img'][0]
label = data['label'][0]
box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id].detach().cpu().numpy())
point = (box[0]+box[2])*0.5 , (box[1]+box[3])*0.5
point = np.array([point]).astype(int)
if label[template_slice_id][point[0,1], point[0,0]] == 0:
print("Use random point instead !!!")
point = PointPromptGenerator().get_prompt_point(label[template_slice_id])
point = np.array([point]).astype(int)
# box = np.array([box])
pred = self.predictor.predict_volume(
x=img,
point_coords=point,
point_labels=np.ones_like(point)[:,:1],
template_slice_id=template_slice_id,
)
dice = compute_dice_np(pred, label.detach().cpu().numpy())
prompt_type = 'point'
Data3dSolver().simple_write(pred, path=tfilename(f"visual/{dataset_name}/pred_{batch_idx}_label_{label_idx}_{prompt_type}.nii.gz"), spacing=spacing)
# Data3dSolver().simple_write(pred, path=tfilename(f"visual/{dataset_name}/pred_{batch_idx}.nii.gz"))
print(dataset_name, name, dice)
return {"dice": dice}
def to_RGB(img):
pass
if __name__ == "__main__":
# from core.learner3 import SamLearner
# from modeling.build_sam3d import sam_model_registry
from core.learner3 import SamLearner
from modeling.build_sam3d2 import sam_model_registry
EX_CONFIG = {
'dataset':{
'prompt': 'box',
'dataset_list': ['word'], # ["sabs"], chaos, word
'label_idx': 1,
}
}
config = ConfigManager()
config.add_config("configs/vit_b_103.yaml")
config.add_config(EX_CONFIG)
# Init Model
model_type = "vit_b"
sam = sam_model_registry[model_type](checkpoint=None)
learner = SamLearner(sam_model=sam, config=config, data_engine=DataManager(img_size=(1024,1024)))
learner.use_lora()
# pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b3/lora+edge2/ckpt_v/model_latest.pth"
# pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b3/lora+edge2/ckpt/model_epoch_20.pth"
# pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b3/lora+edge2/ckpt/model_epoch_16.pth"
# pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b3/lora_small/ckpt/model_epoch_6.pth"
# pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b9/lora/ckpt/model_epoch_50.pth"
# pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b11/spec_8/ckpt_v/model_latest.pth"
# pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b11/spec_5/ckpt/model_epoch_100.pth"
pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b9/lora3/ckpt/model_iter_360000.pth"
# pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b9/lora3/ckpt/model_iter_500000.pth"
learner.load_well_trained_model(pth)
learner.cuda()
predictor = VolumePredictor(
model=learner.model,
use_postprocess=True,
use_noise_remove=True,)
solver = Evaluater(config)
dataset = AbstractLoader(config['dataset'], split="test")
solver.solve(predictor, dataset)

0
trans_utils/__init__.py Normal file
View File

Binary file not shown.

Binary file not shown.

97
trans_utils/data_utils.py Normal file
View File

@ -0,0 +1,97 @@
import torch
import numpy as np
from tutils import tfilename
from tutils.nn.data import read, itk_to_np, np_to_itk, write
from torchvision.utils import save_image
import SimpleITK as sitk
from tutils.nn.data.tsitk.preprocess import resampleImage
class Data3dSolver:
def __init__(self) -> None:
pass
def simple_write(self, data_np, path="tmp.nii.gz", spacing=None):
assert len(data_np.shape) == 3, f"Got {data_np.shape}"
data_np = data_np.astype(np.int16)
data_itk = np_to_itk(data_np)
if spacing is not None:
data_itk.SetSpacing(spacing)
write(data_itk, path=tfilename(path))
print("Save to ", path)
def write_slices(self, data, path="tmp_masks.jpg"):
if isinstance(data, torch.Tensor):
pass
if isinstance(data, np.ndarray):
data = torch.Tensor(data)
assert len(data.shape) == 4, f"Shape should be (b c h w) c=1/3, Got {data.shape}"
assert data.shape[1] == 1 or data.shape[1] == 3, f"Shape should be (b c h w) c=1/3, Got {data.shape}"
assert path.endswith(".jpg") or path.endswith(".png")
        save_image(torch.Tensor(data), tfilename(path))  # data is already (b, c, h, w)
print("Save to ", path)
def write_multilabel_nii(self, data, path, meta=None):
if isinstance(data, dict):
data_all = [v for k,v in data.items()]
data = np.stack(data_all, axis=0)
assert len(data.shape) == 4, f"Shape should be (b c h w) , Got {data.shape}"
# Merge labels to one
merged = np.zeros_like(data[0])
for i, datai in enumerate(data):
merged = np.where(datai > 0, datai * (i+1), merged)
merged = merged.astype(np.int16)
data_itk = np_to_itk(merged)
if meta is not None:
data_itk = formalize(data_itk, meta)
write(data_itk, path=tfilename(path))
print("Save to ", path)
def fwrite(self, data, path, meta):
data = data.astype(np.int16)
data_itk = np_to_itk(data)
data_itk = formalize(data_itk, meta)
write(data_itk, path=tfilename(path))
def read(self, path, spacing_norm=True):
data_itk = read(path)
if spacing_norm:
ori_size = data_itk.GetSize()
ori_spacing = data_itk.GetSpacing()
data_itk = self.normalize_spacing(data_itk)
new_size = data_itk.GetSize()
new_spacing = data_itk.GetSpacing()
print("Change size from ", ori_size, new_size)
print("Change spacing from ", ori_spacing, new_spacing)
data_np = itk_to_np(data_itk)
print("[data_utils.DEBUG]", data_np.shape)
return data_np, data_itk.GetSpacing()
def normalize_spacing(self, data_itk):
spacing = data_itk.GetSpacing()
new_spacing = (min(spacing),min(spacing),min(spacing))
data_itk = resampleImage(data_itk, NewSpacing=new_spacing)
return data_itk
def formalize(img:sitk.SimpleITK.Image, meta:sitk.SimpleITK.Image):
# Size = meta.GetSize()
Spacing = meta.GetSpacing()
Origin = meta.GetOrigin()
Direction = meta.GetDirection()
img.SetSpacing(Spacing)
img.SetOrigin(Origin)
img.SetDirection(Direction)
return img
def write(img:sitk.SimpleITK.Image, path:str, mode:str="nifti"):
"""
Path: (example) os.path.join(jpg_dir, f"trans_{random_name}.nii.gz")
"""
mode = mode.lower()
writer = sitk.ImageFileWriter()
writer.SetFileName(path)
writer.Execute(img)

24
trans_utils/metrics.py Normal file
View File

@ -0,0 +1,24 @@
import numpy as np
def compute_dice_np(pred_mask, gt_mask):
""" numpy values
"""
assert gt_mask.max() == 1, f"Got gt_mask.max():{gt_mask.max()} Error!!"
pred_mask = np.array(pred_mask>0)
gt_mask = np.array(gt_mask>0)
intersection = np.array(pred_mask * gt_mask).sum()
union = pred_mask.sum() + gt_mask.sum()
dice = intersection * 2 / union # if union > 0 else 0
return dice
def compute_prec_np(pred_mask, gt_mask):
true_pos = (np.int32(pred_mask>0) * np.int32(gt_mask>0)).sum()
return true_pos / np.int32(pred_mask>0).sum()
def compute_recall_np(pred_mask, gt_mask):
true_pos = (np.int32(pred_mask>0) * np.int32(gt_mask>0)).sum()
false_neg = ((gt_mask - pred_mask)>0).sum()
return true_pos / (true_pos + false_neg)
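# Hedged usage sketch (illustration only): a one-pixel overlap between two 2x2 masks.
if __name__ == "__main__":
    _pred = np.array([[1, 1], [0, 0]])
    _gt = np.array([[1, 0], [0, 0]])
    # intersection = 1, |pred| + |gt| = 3  ->  dice = 2 / 3
    print(compute_dice_np(_pred, _gt))    # ~0.667
    print(compute_prec_np(_pred, _gt))    # 1 true positive / 2 predicted = 0.5
    print(compute_recall_np(_pred, _gt))  # 1 / (1 + 0) = 1.0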

418
trans_utils/trainer_ddp.py Normal file
View File

@ -0,0 +1,418 @@
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader, Dataset
import tempfile
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import autocast, GradScaler
from tutils.new.trainer.trainer_abstract import AbstractTrainer
from tutils.new.manager.loggers import MultiLogger
from tutils.new.manager.csv_recorder import CSVLogger
from tutils.new.trainer.recorder import Recorder
from tutils.new.utils.core_utils import _get_time_str
from tutils.new.utils.public_utils import dict_to_str
# Waiting for update
from tutils import tfilename, tenum
# export MASTER_ADDR=192.168.1.100
# export MASTER_PORT=12345
def setup(rank, world_size):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
# initialize the process group
dist.init_process_group("gloo", rank=rank, world_size=world_size)
def cleanup():
dist.destroy_process_group()
def ddp(rank, world_size):
print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)
    # create a toy model and move it to GPU with id rank
    # (an nn.Linear(10, 5) stand-in keeps this example self-contained)
    model = nn.Linear(10, 5).to(rank)
    ddp_model = DDP(model, device_ids=[rank])
loss_fn = nn.MSELoss()
optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10).to(rank))
labels = torch.randn(20, 5).to(rank)
loss_fn(outputs, labels).backward()
optimizer.step()
cleanup()
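# Hedged launch sketch (not in the original file): the usual way to drive the
# example above is torch.multiprocessing.spawn, one process per rank.
def run_ddp_example(world_size: int = 2):
    mp.spawn(ddp, args=(world_size,), nprocs=world_size, join=True)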
def get_logger(config):
config_base = config['base']
config_logger = config['logger']
logger = MultiLogger(logdir=config_base['runs_dir'],
record_mode=config_logger.get('record_mode', None),
tag=config_base['tag'],
extag=config_base.get('experiment', None),
action=config_logger.get('action', 'k')) # backup config.yaml
return logger
class DDPTrainer(AbstractTrainer):
def __init__(self, config, tester=None, monitor=None, rank='cuda', world_size=0, logger=None):
super().__init__(config, tester, monitor, rank, world_size)
self.rank = rank
self.logger = logger
self.logging_available = (self.rank == 0 or self.rank == 'cuda')
print("Running on ", rank)
self.global_iteration = 0
if self.logging_available:
print(f"Logger at Process(rank={rank})")
self.recorder = Recorder(reduction=self.recorder_mode)
self.recorder_valid = Recorder(reduction=self.recorder_mode)
self.recorder_test = Recorder(reduction=self.recorder_mode)
self.logger = None
self.csvlogger = CSVLogger(tfilename(self.runs_dir, "best_record"))
self.csvlogger_all = CSVLogger(tfilename(self.runs_dir, "all_record"))
self.monitor = monitor
self.tester = tester
self.logger = get_logger(config)
        assert self.logger is not None, f"{__file__} Got rank {self.rank}"
if self.use_amp:
self.scalar = GradScaler()
print("Debug settings: use amp=",self.use_amp)
def init_model(self, model, trainset, validset=None, **kwargs):
# Use CacheDataset
# trainset = CacheDataset(trainset, num_workers=12, cache_rate=0.5)
if trainset is not None:
            assert len(trainset) > 0, f"{__file__} Got {len(trainset)}"
self.trainloader = DataLoader(dataset=trainset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
drop_last=True,
pin_memory=True)
if validset is not None:
self.validloader = DataLoader(dataset=validset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
drop_last=True,
pin_memory=True)
if self.load_pretrain_model:
model.module.load()
rank = self.rank
model = model.to(rank)
ddp_model = DDP(model, device_ids=[rank])
return ddp_model
def configure_optim(self, model, **kwargs):
# Set optimizer and scheduler
optim_configs = model.module.configure_optimizers()
assert isinstance(optim_configs, dict)
optimizer = optim_configs['optimizer']
scheduler = optim_configs['scheduler']
if self.load_optimizer:
start_epoch = model.module.load_optim(optimizer)
print(f"[DDPTrainer] Continue training, from epoch {start_epoch}")
else:
start_epoch = self.start_epoch
return optimizer, scheduler, start_epoch
def fit(self, model, trainset, validset=None):
model = self.init_model(model, trainset, validset=validset, rank=self.rank)
self.init_timers()
optimizer, scheduler, start_epoch = self.configure_optim(model)
for epoch in range(start_epoch, self.max_epochs):
self.on_before_zero_grad()
# Training
self.timer_epoch()
do_training_log = (epoch % self.training_log_interval == 0)
if self.validloader is not None and self.validation_interval > 0 and epoch % self.validation_interval == 0:
self.valid(model, self.validloader, epoch, do_training_log)
if self.trainloader is not None:
self.train(model, self.trainloader, epoch, optimizer, scheduler, do_training_log)
if self.logging_available:
self.test(model, epoch=epoch)
if epoch % self.save_interval == 0 and self.logging_available:
if 'latest' in self.save_mode:
self.save(model, epoch, 'latest', optimizer)
if 'all' in self.save_mode:
self.save(model, epoch, None, optimizer)
# time_save_model = self.timer_epoch()
print("Training is Over for GPU rank ", self.rank)
self.cleanup()
def test(self, model, epoch):
# Evaluation
if epoch % self.val_check_interval == 0 and self.logging_available:
print("Note: Tester runs on <rank 0> only")
if self.tester is not None:
out = self.tester.test(model=model, epoch=epoch, rank=self.rank)
if self.monitor is not None:
best_dict = self.monitor.record(out, epoch)
self.recorder_test.record({**best_dict, **out})
if best_dict['isbest']:
if 'best' in self.save_mode:
self.save(model, epoch, type='best')
self.csvlogger.record({**best_dict, **out, "time": _get_time_str()})
if self.save_all_records:
self.csvlogger_all.record({**best_dict, **out, "time": _get_time_str()})
self.logger.info(f"\n[*] {dict_to_str(best_dict)}[*] Epoch {epoch}: \n{dict_to_str(out)}")
self.logger.add_scalars(out, step=epoch, tag='test')
# if ''
else:
self.logger.info(f"\n[*] Epoch {epoch}: {dict_to_str(out)}")
self.on_after_testing(d=out)
def save(self, model, epoch, type=None, optimizer=None, **kwargs):
if self.logging_available:
if type is None:
# if self.save_interval > 0 and epoch % self.save_interval == 0:
save_name = "/ckpt/model_epoch_{}.pth".format(epoch)
model.module.save(tfilename(self.runs_dir, save_name), epoch=epoch)
self.logger.info(f"Epoch {epoch}: Save model to ``{save_name}``! ")
elif type == 'best':
# save_name = "/ckpt/best_model_epoch_{}.pth".format(epoch)
save_name2 = "/ckpt_v/model_best.pth"
# model.save(tfilename(self.runs_dir, save_name), epoch=epoch, is_best=True)
model.module.save(tfilename(self.runs_dir, save_name2), epoch=epoch, is_best=True)
self.logger.info(f"[Best model] Epoch {epoch}: Save model to ``{save_name2}``! ")
elif type == 'latest':
if self.save_interval > 0 and epoch % self.save_interval == 0:
save_name = "/ckpt_v/model_latest.pth"
model.module.save(tfilename(self.runs_dir, save_name), epoch=epoch, is_latest=True)
save_optim_name = "/ckpt/optim_latest.pth"
model.module.save_optim(tfilename(self.runs_dir, save_optim_name), optimizer=optimizer, epoch=epoch)
self.logger.info(f"Epoch {epoch}: Save checkpoint to ``{save_name}``")
elif type == "iteration":
save_name = "/ckpt/model_iter_{}.pth".format(self.global_iteration)
model.module.save(tfilename(self.runs_dir, save_name), epoch=self.global_iteration)
self.logger.info(f"Epoch {epoch}: Save model to ``{save_name}``! ")
def train(self, model, trainloader, epoch, optimizer, scheduler=None, do_training_log=True):
model.train()
out = {}
if do_training_log and self.logging_available:
self.recorder.clear()
time_record = 0.1111
self.timer_batch()
success_count = 0
failed_count = 0
for load_time, batch_idx, data in tenum(trainloader):
optimizer.zero_grad()
self.timer_data()
# training steps
for k, v in data.items():
if isinstance(v, torch.Tensor):
data[k] = v.to(self.rank)
time_data_cuda = self.timer_data()
if self.use_amp:
with autocast():
self.timer_net()
out = model.module.training_step(data, batch_idx, epoch=epoch)
# try:
# out = model.module.training_step(data, batch_idx, epoch=epoch)
# except Exception as e:
# msg = f"Ignore Error! {e}"
# if self.logging_available:
# self.logger.info(msg)
# else:
# print(msg)
# continue
assert isinstance(out, dict)
time_fd = self.timer_net()
loss = out['loss']
self.scalar.scale(loss).backward()
self.scalar.step(optimizer)
self.scalar.update()
time_bp = self.timer_net()
else:
self.timer_net()
try:
out = model.module.training_step(data, batch_idx, epoch=epoch)
except Exception as e:
msg = f"Ignore Error! {e}"
if self.logging_available:
self.logger.info(msg)
else:
print(msg)
continue
if out['loss'] is None:
failed_count += 1
continue
if torch.isnan(out['loss']):
print("Ignore Nan Value: ", out['loss'])
failed_count += 1
# raise ValueError(f"Get loss: {out['loss']}")
assert isinstance(out, dict)
time_fd = self.timer_net()
loss = out['loss']
loss.backward()
optimizer.step()
time_bp = self.timer_net()
success_count += 1
time_batch = self.timer_batch()
# batch logger !
if self.logging_available and do_training_log:
out['time_load'] = load_time
out['time_cuda'] = time_data_cuda
out['time_forward'] = time_fd
out['time_bp'] = time_bp
out['time_record'] = time_record
out['time_batch'] = time_batch
self.timer_data()
self.recorder.record(out)
time_record = self.timer_data()
# for debug !
if epoch == 0:
if self.logging_available:
self.logger.info("[*] Debug Checking Pipeline !!!")
del out
return
if self.global_iteration % 100 == 0:
print(f"Epoch: {epoch} | batch:{batch_idx}/{len(trainloader)}, Iteration:{self.global_iteration}, results: {to_item(out)}", end='\n')
if self.global_iteration % 5000 == 0:
self.save(model, epoch, "iteration", optimizer)
# print("")
self.global_iteration += 1
if scheduler is not None:
scheduler.step()
# epoch logger !
if self.logging_available:
if do_training_log :
_dict = self.recorder.cal_metrics()
_dict['time_total'] = self.timer_epoch()
# print(_dict)
# assert isinstance(lr, float), f"Got lr={lr}, type: {type(lr)}"
loss_str = ""
for k, v in _dict.items():
loss_str += "{}:{:.4f} ".format(k, v)
# lr = optimizer.param_groups[0]['lr']
lr = self.get_lr(optimizer)
_dict['lr'] = lr
loss_str += "{}:{:.6e} ".format('lr', lr)
self.logger.info(f"Epoch {epoch}: {loss_str}")
# _dict_with_train_tag = {f"train/{k}":v for k,v in _dict.items()}
self.logger.add_scalars(_dict, step=epoch, tag='train')
time_log_scalars = self.timer_epoch()
self.on_after_training(d=_dict)
# Clear
del out
del data
def valid(self, model, validloader, epoch, do_training_log=True):
model.eval()
out = {}
if do_training_log and self.logging_available:
self.recorder_valid.clear()
time_record = 0.1
self.timer_batch()
success_count = 1
failed_count = 1
for load_time, batch_idx, data in tenum(validloader):
# model.on_before_zero_grad()
self.timer_data()
# training steps
for k, v in data.items():
if isinstance(v, torch.Tensor):
data[k] = v.to(self.rank)
time_data_cuda = self.timer_data()
if self.use_amp:
with autocast():
self.timer_net()
out = model.module.validation_step(data, batch_idx, epoch=epoch)
assert isinstance(out, dict)
time_fd = self.timer_net()
else:
self.timer_net()
out = model.module.validation_step(data, batch_idx, epoch=epoch)
if out['loss'] is None:
failed_count += 1
continue
if torch.isnan(out['loss']):
self.logger.info("Nan Value: ", out['loss'])
failed_count += 1
raise ValueError(f"Get loss: {out['loss']}")
assert isinstance(out, dict)
time_fd = self.timer_net()
success_count += 1
time_batch = self.timer_batch()
# batch logger !
if do_training_log and self.logging_available:
out['time_load'] = load_time
out['time_cuda'] = time_data_cuda
out['time_forward'] = time_fd
out['time_record'] = time_record
out['time_batch'] = time_batch
self.timer_data()
self.recorder_valid.record(out)
time_record = self.timer_data()
if batch_idx % 2 == 0:
print(f"Valid Epoch: {epoch}. Processing batch_idx:{batch_idx} / {len(validloader)}, time_load: {load_time}, results: {to_item(out)}", end='\r')
if epoch == 0:
if self.logging_available:
self.logger.info("[*] Debug Checking validation Pipeline !!!")
del out
return
# model.on_after_zero_grad(d=out)
if self.logging_available:
self.logger.info(f"Training Success Ratio: {success_count / (success_count + failed_count)}")
# epoch logger !
if self.logging_available:
if do_training_log :
_dict = self.recorder_valid.cal_metrics()
_dict['time_total'] = self.timer_epoch()
# print(_dict)
# assert isinstance(lr, float), f"Got lr={lr}, type: {type(lr)}"
loss_str = ""
for k, v in _dict.items():
loss_str += "{}:{:.4f} ".format(k, v)
self.logger.info(f"Epoch {epoch}: {loss_str}")
# _dict_with_val_tag = {f"val/{k}":v for k,v in _dict.items()}
self.logger.add_scalars(_dict, step=epoch, tag='val')
time_log_scalars = self.timer_epoch()
self.on_after_training(d=_dict)
# Clear
del out
del data
def to_item(tensors):
for k,v in tensors.items():
if isinstance(v, torch.Tensor):
tensors[k] = v.detach().cpu().item()
return tensors

5
utils/__init__.py Normal file
View File

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

399
utils/amg.py Normal file
View File

@ -0,0 +1,399 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
import math
from copy import deepcopy
from itertools import product
from typing import Any, Dict, Generator, ItemsView, List, Tuple
class MaskData:
"""
A structure for storing masks and their related data in batched format.
Implements basic filtering and concatenation.
"""
def __init__(self, **kwargs) -> None:
for v in kwargs.values():
assert isinstance(
v, (list, np.ndarray, torch.Tensor)
), "MaskData only supports list, numpy arrays, and torch tensors."
self._stats = dict(**kwargs)
def __setitem__(self, key: str, item: Any) -> None:
assert isinstance(
item, (list, np.ndarray, torch.Tensor)
), "MaskData only supports list, numpy arrays, and torch tensors."
self._stats[key] = item
def __delitem__(self, key: str) -> None:
del self._stats[key]
def __getitem__(self, key: str) -> Any:
return self._stats[key]
def items(self) -> ItemsView[str, Any]:
return self._stats.items()
def filter(self, keep: torch.Tensor) -> None:
for k, v in self._stats.items():
if v is None:
self._stats[k] = None
elif isinstance(v, torch.Tensor):
self._stats[k] = v[torch.as_tensor(keep, device=v.device)]
elif isinstance(v, np.ndarray):
self._stats[k] = v[keep.detach().cpu().numpy()]
elif isinstance(v, list) and keep.dtype == torch.bool:
self._stats[k] = [a for i, a in enumerate(v) if keep[i]]
elif isinstance(v, list):
self._stats[k] = [v[i] for i in keep]
else:
raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
def cat(self, new_stats: "MaskData") -> None:
for k, v in new_stats.items():
if k not in self._stats or self._stats[k] is None:
self._stats[k] = deepcopy(v)
elif isinstance(v, torch.Tensor):
if k == "boxes" and len(v.shape) == 1:
v = v[None,:]
if k == "boxes" and len(self._stats[k].shape) == 1:
self._stats[k] = self._stats[k][None, :]
self._stats[k] = torch.cat([self._stats[k], v], dim=0)
elif isinstance(v, np.ndarray):
self._stats[k] = np.concatenate([self._stats[k], v], axis=0)
elif isinstance(v, list):
self._stats[k] = self._stats[k] + deepcopy(v)
else:
raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
def to_numpy(self) -> None:
for k, v in self._stats.items():
if isinstance(v, torch.Tensor):
self._stats[k] = v.detach().cpu().numpy()
def is_box_near_crop_edge(
boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
) -> torch.Tensor:
"""Filter masks at the edge of a crop, but not at the edge of the original image."""
crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
return torch.any(near_crop_edge, dim=1)
def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor:
box_xywh = deepcopy(box_xyxy)
box_xywh[2] = box_xywh[2] - box_xywh[0]
box_xywh[3] = box_xywh[3] - box_xywh[1]
return box_xywh
def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
assert len(args) > 0 and all(
len(a) == len(args[0]) for a in args
), "Batched iteration must have inputs of all the same size."
n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
for b in range(n_batches):
yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
"""
Encodes masks to an uncompressed RLE, in the format expected by
pycoco tools.
"""
# Put in fortran order and flatten h,w
b, h, w = tensor.shape
tensor = tensor.permute(0, 2, 1).flatten(1)
# Compute change indices
diff = tensor[:, 1:] ^ tensor[:, :-1]
change_indices = diff.nonzero()
# Encode run length
out = []
for i in range(b):
cur_idxs = change_indices[change_indices[:, 0] == i, 1]
cur_idxs = torch.cat(
[
torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device),
cur_idxs + 1,
torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device),
]
)
btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
counts = [] if tensor[i, 0] == 0 else [0]
counts.extend(btw_idxs.detach().cpu().tolist())
out.append({"size": [h, w], "counts": counts})
return out
def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
"""Compute a binary mask from an uncompressed RLE."""
h, w = rle["size"]
mask = np.empty(h * w, dtype=bool)
idx = 0
parity = False
for count in rle["counts"]:
mask[idx : idx + count] = parity
idx += count
parity ^= True
mask = mask.reshape(w, h)
return mask.transpose() # Put in C order
def area_from_rle(rle: Dict[str, Any]) -> int:
return sum(rle["counts"][1::2])
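# Hedged worked example (illustration only): a 2x2 mask is flattened in
# column-major (Fortran) order and the leading count is the number of zeros.
if __name__ == "__main__":
    _mask = torch.tensor([[[0, 1], [1, 0]]], dtype=torch.bool)  # (B=1, H=2, W=2)
    _rle = mask_to_rle_pytorch(_mask)[0]
    # Fortran-order flattening gives [0, 1, 1, 0] -> counts [1, 2, 1]
    assert _rle == {"size": [2, 2], "counts": [1, 2, 1]}
    assert area_from_rle(_rle) == 2
    assert (rle_to_mask(_rle) == _mask[0].numpy()).all()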
def calculate_stability_score(
masks: torch.Tensor, mask_threshold: float, threshold_offset: float
) -> torch.Tensor:
"""
Computes the stability score for a batch of masks. The stability
score is the IoU between the binary masks obtained by thresholding
the predicted mask logits at high and low values.
"""
# One mask is always contained inside the other.
# Save memory by preventing unnecessary cast to torch.int64
intersections = (
(masks > (mask_threshold + threshold_offset))
.sum(-1, dtype=torch.int16)
.sum(-1, dtype=torch.int32)
)
unions = (
(masks > (mask_threshold - threshold_offset))
.sum(-1, dtype=torch.int16)
.sum(-1, dtype=torch.int32)
)
return intersections / unions
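# Hedged worked example (threshold values are assumptions): with mask_threshold=0.0
# and threshold_offset=1.0, the strict mask (>1.0) is compared against the loose
# mask (>-1.0); their IoU is the stability score.
if __name__ == "__main__":
    _logits = torch.tensor([[[2.0, 0.5], [-0.5, -2.0]]])  # one 2x2 mask of logits
    # strict mask has 1 pixel, loose mask has 3 -> score 1/3
    _score = calculate_stability_score(_logits, mask_threshold=0.0, threshold_offset=1.0)
    assert torch.allclose(_score, torch.tensor([1 / 3]))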
def build_point_grid(n_per_side: int) -> np.ndarray:
"""Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
offset = 1 / (2 * n_per_side)
points_one_side = np.linspace(offset, 1 - offset, n_per_side)
points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
points_y = np.tile(points_one_side[:, None], (1, n_per_side))
points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
return points
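# Hedged worked example: n_per_side=2 places points at the cell centres 0.25 and
# 0.75 on each axis, i.e. 4 points over [0,1]x[0,1].
if __name__ == "__main__":
    _grid = build_point_grid(2)
    assert _grid.shape == (4, 2)
    assert np.allclose(np.unique(_grid), [0.25, 0.75])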
def build_all_layer_point_grids(
n_per_side: int, n_layers: int, scale_per_layer: int
) -> List[np.ndarray]:
"""Generates point grids for all crop layers."""
points_by_layer = []
for i in range(n_layers + 1):
n_points = int(n_per_side / (scale_per_layer**i))
points_by_layer.append(build_point_grid(n_points))
return points_by_layer
def generate_crop_boxes(
im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
) -> Tuple[List[List[int]], List[int]]:
"""
Generates a list of crop boxes of different sizes. Each layer
has (2**i)**2 boxes for the ith layer.
"""
crop_boxes, layer_idxs = [], []
im_h, im_w = im_size
short_side = min(im_h, im_w)
# Original image
crop_boxes.append([0, 0, im_w, im_h])
layer_idxs.append(0)
def crop_len(orig_len, n_crops, overlap):
return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))
for i_layer in range(n_layers):
n_crops_per_side = 2 ** (i_layer + 1)
overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
crop_w = crop_len(im_w, n_crops_per_side, overlap)
crop_h = crop_len(im_h, n_crops_per_side, overlap)
crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]
        # Crops in XYXY format
for x0, y0 in product(crop_box_x0, crop_box_y0):
box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
crop_boxes.append(box)
layer_idxs.append(i_layer + 1)
return crop_boxes, layer_idxs
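# Hedged worked example (sizes are assumptions): a 512x512 image with n_layers=1
# and overlap_ratio=0.25 yields the full image plus a 2x2 grid of overlapping crops.
if __name__ == "__main__":
    _boxes, _layers = generate_crop_boxes((512, 512), n_layers=1, overlap_ratio=0.25)
    assert len(_boxes) == 1 + 4 and _layers == [0, 1, 1, 1, 1]
    assert _boxes[0] == [0, 0, 512, 512]  # layer 0: the original image
    assert _boxes[1] == [0, 0, 320, 320]  # layer 1 crops are 320 px wide, overlapping by 128 px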
def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
x0, y0, _, _ = crop_box
offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
# Check if boxes has a channel dimension
if len(boxes.shape) == 3:
offset = offset.unsqueeze(1)
return boxes + offset
def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
x0, y0, _, _ = crop_box
offset = torch.tensor([[x0, y0]], device=points.device)
# Check if points has a channel dimension
if len(points.shape) == 3:
offset = offset.unsqueeze(1)
return points + offset
def uncrop_masks(
masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int
) -> torch.Tensor:
x0, y0, x1, y1 = crop_box
if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
return masks
# Coordinate transform masks
pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
pad = (x0, pad_x - x0, y0, pad_y - y0)
return torch.nn.functional.pad(masks, pad, value=0)
def remove_small_regions(
mask: np.ndarray, area_thresh: float, mode: str
) -> Tuple[np.ndarray, bool]:
"""
Removes small disconnected regions and holes in a mask. Returns the
mask and an indicator of if the mask has been modified.
"""
import cv2 # type: ignore
assert mode in ["holes", "islands"]
correct_holes = mode == "holes"
working_mask = (correct_holes ^ mask).astype(np.uint8)
n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
sizes = stats[:, -1][1:] # Row 0 is background label
small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
if len(small_regions) == 0:
return mask, False
fill_labels = [0] + small_regions
if not correct_holes:
fill_labels = [i for i in range(n_labels) if i not in fill_labels]
# If every region is below threshold, keep largest
if len(fill_labels) == 0:
fill_labels = [int(np.argmax(sizes)) + 1]
mask = np.isin(regions, fill_labels)
return mask, True
def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]:
from pycocotools import mask as mask_utils # type: ignore
h, w = uncompressed_rle["size"]
rle = mask_utils.frPyObjects(uncompressed_rle, h, w)
rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json
return rle
def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
"""
Calculates boxes in XYXY format around masks. Return [0,0,0,0] for
an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
"""
# torch.max below raises an error on empty inputs, just skip in this case
if torch.numel(masks) == 0:
return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
# Normalize shape to CxHxW
masks = masks.squeeze()
assert len(masks.shape) <= 3
shape = masks.shape
h, w = shape[-2:]
if len(shape) > 2:
masks = masks.flatten(0, -3)
else:
masks = masks.unsqueeze(0)
# Get top and bottom edges
in_height, _ = torch.max(masks, dim=-1)
in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]
bottom_edges, _ = torch.max(in_height_coords, dim=-1)
in_height_coords = in_height_coords + h * (~in_height)
top_edges, _ = torch.min(in_height_coords, dim=-1)
# Get left and right edges
in_width, _ = torch.max(masks, dim=-2)
in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]
right_edges, _ = torch.max(in_width_coords, dim=-1)
in_width_coords = in_width_coords + w * (~in_width)
left_edges, _ = torch.min(in_width_coords, dim=-1)
# If the mask is empty the right edge will be to the left of the left edge.
# Replace these boxes with [0, 0, 0, 0]
empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
out = out * (~empty_filter).unsqueeze(-1)
# Return to original shape
if len(shape) > 2:
out = out.reshape(*shape[:-2], 4)
else:
out = out[0]
return out
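# Hedged worked example: boxes are XYXY around the nonzero region, and an
# all-zero mask maps to [0, 0, 0, 0].
if __name__ == "__main__":
    _m = torch.zeros(2, 6, 8, dtype=torch.bool)
    _m[0, 2:5, 1:4] = True            # rows 2..4, cols 1..3
    _b = batched_mask_to_box(_m)
    assert _b[0].tolist() == [1, 2, 3, 4]  # [x0, y0, x1, y1]
    assert _b[1].tolist() == [0, 0, 0, 0]  # empty mask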
def batched_mask3d_to_box(masks: torch.Tensor) -> torch.Tensor:
"""
Calculates boxes in XYXY format around masks. Return [0,0,0,0] for
an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
"""
# torch.max below raises an error on empty inputs, just skip in this case
if torch.numel(masks) == 0:
return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
# Normalize shape to CxHxW
assert masks.shape[1] == 3
masks = masks[:,1,:,:]
shape = masks.shape
h, w = shape[-2:]
if len(shape) > 2:
masks = masks.flatten(0, -3)
else:
masks = masks.unsqueeze(0)
# Get top and bottom edges
in_height, _ = torch.max(masks, dim=-1)
in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]
bottom_edges, _ = torch.max(in_height_coords, dim=-1)
in_height_coords = in_height_coords + h * (~in_height)
top_edges, _ = torch.min(in_height_coords, dim=-1)
# Get left and right edges
in_width, _ = torch.max(masks, dim=-2)
in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]
right_edges, _ = torch.max(in_width_coords, dim=-1)
in_width_coords = in_width_coords + w * (~in_width)
left_edges, _ = torch.min(in_width_coords, dim=-1)
# If the mask is empty the right edge will be to the left of the left edge.
# Replace these boxes with [0, 0, 0, 0]
empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
out = out * (~empty_filter).unsqueeze(-1)
# Return to original shape
if len(shape) > 2:
out = out.reshape(*shape[:-2], 4)
else:
out = out[0]
return out

204
utils/amg3d.py Normal file
View File

@ -0,0 +1,204 @@
import numpy as np
import torch
import math
from copy import deepcopy
from itertools import product
from typing import Any, Dict, Generator, ItemsView, List, Tuple
class MaskData3d:
def __init__(self, size, **kwargs) -> None:
self.size = size
for v in kwargs.values():
assert isinstance(
v, (list, np.ndarray, torch.Tensor)
), "MaskData only supports list, numpy arrays, and torch tensors."
self._stats = dict(**kwargs)
    def __setitem__(self, key, item):
        # __setitem__ takes (key, value); the key is expected to be a (slice_id, num)
        # tuple, i.e. data[slice_id, num] = item
        slice_id, num = key
assert isinstance(
item, (list, np.ndarray, torch.Tensor)
), "MaskData only supports list, numpy arrays, and torch tensors."
key = str(slice_id) + str(num)
if self._stats.get(key, None) is None:
self._stats[key] = np.zeros(self.size)
self._stats[key][slice_id-1:slice_id+2] = item
def __delitem__(self, key: str) -> None:
del self._stats[key]
def __getitem__(self, key: str) -> Any:
return self._stats[key]
def items(self) -> ItemsView[str, Any]:
return self._stats.items()
def merge(self, slice_ids, num, item):
pass
def build_all_layer_point_grids(
n_per_side: int = 32, n_layers: int = 0, scale_per_layer: int = 1) -> List[np.ndarray]:
"""Generates point grids for all crop layers."""
points_by_layer = []
for i in range(n_layers + 1):
n_points = int(n_per_side / (scale_per_layer**i))
points_by_layer.append(build_point_grid(n_points))
return points_by_layer
def build_point_grid(n_per_side: int, size=(1.0, 1.0)) -> np.ndarray:
    """Generates a 2D grid of points evenly spaced in [0,1]x[0,1], scaled by `size`
    (the default keeps the normalized grid so callers that omit `size` still work)."""
offset = 1 / (2 * n_per_side)
points_one_side = np.linspace(offset, 1 - offset, n_per_side)
points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
points_y = np.tile(points_one_side[:, None], (1, n_per_side))
points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
return points * np.array(size)
# def calculate_stability_score(
# masks: torch.Tensor, mask_threshold: float, threshold_offset: float
# ) -> torch.Tensor:
# """
# Computes the stability score for a batch of masks. The stability
# score is the IoU between the binary masks obtained by thresholding
# the predicted mask logits at high and low values.
# """
# # One mask is always contained inside the other.
# # Save memory by preventing unnecessary cast to torch.int64
# intersections = (
# (masks > (mask_threshold + threshold_offset))
# .sum(-1, dtype=torch.int16)
# .sum(-1, dtype=torch.int32)
# )
# unions = (
# (masks > (mask_threshold - threshold_offset))
# .sum(-1, dtype=torch.int16)
# .sum(-1, dtype=torch.int32)
# )
# return intersections / unions
def calculate_stability_score_3d(
masks: torch.Tensor, mask_threshold: float, threshold_offset: float
) -> torch.Tensor:
"""
Computes the stability score for a batch of masks. The stability
score is the IoU between the binary masks obtained by thresholding
the predicted mask logits at high and low values.
"""
# One mask is always contained inside the other.
# Save memory by preventing unnecessary cast to torch.int64
intersections = (
(masks > (mask_threshold + threshold_offset))
.sum(-1, dtype=torch.int16)
.sum(-1, dtype=torch.int32)
.sum(-1, dtype=torch.int32)
)
# intersections = intersections2d.sum(-1, dtype=torch.int32)
unions = (
(masks > (mask_threshold - threshold_offset))
.sum(-1, dtype=torch.int16)
.sum(-1, dtype=torch.int32)
.sum(-1, dtype=torch.int32)
)
return (intersections / unions)
def remove_small_regions(
mask: np.ndarray, area_thresh: float, mode: str
) -> Tuple[np.ndarray, bool]:
"""
Removes small disconnected regions and holes in a mask. Returns the
mask and an indicator of if the mask has been modified.
"""
import cv2 # type: ignore
assert mode in ["holes", "islands"]
correct_holes = mode == "holes"
working_mask = (correct_holes ^ mask).astype(np.uint8)
n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
sizes = stats[:, -1][1:] # Row 0 is background label
small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
if len(small_regions) == 0:
return mask, False
fill_labels = [0] + small_regions
if not correct_holes:
fill_labels = [i for i in range(n_labels) if i not in fill_labels]
# If every region is below threshold, keep largest
if len(fill_labels) == 0:
fill_labels = [int(np.argmax(sizes)) + 1]
mask = np.isin(regions, fill_labels)
return mask, True
def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
assert len(args) > 0 and all(
len(a) == len(args[0]) for a in args
), "Batched iteration must have inputs of all the same size."
n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
for b in range(n_batches):
yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
class MaskData:
"""
A structure for storing masks and their related data in batched format.
Implements basic filtering and concatenation.
"""
def __init__(self, **kwargs) -> None:
for v in kwargs.values():
assert isinstance(
v, (list, np.ndarray, torch.Tensor)
), "MaskData only supports list, numpy arrays, and torch tensors."
self._stats = dict(**kwargs)
def __setitem__(self, key: str, item: Any) -> None:
assert isinstance(
item, (list, np.ndarray, torch.Tensor)
), "MaskData only supports list, numpy arrays, and torch tensors."
self._stats[key] = item
def __delitem__(self, key: str) -> None:
del self._stats[key]
def __getitem__(self, key: str) -> Any:
return self._stats[key]
def items(self) -> ItemsView[str, Any]:
return self._stats.items()
def filter(self, keep: torch.Tensor) -> None:
for k, v in self._stats.items():
if v is None:
self._stats[k] = None
elif isinstance(v, torch.Tensor):
self._stats[k] = v[torch.as_tensor(keep, device=v.device)]
elif isinstance(v, np.ndarray):
self._stats[k] = v[keep.detach().cpu().numpy()]
elif isinstance(v, list) and keep.dtype == torch.bool:
self._stats[k] = [a for i, a in enumerate(v) if keep[i]]
elif isinstance(v, list):
self._stats[k] = [v[i] for i in keep]
else:
raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
def cat(self, new_stats: "MaskData") -> None:
for k, v in new_stats.items():
if k not in self._stats or self._stats[k] is None:
self._stats[k] = deepcopy(v)
elif isinstance(v, torch.Tensor):
self._stats[k] = torch.cat([self._stats[k], v], dim=0)
elif isinstance(v, np.ndarray):
self._stats[k] = np.concatenate([self._stats[k], v], axis=0)
elif isinstance(v, list):
self._stats[k] = self._stats[k] + deepcopy(v)
else:
raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
def to_numpy(self) -> None:
for k, v in self._stats.items():
if isinstance(v, torch.Tensor):
self._stats[k] = v.detach().cpu().numpy()

59
utils/masks3d_utils.py Normal file
View File

@ -0,0 +1,59 @@
from tutils.nn.data import write, np_to_itk
import numpy as np
class Masks3D:
def __init__(self) -> None:
pass
def from_dict(self, masks):
self.masks = masks
# self.tags = masks
def to_2dmasks(self):
pass
def filter_by_bg(self, volume, threshold=None):
        if threshold is None:
            threshold = (volume.max() - volume.min()) * 0.1 + volume.min()
        # copy the keys so entries can be popped while iterating
        keys = list(self.masks.keys())
        for k in keys:
v = self.masks[k]
assert v.shape == volume.shape, f"Got shape ERROR, {v.shape, volume.shape}"
if (v * volume).mean() <= threshold:
self.masks.pop(k)
# def filter_by_area(self,):
def sort_by_logits(self):
self.confidences = []
self.tags_by_conf = []
for k, v in self.masks.items():
confidence = v[v>0].mean()
self.confidences.append(confidence)
self.tags_by_conf.append(k)
indices = np.argsort(self.confidences)[::-1]
self.tags_by_conf = np.array(self.tags_by_conf)[indices].tolist()
self.confidences = np.array(self.confidences)[indices]
def to_nii(self, path="tmp.nii.gz"):
self.sort_by_logits()
total = None
for i, k in enumerate(self.tags_by_conf):
mask = np.int32(self.masks[k]>0)
if total is None:
                # label values start at 1 so the top-ranked mask is not written as background
                total = mask * (i + 1)
            else:
                total = np.where(total > 0, total, mask * (i + 1))
mask_itk = np_to_itk(total)
write(mask_itk, path)
if __name__ == "__main__":
p = "/home1/quanquan/code/projects/finetune_large/segment_anything/tmp.npy"
data = np.load(p, allow_pickle=True).tolist()
# import ipdb; ipdb.set_trace()
print(data.keys())
mm = Masks3D()
mm.from_dict(data)
mm.to_nii()

144
utils/onnx.py Normal file
View File

@ -0,0 +1,144 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
from torch.nn import functional as F
from typing import Tuple
from ..modeling import Sam
from .amg import calculate_stability_score
class SamOnnxModel(nn.Module):
"""
This model should not be called directly, but is used in ONNX export.
It combines the prompt encoder, mask decoder, and mask postprocessing of Sam,
    with some functions modified to enable model tracing. Also supports extra
    options controlling what information is returned. See the ONNX export script for details.
"""
def __init__(
self,
model: Sam,
return_single_mask: bool,
use_stability_score: bool = False,
return_extra_metrics: bool = False,
) -> None:
super().__init__()
self.mask_decoder = model.mask_decoder
self.model = model
self.img_size = model.image_encoder.img_size
self.return_single_mask = return_single_mask
self.use_stability_score = use_stability_score
self.stability_score_offset = 1.0
self.return_extra_metrics = return_extra_metrics
@staticmethod
def resize_longest_image_size(
input_image_size: torch.Tensor, longest_side: int
) -> torch.Tensor:
input_image_size = input_image_size.to(torch.float32)
scale = longest_side / torch.max(input_image_size)
transformed_size = scale * input_image_size
transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
return transformed_size
def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
point_coords = point_coords + 0.5
point_coords = point_coords / self.img_size
point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords)
point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)
point_embedding = point_embedding * (point_labels != -1)
point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * (
point_labels == -1
)
for i in range(self.model.prompt_encoder.num_point_embeddings):
point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[
i
].weight * (point_labels == i)
return point_embedding
def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor:
mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask)
mask_embedding = mask_embedding + (
1 - has_mask_input
) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
return mask_embedding
def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor:
masks = F.interpolate(
masks,
size=(self.img_size, self.img_size),
mode="bilinear",
align_corners=False,
)
prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64)
masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore
orig_im_size = orig_im_size.to(torch.int64)
h, w = orig_im_size[0], orig_im_size[1]
masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
return masks
def select_masks(
self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
) -> Tuple[torch.Tensor, torch.Tensor]:
# Determine if we should return the multiclick mask or not from the number of points.
# The reweighting is used to avoid control flow.
score_reweight = torch.tensor(
[[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]
).to(iou_preds.device)
score = iou_preds + (num_points - 2.5) * score_reweight
best_idx = torch.argmax(score, dim=1)
masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)
return masks, iou_preds
@torch.no_grad()
def forward(
self,
image_embeddings: torch.Tensor,
point_coords: torch.Tensor,
point_labels: torch.Tensor,
mask_input: torch.Tensor,
has_mask_input: torch.Tensor,
orig_im_size: torch.Tensor,
):
sparse_embedding = self._embed_points(point_coords, point_labels)
dense_embedding = self._embed_masks(mask_input, has_mask_input)
masks, scores = self.model.mask_decoder.predict_masks(
image_embeddings=image_embeddings,
image_pe=self.model.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embedding,
dense_prompt_embeddings=dense_embedding,
)
if self.use_stability_score:
scores = calculate_stability_score(
masks, self.model.mask_threshold, self.stability_score_offset
)
if self.return_single_mask:
masks, scores = self.select_masks(masks, scores, point_coords.shape[1])
upscaled_masks = self.mask_postprocessing(masks, orig_im_size)
if self.return_extra_metrics:
stability_scores = calculate_stability_score(
upscaled_masks, self.model.mask_threshold, self.stability_score_offset
)
areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
return upscaled_masks, scores, stability_scores, areas, masks
return upscaled_masks, scores, masks

102
utils/transforms.py Normal file
View File

@ -0,0 +1,102 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
from torch.nn import functional as F
from torchvision.transforms.functional import resize, to_pil_image # type: ignore
from copy import deepcopy
from typing import Tuple
class ResizeLongestSide:
"""
Resizes images to the longest side 'target_length', as well as provides
methods for resizing coordinates and boxes. Provides methods for
transforming both numpy array and batched torch tensors.
"""
def __init__(self, target_length: int) -> None:
self.target_length = target_length
def apply_image(self, image: np.ndarray) -> np.ndarray:
"""
Expects a numpy array with shape HxWxC in uint8 format.
"""
target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
return np.array(resize(to_pil_image(image), target_size))
def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
"""
Expects a numpy array of length 2 in the final dimension. Requires the
original image size in (H, W) format.
"""
old_h, old_w = original_size
new_h, new_w = self.get_preprocess_shape(
original_size[0], original_size[1], self.target_length
)
coords = deepcopy(coords).astype(float)
coords[..., 0] = coords[..., 0] * (new_w / old_w)
coords[..., 1] = coords[..., 1] * (new_h / old_h)
return coords
def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
"""
Expects a numpy array shape Bx4. Requires the original image size
in (H, W) format.
"""
boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
return boxes.reshape(-1, 4)
def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
"""
Expects batched images with shape BxCxHxW and float format. This
transformation may not exactly match apply_image. apply_image is
the transformation expected by the model.
"""
# Expects an image in BCHW format. May not exactly match apply_image.
target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length)
return F.interpolate(
image, target_size, mode="bilinear", align_corners=False, antialias=True
)
def apply_coords_torch(
self, coords: torch.Tensor, original_size: Tuple[int, ...]
) -> torch.Tensor:
"""
Expects a torch tensor with length 2 in the last dimension. Requires the
original image size in (H, W) format.
"""
old_h, old_w = original_size
new_h, new_w = self.get_preprocess_shape(
original_size[0], original_size[1], self.target_length
)
coords = deepcopy(coords).to(torch.float)
coords[..., 0] = coords[..., 0] * (new_w / old_w)
coords[..., 1] = coords[..., 1] * (new_h / old_h)
return coords
def apply_boxes_torch(
self, boxes: torch.Tensor, original_size: Tuple[int, ...]
) -> torch.Tensor:
"""
Expects a torch tensor with shape Bx4. Requires the original image
size in (H, W) format.
"""
boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
return boxes.reshape(-1, 4)
@staticmethod
def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
"""
Compute the output size given input size and target long side length.
"""
scale = long_side_length * 1.0 / max(oldh, oldw)
newh, neww = oldh * scale, oldw * scale
neww = int(neww + 0.5)
newh = int(newh + 0.5)
return (newh, neww)
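# Hedged worked example (sizes are assumptions): a 600x800 image resized so the
# longest side is 1024 keeps the aspect ratio and rounds to the nearest integer.
if __name__ == "__main__":
    assert ResizeLongestSide.get_preprocess_shape(600, 800, 1024) == (768, 1024)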

157
utils/transforms3d.py Normal file
View File

@ -0,0 +1,157 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
from torch.nn import functional as F
from torchvision.transforms.functional import resize, to_pil_image  # type: ignore
from torch.nn.functional import interpolate
from copy import deepcopy
from typing import Tuple
class SimpleResize:
"""
    Keeps the resize behaviour consistent with training (fixed square output);
    may be refined in the future.
"""
def __init__(self, target_length: int = 1024) -> None:
self.target_length = (target_length, target_length)
def apply_image(self, image: torch.tensor) -> torch.tensor:
target_size = self.target_length
return interpolate(image.unsqueeze(1), size=(target_size[0], target_size[1])).squeeze()
# return np.array(resize(to_pil_image(image), target_size))
def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
old_h, old_w = original_size
new_h, new_w = self.target_length
coords = deepcopy(coords).astype(float)
coords[..., 0] = coords[..., 0] * (new_w / old_w)
coords[..., 1] = coords[..., 1] * (new_h / old_h)
return coords
def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
return boxes.reshape(-1, 4)
def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
# Expects an image in BCHW format. May not exactly match apply_image.
target_size = self.target_length
return F.interpolate(
image, target_size, mode="bilinear", align_corners=False, antialias=True
)
def apply_coords_torch(
self, coords: torch.Tensor, original_size: Tuple[int, ...]
) -> torch.Tensor:
old_h, old_w = original_size
new_h, new_w = self.target_length
coords = deepcopy(coords).to(torch.float)
coords[..., 0] = coords[..., 0] * (new_w / old_w)
coords[..., 1] = coords[..., 1] * (new_h / old_h)
return coords
def apply_boxes_torch(
self, boxes: torch.Tensor, original_size: Tuple[int, ...]
) -> torch.Tensor:
boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
return boxes.reshape(-1, 4)
class ResizeLongestSide:
"""
Resizes images to the longest side 'target_length', as well as provides
methods for resizing coordinates and boxes. Provides methods for
transforming both numpy array and batched torch tensors.
"""
def __init__(self, target_length: int = 1024) -> None:
self.target_length = target_length
def apply_image(self, image: torch.tensor) -> torch.tensor:
"""
Expects a numpy array with shape HxWxC in uint8 format.
"""
target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
return interpolate(image, size=(target_size[0], target_size[1], image.shape[-1]))
# return np.array(resize(to_pil_image(image), target_size))
def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
"""
Expects a numpy array of length 2 in the final dimension. Requires the
original image size in (H, W) format.
"""
old_h, old_w = original_size
new_h, new_w = self.get_preprocess_shape(
original_size[0], original_size[1], self.target_length
)
coords = deepcopy(coords).astype(float)
coords[..., 0] = coords[..., 0] * (new_w / old_w)
coords[..., 1] = coords[..., 1] * (new_h / old_h)
return coords
def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
"""
Expects a numpy array shape Bx4. Requires the original image size
in (H, W) format.
"""
boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
return boxes.reshape(-1, 4)
def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
"""
Expects batched images with shape BxCxHxW and float format. This
transformation may not exactly match apply_image. apply_image is
the transformation expected by the model.
"""
# Expects an image in BCHW format. May not exactly match apply_image.
target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length)
return F.interpolate(
image, target_size, mode="bilinear", align_corners=False, antialias=True
)
def apply_coords_torch(
self, coords: torch.Tensor, original_size: Tuple[int, ...]
) -> torch.Tensor:
"""
Expects a torch tensor with length 2 in the last dimension. Requires the
original image size in (H, W) format.
"""
old_h, old_w = original_size
new_h, new_w = self.get_preprocess_shape(
original_size[0], original_size[1], self.target_length
)
coords = deepcopy(coords).to(torch.float)
coords[..., 0] = coords[..., 0] * (new_w / old_w)
coords[..., 1] = coords[..., 1] * (new_h / old_h)
return coords
def apply_boxes_torch(
self, boxes: torch.Tensor, original_size: Tuple[int, ...]
) -> torch.Tensor:
"""
Expects a torch tensor with shape Bx4. Requires the original image
size in (H, W) format.
"""
boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
return boxes.reshape(-1, 4)
@staticmethod
def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int):
"""
Compute the output size given input size and target long side length.
"""
scale = long_side_length * 1.0 / max(oldh, oldw)
newh, neww = oldh * scale, oldw * scale
neww = int(neww + 0.5)
newh = int(newh + 0.5)
return (newh, neww)