# Source code for kornia.geometry.transform.thin_plate_spline

from typing import Tuple

import torch
import torch.nn as nn

from kornia.utils import create_meshgrid
from kornia.utils.helpers import _torch_solve_cast

# Public API of this module.
__all__ = [
    "get_tps_transform",
    "warp_points_tps",
    "warp_image_tps",
]

# utilities for computing thin plate spline transforms

def _pair_square_euclidean(tensor1: torch.Tensor, tensor2: torch.Tensor) -> torch.Tensor:
r"""Compute the pairwise squared euclidean distance matrices :math:(B, N, M) between two tensors
with shapes (B, N, C) and (B, M, C)."""
# ||t1-t2||^2 = (t1-t2)^T(t1-t2) = t1^T*t1 + t2^T*t2 - 2*t1^T*t2
t1_sq: torch.Tensor = tensor1.mul(tensor1).sum(dim=-1, keepdim=True)
t2_sq: torch.Tensor = tensor2.mul(tensor2).sum(dim=-1, keepdim=True).transpose(1, 2)
t1_t2: torch.Tensor = tensor1.matmul(tensor2.transpose(1, 2))
square_dist: torch.Tensor = -2 * t1_t2 + t1_sq + t2_sq
square_dist = square_dist.clamp(min=0)  # handle possible numerical errors
return square_dist

