# Source code for kornia.losses.dice

import torch
import torch.nn as nn
import torch.nn.functional as F

from kornia.utils.one_hot import one_hot

# based on:
# https://github.com/kevinzakka/pytorch-goodies/blob/master/losses.py


def dice_loss(input: torch.Tensor, target: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    r"""Criterion that computes Sørensen-Dice Coefficient loss.

    According to [1], we compute the Sørensen-Dice Coefficient as follows:

    .. math::

        \text{Dice}(x, class) = \frac{2 |X \cap Y|}{|X| + |Y|}

    Where:
       - :math:`X` are the per-class scores (softmax over the class axis of the input logits).
       - :math:`Y` is the one-hot tensor with the class labels.

    The loss is finally computed as:

    .. math::

        \text{loss}(x, class) = 1 - \text{Dice}(x, class)

    Reference:
        [1] https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient

    Args:
        input: logits tensor with shape :math:`(N, C, H, W)` where C = number of classes.
        target: labels tensor with shape :math:`(N, H, W)` where each value is
          :math:`0 ≤ targets[i] ≤ C−1`.
        eps: Scalar added to the denominator to enforce numerical stability.

    Return:
        the computed loss: the per-sample Dice loss averaged over the batch.

    Raises:
        TypeError: if ``input`` is not a ``torch.Tensor``.
        ValueError: if ``input`` is not 4-D, if ``target`` does not have shape
          :math:`(N, H, W)` matching ``input``, or if they live on different devices.

    Example:
        >>> N = 5  # num_classes
        >>> input = torch.randn(1, N, 3, 5, requires_grad=True)
        >>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
        >>> output = dice_loss(input, target)
        >>> output.backward()
    """
    if not isinstance(input, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")

    if not len(input.shape) == 4:
        raise ValueError(f"Invalid input shape, we expect BxNxHxW. Got: {input.shape}")

    # Validate the whole (N, H, W) target shape, not just the spatial dims:
    # a mismatched batch size (or an extra dim) would otherwise broadcast
    # silently against ``input_soft`` below and produce a wrong loss.
    expected_target_shape = (input.shape[0],) + tuple(input.shape[-2:])
    if tuple(target.shape) != expected_target_shape:
        raise ValueError(
            f"input and target shapes must be the same. Got: {input.shape} and {target.shape}"
            f" (expected target shape {expected_target_shape})"
        )

    if not input.device == target.device:
        raise ValueError(f"input and target must be in the same device. Got: {input.device} and {target.device}")

    # compute softmax over the classes axis
    input_soft: torch.Tensor = F.softmax(input, dim=1)

    # create the labels one hot tensor
    target_one_hot: torch.Tensor = one_hot(target, num_classes=input.shape[1], device=input.device, dtype=input.dtype)

    # compute the actual dice score per sample, reducing over (C, H, W)
    dims = (1, 2, 3)
    intersection = torch.sum(input_soft * target_one_hot, dims)
    cardinality = torch.sum(input_soft + target_one_hot, dims)

    dice_score = 2.0 * intersection / (cardinality + eps)

    # average the per-sample loss (1 - dice) over the batch
    return torch.mean(-dice_score + 1.0)
class DiceLoss(nn.Module):
    r"""Criterion that computes Sørensen-Dice Coefficient loss.

    According to [1], we compute the Sørensen-Dice Coefficient as follows:

    .. math::

        \text{Dice}(x, class) = \frac{2 |X \cap Y|}{|X| + |Y|}

    Where:
       - :math:`X` are the per-class scores (softmax over the class axis of the input logits).
       - :math:`Y` is the one-hot tensor with the class labels.

    The loss is finally computed as:

    .. math::

        \text{loss}(x, class) = 1 - \text{Dice}(x, class)

    Reference:
        [1] https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient

    Args:
        eps: Scalar added to the denominator to enforce numerical stability.

    Shape:
        - Input: :math:`(N, C, H, W)` where C = number of classes.
        - Target: :math:`(N, H, W)` where each value is
          :math:`0 ≤ targets[i] ≤ C−1`.

    Example:
        >>> N = 5  # num_classes
        >>> criterion = DiceLoss()
        >>> input = torch.randn(1, N, 3, 5, requires_grad=True)
        >>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
        >>> output = criterion(input, target)
        >>> output.backward()
    """

    def __init__(self, eps: float = 1e-8) -> None:
        super().__init__()
        # stability term forwarded unchanged to ``dice_loss``
        self.eps: float = eps

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Delegate to the functional form so both entry points share a single
        # implementation and the same input validation.
        return dice_loss(input, target, self.eps)