"""
|
|
Use mask_decoder3d_2.py
|
|
"""
|
|
|
|
import torch
import torchvision
import numpy as np
import torch.optim.lr_scheduler as lr_scheduler
from einops import rearrange, repeat, reduce
from tutils.trainer import Trainer, LearnerModule

from core.loss import ranked_combined_loss_with_indicators
from .learner5 import SamLearner as basic_learner
from .loss import compute_all_loss, ranked_combined_loss, compute_iou, combined_loss
from .lora_sam import LoRA_Sam

# Optional imports used by quantize(); imported lazily there instead.
# from torchao.quantization import apply_dynamic_quant
# from torch._inductor import config as inductorconfig


class SamLearner(basic_learner):

    def load_pretrained_model(self, pth, *args, **kwargs):
        """Load a 2D SAM checkpoint, adapting the weights whose shapes differ.

        Checkpoint keys that do not match the 3D model directly:
            Unmatched: prompt_encoder.mask_downscaling.0.weight
                their: torch.Size([4, 1, 2, 2])
                our:   torch.Size([4, 3, 2, 2])
            Unmatched: mask_decoder.mask_tokens.weight
                their: torch.Size([4, 256])
                our:   torch.Size([12, 256])
        """
print("Load pretrained model for mask_decoder3d_2 !!")
|
|
|
|
state_dict = torch.load(pth)
|
|
model_state_dict = self.model.state_dict()
|
|
model_state_dict.update(state_dict)
|
|
model_state_dict['prompt_encoder.mask_downscaling.0.weight'] = repeat(state_dict['prompt_encoder.mask_downscaling.0.weight'], "a 1 c d -> a b c d", b=3)
|
|
# model_state_dict['mask_decoder.mask_tokens.weight'] = repeat(state_dict['mask_decoder.mask_tokens.weight'], "a d -> (a 3) d")
|
|
|
|
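        # Quick sanity check of the einops pattern above (hypothetical tensor,
        # shapes taken from the docstring):
        #     w = torch.zeros(4, 1, 2, 2)
        #     repeat(w, "a 1 c d -> a b c d", b=3).shape  # torch.Size([4, 3, 2, 2])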

        # The 3D decoder adds two extra upscaling branches (output_upscaling2
        # and output_upscaling3); initialise both from the single pretrained
        # output_upscaling branch.
        for k in list(model_state_dict.keys()):
            if k.startswith("mask_decoder.output_upscaling2"):
                k2 = k.replace("output_upscaling2.", "output_upscaling.")
                model_state_dict[k] = model_state_dict[k2]
                print("Load weights: ", k)
            if k.startswith("mask_decoder.output_upscaling3"):
                k2 = k.replace("output_upscaling3.", "output_upscaling.")
                model_state_dict[k] = model_state_dict[k2]
                print("Load weights: ", k)

        # The 3D decoder has three times as many hypernetwork MLPs as the 2D
        # model (12 mask tokens vs. 4), so MLP i reuses pretrained MLP i // 3.
        hyper_params_names = [k for k in model_state_dict.keys() if k.startswith("mask_decoder.output_hypernetworks_mlps")]
        for name in hyper_params_names:
            words = name.split('.')
            # words[2] is the MLP index in keys of the form
            # "mask_decoder.output_hypernetworks_mlps.<i>....".
            words[2] = str(int(words[2]) // 3)
            name_to_copy = ".".join(words)
            model_state_dict[name] = state_dict[name_to_copy]
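        # Hypothetical example of the index remapping (parameter naming as in
        # the original SAM MLP, which stores its linears under ".layers"):
        #     "mask_decoder.output_hypernetworks_mlps.7.layers.0.weight"
        #     -> 7 // 3 == 2 ->
        #     "mask_decoder.output_hypernetworks_mlps.2.layers.0.weight"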

        # for k, v in state_dict.items():
        #     if model_state_dict[k].shape != state_dict[k].shape:
        #         print("Unmatched:", k)
        self.model.load_state_dict(model_state_dict)

    def quantize(self):
        # Imported lazily so the module still loads when torchao / inductor
        # are unavailable (the module-level imports are kept commented out).
        from torchao.quantization import apply_dynamic_quant
        from torch._inductor import config as inductorconfig

        # Alternative considered: PyTorch eager dynamic quantization.
        # self.model.image_encoder = torch.ao.quantization.quantize_dynamic(
        #     self.model.image_encoder, dtype=torch.qint8)

        apply_dynamic_quant(self.model.image_encoder)
        # Let the inductor backend fuse int8 matmuls with the following mul.
        inductorconfig.force_fuse_int_mm_with_mul = True
        print("Quantized!")
    def use_lora_sub(self):
        # Wrap the model with rank-1 LoRA adapters and freeze the base weights.
        lora_r = 1
        lora_sam = LoRA_Sam(self.model, lora_r, freeze_all=True)
        self.lora_module = lora_sam
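

# Minimal usage sketch (hypothetical: the checkpoint path and the learner
# construction depend on wiring defined elsewhere in this repo):
#
#     learner = SamLearner(...)                       # built as in .learner5
#     learner.load_pretrained_model("sam_vit_h.pth")  # hypothetical path
#     learner.use_lora_sub()                          # rank-1 LoRA, base frozen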