97 lines
3.5 KiB
Python
97 lines
3.5 KiB
Python
import torch
|
|
import numpy as np
|
|
from tutils import tfilename
|
|
from tutils.nn.data import read, itk_to_np, np_to_itk, write
|
|
from torchvision.utils import save_image
|
|
import SimpleITK as sitk
|
|
from tutils.nn.data.tsitk.preprocess import resampleImage
|
|
|
|
|
|
class Data3dSolver:
    """Convenience wrapper for reading/writing 3-D volumes and slice previews.

    The methods rely on module-level helpers (``np_to_itk``, ``itk_to_np``,
    ``read``, ``write``, ``formalize``, ``tfilename``, ``resampleImage``) and
    on ``torchvision.utils.save_image``.
    """

    def __init__(self) -> None:
        pass

    def simple_write(self, data_np, path="tmp.nii.gz", spacing=None):
        """Write a 3-D array to ``path`` after casting it to int16.

        Args:
            data_np: Array with exactly three dimensions.
            path: Output file name; passed through ``tfilename`` before writing.
            spacing: Optional spacing applied to the image via ``SetSpacing``.

        Raises:
            AssertionError: If ``data_np`` is not 3-D.
        """
        assert len(data_np.shape) == 3, f"Got {data_np.shape}"
        data_itk = np_to_itk(data_np.astype(np.int16))
        if spacing is not None:
            data_itk.SetSpacing(spacing)
        write(data_itk, path=tfilename(path))
        print("Save to ", path)

    def write_slices(self, data, path="tmp_masks.jpg"):
        """Save a batch of 2-D slices as a single image grid.

        Args:
            data: ``torch.Tensor`` or ``np.ndarray`` of shape ``(b, c, h, w)``
                with ``c`` equal to 1 or 3. Arrays are converted with
                ``torch.Tensor`` (i.e. to the default float tensor type).
            path: Output file name ending in ``.jpg`` or ``.png``.

        Raises:
            AssertionError: If the shape or the file extension is invalid.
        """
        if isinstance(data, np.ndarray):
            data = torch.Tensor(data)
        assert len(data.shape) == 4, f"Shape should be (b c h w) c=1/3, Got {data.shape}"
        assert data.shape[1] == 1 or data.shape[1] == 3, f"Shape should be (b c h w) c=1/3, Got {data.shape}"
        assert path.endswith((".jpg", ".png"))
        # The batch is already (b, c, h, w), which is what save_image tiles
        # into a grid. Adding another axis here would yield a 5-D tensor that
        # save_image cannot handle, so the validated tensor is passed as-is.
        save_image(data, tfilename(path))
        print("Save to ", path)

    def write_multilabel_nii(self, data, path, meta=None):
        """Merge per-label masks into one label volume and write it.

        Args:
            data: 4-D array whose first axis indexes the labels, or a dict
                whose values are stacked along a new first axis (in dict
                iteration order; keys are ignored).
            path: Output file name; passed through ``tfilename`` before writing.
            meta: Optional reference image; when given, its spacing, origin
                and direction are copied onto the output via ``formalize``.

        Raises:
            AssertionError: If the (stacked) data is not 4-D.
        """
        if isinstance(data, dict):
            data = np.stack(list(data.values()), axis=0)
        assert len(data.shape) == 4, f"Shape should be (b c h w) , Got {data.shape}"

        # Merge labels to one: wherever mask number ``label`` (1-based) is
        # positive, its value scaled by ``label`` overwrites what earlier
        # masks wrote, so later masks win on overlap.
        merged = np.zeros_like(data[0])
        for label, mask in enumerate(data, start=1):
            merged = np.where(mask > 0, mask * label, merged)

        data_itk = np_to_itk(merged.astype(np.int16))
        if meta is not None:
            data_itk = formalize(data_itk, meta)
        write(data_itk, path=tfilename(path))
        print("Save to ", path)

    def fwrite(self, data, path, meta):
        """Write ``data`` as int16, copying geometry from ``meta``.

        Args:
            data: Array to write; cast to int16 first.
            path: Output file name; passed through ``tfilename`` before writing.
            meta: Reference image handed to ``formalize``.
        """
        data_itk = np_to_itk(data.astype(np.int16))
        data_itk = formalize(data_itk, meta)
        write(data_itk, path=tfilename(path))

    def read(self, path, spacing_norm=True):
        """Load an image and return it as an array together with its spacing.

        Args:
            path: File to load with the module-level ``read`` helper.
            spacing_norm: When true, resample the image with
                ``normalize_spacing`` and print the size/spacing change.

        Returns:
            Tuple ``(data_np, spacing)`` where ``spacing`` is taken from the
            (possibly resampled) image.
        """
        # Inside a method body the bare name ``read`` resolves to the
        # module-level helper, not to this method.
        data_itk = read(path)
        if spacing_norm:
            ori_size = data_itk.GetSize()
            ori_spacing = data_itk.GetSpacing()
            data_itk = self.normalize_spacing(data_itk)
            new_size = data_itk.GetSize()
            new_spacing = data_itk.GetSpacing()
            print("Change size from ", ori_size, new_size)
            print("Change spacing from ", ori_spacing, new_spacing)
        data_np = itk_to_np(data_itk)
        print("[data_utils.DEBUG]", data_np.shape)
        return data_np, data_itk.GetSpacing()

    def normalize_spacing(self, data_itk):
        """Resample ``data_itk`` so all three axes use its smallest spacing.

        Args:
            data_itk: Image exposing ``GetSpacing``.

        Returns:
            The result of ``resampleImage`` with the isotropic spacing.
        """
        smallest = min(data_itk.GetSpacing())
        new_spacing = (smallest, smallest, smallest)
        return resampleImage(data_itk, NewSpacing=new_spacing)
|
|
|
|
|
|
def formalize(img:sitk.SimpleITK.Image, meta:sitk.SimpleITK.Image):
    """Copy spacing, origin and direction from ``meta`` onto ``img``.

    ``img`` is modified in place and also returned. The size of ``meta`` is
    not consulted.
    """
    # Read all three properties from the reference before touching ``img``.
    ref_spacing, ref_origin, ref_direction = (
        meta.GetSpacing(),
        meta.GetOrigin(),
        meta.GetDirection(),
    )

    img.SetSpacing(ref_spacing)
    img.SetOrigin(ref_origin)
    img.SetDirection(ref_direction)
    return img
|
|
|
|
|
|
def write(img:sitk.SimpleITK.Image, path:str, mode:str="nifti"):
    """Write ``img`` to ``path`` using a SimpleITK ``ImageFileWriter``.

    Path: (example) os.path.join(jpg_dir, f"trans_{random_name}.nii.gz")

    ``mode`` is lower-cased but does not otherwise influence the output.
    """
    # Kept so that a non-string ``mode`` still fails the same way as before.
    _normalized_mode = mode.lower()

    file_writer = sitk.ImageFileWriter()
    file_writer.SetFileName(path)
    file_writer.Execute(img)