# Source code for kornia.losses.dice

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from kornia.utils.one_hot import one_hot

# based on:
# https://github.com/kevinzakka/pytorch-goodies/blob/master/losses.py

def dice_loss(input: torch.Tensor, target: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    r"""Criterion that computes Sørensen-Dice Coefficient loss.

    According to [1], we compute the Sørensen-Dice Coefficient as follows:

    .. math::

        \text{Dice}(x, class) = \frac{2 |X \cap Y|}{|X| + |Y|}

    Where:
       - :math:`X` expects to be the scores of each class.
       - :math:`Y` expects to be the one-hot tensor with the class labels.

    The loss is finally computed as:

    .. math::

        \text{loss}(x, class) = 1 - \text{Dice}(x, class)

    The Dice score is computed per sample (summing over classes and spatial
    dimensions) and the returned loss is the mean of ``1 - Dice`` over the batch.

    Reference:
        [1] https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient

    Args:
        input (torch.Tensor): logits tensor with shape :math:`(N, C, H, W)` where C = number of classes.
        target (torch.Tensor): labels tensor with shape :math:`(N, H, W)` where each value
          is :math:`0 ≤ targets[i] ≤ C−1`.
        eps (float, optional): Scalar added to the denominator to enforce numerical
          stability. Default: 1e-8.

    Return:
        torch.Tensor: the computed loss (a scalar tensor).

    Raises:
        TypeError: if ``input`` or ``target`` is not a ``torch.Tensor``.
        ValueError: if ``input`` is not 4-D, if the spatial sizes of ``input`` and
          ``target`` differ, or if they live on different devices.

    Example:
        >>> N = 5  # num_classes
        >>> input = torch.randn(1, N, 3, 5, requires_grad=True)
        >>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
        >>> output = dice_loss(input, target)
        >>> output.backward()
    """
    if not isinstance(input, torch.Tensor):
        raise TypeError("Input type is not a torch.Tensor. Got {}"
                        .format(type(input)))

    # Validate target up front so a non-tensor fails with a clear TypeError
    # instead of an AttributeError on `.shape` below.
    if not isinstance(target, torch.Tensor):
        raise TypeError("Target type is not a torch.Tensor. Got {}"
                        .format(type(target)))

    if not len(input.shape) == 4:
        raise ValueError("Invalid input shape, we expect BxNxHxW. Got: {}"
                         .format(input.shape))

    if not input.shape[-2:] == target.shape[-2:]:
        raise ValueError("input and target shapes must be the same. Got: {} and {}"
                         .format(input.shape, target.shape))

    if not input.device == target.device:
        raise ValueError(
            "input and target must be in the same device. Got: {} and {}" .format(
                input.device, target.device))

    # compute softmax over the classes axis
    input_soft: torch.Tensor = F.softmax(input, dim=1)

    # create the labels one hot tensor; the number of classes is the size of
    # the channel axis (dim 1) of the logits, not the whole shape.
    target_one_hot: torch.Tensor = one_hot(
        target, num_classes=input.shape[1],
        device=input.device, dtype=input.dtype)

    # compute the actual dice score per sample, reducing over (C, H, W)
    dims = (1, 2, 3)
    intersection = torch.sum(input_soft * target_one_hot, dims)
    cardinality = torch.sum(input_soft + target_one_hot, dims)

    dice_score = 2. * intersection / (cardinality + eps)

    # loss = 1 - Dice, averaged over the batch
    return torch.mean(1. - dice_score)

class DiceLoss(nn.Module):
    r"""Criterion that computes Sørensen-Dice Coefficient loss.

    According to [1], we compute the Sørensen-Dice Coefficient as follows:

    .. math::

        \text{Dice}(x, class) = \frac{2 |X \cap Y|}{|X| + |Y|}

    Where:
       - :math:`X` expects to be the scores of each class.
       - :math:`Y` expects to be the one-hot tensor with the class labels.

    The loss is finally computed as:

    .. math::

        \text{loss}(x, class) = 1 - \text{Dice}(x, class)

    This module is a thin wrapper that forwards to :func:`dice_loss` with the
    ``eps`` given at construction time.

    Reference:
        [1] https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient

    Args:
        eps (float, optional): Scalar added to the denominator to enforce numerical
          stability. Default: 1e-8.

    Shape:
        - Input: :math:`(N, C, H, W)` where C = number of classes.
        - Target: :math:`(N, H, W)` where each value is
          :math:`0 ≤ targets[i] ≤ C−1`.

    Example:
        >>> N = 5  # num_classes
        >>> criterion = DiceLoss()
        >>> input = torch.randn(1, N, 3, 5, requires_grad=True)
        >>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
        >>> output = criterion(input, target)
        >>> output.backward()
    """

    def __init__(self, eps: float = 1e-8) -> None:
        super().__init__()
        # stability term forwarded to dice_loss on every call
        self.eps: float = eps

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the Dice loss of ``input`` logits against ``target`` labels."""
        return dice_loss(input, target, self.eps)