# Source code for torchaug.transforms.functional._misc

# @Copyright: CEA-LIST/DIASI/SIALV/ (2023-    )
# @Author: CEA-LIST/DIASI/SIALV/ <julien.denize@cea.fr>
# @License: CECILL-C
#
# Code partially based on Torchvision (BSD 3-Clause License), available at:
#   https://github.com/pytorch/vision

from __future__ import annotations

import math
from typing import List, Optional

import torch
import torchvision.transforms.v2.functional as TVF
from torch.nn.functional import conv2d
from torch.nn.functional import pad as torch_pad

from torchaug import ta_tensors
from torchaug._utils import _log_api_usage_once

from ._utils._kernel import _get_kernel, _register_kernel_internal
from ._utils._tensor import _transfer_tensor_on_device


def normalize(
    inpt: torch.Tensor,
    mean: List[float],
    std: List[float],
    inplace: bool = False,
) -> torch.Tensor:
    """Normalize ``inpt`` using ``mean`` and ``std``.

    In eager mode the kernel registered for ``type(inpt)`` is looked up and
    called; under TorchScript the plain image kernel is used directly.

    See :class:`~torchaug.transforms.Normalize` for details.
    """
    if torch.jit.is_scripting():
        return normalize_image(inpt, mean=mean, std=std, inplace=inplace)
    else:
        _log_api_usage_once(normalize)
        normalize_fn = _get_kernel(normalize, type(inpt))
        return normalize_fn(inpt, mean=mean, std=std, inplace=inplace)
@_register_kernel_internal(normalize, torch.Tensor)
@_register_kernel_internal(normalize, ta_tensors.Image)
@_register_kernel_internal(normalize, ta_tensors.BatchImages)
def normalize_image(
    image: torch.Tensor,
    mean: List[float],
    std: List[float],
    inplace: bool = False,
) -> torch.Tensor:
    """Image kernel of :func:`normalize`; delegates to torchvision's ``normalize_image``."""
    return TVF.normalize_image(image, mean=mean, std=std, inplace=inplace)


@_register_kernel_internal(normalize, ta_tensors.Video)
@_register_kernel_internal(normalize, ta_tensors.BatchVideos)
def normalize_video(
    video: torch.Tensor,
    mean: List[float],
    std: List[float],
    inplace: bool = False,
) -> torch.Tensor:
    """Video kernel of :func:`normalize`; videos go through the image kernel unchanged."""
    return normalize_image(video, mean=mean, std=std, inplace=inplace)
def gaussian_blur(
    inpt: torch.Tensor,
    kernel_size: List[int],
    sigma: Optional[List[float]] = None,
) -> torch.Tensor:
    """Apply a Gaussian blur to ``inpt``.

    In eager mode the kernel registered for ``type(inpt)`` is looked up and
    called; under TorchScript the plain image kernel is used directly.

    See :class:`~torchaug.transforms.RandomGaussianBlur` for details.
    """
    if torch.jit.is_scripting():
        return gaussian_blur_image(inpt, kernel_size=kernel_size, sigma=sigma)
    else:
        _log_api_usage_once(gaussian_blur)
        blur_fn = _get_kernel(gaussian_blur, type(inpt))
        return blur_fn(inpt, kernel_size=kernel_size, sigma=sigma)
def gaussian_blur_batch(
    inpt: torch.Tensor,
    kernel_size: List[int],
    sigma: Optional[torch.Tensor] = None,
    value_check: bool = False,
) -> torch.Tensor:
    """Apply a Gaussian blur to a batch, where ``sigma`` may be given as a tensor.

    In eager mode the kernel registered for ``type(inpt)`` is looked up and
    called; under TorchScript the batch-images kernel is used directly.

    See :class:`~torchaug.transforms.RandomGaussianBlur` for details.
    """
    if torch.jit.is_scripting():
        return gaussian_blur_batch_images(inpt, kernel_size=kernel_size, sigma=sigma, value_check=value_check)
    else:
        _log_api_usage_once(gaussian_blur_batch)
        batch_blur_fn = _get_kernel(gaussian_blur_batch, type(inpt))
        return batch_blur_fn(inpt, kernel_size=kernel_size, sigma=sigma, value_check=value_check)
