from __future__ import annotations
import numbers
from abc import ABC, abstractmethod
from itertools import permutations
from math import ceil, floor
from typing import Sequence
import torch
import torchvision.transforms as tv_transforms
import torchvision.transforms.functional as F_tv
from torch import Tensor, nn
from torchvision.transforms.functional import InterpolationMode
from torchvision.transforms.transforms import (_interpolation_modes_from_int,
_setup_size)
import torchaug.batch_transforms.functional as F_b
import torchaug.transforms.functional as F
from torchaug.batch_transforms._utils import (_assert_batch_images_tensor,
_assert_batch_videos_tensor)
from torchaug.transforms._utils import (_assert_module_or_list_of_modules,
_check_input,
transfer_tensor_on_device)
from torchaug.transforms.transforms import Wrapper
from torchaug.utils import _log_api_usage_once
[docs]
class BatchImageWrapper(Wrapper):
    """Wrapper applying transforms to a batch of images.

    .. note::
        The transforms and their submodules are traversed:

        - If an ``inplace`` attribute is found, it is set to ``True``;
          ``inplace`` is handled at the wrapper level.

    .. note::
        A transform that copies its input yields a tensor that does not share
        the input storage, even if ``inplace`` is set to ``True``.

    Args:
        transforms: A transform module or a list of transform modules.
        inplace: Whether to perform the transforms inplace.
    """

    def __init__(
        self, transforms: Sequence[nn.Module] | nn.Module, inplace: bool = False
    ) -> None:
        super().__init__(transforms=transforms, inplace=inplace)

    def forward(self, imgs: torch.Tensor) -> Tensor:
        """Validate the batch of images, then apply :attr:`~transforms` on it.

        The transforms themselves are run by
        :meth:`torchaug.transforms.Wrapper.forward`.

        Args:
            imgs: The batch of images to transform.

        Returns:
            The transformed images.
        """
        _assert_batch_images_tensor(imgs)
        return super().forward(imgs)
[docs]
class BatchRandomApply(BatchRandomTransform):
    """Randomly apply a list of transformations to a batch of images.

    Args:
        transforms: A transformation module or a list of transformation modules.
        p: Probability to apply the transforms.
        inplace: If ``True``, perform inplace operation to save memory.
    """

    def __init__(
        self,
        transforms: list[nn.Module] | nn.Module,
        p: float = 0.5,
        inplace: bool = False,
    ) -> None:
        super().__init__(p=p, inplace=inplace)
        _log_api_usage_once(self)

        _assert_module_or_list_of_modules(transforms)
        # A lone module is handled as a one-element list.
        modules = [transforms] if isinstance(transforms, nn.Module) else transforms
        self.transforms = nn.ModuleList(modules)
        self.inplace = inplace

    def __repr__(self) -> str:
        nested = str(self.transforms).replace("\n", "\n ")
        lines = [
            f"{self.__class__.__name__}(",
            f" p={self.p},",
            f" inplace={self.inplace},",
            f" transforms={nested}",
            ")",
        ]
        return "\n".join(lines)
