BatchMasks
- class torchaug.ta_tensors.BatchMasks(data: Any, *, samples_ranges: List[Tuple[int, int]], dtype: dtype | None = None, device: device | str | int | None = None, requires_grad: bool | None = None)
torch.Tensor subclass for a batch of segmentation and detection masks.
- Parameters:
data – Any data that can be turned into a tensor with torch.as_tensor().
samples_ranges – Each element is the range of indices of the masks that belong to one sample.
dtype – Desired data type. If omitted, it is inferred from data.
device – Desired device. If omitted and data is a torch.Tensor, the device is taken from it. Otherwise, the masks are constructed on the CPU.
requires_grad – Whether autograd should record operations. If omitted and data is a torch.Tensor, the value is taken from it. Otherwise, it defaults to False.
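The samples_ranges argument ties each mask along the first dimension of data to the sample it belongs to. The plain-Python sketch below illustrates that bookkeeping without torch or torchaug. It assumes, as its own convention, that each range is a half-open (start, end) interval, that consecutive samples occupy contiguous ranges, and that a sample with no masks has an empty range. validate_samples_ranges and masks_of_sample are hypothetical helpers, not part of torchaug, and strings stand in for mask tensors.

```python
from typing import List, Sequence, Tuple

Range = Tuple[int, int]


def validate_samples_ranges(num_masks: int, samples_ranges: Sequence[Range]) -> None:
    """Hypothetical check: ranges must be contiguous, half-open, and cover all masks."""
    expected_start = 0
    for start, end in samples_ranges:
        # Each range must begin where the previous one ended and must not run backwards.
        if start != expected_start or end < start:
            raise ValueError(f"invalid range ({start}, {end})")
        expected_start = end
    if expected_start != num_masks:
        raise ValueError("samples_ranges do not cover every mask")


def masks_of_sample(data: Sequence[str], samples_ranges: Sequence[Range], i: int) -> List[str]:
    """Hypothetical accessor: the masks of sample i are data[start:end]."""
    start, end = samples_ranges[i]
    return list(data[start:end])


# Strings stand in for mask tensors: sample 0 owns 2 masks, sample 1 owns 3, sample 2 owns none.
masks = ["m0", "m1", "m2", "m3", "m4"]
ranges = [(0, 2), (2, 5), (5, 5)]
validate_samples_ranges(len(masks), ranges)
print(masks_of_sample(masks, ranges, 1))  # ['m2', 'm3', 'm4']
```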
- classmethod cat(masks_batches)
Concatenates a sequence of BatchMasks along the first dimension.
- Parameters:
masks_batches (Sequence[BatchMasks]) – A sequence of BatchMasks to concatenate.
- Returns:
The concatenated BatchMasks.
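Concatenating along the first dimension places the masks of later batches after those of earlier ones, so their sample ranges must be shifted by the number of masks already placed. A minimal plain-Python sketch of that offsetting follows. cat_batches is a hypothetical helper, each batch is modeled as a (masks, samples_ranges) pair with strings standing in for mask tensors, and it assumes every batch's ranges start at 0 and cover exactly its own masks. It illustrates the bookkeeping, not torchaug's actual implementation.

```python
from typing import List, Sequence, Tuple

Range = Tuple[int, int]
Batch = Tuple[List[str], List[Range]]


def cat_batches(batches: Sequence[Batch]) -> Batch:
    """Hypothetical sketch of concatenating batches of masks and their sample ranges."""
    data: List[str] = []
    samples_ranges: List[Range] = []
    for masks, ranges in batches:
        # Masks from earlier batches precede this one, so shift its ranges by that count.
        offset = len(data)
        data.extend(masks)
        samples_ranges.extend((start + offset, end + offset) for start, end in ranges)
    return data, samples_ranges


batch_a = (["a0", "a1", "a2"], [(0, 1), (1, 3)])  # 2 samples
batch_b = (["b0", "b1"], [(0, 2)])                 # 1 sample
data, ranges = cat_batches([batch_a, batch_b])
print(data)    # ['a0', 'a1', 'a2', 'b0', 'b1']
print(ranges)  # [(0, 1), (1, 3), (3, 5)]
```

The offset is the running mask count rather than the running sample count, which is what keeps each sample pointing at its own masks after concatenation.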
- get_chunk(chunk_indices)
Get a chunk of the batch of masks.
- Parameters:
chunk_indices (Tensor) – The indices of the chunk to get.
- Return type:
BatchMasks
- Returns:
The chunk of the batch of masks.
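A chunk can be pictured as a subset of samples together with their masks, re-indexed so that its own ranges start at 0. The plain-Python sketch below assumes that chunk_indices are sample indices (the parameter description only says "the indices of the chunk"), takes a list of ints where the real method takes a Tensor, and keeps the order in which the indices are given. get_chunk_sketch is a hypothetical helper, not torchaug's code, and strings stand in for mask tensors.

```python
from typing import List, Sequence, Tuple

Range = Tuple[int, int]


def get_chunk_sketch(
    data: Sequence[str], samples_ranges: Sequence[Range], chunk_indices: Sequence[int]
) -> Tuple[List[str], List[Range]]:
    """Hypothetical sketch: gather the masks of the selected samples and rebuild ranges."""
    chunk_data: List[str] = []
    chunk_ranges: List[Range] = []
    for i in chunk_indices:
        start, end = samples_ranges[i]
        new_start = len(chunk_data)  # where this sample's masks begin inside the chunk
        chunk_data.extend(data[start:end])
        chunk_ranges.append((new_start, len(chunk_data)))
    return chunk_data, chunk_ranges


masks = ["m0", "m1", "m2", "m3", "m4", "m5"]
ranges = [(0, 2), (2, 5), (5, 6)]
chunk, chunk_ranges = get_chunk_sketch(masks, ranges, [2, 0])
print(chunk)         # ['m5', 'm0', 'm1']
print(chunk_ranges)  # [(0, 1), (1, 3)]
```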
- classmethod masked_select(masks, mask)
Remove masks from the batch of masks, keeping only those selected by a boolean mask.
- Parameters:
masks (BatchMasks) – The batch of masks to remove masks from.
mask (Tensor) – A boolean mask indicating which masks to keep.
- Return type:
BatchMasks
- Returns:
The updated batch of masks.
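Removing masks changes how many masks each sample owns, so the sample ranges have to be recomputed. The plain-Python sketch below keeps the masks whose entry in keep is True and rebuilds the ranges. It assumes one boolean per mask along the first dimension and, as its own choice, keeps a sample whose masks are all removed as an empty range. masked_select_sketch is a hypothetical helper, not torchaug's implementation, and strings stand in for mask tensors.

```python
from typing import List, Sequence, Tuple

Range = Tuple[int, int]


def masked_select_sketch(
    data: Sequence[str], samples_ranges: Sequence[Range], keep: Sequence[bool]
) -> Tuple[List[str], List[Range]]:
    """Hypothetical sketch: drop masks where keep is False and recompute sample ranges."""
    if len(keep) != len(data):
        raise ValueError("keep must have one entry per mask")
    new_data: List[str] = []
    new_ranges: List[Range] = []
    for start, end in samples_ranges:
        new_start = len(new_data)
        # Keep only this sample's masks whose flag is True.
        new_data.extend(m for m, k in zip(data[start:end], keep[start:end]) if k)
        new_ranges.append((new_start, len(new_data)))
    return new_data, new_ranges


masks = ["m0", "m1", "m2", "m3", "m4"]
ranges = [(0, 2), (2, 5)]
kept, kept_ranges = masked_select_sketch(masks, ranges, [True, False, False, True, True])
print(kept)         # ['m0', 'm3', 'm4']
print(kept_ranges)  # [(0, 1), (1, 3)]
```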
- update_chunk_(chunk, chunk_indices)
Update a chunk of the batch of masks.
- Parameters:
chunk (BatchMasks) – The chunk containing the updated masks.
chunk_indices (Tensor) – The indices of the chunk to update.
- Return type:
BatchMasks
- Returns:
The updated batch of masks.
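update_chunk_ can be read as the counterpart of get_chunk: it writes a chunk back into the positions of the samples it was taken from. Following PyTorch's convention that a trailing underscore marks an in-place operation, the name suggests the batch is modified in place. The plain-Python sketch below assumes that chunk_indices are sample indices in the same order used to build the chunk and that each sample's mask count is unchanged. update_chunk_sketch_ is a hypothetical helper, not torchaug's code, and a Python list of strings stands in for the batch tensor.

```python
from typing import List, Sequence, Tuple

Range = Tuple[int, int]


def update_chunk_sketch_(
    data: List[str],
    samples_ranges: Sequence[Range],
    chunk_data: Sequence[str],
    chunk_ranges: Sequence[Range],
    chunk_indices: Sequence[int],
) -> List[str]:
    """Hypothetical sketch: copy each chunk sample's masks back into the batch, in place."""
    if len(chunk_ranges) != len(chunk_indices):
        raise ValueError("need one chunk range per chunk index")
    for (c_start, c_end), i in zip(chunk_ranges, chunk_indices):
        start, end = samples_ranges[i]
        if c_end - c_start != end - start:
            raise ValueError(f"sample {i}: mask count changed")
        # Slice assignment of equal length overwrites the sample's masks in place.
        data[start:end] = chunk_data[c_start:c_end]
    return data


masks = ["m0", "m1", "m2", "m3", "m4"]
ranges = [(0, 2), (2, 5)]
update_chunk_sketch_(masks, ranges, ["x2", "x3", "x4"], [(0, 3)], [1])
print(masks)  # ['m0', 'm1', 'x2', 'x3', 'x4']
```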