Source code for torchaug.transforms._container

# ==================================
# Copyright: CEA-LIST/DIASI/SIALV/
# Author : Torchaug Developers
# License: CECILL-C
# ==================================

# Code partially based on Torchvision (BSD 3-Clause License), available at:
#   https://github.com/pytorch/vision

from __future__ import annotations

from typing import Any, Callable, Dict, List, Optional, Sequence, Union

import torch
from torch import nn
from torch.utils._pytree import tree_flatten, tree_unflatten

from torchaug._utils import _log_api_usage_once
from torchaug.transforms._utils import _assert_list_of_modules

from ._transform import RandomApplyTransform, Transform


[docs] class Compose(Transform): """Composes several transforms together. This transform does not support torchscript. Please, see the note below. Args: transforms: list of transforms to compose. Example: >>> transforms.Compose([ >>> transforms.CenterCrop(10), >>> transforms.PILToTensor(), >>> transforms.ConvertImageDtype(torch.float), >>> ]) .. note:: In order to script the transformations, please use ``torch.nn.SequentialTransform`` as below. >>> transforms = torch.nn.SequentialTransform( >>> transforms.CenterCrop(10), >>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), >>> ) >>> scripted_transforms = torch.jit.script(transforms) Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``. """ def __init__(self, transforms: Sequence[Callable]) -> None: super().__init__() if not isinstance(transforms, Sequence): raise TypeError("Argument transforms should be a sequence of callables") elif not transforms: raise ValueError("Pass at least one transform") self.transforms = transforms
[docs] def forward(self, *inputs: Any) -> Any: needs_unpacking = len(inputs) > 1 for transform in self.transforms: outputs = transform(*inputs) inputs = outputs if needs_unpacking else (outputs,) return outputs
[docs] def extra_repr(self) -> str: # type: ignore[override] format_string = [] for t in self.transforms: format_string.append(f" {t}") return "\n".join(format_string)
[docs] class RandomApply(Transform): """Apply randomly a list of transformations with a given probability. .. note:: In order to script the transformation, please use ``torch.nn.ModuleList`` as input instead of list/tuple of transforms as shown below: >>> transforms = transforms.RandomApply(torch.nn.ModuleList([ >>> transforms.ColorJitter(), >>> ]), p=0.3) >>> scripted_transforms = torch.jit.script(transforms) Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require `lambda` functions or ``PIL.Image``. Args: transforms: list of transformations p: probability of applying the list of transforms. """ def __init__(self, transforms: Sequence[Union[Callable, nn.ModuleList]], p: float = 0.5) -> None: super().__init__() if not isinstance(transforms, (Sequence, nn.ModuleList)): raise TypeError("Argument transforms should be a sequence of callables or a `nn.ModuleList`") self.transforms = transforms if not (0.0 <= p <= 1.0): raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].") self.p = p
[docs] def forward(self, *inputs: Any) -> Any: needs_unpacking = len(inputs) > 1 if torch.rand(1) >= self.p: return inputs if needs_unpacking else inputs[0] for transform in self.transforms: outputs = transform(*inputs) inputs = outputs if needs_unpacking else (outputs,) return outputs
[docs] def extra_repr(self) -> str: # type: ignore[override] format_string = [] for t in self.transforms: format_string.append(f" {t}") return "\n".join(format_string)
[docs] class RandomChoice(Transform): """Apply single transformation randomly picked from a list. This transform does not support torchscript. Args: transforms: list of transformations p: probability of each transform being picked. If ``p`` doesn't sum to 1, it is automatically normalized. If ``None`` (default), all transforms have the same probability. """ def __init__( self, transforms: Sequence[Callable], p: Optional[List[float]] = None, ) -> None: if not isinstance(transforms, Sequence): raise TypeError("Argument transforms should be a sequence of callables") if p is None: p = [1] * len(transforms) elif len(p) != len(transforms): raise ValueError(f"Length of p doesn't match the number of transforms: {len(p)} != {len(transforms)}") super().__init__() self.transforms = transforms total = sum(p) self.p_choices = [prob / total for prob in p]
[docs] def forward(self, *inputs: Any) -> Any: idx = int(torch.multinomial(torch.tensor(self.p_choices), 1)) transform = self.transforms[idx] return transform(*inputs)
[docs] class RandomOrder(Transform): """Apply a list of transformations in a random order. This transform does not support torchscript. Args: transforms: list of transformations """ def __init__(self, transforms: Sequence[Callable]) -> None: if not isinstance(transforms, Sequence): raise TypeError("Argument transforms should be a sequence of callables") super().__init__() self.transforms = transforms
[docs] def forward(self, *inputs: Any) -> Any: needs_unpacking = len(inputs) > 1 for idx in torch.randperm(len(self.transforms)): transform = self.transforms[idx] outputs = transform(*inputs) inputs = outputs if needs_unpacking else (outputs,) return outputs
[docs] class SequentialTransform(Transform): """Sequentially apply a list of transforms. .. note:: By default this SequentialTransform makes all its transforms to be inplace, batched inplace, and batched transform with unlimited chunks. If you want to change this behavior, you can override the default parameters by passing them as arguments. For example, if you want to make the transforms non-inplace, you can do the following: >>> transforms = SequentialTransform( >>> [ >>> Normalize([0.5], [0.5]), >>> RandomColorJitter(), >>> ], >>> inplace=False, >>> batch_inplace=False, >>> ) Passing `batch_transform=False` will make the transforms non-batched and not inplace. Args: transforms: A list of transforms. transforms_attributes_override: Additional parameters to override the default parameters of the transforms if they exist. Useful to make transforms for batches. The list of parameters that can be overridden are: - inplace: whether the transform is inplace or not. - batch_inplace: whether the transform is inplace for batches or not. - batch_transform: whether the transform is batched or not. - num_chunks: number of chunks to split the input tensor. - permute_chunks: whether to permute the chunks or not. """ _receive_flatten_inputs = False def __init__( self, transforms: List[RandomApplyTransform], **transforms_attributes_override: Dict[str, Any], ) -> None: super().__init__() _log_api_usage_once(self) _assert_list_of_modules(transforms) base_override: Dict[str, Any] = { "inplace": True, "batch_inplace": True, "batch_transform": True, "num_chunks": -1, "permute_chunks": False, } base_override.update(transforms_attributes_override) if not base_override["batch_transform"]: base_override["num_chunks"] = 1 base_override["permute_chunks"] = False base_override["batch_inplace"] = False self.transforms_attributes_override = base_override self._prepare_transforms(transforms) self.transforms = nn.ModuleList(transforms) def _prepare_transform(self, transform: nn.Module): import inspect if isinstance(transform, RandomApplyTransform): if self.transforms_attributes_override: init_signature = inspect.signature(transform.__init__) # type: ignore[misc] parameters = init_signature.parameters for key, value in self.transforms_attributes_override.items(): has_key = key in parameters if has_key: setattr(transform, key, value) transform._receive_flatten_inputs = True def _prepare_transforms(self, transforms: List[nn.Module]): for transform in transforms: self._prepare_transform(transform) self._prepare_transforms(list(transform.modules())[1:])
[docs] def forward(self, *inputs: Any) -> Any: if not self._receive_flatten_inputs: inputs = inputs if len(inputs) > 1 else inputs[0] flat_inputs, spec = tree_flatten(inputs) else: flat_inputs = list(inputs) for transform in self.transforms: flat_inputs = transform(*flat_inputs) if not self._receive_flatten_inputs: return tree_unflatten(flat_inputs, spec) return flat_inputs
[docs] def extra_repr(self) -> str: # type: ignore[override] format_string = [] for t in self.transforms: format_string.append(f" {t}") return f"transforms_attributes_override={self.transforms_attributes_override},\ntransforms=\n" + "\n".join( format_string )