[docs]
class BatchRandomColorJitter(BatchRandomTransform):
    """Randomly change the brightness, contrast, saturation and hue to a batch of images.

    The batch is expected to have [B, ..., 1 or 3, H, W] shape,
    where ... means an arbitrary number of dimensions.

    Args:
        brightness: How much to jitter brightness.
            brightness factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
            or the given [min, max]. Should be non negative numbers.
        contrast: How much to jitter contrast.
            contrast factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
            or the given [min, max]. Should be non-negative numbers.
        saturation: How much to jitter saturation.
            saturation factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
            or the given [min, max]. Should be non negative numbers.
        hue: How much to jitter hue.
            hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
            Should have 0 <= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
            To jitter hue, the pixel values of the input image has to be non-negative for conversion to HSV space;
            thus it does not work if you normalize your image to an interval with negative values,
            or use an interpolation that generates negative values before using this function.
        p: Probability to apply color jitter.
        num_rand_calls: Number of random calls performed to apply augmentations at different orders on sub-batches.
            If -1, B calls are performed. The maximum is 24 = 4!, adjusted automatically if num_rand_calls > 24.
        inplace: If True, perform inplace operation to save memory and time.
        value_check: Bool to perform tensor value check.
            Might cause slow down on some devices because of synchronization or large batch size.

    Raises:
        ValueError: If ``num_rand_calls`` is not an int greater than or equal to -1.
    """

    def __init__(
        self,
        brightness: float | tuple[float, float] = 0,
        contrast: float | tuple[float, float] = 0,
        saturation: float | tuple[float, float] = 0,
        hue: float | tuple[float, float] = 0,
        p: float = 0.5,
        num_rand_calls: int = -1,
        inplace: bool = False,
        value_check: bool = False,
    ) -> None:
        super().__init__(p=p, inplace=inplace)
        _log_api_usage_once(self)

        if not isinstance(num_rand_calls, int) or not num_rand_calls >= -1:
            raise ValueError(
                f"num_rand_calls attribute should be an int superior to -1, {num_rand_calls} given."
            )

        self.p = p
        self.num_rand_calls = num_rand_calls
        self.inplace = inplace

        # Each range is validated by ``_check_input`` and then stored either as
        # a buffer (so it follows the module across devices) or as ``None``
        # when that jitter is turned off.
        self._register_range("brightness", _check_input(brightness, "brightness"))
        self._register_range("contrast", _check_input(contrast, "contrast"))
        self._register_range("saturation", _check_input(saturation, "saturation"))
        self._register_range(
            "hue",
            _check_input(
                hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False
            ),
        )

        self.value_check = value_check

    def _register_range(self, name: str, bounds) -> None:
        """Store ``bounds`` as buffer ``name``, or set the attribute to ``None`` if disabled."""
        if bounds is not None:
            self.register_buffer(name, torch.as_tensor(bounds))
        else:
            setattr(self, name, None)

    @staticmethod
    def get_params(
        brightness: Tensor | None,
        contrast: Tensor | None,
        saturation: Tensor | None,
        hue: Tensor | None,
        batch_size: int,
    ) -> tuple[Tensor | None, Tensor | None, Tensor | None, Tensor | None]:
        """Get the parameters for the randomized transform to be applied on image.

        Args:
            brightness (tensor (min, max), optional): The range from which the brightness factor is chosen
                uniformly. Pass None to turn off the transformation.
            contrast (tensor (min, max), optional): The range from which the contrast factor is chosen
                uniformly. Pass None to turn off the transformation.
            saturation (tensor (min, max), optional): The range from which the saturation factor is chosen
                uniformly. Pass None to turn off the transformation.
            hue: The range from which the hue_factor is chosen uniformly.
                Pass None to turn off the transformation.
            batch_size: The number of samples to draw.

        Returns:
            The ``(brightness, contrast, saturation, hue)`` factors, in that
            order. Each entry holds ``batch_size`` drawn values, or is ``None``
            when the corresponding range is ``None``.
        """

        def _get_uniform_values(batch_size: int, bounds: Tensor):
            # With u drawn by torch.rand, (b0 - b1) * u + b1 covers the span
            # between the two bounds; the arithmetic is kept as is so seeded
            # runs reproduce the same values.
            return (bounds[0] - bounds[1]) * torch.rand(
                (batch_size,), device=bounds.device
            ) + bounds[1]

        b = None if brightness is None else _get_uniform_values(batch_size, brightness)
        c = None if contrast is None else _get_uniform_values(batch_size, contrast)
        s = None if saturation is None else _get_uniform_values(batch_size, saturation)
        h = None if hue is None else _get_uniform_values(batch_size, hue)

        return b, c, s, h

    def __repr__(self) -> str:
        def _as_list(bounds):
            # Disabled jitters are stored as None and printed as such.
            return None if bounds is None else bounds.tolist()

        return (
            f"{self.__class__.__name__}("
            f"brightness={_as_list(self.brightness)}"
            f", contrast={_as_list(self.contrast)}"
            f", saturation={_as_list(self.saturation)}"
            f", hue={_as_list(self.hue)}"
            f", p={self.p}"
            f", num_rand_calls={self.num_rand_calls}"
            f", inplace={self.inplace}"
            f", value_check={self.value_check})"
        )