def _get_gaussian_kernel1d(
    kernel_size: int, sigma: torch.Tensor, dtype: torch.dtype, device: torch.device
) -> torch.Tensor:
    """Build one normalized 1D Gaussian kernel per value of ``sigma``.

    Args:
        kernel_size: Number of taps of each kernel.
        sigma: Standard deviations, flattened to one value per kernel.
        dtype: dtype passed to ``linspace`` and ``softmax``.
        device: Device of the sample grid.

    Returns:
        Tensor of shape ``(sigma.numel(), kernel_size)`` whose rows sum to 1.
    """
    # The taps are spaced by 1 / sqrt(2), so softmax(-(x / sigma) ** 2) is
    # exp(-i ** 2 / (2 * sigma ** 2)) normalized over the integer offsets i.
    lim = (kernel_size - 1) / (2.0 * math.sqrt(2.0))
    x = torch.linspace(-lim, lim, steps=kernel_size, dtype=dtype, device=device)
    sigma = sigma.view(-1, 1)
    x = x.view(1, -1).expand(sigma.shape[0], -1)
    kernel1d = torch.softmax((x.div(sigma)).pow_(2).neg_(), dim=1, dtype=dtype)
    return kernel1d


def _get_gaussian_kernel2d(
    kernel_size: List[int],
    sigma: torch.Tensor,
    dtype: torch.dtype,
    device: torch.device,
) -> torch.Tensor:
    """Build separable 2D Gaussian kernels of shape ``(N, kernel_size[1], kernel_size[0])``.

    ``sigma[..., 0]`` drives the horizontal (``kernel_size[0]``) axis and
    ``sigma[..., 1]`` the vertical (``kernel_size[1]``) axis.
    """
    kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[..., 0], dtype, device)
    kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[..., 1], dtype, device)
    # Outer product of the two 1D kernels: rows follow y, columns follow x.
    kernel2d = (kernel1d_y.view(-1, kernel_size[1], 1) * kernel1d_x.view(-1, 1, kernel_size[0])).view(
        -1, kernel_size[1], kernel_size[0]
    )
    return kernel2d


def _format_kernel_size(kernel_size: List[int]) -> List[int]:
    """Validate ``kernel_size`` and return it as a ``[width, height]`` pair.

    An int is duplicated for both axes; each value must be odd and positive.

    Raises:
        ValueError: If a sequence does not have length 2 or a value is even/negative.
    """
    if isinstance(kernel_size, int):
        kernel_size = [kernel_size, kernel_size]
    elif len(kernel_size) != 2:
        raise ValueError(f"If kernel_size is a sequence its length should be 2. Got {len(kernel_size)}")
    for ksize in kernel_size:
        if ksize % 2 == 0 or ksize < 0:
            raise ValueError(f"kernel_size should have odd and positive integers. Got {kernel_size}")
    return kernel_size


def _format_batch_sigma(sigma: torch.Tensor, batch_size: int) -> torch.Tensor:
    """Broadcast a tensor ``sigma`` to shape ``(batch_size, 2)`` (``[sigma_x, sigma_y]`` per sample).

    Accepted layouts:
        * 0D, or 1D/2D with a single element: one isotropic sigma for the whole batch.
        * Exactly two elements: one ``(sigma_x, sigma_y)`` pair shared by the batch
          (this takes precedence, so with ``batch_size == 2`` a 2-element tensor is a pair).
        * 1D of length ``batch_size``: one isotropic sigma per sample.
        * 2D of shape ``(batch_size, 1)`` or ``(batch_size, 2)``: per-sample sigmas.

    Raises:
        ValueError: For any other layout.
    """
    if sigma.ndim > 2:
        raise ValueError(f"sigma should have 1 or 2 dimensions. Got {sigma.ndim}")

    if sigma.ndim == 2 and sigma.numel() > 2:
        # Per-sample sigmas, one row per sample.
        if sigma.shape[0] != batch_size:
            raise ValueError(
                f"sigma should have one element or the same length as the batch size. Got {sigma.shape[0]} and {batch_size}"
            )
        if sigma.shape[1] == 1:
            # Isotropic per-sample sigma: reuse the same value for both axes.
            return sigma.expand(batch_size, 2)
        if sigma.shape[1] != 2:
            raise ValueError(f"sigma should have 1 or 2 values per sample. Got {sigma.shape[1]}")
        return sigma

    if sigma.numel() == 2:
        # A single (sigma_x, sigma_y) pair shared by the whole batch.
        return sigma.view(1, -1).expand(batch_size, 2)

    # Isotropic sigma: either one shared value or one value per sample.
    sigma = sigma.view(-1, 1)
    if sigma.shape[0] not in (1, batch_size):
        raise ValueError(
            f"sigma should have one element or the same length as the batch size. Got {sigma.shape[0]} and {batch_size}"
        )
    return sigma.expand(batch_size, 2)


