fixed eval_dataloader

Curli Trans 2024-04-02 15:48:48 +08:00
parent 841be2acbe
commit 96a2bd15a6
40 changed files with 389 additions and 12 deletions

.gitignore vendored Normal file

@@ -0,0 +1,3 @@
*.pth
*.pyc
*.nii.gz


@@ -25,7 +25,7 @@ training:
   load_pretrain_model: false
   # optim:
-  lr: 0.0002
+  lr: 0.000005
   decay_step: 2000
   decay_gamma: 0.8
   weight_decay: 0.0001
@@ -35,7 +35,7 @@ training:
 dataset:
   types: ['3d'] # ['3d', '2d']
   split: 'train'
-  data_root_path: '/quanquan/datasets/'
+  data_root_path: '/home1/quanquan/datasets/'
   dataset_list: ["alp", "word", "debug"] # ['sam', "their", "ours"]
   data_txt_path: './datasets/dataset_list/'
   dataset2d_path: "/home1/quanquan/datasets/08_AbdomenCT-1K/"
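For reference, the optim block above amounts to a standard step-decay schedule: every decay_step iterations the learning rate is multiplied by decay_gamma. A minimal sketch of the equivalent PyTorch setup (the Linear model and optimizer here are stand-ins for illustration, not part of this repo):

import torch

# Hypothetical model/optimizer, just to show the schedule implied by the config
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=0.000005, weight_decay=0.0001)
# lr is multiplied by decay_gamma (0.8) every decay_step (2000) iterations
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2000, gamma=0.8)

for step in range(6000):
    optimizer.step()
    scheduler.step()
# After 6000 steps: lr = 0.000005 * 0.8**3 ~= 2.56e-6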



@@ -107,8 +107,8 @@ class SamLearner(LearnerModule):
         self.model.load_state_dict(state_dict)
         # self.lora_module.load_lora_parameters(pth.replace(".pth", "_lora.safetensors"))
-    def use_lora(self):
-        lora_r = 8
+    def use_lora(self, r=8):
+        lora_r = r
         lora_sam = LoRA_Sam(self.model, lora_r, freeze_prompt_encoder=True)
         self.lora_module = lora_sam
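This change lets callers choose the LoRA rank instead of the hard-coded 8, while freeze_prompt_encoder=True still keeps the prompt encoder frozen. Usage, as it appears elsewhere in this commit (test/volume_eval_mbox.py sets "lora_r": 24):

learner.use_lora()                    # default rank, r=8, same behavior as before
learner.use_lora(r=config['lora_r'])  # rank taken from the experiment config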


@@ -0,0 +1,2 @@
07_WORD/WORD-V0.1.0/imagesVa/word_0001.nii.gz 07_WORD/WORD-V0.1.0/labelsVa/word_0001.nii.gz
07_WORD/WORD-V0.1.0/imagesVa/word_0007.nii.gz 07_WORD/WORD-V0.1.0/labelsVa/word_0007.nii.gz
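Each line pairs an image volume with its label volume, both relative to data_root_path. A minimal sketch of how prepare_datalist in loader_abstract.py (added below) parses one such line:

# One line of a dataset list file: "<relative img path> <relative label path>"
data_root_path = '/home1/quanquan/datasets/'  # example value from the config above
line = "07_WORD/WORD-V0.1.0/imagesVa/word_0001.nii.gz 07_WORD/WORD-V0.1.0/labelsVa/word_0001.nii.gz"
img_rel, label_rel = line.strip().split()
name = label_rel.split('.')[0]       # key derived from the label path
img_path = data_root_path + img_rel
label_path = data_root_path + label_rel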

datasets/eval_dataloader/loader_abstract.py Normal file

@@ -0,0 +1,146 @@
"""
DataLoader only for evaluation
"""
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from tutils.nn.data import read, itk_to_np, np_to_itk
from einops import reduce, repeat, rearrange
# from tutils.nn.data.tsitk.preprocess import resampleImage
from trans_utils.data_utils import Data3dSolver
from tutils.nn.data.tsitk.preprocess import resampleImage
import SimpleITK as sitk

# Example
DATASET_CONFIG = {
    'split': 'test',
    'data_root_path': '/quanquan/datasets/',
    'dataset_list': ["ours"],
    'data_txt_path': './datasets/dataset_list/',
    'label_idx': 0,
}


class AbstractLoader(Dataset):
    def __init__(self, config, split="test") -> None:
        super().__init__()
        self.config = config
        self.split = split
        self.img_names = self.prepare_datalist()

    def __len__(self):
        return len(self.img_names)

    def prepare_datalist(self):
        config = self.config
        data_paths = []
        for item in config['dataset_list']:
            print("Load datalist from ", item)
            # Each line holds "<img path> <label path>", relative to data_root_path
            for line in open(config["data_txt_path"] + item + f"_{self.split}.txt"):
                name = line.strip().split()[1].split('.')[0]
                img_path = config['data_root_path'] + line.strip().split()[0]
                label_path = config['data_root_path'] + line.strip().split()[1]
                data_paths.append({'img_path': img_path, 'label_path': label_path, 'name': name})
        print('train len {}'.format(len(data_paths)))
        return data_paths

    def _get_data(self, index, debug=False):
        label_idx = self.config['label_idx']
        name = self.img_names[index]['img_path']
        img_itk = read(name)
        spacing = img_itk.GetSpacing()
        img_ori = itk_to_np(img_itk)
        scan_orientation = np.argmin(img_ori.shape)
        label_ori = itk_to_np(read(self.img_names[index]['label_path']))
        label = label_ori == label_idx
        # img_ori, new_spacing = Data3dSolver().read(self.img_names[index]['img_path'])
        # label_itk = read(self.img_names[index]['label_path'])
        # ori_spacing = label_itk.GetSpacing()
        # label = itk_to_np(label_itk) == label_idx
        # print("[loader_abstract.DEBUG] size", img_ori.shape, label.shape)
        # label = self._get_resized_label(label, new_size=img_ori.shape)
        if debug:
            Data3dSolver().simple_write(label)
            Data3dSolver().simple_write(img_ori, "tmp_img.nii.gz")

        # Crop to the slice range that contains the target label
        s = reduce(label, "c h w -> c", reduction="sum")
        coords = np.nonzero(s)
        x_min = np.min(coords[0])
        x_max = np.max(coords[0])
        # Template slice: the slice with the largest mask area, re-indexed after cropping
        template_slice_id = s.argmax() - x_min

        if img_ori.min() < -10:
            img_ori = np.clip(img_ori, -200, 400)
        else:
            img_ori = np.clip(img_ori, 0, 600)
        img_ori = img_ori[x_min:x_max+1, :, :]
        label = label[x_min:x_max+1, :, :]
        assert label.shape[0] >= 3
        if template_slice_id <= 1 or template_slice_id >= label.shape[0]-2:
            template_slice_id = label.shape[0] // 2  # fall back to the middle slice

        dataset_name = name.replace(self.config['data_root_path'], "").split("/")[0]
        template_slice = label[template_slice_id, :, :]
        print("template_slice.area ", template_slice.sum(),
              template_slice.sum() / (template_slice.shape[0] * template_slice.shape[1]))
        d = {
            "name": name,
            "dataset_name": dataset_name,
            "img": np.array(img_ori).astype(np.float32),
            "label_idx": label_idx,
            "label": np.array(label).astype(np.float32),
            "template_slice_id": template_slice_id,
            "template_slice": np.array(label[template_slice_id, :, :]).astype(np.float32),
            "spacing": np.array(spacing),
        }
        return d

    def __getitem__(self, index):
        return self._get_data(index)


if __name__ == "__main__":
    from tutils.new.manager import ConfigManager
    EX_CONFIG = {
        'dataset': {
            'prompt': 'box',
            'dataset_list': ['guangdong'],  # ["sabs"], chaos, word
            'label_idx': 2,
        }
    }
    config = ConfigManager()
    config.add_config("configs/vit_sub_rectum.yaml")
    config.add_config(EX_CONFIG)
    dataset = AbstractLoader(config['dataset'], split="test")
    for i in range(len(dataset)):
        dataset._get_data(i, debug=False)

    # label_path = "/home1/quanquan/datasets/01_BCV-Abdomen/Training/label/label0001.nii.gz"
    # from monai.transforms import SpatialResample
    # resample = SpatialResample()
    # label = itk_to_np(read(label_path)) == 1
    # print(label.shape)
    # resampled = resample(label, spatial_size=(label.shape[0]*7, label.shape[1], label.shape[2]))
    # print(label.shape)
    exit(0)

    # Unreachable debug snippet: dump 3-slice windows of the cropped volume to an image grid
    data = itk_to_np(read("tmp_img.nii.gz"))
    data = torch.Tensor(data)
    maxlen = data.shape[0]
    slices = []
    for i in range(1, maxlen-1):
        slices.append(data[i-1:i+2, :, :])
    input_slices = torch.stack(slices, dim=0)
    input_slices = torch.clip(input_slices, -200, 600)
    from torchvision.utils import save_image
    save_image(input_slices, "tmp.jpg")
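As a quick smoke test, the dict returned by _get_data can be inspected directly; a minimal sketch, assuming the volumes referenced by DATASET_CONFIG exist on disk:

from datasets.eval_dataloader.loader_abstract import AbstractLoader, DATASET_CONFIG

dataset = AbstractLoader(DATASET_CONFIG, split="test")
sample = dataset[0]
print(sample['img'].shape)          # (S, H, W), cropped to the labeled slice range
print(sample['template_slice_id'])  # largest-area slice, used as the prompt slice
print(sample['spacing'])            # original voxel spacing from the ITK header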


@@ -137,21 +137,31 @@ if __name__ == "__main__":
     EX_CONFIG = {
         'dataset':{
-            'prompt': 'box',
-            'dataset_list': ['word'], # ["sabs"], chaos, word
-            'label_idx': 1,
+            'prompt': 'box', # box / point
+            'dataset_list': ['example'], # ["sabs"], chaos, word
+            'label_idx': 5,
         },
-        'pth': "model.pth"
+        'model_type': "vit_b",
+        'pth': "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b9/lora3/ckpt/model_iter_360000.pth",
+    }
+    EX_CONFIG = {
+        'dataset':{
+            'prompt': 'box', # box / point
+            'dataset_list': ['example'], # ["sabs"], chaos, word
+            'label_idx': 5,
+        },
+        'model_type': "vit_h",
+        'pth': "/home1/quanquan/code/projects/finetune_large/segment_anything/model_iter_3935000.pth",
     }
     config = ConfigManager()
-    config.add_config("configs/vit_b_103.yaml")
+    config.add_config("configs/vit_b.yaml")
     config.add_config(EX_CONFIG)
-    print(config)
+    config.print()

     # Init Model
-    model_type = "vit_b"
+    model_type = config['model_type']
     sam = sam_model_registry[model_type](checkpoint=None)
     learner = SamLearner(sam_model=sam, config=config, data_engine=DataManager(img_size=(1024,1024)))
     learner.use_lora()

test/volume_eval_mbox.py Normal file

@@ -0,0 +1,216 @@
"""
Volume evaluation
"""
import torch
import numpy as np
from torch.utils.data import DataLoader
# from datasets.dataset3d import Dataset3D
from tutils.new.manager import ConfigManager
from datasets.eval_dataloader.loader_abstract import AbstractLoader
from core.volume_predictor import VolumePredictor
from datasets.data_engine import DataManager, BoxPromptGenerator, PointPromptGenerator
from tutils import tfilename
from tutils.new.trainer.recorder import Recorder
from trans_utils.metrics import compute_dice_np
from trans_utils.data_utils import Data3dSolver
from tutils.tutils.ttimer import timer


class Evaluater:
    def __init__(self, config) -> None:
        self.config = config
        self.recorder = Recorder()

    def solve(self, model, dataset, finetune_number=1):
        # model.eval()
        self.predictor = model
        dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
        for i, data in enumerate(dataloader):
            # for k, v in data.items():
            #     if isinstance(v, torch.Tensor):
            #         data[k] = v.to(self.rank)
            # try:
            if True:
                if self.config['dataset']['prompt'] == 'box':
                    dice, pred, label, temp_slice = self.eval_step(data, batch_idx=i)
                    used_slice = [temp_slice]
                    # Refine with additional prompts on the worst-predicted slices
                    if finetune_number > 1:
                        for _ in range(finetune_number - 1):
                            dice, pred, label, temp_slice = self.finetune_with_more_prompt(
                                pred, label, exclude_slide_id=used_slice)
                            used_slice.append(temp_slice)
                    res = {"dice": dice}
                if self.config['dataset']['prompt'] == 'point':
                    res = self.eval_step_point(data, batch_idx=i)
                self.recorder.record(res)
            # except Exception as e:
            #     print(e)
        res = self.recorder.cal_metrics()
        print(res)
        print("prompt:", self.config['dataset']['prompt'], " class_idx:", self.config['dataset']['label_idx'])

    def eval_step(self, data, batch_idx=0):
        name = data['name']
        dataset_name = data['dataset_name'][0]
        label_idx = data['label_idx'][0]
        template_slice_id = data['template_slice_id'][0]
        assert data['img'].shape[1] >= 3, f" Got img.shape {data['img'].shape}"
        # Keep the template slice away from the volume boundaries;
        # img is (1, S, H, W) after batching, so the slice axis is dim 1
        if template_slice_id == 0:
            template_slice_id += 1
        elif template_slice_id == (data['img'].shape[1] - 1):
            template_slice_id -= 1
        print("Using slice ", template_slice_id, " as template slice")
        spacing = data['spacing'].numpy().tolist()[0]
        if data['img'].shape[-1] < 260:
            img = data['img'][0][:, :256, :256]
            label = data['label'][0][:, :256, :256]
        else:
            img = data['img'][0]
            label = data['label'][0]
        # img = torch.clip(img, -200, 600)
        box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id].detach().cpu().numpy())
        box = np.array([box])
        pred, stability = self.predictor.predict_volume(
            x=img,
            box=box,
            template_slice_id=template_slice_id,
            return_stability=True,
        )
        prompt_type = 'box'
        dice = compute_dice_np(pred, label.detach().cpu().numpy())
        # Data3dSolver().simple_write(pred, path=tfilename(f"visual/{dataset_name}/pred_{batch_idx}_label_{label_idx}_{prompt_type}.nii.gz"), spacing=spacing)
        # Data3dSolver().simple_write(label.detach().cpu().numpy(), path=tfilename(f"visual/{dataset_name}/label_{batch_idx}.nii.gz"))
        # Data3dSolver().simple_write(img.detach().cpu().numpy(), path=tfilename(f"visual/{dataset_name}/img_{batch_idx}.nii.gz"))
        # np.save(tfilename(f"meta/{dataset_name}/stability_{batch_idx}.npy"), stability)
        print(dataset_name, name, dice)
        template_slice_id = template_slice_id if isinstance(template_slice_id, int) else template_slice_id.item()
        return dice, pred, label.detach().cpu().numpy(), template_slice_id

    def eval_step_point(self, data, batch_idx=0):
        name = data['name']
        dataset_name = data['dataset_name'][0]
        label_idx = data['label_idx'][0]
        template_slice_id = data['template_slice_id'][0]
        spacing = data['spacing'].numpy().tolist()[0]
        assert data['img'].shape[1] >= 3, f" Got img.shape {data['img'].shape}"
        if template_slice_id == 0:
            template_slice_id += 1
        elif template_slice_id == (data['img'].shape[1] - 1):
            template_slice_id -= 1
        if data['img'].shape[-1] < 260:
            img = data['img'][0][:, :256, :256]
            label = data['label'][0][:, :256, :256]
        else:
            img = data['img'][0]
            label = data['label'][0]
        # Use the center of the template slice's bbox as the point prompt
        box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id].detach().cpu().numpy())
        point = (box[0] + box[2]) * 0.5, (box[1] + box[3]) * 0.5
        point = np.array([point]).astype(int)
        if label[template_slice_id][point[0, 1], point[0, 0]] == 0:
            # Center falls outside the mask; sample a random foreground point instead
            print("Use random point instead !!!")
            point = PointPromptGenerator().get_prompt_point(label[template_slice_id])
            point = np.array([point]).astype(int)
        pred = self.predictor.predict_volume(
            x=img,
            point_coords=point,
            point_labels=np.ones_like(point)[:, :1],
            template_slice_id=template_slice_id,
        )
        dice = compute_dice_np(pred, label.detach().cpu().numpy())
        prompt_type = 'point'
        Data3dSolver().simple_write(pred, path=tfilename(f"visual/{dataset_name}/pred_{batch_idx}_label_{label_idx}_{prompt_type}.nii.gz"), spacing=spacing)
        print(dataset_name, name, dice)
        return {"dice": dice}

    def finetune_with_more_prompt(self, pred, label, prompt_type="box", exclude_slide_id=[]):
        assert pred.shape == label.shape
        # Rank slices by Dice and re-prompt on the worst slice not used yet
        dices = [compute_dice_np(pred[j, :, :], label[j, :, :]) for j in range(pred.shape[0])]
        rank_list = np.array(dices[1:-1]).argsort()  # ignore the head and tail slices
        rank_list += 1
        for i in rank_list:
            if i in exclude_slide_id:
                continue
            template_slice_id = i
            break
        print("Using slice ", template_slice_id, " as template slice")
        old_confidence = self.predictor.get_confidence()
        box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id])
        box = np.array([box])
        new_pred, stability = self.predictor.predict_with_prompt(
            box=box,
            template_slice_id=template_slice_id,
            return_stability=True,
        )
        # Merge old and new predictions slice by slice, keeping the more confident one;
        # the freshly prompted slice gets double weight
        new_confidence = self.predictor.get_confidence()
        new_confidence[template_slice_id] *= 2
        all_conf = np.stack([old_confidence, new_confidence], axis=1)
        preds = [pred, new_pred]
        merged = np.zeros_like(label)
        for slice_idx in range(pred.shape[0]):
            idx = np.argsort(all_conf[slice_idx, :])[-1]
            merged[slice_idx, :, :] = preds[idx][slice_idx]
        print("old dices", [compute_dice_np(p, label) for p in preds])
        dice = compute_dice_np(merged, label)
        print("merged dice, idx", dice)
        return dice, merged, label, template_slice_id


def to_RGB(img):
    pass


if __name__ == "__main__":
    from core.learner3 import SamLearner
    from modeling.build_sam3d2 import sam_model_registry

    EX_CONFIG = {
        'dataset': {
            'prompt': 'box',
            'prompt_number': 5,
            'dataset_list': ['example'],  # ["sabs"], chaos, word, pancreas
            'label_idx': 5,
        },
        "lora_r": 24,
        'model_type': "vit_h",
        'ckpt': "/home1/quanquan/code/projects/finetune_large/segment_anything/model_iter_3935000.pth",
    }

    config = ConfigManager()
    config.add_config("configs/vit_b.yaml")
    config.add_config(EX_CONFIG)
    config.print()

    dataset = AbstractLoader(config['dataset'], split="test")
    print(len(dataset))
    assert len(dataset) >= 1

    # Init Model
    model_type = config['model_type']  # "vit_b"
    sam = sam_model_registry[model_type](checkpoint=None)
    learner = SamLearner(sam_model=sam, config=config, data_engine=DataManager(img_size=(1024, 1024)))
    learner.use_lora(r=config['lora_r'])
    pth = config['ckpt']
    learner.load_well_trained_model(pth)
    learner.cuda()

    predictor = VolumePredictor(
        model=learner.model,
        use_postprocess=True,
        use_noise_remove=True,
    )
    solver = Evaluater(config)
    solver.solve(predictor, dataset, finetune_number=config['dataset']['prompt_number'])
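compute_dice_np comes from trans_utils.metrics, which is outside this commit; presumably it implements the standard Dice coefficient, 2|A∩B| / (|A| + |B|). A minimal reference sketch of that formula on binary masks:

import numpy as np

def dice_np(pred: np.ndarray, label: np.ndarray, eps: float = 1e-8) -> float:
    """Standard Dice coefficient on binary masks: 2|A∩B| / (|A| + |B|)."""
    pred = pred.astype(bool)
    label = label.astype(bool)
    inter = np.logical_and(pred, label).sum()
    return float(2.0 * inter / (pred.sum() + label.sum() + eps))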
