74 lines
2.4 KiB
Python
74 lines
2.4 KiB
Python
# from torchvision import transforms
|
|
from monai import transforms
|
|
import numpy as np
|
|
import SimpleITK as sitk
|
|
import torch
|
|
from torch.utils.data import Dataset as dataset
|
|
import torch.nn.functional as F
|
|
import glob
|
|
import os
|
|
from einops import rearrange
|
|
|
|
from tutils.nn.data import read, itk_to_np
|
|
|
|
from tqdm import tqdm
|
|
from tutils import tfilename, tdir
|
|
import random
|
|
# from .dataset2d import Dataset2D
|
|
from .dataset3d_2dmask import Dataset2D
|
|
from .dataset3d import Dataset3D
|
|
|
|
|
|
class DatasetMerged(dataset):
    """Training mix of a 2D dataset followed by an oversampled 3D dataset.

    Indices ``[0, len_2d)`` are served by ``Dataset2D``. The remaining
    ``len_3d * repeat_3d`` indices cycle through ``Dataset3D``, so each 3D
    sample appears ``repeat_3d`` times per pass over this dataset. Any
    integer index (including negative or out-of-range ones) is wrapped
    modulo ``len(self)``.
    """

    def __init__(
        self,
        config=None,
        is_train=True,
        getting_multi_mask=False,
        *,
        dirpath_2d="/quanquan/datasets/08_AbdomenCT-1K/",
        repeat_3d=200,
    ) -> None:
        """Build both sub-datasets and cache their sizes.

        Args:
            config: Passed unchanged to ``Dataset3D(config=...)``.
            is_train: Currently unused; both sub-datasets are always built
                with ``is_train=True``. Kept for interface compatibility.
            getting_multi_mask: Currently unused; kept for interface
                compatibility.
            dirpath_2d: Directory handed to ``Dataset2D``. Defaults to the
                path that used to be hard-coded here.
            repeat_3d: Number of times each 3D sample is counted in
                ``__len__``. Defaults to the previously hard-coded 200.

        Raises:
            ValueError: If ``repeat_3d`` is negative.
        """
        super().__init__()
        if repeat_3d < 0:
            raise ValueError(f"repeat_3d must be non-negative, got {repeat_3d}")
        # NOTE(review): `is_train` is not forwarded. Dataset2D may need a
        # different dirpath for non-training data, so confirm before wiring
        # it through.
        self.dataset2d = Dataset2D(dirpath=dirpath_2d, is_train=True)
        self.dataset3d = Dataset3D(config=config, is_train=True)
        self.repeat_3d = repeat_3d
        # Cached sizes; __getitem__ and __len__ both rely on these.
        self.len_2d = len(self.dataset2d)
        self.len_3d = len(self.dataset3d)

    def __getitem__(self, index, debug=False):
        """Return the sample at ``index``, wrapped modulo ``len(self)``.

        Args:
            index: Integer position; any value is accepted and wrapped.
            debug: Unused; kept for interface compatibility.

        Raises:
            IndexError: If the merged dataset is empty.
        """
        total = len(self)
        if total == 0:
            # Without this guard `index % 0` raises ZeroDivisionError.
            raise IndexError("DatasetMerged is empty")
        index %= total
        if index < self.len_2d:
            return self.dataset2d[index]
        # Past the 2D block: cycle through the 3D dataset repeat_3d times.
        # len_3d > 0 here, otherwise total == len_2d and we returned above.
        return self.dataset3d[(index - self.len_2d) % self.len_3d]

    def __len__(self):
        """Return ``len_2d + len_3d * repeat_3d``."""
        return self.len_2d + self.len_3d * self.repeat_3d
|
|
|
|
|
|
|
|
class TestsetMerged(dataset):
    """Evaluation mix of a 2D dataset followed by a lightly repeated 3D split.

    Indices ``[0, len_2d)`` are served by ``Dataset2D``. The remaining
    ``len_3d * repeat_3d`` indices cycle through the ``'val'`` split of
    ``Dataset3D``. Any integer index (including negative or out-of-range
    ones) is wrapped modulo ``len(self)``.
    """

    def __init__(
        self,
        config=None,
        is_train=False,
        *,
        dirpath_2d="/quanquan/datasets/08_AbdomenCT-1K/preprocessed/",
        repeat_3d=2,
    ) -> None:
        """Build both sub-datasets and cache their sizes.

        Args:
            config: Passed unchanged to ``Dataset3D(config=...)``.
            is_train: Currently unused; both sub-datasets are always built
                with ``is_train=False``. Kept for interface compatibility.
            dirpath_2d: Directory handed to ``Dataset2D``. Defaults to the
                path that used to be hard-coded here.
            repeat_3d: Number of times each 3D sample is counted in
                ``__len__``. Defaults to the previously hard-coded 2.

        Raises:
            ValueError: If ``repeat_3d`` is negative.
        """
        super().__init__()
        if repeat_3d < 0:
            raise ValueError(f"repeat_3d must be non-negative, got {repeat_3d}")
        self.dataset2d = Dataset2D(dirpath=dirpath_2d, is_train=False)
        self.dataset3d = Dataset3D(config=config, is_train=False, split='val')
        self.repeat_3d = repeat_3d
        # Cached sizes; __getitem__ and __len__ both rely on these.
        self.len_2d = len(self.dataset2d)
        self.len_3d = len(self.dataset3d)

    def __getitem__(self, index, debug=False):
        """Return the sample at ``index``, wrapped modulo ``len(self)``.

        Args:
            index: Integer position; any value is accepted and wrapped.
            debug: Unused; kept for interface compatibility.

        Raises:
            IndexError: If the merged dataset is empty.
        """
        total = len(self)
        if total == 0:
            # Without this guard `index % 0` raises ZeroDivisionError.
            raise IndexError("TestsetMerged is empty")
        index %= total
        if index < self.len_2d:
            return self.dataset2d[index]
        # Past the 2D block: cycle through the 3D split repeat_3d times.
        # len_3d > 0 here, otherwise total == len_2d and we returned above.
        return self.dataset3d[(index - self.len_2d) % self.len_3d]

    def __len__(self):
        """Return ``len_2d + len_3d * repeat_3d``."""
        return self.len_2d + self.len_3d * self.repeat_3d
|
|
|
|
|
|
if __name__ == "__main__":
    # Ad-hoc smoke/timing run: build the training mix from
    # configs/vit_b.yaml and time fetching samples from index 20000 onward.
    from tutils import timer
    from tutils.new.manager import ConfigManager

    config = ConfigManager()
    config.add_basic_config()
    config.add_config("configs/vit_b.yaml")

    # Named `merged` (not `dataset`) so the module-level alias for torch's
    # Dataset base class is not rebound when this file runs as a script.
    merged = DatasetMerged(config['dataset'])

    tt = timer()
    # NOTE(review): the start index 20000 is hard-coded, presumably to skip
    # past the 2D samples into the 3D region. Confirm this. If the dataset is
    # shorter than 20000 the loop body never runs.
    for i in range(20000, len(merged)):
        _ = merged[i]
    print("time: ", tt())