@_register_kernel_internal(gaussian_blur, torch.Tensor)
@_register_kernel_internal(gaussian_blur, ta_tensors.Image)
@_register_kernel_internal(gaussian_blur, ta_tensors.BatchImages)
def gaussian_blur_image(
    image: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None
) -> torch.Tensor:
    """Blur ``image`` with a single Gaussian kernel applied to every channel.

    Args:
        image: Tensor whose last three dimensions are ``(C, H, W)``. A 3D input is
            treated as one image; dimensions before the last three are flattened.
        kernel_size: ``[width, height]`` of the kernel (an int is used for both);
            values must be odd and positive.
        sigma: ``[sigma_x, sigma_y]``; a number or a one-element sequence is used for
            both axes. Defaults to ``0.15 * ksize + 0.35`` per axis.

    Returns:
        The blurred tensor, with the same shape and dtype as ``image``. Integer
        inputs are blurred in float32 and rounded back.

    Raises:
        ValueError: If ``kernel_size`` or ``sigma`` have an invalid length or value.
        TypeError: If ``sigma`` is neither a number nor a sequence.
    """
    # TODO: consider deprecating integers from sigma on the future
    kernel_size = _format_kernel_size(kernel_size)

    if sigma is None:
        sigma = [ksize * 0.15 + 0.35 for ksize in kernel_size]
    else:
        if isinstance(sigma, (list, tuple)):
            length = len(sigma)
            if length == 1:
                s = float(sigma[0])
                sigma = [s, s]
            elif length != 2:
                raise ValueError(f"If sigma is a sequence, its length should be 2. Got {length}")
        elif isinstance(sigma, (int, float)):
            s = float(sigma)
            sigma = [s, s]
        else:
            raise TypeError(f"sigma should be either float or sequence of floats. Got {type(sigma)}")
    for s in sigma:
        if s <= 0.0:
            raise ValueError(f"sigma should have positive values. Got {sigma}")
    sigma = torch.as_tensor(sigma, device=image.device)

    if image.numel() == 0:
        return image

    dtype = image.dtype
    shape = image.shape
    ndim = image.ndim
    # conv2d needs a 4D (N, C, H, W) input.
    if ndim == 3:
        image = image.unsqueeze(dim=0)
    elif ndim > 4:
        image = image.reshape((-1,) + shape[-3:])

    fp = torch.is_floating_point(image)
    kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype if fp else torch.float32, device=image.device)
    # One depthwise filter per channel: (C, 1, kh, kw).
    kernel = kernel.expand(shape[-3], 1, kernel.shape[1], kernel.shape[2])

    output = image if fp else image.to(dtype=torch.float32)

    # padding = (left, right, top, bottom)
    padding = [
        kernel_size[0] // 2,
        kernel_size[0] // 2,
        kernel_size[1] // 2,
        kernel_size[1] // 2,
    ]
    output = torch_pad(output, padding, mode="reflect")
    output = conv2d(output, kernel, groups=shape[-3])

    if ndim == 3:
        output = output.squeeze(dim=0)
    elif ndim > 4:
        output = output.reshape(shape)

    if not fp:
        output = output.round_().to(dtype=dtype)

    return output


@_register_kernel_internal(gaussian_blur, ta_tensors.Video)
@_register_kernel_internal(gaussian_blur, ta_tensors.BatchVideos)
def gaussian_blur_video(
    video: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None
) -> torch.Tensor:
    """Video kernel of ``gaussian_blur``; videos go through the image kernel unchanged."""
    return gaussian_blur_image(image=video, kernel_size=kernel_size, sigma=sigma)


