BatchLabels

class torchaug.ta_tensors.BatchLabels(data: Any, *, samples_ranges: List[Tuple[int, int]], dtype: dtype | None = None, device: device | str | int | None = None, requires_grad: bool | None = None)

Tensor subclass holding the labels of a batch of samples, concatenated along the first dimension.

Useful for labels of bounding boxes or masks, where each sample can have a different number of labels.

Parameters:
  • data – Any data that can be turned into a tensor with torch.as_tensor().

  • samples_ranges – One (start, end) pair per sample, giving the range of indices in data that hold the labels of that sample.

  • dtype – Desired data type. If omitted, it is inferred from data.

  • device – Desired device. If omitted and data is a torch.Tensor, the device is taken from it. Otherwise, the tensor is constructed on the CPU.

  • requires_grad – Whether autograd should record operations. If omitted and data is a Labels, the value is taken from it. Otherwise, defaults to False.
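A minimal sketch of how samples_ranges relates to data, using plain Python lists in place of tensors. It assumes each range is a half-open (start, end) interval, like a Python slice; the helper concat_with_ranges is hypothetical and not part of torchaug.

```python
from typing import List, Sequence, Tuple


def concat_with_ranges(
    per_sample: Sequence[Sequence[int]],
) -> Tuple[List[int], List[Tuple[int, int]]]:
    """Concatenate per-sample labels and record each sample's (start, end) range.

    Hypothetical helper: plain lists stand in for tensors, and `end` is
    assumed to be exclusive, like a Python slice.
    """
    data: List[int] = []
    samples_ranges: List[Tuple[int, int]] = []
    for labels in per_sample:
        start = len(data)
        data.extend(labels)
        samples_ranges.append((start, len(data)))
    return data, samples_ranges


# Three samples with 2, 1 and 3 labels respectively.
data, samples_ranges = concat_with_ranges([[3, 7], [2], [1, 1, 4]])
print(data)            # [3, 7, 2, 1, 1, 4]
print(samples_ranges)  # [(0, 2), (2, 3), (3, 6)]
```

Following the signature above, the two results could then be passed as BatchLabels(torch.as_tensor(data), samples_ranges=samples_ranges).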

classmethod cat(labels_batches)

Concatenates a sequence of BatchLabels along the first dimension.

Parameters:

labels_batches (Sequence[BatchLabels]) – A sequence of BatchLabels to concatenate.

Returns:

The concatenated BatchLabels.
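Concatenating batches along the first dimension means the ranges of each later batch must be shifted by the number of labels that precede it. The sketch below shows that bookkeeping with plain lists; it assumes each input batch's ranges index into that batch's own data (starting at 0), and cat_batches is a hypothetical name, not the library's implementation.

```python
from typing import List, Sequence, Tuple

# (concatenated labels, samples_ranges); plain lists stand in for tensors.
Batch = Tuple[List[int], List[Tuple[int, int]]]


def cat_batches(batches: Sequence[Batch]) -> Batch:
    """Concatenate batches and shift each batch's ranges by the labels before it."""
    data: List[int] = []
    samples_ranges: List[Tuple[int, int]] = []
    for labels, ranges in batches:
        offset = len(data)  # number of labels already concatenated
        data.extend(labels)
        samples_ranges.extend((start + offset, end + offset) for start, end in ranges)
    return data, samples_ranges


a = ([3, 7, 2], [(0, 2), (2, 3)])
b = ([5, 5, 8], [(0, 1), (1, 3)])
data, samples_ranges = cat_batches([a, b])
print(data)            # [3, 7, 2, 5, 5, 8]
print(samples_ranges)  # [(0, 2), (2, 3), (3, 4), (4, 6)]
```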

get_chunk(chunk_indices)

Get a chunk of the batch, made of the labels of the selected samples.

Parameters:

chunk_indices (Tensor) – The indices of the samples that make up the chunk.

Return type:

BatchLabels

Returns:

The selected samples as a new batch of labels.
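A minimal sketch of extracting a chunk, assuming chunk_indices index samples (not individual labels) and that the chunk's ranges are rebased to start at 0. Plain lists stand in for tensors, and get_chunk here is a hypothetical stand-alone function, not the method itself.

```python
from typing import List, Sequence, Tuple


def get_chunk(
    data: List[int],
    samples_ranges: List[Tuple[int, int]],
    chunk_indices: Sequence[int],
) -> Tuple[List[int], List[Tuple[int, int]]]:
    """Gather the labels of the selected samples and rebase their ranges to 0."""
    chunk_data: List[int] = []
    chunk_ranges: List[Tuple[int, int]] = []
    for i in chunk_indices:
        start, end = samples_ranges[i]
        new_start = len(chunk_data)
        chunk_data.extend(data[start:end])
        chunk_ranges.append((new_start, len(chunk_data)))
    return chunk_data, chunk_ranges


data = [3, 7, 2, 5, 5, 8]
samples_ranges = [(0, 2), (2, 3), (3, 4), (4, 6)]
print(get_chunk(data, samples_ranges, [1, 3]))  # ([2, 5, 8], [(0, 1), (1, 3)])
```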

get_sample(idx)

Get the labels of a single sample in the batch.

Parameters:

idx (int) – The index of the sample to get.

Return type:

Labels

Returns:

The labels of the sample.
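Getting one sample amounts to slicing data with that sample's range. A minimal sketch with plain lists, assuming half-open (start, end) ranges; get_sample here is a hypothetical stand-alone function.

```python
from typing import List, Tuple


def get_sample(data: List[int], samples_ranges: List[Tuple[int, int]], idx: int) -> List[int]:
    """Return the labels of sample `idx` by slicing its (start, end) range."""
    start, end = samples_ranges[idx]
    return data[start:end]


data = [3, 7, 2, 5, 5, 8]
samples_ranges = [(0, 2), (2, 3), (3, 4), (4, 6)]
print(get_sample(data, samples_ranges, 0))  # [3, 7]
print(get_sample(data, samples_ranges, 3))  # [5, 8]
```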

classmethod masked_select(labels, mask)

Remove labels from a batch of labels according to a boolean mask.

Parameters:
  • labels (BatchLabels) – The batch of labels to remove labels from.

  • mask (Tensor) – A boolean mask over the labels; labels where the mask is True are kept.

Return type:

BatchLabels

Returns:

The updated batch of labels.
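Masking changes how many labels each sample owns, so every range has to be recomputed. The sketch below assumes the mask has one entry per label and that a sample whose labels are all removed keeps an empty range (so the number of samples is unchanged); both are assumptions, and masked_select here is a hypothetical list-based stand-in.

```python
from typing import List, Tuple


def masked_select(
    data: List[int],
    samples_ranges: List[Tuple[int, int]],
    mask: List[bool],
) -> Tuple[List[int], List[Tuple[int, int]]]:
    """Keep labels where mask is True and recompute every sample's range."""
    if len(mask) != len(data):
        raise ValueError("mask must have one entry per label")
    kept: List[int] = []
    new_ranges: List[Tuple[int, int]] = []
    for start, end in samples_ranges:
        new_start = len(kept)
        kept.extend(x for x, keep in zip(data[start:end], mask[start:end]) if keep)
        # A sample whose labels are all removed keeps an empty (n, n) range.
        new_ranges.append((new_start, len(kept)))
    return kept, new_ranges


data = [3, 7, 2, 5, 5, 8]
samples_ranges = [(0, 2), (2, 3), (3, 4), (4, 6)]
mask = [True, False, False, True, True, True]
print(masked_select(data, samples_ranges, mask))
# ([3, 5, 5, 8], [(0, 1), (1, 1), (1, 2), (2, 4)])
```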

to_samples()

Split the batch into the labels of each sample.

Return type:

list[Labels]
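Splitting the batch applies the per-sample slice to every range in order. A minimal sketch with plain lists in place of tensors and Labels, assuming half-open ranges; to_samples here is a hypothetical stand-alone function.

```python
from typing import List, Tuple


def to_samples(data: List[int], samples_ranges: List[Tuple[int, int]]) -> List[List[int]]:
    """Split the concatenated labels into one list per sample."""
    return [data[start:end] for start, end in samples_ranges]


print(to_samples([3, 7, 2, 5, 5, 8], [(0, 2), (2, 3), (3, 4), (4, 6)]))
# [[3, 7], [2], [5], [5, 8]]
```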

update_chunk_(chunk, chunk_indices)

Update a chunk of the batch of labels in place.

Parameters:
  • chunk (BatchLabels) – The new labels for the chunk.

  • chunk_indices (Tensor) – The indices of the samples in the chunk to update.

Return type:

BatchLabels

Returns:

The updated batch of labels.
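A minimal sketch of writing a chunk back into the batch in place, assuming chunk_indices index samples, the chunk's ranges start at 0, and every updated sample keeps its number of labels (the sketch raises otherwise). Plain lists stand in for tensors; update_chunk_ here is a hypothetical stand-alone function, not the library's implementation.

```python
from typing import List, Sequence, Tuple


def update_chunk_(
    data: List[int],
    samples_ranges: List[Tuple[int, int]],
    chunk_data: List[int],
    chunk_ranges: List[Tuple[int, int]],
    chunk_indices: Sequence[int],
) -> List[int]:
    """Write each chunk sample back into the slot of the sample it came from, in place."""
    for (c_start, c_end), i in zip(chunk_ranges, chunk_indices):
        start, end = samples_ranges[i]
        if end - start != c_end - c_start:
            raise ValueError(f"sample {i}: label count changed")
        data[start:end] = chunk_data[c_start:c_end]
    return data


data = [3, 7, 2, 5, 5, 8]
samples_ranges = [(0, 2), (2, 3), (3, 4), (4, 6)]
# New labels for samples 1 and 3, laid out as a chunk with rebased ranges.
update_chunk_(data, samples_ranges, [0, 9, 9], [(0, 1), (1, 3)], [1, 3])
print(data)  # [3, 7, 0, 5, 9, 9]
```

Following the trailing-underscore convention of PyTorch, the list is modified in place and also returned for convenience.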