Source code for kornia.geometry.transform.elastic_transform
from typing import Tuple, Union
import torch.nn.functional as F
from kornia.core import Tensor, concatenate, tensor
from kornia.core.check import KORNIA_CHECK_IS_TENSOR, KORNIA_CHECK_SHAPE
from kornia.filters import filter2d
from kornia.filters.kernels import get_gaussian_kernel2d
from kornia.utils import create_meshgrid
__all__ = ["elastic_transform2d"]
def elastic_transform2d(
    image: Tensor,
    noise: Tensor,
    kernel_size: Tuple[int, int] = (63, 63),
    sigma: Union[Tuple[float, float], Tensor] = (32.0, 32.0),
    alpha: Union[Tuple[float, float], Tensor] = (1.0, 1.0),
    align_corners: bool = False,
    mode: str = "bilinear",
    padding_mode: str = "zeros",
) -> Tensor:
    r"""Apply elastic transform of images as described in :cite:`Simard2003BestPF`.

    .. image:: _static/img/elastic_transform2d.png

    The noise is smoothed with a Gaussian kernel, scaled by ``alpha`` and added
    to a regular sampling grid; the image is then resampled on the displaced
    grid with ``grid_sample``.

    Args:
        image: Input image to be transformed with shape :math:`(B, C, H, W)`.
        noise: Noise image used to spatially transform the input image. Same
          resolution as the input image with shape :math:`(B, 2, H, W)`. The two
          channels are expected in x-y order.
        kernel_size: the size of the Gaussian kernel.
        sigma: The standard deviation of the Gaussian in the y and x directions,
          respectively. Larger sigma results in smaller pixel displacements.
          A tensor is expanded to two elements before use. The same kernel is
          used to smooth both displacement channels.
        alpha: The scaling factors that control the intensity of the deformation.
          ``alpha[0]`` scales the x displacement (first noise channel) and
          ``alpha[1]`` scales the y displacement (second noise channel).
        align_corners: Interpolation flag used by ``grid_sample``.
        mode: Interpolation mode used by ``grid_sample``. Either ``'bilinear'`` or ``'nearest'``.
        padding_mode: The padding used by ``grid_sample``. Either ``'zeros'``, ``'border'`` or ``'reflection'``.

    Returns:
        the elastically transformed input image with shape :math:`(B,C,H,W)`.

    Example:
        >>> image = torch.rand(1, 3, 5, 5)
        >>> noise = torch.rand(1, 2, 5, 5, requires_grad=True)
        >>> image_hat = elastic_transform2d(image, noise, (3, 3))
        >>> image_hat.mean().backward()

        >>> image = torch.rand(1, 3, 5, 5)
        >>> noise = torch.rand(1, 2, 5, 5)
        >>> sigma = torch.tensor([4., 4.], requires_grad=True)
        >>> image_hat = elastic_transform2d(image, noise, (3, 3), sigma)
        >>> image_hat.mean().backward()

        >>> image = torch.rand(1, 3, 5, 5)
        >>> noise = torch.rand(1, 2, 5, 5)
        >>> alpha = torch.tensor([16., 32.], requires_grad=True)
        >>> image_hat = elastic_transform2d(image, noise, (3, 3), alpha=alpha)
        >>> image_hat.mean().backward()
    """
    KORNIA_CHECK_IS_TENSOR(image)
    KORNIA_CHECK_IS_TENSOR(noise)
    KORNIA_CHECK_SHAPE(image, ["B", "C", "H", "W"])
    # The noise is split into exactly one x and one y displacement channel
    # below, so require two channels up front instead of failing later with
    # an unrelated shape/broadcast error.
    KORNIA_CHECK_SHAPE(noise, ["B", "2", "H", "W"])

    device, dtype = image.device, image.dtype

    if isinstance(sigma, Tensor):
        # Accept a one- or two-element tensor and add the leading batch dim.
        sigma = sigma.expand(2)[None, ...]

    # Both displacement channels are smoothed with the same Gaussian kernel,
    # so build it once.
    kernel = get_gaussian_kernel2d(kernel_size, sigma)

    if isinstance(alpha, Tensor):
        alpha_x = alpha[0]
        alpha_y = alpha[1]
    else:
        alpha_x = tensor(alpha[0], device=device, dtype=dtype)
        alpha_y = tensor(alpha[1], device=device, dtype=dtype)

    # Smooth each random displacement channel and scale it with 'alpha'.
    disp_x = filter2d(noise[:, :1], kernel=kernel, border_type="constant") * alpha_x
    disp_y = filter2d(noise[:, 1:], kernel=kernel, border_type="constant") * alpha_y

    # Stack as (B, H, W, 2) with (x, y) last, the layout of the sampling grid.
    disp = concatenate([disp_x, disp_y], 1).permute(0, 2, 3, 1)

    # Warp the image by sampling it on the displaced grid; the grid is clamped
    # to [-1, 1] before sampling.
    _, _, h, w = image.shape
    grid = create_meshgrid(h, w, device=image.device).to(image.dtype)
    warped = F.grid_sample(
        image, (grid + disp).clamp(-1, 1), align_corners=align_corners, mode=mode, padding_mode=padding_mode
    )
    return warped