kornia.filters
The functions in this section perform various image filtering operations.
Blurring
filter2D(input: torch.Tensor, kernel: torch.Tensor, border_type: str = 'reflect', normalized: bool = False) → torch.Tensor
Convolve a tensor with a 2d kernel.
The function applies a given kernel to a tensor. The kernel is applied independently at each depth channel of the tensor. Before applying the kernel, the function pads the input according to the specified mode so that the output keeps the same shape as the input.
- Parameters
input (torch.Tensor) – the input tensor with shape \((B, C, H, W)\).
kernel (torch.Tensor) – the kernel to be convolved with the input tensor. The kernel shape must be \((1, kH, kW)\) or \((B, kH, kW)\).
border_type (str) – the padding mode to be applied before convolving. The expected modes are: 'constant', 'reflect', 'replicate' or 'circular'. Default: 'reflect'.
normalized (bool) – if True, the kernel will be L1 normalized.
- Returns
the convolved tensor, with the same size and number of channels as the input, with shape \((B, C, H, W)\).
- Return type
torch.Tensor
Example
>>> input = torch.tensor([[[
...     [0., 0., 0., 0., 0.],
...     [0., 0., 0., 0., 0.],
...     [0., 0., 5., 0., 0.],
...     [0., 0., 0., 0., 0.],
...     [0., 0., 0., 0., 0.],]]])
>>> kernel = torch.ones(1, 3, 3)
>>> filter2D(input, kernel)
tensor([[[[0., 0., 0., 0., 0.],
          [0., 5., 5., 5., 0.],
          [0., 5., 5., 5., 0.],
          [0., 5., 5., 5., 0.],
          [0., 0., 0., 0., 0.]]]])
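The padding-then-filtering behaviour described for filter2D can be illustrated without the library. The following is a minimal plain-Python sketch for a single channel, assuming an odd-sized kernel and 'reflect' padding that mirrors about the border without repeating the edge pixel. The helper names reflect_index and filter2d_single_channel are hypothetical and are not part of kornia.

```python
def reflect_index(i, n):
    """Mirror an out-of-range index back into [0, n) without repeating the edge."""
    if i < 0:
        return -i
    if i >= n:
        return 2 * (n - 1) - i
    return i


def filter2d_single_channel(image, kernel, normalized=False):
    """Apply a kH x kW kernel to one H x W channel with reflect padding.

    Hypothetical helper for illustration only; assumes odd kH and kW that are
    smaller than the image so a single reflection is enough.
    """
    h, w = len(image), len(image[0])
    kh, kw = len(kernel), len(kernel[0])
    if normalized:
        # L1 normalization: divide by the sum of absolute kernel values.
        norm = sum(abs(v) for row in kernel for v in row)
        kernel = [[v / norm for v in row] for row in kernel]
    out = [[0.0] * w for _ in range(h)]
    for y in range(h):
        for x in range(w):
            acc = 0.0
            for dy in range(kh):
                for dx in range(kw):
                    yy = reflect_index(y + dy - kh // 2, h)
                    xx = reflect_index(x + dx - kw // 2, w)
                    acc += kernel[dy][dx] * image[yy][xx]
            out[y][x] = acc
    return out


# Same data as the documented example: a single 5 in the middle of a 5x5 image.
image = [[0.0] * 5 for _ in range(5)]
image[2][2] = 5.0
kernel = [[1.0] * 3 for _ in range(3)]
result = filter2d_single_channel(image, kernel)
for row in result:
    print(row)
```

The output shape equals the input shape because every output pixel reads its neighbourhood through the mirrored indices instead of shrinking the valid region.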
filter3D(input: torch.Tensor, kernel: torch.Tensor, border_type: str = 'replicate', normalized: bool = False) → torch.Tensor
Convolve a tensor with a 3d kernel.
The function applies a given kernel to a tensor. The kernel is applied independently at each depth channel of the tensor. Before applying the kernel, the function pads the input according to the specified mode so that the output keeps the same shape as the input.
- Parameters
input (torch.Tensor) – the input tensor with shape \((B, C, D, H, W)\).
kernel (torch.Tensor) – the kernel to be convolved with the input tensor. The kernel shape must be \((1, kD, kH, kW)\) or \((B, kD, kH, kW)\).
border_type (str) – the padding mode to be applied before convolving. The expected modes are: 'constant', 'replicate' or 'circular'. Default: 'replicate'.
normalized (bool) – if True, the kernel will be L1 normalized.
- Returns
the convolved tensor, with the same size and number of channels as the input, with shape \((B, C, D, H, W)\).
- Return type
torch.Tensor
Example
>>> input = torch.tensor([[[
...    [[0., 0., 0., 0., 0.],
...     [0., 0., 0., 0., 0.],
...     [0., 0., 0., 0., 0.],
...     [0., 0., 0., 0., 0.],
...     [0., 0., 0., 0., 0.]],
...    [[0., 0., 0., 0., 0.],
...     [0., 0., 0., 0., 0.],
...     [0., 0., 5., 0., 0.],
...     [0., 0., 0., 0., 0.],
...     [0., 0., 0., 0., 0.]],
...    [[0., 0., 0., 0., 0.],
...     [0., 0., 0., 0., 0.],
...     [0., 0., 0., 0., 0.],
...     [0., 0., 0., 0., 0.],
...     [0., 0., 0., 0., 0.]]
... ]]])
>>> kernel = torch.ones(1, 3, 3, 3)
>>> filter3D(input, kernel)
tensor([[[[[0., 0., 0., 0., 0.],
           [0., 5., 5., 5., 0.],
           [0., 5., 5., 5., 0.],
           [0., 5., 5., 5., 0.],
           [0., 0., 0., 0., 0.]],
<BLANKLINE>
          [[0., 0., 0., 0., 0.],
           [0., 5., 5., 5., 0.],
           [0., 5., 5., 5., 0.],
           [0., 5., 5., 5., 0.],
           [0., 0., 0., 0., 0.]],
<BLANKLINE>
          [[0., 0., 0., 0., 0.],
           [0., 5., 5., 5., 0.],
           [0., 5., 5., 5., 0.],
           [0., 5., 5., 5., 0.],
           [0., 0., 0., 0., 0.]]]]])
box_blur(input: torch.Tensor, kernel_size: Tuple[int, int], border_type: str = 'reflect', normalized: bool = True) → torch.Tensor
Blurs an image using the box filter.
The function smooths an image using the kernel:
\[\begin{split}K = \frac{1}{\text{kernel_size}_x * \text{kernel_size}_y} \begin{bmatrix} 1 & 1 & 1 & \cdots & 1 & 1 \\ 1 & 1 & 1 & \cdots & 1 & 1 \\ \vdots & \vdots & \vdots & \ddots & \vdots & \vdots \\ 1 & 1 & 1 & \cdots & 1 & 1 \\ \end{bmatrix}\end{split}\]
- Parameters
input (torch.Tensor) – the image to blur with shape \((B,C,H,W)\).
border_type (str) – the padding mode to be applied before convolving. The expected modes are: 'constant', 'reflect', 'replicate' or 'circular'. Default: 'reflect'.
normalized (bool) – if True, L1 norm of the kernel is set to 1.
- Returns
the blurred tensor with shape \((B,C,H,W)\).
- Return type
torch.Tensor
Example
>>> input = torch.rand(2, 4, 5, 7)
>>> output = box_blur(input, (3, 3))  # 2x4x5x7
>>> output.shape
torch.Size([2, 4, 5, 7])
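As a worked instance of the kernel formula above, the sketch below builds the normalized box kernel in plain Python and applies it to one 3x3 neighbourhood. The helper box_kernel is hypothetical, and reading kernel_size as (rows, columns) is an assumption made only for this illustration.

```python
def box_kernel(kernel_size):
    """Return a normalized box kernel as nested lists.

    Hypothetical helper; kernel_size is read here as (rows, columns).
    """
    rows, cols = kernel_size
    value = 1.0 / (rows * cols)  # every entry equals 1 / (number of entries)
    return [[value] * cols for _ in range(rows)]


kernel = box_kernel((3, 3))

# One 3x3 neighbourhood of an image; the blurred value of the centre pixel
# is simply the mean of the nine values.
patch = [[1.0, 2.0, 3.0],
         [4.0, 5.0, 6.0],
         [7.0, 8.0, 9.0]]
blurred_center = sum(
    k * p
    for kernel_row, patch_row in zip(kernel, patch)
    for k, p in zip(kernel_row, patch_row)
)
print(blurred_center)
```

Because the entries sum to 1, the filter preserves the mean intensity of flat regions, which is what normalized=True expresses.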
median_blur(input: torch.Tensor, kernel_size: Tuple[int, int]) → torch.Tensor
Blurs an image using the median filter.
- Parameters
input (torch.Tensor) – the input image with shape \((B,C,H,W)\).
- Returns
the blurred input tensor with shape \((B,C,H,W)\).
- Return type
torch.Tensor
Example
>>> input = torch.rand(2, 4, 5, 7)
>>> output = median_blur(input, (3, 3))
>>> output.shape
torch.Size([2, 4, 5, 7])
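For intuition about what a median filter computes, here is a small plain-Python sketch that evaluates a single interior pixel. The helper median_at is hypothetical, and border handling is deliberately left out because this page does not state how median_blur treats the borders.

```python
from statistics import median


def median_at(image, y, x, kernel_size=(3, 3)):
    """Median of the kernel_size window centred on the interior pixel (y, x).

    Hypothetical helper for illustration; the window must lie fully inside
    the image.
    """
    kh, kw = kernel_size
    window = [
        image[yy][xx]
        for yy in range(y - kh // 2, y + kh // 2 + 1)
        for xx in range(x - kw // 2, x + kw // 2 + 1)
    ]
    return median(window)


# A single bright outlier surrounded by a flat background.
image = [[10, 10, 10],
         [10, 255, 10],
         [10, 10, 10]]
print(median_at(image, 1, 1))
```

The outlier is replaced by the background value, which is why median filtering is commonly used against salt-and-pepper noise, whereas a box filter would spread the outlier over the window.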
gaussian_blur2d(input: torch.Tensor, kernel_size: Tuple[int, int], sigma: Tuple[float, float], border_type: str = 'reflect') → torch.Tensor
Blurs a tensor using a Gaussian filter.
The function smooths the given tensor by convolving each channel with a gaussian kernel. It supports batched operation.
- Parameters
input (torch.Tensor) – the input tensor with shape \((B,C,H,W)\).
sigma (Tuple[float, float]) – the standard deviation of the kernel.
border_type (str) – the padding mode to be applied before convolving. The expected modes are: 'constant', 'reflect', 'replicate' or 'circular'. Default: 'reflect'.
- Returns
the blurred tensor with shape \((B, C, H, W)\).
- Return type
torch.Tensor
Examples
>>> input = torch.rand(2, 4, 5, 5)
>>> output = gaussian_blur2d(input, (3, 3), (1.5, 1.5))
>>> output.shape
torch.Size([2, 4, 5, 5])
motion_blur(input: torch.Tensor, kernel_size: int, angle: Union[float, torch.Tensor], direction: Union[float, torch.Tensor], border_type: str = 'constant', mode: str = 'nearest') → torch.Tensor
Perform motion blur on 2D images (4D tensor).
- Parameters
input (torch.Tensor) – the input tensor with shape \((B, C, H, W)\).
kernel_size (int) – motion kernel width and height. It should be odd and positive.
angle (Union[torch.Tensor, float]) – angle of the motion blur in degrees (anti-clockwise rotation). If tensor, it must be \((B,)\).
direction (Union[torch.Tensor, float]) – forward/backward direction of the motion blur. Lower values towards -1.0 will point the motion blur towards the back (with angle provided via angle), while higher values towards 1.0 will point the motion blur forward. A value of 0.0 leads to a uniform (but still angled) motion blur. If tensor, it must be \((B,)\).
border_type (str) – the padding mode to be applied before convolving. The expected modes are: 'constant', 'reflect', 'replicate' or 'circular'. Default: 'constant'.
mode (str) – interpolation mode for rotating the kernel: 'bilinear' or 'nearest'. Default: 'nearest'.
- Returns
the blurred image with shape \((B, C, H, W)\).
- Return type
torch.Tensor
Example
>>> input = torch.randn(1, 3, 80, 90).repeat(2, 1, 1, 1)
>>> # perform exact motion blur across the batch
>>> out_1 = motion_blur(input, 5, 90., 1)
>>> torch.allclose(out_1[0], out_1[1])
True
>>> # perform element-wise motion blur across the batch
>>> out_1 = motion_blur(input, 5, torch.tensor([90., 180,]), torch.tensor([1., -1.]))
>>> torch.allclose(out_1[0], out_1[1])
False
Kernels
get_gaussian_kernel1d(kernel_size: int, sigma: float, force_even: bool = False) → torch.Tensor
Function that returns Gaussian filter coefficients.
- Parameters
- Returns
1D tensor with gaussian filter coefficients.
- Return type
Tensor
- Shape:
Output: \((\text{kernel_size})\)
Examples
>>> get_gaussian_kernel1d(3, 2.5)
tensor([0.3243, 0.3513, 0.3243])
>>> get_gaussian_kernel1d(5, 1.5)
tensor([0.1201, 0.2339, 0.2921, 0.2339, 0.1201])
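The coefficients in these examples are consistent with sampling a Gaussian at integer offsets from the kernel centre and normalizing the samples to sum to 1. The plain-Python sketch below shows that arithmetic for odd kernel sizes. The helper gaussian_kernel1d is hypothetical; it is not the library's implementation and it ignores force_even.

```python
import math


def gaussian_kernel1d(kernel_size, sigma):
    """Sample exp(-x^2 / (2 sigma^2)) at integer offsets and normalize to sum 1.

    Hypothetical helper for illustration; assumes an odd kernel_size.
    """
    center = kernel_size // 2
    weights = [
        math.exp(-((i - center) ** 2) / (2.0 * sigma ** 2))
        for i in range(kernel_size)
    ]
    total = sum(weights)
    return [w / total for w in weights]


k3 = gaussian_kernel1d(3, 2.5)
k5 = gaussian_kernel1d(5, 1.5)
print([round(v, 4) for v in k3])
print([round(v, 4) for v in k5])
```

A larger sigma flattens the kernel: with sigma = 2.5 the three taps are almost equal, while sigma = 1.5 over five taps already shows a clear peak at the centre.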
get_gaussian_erf_kernel1d(kernel_size: int, sigma: float, force_even: bool = False) → torch.Tensor
Function that returns Gaussian filter coefficients by interpolating the error function. Adapted from: https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/layers/convutils.py
- Parameters
- Returns
1D tensor with gaussian filter coefficients.
- Return type
Tensor
- Shape:
Output: \((\text{kernel_size})\)
Examples
>>> get_gaussian_erf_kernel1d(3, 2.5)
tensor([0.3245, 0.3511, 0.3245])
>>> get_gaussian_erf_kernel1d(5, 1.5)
tensor([0.1226, 0.2331, 0.2887, 0.2331, 0.1226])
get_gaussian_discrete_kernel1d(kernel_size: int, sigma: float, force_even: bool = False) → torch.Tensor
Function that returns Gaussian filter coefficients based on the modified Bessel functions. Adapted from: https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/layers/convutils.py
- Parameters
- Returns
1D tensor with gaussian filter coefficients.
- Return type
Tensor
- Shape:
Output: \((\text{kernel_size})\)
Examples
>>> get_gaussian_discrete_kernel1d(3, 2.5)
tensor([0.3235, 0.3531, 0.3235])
>>> get_gaussian_discrete_kernel1d(5, 1.5)
tensor([0.1096, 0.2323, 0.3161, 0.2323, 0.1096])
get_gaussian_kernel2d(kernel_size: Tuple[int, int], sigma: Tuple[float, float], force_even: bool = False) → torch.Tensor
Function that returns Gaussian filter matrix coefficients.
- Parameters
- Returns
2D tensor with gaussian filter matrix coefficients.
- Return type
Tensor
- Shape:
Output: \((\text{kernel_size}_x, \text{kernel_size}_y)\)
Examples
>>> get_gaussian_kernel2d((3, 3), (1.5, 1.5))
tensor([[0.0947, 0.1183, 0.0947],
        [0.1183, 0.1478, 0.1183],
        [0.0947, 0.1183, 0.0947]])
>>> get_gaussian_kernel2d((3, 5), (1.5, 1.5))
tensor([[0.0370, 0.0720, 0.0899, 0.0720, 0.0370],
        [0.0462, 0.0899, 0.1123, 0.0899, 0.0462],
        [0.0370, 0.0720, 0.0899, 0.0720, 0.0370]])
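The 2D values in these examples match the outer product of two normalized 1D Gaussian kernels, one per axis. The following plain-Python sketch reproduces them under that assumption. The helpers gaussian_kernel1d and gaussian_kernel2d below are hypothetical stand-ins defined only for this illustration, restricted to odd sizes, with the first tuple entry taken as the number of rows.

```python
import math


def gaussian_kernel1d(kernel_size, sigma):
    """Normalized samples of a Gaussian at integer offsets (odd kernel_size)."""
    center = kernel_size // 2
    weights = [
        math.exp(-((i - center) ** 2) / (2.0 * sigma ** 2))
        for i in range(kernel_size)
    ]
    total = sum(weights)
    return [w / total for w in weights]


def gaussian_kernel2d(kernel_size, sigma):
    """Outer product of two 1D kernels; the first tuple entry indexes rows."""
    row_kernel = gaussian_kernel1d(kernel_size[0], sigma[0])
    col_kernel = gaussian_kernel1d(kernel_size[1], sigma[1])
    return [[r * c for c in col_kernel] for r in row_kernel]


k33 = gaussian_kernel2d((3, 3), (1.5, 1.5))
k35 = gaussian_kernel2d((3, 5), (1.5, 1.5))
for row in k35:
    print([round(v, 4) for v in row])
```

Because the kernel is an outer product, it is separable: filtering with the 2D kernel is equivalent to filtering rows and columns with the two 1D kernels in turn.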
get_laplacian_kernel1d(kernel_size: int) → torch.Tensor
Function that returns the coefficients of a 1D Laplacian filter.
- Parameters
kernel_size (int) – filter size. It should be odd and positive.
- Returns
1D tensor with laplacian filter coefficients.
- Return type
Tensor (float)
- Shape:
Output: \((\text{kernel_size})\)
Examples
>>> get_laplacian_kernel1d(3)
tensor([ 1., -2.,  1.])
>>> get_laplacian_kernel1d(5)
tensor([ 1.,  1., -4.,  1.,  1.])
get_laplacian_kernel2d(kernel_size: int) → torch.Tensor
Function that returns the coefficients of a 2D Laplacian filter.
- Parameters
kernel_size (int) – filter size. It should be odd.
- Returns
2D tensor with laplacian filter matrix coefficients.
- Return type
Tensor
- Shape:
Output: \((\text{kernel_size}_x, \text{kernel_size}_y)\)
Examples
>>> get_laplacian_kernel2d(3)
tensor([[ 1.,  1.,  1.],
        [ 1., -8.,  1.],
        [ 1.,  1.,  1.]])
>>> get_laplacian_kernel2d(5)
tensor([[  1.,   1.,   1.,   1.,   1.],
        [  1.,   1.,   1.,   1.,   1.],
        [  1.,   1., -24.,   1.,   1.],
        [  1.,   1.,   1.,   1.,   1.],
        [  1.,   1.,   1.,   1.,   1.]])
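The example outputs of both Laplacian kernel functions follow one pattern: every entry is 1 except the centre, which is chosen so that all entries sum to zero. The plain-Python sketch below reproduces the documented values under that reading. The helpers laplacian_kernel1d and laplacian_kernel2d are hypothetical and assume an odd kernel_size.

```python
def laplacian_kernel1d(kernel_size):
    """All ones, with the centre set so the entries sum to zero (odd size)."""
    kernel = [1.0] * kernel_size
    kernel[kernel_size // 2] = 1.0 - kernel_size
    return kernel


def laplacian_kernel2d(kernel_size):
    """Square kernel of ones, with the centre set so the entries sum to zero."""
    kernel = [[1.0] * kernel_size for _ in range(kernel_size)]
    mid = kernel_size // 2
    kernel[mid][mid] = 1.0 - kernel_size ** 2
    return kernel


print(laplacian_kernel1d(3))
print(laplacian_kernel1d(5))
for row in laplacian_kernel2d(3):
    print(row)
```

The zero sum means the filter responds with 0 on constant regions and only reacts where intensity changes, which is the property that makes it useful for edge detection.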
get_motion_kernel2d(kernel_size: int, angle: Union[torch.Tensor, float], direction: Union[torch.Tensor, float] = 0.0, mode: str = 'nearest') → torch.Tensor
Return 2D motion blur filter.
- Parameters
kernel_size (int) – motion kernel width and height. It should be odd and positive.
angle (Union[torch.Tensor, float]) – angle of the motion blur in degrees (anti-clockwise rotation).
direction (Union[torch.Tensor, float]) – forward/backward direction of the motion blur. Lower values towards -1.0 will point the motion blur towards the back (with angle provided via angle), while higher values towards 1.0 will point the motion blur forward. A value of 0.0 leads to a uniform (but still angled) motion blur.
mode (str) – interpolation mode for rotating the kernel: 'bilinear' or 'nearest'. Default: 'nearest'.
- Returns
the motion blur kernel.
- Return type
torch.Tensor
- Shape:
Output: \((B, ksize, ksize)\)
Examples
>>> get_motion_kernel2d(5, 0., 0.)
tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]])
>>> get_motion_kernel2d(3, 215., -0.5)
tensor([[[0.0000, 0.0000, 0.1667],
         [0.0000, 0.3333, 0.0000],
         [0.5000, 0.0000, 0.0000]]])
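The nonzero values in these examples are consistent with a linear ramp of weights along the motion line, normalized to sum to 1, before the line is rotated by angle. The plain-Python sketch below computes only that unrotated row. It is an assumption about the weighting inferred from the example values, not the library's code; the helper motion_row_weights is hypothetical, and the rotation step is not modelled.

```python
def motion_row_weights(kernel_size, direction):
    """Weights along the unrotated motion line, normalized to sum to 1.

    Hypothetical helper: direction in [-1, 1] is mapped to a start weight d in
    [0, 1], and the weights ramp linearly from d to 1 - d across the row.
    """
    d = (max(-1.0, min(1.0, direction)) + 1.0) / 2.0
    step = (1.0 - 2.0 * d) / (kernel_size - 1)
    row = [d + step * i for i in range(kernel_size)]
    total = sum(row)
    return [v / total for v in row]


uniform = motion_row_weights(5, 0.0)    # direction 0.0: equal weights
skewed = motion_row_weights(3, -0.5)    # negative direction: ramp towards one end
print([round(v, 4) for v in uniform])
print([round(v, 4) for v in skewed])
```

With direction 0.0 every tap has the same weight, giving the uniform row of the first example; a nonzero direction shifts weight towards one end of the line, giving the unequal values seen along the rotated line in the second example.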
Edge detection
laplacian(input: torch.Tensor, kernel_size: int, border_type: str = 'reflect', normalized: bool = True) → torch.Tensor
Applies a Laplacian filter to a tensor.
The function filters the given tensor by convolving each channel with a laplacian kernel. It supports batched operation.
- Parameters
input (torch.Tensor) – the input image tensor with shape \((B, C, H, W)\).
kernel_size (int) – the size of the kernel.
border_type (str) – the padding mode to be applied before convolving. The expected modes are: 'constant', 'reflect', 'replicate' or 'circular'. Default: 'reflect'.
normalized (bool) – if True, L1 norm of the kernel is set to 1.
- Returns
the filtered image with shape \((B, C, H, W)\).
- Return type
torch.Tensor
Examples
>>> input = torch.rand(2, 4, 5, 5)
>>> output = laplacian(input, 3)
>>> output.shape
torch.Size([2, 4, 5, 5])
sobel(input: torch.Tensor, normalized: bool = True, eps: float = 1e-06) → torch.Tensor
Computes the Sobel operator and returns the magnitude per channel.
- Parameters
input (torch.Tensor) – the input image with shape \((B,C,H,W)\).
normalized (bool) – if True, L1 norm of the kernel is set to 1.
eps (float) – regularization number to avoid NaN during backprop. Default: 1e-6.
- Returns
the sobel edge gradient magnitudes map with shape \((B,C,H,W)\).
- Return type
torch.Tensor
Example
>>> input = torch.rand(1, 3, 4, 4)
>>> output = sobel(input)  # 1x3x4x4
>>> output.shape
torch.Size([1, 3, 4, 4])
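To make the role of normalized and eps concrete, the plain-Python sketch below evaluates a Sobel gradient magnitude for a single 3x3 patch. It assumes the standard Sobel kernels, L1 normalization by the sum of absolute kernel values (8), and that eps is added under the square root; these are assumptions for illustration, and the helper sobel_magnitude is hypothetical.

```python
import math

# Standard 3x3 Sobel kernels for the horizontal and vertical derivatives.
SOBEL_X = [[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]]
SOBEL_Y = [[-1.0, -2.0, -1.0], [0.0, 0.0, 0.0], [1.0, 2.0, 1.0]]


def sobel_magnitude(patch, normalized=True, eps=1e-6):
    """Gradient magnitude at the centre of a 3x3 patch (hypothetical helper)."""
    scale = 1.0 / 8.0 if normalized else 1.0  # 8 is the L1 norm of each kernel
    gx = scale * sum(SOBEL_X[i][j] * patch[i][j] for i in range(3) for j in range(3))
    gy = scale * sum(SOBEL_Y[i][j] * patch[i][j] for i in range(3) for j in range(3))
    # eps keeps the argument of sqrt strictly positive, so the derivative of
    # sqrt stays finite even where both gradients are exactly zero.
    return math.sqrt(gx * gx + gy * gy + eps)


vertical_edge = [[0.0, 0.0, 1.0],
                 [0.0, 0.0, 1.0],
                 [0.0, 0.0, 1.0]]
flat = [[1.0] * 3 for _ in range(3)]
print(sobel_magnitude(vertical_edge))
print(sobel_magnitude(flat))
```

On the flat patch the result is sqrt(eps) rather than exactly 0; without eps the square root would be evaluated at 0, where its derivative is unbounded.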
spatial_gradient(input: torch.Tensor, mode: str = 'sobel', order: int = 1, normalized: bool = True) → torch.Tensor
Computes the first order image derivative in both x and y using a Sobel operator.
- Parameters
input (torch.Tensor) – input image tensor with shape \((B, C, H, W)\).
mode (str) – derivatives modality, can be: sobel or diff. Default: sobel.
order (int) – the order of the derivatives. Default: 1.
normalized (bool) – whether the output is normalized. Default: True.
- Returns
the derivatives of the input feature map, with shape \((B, C, 2, H, W)\).
- Return type
torch.Tensor
Examples
>>> input = torch.rand(1, 3, 4, 4)
>>> output = spatial_gradient(input)  # 1x3x2x4x4
>>> output.shape
torch.Size([1, 3, 2, 4, 4])
spatial_gradient3d(input: torch.Tensor, mode: str = 'diff', order: int = 1) → torch.Tensor
Computes the first and second order volume derivative in x, y and d using a diff operator.
- Parameters
input (torch.Tensor) – input features tensor with shape \((B, C, D, H, W)\).
mode (str) – derivatives modality, can be: sobel or diff. Default: diff.
order (int) – the order of the derivatives. Default: 1.
- Returns
the spatial gradients of the input feature map.
- Return type
torch.Tensor
- Shape:
Input: \((B, C, D, H, W)\). D, H, W are spatial dimensions; the gradient is calculated with respect to them.
Output: \((B, C, 3, D, H, W)\) or \((B, C, 6, D, H, W)\)
Examples
>>> input = torch.rand(1, 4, 2, 4, 4)
>>> output = spatial_gradient3d(input)
>>> output.shape
torch.Size([1, 4, 3, 2, 4, 4])
Module
class BoxBlur(kernel_size: Tuple[int, int], border_type: str = 'reflect', normalized: bool = True)
Blurs an image using the box filter.
The module smooths an image using the kernel:
\[\begin{split}K = \frac{1}{\text{kernel_size}_x * \text{kernel_size}_y} \begin{bmatrix} 1 & 1 & 1 & \cdots & 1 & 1 \\ 1 & 1 & 1 & \cdots & 1 & 1 \\ \vdots & \vdots & \vdots & \ddots & \vdots & \vdots \\ 1 & 1 & 1 & \cdots & 1 & 1 \\ \end{bmatrix}\end{split}\]
- Parameters
- Returns
the blurred input tensor.
- Return type
torch.Tensor
- Shape:
Input: \((B, C, H, W)\)
Output: \((B, C, H, W)\)
Example
>>> input = torch.rand(2, 4, 5, 7)
>>> blur = BoxBlur((3, 3))
>>> output = blur(input)  # 2x4x5x7
>>> output.shape
torch.Size([2, 4, 5, 7])
class MedianBlur(kernel_size: Tuple[int, int])
Blurs an image using the median filter.
- Parameters
- Returns
the blurred input tensor.
- Return type
torch.Tensor
- Shape:
Input: \((B, C, H, W)\)
Output: \((B, C, H, W)\)
Example
>>> input = torch.rand(2, 4, 5, 7)
>>> blur = MedianBlur((3, 3))
>>> output = blur(input)
>>> output.shape
torch.Size([2, 4, 5, 7])
class GaussianBlur2d(kernel_size: Tuple[int, int], sigma: Tuple[float, float], border_type: str = 'reflect')
Creates an operator that blurs a tensor using a Gaussian filter.
The operator smooths the given tensor by convolving each channel with a gaussian kernel. It supports batched operation.
- Parameters
- Returns
the blurred tensor.
- Return type
Tensor
- Shape:
Input: \((B, C, H, W)\)
Output: \((B, C, H, W)\)
Examples
>>> input = torch.rand(2, 4, 5, 5)
>>> gauss = GaussianBlur2d((3, 3), (1.5, 1.5))
>>> output = gauss(input)  # 2x4x5x5
>>> output.shape
torch.Size([2, 4, 5, 5])
class Laplacian(kernel_size: int, border_type: str = 'reflect', normalized: bool = True)
Creates an operator that filters a tensor using a Laplacian filter.
The operator filters the given tensor by convolving each channel with a laplacian kernel. It supports batched operation.
- Parameters
- Shape:
Input: \((B, C, H, W)\)
Output: \((B, C, H, W)\)
Examples
>>> input = torch.rand(2, 4, 5, 5)
>>> laplace = Laplacian(5)
>>> output = laplace(input)
>>> output.shape
torch.Size([2, 4, 5, 5])
class Sobel(normalized: bool = True, eps: float = 1e-06)
Computes the Sobel operator and returns the magnitude per channel.
- Parameters
- Returns
the sobel edge gradient magnitudes map.
- Return type
torch.Tensor
- Shape:
Input: \((B, C, H, W)\)
Output: \((B, C, H, W)\)
Examples
>>> input = torch.rand(1, 3, 4, 4)
>>> output = Sobel()(input)  # 1x3x4x4
class SpatialGradient(mode: str = 'sobel', order: int = 1, normalized: bool = True)
Computes the first order image derivative in both x and y using a Sobel operator.
- Parameters
- Returns
the sobel edges of the input feature map.
- Return type
torch.Tensor
- Shape:
Input: \((B, C, H, W)\)
Output: \((B, C, 2, H, W)\)
Examples
>>> input = torch.rand(1, 3, 4, 4)
>>> output = SpatialGradient()(input)  # 1x3x2x4x4
class SpatialGradient3d(mode: str = 'diff', order: int = 1)
Computes the first and second order volume derivative in x, y and d using a diff operator.
- Parameters
- Returns
the spatial gradients of the input feature map.
- Return type
torch.Tensor
- Shape:
Input: \((B, C, D, H, W)\). D, H, W are spatial dimensions; the gradient is calculated with respect to them.
Output: \((B, C, 3, D, H, W)\) or \((B, C, 6, D, H, W)\)
Examples
>>> input = torch.rand(1, 4, 2, 4, 4)
>>> output = SpatialGradient3d()(input)
>>> output.shape
torch.Size([1, 4, 3, 2, 4, 4])
class MotionBlur(kernel_size: int, angle: float, direction: float, border_type: str = 'constant')
Blur 2D images (4D tensor) using the motion filter.
- Parameters
kernel_size (int) – motion kernel width and height. It should be odd and positive.
angle (float) – angle of the motion blur in degrees (anti-clockwise rotation).
direction (float) – forward/backward direction of the motion blur. Lower values towards -1.0 will point the motion blur towards the back (with angle provided via angle), while higher values towards 1.0 will point the motion blur forward. A value of 0.0 leads to a uniform (but still angled) motion blur.
border_type (str) – the padding mode to be applied before convolving. The expected modes are: 'constant', 'reflect', 'replicate' or 'circular'. Default: 'constant'.
- Returns
the blurred input tensor.
- Return type
torch.Tensor
- Shape:
Input: \((B, C, H, W)\)
Output: \((B, C, H, W)\)
Examples
>>> input = torch.rand(2, 4, 5, 7)
>>> motion_blur = MotionBlur(3, 35., 0.5)
>>> output = motion_blur(input)  # 2x4x5x7