Source code for kornia.losses.hausdorff

from typing import Callable

import torch
import torch.nn as nn


class _HausdorffERLossBase(torch.jit.ScriptModule):
    """Base class for binary Hausdorff loss based on morphological erosion.

    This is a Hausdorff Distance (HD) Loss based on morphological erosion, which provides
    a differentiable approximation of the Hausdorff distance as stated in :cite:`karimi2019reducing`.
    The code is refactored on top of `here <https://github.com/PatRyg99/HausdorffLoss/
        blob/master/hausdorff_loss.py>`__.

    Args:
        alpha: controls the erosion rate in each iteration.
        k: the number of iterations of erosion.
        reduction: Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'.
            'none': no reduction will be applied, 'mean': the weighted mean of the output is taken,
            'sum': the output will be summed.

    Returns:
        Estimated Hausdorff Loss.
    """

    conv: Callable
    max_pool: Callable

    def __init__(self, alpha: float = 2.0, k: int = 10, reduction: str = 'mean') -> None:
        super().__init__()
        self.alpha = alpha
        self.k = k
        self.reduction = reduction
        self.register_buffer("kernel", self.get_kernel())

    def get_kernel(self) -> torch.Tensor:
        """Get kernel for image morphology convolution."""
        raise NotImplementedError

    def perform_erosion(
        self, pred: torch.Tensor, target: torch.Tensor
    ) -> torch.Tensor:
        bound = (pred - target) ** 2

        kernel = torch.as_tensor(self.kernel, device=pred.device, dtype=pred.dtype)
        eroded = torch.zeros_like(bound, device=pred.device, dtype=pred.dtype)
        mask = torch.ones_like(bound, device=pred.device, dtype=torch.bool)

        # Same padding, assuming kernel is odd and square (cube) shaped.
        # NOTE: int() has to be added to enable JIT.
        padding = int((kernel.size(-1) - 1) // 2)
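        # Iteratively erode the squared prediction-target difference map: each pass convolves
        # with the morphology kernel, soft-thresholds at 0.5 and min-max normalizes, and the
        # surviving activations of iteration k are accumulated with weight (k + 1) ** alpha,
        # giving the erosion-based approximation of the Hausdorff distance.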
        for k in range(self.k):
            # compute convolution with kernel
            dilation = self.conv(bound, weight=kernel, padding=padding, groups=1)
            # apply soft thresholding at 0.5 and normalize
            erosion = dilation - 0.5
            erosion[erosion < 0] = 0

            # image-wise differences for 2D images
            erosion_max = self.max_pool(erosion)
            erosion_min = -self.max_pool(-erosion)
            # No normalization needed if `max - min = 0`
            _to_norm = (erosion_max - erosion_min) != 0
            to_norm = _to_norm.squeeze()
            if to_norm.any():
                # NOTE: avoid in-place ops like below, which will not pass gradcheck:
                #       erosion[to_norm] = (erosion[to_norm] - erosion_min[to_norm]) / (
                #           erosion_max[to_norm] - erosion_min[to_norm])
                _erosion_to_fill = (erosion - erosion_min) / (erosion_max - erosion_min)
                erosion = torch.where(mask * _to_norm, _erosion_to_fill, erosion)

            # save erosion and add to loss
            eroded = eroded + erosion * (k + 1) ** self.alpha
            bound = erosion

        return eroded

    # NOTE: we add type ignore because the forward pass does not work well with subclassing
    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:  # type: ignore
        """Compute Hausdorff loss.

        Args:
            pred: predicted tensor with a shape of :math:`(B, C, H, W)` or :math:`(B, C, D, H, W)`.
                Each channel is binary: 1 -> fg, 0 -> bg.
            target: target tensor with a shape of :math:`(B, 1, H, W)` or :math:`(B, 1, D, H, W)`.

        Returns:
            Estimated Hausdorff Loss.
        """
        if not (pred.shape[2:] == target.shape[2:] and pred.size(0) == target.size(0) and target.size(1) == 1):
            raise ValueError(
                "Prediction and target need to be of same size, and target should not be one-hot."
                f"Got {pred.shape} and {target.shape}."
            )

        if pred.size(1) < target.max().item():
            raise ValueError("Invalid target value.")

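        # One-vs-rest: for each prediction channel i, binarize the target against class index i
        # and erode that channel independently, then stack the per-class results.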
        out = torch.stack([
            self.perform_erosion(
                pred[:, i:i + 1],
                torch.where(
                    target == i,
                    torch.tensor(1, device=target.device, dtype=target.dtype),
                    torch.tensor(0, device=target.device, dtype=target.dtype)
                )
            )
            for i in range(pred.size(1))
        ])

        if self.reduction == 'mean':
            out = out.mean()
        elif self.reduction == 'sum':
            out = out.sum()
        elif self.reduction == 'none':
            pass
        else:
            raise NotImplementedError(f"reduction `{self.reduction}` has not been implemented yet.")

        return out


class HausdorffERLoss(_HausdorffERLossBase):
    r"""Binary Hausdorff loss based on morphological erosion.

    Hausdorff Distance loss measures the maximum distance of a predicted segmentation boundary
    to the nearest ground-truth edge pixel. For two segmentation point sets X and Y,
    the one-sided HD from X to Y is defined as:

    .. math::

        hd(X,Y) = \max_{x \in X} \min_{y \in Y}||x - y||_2

    Furthermore, the bidirectional HD is:

    .. math::

        HD(X,Y) = max(hd(X, Y), hd(Y, X))

    This is a Hausdorff Distance (HD) Loss based on morphological erosion, which provides
    a differentiable approximation of the Hausdorff distance as stated in :cite:`karimi2019reducing`.
    The code is refactored on top of `here <https://github.com/PatRyg99/HausdorffLoss/
        blob/master/hausdorff_loss.py>`__.

    Args:
        alpha: controls the erosion rate in each iteration.
        k: the number of iterations of erosion.
        reduction: Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'.
            'none': no reduction will be applied, 'mean': the weighted mean of the output is taken,
            'sum': the output will be summed.

    Examples:
        >>> hdloss = HausdorffERLoss()
        >>> input = torch.randn(5, 3, 20, 20)
        >>> target = (torch.rand(5, 1, 20, 20) * 2).long()
        >>> res = hdloss(input, target)
    """

    conv = torch.conv2d
    max_pool = nn.AdaptiveMaxPool2d(1)

    def get_kernel(self) -> torch.Tensor:
        """Get kernel for image morphology convolution."""
        cross = torch.tensor([[[0, 1, 0], [1, 1, 1], [0, 1, 0]]])
        kernel = cross * 0.2
        return kernel[None]

    # NOTE: we add type ignore because the forward pass does not work well with subclassing
    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:  # type: ignore
        """Compute Hausdorff loss.

        Args:
            pred: predicted tensor with a shape of :math:`(B, C, H, W)`.
                Each channel is binary: 1 -> fg, 0 -> bg.
            target: target tensor with a shape of :math:`(B, 1, H, W)`.

        Returns:
            Estimated Hausdorff Loss.
        """
        if pred.dim() != 4:
            raise ValueError(f"Only 2D images supported. Got {pred.dim()}.")

        if not (target.max() < pred.size(1) and target.min() >= 0 and target.dtype == torch.long):
            raise ValueError(
                f"Expect long type target value in range (0, {pred.size(1)})."
                f" Got ({target.min()}, {target.max()})."
            )

        return super().forward(pred, target)


class HausdorffERLoss3D(_HausdorffERLossBase):
    r"""Binary 3D Hausdorff loss based on morphological erosion.

    Hausdorff Distance loss measures the maximum distance of a predicted segmentation boundary
    to the nearest ground-truth edge pixel. For two segmentation point sets X and Y,
    the one-sided HD from X to Y is defined as:

    .. math::

        hd(X,Y) = \max_{x \in X} \min_{y \in Y}||x - y||_2

    Furthermore, the bidirectional HD is:

    .. math::

        HD(X,Y) = max(hd(X, Y), hd(Y, X))

    This is a 3D Hausdorff Distance (HD) Loss based on morphological erosion, which provides
    a differentiable approximation of the Hausdorff distance as stated in :cite:`karimi2019reducing`.
    The code is refactored on top of `here <https://github.com/PatRyg99/HausdorffLoss/
        blob/master/hausdorff_loss.py>`__.

    Args:
        alpha: controls the erosion rate in each iteration.
        k: the number of iterations of erosion.
        reduction: Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'.
            'none': no reduction will be applied, 'mean': the weighted mean of the output is taken,
            'sum': the output will be summed.

    Examples:
        >>> hdloss = HausdorffERLoss3D()
        >>> input = torch.randn(5, 3, 20, 20, 20)
        >>> target = (torch.rand(5, 1, 20, 20, 20) * 2).long()
        >>> res = hdloss(input, target)
    """

    conv = torch.conv3d
    max_pool = nn.AdaptiveMaxPool3d(1)

    def get_kernel(self) -> torch.Tensor:
        """Get kernel for image morphology convolution."""
        cross = torch.tensor([[[0, 1, 0], [1, 1, 1], [0, 1, 0]]])
        bound = torch.tensor([[[0, 0, 0], [0, 1, 0], [0, 0, 0]]])
        # NOTE: the original repo claims the kernel is shaped as (3, 1, 3, 3),
        #       which Jian suspects is wrongly implemented.
        # https://github.com/PatRyg99/HausdorffLoss/blob/9f580acd421af648e74b45d46555ccb7a876c27c/hausdorff_loss.py#L94
        kernel = torch.stack([bound, cross, bound], dim=1) * (1 / 7)
        return kernel[None]

    # NOTE: we add type ignore because the forward pass does not work well with subclassing
    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:  # type: ignore
        """Compute 3D Hausdorff loss.

        Args:
            pred: predicted tensor with a shape of :math:`(B, C, D, H, W)`.
                Each channel is binary: 1 -> fg, 0 -> bg.
            target: target tensor with a shape of :math:`(B, 1, D, H, W)`.

        Returns:
            Estimated Hausdorff Loss.
        """
        if pred.dim() != 5:
            raise ValueError(f"Only 3D images supported. Got {pred.dim()}.")

        return super().forward(pred, target)
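

# Illustrative usage sketch (not part of the kornia module above): exercises the 2D
# erosion-based loss on random data and backpropagates through it. The shapes follow the
# docstrings; feeding softmax outputs as the soft per-class maps is an assumption about
# how an upstream segmentation model would produce them.
if __name__ == "__main__":
    logits = torch.randn(2, 3, 32, 32, requires_grad=True)  # (B, C, H, W) raw class scores
    labels = torch.randint(0, 3, (2, 1, 32, 32))  # (B, 1, H, W) long class indices
    criterion = HausdorffERLoss(alpha=2.0, k=10, reduction='mean')
    loss = criterion(logits.softmax(dim=1), labels)  # soft per-class maps in [0, 1]
    loss.backward()
    print(f"Estimated 2D Hausdorff loss: {loss.item():.4f}")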