Source code for kornia.losses.dice

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from kornia.utils import one_hot


# based on:
# https://github.com/kevinzakka/pytorch-goodies/blob/master/losses.py

class DiceLoss(nn.Module):
    r"""Criterion that computes Sørensen-Dice Coefficient loss.

    According to [1], we compute the Sørensen-Dice Coefficient as follows:

    .. math::

        \text{Dice}(x, class) = \frac{2 |X \cap Y|}{|X| + |Y|}

    where:
       - :math:`X` expects to be the scores of each class.
       - :math:`Y` expects to be the one-hot tensor with the class labels.

    the loss, is finally computed as:

    .. math::

        \text{loss}(x, class) = 1 - \text{Dice}(x, class)

    [1] https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient

    Shape:
        - Input: :math:`(N, C, H, W)` where C = number of classes.
        - Target: :math:`(N, H, W)` where each value is
          :math:`0 ≤ targets[i] ≤ C−1`.

    Examples:
        >>> N = 5  # num_classes
        >>> loss = kornia.losses.DiceLoss()
        >>> input = torch.randn(1, N, 3, 5, requires_grad=True)
        >>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
        >>> output = loss(input, target)
        >>> output.backward()
    """

    def __init__(self) -> None:
        super(DiceLoss, self).__init__()
        # Small constant added to the denominator of the dice score for
        # numerical stability.
        self.eps: float = 1e-6

    def forward(  # type: ignore
            self,
            input: torch.Tensor,
            target: torch.Tensor) -> torch.Tensor:
        r"""Compute the dice loss between class scores and integer labels.

        Args:
            input: raw class scores with four dimensions; a softmax is
              applied over dimension 1 before the score is computed.
            target: class labels whose last two dimensions match those of
              ``input`` and which lives on the same device.

        Returns:
            The mean of ``1 - dice_score`` as a tensor.

        Raises:
            TypeError: if ``input`` is not a ``torch.Tensor``.
            ValueError: if ``input`` is not 4-dimensional, if the spatial
              sizes of ``input`` and ``target`` differ, or if the two
              tensors are on different devices.
        """
        if not torch.is_tensor(input):
            raise TypeError("Input type is not a torch.Tensor. Got {}"
                            .format(type(input)))
        if not len(input.shape) == 4:
            raise ValueError("Invalid input shape, we expect BxNxHxW. Got: {}"
                             .format(input.shape))
        if not input.shape[-2:] == target.shape[-2:]:
            # Report BOTH shapes so the mismatch is actually visible.
            raise ValueError(
                "input and target shapes must be the same. Got: {} and {}"
                .format(input.shape, target.shape))
        if not input.device == target.device:
            raise ValueError(
                "input and target must be in the same device. Got: {} and {}"
                .format(input.device, target.device))

        # compute softmax over the classes axis
        input_soft = F.softmax(input, dim=1)

        # create the labels one hot tensor
        # NOTE(review): presumably one_hot returns a tensor shaped like
        # `input` (classes on dim 1) - confirm against kornia.utils.one_hot.
        target_one_hot = one_hot(target, num_classes=input.shape[1],
                                 device=input.device, dtype=input.dtype)

        # compute the actual dice score, reducing every non-batch dimension
        dims = (1, 2, 3)
        intersection = torch.sum(input_soft * target_one_hot, dims)
        cardinality = torch.sum(input_soft + target_one_hot, dims)

        dice_score = 2. * intersection / (cardinality + self.eps)
        # A plain Python scalar avoids allocating a fresh CPU tensor per call.
        return torch.mean(1. - dice_score)
######################
# functional interface
######################
def dice_loss(
        input: torch.Tensor,
        target: torch.Tensor) -> torch.Tensor:
    r"""Functional form of the Sørensen-Dice Coefficient loss.

    Builds a fresh :class:`~kornia.losses.DiceLoss` criterion and applies it
    to the given tensors; see that class for the formula, the expected
    shapes and the errors that may be raised.

    Args:
        input: class scores forwarded unchanged to the criterion.
        target: class labels forwarded unchanged to the criterion.

    Returns:
        Whatever the criterion returns for ``(input, target)``.
    """
    criterion = DiceLoss()
    loss = criterion(input, target)
    return loss