# Source code for kornia.feature.keynet

from typing import List, Optional

import torch
import torch.nn.functional as F
from torch import nn
from typing_extensions import TypedDict

from kornia.core import Module, Tensor, concatenate
from kornia.filters import SpatialGradient
from kornia.geometry.transform import pyrdown
from kornia.utils.helpers import map_location_to_cpu

from .scale_space_detector import Detector_config, MultiResolutionDetector, get_default_detector_config


class KeyNet_conf(TypedDict):
    """Typed schema of the configuration dict accepted by :class:`KeyNet` and :class:`KeyNetDetector`."""

    # Per-level channel count; KeyNet multiplies it by ``num_levels`` for the final conv's input width.
    num_filters: int
    # Number of pyramid levels whose features are concatenated before the final conv.
    num_levels: int
    # Kernel size of KeyNet's final conv (padding is derived as ``kernel_size // 2``).
    kernel_size: int
    # Extraction settings handed to the multi-resolution detector.
    Detector_conf: Detector_config


# Default configuration: Key.Net architecture hyper-parameters followed by the
# extraction parameters used by the multi-resolution detector.
keynet_default_config: KeyNet_conf = KeyNet_conf(
    num_filters=8,
    num_levels=3,
    kernel_size=5,
    Detector_conf=get_default_detector_config(),
)

# Pretrained Key.Net weights, downloaded through torch.hub when ``pretrained=True``.
KeyNet_URL = "https://github.com/axelBarroso/Key.Net-Pytorch/raw/main/model/weights/keynet_pytorch.pth"


class _FeatureExtractor(Module):
    """Helper class for KeyNet that chains the handcrafted and the learnable blocks.

    The input is first expanded into fixed derivative features by ``_HandcraftedBlock`` and
    those features are then passed through ``_LearnableBlock``.
    """

    def __init__(self) -> None:
        super().__init__()
        # The attribute names become the ``hc_block.*`` / ``lb_block.*`` state_dict prefixes,
        # so they must not change if pretrained weights are to load.
        self.hc_block = _HandcraftedBlock()
        self.lb_block = _LearnableBlock()

    def forward(self, x: Tensor) -> Tensor:
        return self.lb_block(self.hc_block(x))


class _HandcraftedBlock(Module):
    """Helper class for KeyNet computing the fixed (non-learnable) derivative features.

    Ten feature maps are concatenated along dim 1, in this order:
    ``dx, dy, dx^2, dy^2, dx*dy, dxy, dxy^2, dxx, dyy, dxx*dyy``
    (i.e. 10 channels for a single-channel input).
    """

    def __init__(self) -> None:
        super().__init__()
        self.spatial_gradient = SpatialGradient("sobel", 1)

    def forward(self, x: Tensor) -> Tensor:
        # The gradient output stacks the x / y derivatives along dim 2 (index 0 = x, 1 = y).
        grad = self.spatial_gradient(x)
        dx = grad[:, :, 0]
        dy = grad[:, :, 1]

        # Differentiate the first-order maps again to obtain the second-order terms.
        grad_of_dx = self.spatial_gradient(dx)
        dxx = grad_of_dx[:, :, 0]
        dxy = grad_of_dx[:, :, 1]
        dyy = self.spatial_gradient(dy)[:, :, 1]

        features = [
            dx,
            dy,
            dx**2.0,
            dy**2.0,
            dx * dy,
            dxy,
            dxy**2.0,
            dxx,
            dyy,
            dxx * dyy,
        ]
        return concatenate(features, 1)


class _LearnableBlock(nn.Sequential):
    """Helper class for KeyNet: three stacked learnable Conv-BatchNorm-ReLU units.

    Args:
        in_channels: channel count of the incoming feature map; only the first unit adapts
            to it, the following two keep the default 8 -> 8 width.
    """

    def __init__(self, in_channels: int = 10) -> None:
        super().__init__()
        # Names and registration order define the ``conv0/conv1/conv2`` state_dict keys.
        self.conv0 = _KeyNetConvBlock(in_channels)
        self.conv1 = _KeyNetConvBlock()
        self.conv2 = _KeyNetConvBlock()

    def forward(self, x: Tensor) -> Tensor:
        out = x
        for unit in (self.conv0, self.conv1, self.conv2):
            out = unit(out)
        return out


def _KeyNetConvBlock(
    in_channels: int = 8,
    out_channels: int = 8,
    kernel_size: int = 5,
    stride: int = 1,
    padding: int = 2,
    dilation: int = 1,
) -> nn.Sequential:
    """Build the default learnable unit used by KeyNet: ``Conv2d -> BatchNorm2d -> ReLU``.

    Args:
        in_channels: input channels of the convolution.
        out_channels: output channels of the convolution (and batch-norm features).
        kernel_size: convolution kernel size.
        stride: convolution stride.
        padding: convolution zero-padding.
        dilation: convolution dilation.

    Returns:
        An ``nn.Sequential`` with children ``0`` (conv), ``1`` (batch norm) and ``2`` (in-place ReLU).
    """
    conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation)
    norm = nn.BatchNorm2d(out_channels)
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(conv, norm, activation)


