add finetuning
This commit is contained in:
parent 738d4258c3
commit 840009725f
core/ddp_sub.py (new file, +122 lines)
@@ -0,0 +1,122 @@
"""
Based on ddp_b9.py.

Adds an additional bypass/side branch for finetuning on other datasets.
"""

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from tutils import tfilename, tdir

from datasets.dataset3d_2dmask import Dataset2D
# from datasets.dataset3d import Dataset3D
from datasets.cache_dataset3d3 import Dataset3D
from datasets.dataset_merged import DatasetMerged, TestsetMerged
from datasets.data_engine import DataEngine
from modeling.build_sam3d2 import sam_model_registry

from .learner_sub1 import SamLearner
# from tutils.new.trainer.trainer_ddp import DDPTrainer
from trans_utils.trainer_ddp import DDPTrainer
# from .lora_sam import LoRA_Sam

import warnings
warnings.filterwarnings("ignore")

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()

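A minimal variant of setup(), shown as a sketch: the file uses the "gloo" backend, which also works without GPUs; NCCL is the usual choice for CUDA tensors, so a backend switch could look like this.

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    # pick NCCL when CUDA is available, otherwise fall back to gloo
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend, rank=rank, world_size=world_size)
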
def ddp_train(rank, world_size, config):
    setup(rank, world_size)

    # sam_checkpoint = "/quanquan/code/segment-anything/segment_anything/sam_vit_b_01ec64.pth"  # A800 server
    # sam_checkpoint = "/home1/quanquan/code/projects/medical-guangdong/segment-anything/sam_vit_b_01ec64.pth"  # 103 server
    model_type = "vit_b"
    device = rank

    config_data = config['dataset']
    data_type = config_data.get("types", ["3d", "2d"])
    data_type = [data_type] if isinstance(data_type, str) else data_type
    dataset = Dataset3D(config_data, split='train')

    # assert len(validset) > 0
    data_engine = DataEngine(dataset=dataset, img_size=(1024, 1024))
    sam = sam_model_registry[model_type](checkpoint=None)

    learner = SamLearner(sam_model=sam, config=config, data_engine=data_engine)
    learner.use_lora()
    learner.load_well_trained_model(config['training']['breakpoint_path'])  # use preset path
    learner.use_lora_sub()

    ddp_trainer = DDPTrainer(config=config, rank=rank, world_size=world_size)
    ddp_trainer.fit(learner, trainset=data_engine, validset=None)

    cleanup()

def get_parameter_number(model):
    total_num = sum(p.numel() for p in model.parameters())
    trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return {'Total': total_num, 'Trainable': trainable_num}

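A small usage sketch for this helper (assuming LoRA_Sam with freeze_all=True freezes the base SAM weights, and learner is the SamLearner built in ddp_train above):

# hypothetical check, placed right after learner.use_lora_sub() in ddp_train()
stats = get_parameter_number(learner.model)
print("Total:", stats['Total'], "Trainable:", stats['Trainable'])
# with the base model frozen, Trainable should be a small fraction of Total
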
def run_demo(demo_fn, world_size, config):
    mp.spawn(demo_fn,
             args=(world_size, config),
             nprocs=world_size,
             join=True)

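Note that mp.spawn passes the process index as the first positional argument, so each worker is effectively invoked as demo_fn(rank, world_size, config), which matches the signature of ddp_train above.
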
from collections import OrderedDict
import yaml
import yamlloader


def _ordereddict_to_dict(d):
    if not isinstance(d, dict):
        return d
    for k, v in d.items():
        if isinstance(v, OrderedDict):
            v = _ordereddict_to_dict(v)
            d[k] = dict(v)
        elif isinstance(v, list):
            # also convert any OrderedDicts nested inside lists
            d[k] = [dict(_ordereddict_to_dict(item)) if isinstance(item, OrderedDict)
                    else _ordereddict_to_dict(item) for item in v]
        elif isinstance(v, dict):
            d[k] = _ordereddict_to_dict(v)
    return d

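A quick usage sketch (values are placeholders): yaml.dump does not serialise an OrderedDict as a plain mapping (it falls back to a Python-specific tag), so nested OrderedDicts are flattened to ordinary dicts before the config is written:

cfg = {"training": OrderedDict(breakpoint_path="/path/to/ckpt.pth")}  # placeholder values
print(yaml.dump(_ordereddict_to_dict(cfg)))
# training:
#   breakpoint_path: /path/to/ckpt.pth
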
# CUDA_VISIBLE_DEVICES=4,5,6,7 python -m core.ddp_b3 --tag lora --config configs/vit_b_103.yaml

if __name__ == "__main__":
    import argparse
    from tutils.new.manager import trans_args, trans_init, ConfigManager

    n_gpus = torch.cuda.device_count()
    # assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but {__file__} got {n_gpus}"
    if n_gpus == 1:
        print("Warning! Running on only 1 GPU, for debugging only")
    world_size = n_gpus

    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="./configs/vit_sub_rectum.yaml")
    parser.add_argument("--func", default="train")
    parser.add_argument("--reuse", action="store_true")

    args = trans_args(parser=parser)
    config = ConfigManager()
    config.auto_init(file=__file__, args=args, ex_config=None)
    # config.save()
    path = tfilename(config['base']['runs_dir'], "config.yaml")
    with open(path, "w") as f:
        yaml.dump(_ordereddict_to_dict(config), f)
    print("Saved config file to", path)

    if n_gpus < 1:
        exit(0)
    run_demo(ddp_train, world_size, config)
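By analogy with the comment above (which refers to core.ddp_b3), this module would presumably be launched like this; the GPU ids are examples and the config path is simply the parser default:

# hypothetical invocation
# CUDA_VISIBLE_DEVICES=0,1,2,3 python -m core.ddp_sub --config ./configs/vit_sub_rectum.yaml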
core/learner_sub1.py (new file, +73 lines)
@@ -0,0 +1,73 @@
"""
Use mask_decoder3d_2.py
"""

import torch
import torchvision
import numpy as np
from tutils.trainer import Trainer, LearnerModule
from einops import rearrange, repeat, reduce
import torch.optim.lr_scheduler as lr_scheduler

from core.loss import ranked_combined_loss_with_indicators
from .learner5 import SamLearner as basic_learner
from .loss import compute_all_loss, ranked_combined_loss, compute_iou, combined_loss
# from torchao.quantization import apply_dynamic_quant
# from torch._inductor import config as inductorconfig
from .lora_sam import LoRA_Sam

class SamLearner(basic_learner):

    def load_pretrained_model(self, pth, *args, **kwargs):
        """
        Unmatched: prompt_encoder.mask_downscaling.0.weight
            their: torch.Size([4, 1, 2, 2])
            our:   torch.Size([4, 3, 2, 2])
        Unmatched: mask_decoder.mask_tokens.weight
            their: torch.Size([4, 256])
            our:   torch.Size([12, 256])
        """
        print("Load pretrained model for mask_decoder3d_2 !!")

        state_dict = torch.load(pth)
        model_state_dict = self.model.state_dict()
        model_state_dict.update(state_dict)
        # the pretrained mask_downscaling conv expects 1 input channel; tile it to 3
        model_state_dict['prompt_encoder.mask_downscaling.0.weight'] = repeat(
            state_dict['prompt_encoder.mask_downscaling.0.weight'], "a 1 c d -> a b c d", b=3)
        # model_state_dict['mask_decoder.mask_tokens.weight'] = repeat(state_dict['mask_decoder.mask_tokens.weight'], "a d -> (a 3) d")

        # initialise the extra upscaling branches from the pretrained output_upscaling weights
        for k, v in model_state_dict.items():
            if k.startswith("mask_decoder.output_upscaling2"):
                k2 = k.replace("output_upscaling2.", "output_upscaling.")
                model_state_dict[k] = model_state_dict[k2]
                print("Load weights: ", k)
            if k.startswith("mask_decoder.output_upscaling3"):
                k2 = k.replace("output_upscaling3.", "output_upscaling.")
                model_state_dict[k] = model_state_dict[k2]
                print("Load weights: ", k)

        # map the tripled hypernetwork MLPs back to their pretrained counterparts
        hyper_params_names = [k for k in model_state_dict.keys() if k.startswith("mask_decoder.output_hypernetworks_mlps")]
        for name in hyper_params_names:
            words = name.split('.')
            words[2] = str(int(words[2]) // 3)
            name_to_copy = ".".join(words)
            model_state_dict[name] = state_dict[name_to_copy]
        # for k, v in state_dict.items():
        #     if model_state_dict[k].shape != state_dict[k].shape:
        #         print("Unmatched:", k)
        self.model.load_state_dict(model_state_dict)

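    # Worked example of the hypernetwork remapping above (key name illustrative, following
    # SAM's output_hypernetworks_mlps.<i>.layers.<j> naming): with the mask tokens tripled
    # from 4 to 12, new MLP i is initialised from pretrained MLP i // 3, e.g.
    #     "mask_decoder.output_hypernetworks_mlps.7.layers.0.weight"
    #         -> words[2] == '7', and 7 // 3 == 2
    #         -> copied from "mask_decoder.output_hypernetworks_mlps.2.layers.0.weight"
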
    def quantize(self):
        # apply_dynamic_quant / inductorconfig are imported lazily here because their
        # module-level imports are commented out above; quantize() needs them to run.
        from torchao.quantization import apply_dynamic_quant
        from torch._inductor import config as inductorconfig

        # self.model.image_encoder = torch.ao.quantization.quantize_dynamic(
        #     self.model.image_encoder,
        #     dtype=torch.qint8
        # )
        apply_dynamic_quant(self.model.image_encoder)
        inductorconfig.force_fuse_int_mm_with_mul = True
        print("Quantized !")

    def use_lora_sub(self):
        lora_r = 1
        lora_sam = LoRA_Sam(self.model, lora_r, freeze_all=True)
        self.lora_module = lora_sam
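A minimal sketch of how core/ddp_sub.py drives this learner; config, data_engine, and the checkpoint path are placeholders taken from that script, use_lora() and load_well_trained_model() presumably come from the learner5 base class, and freeze_all=True is assumed to freeze the non-LoRA weights:

# sketch only: mirrors the call order in core/ddp_sub.py
sam = sam_model_registry["vit_b"](checkpoint=None)
learner = SamLearner(sam_model=sam, config=config, data_engine=data_engine)
learner.use_lora()                                    # attach the base LoRA wrapper
learner.load_well_trained_model("/path/to/ckpt.pth")  # restore previously trained weights
learner.use_lora_sub()                                # add the extra rank-1 LoRA branch on top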