# @Copyright: CEA-LIST/DIASI/SIALV/ (2023- )
# @Author: CEA-LIST/DIASI/SIALV/ <julien.denize@cea.fr>
# @License: CECILL-C
#
# Code partially based on Torchvision (BSD 3-Clause License), available at:
# https://github.com/pytorch/vision
from __future__ import annotations
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
import torch
from torch import nn
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchaug._utils import _log_api_usage_once
from torchaug.transforms._utils import _assert_list_of_modules
from ._transform import RandomApplyTransform, Transform
class Compose(Transform):
    """Compose several transforms together.

    The transforms are called one after another, in the given order; each one
    receives the output of the previous one.

    This transform does not support torchscript.
    Please, see the note below.

    Args:
        transforms: non-empty sequence of callables (or a ``torch.nn.ModuleList``) to compose.

    Raises:
        TypeError: if ``transforms`` is neither a sequence nor a ``torch.nn.ModuleList``.
        ValueError: if ``transforms`` is empty.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.PILToTensor(),
        >>>     transforms.ConvertImageDtype(torch.float),
        >>> ])

    .. note::
        In order to script the transformations, please use ``torch.nn.Sequential`` as below.

        >>> transforms = torch.nn.Sequential(
        >>>     transforms.CenterCrop(10),
        >>>     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        >>> )
        >>> scripted_transforms = torch.jit.script(transforms)

        Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``.
    """

    def __init__(self, transforms: Union[Sequence[Callable], nn.ModuleList]) -> None:
        super().__init__()
        # ``nn.ModuleList`` is not a ``Sequence`` subclass, so it has to be allowed
        # explicitly (same rule as ``RandomApply``).
        if not isinstance(transforms, (Sequence, nn.ModuleList)):
            raise TypeError("Argument transforms should be a sequence of callables or a `nn.ModuleList`")
        if not transforms:
            raise ValueError("Pass at least one transform")
        self.transforms = transforms

    def forward(self, *inputs: Any) -> Any:
        """Apply every transform in order and return the last output.

        With a single input, each transform is called with one argument and its
        return value is fed as-is to the next one. With several inputs, each
        transform's return value is unpacked into the arguments of the next one.

        Args:
            inputs: the value(s) to transform.

        Returns:
            The output of the last transform.
        """
        needs_unpacking = len(inputs) > 1
        # ``__init__`` guarantees at least one transform, so ``outputs`` is always bound.
        for transform in self.transforms:
            outputs = transform(*inputs)
            inputs = outputs if needs_unpacking else (outputs,)
        return outputs
class RandomApply(Transform):
    """Apply randomly a list of transformations with a given probability.

    Either all the transforms are applied in order, or none of them is.

    .. note::
        In order to script the transformation, please use ``torch.nn.ModuleList`` as input instead of list/tuple of
        transforms as shown below:

        >>> transforms = transforms.RandomApply(torch.nn.ModuleList([
        >>>     transforms.ColorJitter(),
        >>> ]), p=0.3)
        >>> scripted_transforms = torch.jit.script(transforms)

        Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor`` and do not require
        `lambda` functions or ``PIL.Image``.

    Args:
        transforms: list of transformations
        p: probability of applying the list of transforms.

    Raises:
        TypeError: if ``transforms`` is neither a sequence nor a ``torch.nn.ModuleList``.
        ValueError: if ``p`` is outside ``[0.0, 1.0]``.
    """

    def __init__(self, transforms: Union[Sequence[Callable], nn.ModuleList], p: float = 0.5) -> None:
        super().__init__()
        if not isinstance(transforms, (Sequence, nn.ModuleList)):
            raise TypeError("Argument transforms should be a sequence of callables or a `nn.ModuleList`")
        self.transforms = transforms
        if not (0.0 <= p <= 1.0):
            raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].")
        self.p = p

    def forward(self, *inputs: Any) -> Any:
        """Apply the transforms with probability ``p``, otherwise return the inputs unchanged.

        Args:
            inputs: the value(s) to transform.

        Returns:
            The output of the last transform when the transforms are applied.
            Otherwise the inputs themselves: the single input, or the tuple of
            inputs when several were given.
        """
        needs_unpacking = len(inputs) > 1
        # The random draw comes first so one sample is consumed on every call.
        # An empty ``transforms`` is treated as "nothing to apply": without this
        # guard the loop below would never bind ``outputs``.
        if torch.rand(1) >= self.p or not self.transforms:
            return inputs if needs_unpacking else inputs[0]
        for transform in self.transforms:
            outputs = transform(*inputs)
            inputs = outputs if needs_unpacking else (outputs,)
        return outputs

    def extra_repr(self) -> str:  # type: ignore[override]
        """Return one line per contained transform, used in the module's ``repr``."""
        return "\n".join(f" {t}" for t in self.transforms)
