Source code for kornia.feature.disk.disk

from __future__ import annotations

from typing import Optional

import torch

from kornia.core import Module, Tensor

from ._unets import Unet
from .detector import heatmap_to_keypoints
from .structs import DISKFeatures


class DISK(Module):
    r"""Module which detects and describes local features in an image using the DISK method.

    See :cite:`tyszkiewicz2020disk` for details.

    .. image:: _static/img/disk_outdoor_depth.jpg

    Args:
        desc_dim: The dimension of the descriptor.
        unet: The U-Net to use. If None, a default U-Net is used. Kornia doesn't provide the training code for
            DISK so this is only useful when using a custom checkpoint trained using the code released with the
            paper. The unet should take as input a tensor of shape :math:`(B, C, H, W)` and output a tensor of
            shape :math:`(B, \mathrm{desc\_dim} + 1, H, W)`.

    Example:
        >>> disk = DISK.from_pretrained('depth')
        >>> images = torch.rand(1, 3, 256, 256)
        >>> features = disk(images)
    """

    def __init__(self, desc_dim: int = 128, unet: None | Module = None) -> None:
        super().__init__()

        self.desc_dim = desc_dim
        if unet is None:
            # The last ``up`` width is desc_dim + 1: one channel per descriptor dimension plus a single
            # heatmap channel, which is exactly what heatmap_and_dense_descriptors validates and splits.
            unet = Unet(in_features=3, size=5, down=[16, 32, 64, 64, 64], up=[64, 64, 64, desc_dim + 1])
        self.unet = unet

    def heatmap_and_dense_descriptors(self, images: Tensor) -> tuple[Tensor, Tensor]:
        """Returns the heatmap and the dense descriptors.

        .. image:: _static/img/DISK.png

        Args:
            images: The image to detect features in. Shape :math:`(B, 3, H, W)`.

        Returns:
            A tuple of dense detection scores and descriptors.
            Shapes are :math:`(B, 1, H, W)` and :math:`(B, D, H, W)`, where
            :math:`D` is the descriptor dimension.

        Raises:
            ValueError: If the U-Net output does not have exactly ``desc_dim + 1`` channels.
        """
        unet_output = self.unet(images)

        if unet_output.shape[1] != self.desc_dim + 1:
            raise ValueError(
                f"U-Net output has {unet_output.shape[1]} channels, but expected self.desc_dim={self.desc_dim} + 1."
            )

        # Channel layout of the U-Net output: the first desc_dim channels are the dense descriptors,
        # the remaining (single) channel is the detection heatmap.
        descriptors = unet_output[:, : self.desc_dim]
        heatmaps = unet_output[:, self.desc_dim :]

        return heatmaps, descriptors

    def forward(
        self,
        images: Tensor,
        n: Optional[int] = None,
        window_size: int = 5,
        score_threshold: float = 0.0,
        pad_if_not_divisible: bool = False,
    ) -> list[DISKFeatures]:
        """Detects features in an image, returning keypoint locations, descriptors and detection scores.

        Args:
            images: The image to detect features in. Shape :math:`(B, 3, H, W)`.
            n: The maximum number of keypoints to detect. If None, all keypoints are returned.
            window_size: The size of the non-maxima suppression window used to filter detections.
            score_threshold: The minimum score a detection must have to be returned.
                See :py:class:`DISKFeatures` for details.
            pad_if_not_divisible: If True, an input whose height or width is not divisible by 16 is zero-padded
                (bottom/right) up to the next multiple of 16 before running the network; the dense outputs are
                then cropped back to the original size.

        Returns:
            A list of length :math:`B` containing the detected features.
        """
        batch_size = images.shape[0]

        if pad_if_not_divisible:
            h, w = images.shape[2:]
            # Smallest non-negative padding that makes each spatial side a multiple of 16
            # (0 when the side is already divisible).
            pad_h = (16 - h % 16) % 16
            pad_w = (16 - w % 16) % 16
            # F.pad takes (left, right, top, bottom) for the last two dims: pad only right and bottom,
            # so the original pixels keep their coordinates.
            images = torch.nn.functional.pad(images, (0, pad_w, 0, pad_h), value=0.0)

        heatmaps, descriptors = self.heatmap_and_dense_descriptors(images)

        if pad_if_not_divisible:
            # Drop the padded region so no keypoint can be detected outside the original image.
            heatmaps = heatmaps[..., :h, :w]
            descriptors = descriptors[..., :h, :w]

        keypoints = heatmap_to_keypoints(heatmaps, n=n, window_size=window_size, score_threshold=score_threshold)

        # Attach each image's dense descriptor map to that image's detected keypoints.
        return [keypoints[i].merge_with_descriptors(descriptors[i]) for i in range(batch_size)]

    @classmethod
    def from_pretrained(cls, checkpoint: str = "depth", device: torch.device = torch.device("cpu")) -> DISK:
        r"""Loads a pretrained model.

        Depth model was trained using depth map supervision and is slightly more precise but biased to detect
        keypoints only where SfM depth is available. Epipolar model was trained using epipolar geometry
        supervision and is less precise but detects keypoints everywhere where they are matchable. The difference
        is especially pronounced on thin structures and on edges of objects.

        Args:
            checkpoint: The checkpoint to load. One of 'depth' or 'epipolar'.
            device: The device to load the model to.

        Returns:
            The pretrained model, in evaluation mode.

        Raises:
            ValueError: If ``checkpoint`` is not one of the known checkpoint names.
        """
        urls = {
            "depth": "https://raw.githubusercontent.com/cvlab-epfl/disk/master/depth-save.pth",
            "epipolar": "https://raw.githubusercontent.com/cvlab-epfl/disk/master/epipolar-save.pth",
        }
        if checkpoint not in urls:
            raise ValueError(f"Unknown pretrained model: {checkpoint}. Expected one of {sorted(urls)}.")

        pretrained_dict = torch.hub.load_state_dict_from_url(urls[checkpoint], map_location=device)
        model: DISK = cls().to(device)
        # The released checkpoints store the feature-extractor weights under the "extractor" key.
        model.load_state_dict(pretrained_dict["extractor"])
        model.eval()
        return model