Source code for kornia.geometry.transform.thin_plate_spline

from typing import Tuple

import torch
import torch.nn as nn

from kornia.utils import create_meshgrid

__all__ = [
    "get_tps_transform",
    "warp_points_tps",
    "warp_image_tps"
]

# utilities for computing thin plate spline transforms


def _pair_square_euclidean(tensor1: torch.Tensor, tensor2: torch.Tensor) -> torch.Tensor:
    r"""Compute the pairwise squared euclidean distance matrices :math:`(B, N, M)` between two tensors
    with shapes (B, N, C) and (B, M, C)."""
    # ||t1-t2||^2 = (t1-t2)^T(t1-t2) = t1^T*t1 + t2^T*t2 - 2*t1^T*t2
    t1_sq: torch.Tensor = tensor1.mul(tensor1).sum(dim=-1, keepdim=True)
    t2_sq: torch.Tensor = tensor2.mul(tensor2).sum(dim=-1, keepdim=True).transpose(1, 2)
    t1_t2: torch.Tensor = tensor1.matmul(tensor2.transpose(1, 2))
    square_dist: torch.Tensor = -2 * t1_t2 + t1_sq + t2_sq
    square_dist = square_dist.clamp(min=0)  # handle possible numerical errors
    return square_dist
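

# Illustrative sketch (not part of kornia): `_example_pair_distance_check` is a
# hypothetical helper showing that `_pair_square_euclidean` agrees with the squared
# distances returned by `torch.cdist`, which can be useful when checking the kernel
# construction below.
def _example_pair_distance_check() -> None:
    # two random batches of 2D points: (B, N, C) and (B, M, C)
    t1 = torch.rand(2, 5, 2)
    t2 = torch.rand(2, 7, 2)
    d2 = _pair_square_euclidean(t1, t2)           # (2, 5, 7)
    d2_ref = torch.cdist(t1, t2).pow(2)           # reference squared distances
    print(torch.allclose(d2, d2_ref, atol=1e-6))  # expected: True (up to numerics)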


def _kernel_distance(squared_distances: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    r"""Compute the TPS kernel distance function: :math:`r^2 log(r)`, where `r` is the euclidean distance.
    Since :math:`\log(r) = 1/2 \log(r^2)`, this function takes the squared distance matrix and calculates
    :math:`0.5 r^2 log(r^2)`."""
    # r^2 * log(r) = 1/2 * r^2 * log(r^2)
    return 0.5 * squared_distances * squared_distances.add(eps).log()
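

# Illustrative sketch (not part of kornia): `_example_kernel_distance_check` is a
# hypothetical scratch check that `_kernel_distance` evaluates the TPS radial basis
# U(r) = r^2 log(r) when fed squared distances, using strictly positive radii so
# log(r) is finite.
def _example_kernel_distance_check() -> None:
    r = torch.tensor([0.5, 1.0, 2.0])
    u_from_squared = _kernel_distance(r.pow(2))   # 0.5 * r^2 * log(r^2 + eps)
    u_direct = r.pow(2) * r.log()                 # r^2 * log(r)
    print(torch.allclose(u_from_squared, u_direct, atol=1e-6))  # expected: True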


def get_tps_transform(points_src: torch.Tensor, points_dst: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Compute the TPS transform parameters that warp source points to target points.

    The input to this function is a tensor of :math:`(x, y)` source points :math:`(B, N, 2)` and a
    corresponding tensor of target :math:`(x, y)` points :math:`(B, N, 2)`.

    Args:
        points_src (torch.Tensor): batch of source points :math:`(B, N, 2)` as :math:`(x, y)` coordinate vectors.
        points_dst (torch.Tensor): batch of target points :math:`(B, N, 2)` as :math:`(x, y)` coordinate vectors.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: :math:`(B, N, 2)` tensor of kernel weights and :math:`(B, 3, 2)`
            tensor of affine weights. The last dimension contains the x-transform and y-transform weights
            as separate columns.

    Example:
        >>> points_src = torch.rand(1, 5, 2)
        >>> points_dst = torch.rand(1, 5, 2)
        >>> kernel_weights, affine_weights = get_tps_transform(points_src, points_dst)

    .. note::
        This function is often used in conjunction with :func:`warp_points_tps`, :func:`warp_image_tps`.
    """
    if not isinstance(points_src, torch.Tensor):
        raise TypeError(f"Input points_src is not torch.Tensor. Got {type(points_src)}")

    if not isinstance(points_dst, torch.Tensor):
        raise TypeError(f"Input points_dst is not torch.Tensor. Got {type(points_dst)}")

    if not len(points_src.shape) == 3:
        raise ValueError(f"Invalid shape for points_src, expected BxNx2. Got {points_src.shape}")

    if not len(points_dst.shape) == 3:
        raise ValueError(f"Invalid shape for points_dst, expected BxNx2. Got {points_dst.shape}")

    device, dtype = points_src.device, points_src.dtype
    batch_size, num_points = points_src.shape[:2]

    # set up and solve linear system
    # [K   P] [w]   [dst]
    # [P^T 0] [a] = [ 0 ]
    pair_distance: torch.Tensor = _pair_square_euclidean(points_src, points_dst)
    k_matrix: torch.Tensor = _kernel_distance(pair_distance)

    zero_mat: torch.Tensor = torch.zeros(batch_size, 3, 3, device=device, dtype=dtype)
    one_mat: torch.Tensor = torch.ones(batch_size, num_points, 1, device=device, dtype=dtype)
    dest_with_zeros: torch.Tensor = torch.cat((points_dst, zero_mat[:, :, :2]), 1)
    p_matrix: torch.Tensor = torch.cat((one_mat, points_src), -1)
    p_matrix_t: torch.Tensor = torch.cat((p_matrix, zero_mat), 1).transpose(1, 2)
    l_matrix: torch.Tensor = torch.cat((k_matrix, p_matrix), -1)
    l_matrix = torch.cat((l_matrix, p_matrix_t), 1)

    weights, _ = torch.solve(dest_with_zeros, l_matrix)
    kernel_weights: torch.Tensor = weights[:, :-3]
    affine_weights: torch.Tensor = weights[:, -3:]

    return (kernel_weights, affine_weights)
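

# Illustrative sketch (not part of kornia): `_example_tps_side_conditions` is a
# hypothetical helper. The last three rows of the linear system above enforce
# P^T w = 0, so the recovered kernel weights should be orthogonal to the columns
# [1, x, y] built from the source points, up to the conditioning of the system.
def _example_tps_side_conditions() -> None:
    points_src = torch.rand(1, 5, 2)
    points_dst = torch.rand(1, 5, 2)
    kernel_weights, affine_weights = get_tps_transform(points_src, points_dst)
    ones = torch.ones(1, 5, 1)
    p_matrix = torch.cat((ones, points_src), -1)                 # (1, 5, 3)
    residual = p_matrix.transpose(1, 2).matmul(kernel_weights)   # (1, 3, 2)
    print(residual.abs().max())  # expected: close to zero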


def warp_points_tps(points_src: torch.Tensor,
                    kernel_centers: torch.Tensor,
                    kernel_weights: torch.Tensor,
                    affine_weights: torch.Tensor) -> torch.Tensor:
    r"""Warp a tensor of coordinate points using the thin plate spline defined by kernel centers,
    kernel weights, and affine weights.

    The source points should be a :math:`(B, N, 2)` tensor of :math:`(x, y)` coordinates. The kernel
    centers are a :math:`(B, K, 2)` tensor of :math:`(x, y)` coordinates. The kernel weights are a
    :math:`(B, K, 2)` tensor, and the affine weights are a :math:`(B, 3, 2)` tensor. For the weight
    tensors, tensor[..., 0] contains the weights for the x-transform and tensor[..., 1] the weights
    for the y-transform.

    Args:
        points_src (torch.Tensor): tensor of source points :math:`(B, N, 2)`.
        kernel_centers (torch.Tensor): tensor of kernel center points :math:`(B, K, 2)`.
        kernel_weights (torch.Tensor): tensor of kernel weights :math:`(B, K, 2)`.
        affine_weights (torch.Tensor): tensor of affine weights :math:`(B, 3, 2)`.

    Returns:
        torch.Tensor: The :math:`(B, N, 2)` tensor of warped source points, from applying the TPS transform.

    Example:
        >>> points_src = torch.rand(1, 5, 2)
        >>> points_dst = torch.rand(1, 5, 2)
        >>> kernel_weights, affine_weights = get_tps_transform(points_src, points_dst)
        >>> warped = warp_points_tps(points_src, points_dst, kernel_weights, affine_weights)
        >>> warped_correct = torch.allclose(warped, points_dst)

    .. note::
        This function is often used in conjunction with :func:`get_tps_transform`.
    """
    if not isinstance(points_src, torch.Tensor):
        raise TypeError(f"Input points_src is not torch.Tensor. Got {type(points_src)}")

    if not isinstance(kernel_centers, torch.Tensor):
        raise TypeError(f"Input kernel_centers is not torch.Tensor. Got {type(kernel_centers)}")

    if not isinstance(kernel_weights, torch.Tensor):
        raise TypeError(f"Input kernel_weights is not torch.Tensor. Got {type(kernel_weights)}")

    if not isinstance(affine_weights, torch.Tensor):
        raise TypeError(f"Input affine_weights is not torch.Tensor. Got {type(affine_weights)}")

    if not len(points_src.shape) == 3:
        raise ValueError(f"Invalid shape for points_src, expected BxNx2. Got {points_src.shape}")

    if not len(kernel_centers.shape) == 3:
        raise ValueError(f"Invalid shape for kernel_centers, expected BxNx2. Got {kernel_centers.shape}")

    if not len(kernel_weights.shape) == 3:
        raise ValueError(f"Invalid shape for kernel_weights, expected BxNx2. Got {kernel_weights.shape}")

    if not len(affine_weights.shape) == 3:
        raise ValueError(f"Invalid shape for affine_weights, expected BxNx2. Got {affine_weights.shape}")

    # f_{x|y}(v) = a_0 + [a_x a_y].v + \sum_i w_i * U(||v-u_i||)
    pair_distance: torch.Tensor = _pair_square_euclidean(points_src, kernel_centers)
    k_matrix: torch.Tensor = _kernel_distance(pair_distance)

    # broadcast the kernel distance matrix against the x and y weights to compute the x and y
    # transforms simultaneously
    warped: torch.Tensor = (
        k_matrix[..., None].mul(kernel_weights[:, None]).sum(-2)
        + points_src[..., None].mul(affine_weights[:, None, 1:]).sum(-2)
        + affine_weights[:, None, 0]
    )

    return warped
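

# Illustrative sketch (not part of kornia): `_example_warp_dense_grid` is a
# hypothetical helper that warps a dense grid of normalized pixel coordinates with
# `warp_points_tps`. This mirrors the grid construction used by `warp_image_tps`
# below and produces the sampling grid that would be fed to `grid_sample`.
def _example_warp_dense_grid() -> None:
    h, w = 8, 8
    points_src = torch.rand(1, 5, 2)
    points_dst = torch.rand(1, 5, 2)
    # reverse transform (dst -> src), as used when sampling an image
    kernel_weights, affine_weights = get_tps_transform(points_dst, points_src)
    grid = create_meshgrid(h, w).reshape(1, -1, 2)  # (1, H*W, 2), normalized to [-1, 1]
    warped_grid = warp_points_tps(grid, points_src, kernel_weights, affine_weights)
    print(warped_grid.view(1, h, w, 2).shape)  # sampling grid for grid_sample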


def warp_image_tps(image: torch.Tensor,
                   kernel_centers: torch.Tensor,
                   kernel_weights: torch.Tensor,
                   affine_weights: torch.Tensor,
                   align_corners: bool = False) -> torch.Tensor:
    r"""Warp an image tensor according to the thin plate spline transform defined by kernel centers,
    kernel weights, and affine weights.

    The transform is applied to each pixel coordinate in the output image to obtain a point in the
    input image for interpolation of the output pixel. So the TPS parameters should correspond to a
    warp from output space to input space.

    The input `image` is a :math:`(B, C, H, W)` tensor. The kernel centers, kernel weights, and affine
    weights are the same as in `warp_points_tps`.

    Args:
        image (torch.Tensor): input image tensor :math:`(B, C, H, W)`.
        kernel_centers (torch.Tensor): kernel center points :math:`(B, K, 2)`.
        kernel_weights (torch.Tensor): tensor of kernel weights :math:`(B, K, 2)`.
        affine_weights (torch.Tensor): tensor of affine weights :math:`(B, 3, 2)`.
        align_corners (bool): interpolation flag used by `grid_sample`. Default: False.

    Returns:
        torch.Tensor: warped image tensor :math:`(B, C, H, W)`.

    Example:
        >>> points_src = torch.rand(1, 5, 2)
        >>> points_dst = torch.rand(1, 5, 2)
        >>> image = torch.rand(1, 3, 32, 32)
        >>> # note that we are getting the reverse transform: dst -> src
        >>> kernel_weights, affine_weights = get_tps_transform(points_dst, points_src)
        >>> warped_image = warp_image_tps(image, points_src, kernel_weights, affine_weights)

    .. note::
        This function is often used in conjunction with :func:`get_tps_transform`.
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input image is not torch.Tensor. Got {type(image)}")

    if not isinstance(kernel_centers, torch.Tensor):
        raise TypeError(f"Input kernel_centers is not torch.Tensor. Got {type(kernel_centers)}")

    if not isinstance(kernel_weights, torch.Tensor):
        raise TypeError(f"Input kernel_weights is not torch.Tensor. Got {type(kernel_weights)}")

    if not isinstance(affine_weights, torch.Tensor):
        raise TypeError(f"Input affine_weights is not torch.Tensor. Got {type(affine_weights)}")

    if not len(image.shape) == 4:
        raise ValueError(f"Invalid shape for image, expected BxCxHxW. Got {image.shape}")

    if not len(kernel_centers.shape) == 3:
        raise ValueError(f"Invalid shape for kernel_centers, expected BxNx2. Got {kernel_centers.shape}")

    if not len(kernel_weights.shape) == 3:
        raise ValueError(f"Invalid shape for kernel_weights, expected BxNx2. Got {kernel_weights.shape}")

    if not len(affine_weights.shape) == 3:
        raise ValueError(f"Invalid shape for affine_weights, expected BxNx2. Got {affine_weights.shape}")

    device, dtype = image.device, image.dtype
    batch_size, _, h, w = image.shape

    coords: torch.Tensor = create_meshgrid(h, w, device=device).to(dtype=dtype)
    coords = coords.reshape(-1, 2).expand(batch_size, -1, -1)
    warped: torch.Tensor = warp_points_tps(coords, kernel_centers, kernel_weights, affine_weights)
    warped = warped.view(-1, h, w, 2)
    warped_image: torch.Tensor = nn.functional.grid_sample(image, warped, align_corners=align_corners)

    return warped_image
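

# Illustrative sketch (not part of kornia): `_example_random_elastic_warp` is a
# hypothetical helper showing one way to combine these functions, jittering a coarse
# grid of control points by small random offsets to produce an elastic-style
# deformation. The control-grid size and noise scale are arbitrary choices for the
# example; it follows the reverse-transform (dst -> src) convention from the
# `warp_image_tps` docstring above.
def _example_random_elastic_warp() -> None:
    image = torch.rand(1, 3, 32, 32)
    # 4x4 control points in normalized [-1, 1] coordinates
    points_src = create_meshgrid(4, 4).reshape(1, -1, 2)
    points_dst = points_src + 0.05 * torch.randn_like(points_src)
    # reverse transform (dst -> src), as required by warp_image_tps
    kernel_weights, affine_weights = get_tps_transform(points_dst, points_src)
    warped_image = warp_image_tps(image, points_src, kernel_weights, affine_weights)
    print(warped_image.shape)  # (1, 3, 32, 32)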