class RandomChoice(Transform):
    """Apply single transformation randomly picked from a list.

    This transform does not support torchscript.

    Args:
        transforms: non-empty list of transformations
        p: probability of each transform being picked.
            If ``p`` doesn't sum to 1, it is automatically normalized. If ``None``
            (default), all transforms have the same probability.

    Raises:
        TypeError: if ``transforms`` is not a sequence.
        ValueError: if ``transforms`` is empty, if ``p`` and ``transforms`` have
            different lengths, if ``p`` contains a negative value, or if ``p``
            does not contain any positive value.
    """

    def __init__(
        self,
        transforms: Sequence[Callable],
        p: Optional[List[float]] = None,
    ) -> None:
        if not isinstance(transforms, Sequence):
            raise TypeError("Argument transforms should be a sequence of callables")
        # Nothing can ever be picked from an empty list: fail at construction
        # rather than on the first call.
        if not transforms:
            raise ValueError("Pass at least one transform")

        if p is None:
            p = [1] * len(transforms)
        elif len(p) != len(transforms):
            raise ValueError(f"Length of p doesn't match the number of transforms: {len(p)} != {len(transforms)}")
        elif any(prob < 0 for prob in p):
            raise ValueError("`p` should only contain non-negative values.")

        total = sum(p)
        # Guard the normalization below against a division by zero.
        if total <= 0:
            raise ValueError("`p` should contain at least one positive value.")

        super().__init__()

        self.transforms = transforms
        # Normalized weights, summing to 1.
        self.p_choices = [prob / total for prob in p]

    def forward(self, *inputs: Any) -> Any:
        """Pick one transform according to ``p_choices`` and apply it.

        Args:
            inputs: the value(s) forwarded to the picked transform.

        Returns:
            Whatever the picked transform returns.
        """
        idx = int(torch.multinomial(torch.tensor(self.p_choices), 1))
        return self.transforms[idx](*inputs)
class RandomOrder(Transform):
    """Apply a list of transformations in a random order.

    This transform does not support torchscript.

    Args:
        transforms: list of transformations

    Raises:
        TypeError: if ``transforms`` is not a sequence.
    """

    def __init__(self, transforms: Sequence[Callable]) -> None:
        if not isinstance(transforms, Sequence):
            raise TypeError("Argument transforms should be a sequence of callables")
        super().__init__()
        self.transforms = transforms

    def forward(self, *inputs: Any) -> Any:
        """Apply every transform once, in an order drawn with ``torch.randperm``.

        Args:
            inputs: the value(s) to transform.

        Returns:
            The output of the last applied transform, or the inputs unchanged
            (the single input, or the tuple of inputs) when there is no
            transform to apply.
        """
        needs_unpacking = len(inputs) > 1
        # Without any transform the loop below would never bind ``outputs``;
        # hand the inputs back untouched instead.
        if not self.transforms:
            return inputs if needs_unpacking else inputs[0]
        # ``tolist()`` yields plain ints for indexing the sequence.
        for idx in torch.randperm(len(self.transforms)).tolist():
            outputs = self.transforms[idx](*inputs)
            inputs = outputs if needs_unpacking else (outputs,)
        return outputs