# ==================================
# Copyright: CEA-LIST/DIASI/SIALV/
# Author : Torchaug Developers
# License: CECILL-C
# ==================================
# Code partially based on Torchvision (BSD 3-Clause License), available at:
# https://github.com/pytorch/vision
from __future__ import annotations
import torch
import torchvision.transforms.v2.functional as TVF
from torchvision.io import decode_jpeg, encode_jpeg
from torchaug import ta_tensors
from torchaug._utils import _log_api_usage_once
from ._utils._kernel import _get_kernel, _register_kernel_internal
# [docs]
def erase(
    inpt: torch.Tensor,
    i: int,
    j: int,
    h: int,
    w: int,
    v: torch.Tensor,
    inplace: bool = False,
) -> torch.Tensor:
    """See :class:`~torchaug.transforms.RandomErasing` for details.

    Dispatcher for the erase operation: looks up the kernel registered for
    ``type(inpt)`` and calls it with every argument forwarded unchanged.

    Args:
        inpt: Input to erase from; its type selects the kernel.
        i, j, h, w: Forwarded as-is to the kernel (NOTE(review): per
            torchvision's ``erase`` these are presumably the top/left corner
            and height/width of the erased region -- confirm).
        v: Forwarded as-is to the kernel (presumably the erasing value).
        inplace: Forwarded as-is to the kernel.

    Returns:
        Whatever the selected kernel returns.
    """
    # When scripting, skip the registry lookup and call the image kernel
    # directly.
    if torch.jit.is_scripting():
        return erase_image(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)

    _log_api_usage_once(erase)

    dispatch_fn = _get_kernel(erase, type(inpt))
    return dispatch_fn(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
# [docs]
@_register_kernel_internal(erase, torch.Tensor)
@_register_kernel_internal(erase, ta_tensors.Image)
@_register_kernel_internal(erase, ta_tensors.BatchImages)
def erase_image(
    image: torch.Tensor,
    i: int,
    j: int,
    h: int,
    w: int,
    v: torch.Tensor,
    inplace: bool = False,
) -> torch.Tensor:
    """Erase kernel for plain tensors, images and batches of images.

    Registered for ``torch.Tensor``, ``ta_tensors.Image`` and
    ``ta_tensors.BatchImages``. The work is delegated entirely to
    torchvision's ``erase_image`` with all arguments passed through by
    keyword.
    """
    erased = TVF.erase_image(
        image=image,
        i=i,
        j=j,
        h=h,
        w=w,
        v=v,
        inplace=inplace,
    )
    return erased
# [docs]
@_register_kernel_internal(erase, ta_tensors.Video)
@_register_kernel_internal(erase, ta_tensors.BatchVideos)
def erase_video(
    video: torch.Tensor,
    i: int,
    j: int,
    h: int,
    w: int,
    v: torch.Tensor,
    inplace: bool = False,
) -> torch.Tensor:
    """Erase kernel for videos and batches of videos.

    Registered for ``ta_tensors.Video`` and ``ta_tensors.BatchVideos``. The
    video tensor is handed to :func:`erase_image` as its ``image`` argument;
    every other argument is forwarded unchanged.
    """
    erased = erase_image(
        image=video,
        i=i,
        j=j,
        h=h,
        w=w,
        v=v,
        inplace=inplace,
    )
    return erased
# [docs]
def jpeg(image: torch.Tensor, quality: int) -> torch.Tensor:
    """See :class:`~torchaug.transforms.JPEG` for details.

    Dispatcher for the JPEG operation: looks up the kernel registered for
    ``type(image)`` and calls it with ``quality`` forwarded unchanged.

    Args:
        image: Input to process; its type selects the kernel.
        quality: Forwarded as-is to the kernel.

    Returns:
        Whatever the selected kernel returns.
    """
    # When scripting, skip the registry lookup and call the image kernel
    # directly.
    if torch.jit.is_scripting():
        return jpeg_image(image, quality=quality)

    _log_api_usage_once(jpeg)

    dispatch_fn = _get_kernel(jpeg, type(image))
    return dispatch_fn(image, quality=quality)
# [docs]
@_register_kernel_internal(jpeg, torch.Tensor)
@_register_kernel_internal(jpeg, ta_tensors.Image)
@_register_kernel_internal(jpeg, ta_tensors.BatchImages)
def jpeg_image(image: torch.Tensor, quality: int) -> torch.Tensor:
    """Round-trip every image in ``image`` through JPEG encoding and decoding.

    All leading dimensions are collapsed into a single batch dimension. Each
    slice made of the last three dimensions is encoded with ``encode_jpeg``
    at the requested ``quality`` and decoded again with ``decode_jpeg``; the
    decoded slices are stacked and given the input's original shape.

    Args:
        image: Tensor whose last three dimensions form one image
            (NOTE(review): ``encode_jpeg`` presumably expects a uint8
            ``(C, H, W)`` tensor -- confirm against torchvision).
        quality: Forwarded to ``encode_jpeg`` as its ``quality`` argument.

    Returns:
        A new tensor with the same shape as ``image``. When the collapsed
        batch dimension is empty, a clone of the input is returned.
    """
    original_shape = image.shape

    # ``reshape`` instead of ``view``: it yields the same view whenever the
    # leading dimensions can be merged in place and only copies otherwise,
    # whereas ``view`` raises for such non-contiguous layouts.
    flat = image.reshape((-1,) + image.shape[-3:])

    if flat.shape[0] == 0:  # degenerate
        return flat.reshape(original_shape).clone()

    decoded_images = []
    for idx in range(flat.shape[0]):
        encoded = encode_jpeg(flat[idx], quality=quality)
        assert isinstance(encoded, torch.Tensor)  # For torchscript
        decoded = decode_jpeg(encoded)
        assert isinstance(decoded, torch.Tensor)  # For torchscript
        decoded_images.append(decoded)

    # The list and the stacked tensor use separate names so that no variable
    # changes type within the function.
    return torch.stack(decoded_images, dim=0).view(original_shape)
# [docs]
@_register_kernel_internal(jpeg, ta_tensors.Video)
@_register_kernel_internal(jpeg, ta_tensors.BatchVideos)
def jpeg_video(video: torch.Tensor, quality: int) -> torch.Tensor:
    """JPEG kernel for videos and batches of videos.

    Registered for ``ta_tensors.Video`` and ``ta_tensors.BatchVideos``. The
    video tensor is passed straight to :func:`jpeg_image` together with
    ``quality``.
    """
    compressed = jpeg_image(video, quality=quality)
    return compressed