Source code for torchaug.ta_tensors._batch_masks

# @Copyright: CEA-LIST/DIASI/SIALV/ (2023-    )
# @Author: CEA-LIST/DIASI/SIALV/ <julien.denize@cea.fr>
# @License: CECILL-C

from __future__ import annotations

from typing import Any, List, Mapping, Optional, Sequence, Union

import torch
from torch import Tensor
from torch.utils._pytree import tree_flatten

from ._mask import Mask
from ._ta_tensor import TATensor


def convert_masks_to_batch_masks(
    masks: Sequence[Mask],
) -> BatchMasks:
    """Merge several :class:`~torchaug.ta_tensors.Mask` objects into one :class:`~torchaug.ta_tensors.BatchMasks`.

    Each element of ``masks`` becomes one sample of the batch. The number of masks
    contributed by a sample is read from the first dimension of its tensor.
    Assumes all masks are valid.

    Args:
        masks: The per-sample masks to merge.

    Returns:
        A batch whose data is the concatenation of all inputs along the first dimension and
        whose ``idx_sample`` holds the cumulative mask counts (starting at 0).

    Raises:
        ValueError: If the masks differ in their last two dimensions, or in their
            ``requires_grad``, ``device`` or ``dtype`` attribute.
    """
    shared_attributes = ("requires_grad", "device", "dtype")

    # Every mask is compared against the first one; spatial size is checked before attributes.
    for candidate in masks:
        reference = masks[0]
        if candidate.shape[-2:] != reference.shape[-2:]:
            raise ValueError("All masks must have the same size.")
        for attr in shared_attributes:
            if getattr(candidate, attr) != getattr(reference, attr):
                raise ValueError(f"All masks must have the same {attr} attribute.")

    stacked = torch.cat([candidate.data for candidate in masks])

    # Boundaries are the running total of masks per sample, prefixed with 0.
    counts = [0]
    counts.extend(candidate.shape[0] for candidate in masks)
    boundaries = torch.tensor(counts, dtype=torch.long).cumsum(0).tolist()

    return BatchMasks(stacked, idx_sample=boundaries)
def convert_batch_masks_to_masks(
    batch_masks: BatchMasks,
) -> List[Mask]:
    """Split a :class:`~torchaug.ta_tensors.BatchMasks` into one :class:`~torchaug.ta_tensors.Mask` per sample.

    Args:
        batch_masks: The batch to split.

    Returns:
        A list with one entry per sample, each wrapping the slice of the batch delimited by
        two consecutive values of ``batch_masks.idx_sample``.
    """
    boundaries = batch_masks.idx_sample

    per_sample_masks = []
    # Walk over consecutive (start, stop) boundary pairs.
    for start, stop in zip(boundaries[:-1], boundaries[1:]):
        per_sample_masks.append(Mask(batch_masks[start:stop]))
    return per_sample_masks
