add finetune
This commit is contained in:
parent
840009725f
commit
173762f756
50
configs/vit_sub.yaml
Normal file
50
configs/vit_sub.yaml
Normal file
@ -0,0 +1,50 @@
|
||||
# ---------------------- Common Configs --------------------------
|
||||
base:
|
||||
base_dir: "../runs/sam/"
|
||||
tag: ''
|
||||
stage: ''
|
||||
logger:
|
||||
mode: ['tb', ]
|
||||
# mode: ''
|
||||
recorder_reduction: 'mean'
|
||||
|
||||
training:
|
||||
save_mode: ['all', 'best', 'latest'] # ,
|
||||
batch_size : 2 # 8 for A100
|
||||
num_workers : 8
|
||||
num_epochs : 100 # epochs
|
||||
use_amp: false
|
||||
save_interval : 1
|
||||
val_check_interval: 6
|
||||
load_pretrain_model: false
|
||||
|
||||
# optim:
|
||||
lr: 0.00002
|
||||
decay_step: 2000
|
||||
decay_gamma: 0.8
|
||||
weight_decay: 0.0001
|
||||
alpha: 0.99
|
||||
validation_interval: 100
|
||||
|
||||
sam_checkpoint: "/home1/quanquan/code/projects/medical-guangdong/segment-anything/sam_vit_b_01ec64.pth" # 103 server
|
||||
model_type: "vit_b"
|
||||
|
||||
continue_training: false
|
||||
load_optimizer: false
|
||||
breakpoint_path: "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b9/lora3/ckpt/model_iter_360000.pth"
|
||||
|
||||
dataset:
|
||||
types: ['3d'] # ['3d', '2d']
|
||||
split: 'train'
|
||||
data_root_path: '/home1/quanquan/datasets/'
|
||||
dataset_list: ["pancreas"]
|
||||
data_txt_path: './datasets/dataset_list/'
|
||||
dataset2d_path: "/home1/quanquan/datasets/08_AbdomenCT-1K/"
|
||||
cache_data_path: '/home1/quanquan/datasets/cached_dataset2/'
|
||||
|
||||
cache_prefix: ['6016'] # '07'
|
||||
specific_label: [2]
|
||||
|
||||
test:
|
||||
batch_size: 1
|
||||
|
@ -573,7 +573,7 @@ if __name__ == "__main__":
|
||||
sam = sam_model_registry[model_type](checkpoint=None)
|
||||
learner = SamLearner(sam_model=sam, config=config, data_engine=DataManager(img_size=(1024,1024)))
|
||||
learner.use_lora()
|
||||
pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b9/lora3/ckpt/model_iter_360000.pth"
|
||||
pth = "model_iter_360000.pth"
|
||||
learner.load_well_trained_model(pth)
|
||||
learner.cuda()
|
||||
|
||||
|
59
readme.md
59
readme.md
@ -118,9 +118,29 @@ Then run
|
||||
python -m datasets.cache_dataset3d
|
||||
```
|
||||
|
||||
## Configs Settings
|
||||
|
||||
important settings
|
||||
|
||||
```yaml
|
||||
base:
|
||||
base_dir: "../runs/sam/" # logging dir
|
||||
|
||||
dataset:
|
||||
types: ['3d'] # ['3d', '2d']
|
||||
split: 'train'
|
||||
data_root_path: '../datasets/'
|
||||
dataset_list: ["pancreas"]
|
||||
data_txt_path: './datasets/dataset_list/'
|
||||
dataset2d_path: "../08_AbdomenCT-1K/"
|
||||
cache_data_path: '../cached_dataset2/'
|
||||
|
||||
cache_prefix: ['6016'] # cache prefix of cached dataset for training
|
||||
# For example: ['07',] for 07_WORD
|
||||
```
|
||||
|
||||
|
||||
## Start Training
|
||||
## Start Training from scratch (SAM)
|
||||
|
||||
Run training on multi-gpu
|
||||
|
||||
@ -141,11 +161,46 @@ python -m core.volume_predictor
|
||||
```
|
||||
|
||||
## Testset Validation
|
||||
|
||||
```python
|
||||
EX_CONFIG = {
|
||||
'dataset':{
|
||||
'prompt': 'box', # prompt type: box or point
|
||||
'dataset_list': ['word'], # dataset_list name
|
||||
'label_idx': 2, # label index for inference,
|
||||
},
|
||||
"pth": "./model.pth"
|
||||
}
|
||||
```
|
||||
```
|
||||
python -m test.volume_eval
|
||||
```
|
||||
|
||||
## Finetuning (Recommended)
|
||||
```yaml
|
||||
training:
|
||||
breakpoint_path: "./model.pth" # pretrained weight path
|
||||
```
|
||||
|
||||
```
|
||||
python -m core.ddp_sub --tag run
|
||||
```
|
||||
|
||||
## Validation with Finetuned Weights
|
||||
|
||||
```
|
||||
python -m test.volume_eval_sublora
|
||||
```
|
||||
|
||||
```python
|
||||
EX_CONFIG = {
|
||||
'dataset':{
|
||||
'prompt': 'box', # prompt type: box or point
|
||||
'dataset_list': ['word'], # dataset_list name
|
||||
'label_idx': 2, # label index for inference,
|
||||
},
|
||||
"pth": "./model_finetuned.pth"
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
<p align="center" width="100%">
|
||||
|
@ -140,28 +140,22 @@ if __name__ == "__main__":
|
||||
'prompt': 'box',
|
||||
'dataset_list': ['word'], # ["sabs"], chaos, word
|
||||
'label_idx': 1,
|
||||
}
|
||||
},
|
||||
'pth': "model_latest.pth"
|
||||
}
|
||||
|
||||
config = ConfigManager()
|
||||
config.add_config("configs/vit_b_103.yaml")
|
||||
config.add_config(EX_CONFIG)
|
||||
|
||||
print(config)
|
||||
|
||||
# Init Model
|
||||
model_type = "vit_b"
|
||||
sam = sam_model_registry[model_type](checkpoint=None)
|
||||
learner = SamLearner(sam_model=sam, config=config, data_engine=DataManager(img_size=(1024,1024)))
|
||||
learner.use_lora()
|
||||
# pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b3/lora+edge2/ckpt_v/model_latest.pth"
|
||||
# pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b3/lora+edge2/ckpt/model_epoch_20.pth"
|
||||
# pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b3/lora+edge2/ckpt/model_epoch_16.pth"
|
||||
# pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b3/lora_small/ckpt/model_epoch_6.pth"
|
||||
# pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b9/lora/ckpt/model_epoch_50.pth"
|
||||
# pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b11/spec_8/ckpt_v/model_latest.pth"
|
||||
# pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b11/spec_5/ckpt/model_epoch_100.pth"
|
||||
pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b9/lora3/ckpt/model_iter_360000.pth"
|
||||
# pth = "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b9/lora3/ckpt/model_iter_500000.pth"
|
||||
learner.load_well_trained_model(pth)
|
||||
learner.load_well_trained_model(EX_CONFIG['pth'])
|
||||
learner.cuda()
|
||||
predictor = VolumePredictor(
|
||||
model=learner.model,
|
||||
|
231
test/volume_eval_sublora.py
Normal file
231
test/volume_eval_sublora.py
Normal file
@ -0,0 +1,231 @@
|
||||
"""
|
||||
Volume evalutaion
|
||||
|
||||
"""
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch.utils.data import DataLoader
|
||||
# from datasets.dataset3d import Dataset3D
|
||||
from tutils.new.manager import ConfigManager
|
||||
from datasets.eval_dataloader.loader_abstract import AbstractLoader
|
||||
|
||||
from core.volume_predictor import VolumePredictor
|
||||
from datasets.data_engine import DataManager, BoxPromptGenerator, PointPromptGenerator
|
||||
|
||||
from tutils import tfilename
|
||||
from tutils.new.trainer.recorder import Recorder
|
||||
from trans_utils.metrics import compute_dice_np
|
||||
from trans_utils.data_utils import Data3dSolver
|
||||
|
||||
# from monai.metrics import compute_surface_dice
|
||||
import surface_distance as surfdist
|
||||
from tutils.tutils.ttimer import timer
|
||||
|
||||
|
||||
|
||||
class Evaluater:
|
||||
def __init__(self, config) -> None:
|
||||
self.config = config
|
||||
self.recorder = Recorder()
|
||||
|
||||
def solve(self, model, dataset):
|
||||
# model.eval()
|
||||
self.predictor = model
|
||||
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
|
||||
|
||||
for i, data in enumerate(dataloader):
|
||||
# if i <4:
|
||||
# print
|
||||
# continue
|
||||
# for k, v in data.items():
|
||||
# if isinstance(v, torch.Tensor):
|
||||
# data[k] = v.to(self.rank)
|
||||
if self.config['dataset']['prompt'] == 'box':
|
||||
# res = self.eval_step_slice(data, batch_idx=i)
|
||||
res = self.eval_step(data, batch_idx=i)
|
||||
if self.config['dataset']['prompt'] == 'point':
|
||||
res = self.eval_step_point(data, batch_idx=i)
|
||||
self.recorder.record(res)
|
||||
res = self.recorder.cal_metrics()
|
||||
print(res)
|
||||
print("prompt:", self.config['dataset']['prompt'], " class_idx:", self.config['dataset']['label_idx'])
|
||||
|
||||
def eval_step(self, data, batch_idx=0):
|
||||
name = data['name']
|
||||
dataset_name = data['dataset_name'][0]
|
||||
label_idx = data['label_idx'][0]
|
||||
template_slice_id = data['template_slice_id'][0]
|
||||
|
||||
assert data['img'].shape[1] >= 3, f" Got img.shape {data['img'].shape}"
|
||||
if template_slice_id == 0:
|
||||
template_slice_id += 1
|
||||
elif template_slice_id == (data['img'].shape[0] - 1):
|
||||
template_slice_id -= 1
|
||||
|
||||
spacing = data['spacing'].numpy().tolist()[0]
|
||||
if data['img'].shape[-1] < 260:
|
||||
# assert data['img'].shape[-1] < 260, f"Got {data['img'].shape}"
|
||||
img = data['img'][0][:,:256,:256]
|
||||
label = data['label'][0][:,:256,:256]
|
||||
else:
|
||||
img = data['img'][0]
|
||||
label = data['label'][0]
|
||||
# img = torch.clip(img, -200, 600)
|
||||
box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id].detach().cpu().numpy())
|
||||
box = np.array([box])
|
||||
pred, stability = self.predictor.predict_volume(
|
||||
x=img,
|
||||
box=box,
|
||||
template_slice_id=template_slice_id,
|
||||
return_stability=True,
|
||||
)
|
||||
prompt_type = 'box'
|
||||
dice = compute_dice_np(pred, label.detach().cpu().numpy())
|
||||
# Data3dSolver().simple_write(pred, path=tfilename(f"visual/{dataset_name}/pred_{batch_idx}_label_{label_idx}_{prompt_type}.nii.gz"), spacing=spacing)
|
||||
# Data3dSolver().simple_write(label.detach().cpu().numpy(), path=tfilename(f"visual/{dataset_name}/label_{batch_idx}.nii.gz"))
|
||||
# Data3dSolver().simple_write(img.detach().cpu().numpy(), path=tfilename(f"visual/{dataset_name}/img_{batch_idx}.nii.gz"))
|
||||
# np.save(tfilename(f"meta/{dataset_name}/stability_{batch_idx}.npy"), stability)
|
||||
|
||||
# nsd = compute_surface_dice(torch.Tensor(pred), label.detach().cpu(), 1)
|
||||
|
||||
# surface_distances = surfdist.compute_surface_distances(
|
||||
# label.detach().cpu().numpy(), pred, spacing_mm=(0.6, 0.6445, 0.6445))
|
||||
# nsd = surfdist.compute_surface_dice_at_tolerance(surface_distances, 1)
|
||||
nsd = 0
|
||||
|
||||
print(dataset_name, name, dice, nsd)
|
||||
# import ipdb; ipdb.set_trace()
|
||||
|
||||
return {"dice": dice, "nsd": nsd}
|
||||
|
||||
def eval_step_point(self, data, batch_idx=0):
|
||||
name = data['name']
|
||||
dataset_name = data['dataset_name'][0]
|
||||
label_idx = data['label_idx'][0]
|
||||
template_slice_id = data['template_slice_id'][0]
|
||||
spacing = data['spacing'].numpy().tolist()[0]
|
||||
|
||||
assert data['img'].shape[1] >= 3, f" Got img.shape {data['img'].shape}"
|
||||
if template_slice_id == 0:
|
||||
template_slice_id += 1
|
||||
elif template_slice_id == (data['img'].shape[0] - 1):
|
||||
template_slice_id -= 1
|
||||
|
||||
if data['img'].shape[-1] < 260:
|
||||
# assert data['img'].shape[-1] < 260, f"Got {data['img'].shape}"
|
||||
img = data['img'][0][:,:256,:256]
|
||||
label = data['label'][0][:,:256,:256]
|
||||
else:
|
||||
img = data['img'][0]
|
||||
label = data['label'][0]
|
||||
|
||||
box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id].detach().cpu().numpy())
|
||||
point = (box[0]+box[2])*0.5 , (box[1]+box[3])*0.5
|
||||
point = np.array([point]).astype(int)
|
||||
if label[template_slice_id][point[0,1], point[0,0]] == 0:
|
||||
print("Use random point instead !!!")
|
||||
point = PointPromptGenerator().get_prompt_point(label[template_slice_id])
|
||||
point = np.array([point]).astype(int)
|
||||
# box = np.array([box])
|
||||
pred = self.predictor.predict_volume(
|
||||
x=img,
|
||||
point_coords=point,
|
||||
point_labels=np.ones_like(point)[:,:1],
|
||||
template_slice_id=template_slice_id,
|
||||
)
|
||||
dice = compute_dice_np(pred, label.detach().cpu().numpy())
|
||||
prompt_type = 'point'
|
||||
# Data3dSolver().simple_write(pred, path=tfilename(f"visual/{dataset_name}/pred_{batch_idx}_label_{label_idx}_{prompt_type}.nii.gz"), spacing=spacing)
|
||||
# Data3dSolver().simple_write(pred, path=tfilename(f"visual/{dataset_name}/pred_{batch_idx}.nii.gz"))
|
||||
nsd = compute_surface_dice(pred, label.detach().cpu().numpy())
|
||||
print(dataset_name, name, dice)
|
||||
return {"dice": dice, "nsd": nsd}
|
||||
|
||||
def eval_step_slice(self, data, batch_idx=0):
|
||||
name = data['name']
|
||||
dataset_name = data['dataset_name'][0]
|
||||
label_idx = data['label_idx'][0]
|
||||
template_slice_id = data['template_slice_id'][0]
|
||||
|
||||
assert data['img'].shape[1] >= 3, f" Got img.shape {data['img'].shape}"
|
||||
if template_slice_id == 0:
|
||||
template_slice_id += 1
|
||||
elif template_slice_id == (data['img'].shape[0] - 1):
|
||||
template_slice_id -= 1
|
||||
|
||||
spacing = data['spacing'].numpy().tolist()[0]
|
||||
if data['img'].shape[-1] < 260:
|
||||
# assert data['img'].shape[-1] < 260, f"Got {data['img'].shape}"
|
||||
img = data['img'][0][:,:256,:256]
|
||||
label = data['label'][0][:,:256,:256]
|
||||
else:
|
||||
img = data['img'][0]
|
||||
label = data['label'][0]
|
||||
|
||||
img = img[template_slice_id-1:template_slice_id+2, :,:]
|
||||
label = label[template_slice_id-1:template_slice_id+2, :,:]
|
||||
template_slice_id = 1
|
||||
|
||||
# img = torch.clip(img, -200, 600)
|
||||
box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id].detach().cpu().numpy())
|
||||
box = np.array([box])
|
||||
pred, stability = self.predictor.predict_volume(
|
||||
x=img,
|
||||
box=box,
|
||||
template_slice_id=template_slice_id,
|
||||
return_stability=True,
|
||||
)
|
||||
prompt_type = 'box'
|
||||
dice = compute_dice_np(pred, label.detach().cpu().numpy())
|
||||
# Data3dSolver().simple_write(pred, path=tfilename(f"visual/{dataset_name}/pred_{batch_idx}_label_{label_idx}_{prompt_type}.nii.gz"), spacing=spacing)
|
||||
# Data3dSolver().simple_write(label.detach().cpu().numpy(), path=tfilename(f"visual/{dataset_name}/label_{batch_idx}.nii.gz"))
|
||||
# Data3dSolver().simple_write(img.detach().cpu().numpy(), path=tfilename(f"visual/{dataset_name}/img_{batch_idx}.nii.gz"))
|
||||
# np.save(tfilename(f"meta/{dataset_name}/stability_{batch_idx}.npy"), stability)
|
||||
print("Slice evaluation: ", dataset_name, name, dice)
|
||||
return {"dice": dice}
|
||||
|
||||
def to_RGB(img):
|
||||
pass
|
||||
|
||||
if __name__ == "__main__":
|
||||
# from core.learner3 import SamLearner
|
||||
# from modeling.build_sam3d import sam_model_registry
|
||||
|
||||
# from core.learner3 import SamLearner
|
||||
from core.learner_sub1 import SamLearner
|
||||
from modeling.build_sam3d2 import sam_model_registry
|
||||
|
||||
EX_CONFIG = {
|
||||
'dataset':{
|
||||
'prompt': 'box',
|
||||
'dataset_list': ['guangdong'], # ["sabs"], chaos, word, decathlon_colon, pancreas
|
||||
'label_idx': 2,
|
||||
},
|
||||
"pth": "./model_latest.pth"
|
||||
}
|
||||
|
||||
config = ConfigManager()
|
||||
# config.add_config("configs/vit_sub.yaml")
|
||||
config.add_config("configs/vit_sub.yaml")
|
||||
config.add_config(EX_CONFIG)
|
||||
|
||||
# Init Model
|
||||
model_type = "vit_b"
|
||||
sam = sam_model_registry[model_type](checkpoint=None)
|
||||
learner = SamLearner(sam_model=sam, config=config, data_engine=DataManager(img_size=(1024,1024)))
|
||||
learner.use_lora()
|
||||
learner.use_lora_sub()
|
||||
pth = EX_CONFIG['pth']
|
||||
learner.load_well_trained_model(pth)
|
||||
learner.cuda()
|
||||
predictor = VolumePredictor(
|
||||
model=learner.model,
|
||||
use_postprocess=True,
|
||||
use_noise_remove=True,)
|
||||
|
||||
solver = Evaluater(config)
|
||||
dataset = AbstractLoader(config['dataset'], split="test")
|
||||
tt = timer()
|
||||
solver.solve(predictor, dataset)
|
||||
|
||||
print("Time: ", tt())
|
Loading…
x
Reference in New Issue
Block a user