fixed eval_dataloader
This commit is contained in:
parent
841be2acbe
commit
96a2bd15a6
3
.gitignore
vendored
Normal file
3
.gitignore
vendored
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
*.pth
|
||||||
|
*.pyc
|
||||||
|
*.nii.gz
|
@ -25,7 +25,7 @@ training:
|
|||||||
load_pretrain_model: false
|
load_pretrain_model: false
|
||||||
|
|
||||||
# optim:
|
# optim:
|
||||||
lr: 0.0002
|
lr: 0.000005
|
||||||
decay_step: 2000
|
decay_step: 2000
|
||||||
decay_gamma: 0.8
|
decay_gamma: 0.8
|
||||||
weight_decay: 0.0001
|
weight_decay: 0.0001
|
||||||
@ -35,7 +35,7 @@ training:
|
|||||||
dataset:
|
dataset:
|
||||||
types: ['3d'] # ['3d', '2d']
|
types: ['3d'] # ['3d', '2d']
|
||||||
split: 'train'
|
split: 'train'
|
||||||
data_root_path: '/quanquan/datasets/'
|
data_root_path: '/home1/quanquan/datasets/'
|
||||||
dataset_list: ["alp", "word", "debug"] # ['sam', "their", "ours"]
|
dataset_list: ["alp", "word", "debug"] # ['sam', "their", "ours"]
|
||||||
data_txt_path: './datasets/dataset_list/'
|
data_txt_path: './datasets/dataset_list/'
|
||||||
dataset2d_path: "/home1/quanquan/datasets/08_AbdomenCT-1K/"
|
dataset2d_path: "/home1/quanquan/datasets/08_AbdomenCT-1K/"
|
||||||
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -107,8 +107,8 @@ class SamLearner(LearnerModule):
|
|||||||
self.model.load_state_dict(state_dict)
|
self.model.load_state_dict(state_dict)
|
||||||
# self.lora_module.load_lora_parameters(pth.replace(".pth", "_lora.safetensors"))
|
# self.lora_module.load_lora_parameters(pth.replace(".pth", "_lora.safetensors"))
|
||||||
|
|
||||||
def use_lora(self):
|
def use_lora(self, r=8):
|
||||||
lora_r = 8
|
lora_r = r
|
||||||
lora_sam = LoRA_Sam(self.model, lora_r, freeze_prompt_encoder=True)
|
lora_sam = LoRA_Sam(self.model, lora_r, freeze_prompt_encoder=True)
|
||||||
self.lora_module = lora_sam
|
self.lora_module = lora_sam
|
||||||
|
|
||||||
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
2
datasets/dataset_list/example_test.txt
Normal file
2
datasets/dataset_list/example_test.txt
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
07_WORD/WORD-V0.1.0/imagesVa/word_0001.nii.gz 07_WORD/WORD-V0.1.0/labelsVa/word_0001.nii.gz
|
||||||
|
07_WORD/WORD-V0.1.0/imagesVa/word_0007.nii.gz 07_WORD/WORD-V0.1.0/labelsVa/word_0007.nii.gz
|
146
datasets/eval_dataloader/loader_abstract.py
Normal file
146
datasets/eval_dataloader/loader_abstract.py
Normal file
@ -0,0 +1,146 @@
|
|||||||
|
"""
|
||||||
|
DataLoader only for evaluation
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
from tutils.nn.data import read, itk_to_np, np_to_itk
|
||||||
|
from einops import reduce, repeat, rearrange
|
||||||
|
# from tutils.nn.data.tsitk.preprocess import resampleImage
|
||||||
|
from trans_utils.data_utils import Data3dSolver
|
||||||
|
from tutils.nn.data.tsitk.preprocess import resampleImage
|
||||||
|
import SimpleITK as sitk
|
||||||
|
|
||||||
|
|
||||||
|
# Example
|
||||||
|
DATASET_CONFIG={
|
||||||
|
'split': 'test',
|
||||||
|
'data_root_path':'/quanquan/datasets/',
|
||||||
|
'dataset_list': ["ours"],
|
||||||
|
'data_txt_path':'./datasets/dataset_list/',
|
||||||
|
'label_idx': 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
class AbstractLoader(Dataset):
|
||||||
|
def __init__(self, config, split="test") -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.split = split
|
||||||
|
self.img_names = self.prepare_datalist()
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.img_names)
|
||||||
|
|
||||||
|
def prepare_datalist(self):
|
||||||
|
config = self.config
|
||||||
|
data_paths = []
|
||||||
|
for item in config['dataset_list']:
|
||||||
|
print("Load datalist from ", item)
|
||||||
|
for line in open(config["data_txt_path"]+ item + f"_{self.split}.txt"):
|
||||||
|
name = line.strip().split()[1].split('.')[0]
|
||||||
|
img_path = config['data_root_path'] + line.strip().split()[0]
|
||||||
|
label_path = config['data_root_path'] + line.strip().split()[1]
|
||||||
|
data_paths.append({'img_path': img_path, 'label_path': label_path, 'name': name})
|
||||||
|
print('train len {}'.format(len(data_paths)))
|
||||||
|
return data_paths
|
||||||
|
|
||||||
|
|
||||||
|
def _get_data(self, index, debug=False):
|
||||||
|
label_idx = self.config['label_idx']
|
||||||
|
name = self.img_names[index]['img_path']
|
||||||
|
img_itk = read(name)
|
||||||
|
spacing = img_itk.GetSpacing()
|
||||||
|
img_ori = itk_to_np(img_itk)
|
||||||
|
scan_orientation = np.argmin(img_ori.shape)
|
||||||
|
label_ori = itk_to_np(read(self.img_names[index]['label_path']))
|
||||||
|
label = label_ori == label_idx
|
||||||
|
|
||||||
|
# img_ori, new_spacing = Data3dSolver().read(self.img_names[index]['img_path'])
|
||||||
|
# label_itk = read(self.img_names[index]['label_path'])
|
||||||
|
# ori_spacing = label_itk.GetSpacing()
|
||||||
|
# label = itk_to_np(label_itk) == label_idx
|
||||||
|
# print("[loader_abstract.DEBUG] size", img_ori.shape, label.shape)
|
||||||
|
# label = self._get_resized_label(label, new_size=img_ori.shape)
|
||||||
|
|
||||||
|
if debug:
|
||||||
|
Data3dSolver().simple_write(label)
|
||||||
|
Data3dSolver().simple_write(img_ori, "tmp_img.nii.gz")
|
||||||
|
|
||||||
|
s = reduce(label, "c h w -> c", reduction="sum")
|
||||||
|
coords = np.nonzero(s)
|
||||||
|
x_min = np.min(coords[0])
|
||||||
|
x_max = np.max(coords[0])
|
||||||
|
template_slice_id = s.argmax() - x_min
|
||||||
|
|
||||||
|
if img_ori.min() < -10:
|
||||||
|
img_ori = np.clip(img_ori, -200, 400)
|
||||||
|
else:
|
||||||
|
img_ori = np.clip(img_ori, 0, 600)
|
||||||
|
|
||||||
|
img_ori = img_ori[x_min:x_max+1,:,:]
|
||||||
|
label = label[x_min:x_max+1,:,:]
|
||||||
|
assert label.shape[0] >= 3
|
||||||
|
|
||||||
|
if template_slice_id <= 1 or template_slice_id >= label.shape[0]-2:
|
||||||
|
template_slice_id == label.shape[0] // 2
|
||||||
|
|
||||||
|
dataset_name = name.replace(self.config['data_root_path'], "").split("/")[0]
|
||||||
|
template_slice = label[template_slice_id,:,:]
|
||||||
|
print("template_slice.area ", template_slice.sum(), template_slice.sum() / (template_slice.shape[0] * template_slice.shape[1]))
|
||||||
|
d = {
|
||||||
|
"name": name,
|
||||||
|
"dataset_name": dataset_name,
|
||||||
|
"img": np.array(img_ori).astype(np.float32),
|
||||||
|
"label_idx": label_idx,
|
||||||
|
"label": np.array(label).astype(np.float32),
|
||||||
|
"template_slice_id": template_slice_id,
|
||||||
|
"template_slice": np.array(label[template_slice_id,:,:]).astype(np.float32),
|
||||||
|
"spacing": np.array(spacing),
|
||||||
|
}
|
||||||
|
return d
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
return self._get_data(index)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
from tutils.new.manager import ConfigManager
|
||||||
|
EX_CONFIG = {
|
||||||
|
'dataset':{
|
||||||
|
'prompt': 'box',
|
||||||
|
'dataset_list': ['guangdong'], # ["sabs"], chaos, word
|
||||||
|
'label_idx': 2,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
config = ConfigManager()
|
||||||
|
config.add_config("configs/vit_sub_rectum.yaml")
|
||||||
|
config.add_config(EX_CONFIG)
|
||||||
|
dataset = AbstractLoader(config['dataset'], split="test")
|
||||||
|
for i in range(len(dataset)):
|
||||||
|
dataset._get_data(i, debug=False)
|
||||||
|
|
||||||
|
# label_path = "/home1/quanquan/datasets/01_BCV-Abdomen/Training/label/label0001.nii.gz"
|
||||||
|
# from monai.transforms import SpatialResample
|
||||||
|
# resample = SpatialResample()
|
||||||
|
# label = itk_to_np(read(label_path)) == 1
|
||||||
|
# print(label.shape)
|
||||||
|
# # resampled = resample(label, spatial_size=(label.shape[0]*7, label.shape[1], label.shape[2]))
|
||||||
|
# print(label.shape)
|
||||||
|
|
||||||
|
exit(0)
|
||||||
|
data = itk_to_np(read("tmp_img.nii.gz"))
|
||||||
|
data = torch.Tensor(data)
|
||||||
|
|
||||||
|
maxlen = data.shape[0]
|
||||||
|
slices = []
|
||||||
|
for i in range(1, maxlen-1):
|
||||||
|
slices.append(data[i-1:i+2, :, :])
|
||||||
|
input_slices = torch.stack(slices, axis=0)
|
||||||
|
input_slices = torch.clip(input_slices, -200, 600)
|
||||||
|
|
||||||
|
input_slices
|
||||||
|
|
||||||
|
from torchvision.utils import save_image
|
||||||
|
save_image(input_slices, "tmp.jpg")
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -137,21 +137,31 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
EX_CONFIG = {
|
EX_CONFIG = {
|
||||||
'dataset':{
|
'dataset':{
|
||||||
'prompt': 'box',
|
'prompt': 'box', # box / point
|
||||||
'dataset_list': ['word'], # ["sabs"], chaos, word
|
'dataset_list': ['example'], # ["sabs"], chaos, word
|
||||||
'label_idx': 1,
|
'label_idx': 5,
|
||||||
},
|
},
|
||||||
'pth': "model.pth"
|
'model_type': "vit_b",
|
||||||
|
'pth': "/home1/quanquan/code/projects/finetune_large/runs/sam/ddp_b9/lora3/ckpt/model_iter_360000.pth",
|
||||||
|
}
|
||||||
|
|
||||||
|
EX_CONFIG = {
|
||||||
|
'dataset':{
|
||||||
|
'prompt': 'box', # box / point
|
||||||
|
'dataset_list': ['example'], # ["sabs"], chaos, word
|
||||||
|
'label_idx': 5,
|
||||||
|
},
|
||||||
|
'model_type': "vit_h",
|
||||||
|
'pth': "/home1/quanquan/code/projects/finetune_large/segment_anything/model_iter_3935000.pth",
|
||||||
}
|
}
|
||||||
|
|
||||||
config = ConfigManager()
|
config = ConfigManager()
|
||||||
config.add_config("configs/vit_b_103.yaml")
|
config.add_config("configs/vit_b.yaml")
|
||||||
config.add_config(EX_CONFIG)
|
config.add_config(EX_CONFIG)
|
||||||
|
config.print()
|
||||||
print(config)
|
|
||||||
|
|
||||||
# Init Model
|
# Init Model
|
||||||
model_type = "vit_b"
|
model_type = config['model_type']
|
||||||
sam = sam_model_registry[model_type](checkpoint=None)
|
sam = sam_model_registry[model_type](checkpoint=None)
|
||||||
learner = SamLearner(sam_model=sam, config=config, data_engine=DataManager(img_size=(1024,1024)))
|
learner = SamLearner(sam_model=sam, config=config, data_engine=DataManager(img_size=(1024,1024)))
|
||||||
learner.use_lora()
|
learner.use_lora()
|
||||||
|
216
test/volume_eval_mbox.py
Normal file
216
test/volume_eval_mbox.py
Normal file
@ -0,0 +1,216 @@
|
|||||||
|
"""
|
||||||
|
Volume evalutaion
|
||||||
|
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
# from datasets.dataset3d import Dataset3D
|
||||||
|
from tutils.new.manager import ConfigManager
|
||||||
|
from datasets.eval_dataloader.loader_abstract import AbstractLoader
|
||||||
|
|
||||||
|
from core.volume_predictor import VolumePredictor
|
||||||
|
from datasets.data_engine import DataManager, BoxPromptGenerator, PointPromptGenerator
|
||||||
|
|
||||||
|
from tutils import tfilename
|
||||||
|
from tutils.new.trainer.recorder import Recorder
|
||||||
|
from trans_utils.metrics import compute_dice_np
|
||||||
|
from trans_utils.data_utils import Data3dSolver
|
||||||
|
from tutils.tutils.ttimer import timer
|
||||||
|
|
||||||
|
|
||||||
|
class Evaluater:
|
||||||
|
def __init__(self, config) -> None:
|
||||||
|
self.config = config
|
||||||
|
self.recorder = Recorder()
|
||||||
|
|
||||||
|
def solve(self, model, dataset, finetune_number=1):
|
||||||
|
# model.eval()
|
||||||
|
self.predictor = model
|
||||||
|
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
|
||||||
|
|
||||||
|
for i, data in enumerate(dataloader):
|
||||||
|
# if i <4:
|
||||||
|
# print
|
||||||
|
# continue
|
||||||
|
# for k, v in data.items():
|
||||||
|
# if isinstance(v, torch.Tensor):
|
||||||
|
# data[k] = v.to(self.rank)
|
||||||
|
# try:
|
||||||
|
if True:
|
||||||
|
if self.config['dataset']['prompt'] == 'box':
|
||||||
|
dice, pred, label, temp_slice = self.eval_step(data, batch_idx=i)
|
||||||
|
used_slice = [temp_slice]
|
||||||
|
if finetune_number > 1:
|
||||||
|
for i in range(finetune_number - 1):
|
||||||
|
dice, pred, label, temp_slice = self.finetune_with_more_prompt(pred, label, exclude_slide_id=used_slice)
|
||||||
|
used_slice.append(temp_slice)
|
||||||
|
res = {"dice": dice}
|
||||||
|
if self.config['dataset']['prompt'] == 'point':
|
||||||
|
res = self.eval_step_point(data, batch_idx=i)
|
||||||
|
self.recorder.record(res)
|
||||||
|
# except Exception as e:
|
||||||
|
# print(e)
|
||||||
|
res = self.recorder.cal_metrics()
|
||||||
|
print(res)
|
||||||
|
print("prompt:", self.config['dataset']['prompt'], " class_idx:", self.config['dataset']['label_idx'])
|
||||||
|
|
||||||
|
def eval_step(self, data, batch_idx=0):
|
||||||
|
name = data['name']
|
||||||
|
dataset_name = data['dataset_name'][0]
|
||||||
|
label_idx = data['label_idx'][0]
|
||||||
|
template_slice_id = data['template_slice_id'][0]
|
||||||
|
|
||||||
|
assert data['img'].shape[1] >= 3, f" Got img.shape {data['img'].shape}"
|
||||||
|
if template_slice_id == 0:
|
||||||
|
template_slice_id += 1
|
||||||
|
elif template_slice_id == (data['img'].shape[0] - 1):
|
||||||
|
template_slice_id -= 1
|
||||||
|
print("Using slice ", template_slice_id, " as template slice")
|
||||||
|
|
||||||
|
spacing = data['spacing'].numpy().tolist()[0]
|
||||||
|
if data['img'].shape[-1] < 260:
|
||||||
|
# assert data['img'].shape[-1] < 260, f"Got {data['img'].shape}"
|
||||||
|
img = data['img'][0][:,:256,:256]
|
||||||
|
label = data['label'][0][:,:256,:256]
|
||||||
|
else:
|
||||||
|
img = data['img'][0]
|
||||||
|
label = data['label'][0]
|
||||||
|
# img = torch.clip(img, -200, 600)
|
||||||
|
box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id].detach().cpu().numpy())
|
||||||
|
box = np.array([box])
|
||||||
|
pred, stability = self.predictor.predict_volume(
|
||||||
|
x=img,
|
||||||
|
box=box,
|
||||||
|
template_slice_id=template_slice_id,
|
||||||
|
return_stability=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt_type = 'box'
|
||||||
|
dice = compute_dice_np(pred, label.detach().cpu().numpy())
|
||||||
|
# Data3dSolver().simple_write(pred, path=tfilename(f"visual/{dataset_name}/pred_{batch_idx}_label_{label_idx}_{prompt_type}.nii.gz"), spacing=spacing)
|
||||||
|
# Data3dSolver().simple_write(label.detach().cpu().numpy(), path=tfilename(f"visual/{dataset_name}/label_{batch_idx}.nii.gz"))
|
||||||
|
# Data3dSolver().simple_write(img.detach().cpu().numpy(), path=tfilename(f"visual/{dataset_name}/img_{batch_idx}.nii.gz"))
|
||||||
|
# np.save(tfilename(f"meta/{dataset_name}/stability_{batch_idx}.npy"), stability)
|
||||||
|
print(dataset_name, name, dice)
|
||||||
|
template_slice_id = template_slice_id if isinstance(template_slice_id, int) else template_slice_id.item()
|
||||||
|
return dice, pred, label.detach().cpu().numpy(), template_slice_id
|
||||||
|
|
||||||
|
def eval_step_point(self, data, batch_idx=0):
|
||||||
|
name = data['name']
|
||||||
|
dataset_name = data['dataset_name'][0]
|
||||||
|
label_idx = data['label_idx'][0]
|
||||||
|
template_slice_id = data['template_slice_id'][0]
|
||||||
|
spacing = data['spacing'].numpy().tolist()[0]
|
||||||
|
|
||||||
|
assert data['img'].shape[1] >= 3, f" Got img.shape {data['img'].shape}"
|
||||||
|
if template_slice_id == 0:
|
||||||
|
template_slice_id += 1
|
||||||
|
elif template_slice_id == (data['img'].shape[0] - 1):
|
||||||
|
template_slice_id -= 1
|
||||||
|
|
||||||
|
if data['img'].shape[-1] < 260:
|
||||||
|
# assert data['img'].shape[-1] < 260, f"Got {data['img'].shape}"
|
||||||
|
img = data['img'][0][:,:256,:256]
|
||||||
|
label = data['label'][0][:,:256,:256]
|
||||||
|
else:
|
||||||
|
img = data['img'][0]
|
||||||
|
label = data['label'][0]
|
||||||
|
|
||||||
|
box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id].detach().cpu().numpy())
|
||||||
|
point = (box[0]+box[2])*0.5 , (box[1]+box[3])*0.5
|
||||||
|
point = np.array([point]).astype(int)
|
||||||
|
if label[template_slice_id][point[0,1], point[0,0]] == 0:
|
||||||
|
print("Use random point instead !!!")
|
||||||
|
point = PointPromptGenerator().get_prompt_point(label[template_slice_id])
|
||||||
|
point = np.array([point]).astype(int)
|
||||||
|
# box = np.array([box])
|
||||||
|
pred = self.predictor.predict_volume(
|
||||||
|
x=img,
|
||||||
|
point_coords=point,
|
||||||
|
point_labels=np.ones_like(point)[:,:1],
|
||||||
|
template_slice_id=template_slice_id,
|
||||||
|
)
|
||||||
|
dice = compute_dice_np(pred, label.detach().cpu().numpy())
|
||||||
|
prompt_type = 'point'
|
||||||
|
Data3dSolver().simple_write(pred, path=tfilename(f"visual/{dataset_name}/pred_{batch_idx}_label_{label_idx}_{prompt_type}.nii.gz"), spacing=spacing)
|
||||||
|
# Data3dSolver().simple_write(pred, path=tfilename(f"visual/{dataset_name}/pred_{batch_idx}.nii.gz"))
|
||||||
|
print(dataset_name, name, dice)
|
||||||
|
return {"dice": dice}
|
||||||
|
|
||||||
|
def finetune_with_more_prompt(self, pred, label, prompt_type="box", exclude_slide_id=[]):
|
||||||
|
assert pred.shape == label.shape
|
||||||
|
dices = [compute_dice_np(pred[j,:,:], label[j,:,:]) for j in range(pred.shape[0])]
|
||||||
|
rank_list = np.array(dices[1:-1]).argsort() # Ignore the head and tail
|
||||||
|
rank_list += 1 # Ignore the head and tail
|
||||||
|
for i in rank_list:
|
||||||
|
if i in exclude_slide_id:
|
||||||
|
continue
|
||||||
|
template_slice_id = i
|
||||||
|
break
|
||||||
|
# template_slice_id += 1 # Ignore the head and tail
|
||||||
|
print("Using slice ", template_slice_id, " as template slice")
|
||||||
|
old_confidence = self.predictor.get_confidence()
|
||||||
|
box = BoxPromptGenerator(size=None).mask_to_bbox(label[template_slice_id])
|
||||||
|
box = np.array([box])
|
||||||
|
new_pred, stability = self.predictor.predict_with_prompt(
|
||||||
|
box=box,
|
||||||
|
template_slice_id=template_slice_id,
|
||||||
|
return_stability=True,
|
||||||
|
)
|
||||||
|
new_confidence = self.predictor.get_confidence()
|
||||||
|
new_confidence[template_slice_id] *= 2
|
||||||
|
all_conf = np.stack([old_confidence, new_confidence], axis=1)
|
||||||
|
preds = [pred, new_pred]
|
||||||
|
merged = np.zeros_like(label)
|
||||||
|
for slice_idx in range(pred.shape[0]):
|
||||||
|
idx = np.argsort(all_conf[slice_idx,:])[-1]
|
||||||
|
merged[slice_idx,:,:] = preds[idx][slice_idx]
|
||||||
|
|
||||||
|
print("old dices", [compute_dice_np(pred, label) for pred in preds])
|
||||||
|
dice = compute_dice_np(merged, label)
|
||||||
|
print("merged dice, idx", dice)
|
||||||
|
return dice, merged, label, template_slice_id
|
||||||
|
|
||||||
|
|
||||||
|
def to_RGB(img):
|
||||||
|
pass
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
from core.learner3 import SamLearner
|
||||||
|
from modeling.build_sam3d2 import sam_model_registry
|
||||||
|
EX_CONFIG = {
|
||||||
|
'dataset':{
|
||||||
|
'prompt': 'box',
|
||||||
|
'prompt_number': 5,
|
||||||
|
'dataset_list': ['example'], # ["sabs"], chaos, word pancreas
|
||||||
|
'label_idx': 5,
|
||||||
|
},
|
||||||
|
"lora_r": 24,
|
||||||
|
'model_type': "vit_h",
|
||||||
|
'ckpt': "/home1/quanquan/code/projects/finetune_large/segment_anything/model_iter_3935000.pth",
|
||||||
|
}
|
||||||
|
config = ConfigManager()
|
||||||
|
config.add_config("configs/vit_b.yaml")
|
||||||
|
config.add_config(EX_CONFIG)
|
||||||
|
config.print()
|
||||||
|
dataset = AbstractLoader(config['dataset'], split="test")
|
||||||
|
print(len(dataset))
|
||||||
|
assert len(dataset) >= 1
|
||||||
|
|
||||||
|
# Init Model
|
||||||
|
model_type = config['model_type'] # "vit_b"
|
||||||
|
sam = sam_model_registry[model_type](checkpoint=None)
|
||||||
|
learner = SamLearner(sam_model=sam, config=config, data_engine=DataManager(img_size=(1024,1024)))
|
||||||
|
learner.use_lora(r=config['lora_r'])
|
||||||
|
pth = config['ckpt']
|
||||||
|
learner.load_well_trained_model(pth)
|
||||||
|
learner.cuda()
|
||||||
|
predictor = VolumePredictor(
|
||||||
|
model=learner.model,
|
||||||
|
use_postprocess=True,
|
||||||
|
use_noise_remove=True,)
|
||||||
|
|
||||||
|
solver = Evaluater(config)
|
||||||
|
dataset = AbstractLoader(config['dataset'], split="test")
|
||||||
|
solver.solve(predictor, dataset, finetune_number=config['dataset']['prompt_number'])
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading…
x
Reference in New Issue
Block a user