Compare commits

...

10 Commits

Author SHA1 Message Date
transcendentsiki
72a73cb4b9 add mbox 2024-04-02 16:57:10 +08:00
Curli Trans
96a2bd15a6 fixed eval_dataloader 2024-04-02 15:48:48 +08:00
transcendentsky
841be2acbe . 2024-03-20 22:20:41 +08:00
transcendentsky
ce31f13be0 . 2024-03-20 22:07:35 +08:00
transcendentsky
88745c6105 . 2024-03-20 21:43:27 +08:00
transcendentsky
635f319708 . 2024-03-20 21:43:17 +08:00
transcendentsky
173762f756 add finetune 2024-03-20 15:58:48 +08:00
transcendentsky
840009725f add fintuning 2024-03-20 15:39:07 +08:00
Fenghe Tang
738d4258c3 Create LICENSE 2024-03-17 15:36:23 +08:00
Curli-quan
6332a3ea31 Update readme.md 2024-03-17 15:07:01 +08:00
51 changed files with 1587 additions and 59 deletions

3 .gitignore vendored Normal file

@@ -0,0 +1,3 @@
*.pth
*.pyc
*.nii.gz

201 LICENSE Normal file

@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


@@ -25,7 +25,7 @@ training:
load_pretrain_model: false
# optim:
lr: 0.0002
lr: 0.000005
decay_step: 2000
decay_gamma: 0.8
weight_decay: 0.0001
@@ -35,11 +35,13 @@ training:
dataset:
types: ['3d'] # ['3d', '2d']
split: 'train'
data_root_path: '/quanquan/datasets/'
data_root_path: '/home1/quanquan/datasets/'
dataset_list: ["alp", "word", "debug"] # ['sam', "their", "ours"]
data_txt_path: './datasets/dataset_list/'
dataset2d_path: "/home1/quanquan/datasets/08_AbdomenCT-1K/"
cache_data_path: '/home1/quanquan/datasets/cached_dataset2/'
cache_prefix: ['6016'] # '07'
specific_label: [2]
# sam_checkpoint: "/quanquan/code/projects/medical-guangdong/segment-anything/sam_vit_b_01ec64.pth" # 103 server
# model_type: "vit_b"

47 configs/vit_sub.yaml Normal file

@@ -0,0 +1,47 @@
# ---------------------- Common Configs --------------------------
base:
base_dir: "../runs/sam/"
tag: ''
stage: ''
logger:
mode: ['tb', ]
# mode: ''
recorder_reduction: 'mean'
training:
save_mode: ['all', 'best', 'latest'] # ,
batch_size : 2 # 8 for A100
num_workers : 8
num_epochs : 100 # epochs
use_amp: false
save_interval : 1
val_check_interval: 6
load_pretrain_model: false
# optim:
lr: 0.00002
decay_step: 2000
decay_gamma: 0.8
weight_decay: 0.0001
alpha: 0.99
validation_interval: 100
continue_training: false
load_optimizer: false
breakpoint_path: "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b9/lora3/ckpt/model_iter_360000.pth"
dataset:
types: ['3d'] # ['3d', '2d']
split: 'train'
data_root_path: '/home1/quanquan/datasets/'
dataset_list: ["example"] # for example_train.txt
data_txt_path: './datasets/dataset_list/'
dataset2d_path: "/home1/quanquan/datasets/08_AbdomenCT-1K/"
cache_data_path: '/home1/quanquan/datasets/cached_dataset2/'
cache_prefix: ['6016'] # '07'
specific_label: [2]
test:
batch_size: 1

Binary file not shown.

Binary file not shown.

122 core/ddp_sub.py Normal file

@@ -0,0 +1,122 @@
"""
Derived from ddp_b9.py.
Adds an additional bypass (side branch) for finetuning on other datasets.
"""
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from tutils import tfilename, tdir
from datasets.dataset3d_2dmask import Dataset2D
# from datasets.dataset3d import Dataset3D
from datasets.cache_dataset3d3 import Dataset3D
from datasets.dataset_merged import DatasetMerged, TestsetMerged
from datasets.data_engine import DataEngine
from modeling.build_sam3d2 import sam_model_registry
from .learner_sub1 import SamLearner
# from tutils.new.trainer.trainer_ddp import DDPTrainer
from trans_utils.trainer_ddp import DDPTrainer
# from .lora_sam import LoRA_Sam
import warnings
warnings.filterwarnings("ignore")
def setup(rank, world_size):
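# Single-node DDP setup: every rank rendezvous at localhost:12355 before training starts.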
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
# initialize the process group
dist.init_process_group("gloo", rank=rank, world_size=world_size)
def cleanup():
dist.destroy_process_group()
def ddp_train(rank, world_size, config):
setup(rank, world_size)
# sam_checkpoint = "/quanquan/code/segment-anything/segment_anything/sam_vit_b_01ec64.pth" # A800 server
# sam_checkpoint = "/home1/quanquan/code/projects/medical-guangdong/segment-anything/sam_vit_b_01ec64.pth" # 103 server
model_type = "vit_b"
device = rank
config_data = config['dataset']
data_type = config_data.get("types", ["3d", "2d"])
data_type = [data_type] if isinstance(data_type, str) else data_type
dataset = Dataset3D(config_data, split='train')
# assert len(validset) > 0
data_engine = DataEngine(dataset=dataset, img_size=(1024,1024))
sam = sam_model_registry[model_type](checkpoint=None)
learner = SamLearner(sam_model=sam, config=config, data_engine=data_engine)
learner.use_lora()
learner.load_well_trained_model(config['training']['breakpoint_path']) # use preset path
learner.use_lora_sub()
ddp_trainer = DDPTrainer(config=config, rank=rank, world_size=world_size)
ddp_trainer.fit(learner, trainset=data_engine, validset=None)
cleanup()
def get_parameter_number(model):
total_num = sum(p.numel() for p in model.parameters())
trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
return {'Total': total_num, 'Trainable': trainable_num}
def run_demo(demo_fn, world_size, config):
mp.spawn(demo_fn,
args=(world_size,config),
nprocs=world_size,
join=True)
from collections import OrderedDict
import yaml
import yamlloader
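# Convert nested OrderedDicts (as produced by yamlloader) into plain dicts so yaml.dump writes clean YAML.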
def _ordereddict_to_dict(d):
if not isinstance(d, dict):
return d
for k, v in d.items():
if isinstance(v, OrderedDict):
v = _ordereddict_to_dict(v)
d[k] = dict(v)
elif type(v) == list:
d[k] = _ordereddict_to_dict(v)
elif isinstance(v, dict):
d[k] = _ordereddict_to_dict(v)
return d
# CUDA_VISIBLE_DEVICES=4,5,6,7 python -m core.ddp_b3 --tag lora --config configs/vit_b_103.yaml
if __name__ == "__main__":
import argparse
from tutils.new.manager import trans_args, trans_init, ConfigManager
n_gpus = torch.cuda.device_count()
# assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but {__file__} Got{n_gpus}"
if n_gpus == 1:
print("Warning! Running on only 1 GPU! just for debug")
world_size = n_gpus
parser = argparse.ArgumentParser()
parser.add_argument("--config", default="./configs/vit_sub.yaml")
parser.add_argument("--func", default="train")
parser.add_argument("--reuse", action="store_true")
args = trans_args(parser=parser)
config = ConfigManager()
config.auto_init(file=__file__, args=args, ex_config=None)
# config.save()
path = tfilename(config['base']['runs_dir'], "config.yaml")
with open(path, "w") as f:
yaml.dump(_ordereddict_to_dict(config), f)
print("Save config file to ", path)
if n_gpus < 1: exit(0)
run_demo(ddp_train, world_size, config)


@@ -107,8 +107,8 @@ class SamLearner(LearnerModule):
self.model.load_state_dict(state_dict)
# self.lora_module.load_lora_parameters(pth.replace(".pth", "_lora.safetensors"))
def use_lora(self):
lora_r = 8
def use_lora(self, r=8):
lora_r = r
lora_sam = LoRA_Sam(self.model, lora_r, freeze_prompt_encoder=True)
self.lora_module = lora_sam

73 core/learner_sub1.py Normal file

@@ -0,0 +1,73 @@
"""
Use mask_decoder3d_2.py
"""
import torch
import torchvision
import numpy as np
from tutils.trainer import Trainer, LearnerModule
from einops import rearrange, repeat, reduce
import torch.optim.lr_scheduler as lr_scheduler
from core.loss import ranked_combined_loss_with_indicators
from .learner5 import SamLearner as basic_learner
from .loss import compute_all_loss, ranked_combined_loss, compute_iou, combined_loss
# from torchao.quantization import apply_dynamic_quant
# from torch._inductor import config as inductorconfig
from .lora_sam import LoRA_Sam
class SamLearner(basic_learner):
def load_pretrained_model(self, pth, *args, **kwargs):
"""
Unmatched: prompt_encoder.mask_downscaling.0.weight
their: torch.Size([4, 1, 2, 2])
our: torch.Size([4, 3, 2, 2])
Unmatched: mask_decoder.mask_tokens.weight
their: torch.Size([4, 256])
our: torch.Size([12, 256])
"""
print("Load pretrained model for mask_decoder3d_2 !!")
state_dict = torch.load(pth)
model_state_dict = self.model.state_dict()
model_state_dict.update(state_dict)
model_state_dict['prompt_encoder.mask_downscaling.0.weight'] = repeat(state_dict['prompt_encoder.mask_downscaling.0.weight'], "a 1 c d -> a b c d", b=3)
# model_state_dict['mask_decoder.mask_tokens.weight'] = repeat(state_dict['mask_decoder.mask_tokens.weight'], "a d -> (a 3) d")
for k, v in model_state_dict.items():
if k.startswith("mask_decoder.output_upscaling2"):
k2 = k.replace("output_upscaling2.", "output_upscaling." )
model_state_dict[k] = model_state_dict[k2]
print("Load weights: ", k)
if k.startswith("mask_decoder.output_upscaling3"):
k2 = k.replace("output_upscaling3.", "output_upscaling." )
model_state_dict[k] = model_state_dict[k2]
print("Load weights: ", k)
hyper_params_names = [k for k in model_state_dict.keys() if k.startswith("mask_decoder.output_hypernetworks_mlps")]
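# The new decoder has 3x the hypernetwork MLPs (12 vs 4 in the pretrained weights, per the docstring above); initialize each from the pretrained MLP at index // 3.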
for name in hyper_params_names:
words = name.split('.')
words[2] = str(int(words[2]) // 3)
name_to_copy = ".".join(words)
model_state_dict[name] = state_dict[name_to_copy]
# for k, v in state_dict.items():
# if model_state_dict[k].shape != state_dict[k].shape:
# print("Unmatched:", k)
self.model.load_state_dict(model_state_dict)
def quantize(self):
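# NOTE: depends on the torchao / torch._inductor imports commented out at the top of this file; re-enable them before calling this.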
# self.model.image_encoder = torch.ao.quantization.quantize_dynamic(
# self.model.image_encoder,
# dtype=torch.qint8
# )
apply_dynamic_quant(self.model.image_encoder)
inductorconfig.force_fuse_int_mm_with_mul = True
print("Quantized !")
def use_lora_sub(self):
lora_r = 1
lora_sam = LoRA_Sam(self.model, lora_r, freeze_all=True)
self.lora_module = lora_sam


@@ -33,6 +33,7 @@ class _LoRA_qkv(nn.Module):
self.linear_a_v = linear_a_v
self.linear_b_v = linear_b_v
self.dim = qkv.in_features
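# also expose in_features, like the nn.Linear this wrapper replaces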
self.in_features = qkv.in_features
self.w_identity = torch.eye(qkv.in_features)
def forward(self, x):


@@ -191,7 +191,7 @@ class VolumePredictor:
raise NotImplementedError
# Preprocess prompts
self.original_size = x.shape[1:]
# self.original_size = x.shape[1:]
if point_coords is not None:
assert (
point_labels is not None
@@ -217,6 +217,7 @@ class VolumePredictor:
center_masks = self._predict_center_slice(center_idx, point_coords, box)
return center_masks['masks']
@torch.no_grad()
def predict_volume(
self,
x,
@@ -255,8 +256,37 @@
else:
raise NotImplementedError
# set 3d image
self.set_image(x)
# Preprocess prompts
self.original_size = x.shape[1:]
if self.masks3d is None:
self.masks3d = np.zeros_like(x)
self.slice_count = x.shape[0]
return self.predict_with_prompt(
point_coords = point_coords,
point_labels = point_labels,
box = box,
mask_input = mask_input,
multimask_output = multimask_output,
return_logits = return_logits,
template_slice_id = template_slice_id,
return_stability = return_stability
)
@torch.no_grad()
def predict_with_prompt(
self,
point_coords: Optional[np.ndarray] = None,
point_labels: Optional[np.ndarray] = None,
box: Optional[np.ndarray] = None,
mask_input: Optional[np.ndarray] = None,
multimask_output: bool = True,
return_logits: bool = False,
template_slice_id:int = None,
return_stability: bool = False,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
if point_coords is not None:
assert (
point_labels is not None
@@ -273,38 +303,39 @@ class VolumePredictor:
mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device)
mask_input_torch = mask_input_torch[None, :, :, :]
# set 3d image
self.set_image(x)
self.all_prompts = {}
# predict center slice
center_idx = template_slice_id if template_slice_id is not None else x.shape[0] // 2
center_idx = template_slice_id if template_slice_id is not None else self.slice_count // 2
# print("Processing ", center_idx)
center_masks = self._predict_center_slice(center_idx, point_coords, box)
if center_masks._stats == {}:
print("Ends for no mask.")
raise ValueError
self.merge_to_mask3d(center_idx, center_masks)
center_idx = center_idx.item() if not isinstance(center_idx, int) else center_idx
self.all_prompts[center_idx] = box if box is not None else point_coords
previous_masks = center_masks
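# Propagate outward from the template slice: each subsequent slice is re-prompted with boxes derived from the previous slice's masks.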
for i in range(center_idx+1, x.shape[0]-1):
for i in range(center_idx+1, self.slice_count-1):
# print("Processing downward", i)
previous_masks = self._predict_slice(i, previous_masks, orientation="down")
previous_masks, scaled_boxes = self._predict_slice(i, previous_masks, orientation="down")
if previous_masks._stats == {}:
print("Ends for no mask.")
break
self.merge_to_mask3d(i, previous_masks)
self.all_prompts[i] = scaled_boxes
previous_masks = center_masks
for i in np.arange(1, center_idx)[::-1]:
# print("Processing upward", i)
previous_masks = self._predict_slice(i, previous_masks, orientation="up")
previous_masks, scaled_boxes = self._predict_slice(i, previous_masks, orientation="up")
if previous_masks._stats == {}:
print("Ends for no mask.")
break
self.merge_to_mask3d(i, previous_masks)
self.all_prompts[i] = scaled_boxes
if self.masks3d is None:
self.masks3d = np.zeros_like(x)
if return_stability:
return self.postprocess_3d(self.masks3d), self.stability_score_2d
return self.postprocess_3d(self.masks3d)
@@ -324,7 +355,7 @@ class VolumePredictor:
scaled_boxes, tags = self.generate_prompts_from_previous_masks(previous_masks, orientation)
masks = self.genetate_masks_from_boxes(idx, all_boxes=scaled_boxes, tags=tags)
masks.to_numpy()
return masks
return masks, scaled_boxes
def generate_prompts_from_previous_masks(self, previous_masks: MaskData, orientation):
if orientation == "down":
@@ -486,10 +517,6 @@ class VolumePredictor:
data["boxes"] = batched_mask_to_box(data["masks"][:,1,:,:]>0)
return data
# @staticmethod
# def calculate_
@staticmethod
def batched_remove_noise(masks):
ori_shape = masks.shape
@@ -539,26 +566,80 @@ class VolumePredictor:
)
# Upscale the masks to the original image resolution
masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size[1:])
masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)
if not return_logits:
masks = masks > self.model.mask_threshold
return masks, iou_predictions, low_res_masks
def valid_box(self, data, batch_idx):
# valid image with box, or point prompt
assert data['img'].shape[0] == 1, f"shape {data['img'].shape}"
image = data['img']
label = data['label']
# def valid_box(self, data, batch_idx):
# # valid image with box, or point prompt
# assert data['img'].shape[0] == 1, f"shape {data['img'].shape}"
# image = data['img']
# label = data['label']
box = BoxPromptGenerator().mask_to_bbox(label)
box_mask3d = self.predict_volume(
x=image,
box=box,
)
dice = compute_dice_np(box_mask3d, label.detach().cpu().numpy())
# box = BoxPromptGenerator().mask_to_bbox(label)
# box_mask3d = self.predict_volume(
# x=image,
# box=box,
# )
# dice = compute_dice_np(box_mask3d, label.detach().cpu().numpy())
def get_confidence(self):
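# Per-slice confidence: IoU between the bbox of the predicted mask on each slice and the prompt box propagated to that slice; 0 where no prompt was recorded.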
masks = self.postprocess_3d(self.masks3d)
conf_collect = []
for i in range(1,self.masks3d.shape[0]-1):
prompt_box = self.all_prompts.get(i, None)
if prompt_box is not None:
mask = masks[i,:,:]
if mask.sum() > 0:
bbox = BoxPromptGenerator(size=None).mask_to_bbox(mask)
bbox = self.transform.apply_boxes(np.array([bbox]), self.original_size)[0]
else:
bbox = [0,0,0,0]
prompt_box = self.all_prompts[i][0]
confidence = calculate_iou(bbox, prompt_box)
else:
confidence = 0
if i == 1:
conf_collect.append(confidence)
conf_collect.append(confidence)
assert len(conf_collect) == i+1
conf_collect.append(confidence)
print(conf_collect)
return conf_collect
def calculate_iou(box1, box2):
"""
Compute the IoU (Intersection over Union) of two boxes.
Args:
box1 and box2 are boxes, each given as four values (x1, y1, x2, y2), where (x1, y1) is the
top-left corner and (x2, y2) is the bottom-right corner.
Returns:
The IoU of the two boxes, e.g. calculate_iou((0, 0, 2, 2), (1, 1, 3, 3)) == 1/7.
"""
# Top-left and bottom-right corners of the intersection
x1_i = max(box1[0], box2[0])
y1_i = max(box1[1], box2[1])
x2_i = min(box1[2], box2[2])
y2_i = min(box1[3], box2[3])
# Intersection area
intersection_area = max(0, x2_i - x1_i) * max(0, y2_i - y1_i)
# Union area
box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
union_area = box1_area + box2_area - intersection_area
# print(intersection_area, union_area)
# Compute the IoU
iou = intersection_area / union_area
return iou
if __name__ == "__main__":
@ -588,7 +669,7 @@ if __name__ == "__main__":
volume = itk_to_np(read(img_path)) # test several slices
label_itk = read(label_path)
spacing = label_itk.GetSpacing()
label = itk_to_np(label_itk) == 1
label = itk_to_np(label_itk) == 13
volume = np.clip(volume, -200, 400)
# Select the slice with the largest mask
@@ -598,8 +679,14 @@
x_max = np.max(coords[0])
template_slice_id = s.argmax()
box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id])
box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id]) # (115, 207, 309, 339)
# import ipdb; ipdb.set_trace()
# box = (125, 210, 300, 310)
box = np.array([box])
box[0][0] += 10
box[0][1] += 10
box[0][2] -= 10
box[0][3] -= 10
pred = predictor.predict_volume(
x=volume,
@@ -610,3 +697,8 @@
Data3dSolver().simple_write(pred, path="mask.nii.gz", spacing=spacing)
Data3dSolver().simple_write(label, path="gt.nii.gz", spacing=spacing)
dice = compute_dice_np(pred, label)
print("Dice ", dice, " box: ", box, "slice id", template_slice_id)
print(tuple(box))
# import ipdb; ipdb.set_trace()


@@ -43,12 +43,14 @@ TEMPLATE={
'10_10': [31],
'58': [6,2,3,1],
'59': [1,2,3,4,5,6,7,8,9,10,11,12,13,14],
'60': np.arange(200).tolist(), # for debug
'60': (np.ones(200)*(-1)).tolist(), # for debug
"65": np.zeros(200).tolist(),
}
class Dataset3D(basic_3d_dataset):
def __init__(self, config=..., use_cache=True, *args, **kwargs) -> None:
super().__init__(config, use_cache=use_cache, *args, **kwargs)
def __init__(self, config, use_cache=True, *args, **kwargs) -> None:
super().__init__(config=config, use_cache=use_cache, *args, **kwargs)
self.basic_dir = config['data_root_path']
self.cache_dir = config['cache_data_path']
@@ -93,6 +95,13 @@ class Dataset3D(basic_3d_dataset):
dataset_name = self.img_names[index]['img_path'].replace(self.basic_dir,"").split("/")[0]
assert dataset_name[0] in ['0','1','2','3','4','5','6','7','8','9'], f"Got {dataset_name}"
if dataset_name[:2] == "10":
subname = self.img_names[index]['img_path'].replace(self.basic_dir,"")[17:19]
assert subname in ['10', '03', '06', '07']
all_labels = TEMPLATE[dataset_name[:2] + "_" + subname]
else:
all_labels = TEMPLATE[dataset_name[:2]]
all_labels = TEMPLATE[dataset_name[:2]]
num = 0
@@ -152,7 +161,7 @@ class Dataset3D(basic_3d_dataset):
self.save_img_rgb(rearrange(img_rgb, "c h w -> h w c"), save_image_name)
# Save cache data
save_label_name = tfilename(self.cache_dir, dataset_name, f"label/label_{index:04d}_{num:08d}.npz")
save_label_name = tfilename(self.cache_dir, dataset_name, f"label_jpg/label_{index:04d}_{num:08d}")
self.save_slice_mask(masks_data, save_label_name)
print("Save ", save_image_name)
@@ -257,11 +266,26 @@ class Dataset3D(basic_3d_dataset):
label_path = name.replace("image/image_", "label/label_").replace(".jpg", ".npz")
self._convert_one_mask_from_npz_to_jpg(label_path)
# EX_CONFIG={
# "dataset":{
# "split": 'train',
# "data_root_path": '/home1/quanquan/datasets/',
# "dataset_list": ["decathlon_colon"],
# "data_txt_path": './datasets/dataset_list/',
# "cache_data_path": '/home1/quanquan/datasets/cached_dataset2/',
# "cache_prefix": ['10'] # '07'
# }
# }
if __name__ == "__main__":
# def go_cache():
from tutils.new.manager import ConfigManager
config = ConfigManager()
config.add_config("configs/vit_b.yaml")
dataset = Dataset3D(config=config['dataset'], use_cache=True)
config.add_config("configs/vit_sub.yaml")
# config.add_config(EX_CONFIG)
# Caching data
dataset = Dataset3D(config=config['dataset'], use_cache=False)
dataset.caching_data()
# dataset.convert_masks_types()


@@ -33,7 +33,7 @@ class Dataset3D(basic_3d_dataset):
if not os.path.isdir(dirpath):
continue
prefix = dirpath.split("/")[-1]
if prefix[:2] in config['cache_prefix']:
if prefix.split("_")[0] in config['cache_prefix']:
data_paths += glob.glob(dirpath + "/label_jpg/*.jpg")
print("Load ", dirpath)
print('Masks len {}'.format(len(data_paths)))


@@ -0,0 +1,2 @@
07_WORD/WORD-V0.1.0/imagesVa/word_0001.nii.gz 07_WORD/WORD-V0.1.0/labelsVa/word_0001.nii.gz
07_WORD/WORD-V0.1.0/imagesVa/word_0007.nii.gz 07_WORD/WORD-V0.1.0/labelsVa/word_0007.nii.gz


@@ -0,0 +1,146 @@
"""
DataLoader only for evaluation
"""
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from tutils.nn.data import read, itk_to_np, np_to_itk
from einops import reduce, repeat, rearrange
# from tutils.nn.data.tsitk.preprocess import resampleImage
from trans_utils.data_utils import Data3dSolver
from tutils.nn.data.tsitk.preprocess import resampleImage
import SimpleITK as sitk
# Example
DATASET_CONFIG={
'split': 'test',
'data_root_path':'/quanquan/datasets/',
'dataset_list': ["ours"],
'data_txt_path':'./datasets/dataset_list/',
'label_idx': 0,
}
class AbstractLoader(Dataset):
def __init__(self, config, split="test") -> None:
super().__init__()
self.config = config
self.split = split
self.img_names = self.prepare_datalist()
def __len__(self):
return len(self.img_names)
def prepare_datalist(self):
config = self.config
data_paths = []
for item in config['dataset_list']:
print("Load datalist from ", item)
for line in open(config["data_txt_path"]+ item + f"_{self.split}.txt"):
name = line.strip().split()[1].split('.')[0]
img_path = config['data_root_path'] + line.strip().split()[0]
label_path = config['data_root_path'] + line.strip().split()[1]
data_paths.append({'img_path': img_path, 'label_path': label_path, 'name': name})
print('train len {}'.format(len(data_paths)))
return data_paths
def _get_data(self, index, debug=False):
label_idx = self.config['label_idx']
name = self.img_names[index]['img_path']
img_itk = read(name)
spacing = img_itk.GetSpacing()
img_ori = itk_to_np(img_itk)
scan_orientation = np.argmin(img_ori.shape)
label_ori = itk_to_np(read(self.img_names[index]['label_path']))
label = label_ori == label_idx
# img_ori, new_spacing = Data3dSolver().read(self.img_names[index]['img_path'])
# label_itk = read(self.img_names[index]['label_path'])
# ori_spacing = label_itk.GetSpacing()
# label = itk_to_np(label_itk) == label_idx
# print("[loader_abstract.DEBUG] size", img_ori.shape, label.shape)
# label = self._get_resized_label(label, new_size=img_ori.shape)
if debug:
Data3dSolver().simple_write(label)
Data3dSolver().simple_write(img_ori, "tmp_img.nii.gz")
s = reduce(label, "c h w -> c", reduction="sum")
coords = np.nonzero(s)
x_min = np.min(coords[0])
x_max = np.max(coords[0])
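# Template slice = the slice with the largest mask area, indexed relative to the crop applied below.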
template_slice_id = s.argmax() - x_min
if img_ori.min() < -10:
img_ori = np.clip(img_ori, -200, 400)
else:
img_ori = np.clip(img_ori, 0, 600)
img_ori = img_ori[x_min:x_max+1,:,:]
label = label[x_min:x_max+1,:,:]
assert label.shape[0] >= 3
if template_slice_id <= 1 or template_slice_id >= label.shape[0]-2:
template_slice_id = label.shape[0] // 2  # fall back to the middle slice
dataset_name = name.replace(self.config['data_root_path'], "").split("/")[0]
template_slice = label[template_slice_id,:,:]
print("template_slice.area ", template_slice.sum(), template_slice.sum() / (template_slice.shape[0] * template_slice.shape[1]))
d = {
"name": name,
"dataset_name": dataset_name,
"img": np.array(img_ori).astype(np.float32),
"label_idx": label_idx,
"label": np.array(label).astype(np.float32),
"template_slice_id": template_slice_id,
"template_slice": np.array(label[template_slice_id,:,:]).astype(np.float32),
"spacing": np.array(spacing),
}
return d
def __getitem__(self, index):
return self._get_data(index)
if __name__ == "__main__":
from tutils.new.manager import ConfigManager
EX_CONFIG = {
'dataset':{
'prompt': 'box',
'dataset_list': ['guangdong'], # ["sabs"], chaos, word
'label_idx': 2,
}
}
config = ConfigManager()
config.add_config("configs/vit_sub_rectum.yaml")
config.add_config(EX_CONFIG)
dataset = AbstractLoader(config['dataset'], split="test")
for i in range(len(dataset)):
dataset._get_data(i, debug=False)
# label_path = "/home1/quanquan/datasets/01_BCV-Abdomen/Training/label/label0001.nii.gz"
# from monai.transforms import SpatialResample
# resample = SpatialResample()
# label = itk_to_np(read(label_path)) == 1
# print(label.shape)
# # resampled = resample(label, spatial_size=(label.shape[0]*7, label.shape[1], label.shape[2]))
# print(label.shape)
exit(0)
data = itk_to_np(read("tmp_img.nii.gz"))
data = torch.Tensor(data)
maxlen = data.shape[0]
slices = []
for i in range(1, maxlen-1):
slices.append(data[i-1:i+2, :, :])
input_slices = torch.stack(slices, axis=0)
input_slices = torch.clip(input_slices, -200, 600)
input_slices
from torchvision.utils import save_image
save_image(input_slices, "tmp.jpg")


@@ -118,9 +118,29 @@ Then run
python -m datasets.cache_dataset3d
```
## Config Settings
Important settings:
```yaml
base:
base_dir: "../runs/sam/" # logging dir
dataset:
types: ['3d'] # ['3d', '2d']
split: 'train'
data_root_path: '../datasets/'
dataset_list: ["pancreas"]
data_txt_path: './datasets/dataset_list/'
dataset2d_path: "../08_AbdomenCT-1K/"
cache_data_path: '../cached_dataset2/'
cache_prefix: ['6016'] # cache prefix of cached dataset for training
# For example: ['07',] for 07_WORD
```
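For reference, here is a minimal sketch of how these settings are consumed (it mirrors the `ConfigManager` usage in this repo's test scripts; the override dict is only an example):
```python
from tutils.new.manager import ConfigManager

config = ConfigManager()
config.add_config("configs/vit_b.yaml")           # the settings shown above
config.add_config({"dataset": {"label_idx": 2}})  # optional in-code override (example)
config.print()
```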
## Start Training
## Start Training from scratch (SAM)
Run training on multiple GPUs
@@ -140,6 +160,47 @@ CUDA_VISIBLE_DEVICES=0 python -m core.ddp --tag debug
python -m core.volume_predictor
```
## Testset Validation
```python
EX_CONFIG = {
'dataset':{
'prompt': 'box', # prompt type: box or point
'dataset_list': ['word'], # dataset_list name
'label_idx': 2, # label index for inference,
},
"pth": "./model.pth"
}
```
```
python -m test.volume_eval
```
## Finetuning (Recommended)
```yaml
training:
breakpoint_path: "./model.pth" # pretrained weight path
```
```
python -m core.ddp_sub --tag run
```
## Validation with Finetuned Weights
```
python -m test.volume_eval_sublora
```
```python
EX_CONFIG = {
'dataset':{
'prompt': 'box', # prompt type: box or point
'dataset_list': ['word'], # dataset_list name
'label_idx': 2, # label index for inference,
},
"pth": "./model_finetuned.pth"
}
```
<p align="center" width="100%">


@ -137,31 +137,35 @@ if __name__ == "__main__":
EX_CONFIG = {
'dataset':{
'prompt': 'box',
'dataset_list': ['word'], # ["sabs"], chaos, word
'label_idx': 1,
}
'prompt': 'box', # box / point
'dataset_list': ['example'], # ["sabs"], chaos, word
'label_idx': 5,
},
'model_type': "vit_b",
'pth': "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b9/lora3/ckpt/model_iter_360000.pth",
}
EX_CONFIG = {
'dataset':{
'prompt': 'box', # box / point
'dataset_list': ['example'], # ["sabs"], chaos, word
'label_idx': 5,
},
'model_type': "vit_h",
'pth': "/home1/quanquan/code/projects/finetune_large/segment_anything/model_iter_3935000.pth",
}
config = ConfigManager()
config.add_config("configs/vit_b_103.yaml")
config.add_config("configs/vit_b.yaml")
config.add_config(EX_CONFIG)
config.print()
# Init Model
model_type = "vit_b"
model_type = config['model_type']
sam = sam_model_registry[model_type](checkpoint=None)
learner = SamLearner(sam_model=sam, config=config, data_engine=DataManager(img_size=(1024,1024)))
learner.use_lora()
# pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b3/lora+edge2/ckpt_v/model_latest.pth"
# pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b3/lora+edge2/ckpt/model_epoch_20.pth"
# pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b3/lora+edge2/ckpt/model_epoch_16.pth"
# pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b3/lora_small/ckpt/model_epoch_6.pth"
# pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b9/lora/ckpt/model_epoch_50.pth"
# pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b11/spec_8/ckpt_v/model_latest.pth"
# pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b11/spec_5/ckpt/model_epoch_100.pth"
pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b9/lora3/ckpt/model_iter_360000.pth"
# pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b9/lora3/ckpt/model_iter_500000.pth"
learner.load_well_trained_model(pth)
learner.load_well_trained_model(EX_CONFIG['pth'])
learner.cuda()
predictor = VolumePredictor(
model=learner.model,

216 test/volume_eval_mbox.py Normal file

@@ -0,0 +1,216 @@
"""
Volume evaluation
"""
import torch
import numpy as np
from torch.utils.data import DataLoader
# from datasets.dataset3d import Dataset3D
from tutils.new.manager import ConfigManager
from datasets.eval_dataloader.loader_abstract import AbstractLoader
from core.volume_predictor import VolumePredictor
from datasets.data_engine import DataManager, BoxPromptGenerator, PointPromptGenerator
from tutils import tfilename
from tutils.new.trainer.recorder import Recorder
from trans_utils.metrics import compute_dice_np
from trans_utils.data_utils import Data3dSolver
from tutils.tutils.ttimer import timer
class Evaluater:
def __init__(self, config) -> None:
self.config = config
self.recorder = Recorder()
def solve(self, model, dataset, finetune_number=1):
# model.eval()
self.predictor = model
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
for i, data in enumerate(dataloader):
# if i <4:
# print
# continue
# for k, v in data.items():
# if isinstance(v, torch.Tensor):
# data[k] = v.to(self.rank)
# try:
if True:
if self.config['dataset']['prompt'] == 'box':
dice, pred, label, temp_slice = self.eval_step(data, batch_idx=i)
used_slice = [temp_slice]
if finetune_number > 1:
for i in range(finetune_number - 1):
dice, pred, label, temp_slice = self.finetune_with_more_prompt(pred, label, exclude_slide_id=used_slice)
used_slice.append(temp_slice)
res = {"dice": dice}
if self.config['dataset']['prompt'] == 'point':
res = self.eval_step_point(data, batch_idx=i)
self.recorder.record(res)
# except Exception as e:
# print(e)
res = self.recorder.cal_metrics()
print(res)
print("prompt:", self.config['dataset']['prompt'], " class_idx:", self.config['dataset']['label_idx'])
def eval_step(self, data, batch_idx=0):
name = data['name']
dataset_name = data['dataset_name'][0]
label_idx = data['label_idx'][0]
template_slice_id = data['template_slice_id'][0]
assert data['img'].shape[1] >= 3, f" Got img.shape {data['img'].shape}"
if template_slice_id == 0:
template_slice_id += 1
elif template_slice_id == (data['img'].shape[0] - 1):
template_slice_id -= 1
print("Using slice ", template_slice_id, " as template slice")
spacing = data['spacing'].numpy().tolist()[0]
if data['img'].shape[-1] < 260:
# assert data['img'].shape[-1] < 260, f"Got {data['img'].shape}"
img = data['img'][0][:,:256,:256]
label = data['label'][0][:,:256,:256]
else:
img = data['img'][0]
label = data['label'][0]
# img = torch.clip(img, -200, 600)
box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id].detach().cpu().numpy())
box = np.array([box])
pred, stability = self.predictor.predict_volume(
x=img,
box=box,
template_slice_id=template_slice_id,
return_stability=True,
)
prompt_type = 'box'
dice = compute_dice_np(pred, label.detach().cpu().numpy())
# Data3dSolver().simple_write(pred, path=tfilename(f"visual/{dataset_name}/pred_{batch_idx}_label_{label_idx}_{prompt_type}.nii.gz"), spacing=spacing)
# Data3dSolver().simple_write(label.detach().cpu().numpy(), path=tfilename(f"visual/{dataset_name}/label_{batch_idx}.nii.gz"))
# Data3dSolver().simple_write(img.detach().cpu().numpy(), path=tfilename(f"visual/{dataset_name}/img_{batch_idx}.nii.gz"))
# np.save(tfilename(f"meta/{dataset_name}/stability_{batch_idx}.npy"), stability)
print(dataset_name, name, dice)
template_slice_id = template_slice_id if isinstance(template_slice_id, int) else template_slice_id.item()
return dice, pred, label.detach().cpu().numpy(), template_slice_id
def eval_step_point(self, data, batch_idx=0):
name = data['name']
dataset_name = data['dataset_name'][0]
label_idx = data['label_idx'][0]
template_slice_id = data['template_slice_id'][0]
spacing = data['spacing'].numpy().tolist()[0]
assert data['img'].shape[1] >= 3, f" Got img.shape {data['img'].shape}"
if template_slice_id == 0:
template_slice_id += 1
elif template_slice_id == (data['img'].shape[0] - 1):
template_slice_id -= 1
if data['img'].shape[-1] < 260:
# assert data['img'].shape[-1] < 260, f"Got {data['img'].shape}"
img = data['img'][0][:,:256,:256]
label = data['label'][0][:,:256,:256]
else:
img = data['img'][0]
label = data['label'][0]
box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id].detach().cpu().numpy())
point = (box[0]+box[2])*0.5 , (box[1]+box[3])*0.5
point = np.array([point]).astype(int)
if label[template_slice_id][point[0,1], point[0,0]] == 0:
print("Use random point instead !!!")
point = PointPromptGenerator().get_prompt_point(label[template_slice_id])
point = np.array([point]).astype(int)
# box = np.array([box])
pred = self.predictor.predict_volume(
x=img,
point_coords=point,
point_labels=np.ones_like(point)[:,:1],
template_slice_id=template_slice_id,
)
dice = compute_dice_np(pred, label.detach().cpu().numpy())
prompt_type = 'point'
Data3dSolver().simple_write(pred, path=tfilename(f"visual/{dataset_name}/pred_{batch_idx}_label_{label_idx}_{prompt_type}.nii.gz"), spacing=spacing)
# Data3dSolver().simple_write(pred, path=tfilename(f"visual/{dataset_name}/pred_{batch_idx}.nii.gz"))
print(dataset_name, name, dice)
return {"dice": dice}
def finetune_with_more_prompt(self, pred, label, prompt_type="box", exclude_slide_id=[]):
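# Pick the worst-dice slice not yet used as a new template, re-predict with its GT box, then keep the higher-confidence prediction per slice.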
assert pred.shape == label.shape
dices = [compute_dice_np(pred[j,:,:], label[j,:,:]) for j in range(pred.shape[0])]
rank_list = np.array(dices[1:-1]).argsort() # Ignore the head and tail
rank_list += 1 # Ignore the head and tail
for i in rank_list:
if i in exclude_slide_id:
continue
template_slice_id = i
break
# template_slice_id += 1 # Ignore the head and tail
print("Using slice ", template_slice_id, " as template slice")
old_confidence = self.predictor.get_confidence()
box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id])
box = np.array([box])
new_pred, stability = self.predictor.predict_with_prompt(
box=box,
template_slice_id=template_slice_id,
return_stability=True,
)
new_confidence = self.predictor.get_confidence()
new_confidence[template_slice_id] *= 2
all_conf = np.stack([old_confidence, new_confidence], axis=1)
preds = [pred, new_pred]
merged = np.zeros_like(label)
for slice_idx in range(pred.shape[0]):
idx = np.argsort(all_conf[slice_idx,:])[-1]
merged[slice_idx,:,:] = preds[idx][slice_idx]
print("old dices", [compute_dice_np(pred, label) for pred in preds])
dice = compute_dice_np(merged, label)
print("merged dice, idx", dice)
return dice, merged, label, template_slice_id
def to_RGB(img):
pass
if __name__ == "__main__":
from core.learner3 import SamLearner
from modeling.build_sam3d2 import sam_model_registry
EX_CONFIG = {
'dataset':{
'prompt': 'box',
'prompt_number': 5,
'dataset_list': ['example'], # ["sabs"], chaos, word pancreas
'label_idx': 5,
},
"lora_r": 24,
'model_type': "vit_h",
'ckpt': "/home1/quanquan/code/projects/finetune_large/segment_anything/model_iter_3935000.pth",
}
config = ConfigManager()
config.add_config("configs/vit_b.yaml")
config.add_config(EX_CONFIG)
config.print()
dataset = AbstractLoader(config['dataset'], split="test")
print(len(dataset))
assert len(dataset) >= 1
# Init Model
model_type = config['model_type'] # "vit_b"
sam = sam_model_registry[model_type](checkpoint=None)
learner = SamLearner(sam_model=sam, config=config, data_engine=DataManager(img_size=(1024,1024)))
learner.use_lora(r=config['lora_r'])
pth = config['ckpt']
learner.load_well_trained_model(pth)
learner.cuda()
predictor = VolumePredictor(
model=learner.model,
use_postprocess=True,
use_noise_remove=True,)
solver = Evaluater(config)
dataset = AbstractLoader(config['dataset'], split="test")
solver.solve(predictor, dataset, finetune_number=config['dataset']['prompt_number'])

231 test/volume_eval_sublora.py Normal file

@@ -0,0 +1,231 @@
"""
Volume evaluation
"""
import torch
import numpy as np
from torch.utils.data import DataLoader
# from datasets.dataset3d import Dataset3D
from tutils.new.manager import ConfigManager
from datasets.eval_dataloader.loader_abstract import AbstractLoader
from core.volume_predictor import VolumePredictor
from datasets.data_engine import DataManager, BoxPromptGenerator, PointPromptGenerator
from tutils import tfilename
from tutils.new.trainer.recorder import Recorder
from trans_utils.metrics import compute_dice_np
from trans_utils.data_utils import Data3dSolver
# from monai.metrics import compute_surface_dice
import surface_distance as surfdist
from tutils.tutils.ttimer import timer
class Evaluater:
def __init__(self, config) -> None:
self.config = config
self.recorder = Recorder()
def solve(self, model, dataset):
# model.eval()
self.predictor = model
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
for i, data in enumerate(dataloader):
# if i <4:
# print
# continue
# for k, v in data.items():
# if isinstance(v, torch.Tensor):
# data[k] = v.to(self.rank)
if self.config['dataset']['prompt'] == 'box':
# res = self.eval_step_slice(data, batch_idx=i)
res = self.eval_step(data, batch_idx=i)
if self.config['dataset']['prompt'] == 'point':
res = self.eval_step_point(data, batch_idx=i)
self.recorder.record(res)
res = self.recorder.cal_metrics()
print(res)
print("prompt:", self.config['dataset']['prompt'], " class_idx:", self.config['dataset']['label_idx'])
def eval_step(self, data, batch_idx=0):
name = data['name']
dataset_name = data['dataset_name'][0]
label_idx = data['label_idx'][0]
template_slice_id = data['template_slice_id'][0]
assert data['img'].shape[1] >= 3, f" Got img.shape {data['img'].shape}"
if template_slice_id == 0:
template_slice_id += 1
elif template_slice_id == (data['img'].shape[0] - 1):
template_slice_id -= 1
spacing = data['spacing'].numpy().tolist()[0]
if data['img'].shape[-1] < 260:
# assert data['img'].shape[-1] < 260, f"Got {data['img'].shape}"
img = data['img'][0][:,:256,:256]
label = data['label'][0][:,:256,:256]
else:
img = data['img'][0]
label = data['label'][0]
# img = torch.clip(img, -200, 600)
box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id].detach().cpu().numpy())
box = np.array([box])
pred, stability = self.predictor.predict_volume(
x=img,
box=box,
template_slice_id=template_slice_id,
return_stability=True,
)
prompt_type = 'box'
dice = compute_dice_np(pred, label.detach().cpu().numpy())
# Data3dSolver().simple_write(pred, path=tfilename(f"visual/{dataset_name}/pred_{batch_idx}_label_{label_idx}_{prompt_type}.nii.gz"), spacing=spacing)
# Data3dSolver().simple_write(label.detach().cpu().numpy(), path=tfilename(f"visual/{dataset_name}/label_{batch_idx}.nii.gz"))
# Data3dSolver().simple_write(img.detach().cpu().numpy(), path=tfilename(f"visual/{dataset_name}/img_{batch_idx}.nii.gz"))
# np.save(tfilename(f"meta/{dataset_name}/stability_{batch_idx}.npy"), stability)
# nsd = compute_surface_dice(torch.Tensor(pred), label.detach().cpu(), 1)
# surface_distances = surfdist.compute_surface_distances(
# label.detach().cpu().numpy(), pred, spacing_mm=(0.6, 0.6445, 0.6445))
# nsd = surfdist.compute_surface_dice_at_tolerance(surface_distances, 1)
nsd = 0
print(dataset_name, name, dice, nsd)
# import ipdb; ipdb.set_trace()
return {"dice": dice, "nsd": nsd}
def eval_step_point(self, data, batch_idx=0):
name = data['name']
dataset_name = data['dataset_name'][0]
label_idx = data['label_idx'][0]
template_slice_id = data['template_slice_id'][0]
spacing = data['spacing'].numpy().tolist()[0]
assert data['img'].shape[1] >= 3, f" Got img.shape {data['img'].shape}"
if template_slice_id == 0:
template_slice_id += 1
elif template_slice_id == (data['img'].shape[0] - 1):
template_slice_id -= 1
if data['img'].shape[-1] < 260:
# assert data['img'].shape[-1] < 260, f"Got {data['img'].shape}"
img = data['img'][0][:,:256,:256]
label = data['label'][0][:,:256,:256]
else:
img = data['img'][0]
label = data['label'][0]
box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id].detach().cpu().numpy())
point = (box[0]+box[2])*0.5 , (box[1]+box[3])*0.5
point = np.array([point]).astype(int)
if label[template_slice_id][point[0,1], point[0,0]] == 0:
print("Use random point instead !!!")
point = PointPromptGenerator().get_prompt_point(label[template_slice_id])
point = np.array([point]).astype(int)
# box = np.array([box])
pred = self.predictor.predict_volume(
x=img,
point_coords=point,
point_labels=np.ones_like(point)[:,:1],
template_slice_id=template_slice_id,
)
dice = compute_dice_np(pred, label.detach().cpu().numpy())
prompt_type = 'point'
# Data3dSolver().simple_write(pred, path=tfilename(f"visual/{dataset_name}/pred_{batch_idx}_label_{label_idx}_{prompt_type}.nii.gz"), spacing=spacing)
# Data3dSolver().simple_write(pred, path=tfilename(f"visual/{dataset_name}/pred_{batch_idx}.nii.gz"))
# nsd = compute_surface_dice(pred, label.detach().cpu().numpy())  # compute_surface_dice import is commented out above
nsd = 0
print(dataset_name, name, dice)
return {"dice": dice, "nsd": nsd}
def eval_step_slice(self, data, batch_idx=0):
name = data['name']
dataset_name = data['dataset_name'][0]
label_idx = data['label_idx'][0]
template_slice_id = data['template_slice_id'][0]
assert data['img'].shape[1] >= 3, f" Got img.shape {data['img'].shape}"
if template_slice_id == 0:
template_slice_id += 1
elif template_slice_id == (data['img'].shape[0] - 1):
template_slice_id -= 1
spacing = data['spacing'].numpy().tolist()[0]
if data['img'].shape[-1] < 260:
# assert data['img'].shape[-1] < 260, f"Got {data['img'].shape}"
img = data['img'][0][:,:256,:256]
label = data['label'][0][:,:256,:256]
else:
img = data['img'][0]
label = data['label'][0]
img = img[template_slice_id-1:template_slice_id+2, :,:]
label = label[template_slice_id-1:template_slice_id+2, :,:]
template_slice_id = 1
# img = torch.clip(img, -200, 600)
box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id].detach().cpu().numpy())
box = np.array([box])
pred, stability = self.predictor.predict_volume(
x=img,
box=box,
template_slice_id=template_slice_id,
return_stability=True,
)
prompt_type = 'box'
dice = compute_dice_np(pred, label.detach().cpu().numpy())
# Data3dSolver().simple_write(pred, path=tfilename(f"visual/{dataset_name}/pred_{batch_idx}_label_{label_idx}_{prompt_type}.nii.gz"), spacing=spacing)
# Data3dSolver().simple_write(label.detach().cpu().numpy(), path=tfilename(f"visual/{dataset_name}/label_{batch_idx}.nii.gz"))
# Data3dSolver().simple_write(img.detach().cpu().numpy(), path=tfilename(f"visual/{dataset_name}/img_{batch_idx}.nii.gz"))
# np.save(tfilename(f"meta/{dataset_name}/stability_{batch_idx}.npy"), stability)
print("Slice evaluation: ", dataset_name, name, dice)
return {"dice": dice}
def to_RGB(img):
pass
if __name__ == "__main__":
# from core.learner3 import SamLearner
# from modeling.build_sam3d import sam_model_registry
# from core.learner3 import SamLearner
from core.learner_sub1 import SamLearner
from modeling.build_sam3d2 import sam_model_registry
EX_CONFIG = {
'dataset':{
'prompt': 'box',
'dataset_list': ['guangdong'], # ["sabs"], chaos, word, decathlon_colon, pancreas
'label_idx': 2,
},
"pth": "./model_latest.pth"
}
config = ConfigManager()
# config.add_config("configs/vit_sub.yaml")
config.add_config("configs/vit_sub.yaml")
config.add_config(EX_CONFIG)
# Init Model
model_type = "vit_b"
sam = sam_model_registry[model_type](checkpoint=None)
learner = SamLearner(sam_model=sam, config=config, data_engine=DataManager(img_size=(1024,1024)))
learner.use_lora()
learner.use_lora_sub()
pth = EX_CONFIG['pth']
learner.load_well_trained_model(pth)
learner.cuda()
predictor = VolumePredictor(
model=learner.model,
use_postprocess=True,
use_noise_remove=True,)
solver = Evaluater(config)
dataset = AbstractLoader(config['dataset'], split="test")
tt = timer()
solver.solve(predictor, dataset)
print("Time: ", tt())

303 tmp.py Normal file

@@ -0,0 +1,303 @@
"""
Loading directly is slow,
so we pre-process the data.
"""
import numpy as np
import os
from einops import rearrange, reduce, repeat
from tutils.nn.data import read, itk_to_np, np_to_itk, write
from tutils import tfilename
from .dataset3d import DATASET_CONFIG, Dataset3D as basic_3d_dataset
from monai import transforms
import torch
import cv2
from scipy.sparse import csr_matrix
import torch.nn.functional as F
from torchvision import transforms
from einops import rearrange
import glob
from monai import transforms as monai_transforms
# "class": ["spleen", "right kidney", "left kidney", "gallbladder", "esophagus", "liver", "stomach", "aorta", "postcava", "portal vein and splenic vein", "pancrease", "right adrenal gland", "left adrenal gland"],
# "class": ["liver", "right kidney", "left kidney", "spleen"],
TEMPLATE={
'01': [1,2,3,4,5,6,7,8,9,10,11,12,13,14],
'02': [1,0,3,4,5,6,7,0,0,0,11,0,0,14],
'03': [6],
'04': [6,27], # post process
'05': [2,26,32], # post process
'07': [6,1,3,2,7,4,5,11,14,18,19,12,20,21,23,24],
'08': [6, 2, 1, 11],
'09': [1,2,3,4,5,6,7,8,9,11,12,13,14,21,22],
'12': [6,21,16,2],
'13': [6,2,1,11,8,9,7,4,5,12,13,25],
'14': [11,11,28,28,28], # Felix data, post process
'10_03': [6, 27], # post process
'10_06': [30],
'10_07': [11, 28], # post process
'10_08': [15, 29], # post process
'10_09': [1],
'10_10': [31],
'58': [6,2,3,1],
'59': [1,2,3,4,5,6,7,8,9,10,11,12,13,14],
'60': (np.ones(200)*(-1)).tolist(), # for debug
"65": np.zeros(200).tolist(),
}
EX_CONFIG={
"dataset":{
"split": 'train',
"data_root_path": '/home1/quanquan/datasets/',
"dataset_list": ["decathlon_colon"],
"data_txt_path": './datasets/dataset_list/',
"cache_data_path": '/home1/quanquan/datasets/cached_dataset2/',
"cache_prefix": ['10'] # '07'
}
}
class Dataset3D(basic_3d_dataset):
def __init__(self, config, use_cache=True, *args, **kwargs) -> None:
super().__init__(config=config, use_cache=use_cache, *args, **kwargs)
self.basic_dir = config['data_root_path']
self.cache_dir = config['cache_data_path']
def prepare_transforms(self):
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Resize((1024,1024)),
])
self.test_transform = transforms.Compose([
monai_transforms.Resized(keys=['img', 'label'], spatial_size=(3,1024,1024)),
])
# @tfunctime
# def prepare_datalist(self):
def prepare_cached_datalist(self):
raise DeprecationWarning("[Warning] Please use cache_dataset3d new version instead!")
config = self.config
data_paths = []
for dirpath in glob.glob(config['cache_data_path'] + "/*"):
data_paths += glob.glob(dirpath + "/image/*.jpg")
print("Load ", dirpath)
print('train len {}'.format(len(data_paths)))
print('Examples: ', data_paths[:2])
return data_paths
def caching_data(self):
assert self.use_cache == False
for index in range(len(self)):
self.cache_one_sample(index)
def cache_one_sample(self, index, debug=False):
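# Cache one volume slice-by-slice: a 3-channel jpg (slices i-1, i, i+1), one jpg mask per kept label, and a meta .npy record.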
# LABEL_INDICES
name = self.img_names[index]['img_path']
img_itk = read(self.img_names[index]['img_path'])
img_ori = itk_to_np(img_itk)
img_ori = np.clip(img_ori, -200,400)
# spacing = img_itk.GetSpacing()
scan_orientation = np.argmin(img_ori.shape)
label_ori = itk_to_np(read(self.img_names[index]['label_path']))
dataset_name = self.img_names[index]['img_path'].replace(self.basic_dir,"").split("/")[0]
        assert dataset_name[0].isdigit(), f"Got {dataset_name}"
if dataset_name[:2] == "10":
subname = self.img_names[index]['img_path'].replace(self.basic_dir,"")[17:19]
assert subname in ['10', '03', '06', '07']
all_labels = TEMPLATE[dataset_name[:2] + "_" + subname]
else:
all_labels = TEMPLATE[dataset_name[:2]]
num = 0
# if min(img_ori.shape) * 1.2 < max(img_ori.shape):
# orientation_all = [scan_orientation]
# else:
# orientation_all = [0,1,2]
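        # Cache slices only along the thinnest axis (the scan direction); the
        # commented branch above would instead cache all three orientations.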
orientation_all = [scan_orientation]
for orientation in orientation_all:
for slice_idx in range(2, img_ori.shape[orientation]-2):
# slice_idx = np.random.randint(2, img_ori.shape[orientation]-2)
if orientation == 0:
s = img_ori[slice_idx-1:slice_idx+2, :,:]
lb = label_ori[slice_idx-1:slice_idx+2, :,:]
# spacing = (spacing[1], spacing[2])
if orientation == 1:
s = img_ori[:,slice_idx-1:slice_idx+2,:]
s = rearrange(s, "h c w -> c h w")
lb = label_ori[:,slice_idx-1:slice_idx+2,:]
lb = rearrange(lb, "h c w -> c h w")
# spacing = (spacing[0], spacing[2])
if orientation == 2:
s = img_ori[:,:,slice_idx-1:slice_idx+2]
s = rearrange(s, "h w c -> c h w")
lb = label_ori[:,:,slice_idx-1:slice_idx+2]
lb = rearrange(lb, "h w c -> c h w")
# spacing = (spacing[0], spacing[1])
assert s.shape[0] == 3
# if np.float32(lb[1,:,:]>0).sum() <= 200:
# # return self._get_data((index+1)%len(self))
# continue
# Choose one label
label_num = int(lb.max())
masks_data = []
meta = {"img_name": name, "slice": slice_idx, "orientation": orientation, "label_idx": [], "labels": [], "id": f"{num:08d}" }
for label_idx in range(1,label_num+1):
one_lb = np.float32(lb==label_idx)
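                # Skip tiny masks: centre-slice foreground below ~0.14% of the slice area.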
if one_lb[1,:,:].sum() <= (one_lb.shape[-1] * one_lb.shape[-2] * 0.0014):
continue
# if one_lb[0,:,:].sum()<=50 or one_lb[2,:,:].sum()<=50:
masks_data.append(one_lb)
meta['label_idx'].append(label_idx)
meta['labels'].append(all_labels[label_idx-1])
if len(masks_data) <= 0:
continue
img_rgb = s
img_rgb = F.interpolate(torch.Tensor(img_rgb).unsqueeze(0), size=(1024,1024)).squeeze().numpy()
img_rgb = self.to_RGB(img_rgb)
save_image_name = tfilename(self.cache_dir, dataset_name, f"image/image_{index:04d}_{num:08d}.jpg")
self.save_img_rgb(rearrange(img_rgb, "c h w -> h w c"), save_image_name)
# Save cache data
save_label_name = tfilename(self.cache_dir, dataset_name, f"label_jpg/label_{index:04d}_{num:08d}")
self.save_slice_mask(masks_data, save_label_name)
print("Save ", save_image_name)
self.save_meta(meta, tfilename(self.cache_dir, dataset_name, f"meta/meta_{index:04d}_{num:08d}.npy"))
num += 1
def save_meta(self, meta, path):
assert path.endswith(".npy")
np.save(path, meta)
def save_slice_mask(self, masks_data, prefix):
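        # Each mask is resized to (3,1024,1024) and written as a single JPEG whose
        # three channels hold neighbouring slices; files are named {prefix}_{i:04d}.jpg.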
        masks_data = F.interpolate(torch.from_numpy(np.stack(masks_data)).float(), size=(1024, 1024)).numpy()
        assert masks_data.shape[1:] == (3, 1024, 1024), f"{__file__}: got {masks_data.shape}"
for i in range(masks_data.shape[0]):
labeli = masks_data[i].astype(np.uint8) * 255
assert labeli.sum() > 0
path = tfilename(prefix+f"_{i:04d}.jpg")
cv2.imwrite(path, rearrange(labeli, "c h w -> h w c"))
print("save to ", path)
    def _old_save_slice_mask(self, masks_data, path):
        # Deprecated npz/CSR format; kept for reference only (unreachable after the raise).
        raise DeprecationWarning("Use save_slice_mask() instead.")
        assert path.endswith(".npz")
# masks_data = np.array([m['segmentation'] for m in masks]).astype(int)
masks_data = F.interpolate(torch.Tensor(masks_data), size=(1024,1024)).numpy()
# masks_data = np.int8(masks_data>0)
        assert masks_data.shape[1:] == (3, 1024, 1024), f"{__file__}: got {masks_data.shape}"
masks_data = rearrange(masks_data, "n c h w -> n (c h w)")
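        # Binary masks are mostly zeros, so CSR plus compressed .npz keeps the cache small.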
csr = csr_matrix(masks_data)
np.savez_compressed(path, data=csr.data, indices=csr.indices, indptr=csr.indptr, shape=csr.shape)
def save_img_rgb(self, img, path):
assert path.endswith(".jpg")
assert img.shape == (1024,1024,3)
cv2.imwrite(path, img.astype(np.uint8))
def _get_cached_data(self, index):
name = self.img_names[index]
# print(name)
img = cv2.imread(name)
compressed = np.load(name.replace("image/image_", "label/label_").replace(".jpg", ".npz"))
csr = csr_matrix((compressed['data'], compressed['indices'], compressed['indptr']), shape=compressed['shape'])
label_ori = csr.toarray()
label_ori = rearrange(label_ori, "n (c h w) -> n c h w", c=3, h=1024, w=1024)
meta = np.load(name.replace("image/image_", "meta/meta_").replace(".jpg", ".npy"), allow_pickle=True).tolist()
# print(meta)
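        # Keep only masks with >500 foreground pixels in the centre slice,
        # then sample one of them uniformly at random.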
pp = reduce(label_ori[:,1,:,:], "n h w -> n", reduction="sum") > 500
if pp.sum() == 0:
return self._get_cached_data((index+1)%len(self))
label_idx = np.random.choice(a=np.arange(len(pp)), p=pp/pp.sum())
# label_idx = np.random.randint(0, label_ori.shape[0])
label_ori = label_ori[label_idx]
is_edge = meta.get('is_edge', 0)
return rearrange(img, "h w c -> c h w"), label_ori, name, meta['labels'][label_idx], meta['label_idx'][label_idx]
# @tfunctime
def __getitem__(self, index, debug=False):
# print("Dataset warning", index, len(self))
index = index % len(self)
img_rgb, label_ori, name, label_idx, local_idx = self._get_cached_data(index)
if label_ori.sum() <= 0:
print("[Label Error] ", name)
return self.__getitem__(index+1)
# assert len(img_rgb.shape) == 3, f"{__file__} Got{img_rgb.shape}"
# img_rgb = self.transform((img_rgb[None,:,:,:]))
img_rgb = F.interpolate(torch.Tensor(img_rgb).unsqueeze(0), size=(1024,1024)).squeeze().numpy()
vector = np.ones(3)
ret_dict = {
"name": name,
"img": img_rgb,
"label": label_ori,
"indicators": vector,
"class": label_idx,
"local_idx": local_idx,
}
return ret_dict
def _convert_one_mask_from_npz_to_jpg(self, path1=None):
# path1 = "/home1/quanquan/datasets/cached_dataset2/01_BCV-Abdomen/label/label_0129_00000043.npz" # 32K
prefix = path1.replace(".npz", "").replace("/label/", "/label_jpg/")
compressed = np.load(path1)
csr = csr_matrix((compressed['data'], compressed['indices'], compressed['indptr']), shape=compressed['shape'])
label_ori = csr.toarray()
label_ori = rearrange(label_ori, "n (c h w) -> n c h w", c=3, h=1024, w=1024)
# print(label_ori.shape)
for i in range(label_ori.shape[0]):
labeli = label_ori[i]
path = tfilename(prefix+f"_{i:04d}.jpg")
cv2.imwrite(path, rearrange(labeli, "c h w -> h w c").astype(np.uint8))
print("save to ", path)
def convert_masks_types(self):
        assert self.use_cache
for index in range(len(self)):
name = self.img_names[index]
label_path = name.replace("image/image_", "label/label_").replace(".jpg", ".npz")
self._convert_one_mask_from_npz_to_jpg(label_path)
if __name__ == "__main__":
# def go_cache():
from tutils.new.manager import ConfigManager
config = ConfigManager()
config.add_config("configs/vit_sub.yaml")
config.add_config(EX_CONFIG)
dataset = Dataset3D(config=config['dataset'], use_cache=False)
dataset.caching_data()
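    # After caching completes, re-create the dataset with use_cache=True to train
    # directly from the cached JPEG slices (see the commented examples below).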
# config.add_config("configs/vit_b_word_103.yaml")
# dataset = Dataset3D(config=config['dataset'], use_cache=True)
# dataset.caching_data()
# dataset.convert_masks_types()
# from tutils.new.manager import ConfigManager
# config = ConfigManager()
# config.add_config("configs/vit_b_103.yaml")
# dataset = Dataset3D(config=config['dataset']) # , use_cache=True
# data = dataset.__getitem__(0)
# # import ipdb; ipdb.set_trace()
# from torch.utils.data import DataLoader
# loader = DataLoader(dataset, batch_size=8)
# for batch in loader:
# print(batch['img'].shape, batch['label'].shape)
# print(data['label'].max())
# # import ipdb; ipdb.set_trace()

Binary file not shown.