kornia.losses¶
Reconstruction¶
-
ssim(img1: torch.Tensor, img2: torch.Tensor, window_size: int, max_val: float = 1.0, eps: float = 1e-12) → torch.Tensor
Function that computes the Structural Similarity (SSIM) index map between two images.
Measures the SSIM index between each element of the input x and the target y.
The index can be described as:
\[\text{SSIM}(x, y) = \frac{(2\mu_x\mu_y+c_1)(2\sigma_{xy}+c_2)} {(\mu_x^2+\mu_y^2+c_1)(\sigma_x^2+\sigma_y^2+c_2)}\]
- where:
\(c_1=(k_1 L)^2\) and \(c_2=(k_2 L)^2\) are two constants that stabilize the division when the denominator is small.
\(L\) is the dynamic range of the pixel-values (typically this is \(2^{\#\text{bits per pixel}}-1\)).
- Parameters
img1 (torch.Tensor) – the first input image with shape \((B, C, H, W)\).
img2 (torch.Tensor) – the second input image with shape \((B, C, H, W)\).
window_size (int) – the size of the gaussian kernel to smooth the images.
max_val (float) – the dynamic range of the images. Default: 1.
eps (float) – Small value for numerical stability when dividing. Default: 1e-12.
- Returns
The ssim index map with shape \((B, C, H, W)\).
- Return type
torch.Tensor
Examples
>>> input1 = torch.rand(1, 4, 5, 5)
>>> input2 = torch.rand(1, 4, 5, 5)
>>> ssim_map = ssim(input1, input2, 5)  # 1x4x5x5
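To make the SSIM formula concrete, the sketch below evaluates it once over two whole signals in plain Python. It is not the library's implementation: `ssim` smooths with a Gaussian window to produce a per-pixel map, whereas this sketch uses global statistics. The name `global_ssim` and the constants `k1=0.01` and `k2=0.03` (values commonly used in the SSIM literature) are assumptions, not taken from this page.

```python
def global_ssim(x, y, max_val=1.0, k1=0.01, k2=0.03):
    """SSIM formula evaluated with global (unwindowed) statistics.

    x, y: equal-length sequences of pixel values.
    k1, k2: assumed stabilizing constants (not specified on this page).
    """
    n = len(x)
    mu_x = sum(x) / n
    mu_y = sum(y) / n
    # Population variances and covariance.
    var_x = sum((a - mu_x) ** 2 for a in x) / n
    var_y = sum((b - mu_y) ** 2 for b in y) / n
    cov_xy = sum((a - mu_x) * (b - mu_y) for a, b in zip(x, y)) / n
    c1 = (k1 * max_val) ** 2
    c2 = (k2 * max_val) ** 2
    num = (2 * mu_x * mu_y + c1) * (2 * cov_xy + c2)
    den = (mu_x ** 2 + mu_y ** 2 + c1) * (var_x + var_y + c2)
    return num / den


a = [0.1, 0.4, 0.7, 0.9]
same = global_ssim(a, a)                        # identical signals
flipped = global_ssim(a, [1.0 - v for v in a])  # inverted contrast
```

Identical signals give an index of 1 (up to rounding), while inverting the contrast makes the covariance negative and pushes the index below zero.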
-
ssim_loss(img1: torch.Tensor, img2: torch.Tensor, window_size: int, max_val: float = 1.0, eps: float = 1e-12, reduction: str = 'mean') → torch.Tensor
Function that computes a loss based on the SSIM measurement.
The loss, or the Structural dissimilarity (DSSIM) is described as:
\[\text{loss}(x, y) = \frac{1 - \text{SSIM}(x, y)}{2}\]
See ssim() for details about SSIM.
- Parameters
img1 (torch.Tensor) – the first input image with shape \((B, C, H, W)\).
img2 (torch.Tensor) – the second input image with shape \((B, C, H, W)\).
window_size (int) – the size of the gaussian kernel to smooth the images.
max_val (float) – the dynamic range of the images. Default: 1.
eps (float) – Small value for numerical stability when dividing. Default: 1e-12.
reduction (str, optional) – Specifies the reduction to apply to the output: ‘none’ | ‘mean’ | ‘sum’. ‘none’: no reduction will be applied, ‘mean’: the sum of the output will be divided by the number of elements in the output, ‘sum’: the output will be summed. Default: ‘mean’.
- Returns
The loss based on the ssim index.
- Return type
torch.Tensor
Examples
>>> input1 = torch.rand(1, 4, 5, 5)
>>> input2 = torch.rand(1, 4, 5, 5)
>>> loss = ssim_loss(input1, input2, 5)
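The DSSIM mapping and the three `reduction` modes described above can be illustrated without any tensor library. This is a plain-Python sketch operating on a flat list of SSIM values; the helper name `dssim` is hypothetical and the code does not reproduce the library's windowed SSIM computation.

```python
def dssim(ssim_map, reduction="mean"):
    """Map SSIM values in [-1, 1] to a loss in [0, 1], then reduce."""
    losses = [(1.0 - s) / 2.0 for s in ssim_map]
    if reduction == "none":
        return losses
    if reduction == "sum":
        return sum(losses)
    if reduction == "mean":
        return sum(losses) / len(losses)
    raise ValueError(f"unknown reduction: {reduction!r}")


ssim_values = [1.0, 0.0, -1.0]            # perfect, unrelated, inverted
per_element = dssim(ssim_values, "none")  # [0.0, 0.5, 1.0]
mean_loss = dssim(ssim_values)            # 0.5
```

A perfect match (SSIM of 1) contributes zero loss, so minimizing the loss maximizes structural similarity.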
-
psnr(input: torch.Tensor, target: torch.Tensor, max_val: float) → torch.Tensor
Function that computes the PSNR between two images.
PSNR is the Peak Signal-to-Noise Ratio, which is derived from the mean squared error. Given an m x n image, the PSNR is:
\[\text{PSNR} = 10 \log_{10} \bigg(\frac{\text{MAX}_I^2}{\text{MSE}(I,T)}\bigg)\]
where
\[\text{MSE}(I,T) = \frac{1}{mn}\sum_{i=0}^{m-1}\sum_{j=0}^{n-1} [I(i,j) - T(i,j)]^2\]
and \(\text{MAX}_I\) is the maximum possible input value (e.g. for floating point images \(\text{MAX}_I=1\)).
- Parameters
input (torch.Tensor) – the input image with arbitrary shape \((*)\).
target (torch.Tensor) – the target image with arbitrary shape \((*)\).
max_val (float) – the maximum possible value of the input tensor.
- Returns
the computed loss as a scalar.
- Return type
torch.Tensor
Examples
>>> ones = torch.ones(1)
>>> psnr(ones, 1.2 * ones, 2.)  # 10 * log(4/((1.2-1)**2)) / log(10)
tensor(20.0000)
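The arithmetic in the example above can be checked by hand. The following is a minimal plain-Python sketch of the two formulas, using flat lists instead of tensors; the helper name `psnr_plain` is hypothetical.

```python
import math


def psnr_plain(image, target, max_val):
    """PSNR in decibels for two equal-length flat sequences of pixel values."""
    mse = sum((i - t) ** 2 for i, t in zip(image, target)) / len(image)
    return 10.0 * math.log10(max_val ** 2 / mse)


# MSE = (1.0 - 1.2)^2 = 0.04, so PSNR = 10 * log10(4 / 0.04) = 20 dB.
value = psnr_plain([1.0], [1.2], 2.0)
```

Note that the sketch divides by the MSE directly, so identical inputs (MSE of zero) would raise `ZeroDivisionError`.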
-
psnr_loss(input: torch.Tensor, target: torch.Tensor, max_val: float) → torch.Tensor
Function that computes the PSNR loss.
The loss is computed as follows:
\[\text{loss} = -\text{psnr}(x, y)\]
See psnr() for details about PSNR.
- Parameters
input (torch.Tensor) – the input image with shape \((*)\).
target (torch.Tensor) – the target image with shape \((*)\).
max_val (float) – the maximum possible value of the input tensor.
- Returns
the computed loss as a scalar.
- Return type
torch.Tensor
Examples
>>> ones = torch.ones(1)
>>> psnr_loss(ones, 1.2 * ones, 2.)  # -10 * log(4/((1.2-1)**2)) / log(10)
tensor(-20.0000)
-
total_variation(img: torch.Tensor) → torch.Tensor
Function that computes Total Variation according to [1].
- Parameters
img (torch.Tensor) – the input image with shape \((N, C, H, W)\) or \((C, H, W)\).
- Returns
a tensor with the computed loss: shape \((N,)\) for a batched input, or a scalar for a single image.
- Return type
torch.Tensor
Examples
>>> total_variation(torch.ones(3, 4, 4))
tensor(0.)
- Reference:
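This page does not spell out the formula. As an illustration only, the sketch below assumes the common anisotropic definition, which sums the absolute differences between vertically and horizontally adjacent pixels. It handles a single channel given as a list of rows; the name `total_variation_plain` is hypothetical.

```python
def total_variation_plain(img):
    """Anisotropic total variation of one channel given as a list of rows."""
    height, width = len(img), len(img[0])
    # Absolute differences between vertically adjacent pixels.
    vertical = sum(
        abs(img[r + 1][c] - img[r][c])
        for r in range(height - 1)
        for c in range(width)
    )
    # Absolute differences between horizontally adjacent pixels.
    horizontal = sum(
        abs(img[r][c + 1] - img[r][c])
        for r in range(height)
        for c in range(width - 1)
    )
    return vertical + horizontal


flat = [[1.0] * 4 for _ in range(4)]
step = [[0.0, 0.0, 1.0, 1.0] for _ in range(4)]
tv_flat = total_variation_plain(flat)  # constant image: 0
tv_step = total_variation_plain(step)  # one unit edge crossed by 4 rows: 4.0
```

A constant image has zero total variation, which matches the `tensor(0.)` result in the example above.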
-
inverse_depth_smoothness_loss(idepth: torch.Tensor, image: torch.Tensor) → torch.Tensor
Criterion that computes image-aware inverse depth smoothness loss.
\[\text{loss} = \left | \partial_x d_{ij} \right | e^{-\left \| \partial_x I_{ij} \right \|} + \left | \partial_y d_{ij} \right | e^{-\left \| \partial_y I_{ij} \right \|}\]
- Parameters
idepth (torch.Tensor) – tensor with the inverse depth with shape \((N, 1, H, W)\).
image (torch.Tensor) – tensor with the input image with shape \((N, 3, H, W)\).
- Returns
a scalar with the computed loss.
- Return type
torch.Tensor
Examples
>>> idepth = torch.rand(1, 1, 4, 5)
>>> image = torch.rand(1, 3, 4, 5)
>>> loss = inverse_depth_smoothness_loss(idepth, image)
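The effect of the exponential weighting in the formula above is easiest to see on a tiny example. The sketch below is plain Python for a single-channel image and a single-channel inverse depth map, both given as lists of rows. Finite differences stand in for the partial derivatives, and averaging each term over its valid pixels is an assumption of this sketch rather than something stated on this page; the name `smoothness_plain` is hypothetical.

```python
import math


def smoothness_plain(idepth, image):
    """Edge-aware smoothness for same-sized single-channel grids."""
    height, width = len(idepth), len(idepth[0])
    # Horizontal term: |d_x depth| * exp(-|d_x image|).
    x_terms = [
        abs(idepth[r][c + 1] - idepth[r][c])
        * math.exp(-abs(image[r][c + 1] - image[r][c]))
        for r in range(height)
        for c in range(width - 1)
    ]
    # Vertical term: |d_y depth| * exp(-|d_y image|).
    y_terms = [
        abs(idepth[r + 1][c] - idepth[r][c])
        * math.exp(-abs(image[r + 1][c] - image[r][c]))
        for r in range(height - 1)
        for c in range(width)
    ]
    return sum(x_terms) / len(x_terms) + sum(y_terms) / len(y_terms)


depth_step = [[0.0, 0.0, 1.0, 1.0] for _ in range(3)]
flat_image = [[0.5] * 4 for _ in range(3)]
edge_image = [[0.0, 0.0, 1.0, 1.0] for _ in range(3)]

loss_flat = smoothness_plain(depth_step, flat_image)  # depth edge, no image edge
loss_edge = smoothness_plain(depth_step, edge_image)  # depth edge on an image edge
```

The depth discontinuity is penalized less when it coincides with an image edge, which is the purpose of the image-aware weighting.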
Semantic Segmentation¶
-
binary_focal_loss_with_logits(input: torch.Tensor, target: torch.Tensor, alpha: float = 0.25, gamma: float = 2.0, reduction: str = 'none', eps: float = 1e-08) → torch.Tensor
Function that computes Binary Focal loss.
\[\text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)\]
- where:
\(p_t\) is the model’s estimated probability for each class.
- Parameters
input (torch.Tensor) – input data tensor with shape \((N, 1, *)\).
target (torch.Tensor) – the target tensor with shape \((N, 1, *)\).
alpha (float) – Weighting factor for the rare class \(\alpha \in [0, 1]\). Default: 0.25.
gamma (float) – Focusing parameter \(\gamma >= 0\). Default: 2.0.
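reduction (str, optional) – Specifies the reduction to apply to the output: ‘none’ | ‘mean’ | ‘sum’. ‘none’: no reduction will be applied, ‘mean’: the sum of the output will be divided by the number of elements in the output, ‘sum’: the output will be summed. Default: ‘none’.
eps (float) – Small value for numerical stability when dividing. Default: 1e-8.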
- Returns
the computed loss.
- Return type
torch.Tensor
Examples
>>> num_classes = 1
>>> kwargs = {"alpha": 0.25, "gamma": 2.0, "reduction": 'mean'}
>>> logits = torch.tensor([[[[6.325]]],[[[5.26]]],[[[87.49]]]])
>>> labels = torch.tensor([[[1.]],[[1.]],[[0.]]])
>>> binary_focal_loss_with_logits(logits, labels, **kwargs)
tensor(4.6052)
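To show where the value in the example above comes from, here is a plain-Python sketch of the binary focal loss with `'mean'` reduction. It assumes positives are weighted by `alpha` and negatives by `1 - alpha`, and that `eps` is added inside the logarithms; both are assumptions of this sketch, and the name `binary_focal_plain` is hypothetical.

```python
import math


def binary_focal_plain(logits, targets, alpha=0.25, gamma=2.0, eps=1e-8):
    """Mean binary focal loss over flat sequences of logits and 0/1 targets."""
    total = 0.0
    for x, y in zip(logits, targets):
        p = 1.0 / (1.0 + math.exp(-x))  # sigmoid
        # Positive targets are weighted by alpha, negatives by (1 - alpha).
        pos = -alpha * (1.0 - p) ** gamma * math.log(p + eps)
        neg = -(1.0 - alpha) * p ** gamma * math.log(1.0 - p + eps)
        total += y * pos + (1.0 - y) * neg
    return total / len(logits)


# Same logits and labels as the example above.
loss = binary_focal_plain([6.325, 5.26, 87.49], [1.0, 1.0, 0.0])
```

Under these assumptions the result is approximately 4.6052. Almost all of it comes from the third element, a confident prediction of the wrong class, while the two confident correct predictions are down-weighted to nearly zero by the \((1 - p_t)^{\gamma}\) factor.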
-
focal_loss(input: torch.Tensor, target: torch.Tensor, alpha: float, gamma: float = 2.0, reduction: str = 'none', eps: float = 1e-08) → torch.Tensor
Criterion that computes Focal loss.
According to [lin2018focal], the Focal loss is computed as follows:
\[\text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)\]
- Where:
\(p_t\) is the model’s estimated probability for each class.
- Parameters
input (torch.Tensor) – logits tensor with shape \((N, C, *)\) where C = number of classes.
target (torch.Tensor) – labels tensor with shape \((N, *)\) where each value is \(0 ≤ targets[i] ≤ C−1\).
alpha (float) – Weighting factor \(\alpha \in [0, 1]\).
gamma (float, optional) – Focusing parameter \(\gamma >= 0\). Default 2.
reduction (str, optional) – Specifies the reduction to apply to the output: ‘none’ | ‘mean’ | ‘sum’. ‘none’: no reduction will be applied, ‘mean’: the sum of the output will be divided by the number of elements in the output, ‘sum’: the output will be summed. Default: ‘none’.
eps (float, optional) – Scalar to enforce numerical stability. Default: 1e-8.
- Returns
the computed loss.
- Return type
torch.Tensor
Example
>>> N = 5  # num_classes
>>> input = torch.randn(1, N, 3, 5, requires_grad=True)
>>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
>>> output = focal_loss(input, target, alpha=0.5, gamma=2.0, reduction='mean')
>>> output.backward()
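For the multi-class case, the formula can be sketched for a single sample in plain Python. The sketch assumes that \(p_t\) is the softmax probability of the target class and that the same scalar `alpha` applies to every class; the name `focal_plain` is hypothetical and this is not the library's implementation.

```python
import math


def focal_plain(logits, target, alpha, gamma=2.0, eps=1e-8):
    """Focal loss for one sample; `logits` holds one score per class."""
    # Numerically stable softmax probability of the target class.
    shift = max(logits)
    exps = [math.exp(v - shift) for v in logits]
    p_t = exps[target] / sum(exps)
    return -alpha * (1.0 - p_t) ** gamma * math.log(p_t + eps)


easy = focal_plain([5.0, 0.0, 0.0], target=0, alpha=0.5)  # confident and correct
hard = focal_plain([0.0, 5.0, 0.0], target=0, alpha=0.5)  # confident and wrong
# With gamma=0 and alpha=1 the expression reduces to cross entropy.
cross_entropy = focal_plain([0.0, 0.0], target=0, alpha=1.0, gamma=0.0)
```

The well-classified sample contributes almost nothing, while the misclassified one dominates, which is the focusing behaviour controlled by `gamma`.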
-
dice_loss(input: torch.Tensor, target: torch.Tensor, eps: float = 1e-08) → torch.Tensor
Criterion that computes Sørensen-Dice Coefficient loss.
According to [1], we compute the Sørensen-Dice Coefficient as follows:
\[\text{Dice}(x, class) = \frac{2 |X \cap Y|}{|X| + |Y|}\]
- Where:
\(X\) is expected to be the scores of each class.
\(Y\) is expected to be the one-hot tensor with the class labels.
The loss is finally computed as:
\[\text{loss}(x, class) = 1 - \text{Dice}(x, class)\]
- Parameters
input (torch.Tensor) – logits tensor with shape \((N, C, H, W)\) where C = number of classes.
target (torch.Tensor) – labels tensor with shape \((N, H, W)\) where each value is \(0 ≤ targets[i] ≤ C−1\).
eps (float, optional) – Scalar to enforce numerical stability. Default: 1e-8.
- Returns
the computed loss.
- Return type
torch.Tensor
Example
>>> N = 5  # num_classes
>>> input = torch.randn(1, N, 3, 5, requires_grad=True)
>>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
>>> output = dice_loss(input, target)
>>> output.backward()
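With soft scores, the set sizes in the Dice formula become sums. The sketch below illustrates this for one image in plain Python, assuming the logits have already been converted to per-pixel class probabilities (for example with a softmax) and that the intersection is the sum of elementwise products with the one-hot labels. The name `soft_dice_loss` is hypothetical.

```python
def soft_dice_loss(probs, labels, eps=1e-8):
    """Soft Dice loss for one image.

    probs: per-pixel lists of class probabilities (each sums to 1).
    labels: per-pixel integer class indices.
    """
    num_classes = len(probs[0])
    intersection = 0.0
    cardinality = 0.0
    for pixel_probs, label in zip(probs, labels):
        one_hot = [1.0 if k == label else 0.0 for k in range(num_classes)]
        intersection += sum(p * t for p, t in zip(pixel_probs, one_hot))
        cardinality += sum(pixel_probs) + sum(one_hot)
    dice = 2.0 * intersection / (cardinality + eps)
    return 1.0 - dice


perfect = soft_dice_loss([[1.0, 0.0], [0.0, 1.0]], [0, 1])  # close to 0
wrong = soft_dice_loss([[0.0, 1.0], [1.0, 0.0]], [0, 1])    # exactly 1.0
```

The `eps` term keeps the division defined when both the prediction and the labels are empty.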
-
tversky_loss(input: torch.Tensor, target: torch.Tensor, alpha: float, beta: float, eps: float = 1e-08) → torch.Tensor
Criterion that computes Tversky Coefficient loss.
According to [salehi2017tversky], we compute the Tversky Coefficient as follows:
\[\text{S}(P, G, \alpha; \beta) = \frac{|PG|}{|PG| + \alpha |P \setminus G| + \beta |G \setminus P|}\]
- Where:
\(P\) and \(G\) are the predicted and ground truth binary labels.
\(\alpha\) and \(\beta\) control the magnitude of the penalties for FPs and FNs, respectively.
Note
\(\alpha = \beta = 0.5\) => Dice coefficient
\(\alpha = \beta = 1\) => Tanimoto coefficient
\(\alpha + \beta = 1\) => F-beta coefficient
- Parameters
input (torch.Tensor) – logits tensor with shape \((N, C, H, W)\) where C = number of classes.
target (torch.Tensor) – labels tensor with shape \((N, H, W)\) where each value is \(0 ≤ targets[i] ≤ C−1\).
alpha (float) – the first coefficient in the denominator.
beta (float) – the second coefficient in the denominator.
eps (float, optional) – scalar for numerical stability. Default: 1e-8.
- Returns
the computed loss.
- Return type
torch.Tensor
Example
>>> N = 5  # num_classes
>>> input = torch.randn(1, N, 3, 5, requires_grad=True)
>>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
>>> output = tversky_loss(input, target, alpha=0.5, beta=0.5)
>>> output.backward()
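The special cases listed in the note can be verified on a small example with hard binary labels. The sketch below represents the prediction and the ground truth as Python sets of positive pixel indices; the name `tversky_index` is hypothetical, and the library itself works on soft scores with an `eps` term rather than on sets.

```python
def tversky_index(pred, truth, alpha, beta):
    """Tversky index for two sets of positive pixel indices."""
    tp = len(pred & truth)  # |PG|: true positives
    fp = len(pred - truth)  # |P \ G|: false positives
    fn = len(truth - pred)  # |G \ P|: false negatives
    return tp / (tp + alpha * fp + beta * fn)


pred = {0, 1, 2, 3}
truth = {2, 3, 4}

dice = 2 * len(pred & truth) / (len(pred) + len(truth))  # 4/7
jaccard = len(pred & truth) / len(pred | truth)          # 2/5

as_dice = tversky_index(pred, truth, 0.5, 0.5)
as_jaccard = tversky_index(pred, truth, 1.0, 1.0)
```

Here `as_dice` equals the Dice coefficient and `as_jaccard` equals the Tanimoto (Jaccard) coefficient. Raising `beta` relative to `alpha` penalizes false negatives more heavily.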
Distributions¶
-
js_div_loss_2d(input: torch.Tensor, target: torch.Tensor, reduction: str = 'mean')
Calculates the Jensen-Shannon divergence loss between heatmaps.
- Parameters
input (torch.Tensor) – the input tensor with shape \((B, N, H, W)\).
target (torch.Tensor) – the target tensor with shape \((B, N, H, W)\).
reduction (string, optional) – Specifies the reduction to apply to the output: none | mean | sum. none: no reduction will be applied, mean: the sum of the output will be divided by the number of elements in the output, sum: the output will be summed. Default: mean.
Examples
>>> input = torch.full((1, 1, 2, 4), 0.125)
>>> loss = js_div_loss_2d(input, input)
>>> loss.item()
0.0
-
kl_div_loss_2d(input: torch.Tensor, target: torch.Tensor, reduction: str = 'mean')
Calculates the Kullback-Leibler divergence loss between heatmaps.
- Parameters
input (torch.Tensor) – the input tensor with shape \((B, N, H, W)\).
target (torch.Tensor) – the target tensor with shape \((B, N, H, W)\).
reduction (string, optional) – Specifies the reduction to apply to the output: none | mean | sum. none: no reduction will be applied, mean: the sum of the output will be divided by the number of elements in the output, sum: the output will be summed. Default: mean.
Examples
>>> input = torch.full((1, 1, 2, 4), 0.125)
>>> loss = kl_div_loss_2d(input, input)
>>> loss.item()
0.0
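Both divergences treat each heatmap as a discrete probability distribution, as in the examples above where the 2x4 map of 0.125 sums to 1. The plain-Python sketch below shows the two definitions on flattened heatmaps. The helper names `kl_div` and `js_div` are hypothetical, and the direction of the KL term is an assumption, since this page does not say which argument plays which role.

```python
import math


def kl_div(p, q):
    """KL(p || q) for two discrete distributions given as flat sequences."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def js_div(p, q):
    """Jensen-Shannon divergence: average KL of p and q to their mixture."""
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl_div(p, m) + 0.5 * kl_div(q, m)


uniform = [0.125] * 8            # the flattened 2x4 heatmap from the examples
peaked = [0.65] + [0.05] * 7     # a second distribution that also sums to 1

kl_same = kl_div(uniform, uniform)   # 0.0
js_same = js_div(uniform, uniform)   # 0.0
js_ab = js_div(uniform, peaked)
js_ba = js_div(peaked, uniform)
```

The KL divergence is asymmetric in its arguments, while the Jensen-Shannon divergence is symmetric and bounded above by \(\log 2\). The sketch assumes `q` is nonzero wherever `p` is positive.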
Module¶
-
class DiceLoss(eps: float = 1e-08)
Criterion that computes Sørensen-Dice Coefficient loss.
According to [1], we compute the Sørensen-Dice Coefficient as follows:
\[\text{Dice}(x, class) = \frac{2 |X \cap Y|}{|X| + |Y|}\]
- Where:
\(X\) is expected to be the scores of each class.
\(Y\) is expected to be the one-hot tensor with the class labels.
The loss is finally computed as:
\[\text{loss}(x, class) = 1 - \text{Dice}(x, class)\]
- Parameters
eps (float, optional) – Scalar to enforce numerical stability. Default: 1e-8.
- Shape:
Input: \((N, C, H, W)\) where C = number of classes.
Target: \((N, H, W)\) where each value is \(0 ≤ targets[i] ≤ C−1\).
Example
>>> N = 5  # num_classes
>>> criterion = DiceLoss()
>>> input = torch.randn(1, N, 3, 5, requires_grad=True)
>>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
>>> output = criterion(input, target)
>>> output.backward()
-
class TverskyLoss(alpha: float, beta: float, eps: float = 1e-08)
Criterion that computes Tversky Coefficient loss.
According to [salehi2017tversky], we compute the Tversky Coefficient as follows:
\[\text{S}(P, G, \alpha; \beta) = \frac{|PG|}{|PG| + \alpha |P \setminus G| + \beta |G \setminus P|}\]
- Where:
\(P\) and \(G\) are the predicted and ground truth binary labels.
\(\alpha\) and \(\beta\) control the magnitude of the penalties for FPs and FNs, respectively.
Note
\(\alpha = \beta = 0.5\) => Dice coefficient
\(\alpha = \beta = 1\) => Tanimoto coefficient
\(\alpha + \beta = 1\) => F-beta coefficient
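- Parameters
alpha (float) – the first coefficient in the denominator.
beta (float) – the second coefficient in the denominator.
eps (float, optional) – scalar for numerical stability. Default: 1e-8.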
- Shape:
Input: \((N, C, H, W)\) where C = number of classes.
Target: \((N, H, W)\) where each value is \(0 ≤ targets[i] ≤ C−1\).
Examples
>>> N = 5  # num_classes
>>> criterion = TverskyLoss(alpha=0.5, beta=0.5)
>>> input = torch.randn(1, N, 3, 5, requires_grad=True)
>>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
>>> output = criterion(input, target)
>>> output.backward()
-
class FocalLoss(alpha: float, gamma: float = 2.0, reduction: str = 'none', eps: float = 1e-08)
Criterion that computes Focal loss.
According to [lin2018focal], the Focal loss is computed as follows:
\[\text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)\]
- Where:
\(p_t\) is the model’s estimated probability for each class.
- Parameters
alpha (float) – Weighting factor \(\alpha \in [0, 1]\).
gamma (float, optional) – Focusing parameter \(\gamma >= 0\). Default 2.
reduction (str, optional) – Specifies the reduction to apply to the output: ‘none’ | ‘mean’ | ‘sum’. ‘none’: no reduction will be applied, ‘mean’: the sum of the output will be divided by the number of elements in the output, ‘sum’: the output will be summed. Default: ‘none’.
eps (float, optional) – Scalar to enforce numerical stability. Default: 1e-8.
- Shape:
Input: \((N, C, *)\) where C = number of classes.
Target: \((N, *)\) where each value is \(0 ≤ targets[i] ≤ C−1\).
Example
>>> N = 5  # num_classes
>>> kwargs = {"alpha": 0.5, "gamma": 2.0, "reduction": 'mean'}
>>> criterion = FocalLoss(**kwargs)
>>> input = torch.randn(1, N, 3, 5, requires_grad=True)
>>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
>>> output = criterion(input, target)
>>> output.backward()
-
class SSIM(window_size: int, max_val: float = 1.0, eps: float = 1e-12)
Creates a module that computes the Structural Similarity (SSIM) index between two images.
Measures the SSIM index between each element of the input x and the target y.
The index can be described as:
\[\text{SSIM}(x, y) = \frac{(2\mu_x\mu_y+c_1)(2\sigma_{xy}+c_2)} {(\mu_x^2+\mu_y^2+c_1)(\sigma_x^2+\sigma_y^2+c_2)}\]
- where:
\(c_1=(k_1 L)^2\) and \(c_2=(k_2 L)^2\) are two constants that stabilize the division when the denominator is small.
\(L\) is the dynamic range of the pixel-values (typically this is \(2^{\#\text{bits per pixel}}-1\)).
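- Parameters
window_size (int) – the size of the gaussian kernel to smooth the images.
max_val (float) – the dynamic range of the images. Default: 1.
eps (float) – Small value for numerical stability when dividing. Default: 1e-12.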
- Shape:
Input: \((B, C, H, W)\).
Target: \((B, C, H, W)\).
Output: \((B, C, H, W)\).
Examples
>>> input1 = torch.rand(1, 4, 5, 5)
>>> input2 = torch.rand(1, 4, 5, 5)
>>> ssim = SSIM(5)
>>> ssim_map = ssim(input1, input2)  # 1x4x5x5
-
class SSIMLoss(window_size: int, max_val: float = 1.0, eps: float = 1e-12, reduction: str = 'mean')
Creates a criterion that computes a loss based on the SSIM measurement.
The loss, or the Structural dissimilarity (DSSIM) is described as:
\[\text{loss}(x, y) = \frac{1 - \text{SSIM}(x, y)}{2}\]
See ssim_loss() for details about SSIM.
- Parameters
window_size (int) – the size of the gaussian kernel to smooth the images.
max_val (float) – the dynamic range of the images. Default: 1.
eps (float) – Small value for numerical stability when dividing. Default: 1e-12.
reduction (str, optional) – Specifies the reduction to apply to the output: ‘none’ | ‘mean’ | ‘sum’. ‘none’: no reduction will be applied, ‘mean’: the sum of the output will be divided by the number of elements in the output, ‘sum’: the output will be summed. Default: ‘mean’.
- Returns
The loss based on the ssim index.
- Return type
Examples
>>> input1 = torch.rand(1, 4, 5, 5)
>>> input2 = torch.rand(1, 4, 5, 5)
>>> criterion = SSIMLoss(5)
>>> loss = criterion(input1, input2)
-
class InverseDepthSmoothnessLoss
Criterion that computes image-aware inverse depth smoothness loss.
\[\text{loss} = \left | \partial_x d_{ij} \right | e^{-\left \| \partial_x I_{ij} \right \|} + \left | \partial_y d_{ij} \right | e^{-\left \| \partial_y I_{ij} \right \|}\]
- Shape:
Inverse Depth: \((N, 1, H, W)\)
Image: \((N, 3, H, W)\)
Output: scalar
Examples
>>> idepth = torch.rand(1, 1, 4, 5)
>>> image = torch.rand(1, 3, 4, 5)
>>> smooth = InverseDepthSmoothnessLoss()
>>> loss = smooth(idepth, image)
-
class TotalVariation
Computes the Total Variation according to [1].
- Shape:
Input: \((N, C, H, W)\) or \((C, H, W)\).
Output: \((N,)\) or scalar.
Examples
>>> tv = TotalVariation()
>>> output = tv(torch.ones((2, 3, 4, 4), requires_grad=True))
>>> output.data
tensor([0., 0.])
>>> output.sum().backward()  # grad can be implicitly created only for scalar outputs
- Reference:
-
class PSNRLoss(max_val: float)
Creates a criterion that calculates the PSNR loss.
The loss is computed as follows:
\[\text{loss} = -\text{psnr}(x, y)\]
See psnr() for details about PSNR.
- Shape:
Input: arbitrary dimensional tensor \((*)\).
Target: arbitrary dimensional tensor \((*)\) same shape as input.
Output: a scalar.
Examples
>>> ones = torch.ones(1)
>>> criterion = PSNRLoss(2.)
>>> criterion(ones, 1.2 * ones)  # -10 * log(4/((1.2-1)**2)) / log(10)
tensor(-20.0000)
-
class BinaryFocalLossWithLogits(alpha: float, gamma: float = 2.0, reduction: str = 'none')
Criterion that computes Focal loss.
According to [lin2017focal], the Focal loss is computed as follows:
\[\text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)\]
- where:
\(p_t\) is the model’s estimated probability for each class.
- Parameters
alpha (float) – Weighting factor for the rare class \(\alpha \in [0, 1]\).
gamma (float) – Focusing parameter \(\gamma >= 0\).
reduction (str, optional) – Specifies the reduction to apply to the output: ‘none’ | ‘mean’ | ‘sum’. ‘none’: no reduction will be applied, ‘mean’: the sum of the output will be divided by the number of elements in the output, ‘sum’: the output will be summed. Default: ‘none’.
- Shape:
Input: \((N, 1, *)\).
Target: \((N, 1, *)\).
Examples
>>> N = 1  # num_classes
>>> kwargs = {"alpha": 0.25, "gamma": 2.0, "reduction": 'mean'}
>>> loss = BinaryFocalLossWithLogits(**kwargs)
>>> input = torch.randn(1, N, 3, 5, requires_grad=True)
>>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
>>> output = loss(input, target)
>>> output.backward()