Source code for kornia.feature.lightglue

import math
import warnings
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn

from kornia.core import (
    Module,
    ModuleList,
    Tensor,
    arange,
    concatenate,
    cos,
    einsum,
    ones_like,
    sin,
    softmax,
    stack,
    where,
    zeros,
)
from kornia.core.check import KORNIA_CHECK
from kornia.utils._compat import torch_meshgrid

# flash-attn is an optional dependency: when the package is not installed the
# name is bound to ``None`` so later code can test it for truthiness.
try:
    from flash_attn.modules.mha import FlashCrossAttention
except ModuleNotFoundError:
    FlashCrossAttention = None

# A fused attention path is considered available when either flash-attn was
# imported or this torch build exposes ``F.scaled_dot_product_attention``.
FLASH_AVAILABLE = bool(FlashCrossAttention) or hasattr(F, "scaled_dot_product_attention")


def math_clamp(x, min_, max_):  # type: ignore
    """Clamp the scalar ``x`` to the closed interval ``[min_, max_]``.

    Args:
        x: value to clamp.
        min_: lower bound of the interval.
        max_: upper bound of the interval.

    Returns:
        ``min_`` if ``x < min_``, ``max_`` if ``x > max_``, otherwise ``x``.
    """
    # Cap at the upper bound first, then lift to the lower bound. The previous
    # form compared against ``min_`` twice and therefore always returned ``min_``.
    return max(min(x, max_), min_)


@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
def normalize_keypoints(kpts: Tensor, size: Tensor) -> Tensor:
    """Centre keypoints on the image and scale them by half of the longest side.

    Args:
        kpts: keypoint coordinates; indexed below as ``[B, N, 2]``.
        size: image size, either a tensor indexed as ``[B, 2]`` or a flat
            two-element ``torch.Size`` / tuple / list. The caller in this file
            passes the reversed trailing two dims of an image shape, which
            presumably means (width, height) — TODO confirm.

    Returns:
        Keypoints shifted by ``size / 2`` and divided by ``max(size) / 2``.
    """
    if isinstance(size, (tuple, list)):
        # ``torch.Size`` is a tuple subclass. The legacy ``Tensor(torch.Size)``
        # constructor interprets its argument as a *shape* and allocates an
        # uninitialised tensor, so build the tensor from the values explicitly.
        size = torch.tensor(list(size), dtype=torch.float32, device=kpts.device)[None]
    shift = size.float().to(kpts) / 2
    scale = size.max(1).values.float().to(kpts) / 2
    # Broadcast the per-image shift over keypoints and the per-image scale over
    # both keypoints and coordinates.
    return (kpts - shift[:, None]) / scale[:, None, None]


def rotate_half(x: Tensor) -> Tensor:
    """Map every consecutive pair ``(a, b)`` of the last dimension to ``(-b, a)``.

    The last dimension is viewed as pairs, each pair is rotated, and the result
    is flattened back to the original shape.
    """
    pairs = x.unflatten(-1, (-1, 2))
    first, second = pairs[..., 0], pairs[..., 1]
    rotated = stack((-second, first), dim=-1)
    return rotated.flatten(start_dim=-2)


def apply_cached_rotary_emb(freqs: Tensor, t: Tensor) -> Tensor:
    """Apply a precomputed rotary embedding to ``t``.

    ``freqs[0]`` multiplies ``t`` itself and ``freqs[1]`` multiplies
    ``rotate_half(t)``; the two products are summed.
    """
    direct_term = t * freqs[0]
    rotated_term = rotate_half(t) * freqs[1]
    return direct_term + rotated_term


