kornia.geometry.subpix

Module with useful functionalities to extract coordinates with sub-pixel accuracy.

Convolutional

kornia.geometry.subpix.conv_soft_argmax2d(input, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), temperature=1.0, normalized_coordinates=True, eps=1e-8, output_value=False)

Compute the convolutional spatial Soft-Argmax 2D over the windows of a given heatmap.

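\[ij(X) = \frac{\sum_{(i,j) \in X} (i,j) \exp(x_{ij} / T)}{\sum_{(i,j) \in X} \exp(x_{ij} / T)}\]
\[val(X) = \frac{\sum_{(i,j) \in X} x_{ij} \exp(x_{ij} / T)}{\sum_{(i,j) \in X} \exp(x_{ij} / T)}\]

where \(X\) is a window of the heatmap, \(x_{ij}\) is the heatmap value at position \((i, j)\), and \(T\) is the temperature.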
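The two formulas can be checked on a single window with the framework-free sketch below. It is plain Python written for illustration only, not kornia's implementation. The helper name window_soft_argmax is hypothetical, and it uses local indices \((i, j)\) inside one window rather than image coordinates.

```python
import math


def window_soft_argmax(window, temperature=1.0):
    """Soft-argmax over one 2D window given as a list of rows.

    Returns ((i, j), val): the exp(x / T)-weighted mean of the local (row, column)
    indices and the exp(x / T)-weighted mean of the values.
    """
    # Subtracting the window maximum keeps exp() from overflowing; the common
    # factor cancels between numerator and denominator.
    x_max = max(max(row) for row in window)
    total = sum_i = sum_j = sum_val = 0.0
    for i, row in enumerate(window):
        for j, x in enumerate(row):
            w = math.exp((x - x_max) / temperature)
            total += w
            sum_i += i * w
            sum_j += j * w
            sum_val += x * w
    return (sum_i / total, sum_j / total), sum_val / total


# A flat window weights every position equally, so ij lands on the centre.
print(window_soft_argmax([[0.0] * 3 for _ in range(3)]))  # ((1.0, 1.0), 0.0)
```

Lowering \(T\) concentrates the weights on the largest value, so ij approaches the hard argmax and val approaches the window maximum.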

Parameters:
  • input (Tensor) – the given heatmap with shape \((N, C, H_{in}, W_{in})\).

  • kernel_size (tuple[int, int], optional) – the size of the window. Default: (3, 3)

  • stride (tuple[int, int], optional) – the stride of the window. Default: (1, 1)

  • padding (tuple[int, int], optional) – input zero padding. Default: (1, 1)

  • temperature (Tensor | float, optional) – factor to apply to input. Default: 1.0

  • normalized_coordinates (bool, optional) – whether to return the coordinates normalized in the range of \([-1, 1]\). Otherwise, it will return the coordinates in the range of the input shape. Default: True

  • eps (float, optional) – small value to avoid zero division. Default: 1e-8

  • output_value (bool, optional) – if True, return both ij and val; if False, return only ij. Default: False

Return type:

Tensor | tuple[Tensor, Tensor]

Returns:

The function returns the soft-argmax coordinates of each window, with shape \((N, C, 2, H_{out}, W_{out})\), and, if output_value is True, also the soft-pooled heatmap values, with shape \((N, C, H_{out}, W_{out})\),

where

\[H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] - (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor\]
\[W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] - (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor\]
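The output-size formula can be evaluated with a one-line helper. This is a hypothetical plain-Python sketch that simply transcribes the formula. Because the formula is applied independently to each axis, the same helper also covers \(D_{out}\) in the 3D variant.

```python
def soft_argmax_out_size(size_in, kernel_size, stride, padding):
    """Output length along one axis, per the H_out / W_out formula above."""
    # floor((in + 2*pad - (k - 1) - 1) / stride + 1); for a non-negative
    # numerator, integer floor division gives the same result.
    return (size_in + 2 * padding - (kernel_size - 1) - 1) // stride + 1


# The example below: a (20, 16, 50, 32) input with kernel (3, 3),
# stride (2, 2) and padding (1, 1).
h_out = soft_argmax_out_size(50, 3, 2, 1)
w_out = soft_argmax_out_size(32, 3, 2, 1)
print((20, 16, 2, h_out, w_out))  # (20, 16, 2, 25, 16)
```

With stride 1, kernel 3 and padding 1, the spatial size is preserved. That is why these are the defaults.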

Examples

>>> input = torch.randn(20, 16, 50, 32)
>>> nms_coords, nms_val = conv_soft_argmax2d(input, (3,3), (2,2), (1,1), output_value=True)
kornia.geometry.subpix.conv_soft_argmax3d(input, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), temperature=1.0, normalized_coordinates=False, eps=1e-8, output_value=True, strict_maxima_bonus=0.0)

Compute the convolutional spatial Soft-Argmax 3D over the windows of a given heatmap.

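\[ijk(X) = \frac{\sum_{(i,j,k) \in X} (i,j,k) \exp(x_{ijk} / T)}{\sum_{(i,j,k) \in X} \exp(x_{ijk} / T)}\]
\[val(X) = \frac{\sum_{(i,j,k) \in X} x_{ijk} \exp(x_{ijk} / T)}{\sum_{(i,j,k) \in X} \exp(x_{ijk} / T)}\]

where \(X\) is a window of the heatmap, \(x_{ijk}\) is the heatmap value at position \((i, j, k)\), and \(T\) is the temperature.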

Parameters:
  • input (Tensor) – the given heatmap with shape \((N, C, D_{in}, H_{in}, W_{in})\).

  • kernel_size (tuple[int, int, int], optional) – size of the window. Default: (3, 3, 3)

  • stride (tuple[int, int, int], optional) – stride of the window. Default: (1, 1, 1)

  • padding (tuple[int, int, int], optional) – input zero padding. Default: (1, 1, 1)

  • temperature (Tensor | float, optional) – factor to apply to input. Default: 1.0

  • normalized_coordinates (bool, optional) – whether to return the coordinates normalized in the range of \([-1, 1]\). Otherwise, it will return the coordinates in the range of the input shape. Default: False

  • eps (float, optional) – small value to avoid zero division. Default: 1e-8

  • output_value (bool, optional) – if True, return both ijk and val; if False, return only ijk. Default: True

  • strict_maxima_bonus (float, optional) – pixels that are strict maxima score (1 + strict_maxima_bonus) * value. This is needed to mimic the behavior of strict NMS in classic local features. Default: 0.0
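The strict_maxima_bonus rule can be illustrated on a single \(3 \times 3 \times 3\) window. This plain-Python sketch is not kornia's code, and apply_strict_maxima_bonus is a hypothetical helper. It assumes that "strict maximum" means strictly greater than all 26 neighbours, so a tie earns no bonus.

```python
def apply_strict_maxima_bonus(patch, strict_maxima_bonus):
    """Return the centre value of a 3x3x3 patch, boosted if it is a strict maximum.

    `patch` is a nested list indexed as patch[d][h][w] with the centre at (1, 1, 1).
    """
    centre = patch[1][1][1]
    neighbours = [
        patch[d][h][w]
        for d in range(3)
        for h in range(3)
        for w in range(3)
        if (d, h, w) != (1, 1, 1)
    ]
    # Strict: the centre must exceed all 26 neighbours; a tie does not count.
    if all(centre > n for n in neighbours):
        return (1.0 + strict_maxima_bonus) * centre
    return centre
```

With the default bonus of 0.0 the score is unchanged. A positive bonus lets strict maxima outrank plateau points of equal response.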

Return type:

Tensor | tuple[Tensor, Tensor]

Returns:

The function returns the soft-argmax coordinates of each window, with shape \((N, C, 3, D_{out}, H_{out}, W_{out})\), and, if output_value is True, also the soft-pooled heatmap values, with shape \((N, C, D_{out}, H_{out}, W_{out})\),

where

\[D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] - (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor\]
\[H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] - (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor\]
\[W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] - (\text{kernel\_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor\]

Examples

>>> input = torch.randn(20, 16, 3, 50, 32)
>>> nms_coords, nms_val = conv_soft_argmax3d(input, (3, 3, 3), (1, 2, 2), (0, 1, 1))
kornia.geometry.subpix.conv_quad_interp3d(input, n_iters=5, strict_maxima_bonus=10.0, max_subpixel_shift=0.6, precomputed_nms_mask=None, dilation_radius=1, allow_scale_steps=True)

Subpixel localization of 3D scale-space extrema via quadratic interpolation.

For each NMS maximum, the function fits a 3-D quadratic to the local \(3 \times 3 \times 3\) neighbourhood and solves for the sub-voxel shift that maximises the fit. When the shift along any axis exceeds max_subpixel_shift, the integer centre is moved one step in that direction and the solve is repeated, up to n_iters times.
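A single quadratic fit can be sketched in plain Python. This sketch illustrates the standard second-order (Newton) step and assumes that the gradient \(g\) and Hessian \(H\) come from central finite differences on the \(3 \times 3 \times 3\) patch. It is not kornia's implementation, which batches these solves on tensors. The helper names are hypothetical, and the shift is returned in patch order (d, h, w).

```python
def det3(m):
    """Determinant of a 3x3 matrix given as a list of rows."""
    return (m[0][0] * (m[1][1] * m[2][2] - m[1][2] * m[2][1])
            - m[0][1] * (m[1][0] * m[2][2] - m[1][2] * m[2][0])
            + m[0][2] * (m[1][0] * m[2][1] - m[1][1] * m[2][0]))


def solve3(a, b):
    """Solve a @ x = b for a 3x3 system with Cramer's rule."""
    d = det3(a)
    if abs(d) < 1e-12:
        raise ValueError("singular Hessian")
    x = []
    for col in range(3):
        m = [row[:] for row in a]
        for r in range(3):
            m[r][col] = b[r]
        x.append(det3(m) / d)
    return x


def _neg(v):
    return tuple(-x for x in v)


def _add(u, v):
    return tuple(x + y for x, y in zip(u, v))


def quad_fit_3x3x3(p):
    """Fit a quadratic to p[d][h][w] around (1, 1, 1); return (shift, value)."""
    c = p[1][1][1]

    def at(off):
        return p[1 + off[0]][1 + off[1]][1 + off[2]]

    axes = [(1, 0, 0), (0, 1, 0), (0, 0, 1)]
    # Central-difference gradient g and Hessian H at the centre voxel.
    g = [(at(a) - at(_neg(a))) / 2.0 for a in axes]
    H = [[0.0] * 3 for _ in range(3)]
    for r, a in enumerate(axes):
        for s, b in enumerate(axes):
            if r == s:
                H[r][s] = at(a) - 2.0 * c + at(_neg(a))
            else:
                H[r][s] = (at(_add(a, b)) - at(_add(a, _neg(b)))
                           - at(_add(_neg(a), b)) + at(_add(_neg(a), _neg(b)))) / 4.0
    # Newton step: the stationary point of the quadratic solves H @ shift = -g.
    shift = solve3(H, [-x for x in g])
    # Value of the fitted quadratic there: c + g.shift + 0.5 shift^T H shift.
    value = c + 0.5 * sum(gi * si for gi, si in zip(g, shift))
    return shift, value
```

Because finite differences are exact for quadratics, sampling a quadratic whose peak lies at \((1.2, 0.9, 1.1)\) recovers the shift \((0.2, -0.1, 0.1)\).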

Unlike a naive iterative approach, all Hessian solves are precomputed once at the start for every voxel that any keypoint could possibly visit (the dilated NMS neighbourhood, an L\(\infty\) ball of radius dilation_radius around each maximum). The subsequent iteration loop contains no data-dependent Python control flow and no GPU→CPU synchronisation, making the function fully compatible with torch.compile / CUDA graphs.
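The precompute-then-lookup scheme can be sketched with a 1-D analogue in plain Python. In this sketch, a dict stands in for the precomputed tensor of shifts. The loop mimics "no data-dependent control flow" by running a fixed number of iterations and computing each integer step arithmetically. Centres outside the footprint are flagged invalid instead of triggering an early exit. The helper names are hypothetical, and the sketch assumes non-zero curvature at every precomputed centre.

```python
def precompute_shifts(f, positions):
    """Solve the 1D parabola fit once for every centre that may be visited."""
    table = {}
    for c in positions:
        g = (f[c + 1] - f[c - 1]) / 2.0          # first derivative
        h = f[c + 1] - 2.0 * f[c] + f[c - 1]     # second derivative
        table[c] = -g / h                        # offset of the parabola's vertex
    return table


def follow_shift_chain(table, start, n_iters=5, max_shift=0.6):
    """Follow integer-centre moves by table lookup; returns (position, valid)."""
    centre, valid = start, True
    for _ in range(n_iters):  # fixed trip count, no early exit
        # A centre outside the precomputed footprint is flagged invalid and
        # given a zero shift, so every iteration does the same work.
        valid = valid and centre in table
        s = table.get(centre, 0.0)
        # Step -1, 0 or +1 depending on whether |s| exceeds the threshold.
        centre += (s > max_shift) - (s < -max_shift)
    valid = valid and centre in table
    return centre + table.get(centre, 0.0), valid


signal = [-(x - 3.7) ** 2 for x in range(10)]
table = precompute_shifts(signal, range(2, 5))   # radius-1 footprint around 3
print(follow_shift_chain(table, start=3))        # about (3.7, True)
```

Starting at 3, the fitted shift is about 0.7. That exceeds 0.6, so the centre moves to 4, where the shift of about -0.3 is accepted.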

The dilation_radius controls the precompute footprint and should be set to the maximum number of integer-centre moves expected per keypoint. With the default max_subpixel_shift=0.6 almost all keypoints converge within 1 move, so the default dilation_radius=1 (i.e. \(3^3 = 27\) positions per maximum) is sufficient. Use dilation_radius=2 (\(5^3 = 125\)) for extra safety. Setting it equal to n_iters recovers the original behaviour but is much slower on large images.
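The footprint arithmetic and the "marked invalid" rule can be checked with two small hypothetical helpers. The L\(\infty\) ball is used because one move may change several axes at once, and such a diagonal move is still a single step in the L\(\infty\) metric.

```python
def footprint_per_maximum(dilation_radius):
    """Voxels in a 3D L-infinity ball of radius r: (2r + 1) ** 3."""
    return (2 * dilation_radius + 1) ** 3


def stays_in_footprint(moves, dilation_radius):
    """Check that a chain of integer-centre moves (dd, dh, dw) never leaves the ball."""
    pos = (0, 0, 0)
    for move in moves:
        pos = tuple(p + m for p, m in zip(pos, move))
        # L-infinity distance: a diagonal move still counts as one step.
        if max(abs(p) for p in pos) > dilation_radius:
            return False
    return True


print(footprint_per_maximum(1), footprint_per_maximum(2))  # 27 125
```

Two moves in the same direction leave a radius-1 footprint, which is the case where dilation_radius=2 would be needed.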

Parameters:
  • input (Tensor) – response pyramid with shape \((B, C, D, H, W)\).

  • n_iters (int, optional) – maximum number of localization iterations per keypoint. Default: 5

  • strict_maxima_bonus (float, optional) – value added to y_max at NMS-maximum positions so that strict maxima are preferred during top-K selection. Default: 10.0

  • max_subpixel_shift (float, optional) – threshold above which the integer centre is moved one step and another iteration is run. Default: 0.6

  • precomputed_nms_mask (Optional[Tensor], optional) – optional bool tensor of shape \((B, C, D, H, W)\) — pass the result of nms3d() to skip the internal NMS call. Default: None

  • dilation_radius (int, optional) – L\(\infty\) radius (in voxels) of the neighbourhood around each NMS maximum where the Hessian solve is precomputed. Keypoints that attempt to move farther than this are marked invalid. Default: 1

  • allow_scale_steps (bool, optional) – if True (default), the iterative shift is also applied along the scale (depth) axis; set to False to keep the keypoint on its original scale level. Default: True

Returns:

  • coords_max — shape \((B, C, 3, D, H, W)\), refined [scale, x(width), y(height)] coordinates for each NMS maximum; non-maximum positions keep their grid coordinates.

  • y_max — shape \((B, C, D, H, W)\), quadratically corrected response with optional strict-maxima bonus.

Return type:

tuple[Tensor, Tensor]

Example

>>> input = torch.randn(2, 3, 5, 64, 64)
>>> coords, vals = conv_quad_interp3d(input, n_iters=5)
>>> coords.shape
torch.Size([2, 3, 3, 5, 64, 64])
>>> vals.shape
torch.Size([2, 3, 5, 64, 64])
kornia.geometry.subpix.iterative_quad_interp3d(input, n_iters=5, strict_maxima_bonus=10.0, max_subpixel_shift=0.6, allow_scale_steps=True, precomputed_nms_mask=None, max_candidates=None)

Iterative subpixel localization of 3D extrema via quadratic interpolation.

Unlike conv_quad_interp3d(), which pre-computes the Hessian solve for all voxels reachable from NMS maxima and then follows shifts by table lookup, this function explicitly re-extracts the \(3 \times 3 \times 3\) patch at each NMS maximum and iterates up to n_iters times. When the estimated subpixel shift along any spatial or scale axis exceeds max_subpixel_shift the integer center is moved one step in that direction and the solve is repeated — matching the localization loop from the HessAff / SIFT family of detectors.
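The patch-based loop can be sketched with a 1-D analogue in plain Python, where the parabola is refit at the current centre on every iteration. The helper name is hypothetical. Returning None when the centre leaves the signal, the curvature vanishes, or n_iters is exhausted is a choice of this sketch; the text above does not specify that behaviour.

```python
def iterative_refine_1d(f, start, n_iters=5, max_shift=0.6):
    """1D analogue of the patch loop: refit a parabola at the current centre each time.

    Returns the refined position, or None if the centre leaves the signal, the
    curvature vanishes, or the loop does not converge within n_iters.
    """
    centre = start
    for _ in range(n_iters):
        if not 1 <= centre <= len(f) - 2:
            return None  # the 3-sample patch would fall off the signal
        g = (f[centre + 1] - f[centre - 1]) / 2.0
        h = f[centre + 1] - 2.0 * f[centre] + f[centre - 1]
        if h == 0.0:
            return None
        shift = -g / h
        if abs(shift) <= max_shift:
            return centre + shift  # converged: sub-sample position of the vertex
        centre += 1 if shift > 0 else -1  # move the integer centre one step
    return None


signal = [-(x - 3.7) ** 2 for x in range(10)]
print(iterative_refine_1d(signal, start=3))  # about 3.7, after one move to 4
```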

Parameters:
  • input (Tensor) – response pyramid with shape \((B, C, D, H, W)\).

  • n_iters (int, optional) – maximum number of localization iterations per keypoint. Default: 5

  • strict_maxima_bonus (float, optional) – value added to y_max at NMS-maximum positions so that strict maxima are preferred when selecting the top-K keypoints. Default: 10.0

  • max_subpixel_shift (float, optional) – if the estimated shift along any axis is larger than this threshold the integer center is displaced and another iteration is run. Default: 0.6

  • allow_scale_steps (bool, optional) – if True (default), the iterative shift is also applied along the scale (depth) axis; set to False to keep the keypoint on its original scale level. Default: True

  • precomputed_nms_mask (Optional[Tensor], optional) – optional bool tensor of shape \((B, C, D, H, W)\) — pass the result of nms3d() to skip the internal NMS call. Default: None

  • max_candidates (Optional[int], optional) – if given, only the top-max_candidates NMS maxima (ranked by pre-refinement response) are processed. The rest keep their grid-coordinate values. This is a CPU speed-up knob: for large images, the number of 3-D NMS maxima can be 10x-100x larger than the desired number of keypoints, making the per-candidate gather+solve loop the dominant CPU cost. Setting, for example, max_candidates = num_features * 5 greatly reduces that work, at the cost of occasionally missing a feature whose response rank would have improved after refinement. Default: None
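The candidate cap amounts to a top-k selection on the pre-refinement responses. The plain-Python sketch below uses heapq.nlargest. The helper name and the response values are made up for illustration.

```python
import heapq


def select_candidates(responses, max_candidates=None):
    """Indices of the NMS maxima to refine, strongest pre-refinement response first."""
    if max_candidates is None:
        return list(range(len(responses)))  # refine every maximum
    return heapq.nlargest(max_candidates, range(len(responses)), key=responses.__getitem__)


num_features = 2
responses = [0.1, 0.9, 0.5, 0.7, 0.3, 0.2, 0.8, 0.05, 0.6, 0.4, 0.45, 0.15]
print(select_candidates(responses, max_candidates=num_features * 5))
```

The two weakest maxima (indices 7 and 0) are never refined, so their rank cannot improve after refinement. This is the trade-off described above.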

Return type:

tuple[Tensor, Tensor]

Returns:

A tuple (coords_max, y_max) where

  • coords_max has shape \((B, C, 3, D, H, W)\) and stores the refined coordinates [scale, x, y] for every position in the input. Non-NMS positions keep their original grid coordinates.

  • y_max has shape \((B, C, D, H, W)\) and stores the quadratically corrected response values (with the optional strict-maxima bonus added).

Example

>>> input = torch.randn(2, 3, 3, 8, 8)
>>> coords, vals = iterative_quad_interp3d(input, n_iters=5)
>>> coords.shape
torch.Size([2, 3, 3, 3, 8, 8])
>>> vals.shape
torch.Size([2, 3, 3, 8, 8])

Tip

AdaptiveQuadInterp3d (the default subpix module in ScaleSpaceDetector) automatically picks the faster backend based on the input device:

  • CUDA: conv_quad_interp3d()

  • CPU: iterative_quad_interp3d()

Both backends produce numerically equivalent results (maximum difference below \(2 \times 10^{-6}\)).

import torch
from kornia.geometry.subpix import AdaptiveQuadInterp3d
from kornia.feature import ScaleSpaceDetector
from kornia.feature.responses import BlobDoG
from kornia.geometry.transform import ScalePyramid

detector = ScaleSpaceDetector(
    num_features=2000,
    resp_module=BlobDoG(),
    # default — auto-selects conv on CUDA, patch on CPU:
    subpix_module=AdaptiveQuadInterp3d(strict_maxima_bonus=0.0),
    scale_pyr_module=ScalePyramid(3, 1.6, 32, double_image=True),
    scale_space_response=True,
    minima_are_also_good=True,
)

Spatial

kornia.geometry.subpix.spatial_softmax2d(input, temperature=None)

Apply the Softmax function over features in each image channel.

Note that this function behaves differently to torch.nn.Softmax2d, which instead applies Softmax over features at each spatial location.
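To make the distinction concrete, the framework-free sketch below contrasts the two normalizations on nested lists. The batch dimension is omitted and the helper names are hypothetical. The temperature is also left out because this sketch does not model it.

```python
import math


def spatial_softmax(channels):
    """Softmax over all H*W positions of each channel (per-channel normalization)."""
    out = []
    for ch in channels:
        m = max(max(row) for row in ch)  # stabilize exp()
        exps = [[math.exp(x - m) for x in row] for row in ch]
        total = sum(sum(row) for row in exps)
        out.append([[e / total for e in row] for row in exps])
    return out


def channel_softmax(channels):
    """Softmax over channels at each (h, w) position, as torch.nn.Softmax2d does."""
    n_ch, height, width = len(channels), len(channels[0]), len(channels[0][0])
    out = [[[0.0] * width for _ in range(height)] for _ in range(n_ch)]
    for h in range(height):
        for w in range(width):
            vals = [channels[c][h][w] for c in range(n_ch)]
            m = max(vals)
            exps = [math.exp(v - m) for v in vals]
            total = sum(exps)
            for c in range(n_ch):
                out[c][h][w] = exps[c] / total
    return out
```

With spatial_softmax each channel sums to one over the image. With channel_softmax the values at each pixel sum to one across channels.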

Parameters:
  • input (Tensor) – the input torch.Tensor with shape \((B, N, H, W)\).

  • temperature (Optional[Tensor], optional) – factor to apply to input, adjusting the “smoothness” of the output distribution. Default: None

Return type:

Tensor

Returns:

a 2D probability distribution per image channel with shape \((B, N, H, W)\).

Examples

>>> heatmaps = torch.tensor([[[
... [0., 0., 0.],
... [0., 0., 0.],
... [0., 1., 2.]]]])
>>> spatial_softmax2d(heatmaps)
tensor([[[[0.0585, 0.0585, 0.0585],
          [0.0585, 0.0585, 0.0585],
          [0.0585, 0.1589, 0.4319]]]])
kornia.geometry.subpix.spatial_expectation2d(input, normalized_coordinates=True)

Compute the expectation of coordinate values using spatial probabilities.

The input heatmap is assumed to represent a valid spatial probability distribution, which can be achieved using spatial_softmax2d().
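The expectation is a probability-weighted average of pixel coordinates. The plain-Python sketch below assumes that normalized coordinates map pixel indices linearly onto \([-1, 1]\), with the first and last pixel at \(-1\) and \(1\). The helper name is hypothetical, and the sketch assumes H and W are at least 2.

```python
def spatial_expectation(prob, normalized_coordinates=True):
    """Expected (x, y) under a probability map given as a list of rows (H, W >= 2)."""
    height, width = len(prob), len(prob[0])

    def coord(k, n):
        # Pixel index k, or the same index mapped linearly onto [-1, 1].
        if not normalized_coordinates:
            return float(k)
        return -1.0 + 2.0 * k / (n - 1)

    ex = sum(p * coord(x, width) for row in prob for x, p in enumerate(row))
    ey = sum(p * coord(y, height) for y, row in enumerate(prob) for p in row)
    return ex, ey


heatmap = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
print(spatial_expectation(heatmap, normalized_coordinates=False))  # (1.0, 2.0)
```

For the same heatmap, normalized coordinates give (0.0, 1.0): the middle column maps to 0 and the last row maps to 1.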

Parameters:
  • input (Tensor) – the input torch.Tensor representing dense spatial probabilities with shape \((B, N, H, W)\).

  • normalized_coordinates (bool, optional) – whether to return the coordinates normalized in the range of \([-1, 1]\). Otherwise, it will return the coordinates in the range of the input shape. Default: True

Return type:

Tensor

Returns:

expected value of the 2D coordinates with shape \((B, N, 2)\). Output order of the coordinates is (x, y).

Examples

>>> heatmaps = torch.tensor([[[
... [0., 0., 0.],
... [0., 0., 0.],
... [0., 1., 0.]]]])
>>> spatial_expectation2d(heatmaps, False)
tensor([[[1., 2.]]])
kornia.geometry.subpix.spatial_soft_argmax2d(input, temperature=None, normalized_coordinates=True)

Compute the Spatial Soft-Argmax 2D of a given input heatmap.
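Conceptually, the spatial soft-argmax is a spatial softmax followed by a spatial expectation. The plain-Python sketch below fuses both steps for one heatmap in pixel coordinates. The helper name is hypothetical, and temperature handling is left out.

```python
import math


def spatial_soft_argmax(heatmap):
    """Soft-argmax of one heatmap (list of rows), returned as pixel coordinates (x, y)."""
    m = max(max(row) for row in heatmap)  # stabilize exp()
    weights = [[math.exp(v - m) for v in row] for row in heatmap]
    total = sum(sum(row) for row in weights)
    ex = sum(w * x for row in weights for x, w in enumerate(row)) / total
    ey = sum(w * y for y, row in enumerate(weights) for w in row) / total
    return ex, ey


print(spatial_soft_argmax([[0.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 0.0]]))
```

Unlike a hard argmax, the result is differentiable. A peak in a corner is pulled slightly toward the centre by the small weights of the other pixels.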

Parameters:
  • input (Tensor) – the given heatmap with shape \((B, N, H, W)\).

  • temperature (Optional[Tensor], optional) – factor to apply to input. Default: None

  • normalized_coordinates (bool, optional) – whether to return the coordinates normalized in the range of \([-1, 1]\). Otherwise, it will return the coordinates in the range of the input shape. Default: True

Return type:

Tensor

Returns:

the soft-argmax 2D coordinates of the given map with shape \((B, N, 2)\). The output order is (x, y).

Examples

>>> input = torch.tensor([[[
... [0., 0., 0.],
... [0., 10., 0.],
... [0., 0., 0.]]]])
>>> spatial_soft_argmax2d(input, normalized_coordinates=False)
tensor([[[1.0000, 1.0000]]])
kornia.geometry.subpix.render_gaussian2d(mean, std, size, normalized_coordinates=True)

Render the PDF of a 2D Gaussian distribution.

Parameters:
  • mean (Tensor) – the mean location of the Gaussian to render, \((\mu_x, \mu_y)\). Shape: \((*, 2)\).

  • std (Tensor) – the standard deviation of the Gaussian to render, \((\sigma_x, \sigma_y)\). Shape \((*, 2)\). Should be able to be broadcast with mean.

  • size (tuple[int, int]) – the (height, width) of the output image.

  • normalized_coordinates (bool, optional) – whether mean and std are assumed to use coordinates normalized in the range of \([-1, 1]\). Otherwise, coordinates are assumed to be in the range of the output shape. Default: True

Return type:

Tensor

Returns:

the rendered Gaussian as a torch.Tensor with shape \((*, H, W)\).
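The sketch below renders an axis-aligned Gaussian on a pixel grid, for the case normalized_coordinates=False. It is plain Python with a hypothetical helper name. Normalizing the discrete values to sum to one is an assumption of this sketch, and kornia's exact scaling may differ.

```python
import math


def render_gaussian_pixels(mean, std, size):
    """Render an axis-aligned 2D Gaussian on an (H, W) pixel grid, normalized to sum to 1.

    mean = (mu_x, mu_y) and std = (sigma_x, sigma_y) are in pixel units.
    """
    height, width = size
    img = [
        [
            math.exp(-0.5 * (((x - mean[0]) / std[0]) ** 2 + ((y - mean[1]) / std[1]) ** 2))
            for x in range(width)
        ]
        for y in range(height)
    ]
    total = sum(sum(row) for row in img)
    return [[v / total for v in row] for row in img]


img = render_gaussian_pixels(mean=(3.0, 2.0), std=(1.0, 1.5), size=(5, 7))
```

Note the convention mismatch that is easy to trip over: mean and std are given as (x, y), but size is given as (height, width).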

Non Maxima Suppression

kornia.geometry.subpix.nms2d(input, kernel_size, mask_only=False)

Apply non-maxima suppression to a filter response.

See NonMaximaSuppression2d for details.

Return type:

Tensor

kornia.geometry.subpix.nms3d(input, kernel_size, mask_only=False)

Apply non-maxima suppression to a filter response.

See NonMaximaSuppression3d for details.

Return type:

Tensor

kornia.geometry.subpix.nms3d_minmax(input)

Compute both local-maxima and local-minima NMS masks for a 3-D scale-space tensor in one pass.

Equivalent to calling nms3d(input, (3,3,3), mask_only=True) and nms3d(-input, (3,3,3), mask_only=True) separately, but only traverses the 26-neighbour comparisons once, halving the NMS cost.
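The single-pass idea can be sketched in plain Python on a nested D x H x W list: each of the 26 neighbour values is read once and feeds both the maximum test and the minimum test. The helper name is hypothetical. Leaving border voxels False is a simplification of this sketch, not necessarily kornia's border handling.

```python
def nms3d_minmax_sketch(vol):
    """Strict local max/min masks over the 26-neighbourhood of a D x H x W nested list.

    Border voxels (which lack a full neighbourhood) are left False in this sketch.
    """
    depth, height, width = len(vol), len(vol[0]), len(vol[0][0])
    max_mask = [[[False] * width for _ in range(height)] for _ in range(depth)]
    min_mask = [[[False] * width for _ in range(height)] for _ in range(depth)]
    offsets = [(a, b, c) for a in (-1, 0, 1) for b in (-1, 0, 1) for c in (-1, 0, 1)
               if (a, b, c) != (0, 0, 0)]
    for d in range(1, depth - 1):
        for h in range(1, height - 1):
            for w in range(1, width - 1):
                v = vol[d][h][w]
                is_max = is_min = True
                # One pass over the 26 neighbours feeds both comparisons.
                for a, b, c in offsets:
                    n = vol[d + a][h + b][w + c]
                    is_max = is_max and v > n
                    is_min = is_min and v < n
                max_mask[d][h][w] = is_max
                min_mask[d][h][w] = is_min
    return max_mask, min_mask
```

As a consistency check, the minimum mask of a volume equals the maximum mask of its negation. This mirrors the nms3d(-input, ...) equivalence above.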

Uses integer slice literals (not Python loops or slice objects) so the 52 comparison-and-reduction ops are visible to the compiler at trace time, allowing full fusion into a minimal number of kernels.

Parameters:

input (Tensor) – 5-D tensor of shape \((B, C, D, H, W)\).

Return type:

tuple[Tensor, Tensor]

Returns:

A pair (max_mask, min_mask) of bool tensors with the same shape as input. max_mask[..., d, h, w] is True when the voxel is strictly greater than all 26 neighbours; min_mask is the same for strict local minima.

Example

>>> x = torch.randn(1, 1, 5, 10, 10)
>>> max_mask, min_mask = nms3d_minmax(x)
>>> max_mask.shape
torch.Size([1, 1, 5, 10, 10])

Module

class kornia.geometry.subpix.SpatialSoftArgmax2d(temperature=None, normalized_coordinates=True)

Compute the Spatial Soft-Argmax 2D of a given heatmap.

See spatial_soft_argmax2d() for details.

class kornia.geometry.subpix.ConvSoftArgmax2d(kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), temperature=1.0, normalized_coordinates=True, eps=1e-8, output_value=False)

nn.Module that calculates soft argmax 2d per window.

See conv_soft_argmax2d() for details.

class kornia.geometry.subpix.ConvSoftArgmax3d(kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), temperature=1.0, normalized_coordinates=False, eps=1e-8, output_value=True, strict_maxima_bonus=0.0)

nn.Module that calculates soft argmax 3d per window.

See conv_soft_argmax3d() for details.

class kornia.geometry.subpix.AdaptiveQuadInterp3d(mode='auto', n_iters=5, strict_maxima_bonus=10.0, max_subpixel_shift=0.6, dilation_radius=1, allow_scale_steps=True, max_candidates=None)

Subpixel localization of 3D scale-space extrema with automatic backend selection.

Wraps conv_quad_interp3d() and iterative_quad_interp3d(), choosing the faster backend based on the input device and the requested mode.

Benchmarks show:

  • GPU: conv_quad_interp3d() is 1.5-2x faster due to better parallelism on the batched gather+solve.

  • CPU: iterative_quad_interp3d() is faster for large images because it processes only the NMS maxima directly without any dilation/dedup overhead.
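The selection rule reduces to a small dispatch on mode and device type. The plain-Python helper below is hypothetical, not kornia's code. It mirrors the rule stated for "auto" (conv on CUDA, patch on CPU). Treating every non-CUDA device like the CPU is an assumption of this sketch.

```python
def pick_backend(mode, device_type):
    """Map (mode, device type) to a backend name: "conv" or "patch"."""
    if mode == "auto":
        # conv -> conv_quad_interp3d(), patch -> iterative_quad_interp3d()
        return "conv" if device_type == "cuda" else "patch"
    if mode in ("conv", "patch"):
        return mode  # an explicit mode overrides the device heuristic
    raise ValueError(f"unknown mode: {mode!r}")


print(pick_backend("auto", "cuda"), pick_backend("auto", "cpu"))  # conv patch
```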

Parameters:
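  • mode (str, optional) – backend selection strategy: "auto" uses "conv" (conv_quad_interp3d()) on CUDA and "patch" (iterative_quad_interp3d()) on CPU; "conv" or "patch" forces that backend. Default: "auto"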

  • n_iters (int, optional) – maximum localization iterations per keypoint. Default: 5

  • strict_maxima_bonus (float, optional) – score bonus added at NMS-maximum positions. Default: 10.0

  • max_subpixel_shift (float, optional) – integer-centre move threshold. Default: 0.6

  • dilation_radius (int, optional) – L\(\infty\) precompute radius for "conv" mode (ignored in "patch" mode). Default: 1

  • max_candidates (Optional[int], optional) – if set, only the top-max_candidates NMS maxima by pre-refinement response are processed in "patch" mode. Has no effect in "conv" mode. Useful on CPU when the number of 3-D NMS maxima greatly exceeds the desired number of keypoints (see iterative_quad_interp3d()). Default: None

Example

>>> inp = torch.randn(1, 1, 3, 64, 64)
>>> subpix = AdaptiveQuadInterp3d(mode="auto")
>>> coords, vals = subpix(inp)
>>> coords.shape
torch.Size([1, 1, 3, 3, 64, 64])
>>> vals.shape
torch.Size([1, 1, 3, 64, 64])
class kornia.geometry.subpix.ConvQuadInterp3d(n_iters=5, strict_maxima_bonus=10.0, max_subpixel_shift=0.6, dilation_radius=1, allow_scale_steps=True)

Subpixel localization of 3D scale-space extrema via quadratic interpolation.

Wraps conv_quad_interp3d(). The Hessian system is solved once for each voxel in the dilated NMS neighbourhood (no dense precomputation over the whole volume), then the shift chain is followed by table lookup with no GPU→CPU synchronisation — making the module compatible with torch.compile and CUDA graphs.

Parameters:
  • n_iters (int, optional) – maximum localization iterations per keypoint. Default: 5

  • strict_maxima_bonus (float, optional) – score bonus at NMS-maximum positions. Default: 10.0

  • max_subpixel_shift (float, optional) – shift threshold that triggers an integer-centre move. Default: 0.6

class kornia.geometry.subpix.IterativeQuadInterp3d(n_iters=5, strict_maxima_bonus=10.0, max_subpixel_shift=0.6, allow_scale_steps=True, max_candidates=None)

Iterative subpixel localization of 3D extrema via quadratic interpolation.

See iterative_quad_interp3d() for details.

class kornia.geometry.subpix.NonMaximaSuppression2d(kernel_size)

Apply non-maxima suppression to a filter response.

To detect both maxima and minima, e.g. for DoG, use the minima_are_also_good flag of ScaleSpaceDetector.

class kornia.geometry.subpix.NonMaximaSuppression3d(kernel_size)

Apply non-maxima suppression to a filter response.