# Source code for torchaug.transforms.transforms

from __future__ import annotations

import numbers
from abc import ABC, abstractmethod
from typing import Sequence

import torch
import torchvision.transforms.functional as F_tv
from torch import Tensor, nn
from torchvision.transforms._functional_tensor import _assert_image_tensor
from torchvision.transforms.transforms import _setup_size

import torchaug.transforms.functional as F
from torchaug.batch_transforms._utils import \
    _assert_video_or_batch_videos_tensor
from torchaug.transforms._utils import (_assert_module_or_list_of_modules,
                                        _assert_tensor, _assert_video_tensor,
                                        _check_input)
from torchaug.utils import _log_api_usage_once


class RandomTransform(nn.Module, ABC):
    """Abstract base class shared by all random transforms.

    Subclasses implement :meth:`apply_transform`; :meth:`forward` decides from
    a random draw whether that method is called for a given input.

    Args:
        p: Probability to apply the transform.

    Raises:
        ValueError: If ``p`` lies outside the closed interval ``[0, 1]``.
    """

    def __init__(self, p: float) -> None:
        super().__init__()

        probability_is_valid = 0 <= p <= 1
        if not probability_is_valid:
            raise ValueError(
                f"p should be superior to 0 (included) and inferior to 1 (included). Got {p}."
            )

        self.p = p

    @abstractmethod
    def apply_transform(self, img: Tensor):
        """Perform the transformation on the image.

        .. note::
            Should be overridden by subclasses.

        Args:
            img: Image to transform.

        Return:
            The transformed image.
        """
        ...

    def forward(self, img: Tensor):
        """Call :meth:`apply_transform` with probability :attr:`~p`.

        Args:
            img: Image to randomly transform.

        Return:
            The transformed image, or ``img`` itself when the transform is
            not applied.
        """
        # The two degenerate probabilities are resolved without drawing a
        # random number.
        if self.p == 0.0:
            return img
        if self.p == 1.0:
            return self.apply_transform(img)

        should_apply = torch.rand(1).item() < self.p
        return self.apply_transform(img) if should_apply else img
class VideoBase(ABC):
    """Abstract base class shared by all video transforms.

    Stores the dimension order of the videos and validates every change to it.

    Args:
        video_format: Dimension order of the video. Can be ``TCHW`` or ``CTHW``.

    Raises:
        ValueError: If ``video_format`` is neither ``TCHW`` nor ``CTHW``.
    """

    def __init__(self, video_format: str) -> None:
        super().__init__()
        self.check_format(video_format)
        self._video_format = video_format

    @staticmethod
    def check_format(format: str) -> None:
        """Check that the format is either ``TCHW`` or ``CTHW``.

        Args:
            format: Format to check.

        Raises:
            ValueError: If ``format`` is not one of the two supported values.
        """
        if format in ("CTHW", "TCHW"):
            return
        raise ValueError(
            f"video_format should be either 'CTHW' or 'TCHW'. Got {format}."
        )

    @property
    def video_format(self):
        """Dimension order of the video. Can be ``TCHW`` or ``CTHW``."""
        return self._video_format

    @video_format.setter
    def video_format(self, format: str) -> None:
        # Validate first so an invalid value never replaces the stored one.
        self.check_format(format)
        self._video_format = format

    @property
    def time_before_channel(self) -> bool:
        """Whether :attr:`~video_format` has the time dimension before channel."""
        current = self.video_format
        if current == "TCHW":
            return True
        if current == "CTHW":
            return False
        # Only reachable when ``_video_format`` was written without going
        # through the validating setter.
        raise ValueError("Attribute _video_format was wrongly changed by user.")
