Compare commits

10 commits: 1a352792a8 ... 72a73cb4b9

SHA1:
- 72a73cb4b9
- 96a2bd15a6
- 841be2acbe
- ce31f13be0
- 88745c6105
- 635f319708
- 173762f756
- 840009725f
- 738d4258c3
- 6332a3ea31
.gitignore (vendored, new file, 3 lines)

@@ -0,0 +1,3 @@
*.pth
*.pyc
*.nii.gz
LICENSE (new file, 201 lines)

@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
@@ -25,7 +25,7 @@ training:
   load_pretrain_model: false

 # optim:
-  lr: 0.0002
+  lr: 0.000005
   decay_step: 2000
   decay_gamma: 0.8
   weight_decay: 0.0001
@@ -35,11 +35,13 @@ training:
 dataset:
   types: ['3d'] # ['3d', '2d']
   split: 'train'
-  data_root_path: '/quanquan/datasets/'
+  data_root_path: '/home1/quanquan/datasets/'
   dataset_list: ["alp", "word", "debug"] # ['sam', "their", "ours"]
   data_txt_path: './datasets/dataset_list/'
   dataset2d_path: "/home1/quanquan/datasets/08_AbdomenCT-1K/"
   cache_data_path: '/home1/quanquan/datasets/cached_dataset2/'
   cache_prefix: ['6016'] # '07'
   specific_label: [2]

 # sam_checkpoint: "/quanquan/code/projects/medical-guangdong/segment-anything/sam_vit_b_01ec64.pth" # 103 server
 # model_type: "vit_b"
configs/vit_sub.yaml (new file, 47 lines)

@@ -0,0 +1,47 @@
# ---------------------- Common Configs --------------------------
base:
  base_dir: "../runs/sam/"
  tag: ''
  stage: ''
logger:
  mode: ['tb', ]
  # mode: ''
  recorder_reduction: 'mean'

training:
  save_mode: ['all', 'best', 'latest'] # ,
  batch_size : 2 # 8 for A100
  num_workers : 8
  num_epochs : 100 # epochs
  use_amp: false
  save_interval : 1
  val_check_interval: 6
  load_pretrain_model: false

# optim:
  lr: 0.00002
  decay_step: 2000
  decay_gamma: 0.8
  weight_decay: 0.0001
  alpha: 0.99
  validation_interval: 100

  continue_training: false
  load_optimizer: false
  breakpoint_path: "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b9/lora3/ckpt/model_iter_360000.pth"

dataset:
  types: ['3d'] # ['3d', '2d']
  split: 'train'
  data_root_path: '/home1/quanquan/datasets/'
  dataset_list: ["example"] # for example_train.txt
  data_txt_path: './datasets/dataset_list/'
  dataset2d_path: "/home1/quanquan/datasets/08_AbdomenCT-1K/"
  cache_data_path: '/home1/quanquan/datasets/cached_dataset2/'

  cache_prefix: ['6016'] # '07'
  specific_label: [2]

test:
  batch_size: 1
8 binary files not shown.
core/ddp_sub.py (new file, 122 lines)

@@ -0,0 +1,122 @@
"""
from ddp_b9.py

Add additional bypass/side-way to finetune on other datasets
"""

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from tutils import tfilename, tdir

from datasets.dataset3d_2dmask import Dataset2D
# from datasets.dataset3d import Dataset3D
from datasets.cache_dataset3d3 import Dataset3D
from datasets.dataset_merged import DatasetMerged, TestsetMerged
from datasets.data_engine import DataEngine
from modeling.build_sam3d2 import sam_model_registry

from .learner_sub1 import SamLearner
# from tutils.new.trainer.trainer_ddp import DDPTrainer
from trans_utils.trainer_ddp import DDPTrainer
# from .lora_sam import LoRA_Sam

import warnings
warnings.filterwarnings("ignore")


def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

def ddp_train(rank, world_size, config):
    setup(rank, world_size)

    # sam_checkpoint = "/quanquan/code/segment-anything/segment_anything/sam_vit_b_01ec64.pth" # A800 server
    # sam_checkpoint = "/home1/quanquan/code/projects/medical-guangdong/segment-anything/sam_vit_b_01ec64.pth" # 103 server
    model_type = "vit_b"
    device = rank

    config_data = config['dataset']
    data_type = config_data.get("types", ["3d", "2d"])
    data_type = [data_type] if isinstance(data_type, str) else data_type
    dataset = Dataset3D(config_data, split='train')

    # assert len(validset) > 0
    data_engine = DataEngine(dataset=dataset, img_size=(1024,1024))
    sam = sam_model_registry[model_type](checkpoint=None)

    learner = SamLearner(sam_model=sam, config=config, data_engine=data_engine)
    learner.use_lora()
    learner.load_well_trained_model(config['training']['breakpoint_path'])  # use preset path
    learner.use_lora_sub()

    ddp_trainer = DDPTrainer(config=config, rank=rank, world_size=world_size)
    ddp_trainer.fit(learner, trainset=data_engine, validset=None)

    cleanup()


def get_parameter_number(model):
    total_num = sum(p.numel() for p in model.parameters())
    trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return {'Total': total_num, 'Trainable': trainable_num}

def run_demo(demo_fn, world_size, config):
    mp.spawn(demo_fn,
             args=(world_size, config),
             nprocs=world_size,
             join=True)

from collections import OrderedDict
import yaml
import yamlloader

def _ordereddict_to_dict(d):
    if not isinstance(d, dict):
        return d
    for k, v in d.items():
        if isinstance(v, OrderedDict):
            v = _ordereddict_to_dict(v)
            d[k] = dict(v)
        elif type(v) == list:
            d[k] = _ordereddict_to_dict(v)
        elif isinstance(v, dict):
            d[k] = _ordereddict_to_dict(v)
    return d

# CUDA_VISIBLE_DEVICES=4,5,6,7 python -m core.ddp_b3 --tag lora --config configs/vit_b_103.yaml

if __name__ == "__main__":
    import argparse
    from tutils.new.manager import trans_args, trans_init, ConfigManager

    n_gpus = torch.cuda.device_count()
    # assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but {__file__} Got{n_gpus}"
    if n_gpus == 1:
        print("Warning! Running on only 1 GPU! just for debug")
    world_size = n_gpus

    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="./configs/vit_sub.yaml")
    parser.add_argument("--func", default="train")
    parser.add_argument("--reuse", action="store_true")

    args = trans_args(parser=parser)
    config = ConfigManager()
    config.auto_init(file=__file__, args=args, ex_config=None)
    # config.save()
    path = tfilename(config['base']['runs_dir'], "config.yaml")
    with open(path, "w") as f:
        yaml.dump(_ordereddict_to_dict(config), f)
    print("Save config file to ", path)

    if n_gpus < 1: exit(0)
    run_demo(ddp_train, world_size, config)
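One note on the launch path above: `mp.spawn` prepends the process index to the worker's arguments, so `ddp_train(rank, world_size, config)` receives `rank` even though `run_demo` only passes `args=(world_size, config)`. A minimal standalone sketch of that pattern with a hypothetical worker:

```python
import torch.multiprocessing as mp

def worker(rank, world_size, message):
    # mp.spawn supplies rank as the first positional argument
    print(f"[{rank}/{world_size}] {message}")

if __name__ == "__main__":
    world_size = 2
    # args carries everything except rank, exactly as run_demo() does above
    mp.spawn(worker, args=(world_size, "hello"), nprocs=world_size, join=True)
```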
@@ -107,8 +107,8 @@ class SamLearner(LearnerModule):
         self.model.load_state_dict(state_dict)
         # self.lora_module.load_lora_parameters(pth.replace(".pth", "_lora.safetensors"))

-    def use_lora(self):
-        lora_r = 8
+    def use_lora(self, r=8):
+        lora_r = r
         lora_sam = LoRA_Sam(self.model, lora_r, freeze_prompt_encoder=True)
         self.lora_module = lora_sam
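The widened `use_lora` signature is exercised later in this diff; test/volume_eval_mbox.py (added below) threads the LoRA rank through its config:

```python
# quoted from test/volume_eval_mbox.py later in this diff; its EX_CONFIG sets "lora_r": 24
learner = SamLearner(sam_model=sam, config=config, data_engine=DataManager(img_size=(1024,1024)))
learner.use_lora(r=config['lora_r'])
```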
core/learner_sub1.py (new file, 73 lines)

@@ -0,0 +1,73 @@
"""
Use mask_decoder3d_2.py
"""

import torch
import torchvision
import numpy as np
from tutils.trainer import Trainer, LearnerModule
from einops import rearrange, repeat, reduce
import torch.optim.lr_scheduler as lr_scheduler

from core.loss import ranked_combined_loss_with_indicators
from .learner5 import SamLearner as basic_learner
from .loss import compute_all_loss, ranked_combined_loss, compute_iou, combined_loss
# from torchao.quantization import apply_dynamic_quant
# from torch._inductor import config as inductorconfig
from .lora_sam import LoRA_Sam


class SamLearner(basic_learner):

    def load_pretrained_model(self, pth, *args, **kwargs):
        """
        Unmatched: prompt_encoder.mask_downscaling.0.weight
            their: torch.Size([4, 1, 2, 2])
            our:   torch.Size([4, 3, 2, 2])
        Unmatched: mask_decoder.mask_tokens.weight
            their: torch.Size([4, 256])
            our:   torch.Size([12, 256])
        """
        print("Load pretrained model for mask_decoder3d_2 !!")

        state_dict = torch.load(pth)
        model_state_dict = self.model.state_dict()
        model_state_dict.update(state_dict)
        model_state_dict['prompt_encoder.mask_downscaling.0.weight'] = repeat(state_dict['prompt_encoder.mask_downscaling.0.weight'], "a 1 c d -> a b c d", b=3)
        # model_state_dict['mask_decoder.mask_tokens.weight'] = repeat(state_dict['mask_decoder.mask_tokens.weight'], "a d -> (a 3) d")

        for k, v in model_state_dict.items():
            if k.startswith("mask_decoder.output_upscaling2"):
                k2 = k.replace("output_upscaling2.", "output_upscaling.")
                model_state_dict[k] = model_state_dict[k2]
                print("Load weights: ", k)
            if k.startswith("mask_decoder.output_upscaling3"):
                k2 = k.replace("output_upscaling3.", "output_upscaling.")
                model_state_dict[k] = model_state_dict[k2]
                print("Load weights: ", k)

        hyper_params_names = [k for k in model_state_dict.keys() if k.startswith("mask_decoder.output_hypernetworks_mlps")]
        for name in hyper_params_names:
            words = name.split('.')
            words[2] = str(int(words[2]) // 3)
            name_to_copy = ".".join(words)
            model_state_dict[name] = state_dict[name_to_copy]
        # for k, v in state_dict.items():
        #     if model_state_dict[k].shape != state_dict[k].shape:
        #         print("Unmatched:", k)
        self.model.load_state_dict(model_state_dict)

    def quantize(self):
        # NOTE: relies on the torchao/_inductor imports that are commented out above
        # self.model.image_encoder = torch.ao.quantization.quantize_dynamic(
        #     self.model.image_encoder,
        #     dtype=torch.qint8
        # )
        apply_dynamic_quant(self.model.image_encoder)
        inductorconfig.force_fuse_int_mm_with_mul = True
        print("Quantized !")


    def use_lora_sub(self):
        lora_r = 1
        lora_sam = LoRA_Sam(self.model, lora_r, freeze_all=True)
        self.lora_module = lora_sam
@@ -33,6 +33,7 @@ class _LoRA_qkv(nn.Module):
         self.linear_a_v = linear_a_v
         self.linear_b_v = linear_b_v
         self.dim = qkv.in_features
+        self.in_features = qkv.in_features
         self.w_identity = torch.eye(qkv.in_features)

     def forward(self, x):
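For orientation: storing `in_features` lets the wrapper stand in for the `nn.Linear` it replaces wherever callers inspect that attribute. The forward body is not shown in this hunk; the following is only a sketch of the common LoRA-qkv pattern these attribute names suggest, with `linear_a_q`/`linear_b_q` as hypothetical query-side counterparts of the `linear_a_v`/`linear_b_v` seen above:

```python
# hypothetical reconstruction, not code from this commit
def forward(self, x):
    qkv = self.qkv(x)                            # frozen base projection -> (..., 3 * dim)
    new_q = self.linear_b_q(self.linear_a_q(x))  # low-rank update for the query part
    new_v = self.linear_b_v(self.linear_a_v(x))  # low-rank update for the value part
    qkv[..., :self.dim] += new_q                 # queries occupy the first third
    qkv[..., -self.dim:] += new_v                # values occupy the last third
    return qkv
```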
@@ -191,7 +191,7 @@ class VolumePredictor:
             raise NotImplementedError

         # Preprocess prompts
-        self.original_size = x.shape[1:]
+        # self.original_size = x.shape[1:]
         if point_coords is not None:
             assert (
                 point_labels is not None
@@ -217,6 +217,7 @@ class VolumePredictor:
         center_masks = self._predict_center_slice(center_idx, point_coords, box)
         return center_masks['masks']

+    @torch.no_grad()
     def predict_volume(
         self,
         x,
@@ -255,8 +256,37 @@ class VolumePredictor:
         else:
             raise NotImplementedError

+        # set 3d image
+        self.set_image(x)
+        # Preprocess prompts
+        self.original_size = x.shape[1:]
+        if self.masks3d is None:
+            self.masks3d = np.zeros_like(x)
+        self.slice_count = x.shape[0]
+        return self.predict_with_prompt(
+            point_coords = point_coords,
+            point_labels = point_labels,
+            box = box,
+            mask_input = mask_input,
+            multimask_output = multimask_output,
+            return_logits = return_logits,
+            template_slice_id = template_slice_id,
+            return_stability = return_stability
+        )
+
+    @torch.no_grad()
+    def predict_with_prompt(
+        self,
+        point_coords: Optional[np.ndarray] = None,
+        point_labels: Optional[np.ndarray] = None,
+        box: Optional[np.ndarray] = None,
+        mask_input: Optional[np.ndarray] = None,
+        multimask_output: bool = True,
+        return_logits: bool = False,
+        template_slice_id: int = None,
+        return_stability: bool = False,
+    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+
         if point_coords is not None:
             assert (
                 point_labels is not None
@@ -273,38 +303,39 @@ class VolumePredictor:
             mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device)
             mask_input_torch = mask_input_torch[None, :, :, :]

-        # set 3d image
-        self.set_image(x)
+        self.all_prompts = {}

         # predict center slice
-        center_idx = template_slice_id if template_slice_id is not None else x.shape[0] // 2
+        center_idx = template_slice_id if template_slice_id is not None else self.slice_count // 2
         # print("Processing ", center_idx)
         center_masks = self._predict_center_slice(center_idx, point_coords, box)
         if center_masks._stats == {}:
             print("Ends for no mask.")
             raise ValueError
         self.merge_to_mask3d(center_idx, center_masks)
+        center_idx = center_idx.item() if not isinstance(center_idx, int) else center_idx
+        self.all_prompts[center_idx] = box if box is not None else point_coords

         previous_masks = center_masks
-        for i in range(center_idx+1, x.shape[0]-1):
+        for i in range(center_idx+1, self.slice_count-1):
             # print("Processing downward", i)
-            previous_masks = self._predict_slice(i, previous_masks, orientation="down")
+            previous_masks, scaled_boxes = self._predict_slice(i, previous_masks, orientation="down")
             if previous_masks._stats == {}:
                 print("Ends for no mask.")
                 break
             self.merge_to_mask3d(i, previous_masks)
+            self.all_prompts[i] = scaled_boxes

         previous_masks = center_masks
         for i in np.arange(1, center_idx)[::-1]:
             # print("Processing upward", i)
-            previous_masks = self._predict_slice(i, previous_masks, orientation="up")
+            previous_masks, scaled_boxes = self._predict_slice(i, previous_masks, orientation="up")
             if previous_masks._stats == {}:
                 print("Ends for no mask.")
                 break
             self.merge_to_mask3d(i, previous_masks)
+            self.all_prompts[i] = scaled_boxes

         if self.masks3d is None:
             self.masks3d = np.zeros_like(x)
         if return_stability:
             return self.postprocess_3d(self.masks3d), self.stability_score_2d
         return self.postprocess_3d(self.masks3d)
@@ -324,7 +355,7 @@ class VolumePredictor:
         scaled_boxes, tags = self.generate_prompts_from_previous_masks(previous_masks, orientation)
         masks = self.genetate_masks_from_boxes(idx, all_boxes=scaled_boxes, tags=tags)
         masks.to_numpy()
-        return masks
+        return masks, scaled_boxes

     def generate_prompts_from_previous_masks(self, previous_masks: MaskData, orientation):
         if orientation == "down":
@@ -486,10 +517,6 @@ class VolumePredictor:
         data["boxes"] = batched_mask_to_box(data["masks"][:,1,:,:]>0)
         return data

-    # @staticmethod
-    # def calculate_
-
-
     @staticmethod
     def batched_remove_noise(masks):
         ori_shape = masks.shape
@@ -539,26 +566,80 @@ class VolumePredictor:
         )

         # Upscale the masks to the original image resolution
-        masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size[1:])
+        masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)

         if not return_logits:
             masks = masks > self.model.mask_threshold

         return masks, iou_predictions, low_res_masks

-    def valid_box(self, data, batch_idx):
-        # valid image with box, or point prompt
-        assert data['img'].shape[0] == 1, f"shape {data['img'].shape}"
-        image = data['img']
-        label = data['label']
+    # def valid_box(self, data, batch_idx):
+    #     # valid image with box, or point prompt
+    #     assert data['img'].shape[0] == 1, f"shape {data['img'].shape}"
+    #     image = data['img']
+    #     label = data['label']

-        box = BoxPromptGenerator().mask_to_bbox(label)
-        box_mask3d = self.predict_volume(
-            x=image,
-            box=box,
-        )
-        dice = compute_dice_np(box_mask3d, label.detach().cpu().numpy())
+    #     box = BoxPromptGenerator().mask_to_bbox(label)
+    #     box_mask3d = self.predict_volume(
+    #         x=image,
+    #         box=box,
+    #     )
+    #     dice = compute_dice_np(box_mask3d, label.detach().cpu().numpy())

+    def get_confidence(self):
+        masks = self.postprocess_3d(self.masks3d)
+        conf_collect = []
+        for i in range(1, self.masks3d.shape[0]-1):
+            prompt_box = self.all_prompts.get(i, None)
+            if prompt_box is not None:
+                mask = masks[i,:,:]
+                if mask.sum() > 0:
+                    bbox = BoxPromptGenerator(size=None).mask_to_bbox(mask)
+                    bbox = self.transform.apply_boxes(np.array([bbox]), self.original_size)[0]
+                else:
+                    bbox = [0,0,0,0]
+                prompt_box = self.all_prompts[i][0]
+                confidence = calculate_iou(bbox, prompt_box)
+            else:
+                confidence = 0
+            if i == 1:
+                conf_collect.append(confidence)
+            conf_collect.append(confidence)
+            assert len(conf_collect) == i+1
+        conf_collect.append(confidence)
+        print(conf_collect)
+        return conf_collect
+
+def calculate_iou(box1, box2):
+    """
+    Compute the IoU (Intersection over Union) of two boxes.
+
+    Args:
+        box1, box2: boxes given as four values (x1, y1, x2, y2), where (x1, y1)
+        is the top-left corner and (x2, y2) is the bottom-right corner.
+
+    Returns:
+        The IoU of the two boxes.
+    """
+    # Top-left and bottom-right coordinates of the intersection
+    x1_i = max(box1[0], box2[0])
+    y1_i = max(box1[1], box2[1])
+    x2_i = min(box1[2], box2[2])
+    y2_i = min(box1[3], box2[3])
+
+    # Intersection area
+    intersection_area = max(0, x2_i - x1_i) * max(0, y2_i - y1_i)
+
+    # Union area
+    box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
+    box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
+    union_area = box1_area + box2_area - intersection_area
+    # print(intersection_area, union_area)
+
+    # IoU
+    iou = intersection_area / union_area
+
+    return iou
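A quick sanity check of `calculate_iou` with hypothetical boxes: two 10x10 boxes offset by (5, 5) intersect in a 5x5 region, so IoU = 25 / (100 + 100 - 25) = 25/175, about 0.143:

```python
box1 = (0, 0, 10, 10)
box2 = (5, 5, 15, 15)
# intersection (5, 5, 10, 10): 5 * 5 = 25; union: 100 + 100 - 25 = 175
assert abs(calculate_iou(box1, box2) - 25 / 175) < 1e-9
```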
@@ -588,7 +669,7 @@ if __name__ == "__main__":
     volume = itk_to_np(read(img_path)) # test several slices
     label_itk = read(label_path)
     spacing = label_itk.GetSpacing()
-    label = itk_to_np(label_itk) == 1
+    label = itk_to_np(label_itk) == 13
     volume = np.clip(volume, -200, 400)

     # Select the slice with the largest mask
@@ -598,8 +679,14 @@ if __name__ == "__main__":
     x_max = np.max(coords[0])
     template_slice_id = s.argmax()

-    box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id])
+    box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id]) # (115, 207, 309, 339)
+    # import ipdb; ipdb.set_trace()
+    # box = (125, 210, 300, 310)
     box = np.array([box])
+    box[0][0] += 10
+    box[0][1] += 10
+    box[0][2] -= 10
+    box[0][3] -= 10

     pred = predictor.predict_volume(
         x=volume,
@@ -610,3 +697,8 @@ if __name__ == "__main__":

     Data3dSolver().simple_write(pred, path="mask.nii.gz", spacing=spacing)
     Data3dSolver().simple_write(label, path="gt.nii.gz", spacing=spacing)
+
+    dice = compute_dice_np(pred, label)
+    print("Dice ", dice, " box: ", box, "slice id", template_slice_id)
+    print(tuple(box))
+    # import ipdb; ipdb.set_trace()
7 binary files not shown.
@@ -43,12 +43,14 @@ TEMPLATE={
     '10_10': [31],
     '58': [6,2,3,1],
     '59': [1,2,3,4,5,6,7,8,9,10,11,12,13,14],
-    '60': np.arange(200).tolist(), # for debug
+    '60': (np.ones(200)*(-1)).tolist(), # for debug
+    "65": np.zeros(200).tolist(),
 }


 class Dataset3D(basic_3d_dataset):
-    def __init__(self, config=..., use_cache=True, *args, **kwargs) -> None:
-        super().__init__(config, use_cache=use_cache, *args, **kwargs)
+    def __init__(self, config, use_cache=True, *args, **kwargs) -> None:
+        super().__init__(config=config, use_cache=use_cache, *args, **kwargs)
         self.basic_dir = config['data_root_path']
         self.cache_dir = config['cache_data_path']
@@ -93,6 +95,13 @@ class Dataset3D(basic_3d_dataset):

         dataset_name = self.img_names[index]['img_path'].replace(self.basic_dir,"").split("/")[0]
         assert dataset_name[0] in ['0','1','2','3','4','5','6','7','8','9'], f"Got {dataset_name}"
+        if dataset_name[:2] == "10":
+            subname = self.img_names[index]['img_path'].replace(self.basic_dir,"")[17:19]
+            assert subname in ['10', '03', '06', '07']
+            all_labels = TEMPLATE[dataset_name[:2] + "_" + subname]
+        else:
+            all_labels = TEMPLATE[dataset_name[:2]]
+
         all_labels = TEMPLATE[dataset_name[:2]]

         num = 0
@@ -152,7 +161,7 @@ class Dataset3D(basic_3d_dataset):
         self.save_img_rgb(rearrange(img_rgb, "c h w -> h w c"), save_image_name)

         # Save cache data
-        save_label_name = tfilename(self.cache_dir, dataset_name, f"label/label_{index:04d}_{num:08d}.npz")
+        save_label_name = tfilename(self.cache_dir, dataset_name, f"label_jpg/label_{index:04d}_{num:08d}")
         self.save_slice_mask(masks_data, save_label_name)
         print("Save ", save_image_name)
@@ -257,11 +266,26 @@ class Dataset3D(basic_3d_dataset):
         label_path = name.replace("image/image_", "label/label_").replace(".jpg", ".npz")
         self._convert_one_mask_from_npz_to_jpg(label_path)


+# EX_CONFIG={
+#     "dataset":{
+#         "split": 'train',
+#         "data_root_path": '/home1/quanquan/datasets/',
+#         "dataset_list": ["decathlon_colon"],
+#         "data_txt_path": './datasets/dataset_list/',
+#         "cache_data_path": '/home1/quanquan/datasets/cached_dataset2/',
+#         "cache_prefix": ['10'] # '07'
+#     }
+# }
+
 if __name__ == "__main__":
     # def go_cache():
     from tutils.new.manager import ConfigManager
     config = ConfigManager()
-    config.add_config("configs/vit_b.yaml")
-    dataset = Dataset3D(config=config['dataset'], use_cache=True)
+    config.add_config("configs/vit_sub.yaml")
+    # config.add_config(EX_CONFIG)
+
+    # Caching data
+    dataset = Dataset3D(config=config['dataset'], use_cache=False)
     dataset.caching_data()
     # dataset.convert_masks_types()
@@ -33,7 +33,7 @@ class Dataset3D(basic_3d_dataset):
             if not os.path.isdir(dirpath):
                 continue
             prefix = dirpath.split("/")[-1]
-            if prefix[:2] in config['cache_prefix']:
+            if prefix.split("_")[0] in config['cache_prefix']:
                 data_paths += glob.glob(dirpath + "/label_jpg/*.jpg")
                 print("Load ", dirpath)
         print('Masks len {}'.format(len(data_paths)))
datasets/dataset_list/example_test.txt (new file, 2 lines)

@@ -0,0 +1,2 @@
07_WORD/WORD-V0.1.0/imagesVa/word_0001.nii.gz 07_WORD/WORD-V0.1.0/labelsVa/word_0001.nii.gz
07_WORD/WORD-V0.1.0/imagesVa/word_0007.nii.gz 07_WORD/WORD-V0.1.0/labelsVa/word_0007.nii.gz
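Each line pairs an image path with its label path, separated by whitespace; `prepare_datalist` in datasets/eval_dataloader/loader_abstract.py (added below) consumes exactly this layout. A minimal sketch of that parsing, with a hypothetical `data_root_path`:

```python
data_root_path = "../datasets/"  # hypothetical root; the real one comes from the config

line = "07_WORD/WORD-V0.1.0/imagesVa/word_0001.nii.gz 07_WORD/WORD-V0.1.0/labelsVa/word_0001.nii.gz"
img_rel, label_rel = line.strip().split()   # whitespace-separated image/label pair
name = label_rel.split('.')[0]              # label-path stem, used as the sample name
img_path = data_root_path + img_rel         # plain string concatenation, as in prepare_datalist
label_path = data_root_path + label_rel
```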
datasets/eval_dataloader/loader_abstract.py (new file, 146 lines)

@@ -0,0 +1,146 @@
"""
DataLoader only for evaluation
"""

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from tutils.nn.data import read, itk_to_np, np_to_itk
from einops import reduce, repeat, rearrange
# from tutils.nn.data.tsitk.preprocess import resampleImage
from trans_utils.data_utils import Data3dSolver
from tutils.nn.data.tsitk.preprocess import resampleImage
import SimpleITK as sitk


# Example
DATASET_CONFIG={
    'split': 'test',
    'data_root_path':'/quanquan/datasets/',
    'dataset_list': ["ours"],
    'data_txt_path':'./datasets/dataset_list/',
    'label_idx': 0,
}

class AbstractLoader(Dataset):
    def __init__(self, config, split="test") -> None:
        super().__init__()
        self.config = config
        self.split = split
        self.img_names = self.prepare_datalist()

    def __len__(self):
        return len(self.img_names)

    def prepare_datalist(self):
        config = self.config
        data_paths = []
        for item in config['dataset_list']:
            print("Load datalist from ", item)
            for line in open(config["data_txt_path"] + item + f"_{self.split}.txt"):
                name = line.strip().split()[1].split('.')[0]
                img_path = config['data_root_path'] + line.strip().split()[0]
                label_path = config['data_root_path'] + line.strip().split()[1]
                data_paths.append({'img_path': img_path, 'label_path': label_path, 'name': name})
        print('train len {}'.format(len(data_paths)))
        return data_paths

    def _get_data(self, index, debug=False):
        label_idx = self.config['label_idx']
        name = self.img_names[index]['img_path']
        img_itk = read(name)
        spacing = img_itk.GetSpacing()
        img_ori = itk_to_np(img_itk)
        scan_orientation = np.argmin(img_ori.shape)
        label_ori = itk_to_np(read(self.img_names[index]['label_path']))
        label = label_ori == label_idx

        # img_ori, new_spacing = Data3dSolver().read(self.img_names[index]['img_path'])
        # label_itk = read(self.img_names[index]['label_path'])
        # ori_spacing = label_itk.GetSpacing()
        # label = itk_to_np(label_itk) == label_idx
        # print("[loader_abstract.DEBUG] size", img_ori.shape, label.shape)
        # label = self._get_resized_label(label, new_size=img_ori.shape)

        if debug:
            Data3dSolver().simple_write(label)
            Data3dSolver().simple_write(img_ori, "tmp_img.nii.gz")

        s = reduce(label, "c h w -> c", reduction="sum")
        coords = np.nonzero(s)
        x_min = np.min(coords[0])
        x_max = np.max(coords[0])
        template_slice_id = s.argmax() - x_min

        if img_ori.min() < -10:
            img_ori = np.clip(img_ori, -200, 400)
        else:
            img_ori = np.clip(img_ori, 0, 600)

        img_ori = img_ori[x_min:x_max+1,:,:]
        label = label[x_min:x_max+1,:,:]
        assert label.shape[0] >= 3

        if template_slice_id <= 1 or template_slice_id >= label.shape[0]-2:
            template_slice_id == label.shape[0] // 2  # NOTE: '==' is a no-op here; assignment ('=') looks intended

        dataset_name = name.replace(self.config['data_root_path'], "").split("/")[0]
        template_slice = label[template_slice_id,:,:]
        print("template_slice.area ", template_slice.sum(), template_slice.sum() / (template_slice.shape[0] * template_slice.shape[1]))
        d = {
            "name": name,
            "dataset_name": dataset_name,
            "img": np.array(img_ori).astype(np.float32),
            "label_idx": label_idx,
            "label": np.array(label).astype(np.float32),
            "template_slice_id": template_slice_id,
            "template_slice": np.array(label[template_slice_id,:,:]).astype(np.float32),
            "spacing": np.array(spacing),
        }
        return d

    def __getitem__(self, index):
        return self._get_data(index)


if __name__ == "__main__":
    from tutils.new.manager import ConfigManager
    EX_CONFIG = {
        'dataset':{
            'prompt': 'box',
            'dataset_list': ['guangdong'], # ["sabs"], chaos, word
            'label_idx': 2,
        }
    }

    config = ConfigManager()
    config.add_config("configs/vit_sub_rectum.yaml")
    config.add_config(EX_CONFIG)
    dataset = AbstractLoader(config['dataset'], split="test")
    for i in range(len(dataset)):
        dataset._get_data(i, debug=False)

    # label_path = "/home1/quanquan/datasets/01_BCV-Abdomen/Training/label/label0001.nii.gz"
    # from monai.transforms import SpatialResample
    # resample = SpatialResample()
    # label = itk_to_np(read(label_path)) == 1
    # print(label.shape)
    # # resampled = resample(label, spatial_size=(label.shape[0]*7, label.shape[1], label.shape[2]))
    # print(label.shape)

    exit(0)
    data = itk_to_np(read("tmp_img.nii.gz"))
    data = torch.Tensor(data)

    maxlen = data.shape[0]
    slices = []
    for i in range(1, maxlen-1):
        slices.append(data[i-1:i+2, :, :])
    input_slices = torch.stack(slices, axis=0)
    input_slices = torch.clip(input_slices, -200, 600)

    input_slices

    from torchvision.utils import save_image
    save_image(input_slices, "tmp.jpg")
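The scratch code at the tail of this file turns a volume into pseudo-RGB inputs by stacking each slice with its two neighbours; a compact equivalent of that windowing on a hypothetical tensor:

```python
import torch

volume = torch.randn(8, 64, 64)  # hypothetical (slices, H, W) volume
windows = torch.stack([volume[i - 1:i + 2] for i in range(1, volume.shape[0] - 1)], dim=0)
print(windows.shape)  # torch.Size([6, 3, 64, 64]): one 3-channel sample per interior slice
```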
8 binary files not shown.
readme.md (63 lines changed)
@@ -118,9 +118,29 @@ Then run
 python -m datasets.cache_dataset3d
 ```

+## Configs Settings
+
+important settings
+
+```yaml
+base:
+  base_dir: "../runs/sam/" # logging dir
+
+dataset:
+  types: ['3d'] # ['3d', '2d']
+  split: 'train'
+  data_root_path: '../datasets/'
+  dataset_list: ["pancreas"]
+  data_txt_path: './datasets/dataset_list/'
+  dataset2d_path: "../08_AbdomenCT-1K/"
+  cache_data_path: '../cached_dataset2/'
+
+  cache_prefix: ['6016'] # cache prefix of cached dataset for training
+                         # For example: ['07',] for 07_WORD
+```
+
-## Start Training
+## Start Training from scratch (SAM)

 Run training on multi-gpu
@@ -140,6 +160,47 @@ CUDA_VISIBLE_DEVICES=0 python -m core.ddp --tag debug
 python -m core.volume_predictor
 ```

+## Testset Validation
+```python
+EX_CONFIG = {
+    'dataset':{
+        'prompt': 'box', # prompt type: box or point
+        'dataset_list': ['word'], # dataset_list name
+        'label_idx': 2, # label index for inference,
+    },
+    "pth": "./model.pth"
+}
+```
+```
+python -m test.volume_eval
+```
+
+## Finetuning (Recommended)
+```yaml
+training:
+  breakpoint_path: "./model.pth" # pretrained weight path
+```
+
+```
+python -m core.ddp_sub --tag run
+```
+
+## Validation with Finetuned Weights
+
+```
+python -m test.volume_eval_sublora
+```
+
+```python
+EX_CONFIG = {
+    'dataset':{
+        'prompt': 'box', # prompt type: box or point
+        'dataset_list': ['word'], # dataset_list name
+        'label_idx': 2, # label index for inference,
+    },
+    "pth": "./model_finetuned.pth"
+}
+```

 <p align="center" width="100%">
@@ -137,31 +137,35 @@ if __name__ == "__main__":

     EX_CONFIG = {
         'dataset':{
-            'prompt': 'box',
-            'dataset_list': ['word'], # ["sabs"], chaos, word
-            'label_idx': 1,
+            'prompt': 'box', # box / point
+            'dataset_list': ['example'], # ["sabs"], chaos, word
+            'label_idx': 5,
         },
+        'model_type': "vit_b",
+        'pth': "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b9/lora3/ckpt/model_iter_360000.pth",
     }

+    EX_CONFIG = {
+        'dataset':{
+            'prompt': 'box', # box / point
+            'dataset_list': ['example'], # ["sabs"], chaos, word
+            'label_idx': 5,
+        },
+        'model_type': "vit_h",
+        'pth': "/home1/quanquan/code/projects/finetune_large/segment_anything/model_iter_3935000.pth",
+    }
+
     config = ConfigManager()
-    config.add_config("configs/vit_b_103.yaml")
+    config.add_config("configs/vit_b.yaml")
     config.add_config(EX_CONFIG)
     config.print()

     # Init Model
-    model_type = "vit_b"
+    model_type = config['model_type']
     sam = sam_model_registry[model_type](checkpoint=None)
     learner = SamLearner(sam_model=sam, config=config, data_engine=DataManager(img_size=(1024,1024)))
     learner.use_lora()
     # pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b3/lora+edge2/ckpt_v/model_latest.pth"
     # pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b3/lora+edge2/ckpt/model_epoch_20.pth"
     # pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b3/lora+edge2/ckpt/model_epoch_16.pth"
     # pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b3/lora_small/ckpt/model_epoch_6.pth"
     # pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b9/lora/ckpt/model_epoch_50.pth"
     # pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b11/spec_8/ckpt_v/model_latest.pth"
     # pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b11/spec_5/ckpt/model_epoch_100.pth"
-    pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b9/lora3/ckpt/model_iter_360000.pth"
     # pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b9/lora3/ckpt/model_iter_500000.pth"
-    learner.load_well_trained_model(pth)
+    learner.load_well_trained_model(EX_CONFIG['pth'])
     learner.cuda()
     predictor = VolumePredictor(
         model=learner.model,
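For context on the EX_CONFIG pattern above: the test scripts layer these dicts over the YAML defaults with tutils' ConfigManager, so later additions override earlier keys. A minimal sketch, assuming `add_config` merges recursively as its usage throughout this diff suggests:

```python
from tutils.new.manager import ConfigManager

EX_CONFIG = {
    'dataset': {
        'prompt': 'box',           # prompt type: box or point
        'dataset_list': ['word'],  # dataset_list name
        'label_idx': 2,            # label index for inference
    },
    "pth": "./model.pth",
}

config = ConfigManager()
config.add_config("configs/vit_b.yaml")  # YAML defaults from the repo
config.add_config(EX_CONFIG)             # script-local overrides applied last
config.print()
```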
test/volume_eval_mbox.py (new file, 216 lines)

@@ -0,0 +1,216 @@
"""
Volume evaluation
"""
import torch
import numpy as np
from torch.utils.data import DataLoader
# from datasets.dataset3d import Dataset3D
from tutils.new.manager import ConfigManager
from datasets.eval_dataloader.loader_abstract import AbstractLoader

from core.volume_predictor import VolumePredictor
from datasets.data_engine import DataManager, BoxPromptGenerator, PointPromptGenerator

from tutils import tfilename
from tutils.new.trainer.recorder import Recorder
from trans_utils.metrics import compute_dice_np
from trans_utils.data_utils import Data3dSolver
from tutils.tutils.ttimer import timer


class Evaluater:
    def __init__(self, config) -> None:
        self.config = config
        self.recorder = Recorder()

    def solve(self, model, dataset, finetune_number=1):
        # model.eval()
        self.predictor = model
        dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

        for i, data in enumerate(dataloader):
            # if i < 4:
            #     print
            #     continue
            # for k, v in data.items():
            #     if isinstance(v, torch.Tensor):
            #         data[k] = v.to(self.rank)
            # try:
            if True:
                if self.config['dataset']['prompt'] == 'box':
                    dice, pred, label, temp_slice = self.eval_step(data, batch_idx=i)
                    used_slice = [temp_slice]
                    if finetune_number > 1:
                        for i in range(finetune_number - 1):
                            dice, pred, label, temp_slice = self.finetune_with_more_prompt(pred, label, exclude_slide_id=used_slice)
                            used_slice.append(temp_slice)
                    res = {"dice": dice}
                if self.config['dataset']['prompt'] == 'point':
                    res = self.eval_step_point(data, batch_idx=i)
                self.recorder.record(res)
            # except Exception as e:
            #     print(e)
        res = self.recorder.cal_metrics()
        print(res)
        print("prompt:", self.config['dataset']['prompt'], " class_idx:", self.config['dataset']['label_idx'])

    def eval_step(self, data, batch_idx=0):
        name = data['name']
        dataset_name = data['dataset_name'][0]
        label_idx = data['label_idx'][0]
        template_slice_id = data['template_slice_id'][0]

        assert data['img'].shape[1] >= 3, f" Got img.shape {data['img'].shape}"
        if template_slice_id == 0:
            template_slice_id += 1
        elif template_slice_id == (data['img'].shape[0] - 1):
            template_slice_id -= 1
        print("Using slice ", template_slice_id, " as template slice")

        spacing = data['spacing'].numpy().tolist()[0]
        if data['img'].shape[-1] < 260:
            # assert data['img'].shape[-1] < 260, f"Got {data['img'].shape}"
            img = data['img'][0][:,:256,:256]
            label = data['label'][0][:,:256,:256]
        else:
            img = data['img'][0]
            label = data['label'][0]
        # img = torch.clip(img, -200, 600)
        box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id].detach().cpu().numpy())
        box = np.array([box])
        pred, stability = self.predictor.predict_volume(
            x=img,
            box=box,
            template_slice_id=template_slice_id,
            return_stability=True,
        )

        prompt_type = 'box'
        dice = compute_dice_np(pred, label.detach().cpu().numpy())
        # Data3dSolver().simple_write(pred, path=tfilename(f"visual/{dataset_name}/pred_{batch_idx}_label_{label_idx}_{prompt_type}.nii.gz"), spacing=spacing)
        # Data3dSolver().simple_write(label.detach().cpu().numpy(), path=tfilename(f"visual/{dataset_name}/label_{batch_idx}.nii.gz"))
        # Data3dSolver().simple_write(img.detach().cpu().numpy(), path=tfilename(f"visual/{dataset_name}/img_{batch_idx}.nii.gz"))
        # np.save(tfilename(f"meta/{dataset_name}/stability_{batch_idx}.npy"), stability)
        print(dataset_name, name, dice)
        template_slice_id = template_slice_id if isinstance(template_slice_id, int) else template_slice_id.item()
        return dice, pred, label.detach().cpu().numpy(), template_slice_id

    def eval_step_point(self, data, batch_idx=0):
        name = data['name']
        dataset_name = data['dataset_name'][0]
        label_idx = data['label_idx'][0]
        template_slice_id = data['template_slice_id'][0]
        spacing = data['spacing'].numpy().tolist()[0]

        assert data['img'].shape[1] >= 3, f" Got img.shape {data['img'].shape}"
        if template_slice_id == 0:
            template_slice_id += 1
        elif template_slice_id == (data['img'].shape[0] - 1):
            template_slice_id -= 1

        if data['img'].shape[-1] < 260:
            # assert data['img'].shape[-1] < 260, f"Got {data['img'].shape}"
            img = data['img'][0][:,:256,:256]
            label = data['label'][0][:,:256,:256]
        else:
            img = data['img'][0]
            label = data['label'][0]

        box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id].detach().cpu().numpy())
        point = (box[0]+box[2])*0.5 , (box[1]+box[3])*0.5
        point = np.array([point]).astype(int)
        if label[template_slice_id][point[0,1], point[0,0]] == 0:
            print("Use random point instead !!!")
            point = PointPromptGenerator().get_prompt_point(label[template_slice_id])
            point = np.array([point]).astype(int)
        # box = np.array([box])
        pred = self.predictor.predict_volume(
            x=img,
            point_coords=point,
            point_labels=np.ones_like(point)[:,:1],
            template_slice_id=template_slice_id,
        )
        dice = compute_dice_np(pred, label.detach().cpu().numpy())
        prompt_type = 'point'
        Data3dSolver().simple_write(pred, path=tfilename(f"visual/{dataset_name}/pred_{batch_idx}_label_{label_idx}_{prompt_type}.nii.gz"), spacing=spacing)
        # Data3dSolver().simple_write(pred, path=tfilename(f"visual/{dataset_name}/pred_{batch_idx}.nii.gz"))
        print(dataset_name, name, dice)
        return {"dice": dice}

    def finetune_with_more_prompt(self, pred, label, prompt_type="box", exclude_slide_id=[]):
        assert pred.shape == label.shape
        dices = [compute_dice_np(pred[j,:,:], label[j,:,:]) for j in range(pred.shape[0])]
        rank_list = np.array(dices[1:-1]).argsort() # Ignore the head and tail
        rank_list += 1 # Ignore the head and tail
        for i in rank_list:
            if i in exclude_slide_id:
                continue
            template_slice_id = i
            break
        # template_slice_id += 1 # Ignore the head and tail
        print("Using slice ", template_slice_id, " as template slice")
        old_confidence = self.predictor.get_confidence()
        box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id])
        box = np.array([box])
        new_pred, stability = self.predictor.predict_with_prompt(
            box=box,
            template_slice_id=template_slice_id,
            return_stability=True,
        )
        new_confidence = self.predictor.get_confidence()
        new_confidence[template_slice_id] *= 2
        all_conf = np.stack([old_confidence, new_confidence], axis=1)
        preds = [pred, new_pred]
        merged = np.zeros_like(label)
        for slice_idx in range(pred.shape[0]):
            idx = np.argsort(all_conf[slice_idx,:])[-1]
            merged[slice_idx,:,:] = preds[idx][slice_idx]

        print("old dices", [compute_dice_np(pred, label) for pred in preds])
        dice = compute_dice_np(merged, label)
        print("merged dice, idx", dice)
        return dice, merged, label, template_slice_id


def to_RGB(img):
    pass

if __name__ == "__main__":
    from core.learner3 import SamLearner
    from modeling.build_sam3d2 import sam_model_registry
    EX_CONFIG = {
        'dataset':{
            'prompt': 'box',
            'prompt_number': 5,
            'dataset_list': ['example'], # ["sabs"], chaos, word pancreas
            'label_idx': 5,
        },
        "lora_r": 24,
        'model_type': "vit_h",
        'ckpt': "/home1/quanquan/code/projects/finetune_large/segment_anything/model_iter_3935000.pth",
    }
    config = ConfigManager()
    config.add_config("configs/vit_b.yaml")
    config.add_config(EX_CONFIG)
    config.print()
    dataset = AbstractLoader(config['dataset'], split="test")
    print(len(dataset))
    assert len(dataset) >= 1

    # Init Model
    model_type = config['model_type'] # "vit_b"
    sam = sam_model_registry[model_type](checkpoint=None)
    learner = SamLearner(sam_model=sam, config=config, data_engine=DataManager(img_size=(1024,1024)))
    learner.use_lora(r=config['lora_r'])
    pth = config['ckpt']
    learner.load_well_trained_model(pth)
    learner.cuda()
    predictor = VolumePredictor(
        model=learner.model,
        use_postprocess=True,
        use_noise_remove=True,)

    solver = Evaluater(config)
    dataset = AbstractLoader(config['dataset'], split="test")
    solver.solve(predictor, dataset, finetune_number=config['dataset']['prompt_number'])
231
test/volume_eval_sublora.py
Normal file
231
test/volume_eval_sublora.py
Normal file
@ -0,0 +1,231 @@
|
||||
"""
|
||||
Volume evalutaion
|
||||
|
||||
"""
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch.utils.data import DataLoader
|
||||
# from datasets.dataset3d import Dataset3D
|
||||
from tutils.new.manager import ConfigManager
|
||||
from datasets.eval_dataloader.loader_abstract import AbstractLoader
|
||||
|
||||
from core.volume_predictor import VolumePredictor
|
||||
from datasets.data_engine import DataManager, BoxPromptGenerator, PointPromptGenerator
|
||||
|
||||
from tutils import tfilename
|
||||
from tutils.new.trainer.recorder import Recorder
|
||||
from trans_utils.metrics import compute_dice_np
|
||||
from trans_utils.data_utils import Data3dSolver
|
||||
|
||||
# from monai.metrics import compute_surface_dice
|
||||
import surface_distance as surfdist
|
||||
from tutils.tutils.ttimer import timer
|
||||
|
||||
|
||||
|
||||
class Evaluater:
    def __init__(self, config) -> None:
        self.config = config
        self.recorder = Recorder()

    def solve(self, model, dataset):
        # model.eval()
        self.predictor = model
        dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

        for i, data in enumerate(dataloader):
            # for k, v in data.items():
            #     if isinstance(v, torch.Tensor):
            #         data[k] = v.to(self.rank)
            if self.config['dataset']['prompt'] == 'box':
                # res = self.eval_step_slice(data, batch_idx=i)
                res = self.eval_step(data, batch_idx=i)
            elif self.config['dataset']['prompt'] == 'point':
                res = self.eval_step_point(data, batch_idx=i)
            self.recorder.record(res)
        res = self.recorder.cal_metrics()
        print(res)
        print("prompt:", self.config['dataset']['prompt'], " class_idx:", self.config['dataset']['label_idx'])

    def eval_step(self, data, batch_idx=0):
        name = data['name']
        dataset_name = data['dataset_name'][0]
        label_idx = data['label_idx'][0]
        template_slice_id = data['template_slice_id'][0]

        assert data['img'].shape[1] >= 3, f" Got img.shape {data['img'].shape}"
        # Keep the template slice away from the volume ends so a 3-slice window
        # around it always exists (depth is dim 1 of the batched tensor).
        if template_slice_id == 0:
            template_slice_id += 1
        elif template_slice_id == (data['img'].shape[1] - 1):
            template_slice_id -= 1

        spacing = data['spacing'].numpy().tolist()[0]
        if data['img'].shape[-1] < 260:
            img = data['img'][0][:, :256, :256]
            label = data['label'][0][:, :256, :256]
        else:
            img = data['img'][0]
            label = data['label'][0]
        # img = torch.clip(img, -200, 600)
        box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id].detach().cpu().numpy())
        box = np.array([box])
        pred, stability = self.predictor.predict_volume(
            x=img,
            box=box,
            template_slice_id=template_slice_id,
            return_stability=True,
        )
        prompt_type = 'box'
        dice = compute_dice_np(pred, label.detach().cpu().numpy())
        # Data3dSolver().simple_write(pred, path=tfilename(f"visual/{dataset_name}/pred_{batch_idx}_label_{label_idx}_{prompt_type}.nii.gz"), spacing=spacing)
        # Data3dSolver().simple_write(label.detach().cpu().numpy(), path=tfilename(f"visual/{dataset_name}/label_{batch_idx}.nii.gz"))
        # Data3dSolver().simple_write(img.detach().cpu().numpy(), path=tfilename(f"visual/{dataset_name}/img_{batch_idx}.nii.gz"))
        # np.save(tfilename(f"meta/{dataset_name}/stability_{batch_idx}.npy"), stability)

        # surface_distances = surfdist.compute_surface_distances(
        #     label.detach().cpu().numpy(), pred, spacing_mm=(0.6, 0.6445, 0.6445))
        # nsd = surfdist.compute_surface_dice_at_tolerance(surface_distances, 1)
        nsd = 0

        print(dataset_name, name, dice, nsd)
        return {"dice": dice, "nsd": nsd}

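    # For context: mask_to_bbox (from datasets.data_engine) turns the template
    # slice's GT mask into the box prompt. A minimal sketch of such a helper,
    # assuming the (x_min, y_min, x_max, y_max) convention implied by the
    # point-center computation in eval_step_point below; hypothetical, not the
    # repo's implementation.
    @staticmethod
    def _mask_to_bbox_sketch(mask):
        ys, xs = np.nonzero(mask)  # row (y) and column (x) indices of foreground pixels
        if len(xs) == 0:
            return None  # empty mask: no box prompt can be derived
        return int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())
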
    def eval_step_point(self, data, batch_idx=0):
        name = data['name']
        dataset_name = data['dataset_name'][0]
        label_idx = data['label_idx'][0]
        template_slice_id = data['template_slice_id'][0]
        spacing = data['spacing'].numpy().tolist()[0]

        assert data['img'].shape[1] >= 3, f" Got img.shape {data['img'].shape}"
        if template_slice_id == 0:
            template_slice_id += 1
        elif template_slice_id == (data['img'].shape[1] - 1):
            template_slice_id -= 1

        if data['img'].shape[-1] < 260:
            img = data['img'][0][:, :256, :256]
            label = data['label'][0][:, :256, :256]
        else:
            img = data['img'][0]
            label = data['label'][0]

        # Use the center of the template slice's bounding box as the point prompt.
        box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id].detach().cpu().numpy())
        point = (box[0] + box[2]) * 0.5, (box[1] + box[3]) * 0.5
        point = np.array([point]).astype(int)
        # Fall back to a random foreground point if the box center misses the mask.
        if label[template_slice_id][point[0, 1], point[0, 0]] == 0:
            print("Use random point instead !!!")
            point = PointPromptGenerator().get_prompt_point(label[template_slice_id])
            point = np.array([point]).astype(int)
        pred = self.predictor.predict_volume(
            x=img,
            point_coords=point,
            point_labels=np.ones_like(point)[:, :1],
            template_slice_id=template_slice_id,
        )
        dice = compute_dice_np(pred, label.detach().cpu().numpy())
        prompt_type = 'point'
        # Data3dSolver().simple_write(pred, path=tfilename(f"visual/{dataset_name}/pred_{batch_idx}_label_{label_idx}_{prompt_type}.nii.gz"), spacing=spacing)
        nsd = 0  # compute_surface_dice is unavailable here (its import is commented out above)
        print(dataset_name, name, dice)
        return {"dice": dice, "nsd": nsd}

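    # For reference: compute_dice_np is imported from trans_utils.metrics. A
    # minimal binary-Dice sketch consistent with how it is called here (two
    # numpy arrays, foreground as values > 0); an assumption, not the imported code.
    @staticmethod
    def _dice_np_sketch(pred, gt, eps=1e-8):
        p = (np.asarray(pred) > 0).astype(np.float32)
        g = (np.asarray(gt) > 0).astype(np.float32)
        inter = (p * g).sum()
        return (2.0 * inter + eps) / (p.sum() + g.sum() + eps)
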
    def eval_step_slice(self, data, batch_idx=0):
        name = data['name']
        dataset_name = data['dataset_name'][0]
        label_idx = data['label_idx'][0]
        template_slice_id = data['template_slice_id'][0]

        assert data['img'].shape[1] >= 3, f" Got img.shape {data['img'].shape}"
        if template_slice_id == 0:
            template_slice_id += 1
        elif template_slice_id == (data['img'].shape[1] - 1):
            template_slice_id -= 1

        spacing = data['spacing'].numpy().tolist()[0]
        if data['img'].shape[-1] < 260:
            img = data['img'][0][:, :256, :256]
            label = data['label'][0][:, :256, :256]
        else:
            img = data['img'][0]
            label = data['label'][0]

        # Evaluate on the 3-slice window centered on the template slice only.
        img = img[template_slice_id - 1:template_slice_id + 2, :, :]
        label = label[template_slice_id - 1:template_slice_id + 2, :, :]
        template_slice_id = 1

        # img = torch.clip(img, -200, 600)
        box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id].detach().cpu().numpy())
        box = np.array([box])
        pred, stability = self.predictor.predict_volume(
            x=img,
            box=box,
            template_slice_id=template_slice_id,
            return_stability=True,
        )
        prompt_type = 'box'
        dice = compute_dice_np(pred, label.detach().cpu().numpy())
        # Data3dSolver().simple_write(pred, path=tfilename(f"visual/{dataset_name}/pred_{batch_idx}_label_{label_idx}_{prompt_type}.nii.gz"), spacing=spacing)
        # Data3dSolver().simple_write(label.detach().cpu().numpy(), path=tfilename(f"visual/{dataset_name}/label_{batch_idx}.nii.gz"))
        # Data3dSolver().simple_write(img.detach().cpu().numpy(), path=tfilename(f"visual/{dataset_name}/img_{batch_idx}.nii.gz"))
        # np.save(tfilename(f"meta/{dataset_name}/stability_{batch_idx}.npy"), stability)
        print("Slice evaluation: ", dataset_name, name, dice)
        return {"dice": dice}


def to_RGB(img):
    # Placeholder; unused in this script.
    pass

if __name__ == "__main__":
    # from core.learner3 import SamLearner
    # from modeling.build_sam3d import sam_model_registry
    from core.learner_sub1 import SamLearner
    from modeling.build_sam3d2 import sam_model_registry

    EX_CONFIG = {
        'dataset': {
            'prompt': 'box',
            'dataset_list': ['guangdong'],  # alternatives: "sabs", "chaos", "word", "decathlon_colon", "pancreas"
            'label_idx': 2,
        },
        "pth": "./model_latest.pth",
    }

    config = ConfigManager()
    config.add_config("configs/vit_sub.yaml")
    config.add_config(EX_CONFIG)

    # Init model
    model_type = "vit_b"
    sam = sam_model_registry[model_type](checkpoint=None)
    learner = SamLearner(sam_model=sam, config=config, data_engine=DataManager(img_size=(1024, 1024)))
    learner.use_lora()
    learner.use_lora_sub()
    pth = EX_CONFIG['pth']
    learner.load_well_trained_model(pth)
    learner.cuda()
    predictor = VolumePredictor(
        model=learner.model,
        use_postprocess=True,
        use_noise_remove=True,
    )

    solver = Evaluater(config)
    dataset = AbstractLoader(config['dataset'], split="test")
    tt = timer()
    solver.solve(predictor, dataset)

    print("Time: ", tt())
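
# Background note: use_lora() / use_lora_sub() wrap the frozen SAM encoder's
# linear projections with trainable low-rank adapters before the checkpoint is
# loaded. A generic, self-contained LoRA layer sketch for readers unfamiliar
# with the idea -- hypothetical, not this repository's implementation:
#
#     import torch.nn as nn
#
#     class LoRALinear(nn.Module):
#         def __init__(self, base: nn.Linear, r: int = 4, alpha: float = 1.0):
#             super().__init__()
#             self.base = base                          # frozen pretrained projection
#             self.base.weight.requires_grad_(False)
#             self.lora_a = nn.Linear(base.in_features, r, bias=False)   # down-project
#             self.lora_b = nn.Linear(r, base.out_features, bias=False)  # up-project
#             nn.init.zeros_(self.lora_b.weight)        # adapter starts as a no-op
#             self.scale = alpha / r
#
#         def forward(self, x):
#             return self.base(x) + self.scale * self.lora_b(self.lora_a(x))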
303
tmp.py
Normal file
@ -0,0 +1,303 @@
"""
Loading volumes directly is slow, so we pre-process the data into a cache first.
"""

import os
import glob
import numpy as np
import torch
import torch.nn.functional as F
import cv2
from scipy.sparse import csr_matrix
from einops import rearrange, reduce, repeat
from torchvision import transforms
from monai import transforms as monai_transforms

from tutils.nn.data import read, itk_to_np, np_to_itk, write
from tutils import tfilename
from .dataset3d import DATASET_CONFIG, Dataset3D as basic_3d_dataset


# "class": ["spleen", "right kidney", "left kidney", "gallbladder", "esophagus", "liver", "stomach", "aorta", "postcava", "portal vein and splenic vein", "pancreas", "right adrenal gland", "left adrenal gland"],
# "class": ["liver", "right kidney", "left kidney", "spleen"],
TEMPLATE = {
    '01': [1,2,3,4,5,6,7,8,9,10,11,12,13,14],
    '02': [1,0,3,4,5,6,7,0,0,0,11,0,0,14],
    '03': [6],
    '04': [6,27],  # post process
    '05': [2,26,32],  # post process
    '07': [6,1,3,2,7,4,5,11,14,18,19,12,20,21,23,24],
    '08': [6,2,1,11],
    '09': [1,2,3,4,5,6,7,8,9,11,12,13,14,21,22],
    '12': [6,21,16,2],
    '13': [6,2,1,11,8,9,7,4,5,12,13,25],
    '14': [11,11,28,28,28],  # Felix data, post process
    '10_03': [6,27],  # post process
    '10_06': [30],
    '10_07': [11,28],  # post process
    '10_08': [15,29],  # post process
    '10_09': [1],
    '10_10': [31],
    '58': [6,2,3,1],
    '59': [1,2,3,4,5,6,7,8,9,10,11,12,13,14],
    '60': (np.ones(200) * (-1)).tolist(),  # for debug
    "65": np.zeros(200).tolist(),
}

EX_CONFIG = {
    "dataset": {
        "split": 'train',
        "data_root_path": '/home1/quanquan/datasets/',
        "dataset_list": ["decathlon_colon"],
        "data_txt_path": './datasets/dataset_list/',
        "cache_data_path": '/home1/quanquan/datasets/cached_dataset2/',
        "cache_prefix": ['10'],  # '07'
    }
}

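# How TEMPLATE is consumed (see cache_one_sample below): a scan's local label k
# is remapped to the global organ id TEMPLATE[<dataset prefix>][k - 1]. A small
# illustrative check, assuming dataset prefix '08':
#
#     all_labels = TEMPLATE['08']             # [6, 2, 1, 11]
#     local_idx = 2                           # second foreground label in the scan
#     global_id = all_labels[local_idx - 1]   # -> 2, "right kidney" per the class list above
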
class Dataset3D(basic_3d_dataset):
    def __init__(self, config, use_cache=True, *args, **kwargs) -> None:
        super().__init__(config=config, use_cache=use_cache, *args, **kwargs)
        self.basic_dir = config['data_root_path']
        self.cache_dir = config['cache_data_path']

    def prepare_transforms(self):
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((1024, 1024)),
        ])
        self.test_transform = transforms.Compose([
            monai_transforms.Resized(keys=['img', 'label'], spatial_size=(3, 1024, 1024)),
        ])

    def prepare_cached_datalist(self):
        raise DeprecationWarning("[Warning] Please use cache_dataset3d new version instead!")
        config = self.config
        data_paths = []
        for dirpath in glob.glob(config['cache_data_path'] + "/*"):
            data_paths += glob.glob(dirpath + "/image/*.jpg")
            print("Load ", dirpath)
        print('train len {}'.format(len(data_paths)))
        print('Examples: ', data_paths[:2])
        return data_paths

    def caching_data(self):
        assert self.use_cache == False
        for index in range(len(self)):
            self.cache_one_sample(index)

    def cache_one_sample(self, index, debug=False):
        name = self.img_names[index]['img_path']
        img_itk = read(self.img_names[index]['img_path'])
        img_ori = itk_to_np(img_itk)

        # Clip HU values to a soft-tissue window.
        img_ori = np.clip(img_ori, -200, 400)

        # spacing = img_itk.GetSpacing()
        scan_orientation = np.argmin(img_ori.shape)
        label_ori = itk_to_np(read(self.img_names[index]['label_path']))

        dataset_name = self.img_names[index]['img_path'].replace(self.basic_dir, "").split("/")[0]
        assert dataset_name[0] in ['0','1','2','3','4','5','6','7','8','9'], f"Got {dataset_name}"
        if dataset_name[:2] == "10":
            subname = self.img_names[index]['img_path'].replace(self.basic_dir, "")[17:19]
            assert subname in ['10', '03', '06', '07']
            all_labels = TEMPLATE[dataset_name[:2] + "_" + subname]
        else:
            all_labels = TEMPLATE[dataset_name[:2]]

        num = 0

        # if min(img_ori.shape) * 1.2 < max(img_ori.shape):
        #     orientation_all = [scan_orientation]
        # else:
        #     orientation_all = [0,1,2]
        orientation_all = [scan_orientation]

        for orientation in orientation_all:
            for slice_idx in range(2, img_ori.shape[orientation] - 2):
                # Extract a 3-slice window around slice_idx along the chosen
                # axis, rearranged to channel-first (c h w).
                if orientation == 0:
                    s = img_ori[slice_idx-1:slice_idx+2, :, :]
                    lb = label_ori[slice_idx-1:slice_idx+2, :, :]
                if orientation == 1:
                    s = img_ori[:, slice_idx-1:slice_idx+2, :]
                    s = rearrange(s, "h c w -> c h w")
                    lb = label_ori[:, slice_idx-1:slice_idx+2, :]
                    lb = rearrange(lb, "h c w -> c h w")
                if orientation == 2:
                    s = img_ori[:, :, slice_idx-1:slice_idx+2]
                    s = rearrange(s, "h w c -> c h w")
                    lb = label_ori[:, :, slice_idx-1:slice_idx+2]
                    lb = rearrange(lb, "h w c -> c h w")
                assert s.shape[0] == 3

                # Choose the labels to keep for this slice.
                label_num = int(lb.max())

                masks_data = []
                meta = {"img_name": name, "slice": slice_idx, "orientation": orientation, "label_idx": [], "labels": [], "id": f"{num:08d}"}
                for label_idx in range(1, label_num + 1):
                    one_lb = np.float32(lb == label_idx)
                    # Skip labels whose center-slice area is below ~0.14% of the slice.
                    if one_lb[1, :, :].sum() <= (one_lb.shape[-1] * one_lb.shape[-2] * 0.0014):
                        continue
                    masks_data.append(one_lb)
                    meta['label_idx'].append(label_idx)
                    meta['labels'].append(all_labels[label_idx - 1])

                if len(masks_data) <= 0:
                    continue

                img_rgb = s
                img_rgb = F.interpolate(torch.Tensor(img_rgb).unsqueeze(0), size=(1024, 1024)).squeeze().numpy()
                img_rgb = self.to_RGB(img_rgb)
                save_image_name = tfilename(self.cache_dir, dataset_name, f"image/image_{index:04d}_{num:08d}.jpg")
                self.save_img_rgb(rearrange(img_rgb, "c h w -> h w c"), save_image_name)

                # Save cache data
                save_label_name = tfilename(self.cache_dir, dataset_name, f"label_jpg/label_{index:04d}_{num:08d}")
                self.save_slice_mask(masks_data, save_label_name)
                print("Save ", save_image_name)

                self.save_meta(meta, tfilename(self.cache_dir, dataset_name, f"meta/meta_{index:04d}_{num:08d}.npy"))

                num += 1

    def save_meta(self, meta, path):
        assert path.endswith(".npy")
        np.save(path, meta)

    def save_slice_mask(self, masks_data, prefix):
        # Stack the list of (3, H, W) masks before converting to a tensor.
        masks_data = F.interpolate(torch.Tensor(np.array(masks_data)), size=(1024, 1024)).numpy()
        assert masks_data.shape[1:] == (3, 1024, 1024), f"{__file__} Got{masks_data.shape}"
        for i in range(masks_data.shape[0]):
            labeli = masks_data[i].astype(np.uint8) * 255
            assert labeli.sum() > 0
            path = tfilename(prefix + f"_{i:04d}.jpg")
            cv2.imwrite(path, rearrange(labeli, "c h w -> h w c"))
            print("save to ", path)

    def _old_save_slice_mask(self, masks_data, path):
        raise DeprecationWarning()
        assert path.endswith(".npz")
        # masks_data = np.array([m['segmentation'] for m in masks]).astype(int)
        masks_data = F.interpolate(torch.Tensor(masks_data), size=(1024, 1024)).numpy()
        # masks_data = np.int8(masks_data>0)
        assert masks_data.shape[1:] == (3, 1024, 1024), f"{__file__} Got{masks_data.shape}"
        masks_data = rearrange(masks_data, "n c h w -> n (c h w)")
        csr = csr_matrix(masks_data)
        np.savez_compressed(path, data=csr.data, indices=csr.indices, indptr=csr.indptr, shape=csr.shape)

    def save_img_rgb(self, img, path):
        assert path.endswith(".jpg")
        assert img.shape == (1024, 1024, 3)
        cv2.imwrite(path, img.astype(np.uint8))

    def _get_cached_data(self, index):
        name = self.img_names[index]
        img = cv2.imread(name)
        # Masks are stored as a compressed CSR matrix: one row per label,
        # each row a flattened (3, 1024, 1024) binary mask.
        compressed = np.load(name.replace("image/image_", "label/label_").replace(".jpg", ".npz"))
        csr = csr_matrix((compressed['data'], compressed['indices'], compressed['indptr']), shape=compressed['shape'])
        label_ori = csr.toarray()
        label_ori = rearrange(label_ori, "n (c h w) -> n c h w", c=3, h=1024, w=1024)
        meta = np.load(name.replace("image/image_", "meta/meta_").replace(".jpg", ".npy"), allow_pickle=True).tolist()
        # Keep only labels whose center-slice area exceeds 500 pixels.
        pp = reduce(label_ori[:, 1, :, :], "n h w -> n", reduction="sum") > 500
        if pp.sum() == 0:
            return self._get_cached_data((index + 1) % len(self))

        # Sample one of the remaining labels uniformly.
        label_idx = np.random.choice(a=np.arange(len(pp)), p=pp / pp.sum())
        # label_idx = np.random.randint(0, label_ori.shape[0])
        label_ori = label_ori[label_idx]
        is_edge = meta.get('is_edge', 0)
        return rearrange(img, "h w c -> c h w"), label_ori, name, meta['labels'][label_idx], meta['label_idx'][label_idx]

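    # The npz <-> CSR round trip used above, shown in isolation (a standalone
    # sketch with made-up shapes, not part of the dataset class):
    #
    #     import numpy as np
    #     from scipy.sparse import csr_matrix
    #
    #     masks = (np.random.rand(4, 3 * 1024 * 1024) > 0.99)  # n flattened binary masks
    #     csr = csr_matrix(masks)
    #     np.savez_compressed("masks.npz", data=csr.data, indices=csr.indices,
    #                         indptr=csr.indptr, shape=csr.shape)
    #     z = np.load("masks.npz")
    #     restored = csr_matrix((z['data'], z['indices'], z['indptr']),
    #                           shape=z['shape']).toarray()
    #     assert (restored == masks).all()
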
    # @tfunctime
    def __getitem__(self, index, debug=False):
        index = index % len(self)
        img_rgb, label_ori, name, label_idx, local_idx = self._get_cached_data(index)

        if label_ori.sum() <= 0:
            print("[Label Error] ", name)
            return self.__getitem__(index + 1)

        # img_rgb = self.transform(img_rgb[None, :, :, :])
        img_rgb = F.interpolate(torch.Tensor(img_rgb).unsqueeze(0), size=(1024, 1024)).squeeze().numpy()

        vector = np.ones(3)
        ret_dict = {
            "name": name,
            "img": img_rgb,
            "label": label_ori,
            "indicators": vector,
            "class": label_idx,
            "local_idx": local_idx,
        }
        return ret_dict

    def _convert_one_mask_from_npz_to_jpg(self, path1=None):
        # path1 = "/home1/quanquan/datasets/cached_dataset2/01_BCV-Abdomen/label/label_0129_00000043.npz"  # 32K
        prefix = path1.replace(".npz", "").replace("/label/", "/label_jpg/")
        compressed = np.load(path1)
        csr = csr_matrix((compressed['data'], compressed['indices'], compressed['indptr']), shape=compressed['shape'])
        label_ori = csr.toarray()
        label_ori = rearrange(label_ori, "n (c h w) -> n c h w", c=3, h=1024, w=1024)
        for i in range(label_ori.shape[0]):
            labeli = label_ori[i]
            path = tfilename(prefix + f"_{i:04d}.jpg")
            cv2.imwrite(path, rearrange(labeli, "c h w -> h w c").astype(np.uint8))
            print("save to ", path)

    def convert_masks_types(self):
        assert self.use_cache == True
        for index in range(len(self)):
            name = self.img_names[index]
            label_path = name.replace("image/image_", "label/label_").replace(".jpg", ".npz")
            self._convert_one_mask_from_npz_to_jpg(label_path)


if __name__ == "__main__":
    from tutils.new.manager import ConfigManager
    config = ConfigManager()
    config.add_config("configs/vit_sub.yaml")
    config.add_config(EX_CONFIG)
    dataset = Dataset3D(config=config['dataset'], use_cache=False)
    dataset.caching_data()

    # Other one-off maintenance runs:
    # config.add_config("configs/vit_b_word_103.yaml")
    # dataset = Dataset3D(config=config['dataset'], use_cache=True)
    # dataset.convert_masks_types()

    # Smoke test for cached loading:
    # config = ConfigManager()
    # config.add_config("configs/vit_b_103.yaml")
    # dataset = Dataset3D(config=config['dataset'])  # use_cache=True
    # data = dataset.__getitem__(0)
    # from torch.utils.data import DataLoader
    # loader = DataLoader(dataset, batch_size=8)
    # for batch in loader:
    #     print(batch['img'].shape, batch['label'].shape)
    # print(data['label'].max())