205 lines
7.5 KiB
Python
205 lines
7.5 KiB
Python
|
|
import numpy as np
|
|
import torch
|
|
|
|
import math
|
|
from copy import deepcopy
|
|
from itertools import product
|
|
from typing import Any, Dict, Generator, ItemsView, List, Tuple
|
|
|
|
|
|
class MaskData3d:
    """
    Container for 3D mask volumes that are filled in a slab at a time.

    Entries live in an internal dict. Entries created through
    ``__setitem__`` are keyed by ``str(slice_id) + str(num)`` and start
    out as ``np.zeros(size)``.
    """

    # Sentinel used to tell the two calling conventions of __setitem__ apart.
    _MISSING = object()

    def __init__(self, size, **kwargs) -> None:
        """
        Arguments:
          size: shape handed to ``np.zeros`` when a new volume is created.
          **kwargs: initial entries; each value must be a list, numpy
            array, or torch tensor.
        """
        self.size = size
        for v in kwargs.values():
            assert isinstance(
                v, (list, np.ndarray, torch.Tensor)
            ), "MaskData only supports list, numpy arrays, and torch tensors."
        self._stats = dict(**kwargs)

    def __setitem__(self, slice_id, num, item=_MISSING):
        """
        Writes ``item`` into the slab ``[slice_id - 1, slice_id + 2)`` of the
        volume stored under ``str(slice_id) + str(num)``, creating a zero
        volume first if the key is absent.

        Both ``data.__setitem__(slice_id, num, item)`` and the subscript form
        ``data[slice_id, num] = item`` are accepted.
        """
        if item is MaskData3d._MISSING:
            # Subscript assignment always calls __setitem__(key, value), so
            # the (slice_id, num) pair arrived in `slice_id` and the value
            # arrived in `num`.
            key_pair, value = slice_id, num
            slice_id, num = key_pair
            item = value
        assert isinstance(
            item, (list, np.ndarray, torch.Tensor)
        ), "MaskData only supports list, numpy arrays, and torch tensors."
        # NOTE(review): plain concatenation is ambiguous, e.g. (1, 23) and
        # (12, 3) both map to "123". Kept as-is because readers of
        # __getitem__ depend on this key format.
        key = str(slice_id) + str(num)
        if self._stats.get(key, None) is None:
            self._stats[key] = np.zeros(self.size)
        # Clamp the lower bound: for slice_id == 0 a start of -1 would be
        # read as "last slice", making the target empty so that
        # broadcastable items were silently dropped.
        start = max(slice_id - 1, 0)
        self._stats[key][start : slice_id + 2] = item

    def __delitem__(self, key: str) -> None:
        """Removes the entry stored under ``key``."""
        del self._stats[key]

    def __getitem__(self, key: str) -> Any:
        """Returns the entry stored under ``key``."""
        return self._stats[key]

    def items(self) -> ItemsView[str, Any]:
        """Returns a view over the stored ``(key, value)`` pairs."""
        return self._stats.items()

    def merge(self, slice_ids, num, item):
        """Placeholder; currently does nothing and returns None."""
        pass
|
|
|
|
|
|
def build_all_layer_point_grids(
    n_per_side: int = 32,
    n_layers: int = 0,
    scale_per_layer: int = 1,
    size=(1, 1),
) -> List[np.ndarray]:
    """
    Generates point grids for all crop layers.

    Layer ``i`` uses ``int(n_per_side / scale_per_layer**i)`` points per
    side. ``size`` is forwarded to ``build_point_grid`` and scales every
    grid; the default keeps the points in the unit square.

    Returns one ``(n_points**2, 2)`` array per layer, ``n_layers + 1`` in
    total.
    """
    points_by_layer = []
    for i in range(n_layers + 1):
        n_points = int(n_per_side / (scale_per_layer**i))
        # build_point_grid needs the scaling factor; the call used to omit
        # it and therefore always raised TypeError.
        points_by_layer.append(build_point_grid(n_points, size))
    return points_by_layer


