# Source code for kornia.geometry.transform.image_registrator

from typing import Callable, List, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from kornia.geometry.conversions import angle_to_rotation_matrix, convert_affinematrix_to_homography

from .homography_warper import HomographyWarper
from .pyramid import build_pyramid

# Public API of this module: the two geometric models and the registrator that optimizes them.
__all__ = ["Homography", "Similarity", "ImageRegistrator"]

class Homography(nn.Module):
    r"""Homography geometric model to be used together with ImageRegistrator module for the optimization-based
    image registration.

    The model is a single learnable 3x3 matrix initialized to identity.
    """

    def __init__(self) -> None:
        super().__init__()
        self.model = nn.Parameter(torch.eye(3))
        self.reset_model()

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({self.model})'

    def reset_model(self) -> None:
        """Initializes the model with identity transform."""
        torch.nn.init.eye_(self.model)

    def forward(self) -> torch.Tensor:
        r"""Single-batch homography.

        Returns:
            Homography matrix with shape :math:`(1, 3, 3)`, normalized so that its bottom-right entry is 1.
        """
        # Dividing by the (2, 2) entry removes the projective scale ambiguity of the raw parameter.
        return torch.unsqueeze(self.model / self.model[2, 2], dim=0)  # 1x3x3

    def forward_inverse(self) -> torch.Tensor:
        r"""Inverted single-batch homography.

        Returns:
            Inverse of the matrix returned by :meth:`forward`, with shape :math:`(1, 3, 3)`.
        """
        # Invert the normalized matrix (not the raw parameter) so that
        # forward() @ forward_inverse() is the identity rather than identity / model[2, 2].
        return torch.inverse(self.forward())


class Similarity(nn.Module):
    """Similarity geometric model to be used together with ImageRegistrator module for the optimization-based
    image registration.

    The transform is a rotation by ``rot`` multiplied by ``scale``, followed by a translation by ``shift``,
    returned as a 3x3 homography.

    Args:
        rotation: if True, the rotation is optimizable, else constant zero.
        scale: if True, the scale is optimizable, else constant one.
        shift: if True, the shift is optimizable, else constant zero.
    """

    def __init__(self, rotation: bool = True, scale: bool = True, shift: bool = True) -> None:
        super().__init__()
        # Disabled components are registered as buffers instead of Parameters, so they are
        # excluded from .parameters() and stay at their identity values during optimization.
        if rotation:
            self.rot = nn.Parameter(torch.zeros(1))
        else:
            self.register_buffer('rot', torch.zeros(1))
        if shift:
            self.shift = nn.Parameter(torch.zeros(1, 2, 1))
        else:
            self.register_buffer('shift', torch.zeros(1, 2, 1))
        if scale:
            self.scale = nn.Parameter(torch.ones(1))
        else:
            self.register_buffer('scale', torch.ones(1))
        self.reset_model()

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}(angle = {self.rot},'
                f'\n shift={self.shift}, \n scale={self.scale})')

    def reset_model(self) -> None:
        """Initialize the model with identity transform."""
        torch.nn.init.zeros_(self.rot)
        torch.nn.init.zeros_(self.shift)
        torch.nn.init.ones_(self.scale)

    def forward(self) -> torch.Tensor:
        r"""Single-batch similarity transform.

        Returns:
            Similarity with shape :math:`(1, 3, 3)`.
        """
        # Scaled rotation block; assumes angle_to_rotation_matrix maps the (1,) angle to (1, 2, 2).
        rot = self.scale * angle_to_rotation_matrix(self.rot)
        # Stack [s*R | t] into a (1, 2, 3) affine matrix, then lift it to a (1, 3, 3) homography.
        return convert_affinematrix_to_homography(torch.cat([rot, self.shift], dim=2))

    def forward_inverse(self) -> torch.Tensor:
        r"""Single-batch inverse similarity transform.

        Returns:
            Inverse of the matrix returned by :meth:`forward`, with shape :math:`(1, 3, 3)`.
        """
        return torch.inverse(self.forward())