class KeyNet(Module):
    """Key.Net model definition -- local feature detector (response function).

    This is based on the original code from paper "Key.Net: Keypoint Detection by Handcrafted and Learned CNN
    Filters". See :cite:`KeyNet2019` for more details.

    .. image:: _static/img/KeyNet.png

    Args:
        pretrained: Download and set pretrained weights to the model.
        keynet_conf: Dict with initialization parameters. Do not pass it, unless you know what you are doing.

    Returns:
        KeyNet response score.

    Raises:
        ValueError: if ``keynet_conf["num_filters"]`` differs from the 8 channels produced by the feature
            extractor, or if ``keynet_conf["num_levels"]`` is smaller than 1. Either setting would make the
            input width of the final convolution disagree with the concatenated features.

    Shape:
        - Input: :math:`(B, 1, H, W)`
        - Output: :math:`(B, 1, H, W)`
    """

    def __init__(self, pretrained: bool = False, keynet_conf: KeyNet_conf = keynet_default_config) -> None:
        super().__init__()

        num_filters = keynet_conf["num_filters"]
        self.num_levels = keynet_conf["num_levels"]
        kernel_size = keynet_conf["kernel_size"]
        padding = kernel_size // 2

        # _FeatureExtractor() is built with its default widths, so each pyramid level yields 8 channels.
        # last_conv expects num_filters * num_levels input channels; any other combination could never
        # run forward() nor load the pretrained weights, so fail early with a clear message.
        extractor_channels = 8
        if num_filters != extractor_channels:
            raise ValueError(
                f"keynet_conf['num_filters'] must be {extractor_channels} to match the feature extractor "
                f"output, got {num_filters}"
            )
        if self.num_levels < 1:
            raise ValueError(f"keynet_conf['num_levels'] must be >= 1, got {self.num_levels}")

        self.feature_extractor = _FeatureExtractor()
        self.last_conv = nn.Sequential(
            nn.Conv2d(
                in_channels=num_filters * self.num_levels, out_channels=1, kernel_size=kernel_size, padding=padding
            ),
            nn.ReLU(inplace=True),
        )

        # use torch.hub to load pretrained model
        if pretrained:
            pretrained_dict = torch.hub.load_state_dict_from_url(KeyNet_URL, map_location=map_location_to_cpu)
            self.load_state_dict(pretrained_dict["state_dict"], strict=True)
        self.eval()

    def forward(self, x: Tensor) -> Tensor:
        """Compute the Key.Net response map.

        Args:
            x: input image batch of shape :math:`(B, 1, H, W)`.

        Returns:
            Response scores of shape :math:`(B, 1, H, W)`.
        """
        height, width = x.shape[2], x.shape[3]
        feats: List[Tensor] = [self.feature_extractor(x)]
        for _ in range(1, self.num_levels):
            # Each extra level works on a further downscaled image; its features are resized
            # back to the input resolution so all levels can be stacked channel-wise.
            x = pyrdown(x, factor=1.2)
            level_feats = self.feature_extractor(x)
            feats.append(F.interpolate(level_feats, size=(height, width), mode="bilinear", align_corners=False))
        return self.last_conv(concatenate(feats, 1))
class KeyNetDetector(MultiResolutionDetector):
    """Multi-scale feature detector based on KeyNet.

    This is based on the original code from paper "Key.Net: Keypoint Detection by Handcrafted and Learned CNN
    Filters". See :cite:`KeyNet2019` for more details.

    .. image:: _static/img/keynet.jpg

    Args:
        pretrained: Download and set pretrained weights to the model.
        num_features: Number of features to detect.
        keynet_conf: Dict with initialization parameters. Do not pass it, unless you know what you are doing.
        ori_module: for local feature orientation estimation. Default: :class:`~kornia.feature.PassLAF`,
            which does nothing. See :class:`~kornia.feature.LAFOrienter` for details.
        aff_module: for local feature affine shape estimation. Default: :class:`~kornia.feature.PassLAF`,
            which does nothing. See :class:`~kornia.feature.LAFAffineShapeEstimator` for details.
    """

    def __init__(
        self,
        pretrained: bool = False,
        num_features: int = 2048,
        keynet_conf: KeyNet_conf = keynet_default_config,
        ori_module: Optional[Module] = None,
        aff_module: Optional[Module] = None,
    ) -> None:
        # The KeyNet response model and the detector share one config: model hyper-parameters
        # go to KeyNet, the "Detector_conf" section goes to the multi-resolution detector.
        model = KeyNet(pretrained, keynet_conf)
        super().__init__(model, num_features, keynet_conf["Detector_conf"], ori_module, aff_module)