Source code for kornia.feature.mkd

from typing import Union, Tuple, Dict, List

import torch
import torch.nn as nn
import torch.nn.functional as F

from kornia.constants import pi
from kornia.utils import create_meshgrid
from kornia.geometry.conversions import cart2pol
from kornia.filters import SpatialGradient, GaussianBlur2d


# Precomputed coefficients for Von Mises kernel, given N and K(appa).
sqrt2: float = 1.4142135623730951  # sqrt(2); used below to scale the 'rho' grid.
COEFFS_N1_K1: List[float] = [0.38214156, 0.48090413]
COEFFS_N2_K8: List[float] = [0.14343168, 0.268285, 0.21979234]
COEFFS_N3_K8: List[float] = [0.14343168, 0.268285, 0.21979234, 0.15838885]
# Coefficient sets keyed by the grid parametrization they are used to embed:
# 'xy' for cartesian grids, 'rhophi' for polar grids, 'theta' for orientations.
COEFFS: Dict[str, List[float]] = {'xy': COEFFS_N1_K1,
                                  'rhophi': COEFFS_N2_K8,
                                  'theta': COEFFS_N3_K8}

# Download URLs of pretrained whitening models, one per kernel type.
urls: Dict[str, str] = {k: f'https://github.com/manyids2/mkd_pytorch/raw/master/mkd_pytorch/mkd-{k}-64.pth'
                        for k in ['cart', 'polar', 'concat']}


def get_grid_dict(patch_size: int = 32) -> Dict[str, torch.Tensor]:
    """Return cartesian ('x', 'y') and polar ('rho', 'phi') grid parametrizations.

    Builds a normalized ``patch_size x patch_size`` meshgrid and converts it
    to polar coordinates as well.
    """
    mesh = create_meshgrid(height=patch_size,
                           width=patch_size,
                           normalized_coordinates=True)
    xs, ys = mesh[0, :, :, 0], mesh[0, :, :, 1]
    rhos, phis = cart2pol(xs, ys)
    return {'x': xs, 'y': ys, 'rho': rhos, 'phi': phis}


def get_kron_order(d1: int, d2: int) -> torch.Tensor:
    """Return the index table used to form a Kronecker product.

    Row ``i * d2 + j`` holds the pair ``(i, j)``, i.e. all index pairs
    in row-major order.

    Args:
        d1: Size of the first factor.
        d2: Size of the second factor.

    Returns:
        torch.Tensor: Long tensor of shape (d1 * d2, 2).
    """
    # torch.cartesian_prod enumerates pairs in the same row-major order as
    # the original nested Python loop, but at C speed.
    return torch.cartesian_prod(torch.arange(d1, dtype=torch.int64),
                                torch.arange(d2, dtype=torch.int64))


