# Source code for kornia.losses.ssim

from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from kornia.filters import get_gaussian_kernel2d, filter2D

return (kernel_size - 1) // 2

class SSIM(nn.Module):
    r"""Creates a criterion that measures the structural dissimilarity (DSSIM)
    loss, derived from the Structural Similarity (SSIM) index, between each
    element in the input `x` and target `y`.

    The index can be described as:

    .. math::

      \text{SSIM}(x, y) = \frac{(2\mu_x\mu_y+c_1)(2\sigma_{xy}+c_2)}
      {(\mu_x^2+\mu_y^2+c_1)(\sigma_x^2+\sigma_y^2+c_2)}

    where:
      - :math:`c_1=(k_1 L)^2` and :math:`c_2=(k_2 L)^2` are two variables to
        stabilize the division with weak denominator. This implementation
        uses :math:`k_1=0.01` and :math:`k_2=0.03`.
      - :math:`L` is the dynamic range of the pixel-values (typically this is
        :math:`2^{\#\text{bits per pixel}}-1`).

    the loss, or the Structural dissimilarity (DSSIM) can be finally described
    as:

    .. math::

      \text{loss}(x, y) = \frac{1 - \text{SSIM}(x, y)}{2}

    The local statistics are obtained by filtering with a 2D Gaussian window
    built by ``get_gaussian_kernel2d`` with ``(1.5, 1.5)`` as its second
    argument (presumably the per-axis standard deviation).

    Arguments:
        window_size (int): the size of the kernel.
        reduction (str, optional): Specifies the reduction to apply to the
          output: 'none' | 'mean' | 'sum'. 'none': no reduction will be applied,
          'mean': the sum of the output will be divided by the number of elements
          in the output, 'sum': the output will be summed. Default: 'none'.
        max_val (float): the dynamic range of the images. Default: 1.

    Returns:
        Tensor: the DSSIM loss, with per-element values in :math:`[0, 1]`
        before reduction.

    Raises:
        ValueError: if ``reduction`` is not one of 'none', 'mean' or 'sum',
          or if the inputs are not 4D, or differ in shape, device or dtype.
        TypeError: if either input is not a ``torch.Tensor``.

    Shape:
        - Input: :math:`(B, C, H, W)`
        - Target: :math:`(B, C, H, W)`
        - Output: scalar if reduction is 'mean' or 'sum'; if reduction is
          'none', then :math:`(B, C, H, W)`

    Examples::

        >>> input1 = torch.rand(1, 4, 5, 5)
        >>> input2 = torch.rand(1, 4, 5, 5)
        >>> ssim = kornia.losses.SSIM(5, reduction='none')
        >>> loss = ssim(input1, input2)  # 1x4x5x5
    """

    def __init__(
            self,
            window_size: int,
            reduction: str = "none",
            max_val: float = 1.0) -> None:
        super().__init__()

        # Fail fast: an unknown mode would otherwise only surface (or be
        # silently ignored) at the very end of forward().
        if reduction not in ("none", "mean", "sum"):
            raise ValueError(
                "Invalid reduction mode: {}. Expected one of: "
                "'none', 'mean', 'sum'.".format(reduction))

        self.window_size: int = window_size
        self.max_val: float = max_val
        self.reduction: str = reduction

        # Square Gaussian window used for every local statistic.
        self.window: torch.Tensor = get_gaussian_kernel2d(
            (window_size, window_size), (1.5, 1.5))

        # Stabilisation constants c1 = (k1 * L)^2 and c2 = (k2 * L)^2.
        self.C1: float = (0.01 * self.max_val) ** 2
        self.C2: float = (0.03 * self.max_val) ** 2

    def forward(  # type: ignore
            self,
            img1: torch.Tensor,
            img2: torch.Tensor) -> torch.Tensor:
        """Compute the DSSIM loss between ``img1`` and ``img2``.

        Both inputs must be 4D tensors sharing shape, device and dtype.
        """
        if not isinstance(img1, torch.Tensor):
            raise TypeError("Input img1 type is not a torch.Tensor. Got {}"
                            .format(type(img1)))

        if not isinstance(img2, torch.Tensor):
            raise TypeError("Input img2 type is not a torch.Tensor. Got {}"
                            .format(type(img2)))

        if len(img1.shape) != 4:
            raise ValueError("Invalid img1 shape, we expect BxCxHxW. Got: {}"
                             .format(img1.shape))

        if len(img2.shape) != 4:
            raise ValueError("Invalid img2 shape, we expect BxCxHxW. Got: {}"
                             .format(img2.shape))

        if img1.shape != img2.shape:
            raise ValueError("img1 and img2 shapes must be the same. Got: {} and {}"
                             .format(img1.shape, img2.shape))

        if img1.device != img2.device:
            raise ValueError("img1 and img2 must be in the same device. Got: {} and {}"
                             .format(img1.device, img2.device))

        if img1.dtype != img2.dtype:
            raise ValueError("img1 and img2 must be in the same dtype. Got: {} and {}"
                             .format(img1.dtype, img2.dtype))

        # prepare kernel: match the inputs' device/dtype in a single transfer
        # and add the leading dimension passed to filter2D.
        kernel: torch.Tensor = self.window.to(
            device=img1.device, dtype=img1.dtype)
        kernel = torch.unsqueeze(kernel, dim=0)

        # compute local mean per channel
        mu1: torch.Tensor = filter2D(img1, kernel)
        mu2: torch.Tensor = filter2D(img2, kernel)

        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1 * mu2

        # compute local sigma per channel: E[x*y] - E[x]E[y]
        sigma1_sq = filter2D(img1 * img1, kernel) - mu1_sq
        sigma2_sq = filter2D(img2 * img2, kernel) - mu2_sq
        sigma12 = filter2D(img1 * img2, kernel) - mu1_mu2

        numerator = (2. * mu1_mu2 + self.C1) * (2. * sigma12 + self.C2)
        denominator = (mu1_sq + mu2_sq + self.C1) * \
            (sigma1_sq + sigma2_sq + self.C2)
        ssim_map = numerator / denominator

        # Halve *before* clamping: clamping (1 - SSIM) to [0, 1] first would
        # cap the loss at 0.5 and kill the gradient wherever SSIM < 0.
        loss = torch.clamp((1. - ssim_map) / 2., min=0, max=1)

        if self.reduction == "mean":
            loss = torch.mean(loss)
        elif self.reduction == "sum":
            loss = torch.sum(loss)
        elif self.reduction == "none":
            pass
        else:
            # Reachable only if the attribute was changed after construction.
            raise ValueError(
                "Invalid reduction mode: {}. Expected one of: "
                "'none', 'mean', 'sum'.".format(self.reduction))
        return loss

######################
# functional interface
######################

def ssim(
        img1: torch.Tensor,
        img2: torch.Tensor,
        window_size: int,
        reduction: str = "none",
        max_val: float = 1.0) -> torch.Tensor:
    r"""Function that computes the SSIM-based loss between each element in
    the input `x` and target `y`.

    A :class:`~kornia.losses.SSIM` module is built from ``window_size``,
    ``reduction`` and ``max_val`` and immediately applied to the two images,
    so the returned tensor is whatever that module's ``forward`` produces.

    Arguments:
        img1 (torch.Tensor): the first input image batch.
        img2 (torch.Tensor): the second input image batch.
        window_size (int): the size of the kernel.
        reduction (str, optional): the reduction to apply to the output.
          Default: 'none'.
        max_val (float): the dynamic range of the images. Default: 1.

    Returns:
        Tensor: the loss computed by :class:`~kornia.losses.SSIM`.

    See :class:`~kornia.losses.SSIM` for details.
    """
    criterion = SSIM(window_size, reduction, max_val)
    return criterion(img1, img2)