class BatchMasks(TATensor):
    """:class:`torch.Tensor` subclass for batch of segmentation and detection masks.

    The masks of all samples are stored contiguously along the first dimension;
    ``idx_sample`` records where the masks of each sample start and stop.

    Args:
        data: Any data that can be turned into a tensor with :func:`torch.as_tensor`.
        dtype: Desired data type. If omitted, will be inferred from
            ``data``.
        idx_sample: Each element is the index of the first mask of the corresponding sample in the batch of
            N samples. Contains N+1 elements whose last value is the number of masks.
        device: Desired device. If omitted and ``data`` is a
            :class:`torch.Tensor`, the device is taken from it. Otherwise, the mask is constructed on the CPU.
        requires_grad: Whether autograd should record operations. If omitted and
            ``data`` is a :class:`torch.Tensor`, the value is taken from it. Otherwise, defaults to ``False``.
    """

    # Boundaries between samples: masks of sample ``i`` are rows ``idx_sample[i]:idx_sample[i + 1]``.
    idx_sample: List[int]

    @property
    def batch_size(self) -> int:
        """Number of samples in the batch (one less than the number of boundaries)."""
        return len(self.idx_sample) - 1

    @property
    def num_masks(self) -> int:
        """Total number of masks, i.e. the size of the first dimension of the data."""
        return self.data.shape[0]

    def get_num_masks_sample(self, idx: int) -> int:
        """Return the number of masks that belong to sample ``idx``."""
        return self.idx_sample[idx + 1] - self.idx_sample[idx]

    def _rows_of_samples(self, sample_indices: Any) -> Tensor:
        """Return the first-dimension indices of every mask owned by the given samples, in order."""
        rows = [
            row
            for sample in sample_indices
            for row in range(self.idx_sample[int(sample)], self.idx_sample[int(sample) + 1])
        ]
        return torch.tensor(rows, dtype=torch.long, device=self.device)

    @classmethod
    def cat(cls, masks_batches: Sequence[BatchMasks]):
        """Concatenates a sequence of :class:`~torchaug.ta_tensors.BatchMasks` along the first dimension.

        Args:
            masks_batches: A sequence of :class:`~torchaug.ta_tensors.BatchMasks` to concatenate.

        Returns:
            The concatenated :class:`~torchaug.ta_tensors.BatchMasks`.

        Raises:
            ValueError: If an element is not a :class:`~torchaug.ta_tensors.BatchMasks`, or if the batches
                differ in their last two dimensions or in their ``requires_grad``, ``device`` or ``dtype``.
        """
        attrs = [
            "requires_grad",
            "device",
            "dtype",
        ]

        # Every batch is validated against the first one.
        for batch_mask in masks_batches:
            if not isinstance(batch_mask, BatchMasks):
                raise ValueError("All batches must be of type BatchMasks.")
            if batch_mask.shape[-2:] != masks_batches[0].shape[-2:]:
                raise ValueError("All batches of masks must have the same size.")
            for attr in attrs:
                if getattr(batch_mask, attr) != getattr(masks_batches[0], attr):
                    raise ValueError(f"All batches of masks must have the same {attr} attribute.")

        # Rebuild the boundaries from the per-sample mask counts of all batches, in order.
        idx_sample = (
            torch.tensor(
                [0]
                + [
                    batch_mask.get_num_masks_sample(i)
                    for batch_mask in masks_batches
                    for i in range(batch_mask.batch_size)
                ]
            )
            .cumsum(0)
            .tolist()
        )

        data = torch.cat([mask.data for mask in masks_batches], 0)

        return cls(
            data,
            idx_sample=idx_sample,
        )

    @classmethod
    def _wrap(
        cls,
        tensor: Tensor,
        *,
        idx_sample: List[int],
        check_dims: bool = True,
    ) -> BatchMasks:  # type: ignore[override]
        """View ``tensor`` as an instance of this class and attach ``idx_sample`` to it.

        Raises:
            ValueError: If ``check_dims`` is true and ``tensor`` has fewer than 2 dimensions.
        """
        if check_dims and tensor.ndim < 2:
            raise ValueError(f"Expected at least a 2D tensor, got {tensor.ndim}D tensor")
        batch_masks = tensor.as_subclass(cls)
        batch_masks.idx_sample = idx_sample
        return batch_masks

    def __new__(
        cls,
        data: Any,
        *,
        idx_sample: List[int],
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[torch.device, str, int]] = None,
        requires_grad: Optional[bool] = None,
    ) -> BatchMasks:
        """Convert ``data`` to a tensor and wrap it together with ``idx_sample``.

        Raises:
            ValueError: If the resulting tensor has fewer than 2 dimensions.
        """
        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
        if tensor.ndim < 2:
            # Same message as ``_wrap`` so the failure is self-explanatory.
            raise ValueError(f"Expected at least a 2D tensor, got {tensor.ndim}D tensor")
        return cls._wrap(tensor, idx_sample=idx_sample)

    @classmethod
    def _wrap_output(
        cls,
        output: torch.Tensor,
        args: Sequence[Any] = (),
        kwargs: Optional[Mapping[str, Any]] = None,
    ) -> BatchMasks:
        """Re-attach ``idx_sample`` to ``output`` using the first BatchMasks found in ``args``/``kwargs``."""
        # If there are BatchMasks instances in the output, their metadata got lost when we called
        # super().__torch_function__. We need to restore the metadata somehow, so we choose to take
        # the metadata from the first mask in the parameters.
        # This should be what we want in most cases. When it's not, it's probably a mis-use anyway, e.g.
        # something like mask_N1HW + mask_N2HW; we don't guard against those cases.
        flat_params, _ = tree_flatten(args + (tuple(kwargs.values()) if kwargs else ()))  # type: ignore[operator]
        first_batch_masks_from_args = next(x for x in flat_params if isinstance(x, BatchMasks))
        idx_sample = first_batch_masks_from_args.idx_sample.copy()  # clone the list.

        if isinstance(output, torch.Tensor) and not isinstance(output, BatchMasks):
            output = BatchMasks._wrap(
                output,
                idx_sample=idx_sample,
                check_dims=False,
            )
        elif isinstance(output, (tuple, list)):
            # Wrap every element and keep the container type of the output.
            output = type(output)(
                BatchMasks._wrap(
                    part,
                    idx_sample=idx_sample,
                    check_dims=False,
                )
                for part in output
            )
        return output

    def get_sample(self, idx: int) -> Mask:
        """Get the masks for a sample in the batch.

        Args:
            idx: The index of the sample to get.

        Returns:
            The masks for the sample.
        """
        masks = self[self.idx_sample[idx] : self.idx_sample[idx + 1]]
        return Mask(
            masks,
            device=self.device,
            requires_grad=self.requires_grad,
        )

    def get_chunk(self, chunk_indices: torch.Tensor) -> BatchMasks:
        """Get a chunk of the batch of masks.

        Args:
            chunk_indices: The indices of the samples that form the chunk.

        Returns:
            The chunk of the batch of masks, holding all the masks of the selected samples
            in the order given by ``chunk_indices``.
        """
        chunk_idx_sample = (
            torch.tensor(
                [0] + [self.get_num_masks_sample(int(chunk_indice)) for chunk_indice in chunk_indices],
                dtype=torch.long,
            )
            .cumsum(0)
            .tolist()
        )

        # ``chunk_indices`` designates samples, so it must be expanded to the rows of the masks
        # those samples own; indexing the data with it directly only matches ``chunk_idx_sample``
        # when every sample holds exactly one mask.
        return BatchMasks(
            self[self._rows_of_samples(chunk_indices)],
            idx_sample=chunk_idx_sample,
            device=self.device,
            requires_grad=self.requires_grad,
        )

    def update_chunk_(self, chunk: BatchMasks, chunk_indices: torch.Tensor) -> BatchMasks:
        """Update a chunk of the batch of masks.

        Args:
            chunk: The chunk update.
            chunk_indices: The indices of the samples to update.

        Returns:
            The updated batch of masks (``self``, modified in place).
        """
        # Write into the rows of the masks owned by the selected samples, mirroring ``get_chunk``.
        self[self._rows_of_samples(chunk_indices)] = chunk
        return self

    @classmethod
    def masked_remove(cls, masks: BatchMasks, mask: torch.Tensor) -> BatchMasks:
        """Remove masks from the batch of masks.

        Args:
            masks: The batch of masks to remove masks from.
            mask: A boolean mask over the first dimension; ``True`` marks a mask to remove.

        Returns:
            The updated batch of masks.
        """
        old_idx_sample = masks.idx_sample

        data = masks.data[~mask]

        cpu_mask = mask.cpu()
        num_delete_per_sample = [
            cpu_mask[old_idx_sample[i] : old_idx_sample[i + 1]].sum().item() for i in range(len(old_idx_sample) - 1)
        ]

        # Each boundary moves back by the number of masks removed from the samples BEFORE it;
        # the removals of the sample starting at a boundary must not shift that boundary.
        new_idx_sample = []
        removed_before = 0
        for i, boundary in enumerate(old_idx_sample):
            new_idx_sample.append(boundary - removed_before)
            if i < len(num_delete_per_sample):
                removed_before += num_delete_per_sample[i]

        return cls._wrap(
            data,
            idx_sample=new_idx_sample,
        )