class MKDGradients(nn.Module):
    r"""
    Module, which computes gradients of given patches,
    stacked as [magnitudes, orientations].
    Given gradients :math:`g_x`, :math:`g_y` with respect to :math:`x`, :math:`y` respectively,
      - :math:`\text{mags} = \sqrt{g_x^2 + g_y^2 + \epsilon}`
      - :math:`\text{oris} = \tan^{-1}(g_y / g_x)`.
    Returns:
        Tensor: gradients of given patches
    Shape:
        - Input: (B, 1, patch_size, patch_size)
        - Output: (B, 2, patch_size, patch_size)
    Examples::
        >>> patches = torch.rand(23, 1, 32, 32)
        >>> gradient = MKDGradients()
        >>> g = gradient(patches) # 23x2x32x32
    """

    def __init__(self) -> None:
        super().__init__()
        # Epsilon keeps the magnitude sqrt well-defined at zero gradients.
        self.eps = 1e-8

        # First-order finite-difference spatial gradient (unnormalized).
        self.grad = SpatialGradient(mode='diff', order=1, normalized=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not isinstance(x, torch.Tensor):
            raise TypeError("Input type is not a torch.Tensor. Got {}"
                            .format(type(x)))
        if not len(x.shape) == 4:
            raise ValueError("Invalid input shape, we expect Bx1xHxW. Got: {}"
                             .format(x.shape))
        # Modify 'diff' gradient. Before we had lambda function, but it is not jittable
        grads_xy = -self.grad(x)
        gx = grads_xy[:, :, 0, :, :]
        gy = grads_xy[:, :, 1, :, :]
        # cart2pol returns (magnitudes, orientations); stack them channel-wise.
        y = torch.cat(cart2pol(gx, gy, self.eps), dim=1)
        return y

    def __repr__(self) -> str:
        return self.__class__.__name__


class VonMisesKernel(nn.Module):
    """Von Mises kernel embedding.

    Derives the kernel parameters from the given coefficients and embeds
    input patches into the corresponding truncated Fourier feature space
    ``[constant, cos(1..n * x), sin(1..n * x)]``, weighted per channel.

    Args:
        patch_size: (int) Input patch size in pixels (32 is default)
        coeffs: (list) List of coefficients
              Some examples are hardcoded in COEFFS
    Returns:
        Tensor: Von Mises embedding of given parametrization
    Shape:
        - Input: (B, 1, patch_size, patch_size)
        - Output: (B, d, patch_size, patch_size)
    Examples::
        >>> oris = torch.rand(23, 1, 32, 32)
        >>> vm = VonMisesKernel(patch_size=32,
        ...                     coeffs=[0.14343168,
        ...                             0.268285,
        ...                             0.21979234])
        >>> emb = vm(oris) # 23x7x32x32
    """

    def __init__(self,
                 patch_size: int,
                 coeffs: Union[list, tuple]) -> None:
        super().__init__()

        self.patch_size = patch_size
        coeffs_tensor: torch.Tensor = torch.tensor(coeffs)
        self.register_buffer('coeffs', coeffs_tensor)

        # n harmonics give an embedding of dimensionality d = 2n + 1.
        num_harmonics: int = len(coeffs) - 1
        self.n: int = num_harmonics
        self.d: int = 2 * num_harmonics + 1

        # Constant (DC) component of the embedding.
        dc_term = torch.ones([1, 1, patch_size, patch_size])
        # Harmonic frequencies 1..n, shaped to broadcast over (B, n, H, W).
        freqs = (torch.arange(num_harmonics) + 1).reshape(-1, 1, 1)
        # Channel weights: sqrt of coefficients, ordered [DC, cos_1..n, sin_1..n].
        channel_weights = torch.cat([torch.sqrt(coeffs_tensor),
                                     torch.sqrt(coeffs_tensor[1:])]).reshape(-1, 1, 1)
        self.register_buffer('emb0', dc_term)
        self.register_buffer('frange', freqs)
        self.register_buffer('weights', channel_weights)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not isinstance(x, torch.Tensor):
            raise TypeError("Input type is not a torch.Tensor. Got {}"
                            .format(type(x)))

        if not len(x.shape) == 4 or x.shape[1] != 1:
            raise ValueError("Invalid input shape, we expect Bx1xHxW. Got: {}"
                             .format(x.shape))

        # TODO: unify the two lines below when pytorch 1.6 support is dropped
        dc: torch.Tensor = torch.jit.annotate(torch.Tensor, self.emb0)
        dc = dc.to(x).repeat(x.size(0), 1, 1, 1)
        # Phase for each harmonic; cos/sin give the Fourier features.
        phases = self.frange.to(x) * x
        cos_part = torch.cos(phases)
        sin_part = torch.sin(phases)
        return self.weights * torch.cat([dc, cos_part, sin_part], dim=1)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}'
                f'(patch_size={self.patch_size}'
                f', n={self.n}'
                f', d={self.d}'
                f', coeffs={self.coeffs})')