class Wrapper(nn.Module):
    """Wrap transforms to handle tensor data.

    .. note::
        Iterates through transforms and their submodules.

        - If ``inplace`` attribute is found, it is set to ``True``,
          ``inplace`` is handled at the wrapper level.

    .. note::
        If a transform makes a copy, the resulting tensor will not share the
        same underlying storage even if ``inplace`` is set to ``True``.

    Args:
        transforms: A transform module or a list of transform modules.
        inplace: Whether to perform the transforms inplace. If a transform
            makes a copy, the resulting tensor will not share the same
            underlying storage.
    """

    def __init__(
        self, transforms: list[nn.Module] | nn.Module, inplace: bool = False
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)

        _assert_module_or_list_of_modules(transforms)
        if isinstance(transforms, nn.Module):
            transforms = [transforms]

        # Looked up on ``self`` so a subclass that overrides the hook is
        # honoured.
        self._prepare_transforms(transforms)

        self.transforms = nn.ModuleList(transforms)
        self.inplace = inplace

    @staticmethod
    def _prepare_transform(transform: nn.Module):
        """Force ``inplace=True`` on a module that exposes that attribute."""
        if hasattr(transform, "inplace"):
            transform.inplace = True

    @staticmethod
    def _prepare_transforms(transforms: list[nn.Module]):
        """Prepare every transform together with all of its submodules.

        ``nn.Module.modules()`` already yields the module itself followed by
        every descendant, so a single pass is enough; recursing on that list
        again would only revisit nested modules repeatedly.
        """
        for transform in transforms:
            for module in transform.modules():
                Wrapper._prepare_transform(module)

    def forward(self, tensor: torch.Tensor) -> Tensor:
        """Apply :attr:`~transforms` on the tensor.

        If :attr:`~inplace` is ``False``, the input is cloned first so the
        caller's tensor is not handed to the transforms.

        Args:
            tensor: The tensor to transform.

        Returns:
            The transformed tensor, made contiguous.
        """
        _assert_tensor(tensor)

        output = tensor if self.inplace else tensor.clone()

        # Each transform receives a contiguous tensor.
        for transform in self.transforms:
            output = transform(output.contiguous())

        return output.contiguous()

    def __repr__(self):
        transforms_repr = str(self.transforms).replace("\n", "\n ")
        return (
            f"{self.__class__.__name__}(\n"
            f" inplace={self.inplace},\n"
            f" transforms={transforms_repr}\n)"
        )
class Div255(nn.Module):
    """Divide a tensor by 255.

    Args:
        inplace: Bool to make this operation in-place.
    """

    def __init__(self, inplace: bool = False) -> None:
        super().__init__()
        _log_api_usage_once(self)

        self.inplace = inplace

    def forward(self, tensor: Tensor) -> Tensor:
        """Divide tensor by 255.

        Args:
            tensor: The tensor to divide.

        Returns:
            Divided tensor.
        """
        return F.div_255(tensor, inplace=self.inplace)

    def __repr__(self):
        # Use the runtime class so subclasses report their own name, like the
        # other transforms in this module.
        return f"{self.__class__.__name__}(inplace={self.inplace})"
class ImageWrapper(Wrapper):
    """Wrap transforms to handle image data.

    .. note::
        Iterates through transforms and their submodules.

        - If ``inplace`` attribute is found, it is set to ``True``,
          ``inplace`` is handled at the wrapper level.

    .. note::
        If a transform makes a copy, the resulting tensor will not share the
        same underlying storage even if ``inplace`` is set to ``True``.

    Args:
        transforms: A transform module or a sequence of transform modules.
        inplace: Whether to perform the transforms inplace.
    """

    def __init__(
        self, transforms: Sequence[nn.Module] | nn.Module, inplace: bool = False
    ) -> None:
        super().__init__(transforms=transforms, inplace=inplace)
        _log_api_usage_once(self)

    def forward(self, img: torch.Tensor) -> Tensor:
        """Apply :attr:`~transforms` on the image.

        Checks the input with ``_assert_image_tensor`` and then defers to
        :meth:`Wrapper.forward`.

        Args:
            img: The image to transform.

        Returns:
            The transformed image.
        """
        _assert_image_tensor(img)
        return super().forward(img)
