# Slide-SAM/trans_utils/data_utils.py
# Commit e04459c6fe by transcendentsky ("first commit"), 2023-12-05 14:58:38 +08:00
# Original listing: 97 lines, 3.5 KiB, Python

import torch
import numpy as np
from tutils import tfilename
from tutils.nn.data import read, itk_to_np, np_to_itk, write
from torchvision.utils import save_image
import SimpleITK as sitk
from tutils.nn.data.tsitk.preprocess import resampleImage
class Data3dSolver:
    """Helpers for reading, resampling and writing 3D volumes and 2D slice previews.

    Image I/O goes through the ``tutils`` helpers (``read``, ``np_to_itk``,
    ``itk_to_np``, ``tfilename``) and the module-level ``write`` /
    ``formalize`` functions; slice previews use ``torchvision.utils.save_image``.
    """

    def __init__(self) -> None:
        pass

    def simple_write(self, data_np, path="tmp.nii.gz", spacing=None):
        """Save a 3D array as an int16 image at ``path``.

        Args:
            data_np: Array with exactly 3 dimensions; cast to ``np.int16``.
            path: Output file; passed through ``tfilename`` before writing.
            spacing: Optional spacing set on the image before saving.
        """
        assert len(data_np.shape) == 3, f"Got {data_np.shape}"
        data_np = data_np.astype(np.int16)
        data_itk = np_to_itk(data_np)
        if spacing is not None:
            data_itk.SetSpacing(spacing)
        write(data_itk, path=tfilename(path))
        print("Save to ", path)

    @staticmethod
    def _to_slice_batch(data):
        """Return ``data`` as a float32 ``(b, c, h, w)`` tensor with ``c`` in {1, 3}.

        Accepts a torch tensor or anything ``np.ascontiguousarray`` accepts.
        A 3D ``(b, h, w)`` input gets a singleton channel axis at dim 1; 4D
        input keeps its shape unchanged. Any other layout raises AssertionError.
        """
        if isinstance(data, torch.Tensor):
            batch = data.detach().float()
        else:
            # Contiguous float32 array so torch.from_numpy accepts any numeric/bool
            # input dtype as well as negative-stride views.
            batch = torch.from_numpy(np.ascontiguousarray(data, dtype=np.float32))
        if batch.dim() == 3:
            # (b, h, w) -> (b, 1, h, w): treat as single-channel slices.
            batch = batch.unsqueeze(1)
        assert batch.dim() == 4, f"Shape should be (b c h w) c=1/3, Got {batch.shape}"
        assert batch.shape[1] == 1 or batch.shape[1] == 3, f"Shape should be (b c h w) c=1/3, Got {batch.shape}"
        return batch

    def write_slices(self, data, path="tmp_masks.jpg"):
        """Save a batch of 2D slices as one tiled image via ``save_image``.

        Args:
            data: Tensor or array shaped ``(b, c, h, w)`` with ``c`` in {1, 3},
                or ``(b, h, w)`` (treated as single-channel).
            path: Output file; must end with ``.jpg`` or ``.png``.

        NOTE(review): ``save_image`` scales values as if they lie in [0, 1];
        binary masks render black/white, other ranges are clamped.
        """
        batch = self._to_slice_batch(data)
        assert path.endswith((".jpg", ".png"))
        # The batch is already (b, c, h, w); save_image/make_grid expect exactly
        # that layout, so no extra axis is inserted here.
        save_image(batch, tfilename(path))
        print("Save to ", path)

    def write_multilabel_nii(self, data, path, meta=None):
        """Merge per-label masks into one label map and save it to ``path``.

        Args:
            data: Either an array whose first axis indexes labels, or a dict of
                per-label arrays (stacked in dict iteration order). Must be 4D
                after stacking.
            path: Output file; passed through ``tfilename``.
            meta: Optional reference image whose spacing/origin/direction are
                copied onto the output via ``formalize``.

        Voxels where mask ``i`` is > 0 get ``mask_value * (i + 1)`` (i.e. ``i + 1``
        for binary masks); later masks overwrite earlier ones where they overlap.
        The merged map is cast to int16.
        """
        if isinstance(data, dict):
            data = np.stack(list(data.values()), axis=0)
        # presumably (num_labels, d, h, w) despite the message wording — TODO confirm.
        assert len(data.shape) == 4, f"Shape should be (b c h w) , Got {data.shape}"
        merged = np.zeros_like(data[0])
        for i, datai in enumerate(data):
            merged = np.where(datai > 0, datai * (i + 1), merged)
        merged = merged.astype(np.int16)
        data_itk = np_to_itk(merged)
        if meta is not None:
            data_itk = formalize(data_itk, meta)
        write(data_itk, path=tfilename(path))
        print("Save to ", path)

    def fwrite(self, data, path, meta):
        """Cast ``data`` to int16, copy geometry from ``meta`` and save to ``path``."""
        data = data.astype(np.int16)
        data_itk = np_to_itk(data)
        data_itk = formalize(data_itk, meta)
        write(data_itk, path=tfilename(path))

    def read(self, path, spacing_norm=True):
        """Load an image from ``path`` and return ``(array, spacing)``.

        Args:
            path: File readable by the ``tutils`` ``read`` helper.
            spacing_norm: If True, resample to isotropic spacing first (see
                ``normalize_spacing``) and print the size/spacing change.

        Returns:
            The numpy array produced by ``itk_to_np`` and the image's (possibly
            resampled) ``GetSpacing()`` tuple.
        """
        # Bare ``read`` is the module-level tutils reader, not this method.
        data_itk = read(path)
        if spacing_norm:
            ori_size = data_itk.GetSize()
            ori_spacing = data_itk.GetSpacing()
            data_itk = self.normalize_spacing(data_itk)
            new_size = data_itk.GetSize()
            new_spacing = data_itk.GetSpacing()
            print("Change size from ", ori_size, new_size)
            print("Change spacing from ", ori_spacing, new_spacing)
        data_np = itk_to_np(data_itk)
        print("[data_utils.DEBUG]", data_np.shape)
        return data_np, data_itk.GetSpacing()

    def normalize_spacing(self, data_itk):
        """Resample ``data_itk`` to isotropic spacing equal to its smallest spacing.

        Builds a 3-element spacing, so a 3D image is assumed. Interpolation is
        whatever ``resampleImage`` uses by default — TODO confirm.
        """
        spacing = data_itk.GetSpacing()
        finest = min(spacing)
        new_spacing = (finest, finest, finest)
        data_itk = resampleImage(data_itk, NewSpacing=new_spacing)
        return data_itk