class EmbedGradients(nn.Module):
    """Embed gradient orientations, weighted by the square roots of magnitudes.

    Args:
        patch_size: (int) Input patch size in pixels (32 is default)
        relative: (bool) absolute or relative gradients (False is default)
    Returns:
        torch.Tensor: Gradient embedding
    Shape:
        - Input: (B, 2, patch_size, patch_size)
        - Output: (B, 7, patch_size, patch_size)
    Examples::
        >>> grads = torch.rand(23, 2, 32, 32)
        >>> emb_grads = EmbedGradients(patch_size=32,
        ...                            relative=False)
        >>> emb = emb_grads(grads) # 23x7x32x32
    """

    def __init__(self,
                 patch_size: int = 32,
                 relative: bool = False) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.relative = relative
        self.eps = 1e-8

        # Von Mises kernel over orientation angles.
        self.kernel = VonMisesKernel(patch_size=patch_size,
                                     coeffs=COEFFS['theta'])

        # Per-pixel polar angle of the patch grid; subtracted for relative mode.
        mesh = create_meshgrid(height=patch_size,
                               width=patch_size,
                               normalized_coordinates=True)
        _, phi = cart2pol(mesh[:, :, :, 0], mesh[:, :, :, 1])
        self.register_buffer('phi', phi)

    def emb_mags(self, mags: torch.Tensor) -> torch.Tensor:
        """Embed square roots of magnitudes with eps for numerical reasons. """
        return torch.sqrt(mags + self.eps)

    def forward(self, grads: torch.Tensor) -> torch.Tensor:
        if not isinstance(grads, torch.Tensor):
            raise TypeError("Input type is not a torch.Tensor. Got {}"
                            .format(type(grads)))
        if not len(grads.shape) == 4:
            raise ValueError("Invalid input shape, we expect Bx2xHxW. Got: {}"
                             .format(grads.shape))
        magnitudes = grads[:, :1, :, :]
        orientations = grads[:, 1:, :, :]
        if self.relative:
            # Make orientations relative to the grid angle at each pixel.
            orientations = orientations - self.phi.to(orientations)
        return self.kernel(orientations) * self.emb_mags(magnitudes)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}'
                f'(patch_size={self.patch_size}'
                f', relative={self.relative})')