class MixUp(nn.Module):
    """Mix input tensor with linear interpolation drawn according a Beta law.

    The shape of the tensors is expected to be [B, ...] with ... any number of
    dimensions. The tensor should be float.

    .. note::
        The tensor is rolled according its first dimension and mixed with one
        drawn interpolation parameter for the whole tensor.

    Args:
        alpha: Parameter for the Beta law (used for both concentrations).
        inplace: Whether to perform the operation inplace.
    """

    def __init__(self, alpha: float, inplace: bool = False) -> None:
        super().__init__()
        _log_api_usage_once(self)

        self.alpha = alpha
        self.inplace = inplace

        # Cast to float: an integer ``alpha`` would otherwise produce an
        # integer concentration tensor, which the Beta sampler cannot draw
        # from.
        concentration = torch.tensor([float(alpha)])
        self.mix_sampler = torch.distributions.Beta(concentration, concentration)

    def _get_params(self) -> float:
        """Draw the mixing coefficient.

        Returns:
            The mixing coefficient.
        """
        return float(self.mix_sampler.sample(()))

    def forward(
        self, tensor: Tensor, labels: Tensor | None = None
    ) -> tuple[Tensor, Tensor | None, float]:
        """Mix the input tensor and labels.

        When :attr:`~inplace` is ``False`` the inputs are cloned before being
        mixed.

        Args:
            tensor: The tensor to mix.
            labels: If not None, the labels to mix.

        Returns:
            Tuple:
            - mixed tensor.
            - mixed labels or None.
            - mixing coefficient.
        """
        lam = self._get_params()

        if not self.inplace:
            tensor = tensor.clone()
        # The last positional argument is ``True`` in every call; presumably
        # it is mixup's in-place flag, which is safe here because ``tensor``
        # is either a clone or explicitly allowed to be modified. TODO confirm
        # against ``F.mixup``.
        mixed_tensor = F.mixup(tensor, tensor.roll(1, 0), lam, True)

        if labels is None:
            return mixed_tensor, None, lam

        if not self.inplace:
            labels = labels.clone()
        mixed_labels = F.mixup(labels, labels.roll(1, 0), lam, True)

        return mixed_tensor, mixed_labels, lam

    def __repr__(self):
        # Use the runtime class so subclasses report their own name.
        return (
            f"{self.__class__.__name__}(alpha={self.alpha}, inplace={self.inplace})"
        )
class Mul255(nn.Module):
    """Multiply a tensor by 255.

    Args:
        inplace: Bool to make this operation in-place.
    """

    def __init__(self, inplace: bool = False) -> None:
        super().__init__()
        _log_api_usage_once(self)

        self.inplace = inplace

    def forward(self, tensor: Tensor) -> Tensor:
        """Multiply tensor by 255.

        Args:
            tensor: The tensor to multiply.

        Returns:
            Multiplied tensor.
        """
        return F.mul_255(tensor, inplace=self.inplace)

    def __repr__(self):
        # Use the runtime class so subclasses report their own name, like the
        # other transforms in this module.
        return f"{self.__class__.__name__}(inplace={self.inplace})"
class Normalize(nn.Module):
    """Normalize a tensor image with mean and standard deviation.

    Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for
    ``n`` channels, this transform will normalize each channel of the input
    ``torch.Tensor`` i.e.,
    ``output[channel] = (input[channel] - mean[channel]) / std[channel]``

    Args:
        mean: Sequence of means for each channel.
        std: Sequence of standard deviations for each channel.
        cast_dtype: If not None, scale and cast input to dtype. Expected to be
            a float dtype.
        inplace: Bool to make this operation in-place.
        value_check: Bool to perform tensor value check. Might cause slow down
            on some devices because of synchronization.
    """

    def __init__(
        self,
        mean: Sequence[float] | float,
        std: Sequence[float] | float,
        cast_dtype: torch.dtype | None = None,
        inplace: bool = False,
        value_check: bool = False,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)

        def _as_channel_stat(value) -> Tensor:
            # Scalars and 1-D sequences are reshaped to (-1, 1, 1); tensors
            # with more dimensions are kept as given.
            stat = torch.as_tensor(value)
            return stat.view(-1, 1, 1) if stat.ndim <= 1 else stat

        # Registered as buffers so they follow the module across devices and
        # appear in its state dict.
        self.register_buffer("mean", _as_channel_stat(mean))
        self.register_buffer("std", _as_channel_stat(std))

        self.cast_dtype = cast_dtype
        self.inplace = inplace
        self.value_check = value_check

    def forward(self, tensor: Tensor) -> Tensor:
        """Normalize tensor.

        Args:
            tensor: The tensor to normalize.

        Returns:
            Normalized tensor.
        """
        options = dict(
            mean=self.mean,
            std=self.std,
            cast_dtype=self.cast_dtype,
            inplace=self.inplace,
            value_check=self.value_check,
        )
        return F.normalize(tensor, **options)

    def __repr__(self) -> str:
        fields = (
            f"mean={self.mean.tolist()}, "
            f"std={self.std.tolist()}, "
            f"cast_dtype={self.cast_dtype}, "
            f"inplace={self.inplace}, "
            f"value_check={self.value_check}"
        )
        return f"{self.__class__.__name__}({fields})"
