Slide-SAM/datasets/dataset_merged.py
transcendentsky e04459c6fe first commit
2023-12-05 14:58:38 +08:00

74 lines
2.4 KiB
Python

# from torchvision import transforms
from monai import transforms
import numpy as np
import SimpleITK as sitk
import torch
from torch.utils.data import Dataset as dataset
import torch.nn.functional as F
import glob
import os
from einops import rearrange
from tutils.nn.data import read, itk_to_np
from tqdm import tqdm
from tutils import tfilename, tdir
import random
# from .dataset2d import Dataset2D
from .dataset3d_2dmask import Dataset2D
from .dataset3d import Dataset3D
class DatasetMerged(dataset):
    """Training dataset that chains a ``Dataset2D`` and a repeated ``Dataset3D``.

    Indices ``[0, len_2d)`` are served by ``self.dataset2d``. Every index
    after that is served by ``self.dataset3d``, wrapping around its length,
    and each 3D item is counted ``REPEAT_3D`` times in ``__len__``.
    """

    # Directory handed to Dataset2D; kept as a class attribute so a subclass
    # can point at a different location without touching __init__.
    DIRPATH_2D = "/quanquan/datasets/08_AbdomenCT-1K/"
    # Number of times each Dataset3D item is counted in the reported length.
    REPEAT_3D = 200

    def __init__(self, config=None, is_train=True, getting_multi_mask=False) -> None:
        """Build both underlying datasets and cache their lengths.

        Args:
            config: Forwarded unchanged to ``Dataset3D``.
            is_train: Accepted for interface compatibility; both underlying
                datasets are always built with ``is_train=True``.
            getting_multi_mask: Accepted for interface compatibility; not
                used by this class.
        """
        super().__init__()
        self.dataset2d = Dataset2D(dirpath=self.DIRPATH_2D, is_train=True)
        self.dataset3d = Dataset3D(config=config, is_train=True)
        self.len_2d = len(self.dataset2d)
        self.len_3d = len(self.dataset3d)

    def __getitem__(self, index, debug=False):
        """Return the item at ``index``, wrapping around the merged length.

        Args:
            index: Integer position; any value is reduced modulo ``len(self)``.
            debug: Unused; kept for interface compatibility.

        Raises:
            IndexError: If the merged dataset is empty.
        """
        total = len(self)
        if total == 0:
            # Without this guard the modulo below raises ZeroDivisionError.
            raise IndexError("DatasetMerged is empty")
        index = index % total
        if index < self.len_2d:
            return self.dataset2d[index]
        # Past the 2D range: cycle through the 3D dataset.
        return self.dataset3d[(index - self.len_2d) % self.len_3d]

    def __len__(self):
        """Return the 2D length plus ``REPEAT_3D`` times the 3D length."""
        return len(self.dataset2d) + len(self.dataset3d) * self.REPEAT_3D
class TestsetMerged(dataset):
    """Evaluation dataset that chains a ``Dataset2D`` and a repeated ``Dataset3D``.

    Indices ``[0, len_2d)`` are served by ``self.dataset2d``. Every index
    after that is served by ``self.dataset3d``, wrapping around its length,
    and each 3D item is counted ``REPEAT_3D`` times in ``__len__``.
    """

    # Directory handed to Dataset2D; kept as a class attribute so a subclass
    # can point at a different location without touching __init__.
    DIRPATH_2D = "/quanquan/datasets/08_AbdomenCT-1K/preprocessed/"
    # Number of times each Dataset3D item is counted in the reported length.
    REPEAT_3D = 2

    def __init__(self, config=None, is_train=False) -> None:
        """Build both underlying datasets and cache their lengths.

        Args:
            config: Forwarded unchanged to ``Dataset3D``.
            is_train: Accepted for interface compatibility; both underlying
                datasets are always built with ``is_train=False`` and the 3D
                one with ``split='val'``.
        """
        super().__init__()
        self.dataset2d = Dataset2D(dirpath=self.DIRPATH_2D, is_train=False)
        self.dataset3d = Dataset3D(config=config, is_train=False, split='val')
        self.len_2d = len(self.dataset2d)
        self.len_3d = len(self.dataset3d)

    def __getitem__(self, index, debug=False):
        """Return the item at ``index``, wrapping around the merged length.

        Args:
            index: Integer position; any value is reduced modulo ``len(self)``.
            debug: Unused; kept for interface compatibility.

        Raises:
            IndexError: If the merged dataset is empty.
        """
        total = len(self)
        if total == 0:
            # Without this guard the modulo below raises ZeroDivisionError.
            raise IndexError("TestsetMerged is empty")
        index = index % total
        if index < self.len_2d:
            return self.dataset2d[index]
        # Past the 2D range: cycle through the 3D dataset.
        return self.dataset3d[(index - self.len_2d) % self.len_3d]

    def __len__(self):
        """Return the 2D length plus ``REPEAT_3D`` times the 3D length."""
        return len(self.dataset2d) + len(self.dataset3d) * self.REPEAT_3D
if __name__ == "__main__":
    # Manual timing run: build the merged training set from the vit_b config
    # and fetch every item from index 20000 onward, printing the timer value
    # after each fetch.
    # NOTE(review): this module uses relative imports, so it presumably has
    # to be launched as a module (python -m <package>.dataset_merged) rather
    # than as a plain script -- confirm against the project layout.
    from tutils import timer
    from tutils.new.manager import ConfigManager

    config = ConfigManager()
    config.add_basic_config()
    config.add_config("configs/vit_b.yaml")
    # Named `merged` (not `dataset`) so the module-level `dataset` alias used
    # as the base class of the dataset wrappers is not rebound to an instance.
    merged = DatasetMerged(config['dataset'])
    tt = timer()
    for i in range(20000, len(merged)):
        data = merged[i]
        print("time: ", tt())