import numpy as np
import torch
import math
from copy import deepcopy
from itertools import product
from typing import Any, Dict, Generator, ItemsView, List, Tuple


class MaskData3d:
    """
    Container of dense arrays filled in a few slices at a time.

    Each entry is created as ``np.zeros(size)`` and stored under the key
    ``str(slice_id) + str(num)``; ``__setitem__`` writes its ``item`` into
    rows ``slice_id - 1`` through ``slice_id + 1`` along axis 0.
    """

    def __init__(self, size, **kwargs) -> None:
        """
        Args:
            size: shape passed to ``np.zeros`` when a new entry is created.
            **kwargs: initial entries; each must be a list, numpy array, or
                torch tensor.
        """
        self.size = size
        for v in kwargs.values():
            assert isinstance(
                v, (list, np.ndarray, torch.Tensor)
            ), "MaskData only supports list, numpy arrays, and torch tensors."
        self._stats = dict(**kwargs)

    def __setitem__(self, slice_id, num, item):
        """
        Write ``item`` into the 3-row window centred on ``slice_id``.

        NOTE: this method takes three arguments, so subscription syntax
        ``data[key] = value`` (which passes only a key and a value) raises
        TypeError. Call it explicitly:
        ``data.__setitem__(slice_id, num, item)``.

        Args:
            slice_id: centre row (axis 0) of the window to write.
            num: second component of the storage key.
            item: list, numpy array, or torch tensor broadcastable to
                ``entry[slice_id - 1:slice_id + 2]``.
        """
        assert isinstance(
            item, (list, np.ndarray, torch.Tensor)
        ), "MaskData only supports list, numpy arrays, and torch tensors."
        # NOTE(review): plain concatenation is ambiguous -- (1, 12) and
        # (11, 2) both map to "112". Kept unchanged because entries are read
        # back via __getitem__ with this exact string format.
        key = str(slice_id) + str(num)
        if self._stats.get(key, None) is None:
            self._stats[key] = np.zeros(self.size)
        # NOTE(review): assumes slice_id >= 1. For slice_id == 0 the start
        # index -1 wraps around to the last row, so the window no longer
        # covers rows 0-1.
        self._stats[key][slice_id-1:slice_id+2] = item

    def __delitem__(self, key: str) -> None:
        """Remove the entry stored under ``key``."""
        del self._stats[key]

    def __getitem__(self, key: str) -> Any:
        """Return the entry stored under ``key``."""
        return self._stats[key]

    def items(self) -> ItemsView[str, Any]:
        """Return a view of ``(key, entry)`` pairs."""
        return self._stats.items()

    def merge(self, slice_ids, num, item):
        """Stub: currently does nothing and returns None."""
        # TODO: not implemented; intended behaviour is not visible here.
        pass


def build_all_layer_point_grids(
    n_per_side: int = 32, n_layers: int = 0, scale_per_layer: int = 1, size=1
) -> List[np.ndarray]:
    """
    Generates point grids for all crop layers.

    Layer ``i`` gets ``int(n_per_side / scale_per_layer**i)`` points per side.

    Args:
        n_per_side: points per side for layer 0.
        n_layers: number of additional layers; ``n_layers + 1`` grids are
            returned.
        scale_per_layer: factor by which the points per side shrink at each
            successive layer.
        size: forwarded to ``build_point_grid``; the default of 1 keeps the
            points in ``[0, 1] x [0, 1]``.

    Returns:
        One ``(k * k, 2)`` array of ``(x, y)`` points per layer.
    """
    # build_point_grid needs a ``size``; omitting it previously made every
    # call to this function raise TypeError.
    return [
        build_point_grid(int(n_per_side / (scale_per_layer**i)), size)
        for i in range(n_layers + 1)
    ]


def build_point_grid(n_per_side: int, size=1) -> np.ndarray:
    """
    Generates a 2D grid of evenly spaced points, scaled by ``size``.

    Points are placed at cell centres in ``[0, 1] x [0, 1]`` (inset by
    ``1 / (2 * n_per_side)`` from each edge) and then multiplied by
    ``np.array(size)``.

    Args:
        n_per_side: number of points along each axis.
        size: scalar or ``(x_scale, y_scale)`` multiplied into the normalized
            ``(x, y)`` coordinates. The default of 1 leaves them in [0, 1].

    Returns:
        ``(n_per_side**2, 2)`` array of ``(x, y)`` points, with x varying
        fastest.
    """
    offset = 1 / (2 * n_per_side)
    points_one_side = np.linspace(offset, 1 - offset, n_per_side)
    points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
    points_y = np.tile(points_one_side[:, None], (1, n_per_side))
    points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
    return points * np.array(size)


def calculate_stability_score_3d(
    masks: torch.Tensor, mask_threshold: float, threshold_offset: float
) -> torch.Tensor:
    """
    Computes the stability score for a batch of 3D masks.

    The stability score is the IoU between the binary masks obtained by
    thresholding the predicted mask logits at high and low values. Counts
    are reduced over the last three dimensions of ``masks``, so the result
    has shape ``masks.shape[:-3]``.

    NOTE: a mask with no value above the low threshold gives
    ``unions == 0``, and its score is NaN.
    """
    # For threshold_offset >= 0, one mask is always contained inside the
    # other, so the high-threshold count is the intersection and the
    # low-threshold count is the union.
    # Save memory by preventing unnecessary cast to torch.int64. The first
    # reduction uses int16, so the last dimension must not exceed 32767.
    intersections = (
        (masks > (mask_threshold + threshold_offset))
        .sum(-1, dtype=torch.int16)
        .sum(-1, dtype=torch.int32)
        .sum(-1, dtype=torch.int32)
    )
    unions = (
        (masks > (mask_threshold - threshold_offset))
        .sum(-1, dtype=torch.int16)
        .sum(-1, dtype=torch.int32)
        .sum(-1, dtype=torch.int32)
    )
    return intersections / unions