class LearnableFourierPositionalEncoding(Module):
    """Fourier positional encoding with a learnable linear projection.

    Args:
        M: size of the last dimension of the input positions.
        dim: fallback for ``F_dim`` when ``F_dim`` is ``None``.
        F_dim: the projection outputs ``F_dim // 2`` features.
        gamma: the projection weights are initialised from a normal
            distribution with standard deviation ``gamma ** -2``.
    """

    def __init__(self, M: int, dim: int, F_dim: Optional[int] = None, gamma: float = 1.0) -> None:
        super().__init__()
        if F_dim is None:
            F_dim = dim
        self.gamma = gamma
        self.Wr = nn.Linear(M, F_dim // 2, bias=False)
        nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma**-2)

    def forward(self, x: Tensor) -> Tensor:
        """Encode position vector."""
        angles = self.Wr(x)
        # Leading axis of size 2 holds (cos, sin); a singleton axis is inserted
        # third from the end before each feature is duplicated.
        emb = stack([cos(angles), sin(angles)], 0)
        emb = emb.unsqueeze(-3)
        return emb.repeat_interleave(2, dim=-1)


class TokenConfidence(Module):
    """Per-token confidence head: a linear layer to one value followed by a sigmoid."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.token = nn.Sequential(nn.Linear(dim, 1), nn.Sigmoid())

    def forward(self, desc0: Tensor, desc1: Tensor) -> Tuple[Tensor, Tensor]:
        """Get confidence tokens."""
        head_dtype = self.token[0].weight.dtype
        # Both outputs are cast back to the dtype of ``desc0``.
        out_dtype = desc0.dtype

        def _confidence(desc: Tensor) -> Tensor:
            # Inputs are detached, so no gradient flows back into the descriptors.
            detached = desc.detach().to(head_dtype)
            return self.token(detached).squeeze(-1).to(out_dtype)

        return _confidence(desc0), _confidence(desc1)


class Attention(Module):
    """Dot-product attention that picks a backend at call time.

    Depending on what is installed and on the device of ``q``, the forward pass
    uses flash-attn's ``FlashCrossAttention``, torch's
    ``F.scaled_dot_product_attention``, or an explicit einsum/softmax fallback.
    """

    def __init__(self, allow_flash: bool) -> None:
        super().__init__()
        if allow_flash:
            if not FLASH_AVAILABLE:
                warnings.warn(
                    "FlashAttention is not available. For optimal speed, consider installing torch >= 2.0 or flash-attn.",
                    stacklevel=2,
                )
            if FlashCrossAttention:
                self.flash_ = FlashCrossAttention()
        self.enable_flash = allow_flash and FLASH_AVAILABLE

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        use_flash = self.enable_flash and q.device.type == "cuda"

        if use_flash and FlashCrossAttention:
            # flash-attn path: swap the second- and third-to-last axes, run in
            # half precision with k and v stacked along a new axis 2.
            out_dtype = q.dtype
            q_t = q.transpose(-2, -3)
            k_t = k.transpose(-2, -3)
            v_t = v.transpose(-2, -3)
            kv = stack([k_t, v_t], 2).half()
            m = self.flash_(q_t.half(), kv)
            return m.transpose(-2, -3).to(out_dtype)

        if use_flash:
            # use torch 2.0 scaled_dot_product_attention with flash
            with torch.backends.cuda.sdp_kernel(enable_flash=True):
                q_h = q.half().contiguous()
                k_h = k.half().contiguous()
                v_h = v.half().contiguous()
                return F.scaled_dot_product_attention(q_h, k_h, v_h).to(q.dtype)

        if hasattr(F, "scaled_dot_product_attention"):
            attended = F.scaled_dot_product_attention(q.contiguous(), k.contiguous(), v.contiguous())
            return attended.to(q.dtype)

        # Manual fallback: softmax(q k^T / sqrt(d)) v.
        s = q.shape[-1] ** -0.5
        logits = einsum("...id,...jd->...ij", q, k) * s
        attn = softmax(logits, -1)
        return einsum("...ij,...jd->...id", attn, v)


class Transformer(Module):
    """Self-attention block applied independently to two descriptor sets.

    Args:
        embed_dim: descriptor dimension; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        flash: forwarded to :class:`Attention` as ``allow_flash``.
        bias: whether the QKV and output projections use a bias term.
    """

    def __init__(self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        KORNIA_CHECK(self.embed_dim % num_heads == 0, "Embed dimension should be dividable by num_heads")
        self.head_dim = self.embed_dim // num_heads
        self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
        self.inner_attn = Attention(flash)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        hidden_dim = 2 * embed_dim
        ffn_layers = [
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim, elementwise_affine=True),
            nn.GELU(),
            nn.Linear(hidden_dim, embed_dim),
        ]
        self.ffn = nn.Sequential(*ffn_layers)

    def _forward(self, x: Tensor, encoding: Optional[Tensor] = None) -> Tensor:
        """Run one residual self-attention update on a single descriptor set."""
        # Split the fused projection into heads with a trailing (q, k, v) axis,
        # then swap axes 1 and 2 so heads come before tokens.
        fused = self.Wqkv(x).unflatten(-1, (self.num_heads, -1, 3))
        q, k, v = fused.transpose(1, 2).unbind(dim=-1)
        if encoding is not None:
            q = apply_cached_rotary_emb(encoding, q)
            k = apply_cached_rotary_emb(encoding, k)
        context = self.inner_attn(q, k, v)
        merged_heads = context.transpose(1, 2).flatten(start_dim=-2)
        message = self.out_proj(merged_heads)
        update = self.ffn(concatenate([x, message], -1))
        return x + update

    def forward(
        self, x0: Tensor, x1: Tensor, encoding0: Optional[Tensor] = None, encoding1: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor]:
        """Apply the same block (shared weights) to ``x0`` and ``x1``."""
        out0 = self._forward(x0, encoding0)
        out1 = self._forward(x1, encoding1)
        return out0, out1


class CrossTransformer(Module):
    """Bidirectional cross-attention block between two descriptor sets.

    A single projection (``to_qk``) produces the vectors used as both queries
    and keys, so one similarity matrix serves both attention directions.

    Args:
        embed_dim: descriptor dimension.
        num_heads: number of attention heads.
        flash: use :class:`Attention` when a fused backend is available.
        bias: whether the linear projections use a bias term.
    """

    def __init__(self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True) -> None:
        super().__init__()
        self.heads = num_heads
        dim_head = embed_dim // num_heads
        self.scale = dim_head**-0.5
        inner_dim = dim_head * num_heads
        self.to_qk = nn.Linear(embed_dim, inner_dim, bias=bias)
        self.to_v = nn.Linear(embed_dim, inner_dim, bias=bias)
        self.to_out = nn.Linear(inner_dim, embed_dim, bias=bias)
        hidden_dim = 2 * embed_dim
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim, elementwise_affine=True),
            nn.GELU(),
            nn.Linear(hidden_dim, embed_dim),
        )

        self.flash = Attention(True) if (flash and FLASH_AVAILABLE) else None  # type: ignore

    def map_(self, func: Callable, x0: Tensor, x1: Tensor) -> Tuple[Tensor, Tensor]:  # type: ignore
        """Apply ``func`` to ``x0`` and ``x1`` and return both results."""
        return func(x0), func(x1)

    def forward(self, x0: Tensor, x1: Tensor) -> Tuple[Tensor, Tensor]:
        def _split_heads(t: Tensor) -> Tensor:
            return t.unflatten(-1, (self.heads, -1)).transpose(1, 2)

        def _merge_heads(t: Tensor) -> Tensor:
            return t.transpose(1, 2).flatten(start_dim=-2)

        qk0, qk1 = self.map_(self.to_qk, x0, x1)
        v0, v1 = self.map_(self.to_v, x0, x1)
        qk0, qk1 = _split_heads(qk0), _split_heads(qk1)
        v0, v1 = _split_heads(v0), _split_heads(v1)

        if self.flash is None:
            # Scaling each side by sqrt(scale) scales their product by ``scale``.
            half_scale = self.scale**0.5
            qk0, qk1 = qk0 * half_scale, qk1 * half_scale
            sim = einsum("b h i d, b h j d -> b h i j", qk0, qk1)
            attn01 = softmax(sim, dim=-1)
            attn10 = softmax(sim.transpose(-2, -1).contiguous(), dim=-1)
            m0 = einsum("bhij, bhjd -> bhid", attn01, v1)
            m1 = einsum("bhji, bhjd -> bhid", attn10.transpose(-2, -1), v0)
        else:
            m0 = self.flash(qk0, qk1, v1)
            m1 = self.flash(qk1, qk0, v0)

        m0, m1 = self.map_(_merge_heads, m0, m1)
        m0, m1 = self.map_(self.to_out, m0, m1)
        x0 = x0 + self.ffn(concatenate([x0, m0], -1))
        x1 = x1 + self.ffn(concatenate([x1, m1], -1))
        return x0, x1


def sigmoid_log_double_softmax(sim: Tensor, z0: Tensor, z1: Tensor) -> Tensor:
    """Create the log assignment matrix from logits and similarity.

    The result has one extra row and one extra column compared with ``sim``.
    The inner block is the sum of the row-wise and column-wise log-softmax of
    ``sim`` plus ``logsigmoid(z0) + logsigmoid(z1)^T``. The extra column holds
    ``logsigmoid(-z0)``, the extra row holds ``logsigmoid(-z1)``, and the
    corner stays zero.
    """
    batch, rows, cols = sim.shape
    log_over_cols = F.log_softmax(sim, 2)
    log_over_rows = F.log_softmax(sim.transpose(-1, -2).contiguous(), 2).transpose(-1, -2)
    certainties = F.logsigmoid(z0) + F.logsigmoid(z1).transpose(1, 2)

    scores = sim.new_zeros((batch, rows + 1, cols + 1))
    scores[:, :rows, :cols] = log_over_cols + log_over_rows + certainties
    scores[:, :-1, -1] = F.logsigmoid(-z0.squeeze(-1))
    scores[:, -1, :-1] = F.logsigmoid(-z1.squeeze(-1))
    return scores


class MatchAssignment(Module):
    """Head that turns two descriptor sets into a log assignment matrix.

    Args:
        dim: descriptor dimension, used for both linear layers.
    """

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim
        self.matchability = nn.Linear(dim, 1, bias=True)
        self.final_proj = nn.Linear(dim, dim, bias=True)

    def forward(self, desc0: Tensor, desc1: Tensor) -> Tuple[Tensor, Tensor]:
        """Build assignment matrix from descriptors."""
        mdesc0 = self.final_proj(desc0)
        mdesc1 = self.final_proj(desc1)
        _, _, d = mdesc0.shape
        # Dividing both sides by d^(1/4) divides their dot product by sqrt(d).
        norm = d**0.25
        mdesc0 = mdesc0 / norm
        mdesc1 = mdesc1 / norm
        sim = einsum("bmd,bnd->bmn", mdesc0, mdesc1)
        z0 = self.matchability(desc0)
        z1 = self.matchability(desc1)
        return sigmoid_log_double_softmax(sim, z0, z1), sim

    def scores(self, desc0: Tensor, desc1: Tensor) -> Tuple[Tensor, Tensor]:
        """Return the sigmoid of the matchability logit for each descriptor of both sets."""
        logits0 = self.matchability(desc0)
        logits1 = self.matchability(desc1)
        return torch.sigmoid(logits0).squeeze(-1), torch.sigmoid(logits1).squeeze(-1)


def filter_matches(scores: Tensor, th: float) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Obtain matches from a log assignment matrix [Bx M+1 x N+1]

    The last row and column are ignored. A pair is kept only when each point is
    the other's arg-max (mutual check) and, if ``th`` is not ``None``, when the
    exponentiated score exceeds ``th``.

    Returns:
        ``(m0, m1, mscores0, mscores1)``: match indices for both sides (``-1``
        where there is no valid match) and the matching scores (``0`` where the
        mutual check fails).
    """
    inner = scores[:, :-1, :-1]
    best_per_row = inner.max(2)
    best_per_col = inner.max(1)
    m0, m1 = best_per_row.indices, best_per_col.indices

    # Mutual check: following a match and coming back must land on the start index.
    positions0 = arange(m0.shape[1]).to(m0)[None]
    positions1 = arange(m1.shape[1]).to(m1)[None]
    mutual0 = positions0 == m1.gather(1, m0)
    mutual1 = positions1 == m0.gather(1, m1)

    row_scores = best_per_row.values.exp()
    zero = row_scores.new_tensor(0)
    mscores0 = where(mutual0, row_scores, zero)
    mscores1 = where(mutual1, mscores0.gather(1, m1), zero)

    valid0 = mutual0 if th is None else mutual0 & (mscores0 > th)
    valid1 = mutual1 & valid0.gather(1, m1)

    m0 = where(valid0, m0, m0.new_tensor(-1))
    m1 = where(valid1, m1, m1.new_tensor(-1))
    return m0, m1, mscores0, mscores1


