From 173762f7566f38e96fbb6303a99ab77e0e77448b Mon Sep 17 00:00:00 2001
From: transcendentsky
Date: Wed, 20 Mar 2024 15:58:48 +0800
Subject: [PATCH] add finetune

---
 configs/vit_sub.yaml        |  50 ++++++++
 core/volume_predictor.py    |   2 +-
 readme.md                   |  59 ++++++++-
 test/volume_eval.py         |  16 +--
 test/volume_eval_sublora.py | 231 ++++++++++++++++++++++++++++++++++++
 5 files changed, 344 insertions(+), 14 deletions(-)
 create mode 100644 configs/vit_sub.yaml
 create mode 100644 test/volume_eval_sublora.py

diff --git a/configs/vit_sub.yaml b/configs/vit_sub.yaml
new file mode 100644
index 0000000..787fda3
--- /dev/null
+++ b/configs/vit_sub.yaml
@@ -0,0 +1,50 @@
+# ---------------------- Common Configs --------------------------
+base:
+  base_dir: "../runs/sam/"
+  tag: ''
+  stage: ''
+logger:
+  mode: ['tb', ]
+# mode: ''
+  recorder_reduction: 'mean'
+
+training:
+  save_mode: ['all', 'best', 'latest']
+  batch_size: 2      # 8 for A100
+  num_workers: 8
+  num_epochs: 100
+  use_amp: false
+  save_interval: 1
+  val_check_interval: 6
+  load_pretrain_model: false
+
+  # optim:
+  lr: 0.00002
+  decay_step: 2000
+  decay_gamma: 0.8
+  weight_decay: 0.0001
+  alpha: 0.99
+  validation_interval: 100
+
+  sam_checkpoint: "/home1/quanquan/code/projects/medical-guangdong/segment-anything/sam_vit_b_01ec64.pth"  # 103 server
+  model_type: "vit_b"
+
+  continue_training: false
+  load_optimizer: false
+  breakpoint_path: "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b9/lora3/ckpt/model_iter_360000.pth"
+
+dataset:
+  types: ['3d']      # ['3d', '2d']
+  split: 'train'
+  data_root_path: '/home1/quanquan/datasets/'
+  dataset_list: ["pancreas"]
+  data_txt_path: './datasets/dataset_list/'
+  dataset2d_path: "/home1/quanquan/datasets/08_AbdomenCT-1K/"
+  cache_data_path: '/home1/quanquan/datasets/cached_dataset2/'
+
+  cache_prefix: ['6016']   # e.g. '07'
+  specific_label: [2]
+
+test:
+  batch_size: 1
+
diff --git a/core/volume_predictor.py b/core/volume_predictor.py
index ecfe397..88e77cc 100644
--- a/core/volume_predictor.py
+++ b/core/volume_predictor.py
@@ -573,7 +573,7 @@ if __name__ == "__main__":
     sam = sam_model_registry[model_type](checkpoint=None)
     learner = SamLearner(sam_model=sam, config=config, data_engine=DataManager(img_size=(1024,1024)))
     learner.use_lora()
-    pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b9/lora3/ckpt/model_iter_360000.pth"
+    pth = "model_iter_360000.pth"
     learner.load_well_trained_model(pth)
     learner.cuda()

diff --git a/readme.md b/readme.md
index 5d068d6..ca159ec 100644
--- a/readme.md
+++ b/readme.md
@@ -118,9 +118,29 @@ Then run
 python -m datasets.cache_dataset3d
 ```
 
+## Config Settings
+
+Important settings:
+
+```yaml
+base:
+  base_dir: "../runs/sam/"   # logging dir
+
+dataset:
+  types: ['3d']              # ['3d', '2d']
+  split: 'train'
+  data_root_path: '../datasets/'
+  dataset_list: ["pancreas"]
+  data_txt_path: './datasets/dataset_list/'
+  dataset2d_path: "../08_AbdomenCT-1K/"
+  cache_data_path: '../cached_dataset2/'
+
+  cache_prefix: ['6016']     # cache prefix of the cached dataset used for training,
+                             # e.g. ['07',] for 07_WORD
+```
 
-## Start Training
+## Start Training from Scratch (SAM)
 
 Run training on multi-gpu
 
@@ -141,11 +161,46 @@
 python -m core.volume_predictor
 ```
 
 ## Testset Validation
-
+```python
+EX_CONFIG = {
+    'dataset': {
+        'prompt': 'box',           # prompt type: box or point
+        'dataset_list': ['word'],  # dataset name
+        'label_idx': 2,            # label index for inference
+    },
+    "pth": "./model.pth"
+}
+```
 ```
 python -m test.volume_eval
 ```
 
+## Finetuning (Recommended)
+
+```yaml
+training:
+  breakpoint_path: "./model.pth"   # path to the pretrained weights
+```
+
+```
+python -m core.ddp_sub --tag run
+```
+
+## Validation with Finetuned Weights
+
+```
+python -m test.volume_eval_sublora
+```
+
+```python
+EX_CONFIG = {
+    'dataset': {
+        'prompt': 'box',           # prompt type: box or point
+        'dataset_list': ['word'],  # dataset name
+        'label_idx': 2,            # label index for inference
+    },
+    "pth": "./model_finetuned.pth"
+}
+```
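Both validation entry points follow the same pattern: the yaml config is loaded first and `EX_CONFIG` is applied on top, so per-run overrides such as the checkpoint path belong in the inline dict. A minimal sketch of that pattern, assuming `ConfigManager.add_config` merges later sources over earlier ones (the test scripts in this patch rely on exactly that):

```python
from tutils.new.manager import ConfigManager

config = ConfigManager()
config.add_config("configs/vit_sub.yaml")            # base settings from the yaml file
config.add_config({"pth": "./model_finetuned.pth"})  # inline dict overrides matching yaml keys
print(config)                                        # inspect the merged result before running
```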

diff --git a/test/volume_eval.py b/test/volume_eval.py
index d18ea18..c12fef4 100644
--- a/test/volume_eval.py
+++ b/test/volume_eval.py
@@ -140,28 +140,22 @@ if __name__ == "__main__":
             'prompt': 'box',
             'dataset_list': ['word'],  # ["sabs"], chaos, word
             'label_idx': 1,
-        }
+        },
+        'pth': "model_latest.pth"
     }
 
     config = ConfigManager()
     config.add_config("configs/vit_b_103.yaml")
     config.add_config(EX_CONFIG)
 
+    print(config)
+
     # Init Model
     model_type = "vit_b"
     sam = sam_model_registry[model_type](checkpoint=None)
     learner = SamLearner(sam_model=sam, config=config, data_engine=DataManager(img_size=(1024,1024)))
     learner.use_lora()
-    # pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b3/lora+edge2/ckpt_v/model_latest.pth"
-    # pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b3/lora+edge2/ckpt/model_epoch_20.pth"
-    # pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b3/lora+edge2/ckpt/model_epoch_16.pth"
-    # pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b3/lora_small/ckpt/model_epoch_6.pth"
-    # pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b9/lora/ckpt/model_epoch_50.pth"
-    # pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b11/spec_8/ckpt_v/model_latest.pth"
-    # pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b11/spec_5/ckpt/model_epoch_100.pth"
-    pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b9/lora3/ckpt/model_iter_360000.pth"
-    # pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b9/lora3/ckpt/model_iter_500000.pth"
-    learner.load_well_trained_model(pth)
+    learner.load_well_trained_model(EX_CONFIG['pth'])
     learner.cuda()
     predictor = VolumePredictor(
         model=learner.model,
diff --git a/test/volume_eval_sublora.py b/test/volume_eval_sublora.py
new file mode 100644
index 0000000..4a1a642
--- /dev/null
+++ b/test/volume_eval_sublora.py
@@ -0,0 +1,231 @@
+"""
+    Volume evaluation
+"""
+import torch
+import numpy as np
+from torch.utils.data import DataLoader
+# from datasets.dataset3d import Dataset3D
+from tutils.new.manager import ConfigManager
+from datasets.eval_dataloader.loader_abstract import AbstractLoader
+
+from core.volume_predictor import VolumePredictor
+from datasets.data_engine import DataManager, BoxPromptGenerator, PointPromptGenerator
+
+from tutils import tfilename
+from tutils.new.trainer.recorder import Recorder
+from trans_utils.metrics import compute_dice_np
+from trans_utils.data_utils import Data3dSolver
+
+# from monai.metrics import compute_surface_dice
+import surface_distance as surfdist
+from tutils.tutils.ttimer import timer
+
+
+class Evaluater:
+    def __init__(self, config) -> None:
+        self.config = config
+        self.recorder = Recorder()
+
+    def solve(self, model, dataset):
+        # model.eval()
+        self.predictor = model
+        dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
+
+        for i, data in enumerate(dataloader):
+            if self.config['dataset']['prompt'] == 'box':
+                # res = self.eval_step_slice(data, batch_idx=i)
+                res = self.eval_step(data, batch_idx=i)
+            elif self.config['dataset']['prompt'] == 'point':
+                res = self.eval_step_point(data, batch_idx=i)
+            self.recorder.record(res)
+        res = self.recorder.cal_metrics()
+        print(res)
+        print("prompt:", self.config['dataset']['prompt'], " class_idx:", self.config['dataset']['label_idx'])
+
+    def eval_step(self, data, batch_idx=0):
+        name = data['name']
+        dataset_name = data['dataset_name'][0]
+        label_idx = data['label_idx'][0]
+        template_slice_id = data['template_slice_id'][0]
+
+        assert data['img'].shape[1] >= 3, f"Got img.shape {data['img'].shape}"
+        # keep the template slice off the volume boundary so that both
+        # neighboring slices exist for propagation
+        if template_slice_id == 0:
+            template_slice_id += 1
+        elif template_slice_id == (data['img'].shape[0] - 1):
+            template_slice_id -= 1
+
+        spacing = data['spacing'].numpy().tolist()[0]
+        if data['img'].shape[-1] < 260:
+            img = data['img'][0][:, :256, :256]
+            label = data['label'][0][:, :256, :256]
+        else:
+            img = data['img'][0]
+            label = data['label'][0]
+        # img = torch.clip(img, -200, 600)
+        box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id].detach().cpu().numpy())
+        box = np.array([box])
+        pred, stability = self.predictor.predict_volume(
+            x=img,
+            box=box,
+            template_slice_id=template_slice_id,
+            return_stability=True,
+        )
+        prompt_type = 'box'
+        dice = compute_dice_np(pred, label.detach().cpu().numpy())
+        # Data3dSolver().simple_write(pred, path=tfilename(f"visual/{dataset_name}/pred_{batch_idx}_label_{label_idx}_{prompt_type}.nii.gz"), spacing=spacing)
+        # Data3dSolver().simple_write(label.detach().cpu().numpy(), path=tfilename(f"visual/{dataset_name}/label_{batch_idx}.nii.gz"))
+        # Data3dSolver().simple_write(img.detach().cpu().numpy(), path=tfilename(f"visual/{dataset_name}/img_{batch_idx}.nii.gz"))
+        # np.save(tfilename(f"meta/{dataset_name}/stability_{batch_idx}.npy"), stability)
+
+        # nsd = compute_surface_dice(torch.Tensor(pred), label.detach().cpu(), 1)
+        # surface_distances = surfdist.compute_surface_distances(
+        #     label.detach().cpu().numpy(), pred, spacing_mm=(0.6, 0.6445, 0.6445))
+        # nsd = surfdist.compute_surface_dice_at_tolerance(surface_distances, 1)
+        nsd = 0
+
+        print(dataset_name, name, dice, nsd)
+        return {"dice": dice, "nsd": nsd}
+
+    def eval_step_point(self, data, batch_idx=0):
+        name = data['name']
+        dataset_name = data['dataset_name'][0]
+        label_idx = data['label_idx'][0]
+        template_slice_id = data['template_slice_id'][0]
+        spacing = data['spacing'].numpy().tolist()[0]
+
+        assert data['img'].shape[1] >= 3, f"Got img.shape {data['img'].shape}"
+        if template_slice_id == 0:
+            template_slice_id += 1
+        elif template_slice_id == (data['img'].shape[0] - 1):
+            template_slice_id -= 1
+
+        if data['img'].shape[-1] < 260:
+            img = data['img'][0][:, :256, :256]
+            label = data['label'][0][:, :256, :256]
+        else:
+            img = data['img'][0]
+            label = data['label'][0]
+
+        # take the center of the template-slice bbox as the point prompt
+        box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id].detach().cpu().numpy())
+        point = (box[0] + box[2]) * 0.5, (box[1] + box[3]) * 0.5
+        point = np.array([point]).astype(int)
+        if label[template_slice_id][point[0, 1], point[0, 0]] == 0:
+            # bbox center lies outside the mask; fall back to a random foreground point
+            print("Use a random point instead!")
+            point = PointPromptGenerator().get_prompt_point(label[template_slice_id])
+            point = np.array([point]).astype(int)
+        pred = self.predictor.predict_volume(
+            x=img,
+            point_coords=point,
+            point_labels=np.ones_like(point)[:, :1],
+            template_slice_id=template_slice_id,
+        )
+        dice = compute_dice_np(pred, label.detach().cpu().numpy())
+        prompt_type = 'point'
+        # Data3dSolver().simple_write(pred, path=tfilename(f"visual/{dataset_name}/pred_{batch_idx}_label_{label_idx}_{prompt_type}.nii.gz"), spacing=spacing)
+        # Data3dSolver().simple_write(pred, path=tfilename(f"visual/{dataset_name}/pred_{batch_idx}.nii.gz"))
+        nsd = 0  # fix: the original called compute_surface_dice, whose monai import is commented out above; fall back to 0 as in eval_step
+        print(dataset_name, name, dice, nsd)
+        return {"dice": dice, "nsd": nsd}
+
+    def eval_step_slice(self, data, batch_idx=0):
+        name = data['name']
+        dataset_name = data['dataset_name'][0]
+        label_idx = data['label_idx'][0]
+        template_slice_id = data['template_slice_id'][0]
+
+        assert data['img'].shape[1] >= 3, f"Got img.shape {data['img'].shape}"
+        if template_slice_id == 0:
+            template_slice_id += 1
+        elif template_slice_id == (data['img'].shape[0] - 1):
+            template_slice_id -= 1
+
+        spacing = data['spacing'].numpy().tolist()[0]
+        if data['img'].shape[-1] < 260:
+            img = data['img'][0][:, :256, :256]
+            label = data['label'][0][:, :256, :256]
+        else:
+            img = data['img'][0]
+            label = data['label'][0]
+
+        # evaluate on the template slice and its two neighbors only
+        img = img[template_slice_id-1:template_slice_id+2, :, :]
+        label = label[template_slice_id-1:template_slice_id+2, :, :]
+        template_slice_id = 1
+
+        # img = torch.clip(img, -200, 600)
+        box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id].detach().cpu().numpy())
+        box = np.array([box])
+        pred, stability = self.predictor.predict_volume(
+            x=img,
+            box=box,
+            template_slice_id=template_slice_id,
+            return_stability=True,
+        )
+        prompt_type = 'box'
+        dice = compute_dice_np(pred, label.detach().cpu().numpy())
+        # optional visual dumps: same commented Data3dSolver calls as in eval_step
+        print("Slice evaluation: ", dataset_name, name, dice)
+        return {"dice": dice}
+
+
+def to_RGB(img):
+    pass
+
+
+if __name__ == "__main__":
+    # from core.learner3 import SamLearner
+    # from modeling.build_sam3d import sam_model_registry
+    from core.learner_sub1 import SamLearner
+    from modeling.build_sam3d2 import sam_model_registry
+
+    EX_CONFIG = {
+        'dataset': {
+            'prompt': 'box',
+            'dataset_list': ['guangdong'],  # ["sabs"], chaos, word, decathlon_colon, pancreas
+            'label_idx': 2,
+        },
+        "pth": "./model_latest.pth"
+    }
+
+    config = ConfigManager()
+    config.add_config("configs/vit_sub.yaml")
+    config.add_config(EX_CONFIG)
+
+    # Init Model
+    model_type = "vit_b"
+    sam = sam_model_registry[model_type](checkpoint=None)
+    learner = SamLearner(sam_model=sam, config=config, data_engine=DataManager(img_size=(1024,1024)))
+    learner.use_lora()
+    learner.use_lora_sub()
+    pth = EX_CONFIG['pth']
+    learner.load_well_trained_model(pth)
+    learner.cuda()
+    predictor = VolumePredictor(
+        model=learner.model,
+        use_postprocess=True,
+        use_noise_remove=True,)
+
+    solver = Evaluater(config)
+    dataset = AbstractLoader(config['dataset'], split="test")
+    tt = timer()
+    solver.solve(predictor, dataset)
+
+    print("Time: ", tt())
\ No newline at end of file
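Note on the NSD metric: both `eval_step` and `eval_step_point` currently return `nsd = 0`; the intended surface-dice computation survives only in comments. A minimal sketch of how it could be restored with the already-imported `surface-distance` package, carrying over the 1 mm tolerance from the commented code and using the loader's per-volume `spacing` instead of the hardcoded `(0.6, 0.6445, 0.6445)`; the `compute_nsd` helper name is ours, not the repo's:

```python
import numpy as np
import surface_distance as surfdist

def compute_nsd(pred, label, spacing_mm, tolerance_mm=1.0):
    """Normalized surface dice between two binary volumes (hypothetical helper)."""
    # surface-distance expects boolean masks and per-axis voxel spacing in mm
    distances = surfdist.compute_surface_distances(
        label.astype(bool), pred.astype(bool), spacing_mm=tuple(spacing_mm))
    return surfdist.compute_surface_dice_at_tolerance(distances, tolerance_mm)
```

Inside `eval_step`, this would replace `nsd = 0` with `nsd = compute_nsd(pred, label.detach().cpu().numpy(), spacing)`.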