def _kernel_distance(squared_distances: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
r"""Compute the TPS kernel distance function: :math:r^2 log(r), where r is the euclidean distance.
Since :math:\log(r) = 1/2 \log(r^2), this function takes the squared distance matrix and calculates
:math:0.5 r^2 log(r^2)."""
# r^2 * log(r) = 1/2 * r^2 * log(r^2)
return 0.5 * squared_distances * squared_distances.add(eps).log()

def get_tps_transform(points_src: torch.Tensor, points_dst: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Compute the TPS transform parameters that warp source points to target points.

    The input to this function is a tensor of :math:`(x, y)` source points :math:`(B, N, 2)` and a corresponding
    tensor of target :math:`(x, y)` points :math:`(B, N, 2)`.

    The kernel matrix of the linear system is built from the pairwise distances between ``points_src``
    and ``points_dst``, so the returned kernel weights pair with kernel centers located at ``points_dst``.
    Pass ``points_dst`` as ``kernel_centers`` to :func:`warp_points_tps` / :func:`warp_image_tps` to evaluate
    the fitted transform.

    Args:
        points_src: batch of source points :math:`(B, N, 2)` as :math:`(x, y)` coordinate vectors.
        points_dst: batch of target points :math:`(B, N, 2)` as :math:`(x, y)` coordinate vectors.

    Returns:
        :math:`(B, N, 2)` tensor of kernel weights and :math:`(B, 3, 2)`
        tensor of affine weights. The last dimension contains the x-transform and y-transform weights
        as separate columns. Row 0 of the affine weights is the constant term; rows 1-2 multiply
        the :math:`(x, y)` coordinates.

    Raises:
        TypeError: if either input is not a ``torch.Tensor``.
        ValueError: if either input is not a :math:`(B, N, 2)` tensor, or if the two shapes differ.

    Example:
        >>> points_src = torch.rand(1, 5, 2)
        >>> points_dst = torch.rand(1, 5, 2)
        >>> kernel_weights, affine_weights = get_tps_transform(points_src, points_dst)

    .. note::
        This function is often used in conjunction with :func:`warp_points_tps`, :func:`warp_image_tps`.
    """
    if not isinstance(points_src, torch.Tensor):
        raise TypeError(f"Input points_src is not torch.Tensor. Got {type(points_src)}")

    if not isinstance(points_dst, torch.Tensor):
        raise TypeError(f"Input points_dst is not torch.Tensor. Got {type(points_dst)}")

    # The linear system below only assembles for 2-D points (P = [1 | x | y] must have 3 columns),
    # so reject other trailing sizes up front instead of failing inside torch.cat.
    if points_src.dim() != 3 or points_src.shape[-1] != 2:
        raise ValueError(f"Invalid shape for points_src, expected BxNx2. Got {points_src.shape}")

    if points_dst.dim() != 3 or points_dst.shape[-1] != 2:
        raise ValueError(f"Invalid shape for points_dst, expected BxNx2. Got {points_dst.shape}")

    # Each source point needs exactly one target point; mismatched shapes cannot form a square system.
    if points_src.shape != points_dst.shape:
        raise ValueError(
            f"points_src and points_dst must have the same shape. Got {points_src.shape} and {points_dst.shape}"
        )

    device, dtype = points_src.device, points_src.dtype
    batch_size, num_points = points_src.shape[:2]

    # set up and solve the (N + 3) x (N + 3) linear system
    # [K   P] [w] = [dst]
    # [P^T 0] [a]   [ 0 ]
    # NOTE(review): the classical thin plate spline builds K from the source points against themselves,
    # U(||src_i - src_j||), i.e. kernel centers at the source points. Here K pairs src_i with dst_j, which
    # puts the kernel centers at points_dst; the solution still maps src_i -> dst_i when evaluated with
    # kernel_centers=points_dst. Kept as-is because changing it would change which kernel centers the
    # returned weights pair with.
    pair_distance: torch.Tensor = _pair_square_euclidean(points_src, points_dst)  # (B, N, N)
    k_matrix: torch.Tensor = _kernel_distance(pair_distance)

    zero_mat: torch.Tensor = torch.zeros(batch_size, 3, 3, device=device, dtype=dtype)
    one_mat: torch.Tensor = torch.ones(batch_size, num_points, 1, device=device, dtype=dtype)
    # Right-hand side: the target points followed by three zero rows for the side conditions P^T w = 0.
    dest_with_zeros: torch.Tensor = torch.cat((points_dst, zero_mat[:, :, :2]), 1)  # (B, N + 3, 2)
    p_matrix: torch.Tensor = torch.cat((one_mat, points_src), -1)  # (B, N, 3)
    p_matrix_t: torch.Tensor = torch.cat((p_matrix, zero_mat), 1).transpose(1, 2)  # (B, 3, N + 3)
    l_matrix: torch.Tensor = torch.cat((k_matrix, p_matrix), -1)  # (B, N, N + 3)
    l_matrix = torch.cat((l_matrix, p_matrix_t), 1)  # (B, N + 3, N + 3)

    # The helper returns a pair; only the first element (the solution, assumed to be X in
    # l_matrix @ X = dest_with_zeros -- see kornia.utils.helpers) is used.
    weights, _ = _torch_solve_cast(dest_with_zeros, l_matrix)
    kernel_weights: torch.Tensor = weights[:, :-3]
    affine_weights: torch.Tensor = weights[:, -3:]

    return (kernel_weights, affine_weights)

def warp_points_tps(
    points_src: torch.Tensor, kernel_centers: torch.Tensor, kernel_weights: torch.Tensor, affine_weights: torch.Tensor
) -> torch.Tensor:
    r"""Warp a tensor of coordinate points using the thin plate spline defined by kernel points, kernel weights,
    and affine weights.

    The source points should be a :math:`(B, N, 2)` tensor of :math:`(x, y)` coordinates. The kernel centers are
    a :math:`(B, K, 2)` tensor of :math:`(x, y)` coordinates. The kernel weights are a :math:`(B, K, 2)` tensor,
    and the affine weights are a :math:`(B, 3, 2)` tensor. For the weight tensors, ``tensor[..., 0]`` contains the
    weights for the x-transform and ``tensor[..., 1]`` the weights for the y-transform.

    Each output coordinate is computed as

    .. math::
        f(v) = a_0 + a_x v_x + a_y v_y + \sum_i w_i \, U(\lVert v - u_i \rVert)

    where :math:`u_i` are the kernel centers, :math:`w_i` the kernel weights, :math:`a_0` is row 0 of
    ``affine_weights`` and :math:`(a_x, a_y)` are rows 1-2.

    Args:
        points_src: tensor of source points :math:`(B, N, 2)`.
        kernel_centers: tensor of kernel center points :math:`(B, K, 2)`.
        kernel_weights: tensor of kernel weights :math:`(B, K, 2)`.
        affine_weights: tensor of affine weights :math:`(B, 3, 2)`.

    Returns:
        The :math:`(B, N, 2)` tensor of warped source points, from applying the TPS transform.

    Raises:
        TypeError: if any input is not a ``torch.Tensor``.
        ValueError: if any input is not a rank-3 tensor.

    Example:
        >>> points_src = torch.rand(1, 5, 2)
        >>> points_dst = torch.rand(1, 5, 2)
        >>> kernel_weights, affine_weights = get_tps_transform(points_src, points_dst)
        >>> warped = warp_points_tps(points_src, points_dst, kernel_weights, affine_weights)
        >>> warped_correct = torch.allclose(warped, points_dst)

    .. note::
        This function is often used in conjunction with :func:`get_tps_transform`.
    """
    if not isinstance(points_src, torch.Tensor):
        raise TypeError(f"Input points_src is not torch.Tensor. Got {type(points_src)}")

    if not isinstance(kernel_centers, torch.Tensor):
        raise TypeError(f"Input kernel_centers is not torch.Tensor. Got {type(kernel_centers)}")

    if not isinstance(kernel_weights, torch.Tensor):
        raise TypeError(f"Input kernel_weights is not torch.Tensor. Got {type(kernel_weights)}")

    if not isinstance(affine_weights, torch.Tensor):
        raise TypeError(f"Input affine_weights is not torch.Tensor. Got {type(affine_weights)}")

    if not len(points_src.shape) == 3:
        raise ValueError(f"Invalid shape for points_src, expected BxNx2. Got {points_src.shape}")

    if not len(kernel_centers.shape) == 3:
        raise ValueError(f"Invalid shape for kernel_centers, expected BxNx2. Got {kernel_centers.shape}")

    if not len(kernel_weights.shape) == 3:
        raise ValueError(f"Invalid shape for kernel_weights, expected BxNx2. Got {kernel_weights.shape}")

    if not len(affine_weights.shape) == 3:
        raise ValueError(f"Invalid shape for affine_weights, expected Bx3x2. Got {affine_weights.shape}")

    # f_{x|y}(v) = a_0 + [a_x a_y].v + \sum_i w_i * U(||v-u_i||)
    pair_distance: torch.Tensor = _pair_square_euclidean(points_src, kernel_centers)  # (B, N, K)
    k_matrix: torch.Tensor = _kernel_distance(pair_distance)  # (B, N, K)

    # Batched matrix products evaluate the x and y transforms simultaneously. This is the same sum as a
    # broadcast-multiply-then-reduce, but avoids materialising a (B, N, K, 2) intermediate, which is large
    # when N is every pixel of an image.
    k_mul_kernel: torch.Tensor = k_matrix.matmul(kernel_weights)  # (B, N, 2)
    points_mul_affine: torch.Tensor = points_src.matmul(affine_weights[:, 1:])  # (B, N, 2)
    # Row 0 of the affine weights is the constant offset, broadcast over all N points.
    warped: torch.Tensor = k_mul_kernel + points_mul_affine + affine_weights[:, None, 0]

    return warped

def warp_image_tps(
    image: torch.Tensor,
    kernel_centers: torch.Tensor,
    kernel_weights: torch.Tensor,
    affine_weights: torch.Tensor,
    align_corners: bool = False,
) -> torch.Tensor:
    r"""Warp an image tensor according to the thin plate spline transform defined by kernel centers,
    kernel weights, and affine weights.

    .. image:: _static/img/warp_image_tps.png

    The transform is applied to each pixel coordinate in the output image to obtain a point in the input
    image for interpolation of the output pixel. So the TPS parameters should correspond to a warp from
    output space to input space.

    The input image is a :math:`(B, C, H, W)` tensor. The kernel centers, kernel weight and affine weights
    are the same as in :func:`warp_points_tps`.

    The pixel grid comes from ``create_meshgrid(h, w)`` and the transformed coordinates are passed
    directly as the sampling grid to :func:`torch.nn.functional.grid_sample`, which interprets them as
    normalized :math:`[-1, 1]` coordinates. The TPS parameters should therefore be expressed in the same
    frame as that grid (presumably normalized by ``create_meshgrid``'s defaults -- confirm).

    Args:
        image: input image tensor :math:`(B, C, H, W)`.
        kernel_centers: kernel center points :math:`(B, K, 2)`.
        kernel_weights: tensor of kernel weights :math:`(B, K, 2)`.
        affine_weights: tensor of affine weights :math:`(B, 3, 2)`.
        align_corners: interpolation flag used by grid_sample.

    Returns:
        warped image tensor :math:`(B, C, H, W)`.

    Raises:
        TypeError: if any tensor argument is not a ``torch.Tensor``.
        ValueError: if ``image`` is not rank 4 or any parameter tensor is not rank 3.

    Example:
        >>> points_src = torch.rand(1, 5, 2)
        >>> points_dst = torch.rand(1, 5, 2)
        >>> image = torch.rand(1, 3, 32, 32)
        >>> # note that we are getting the reverse transform: dst -> src
        >>> kernel_weights, affine_weights = get_tps_transform(points_dst, points_src)
        >>> warped_image = warp_image_tps(image, points_src, kernel_weights, affine_weights)

    .. note::
        This function is often used in conjunction with :func:`get_tps_transform`.
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input image is not torch.Tensor. Got {type(image)}")

    if not isinstance(kernel_centers, torch.Tensor):
        raise TypeError(f"Input kernel_centers is not torch.Tensor. Got {type(kernel_centers)}")

    if not isinstance(kernel_weights, torch.Tensor):
        raise TypeError(f"Input kernel_weights is not torch.Tensor. Got {type(kernel_weights)}")

    if not isinstance(affine_weights, torch.Tensor):
        raise TypeError(f"Input affine_weights is not torch.Tensor. Got {type(affine_weights)}")

    if not len(image.shape) == 4:
        raise ValueError(f"Invalid shape for image, expected BxCxHxW. Got {image.shape}")

    if not len(kernel_centers.shape) == 3:
        raise ValueError(f"Invalid shape for kernel_centers, expected BxNx2. Got {kernel_centers.shape}")

    if not len(kernel_weights.shape) == 3:
        raise ValueError(f"Invalid shape for kernel_weights, expected BxNx2. Got {kernel_weights.shape}")

    if not len(affine_weights.shape) == 3:
        raise ValueError(f"Invalid shape for affine_weights, expected Bx3x2. Got {affine_weights.shape}")

    device, dtype = image.device, image.dtype
    batch_size, _, h, w = image.shape
    # One coordinate pair per output pixel, flattened to (B, H*W, 2) so it can go through warp_points_tps.
    coords: torch.Tensor = create_meshgrid(h, w, device=device).to(dtype=dtype)
    coords = coords.reshape(-1, 2).expand(batch_size, -1, -1)
    warped: torch.Tensor = warp_points_tps(coords, kernel_centers, kernel_weights, affine_weights)
    # Restore the (B, H, W, 2) layout that grid_sample expects for its sampling grid.
    warped = warped.view(-1, h, w, 2)
    warped_image: torch.Tensor = nn.functional.grid_sample(image, warped, align_corners=align_corners)

    return warped_image