diff --git a/configs/vit_b.yaml b/configs/vit_b.yaml index 13ce881..25095e3 100644 --- a/configs/vit_b.yaml +++ b/configs/vit_b.yaml @@ -40,6 +40,8 @@ dataset: data_txt_path: './datasets/dataset_list/' dataset2d_path: "/home1/quanquan/datasets/08_AbdomenCT-1K/" cache_data_path: '/home1/quanquan/datasets/cached_dataset2/' + cache_prefix: ['6016'] # '07' + specific_label: [2] # sam_checkpoint: "/quanquan/code/projects/medical-guangdong/segment-anything/sam_vit_b_01ec64.pth" # 103 server # model_type: "vit_b" diff --git a/configs/vit_sub.yaml b/configs/vit_sub.yaml index 3965746..7170f8c 100644 --- a/configs/vit_sub.yaml +++ b/configs/vit_sub.yaml @@ -34,7 +34,7 @@ dataset: types: ['3d'] # ['3d', '2d'] split: 'train' data_root_path: '/home1/quanquan/datasets/' - dataset_list: ["pancreas"] + dataset_list: ["example"] # for example_train.txt data_txt_path: './datasets/dataset_list/' dataset2d_path: "/home1/quanquan/datasets/08_AbdomenCT-1K/" cache_data_path: '/home1/quanquan/datasets/cached_dataset2/' diff --git a/core/ddp_sub.py b/core/ddp_sub.py index 530ed48..b7e45b0 100644 --- a/core/ddp_sub.py +++ b/core/ddp_sub.py @@ -105,7 +105,7 @@ if __name__ == "__main__": world_size = n_gpus parser = argparse.ArgumentParser() - parser.add_argument("--config", default="./configs/vit_sub_rectum.yaml") + parser.add_argument("--config", default="./configs/vit_sub.yaml") parser.add_argument("--func", default="train") parser.add_argument("--reuse", action="store_true") diff --git a/core/lora_sam.py b/core/lora_sam.py index 81f8595..c2a3b60 100644 --- a/core/lora_sam.py +++ b/core/lora_sam.py @@ -33,6 +33,7 @@ class _LoRA_qkv(nn.Module): self.linear_a_v = linear_a_v self.linear_b_v = linear_b_v self.dim = qkv.in_features + self.in_features = qkv.in_features self.w_identity = torch.eye(qkv.in_features) def forward(self, x):