# @Copyright: CEA-LIST/DIASI/SIALV/ (2023- )
# @Author: CEA-LIST/DIASI/SIALV/ <julien.denize@cea.fr>
# @License: CECILL-C
#
# Code partially based on Torchvision (BSD 3-Clause License), available at:
# https://github.com/pytorch/vision
import math
import numbers
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from torch.nn.functional import one_hot
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision.transforms.v2._utils import _parse_labels_getter, has_any, query_chw

from torchaug import ta_tensors

from . import functional as F
from ._transform import RandomApplyTransform, Transform
from ._utils import is_pure_tensor, query_size
class RandomErasing(RandomApplyTransform):
    """Randomly select a rectangle region in the input image or video and erase its pixels.

    'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/abs/1708.04896

    Args:
        p: probability that the random erasing operation will be performed.
        scale: range of proportion of erased area against input image.
        ratio: range of aspect ratio of erased area.
        value: erasing value. If a single number, it is used to erase all pixels.
            If a tuple or list, it must have either a single value or one value per input
            channel (e.g. 3 values to erase R, G, B channels respectively).
            If the str 'random', each erased pixel is filled with values drawn from a
            standard normal distribution.
        inplace: forwarded as ``inplace`` to the erase functional.
        batch_inplace: whether to apply the batch transform in-place.
            Does not prevent functionals from making copies but can reduce time and memory consumption.
        num_chunks: number of chunks to split the batched input into.
        permute_chunks: whether to permute the chunks.
        batch_transform: whether to apply the transform in batch mode.

    Raises:
        TypeError: if ``value`` is not a number, str, tuple or list, or if ``scale`` / ``ratio``
            are not sequences.
        ValueError: if ``value`` is a str other than ``'random'`` or if ``scale`` is not within [0, 1].
    """

    def __init__(
        self,
        p: float = 0.5,
        scale: Tuple[float, float] = (0.02, 0.33),
        ratio: Tuple[float, float] = (0.3, 3.3),
        value: Union[float, str, List[float], Tuple[float, ...]] = 0.0,
        inplace: bool = False,
        batch_inplace: bool = False,
        num_chunks: int = 1,
        permute_chunks: bool = False,
        batch_transform: bool = False,
    ):
        super().__init__(
            p=p,
            batch_inplace=batch_inplace,
            num_chunks=num_chunks,
            permute_chunks=permute_chunks,
            batch_transform=batch_transform,
        )
        if not isinstance(value, (numbers.Number, str, tuple, list)):
            raise TypeError("Argument value should be either a number or str or a sequence")
        if isinstance(value, str) and value != "random":
            raise ValueError("If value is str, it should be 'random'")
        if not isinstance(scale, (tuple, list)):
            raise TypeError("Scale should be a sequence")
        if not isinstance(ratio, (tuple, list)):
            raise TypeError("Ratio should be a sequence")
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            warnings.warn("Scale and ratio should be of kind (min, max)")
        if scale[0] < 0 or scale[1] > 1:
            raise ValueError("Scale should be between 0 and 1")

        self.scale = scale
        self.ratio = ratio

        # Normalise ``value`` to either ``None`` ('random' fill) or a list of per-channel floats.
        self.value: Optional[List[float]] = None
        if isinstance(value, numbers.Number):
            # Covers every scalar accepted by the validation above, including numbers that are not
            # int/float subclasses (e.g. NumPy integer scalars); previously those were stored as-is and
            # crashed later on ``len(self.value)``.
            self.value = [float(value)]
        elif isinstance(value, (list, tuple)):
            self.value = [float(v) for v in value]

        self.inplace = inplace
        # Aspect ratios are sampled uniformly in log-space, so pre-compute the log of the bounds once.
        self._log_ratio = torch.log(torch.tensor(self.ratio))

    def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
        """Dispatch to ``functional``, warning first when given boxes or masks."""
        if isinstance(
            inpt,
            (
                ta_tensors.BoundingBoxes,
                ta_tensors.BatchBoundingBoxes,
                ta_tensors.Mask,
                ta_tensors.BatchMasks,
            ),
        ):
            warnings.warn(
                f"{type(self).__name__}() is currently passing through inputs of type "
                f"ta_tensors.{type(inpt).__name__}. This will likely change in the future."
            )
        return super()._call_kernel(functional, inpt, *args, **kwargs)

    def _get_params(
        self,
        flat_inputs: List[Any],
        num_chunks: int,
        chunks_indices: Tuple[torch.Tensor],
    ) -> List[Dict[str, Any]]:
        """Sample one erasing box and fill value per chunk.

        Args:
            flat_inputs: flattened inputs; channel count and spatial size come from ``query_chw``.
            num_chunks: number of parameter sets to sample.
            chunks_indices: unused by this transform.

        Returns:
            One dict per chunk with keys ``i``/``j`` (row/column offsets of the box), ``h``/``w``
            (box size) and ``v``: a ``(C, h, w)`` standard-normal tensor for 'random', a
            ``(len(value), 1, 1)`` tensor otherwise, or ``None`` when no box fitting strictly inside
            the image was found in 10 attempts.

        Raises:
            ValueError: if ``value`` has neither 1 nor ``C`` entries.
        """
        img_c, img_h, img_w = query_chw(flat_inputs)

        if self.value is not None and len(self.value) not in (1, img_c):
            raise ValueError(
                f"If value is a sequence, it should have either a single value or {img_c} (number of inpt channels)"
            )

        area = img_h * img_w
        log_ratio = self._log_ratio

        params = []
        for _chunk in range(num_chunks):
            # Rejection sampling: try up to 10 boxes and keep the first one strictly inside the image.
            for _attempt in range(10):
                erase_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()
                aspect_ratio = torch.exp(
                    torch.empty(1).uniform_(
                        log_ratio[0],  # type: ignore[arg-type]
                        log_ratio[1],  # type: ignore[arg-type]
                    )
                ).item()

                h = int(round(math.sqrt(erase_area * aspect_ratio)))
                w = int(round(math.sqrt(erase_area / aspect_ratio)))
                if not (h < img_h and w < img_w):
                    continue

                if self.value is None:
                    v = torch.empty([img_c, h, w], dtype=torch.float32).normal_()
                else:
                    v = torch.tensor(self.value)[:, None, None]

                i = torch.randint(0, img_h - h + 1, size=(1,)).item()
                j = torch.randint(0, img_w - w + 1, size=(1,)).item()
                break
            else:
                # No valid box: ``v=None`` tells ``_transform`` to leave the input untouched.
                i, j, h, w, v = 0, 0, img_h, img_w, None

            params.append({"i": i, "j": j, "h": h, "w": w, "v": v})
        return params

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        """Erase the sampled box from ``inpt``, or return it unchanged if no box was sampled."""
        if params["v"] is not None:
            inpt = self._call_kernel(
                F.erase,
                inpt,
                **params,
                inplace=self.inplace,
            )

        return inpt