def formalize(img: sitk.SimpleITK.Image, meta: sitk.SimpleITK.Image):
    """Copy the physical-space geometry of ``meta`` onto ``img``.

    Spacing, origin and direction are all read from ``meta`` first, then set on
    ``img`` in that same order. The size and pixels of ``img`` are not touched
    or compared with ``meta``. ``img`` is modified in place and also returned.
    """
    geometry = {
        field: getattr(meta, f"Get{field}")()
        for field in ("Spacing", "Origin", "Direction")
    }
    for field, value in geometry.items():
        getattr(img, f"Set{field}")(value)
    return img
def write(img: sitk.SimpleITK.Image, path: str, mode: str = "nifti"):
    """Write a SimpleITK image to ``path`` with ``sitk.ImageFileWriter``.

    The writer picks the on-disk format from the file name of ``path``
    (e.g. ``.nii.gz``). This definition shadows the ``write`` imported from
    ``tutils.nn.data`` at the top of this file, so bare ``write(...)`` calls in
    this module use this function.

    Args:
        img: Image to save.
        path: Destination file, e.g.
            ``os.path.join(jpg_dir, f"trans_{random_name}.nii.gz")``.
        mode: Unused; kept only so existing callers that pass it keep working.
    """
    writer = sitk.ImageFileWriter()
    writer.SetFileName(path)
    writer.Execute(img)