class RandomApply(RandomTransform):
    """Apply randomly a list of transformations with a given probability.

    The transformations are applied together, in order, or not at all.

    Args:
        transforms: A transform module or a sequence of transform modules.
        p: Probability to apply the transformations.
    """

    def __init__(
        self, transforms: Sequence[nn.Module] | nn.Module, p: float = 0.5
    ) -> None:
        super().__init__(p=p)
        _log_api_usage_once(self)

        _assert_module_or_list_of_modules(transforms)

        # A lone module is treated as a one-element list.
        modules = [transforms] if isinstance(transforms, nn.Module) else transforms
        self.transforms = nn.ModuleList(modules)

    def apply_transform(self, img: Tensor) -> Tensor:
        """Run every transform in :attr:`~transforms`, in order.

        Args:
            img: Image to transform.

        Returns:
            Transformed image.
        """
        result = img
        for module in self.transforms:
            result = module(result)
        return result

    def __repr__(self) -> str:
        transforms_repr = str(self.transforms).replace("\n", "\n ")
        header = f"{self.__class__.__name__}("
        return f"{header}\n p={self.p},\n transforms={transforms_repr}\n)"
class RandomColorJitter(RandomTransform):
    """Randomly change the brightness, contrast, saturation and hue to images.

    The images is expected to have [..., 1 or 3, H, W] shape, where ... means
    an arbitrary number of leading dimensions.

    Args:
        brightness: How much to jitter brightness. brightness_factor is chosen
            uniformly from [max(0, 1 - brightness), 1 + brightness] or the
            given [min, max]. Should be non negative numbers.
        contrast: How much to jitter contrast. contrast_factor is chosen
            uniformly from [max(0, 1 - contrast), 1 + contrast] or the given
            [min, max]. Should be non-negative numbers.
        saturation: How much to jitter saturation. saturation_factor is chosen
            uniformly from [max(0, 1 - saturation), 1 + saturation] or the
            given [min, max]. Should be non negative numbers.
        hue: How much to jitter hue. hue_factor is chosen uniformly from
            [-hue, hue] or the given [min, max]. Should have 0<= hue <= 0.5 or
            -0.5 <= min <= max <= 0.5. To jitter hue, the pixel values of the
            input image has to be non-negative for conversion to HSV space;
            thus it does not work if you normalize your image to an interval
            with negative values, or use an interpolation that generates
            negative values before using this function.
        p: Probability to apply color jitter. Defaults to ``0.0``.
    """

    def __init__(
        self,
        brightness: float | tuple[float, float] = 0,
        contrast: float | tuple[float, float] = 0,
        saturation: float | tuple[float, float] = 0,
        hue: float | tuple[float, float] = 0,
        p: float = 0.0,
    ):
        # NOTE(review): with the default ``p=0.0`` the base ``forward`` returns
        # the input unchanged; the other random transforms in this module
        # default to 0.5 -- confirm the default is intentional.
        super().__init__(p=p)
        _log_api_usage_once(self)

        self.brightness = _check_input(brightness, "brightness")
        self.contrast = _check_input(contrast, "contrast")
        self.saturation = _check_input(saturation, "saturation")
        self.hue = _check_input(
            hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False
        )

    @staticmethod
    def get_params(
        brightness: list[float] | None,
        contrast: list[float] | None,
        saturation: list[float] | None,
        hue: list[float] | None,
    ) -> tuple[Tensor, float | None, float | None, float | None, float | None]:
        """Get the parameters for the randomized transform to be applied on image.

        Args:
            brightness: The range from which the brightness_factor is chosen
                uniformly. Pass None to turn off the transformation.
            contrast: The range from which the contrast_factor is chosen
                uniformly. Pass None to turn off the transformation.
            saturation: The range from which the saturation_factor is chosen
                uniformly. Pass None to turn off the transformation.
            hue: The range from which the hue_factor is chosen uniformly.
                Pass None to turn off the transformation.

        Returns:
            The parameters used to apply the randomized transform along with
            their random order.
        """

        def _draw(bounds: list[float] | None) -> float | None:
            # A disabled adjustment draws nothing from the random generator.
            if bounds is None:
                return None
            low, high = bounds[0], bounds[1]
            return float(torch.empty(1).uniform_(low, high))

        # The permutation is drawn first, then the factors in a fixed order,
        # so the random stream is consumed deterministically.
        order = torch.randperm(4)
        brightness_factor = _draw(brightness)
        contrast_factor = _draw(contrast)
        saturation_factor = _draw(saturation)
        hue_factor = _draw(hue)

        return order, brightness_factor, contrast_factor, saturation_factor, hue_factor

    def apply_transform(self, img: Tensor) -> Tensor:
        """Color jitter the image.

        Args:
            img: Input image.

        Returns:
            Color jittered image.
        """
        order, b_factor, c_factor, s_factor, h_factor = self.get_params(
            self.brightness, self.contrast, self.saturation, self.hue
        )

        # Index in this table == operation id found in ``order``.
        adjustments = (
            (b_factor, F_tv.adjust_brightness),
            (c_factor, F_tv.adjust_contrast),
            (s_factor, F_tv.adjust_saturation),
            (h_factor, F.adjust_hue),
        )

        for op_id in order.tolist():
            factor, adjust = adjustments[op_id]
            if factor is not None:
                img = adjust(img, factor)

        return img

    def __repr__(self) -> str:
        fields = ", ".join(
            (
                f"brightness={self.brightness}",
                f"contrast={self.contrast}",
                f"saturation={self.saturation}",
                f"hue={self.hue}",
                f"p={self.p}",
            )
        )
        return f"{self.__class__.__name__}({fields})"