class ImageRegistrator(nn.Module):
    r"""Module, which performs optimization-based image registration.

    Args:
        model_type: Geometrical model for registration. Either one of :attr:`known_models`
            (case-insensitive) or a Module providing ``reset_model()``, ``forward()`` and
            ``forward_inverse()``.
        optimizer: optimizer class used for the optimization; constructed as
            ``optimizer(model.parameters(), lr=lr)``.
        loss_fn: torch loss function; must accept ``reduction='none'``.
        pyramid_levels: number of scale pyramid levels.
        lr: learning rate for optimization.
        num_iterations: maximum number of iterations per pyramid level.
        tolerance: stop optimizing a level if the loss difference between consecutive iterations is less.
            default 1e-4.
        warper: if model_type is not string, one needs to provide warper object. It is constructed as
            ``warper(height, width)`` and the result is called as ``warper(image, transform)``.

    Raises:
        ValueError: if a custom model is given without a warper, or the model string is not supported.

    Example:
        >>> from kornia.geometry import ImageRegistrator
        >>> img_src = torch.rand(1, 1, 32, 32)
        >>> img_dst = torch.rand(1, 1, 32, 32)
        >>> registrator = ImageRegistrator('similarity')
        >>> homo = registrator.register(img_src, img_dst)
    """
    known_models = ['homography', 'similarity', 'translation', 'scale', 'rotation']

    # TODO: resolve better type, potentially using factory.
    def __init__(self,
                 model_type: Union[str, nn.Module] = 'homography',
                 optimizer=optim.Adam,
                 loss_fn: Callable = F.l1_loss,
                 pyramid_levels: int = 5,
                 lr: float = 1e-3,
                 num_iterations: int = 100,
                 tolerance: float = 1e-4,
                 warper=None) -> None:
        super().__init__()
        # We provide pre-defined combinations or allow user to supply model
        # together with warper
        if not isinstance(model_type, str):
            if warper is None:
                raise ValueError("You must supply warper together with custom model")
            self.warper = warper
            self.model = model_type
        else:
            # (rotation, scale, shift) flags for the Similarity-based presets.
            similarity_presets = {
                'similarity': (True, True, True),
                'translation': (False, False, True),
                'rotation': (True, False, False),
                'scale': (False, True, False),
            }
            key = model_type.lower()
            if key == 'homography':
                self.model = Homography()
            elif key in similarity_presets:
                rotation, scale, shift = similarity_presets[key]
                self.model = Similarity(rotation=rotation, scale=scale, shift=shift)
            else:
                raise ValueError(f"{model_type} is not supported. Try {self.known_models}")
            self.warper = HomographyWarper
        self.pyramid_levels = pyramid_levels
        self.optimizer = optimizer
        self.lr = lr
        self.loss_fn = loss_fn
        self.num_iterations = num_iterations
        self.tolerance = tolerance

    def get_single_level_loss(self, img_src: torch.Tensor,
                              img_dst: torch.Tensor,
                              transform_model: torch.Tensor) -> torch.Tensor:
        """Warp img_src into img_dst with transform_model and return the masked loss.

        Args:
            img_src: image to warp.
            img_dst: reference image; must have the same shape as ``img_src``.
            transform_model: transformation passed to the warper.

        Returns:
            Mean of the per-element loss over the pixels selected by the warped all-ones mask.

        Raises:
            ValueError: if ``img_src`` and ``img_dst`` have different shapes.
        """
        # ToDo: Make possible registration of images of different shape
        if img_src.shape != img_dst.shape:
            raise ValueError(f"Cannot register images of different shapes {img_src.shape} {img_dst.shape}")
        height, width = img_dst.shape[-2:]
        warper = self.warper(height, width)
        img_src_to_dst = warper(img_src, transform_model)
        # compute and mask loss
        loss = self.loss_fn(img_src_to_dst, img_dst, reduction='none')  # 1xCxHxW
        # Warp an all-ones image with the same transform; presumably pixels that received no
        # source content end up below 0.9 and are excluded from the loss.
        ones = warper(torch.ones_like(img_src), transform_model)
        # NOTE(review): if no element passes the mask, the mean of the empty selection is NaN.
        return loss.masked_select(ones > 0.9).mean()

    def reset_model(self) -> None:
        """Calls model reset function."""
        self.model.reset_model()

    def _symmetric_loss(self, img_src: torch.Tensor, img_dst: torch.Tensor) -> torch.Tensor:
        """Sum of the src->dst loss under the model and the dst->src loss under its inverse."""
        loss = self.get_single_level_loss(img_src, img_dst, self.model())
        return loss + self.get_single_level_loss(img_dst, img_src, self.model.forward_inverse())

    def register(self, src_img: torch.Tensor,
                 dst_img: torch.Tensor,
                 verbose: bool = False,
                 output_intermediate_models: bool = False) -> \
            Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
        r"""Estimate the transformation which warps src_img into dst_img by gradient descent.

        Shapes are not validated up front, because they may depend on the model (e.g. volume
        registration); :meth:`get_single_level_loss` raises if a pyramid level's shapes differ.
        Optimization runs coarse-to-fine over image pyramids of both inputs. At each level, at most
        ``num_iterations`` optimizer steps are taken. A level stops early when the loss changes by less
        than ``tolerance`` between consecutive iterations.

        Args:
            src_img: Input image tensor.
            dst_img: Input image tensor.
            verbose: if True, outputs loss every 10 iterations.
            output_intermediate_models: if True, also return the detached model after each pyramid level.

        Returns:
            the transformation between two images, shape depends on the model,
            typically [1x3x3] tensor for string model_types. With ``output_intermediate_models``,
            a tuple ``(final_model, intermediate_models)``.
        """
        self.reset_model()
        # ToDo: better parameter passing to optimizer
        opt: optim.Optimizer = self.optimizer(self.model.parameters(), lr=self.lr)

        # compute the gaussian pyramids
        # [::-1] because we have to register from coarse to fine
        img_src_pyr = build_pyramid(src_img, self.pyramid_levels)[::-1]
        img_dst_pyr = build_pyramid(dst_img, self.pyramid_levels)[::-1]
        if len(img_dst_pyr) != len(img_src_pyr):
            raise ValueError("Cannot register images of different sizes")

        aux_models: List[torch.Tensor] = []
        for img_src_level, img_dst_level in zip(img_src_pyr, img_dst_pyr):
            # Losses from different pyramid levels are not comparable, so the convergence check
            # restarts at every level; otherwise a level could be skipped without a single step.
            prev_loss = float('inf')
            for i in range(self.num_iterations):
                # compute gradient and update optimizer parameters
                opt.zero_grad()
                loss = self._symmetric_loss(img_src_level, img_dst_level)
                current_loss = loss.item()
                if abs(current_loss - prev_loss) < self.tolerance:
                    break
                prev_loss = current_loss
                loss.backward()
                if verbose and (i % 10 == 0):
                    print(f"Loss = {current_loss:.4f}, iter={i}")
                opt.step()
            if output_intermediate_models:
                aux_models.append(self.model().clone().detach())
        if output_intermediate_models:
            return self.model(), aux_models
        return self.model()

    def warp_src_into_dst(self, src_img: torch.Tensor) -> torch.Tensor:
        r"""Warp src_img with estimated model."""
        height, width = src_img.shape[-2:]
        warper = self.warper(height, width)
        return warper(src_img, self.model())

    def warp_dst_into_src(self, dst_img: torch.Tensor) -> torch.Tensor:
        r"""Warp dst_img with the inverted estimated model."""
        height, width = dst_img.shape[-2:]
        warper = self.warper(height, width)
        return warper(dst_img, self.model.forward_inverse())

    # Misspelled original name, kept as an alias for backward compatibility.
    warp_dst_inro_src = warp_dst_into_src