[docs]
class BatchRandomGaussianBlur(BatchRandomTransform):
    """Blur a batch of images with a randomly chosen Gaussian blur.

    The batch of images is expected to be of shape [B, ..., C, H, W]
    where ... means an arbitrary number of dimensions.

    Args:
        kernel_size: Size of the Gaussian kernel.
        sigma: Standard deviation used to create the blurring kernel.
            If float, sigma is fixed. If it is a tuple of float (min, max), sigma
            is chosen uniformly at random to lie in the given range.
        p: Probability to apply gaussian blur.
        inplace: If True, perform inplace operation to save memory and time.
        value_check: Bool to perform tensor value check.
            Might cause slow down on some devices because of synchronization or large batch size.

    Raises:
        ValueError: If a kernel size is not odd and positive, or if ``sigma``
            is neither a positive number nor a positive ``(min, max)`` pair.
    """

    def __init__(
        self,
        kernel_size: int | tuple[int, int],
        sigma: float | tuple[float, float] = (0.1, 2.0),
        p: float = 0.5,
        inplace: bool = False,
        value_check: bool = False,
    ):
        super().__init__(p=p, inplace=inplace)
        _log_api_usage_once(self)

        self.kernel_size = _setup_size(
            kernel_size, "Kernel size should be a tuple/list of two integers."
        )
        if any(ks <= 0 or ks % 2 == 0 for ks in self.kernel_size):
            raise ValueError("Kernel size value should be an odd and positive number.")

        sigma_is_number = isinstance(sigma, numbers.Number)
        sigma_is_pair = (
            not sigma_is_number and isinstance(sigma, Sequence) and len(sigma) == 2
        )
        if not sigma_is_number and not sigma_is_pair:
            raise ValueError(
                "sigma should be a single number or a list/tuple with length 2."
            )

        if sigma_is_number:
            if sigma <= 0:
                raise ValueError("If sigma is a single number, it must be positive.")
            # A fixed sigma is stored as a degenerate (min, max) range.
            sigma = (sigma, sigma)
        elif not 0.0 < sigma[0] <= sigma[1]:
            raise ValueError(
                "sigma values should be positive and of the form (min, max)."
            )

        self.register_buffer("sigma", torch.as_tensor(sigma))
        self.value_check = value_check

    @staticmethod
    def get_params(sigma_min: Tensor, sigma_max: Tensor, batch_size: int) -> Tensor:
        """Draw the standard deviations used for random gaussian blurring.

        One random value is drawn per sample and repeated along a second
        dimension of size 2 before being scaled to the sigma range.

        Args:
            sigma_min: Minimum standard deviation that can be chosen for blurring kernel.
            sigma_max: Maximum standard deviation that can be chosen for blurring kernel.
            batch_size: The number of samples to draw.

        Returns:
            Standard deviation to calculate kernel for gaussian blurring.
        """
        uniform = torch.rand(
            batch_size, 1, dtype=sigma_min.dtype, device=sigma_min.device
        )
        spread = sigma_max - sigma_min
        return uniform.expand(batch_size, 2) * spread + sigma_min

    def __repr__(self) -> str:
        fields = ", ".join(
            [
                f"kernel_size={self.kernel_size}",
                f"sigma={self.sigma.tolist()}",
                f"p={self.p}",
                f"inplace={self.inplace}",
                f"value_check={self.value_check}",
            ]
        )
        return f"{self.__class__.__name__}({fields})"