class RandomGaussianBlur(RandomTransform):
    """Blurs image with randomly chosen Gaussian blur.

    The image is expected to have the shape [..., C, H, W], where ... means an
    arbitrary number of leading dimensions.

    Args:
        kernel_size: Size of the Gaussian kernel.
        sigma: Standard deviation to be used for creating kernel to perform
            blurring. If float, sigma is fixed. If it is tuple of float
            (min, max), sigma is chosen uniformly at random to lie in the
            given range.
        p: Probability to apply the blur.
        value_check: Bool to perform tensor value check. Might cause slow down
            on some devices because of synchronization.

    Raises:
        ValueError: If a kernel size is not odd and positive, or if ``sigma``
            is not a positive number or a ``(min, max)`` pair of positive
            numbers with ``min <= max``.
    """

    def __init__(
        self,
        kernel_size: int | tuple[int, int],
        sigma: float | tuple[float, float] = (0.1, 2.0),
        p: float = 0.5,
        value_check: bool = False,
    ):
        super().__init__(p=p)
        _log_api_usage_once(self)

        self.kernel_size = _setup_size(
            kernel_size, "Kernel size should be a tuple/list of two integers."
        )
        for ks in self.kernel_size:
            if ks <= 0 or ks % 2 == 0:
                raise ValueError(
                    "Kernel size value should be an odd and positive number."
                )

        if isinstance(sigma, numbers.Number):
            if sigma <= 0:
                raise ValueError("If sigma is a single number, it must be positive.")
            # A fixed sigma is stored as a degenerate (min, max) range.
            sigma = (sigma, sigma)
        elif isinstance(sigma, Sequence) and len(sigma) == 2:
            if not 0.0 < sigma[0] <= sigma[1]:
                raise ValueError(
                    "sigma values should be positive and of the form (min, max)."
                )
        else:
            raise ValueError(
                "sigma should be a single number or a list/tuple with length 2."
            )

        # Always store sigma as floating point: integer input such as ``1`` or
        # ``(1, 2)`` passes validation but would otherwise give an integer
        # buffer, and ``get_params`` samples with ``torch.rand`` in the
        # buffer's dtype, which only supports floating dtypes.
        self.register_buffer(
            "sigma", torch.as_tensor(sigma, dtype=torch.get_default_dtype())
        )
        self.value_check = value_check

    @staticmethod
    def get_params(sigma_min: Tensor, sigma_max: Tensor) -> Tensor:
        """Choose sigma for random gaussian blurring.

        Args:
            sigma_min: Minimum standard deviation that can be chosen for
                blurring kernel.
            sigma_max: Maximum standard deviation that can be chosen for
                blurring kernel.

        Returns:
            Standard deviation to be passed to calculate kernel for gaussian
            blurring, as a 0-d tensor with the dtype and device of
            ``sigma_min``.
        """
        # Sample on the same device/dtype as the bounds.
        unit_sample = torch.rand([], dtype=sigma_min.dtype, device=sigma_min.device)
        return unit_sample * (sigma_max - sigma_min) + sigma_min

    def apply_transform(self, img: Tensor) -> Tensor:
        """Blur the image.

        Args:
            img: Image to be blurred.

        Returns:
            Gaussian blurred image.
        """
        sigma: Tensor = self.get_params(self.sigma[0], self.sigma[1])
        # The same sigma is passed twice, once per entry of the sigma list.
        return F.gaussian_blur(img, self.kernel_size, [sigma, sigma], self.value_check)

    def __repr__(self) -> str:
        s = f"{self.__class__.__name__}(kernel_size={self.kernel_size}, sigma={self.sigma.tolist()}, p={self.p}, value_check={self.value_check})"
        return s
