Source code for kornia.losses.total_variation

from kornia.core import Module, Tensor
from kornia.testing import KORNIA_CHECK, KORNIA_CHECK_SHAPE


def total_variation(img: Tensor, reduction: str = "sum") -> Tensor:
    r"""Function that computes Total Variation according to [1].

    Args:
        img: the input image with shape :math:`(*, H, W)`.
        reduction: Specifies the reduction to apply to the output:
            ``'mean'`` | ``'sum'``. ``'mean'``: the sum of the output will be
            divided by the number of elements in the output, ``'sum'``: the
            output will be summed.

    Return:
        a tensor with shape :math:`(*,)`.

    Examples:
        >>> total_variation(torch.ones(4, 4))
        tensor(0.)
        >>> total_variation(torch.ones(2, 5, 3, 4, 4)).shape
        torch.Size([2, 5, 3])

    .. note::
        See a working example `here
        <https://kornia-tutorials.readthedocs.io/en/latest/total_variation_denoising.html>`__.

        Total Variation is formulated with summation; however, this is not
        resolution invariant. Thus, ``reduction='mean'`` was added as an
        optional reduction method.

    Reference:
        [1] https://en.wikipedia.org/wiki/Total_variation
    """
    KORNIA_CHECK_SHAPE(img, ["*", "H", "W"])
    KORNIA_CHECK(
        reduction in ("mean", "sum"),
        f"Expected reduction to be one of 'mean'/'sum', but got '{reduction}'.",
    )

    # Differences between vertically and horizontally adjacent pixels.
    pixel_dif1 = img[..., 1:, :] - img[..., :-1, :]
    pixel_dif2 = img[..., :, 1:] - img[..., :, :-1]

    res1 = pixel_dif1.abs()
    res2 = pixel_dif2.abs()

    # Reduce over the spatial dimensions only, keeping any leading dims.
    reduce_axes = (-2, -1)
    if reduction == "mean":
        # Integer inputs are promoted to float so the mean is well defined.
        if img.is_floating_point():
            res1 = res1.to(img).mean(dim=reduce_axes)
            res2 = res2.to(img).mean(dim=reduce_axes)
        else:
            res1 = res1.float().mean(dim=reduce_axes)
            res2 = res2.float().mean(dim=reduce_axes)
    elif reduction == "sum":
        res1 = res1.sum(dim=reduce_axes)
        res2 = res2.sum(dim=reduce_axes)

    return res1 + res2
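
# A minimal usage sketch, assuming ``torch`` is installed; the helper name,
# the random input, and the shapes are illustrative, not part of kornia.
def _example_total_variation() -> None:
    import torch

    img = torch.rand(2, 3, 16, 16)  # (B, C, H, W)

    # The default ``reduction="sum"`` grows with image resolution.
    assert total_variation(img).shape == (2, 3)  # one value per (B, C)

    # ``reduction="mean"`` averages the absolute differences over the last
    # two dimensions, so values stay comparable across image sizes.
    assert total_variation(img, reduction="mean").shape == (2, 3)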

class TotalVariation(Module):
    r"""Compute the Total Variation according to [1].

    Shape:
        - Input: :math:`(*, H, W)`.
        - Output: :math:`(*,)`.

    Examples:
        >>> tv = TotalVariation()
        >>> output = tv(torch.ones((2, 3, 4, 4), requires_grad=True))
        >>> output.data
        tensor([[0., 0., 0.],
                [0., 0., 0.]])
        >>> output.sum().backward()  # grad can be implicitly created only for scalar outputs

    Reference:
        [1] https://en.wikipedia.org/wiki/Total_variation
    """

    def forward(self, img: Tensor) -> Tensor:
        return total_variation(img)
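
# A minimal TV-denoising sketch in the spirit of the tutorial linked above,
# assuming ``torch`` is installed; the helper name, optimizer, loss weights,
# and iteration count are illustrative, not part of kornia.
def _example_tv_denoising() -> None:
    import torch

    noisy = torch.rand(1, 3, 32, 32)
    denoised = noisy.clone().requires_grad_(True)
    optimizer = torch.optim.SGD([denoised], lr=0.1)
    tv = TotalVariation()

    for _ in range(100):
        optimizer.zero_grad()
        # Data fidelity keeps the estimate close to the noisy observation,
        # while the TV term (shape ``(1, 3)``, hence ``.sum()``) penalizes
        # high-frequency noise.
        loss = 0.5 * ((denoised - noisy) ** 2).sum() + 1e-4 * tv(denoised).sum()
        loss.backward()
        optimizer.step()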