Source code for torchaug.ta_tensors.nested._mask
# ==================================
# Copyright: CEA-LIST/DIASI/SIALV/
# Author : Torchaug Developers
# License: CECILL-C
# ==================================
from __future__ import annotations
from typing import (
List,
)
from torchaug.ta_tensors import (
BatchMasks,
Mask,
)
from torchaug.ta_tensors._batch_masks import convert_masks_to_batch_masks
from ._ta_nested_tensors import TANestedTensors
[docs]
class MaskNestedTensors(TANestedTensors[Mask, BatchMasks]):
    """Nested container whose elements are :class:`Mask` tensors.

    The individual masks live in :attr:`tensors`; :meth:`to_batch` turns
    them into a single :class:`BatchMasks` by delegating to
    ``convert_masks_to_batch_masks``.
    """

    # Element type of the nested sequence and the type returned by
    # ``to_batch``; these mirror the generic parameters
    # ``TANestedTensors[Mask, BatchMasks]`` declared in the class header.
    tensors_type = Mask
    batch_tensors_type = BatchMasks

    # The nested masks. The attribute is presumably populated by the
    # ``TANestedTensors`` base class — confirm there.
    tensors: List[Mask]

    def to_batch(self) -> BatchMasks:
        """Convert the nested masks into one batched mask.

        Returns:
            BatchMasks: The result of ``convert_masks_to_batch_masks``
            applied to :attr:`tensors`.
        """
        nested_masks = self.tensors
        batched_masks = convert_masks_to_batch_masks(nested_masks)
        return batched_masks