From 840009725f14322a507772ecb71e6c4a9cf428a1 Mon Sep 17 00:00:00 2001
From: transcendentsky
Date: Wed, 20 Mar 2024 15:39:07 +0800
Subject: [PATCH] add finetuning

---
 core/ddp_sub.py      | 122 +++++++++++++++++++++++++++++++++++++++++++
 core/learner_sub1.py |  73 ++++++++++++++++++++++++++
 2 files changed, 195 insertions(+)
 create mode 100644 core/ddp_sub.py
 create mode 100644 core/learner_sub1.py

diff --git a/core/ddp_sub.py b/core/ddp_sub.py
new file mode 100644
index 0000000..530ed48
--- /dev/null
+++ b/core/ddp_sub.py
@@ -0,0 +1,122 @@
+"""
+    Derived from ddp_b9.py.
+
+    Adds an additional bypass/side branch to finetune on other datasets.
+"""
+
+import os
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+from tutils import tfilename, tdir
+
+from datasets.dataset3d_2dmask import Dataset2D
+# from datasets.dataset3d import Dataset3D
+from datasets.cache_dataset3d3 import Dataset3D
+from datasets.dataset_merged import DatasetMerged, TestsetMerged
+from datasets.data_engine import DataEngine
+from modeling.build_sam3d2 import sam_model_registry
+
+from .learner_sub1 import SamLearner
+# from tutils.new.trainer.trainer_ddp import DDPTrainer
+from trans_utils.trainer_ddp import DDPTrainer
+# from .lora_sam import LoRA_Sam
+
+import warnings
+warnings.filterwarnings("ignore")
+
+
+def setup(rank, world_size):
+    os.environ['MASTER_ADDR'] = 'localhost'
+    os.environ['MASTER_PORT'] = '12355'
+
+    # initialize the process group
+    dist.init_process_group("gloo", rank=rank, world_size=world_size)
+
+def cleanup():
+    dist.destroy_process_group()
+
+def ddp_train(rank, world_size, config):
+    setup(rank, world_size)
+
+    # sam_checkpoint = "/quanquan/code/segment-anything/segment_anything/sam_vit_b_01ec64.pth"  # A800 server
+    # sam_checkpoint = "/home1/quanquan/code/projects/medical-guangdong/segment-anything/sam_vit_b_01ec64.pth"  # 103 server
+    model_type = "vit_b"
+    device = rank
+
+    config_data = config['dataset']
+    data_type = config_data.get("types", ["3d", "2d"])
+    data_type = [data_type] if isinstance(data_type, str) else data_type
+    dataset = Dataset3D(config_data, split='train')
+
+    # assert len(validset) > 0
+    data_engine = DataEngine(dataset=dataset, img_size=(1024, 1024))
+    sam = sam_model_registry[model_type](checkpoint=None)
+
+    learner = SamLearner(sam_model=sam, config=config, data_engine=data_engine)
+    learner.use_lora()
+    learner.load_well_trained_model(config['training']['breakpoint_path'])  # use preset path
+    learner.use_lora_sub()
+
+    ddp_trainer = DDPTrainer(config=config, rank=rank, world_size=world_size)
+    ddp_trainer.fit(learner, trainset=data_engine, validset=None)
+
+    cleanup()
+
+
+def get_parameter_number(model):
+    total_num = sum(p.numel() for p in model.parameters())
+    trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    return {'Total': total_num, 'Trainable': trainable_num}
+
+def run_demo(demo_fn, world_size, config):
+    mp.spawn(demo_fn,
+             args=(world_size, config),
+             nprocs=world_size,
+             join=True)
+
+from collections import OrderedDict
+import yaml
+import yamlloader
+def _ordereddict_to_dict(d):
+    if not isinstance(d, dict):
+        return d
+    for k, v in d.items():
+        if isinstance(v, OrderedDict):
+            v = _ordereddict_to_dict(v)
+            d[k] = dict(v)
+        elif type(v) == list:
+            d[k] = _ordereddict_to_dict(v)
+        elif isinstance(v, dict):
+            d[k] = _ordereddict_to_dict(v)
+    return d
+
+# CUDA_VISIBLE_DEVICES=4,5,6,7 python -m core.ddp_b3 --tag lora --config configs/vit_b_103.yaml
+
+if __name__ == "__main__":
+    import argparse
+    from tutils.new.manager import trans_args, trans_init, ConfigManager
+
+    n_gpus = torch.cuda.device_count()
+    # assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but {__file__} got {n_gpus}"
+    if n_gpus == 1:
+        print("Warning: running on only 1 GPU (debug only)")
+    world_size = n_gpus
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--config", default="./configs/vit_sub_rectum.yaml")
+    parser.add_argument("--func", default="train")
+    parser.add_argument("--reuse", action="store_true")
+
+    args = trans_args(parser=parser)
+    config = ConfigManager()
+    config.auto_init(file=__file__, args=args, ex_config=None)
+    # config.save()
+    path = tfilename(config['base']['runs_dir'], "config.yaml")
+    with open(path, "w") as f:
+        yaml.dump(_ordereddict_to_dict(config), f)
+    print("Saved config file to", path)
+
+    if n_gpus < 1: exit(0)
+    run_demo(ddp_train, world_size, config)
diff --git a/core/learner_sub1.py b/core/learner_sub1.py
new file mode 100644
index 0000000..b97fd2b
--- /dev/null
+++ b/core/learner_sub1.py
@@ -0,0 +1,73 @@
+"""
+    Use mask_decoder3d_2.py
+"""
+
+import torch
+import torchvision
+import numpy as np
+from tutils.trainer import Trainer, LearnerModule
+from einops import rearrange, repeat, reduce
+import torch.optim.lr_scheduler as lr_scheduler
+
+from core.loss import ranked_combined_loss_with_indicators
+from .learner5 import SamLearner as basic_learner
+from .loss import compute_all_loss, ranked_combined_loss, compute_iou, combined_loss
+# from torchao.quantization import apply_dynamic_quant
+# from torch._inductor import config as inductorconfig
+from .lora_sam import LoRA_Sam
+
+
+class SamLearner(basic_learner):
+
+    def load_pretrained_model(self, pth, *args, **kwargs):
+        """
+        Unmatched: prompt_encoder.mask_downscaling.0.weight
+            their: torch.Size([4, 1, 2, 2])
+            our:   torch.Size([4, 3, 2, 2])
+        Unmatched: mask_decoder.mask_tokens.weight
+            their: torch.Size([4, 256])
+            our:   torch.Size([12, 256])
+        """
+        print("Load pretrained model for mask_decoder3d_2")
+
+        state_dict = torch.load(pth)
+        model_state_dict = self.model.state_dict()
+        model_state_dict.update(state_dict)
+        model_state_dict['prompt_encoder.mask_downscaling.0.weight'] = repeat(state_dict['prompt_encoder.mask_downscaling.0.weight'], "a 1 c d -> a b c d", b=3)
+        # model_state_dict['mask_decoder.mask_tokens.weight'] = repeat(state_dict['mask_decoder.mask_tokens.weight'], "a d -> (a 3) d")
+
+        for k, v in model_state_dict.items():
+            if k.startswith("mask_decoder.output_upscaling2"):
+                k2 = k.replace("output_upscaling2.", "output_upscaling.")
+                model_state_dict[k] = model_state_dict[k2]
+                print("Load weights: ", k)
+            if k.startswith("mask_decoder.output_upscaling3"):
+                k2 = k.replace("output_upscaling3.", "output_upscaling.")
+                model_state_dict[k] = model_state_dict[k2]
+                print("Load weights: ", k)
+
+        hyper_params_names = [k for k in model_state_dict.keys() if k.startswith("mask_decoder.output_hypernetworks_mlps")]
+        for name in hyper_params_names:
+            words = name.split('.')
+            words[2] = str(int(words[2]) // 3)
+            name_to_copy = ".".join(words)
+            model_state_dict[name] = state_dict[name_to_copy]
+        # for k, v in state_dict.items():
+        #     if model_state_dict[k].shape != state_dict[k].shape:
+        #         print("Unmatched:", k)
+        self.model.load_state_dict(model_state_dict)
+
+    def quantize(self):
+        # Imported lazily so the module still loads when torchao is unavailable.
+        from torchao.quantization import apply_dynamic_quant
+        from torch._inductor import config as inductorconfig
+        # Alternative: torch.ao.quantization.quantize_dynamic(self.model.image_encoder, dtype=torch.qint8)
+        apply_dynamic_quant(self.model.image_encoder)
+        inductorconfig.force_fuse_int_mm_with_mul = True
+        print("Quantized!")
+
+
+    def use_lora_sub(self):
+        lora_r = 1
+        lora_sam = LoRA_Sam(self.model, lora_r, freeze_all=True)
+        self.lora_module = lora_sam
\ No newline at end of file
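
Usage sketch (assumed from the argparse defaults and the launch comment kept inside ddp_sub.py; the GPU list and the --tag value are placeholders): the new finetuning entry point would presumably be started along these lines, with config['training']['breakpoint_path'] in the chosen YAML pointing to the previously trained checkpoint that load_well_trained_model() loads before use_lora_sub() attaches the new LoRA branch.

    CUDA_VISIBLE_DEVICES=0,1 python -m core.ddp_sub --tag finetune_sub --config ./configs/vit_sub_rectum.yaml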