class RandomSolarize(RandomTransform):
    """Solarize the image randomly with a given probability.

    Solarizing inverts all pixel values above a threshold. The image is
    expected to be in [..., 1 or 3, H, W] format, where ... means it can have
    an arbitrary number of leading dimensions.

    Args:
        threshold: all pixels equal or above this value are inverted.
        p: probability of the image being solarized.
        value_check: Bool to perform tensor value check. Might cause slow down
            on some devices because of synchronization.
    """

    def __init__(
        self,
        threshold: float,
        p: float = 0.5,
        value_check: bool = False,
    ):
        super().__init__(p=p)
        _log_api_usage_once(self)

        self.value_check = value_check
        # Kept as a buffer so the threshold moves with the module.
        threshold_tensor = torch.as_tensor(threshold)
        self.register_buffer("threshold", threshold_tensor)

    def apply_transform(self, img: Tensor):
        """Solarize the image.

        Args:
            img: Image to be solarized.

        Returns:
            Solarized image.
        """
        solarized = F.solarize(img, self.threshold, self.value_check)
        return solarized

    def __repr__(self) -> str:
        fields = ", ".join(
            (
                f"threshold={self.threshold.item()}",
                f"p={self.p}",
                f"value_check={self.value_check}",
            )
        )
        return f"{self.__class__.__name__}({fields})"
class VideoNormalize(Normalize, VideoBase):
    """Normalize a tensor video with mean and standard deviation.

    Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for
    ``n`` channels, this transform will normalize each channel of the input
    ``torch.*Tensor`` i.e.,
    ``output[channel] = (input[channel] - mean[channel]) / std[channel]``

    Videos should be in format [..., T, C, H, W] or [..., C, T, H, W] with ...
    0 or 1 leading dimension.

    Args:
        mean: Sequence of means for each channel.
        std: Sequence of standard deviations for each channel.
        cast_dtype: If not None, scale and cast input to the dtype. Expected
            to be a float dtype.
        inplace: Bool to make this operation in-place.
        value_check: Bool to perform tensor value check. Might cause slow down
            on some devices because of synchronization.
        video_format: Dimension order of the video. Can be ``TCHW`` or
            ``CTHW``.
    """

    def __init__(
        self,
        mean: Sequence[float] | None = None,
        std: Sequence[float] | None = None,
        cast_dtype: torch.dtype | None = None,
        inplace: bool = False,
        value_check: bool = False,
        video_format: str = "CTHW",
    ) -> None:
        # Both bases are initialised explicitly rather than through
        # ``super()``.
        normalize_kwargs = dict(
            mean=mean,
            std=std,
            cast_dtype=cast_dtype,
            inplace=inplace,
            value_check=value_check,
        )
        Normalize.__init__(self, **normalize_kwargs)
        VideoBase.__init__(self, video_format=video_format)

        _log_api_usage_once(self)

    def forward(self, video: Tensor) -> Tensor:
        """Normalize a video.

        Args:
            video: The video to normalize.

        Returns:
            Normalized video, in the same dimension order as the input.
        """
        _assert_video_or_batch_videos_tensor(video)

        if self.time_before_channel:
            return Normalize.forward(self, video)

        # Swap the time and channel axes before normalizing -- presumably so
        # channels line up with the per-channel statistics; verify against
        # ``F.normalize``. The permutation is its own inverse, so applying it
        # again restores the original layout.
        if video.ndim == 5:
            swap_time_channel = [0, 2, 1, 3, 4]
        else:
            swap_time_channel = [1, 0, 2, 3]

        normalized = Normalize.forward(self, video.permute(swap_time_channel))
        return normalized.permute(swap_time_channel)

    def __repr__(self) -> str:
        fields = (
            f"mean={self.mean.tolist()}, "
            f"std={self.std.tolist()}, "
            f"cast_dtype={self.cast_dtype}, "
            f"inplace={self.inplace}, "
            f"value_check={self.value_check}, "
            f"video_format={self.video_format}"
        )
        return f"{self.__class__.__name__}({fields})"