@_register_kernel_internal(gaussian_blur_batch, torch.Tensor)
@_register_kernel_internal(gaussian_blur_batch, ta_tensors.BatchImages)
def gaussian_blur_batch_images(
    images: torch.Tensor,
    kernel_size: List[int],
    sigma: Optional[torch.Tensor] = None,
    value_check: bool = False,
) -> torch.Tensor:
    """Blur a batch of images, optionally with a different sigma for each sample.

    Args:
        images: Tensor of shape ``(B, ..., H, W)``. All dimensions between the batch
            and the spatial ones are blurred independently.
        kernel_size: ``[width, height]`` of the kernel (an int is used for both);
            values must be odd and positive.
        sigma: If not a tensor, the call is forwarded to ``gaussian_blur_image``
            (one sigma for every sample). As a tensor, see ``_format_batch_sigma``
            for the accepted layouts (shared or per-sample, isotropic or not).
        value_check: Check that ``sigma`` is positive even when it is not on CPU.

    Returns:
        The blurred tensor, with the same shape and dtype as ``images``. Integer
        inputs are blurred in float32 and rounded back.

    Raises:
        ValueError: On an invalid ``kernel_size``, a non-positive ``sigma`` (when
            checked) or an unsupported ``sigma`` layout.
    """
    if not isinstance(sigma, torch.Tensor):
        return gaussian_blur_image(image=images, kernel_size=kernel_size, sigma=sigma)

    kernel_size = _format_kernel_size(kernel_size)

    if images.numel() == 0:
        return images

    sigma = _transfer_tensor_on_device(sigma, images.device, True)

    # Only checked on CPU unless explicitly requested, presumably to avoid a
    # device-to-host synchronization on accelerators.
    if sigma.device.type == "cpu" or value_check:
        # Zero must be rejected too: it makes the kernel computation divide by
        # zero and the blurred output NaN.
        if (sigma <= 0).any():
            raise ValueError(f"sigma should have positive values. Got {sigma}")

    dtype = images.dtype
    shape = images.shape
    b, h, w = shape[0], shape[-2], shape[-1]
    images = images.reshape(b, -1, h, w)

    sigma = _format_batch_sigma(sigma, b)

    fp = torch.is_floating_point(images)
    kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype if fp else torch.float32, device=images.device)
    # (B, kh, kw) -> (B * C, 1, kh, kw): one depthwise filter per (sample, channel).
    kernel = kernel[:, None, ...]
    kernel = kernel.expand(-1, images.shape[-3], kernel_size[1], kernel_size[0])
    kernel = kernel.reshape(-1, 1, kernel_size[1], kernel_size[0])

    images = images if fp else images.to(dtype=torch.float32)

    # padding = (left, right, top, bottom)
    padding = [
        kernel_size[0] // 2,
        kernel_size[0] // 2,
        kernel_size[1] // 2,
        kernel_size[1] // 2,
    ]
    output = torch_pad(images, padding, mode="reflect")
    # Fold the batch into the channel dimension so that a single grouped
    # convolution applies each sample's own kernel.
    output = output.view(-1, kernel.shape[0], output.shape[-2], output.shape[-1])
    output = conv2d(output, kernel, groups=output.shape[-3])
    output = output.reshape(shape)

    if not fp:
        output = output.round_().to(dtype=dtype)

    return output


@_register_kernel_internal(gaussian_blur_batch, ta_tensors.BatchVideos)
def gaussian_blur_batch_videos(
    videos: torch.Tensor,
    kernel_size: List[int],
    sigma: Optional[torch.Tensor] = None,
    value_check: bool = False,
) -> torch.Tensor:
    """Batch-video kernel of ``gaussian_blur_batch``; forwards to the batch-images kernel."""
    return gaussian_blur_batch_images(images=videos, kernel_size=kernel_size, sigma=sigma, value_check=value_check)
def to_dtype(inpt: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
    """Convert ``inpt`` to ``dtype``, optionally rescaling its values.

    In eager mode the kernel registered for ``type(inpt)`` is looked up and
    called; under TorchScript the plain image kernel is used directly.

    See :func:`~torchaug.transforms.ToDtype` for details.
    """
    if torch.jit.is_scripting():
        return to_dtype_image(inpt, dtype=dtype, scale=scale)
    else:
        _log_api_usage_once(to_dtype)
        convert_fn = _get_kernel(to_dtype, type(inpt))
        return convert_fn(inpt, dtype=dtype, scale=scale)
@_register_kernel_internal(to_dtype, torch.Tensor)
@_register_kernel_internal(to_dtype, ta_tensors.Image)
@_register_kernel_internal(to_dtype, ta_tensors.BatchImages)
def to_dtype_image(image: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
    """Image kernel of :func:`to_dtype`; delegates to torchvision's ``to_dtype_image``."""
    return TVF.to_dtype_image(image, dtype=dtype, scale=scale)


@_register_kernel_internal(to_dtype, ta_tensors.Video)
@_register_kernel_internal(to_dtype, ta_tensors.BatchVideos)
def to_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
    """Video kernel of :func:`to_dtype`; videos go through the image kernel unchanged."""
    return to_dtype_image(video, dtype=dtype, scale=scale)


@_register_kernel_internal(to_dtype, ta_tensors.BoundingBoxes, ta_tensor_wrapper=False)
@_register_kernel_internal(to_dtype, ta_tensors.Mask, ta_tensor_wrapper=False)
@_register_kernel_internal(to_dtype, ta_tensors.BatchBoundingBoxes, ta_tensor_wrapper=False)
@_register_kernel_internal(to_dtype, ta_tensors.BatchMasks, ta_tensor_wrapper=False)
def _to_dtype_tensor_dispatch(inpt: torch.Tensor, dtype: torch.dtype, scale: bool = False) -> torch.Tensor:
    """Cast boxes/masks to ``dtype``; ``scale`` is accepted for a uniform signature but ignored."""
    # No unwrap/rewrap is needed: ``.to()`` on a TVTensor preserves its subclass.
    return inpt.to(dtype=dtype)