[docs]
class BatchRandomGrayScale(BatchRandomTransform):
    """Randomly convert a batch of images to grayscale.

    The batch of images is expected to be of shape [B, ..., C, H, W]
    where ... means an arbitrary number of dimensions.

    Args:
        p: Probability of the images to be grayscaled.
        inplace: If True, perform inplace operation to save memory and time.
    """

    def __init__(
        self,
        p: float = 0.5,
        inplace: bool = False,
    ):
        super().__init__(p=p, inplace=inplace)
        _log_api_usage_once(self)

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f"{cls_name}(p={self.p}, inplace={self.inplace})"
[docs]
class BatchMixUp(nn.Module):
    """Mix input tensor with linear interpolation drawn according a Beta law.

    The shape of the tensors is expected to be [B, ...] with ... any number of dimensions.
    The tensor should be float.

    .. note::
        The tensor is rolled according its first dimension and mixed with one
        drawn interpolation parameter per element in first dimension.

    Args:
        alpha: Parameter for the Beta law. It is used for both concentration
            parameters of the distribution.
        inplace: Whether to perform the operation inplace.
    """

    def __init__(self, alpha: float, inplace: bool = False) -> None:
        super().__init__()
        _log_api_usage_once(self)

        self.alpha = alpha
        self.inplace = inplace
        self.mix_sampler = torch.distributions.Beta(
            torch.tensor([alpha]), torch.tensor([alpha])
        )

    def _get_params(self, batch_size: int, device: torch.device) -> Tensor:
        """Draw the mixing coefficients.

        Args:
            batch_size: Number of coefficients to draw.
            device: Device the drawn coefficients are transferred to.

        Returns:
            The mixing coefficients.
        """
        return transfer_tensor_on_device(self.mix_sampler.sample((batch_size,)), device)

    def forward(
        self, tensor: Tensor, labels: Tensor | None = None
    ) -> tuple[Tensor, Tensor | None, Tensor]:
        """Mix the input tensor and labels.

        Each element is mixed with its neighbour obtained by rolling the
        input by one position along the first dimension.

        Args:
            tensor: The tensor to mix.
            labels: If not None, the labels to mix.

        Returns:
            Tuple:
                - mixed tensor.
                - mixed labels or None.
                - mixing coefficients.
        """
        lam = self._get_params(tensor.shape[0], tensor.device)

        # When not inplace, work on a copy; the mixup call below is then always
        # made with its last argument set to True.
        tensor = tensor if self.inplace else tensor.clone()
        mixed_tensor = F_b.batch_mixup(tensor, tensor.roll(1, 0), lam, True)

        if labels is None:
            return mixed_tensor, None, lam

        labels = labels if self.inplace else labels.clone()
        mixed_labels = F_b.batch_mixup(labels, labels.roll(1, 0), lam, True)

        return mixed_tensor, mixed_labels, lam

    def __repr__(self):
        # ``self.__class__`` (not the bare ``__class__`` cell, which is always
        # the defining class) so subclasses are printed under their own name.
        return f"{self.__class__.__name__}(alpha={self.alpha}, inplace={self.inplace})"
