Source code for kornia.filters.guided

from __future__ import annotations

import torch
from torch.nn.functional import interpolate

from kornia.core import Module, Tensor
from kornia.core.check import KORNIA_CHECK, KORNIA_CHECK_IS_TENSOR, KORNIA_CHECK_SHAPE

from .blur import box_blur
from .kernels import _unpack_2d_ks


def _preprocess_fast_guided_blur(
    guidance: Tensor, input: Tensor, kernel_size: tuple[int, int] | int, subsample: int = 1
) -> tuple[Tensor, Tensor, tuple[int, int]]:
    ky, kx = _unpack_2d_ks(kernel_size)
    if subsample > 1:
        s = 1 / subsample
        guidance_sub = interpolate(guidance, scale_factor=s, mode="nearest")
        input_sub = guidance_sub if input is guidance else interpolate(input, scale_factor=s, mode="nearest")
        ky, kx = ((k - 1) // subsample + 1 for k in (ky, kx))
    else:
        guidance_sub = guidance
        input_sub = input
    return guidance_sub, input_sub, (ky, kx)


def _guided_blur_grayscale_guidance(
    guidance: Tensor,
    input: Tensor,
    kernel_size: tuple[int, int] | int,
    eps: float | Tensor,
    border_type: str = "reflect",
    subsample: int = 1,
) -> Tensor:
    guidance_sub, input_sub, kernel_size = _preprocess_fast_guided_blur(guidance, input, kernel_size, subsample)

    mean_I = box_blur(guidance_sub, kernel_size, border_type)
    corr_I = box_blur(guidance_sub.square(), kernel_size, border_type)
    var_I = corr_I - mean_I.square()

    if input is guidance:
        mean_p = mean_I
        cov_Ip = var_I

    else:
        mean_p = box_blur(input_sub, kernel_size, border_type)
        corr_Ip = box_blur(guidance_sub * input_sub, kernel_size, border_type)
        cov_Ip = corr_Ip - mean_I * mean_p

    if isinstance(eps, Tensor):
        eps = eps.view(-1, 1, 1, 1)  # N -> NCHW

    a = cov_Ip / (var_I + eps)
    b = mean_p - a * mean_I

    mean_a = box_blur(a, kernel_size, border_type)
    mean_b = box_blur(b, kernel_size, border_type)

    if subsample > 1:
        mean_a = interpolate(mean_a, scale_factor=subsample, mode="bilinear")
        mean_b = interpolate(mean_b, scale_factor=subsample, mode="bilinear")

    return mean_a * guidance + mean_b


def _guided_blur_multichannel_guidance(
    guidance: Tensor,
    input: Tensor,
    kernel_size: tuple[int, int] | int,
    eps: float | Tensor,
    border_type: str = "reflect",
    subsample: int = 1,
) -> Tensor:
    guidance_sub, input_sub, kernel_size = _preprocess_fast_guided_blur(guidance, input, kernel_size, subsample)
    B, C, H, W = guidance_sub.shape

    mean_I = box_blur(guidance_sub, kernel_size, border_type).permute(0, 2, 3, 1)
    II = (guidance_sub.unsqueeze(1) * guidance_sub.unsqueeze(2)).flatten(1, 2)
    corr_I = box_blur(II, kernel_size, border_type).permute(0, 2, 3, 1)
    var_I = corr_I.reshape(B, H, W, C, C) - mean_I.unsqueeze(-2) * mean_I.unsqueeze(-1)

    if guidance is input:
        mean_p = mean_I
        cov_Ip = var_I

    else:
        mean_p = box_blur(input_sub, kernel_size, border_type).permute(0, 2, 3, 1)
        Ip = (input_sub.unsqueeze(1) * guidance_sub.unsqueeze(2)).flatten(1, 2)
        corr_Ip = box_blur(Ip, kernel_size, border_type).permute(0, 2, 3, 1)
        cov_Ip = corr_Ip.reshape(B, H, W, C, -1) - mean_p.unsqueeze(-2) * mean_I.unsqueeze(-1)

    if isinstance(eps, Tensor):
        _eps = torch.eye(C, device=guidance.device, dtype=guidance.dtype).view(1, 1, 1, C, C) * eps.view(-1, 1, 1, 1, 1)
    else:
        _eps = guidance.new_full((C,), eps).diag().view(1, 1, 1, C, C)
    a = torch.linalg.solve(var_I + _eps, cov_Ip)  # B, H, W, C_guidance, C_input
    b = mean_p - (mean_I.unsqueeze(-2) @ a).squeeze(-2)  # B, H, W, C_input

    mean_a = box_blur(a.flatten(-2).permute(0, 3, 1, 2), kernel_size, border_type)
    mean_b = box_blur(b.permute(0, 3, 1, 2), kernel_size, border_type)

    if subsample > 1:
        mean_a = interpolate(mean_a, scale_factor=subsample, mode="bilinear")
        mean_b = interpolate(mean_b, scale_factor=subsample, mode="bilinear")
    mean_a = mean_a.view(B, C, -1, H * subsample, W * subsample)

    # einsum might not be contiguous, thus mean_b is the first argument
    return mean_b + torch.einsum("BCHW,BCcHW->BcHW", guidance, mean_a)


[docs]def guided_blur( guidance: Tensor, input: Tensor, kernel_size: tuple[int, int] | int, eps: float | Tensor, border_type: str = "reflect", subsample: int = 1, ) -> Tensor: r"""Blur a tensor using a Guided filter. .. image:: _static/img/guided_blur.png The operator is an edge-preserving image smoothing filter. See :cite:`he2010guided` and :cite:`he2015fast` for details. Guidance and input can have different number of channels. Arguments: guidance: the guidance tensor with shape :math:`(B,C,H,W)`. input: the input tensor with shape :math:`(B,C,H,W)`. kernel_size: the size of the kernel. eps: regularization parameter. Smaller values preserve more edges. border_type: the padding mode to be applied before convolving. The expected modes are: ``'constant'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Default: ``'reflect'``. subsample: subsampling factor for Fast Guided filtering. Default: 1 (no subsampling) Returns: the blurred tensor with same shape as `input` :math:`(B, C, H, W)`. Examples: >>> guidance = torch.rand(2, 3, 5, 5) >>> input = torch.rand(2, 4, 5, 5) >>> output = guided_blur(guidance, input, 3, 0.1) >>> output.shape torch.Size([2, 4, 5, 5]) """ KORNIA_CHECK_IS_TENSOR(guidance) KORNIA_CHECK_SHAPE(guidance, ["B", "C", "H", "W"]) if input is not guidance: KORNIA_CHECK_IS_TENSOR(input) KORNIA_CHECK_SHAPE(input, ["B", "C", "H", "W"]) KORNIA_CHECK( (guidance.shape[0] == input.shape[0]) and (guidance.shape[-2:] == input.shape[-2:]), "guidance and input should have the same batch size and spatial dimensions", ) if guidance.shape[1] == 1: return _guided_blur_grayscale_guidance(guidance, input, kernel_size, eps, border_type, subsample) else: return _guided_blur_multichannel_guidance(guidance, input, kernel_size, eps, border_type, subsample)
[docs]class GuidedBlur(Module): r"""Blur a tensor using a Guided filter. The operator is an edge-preserving image smoothing filter. See :cite:`he2010guided` and :cite:`he2015fast` for details. Guidance and input can have different number of channels. Arguments: kernel_size: the size of the kernel. eps: regularization parameter. Smaller values preserve more edges. border_type: the padding mode to be applied before convolving. The expected modes are: ``'constant'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Default: ``'reflect'``. subsample: subsampling factor for Fast Guided filtering. Default: 1 (no subsampling) Returns: the blurred input tensor. Shape: - Input: :math:`(B, C, H, W)`, :math:`(B, C, H, W)` - Output: :math:`(B, C, H, W)` Examples: >>> guidance = torch.rand(2, 3, 5, 5) >>> input = torch.rand(2, 4, 5, 5) >>> blur = GuidedBlur(3, 0.1) >>> output = blur(guidance, input) >>> output.shape torch.Size([2, 4, 5, 5]) """ def __init__( self, kernel_size: tuple[int, int] | int, eps: float, border_type: str = "reflect", subsample: int = 1 ) -> None: super().__init__() self.kernel_size = kernel_size self.eps = eps self.border_type = border_type self.subsample = subsample def __repr__(self) -> str: return ( f"{self.__class__.__name__}" f"(kernel_size={self.kernel_size}, " f"eps={self.eps}, " f"border_type={self.border_type}, " f"subsample={self.subsample})" ) def forward(self, guidance: Tensor, input: Tensor) -> Tensor: return guided_blur(guidance, input, self.kernel_size, self.eps, self.border_type, self.subsample)