def build_point_grid(n_per_side: int, size=(1, 1)) -> np.ndarray:
    """
    Generates a 2D grid of points evenly spaced in [0,1]x[0,1], then
    multiplies the coordinates by ``size``.

    Arguments:
      n_per_side: number of points along each axis.
      size: anything ``np.array`` can broadcast against an ``(N, 2)``
        array, e.g. a scalar or an ``(x_scale, y_scale)`` pair. The default
        ``(1, 1)`` leaves the points in the unit square.

    Returns an ``(n_per_side**2, 2)`` array of ``(x, y)`` coordinates.
    """
    # Half a cell of margin, so points sit at cell centres.
    offset = 1 / (2 * n_per_side)
    points_one_side = np.linspace(offset, 1 - offset, n_per_side)
    points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
    points_y = np.tile(points_one_side[:, None], (1, n_per_side))
    points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
    return points * np.array(size)
|
|
|
|
# def calculate_stability_score(
|
|
# masks: torch.Tensor, mask_threshold: float, threshold_offset: float
|
|
# ) -> torch.Tensor:
|
|
# """
|
|
# Computes the stability score for a batch of masks. The stability
|
|
# score is the IoU between the binary masks obtained by thresholding
|
|
# the predicted mask logits at high and low values.
|
|
# """
|
|
# # One mask is always contained inside the other.
|
|
# # Save memory by preventing unnecessary cast to torch.int64
|
|
# intersections = (
|
|
# (masks > (mask_threshold + threshold_offset))
|
|
# .sum(-1, dtype=torch.int16)
|
|
# .sum(-1, dtype=torch.int32)
|
|
# )
|
|
# unions = (
|
|
# (masks > (mask_threshold - threshold_offset))
|
|
# .sum(-1, dtype=torch.int16)
|
|
# .sum(-1, dtype=torch.int32)
|
|
# )
|
|
# return intersections / unions
|
|
|
|
|
|
def calculate_stability_score_3d(
    masks: torch.Tensor, mask_threshold: float, threshold_offset: float
) -> torch.Tensor:
    """
    Returns the stability score of each mask in a batch.

    The score is the ratio between the number of elements above
    ``mask_threshold + threshold_offset`` and the number above
    ``mask_threshold - threshold_offset``, counted over the three trailing
    dimensions of ``masks``. The strict mask is always a subset of the
    loose one, so this ratio is their IoU.
    """

    def _count_trailing_three(binary: torch.Tensor) -> torch.Tensor:
        # The innermost reduction uses int16 and the outer two int32 so the
        # counts are never promoted to torch.int64, which saves memory.
        per_row = binary.sum(-1, dtype=torch.int16)
        per_plane = per_row.sum(-1, dtype=torch.int32)
        return per_plane.sum(-1, dtype=torch.int32)

    upper_cut = mask_threshold + threshold_offset
    lower_cut = mask_threshold - threshold_offset
    intersections = _count_trailing_three(masks > upper_cut)
    unions = _count_trailing_three(masks > lower_cut)
    return intersections / unions
|
|
|
|
|
|
def remove_small_regions(
    mask: np.ndarray, area_thresh: float, mode: str
) -> Tuple[np.ndarray, bool]:
    """
    Cleans a mask by dropping small connected regions.

    ``mode`` must be ``"holes"`` or ``"islands"``. In ``"holes"`` mode the
    mask is XOR-ed with True before labelling, so the components examined
    are the mask's complement; in ``"islands"`` mode the mask is labelled
    as given.

    Returns ``(mask, changed)``. When no component is smaller than
    ``area_thresh`` the input mask object is returned with ``changed``
    False; otherwise a new boolean array is returned with ``changed`` True.
    """
    import cv2  # type: ignore

    assert mode in ["holes", "islands"]
    fixing_holes = mode == "holes"
    to_label = (fixing_holes ^ mask).astype(np.uint8)
    label_count, label_map, component_stats, _ = cv2.connectedComponentsWithStats(
        to_label, 8
    )
    # Last stats column, skipping row 0 (the background label).
    areas = component_stats[:, -1][1:]
    undersized = [
        label for label, area in enumerate(areas, start=1) if area < area_thresh
    ]
    if not undersized:
        return mask, False

    selected = [0] + undersized
    if not fixing_holes:
        # Islands mode keeps everything that is neither background nor small.
        rejected = set(selected)
        selected = [label for label in range(label_count) if label not in rejected]
        # If every region is below threshold, keep the largest one.
        if not selected:
            selected = [int(np.argmax(areas)) + 1]
    return np.isin(label_map, selected), True