class VideoWrapper(Wrapper, VideoBase):
    """Wrap transforms to handle video data.

    If the frames should be augmented differently, the transform must handle
    the leading dimension differently. The video is expected to be in format
    [C, T, H, W] or [T, C, H, W].

    .. note::
        Iterates through transforms and their submodules:

        - If ``inplace`` attribute is found, it is set to ``True``,
          ``inplace`` is handled at the wrapper level.
        - If ``video_format`` attribute is found, it is set to ``TCHW``,
          ``video_format`` is handled at the wrapper level.

    .. note::
        If ``video_format`` is ``CTHW``, a copy might occur even if
        ``inplace`` is set to ``True``.

    .. note::
        If a transform makes a copy, the resulting tensor will not share the
        same underlying storage even if ``inplace`` is set to ``True``.

    Args:
        transforms: A transform module or a sequence of transform modules.
        inplace: Whether to perform the transforms inplace.
        video_format: Format of the video. Either ``CTHW`` or ``TCHW``.
    """

    def __init__(
        self,
        transforms: Sequence[nn.Module] | nn.Module,
        inplace: bool = False,
        video_format: str = "CTHW",
    ) -> None:
        # Both bases are initialised explicitly rather than through
        # ``super()``.
        Wrapper.__init__(self, transforms=transforms, inplace=inplace)
        VideoBase.__init__(self, video_format=video_format)

        _log_api_usage_once(self)

    @staticmethod
    def _prepare_transform(transform: nn.Module):
        """Force ``inplace=True`` and ``video_format="TCHW"`` where present."""
        Wrapper._prepare_transform(transform)
        if hasattr(transform, "video_format"):
            transform.video_format = "TCHW"

    @staticmethod
    def _prepare_transforms(transforms: list[nn.Module]):
        """Prepare every transform together with all of its submodules.

        ``nn.Module.modules()`` already yields the module itself followed by
        every descendant, so a single pass is enough; recursing on that list
        again would only revisit nested modules repeatedly.
        """
        for transform in transforms:
            for module in transform.modules():
                VideoWrapper._prepare_transform(module)

    def forward(self, video: Tensor) -> Tensor:
        """Apply :attr:`~transforms` on the video.

        A ``CTHW`` video is permuted to ``TCHW`` before calling
        :meth:`Wrapper.forward` and permuted back afterwards.

        Args:
            video: The video to transform.

        Returns:
            The transformed video, in the wrapper's :attr:`~video_format`.
        """
        _assert_video_tensor(video)

        if self.time_before_channel:
            return Wrapper.forward(self, video)

        # Swapping the first two axes converts CTHW <-> TCHW; the same
        # permutation undoes itself on the way out.
        output = Wrapper.forward(self, video.permute(1, 0, 2, 3))
        return output.permute(1, 0, 2, 3)

    def __repr__(self):
        transforms_repr = str(self.transforms).replace("\n", "\n ")
        return (
            f"{self.__class__.__name__}(\n"
            f" inplace={self.inplace},\n"
            f" video_format={self.video_format},\n"
            f" transforms={transforms_repr}\n)"
        )