# Source code for torchaug.transforms._transform

# @Copyright: CEA-LIST/DIASI/SIALV/ (2023-    )
# @Author: CEA-LIST/DIASI/SIALV/ <julien.denize@cea.fr>
# @License: CECILL-C
#
# Code partially based on Torchvision (BSD 3-Clause License), available at:
#   https://github.com/pytorch/vision

from __future__ import annotations

import enum
from math import ceil, floor
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

import torch
from torch import nn
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision.transforms.v2._utils import check_type, has_any

from torchaug import ta_tensors
from torchaug._utils import _log_api_usage_once
from torchaug.ta_tensors import set_return_type

from ._utils import is_pure_tensor
from .functional._utils._kernel import _get_kernel


class RandomApplyTransform(nn.Module):
    """Base class for all randomly applied transforms.

    For more details, please see :ref:`tutorial/transforms:Transforms Tutorial`.

    Args:
        p: The probability of applying the transform.
        batch_inplace: whether to apply the batch transform in-place. Does not prevent functionals to make copy
            but can reduce time and memory consumption.
        num_chunks: number of chunks to split the batched input into.
        permute_chunks: whether to permute the chunks.
        batch_transform: whether to apply the transform in batch mode.
    """

    _transformed_types: Tuple[Union[Type, Callable[[Any], bool]], ...] = (torch.Tensor,)

    def __init__(
        self,
        p: float = 0.5,
        batch_inplace: bool = False,
        num_chunks: int = 1,
        permute_chunks: bool = False,
        batch_transform: bool = False,
    ) -> None:
        if not (0.0 <= p <= 1.0):
            raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].")
        elif p > 0 and p < 1 and self._reshape_transform and batch_transform:
            raise ValueError("`p` should be 0 or 1 if `_reshape_transform` is True and `batch_transform` is True.")
        if batch_inplace and self._reshape_transform:
            raise ValueError("`inplace` should be False if `_reshape_transform` is True.")

        super().__init__()
        _log_api_usage_once(self)

        self.batch_inplace = batch_inplace
        self.permute_chunks = permute_chunks
        self.p = p
        # `batch_transform` must be set before `num_chunks`: the setter validates against it.
        self.batch_transform = batch_transform
        self.num_chunks = num_chunks
        self._receive_flatten_inputs = False

    @property
    def _reshape_transform(self) -> bool:
        """Whether chunk outputs are concatenated instead of written back into the input.

        When True, ``batch_inplace`` is rejected and, in batch mode, ``p`` must be 0 or 1. Subclasses can
        override it (presumably for transforms that change the input shape — confirm per subclass).
        """
        return False

    @property
    def num_chunks(self) -> int:
        """Get the number of chunks to split the input into.

        Some subclasses can have a specific logic to determine the number of chunks.
        """
        return self._num_chunks

    @num_chunks.setter
    def num_chunks(self, num_chunks: int) -> None:
        """Set the number of chunks to split the input into.

        Raises:
            ValueError: If ``num_chunks`` is -1 or greater than 1 while ``batch_transform`` is False, or if it
                is neither positive nor -1.
        """
        if (num_chunks == -1 or num_chunks > 1) and not self.batch_transform:
            raise ValueError("`num_chunks` should be 1 if `batch_transform` is False.")
        elif num_chunks < -1 or num_chunks == 0:
            raise ValueError("`num_chunks` should be greater than 0 or -1.")
        self._num_chunks = num_chunks

    @staticmethod
    def _get_input_batch_size(inpt: Any):
        """Return the number of samples in ``inpt``.

        ``BatchBoundingBoxes``/``BatchMasks`` report their own ``batch_size``; other tensors use ``shape[0]``.

        Raises:
            ValueError: If ``inpt`` is none of the supported types.
        """
        if isinstance(inpt, (ta_tensors.BatchBoundingBoxes, ta_tensors.BatchMasks)):
            batch_size = inpt.batch_size
        elif isinstance(inpt, torch.Tensor):
            batch_size = inpt.shape[0]
        else:
            raise ValueError(
                f"Expected input to be of type `BatchBoundingBoxes`, `BatchMasks` or `Tensor`, but got {type(inpt)}."
            )
        return batch_size

    @staticmethod
    def _get_inputs_device(flat_inputs: List[Any]) -> torch.device:
        """Return the device of the first tensor in ``flat_inputs``, or CPU if there is none."""
        for inpt in flat_inputs:
            if isinstance(inpt, torch.Tensor):
                return inpt.device
        return torch.device("cpu")

    def _get_chunks_indices(self, batch_size: int, num_chunks: int, device: torch.device) -> Tuple[torch.Tensor, ...]:
        """Split the sample indices ``0..batch_size-1`` into exactly ``num_chunks`` chunks.

        Chunk sizes differ by at most one. If ``permute_chunks`` is True, indices are shuffled before splitting.

        Raises:
            ValueError: If ``num_chunks`` is not positive or exceeds ``batch_size``.
        """
        if num_chunks <= 0:
            raise ValueError("`num_chunks` should be greater than 0.")
        elif num_chunks > batch_size:
            raise ValueError(
                f"`num_chunks` should be less than or equal to the batch size, but got {num_chunks} "
                f"and batch size {batch_size}."
            )
        elif num_chunks == 1:
            return (torch.arange(0, batch_size, device=device),)

        if self.permute_chunks:
            indices = torch.randperm(batch_size, device=device)
        else:
            indices = torch.arange(0, batch_size, device=device)
        # `tensor_split` always returns `num_chunks` pieces; `chunk` may return fewer than requested.
        return indices.tensor_split(num_chunks)

    def _needs_transform_list(self, flat_inputs: List[Any]) -> List[bool]:
        """Return, for each flattened input, whether it should be transformed."""
        # Below is a heuristic on how to deal with pure tensor inputs:
        # 1. Pure tensors, i.e. tensors that are not a ta_tensor, are passed through if there is an explicit image
        #    (`ta_tensors.Image`, `ta_tensors.BatchImages`) or video (`ta_tensors.Video`, `ta_tensors.BatchVideos`)
        #    in the sample.
        # 2. If there is no explicit image or video in the sample, only the first encountered pure tensor is
        #    transformed as image, while the rest is passed through. The order is defined by the returned `flat_inputs`
        #    of `tree_flatten`, which recurses depth-first through the input.
        #
        # This heuristic stems from two requirements:
        # 1. We need to keep BC for single input pure tensors and treat them as images.
        # 2. We don't want to treat all pure tensors as images, because some datasets like `CelebA` or `Widerface`
        #    return supplemental numerical data as tensors that cannot be transformed as images.
        #
        # The heuristic should work well for most people in practice. The only case where it doesn't is if someone
        # tries to transform multiple pure tensors at the same time, expecting them all to be treated as images.
        # However, this case wasn't supported by transforms v1 either, so there is no BC concern.

        needs_transform_list = []
        if self.batch_transform:
            transform_pure_tensor = not has_any(
                flat_inputs,
                ta_tensors.BatchImages,
                ta_tensors.BatchVideos,
            )
        else:
            transform_pure_tensor = not has_any(
                flat_inputs,
                ta_tensors.Image,
                ta_tensors.BatchImages,
                ta_tensors.Video,
                ta_tensors.BatchVideos,
            )
        for inpt in flat_inputs:
            needs_transform = True

            if not check_type(inpt, self._transformed_types):
                needs_transform = False
            elif is_pure_tensor(inpt):
                if transform_pure_tensor:
                    transform_pure_tensor = False
                else:
                    needs_transform = False
            needs_transform_list.append(needs_transform)
        return needs_transform_list

    def _get_params(
        self,
        flat_inputs: List[Any],
        num_chunks: int,
        chunks_indices: Tuple[torch.Tensor, ...],
    ) -> List[Dict[str, Any]]:
        """Return one parameter dict per chunk; this base implementation returns empty dicts."""
        return [{} for _ in range(num_chunks)]

    def _get_indices_transform(self, batch_size: int, device: torch.device) -> torch.Tensor:
        """Sample the indices of the batch elements to transform.

        If ``p * batch_size < 1``, a single random element is selected with probability ``p`` (otherwise none).
        Otherwise ``floor(p * batch_size)`` or ``ceil(p * batch_size)`` distinct elements are selected, the
        ceiling being chosen with probability equal to the fractional part of ``p * batch_size``.
        """
        p_mul_batch_size = self.p * batch_size
        floor_apply = floor(p_mul_batch_size)
        ceil_apply = ceil(p_mul_batch_size)

        if floor_apply == 0 or ceil_apply == 0:
            # Fewer than one expected element: augment at most one, with probability p.
            num_transform = 1 if torch.rand(1).item() < self.p else 0
        elif floor_apply == ceil_apply:
            num_transform = floor_apply
        else:
            # Non-integer expectation: pick the upper count with probability equal to the fractional part.
            decimal = p_mul_batch_size % 1
            num_transform = floor_apply if decimal < torch.rand(1).item() else ceil_apply

        if num_transform == 0:
            return torch.empty(0, device=device, dtype=torch.long)
        elif num_transform == 1:
            return torch.randint(0, batch_size, (1,), device=device)
        return torch.randperm(batch_size, device=device)[:num_transform]

    def _check_inputs(self, flat_inputs: List[Any]) -> None:
        """Validate the flattened inputs; no-op by default."""
        pass

    def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
        """Dispatch ``functional`` to the kernel registered for ``type(inpt)`` (pass-through allowed)."""
        kernel = _get_kernel(functional, type(inpt), allow_passthrough=True)
        return kernel(inpt, *args, **kwargs)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        """Transform a single input with ``params``; must be implemented by subclasses."""
        raise NotImplementedError

    def forward_single(self, flat_inputs: List[Any]) -> List[Any]:
        """Apply the transform to every eligible input with probability ``p``, sharing one set of params."""
        # p == 1 always applies and p == 0 never does; neither draws a random number.
        if self.p != 1.0 and (self.p == 0.0 or torch.rand(1) >= self.p):
            return flat_inputs

        needs_transform_list = self._needs_transform_list(flat_inputs)
        params = self._get_params(
            [inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform],
            num_chunks=1,
            # The first flattened leaf is not necessarily a tensor (e.g. an int label).
            chunks_indices=(torch.tensor([0], device=self._get_inputs_device(flat_inputs)),),
        )[0]

        return [
            self._transform(inpt, params) if needs_transform else inpt
            for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list)
        ]

    def _transform_chunks(
        self,
        transform_inpt: Any,
        params: List[Dict[str, Any]],
        chunks_indices: Tuple[torch.Tensor, ...],
    ) -> Any:
        """Transform ``transform_inpt`` chunk by chunk, chunk ``i`` using ``params[i]``.

        If ``_reshape_transform`` is False, each transformed chunk is written back into ``transform_inpt`` in
        place and ``transform_inpt`` is returned. Otherwise the transformed chunks are concatenated along dim 0
        in chunk order.
        """
        return_type = "TATensor" if isinstance(transform_inpt, ta_tensors.TATensor) else "Tensor"
        is_batch_bboxes_or_masks_inpt = isinstance(
            transform_inpt, (ta_tensors.BatchBoundingBoxes, ta_tensors.BatchMasks)
        )

        chunk_outputs = []
        for i, chunk_indices in enumerate(chunks_indices):
            if is_batch_bboxes_or_masks_inpt:
                chunk_inpt = transform_inpt.get_chunk(chunk_indices=chunk_indices)
            else:
                with set_return_type(return_type):
                    chunk_inpt = transform_inpt[chunk_indices]
            chunk_output = self._transform(chunk_inpt, params[i])

            if self._reshape_transform:
                chunk_outputs.append(chunk_output)
            elif is_batch_bboxes_or_masks_inpt:
                transform_inpt.update_chunk_(chunk_output, chunk_indices=chunk_indices)
            else:
                with set_return_type(return_type):
                    transform_inpt[chunk_indices] = chunk_output

        if self._reshape_transform:
            # NOTE(review): with `permute_chunks=True` the chunks hold shuffled sample indices, so this
            # concatenation follows the shuffled order while pass-through inputs keep the input order —
            # confirm this is intended.
            with set_return_type(return_type):
                return torch.cat(chunk_outputs, dim=0)
        return transform_inpt

    def forward_batch(self, flat_inputs: List[Any]) -> List[Any]:
        """Apply the transform to a random subset of a batch.

        The selected samples (see :meth:`_get_indices_transform`) are split into chunks and every chunk is
        transformed with its own parameters from :meth:`_get_params`.

        Args:
            flat_inputs: Flattened inputs. The first one must be a ``Tensor``, ``BatchBoundingBoxes`` or
                ``BatchMasks``; it determines the batch size and the device.

        Returns:
            The flattened outputs, in the same order as ``flat_inputs``.
        """
        if self.p == 0:
            # Nothing can be transformed: return the inputs directly.
            return flat_inputs

        needs_transform_list = self._needs_transform_list(flat_inputs)
        batch_size = self._get_input_batch_size(flat_inputs[0])
        device = flat_inputs[0].device

        if self.p == 1:
            transform_all = True
            indices_transform = None
        else:
            indices_transform = self._get_indices_transform(batch_size, device)
            transform_all = indices_transform.shape[0] == batch_size
            if not transform_all and indices_transform.shape[0] == 0:
                # No sample selected: return the inputs directly.
                return flat_inputs

        if transform_all:
            transform_inpts = flat_inputs
            flat_pre_outputs = flat_inputs
        else:
            # `flat_pre_outputs` holds the complete outputs before augmentation; the augmented samples are
            # written back into them afterwards. `transform_inpts` holds only the selected samples.
            transform_inpts = []
            flat_pre_outputs = []
            for inpt, needs_transform in zip(flat_inputs, needs_transform_list):
                if not needs_transform:
                    transform_inpts.append(None)
                    flat_pre_outputs.append(inpt)
                    continue

                is_batch_bboxes_or_masks_inpt = isinstance(
                    inpt, (ta_tensors.BatchBoundingBoxes, ta_tensors.BatchMasks)
                )
                if self.batch_inplace or (self._reshape_transform and not is_batch_bboxes_or_masks_inpt):
                    pre_output = inpt
                else:
                    # Copy so writing the augmented samples back does not modify the caller's input.
                    pre_output = inpt.clone()
                flat_pre_outputs.append(pre_output)

                if is_batch_bboxes_or_masks_inpt:
                    transform_inpt = pre_output.get_chunk(chunk_indices=indices_transform)
                else:
                    with set_return_type("TATensor" if isinstance(inpt, ta_tensors.TATensor) else "Tensor"):
                        transform_inpt = pre_output[indices_transform]
                transform_inpts.append(transform_inpt)

        transform_batch_size = batch_size if transform_all else indices_transform.shape[0]
        if self._num_chunks == -1:
            num_chunks = transform_batch_size
        else:
            num_chunks = min(transform_batch_size, self._num_chunks)

        chunks_indices = self._get_chunks_indices(transform_batch_size, num_chunks, device)
        params = self._get_params(
            [
                transform_inpt
                for (transform_inpt, needs_transform) in zip(transform_inpts, needs_transform_list)
                if needs_transform
            ],
            len(chunks_indices),
            chunks_indices,
        )

        transform_outputs = []
        for transform_inpt, needs_transform in zip(transform_inpts, needs_transform_list):
            if not needs_transform:
                transform_outputs.append(transform_inpt)
            elif num_chunks == 1:
                transform_outputs.append(self._transform(transform_inpt, params[0]))
            else:
                if transform_all and not self.batch_inplace and not self._reshape_transform:
                    # Chunks are written back into `transform_inpt` in place, and here it is the caller's own
                    # input: copy it first unless in-place batch processing was requested.
                    transform_inpt = transform_inpt.clone()
                transform_outputs.append(self._transform_chunks(transform_inpt, params, chunks_indices))

        if transform_all:
            return transform_outputs

        flat_outputs = []
        for flat_pre_output, transform_output, needs_transform in zip(
            flat_pre_outputs, transform_outputs, needs_transform_list
        ):
            if not needs_transform:
                flat_outputs.append(flat_pre_output)
                continue

            return_type = "TATensor" if isinstance(flat_pre_output, ta_tensors.TATensor) else "Tensor"
            if isinstance(flat_pre_output, (ta_tensors.BatchBoundingBoxes, ta_tensors.BatchMasks)):
                flat_pre_output.update_chunk_(transform_output, chunk_indices=indices_transform)
            else:
                with set_return_type(return_type):
                    flat_pre_output[indices_transform] = transform_output

            with set_return_type(return_type):
                flat_pre_output = flat_pre_output.contiguous()
            flat_outputs.append(flat_pre_output)

        return flat_outputs

    def forward(self, *inputs: Any) -> Any:
        """Performs forward pass of the transform.

        Args:
            inputs: Inputs to the transform.

        Returns:
            Transformed inputs.
        """
        if not self._receive_flatten_inputs:
            # A single positional argument is transformed as-is; several are treated as one tuple.
            inputs = inputs if len(inputs) > 1 else inputs[0]
            flat_inputs, spec = tree_flatten(inputs)
        else:
            flat_inputs = list(inputs)

        self._check_inputs(flat_inputs)

        if not self.batch_transform:
            flat_outputs = self.forward_single(flat_inputs)
        else:
            flat_outputs = self.forward_batch(flat_inputs)

        if not self._receive_flatten_inputs:
            return tree_unflatten(flat_outputs, spec)

        return flat_outputs

    def extra_repr(self, exclude_names: Optional[List[str]] = None) -> str:
        """Set the extra representation of the transform.

        Public attributes holding ``bool``, ``int``, ``float``, ``str``, ``tuple``, ``list``, ``Enum`` or
        ``None`` are listed as ``name=value``. ``p``, ``batch_inplace``, ``num_chunks``, ``permute_chunks`` and
        ``batch_transform`` come last and are dropped when ``None``. The batch-related ones are dropped when
        ``batch_transform`` is False.

        Args:
            exclude_names: Names of attributes to leave out. The given list is not modified.

        Returns:
            The comma-separated representation.
        """
        # Work on a fresh list: mutating a shared default (or the caller's list) leaks excluded names into
        # later calls, e.g. hiding the batch options of batch transforms.
        exclude_names = [] if exclude_names is None else list(exclude_names)
        if not self.batch_transform:
            exclude_names.extend(["batch_inplace", "num_chunks", "permute_chunks", "batch_transform"])

        last_extra: Dict[str, Any] = {
            "p": None,
            "batch_inplace": None,
            "num_chunks": None,
            "permute_chunks": None,
            "batch_transform": None,
        }
        transform_extra = []
        parameters_dict = dict(self.__dict__, num_chunks=self.num_chunks)
        for name, value in parameters_dict.items():
            if name.startswith("_") or name == "training" or name in exclude_names:
                continue

            if not isinstance(value, (bool, int, float, str, tuple, list, enum.Enum)) and value is not None:
                continue

            if name in last_extra:
                last_extra[name] = value
            else:
                transform_extra.append(f"{name}={value}")

        extra = transform_extra + [f"{name}={value}" for name, value in last_extra.items() if value is not None]

        return ", ".join(extra)