def spatial_kernel_embedding(kernel_type: str, grids: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Compute the precomputed spatial embedding for a grid parametrization.

    Args:
        kernel_type: Either 'cart' or 'polar'.
        grids: Dict with keys 'x', 'y', 'rho', 'phi' holding (H, W) tensors,
            as produced by :func:`get_grid_dict`.

    Returns:
        torch.Tensor: Spatial kernel of shape (d * d, H, W), where d is the
        dimensionality of the Von Mises kernel for the chosen coefficients.

    Raises:
        ValueError: If ``kernel_type`` is neither 'cart' nor 'polar'.
    """
    factors = {"phi": 1.0, "rho": pi / sqrt2, "x": pi / 2, "y": pi / 2}
    if kernel_type == 'cart':
        coeffs_ = 'xy'
        params_ = ['x', 'y']
    elif kernel_type == 'polar':
        coeffs_ = 'rhophi'
        params_ = ['phi', 'rho']
    else:
        # Previously an unknown kernel_type fell through and died with an
        # opaque UnboundLocalError on coeffs_; fail fast and clearly instead.
        raise ValueError(f"kernel_type must be 'cart' or 'polar'. Got: {kernel_type}")

    # Infer patch_size from any grid entry.
    keys = list(grids.keys())
    patch_size = grids[keys[0]].shape[-1]

    # Scale each parametrization into the kernels' expected angular range.
    grids_normed = {k: v * factors[k] for k, v in grids.items()}
    grids_normed = {k: v.unsqueeze(0).unsqueeze(0).float()
                    for k, v in grids_normed.items()}

    # One Von Mises kernel per spatial coordinate.
    vm_a = VonMisesKernel(patch_size=patch_size, coeffs=COEFFS[coeffs_])
    vm_b = VonMisesKernel(patch_size=patch_size, coeffs=COEFFS[coeffs_])

    emb_a = vm_a(grids_normed[params_[0]]).squeeze()
    emb_b = vm_b(grids_normed[params_[1]]).squeeze()

    # Outer (Kronecker) product of the two embeddings along the channel dim.
    kron_order = get_kron_order(vm_a.d, vm_b.d)
    spatial_kernel = emb_a.index_select(0,
                                        kron_order[:, 0]) * emb_b.index_select(0, kron_order[:, 1])
    return spatial_kernel


class ExplicitSpacialEncoding(nn.Module):
    """
    Module that computes explicit cartesian or polar embedding.

    The spatial kernel is precomputed at construction; ``forward`` takes the
    Kronecker product of the input features with it and sum-pools over the
    spatial dimensions, so the output is a per-sample descriptor vector.

    Args:
        kernel_type: (str) Parametrization of kernel.
                     'polar', 'cart' ('polar' is default)
        fmap_size: (int) Input feature map size in pixels (32 is default)
        in_dims: (int) Dimensionality of input feature map (7 is default)
        do_gmask: (bool) Apply gaussian mask (True is default)
        do_l2: (bool) Apply l2-normalization (True is default)
    Returns:
        Tensor: Explicit cartesian or polar embedding
    Shape:
        - Input: (B, in_dims, fmap_size, fmap_size)
        - Output: (B, out_dims)
    Examples::
        >>> emb_ori = torch.rand(23, 7, 32, 32)
        >>> ese = ExplicitSpacialEncoding(kernel_type='polar',
        ...                               fmap_size=32,
        ...                               in_dims=7,
        ...                               do_gmask=True,
        ...                               do_l2=True)
        >>> desc = ese(emb_ori) # 23x175
    """

    def __init__(self,
                 kernel_type: str = 'polar',
                 fmap_size: int = 32,
                 in_dims: int = 7,
                 do_gmask: bool = True,
                 do_l2: bool = True) -> None:
        super().__init__()

        if kernel_type not in ['polar', 'cart']:
            raise NotImplementedError(f'{kernel_type} is not valid, use polar or cart).')

        self.kernel_type = kernel_type
        self.fmap_size = fmap_size
        self.in_dims = in_dims
        self.do_gmask = do_gmask
        self.do_l2 = do_l2
        self.grid = get_grid_dict(fmap_size)
        self.gmask = None

        # Precompute the spatial kernel embedding for this parametrization.
        emb = spatial_kernel_embedding(self.kernel_type, self.grid)

        # Optionally down-weight pixels far from the patch center.
        if self.do_gmask:
            self.gmask = self.get_gmask(sigma=1.0)
            emb = emb * self.gmask

        # Store precomputed embedding and derived dimensionalities.
        self.register_buffer('emb', emb.unsqueeze(0))
        self.d_emb: int = emb.shape[0]
        self.out_dims: int = self.in_dims * self.d_emb
        self.odims: int = self.out_dims

        # Store kronecker form.
        emb2, idx1 = self.init_kron()
        self.register_buffer('emb2', emb2)
        self.register_buffer('idx1', idx1)

    def get_gmask(self, sigma: float) -> torch.Tensor:
        """Compute a centered Gaussian mask over the normalized radius. """
        norm_rho = self.grid['rho'] / self.grid['rho'].max()
        gmask = torch.exp(-1 * norm_rho**2 / sigma**2)
        return gmask

    def init_kron(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Initialize helper variables to calculate kronecker. """
        kron = get_kron_order(self.in_dims, self.d_emb)
        _emb = torch.jit.annotate(torch.Tensor, self.emb)
        emb2 = torch.index_select(_emb, 1, kron[:, 1])
        return emb2, kron[:, 0]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not isinstance(x, torch.Tensor):
            raise TypeError("Input type is not a torch.Tensor. Got {}"
                            .format(type(x)))
        # Fixed: the original combined the two checks with bitwise OR, so a
        # tensor satisfying EITHER condition passed validation. Both the
        # 4-D shape AND the channel count are required.
        if not (len(x.shape) == 4 and x.shape[1] == self.in_dims):
            raise ValueError("Invalid input shape, we expect Bx{}xHxW. Got: {}"
                             .format(self.in_dims, x.shape))
        idx1 = torch.jit.annotate(torch.Tensor, self.idx1)
        # Kronecker product with the precomputed spatial embedding,
        # then global sum-pooling over the spatial dimensions.
        emb1 = torch.index_select(x, 1, idx1)
        output = emb1 * self.emb2
        output = output.sum(dim=(2, 3))
        if self.do_l2:
            output = F.normalize(output, dim=1)
        return output

    def __repr__(self) -> str:
        return self.__class__.__name__ +\
            '(' + 'kernel_type=' + str(self.kernel_type) +\
            ', ' + 'fmap_size=' + str(self.fmap_size) +\
            ', ' + 'in_dims=' + str(self.in_dims) +\
            ', ' + 'out_dims=' + str(self.out_dims) +\
            ', ' + 'do_gmask=' + str(self.do_gmask) +\
            ', ' + 'do_l2=' + str(self.do_l2) + ')'


class Whitening(nn.Module):
    """Supervised or unsupervised whitening of descriptor vectors.

    This is based on the paper "Understanding and Improving Kernel Local Descriptors".
    See :cite:`mukundan2019understanding` for more details.

    Args:
        xform: (str) Variant of whitening to use.
                     None, 'lw', 'pca', 'pcaws', 'pcawt'
        whitening_model: (dict) Dictionary with keys
                    'mean', 'eigvecs', 'eigvals' holding torch.Tensors
        in_dims: (int) Dimensionality of input descriptors
        output_dims: (int) Dimensionality reduction (128 is default)
        keval: (int) Shrinkage parameter (40 is default)
        t: (float) Attenuation parameter (0.7 is default)
    Returns:
        Tensor: l2-normalized, whitened descriptors
    Shape:
        - Input: (B, in_dims, fmap_size, fmap_size)
        - Output: (B, out_dims, fmap_size, fmap_size)
    Examples::
        >>> descs = torch.rand(23, 238)
        >>> whitening_model = {'pca': {'mean': torch.zeros(238),
        ...                            'eigvecs': torch.eye(238),
        ...                            'eigvals': torch.ones(238)}}
        >>> whitening = Whitening(xform='pcawt',
        ...                       whitening_model=whitening_model,
        ...                       in_dims=238,
        ...                       output_dims=128,
        ...                       keval=40,
        ...                       t=0.7)
        >>> wdescs = whitening(descs) # 23x128
    """

    def __init__(self,
                 xform: str,
                 whitening_model: Union[Dict[str, Dict[str, torch.Tensor]], None],
                 in_dims: int,
                 output_dims: int = 128,
                 keval: int = 40,
                 t: float = 0.7) -> None:
        super().__init__()

        self.xform = xform
        self.in_dims = in_dims
        self.keval = keval
        self.t = t
        # Powerlaw exponent; some variants replace it when loading a model.
        self.pval = 1.0

        # The output dimensionality cannot exceed the input dimensionality.
        self.output_dims = min(output_dims, in_dims)

        # Start from an identity transform (no-op whitening).
        self.mean = nn.Parameter(torch.zeros(in_dims), requires_grad=True)
        self.evecs = nn.Parameter(torch.eye(in_dims)[:, :self.output_dims], requires_grad=True)
        self.evals = nn.Parameter(torch.ones(in_dims)[:self.output_dims], requires_grad=True)

        if whitening_model is not None:
            self.load_whitening_parameters(whitening_model)

    def load_whitening_parameters(self,
                                  whitening_model: Dict[str, Dict[str, torch.Tensor]]) -> None:
        """Load mean/eigvecs/eigvals and apply the xform-specific modification."""
        algo = 'lw' if self.xform == 'lw' else 'pca'
        params = whitening_model[algo]
        self.mean.data = params['mean']
        self.evecs.data = params['eigvecs'][:, :self.output_dims]
        self.evals.data = params['eigvals'][:self.output_dims]

        # Dispatch to the per-variant eigen-spectrum modification.
        {'pca': self._modify_pca,
         'lw': self._modify_lw,
         'pcaws': self._modify_pcaws,
         'pcawt': self._modify_pcawt}[self.xform]()

    def _modify_pca(self) -> None:
        """Plain PCA uses a square-root powerlaw."""
        self.pval = 0.5

    def _modify_lw(self) -> None:
        """Learned whitening needs no modification."""

    def _modify_pcaws(self) -> None:
        """Shrink eigenvalues towards the keval-th one and fold into evecs."""
        alpha = self.evals[self.keval]
        shrunk = ((1 - alpha) * self.evals) + alpha
        self.evecs.data = self.evecs @ torch.diag(torch.pow(shrunk, -0.5))

    def _modify_pcawt(self) -> None:
        """Attenuate eigenvalues with exponent -t/2 and fold into evecs."""
        self.evecs.data = self.evecs @ torch.diag(torch.pow(self.evals, -0.5 * self.t))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not isinstance(x, torch.Tensor):
            raise TypeError("Input type is not a torch.Tensor. Got {}"
                            .format(type(x)))
        if not len(x.shape) == 2:
            raise ValueError("Invalid input shape, we expect NxD. Got: {}"
                             .format(x.shape))
        centered = x - self.mean
        projected = centered @ self.evecs
        # Signed powerlaw, then l2-normalization.
        attenuated = torch.sign(projected) * torch.pow(torch.abs(projected), self.pval)
        return F.normalize(attenuated, dim=1)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}'
                f'(xform={self.xform}'
                f', in_dims={self.in_dims}'
                f', output_dims={self.output_dims})')


