Compare commits
No commits in common. "72a73cb4b973d4fb4a506ee2f73da52c445c99a9" and "1a352792a8cc2fe9bb9930dce1d8fef2733e8616" have entirely different histories.
72a73cb4b9
...
1a352792a8
3
.gitignore
vendored
3
.gitignore
vendored
@ -1,3 +0,0 @@
|
|||||||
*.pth
|
|
||||||
*.pyc
|
|
||||||
*.nii.gz
|
|
201
LICENSE
201
LICENSE
@ -1,201 +0,0 @@
|
|||||||
Apache License
|
|
||||||
Version 2.0, January 2004
|
|
||||||
http://www.apache.org/licenses/
|
|
||||||
|
|
||||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
|
||||||
|
|
||||||
1. Definitions.
|
|
||||||
|
|
||||||
"License" shall mean the terms and conditions for use, reproduction,
|
|
||||||
and distribution as defined by Sections 1 through 9 of this document.
|
|
||||||
|
|
||||||
"Licensor" shall mean the copyright owner or entity authorized by
|
|
||||||
the copyright owner that is granting the License.
|
|
||||||
|
|
||||||
"Legal Entity" shall mean the union of the acting entity and all
|
|
||||||
other entities that control, are controlled by, or are under common
|
|
||||||
control with that entity. For the purposes of this definition,
|
|
||||||
"control" means (i) the power, direct or indirect, to cause the
|
|
||||||
direction or management of such entity, whether by contract or
|
|
||||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
|
||||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
|
||||||
|
|
||||||
"You" (or "Your") shall mean an individual or Legal Entity
|
|
||||||
exercising permissions granted by this License.
|
|
||||||
|
|
||||||
"Source" form shall mean the preferred form for making modifications,
|
|
||||||
including but not limited to software source code, documentation
|
|
||||||
source, and configuration files.
|
|
||||||
|
|
||||||
"Object" form shall mean any form resulting from mechanical
|
|
||||||
transformation or translation of a Source form, including but
|
|
||||||
not limited to compiled object code, generated documentation,
|
|
||||||
and conversions to other media types.
|
|
||||||
|
|
||||||
"Work" shall mean the work of authorship, whether in Source or
|
|
||||||
Object form, made available under the License, as indicated by a
|
|
||||||
copyright notice that is included in or attached to the work
|
|
||||||
(an example is provided in the Appendix below).
|
|
||||||
|
|
||||||
"Derivative Works" shall mean any work, whether in Source or Object
|
|
||||||
form, that is based on (or derived from) the Work and for which the
|
|
||||||
editorial revisions, annotations, elaborations, or other modifications
|
|
||||||
represent, as a whole, an original work of authorship. For the purposes
|
|
||||||
of this License, Derivative Works shall not include works that remain
|
|
||||||
separable from, or merely link (or bind by name) to the interfaces of,
|
|
||||||
the Work and Derivative Works thereof.
|
|
||||||
|
|
||||||
"Contribution" shall mean any work of authorship, including
|
|
||||||
the original version of the Work and any modifications or additions
|
|
||||||
to that Work or Derivative Works thereof, that is intentionally
|
|
||||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
|
||||||
or by an individual or Legal Entity authorized to submit on behalf of
|
|
||||||
the copyright owner. For the purposes of this definition, "submitted"
|
|
||||||
means any form of electronic, verbal, or written communication sent
|
|
||||||
to the Licensor or its representatives, including but not limited to
|
|
||||||
communication on electronic mailing lists, source code control systems,
|
|
||||||
and issue tracking systems that are managed by, or on behalf of, the
|
|
||||||
Licensor for the purpose of discussing and improving the Work, but
|
|
||||||
excluding communication that is conspicuously marked or otherwise
|
|
||||||
designated in writing by the copyright owner as "Not a Contribution."
|
|
||||||
|
|
||||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
|
||||||
on behalf of whom a Contribution has been received by Licensor and
|
|
||||||
subsequently incorporated within the Work.
|
|
||||||
|
|
||||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
copyright license to reproduce, prepare Derivative Works of,
|
|
||||||
publicly display, publicly perform, sublicense, and distribute the
|
|
||||||
Work and such Derivative Works in Source or Object form.
|
|
||||||
|
|
||||||
3. Grant of Patent License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
(except as stated in this section) patent license to make, have made,
|
|
||||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
|
||||||
where such license applies only to those patent claims licensable
|
|
||||||
by such Contributor that are necessarily infringed by their
|
|
||||||
Contribution(s) alone or by combination of their Contribution(s)
|
|
||||||
with the Work to which such Contribution(s) was submitted. If You
|
|
||||||
institute patent litigation against any entity (including a
|
|
||||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
|
||||||
or a Contribution incorporated within the Work constitutes direct
|
|
||||||
or contributory patent infringement, then any patent licenses
|
|
||||||
granted to You under this License for that Work shall terminate
|
|
||||||
as of the date such litigation is filed.
|
|
||||||
|
|
||||||
4. Redistribution. You may reproduce and distribute copies of the
|
|
||||||
Work or Derivative Works thereof in any medium, with or without
|
|
||||||
modifications, and in Source or Object form, provided that You
|
|
||||||
meet the following conditions:
|
|
||||||
|
|
||||||
(a) You must give any other recipients of the Work or
|
|
||||||
Derivative Works a copy of this License; and
|
|
||||||
|
|
||||||
(b) You must cause any modified files to carry prominent notices
|
|
||||||
stating that You changed the files; and
|
|
||||||
|
|
||||||
(c) You must retain, in the Source form of any Derivative Works
|
|
||||||
that You distribute, all copyright, patent, trademark, and
|
|
||||||
attribution notices from the Source form of the Work,
|
|
||||||
excluding those notices that do not pertain to any part of
|
|
||||||
the Derivative Works; and
|
|
||||||
|
|
||||||
(d) If the Work includes a "NOTICE" text file as part of its
|
|
||||||
distribution, then any Derivative Works that You distribute must
|
|
||||||
include a readable copy of the attribution notices contained
|
|
||||||
within such NOTICE file, excluding those notices that do not
|
|
||||||
pertain to any part of the Derivative Works, in at least one
|
|
||||||
of the following places: within a NOTICE text file distributed
|
|
||||||
as part of the Derivative Works; within the Source form or
|
|
||||||
documentation, if provided along with the Derivative Works; or,
|
|
||||||
within a display generated by the Derivative Works, if and
|
|
||||||
wherever such third-party notices normally appear. The contents
|
|
||||||
of the NOTICE file are for informational purposes only and
|
|
||||||
do not modify the License. You may add Your own attribution
|
|
||||||
notices within Derivative Works that You distribute, alongside
|
|
||||||
or as an addendum to the NOTICE text from the Work, provided
|
|
||||||
that such additional attribution notices cannot be construed
|
|
||||||
as modifying the License.
|
|
||||||
|
|
||||||
You may add Your own copyright statement to Your modifications and
|
|
||||||
may provide additional or different license terms and conditions
|
|
||||||
for use, reproduction, or distribution of Your modifications, or
|
|
||||||
for any such Derivative Works as a whole, provided Your use,
|
|
||||||
reproduction, and distribution of the Work otherwise complies with
|
|
||||||
the conditions stated in this License.
|
|
||||||
|
|
||||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
|
||||||
any Contribution intentionally submitted for inclusion in the Work
|
|
||||||
by You to the Licensor shall be under the terms and conditions of
|
|
||||||
this License, without any additional terms or conditions.
|
|
||||||
Notwithstanding the above, nothing herein shall supersede or modify
|
|
||||||
the terms of any separate license agreement you may have executed
|
|
||||||
with Licensor regarding such Contributions.
|
|
||||||
|
|
||||||
6. Trademarks. This License does not grant permission to use the trade
|
|
||||||
names, trademarks, service marks, or product names of the Licensor,
|
|
||||||
except as required for reasonable and customary use in describing the
|
|
||||||
origin of the Work and reproducing the content of the NOTICE file.
|
|
||||||
|
|
||||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
|
||||||
agreed to in writing, Licensor provides the Work (and each
|
|
||||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
|
||||||
implied, including, without limitation, any warranties or conditions
|
|
||||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
|
||||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
|
||||||
appropriateness of using or redistributing the Work and assume any
|
|
||||||
risks associated with Your exercise of permissions under this License.
|
|
||||||
|
|
||||||
8. Limitation of Liability. In no event and under no legal theory,
|
|
||||||
whether in tort (including negligence), contract, or otherwise,
|
|
||||||
unless required by applicable law (such as deliberate and grossly
|
|
||||||
negligent acts) or agreed to in writing, shall any Contributor be
|
|
||||||
liable to You for damages, including any direct, indirect, special,
|
|
||||||
incidental, or consequential damages of any character arising as a
|
|
||||||
result of this License or out of the use or inability to use the
|
|
||||||
Work (including but not limited to damages for loss of goodwill,
|
|
||||||
work stoppage, computer failure or malfunction, or any and all
|
|
||||||
other commercial damages or losses), even if such Contributor
|
|
||||||
has been advised of the possibility of such damages.
|
|
||||||
|
|
||||||
9. Accepting Warranty or Additional Liability. While redistributing
|
|
||||||
the Work or Derivative Works thereof, You may choose to offer,
|
|
||||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
|
||||||
or other liability obligations and/or rights consistent with this
|
|
||||||
License. However, in accepting such obligations, You may act only
|
|
||||||
on Your own behalf and on Your sole responsibility, not on behalf
|
|
||||||
of any other Contributor, and only if You agree to indemnify,
|
|
||||||
defend, and hold each Contributor harmless for any liability
|
|
||||||
incurred by, or claims asserted against, such Contributor by reason
|
|
||||||
of your accepting any such warranty or additional liability.
|
|
||||||
|
|
||||||
END OF TERMS AND CONDITIONS
|
|
||||||
|
|
||||||
APPENDIX: How to apply the Apache License to your work.
|
|
||||||
|
|
||||||
To apply the Apache License to your work, attach the following
|
|
||||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
|
||||||
replaced with your own identifying information. (Don't include
|
|
||||||
the brackets!) The text should be enclosed in the appropriate
|
|
||||||
comment syntax for the file format. We also recommend that a
|
|
||||||
file or class name and description of purpose be included on the
|
|
||||||
same "printed page" as the copyright notice for easier
|
|
||||||
identification within third-party archives.
|
|
||||||
|
|
||||||
Copyright [yyyy] [name of copyright owner]
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
@ -25,7 +25,7 @@ training:
|
|||||||
load_pretrain_model: false
|
load_pretrain_model: false
|
||||||
|
|
||||||
# optim:
|
# optim:
|
||||||
lr: 0.000005
|
lr: 0.0002
|
||||||
decay_step: 2000
|
decay_step: 2000
|
||||||
decay_gamma: 0.8
|
decay_gamma: 0.8
|
||||||
weight_decay: 0.0001
|
weight_decay: 0.0001
|
||||||
@ -35,13 +35,11 @@ training:
|
|||||||
dataset:
|
dataset:
|
||||||
types: ['3d'] # ['3d', '2d']
|
types: ['3d'] # ['3d', '2d']
|
||||||
split: 'train'
|
split: 'train'
|
||||||
data_root_path: '/home1/quanquan/datasets/'
|
data_root_path: '/quanquan/datasets/'
|
||||||
dataset_list: ["alp", "word", "debug"] # ['sam', "their", "ours"]
|
dataset_list: ["alp", "word", "debug"] # ['sam', "their", "ours"]
|
||||||
data_txt_path: './datasets/dataset_list/'
|
data_txt_path: './datasets/dataset_list/'
|
||||||
dataset2d_path: "/home1/quanquan/datasets/08_AbdomenCT-1K/"
|
dataset2d_path: "/home1/quanquan/datasets/08_AbdomenCT-1K/"
|
||||||
cache_data_path: '/home1/quanquan/datasets/cached_dataset2/'
|
cache_data_path: '/home1/quanquan/datasets/cached_dataset2/'
|
||||||
cache_prefix: ['6016'] # '07'
|
|
||||||
specific_label: [2]
|
|
||||||
|
|
||||||
# sam_checkpoint: "/quanquan/code/projects/medical-guangdong/segment-anything/sam_vit_b_01ec64.pth" # 103 server
|
# sam_checkpoint: "/quanquan/code/projects/medical-guangdong/segment-anything/sam_vit_b_01ec64.pth" # 103 server
|
||||||
# model_type: "vit_b"
|
# model_type: "vit_b"
|
||||||
|
@ -1,47 +0,0 @@
|
|||||||
# ---------------------- Common Configs --------------------------
|
|
||||||
base:
|
|
||||||
base_dir: "../runs/sam/"
|
|
||||||
tag: ''
|
|
||||||
stage: ''
|
|
||||||
logger:
|
|
||||||
mode: ['tb', ]
|
|
||||||
# mode: ''
|
|
||||||
recorder_reduction: 'mean'
|
|
||||||
|
|
||||||
training:
|
|
||||||
save_mode: ['all', 'best', 'latest'] # ,
|
|
||||||
batch_size : 2 # 8 for A100
|
|
||||||
num_workers : 8
|
|
||||||
num_epochs : 100 # epochs
|
|
||||||
use_amp: false
|
|
||||||
save_interval : 1
|
|
||||||
val_check_interval: 6
|
|
||||||
load_pretrain_model: false
|
|
||||||
|
|
||||||
# optim:
|
|
||||||
lr: 0.00002
|
|
||||||
decay_step: 2000
|
|
||||||
decay_gamma: 0.8
|
|
||||||
weight_decay: 0.0001
|
|
||||||
alpha: 0.99
|
|
||||||
validation_interval: 100
|
|
||||||
|
|
||||||
continue_training: false
|
|
||||||
load_optimizer: false
|
|
||||||
breakpoint_path: "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b9/lora3/ckpt/model_iter_360000.pth"
|
|
||||||
|
|
||||||
dataset:
|
|
||||||
types: ['3d'] # ['3d', '2d']
|
|
||||||
split: 'train'
|
|
||||||
data_root_path: '/home1/quanquan/datasets/'
|
|
||||||
dataset_list: ["example"] # for example_train.txt
|
|
||||||
data_txt_path: './datasets/dataset_list/'
|
|
||||||
dataset2d_path: "/home1/quanquan/datasets/08_AbdomenCT-1K/"
|
|
||||||
cache_data_path: '/home1/quanquan/datasets/cached_dataset2/'
|
|
||||||
|
|
||||||
cache_prefix: ['6016'] # '07'
|
|
||||||
specific_label: [2]
|
|
||||||
|
|
||||||
test:
|
|
||||||
batch_size: 1
|
|
||||||
|
|
BIN
core/__pycache__/__init__.cpython-38.pyc
Normal file
BIN
core/__pycache__/__init__.cpython-38.pyc
Normal file
Binary file not shown.
BIN
core/__pycache__/ddp.cpython-38.pyc
Normal file
BIN
core/__pycache__/ddp.cpython-38.pyc
Normal file
Binary file not shown.
BIN
core/__pycache__/ddp_b10.cpython-38.pyc
Normal file
BIN
core/__pycache__/ddp_b10.cpython-38.pyc
Normal file
Binary file not shown.
BIN
core/__pycache__/learner2.cpython-38.pyc
Normal file
BIN
core/__pycache__/learner2.cpython-38.pyc
Normal file
Binary file not shown.
BIN
core/__pycache__/learner3.cpython-38.pyc
Normal file
BIN
core/__pycache__/learner3.cpython-38.pyc
Normal file
Binary file not shown.
BIN
core/__pycache__/learner5.cpython-38.pyc
Normal file
BIN
core/__pycache__/learner5.cpython-38.pyc
Normal file
Binary file not shown.
BIN
core/__pycache__/lora_sam.cpython-38.pyc
Normal file
BIN
core/__pycache__/lora_sam.cpython-38.pyc
Normal file
Binary file not shown.
BIN
core/__pycache__/loss.cpython-38.pyc
Normal file
BIN
core/__pycache__/loss.cpython-38.pyc
Normal file
Binary file not shown.
122
core/ddp_sub.py
122
core/ddp_sub.py
@ -1,122 +0,0 @@
|
|||||||
"""
|
|
||||||
from ddp_b9.py
|
|
||||||
|
|
||||||
Add additional bypass/side-way to finetune on other datasets
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
|
||||||
import torch.multiprocessing as mp
|
|
||||||
|
|
||||||
from tutils import tfilename, tdir
|
|
||||||
|
|
||||||
from datasets.dataset3d_2dmask import Dataset2D
|
|
||||||
# from datasets.dataset3d import Dataset3D
|
|
||||||
from datasets.cache_dataset3d3 import Dataset3D
|
|
||||||
from datasets.dataset_merged import DatasetMerged, TestsetMerged
|
|
||||||
from datasets.data_engine import DataEngine
|
|
||||||
from modeling.build_sam3d2 import sam_model_registry
|
|
||||||
|
|
||||||
from .learner_sub1 import SamLearner
|
|
||||||
# from tutils.new.trainer.trainer_ddp import DDPTrainer
|
|
||||||
from trans_utils.trainer_ddp import DDPTrainer
|
|
||||||
# from .lora_sam import LoRA_Sam
|
|
||||||
|
|
||||||
import warnings
|
|
||||||
warnings.filterwarnings("ignore")
|
|
||||||
|
|
||||||
|
|
||||||
def setup(rank, world_size):
|
|
||||||
os.environ['MASTER_ADDR'] = 'localhost'
|
|
||||||
os.environ['MASTER_PORT'] = '12355'
|
|
||||||
|
|
||||||
# initialize the process group
|
|
||||||
dist.init_process_group("gloo", rank=rank, world_size=world_size)
|
|
||||||
|
|
||||||
def cleanup():
|
|
||||||
dist.destroy_process_group()
|
|
||||||
|
|
||||||
def ddp_train(rank, world_size, config):
|
|
||||||
setup(rank, world_size)
|
|
||||||
|
|
||||||
# sam_checkpoint = "/quanquan/code/segment-anything/segment_anything/sam_vit_b_01ec64.pth" # A800 server
|
|
||||||
# sam_checkpoint = "/home1/quanquan/code/projects/medical-guangdong/segment-anything/sam_vit_b_01ec64.pth" # 103 server
|
|
||||||
model_type = "vit_b"
|
|
||||||
device = rank
|
|
||||||
|
|
||||||
config_data = config['dataset']
|
|
||||||
data_type = config_data.get("types", ["3d", "2d"])
|
|
||||||
data_type = [data_type] if isinstance(data_type, str) else data_type
|
|
||||||
dataset = Dataset3D(config_data, split='train')
|
|
||||||
|
|
||||||
# assert len(validset) > 0
|
|
||||||
data_engine = DataEngine(dataset=dataset, img_size=(1024,1024))
|
|
||||||
sam = sam_model_registry[model_type](checkpoint=None)
|
|
||||||
|
|
||||||
learner = SamLearner(sam_model=sam, config=config, data_engine=data_engine)
|
|
||||||
learner.use_lora()
|
|
||||||
learner.load_well_trained_model(config['training']['breakpoint_path']) # use preset path
|
|
||||||
learner.use_lora_sub()
|
|
||||||
|
|
||||||
ddp_trainer = DDPTrainer(config=config, rank=rank, world_size=world_size)
|
|
||||||
ddp_trainer.fit(learner, trainset=data_engine, validset=None)
|
|
||||||
|
|
||||||
cleanup()
|
|
||||||
|
|
||||||
|
|
||||||
def get_parameter_number(model):
|
|
||||||
total_num = sum(p.numel() for p in model.parameters())
|
|
||||||
trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
|
||||||
return {'Total': total_num, 'Trainable': trainable_num}
|
|
||||||
|
|
||||||
def run_demo(demo_fn, world_size, config):
|
|
||||||
mp.spawn(demo_fn,
|
|
||||||
args=(world_size,config),
|
|
||||||
nprocs=world_size,
|
|
||||||
join=True)
|
|
||||||
|
|
||||||
from collections import OrderedDict
|
|
||||||
import yaml
|
|
||||||
import yamlloader
|
|
||||||
def _ordereddict_to_dict(d):
|
|
||||||
if not isinstance(d, dict):
|
|
||||||
return d
|
|
||||||
for k, v in d.items():
|
|
||||||
if isinstance(v, OrderedDict):
|
|
||||||
v = _ordereddict_to_dict(v)
|
|
||||||
d[k] = dict(v)
|
|
||||||
elif type(v) == list:
|
|
||||||
d[k] = _ordereddict_to_dict(v)
|
|
||||||
elif isinstance(v, dict):
|
|
||||||
d[k] = _ordereddict_to_dict(v)
|
|
||||||
return d
|
|
||||||
|
|
||||||
# CUDA_VISIBLE_DEVICES=4,5,6,7 python -m core.ddp_b3 --tag lora --config configs/vit_b_103.yaml
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import argparse
|
|
||||||
from tutils.new.manager import trans_args, trans_init, ConfigManager
|
|
||||||
|
|
||||||
n_gpus = torch.cuda.device_count()
|
|
||||||
# assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but {__file__} Got{n_gpus}"
|
|
||||||
if n_gpus == 1:
|
|
||||||
print("Warning! Running on only 1 GPU! just for debug")
|
|
||||||
world_size = n_gpus
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument("--config", default="./configs/vit_sub.yaml")
|
|
||||||
parser.add_argument("--func", default="train")
|
|
||||||
parser.add_argument("--reuse", action="store_true")
|
|
||||||
|
|
||||||
args = trans_args(parser=parser)
|
|
||||||
config = ConfigManager()
|
|
||||||
config.auto_init(file=__file__, args=args, ex_config=None)
|
|
||||||
# config.save()
|
|
||||||
path = tfilename(config['base']['runs_dir'], "config.yaml")
|
|
||||||
with open(path, "w") as f:
|
|
||||||
yaml.dump(_ordereddict_to_dict(config), f)
|
|
||||||
print("Save config file to ", path)
|
|
||||||
|
|
||||||
if n_gpus < 1: exit(0)
|
|
||||||
run_demo(ddp_train, world_size, config)
|
|
@ -107,8 +107,8 @@ class SamLearner(LearnerModule):
|
|||||||
self.model.load_state_dict(state_dict)
|
self.model.load_state_dict(state_dict)
|
||||||
# self.lora_module.load_lora_parameters(pth.replace(".pth", "_lora.safetensors"))
|
# self.lora_module.load_lora_parameters(pth.replace(".pth", "_lora.safetensors"))
|
||||||
|
|
||||||
def use_lora(self, r=8):
|
def use_lora(self):
|
||||||
lora_r = r
|
lora_r = 8
|
||||||
lora_sam = LoRA_Sam(self.model, lora_r, freeze_prompt_encoder=True)
|
lora_sam = LoRA_Sam(self.model, lora_r, freeze_prompt_encoder=True)
|
||||||
self.lora_module = lora_sam
|
self.lora_module = lora_sam
|
||||||
|
|
||||||
|
@ -1,73 +0,0 @@
|
|||||||
"""
|
|
||||||
Use mask_decoder3d_2.py
|
|
||||||
"""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torchvision
|
|
||||||
import numpy as np
|
|
||||||
from tutils.trainer import Trainer, LearnerModule
|
|
||||||
from einops import rearrange, repeat, reduce
|
|
||||||
import torch.optim.lr_scheduler as lr_scheduler
|
|
||||||
|
|
||||||
from core.loss import ranked_combined_loss_with_indicators
|
|
||||||
from .learner5 import SamLearner as basic_learner
|
|
||||||
from .loss import compute_all_loss, ranked_combined_loss, compute_iou, combined_loss
|
|
||||||
# from torchao.quantization import apply_dynamic_quant
|
|
||||||
# from torch._inductor import config as inductorconfig
|
|
||||||
from .lora_sam import LoRA_Sam
|
|
||||||
|
|
||||||
|
|
||||||
class SamLearner(basic_learner):
|
|
||||||
|
|
||||||
def load_pretrained_model(self, pth, *args, **kwargs):
|
|
||||||
"""
|
|
||||||
Unmatched: prompt_encoder.mask_downscaling.0.weight
|
|
||||||
their: torch.Size([4, 1, 2, 2])
|
|
||||||
our: torch.Size([4, 3, 2, 2])
|
|
||||||
Unmatched: mask_decoder.mask_tokens.weight
|
|
||||||
their: torch.Size([4, 256])
|
|
||||||
our: torch.Size([12, 256])
|
|
||||||
"""
|
|
||||||
print("Load pretrained model for mask_decoder3d_2 !!")
|
|
||||||
|
|
||||||
state_dict = torch.load(pth)
|
|
||||||
model_state_dict = self.model.state_dict()
|
|
||||||
model_state_dict.update(state_dict)
|
|
||||||
model_state_dict['prompt_encoder.mask_downscaling.0.weight'] = repeat(state_dict['prompt_encoder.mask_downscaling.0.weight'], "a 1 c d -> a b c d", b=3)
|
|
||||||
# model_state_dict['mask_decoder.mask_tokens.weight'] = repeat(state_dict['mask_decoder.mask_tokens.weight'], "a d -> (a 3) d")
|
|
||||||
|
|
||||||
for k, v in model_state_dict.items():
|
|
||||||
if k.startswith("mask_decoder.output_upscaling2"):
|
|
||||||
k2 = k.replace("output_upscaling2.", "output_upscaling." )
|
|
||||||
model_state_dict[k] = model_state_dict[k2]
|
|
||||||
print("Load weights: ", k)
|
|
||||||
if k.startswith("mask_decoder.output_upscaling3"):
|
|
||||||
k2 = k.replace("output_upscaling3.", "output_upscaling." )
|
|
||||||
model_state_dict[k] = model_state_dict[k2]
|
|
||||||
print("Load weights: ", k)
|
|
||||||
|
|
||||||
hyper_params_names = [k for k in model_state_dict.keys() if k.startswith("mask_decoder.output_hypernetworks_mlps")]
|
|
||||||
for name in hyper_params_names:
|
|
||||||
words = name.split('.')
|
|
||||||
words[2] = str(int(words[2]) // 3)
|
|
||||||
name_to_copy = ".".join(words)
|
|
||||||
model_state_dict[name] = state_dict[name_to_copy]
|
|
||||||
# for k, v in state_dict.items():
|
|
||||||
# if model_state_dict[k].shape != state_dict[k].shape:
|
|
||||||
# print("Unmatched:", k)
|
|
||||||
self.model.load_state_dict(model_state_dict)
|
|
||||||
|
|
||||||
def quantize(self):
|
|
||||||
# self.model.image_encoder = torch.ao.quantization.quantize_dynamic(
|
|
||||||
# self.model.image_encoder,
|
|
||||||
# dtype=torch.qint8
|
|
||||||
# )
|
|
||||||
apply_dynamic_quant(self.model.image_encoder)
|
|
||||||
inductorconfig.force_fuse_int_mm_with_mul = True
|
|
||||||
print("Quantized !")
|
|
||||||
|
|
||||||
|
|
||||||
def use_lora_sub(self):
|
|
||||||
lora_r = 1
|
|
||||||
lora_sam = LoRA_Sam(self.model, lora_r, freeze_all=True)
|
|
||||||
self.lora_module = lora_sam
|
|
@ -33,7 +33,6 @@ class _LoRA_qkv(nn.Module):
|
|||||||
self.linear_a_v = linear_a_v
|
self.linear_a_v = linear_a_v
|
||||||
self.linear_b_v = linear_b_v
|
self.linear_b_v = linear_b_v
|
||||||
self.dim = qkv.in_features
|
self.dim = qkv.in_features
|
||||||
self.in_features = qkv.in_features
|
|
||||||
self.w_identity = torch.eye(qkv.in_features)
|
self.w_identity = torch.eye(qkv.in_features)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
@ -191,7 +191,7 @@ class VolumePredictor:
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
# Preprocess prompts
|
# Preprocess prompts
|
||||||
# self.original_size = x.shape[1:]
|
self.original_size = x.shape[1:]
|
||||||
if point_coords is not None:
|
if point_coords is not None:
|
||||||
assert (
|
assert (
|
||||||
point_labels is not None
|
point_labels is not None
|
||||||
@ -217,7 +217,6 @@ class VolumePredictor:
|
|||||||
center_masks = self._predict_center_slice(center_idx, point_coords, box)
|
center_masks = self._predict_center_slice(center_idx, point_coords, box)
|
||||||
return center_masks['masks']
|
return center_masks['masks']
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def predict_volume(
|
def predict_volume(
|
||||||
self,
|
self,
|
||||||
x,
|
x,
|
||||||
@ -256,37 +255,8 @@ class VolumePredictor:
|
|||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
# set 3d image
|
|
||||||
self.set_image(x)
|
|
||||||
# Preprocess prompts
|
# Preprocess prompts
|
||||||
self.original_size = x.shape[1:]
|
self.original_size = x.shape[1:]
|
||||||
if self.masks3d is None:
|
|
||||||
self.masks3d = np.zeros_like(x)
|
|
||||||
self.slice_count = x.shape[0]
|
|
||||||
return self.predict_with_prompt(
|
|
||||||
point_coords = point_coords,
|
|
||||||
point_labels = point_labels,
|
|
||||||
box = box,
|
|
||||||
mask_input = mask_input,
|
|
||||||
multimask_output = multimask_output,
|
|
||||||
return_logits = return_logits,
|
|
||||||
template_slice_id = template_slice_id,
|
|
||||||
return_stability = return_stability
|
|
||||||
)
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def predict_with_prompt(
|
|
||||||
self,
|
|
||||||
point_coords: Optional[np.ndarray] = None,
|
|
||||||
point_labels: Optional[np.ndarray] = None,
|
|
||||||
box: Optional[np.ndarray] = None,
|
|
||||||
mask_input: Optional[np.ndarray] = None,
|
|
||||||
multimask_output: bool = True,
|
|
||||||
return_logits: bool = False,
|
|
||||||
template_slice_id:int = None,
|
|
||||||
return_stability: bool = False,
|
|
||||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
|
||||||
|
|
||||||
if point_coords is not None:
|
if point_coords is not None:
|
||||||
assert (
|
assert (
|
||||||
point_labels is not None
|
point_labels is not None
|
||||||
@ -303,39 +273,38 @@ class VolumePredictor:
|
|||||||
mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device)
|
mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device)
|
||||||
mask_input_torch = mask_input_torch[None, :, :, :]
|
mask_input_torch = mask_input_torch[None, :, :, :]
|
||||||
|
|
||||||
self.all_prompts = {}
|
# set 3d image
|
||||||
|
self.set_image(x)
|
||||||
|
|
||||||
# predict center slice
|
# predict center slice
|
||||||
center_idx = template_slice_id if template_slice_id is not None else self.slice_count // 2
|
center_idx = template_slice_id if template_slice_id is not None else x.shape[0] // 2
|
||||||
# print("Processing ", center_idx)
|
# print("Processing ", center_idx)
|
||||||
center_masks = self._predict_center_slice(center_idx, point_coords, box)
|
center_masks = self._predict_center_slice(center_idx, point_coords, box)
|
||||||
if center_masks._stats == {}:
|
if center_masks._stats == {}:
|
||||||
print("Ends for no mask.")
|
print("Ends for no mask.")
|
||||||
raise ValueError
|
raise ValueError
|
||||||
self.merge_to_mask3d(center_idx, center_masks)
|
self.merge_to_mask3d(center_idx, center_masks)
|
||||||
center_idx = center_idx.item() if not isinstance(center_idx, int) else center_idx
|
|
||||||
self.all_prompts[center_idx] = box if box is not None else point_coords
|
|
||||||
|
|
||||||
previous_masks = center_masks
|
previous_masks = center_masks
|
||||||
for i in range(center_idx+1, self.slice_count-1):
|
for i in range(center_idx+1, x.shape[0]-1):
|
||||||
# print("Processing downward", i)
|
# print("Processing downward", i)
|
||||||
previous_masks, scaled_boxes = self._predict_slice(i, previous_masks, orientation="down")
|
previous_masks = self._predict_slice(i, previous_masks, orientation="down")
|
||||||
if previous_masks._stats == {}:
|
if previous_masks._stats == {}:
|
||||||
print("Ends for no mask.")
|
print("Ends for no mask.")
|
||||||
break
|
break
|
||||||
self.merge_to_mask3d(i, previous_masks)
|
self.merge_to_mask3d(i, previous_masks)
|
||||||
self.all_prompts[i] = scaled_boxes
|
|
||||||
|
|
||||||
previous_masks = center_masks
|
previous_masks = center_masks
|
||||||
for i in np.arange(1, center_idx)[::-1]:
|
for i in np.arange(1, center_idx)[::-1]:
|
||||||
# print("Processing upward", i)
|
# print("Processing upward", i)
|
||||||
previous_masks, scaled_boxes = self._predict_slice(i, previous_masks, orientation="up")
|
previous_masks = self._predict_slice(i, previous_masks, orientation="up")
|
||||||
if previous_masks._stats == {}:
|
if previous_masks._stats == {}:
|
||||||
print("Ends for no mask.")
|
print("Ends for no mask.")
|
||||||
break
|
break
|
||||||
self.merge_to_mask3d(i, previous_masks)
|
self.merge_to_mask3d(i, previous_masks)
|
||||||
self.all_prompts[i] = scaled_boxes
|
|
||||||
|
|
||||||
|
if self.masks3d is None:
|
||||||
|
self.masks3d = np.zeros_like(x)
|
||||||
if return_stability:
|
if return_stability:
|
||||||
return self.postprocess_3d(self.masks3d), self.stability_score_2d
|
return self.postprocess_3d(self.masks3d), self.stability_score_2d
|
||||||
return self.postprocess_3d(self.masks3d)
|
return self.postprocess_3d(self.masks3d)
|
||||||
@ -355,7 +324,7 @@ class VolumePredictor:
|
|||||||
scaled_boxes, tags = self.generate_prompts_from_previous_masks(previous_masks, orientation)
|
scaled_boxes, tags = self.generate_prompts_from_previous_masks(previous_masks, orientation)
|
||||||
masks = self.genetate_masks_from_boxes(idx, all_boxes=scaled_boxes, tags=tags)
|
masks = self.genetate_masks_from_boxes(idx, all_boxes=scaled_boxes, tags=tags)
|
||||||
masks.to_numpy()
|
masks.to_numpy()
|
||||||
return masks, scaled_boxes
|
return masks
|
||||||
|
|
||||||
def generate_prompts_from_previous_masks(self, previous_masks: MaskData, orientation):
|
def generate_prompts_from_previous_masks(self, previous_masks: MaskData, orientation):
|
||||||
if orientation == "down":
|
if orientation == "down":
|
||||||
@ -517,6 +486,10 @@ class VolumePredictor:
|
|||||||
data["boxes"] = batched_mask_to_box(data["masks"][:,1,:,:]>0)
|
data["boxes"] = batched_mask_to_box(data["masks"][:,1,:,:]>0)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
# @staticmethod
|
||||||
|
# def calculate_
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def batched_remove_noise(masks):
|
def batched_remove_noise(masks):
|
||||||
ori_shape = masks.shape
|
ori_shape = masks.shape
|
||||||
@ -566,80 +539,26 @@ class VolumePredictor:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Upscale the masks to the original image resolution
|
# Upscale the masks to the original image resolution
|
||||||
masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)
|
masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size[1:])
|
||||||
|
|
||||||
if not return_logits:
|
if not return_logits:
|
||||||
masks = masks > self.model.mask_threshold
|
masks = masks > self.model.mask_threshold
|
||||||
|
|
||||||
return masks, iou_predictions, low_res_masks
|
return masks, iou_predictions, low_res_masks
|
||||||
|
|
||||||
# def valid_box(self, data, batch_idx):
|
def valid_box(self, data, batch_idx):
|
||||||
# # valid image with box, or point prompt
|
# valid image with box, or point prompt
|
||||||
# assert data['img'].shape[0] == 1, f"shape {data['img'].shape}"
|
assert data['img'].shape[0] == 1, f"shape {data['img'].shape}"
|
||||||
# image = data['img']
|
image = data['img']
|
||||||
# label = data['label']
|
label = data['label']
|
||||||
|
|
||||||
# box = BoxPromptGenerator().mask_to_bbox(label)
|
box = BoxPromptGenerator().mask_to_bbox(label)
|
||||||
# box_mask3d = self.predict_volume(
|
box_mask3d = self.predict_volume(
|
||||||
# x=image,
|
x=image,
|
||||||
# box=box,
|
box=box,
|
||||||
# )
|
)
|
||||||
# dice = compute_dice_np(box_mask3d, label.detach().cpu().numpy())
|
dice = compute_dice_np(box_mask3d, label.detach().cpu().numpy())
|
||||||
|
|
||||||
def get_confidence(self):
|
|
||||||
masks = self.postprocess_3d(self.masks3d)
|
|
||||||
conf_collect = []
|
|
||||||
for i in range(1,self.masks3d.shape[0]-1):
|
|
||||||
prompt_box = self.all_prompts.get(i, None)
|
|
||||||
if prompt_box is not None:
|
|
||||||
mask = masks[i,:,:]
|
|
||||||
if mask.sum() > 0:
|
|
||||||
bbox = BoxPromptGenerator(size=None).mask_to_bbox(mask)
|
|
||||||
bbox = self.transform.apply_boxes(np.array([bbox]), self.original_size)[0]
|
|
||||||
else:
|
|
||||||
bbox = [0,0,0,0]
|
|
||||||
prompt_box = self.all_prompts[i][0]
|
|
||||||
confidence = calculate_iou(bbox, prompt_box)
|
|
||||||
else:
|
|
||||||
confidence = 0
|
|
||||||
if i == 1:
|
|
||||||
conf_collect.append(confidence)
|
|
||||||
conf_collect.append(confidence)
|
|
||||||
assert len(conf_collect) == i+1
|
|
||||||
conf_collect.append(confidence)
|
|
||||||
print(conf_collect)
|
|
||||||
return conf_collect
|
|
||||||
|
|
||||||
def calculate_iou(box1, box2):
|
|
||||||
"""
|
|
||||||
计算两个框的IoU Intersection over Union。
|
|
||||||
|
|
||||||
参数:
|
|
||||||
box1 和 box2 是两个框,每个框表示为四个值 (x1, y1, x2, y2),其中 (x1, y1) 是左上角的坐标,
|
|
||||||
(x2, y2) 是右下角的坐标。
|
|
||||||
|
|
||||||
返回:
|
|
||||||
返回两个框的IoU。
|
|
||||||
"""
|
|
||||||
# 计算交集的左上角和右下角坐标
|
|
||||||
x1_i = max(box1[0], box2[0])
|
|
||||||
y1_i = max(box1[1], box2[1])
|
|
||||||
x2_i = min(box1[2], box2[2])
|
|
||||||
y2_i = min(box1[3], box2[3])
|
|
||||||
|
|
||||||
# 计算交集的面积
|
|
||||||
intersection_area = max(0, x2_i - x1_i) * max(0, y2_i - y1_i)
|
|
||||||
|
|
||||||
# 计算并集的面积
|
|
||||||
box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
|
|
||||||
box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
|
|
||||||
union_area = box1_area + box2_area - intersection_area
|
|
||||||
# print(intersection_area, union_area)
|
|
||||||
|
|
||||||
# 计算IoU
|
|
||||||
iou = intersection_area / union_area
|
|
||||||
|
|
||||||
return iou
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@ -669,7 +588,7 @@ if __name__ == "__main__":
|
|||||||
volume = itk_to_np(read(img_path)) # test several slices
|
volume = itk_to_np(read(img_path)) # test several slices
|
||||||
label_itk = read(label_path)
|
label_itk = read(label_path)
|
||||||
spacing = label_itk.GetSpacing()
|
spacing = label_itk.GetSpacing()
|
||||||
label = itk_to_np(label_itk) == 13
|
label = itk_to_np(label_itk) == 1
|
||||||
volume = np.clip(volume, -200, 400)
|
volume = np.clip(volume, -200, 400)
|
||||||
|
|
||||||
# Select the slice with the largest mask
|
# Select the slice with the largest mask
|
||||||
@ -679,14 +598,8 @@ if __name__ == "__main__":
|
|||||||
x_max = np.max(coords[0])
|
x_max = np.max(coords[0])
|
||||||
template_slice_id = s.argmax()
|
template_slice_id = s.argmax()
|
||||||
|
|
||||||
box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id]) # (115, 207, 309, 339)
|
box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id])
|
||||||
# import ipdb; ipdb.set_trace()
|
|
||||||
# box = (125, 210, 300, 310)
|
|
||||||
box = np.array([box])
|
box = np.array([box])
|
||||||
box[0][0] += 10
|
|
||||||
box[0][1] += 10
|
|
||||||
box[0][2] -= 10
|
|
||||||
box[0][3] -= 10
|
|
||||||
|
|
||||||
pred = predictor.predict_volume(
|
pred = predictor.predict_volume(
|
||||||
x=volume,
|
x=volume,
|
||||||
@ -697,8 +610,3 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
Data3dSolver().simple_write(pred, path="mask.nii.gz", spacing=spacing)
|
Data3dSolver().simple_write(pred, path="mask.nii.gz", spacing=spacing)
|
||||||
Data3dSolver().simple_write(label, path="gt.nii.gz", spacing=spacing)
|
Data3dSolver().simple_write(label, path="gt.nii.gz", spacing=spacing)
|
||||||
|
|
||||||
dice = compute_dice_np(pred, label)
|
|
||||||
print("Dice ", dice, " box: ", box, "slice id", template_slice_id)
|
|
||||||
print(tuple(box))
|
|
||||||
# import ipdb; ipdb.set_trace()
|
|
BIN
datasets/__pycache__/__init__.cpython-38.pyc
Normal file
BIN
datasets/__pycache__/__init__.cpython-38.pyc
Normal file
Binary file not shown.
BIN
datasets/__pycache__/cache_dataset3d.cpython-38.pyc
Normal file
BIN
datasets/__pycache__/cache_dataset3d.cpython-38.pyc
Normal file
Binary file not shown.
BIN
datasets/__pycache__/cache_dataset3d3.cpython-38.pyc
Normal file
BIN
datasets/__pycache__/cache_dataset3d3.cpython-38.pyc
Normal file
Binary file not shown.
BIN
datasets/__pycache__/data_engine.cpython-38.pyc
Normal file
BIN
datasets/__pycache__/data_engine.cpython-38.pyc
Normal file
Binary file not shown.
BIN
datasets/__pycache__/dataset3d.cpython-38.pyc
Normal file
BIN
datasets/__pycache__/dataset3d.cpython-38.pyc
Normal file
Binary file not shown.
BIN
datasets/__pycache__/dataset3d_2dmask.cpython-38.pyc
Normal file
BIN
datasets/__pycache__/dataset3d_2dmask.cpython-38.pyc
Normal file
Binary file not shown.
BIN
datasets/__pycache__/dataset_merged.cpython-38.pyc
Normal file
BIN
datasets/__pycache__/dataset_merged.cpython-38.pyc
Normal file
Binary file not shown.
@ -43,14 +43,12 @@ TEMPLATE={
|
|||||||
'10_10': [31],
|
'10_10': [31],
|
||||||
'58': [6,2,3,1],
|
'58': [6,2,3,1],
|
||||||
'59': [1,2,3,4,5,6,7,8,9,10,11,12,13,14],
|
'59': [1,2,3,4,5,6,7,8,9,10,11,12,13,14],
|
||||||
'60': (np.ones(200)*(-1)).tolist(), # for debug
|
'60': np.arange(200).tolist(), # for debug
|
||||||
"65": np.zeros(200).tolist(),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class Dataset3D(basic_3d_dataset):
|
class Dataset3D(basic_3d_dataset):
|
||||||
def __init__(self, config, use_cache=True, *args, **kwargs) -> None:
|
def __init__(self, config=..., use_cache=True, *args, **kwargs) -> None:
|
||||||
super().__init__(config=config, use_cache=use_cache, *args, **kwargs)
|
super().__init__(config, use_cache=use_cache, *args, **kwargs)
|
||||||
self.basic_dir = config['data_root_path']
|
self.basic_dir = config['data_root_path']
|
||||||
self.cache_dir = config['cache_data_path']
|
self.cache_dir = config['cache_data_path']
|
||||||
|
|
||||||
@ -95,13 +93,6 @@ class Dataset3D(basic_3d_dataset):
|
|||||||
|
|
||||||
dataset_name = self.img_names[index]['img_path'].replace(self.basic_dir,"").split("/")[0]
|
dataset_name = self.img_names[index]['img_path'].replace(self.basic_dir,"").split("/")[0]
|
||||||
assert dataset_name[0] in ['0','1','2','3','4','5','6','7','8','9'], f"Got {dataset_name}"
|
assert dataset_name[0] in ['0','1','2','3','4','5','6','7','8','9'], f"Got {dataset_name}"
|
||||||
if dataset_name[:2] == "10":
|
|
||||||
subname = self.img_names[index]['img_path'].replace(self.basic_dir,"")[17:19]
|
|
||||||
assert subname in ['10', '03', '06', '07']
|
|
||||||
all_labels = TEMPLATE[dataset_name[:2] + "_" + subname]
|
|
||||||
else:
|
|
||||||
all_labels = TEMPLATE[dataset_name[:2]]
|
|
||||||
|
|
||||||
all_labels = TEMPLATE[dataset_name[:2]]
|
all_labels = TEMPLATE[dataset_name[:2]]
|
||||||
|
|
||||||
num = 0
|
num = 0
|
||||||
@ -161,7 +152,7 @@ class Dataset3D(basic_3d_dataset):
|
|||||||
self.save_img_rgb(rearrange(img_rgb, "c h w -> h w c"), save_image_name)
|
self.save_img_rgb(rearrange(img_rgb, "c h w -> h w c"), save_image_name)
|
||||||
|
|
||||||
# Save cache data
|
# Save cache data
|
||||||
save_label_name = tfilename(self.cache_dir, dataset_name, f"label_jpg/label_{index:04d}_{num:08d}")
|
save_label_name = tfilename(self.cache_dir, dataset_name, f"label/label_{index:04d}_{num:08d}.npz")
|
||||||
self.save_slice_mask(masks_data, save_label_name)
|
self.save_slice_mask(masks_data, save_label_name)
|
||||||
print("Save ", save_image_name)
|
print("Save ", save_image_name)
|
||||||
|
|
||||||
@ -266,26 +257,11 @@ class Dataset3D(basic_3d_dataset):
|
|||||||
label_path = name.replace("image/image_", "label/label_").replace(".jpg", ".npz")
|
label_path = name.replace("image/image_", "label/label_").replace(".jpg", ".npz")
|
||||||
self._convert_one_mask_from_npz_to_jpg(label_path)
|
self._convert_one_mask_from_npz_to_jpg(label_path)
|
||||||
|
|
||||||
|
|
||||||
# EX_CONFIG={
|
|
||||||
# "dataset":{
|
|
||||||
# "split": 'train',
|
|
||||||
# "data_root_path": '/home1/quanquan/datasets/',
|
|
||||||
# "dataset_list": ["decathlon_colon"],
|
|
||||||
# "data_txt_path": './datasets/dataset_list/',
|
|
||||||
# "cache_data_path": '/home1/quanquan/datasets/cached_dataset2/',
|
|
||||||
# "cache_prefix": ['10'] # '07'
|
|
||||||
# }
|
|
||||||
# }
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# def go_cache():
|
# def go_cache():
|
||||||
from tutils.new.manager import ConfigManager
|
from tutils.new.manager import ConfigManager
|
||||||
config = ConfigManager()
|
config = ConfigManager()
|
||||||
config.add_config("configs/vit_sub.yaml")
|
config.add_config("configs/vit_b.yaml")
|
||||||
# config.add_config(EX_CONFIG)
|
dataset = Dataset3D(config=config['dataset'], use_cache=True)
|
||||||
|
|
||||||
# Caching data
|
|
||||||
dataset = Dataset3D(config=config['dataset'], use_cache=False)
|
|
||||||
dataset.caching_data()
|
dataset.caching_data()
|
||||||
# dataset.convert_masks_types()
|
# dataset.convert_masks_types()
|
||||||
|
@ -33,7 +33,7 @@ class Dataset3D(basic_3d_dataset):
|
|||||||
if not os.path.isdir(dirpath):
|
if not os.path.isdir(dirpath):
|
||||||
continue
|
continue
|
||||||
prefix = dirpath.split("/")[-1]
|
prefix = dirpath.split("/")[-1]
|
||||||
if prefix.split("_")[0] in config['cache_prefix']:
|
if prefix[:2] in config['cache_prefix']:
|
||||||
data_paths += glob.glob(dirpath + "/label_jpg/*.jpg")
|
data_paths += glob.glob(dirpath + "/label_jpg/*.jpg")
|
||||||
print("Load ", dirpath)
|
print("Load ", dirpath)
|
||||||
print('Masks len {}'.format(len(data_paths)))
|
print('Masks len {}'.format(len(data_paths)))
|
||||||
|
@ -1,2 +0,0 @@
|
|||||||
07_WORD/WORD-V0.1.0/imagesVa/word_0001.nii.gz 07_WORD/WORD-V0.1.0/labelsVa/word_0001.nii.gz
|
|
||||||
07_WORD/WORD-V0.1.0/imagesVa/word_0007.nii.gz 07_WORD/WORD-V0.1.0/labelsVa/word_0007.nii.gz
|
|
@ -1,146 +0,0 @@
|
|||||||
"""
|
|
||||||
DataLoader only for evaluation
|
|
||||||
"""
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from torch.utils.data import DataLoader, Dataset
|
|
||||||
from tutils.nn.data import read, itk_to_np, np_to_itk
|
|
||||||
from einops import reduce, repeat, rearrange
|
|
||||||
# from tutils.nn.data.tsitk.preprocess import resampleImage
|
|
||||||
from trans_utils.data_utils import Data3dSolver
|
|
||||||
from tutils.nn.data.tsitk.preprocess import resampleImage
|
|
||||||
import SimpleITK as sitk
|
|
||||||
|
|
||||||
|
|
||||||
# Example
|
|
||||||
DATASET_CONFIG={
|
|
||||||
'split': 'test',
|
|
||||||
'data_root_path':'/quanquan/datasets/',
|
|
||||||
'dataset_list': ["ours"],
|
|
||||||
'data_txt_path':'./datasets/dataset_list/',
|
|
||||||
'label_idx': 0,
|
|
||||||
}
|
|
||||||
|
|
||||||
class AbstractLoader(Dataset):
|
|
||||||
def __init__(self, config, split="test") -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.config = config
|
|
||||||
self.split = split
|
|
||||||
self.img_names = self.prepare_datalist()
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.img_names)
|
|
||||||
|
|
||||||
def prepare_datalist(self):
|
|
||||||
config = self.config
|
|
||||||
data_paths = []
|
|
||||||
for item in config['dataset_list']:
|
|
||||||
print("Load datalist from ", item)
|
|
||||||
for line in open(config["data_txt_path"]+ item + f"_{self.split}.txt"):
|
|
||||||
name = line.strip().split()[1].split('.')[0]
|
|
||||||
img_path = config['data_root_path'] + line.strip().split()[0]
|
|
||||||
label_path = config['data_root_path'] + line.strip().split()[1]
|
|
||||||
data_paths.append({'img_path': img_path, 'label_path': label_path, 'name': name})
|
|
||||||
print('train len {}'.format(len(data_paths)))
|
|
||||||
return data_paths
|
|
||||||
|
|
||||||
|
|
||||||
def _get_data(self, index, debug=False):
|
|
||||||
label_idx = self.config['label_idx']
|
|
||||||
name = self.img_names[index]['img_path']
|
|
||||||
img_itk = read(name)
|
|
||||||
spacing = img_itk.GetSpacing()
|
|
||||||
img_ori = itk_to_np(img_itk)
|
|
||||||
scan_orientation = np.argmin(img_ori.shape)
|
|
||||||
label_ori = itk_to_np(read(self.img_names[index]['label_path']))
|
|
||||||
label = label_ori == label_idx
|
|
||||||
|
|
||||||
# img_ori, new_spacing = Data3dSolver().read(self.img_names[index]['img_path'])
|
|
||||||
# label_itk = read(self.img_names[index]['label_path'])
|
|
||||||
# ori_spacing = label_itk.GetSpacing()
|
|
||||||
# label = itk_to_np(label_itk) == label_idx
|
|
||||||
# print("[loader_abstract.DEBUG] size", img_ori.shape, label.shape)
|
|
||||||
# label = self._get_resized_label(label, new_size=img_ori.shape)
|
|
||||||
|
|
||||||
if debug:
|
|
||||||
Data3dSolver().simple_write(label)
|
|
||||||
Data3dSolver().simple_write(img_ori, "tmp_img.nii.gz")
|
|
||||||
|
|
||||||
s = reduce(label, "c h w -> c", reduction="sum")
|
|
||||||
coords = np.nonzero(s)
|
|
||||||
x_min = np.min(coords[0])
|
|
||||||
x_max = np.max(coords[0])
|
|
||||||
template_slice_id = s.argmax() - x_min
|
|
||||||
|
|
||||||
if img_ori.min() < -10:
|
|
||||||
img_ori = np.clip(img_ori, -200, 400)
|
|
||||||
else:
|
|
||||||
img_ori = np.clip(img_ori, 0, 600)
|
|
||||||
|
|
||||||
img_ori = img_ori[x_min:x_max+1,:,:]
|
|
||||||
label = label[x_min:x_max+1,:,:]
|
|
||||||
assert label.shape[0] >= 3
|
|
||||||
|
|
||||||
if template_slice_id <= 1 or template_slice_id >= label.shape[0]-2:
|
|
||||||
template_slice_id == label.shape[0] // 2
|
|
||||||
|
|
||||||
dataset_name = name.replace(self.config['data_root_path'], "").split("/")[0]
|
|
||||||
template_slice = label[template_slice_id,:,:]
|
|
||||||
print("template_slice.area ", template_slice.sum(), template_slice.sum() / (template_slice.shape[0] * template_slice.shape[1]))
|
|
||||||
d = {
|
|
||||||
"name": name,
|
|
||||||
"dataset_name": dataset_name,
|
|
||||||
"img": np.array(img_ori).astype(np.float32),
|
|
||||||
"label_idx": label_idx,
|
|
||||||
"label": np.array(label).astype(np.float32),
|
|
||||||
"template_slice_id": template_slice_id,
|
|
||||||
"template_slice": np.array(label[template_slice_id,:,:]).astype(np.float32),
|
|
||||||
"spacing": np.array(spacing),
|
|
||||||
}
|
|
||||||
return d
|
|
||||||
|
|
||||||
def __getitem__(self, index):
|
|
||||||
return self._get_data(index)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
from tutils.new.manager import ConfigManager
|
|
||||||
EX_CONFIG = {
|
|
||||||
'dataset':{
|
|
||||||
'prompt': 'box',
|
|
||||||
'dataset_list': ['guangdong'], # ["sabs"], chaos, word
|
|
||||||
'label_idx': 2,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
config = ConfigManager()
|
|
||||||
config.add_config("configs/vit_sub_rectum.yaml")
|
|
||||||
config.add_config(EX_CONFIG)
|
|
||||||
dataset = AbstractLoader(config['dataset'], split="test")
|
|
||||||
for i in range(len(dataset)):
|
|
||||||
dataset._get_data(i, debug=False)
|
|
||||||
|
|
||||||
# label_path = "/home1/quanquan/datasets/01_BCV-Abdomen/Training/label/label0001.nii.gz"
|
|
||||||
# from monai.transforms import SpatialResample
|
|
||||||
# resample = SpatialResample()
|
|
||||||
# label = itk_to_np(read(label_path)) == 1
|
|
||||||
# print(label.shape)
|
|
||||||
# # resampled = resample(label, spatial_size=(label.shape[0]*7, label.shape[1], label.shape[2]))
|
|
||||||
# print(label.shape)
|
|
||||||
|
|
||||||
exit(0)
|
|
||||||
data = itk_to_np(read("tmp_img.nii.gz"))
|
|
||||||
data = torch.Tensor(data)
|
|
||||||
|
|
||||||
maxlen = data.shape[0]
|
|
||||||
slices = []
|
|
||||||
for i in range(1, maxlen-1):
|
|
||||||
slices.append(data[i-1:i+2, :, :])
|
|
||||||
input_slices = torch.stack(slices, axis=0)
|
|
||||||
input_slices = torch.clip(input_slices, -200, 600)
|
|
||||||
|
|
||||||
input_slices
|
|
||||||
|
|
||||||
from torchvision.utils import save_image
|
|
||||||
save_image(input_slices, "tmp.jpg")
|
|
BIN
modeling/__pycache__/__init__.cpython-38.pyc
Normal file
BIN
modeling/__pycache__/__init__.cpython-38.pyc
Normal file
Binary file not shown.
BIN
modeling/__pycache__/build_sam3d2.cpython-38.pyc
Normal file
BIN
modeling/__pycache__/build_sam3d2.cpython-38.pyc
Normal file
Binary file not shown.
BIN
modeling/__pycache__/common.cpython-38.pyc
Normal file
BIN
modeling/__pycache__/common.cpython-38.pyc
Normal file
Binary file not shown.
BIN
modeling/__pycache__/image_encoder.cpython-38.pyc
Normal file
BIN
modeling/__pycache__/image_encoder.cpython-38.pyc
Normal file
Binary file not shown.
BIN
modeling/__pycache__/mask_decoder3d_2.cpython-38.pyc
Normal file
BIN
modeling/__pycache__/mask_decoder3d_2.cpython-38.pyc
Normal file
Binary file not shown.
BIN
modeling/__pycache__/prompt_encoder3d.cpython-38.pyc
Normal file
BIN
modeling/__pycache__/prompt_encoder3d.cpython-38.pyc
Normal file
Binary file not shown.
BIN
modeling/__pycache__/sam3d.cpython-38.pyc
Normal file
BIN
modeling/__pycache__/sam3d.cpython-38.pyc
Normal file
Binary file not shown.
BIN
modeling/__pycache__/transformer.cpython-38.pyc
Normal file
BIN
modeling/__pycache__/transformer.cpython-38.pyc
Normal file
Binary file not shown.
63
readme.md
63
readme.md
@ -118,29 +118,9 @@ Then run
|
|||||||
python -m datasets.cache_dataset3d
|
python -m datasets.cache_dataset3d
|
||||||
```
|
```
|
||||||
|
|
||||||
## Configs Settings
|
|
||||||
|
|
||||||
important settings
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
base:
|
|
||||||
base_dir: "../runs/sam/" # logging dir
|
|
||||||
|
|
||||||
dataset:
|
|
||||||
types: ['3d'] # ['3d', '2d']
|
|
||||||
split: 'train'
|
|
||||||
data_root_path: '../datasets/'
|
|
||||||
dataset_list: ["pancreas"]
|
|
||||||
data_txt_path: './datasets/dataset_list/'
|
|
||||||
dataset2d_path: "../08_AbdomenCT-1K/"
|
|
||||||
cache_data_path: '../cached_dataset2/'
|
|
||||||
|
|
||||||
cache_prefix: ['6016'] # cache prefix of cached dataset for training
|
|
||||||
# For example: ['07',] for 07_WORD
|
|
||||||
```
|
|
||||||
|
|
||||||
|
|
||||||
## Start Training from scratch (SAM)
|
## Start Training
|
||||||
|
|
||||||
Run training on multi-gpu
|
Run training on multi-gpu
|
||||||
|
|
||||||
@ -160,47 +140,6 @@ CUDA_VISIBLE_DEVICES=0 python -m core.ddp --tag debug
|
|||||||
python -m core.volume_predictor
|
python -m core.volume_predictor
|
||||||
```
|
```
|
||||||
|
|
||||||
## Testset Validation
|
|
||||||
```python
|
|
||||||
EX_CONFIG = {
|
|
||||||
'dataset':{
|
|
||||||
'prompt': 'box', # prompt type: box or point
|
|
||||||
'dataset_list': ['word'], # dataset_list name
|
|
||||||
'label_idx': 2, # label index for inference,
|
|
||||||
},
|
|
||||||
"pth": "./model.pth"
|
|
||||||
}
|
|
||||||
```
|
|
||||||
```
|
|
||||||
python -m test.volume_eval
|
|
||||||
```
|
|
||||||
|
|
||||||
## Finetuning (Recommended)
|
|
||||||
```yaml
|
|
||||||
training:
|
|
||||||
breakpoint_path: "./model.pth" # pretrained weight path
|
|
||||||
```
|
|
||||||
|
|
||||||
```
|
|
||||||
python -m core.ddp_sub --tag run
|
|
||||||
```
|
|
||||||
|
|
||||||
## Validation with Finetuned Weights
|
|
||||||
|
|
||||||
```
|
|
||||||
python -m test.volume_eval_sublora
|
|
||||||
```
|
|
||||||
|
|
||||||
```python
|
|
||||||
EX_CONFIG = {
|
|
||||||
'dataset':{
|
|
||||||
'prompt': 'box', # prompt type: box or point
|
|
||||||
'dataset_list': ['word'], # dataset_list name
|
|
||||||
'label_idx': 2, # label index for inference,
|
|
||||||
},
|
|
||||||
"pth": "./model_finetuned.pth"
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
|
|
||||||
<p align="center" width="100%">
|
<p align="center" width="100%">
|
||||||
|
@ -137,35 +137,31 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
EX_CONFIG = {
|
EX_CONFIG = {
|
||||||
'dataset':{
|
'dataset':{
|
||||||
'prompt': 'box', # box / point
|
'prompt': 'box',
|
||||||
'dataset_list': ['example'], # ["sabs"], chaos, word
|
'dataset_list': ['word'], # ["sabs"], chaos, word
|
||||||
'label_idx': 5,
|
'label_idx': 1,
|
||||||
},
|
}
|
||||||
'model_type': "vit_b",
|
|
||||||
'pth': "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b9/lora3/ckpt/model_iter_360000.pth",
|
|
||||||
}
|
|
||||||
|
|
||||||
EX_CONFIG = {
|
|
||||||
'dataset':{
|
|
||||||
'prompt': 'box', # box / point
|
|
||||||
'dataset_list': ['example'], # ["sabs"], chaos, word
|
|
||||||
'label_idx': 5,
|
|
||||||
},
|
|
||||||
'model_type': "vit_h",
|
|
||||||
'pth': "/home1/quanquan/code/projects/finetune_large/segment_anything/model_iter_3935000.pth",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
config = ConfigManager()
|
config = ConfigManager()
|
||||||
config.add_config("configs/vit_b.yaml")
|
config.add_config("configs/vit_b_103.yaml")
|
||||||
config.add_config(EX_CONFIG)
|
config.add_config(EX_CONFIG)
|
||||||
config.print()
|
|
||||||
|
|
||||||
# Init Model
|
# Init Model
|
||||||
model_type = config['model_type']
|
model_type = "vit_b"
|
||||||
sam = sam_model_registry[model_type](checkpoint=None)
|
sam = sam_model_registry[model_type](checkpoint=None)
|
||||||
learner = SamLearner(sam_model=sam, config=config, data_engine=DataManager(img_size=(1024,1024)))
|
learner = SamLearner(sam_model=sam, config=config, data_engine=DataManager(img_size=(1024,1024)))
|
||||||
learner.use_lora()
|
learner.use_lora()
|
||||||
learner.load_well_trained_model(EX_CONFIG['pth'])
|
# pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b3/lora+edge2/ckpt_v/model_latest.pth"
|
||||||
|
# pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b3/lora+edge2/ckpt/model_epoch_20.pth"
|
||||||
|
# pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b3/lora+edge2/ckpt/model_epoch_16.pth"
|
||||||
|
# pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b3/lora_small/ckpt/model_epoch_6.pth"
|
||||||
|
# pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b9/lora/ckpt/model_epoch_50.pth"
|
||||||
|
# pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b11/spec_8/ckpt_v/model_latest.pth"
|
||||||
|
# pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b11/spec_5/ckpt/model_epoch_100.pth"
|
||||||
|
pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b9/lora3/ckpt/model_iter_360000.pth"
|
||||||
|
# pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b9/lora3/ckpt/model_iter_500000.pth"
|
||||||
|
learner.load_well_trained_model(pth)
|
||||||
learner.cuda()
|
learner.cuda()
|
||||||
predictor = VolumePredictor(
|
predictor = VolumePredictor(
|
||||||
model=learner.model,
|
model=learner.model,
|
||||||
|
@ -1,216 +0,0 @@
|
|||||||
"""
|
|
||||||
Volume evalutaion
|
|
||||||
|
|
||||||
"""
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
# from datasets.dataset3d import Dataset3D
|
|
||||||
from tutils.new.manager import ConfigManager
|
|
||||||
from datasets.eval_dataloader.loader_abstract import AbstractLoader
|
|
||||||
|
|
||||||
from core.volume_predictor import VolumePredictor
|
|
||||||
from datasets.data_engine import DataManager, BoxPromptGenerator, PointPromptGenerator
|
|
||||||
|
|
||||||
from tutils import tfilename
|
|
||||||
from tutils.new.trainer.recorder import Recorder
|
|
||||||
from trans_utils.metrics import compute_dice_np
|
|
||||||
from trans_utils.data_utils import Data3dSolver
|
|
||||||
from tutils.tutils.ttimer import timer
|
|
||||||
|
|
||||||
|
|
||||||
class Evaluater:
|
|
||||||
def __init__(self, config) -> None:
|
|
||||||
self.config = config
|
|
||||||
self.recorder = Recorder()
|
|
||||||
|
|
||||||
def solve(self, model, dataset, finetune_number=1):
|
|
||||||
# model.eval()
|
|
||||||
self.predictor = model
|
|
||||||
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
|
|
||||||
|
|
||||||
for i, data in enumerate(dataloader):
|
|
||||||
# if i <4:
|
|
||||||
# print
|
|
||||||
# continue
|
|
||||||
# for k, v in data.items():
|
|
||||||
# if isinstance(v, torch.Tensor):
|
|
||||||
# data[k] = v.to(self.rank)
|
|
||||||
# try:
|
|
||||||
if True:
|
|
||||||
if self.config['dataset']['prompt'] == 'box':
|
|
||||||
dice, pred, label, temp_slice = self.eval_step(data, batch_idx=i)
|
|
||||||
used_slice = [temp_slice]
|
|
||||||
if finetune_number > 1:
|
|
||||||
for i in range(finetune_number - 1):
|
|
||||||
dice, pred, label, temp_slice = self.finetune_with_more_prompt(pred, label, exclude_slide_id=used_slice)
|
|
||||||
used_slice.append(temp_slice)
|
|
||||||
res = {"dice": dice}
|
|
||||||
if self.config['dataset']['prompt'] == 'point':
|
|
||||||
res = self.eval_step_point(data, batch_idx=i)
|
|
||||||
self.recorder.record(res)
|
|
||||||
# except Exception as e:
|
|
||||||
# print(e)
|
|
||||||
res = self.recorder.cal_metrics()
|
|
||||||
print(res)
|
|
||||||
print("prompt:", self.config['dataset']['prompt'], " class_idx:", self.config['dataset']['label_idx'])
|
|
||||||
|
|
||||||
def eval_step(self, data, batch_idx=0):
|
|
||||||
name = data['name']
|
|
||||||
dataset_name = data['dataset_name'][0]
|
|
||||||
label_idx = data['label_idx'][0]
|
|
||||||
template_slice_id = data['template_slice_id'][0]
|
|
||||||
|
|
||||||
assert data['img'].shape[1] >= 3, f" Got img.shape {data['img'].shape}"
|
|
||||||
if template_slice_id == 0:
|
|
||||||
template_slice_id += 1
|
|
||||||
elif template_slice_id == (data['img'].shape[0] - 1):
|
|
||||||
template_slice_id -= 1
|
|
||||||
print("Using slice ", template_slice_id, " as template slice")
|
|
||||||
|
|
||||||
spacing = data['spacing'].numpy().tolist()[0]
|
|
||||||
if data['img'].shape[-1] < 260:
|
|
||||||
# assert data['img'].shape[-1] < 260, f"Got {data['img'].shape}"
|
|
||||||
img = data['img'][0][:,:256,:256]
|
|
||||||
label = data['label'][0][:,:256,:256]
|
|
||||||
else:
|
|
||||||
img = data['img'][0]
|
|
||||||
label = data['label'][0]
|
|
||||||
# img = torch.clip(img, -200, 600)
|
|
||||||
box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id].detach().cpu().numpy())
|
|
||||||
box = np.array([box])
|
|
||||||
pred, stability = self.predictor.predict_volume(
|
|
||||||
x=img,
|
|
||||||
box=box,
|
|
||||||
template_slice_id=template_slice_id,
|
|
||||||
return_stability=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
prompt_type = 'box'
|
|
||||||
dice = compute_dice_np(pred, label.detach().cpu().numpy())
|
|
||||||
# Data3dSolver().simple_write(pred, path=tfilename(f"visual/{dataset_name}/pred_{batch_idx}_label_{label_idx}_{prompt_type}.nii.gz"), spacing=spacing)
|
|
||||||
# Data3dSolver().simple_write(label.detach().cpu().numpy(), path=tfilename(f"visual/{dataset_name}/label_{batch_idx}.nii.gz"))
|
|
||||||
# Data3dSolver().simple_write(img.detach().cpu().numpy(), path=tfilename(f"visual/{dataset_name}/img_{batch_idx}.nii.gz"))
|
|
||||||
# np.save(tfilename(f"meta/{dataset_name}/stability_{batch_idx}.npy"), stability)
|
|
||||||
print(dataset_name, name, dice)
|
|
||||||
template_slice_id = template_slice_id if isinstance(template_slice_id, int) else template_slice_id.item()
|
|
||||||
return dice, pred, label.detach().cpu().numpy(), template_slice_id
|
|
||||||
|
|
||||||
def eval_step_point(self, data, batch_idx=0):
|
|
||||||
name = data['name']
|
|
||||||
dataset_name = data['dataset_name'][0]
|
|
||||||
label_idx = data['label_idx'][0]
|
|
||||||
template_slice_id = data['template_slice_id'][0]
|
|
||||||
spacing = data['spacing'].numpy().tolist()[0]
|
|
||||||
|
|
||||||
assert data['img'].shape[1] >= 3, f" Got img.shape {data['img'].shape}"
|
|
||||||
if template_slice_id == 0:
|
|
||||||
template_slice_id += 1
|
|
||||||
elif template_slice_id == (data['img'].shape[0] - 1):
|
|
||||||
template_slice_id -= 1
|
|
||||||
|
|
||||||
if data['img'].shape[-1] < 260:
|
|
||||||
# assert data['img'].shape[-1] < 260, f"Got {data['img'].shape}"
|
|
||||||
img = data['img'][0][:,:256,:256]
|
|
||||||
label = data['label'][0][:,:256,:256]
|
|
||||||
else:
|
|
||||||
img = data['img'][0]
|
|
||||||
label = data['label'][0]
|
|
||||||
|
|
||||||
box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id].detach().cpu().numpy())
|
|
||||||
point = (box[0]+box[2])*0.5 , (box[1]+box[3])*0.5
|
|
||||||
point = np.array([point]).astype(int)
|
|
||||||
if label[template_slice_id][point[0,1], point[0,0]] == 0:
|
|
||||||
print("Use random point instead !!!")
|
|
||||||
point = PointPromptGenerator().get_prompt_point(label[template_slice_id])
|
|
||||||
point = np.array([point]).astype(int)
|
|
||||||
# box = np.array([box])
|
|
||||||
pred = self.predictor.predict_volume(
|
|
||||||
x=img,
|
|
||||||
point_coords=point,
|
|
||||||
point_labels=np.ones_like(point)[:,:1],
|
|
||||||
template_slice_id=template_slice_id,
|
|
||||||
)
|
|
||||||
dice = compute_dice_np(pred, label.detach().cpu().numpy())
|
|
||||||
prompt_type = 'point'
|
|
||||||
Data3dSolver().simple_write(pred, path=tfilename(f"visual/{dataset_name}/pred_{batch_idx}_label_{label_idx}_{prompt_type}.nii.gz"), spacing=spacing)
|
|
||||||
# Data3dSolver().simple_write(pred, path=tfilename(f"visual/{dataset_name}/pred_{batch_idx}.nii.gz"))
|
|
||||||
print(dataset_name, name, dice)
|
|
||||||
return {"dice": dice}
|
|
||||||
|
|
||||||
def finetune_with_more_prompt(self, pred, label, prompt_type="box", exclude_slide_id=[]):
|
|
||||||
assert pred.shape == label.shape
|
|
||||||
dices = [compute_dice_np(pred[j,:,:], label[j,:,:]) for j in range(pred.shape[0])]
|
|
||||||
rank_list = np.array(dices[1:-1]).argsort() # Ignore the head and tail
|
|
||||||
rank_list += 1 # Ignore the head and tail
|
|
||||||
for i in rank_list:
|
|
||||||
if i in exclude_slide_id:
|
|
||||||
continue
|
|
||||||
template_slice_id = i
|
|
||||||
break
|
|
||||||
# template_slice_id += 1 # Ignore the head and tail
|
|
||||||
print("Using slice ", template_slice_id, " as template slice")
|
|
||||||
old_confidence = self.predictor.get_confidence()
|
|
||||||
box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id])
|
|
||||||
box = np.array([box])
|
|
||||||
new_pred, stability = self.predictor.predict_with_prompt(
|
|
||||||
box=box,
|
|
||||||
template_slice_id=template_slice_id,
|
|
||||||
return_stability=True,
|
|
||||||
)
|
|
||||||
new_confidence = self.predictor.get_confidence()
|
|
||||||
new_confidence[template_slice_id] *= 2
|
|
||||||
all_conf = np.stack([old_confidence, new_confidence], axis=1)
|
|
||||||
preds = [pred, new_pred]
|
|
||||||
merged = np.zeros_like(label)
|
|
||||||
for slice_idx in range(pred.shape[0]):
|
|
||||||
idx = np.argsort(all_conf[slice_idx,:])[-1]
|
|
||||||
merged[slice_idx,:,:] = preds[idx][slice_idx]
|
|
||||||
|
|
||||||
print("old dices", [compute_dice_np(pred, label) for pred in preds])
|
|
||||||
dice = compute_dice_np(merged, label)
|
|
||||||
print("merged dice, idx", dice)
|
|
||||||
return dice, merged, label, template_slice_id
|
|
||||||
|
|
||||||
|
|
||||||
def to_RGB(img):
|
|
||||||
pass
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
from core.learner3 import SamLearner
|
|
||||||
from modeling.build_sam3d2 import sam_model_registry
|
|
||||||
EX_CONFIG = {
|
|
||||||
'dataset':{
|
|
||||||
'prompt': 'box',
|
|
||||||
'prompt_number': 5,
|
|
||||||
'dataset_list': ['example'], # ["sabs"], chaos, word pancreas
|
|
||||||
'label_idx': 5,
|
|
||||||
},
|
|
||||||
"lora_r": 24,
|
|
||||||
'model_type': "vit_h",
|
|
||||||
'ckpt': "/home1/quanquan/code/projects/finetune_large/segment_anything/model_iter_3935000.pth",
|
|
||||||
}
|
|
||||||
config = ConfigManager()
|
|
||||||
config.add_config("configs/vit_b.yaml")
|
|
||||||
config.add_config(EX_CONFIG)
|
|
||||||
config.print()
|
|
||||||
dataset = AbstractLoader(config['dataset'], split="test")
|
|
||||||
print(len(dataset))
|
|
||||||
assert len(dataset) >= 1
|
|
||||||
|
|
||||||
# Init Model
|
|
||||||
model_type = config['model_type'] # "vit_b"
|
|
||||||
sam = sam_model_registry[model_type](checkpoint=None)
|
|
||||||
learner = SamLearner(sam_model=sam, config=config, data_engine=DataManager(img_size=(1024,1024)))
|
|
||||||
learner.use_lora(r=config['lora_r'])
|
|
||||||
pth = config['ckpt']
|
|
||||||
learner.load_well_trained_model(pth)
|
|
||||||
learner.cuda()
|
|
||||||
predictor = VolumePredictor(
|
|
||||||
model=learner.model,
|
|
||||||
use_postprocess=True,
|
|
||||||
use_noise_remove=True,)
|
|
||||||
|
|
||||||
solver = Evaluater(config)
|
|
||||||
dataset = AbstractLoader(config['dataset'], split="test")
|
|
||||||
solver.solve(predictor, dataset, finetune_number=config['dataset']['prompt_number'])
|
|
@ -1,231 +0,0 @@
|
|||||||
"""
|
|
||||||
Volume evalutaion
|
|
||||||
|
|
||||||
"""
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
# from datasets.dataset3d import Dataset3D
|
|
||||||
from tutils.new.manager import ConfigManager
|
|
||||||
from datasets.eval_dataloader.loader_abstract import AbstractLoader
|
|
||||||
|
|
||||||
from core.volume_predictor import VolumePredictor
|
|
||||||
from datasets.data_engine import DataManager, BoxPromptGenerator, PointPromptGenerator
|
|
||||||
|
|
||||||
from tutils import tfilename
|
|
||||||
from tutils.new.trainer.recorder import Recorder
|
|
||||||
from trans_utils.metrics import compute_dice_np
|
|
||||||
from trans_utils.data_utils import Data3dSolver
|
|
||||||
|
|
||||||
# from monai.metrics import compute_surface_dice
|
|
||||||
import surface_distance as surfdist
|
|
||||||
from tutils.tutils.ttimer import timer
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Evaluater:
|
|
||||||
def __init__(self, config) -> None:
|
|
||||||
self.config = config
|
|
||||||
self.recorder = Recorder()
|
|
||||||
|
|
||||||
def solve(self, model, dataset):
|
|
||||||
# model.eval()
|
|
||||||
self.predictor = model
|
|
||||||
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
|
|
||||||
|
|
||||||
for i, data in enumerate(dataloader):
|
|
||||||
# if i <4:
|
|
||||||
# print
|
|
||||||
# continue
|
|
||||||
# for k, v in data.items():
|
|
||||||
# if isinstance(v, torch.Tensor):
|
|
||||||
# data[k] = v.to(self.rank)
|
|
||||||
if self.config['dataset']['prompt'] == 'box':
|
|
||||||
# res = self.eval_step_slice(data, batch_idx=i)
|
|
||||||
res = self.eval_step(data, batch_idx=i)
|
|
||||||
if self.config['dataset']['prompt'] == 'point':
|
|
||||||
res = self.eval_step_point(data, batch_idx=i)
|
|
||||||
self.recorder.record(res)
|
|
||||||
res = self.recorder.cal_metrics()
|
|
||||||
print(res)
|
|
||||||
print("prompt:", self.config['dataset']['prompt'], " class_idx:", self.config['dataset']['label_idx'])
|
|
||||||
|
|
||||||
def eval_step(self, data, batch_idx=0):
|
|
||||||
name = data['name']
|
|
||||||
dataset_name = data['dataset_name'][0]
|
|
||||||
label_idx = data['label_idx'][0]
|
|
||||||
template_slice_id = data['template_slice_id'][0]
|
|
||||||
|
|
||||||
assert data['img'].shape[1] >= 3, f" Got img.shape {data['img'].shape}"
|
|
||||||
if template_slice_id == 0:
|
|
||||||
template_slice_id += 1
|
|
||||||
elif template_slice_id == (data['img'].shape[0] - 1):
|
|
||||||
template_slice_id -= 1
|
|
||||||
|
|
||||||
spacing = data['spacing'].numpy().tolist()[0]
|
|
||||||
if data['img'].shape[-1] < 260:
|
|
||||||
# assert data['img'].shape[-1] < 260, f"Got {data['img'].shape}"
|
|
||||||
img = data['img'][0][:,:256,:256]
|
|
||||||
label = data['label'][0][:,:256,:256]
|
|
||||||
else:
|
|
||||||
img = data['img'][0]
|
|
||||||
label = data['label'][0]
|
|
||||||
# img = torch.clip(img, -200, 600)
|
|
||||||
box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id].detach().cpu().numpy())
|
|
||||||
box = np.array([box])
|
|
||||||
pred, stability = self.predictor.predict_volume(
|
|
||||||
x=img,
|
|
||||||
box=box,
|
|
||||||
template_slice_id=template_slice_id,
|
|
||||||
return_stability=True,
|
|
||||||
)
|
|
||||||
prompt_type = 'box'
|
|
||||||
dice = compute_dice_np(pred, label.detach().cpu().numpy())
|
|
||||||
# Data3dSolver().simple_write(pred, path=tfilename(f"visual/{dataset_name}/pred_{batch_idx}_label_{label_idx}_{prompt_type}.nii.gz"), spacing=spacing)
|
|
||||||
# Data3dSolver().simple_write(label.detach().cpu().numpy(), path=tfilename(f"visual/{dataset_name}/label_{batch_idx}.nii.gz"))
|
|
||||||
# Data3dSolver().simple_write(img.detach().cpu().numpy(), path=tfilename(f"visual/{dataset_name}/img_{batch_idx}.nii.gz"))
|
|
||||||
# np.save(tfilename(f"meta/{dataset_name}/stability_{batch_idx}.npy"), stability)
|
|
||||||
|
|
||||||
# nsd = compute_surface_dice(torch.Tensor(pred), label.detach().cpu(), 1)
|
|
||||||
|
|
||||||
# surface_distances = surfdist.compute_surface_distances(
|
|
||||||
# label.detach().cpu().numpy(), pred, spacing_mm=(0.6, 0.6445, 0.6445))
|
|
||||||
# nsd = surfdist.compute_surface_dice_at_tolerance(surface_distances, 1)
|
|
||||||
nsd = 0
|
|
||||||
|
|
||||||
print(dataset_name, name, dice, nsd)
|
|
||||||
# import ipdb; ipdb.set_trace()
|
|
||||||
|
|
||||||
return {"dice": dice, "nsd": nsd}
|
|
||||||
|
|
||||||
def eval_step_point(self, data, batch_idx=0):
|
|
||||||
name = data['name']
|
|
||||||
dataset_name = data['dataset_name'][0]
|
|
||||||
label_idx = data['label_idx'][0]
|
|
||||||
template_slice_id = data['template_slice_id'][0]
|
|
||||||
spacing = data['spacing'].numpy().tolist()[0]
|
|
||||||
|
|
||||||
assert data['img'].shape[1] >= 3, f" Got img.shape {data['img'].shape}"
|
|
||||||
if template_slice_id == 0:
|
|
||||||
template_slice_id += 1
|
|
||||||
elif template_slice_id == (data['img'].shape[0] - 1):
|
|
||||||
template_slice_id -= 1
|
|
||||||
|
|
||||||
if data['img'].shape[-1] < 260:
|
|
||||||
# assert data['img'].shape[-1] < 260, f"Got {data['img'].shape}"
|
|
||||||
img = data['img'][0][:,:256,:256]
|
|
||||||
label = data['label'][0][:,:256,:256]
|
|
||||||
else:
|
|
||||||
img = data['img'][0]
|
|
||||||
label = data['label'][0]
|
|
||||||
|
|
||||||
box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id].detach().cpu().numpy())
|
|
||||||
point = (box[0]+box[2])*0.5 , (box[1]+box[3])*0.5
|
|
||||||
point = np.array([point]).astype(int)
|
|
||||||
if label[template_slice_id][point[0,1], point[0,0]] == 0:
|
|
||||||
print("Use random point instead !!!")
|
|
||||||
point = PointPromptGenerator().get_prompt_point(label[template_slice_id])
|
|
||||||
point = np.array([point]).astype(int)
|
|
||||||
# box = np.array([box])
|
|
||||||
pred = self.predictor.predict_volume(
|
|
||||||
x=img,
|
|
||||||
point_coords=point,
|
|
||||||
point_labels=np.ones_like(point)[:,:1],
|
|
||||||
template_slice_id=template_slice_id,
|
|
||||||
)
|
|
||||||
dice = compute_dice_np(pred, label.detach().cpu().numpy())
|
|
||||||
prompt_type = 'point'
|
|
||||||
# Data3dSolver().simple_write(pred, path=tfilename(f"visual/{dataset_name}/pred_{batch_idx}_label_{label_idx}_{prompt_type}.nii.gz"), spacing=spacing)
|
|
||||||
# Data3dSolver().simple_write(pred, path=tfilename(f"visual/{dataset_name}/pred_{batch_idx}.nii.gz"))
|
|
||||||
nsd = compute_surface_dice(pred, label.detach().cpu().numpy())
|
|
||||||
print(dataset_name, name, dice)
|
|
||||||
return {"dice": dice, "nsd": nsd}
|
|
||||||
|
|
||||||
def eval_step_slice(self, data, batch_idx=0):
|
|
||||||
name = data['name']
|
|
||||||
dataset_name = data['dataset_name'][0]
|
|
||||||
label_idx = data['label_idx'][0]
|
|
||||||
template_slice_id = data['template_slice_id'][0]
|
|
||||||
|
|
||||||
assert data['img'].shape[1] >= 3, f" Got img.shape {data['img'].shape}"
|
|
||||||
if template_slice_id == 0:
|
|
||||||
template_slice_id += 1
|
|
||||||
elif template_slice_id == (data['img'].shape[0] - 1):
|
|
||||||
template_slice_id -= 1
|
|
||||||
|
|
||||||
spacing = data['spacing'].numpy().tolist()[0]
|
|
||||||
if data['img'].shape[-1] < 260:
|
|
||||||
# assert data['img'].shape[-1] < 260, f"Got {data['img'].shape}"
|
|
||||||
img = data['img'][0][:,:256,:256]
|
|
||||||
label = data['label'][0][:,:256,:256]
|
|
||||||
else:
|
|
||||||
img = data['img'][0]
|
|
||||||
label = data['label'][0]
|
|
||||||
|
|
||||||
img = img[template_slice_id-1:template_slice_id+2, :,:]
|
|
||||||
label = label[template_slice_id-1:template_slice_id+2, :,:]
|
|
||||||
template_slice_id = 1
|
|
||||||
|
|
||||||
# img = torch.clip(img, -200, 600)
|
|
||||||
box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id].detach().cpu().numpy())
|
|
||||||
box = np.array([box])
|
|
||||||
pred, stability = self.predictor.predict_volume(
|
|
||||||
x=img,
|
|
||||||
box=box,
|
|
||||||
template_slice_id=template_slice_id,
|
|
||||||
return_stability=True,
|
|
||||||
)
|
|
||||||
prompt_type = 'box'
|
|
||||||
dice = compute_dice_np(pred, label.detach().cpu().numpy())
|
|
||||||
# Data3dSolver().simple_write(pred, path=tfilename(f"visual/{dataset_name}/pred_{batch_idx}_label_{label_idx}_{prompt_type}.nii.gz"), spacing=spacing)
|
|
||||||
# Data3dSolver().simple_write(label.detach().cpu().numpy(), path=tfilename(f"visual/{dataset_name}/label_{batch_idx}.nii.gz"))
|
|
||||||
# Data3dSolver().simple_write(img.detach().cpu().numpy(), path=tfilename(f"visual/{dataset_name}/img_{batch_idx}.nii.gz"))
|
|
||||||
# np.save(tfilename(f"meta/{dataset_name}/stability_{batch_idx}.npy"), stability)
|
|
||||||
print("Slice evaluation: ", dataset_name, name, dice)
|
|
||||||
return {"dice": dice}
|
|
||||||
|
|
||||||
def to_RGB(img):
|
|
||||||
pass
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# from core.learner3 import SamLearner
|
|
||||||
# from modeling.build_sam3d import sam_model_registry
|
|
||||||
|
|
||||||
# from core.learner3 import SamLearner
|
|
||||||
from core.learner_sub1 import SamLearner
|
|
||||||
from modeling.build_sam3d2 import sam_model_registry
|
|
||||||
|
|
||||||
EX_CONFIG = {
|
|
||||||
'dataset':{
|
|
||||||
'prompt': 'box',
|
|
||||||
'dataset_list': ['guangdong'], # ["sabs"], chaos, word, decathlon_colon, pancreas
|
|
||||||
'label_idx': 2,
|
|
||||||
},
|
|
||||||
"pth": "./model_latest.pth"
|
|
||||||
}
|
|
||||||
|
|
||||||
config = ConfigManager()
|
|
||||||
# config.add_config("configs/vit_sub.yaml")
|
|
||||||
config.add_config("configs/vit_sub.yaml")
|
|
||||||
config.add_config(EX_CONFIG)
|
|
||||||
|
|
||||||
# Init Model
|
|
||||||
model_type = "vit_b"
|
|
||||||
sam = sam_model_registry[model_type](checkpoint=None)
|
|
||||||
learner = SamLearner(sam_model=sam, config=config, data_engine=DataManager(img_size=(1024,1024)))
|
|
||||||
learner.use_lora()
|
|
||||||
learner.use_lora_sub()
|
|
||||||
pth = EX_CONFIG['pth']
|
|
||||||
learner.load_well_trained_model(pth)
|
|
||||||
learner.cuda()
|
|
||||||
predictor = VolumePredictor(
|
|
||||||
model=learner.model,
|
|
||||||
use_postprocess=True,
|
|
||||||
use_noise_remove=True,)
|
|
||||||
|
|
||||||
solver = Evaluater(config)
|
|
||||||
dataset = AbstractLoader(config['dataset'], split="test")
|
|
||||||
tt = timer()
|
|
||||||
solver.solve(predictor, dataset)
|
|
||||||
|
|
||||||
print("Time: ", tt())
|
|
303
tmp.py
303
tmp.py
@ -1,303 +0,0 @@
|
|||||||
"""
|
|
||||||
Slow Loading directly
|
|
||||||
|
|
||||||
So we pre-precess data
|
|
||||||
"""
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import os
|
|
||||||
from einops import rearrange, reduce, repeat
|
|
||||||
from tutils.nn.data import read, itk_to_np, np_to_itk, write
|
|
||||||
from tutils import tfilename
|
|
||||||
from .dataset3d import DATASET_CONFIG, Dataset3D as basic_3d_dataset
|
|
||||||
from monai import transforms
|
|
||||||
import torch
|
|
||||||
import cv2
|
|
||||||
from scipy.sparse import csr_matrix
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from torchvision import transforms
|
|
||||||
from einops import rearrange
|
|
||||||
import glob
|
|
||||||
from torchvision import transforms
|
|
||||||
from monai import transforms as monai_transforms
|
|
||||||
|
|
||||||
# "class": ["spleen", "right kidney", "left kidney", "gallbladder", "esophagus", "liver", "stomach", "aorta", "postcava", "portal vein and splenic vein", "pancrease", "right adrenal gland", "left adrenal gland"],
|
|
||||||
# "class": ["liver", "right kidney", "left kidney", "spleen"],
|
|
||||||
TEMPLATE={
|
|
||||||
'01': [1,2,3,4,5,6,7,8,9,10,11,12,13,14],
|
|
||||||
'02': [1,0,3,4,5,6,7,0,0,0,11,0,0,14],
|
|
||||||
'03': [6],
|
|
||||||
'04': [6,27], # post process
|
|
||||||
'05': [2,26,32], # post process
|
|
||||||
'07': [6,1,3,2,7,4,5,11,14,18,19,12,20,21,23,24],
|
|
||||||
'08': [6, 2, 1, 11],
|
|
||||||
'09': [1,2,3,4,5,6,7,8,9,11,12,13,14,21,22],
|
|
||||||
'12': [6,21,16,2],
|
|
||||||
'13': [6,2,1,11,8,9,7,4,5,12,13,25],
|
|
||||||
'14': [11,11,28,28,28], # Felix data, post process
|
|
||||||
'10_03': [6, 27], # post process
|
|
||||||
'10_06': [30],
|
|
||||||
'10_07': [11, 28], # post process
|
|
||||||
'10_08': [15, 29], # post process
|
|
||||||
'10_09': [1],
|
|
||||||
'10_10': [31],
|
|
||||||
'58': [6,2,3,1],
|
|
||||||
'59': [1,2,3,4,5,6,7,8,9,10,11,12,13,14],
|
|
||||||
'60': (np.ones(200)*(-1)).tolist(), # for debug
|
|
||||||
"65": np.zeros(200).tolist(),
|
|
||||||
}
|
|
||||||
|
|
||||||
EX_CONFIG={
|
|
||||||
"dataset":{
|
|
||||||
"split": 'train',
|
|
||||||
"data_root_path": '/home1/quanquan/datasets/',
|
|
||||||
"dataset_list": ["decathlon_colon"],
|
|
||||||
"data_txt_path": './datasets/dataset_list/',
|
|
||||||
"cache_data_path": '/home1/quanquan/datasets/cached_dataset2/',
|
|
||||||
"cache_prefix": ['10'] # '07'
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
class Dataset3D(basic_3d_dataset):
|
|
||||||
def __init__(self, config, use_cache=True, *args, **kwargs) -> None:
|
|
||||||
super().__init__(config=config, use_cache=use_cache, *args, **kwargs)
|
|
||||||
self.basic_dir = config['data_root_path']
|
|
||||||
self.cache_dir = config['cache_data_path']
|
|
||||||
|
|
||||||
def prepare_transforms(self):
|
|
||||||
self.transform = transforms.Compose([
|
|
||||||
transforms.ToTensor(),
|
|
||||||
transforms.Resize((1024,1024)),
|
|
||||||
])
|
|
||||||
self.test_transform = transforms.Compose([
|
|
||||||
monai_transforms.Resized(keys=['img', 'label'], spatial_size=(3,1024,1024)),
|
|
||||||
])
|
|
||||||
|
|
||||||
# @tfunctime
|
|
||||||
# def prepare_datalist(self):
|
|
||||||
def prepare_cached_datalist(self):
|
|
||||||
raise DeprecationWarning("[Warning] Please use cache_dataset3d new version instead!")
|
|
||||||
config = self.config
|
|
||||||
data_paths = []
|
|
||||||
for dirpath in glob.glob(config['cache_data_path'] + "/*"):
|
|
||||||
data_paths += glob.glob(dirpath + "/image/*.jpg")
|
|
||||||
print("Load ", dirpath)
|
|
||||||
print('train len {}'.format(len(data_paths)))
|
|
||||||
print('Examples: ', data_paths[:2])
|
|
||||||
return data_paths
|
|
||||||
|
|
||||||
def caching_data(self):
|
|
||||||
assert self.use_cache == False
|
|
||||||
for index in range(len(self)):
|
|
||||||
self.cache_one_sample(index)
|
|
||||||
|
|
||||||
def cache_one_sample(self, index, debug=False):
|
|
||||||
# LABEL_INDICES
|
|
||||||
name = self.img_names[index]['img_path']
|
|
||||||
img_itk = read(self.img_names[index]['img_path'])
|
|
||||||
img_ori = itk_to_np(img_itk)
|
|
||||||
|
|
||||||
img_ori = np.clip(img_ori, -200,400)
|
|
||||||
|
|
||||||
# spacing = img_itk.GetSpacing()
|
|
||||||
scan_orientation = np.argmin(img_ori.shape)
|
|
||||||
label_ori = itk_to_np(read(self.img_names[index]['label_path']))
|
|
||||||
|
|
||||||
dataset_name = self.img_names[index]['img_path'].replace(self.basic_dir,"").split("/")[0]
|
|
||||||
assert dataset_name[0] in ['0','1','2','3','4','5','6','7','8','9'], f"Got {dataset_name}"
|
|
||||||
if dataset_name[:2] == "10":
|
|
||||||
subname = self.img_names[index]['img_path'].replace(self.basic_dir,"")[17:19]
|
|
||||||
assert subname in ['10', '03', '06', '07']
|
|
||||||
all_labels = TEMPLATE[dataset_name[:2] + "_" + subname]
|
|
||||||
else:
|
|
||||||
all_labels = TEMPLATE[dataset_name[:2]]
|
|
||||||
|
|
||||||
num = 0
|
|
||||||
|
|
||||||
# if min(img_ori.shape) * 1.2 < max(img_ori.shape):
|
|
||||||
# orientation_all = [scan_orientation]
|
|
||||||
# else:
|
|
||||||
# orientation_all = [0,1,2]
|
|
||||||
orientation_all = [scan_orientation]
|
|
||||||
|
|
||||||
for orientation in orientation_all:
|
|
||||||
for slice_idx in range(2, img_ori.shape[orientation]-2):
|
|
||||||
# slice_idx = np.random.randint(2, img_ori.shape[orientation]-2)
|
|
||||||
if orientation == 0:
|
|
||||||
s = img_ori[slice_idx-1:slice_idx+2, :,:]
|
|
||||||
lb = label_ori[slice_idx-1:slice_idx+2, :,:]
|
|
||||||
# spacing = (spacing[1], spacing[2])
|
|
||||||
if orientation == 1:
|
|
||||||
s = img_ori[:,slice_idx-1:slice_idx+2,:]
|
|
||||||
s = rearrange(s, "h c w -> c h w")
|
|
||||||
lb = label_ori[:,slice_idx-1:slice_idx+2,:]
|
|
||||||
lb = rearrange(lb, "h c w -> c h w")
|
|
||||||
# spacing = (spacing[0], spacing[2])
|
|
||||||
if orientation == 2:
|
|
||||||
s = img_ori[:,:,slice_idx-1:slice_idx+2]
|
|
||||||
s = rearrange(s, "h w c -> c h w")
|
|
||||||
lb = label_ori[:,:,slice_idx-1:slice_idx+2]
|
|
||||||
lb = rearrange(lb, "h w c -> c h w")
|
|
||||||
# spacing = (spacing[0], spacing[1])
|
|
||||||
assert s.shape[0] == 3
|
|
||||||
|
|
||||||
# if np.float32(lb[1,:,:]>0).sum() <= 200:
|
|
||||||
# # return self._get_data((index+1)%len(self))
|
|
||||||
# continue
|
|
||||||
# Choose one label
|
|
||||||
label_num = int(lb.max())
|
|
||||||
|
|
||||||
masks_data = []
|
|
||||||
meta = {"img_name": name, "slice": slice_idx, "orientation": orientation, "label_idx": [], "labels": [], "id": f"{num:08d}" }
|
|
||||||
for label_idx in range(1,label_num+1):
|
|
||||||
one_lb = np.float32(lb==label_idx)
|
|
||||||
if one_lb[1,:,:].sum() <= (one_lb.shape[-1] * one_lb.shape[-2] * 0.0014):
|
|
||||||
continue
|
|
||||||
# if one_lb[0,:,:].sum()<=50 or one_lb[2,:,:].sum()<=50:
|
|
||||||
|
|
||||||
masks_data.append(one_lb)
|
|
||||||
meta['label_idx'].append(label_idx)
|
|
||||||
meta['labels'].append(all_labels[label_idx-1])
|
|
||||||
|
|
||||||
if len(masks_data) <= 0:
|
|
||||||
continue
|
|
||||||
|
|
||||||
img_rgb = s
|
|
||||||
img_rgb = F.interpolate(torch.Tensor(img_rgb).unsqueeze(0), size=(1024,1024)).squeeze().numpy()
|
|
||||||
img_rgb = self.to_RGB(img_rgb)
|
|
||||||
save_image_name = tfilename(self.cache_dir, dataset_name, f"image/image_{index:04d}_{num:08d}.jpg")
|
|
||||||
self.save_img_rgb(rearrange(img_rgb, "c h w -> h w c"), save_image_name)
|
|
||||||
|
|
||||||
# Save cache data
|
|
||||||
save_label_name = tfilename(self.cache_dir, dataset_name, f"label_jpg/label_{index:04d}_{num:08d}")
|
|
||||||
self.save_slice_mask(masks_data, save_label_name)
|
|
||||||
print("Save ", save_image_name)
|
|
||||||
|
|
||||||
self.save_meta(meta, tfilename(self.cache_dir, dataset_name, f"meta/meta_{index:04d}_{num:08d}.npy"))
|
|
||||||
|
|
||||||
num += 1
|
|
||||||
|
|
||||||
def save_meta(self, meta, path):
|
|
||||||
assert path.endswith(".npy")
|
|
||||||
np.save(path, meta)
|
|
||||||
|
|
||||||
def save_slice_mask(self, masks_data, prefix):
|
|
||||||
masks_data = F.interpolate(torch.Tensor(masks_data), size=(1024,1024)).numpy()
|
|
||||||
assert masks_data.shape[1:] == (3,1024,1024), f"{__file__} Got{masks_data.shape}"
|
|
||||||
for i in range(masks_data.shape[0]):
|
|
||||||
labeli = masks_data[i].astype(np.uint8) * 255
|
|
||||||
assert labeli.sum() > 0
|
|
||||||
path = tfilename(prefix+f"_{i:04d}.jpg")
|
|
||||||
cv2.imwrite(path, rearrange(labeli, "c h w -> h w c"))
|
|
||||||
print("save to ", path)
|
|
||||||
|
|
||||||
def _old_save_slice_mask(self, masks_data, path):
|
|
||||||
raise DeprecationWarning()
|
|
||||||
exit(0)
|
|
||||||
assert path.endswith(".npz")
|
|
||||||
# masks_data = np.array([m['segmentation'] for m in masks]).astype(int)
|
|
||||||
masks_data = F.interpolate(torch.Tensor(masks_data), size=(1024,1024)).numpy()
|
|
||||||
# masks_data = np.int8(masks_data>0)
|
|
||||||
assert masks_data.shape[1:] == (3,1024,1024), f"{__file__} Got{masks_data.shape}"
|
|
||||||
masks_data = rearrange(masks_data, "n c h w -> n (c h w)")
|
|
||||||
csr = csr_matrix(masks_data)
|
|
||||||
np.savez_compressed(path, data=csr.data, indices=csr.indices, indptr=csr.indptr, shape=csr.shape)
|
|
||||||
|
|
||||||
def save_img_rgb(self, img, path):
|
|
||||||
assert path.endswith(".jpg")
|
|
||||||
assert img.shape == (1024,1024,3)
|
|
||||||
cv2.imwrite(path, img.astype(np.uint8))
|
|
||||||
|
|
||||||
def _get_cached_data(self, index):
|
|
||||||
name = self.img_names[index]
|
|
||||||
# print(name)
|
|
||||||
img = cv2.imread(name)
|
|
||||||
compressed = np.load(name.replace("image/image_", "label/label_").replace(".jpg", ".npz"))
|
|
||||||
csr = csr_matrix((compressed['data'], compressed['indices'], compressed['indptr']), shape=compressed['shape'])
|
|
||||||
label_ori = csr.toarray()
|
|
||||||
label_ori = rearrange(label_ori, "n (c h w) -> n c h w", c=3, h=1024, w=1024)
|
|
||||||
meta = np.load(name.replace("image/image_", "meta/meta_").replace(".jpg", ".npy"), allow_pickle=True).tolist()
|
|
||||||
# print(meta)
|
|
||||||
pp = reduce(label_ori[:,1,:,:], "n h w -> n", reduction="sum") > 500
|
|
||||||
if pp.sum() == 0:
|
|
||||||
return self._get_cached_data((index+1)%len(self))
|
|
||||||
|
|
||||||
label_idx = np.random.choice(a=np.arange(len(pp)), p=pp/pp.sum())
|
|
||||||
# label_idx = np.random.randint(0, label_ori.shape[0])
|
|
||||||
label_ori = label_ori[label_idx]
|
|
||||||
is_edge = meta.get('is_edge', 0)
|
|
||||||
return rearrange(img, "h w c -> c h w"), label_ori, name, meta['labels'][label_idx], meta['label_idx'][label_idx]
|
|
||||||
|
|
||||||
# @tfunctime
|
|
||||||
def __getitem__(self, index, debug=False):
|
|
||||||
# print("Dataset warning", index, len(self))
|
|
||||||
index = index % len(self)
|
|
||||||
img_rgb, label_ori, name, label_idx, local_idx = self._get_cached_data(index)
|
|
||||||
|
|
||||||
if label_ori.sum() <= 0:
|
|
||||||
print("[Label Error] ", name)
|
|
||||||
return self.__getitem__(index+1)
|
|
||||||
|
|
||||||
# assert len(img_rgb.shape) == 3, f"{__file__} Got{img_rgb.shape}"
|
|
||||||
# img_rgb = self.transform((img_rgb[None,:,:,:]))
|
|
||||||
img_rgb = F.interpolate(torch.Tensor(img_rgb).unsqueeze(0), size=(1024,1024)).squeeze().numpy()
|
|
||||||
|
|
||||||
vector = np.ones(3)
|
|
||||||
ret_dict = {
|
|
||||||
"name": name,
|
|
||||||
"img": img_rgb,
|
|
||||||
"label": label_ori,
|
|
||||||
"indicators": vector,
|
|
||||||
"class": label_idx,
|
|
||||||
"local_idx": local_idx,
|
|
||||||
}
|
|
||||||
return ret_dict
|
|
||||||
|
|
||||||
def _convert_one_mask_from_npz_to_jpg(self, path1=None):
|
|
||||||
# path1 = "/home1/quanquan/datasets/cached_dataset2/01_BCV-Abdomen/label/label_0129_00000043.npz" # 32K
|
|
||||||
prefix = path1.replace(".npz", "").replace("/label/", "/label_jpg/")
|
|
||||||
compressed = np.load(path1)
|
|
||||||
csr = csr_matrix((compressed['data'], compressed['indices'], compressed['indptr']), shape=compressed['shape'])
|
|
||||||
label_ori = csr.toarray()
|
|
||||||
label_ori = rearrange(label_ori, "n (c h w) -> n c h w", c=3, h=1024, w=1024)
|
|
||||||
# print(label_ori.shape)
|
|
||||||
for i in range(label_ori.shape[0]):
|
|
||||||
labeli = label_ori[i]
|
|
||||||
path = tfilename(prefix+f"_{i:04d}.jpg")
|
|
||||||
cv2.imwrite(path, rearrange(labeli, "c h w -> h w c").astype(np.uint8))
|
|
||||||
print("save to ", path)
|
|
||||||
|
|
||||||
def convert_masks_types(self):
|
|
||||||
assert self.use_cache == True
|
|
||||||
for index in range(len(self)):
|
|
||||||
name = self.img_names[index]
|
|
||||||
label_path = name.replace("image/image_", "label/label_").replace(".jpg", ".npz")
|
|
||||||
self._convert_one_mask_from_npz_to_jpg(label_path)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# def go_cache():
|
|
||||||
from tutils.new.manager import ConfigManager
|
|
||||||
config = ConfigManager()
|
|
||||||
config.add_config("configs/vit_sub.yaml")
|
|
||||||
config.add_config(EX_CONFIG)
|
|
||||||
dataset = Dataset3D(config=config['dataset'], use_cache=False)
|
|
||||||
dataset.caching_data()
|
|
||||||
# config.add_config("configs/vit_b_word_103.yaml")
|
|
||||||
# dataset = Dataset3D(config=config['dataset'], use_cache=True)
|
|
||||||
# dataset.caching_data()
|
|
||||||
# dataset.convert_masks_types()
|
|
||||||
|
|
||||||
# from tutils.new.manager import ConfigManager
|
|
||||||
# config = ConfigManager()
|
|
||||||
# config.add_config("configs/vit_b_103.yaml")
|
|
||||||
# dataset = Dataset3D(config=config['dataset']) # , use_cache=True
|
|
||||||
# data = dataset.__getitem__(0)
|
|
||||||
|
|
||||||
# # import ipdb; ipdb.set_trace()
|
|
||||||
# from torch.utils.data import DataLoader
|
|
||||||
# loader = DataLoader(dataset, batch_size=8)
|
|
||||||
# for batch in loader:
|
|
||||||
# print(batch['img'].shape, batch['label'].shape)
|
|
||||||
# print(data['label'].max())
|
|
||||||
# # import ipdb; ipdb.set_trace()
|
|
BIN
trans_utils/__pycache__/__init__.cpython-38.pyc
Normal file
BIN
trans_utils/__pycache__/__init__.cpython-38.pyc
Normal file
Binary file not shown.
BIN
trans_utils/__pycache__/trainer_ddp.cpython-38.pyc
Normal file
BIN
trans_utils/__pycache__/trainer_ddp.cpython-38.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
utils/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/__init__.cpython-38.pyc
Normal file
BIN
utils/__pycache__/__init__.cpython-38.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/amg.cpython-38.pyc
Normal file
BIN
utils/__pycache__/amg.cpython-38.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/amg3d.cpython-38.pyc
Normal file
BIN
utils/__pycache__/amg3d.cpython-38.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/masks3d_utils.cpython-38.pyc
Normal file
BIN
utils/__pycache__/masks3d_utils.cpython-38.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/transforms.cpython-310.pyc
Normal file
BIN
utils/__pycache__/transforms.cpython-310.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/transforms.cpython-38.pyc
Normal file
BIN
utils/__pycache__/transforms.cpython-38.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/transforms3d.cpython-38.pyc
Normal file
BIN
utils/__pycache__/transforms3d.cpython-38.pyc
Normal file
Binary file not shown.
Loading…
x
Reference in New Issue
Block a user