[docs]
class BatchRandomHorizontalFlip(BatchRandomTransform):
    """Randomly flip horizontally the images of a batch with a given probability.

    The batch of images is expected to be of shape [B, ..., C, H, W]
    where ... means an arbitrary number of dimensions.

    Args:
        p: Probability of the images being flipped.
        inplace: If True, perform inplace operation to save memory and time.
    """

    def __init__(
        self,
        p: float = 0.5,
        inplace: bool = False,
    ):
        super().__init__(p=p, inplace=inplace)
        _log_api_usage_once(self)

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f"{cls_name}(p={self.p}, inplace={self.inplace})"
[docs]
class BatchRandomResizedCrop(tv_transforms.RandomResizedCrop):
    """Crop a random portion of a batch of images and resize it to a given size.

    The batch shape is expected to be [B, ..., H, W], where ... means an arbitrary number of dimensions

    A crop of the original image is made: the crop has a random area (H * W)
    and a random aspect ratio. This crop is finally resized to the given
    size. This is popularly used to train the Inception networks.

    Args:
        size: Expected output size of the crop, for each edge. If size is an
            int instead of sequence like (h, w), a square output size ``(size, size)`` is
            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).

            .. note::
                In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``.
        scale: Specifies the lower and upper bounds for the random area of the crop,
            before resizing. The scale is defined with respect to the area of the original image.
        ratio: lower and upper bounds for the random aspect ratio of the crop, before
            resizing.
        interpolation: Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`.
            Only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
            ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
        antialias: Whether to apply antialiasing.
            It only affects bilinear or bicubic modes and it is
            ignored otherwise.
        num_rand_calls: Number of random crop parameters drawn per batch.
            If -1, one set of parameters is drawn for each element of the first dimension.
            If 0, the batch is returned unchanged.
            Otherwise the batch is randomly split into at most ``num_rand_calls``
            sub-batches and each sub-batch shares one set of parameters.

    Raises:
        TypeError: If ``size`` is neither an int nor a sequence of int.
        ValueError: If ``num_rand_calls`` is not an int greater than or equal to -1.
    """

    def __init__(
        self,
        size: int | Sequence[int],
        scale: Sequence[float] = (0.08, 1.0),
        ratio: Sequence[float] = (3.0 / 4.0, 4.0 / 3.0),
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
        antialias: bool = True,
        num_rand_calls: int = -1,
    ) -> None:
        # Only the element types are validated here; the parent constructor
        # normalizes ``size`` into ``self.size`` and rejects unsupported lengths,
        # so it is neither indexed nor stored at this point.
        is_int_sequence = isinstance(size, Sequence) and all(
            isinstance(s, int) for s in size
        )
        if not isinstance(size, int) and not is_int_sequence:
            raise TypeError(f"size should be a int or a sequence of int. Got {size}.")

        super().__init__(size, scale, ratio, interpolation, antialias)
        _log_api_usage_once(self)

        if not isinstance(num_rand_calls, int) or not num_rand_calls >= -1:
            raise ValueError(
                f"num_rand_calls attribute should be an int superior to -1, {num_rand_calls} given."
            )

        self.num_rand_calls = num_rand_calls

    def single_forward(self, imgs: Tensor) -> Tensor:
        """Perform a random resized crop of same scale and shape for the batch of images.

        Args:
            imgs: Batch of images to be cropped and resized.

        Returns:
            Cropped and resized batch of images.
        """
        i, j, h, w = self.get_params(imgs, self.scale, self.ratio)
        return F_tv.resized_crop(
            imgs, i, j, h, w, self.size, self.interpolation, antialias=self.antialias
        )

    def forward(self, imgs: Tensor) -> Tensor:
        """Resize and crop the batch of images.

        .. note:: Apply different scaling and cropping based on :attr:`~num_rand_calls`.

        Args:
            imgs: Batch of images to be cropped and resized.

        Returns:
            Randomly cropped and resized batch of images.
        """
        if self.num_rand_calls == -1:
            return torch.stack([self.single_forward(img) for img in imgs])
        if self.num_rand_calls == 0:
            return imgs

        # A fresh output tensor is allocated to avoid an inplace operation.
        output = torch.empty(
            size=(*imgs.shape[:-2], *self.size),
            dtype=imgs.dtype,
            device=imgs.device,
        )

        batch_size = imgs.shape[0]
        if batch_size == 0:
            # Nothing to crop; also avoids dividing by zero below.
            return output

        num_combination = min(batch_size, self.num_rand_calls)
        num_apply_per_combination = ceil(batch_size / num_combination)
        indices_do_apply = torch.randperm(batch_size, device=imgs.device)

        # Stepping over the permutation only visits non-empty slices, so no
        # crop is launched on an empty sub-batch when the batch size is not
        # evenly divisible.
        for start in range(0, batch_size, num_apply_per_combination):
            indices_combination = indices_do_apply[
                start : start + num_apply_per_combination
            ]
            output[indices_combination] = self.single_forward(
                imgs[indices_combination]
            )

        return output

    def __repr__(self) -> str:
        interpolate_str = self.interpolation.value
        format_string = self.__class__.__name__ + f"(size={self.size}"
        format_string += f", scale={tuple(round(s, 4) for s in self.scale)}"
        format_string += f", ratio={tuple(round(r, 4) for r in self.ratio)}"
        format_string += f", interpolation={interpolate_str}"
        format_string += f", antialias={self.antialias}"
        format_string += f", num_rand_calls={self.num_rand_calls})"
        return format_string