class LightGlue(Module):
    """LightGlue feature matcher.

    The model alternates self- and cross-attention layers over two sets of
    keypoint descriptors. After every layer but the last it may stop early
    (``depth_confidence``) and/or drop points (``width_confidence``). The final
    log assignment matrix is converted to matches with :func:`filter_matches`.

    Args:
        features: key into :attr:`features`; selects the pretrained weights and
            the input descriptor dimension. With ``None`` no weights are
            downloaded and ``conf["weights"]`` is used if set.
        **conf: overrides for the entries of :attr:`default_conf`.
    """

    default_conf: ClassVar[Dict[str, Any]] = {
        "name": "lightglue",  # just for interfacing
        "input_dim": 256,  # input descriptor dimension (autoselected from weights)
        "descriptor_dim": 256,
        "n_layers": 9,
        "num_heads": 4,
        "flash": True,  # enable FlashAttention if available.
        "mp": False,  # enable mixed precision
        "depth_confidence": 0.95,  # early stopping, disable with -1
        "width_confidence": 0.99,  # point pruning, disable with -1
        "filter_threshold": 0.1,  # match threshold
        "weights": None,
    }
    required_data_keys: ClassVar[List[str]] = ["image0", "image1"]
    version: ClassVar[str] = "v0.1_arxiv"
    url: ClassVar[str] = "https://github.com/cvg/LightGlue/releases/download/{}/{}_lightglue.pth"
    features: ClassVar[Dict[str, Any]] = {"superpoint": ("superpoint_lightglue", 256), "disk": ("disk_lightglue", 128)}

    def __init__(self, features: str = "superpoint", **conf) -> None:  # type: ignore
        super().__init__()
        temp_conf = {**self.default_conf, **conf}
        if features is not None:
            KORNIA_CHECK(features in self.features, "Features keys are wrong")
            # The feature type dictates both the weight name and the input dimension.
            temp_conf["weights"], temp_conf["input_dim"] = self.features[features]
        self.conf = conf_ = SimpleNamespace(**temp_conf)

        # Project input descriptors only when their dimension differs from the model's.
        if conf_.input_dim != conf_.descriptor_dim:
            self.input_proj = nn.Linear(conf_.input_dim, conf_.descriptor_dim, bias=True)
        else:
            self.input_proj = nn.Identity()  # type: ignore

        head_dim = conf_.descriptor_dim // conf_.num_heads
        self.posenc = LearnableFourierPositionalEncoding(2, head_dim, head_dim)

        h, n, d = conf_.num_heads, conf_.n_layers, conf_.descriptor_dim
        self.self_attn = ModuleList([Transformer(d, h, conf_.flash) for _ in range(n)])
        self.cross_attn = ModuleList([CrossTransformer(d, h, conf_.flash) for _ in range(n)])
        self.log_assignment = ModuleList([MatchAssignment(d) for _ in range(n)])
        # The last layer has no confidence head: it never stops early or prunes.
        self.token_confidence = ModuleList([TokenConfidence(d) for _ in range(n - 1)])

        if features is not None:
            fname = f"{conf_.weights}_{self.version}.pth".replace(".", "-")
            state_dict = torch.hub.load_state_dict_from_url(self.url.format(self.version, features), file_name=fname)
            self.load_state_dict(state_dict, strict=False)
        elif conf_.weights is not None:
            path = Path(__file__).parent
            path = path / f"weights/{self.conf.weights}.pth"
            state_dict = torch.load(str(path), map_location="cpu")
            self.load_state_dict(state_dict, strict=False)

        # NOTE(review): emitted unconditionally; the original nesting was lost in
        # the collapsed source — confirm it was not inside the loading branches.
        print("Loaded LightGlue model")

    def forward(self, data: Dict) -> Dict:  # type: ignore
        """Match keypoints and descriptors between two images.

        Input (dict):
            image0: dict
                keypoints: [B x M x 2]
                descriptors: [B x M x D]
                image: [B x C x H x W] or image_size: [B x 2]
            image1: dict
                keypoints: [B x N x 2]
                descriptors: [B x N x D]
                image: [B x C x H x W] or image_size: [B x 2]
        Output (dict):
            log_assignment: [B x M+1 x N+1]
            matches0: [B x M]
            matching_scores0: [B x M]
            matches1: [B x N]
            matching_scores1: [B x N]
            matches: List[[Si x 2]], scores: List[[Si]]
        """
        with torch.cuda.amp.autocast(enabled=self.conf.mp):
            return self._forward(data)

    def _forward(self, data: Dict[str, Dict[str, Tensor]]) -> Dict[str, Any]:
        """Run the matcher; see :meth:`forward` for the input/output layout."""
        for key in self.required_data_keys:
            KORNIA_CHECK(key in data, f"Missing key {key} in data")
        data0, data1 = data["image0"], data["image1"]
        kpts0_, kpts1_ = data0["keypoints"], data1["keypoints"]
        # torch 1.9.0 compatibility
        if hasattr(torch, "inf"):
            inf = torch.inf
        else:
            inf = math.inf

        b, m, _ = kpts0_.shape
        b, n, _ = kpts1_.shape
        size0, size1 = data0.get("image_size"), data1.get("image_size")
        # Fall back to the trailing two dims of the image shape, reversed.
        size0 = size0 if size0 is not None else data0["image"].shape[-2:][::-1]  # type: ignore
        size1 = size1 if size1 is not None else data1["image"].shape[-2:][::-1]  # type: ignore
        kpts0 = normalize_keypoints(kpts0_, size=size0)
        kpts1 = normalize_keypoints(kpts1_, size=size1)
        KORNIA_CHECK(
            torch.all(kpts0 >= -1).item() and torch.all(kpts0 <= 1).item(),  # type: ignore
            "Normalized keypoints of image0 must lie in [-1, 1]",
        )
        KORNIA_CHECK(
            torch.all(kpts1 >= -1).item() and torch.all(kpts1 <= 1).item(),  # type: ignore
            "Normalized keypoints of image1 must lie in [-1, 1]",
        )

        desc0 = data0["descriptors"].detach()
        desc1 = data1["descriptors"].detach()
        KORNIA_CHECK(desc0.shape[-1] == self.conf.input_dim, "Descriptor dimension does not match input dim in config")
        KORNIA_CHECK(desc1.shape[-1] == self.conf.input_dim, "Descriptor dimension does not match input dim in config")
        if torch.is_autocast_enabled():
            desc0 = desc0.half()
            desc1 = desc1.half()
        desc0 = self.input_proj(desc0)
        desc1 = self.input_proj(desc1)

        # cache positional embeddings
        encoding0 = self.posenc(kpts0)
        encoding1 = self.posenc(kpts1)

        # GNN + final_proj + assignment
        ind0 = arange(0, m).to(device=kpts0.device)[None]
        ind1 = arange(0, n).to(device=kpts0.device)[None]
        prune0 = ones_like(ind0)  # store layer where pruning is detected
        prune1 = ones_like(ind1)
        dec, wic = self.conf.depth_confidence, self.conf.width_confidence
        token0, token1 = None, None
        for i in range(self.conf.n_layers):
            # self+cross attention
            desc0, desc1 = self.self_attn[i](desc0, desc1, encoding0, encoding1)
            desc0, desc1 = self.cross_attn[i](desc0, desc1)
            if i == self.conf.n_layers - 1:
                continue  # no early stopping or adaptive width at last layer
            if dec > 0:  # early stopping
                token0, token1 = self.token_confidence[i](desc0, desc1)
                if self.stop(token0, token1, self.conf_th(i), dec, m + n):
                    break
            if wic > 0:  # point pruning
                match0, match1 = self.log_assignment[i].scores(desc0, desc1)
                mask0 = self.get_mask(token0, match0, self.conf_th(i), 1 - wic)  # type: ignore
                mask1 = self.get_mask(token1, match1, self.conf_th(i), 1 - wic)  # type: ignore
                # NOTE(review): boolean-mask indexing followed by ``[None]`` merges
                # the batch dimension, so pruning looks valid only for B == 1 — confirm.
                ind0, ind1 = ind0[mask0][None], ind1[mask1][None]
                desc0, desc1 = desc0[mask0][None], desc1[mask1][None]
                if desc0.shape[-2] == 0 or desc1.shape[-2] == 0:
                    break
                encoding0 = encoding0[:, :, mask0][:, None]
                encoding1 = encoding1[:, :, mask1][:, None]
            # NOTE(review): placed at loop level (run for every non-final layer that
            # neither stops nor empties a set); the nesting was reconstructed from a
            # collapsed source — confirm against the reference implementation.
            prune0[:, ind0] += 1
            prune1[:, ind1] += 1

        if wic > 0:  # scatter with indices after pruning
            scores_, _ = self.log_assignment[i](desc0, desc1)
            dt, dev = scores_.dtype, scores_.device
            scores = zeros(b, m + 1, n + 1, dtype=dt, device=dev)
            # Pairs involving a pruned point keep a log-score of -inf.
            scores[:, :-1, :-1] = -inf
            scores[:, ind0[0], -1] = scores_[:, :-1, -1]
            scores[:, -1, ind1[0]] = scores_[:, -1, :-1]
            x, y = torch_meshgrid([ind0[0], ind1[0]], "ij")
            scores[:, x, y] = scores_[:, :-1, :-1]
        else:
            scores, _ = self.log_assignment[i](desc0, desc1)

        m0, m1, mscores0, mscores1 = filter_matches(scores, self.conf.filter_threshold)

        # Per-image list of (index in image0, index in image1) pairs and their scores.
        matches, mscores = [], []
        for k in range(b):
            valid = m0[k] > -1
            matches.append(stack([where(valid)[0], m0[k][valid]], -1))
            mscores.append(mscores0[k][valid])

        return {
            "log_assignment": scores,
            "matches0": m0,
            "matches1": m1,
            "matching_scores0": mscores0,
            "matching_scores1": mscores1,
            "stop": i + 1,
            "prune0": prune0,
            "prune1": prune1,
            "matches": matches,
            "scores": mscores,
        }

    def conf_th(self, i: int) -> float:
        """Scaled confidence threshold.

        Computes ``0.8 + 0.1 * exp(-4 * i / n_layers)`` for layer index ``i`` and
        passes it through ``math_clamp`` with bounds ``0.0`` and ``1.0``.
        """
        return math_clamp(0.8 + 0.1 * math.exp(-4.0 * float(i) / self.conf.n_layers), 0.0, 1.0)

    def get_mask(self, confidence: Tensor, match: Tensor, conf_th: float, match_th: float) -> Tensor:
        """Mask points which should be removed.

        The returned boolean mask is ``True`` for points that are *kept*: those
        whose ``match`` score exceeds ``match_th``. When ``conf_th`` is truthy and
        ``confidence`` is given, points whose confidence is not above ``conf_th``
        are always kept.
        """
        if conf_th and confidence is not None:
            mask = where(confidence > conf_th, match, match.new_tensor(1.0)) > match_th
        else:
            mask = match > match_th
        return mask

    def stop(self, token0: Tensor, token1: Tensor, conf_th: float, inl_th: float, seql: int) -> Tensor:
        """Evaluate stopping condition.

        With a truthy ``conf_th`` the result is whether
        ``1 - (number of tokens below conf_th) / seql`` exceeds ``inl_th``;
        otherwise it is whether the mean token value exceeds ``inl_th``.
        """
        tokens = concatenate([token0, token1], -1)
        if conf_th:
            pos = 1.0 - (tokens < conf_th).float().sum() / seql
            return pos > inl_th
        else:
            return tokens.mean() > inl_th