class _BaseMixUpCutMix(Transform):
    """Shared input handling for :class:`MixUp` and :class:`CutMix`.

    Subclasses provide ``_get_params`` and ``_transform``. This base class validates the inputs,
    extracts the labels, and one-hot encodes and mixes them.

    Args:
        alpha: parameter of the ``Beta(alpha, alpha)`` distribution the mixing coefficient is drawn from.
        num_classes: number of classes, used for one-hot encoding.
        labels_getter: how to locate the labels in the input (parsed by ``_parse_labels_getter``).
    """

    def __init__(self, *, alpha: float = 1.0, num_classes: int, labels_getter="default") -> None:
        super().__init__(batch_transform=True)

        self.alpha = float(alpha)
        self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))

        self.num_classes = num_classes

        self._labels_getter = _parse_labels_getter(labels_getter)

    def forward(self, *inputs):
        """Mix a batch of images/videos together with its labels.

        Raises:
            ValueError: if non-batched images/videos, or any boxes/masks, are present; if the labels
                are not a 1D tensor; or if the labels are not one of the inputs.
        """
        inputs = inputs if len(inputs) > 1 else inputs[0]
        flat_inputs, spec = tree_flatten(inputs)
        needs_transform_list = self._needs_transform_list(flat_inputs)

        if has_any(
            flat_inputs,
            ta_tensors.Image,
            ta_tensors.Video,
            ta_tensors.BoundingBoxes,
            ta_tensors.BatchBoundingBoxes,
            ta_tensors.Mask,
            ta_tensors.BatchMasks,
        ):
            raise ValueError(f"{type(self).__name__}() supports only batch of images or videos.")

        labels = self._labels_getter(inputs)
        if not isinstance(labels, torch.Tensor):
            raise ValueError(f"The labels must be a tensor, but got {type(labels)} instead.")
        elif labels.ndim != 1:
            raise ValueError(
                f"labels tensor should be of shape (batch_size,) " f"but got shape {labels.shape} instead."
            )

        # ``_transform`` recognises the labels by identity, so they must be one of the flattened inputs.
        # A bare ``next(...)`` would leak an unhelpful StopIteration here, so fail explicitly instead.
        labels_idx = next((idx for idx, inpt in enumerate(flat_inputs) if inpt is labels), None)
        if labels_idx is None:
            raise ValueError(
                f"{type(self).__name__}() could not find the labels among the inputs: "
                f"the labels getter must return one of the input tensors."
            )

        params = {
            "labels": labels,
            "batch_size": labels.shape[0],
            **self._get_params(
                [inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform],
                1,
                [torch.arange(labels.shape[0], device=labels.device)],
            )[0],
        }

        # By default, the labels will be False inside needs_transform_list, since they are a torch.Tensor coming
        # after an image or video. However, we need to handle them in _transform, so we make sure to set them to True
        needs_transform_list[labels_idx] = True
        flat_outputs = [
            self._transform(inpt, params) if needs_transform else inpt
            for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list)
        ]

        return tree_unflatten(flat_outputs, spec)

    def _check_image_or_video(self, inpt: torch.Tensor, *, batch_size: int):
        """Check that ``inpt`` is batched (5 dims for BatchVideos, 4 otherwise) with ``batch_size`` samples."""
        expected_num_dims = 5 if isinstance(inpt, ta_tensors.BatchVideos) else 4
        if inpt.ndim != expected_num_dims:
            raise ValueError(
                f"Expected a batched input with {expected_num_dims} dims, but got {inpt.ndim} dimensions instead."
            )
        if inpt.shape[0] != batch_size:
            raise ValueError(
                f"The batch size of the image or video does not match the batch size of the labels: "
                f"{inpt.shape[0]} != {batch_size}."
            )

    def _mixup_label(self, label: torch.Tensor, *, lam: float) -> torch.Tensor:
        """One-hot encode ``label`` and return ``lam * label + (1 - lam) * label.roll(1, 0)`` as floats."""
        label = one_hot(label, num_classes=self.num_classes)
        if not label.dtype.is_floating_point:
            label = label.float()
        return label.roll(1, 0).mul_(1.0 - lam).add_(label.mul(lam))