def remove_small_regions(
    mask: np.ndarray, area_thresh: float, mode: str
) -> Tuple[np.ndarray, bool]:
    """
    Removes small disconnected regions and holes in a mask.

    Args:
        mask: boolean mask to clean.
        area_thresh: components with fewer pixels than this are affected.
        mode: ``"holes"`` to fill small holes, ``"islands"`` to drop small
            foreground components.

    Returns:
        The (possibly new) mask and a flag telling whether it was modified.
    """
    import cv2  # type: ignore

    assert mode in ["holes", "islands"]
    correct_holes = mode == "holes"
    # In "holes" mode the mask is inverted, so holes become components.
    working_mask = (correct_holes ^ mask).astype(np.uint8)
    n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
    sizes = stats[:, -1][1:]  # Row 0 is background label
    small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
    if len(small_regions) == 0:
        return mask, False
    fill_labels = [0] + small_regions
    if not correct_holes:
        fill_labels = [i for i in range(n_labels) if i not in fill_labels]
        # If every region is below threshold, keep largest
        if len(fill_labels) == 0:
            fill_labels = [int(np.argmax(sizes)) + 1]
    mask = np.isin(regions, fill_labels)
    return mask, True


def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
    """
    Iterate over several equally long sequences in lockstep batches.

    Each yielded item is a list with one ``batch_size``-long slice per
    argument; the final batch may be shorter.
    """
    assert len(args) > 0 and all(
        len(a) == len(args[0]) for a in args
    ), "Batched iteration must have inputs of all the same size."
    # Ceiling division: one extra batch when there is a remainder.
    n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
    for b in range(n_batches):
        yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]


class MaskData:
    """
    A structure for storing masks and their related data in batched format.
    Implements basic filtering and concatenation.
    """

    def __init__(self, **kwargs) -> None:
        """Store each keyword argument (list, numpy array, or tensor) as a field."""
        for v in kwargs.values():
            assert isinstance(
                v, (list, np.ndarray, torch.Tensor)
            ), "MaskData only supports list, numpy arrays, and torch tensors."
        self._stats = dict(**kwargs)

    def __setitem__(self, key: str, item: Any) -> None:
        """Set field ``key``; ``item`` must be a list, numpy array, or tensor."""
        assert isinstance(
            item, (list, np.ndarray, torch.Tensor)
        ), "MaskData only supports list, numpy arrays, and torch tensors."
        self._stats[key] = item

    def __delitem__(self, key: str) -> None:
        """Remove field ``key``."""
        del self._stats[key]

    def __getitem__(self, key: str) -> Any:
        """Return field ``key``."""
        return self._stats[key]

    def items(self) -> ItemsView[str, Any]:
        """Return a view of ``(key, field)`` pairs."""
        return self._stats.items()

    def filter(self, keep: torch.Tensor) -> None:
        """
        Keep only the selected entries of every field, in place.

        Args:
            keep: torch tensor, either a boolean mask or a set of indices.
                The numpy branch calls ``keep.detach()``, so it must be a
                tensor.

        Raises:
            TypeError: if a field has an unsupported type.
        """
        for k, v in self._stats.items():
            if v is None:
                self._stats[k] = None
            elif isinstance(v, torch.Tensor):
                self._stats[k] = v[torch.as_tensor(keep, device=v.device)]
            elif isinstance(v, np.ndarray):
                self._stats[k] = v[keep.detach().cpu().numpy()]
            elif isinstance(v, list) and keep.dtype == torch.bool:
                self._stats[k] = [a for i, a in enumerate(v) if keep[i]]
            elif isinstance(v, list):
                self._stats[k] = [v[i] for i in keep]
            else:
                raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")

    def cat(self, new_stats: "MaskData") -> None:
        """
        Append ``new_stats`` to this instance along the first dimension.

        Fields missing here, or set to None, are deep-copied from
        ``new_stats``.

        Raises:
            TypeError: if a field has an unsupported type.
        """
        for k, v in new_stats.items():
            if k not in self._stats or self._stats[k] is None:
                self._stats[k] = deepcopy(v)
            elif isinstance(v, torch.Tensor):
                self._stats[k] = torch.cat([self._stats[k], v], dim=0)
            elif isinstance(v, np.ndarray):
                self._stats[k] = np.concatenate([self._stats[k], v], axis=0)
            elif isinstance(v, list):
                self._stats[k] = self._stats[k] + deepcopy(v)
            else:
                raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")

    def to_numpy(self) -> None:
        """Convert every tensor field to a numpy array (detached, on CPU)."""
        for k, v in self._stats.items():
            if isinstance(v, torch.Tensor):
                self._stats[k] = v.detach().cpu().numpy()