# Source code for kornia.filters.median
from __future__ import annotations
import torch.nn.functional as F
from kornia.core import Module, Tensor
from kornia.core.check import KORNIA_CHECK_IS_TENSOR, KORNIA_CHECK_SHAPE
from .kernels import _unpack_2d_ks, get_binary_kernel2d
def _compute_zero_padding(kernel_size: tuple[int, int] | int) -> tuple[int, int]:
    r"""Return the per-axis zero padding ``(pad_y, pad_x)`` for a 2D kernel.

    Each entry is ``(k - 1) // 2`` for the corresponding kernel extent ``k``. For an odd
    ``k`` this is the symmetric padding that keeps a stride-1 window count equal to the
    input size.
    """
    sizes = _unpack_2d_ks(kernel_size)
    pad_y, pad_x = ((size - 1) // 2 for size in sizes)
    return pad_y, pad_x
def median_blur(input: Tensor, kernel_size: tuple[int, int] | int) -> Tensor:
    r"""Blur an image using the median filter.

    .. image:: _static/img/median_blur.png

    Every output pixel is the median of the ``kernel_size`` window around the
    corresponding input pixel. The image border is zero padded, so the output has the
    same spatial size as the input.

    Args:
        input: the input image with shape :math:`(B,C,H,W)`.
        kernel_size: the blurring kernel size, either a single int used for both axes or
            a ``(height, width)`` tuple. For an odd size the window is centred on the
            pixel. For an even size the window reaches one pixel further towards the
            bottom/right (the same split PyTorch uses for ``padding="same"``), and the
            median of the even-sized window is the lower of the two middle values, as
            returned by ``torch.median``.

    Returns:
        the blurred input tensor with shape :math:`(B,C,H,W)`.

    .. note::
       See a working example `here <https://kornia.github.io/tutorials/nbs/filtering_operators.html>`__.

    Example:
        >>> input = torch.rand(2, 4, 5, 7)
        >>> output = median_blur(input, (3, 3))
        >>> output.shape
        torch.Size([2, 4, 5, 7])
    """
    KORNIA_CHECK_IS_TENSOR(input)
    KORNIA_CHECK_SHAPE(input, ["B", "C", "H", "W"])

    ky, kx = _unpack_2d_ks(kernel_size)
    b, c, h, w = input.shape

    # Pad by k - 1 in total per axis so there is exactly one window per input pixel.
    # For odd k both sides get (k - 1) // 2 (the original symmetric zero padding); for
    # even k the extra row/column goes to the bottom/right instead of shrinking the
    # output, which previously made the reshape below fail or mix up windows.
    pad_top, pad_left = _compute_zero_padding((ky, kx))
    pad_bottom, pad_right = ky - 1 - pad_top, kx - 1 - pad_left
    padded = F.pad(input, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=0)

    # Gather the local windows with strided views rather than a convolution against a
    # one-hot kernel. Values are copied exactly, so an inf in the input cannot turn
    # neighbouring windows into nan (0 * inf), and reduced-precision convolution
    # arithmetic (e.g. TF32 on recent GPUs) cannot perturb the returned values.
    patches = padded.unfold(2, ky, 1).unfold(3, kx, 1)  # BxCxHxWxK_hxK_w
    features = patches.reshape(b, c, h, w, ky * kx)  # BxCxHxWx(K_h * K_w)

    # compute the median along the window axis
    return features.median(dim=-1).values
class MedianBlur(Module):
    r"""Module form of :func:`median_blur`.

    Replaces each pixel with the median of its local ``kernel_size`` neighbourhood.

    Args:
        kernel_size: the blurring kernel size.

    Returns:
        the blurred input tensor.

    Shape:
        - Input: :math:`(B, C, H, W)`
        - Output: :math:`(B, C, H, W)`

    Example:
        >>> input = torch.rand(2, 4, 5, 7)
        >>> blur = MedianBlur((3, 3))
        >>> output = blur(input)
        >>> output.shape
        torch.Size([2, 4, 5, 7])
    """

    def __init__(self, kernel_size: tuple[int, int] | int) -> None:
        super().__init__()
        # kept exactly as given; handed to median_blur on every forward call
        self.kernel_size: tuple[int, int] | int = kernel_size

    def forward(self, input: Tensor) -> Tensor:
        kernel_size = self.kernel_size
        return median_blur(input, kernel_size)