# Source code for kornia.losses.ssim

from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from kornia.filters import get_gaussian_kernel2d, filter2D


def ssim(
    img1: torch.Tensor,
    img2: torch.Tensor,
    window_size: int,
    max_val: float = 1.0,
    eps: float = 1e-12,
) -> torch.Tensor:
    r"""Function that computes the Structural Similarity (SSIM) index map between two images.

    Measures the (SSIM) index between each element in the input `x` and target `y`.

    The index can be described as:

    .. math::

      \text{SSIM}(x, y) = \frac{(2\mu_x\mu_y+c_1)(2\sigma_{xy}+c_2)}
      {(\mu_x^2+\mu_y^2+c_1)(\sigma_x^2+\sigma_y^2+c_2)}

    where:
      - :math:`c_1=(k_1 L)^2` and :math:`c_2=(k_2 L)^2` are two variables to
        stabilize the division with weak denominator. This implementation
        uses :math:`k_1=0.01` and :math:`k_2=0.03`.
      - :math:`L` is the dynamic range of the pixel-values (typically this is
        :math:`2^{\#\text{bits per pixel}}-1`).

    The local statistics are obtained by filtering with a
    ``window_size x window_size`` gaussian kernel of sigma 1.5.

    Args:
        img1 (torch.Tensor): the first input image with shape :math:`(B, C, H, W)`.
        img2 (torch.Tensor): the second input image with shape :math:`(B, C, H, W)`.
        window_size (int): the size of the gaussian kernel to smooth the images.
        max_val (float): the dynamic range of the images. Integers (e.g. ``255``)
          are accepted as well. Default: 1.
        eps (float): Small value for numerically stability when dividing. Default: 1e-12.

    Returns:
        torch.Tensor: The ssim index map with shape :math:`(B, C, H, W)`.

    Raises:
        TypeError: if ``img1`` or ``img2`` is not a tensor, or ``max_val`` is
          not a number.
        ValueError: if either image is not 4-dimensional, or the two shapes
          differ.

    Examples:
        >>> input1 = torch.rand(1, 4, 5, 5)
        >>> input2 = torch.rand(1, 4, 5, 5)
        >>> ssim_map = ssim(input1, input2, 5)  # 1x4x5x5
    """
    if not isinstance(img1, torch.Tensor):
        raise TypeError("Input img1 type is not a torch.Tensor. Got {}"
                        .format(type(img1)))

    if not isinstance(img2, torch.Tensor):
        raise TypeError("Input img2 type is not a torch.Tensor. Got {}"
                        .format(type(img2)))

    # Accept ints too: an integer dynamic range such as 255 is a perfectly
    # valid value for a parameter annotated as ``float``.
    if not isinstance(max_val, (int, float)):
        raise TypeError(f"Input max_val type is not a float. Got {type(max_val)}")

    if not len(img1.shape) == 4:
        raise ValueError("Invalid img1 shape, we expect BxCxHxW. Got: {}"
                         .format(img1.shape))

    if not len(img2.shape) == 4:
        raise ValueError("Invalid img2 shape, we expect BxCxHxW. Got: {}"
                         .format(img2.shape))

    if not img1.shape == img2.shape:
        raise ValueError("img1 and img2 shapes must be the same. Got: {} and {}"
                         .format(img1.shape, img2.shape))

    # prepare kernel
    kernel: torch.Tensor = (
        get_gaussian_kernel2d((window_size, window_size), (1.5, 1.5)).unsqueeze(0)
    )

    # compute coefficients
    C1: float = (0.01 * max_val) ** 2
    C2: float = (0.03 * max_val) ** 2

    # compute local mean per channel
    mu1: torch.Tensor = filter2D(img1, kernel)
    mu2: torch.Tensor = filter2D(img2, kernel)

    mu1_sq = mu1 ** 2
    mu2_sq = mu2 ** 2
    mu1_mu2 = mu1 * mu2

    # compute local sigma per channel (E[x^2] - E[x]^2 and E[xy] - E[x]E[y])
    sigma1_sq = filter2D(img1 ** 2, kernel) - mu1_sq
    sigma2_sq = filter2D(img2 ** 2, kernel) - mu2_sq
    sigma12 = filter2D(img1 * img2, kernel) - mu1_mu2

    # compute the similarity index map
    num: torch.Tensor = (2. * mu1_mu2 + C1) * (2. * sigma12 + C2)
    den: torch.Tensor = (
        (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
    )

    # eps guards against a vanishing denominator
    return num / (den + eps)
def ssim_loss(
    img1: torch.Tensor,
    img2: torch.Tensor,
    window_size: int,
    max_val: float = 1.0,
    eps: float = 1e-12,
    reduction: str = 'mean',
) -> torch.Tensor:
    r"""Function that computes a loss based on the SSIM measurement.

    The loss, or the Structural dissimilarity (DSSIM) is described as:

    .. math::

      \text{loss}(x, y) = \frac{1 - \text{SSIM}(x, y)}{2}

    See :meth:`~kornia.losses.ssim` for details about SSIM.

    Args:
        img1 (torch.Tensor): the first input image with shape :math:`(B, C, H, W)`.
        img2 (torch.Tensor): the second input image with shape :math:`(B, C, H, W)`.
        window_size (int): the size of the gaussian kernel to smooth the images.
        max_val (float): the dynamic range of the images. Default: 1.
        eps (float): Small value for numerically stability when dividing. Default: 1e-12.
        reduction (str, optional): Specifies the reduction to apply to the
          output: 'none' | 'mean' | 'sum'. 'none': no reduction will be applied,
          'mean': the sum of the output will be divided by the number of elements
          in the output, 'sum': the output will be summed. Default: 'mean'.

    Returns:
        torch.Tensor: The loss based on the ssim index.

    Raises:
        ValueError: if ``reduction`` is not one of 'none', 'mean' or 'sum'.

    Examples:
        >>> input1 = torch.rand(1, 4, 5, 5)
        >>> input2 = torch.rand(1, 4, 5, 5)
        >>> loss = ssim_loss(input1, input2, 5)
    """
    # Fail fast on a misspelled reduction instead of silently returning the
    # unreduced map (and before spending time on the SSIM computation).
    if reduction not in ("none", "mean", "sum"):
        raise ValueError(
            "Invalid reduction mode: {}. Expected one of 'none', 'mean', 'sum'."
            .format(reduction))

    # compute the ssim map
    ssim_map: torch.Tensor = ssim(img1, img2, window_size, max_val, eps)

    # compute the loss; the clamp keeps it inside [0, 1]
    loss = torch.clamp((1. - ssim_map) / 2, min=0, max=1)

    # reduce the loss
    if reduction == "mean":
        return torch.mean(loss)
    if reduction == "sum":
        return torch.sum(loss)
    return loss  # reduction == "none"
class SSIM(nn.Module):
    r"""Creates a module that computes the Structural Similarity (SSIM) index between two images.

    Measures the (SSIM) index between each element in the input `x` and target `y`.

    The index can be described as:

    .. math::

      \text{SSIM}(x, y) = \frac{(2\mu_x\mu_y+c_1)(2\sigma_{xy}+c_2)}
      {(\mu_x^2+\mu_y^2+c_1)(\sigma_x^2+\sigma_y^2+c_2)}

    where:
      - :math:`c_1=(k_1 L)^2` and :math:`c_2=(k_2 L)^2` are two variables to
        stabilize the division with weak denominator.
      - :math:`L` is the dynamic range of the pixel-values (typically this is
        :math:`2^{\#\text{bits per pixel}}-1`).

    Args:
        window_size (int): the size of the gaussian kernel to smooth the images.
        max_val (float): the dynamic range of the images. Default: 1.
        eps (float): Small value for numerically stability when dividing. Default: 1e-12.

    Shape:
        - Input: :math:`(B, C, H, W)`.
        - Target :math:`(B, C, H, W)`.
        - Output: :math:`(B, C, H, W)`.

    Examples:
        >>> input1 = torch.rand(1, 4, 5, 5)
        >>> input2 = torch.rand(1, 4, 5, 5)
        >>> ssim_module = SSIM(5)
        >>> ssim_map = ssim_module(input1, input2)  # 1x4x5x5
    """

    def __init__(self, window_size: int, max_val: float = 1.0, eps: float = 1e-12) -> None:
        super().__init__()
        # The module only stores the configuration; it holds no parameters
        # or buffers of its own.
        self.window_size: int = window_size
        self.max_val: float = max_val
        self.eps: float = eps

    def forward(self, img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor:
        """Return the SSIM index map of ``img1`` and ``img2`` via :func:`ssim`."""
        return ssim(img1, img2, self.window_size, self.max_val, self.eps)
class SSIMLoss(nn.Module):
    r"""Creates a criterion that computes a loss based on the SSIM measurement.

    The loss, or the Structural dissimilarity (DSSIM) is described as:

    .. math::

      \text{loss}(x, y) = \frac{1 - \text{SSIM}(x, y)}{2}

    See :meth:`~kornia.losses.ssim_loss` for the functional form of this loss.

    Args:
        window_size (int): the size of the gaussian kernel to smooth the images.
        max_val (float): the dynamic range of the images. Default: 1.
        eps (float): Small value for numerically stability when dividing. Default: 1e-12.
        reduction (str, optional): Specifies the reduction to apply to the
          output: 'none' | 'mean' | 'sum'. 'none': no reduction will be applied,
          'mean': the sum of the output will be divided by the number of elements
          in the output, 'sum': the output will be summed. Default: 'mean'.

    Returns:
        torch.Tensor: The loss based on the ssim index.

    Examples:
        >>> input1 = torch.rand(1, 4, 5, 5)
        >>> input2 = torch.rand(1, 4, 5, 5)
        >>> criterion = SSIMLoss(5)
        >>> loss = criterion(input1, input2)
    """

    def __init__(self, window_size: int, max_val: float = 1.0, eps: float = 1e-12,
                 reduction: str = 'mean') -> None:
        super().__init__()
        # The module only stores the configuration forwarded to ssim_loss.
        self.window_size: int = window_size
        self.max_val: float = max_val
        self.eps: float = eps
        self.reduction: str = reduction

    def forward(self, img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor:
        """Return the SSIM-based loss of ``img1`` and ``img2`` via :func:`ssim_loss`."""
        return ssim_loss(img1, img2, self.window_size, self.max_val, self.eps, self.reduction)