|
|
|
|
def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
    """
    Yields consecutive slices of length ``batch_size`` from every argument.

    All arguments must have the same length. Each yielded list holds one
    slice per argument, in argument order; the final batch may be shorter.
    """
    assert len(args) > 0 and all(
        len(a) == len(args[0]) for a in args
    ), "Batched iteration must have inputs of all the same size."
    total = len(args[0])
    whole_batches, leftover = divmod(total, batch_size)
    # One extra batch covers any remainder.
    batch_count = whole_batches + (1 if leftover != 0 else 0)
    for index in range(batch_count):
        lo = index * batch_size
        hi = lo + batch_size
        yield [sequence[lo:hi] for sequence in args]
|
|
|
|
|
|
class MaskData:
    """
    Batched storage for masks and their associated per-mask data.

    Values are kept in a dict of lists, numpy arrays, or torch tensors.
    Supports filtering entries and concatenating another MaskData.
    """

    def __init__(self, **kwargs) -> None:
        """Stores every keyword argument as an entry after type-checking it."""
        assert all(
            isinstance(value, (list, np.ndarray, torch.Tensor))
            for value in kwargs.values()
        ), "MaskData only supports list, numpy arrays, and torch tensors."
        self._stats = dict(**kwargs)

    def __setitem__(self, key: str, item: Any) -> None:
        """Stores ``item`` under ``key``; only list/ndarray/Tensor are accepted."""
        supported = isinstance(item, (list, np.ndarray, torch.Tensor))
        assert supported, "MaskData only supports list, numpy arrays, and torch tensors."
        self._stats[key] = item

    def __delitem__(self, key: str) -> None:
        """Removes the entry stored under ``key``."""
        del self._stats[key]

    def __getitem__(self, key: str) -> Any:
        """Returns the entry stored under ``key``."""
        return self._stats[key]

    def items(self) -> ItemsView[str, Any]:
        """Returns a view over the stored ``(key, value)`` pairs."""
        return self._stats.items()

    def filter(self, keep: torch.Tensor) -> None:
        """
        Keeps only the elements of every entry that ``keep`` selects.

        ``keep`` is used as an index for tensors and arrays. For lists, a
        boolean ``keep`` is treated as a per-element flag and any other
        dtype as a sequence of positions. Entries are updated in place,
        one at a time.
        """
        for name in self._stats:
            current = self._stats[name]
            if current is None:
                selected = None
            elif isinstance(current, torch.Tensor):
                selected = current[torch.as_tensor(keep, device=current.device)]
            elif isinstance(current, np.ndarray):
                selected = current[keep.detach().cpu().numpy()]
            elif isinstance(current, list):
                if keep.dtype == torch.bool:
                    selected = [
                        entry for pos, entry in enumerate(current) if keep[pos]
                    ]
                else:
                    selected = [current[pos] for pos in keep]
            else:
                raise TypeError(
                    f"MaskData key {name} has an unsupported type {type(current)}."
                )
            self._stats[name] = selected

    def cat(self, new_stats: "MaskData") -> None:
        """
        Appends every entry of ``new_stats`` to the entry with the same key.

        Keys that are missing here (or hold None) receive a deep copy of
        the incoming value; list values are deep-copied before being
        appended.
        """
        for name, incoming in new_stats.items():
            existing = self._stats.get(name)
            if existing is None:
                merged = deepcopy(incoming)
            elif isinstance(incoming, torch.Tensor):
                merged = torch.cat([existing, incoming], dim=0)
            elif isinstance(incoming, np.ndarray):
                merged = np.concatenate([existing, incoming], axis=0)
            elif isinstance(incoming, list):
                merged = existing + deepcopy(incoming)
            else:
                raise TypeError(
                    f"MaskData key {name} has an unsupported type {type(incoming)}."
                )
            self._stats[name] = merged

    def to_numpy(self) -> None:
        """Replaces every torch tensor entry with a detached CPU numpy array."""
        tensor_keys = [
            name
            for name, value in self._stats.items()
            if isinstance(value, torch.Tensor)
        ]
        for name in tensor_keys:
            self._stats[name] = self._stats[name].detach().cpu().numpy()
|