[docs]
class BatchRandomSolarize(BatchRandomTransform):
    """Randomly solarize a batch of images with a given probability by inverting
    all pixel values above a threshold.

    The batch of images is expected to be in [B, ..., 1 or 3, H, W] format,
    where ... means it can have an arbitrary number of dimensions.

    Args:
        threshold: All pixels equal or above this value are inverted.
        p: Probability of the image being solarized.
        inplace: If True, perform inplace operation to save memory and time.
        value_check: Bool to perform tensor value check.
            Might cause slow down on some devices because of synchronization or large batch size.
    """

    def __init__(
        self,
        threshold: float,
        p: float = 0.5,
        inplace: bool = False,
        value_check: bool = False,
    ) -> None:
        super().__init__(p=p, inplace=inplace)
        _log_api_usage_once(self)

        # The threshold is kept as a buffer rather than a plain float attribute.
        self.register_buffer("threshold", torch.as_tensor(threshold))
        self.value_check = value_check

    def __repr__(self) -> str:
        fields = ", ".join(
            [
                f"threshold={self.threshold.item()}",
                f"p={self.p}",
                f"inplace={self.inplace}",
                f"value_check={self.value_check}",
            ]
        )
        return f"{self.__class__.__name__}({fields})"
[docs]
class BatchVideoWrapper(Wrapper):
    """Wrap transforms to handle batched video data.

    The transforms are expected to handle the leading dimension differently.
    The batch of videos is expected to be in format [B, C, T, H, W] or [B, T, C, H, W].

    .. note::
        Iterates through transforms and their submodules:

        - If ``inplace`` attribute is found, it is set to ``True``,
          ``inplace`` is handled at the wrapper level.
        - If ``video_format`` attribute is found, it is set to ``TCHW``,
          ``video_format`` is handled at the wrapper level.

    .. note::
        If ``video_format`` is ``CTHW``, a copy might occur even if ``inplace`` is set to ``True``.

    .. note::
        If a transform makes a copy, the resulting tensor will not share the same
        underlying storage even if ``inplace`` is set to ``True``.

    Args:
        transforms: A list of transform modules.
        inplace: Whether to perform the transforms inplace.
        same_on_frames: If True, apply the same transform on all the frames, else it
            flattens the batch and temporal dimensions to apply different transformations to each frame.
        video_format: Format of the video. Either ``CTHW`` or ``TCHW``.

    Raises:
        ValueError: If ``video_format`` is neither ``CTHW`` nor ``TCHW``.
    """

    def __init__(
        self,
        transforms: Sequence[nn.Module] | nn.Module,
        inplace: bool = False,
        same_on_frames: bool = True,
        video_format: str = "CTHW",
    ) -> None:
        super().__init__(transforms=transforms, inplace=inplace)
        _log_api_usage_once(self)

        self.video_format = video_format
        self.same_on_frames = same_on_frames

        if self.video_format == "CTHW":
            self.time_before_channel = False
        elif self.video_format == "TCHW":
            self.time_before_channel = True
        else:
            raise ValueError(
                f"video_format should be either 'CTHW' or 'TCHW'. Got {self.video_format}."
            )

    @staticmethod
    def _prepare_transform(transform: nn.Module):
        """Prepare one module, then force its ``video_format`` (if any) to ``TCHW``."""
        Wrapper._prepare_transform(transform)

        # forward() always hands frames to the transforms in TCHW order.
        if hasattr(transform, "video_format"):
            transform.video_format = "TCHW"

    @staticmethod
    def _prepare_transforms(transforms: list[nn.Module]):
        """Prepare every transform of the list along with its submodules."""
        for transform in transforms:
            BatchVideoWrapper._prepare_transform(transform)
            # modules() yields the transform itself first; only descendants are passed on.
            BatchVideoWrapper._prepare_transforms(list(transform.modules())[1:])

    def forward(self, videos: Tensor) -> Tensor:
        """Apply :attr:`~transforms` on the batch of videos.

        Call :meth:`torchaug.transforms.Wrapper.forward`.

        .. note:: If :attr:`~same_on_frames` is ``False``, the batch and frames
            dimensions are merged.

        Args:
            videos: The batch of videos to transform.

        Returns:
            The transformed videos, in the same format as the input.
        """
        _assert_batch_videos_tensor(videos)

        if self.time_before_channel:
            b, t, c, h, w = videos.shape
        else:
            b, c, t, h, w = videos.shape
            # Bring the videos to [B, T, C, H, W] for the transforms.
            videos = videos.permute(0, 2, 1, 3, 4)

        if not self.same_on_frames:
            videos = videos.reshape(b * t, c, h, w)

        output = super().forward(videos)

        if not self.same_on_frames:
            # Split batch and time again using the OUTPUT trailing dimensions:
            # a wrapped transform may have changed C, H or W (e.g. crop, resize).
            output = output.reshape(b, t, *output.shape[-3:])

        if not self.time_before_channel:
            output = output.permute(0, 2, 1, 3, 4)

        return output

    def __repr__(self):
        transforms_repr = str(self.transforms).replace("\n", "\n ")
        return (
            f"{self.__class__.__name__}(\n"
            f" inplace={self.inplace},\n"
            f" same_on_frames={self.same_on_frames},\n"
            f" video_format={self.video_format},\n"
            f" transforms={transforms_repr}\n)"
        )