class MixUp(_BaseMixUpCutMix):
    """Apply MixUp to the provided batch of images and labels.

    Paper: `mixup: Beyond Empirical Risk Minimization <https://arxiv.org/abs/1710.09412>`_.

    .. note::
        This transform is meant to be used on **batches** of samples, not
        individual images.
        The sample pairing is deterministic and done by matching consecutive
        samples in the batch, so the batch needs to be shuffled (this is an
        implementation detail, not a guaranteed convention).

    In the input, the labels are expected to be a tensor of shape ``(batch_size,)``. They will be transformed
    into a tensor of shape ``(batch_size, num_classes)``.

    Args:
        alpha: hyperparameter of the Beta distribution used for mixup.
        num_classes: number of classes in the batch. Used for one-hot-encoding.
        labels_getter: indicates how to identify the labels in the input.
            By default, this will pick the second parameter as the labels if it's a tensor. This covers the most
            common scenario where this transform is called as ``MixUp()(imgs_batch, labels_batch)``.
            It can also be a callable that takes the same input as the transform, and returns the labels.
    """

    def _get_params(
        self, flat_inputs: List[Any], num_chunks: int, chunks_indices: Tuple[torch.Tensor]
    ) -> List[Dict[str, Any]]:
        """Sample a single mixing coefficient ``lam`` from ``Beta(alpha, alpha)`` for the whole batch."""
        return [{"lam": float(self._dist.sample(()))}]  # type: ignore[arg-type]

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        """Mix labels, batched images/videos and pure tensors; pass anything else through unchanged."""
        lam = params["lam"]

        if inpt is params["labels"]:
            return self._mixup_label(inpt, lam=lam)
        elif isinstance(
            inpt,
            (
                ta_tensors.BatchImages,
                ta_tensors.BatchVideos,
            ),
        ) or is_pure_tensor(inpt):
            self._check_image_or_video(inpt, batch_size=params["batch_size"])

            # Blend every sample with its predecessor in the batch: lam * x + (1 - lam) * x.roll(1, 0).
            # NOTE(review): the in-place ``mul_`` by a float presumably requires a floating-point input.
            output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam))

            if isinstance(
                inpt,
                (
                    ta_tensors.BatchImages,
                    ta_tensors.BatchVideos,
                ),
            ):
                output = ta_tensors.wrap(output, like=inpt)

            return output
        else:
            return inpt