class MKDDescriptor(nn.Module):
    r"""Module that computes Multiple Kernel local descriptors.

    This is based on the paper "Understanding and Improving Kernel Local Descriptors".
    See :cite:`mukundan2019understanding` for more details.

    Args:
        patch_size: (int) Input patch size in pixels (32 is default).
        kernel_type: (str) Parametrization of kernel
                     'concat', 'cart', 'polar' ('concat' is default).
        whitening: (str) Whitening transform to apply
                     None, 'lw', 'pca', 'pcawt', 'pcaws' ('pcawt' is default).
        training_set: (str) Set that model was trained on
                     'liberty', 'notredame', 'yosemite' ('liberty' is default).
        output_dims: (int) Dimensionality reduction (128 is default).
    Returns:
        torch.Tensor: Explicit cartesian or polar embedding.
    Shape:
        - Input: :math:`(B, in_dims, fmap_size, fmap_size)`.
        - Output: :math:`(B, out_dims, fmap_size, fmap_size)`,
    Examples:
        >>> patches = torch.rand(23, 1, 32, 32)
        >>> mkd = MKDDescriptor(patch_size=32,
        ...                     kernel_type='concat',
        ...                     whitening='pcawt',
        ...                     training_set='liberty',
        ...                     output_dims=128)
        >>> desc = mkd(patches) # 23x128
    """

    def __init__(self,
                 patch_size: int = 32,
                 kernel_type: str = 'concat',
                 whitening: str = 'pcawt',
                 training_set: str = 'liberty',
                 output_dims: int = 128) -> None:
        super().__init__()

        self.patch_size: int = patch_size
        self.kernel_type: str = kernel_type
        self.whitening: str = whitening
        self.training_set: str = training_set

        # Smoothing scale grows proportionally with the patch size.
        self.sigma = 1.4 * (patch_size / 64)
        self.smoothing = GaussianBlur2d((5, 5), (self.sigma, self.sigma), 'replicate')
        self.gradients = MKDGradients()
        # This stupid thing needed for jitting...
        polar_s: str = 'polar'
        cart_s: str = 'cart'
        if self.kernel_type == 'concat':
            self.parametrizations = [polar_s, cart_s]
        else:
            self.parametrizations = [self.kernel_type]

        # Initialize cartesian/polar embedding with absolute/relative gradients.
        self.odims: int = 0
        relative_orientations = {polar_s: True, cart_s: False}
        self.feats = {}
        for parametrization in self.parametrizations:
            gradient_embedding = EmbedGradients(
                patch_size=patch_size,
                relative=relative_orientations[parametrization])
            spatial_encoding = ExplicitSpacialEncoding(
                kernel_type=parametrization,
                fmap_size=patch_size,
                in_dims=gradient_embedding.kernel.d)
            self.feats[parametrization] = nn.Sequential(gradient_embedding,
                                                        spatial_encoding)
            self.odims += spatial_encoding.odims

        # Compute true output_dims.
        self.output_dims: int = min(output_dims, self.odims)

        # Load supervised(lw)/unsupervised(pca) model trained on training_set.
        if self.whitening is not None:
            whitening_models = torch.hub.load_state_dict_from_url(
                urls[self.kernel_type], map_location=lambda storage, loc: storage)
            whitening_model = whitening_models[training_set]
            self.whitening_layer = Whitening(whitening,
                                             whitening_model,
                                             in_dims=self.odims,
                                             output_dims=self.output_dims)
            self.odims = self.output_dims

    def forward(self, patches: torch.Tensor) -> torch.Tensor:
        if not isinstance(patches, torch.Tensor):
            raise TypeError("Input type is not a torch.Tensor. Got {}"
                            .format(type(patches)))
        if not len(patches.shape) == 4:
            raise ValueError("Invalid input shape, we expect Bx1xHxW. Got: {}"
                             .format(patches.shape))
        # Extract gradients.
        g = self.smoothing(patches)
        g = self.gradients(g)

        # Extract polar/cart features.
        features = []
        for parametrization in self.parametrizations:
            # NOTE(review): self.feats is a plain dict, so its submodules are
            # not registered with this Module and must be moved manually here.
            self.feats[parametrization].to(g.device)
            features.append(self.feats[parametrization](g))

        # Concatenate and l2-normalize.
        y = torch.cat(features, dim=1)
        y = F.normalize(y, dim=1)

        # Whiten descriptors.
        if self.whitening is not None:
            y = self.whitening_layer(y)
        return y

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}'
                f'(patch_size={self.patch_size}'
                f', kernel_type={self.kernel_type}'
                f', whitening={self.whitening}'
                f', training_set={self.training_set}'
                f', output_dims={self.output_dims})')