[docs]
class BatchVideoResize(nn.Module):
    """Resize a batch of videos to the given size.

    The batch of videos is unpacked as five dimensions, ordered according to
    ``video_format`` after the batch dimension.

    Args:
        size: Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
            (size * height / width, size).

            .. note::
                In torchscript mode size as single int is not supported,
                use a sequence of length 1: ``[size, ]``.
        interpolation: Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`.
            ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
            ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are
            supported.
        max_size: The maximum allowed for the longer edge of
            the resized image: if the longer edge of the image is greater
            than ``max_size`` after being resized according to ``size``, then
            the image is resized again so that the longer edge is equal to
            ``max_size``. As a result, ``size`` might be overruled, i.e. the
            smaller edge may be shorter than ``size``. This is only supported
            if ``size`` is an int (or a sequence of length 1 in torchscript
            mode).
        antialias: Whether to apply antialiasing. If ``True``, will apply antialiasing for bilinear or bicubic modes.
            Other mode aren't affected.
        video_format: Format of the video. Either ``CTHW`` or ``TCHW``.

    Raises:
        TypeError: If ``size`` is neither an int nor a sequence.
        ValueError: If ``size`` is a sequence whose length is not 1 or 2, or if
            ``video_format`` is neither ``CTHW`` nor ``TCHW``.
    """

    def __init__(
        self,
        size: int | list[int],
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
        max_size: int | None = None,
        antialias: bool = True,
        video_format: str = "CTHW",
    ):
        super().__init__()
        _log_api_usage_once(self)

        if isinstance(size, Sequence):
            if len(size) not in (1, 2):
                raise ValueError("If size is a sequence, it should have 1 or 2 values.")
        elif not isinstance(size, int):
            raise TypeError(f"Size should be int or sequence. Got {type(size)}.")

        self.size = size
        self.max_size = max_size

        # Integer interpolation codes are converted to their InterpolationMode.
        if isinstance(interpolation, int):
            interpolation = _interpolation_modes_from_int(interpolation)

        self.interpolation = interpolation
        self.antialias = antialias
        self.video_format = video_format

        if self.video_format == "TCHW":
            self.time_before_channel = True
        elif self.video_format == "CTHW":
            self.time_before_channel = False
        else:
            raise ValueError(
                f"video_format should be either 'CTHW' or 'TCHW'. Got {self.video_format}."
            )

    def forward(self, videos: Tensor):
        """Resize the batch of videos.

        Batch and time dimensions are merged so that all the frames are
        resized in a single call, then split again.

        Args:
            videos: The batch of videos to resize.

        Returns:
            The resized videos.
        """
        _assert_batch_videos_tensor(videos)

        if not self.time_before_channel:
            b, c, t, h, w = videos.shape
            # Move time before channel to work on [B, T, C, H, W].
            videos = videos.permute(0, 2, 1, 3, 4)
        else:
            b, t, c, h, w = videos.shape

        frames = videos.reshape(b * t, c, h, w)
        frames = F_tv.resize(
            frames, self.size, self.interpolation, self.max_size, self.antialias
        )

        # The trailing dimensions are read from the resized frames.
        resized = frames.reshape(b, t, *frames.shape[-3:])

        if not self.time_before_channel:
            resized = resized.permute(0, 2, 1, 3, 4)

        return resized

    def __repr__(self) -> str:
        fields = ", ".join(
            [
                f"size={self.size}",
                f"interpolation={self.interpolation.value}",
                f"max_size={self.max_size}",
                f"antialias={self.antialias}",
                f"video_format={self.video_format}",
            ]
        )
        return f"{self.__class__.__name__}({fields})"