from typing import Optional, Tuple
import torch
import torch.nn as nn
def match_nn(
    desc1: torch.Tensor, desc2: torch.Tensor, dm: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Find, for every descriptor in ``desc1``, its nearest neighbor in ``desc2``.

    When the distance matrix ``dm`` is omitted it is computed with :py:func:`torch.cdist`.

    Args:
        desc1: Batch of descriptors of a shape :math:`(B1, D)`.
        desc2: Batch of descriptors of a shape :math:`(B2, D)`.
        dm: Tensor containing the distances from each descriptor in desc1
          to each descriptor in desc2, shape of :math:`(B1, B2)`.

    Returns:
        - Descriptor distance of matching descriptors, shape of :math:`(B1, 1)`.
        - Long tensor indexes of matching descriptors in desc1 and desc2, shape of :math:`(B1, 2)`.

    Raises:
        AssertionError: if ``desc1``/``desc2`` are not 2-D, or ``dm`` does not have shape
          :math:`(B1, B2)`.
    """
    if desc1.dim() != 2 or desc2.dim() != 2:
        raise AssertionError
    if dm is None:
        dm = torch.cdist(desc1, desc2)
    elif dm.size(0) != desc1.size(0) or dm.size(1) != desc2.size(0):
        raise AssertionError

    # Row-wise minimum: the closest desc2 entry for every desc1 entry.
    nearest_dist, nearest_idx = dm.min(dim=1)
    query_idx = torch.arange(nearest_idx.numel(), device=nearest_idx.device)
    # Each row of ``pairs`` is (index in desc1, index in desc2).
    pairs = torch.stack([query_idx, nearest_idx], dim=1)
    return nearest_dist.unsqueeze(1), pairs
def match_mnn(
    desc1: torch.Tensor, desc2: torch.Tensor, dm: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Find mutual nearest neighbors between ``desc1`` and ``desc2``.

    A pair ``(i, j)`` is kept only when ``desc2[j]`` is the nearest neighbor of ``desc1[i]``
    and ``desc1[i]`` is the nearest neighbor of ``desc2[j]``.
    When the distance matrix ``dm`` is omitted it is computed with :py:func:`torch.cdist`.

    Args:
        desc1: Batch of descriptors of a shape :math:`(B1, D)`.
        desc2: Batch of descriptors of a shape :math:`(B2, D)`.
        dm: Tensor containing the distances from each descriptor in desc1
          to each descriptor in desc2, shape of :math:`(B1, B2)`.

    Return:
        - Descriptor distance of matching descriptors, shape of :math:`(B3, 1)`.
        - Long tensor indexes of matching descriptors in desc1 and desc2, shape of :math:`(B3, 2)`,
          where 0 <= B3 <= min(B1, B2). Rows are ordered by the index on the smaller side.

    Raises:
        AssertionError: if ``desc1``/``desc2`` are not 2-D, or ``dm`` does not have shape
          :math:`(B1, B2)`.
    """
    if desc1.dim() != 2 or desc2.dim() != 2:
        raise AssertionError
    if dm is None:
        dm = torch.cdist(desc1, desc2)
    elif dm.size(0) != desc1.size(0) or dm.size(1) != desc2.size(0):
        raise AssertionError

    n1, n2 = dm.size(0), dm.size(1)
    # Both reductions are taken up front (best desc2 per desc1 row, best desc1 per desc2 column).
    row_dist, row_arg = torch.min(dm, dim=1)
    col_dist, col_arg = torch.min(dm, dim=0)
    short = torch.arange(min(n1, n2), device=dm.device)

    if n1 <= n2:
        # Walk desc1: i is mutual when the column it picks points back to i.
        is_mutual = col_arg[row_arg] == short
        pairs = torch.stack([short, row_arg], dim=1)[is_mutual]
        dists = row_dist[is_mutual]
    else:
        # Walk desc2 (the smaller side): j is mutual when the row it picks points back to j.
        is_mutual = row_arg[col_arg] == short
        pairs = torch.stack([col_arg, short], dim=1)[is_mutual]
        dists = col_dist[is_mutual]
    return dists.view(-1, 1), pairs.view(-1, 2)
def match_snn(
    desc1: torch.Tensor, desc2: torch.Tensor, th: float = 0.8, dm: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Match each descriptor in ``desc1`` to its nearest neighbor in ``desc2`` using the ratio test.

    A descriptor is matched only if the ratio of its first to its second nearest-neighbor
    distance is ``<= th``. When the distance matrix ``dm`` is omitted it is computed with
    :py:func:`torch.cdist`.

    Args:
        desc1: Batch of descriptors of a shape :math:`(B1, D)`.
        desc2: Batch of descriptors of a shape :math:`(B2, D)`, with :math:`B2 >= 2`.
        th: distance ratio threshold.
        dm: Tensor containing the distances from each descriptor in desc1
          to each descriptor in desc2, shape of :math:`(B1, B2)`.

    Return:
        - First-to-second nearest-neighbor distance ratio of each match, shape of :math:`(B3, 1)`.
        - Long tensor indexes of matching descriptors in desc1 and desc2. Shape: :math:`(B3, 2)`,
          where 0 <= B3 <= B1.

    Raises:
        AssertionError: if ``desc1``/``desc2`` are not 2-D, ``desc2`` has fewer than two rows,
          or ``dm`` does not have shape :math:`(B1, B2)`.
    """
    if desc1.dim() != 2 or desc2.dim() != 2 or desc2.size(0) < 2:
        raise AssertionError
    if dm is None:
        dm = torch.cdist(desc1, desc2)
    elif dm.size(0) != desc1.size(0) or dm.size(1) != desc2.size(0):
        raise AssertionError

    # Two smallest distances per desc1 row, ascending.
    two_best, nn_idx = torch.topk(dm, 2, dim=1, largest=False)
    ratio = two_best[:, 0] / two_best[:, 1]
    # NaN ratios (both distances zero) compare False and are therefore rejected.
    keep = ratio <= th

    rows = torch.arange(nn_idx.size(0), device=dm.device)[keep]
    cols = nn_idx[:, 0][keep]
    pairs = torch.stack((rows, cols), dim=1)
    return ratio[keep].view(-1, 1), pairs.view(-1, 2)
def match_smnn(
    desc1: torch.Tensor, desc2: torch.Tensor, th: float = 0.8, dm: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Find mutual nearest neighbors between ``desc1`` and ``desc2`` that pass the ratio test.

    A pair ``(i, j)`` is returned when ``j`` is the :func:`match_snn` match of ``desc1[i]``
    in ``desc2`` and ``i`` is the :func:`match_snn` match of ``desc2[j]`` in ``desc1``
    (both with threshold ``th``). When the distance matrix ``dm`` is omitted it is computed
    with :py:func:`torch.cdist`.

    Args:
        desc1: Batch of descriptors of a shape :math:`(B1, D)`, with :math:`B1 >= 2`.
        desc2: Batch of descriptors of a shape :math:`(B2, D)`, with :math:`B2 >= 2`.
        th: distance ratio threshold.
        dm: Tensor containing the distances from each descriptor in desc1
          to each descriptor in desc2, shape of :math:`(B1, B2)`.

    Return:
        - For each match, the larger of the two values :func:`match_snn` reports for it
          (forward and backward first-to-second distance ratios), shape of :math:`(B3, 1)`.
        - Long tensor indexes of matching descriptors in desc1 and desc2,
          shape of :math:`(B3, 2)` where 0 <= B3 <= min(B1, B2). Rows are sorted by the desc1 index.
          The index tensor is long even when there are no matches.

    Raises:
        AssertionError: if ``desc1``/``desc2`` are not 2-D, either has fewer than two rows,
          or ``dm`` does not have shape :math:`(B1, B2)`.
    """
    if desc1.dim() != 2 or desc2.dim() != 2:
        raise AssertionError
    if desc1.size(0) < 2 or desc2.size(0) < 2:
        raise AssertionError
    if dm is None:
        dm = torch.cdist(desc1, desc2)
    elif dm.size(0) != desc1.size(0) or dm.size(1) != desc2.size(0):
        raise AssertionError

    dists1, idx1 = match_snn(desc1, desc2, th, dm)
    dists2, idx2 = match_snn(desc2, desc1, th, dm.t())
    # Backward matches come as (desc2 index, desc1 index); flip to (desc1 index, desc2 index).
    idx2 = idx2.flip(1)

    # Dense per-desc1 lookup of the forward result (-1 = no forward match). Each desc1 index
    # occurs at most once in idx1, so the scatter is unambiguous. This needs O(B1) memory,
    # instead of an (N1, N2) matrix comparing every forward match with every backward match.
    num1 = dm.size(0)
    fwd_match = torch.full((num1,), -1, dtype=torch.long, device=dm.device)
    fwd_match[idx1[:, 0]] = idx1[:, 1]
    fwd_dist = torch.zeros(num1, dtype=dists1.dtype, device=dm.device)
    fwd_dist[idx1[:, 0]] = dists1.view(-1)

    # A backward match (i, j) is mutual iff the forward match of i is exactly j.
    mutual = fwd_match[idx2[:, 0]] == idx2[:, 1]
    good_idxs = idx2[mutual]
    back_dists = dists2.view(-1)[mutual]

    # Mutual pairs have unique desc1 indices; order them by it. Empty inputs flow through
    # naturally and keep the long index dtype.
    order = torch.argsort(good_idxs[:, 0])
    matches_idxs = good_idxs[order]
    match_dists = torch.max(fwd_dist[matches_idxs[:, 0]], back_dists[order])
    return match_dists.view(-1, 1), matches_idxs.view(-1, 2)
class DescriptorMatcher(nn.Module):
    """``nn.Module`` front-end that dispatches to one of the descriptor matching functions.

    See :func:`~kornia.feature.match_nn`, :func:`~kornia.feature.match_snn`,
    :func:`~kornia.feature.match_mnn` or :func:`~kornia.feature.match_smnn` for more details.

    Args:
        match_mode: type of matching, can be `nn`, `snn`, `mnn`, `smnn` (case-insensitive).
        th: threshold on distance ratio, or other quality measure. Only forwarded in the
          `snn` and `smnn` modes.

    Raises:
        NotImplementedError: if ``match_mode`` is not one of :attr:`known_modes`.
    """

    known_modes = ['nn', 'mnn', 'snn', 'smnn']

    def __init__(self, match_mode: str = 'snn', th: float = 0.8) -> None:
        super().__init__()
        mode: str = match_mode.lower()
        if mode not in self.known_modes:
            raise NotImplementedError(f"{match_mode} is not supported. Try one of {self.known_modes}")
        self.match_mode = mode
        self.th = th

    def forward(self, desc1: torch.Tensor, desc2: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Match ``desc1`` against ``desc2`` with the configured strategy.

        Args:
            desc1: Batch of descriptors of a shape :math:`(B1, D)`.
            desc2: Batch of descriptors of a shape :math:`(B2, D)`.

        Return:
            - Descriptor distance of matching descriptors, shape of :math:`(B3, 1)`.
            - Long tensor indexes of matching descriptors in desc1 and desc2,
              shape of :math:`(B3, 2)` where :math:`0 <= B3 <= B1`.

        Raises:
            NotImplementedError: if ``match_mode`` was changed to an unsupported value.
        """
        mode = self.match_mode
        if mode in ('snn', 'smnn'):
            # Ratio-test variants take the threshold as their third argument.
            ratio_matcher = match_snn if mode == 'snn' else match_smnn
            return ratio_matcher(desc1, desc2, self.th)
        if mode in ('nn', 'mnn'):
            plain_matcher = match_nn if mode == 'nn' else match_mnn
            return plain_matcher(desc1, desc2)
        raise NotImplementedError