"""
|
|
Volume evalutaion
|
|
|
|
"""
|
|
import torch
|
|
import numpy as np
|
|
from torch.utils.data import DataLoader
|
|
# from datasets.dataset3d import Dataset3D
|
|
from tutils.new.manager import ConfigManager
|
|
from datasets.eval_dataloader.loader_abstract import AbstractLoader
|
|
|
|
from core.volume_predictor import VolumePredictor
|
|
from datasets.data_engine import DataManager, BoxPromptGenerator, PointPromptGenerator
|
|
|
|
from tutils import tfilename
|
|
from tutils.new.trainer.recorder import Recorder
|
|
from trans_utils.metrics import compute_dice_np
|
|
from trans_utils.data_utils import Data3dSolver
|
|
from tutils.tutils.ttimer import timer
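
# NOTE: fields of each item yielded by the evaluation dataloader, as inferred
# from how `eval_step` / `eval_step_point` read `data` below (the actual
# contract is defined by `AbstractLoader`):
#   data['name']                  case identifier (printed with the dice score)
#   data['dataset_name'][0]       dataset tag, used in output paths
#   data['label_idx'][0]          class/organ index under evaluation
#   data['template_slice_id'][0]  slice index used to place the prompt
#   data['spacing']               voxel spacing; the first item is used
#   data['img'], data['label']    tensors shaped (1, D, H, W), with D >= 3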


class Evaluater:
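    """Evaluate a volume predictor on a 3D segmentation dataset.

    The prompt type ('box' or 'point') and target class come from
    ``config['dataset']``; per-case dice scores are accumulated in a
    ``Recorder`` and summarized once the whole dataset has been processed.
    """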

    def __init__(self, config) -> None:
        self.config = config
        self.recorder = Recorder()

    def solve(self, model, dataset, finetune_number=1):
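        """Evaluate ``model`` (a volume predictor) over every case in ``dataset``.

        With box prompts, each case is segmented from a single prompted slice;
        if ``finetune_number`` > 1, extra box prompts are placed on the
        worst-scoring slices via ``finetune_with_more_prompt``. With point
        prompts, a single ``eval_step_point`` pass is used. Per-case dice
        scores are accumulated in ``self.recorder``.
        """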
        # model.eval()
        self.predictor = model
        dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

        for i, data in enumerate(dataloader):
            # if i < 4:      # (debug skip, left disabled)
            #     continue
            # for k, v in data.items():
            #     if isinstance(v, torch.Tensor):
            #         data[k] = v.to(self.rank)
            # try:
            if True:  # placeholder for the try/except left commented out below
                if self.config['dataset']['prompt'] == 'box':
                    dice, pred, label, temp_slice = self.eval_step(data, batch_idx=i)
                    used_slice = [temp_slice]
                    if finetune_number > 1:
                        for _ in range(finetune_number - 1):
                            dice, pred, label, temp_slice = self.finetune_with_more_prompt(pred, label, exclude_slide_id=used_slice)
                            used_slice.append(temp_slice)
                    res = {"dice": dice}
                if self.config['dataset']['prompt'] == 'point':
                    res = self.eval_step_point(data, batch_idx=i)
                self.recorder.record(res)
            # except Exception as e:
            #     print(e)

        res = self.recorder.cal_metrics()
        print(res)
        print("prompt:", self.config['dataset']['prompt'], " class_idx:", self.config['dataset']['label_idx'])

    def eval_step(self, data, batch_idx=0):
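        """Evaluate one case using a box prompt.

        The box is the bounding box of the ground-truth mask on the template
        slice; the volume is then propagated with ``predict_volume``. Returns
        ``(dice, pred, label, template_slice_id)`` so ``solve`` can refine the
        result with additional prompts.
        """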
        name = data['name']
        dataset_name = data['dataset_name'][0]
        label_idx = data['label_idx'][0]
        template_slice_id = data['template_slice_id'][0]

        assert data['img'].shape[1] >= 3, f" Got img.shape {data['img'].shape}"
        # Keep the template away from the first/last slice; depth is axis 1
        # of the batched (1, D, H, W) tensor.
        if template_slice_id == 0:
            template_slice_id += 1
        elif template_slice_id == (data['img'].shape[1] - 1):
            template_slice_id -= 1
        print("Using slice ", template_slice_id, " as template slice")

        spacing = data['spacing'].numpy().tolist()[0]
        if data['img'].shape[-1] < 260:
            img = data['img'][0][:, :256, :256]
            label = data['label'][0][:, :256, :256]
        else:
            img = data['img'][0]
            label = data['label'][0]
        # img = torch.clip(img, -200, 600)
        box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id].detach().cpu().numpy())
        box = np.array([box])
        pred, stability = self.predictor.predict_volume(
            x=img,
            box=box,
            template_slice_id=template_slice_id,
            return_stability=True,
        )

        prompt_type = 'box'
        dice = compute_dice_np(pred, label.detach().cpu().numpy())
        # Data3dSolver().simple_write(pred, path=tfilename(f"visual/{dataset_name}/pred_{batch_idx}_label_{label_idx}_{prompt_type}.nii.gz"), spacing=spacing)
        # Data3dSolver().simple_write(label.detach().cpu().numpy(), path=tfilename(f"visual/{dataset_name}/label_{batch_idx}.nii.gz"))
        # Data3dSolver().simple_write(img.detach().cpu().numpy(), path=tfilename(f"visual/{dataset_name}/img_{batch_idx}.nii.gz"))
        # np.save(tfilename(f"meta/{dataset_name}/stability_{batch_idx}.npy"), stability)
        print(dataset_name, name, dice)
        template_slice_id = template_slice_id if isinstance(template_slice_id, int) else template_slice_id.item()
        return dice, pred, label.detach().cpu().numpy(), template_slice_id

    def eval_step_point(self, data, batch_idx=0):
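        """Evaluate one case using a point prompt.

        The point is the center of the template-slice bounding box; if that
        center lies outside the mask, a random foreground point from
        ``PointPromptGenerator`` is used instead. Returns ``{"dice": dice}``.
        """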
        name = data['name']
        dataset_name = data['dataset_name'][0]
        label_idx = data['label_idx'][0]
        template_slice_id = data['template_slice_id'][0]
        spacing = data['spacing'].numpy().tolist()[0]

        assert data['img'].shape[1] >= 3, f" Got img.shape {data['img'].shape}"
        # Keep the template away from the first/last slice; depth is axis 1
        # of the batched (1, D, H, W) tensor.
        if template_slice_id == 0:
            template_slice_id += 1
        elif template_slice_id == (data['img'].shape[1] - 1):
            template_slice_id -= 1

        if data['img'].shape[-1] < 260:
            img = data['img'][0][:, :256, :256]
            label = data['label'][0][:, :256, :256]
        else:
            img = data['img'][0]
            label = data['label'][0]

        # Use the center of the template-slice bounding box as the point prompt.
        box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id].detach().cpu().numpy())
        point = (box[0] + box[2]) * 0.5, (box[1] + box[3]) * 0.5
        point = np.array([point]).astype(int)
        if label[template_slice_id][point[0, 1], point[0, 0]] == 0:
            print("Use random point instead !!!")
            point = PointPromptGenerator().get_prompt_point(label[template_slice_id])
            point = np.array([point]).astype(int)
        # box = np.array([box])
        pred = self.predictor.predict_volume(
            x=img,
            point_coords=point,
            point_labels=np.ones_like(point)[:, :1],
            template_slice_id=template_slice_id,
        )
        dice = compute_dice_np(pred, label.detach().cpu().numpy())
        prompt_type = 'point'
        Data3dSolver().simple_write(pred, path=tfilename(f"visual/{dataset_name}/pred_{batch_idx}_label_{label_idx}_{prompt_type}.nii.gz"), spacing=spacing)
        # Data3dSolver().simple_write(pred, path=tfilename(f"visual/{dataset_name}/pred_{batch_idx}.nii.gz"))
        print(dataset_name, name, dice)
        return {"dice": dice}

    def finetune_with_more_prompt(self, pred, label, prompt_type="box", exclude_slide_id=[]):
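        """Refine ``pred`` by placing one more box prompt on a poorly segmented slice.

        Slices are ranked by per-slice dice against ``label`` (head and tail
        slices are skipped), the worst slice not in ``exclude_slide_id`` is
        re-prompted, and the old and new predictions are merged slice by slice
        by keeping whichever has the higher predictor confidence.
        """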
        assert pred.shape == label.shape
        dices = [compute_dice_np(pred[j, :, :], label[j, :, :]) for j in range(pred.shape[0])]
        rank_list = np.array(dices[1:-1]).argsort()  # ignore the head and tail slices
        rank_list += 1  # shift indices back after dropping the head slice
        for i in rank_list:
            if i in exclude_slide_id:
                continue
            template_slice_id = i
            break
        # template_slice_id += 1  # Ignore the head and tail
        print("Using slice ", template_slice_id, " as template slice")
        old_confidence = self.predictor.get_confidence()
        box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id])
        box = np.array([box])
        new_pred, stability = self.predictor.predict_with_prompt(
            box=box,
            template_slice_id=template_slice_id,
            return_stability=True,
        )
        new_confidence = self.predictor.get_confidence()
        # Bias the merge towards the freshly prompted slice.
        new_confidence[template_slice_id] *= 2
        all_conf = np.stack([old_confidence, new_confidence], axis=1)
        preds = [pred, new_pred]
        merged = np.zeros_like(label)
        for slice_idx in range(pred.shape[0]):
            idx = np.argsort(all_conf[slice_idx, :])[-1]  # index of the more confident prediction
            merged[slice_idx, :, :] = preds[idx][slice_idx]

        print("old dices", [compute_dice_np(p, label) for p in preds])
        dice = compute_dice_np(merged, label)
        print("merged dice", dice)
        return dice, merged, label, template_slice_id


def to_RGB(img):
    # Placeholder; not used in this script.
    pass


if __name__ == "__main__":
    from core.learner3 import SamLearner
    from modeling.build_sam3d2 import sam_model_registry

    EX_CONFIG = {
        'dataset': {
            'prompt': 'box',  # 'box' or 'point'; see Evaluater.solve
            'prompt_number': 5,
            'dataset_list': ['example'],  # e.g. ["sabs"], chaos, word, pancreas
            'label_idx': 5,
        },
        "lora_r": 24,
        'model_type': "vit_h",
        'ckpt': "/home1/quanquan/code/projects/finetune_large/segment_anything/model_iter_3935000.pth",
    }

    config = ConfigManager()
    config.add_config("configs/vit_b.yaml")
    config.add_config(EX_CONFIG)
    config.print()

    dataset = AbstractLoader(config['dataset'], split="test")
    print(len(dataset))
    assert len(dataset) >= 1

    # Init model
    model_type = config['model_type']  # e.g. "vit_b"
    sam = sam_model_registry[model_type](checkpoint=None)
    learner = SamLearner(sam_model=sam, config=config, data_engine=DataManager(img_size=(1024, 1024)))
    learner.use_lora(r=config['lora_r'])
    pth = config['ckpt']
    learner.load_well_trained_model(pth)
    learner.cuda()
    predictor = VolumePredictor(
        model=learner.model,
        use_postprocess=True,
        use_noise_remove=True,
    )

    solver = Evaluater(config)
    solver.solve(predictor, dataset, finetune_number=config['dataset']['prompt_number'])