def load_whitening_model(kernel_type: str, training_set: str) -> Dict:
    """Download and return the pretrained whitening model for the given
    kernel type and training set."""
    whitening_models = torch.hub.load_state_dict_from_url(
        urls[kernel_type], map_location=lambda storage, loc: storage)
    return whitening_models[training_set]


class SimpleKD(nn.Module):
    """Example to write custom Kernel Descriptors. """

    def __init__(self,
                 patch_size: int = 32,
                 kernel_type: str = 'polar',  # 'cart' 'polar'
                 whitening: str = 'pcawt',  # 'lw', 'pca', 'pcaws', 'pcawt
                 training_set: str = 'liberty',  # 'liberty', 'notredame', 'yosemite'
                 output_dims: int = 128) -> None:
        super().__init__()

        # Polar parametrization uses rotation-relative gradients.
        relative: bool = kernel_type == 'polar'
        sigma: float = 1.4 * (patch_size / 64)

        # Sequence of modules: smooth -> gradients -> embed -> encode -> whiten.
        smoothing = GaussianBlur2d((5, 5), (sigma, sigma), 'replicate')
        gradients = MKDGradients()
        ori = EmbedGradients(patch_size=patch_size, relative=relative)
        ese = ExplicitSpacialEncoding(kernel_type=kernel_type,
                                      fmap_size=patch_size,
                                      in_dims=ori.kernel.d)
        wh = Whitening(whitening,
                       load_whitening_model(kernel_type, training_set),
                       in_dims=ese.odims,
                       output_dims=output_dims)
        self.features = nn.Sequential(smoothing, gradients, ori, ese, wh)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.features(x)