# Source code for kornia.geometry.subpix.spatial_soft_argmax

from typing import Tuple, Union

import torch
import torch.nn.functional as F

from kornia.core import Module, Tensor, concatenate, stack, tensor, zeros
from kornia.filters.sobel import spatial_gradient3d
from kornia.geometry.conversions import normalize_pixel_coordinates, normalize_pixel_coordinates3d
from kornia.utils import create_meshgrid, create_meshgrid3d
from kornia.utils._compat import torch_version_geq
from kornia.utils.helpers import safe_solve_with_mask

from .dsnt import spatial_expectation2d, spatial_softmax2d
from .nms import nms3d


def _get_window_grid_kernel2d(h: int, w: int, device: torch.device = torch.device('cpu')) -> Tensor:
    r"""Helper function, which generates a kernel to with window coordinates, residual to window center.

    Args:
         h: kernel height.
         : kernel width.
         device: device, on which generate.

    Returns:
        conv_kernel [2x1xhxw]
    """
    window_grid2d = create_meshgrid(h, w, False, device=device)
    window_grid2d = normalize_pixel_coordinates(window_grid2d, h, w)
    conv_kernel = window_grid2d.permute(3, 0, 1, 2)
    return conv_kernel


def _get_center_kernel2d(h: int, w: int, device: torch.device = torch.device('cpu')) -> Tensor:
    r"""Helper function, which generates a kernel to return center coordinates, when applied with F.conv2d to 2d
    coordinates grid.

    Args:
        h: kernel height.
        w: kernel width.
        device: device, on which generate.

    Returns:
        conv_kernel [2x2xhxw].
    """
    center_kernel = zeros(2, 2, h, w, device=device)

    #  If the size is odd, we have one pixel for center, if even - 2
    if h % 2 != 0:
        h_i1 = h // 2
        h_i2 = (h // 2) + 1
    else:
        h_i1 = (h // 2) - 1
        h_i2 = (h // 2) + 1
    if w % 2 != 0:
        w_i1 = w // 2
        w_i2 = (w // 2) + 1
    else:
        w_i1 = (w // 2) - 1
        w_i2 = (w // 2) + 1
    center_kernel[(0, 1), (0, 1), h_i1:h_i2, w_i1:w_i2] = 1.0 / float((h_i2 - h_i1) * (w_i2 - w_i1))
    return center_kernel


def _get_center_kernel3d(d: int, h: int, w: int, device: torch.device = torch.device('cpu')) -> Tensor:
    r"""Helper function, which generates a kernel to return center coordinates, when applied with F.conv2d to 3d
    coordinates grid.

    Args:
        d: kernel depth.
        h: kernel height.
        w: kernel width.
        device: device, on which generate.

    Returns:
        conv_kernel [3x3xdxhxw].
    """
    center_kernel = zeros(3, 3, d, h, w, device=device)
    #  If the size is odd, we have one pixel for center, if even - 2
    if h % 2 != 0:
        h_i1 = h // 2
        h_i2 = (h // 2) + 1
    else:
        h_i1 = (h // 2) - 1
        h_i2 = (h // 2) + 1
    if w % 2 != 0:
        w_i1 = w // 2
        w_i2 = (w // 2) + 1
    else:
        w_i1 = (w // 2) - 1
        w_i2 = (w // 2) + 1
    if d % 2 != 0:
        d_i1 = d // 2
        d_i2 = (d // 2) + 1
    else:
        d_i1 = (d // 2) - 1
        d_i2 = (d // 2) + 1
    center_num = float((h_i2 - h_i1) * (w_i2 - w_i1) * (d_i2 - d_i1))
    center_kernel[(0, 1, 2), (0, 1, 2), d_i1:d_i2, h_i1:h_i2, w_i1:w_i2] = 1.0 / center_num
    return center_kernel


def _get_window_grid_kernel3d(d: int, h: int, w: int, device: torch.device = torch.device('cpu')) -> Tensor:
    r"""Helper function, which generates a kernel to return coordinates, residual to window center.

    Args:
        d: kernel depth.
        h: kernel height.
        w: kernel width.
        device: device, on which generate.

    Returns:
        conv_kernel [3x1xdxhxw]
    """
    grid2d = create_meshgrid(h, w, True, device=device)
    if d > 1:
        z = torch.linspace(-1, 1, d, device=device).view(d, 1, 1, 1)
    else:  # only onr channel with index == 0
        z = zeros(1, 1, 1, 1, device=device)
    grid3d = concatenate([z.repeat(1, h, w, 1).contiguous(), grid2d.repeat(d, 1, 1, 1)], 3)
    conv_kernel = grid3d.permute(3, 0, 1, 2).unsqueeze(1)
    return conv_kernel


class ConvSoftArgmax2d(Module):
    r"""Module wrapper computing a soft argmax 2d per window.

    The constructor only stores its arguments; ``forward`` forwards them unchanged to
    :func:`~kornia.geometry.subpix.conv_soft_argmax2d`. See that function for details.
    """

    def __init__(
        self,
        kernel_size: Tuple[int, int] = (3, 3),
        stride: Tuple[int, int] = (1, 1),
        padding: Tuple[int, int] = (1, 1),
        temperature: Union[Tensor, float] = tensor(1.0),
        normalized_coordinates: bool = True,
        eps: float = 1e-8,
        output_value: bool = False,
    ) -> None:
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.temperature = temperature
        self.normalized_coordinates = normalized_coordinates
        self.eps = eps
        self.output_value = output_value

    def __repr__(self) -> str:
        # ``!s`` keeps ``str()`` semantics; a bare f-string would format 0-d tensors as plain numbers
        settings = (
            ('kernel_size', self.kernel_size),
            ('stride', self.stride),
            ('padding', self.padding),
            ('temperature', self.temperature),
            ('normalized_coordinates', self.normalized_coordinates),
            ('eps', self.eps),
            ('output_value', self.output_value),
        )
        body = ', '.join(f'{key}={value!s}' for key, value in settings)
        return f'{self.__class__.__name__}({body})'

    def forward(self, x: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        return conv_soft_argmax2d(
            x,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            temperature=self.temperature,
            normalized_coordinates=self.normalized_coordinates,
            eps=self.eps,
            output_value=self.output_value,
        )


class ConvSoftArgmax3d(Module):
    r"""Module wrapper computing a soft argmax 3d per window.

    The constructor only stores its arguments; ``forward`` forwards them unchanged to
    :func:`~kornia.geometry.subpix.conv_soft_argmax3d`. See that function for details.
    """

    def __init__(
        self,
        kernel_size: Tuple[int, int, int] = (3, 3, 3),
        stride: Tuple[int, int, int] = (1, 1, 1),
        padding: Tuple[int, int, int] = (1, 1, 1),
        temperature: Union[Tensor, float] = tensor(1.0),
        normalized_coordinates: bool = False,
        eps: float = 1e-8,
        output_value: bool = True,
        strict_maxima_bonus: float = 0.0,
    ) -> None:
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.temperature = temperature
        self.normalized_coordinates = normalized_coordinates
        self.eps = eps
        self.output_value = output_value
        self.strict_maxima_bonus = strict_maxima_bonus

    def __repr__(self) -> str:
        # ``!s`` keeps ``str()`` semantics; a bare f-string would format 0-d tensors as plain numbers
        settings = (
            ('kernel_size', self.kernel_size),
            ('stride', self.stride),
            ('padding', self.padding),
            ('temperature', self.temperature),
            ('normalized_coordinates', self.normalized_coordinates),
            ('eps', self.eps),
            ('strict_maxima_bonus', self.strict_maxima_bonus),
            ('output_value', self.output_value),
        )
        body = ', '.join(f'{key}={value!s}' for key, value in settings)
        return f'{self.__class__.__name__}({body})'

    def forward(self, x: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        return conv_soft_argmax3d(
            x,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            temperature=self.temperature,
            normalized_coordinates=self.normalized_coordinates,
            eps=self.eps,
            output_value=self.output_value,
            strict_maxima_bonus=self.strict_maxima_bonus,
        )


def conv_soft_argmax2d(
    input: Tensor,
    kernel_size: Tuple[int, int] = (3, 3),
    stride: Tuple[int, int] = (1, 1),
    padding: Tuple[int, int] = (1, 1),
    temperature: Union[Tensor, float] = tensor(1.0),
    normalized_coordinates: bool = True,
    eps: float = 1e-8,
    output_value: bool = False,
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
    r"""Compute the convolutional spatial Soft-Argmax 2D over the windows of a given heatmap.

    .. math::
        ij(X) = \frac{\sum_{(i,j) \in X} (i,j) \exp(x_{ij} / T)}{\sum_{(i,j) \in X} \exp(x_{ij} / T)}

    .. math::
        val(X) = \frac{\sum_{(i,j) \in X} x_{ij} \exp(x_{ij} / T)}{\sum_{(i,j) \in X} \exp(x_{ij} / T)}

    where :math:`T` is temperature.

    Args:
        input: the given heatmap with shape :math:`(N, C, H_{in}, W_{in})`.
        kernel_size: the size (height, width) of the window.
        stride: the stride of the window.
        padding: input zero padding.
        temperature: factor to apply to input. Must be positive.
        normalized_coordinates: whether to return the coordinates normalized in the range
          of :math:`[-1, 1]`. Otherwise, it will return the coordinates in the range of the input shape.
        eps: small value added to the softmax denominator to avoid zero division.
        output_value: if True, the soft-max pooled values are returned as well; if False, only the coordinates.

    Returns:
        The argmax coordinates with shape :math:`(N, C, 2, H_{out}, W_{out})`, with x in channel 0 and y in
        channel 1. If ``output_value`` is True, a tuple of those coordinates and the soft-max pooled heatmap
        values with shape :math:`(N, C, H_{out}, W_{out})` is returned instead, where

        .. math::
            H_{out} = \left\lfloor\frac{H_{in}  + 2 \times \text{padding}[0] -
            (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor

        .. math::
            W_{out} = \left\lfloor\frac{W_{in}  + 2 \times \text{padding}[1] -
            (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor

    Raises:
        TypeError: if ``input`` is not a tensor.
        ValueError: if ``input`` is not 4-dimensional or ``temperature`` is not positive.

    Examples:
        >>> input = torch.randn(20, 16, 50, 32)
        >>> nms_coords, nms_val = conv_soft_argmax2d(input, (3,3), (2,2), (1,1), output_value=True)
    """
    if not torch.is_tensor(input):
        raise TypeError(f"Input type is not a Tensor. Got {type(input)}")

    if not len(input.shape) == 4:
        raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}")

    if temperature <= 0:
        raise ValueError(f"Temperature should be positive float or tensor. Got: {temperature}")

    b, c, h, w = input.shape
    kh, kw = kernel_size
    device: torch.device = input.device
    dtype: torch.dtype = input.dtype
    # Treat every channel as an independent single-channel image. ``reshape`` (unlike ``view``) also accepts
    # non-contiguous inputs such as channel slices ``x[:, :2]``.
    input = input.reshape(b * c, 1, h, w)

    center_kernel: Tensor = _get_center_kernel2d(kh, kw, device).to(dtype)
    window_kernel: Tensor = _get_window_grid_kernel2d(kh, kw, device).to(dtype)

    # applies exponential normalization trick
    # https://timvieira.github.io/blog/post/2014/02/11/exp-normalize-trick/
    # https://github.com/pytorch/pytorch/blob/bcb0bb7e0e03b386ad837015faba6b4b16e3bfb9/aten/src/ATen/native/SoftMax.cpp#L44
    x_max = F.adaptive_max_pool2d(input, (1, 1))

    # max is detached to prevent undesired backprop loops in the graph
    x_exp = ((input - x_max.detach()) / temperature).exp()

    # Sum pooling expressed as window area times average pooling (zero padding is counted, so padded cells
    # contribute nothing to the sums).
    pool_coef: float = float(kh * kw)

    # softmax denominator
    den = pool_coef * F.avg_pool2d(x_exp, kernel_size, stride=stride, padding=padding) + eps

    # Pooled window center coordinates, in pixels
    grid_global: Tensor = create_meshgrid(h, w, False, device).to(dtype).permute(0, 3, 1, 2)
    grid_global_pooled = F.conv2d(grid_global, center_kernel, stride=stride, padding=padding)

    # Softmax-weighted offsets relative to the window center, then shifted to absolute coordinates
    coords_max: Tensor = F.conv2d(x_exp, window_kernel, stride=stride, padding=padding)
    coords_max = coords_max / den.expand_as(coords_max)
    coords_max = coords_max + grid_global_pooled.expand_as(coords_max)
    # [:,:, 0, ...] is x
    # [:,:, 1, ...] is y

    if normalized_coordinates:
        coords_max = normalize_pixel_coordinates(coords_max.permute(0, 2, 3, 1), h, w)
        coords_max = coords_max.permute(0, 3, 1, 2)

    # Back B*C -> (b, c)
    coords_max = coords_max.reshape(b, c, 2, coords_max.size(2), coords_max.size(3))

    if not output_value:
        return coords_max

    # The pooled values are only needed when requested.
    x_softmaxpool = pool_coef * F.avg_pool2d(x_exp * input, kernel_size, stride=stride, padding=padding) / den
    x_softmaxpool = x_softmaxpool.view(b, c, x_softmaxpool.size(2), x_softmaxpool.size(3))
    return coords_max, x_softmaxpool


def conv_soft_argmax3d(
    input: Tensor,
    kernel_size: Tuple[int, int, int] = (3, 3, 3),
    stride: Tuple[int, int, int] = (1, 1, 1),
    padding: Tuple[int, int, int] = (1, 1, 1),
    temperature: Union[Tensor, float] = tensor(1.0),
    normalized_coordinates: bool = False,
    eps: float = 1e-8,
    output_value: bool = True,
    strict_maxima_bonus: float = 0.0,
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
    r"""Compute the convolutional spatial Soft-Argmax 3D over the windows of a given heatmap.

    .. math::
        ijk(X) = \frac{\sum_{(i,j,k) \in X} (i,j,k) \exp(x_{ijk} / T)}{\sum_{(i,j,k) \in X} \exp(x_{ijk} / T)}

    .. math::
        val(X) = \frac{\sum_{(i,j,k) \in X} x_{ijk} \exp(x_{ijk} / T)}{\sum_{(i,j,k) \in X} \exp(x_{ijk} / T)}

    where ``T`` is temperature.

    Args:
        input: the given heatmap with shape :math:`(N, C, D_{in}, H_{in}, W_{in})`.
        kernel_size: size (depth, height, width) of the window.
        stride: stride of the window.
        padding: input zero padding.
        temperature: factor to apply to input. Must be positive.
        normalized_coordinates: whether to return the coordinates normalized in the range
          of :math:`[-1, 1]`. Otherwise, it will return the coordinates in the range of the input shape.
        eps: small value added to the softmax denominator to avoid zero division.
        output_value: if True, the soft-max pooled values are returned as well; if False, only the coordinates.
        strict_maxima_bonus: pixels which are strict maxima will score (1 + strict_maxima_bonus) * value.
          This mimics the behavior of strict NMS in classic local features.

    Returns:
        The argmax coordinates with shape :math:`(N, C, 3, D_{out}, H_{out}, W_{out})`, with depth (scale)
        in channel 0, x in channel 1 and y in channel 2. If ``output_value`` is True, a tuple of those
        coordinates and the soft-max pooled heatmap values with shape
        :math:`(N, C, D_{out}, H_{out}, W_{out})` is returned instead, where

        .. math::
            D_{out} = \left\lfloor\frac{D_{in}  + 2 \times \text{padding}[0] -
            (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor

        .. math::
            H_{out} = \left\lfloor\frac{H_{in}  + 2 \times \text{padding}[1] -
            (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor

        .. math::
            W_{out} = \left\lfloor\frac{W_{in}  + 2 \times \text{padding}[2] -
            (\text{kernel\_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor

    Raises:
        TypeError: if ``input`` is not a tensor.
        ValueError: if ``input`` is not 5-dimensional or ``temperature`` is not positive.

    Examples:
        >>> input = torch.randn(20, 16, 3, 50, 32)
        >>> nms_coords, nms_val = conv_soft_argmax3d(input, (3, 3, 3), (1, 2, 2), (0, 1, 1))
    """
    if not torch.is_tensor(input):
        raise TypeError(f"Input type is not a Tensor. Got {type(input)}")

    if not len(input.shape) == 5:
        raise ValueError(f"Invalid input shape, we expect BxCxDxHxW. Got: {input.shape}")

    if temperature <= 0:
        raise ValueError(f"Temperature should be positive float or tensor. Got: {temperature}")

    b, c, d, h, w = input.shape
    kd, kh, kw = kernel_size
    device: torch.device = input.device
    dtype: torch.dtype = input.dtype
    # Treat every channel as an independent single-channel volume. ``reshape`` (unlike ``view``) also accepts
    # non-contiguous inputs such as channel slices.
    input = input.reshape(b * c, 1, d, h, w)

    center_kernel: Tensor = _get_center_kernel3d(kd, kh, kw, device).to(dtype)
    window_kernel: Tensor = _get_window_grid_kernel3d(kd, kh, kw, device).to(dtype)

    # applies exponential normalization trick
    # https://timvieira.github.io/blog/post/2014/02/11/exp-normalize-trick/
    # https://github.com/pytorch/pytorch/blob/bcb0bb7e0e03b386ad837015faba6b4b16e3bfb9/aten/src/ATen/native/SoftMax.cpp#L44
    x_max = F.adaptive_max_pool3d(input, (1, 1, 1))

    # max is detached to prevent undesired backprop loops in the graph
    x_exp = ((input - x_max.detach()) / temperature).exp()

    # Sum pooling expressed as window volume times average pooling
    pool_coef: float = float(kd * kh * kw)

    # softmax denominator
    den = pool_coef * F.avg_pool3d(x_exp, kernel_size, stride=stride, padding=padding) + eps

    # Pooled window center coordinates, in pixels
    grid_global: Tensor = create_meshgrid3d(d, h, w, False, device=device).to(dtype).permute(0, 4, 1, 2, 3)
    grid_global_pooled = F.conv3d(grid_global, center_kernel, stride=stride, padding=padding)

    # Softmax-weighted offsets relative to the window center, then shifted to absolute coordinates
    coords_max: Tensor = F.conv3d(x_exp, window_kernel, stride=stride, padding=padding)
    coords_max = coords_max / den.expand_as(coords_max)
    coords_max = coords_max + grid_global_pooled.expand_as(coords_max)
    # [:,:, 0, ...] is depth (scale)
    # [:,:, 1, ...] is x
    # [:,:, 2, ...] is y

    if normalized_coordinates:
        coords_max = normalize_pixel_coordinates3d(coords_max.permute(0, 2, 3, 4, 1), d, h, w)
        coords_max = coords_max.permute(0, 4, 1, 2, 3)

    # Back B*C -> (b, c)
    coords_max = coords_max.reshape(b, c, 3, coords_max.size(2), coords_max.size(3), coords_max.size(4))

    if not output_value:
        return coords_max

    x_softmaxpool = pool_coef * F.avg_pool3d(x_exp * input, kernel_size, stride=stride, padding=padding) / den

    if strict_maxima_bonus > 0:
        in_levels: int = input.size(2)
        out_levels: int = x_softmaxpool.size(2)
        # Depth levels lost at each end when the depth is not (fully) padded.
        skip_levels: int = (in_levels - out_levels) // 2
        # Binary mask of strict maxima, so maxima score exactly (1 + bonus) * value as documented.
        strict_maxima: Tensor = nms3d(input, kernel_size, True).to(dtype)
        # Kernel-1 average pooling only subsamples the mask with the same stride as the output.
        strict_maxima = F.avg_pool3d(strict_maxima, 1, stride, 0)
        # Keep exactly ``out_levels`` depth levels, aligned with the output windows.
        strict_maxima = strict_maxima[:, :, skip_levels : skip_levels + out_levels]
        # Out-of-place multiply keeps the autograd graph free of in-place edits.
        x_softmaxpool = x_softmaxpool * (1.0 + strict_maxima_bonus * strict_maxima)

    x_softmaxpool = x_softmaxpool.view(b, c, x_softmaxpool.size(2), x_softmaxpool.size(3), x_softmaxpool.size(4))
    return coords_max, x_softmaxpool


def spatial_soft_argmax2d(
    input: Tensor, temperature: Tensor = tensor(1.0), normalized_coordinates: bool = True
) -> Tensor:
    r"""Compute the Spatial Soft-Argmax 2D of a given input heatmap.

    The heatmap is first turned into a spatial softmax distribution (``spatial_softmax2d``). The expected
    coordinate under that distribution (``spatial_expectation2d``) is then returned.

    Args:
        input: the given heatmap with shape :math:`(B, N, H, W)`.
        temperature: factor to apply to input.
        normalized_coordinates: whether to return the coordinates normalized in the range
          of :math:`[-1, 1]`. Otherwise, it will return the coordinates in the range of the input shape.

    Returns:
        the index of the maximum 2d coordinates of the give map :math:`(B, N, 2)`.
        The output order is x-coord and y-coord.

    Examples:
        >>> input = torch.tensor([[[
        ... [0., 0., 0.],
        ... [0., 10., 0.],
        ... [0., 0., 0.]]]])
        >>> spatial_soft_argmax2d(input, normalized_coordinates=False)
        tensor([[[1.0000, 1.0000]]])
    """
    return spatial_expectation2d(spatial_softmax2d(input, temperature), normalized_coordinates)


class SpatialSoftArgmax2d(Module):
    r"""Module wrapper computing the Spatial Soft-Argmax 2D of a given heatmap.

    ``forward`` delegates to :func:`~kornia.geometry.subpix.spatial_soft_argmax2d` with the stored settings.
    See that function for details.
    """

    def __init__(self, temperature: Tensor = tensor(1.0), normalized_coordinates: bool = True) -> None:
        super().__init__()
        self.temperature: Tensor = temperature
        self.normalized_coordinates: bool = normalized_coordinates

    def __repr__(self) -> str:
        # ``!s`` keeps ``str()`` semantics; a bare f-string would format 0-d tensors as plain numbers
        return (
            f'{self.__class__.__name__}(temperature={self.temperature!s}, '
            f'normalized_coordinates={self.normalized_coordinates!s})'
        )

    def forward(self, input: Tensor) -> Tensor:
        return spatial_soft_argmax2d(input, self.temperature, self.normalized_coordinates)


def conv_quad_interp3d(input: Tensor, strict_maxima_bonus: float = 10.0, eps: float = 1e-7) -> Tuple[Tensor, Tensor]:
    r"""Compute the single iteration of quadratic interpolation of the extremum (max or min).

    At every voxel that ``nms3d`` marks within a 3x3x3 neighbourhood, the linear system
    ``Hessian * offset = gradient`` is solved from finite-difference derivatives, similar to what is done in
    SIFT. The negated solution is used as the sub-voxel offset. The offset is discarded (set to 0) where the
    system cannot be solved, or where any of its components exceeds 0.7 in magnitude.

    Args:
        input: the given heatmap with shape :math:`(N, C, D, H, W)`.
        strict_maxima_bonus: constant added to the interpolated response at every marked voxel whose system
          was solved successfully. This mimics the behavior of strict NMS in classic local features.
        eps: scale of a random perturbation added to every Hessian to avoid singular systems; it is only
          applied with torch older than 1.10.

    Returns:
        A tuple ``(coords, values)``.

        - ``coords`` has shape :math:`(N, C, 3, D, H, W)`: the integer voxel grid from ``create_meshgrid3d``
          plus the (flipped) interpolated offsets.
        - ``values`` has shape :math:`(N, C, D, H, W)`: ``input`` plus the quadratic-model correction
          ``0.5 * gradient^T offset``, plus the bonus described above.

    Raises:
        TypeError: if ``input`` is not a tensor.
        ValueError: if ``input`` is not 5-dimensional.

    Examples:
        >>> input = torch.randn(20, 16, 3, 50, 32)
        >>> nms_coords, nms_val = conv_quad_interp3d(input, 1.0)
    """
    if not torch.is_tensor(input):
        raise TypeError(f"Input type is not a Tensor. Got {type(input)}")

    if not len(input.shape) == 5:
        raise ValueError(f"Invalid input shape, we expect BxCxDxHxW. Got: {input.shape}")

    B, CH, D, H, W = input.shape
    grid_global: Tensor = create_meshgrid3d(D, H, W, False, device=input.device).permute(0, 4, 1, 2, 3)
    grid_global = grid_global.to(input.dtype)

    # The offset solves A x = b, where b is the 1st-order gradient and A is the Hessian matrix.
    grad: Tensor = spatial_gradient3d(input, order=1, mode='diff')
    grad = grad.permute(0, 1, 3, 4, 5, 2).reshape(-1, 3, 1)  # one 3x1 gradient column per voxel
    second: Tensor = spatial_gradient3d(input, order=2, mode='diff')
    second = second.permute(0, 1, 3, 4, 5, 2).reshape(-1, 6)  # six second-order terms per voxel

    dxx, dyy, dss = second[..., 0], second[..., 1], second[..., 2]
    # the 0.25 factor on the mixed terms is a normalization to match the OpenCV implementation
    dxy, dys, dxs = 0.25 * second[..., 3], 0.25 * second[..., 4], 0.25 * second[..., 5]
    hessian = stack([dxx, dxy, dxs, dxy, dyy, dys, dxs, dys, dss], -1).view(-1, 3, 3)

    if not torch_version_geq(1, 10):
        # The following is needed to avoid singular cases
        hessian += torch.rand(hessian[0].size(), device=hessian.device).abs()[None] * eps

    nms_mask: Tensor = nms3d(input, (3, 3, 3), True)
    flat_mask = nms_mask.view(-1)
    x_solved: Tensor = torch.zeros_like(grad)
    x_solved_masked, _, solved_correctly = safe_solve_with_mask(grad[flat_mask], hessian[flat_mask])

    # Kill those points, where we cannot solve
    new_nms_mask = nms_mask.masked_scatter(nms_mask, solved_correctly)
    x_solved.masked_scatter_(new_nms_mask.view(-1, 1, 1), x_solved_masked[solved_correctly])
    offset: Tensor = -x_solved

    # Ignore ones, which are far from window center
    too_far = offset.abs().max(dim=1, keepdim=True).values > 0.7
    offset.masked_fill_(too_far.expand_as(offset), 0)

    value_delta: Tensor = 0.5 * torch.bmm(grad.permute(0, 2, 1), offset)
    y_max = input + value_delta.view(B, CH, D, H, W)
    if strict_maxima_bonus > 0:
        y_max += strict_maxima_bonus * new_nms_mask.to(input.dtype)

    offset_res: Tensor = offset.flip(1).reshape(B, CH, D, H, W, 3).permute(0, 1, 5, 2, 3, 4)
    coords_max: Tensor = grid_global.repeat(B, 1, 1, 1, 1).unsqueeze(1)
    coords_max = coords_max + offset_res
    return coords_max, y_max


class ConvQuadInterp3d(Module):
    r"""Module wrapper computing a single quadratic-interpolation refinement of 3d extrema per window.

    ``forward`` delegates to :func:`~kornia.geometry.subpix.conv_quad_interp3d` with the stored settings.
    See that function for details.
    """

    def __init__(self, strict_maxima_bonus: float = 10.0, eps: float = 1e-7) -> None:
        super().__init__()
        self.strict_maxima_bonus = strict_maxima_bonus
        self.eps = eps

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(strict_maxima_bonus={self.strict_maxima_bonus!s})'

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        return conv_quad_interp3d(x, self.strict_maxima_bonus, self.eps)