class Transform(RandomApplyTransform):
    """Base class for all transforms.

    For more details, please see :ref:`tutorial/transforms:Transforms Tutorial`.

    The transform is always applied (``p`` is fixed to 1).

    Args:
        batch_inplace: whether to apply the batch transform in-place. Does not prevent functionals to make copy
            but can reduce time and memory consumption.
        num_chunks: number of chunks to split the batched input into.
        permute_chunks: whether to permute the chunks.
        batch_transform: whether to apply the transform in batch mode.
    """

    def __init__(
        self,
        batch_inplace: bool = False,
        num_chunks: int = 1,
        permute_chunks: bool = False,
        batch_transform: bool = False,
    ) -> None:
        super().__init__(
            p=1.0,
            batch_inplace=batch_inplace,
            num_chunks=num_chunks,
            permute_chunks=permute_chunks,
            batch_transform=batch_transform,
        )

    def extra_repr(self, exclude_names: Optional[List[str]] = None) -> str:
        """Set the extra representation of the transform, always omitting ``p``.

        Args:
            exclude_names: Names of attributes to leave out. The given list is not modified.

        Returns:
            The comma-separated representation.
        """
        # Build a new list: appending to a shared default (or the caller's list) would accumulate names
        # across calls.
        exclude_names = ["p"] if exclude_names is None else [*exclude_names, "p"]
        return super().extra_repr(exclude_names)