class CutMix(_BaseMixUpCutMix):
    """Apply CutMix to the provided batch of images and labels.

    Paper: `CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features
    <https://arxiv.org/abs/1905.04899>`_.

    .. note::
        This transform is meant to be used on **batches** of samples, not
        individual images.
        The sample pairing is deterministic and done by matching consecutive
        samples in the batch, so the batch needs to be shuffled (this is an
        implementation detail, not a guaranteed convention).

    In the input, the labels are expected to be a tensor of shape ``(batch_size,)``. They will be transformed
    into a tensor of shape ``(batch_size, num_classes)``.

    Args:
        alpha: hyperparameter of the Beta distribution used for cutmix.
        num_classes: number of classes in the batch. Used for one-hot-encoding.
        labels_getter: indicates how to identify the labels in the input.
            By default, this will pick the second parameter as the labels if it's a tensor. This covers the most
            common scenario where this transform is called as ``CutMix()(imgs_batch, labels_batch)``.
            It can also be a callable that takes the same input as the transform, and returns the labels.
    """

    def _get_params(
        self, flat_inputs: List[Any], num_chunks: int, chunks_indices: Tuple[torch.Tensor]
    ) -> List[Dict[str, Any]]:
        """Sample one box ``(x1, y1, x2, y2)`` and its label weight ``lam_adjusted`` for the whole batch."""
        lam = float(self._dist.sample(()))  # type: ignore[arg-type]

        H, W = query_size(flat_inputs)

        # Box centre, uniformly sampled over the image.
        r_x = torch.randint(W, size=(1,))
        r_y = torch.randint(H, size=(1,))

        # Half-extents chosen so that the unclipped box covers a (1 - lam) fraction of the image.
        r = 0.5 * math.sqrt(1.0 - lam)
        r_w_half = int(r * W)
        r_h_half = int(r * H)

        x1 = int(torch.clamp(r_x - r_w_half, min=0))
        y1 = int(torch.clamp(r_y - r_h_half, min=0))
        x2 = int(torch.clamp(r_x + r_w_half, max=W))
        y2 = int(torch.clamp(r_y + r_h_half, max=H))
        box = (x1, y1, x2, y2)

        # Clipping at the borders can shrink the box, so derive the label weight from the actual pasted area.
        lam_adjusted = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))

        return [{"box": box, "lam_adjusted": lam_adjusted}]

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        """Paste the box from the previous sample into each batched image/video and mix the labels."""
        if inpt is params["labels"]:
            return self._mixup_label(inpt, lam=params["lam_adjusted"])
        elif isinstance(
            inpt,
            (
                ta_tensors.BatchImages,
                ta_tensors.BatchVideos,
            ),
        ) or is_pure_tensor(inpt):
            self._check_image_or_video(inpt, batch_size=params["batch_size"])

            x1, y1, x2, y2 = params["box"]
            rolled = inpt.roll(1, 0)
            output = inpt.clone()
            output[..., y1:y2, x1:x2] = rolled[..., y1:y2, x1:x2]

            # Only batched types reach this branch (``forward`` rejects plain Image/Video), matching MixUp.
            if isinstance(
                inpt,
                (
                    ta_tensors.BatchImages,
                    ta_tensors.BatchVideos,
                ),
            ):
                output = ta_tensors.wrap(output, like=inpt)

            return output
        else:
            return inpt