Source code for kornia.filters.median

from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from .kernels import get_binary_kernel2d


def _compute_zero_padding(kernel_size: Tuple[int, int]) -> Tuple[int, int]:
    r"""Utility function that computes zero padding tuple."""
    computed: List[int] = [(k - 1) // 2 for k in kernel_size]
    return computed[0], computed[1]
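
# Note (not part of the original kornia source): for odd kernel sizes the helper
# above yields "same" zero padding, so the blurred output keeps the spatial
# resolution of the input, e.g.:
#
#   >>> _compute_zero_padding((3, 3))
#   (1, 1)
#   >>> _compute_zero_padding((5, 3))
#   (2, 1)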


def median_blur(input: torch.Tensor, kernel_size: Tuple[int, int]) -> torch.Tensor:
    r"""Blur an image using the median filter.

    .. image:: _static/img/median_blur.png

    Args:
        input: the input image with shape :math:`(B,C,H,W)`.
        kernel_size: the blurring kernel size.

    Returns:
        the blurred input tensor with shape :math:`(B,C,H,W)`.

    .. note::
       See a working example `here <https://kornia-tutorials.readthedocs.io/en/latest/
       filtering_operators.html>`__.

    Example:
        >>> input = torch.rand(2, 4, 5, 7)
        >>> output = median_blur(input, (3, 3))
        >>> output.shape
        torch.Size([2, 4, 5, 7])
    """
    if not isinstance(input, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")

    if not len(input.shape) == 4:
        raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}")

    padding: Tuple[int, int] = _compute_zero_padding(kernel_size)

    # prepare kernel
    kernel: torch.Tensor = get_binary_kernel2d(kernel_size).to(input)
    b, c, h, w = input.shape

    # map the local window to single vector
    features: torch.Tensor = F.conv2d(input.reshape(b * c, 1, h, w), kernel, padding=padding, stride=1)
    features = features.view(b, c, -1, h, w)  # BxCx(K_h * K_w)xHxW

    # compute the median along the feature axis
    median: torch.Tensor = torch.median(features, dim=2)[0]

    return median
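
# Illustration only (not part of the kornia source): the function above uses an
# im2col-style trick -- the one-hot kernels from ``get_binary_kernel2d`` turn the
# convolution into a window extractor, and the median is then taken over the
# window dimension. The sketch below, assuming only ``torch.nn.functional.unfold``,
# computes the same result and may help clarify that trick.
def _median_blur_unfold_sketch(input: torch.Tensor, kernel_size: Tuple[int, int]) -> torch.Tensor:
    b, c, h, w = input.shape
    padding = _compute_zero_padding(kernel_size)
    # extract every K_h x K_w window as a column: (B, C * K_h * K_w, H * W)
    windows = F.unfold(input, kernel_size, padding=padding)
    # regroup per channel and window: (B, C, K_h * K_w, H, W)
    windows = windows.view(b, c, kernel_size[0] * kernel_size[1], h, w)
    # the median over the window dimension reproduces median_blur
    return windows.median(dim=2)[0]

# Example (should match median_blur, since both zero-pad and take the same window median):
#   >>> x = torch.rand(2, 4, 5, 7)
#   >>> torch.allclose(median_blur(x, (3, 3)), _median_blur_unfold_sketch(x, (3, 3)))
#   True
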
class MedianBlur(nn.Module):
    r"""Blur an image using the median filter.

    Args:
        kernel_size: the blurring kernel size.

    Returns:
        the blurred input tensor.

    Shape:
        - Input: :math:`(B, C, H, W)`
        - Output: :math:`(B, C, H, W)`

    Example:
        >>> input = torch.rand(2, 4, 5, 7)
        >>> blur = MedianBlur((3, 3))
        >>> output = blur(input)
        >>> output.shape
        torch.Size([2, 4, 5, 7])
    """

    def __init__(self, kernel_size: Tuple[int, int]) -> None:
        super().__init__()
        self.kernel_size: Tuple[int, int] = kernel_size

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return median_blur(input, self.kernel_size)
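
# Usage note (illustrative, not part of the kornia source): ``MedianBlur`` is a
# stateless ``nn.Module`` -- the binary kernel is built inside ``median_blur`` and
# moved to the input's device via ``.to(input)`` -- so it composes freely with
# other modules, e.g.:
#
#   >>> pipeline = nn.Sequential(MedianBlur((3, 3)), nn.AvgPool2d(2))
#   >>> pipeline(torch.rand(1, 3, 8, 8)).shape
#   torch.Size([1, 3, 4, 4])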