# Modified from https://github.com/JamesQFreeman/Sam_LoRA

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn.parameter import Parameter

from safetensors import safe_open
from safetensors.torch import save_file

# from modeling.sam3d import Sam


class _LoRA_qkv(nn.Module):
    """LoRA adapter for SAM's fused qkv projection.

    In Sam it is implemented as::

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
    """

    def __init__(
        self,
        qkv: nn.Module,
        linear_a_q: nn.Module,
        linear_b_q: nn.Module,
        linear_a_v: nn.Module,
        linear_b_v: nn.Module,
    ):
        super().__init__()
        self.qkv = qkv
        self.linear_a_q = linear_a_q
        self.linear_b_q = linear_b_q
        self.linear_a_v = linear_a_v
        self.linear_b_v = linear_b_v
        self.dim = qkv.in_features
        self.w_identity = torch.eye(qkv.in_features)

    def forward(self, x):
        qkv = self.qkv(x)  # shape (B, H, W, 3 * org_C) in SAM's image-encoder blocks
        new_q = self.linear_b_q(self.linear_a_q(x))
        new_v = self.linear_b_v(self.linear_a_v(x))
        # Add the low-rank updates to the q and v slices; k is left unchanged.
        qkv[:, :, :, : self.dim] += new_q
        qkv[:, :, :, -self.dim :] += new_v
        return qkv
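

# A minimal sanity check for _LoRA_qkv, assuming a toy embedding dim of 8 and
# rank 2; this helper is illustrative and not part of the original module.
# Only the q (first dim) and v (last dim) channel slices of the fused output
# should differ from the plain projection, while the k slice stays untouched.
def _lora_qkv_sanity_check() -> None:
    dim, r = 8, 2
    qkv = nn.Linear(dim, dim * 3, bias=False)
    lora_qkv = _LoRA_qkv(
        qkv,
        nn.Linear(dim, r, bias=False), nn.Linear(r, dim, bias=False),
        nn.Linear(dim, r, bias=False), nn.Linear(r, dim, bias=False),
    )
    x = torch.rand(1, 4, 4, dim)  # (B, H, W, C), as in SAM's image-encoder blocks
    out = lora_qkv(x)
    # The k slice (middle dim channels) must match the base projection exactly.
    assert torch.allclose(out[..., dim:2 * dim], qkv(x)[..., dim:2 * dim])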


class LoRA_Sam(nn.Module):
    """Applies low-rank adaptation (LoRA) to a Sam model's image encoder.

    Args:
        sam_model: a Sam model, e.g. from segment_anything.sam_model_registry
        r: rank of the LoRA update matrices
        lora_layer: indices of the image-encoder blocks to adapt;
            defaults to every block
        freeze_all: freeze the whole Sam model; otherwise only the image
            encoder (ViT) is frozen, with patch_embed kept trainable
        freeze_prompt_encoder: additionally freeze the prompt encoder

    Examples::
        >>> sam = sam_model_registry["vit_b"](checkpoint=None)
        >>> lora_sam = LoRA_Sam(sam, r=4)
        >>> feats = lora_sam.sam.image_encoder(torch.rand(1, 3, 1024, 1024))
        >>> print(feats.shape)
        torch.Size([1, 256, 64, 64])
    """

    def __init__(self, sam_model, r: int, lora_layer: list = None,
                 freeze_all: bool = False, freeze_prompt_encoder: bool = True):
        super().__init__()

        assert r > 0
        # base_vit_dim = sam_model.image_encoder.patch_embed.proj.out_channels
        # dim = base_vit_dim
        if lora_layer:
            self.lora_layer = lora_layer
        else:
            self.lora_layer = list(range(len(sam_model.image_encoder.blocks)))
        # Create storage first, so we can initialize the layers or load weights.
        self.w_As = []  # These are linear layers
        self.w_Bs = []

        # Freeze the base model first.
        if freeze_all:
            for param in sam_model.parameters():
                param.requires_grad = False
        else:
            for param in sam_model.image_encoder.parameters():
                param.requires_grad = False
            # Keep the patch embedding trainable when only the ViT is frozen.
            for param in sam_model.image_encoder.patch_embed.parameters():
                param.requires_grad = True
        if freeze_prompt_encoder:
            for param in sam_model.prompt_encoder.parameters():
                param.requires_grad = False

        # Here we do the surgery: replace each selected block's fused qkv
        # projection with a LoRA-augmented wrapper.
        for t_layer_i, blk in enumerate(sam_model.image_encoder.blocks):
            # Skip blocks that were not selected for LoRA.
            if t_layer_i not in self.lora_layer:
                continue
            w_qkv_linear = blk.attn.qkv
            self.dim = w_qkv_linear.in_features
            w_a_linear_q = nn.Linear(self.dim, r, bias=False)
            w_b_linear_q = nn.Linear(r, self.dim, bias=False)
            w_a_linear_v = nn.Linear(self.dim, r, bias=False)
            w_b_linear_v = nn.Linear(r, self.dim, bias=False)
            self.w_As.append(w_a_linear_q)
            self.w_Bs.append(w_b_linear_q)
            self.w_As.append(w_a_linear_v)
            self.w_Bs.append(w_b_linear_v)
            blk.attn.qkv = _LoRA_qkv(
                w_qkv_linear,
                w_a_linear_q,
                w_b_linear_q,
                w_a_linear_v,
                w_b_linear_v,
            )
        self.reset_parameters()
        self.sam = sam_model
        # with open('vit_named_para.txt', 'w') as f:
        #     for k, v in sam_model.image_encoder.named_parameters():
        #         f.write(f'{k} {v.shape}\n')

    def save_lora_parameters(self, filename: str) -> None:
        r"""Save the LoRA A/B weights to a safetensors file.

        Only safetensors is supported now; ``pip install safetensors``
        if you do not have it installed yet.
        """
        assert filename.endswith(".safetensors")

        # w_As holds the A matrices for both the q and v branches, so its
        # length is twice the number of adapted blocks.
        num_layer = len(self.w_As)
        a_tensors = {f"w_a_{i:03d}": self.w_As[i].weight for i in range(num_layer)}
        b_tensors = {f"w_b_{i:03d}": self.w_Bs[i].weight for i in range(num_layer)}

        merged_dict = {**a_tensors, **b_tensors}
        save_file(merged_dict, filename)

    def load_lora_parameters(self, filename: str) -> None:
        r"""Load the LoRA A/B weights from a safetensors file.

        Only safetensors is supported now; ``pip install safetensors``
        if you do not have it installed yet.
        """
        assert filename.endswith(".safetensors")

        with safe_open(filename, framework="pt") as f:
            for i, w_A_linear in enumerate(self.w_As):
                saved_key = f"w_a_{i:03d}"
                saved_tensor = f.get_tensor(saved_key)
                w_A_linear.weight = Parameter(saved_tensor)

            for i, w_B_linear in enumerate(self.w_Bs):
                saved_key = f"w_b_{i:03d}"
                saved_tensor = f.get_tensor(saved_key)
                w_B_linear.weight = Parameter(saved_tensor)

    def reset_parameters(self) -> None:
        # Standard LoRA init: A is Kaiming-uniform, B is zero, so the
        # low-rank update starts as a no-op.
        for w_A in self.w_As:
            nn.init.kaiming_uniform_(w_A.weight, a=5**0.5)
        for w_B in self.w_Bs:
            nn.init.zeros_(w_B.weight)
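

# Illustrative helper, not part of the original module: since reset_parameters
# zeroes every B matrix, the update B @ A @ x is identically zero and a freshly
# wrapped model reproduces the base SAM exactly. This sketch checks that the
# adapters are still a no-op.
def _check_zero_init(lora_sam: "LoRA_Sam") -> bool:
    return all(torch.count_nonzero(w_B.weight).item() == 0 for w_B in lora_sam.w_Bs)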


def get_parameter_number(model):
    """Count total and trainable parameters of a model."""
    total_num = sum(p.numel() for p in model.parameters())
    trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return {'Total': total_num, 'Trainable': trainable_num}
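

# Hedged sketch, not in the original module: check that save_lora_parameters
# and load_lora_parameters round-trip the adapter weights exactly. `lora_sam`
# is assumed to be an already constructed LoRA_Sam; the filename is arbitrary.
def _check_save_load_roundtrip(lora_sam, filename: str = "tmp.safetensors") -> bool:
    before = [w.weight.detach().clone() for w in lora_sam.w_As + lora_sam.w_Bs]
    lora_sam.save_lora_parameters(filename)
    lora_sam.load_lora_parameters(filename)
    after = [w.weight.detach() for w in lora_sam.w_As + lora_sam.w_Bs]
    return all(torch.equal(b, a) for b, a in zip(before, after))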


if __name__ == "__main__":
    from segment_anything import sam_model_registry
    from segment_anything import SamPredictor, SamAutomaticMaskGenerator  # prompt mode and everything mode
    import numpy as np

    lora_r = 8
    ckpt_path = "../checkpoints/sam_vit_b_01ec64.pth"
    sam = sam_model_registry["vit_b"](checkpoint=ckpt_path)
    print('before lora', get_parameter_number(sam))
    lora_sam = LoRA_Sam(sam, lora_r)
    print('after lora', get_parameter_number(sam))

    img_path = '../data/cache/data2d_3/0007_s0069_img.npy'
    img = np.load(img_path)
    print('img shape', img.shape)

    # SamAutomaticMaskGenerator expects an HWC uint8 numpy image, not a tensor.
    x = (np.random.rand(1024, 1024, 3) * 255).astype(np.uint8)
    mask_generator = SamAutomaticMaskGenerator(sam)
    masks = mask_generator.generate(x)
    print('mask num', len(masks))
    # loss = np.sum([mask['segmentation'].mean() for mask in masks])
    # predictor = SamPredictor(sam)
    # predictor.set_image(x)
    # masks, _, _ = predictor.predict(point_coords=np.array([[50, 50]]), point_labels=np.array([1]))

    for f_name in ['save_lora_parameters', 'load_lora_parameters']:
        print(f_name)
        getattr(lora_sam, f_name)('tmp.safetensors')
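
    # Optional usage of the illustrative helpers sketched above:
    print('B matrices still zero:', _check_zero_init(lora_sam))
    print('save/load round-trip exact:', _check_save_load_roundtrip(lora_sam))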