# Source code for kornia.losses.tversky

import torch
import torch.nn as nn
import torch.nn.functional as F

from kornia.utils.one_hot import one_hot

# based on:
# https://github.com/kevinzakka/pytorch-goodies/blob/master/losses.py

def tversky_loss(
    input: torch.Tensor, target: torch.Tensor, alpha: float, beta: float, eps: float = 1e-8
) -> torch.Tensor:
    r"""Criterion that computes Tversky Coefficient loss.

    According to :cite:`salehi2017tversky`, we compute the Tversky Coefficient as follows:

    .. math::

        \text{S}(P, G, \alpha; \beta) =
          \frac{|PG|}{|PG| + \alpha |P \setminus G| + \beta |G \setminus P|}

    Where:
       - :math:`P` and :math:`G` are the predicted and ground truth binary
         labels.
       - :math:`\alpha` and :math:`\beta` control the magnitude of the
         penalties for FPs and FNs, respectively.

    The returned loss is :math:`1 - \text{S}`, averaged over the batch.

    Note:
       - :math:`\alpha = \beta = 0.5` => dice coeff
       - :math:`\alpha = \beta = 1` => tanimoto coeff
       - :math:`\alpha + \beta = 1` => F beta coeff

    Args:
        input: logits tensor with shape :math:`(N, C, H, W)` where C = number of classes.
        target: labels tensor with shape :math:`(N, H, W)` where each value
          is :math:`0 ≤ targets[i] ≤ C−1`.
        alpha: the first coefficient in the denominator (weight of the false positives).
        beta: the second coefficient in the denominator (weight of the false negatives).
        eps: scalar added to the denominator for numerical stability.

    Return:
        the computed loss, a scalar tensor.

    Raises:
        TypeError: if ``input`` is not a :class:`torch.Tensor`.
        ValueError: if ``input`` is not 4-dimensional, if the spatial sizes of
          ``input`` and ``target`` differ, or if they live on different devices.

    Example:
        >>> N = 5  # num_classes
        >>> input = torch.randn(1, N, 3, 5, requires_grad=True)
        >>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
        >>> output = tversky_loss(input, target, alpha=0.5, beta=0.5)
        >>> output.backward()
    """
    if not isinstance(input, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")

    if len(input.shape) != 4:
        raise ValueError(f"Invalid input shape, we expect BxNxHxW. Got: {input.shape}")

    # only the trailing (H, W) dimensions are compared here
    if input.shape[-2:] != target.shape[-2:]:
        raise ValueError(f"input and target shapes must be the same. Got: {input.shape} and {target.shape}")

    if input.device != target.device:
        raise ValueError(f"input and target must be in the same device. Got: {input.device} and {target.device}")

    # compute softmax over the classes axis
    input_soft: torch.Tensor = F.softmax(input, dim=1)

    # create the labels one hot tensor; the number of classes is the size of
    # the channel axis of the logits
    target_one_hot: torch.Tensor = one_hot(
        target, num_classes=input.shape[1], device=input.device, dtype=input.dtype
    )

    # compute the actual tversky index, reducing everything but the batch axis
    dims = (1, 2, 3)
    intersection = torch.sum(input_soft * target_one_hot, dims)
    fps = torch.sum(input_soft * (-target_one_hot + 1.0), dims)
    fns = torch.sum((-input_soft + 1.0) * target_one_hot, dims)

    numerator = intersection
    denominator = intersection + alpha * fps + beta * fns
    tversky_index = numerator / (denominator + eps)

    # loss is one minus the index, averaged over the batch
    return torch.mean(-tversky_index + 1.0)

class TverskyLoss(nn.Module):
    r"""Criterion that computes Tversky Coefficient loss.

    According to :cite:`salehi2017tversky`, we compute the Tversky Coefficient as follows:

    .. math::

        \text{S}(P, G, \alpha; \beta) =
          \frac{|PG|}{|PG| + \alpha |P \setminus G| + \beta |G \setminus P|}

    Where:
       - :math:`P` and :math:`G` are the predicted and ground truth binary
         labels.
       - :math:`\alpha` and :math:`\beta` control the magnitude of the
         penalties for FPs and FNs, respectively.

    Note:
       - :math:`\alpha = \beta = 0.5` => dice coeff
       - :math:`\alpha = \beta = 1` => tanimoto coeff
       - :math:`\alpha + \beta = 1` => F beta coeff

    Args:
        alpha: the first coefficient in the denominator.
        beta: the second coefficient in the denominator.
        eps: scalar for numerical stability.

    Shape:
        - Input: :math:`(N, C, H, W)` where C = number of classes.
        - Target: :math:`(N, H, W)` where each value is
          :math:`0 ≤ targets[i] ≤ C−1`.

    Examples:
        >>> N = 5  # num_classes
        >>> criterion = TverskyLoss(alpha=0.5, beta=0.5)
        >>> input = torch.randn(1, N, 3, 5, requires_grad=True)
        >>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
        >>> output = criterion(input, target)
        >>> output.backward()
    """

    def __init__(self, alpha: float, beta: float, eps: float = 1e-8) -> None:
        super().__init__()
        # plain Python floats; stored so forward() can pass them on unchanged
        self.alpha: float = alpha
        self.beta: float = beta
        self.eps: float = eps

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Delegate to :func:`tversky_loss` using the stored ``alpha``, ``beta`` and ``eps``."""
        return tversky_loss(input, target, self.alpha, self.beta, self.eps)