# Source code for kornia.filters.filter

from typing import Tuple, List

import torch
import torch.nn as nn
import torch.nn.functional as F


def compute_padding(kernel_size: Tuple[int, int]) -> List[int]:
    """Computes the padding list that keeps the spatial size after filtering.

    The total padding along each dimension is ``k - 1``, so a stride-1
    convolution with a kernel of size ``k`` returns an output with the same
    height and width as the unpadded input. For even kernel sizes the total
    is odd, so the extra element is placed at the rear (right / bottom).

    Args:
        kernel_size (Tuple[int, int]): the kernel size as ``(height, width)``.

    Return:
        List[int]: ``[padding_left, padding_right, padding_top,
        padding_bottom]``, the order expected by ``torch.nn.functional.pad``.

    Raises:
        AssertionError: if ``kernel_size`` does not contain exactly two values.
    """
    # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
    # Raised explicitly (instead of using `assert`) so the check is not
    # stripped when Python runs with -O.
    if len(kernel_size) != 2:
        raise AssertionError(kernel_size)
    height, width = kernel_size

    padding: List[int] = []
    # F.pad lists the last dimension first, hence width before height.
    for size in (width, height):
        total = size - 1
        front = total // 2
        # Asymmetric split: identical to the symmetric one for odd sizes.
        padding.extend([front, total - front])
    return padding


def filter2D(input: torch.Tensor, kernel: torch.Tensor,
             border_type: str = 'reflect') -> torch.Tensor:
    r"""Function that convolves a tensor with a kernel.

    The function applies a given kernel to a tensor. The kernel is applied
    independently at each depth channel of the tensor. Before applying the
    kernel, the function applies padding according to the specified mode so
    that the output remains in the same shape.

    Args:
        input (torch.Tensor): the input tensor with shape of
          :math:`(B, C, H, W)`.
        kernel (torch.Tensor): the kernel to be convolved with the input
          tensor. The kernel shape must be :math:`(1, kH, kW)` to share one
          kernel across the batch, or :math:`(B, kH, kW)` to apply a
          different kernel to each sample.
        border_type (str): the padding mode to be applied before convolving.
          The expected modes are: ``'constant'``, ``'reflect'``,
          ``'replicate'`` or ``'circular'``. Default: ``'reflect'``.

    Return:
        torch.Tensor: the convolved tensor of same size and numbers of channels
        as the input.

    Raises:
        TypeError: if ``input`` or ``kernel`` is not a tensor, or
          ``border_type`` is not a string.
        ValueError: if ``input`` is not 4-dimensional, ``kernel`` is not
          3-dimensional, the kernel batch size is neither 1 nor ``B``, or
          ``border_type`` is not one of the supported modes.
    """
    if not isinstance(input, torch.Tensor):
        raise TypeError("Input type is not a torch.Tensor. Got {}"
                        .format(type(input)))

    if not isinstance(kernel, torch.Tensor):
        raise TypeError("Input kernel type is not a torch.Tensor. Got {}"
                        .format(type(kernel)))

    if not isinstance(border_type, str):
        # Report the type of the offending argument (was `type(kernel)`).
        raise TypeError("Input border_type is not string. Got {}"
                        .format(type(border_type)))

    if not len(input.shape) == 4:
        raise ValueError("Invalid input shape, we expect BxCxHxW. Got: {}"
                         .format(input.shape))

    if not len(kernel.shape) == 3:
        raise ValueError("Invalid kernel shape, we expect BxHxW. Got: {}"
                         .format(kernel.shape))

    borders_list: List[str] = ['constant', 'reflect', 'replicate', 'circular']
    if border_type not in borders_list:
        raise ValueError("Invalid border_type, we expect the following: {0}. "
                         "Got: {1}".format(borders_list, border_type))

    b, c, h, w = input.shape
    kernel_batch: int = kernel.shape[0]
    if kernel_batch != 1 and kernel_batch != b:
        raise ValueError("Invalid kernel batch size, we expect 1 or {0}. "
                         "Got: {1}".format(b, kernel_batch))

    # prepare kernel: (kB, kH, kW) -> (kB, C, kH, kW) -> (kB * C, 1, kH, kW)
    # so that every (sample, channel) pair owns one single-channel filter.
    height, width = kernel.shape[-2:]
    tmp_kernel: torch.Tensor = kernel.to(device=input.device, dtype=input.dtype)
    tmp_kernel = tmp_kernel.unsqueeze(1).expand(-1, c, -1, -1)
    tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)

    # pad the input tensor
    padding_shape: List[int] = compute_padding((height, width))
    input_pad: torch.Tensor = F.pad(input, padding_shape, mode=border_type)

    # With a shared kernel the groups are the C channels and the batch is
    # kept; with per-sample kernels the batch is folded into the channel
    # dimension so one grouped convolution handles all B * C planes.
    groups: int = tmp_kernel.shape[0]
    input_pad = input_pad.reshape(
        -1, groups, input_pad.shape[-2], input_pad.shape[-1])

    # convolve the tensor with the kernel
    output: torch.Tensor = F.conv2d(
        input_pad, tmp_kernel, padding=0, stride=1, groups=groups)

    # restore the (B, C, ., .) layout using the actual output spatial size
    return output.reshape(b, c, output.shape[-2], output.shape[-1])