import torch
from .confusion_matrix import confusion_matrix
def mean_iou(input: torch.Tensor, target: torch.Tensor, num_classes: int, eps: float = 1e-6) -> torch.Tensor:
    r"""Calculate the per-class Intersection-Over-Union (IoU) of each batch item.

    The function internally computes the confusion matrix and, for every
    class, divides its diagonal entry (the intersection) by the sum of the
    matching row and column minus that diagonal entry (the union).

    Note:
        Despite the function name, no averaging over classes is performed
        here; callers that need a single mIoU value must reduce the result
        themselves (e.g. ``mean_iou(...).mean(dim=-1)``).

    Args:
        input: tensor with estimated targets returned by a classifier. The
            shape can be :math:`(B, *)` and must contain integer values
            between 0 and K-1.
        target: tensor with ground truth (correct) target values. The shape
            can be :math:`(B, *)` and must contain integer values between
            0 and K-1 (class indices, same shape as ``input``).
        num_classes: total possible number of classes in target. Must be an
            integer greater than or equal to two.
        eps: small constant added to both numerator and denominator so that
            a class that is absent from both prediction and ground truth
            yields an IoU of 1 instead of a division by zero.

    Returns:
        a tensor with the intersection-over-union of every class, with
        shape :math:`(B, K)` where K is the number of classes.

    Raises:
        TypeError: if ``input`` or ``target`` is not a :class:`torch.Tensor`.
        ValueError: if the shapes or devices of ``input`` and ``target``
            differ, or if ``num_classes`` is not an integer >= 2.

    Example:
        >>> logits = torch.tensor([[0, 1, 0]])
        >>> target = torch.tensor([[0, 1, 0]])
        >>> mean_iou(logits, target, num_classes=3)
        tensor([[1., 1., 1.]])
    """
    # The original guard was `not is_tensor(x) and x.dtype is not int64`, which
    # dereferenced `.dtype` on non-tensors (AttributeError) and never fired for
    # tensors. Non-tensors now fail with the intended TypeError.
    # NOTE(review): the dtype is deliberately not enforced here, because every
    # tensor dtype passed this guard before; presumably `confusion_matrix`
    # rejects unsupported dtypes - confirm before tightening.
    if not torch.is_tensor(input):
        raise TypeError(f"Input input type is not a torch.Tensor with torch.int64 dtype. Got {type(input)}")

    if not torch.is_tensor(target):
        raise TypeError(f"Input target type is not a torch.Tensor with torch.int64 dtype. Got {type(target)}")

    if input.shape != target.shape:
        raise ValueError(f"Inputs input and target must have the same shape. Got: {input.shape} and {target.shape}")

    if input.device != target.device:
        raise ValueError(f"Inputs must be in the same device. Got: {input.device} - {target.device}")

    # Two classes is the smallest accepted value, so the message says "one".
    if not isinstance(num_classes, int) or num_classes < 2:
        raise ValueError(f"The number of classes must be an integer bigger than one. Got: {num_classes}")

    # we first compute the confusion matrix; it is reduced below over
    # dims 1 and 2, i.e. it is treated as a batch of square matrices.
    conf_mat: torch.Tensor = confusion_matrix(input, target, num_classes)

    # compute the actual intersection over union
    sum_over_dim1 = torch.sum(conf_mat, dim=1)
    sum_over_dim2 = torch.sum(conf_mat, dim=2)
    intersection = torch.diagonal(conf_mat, dim1=-2, dim2=-1)
    # The diagonal is counted in both sums, so subtract it once.
    union = sum_over_dim1 + sum_over_dim2 - intersection

    # NOTE: we add epsilon so that samples that are neither in the
    # prediction or ground truth are taken into account.
    ious = (intersection + eps) / (union + eps)
    return ious
def mean_iou_bbox(boxes_1: torch.Tensor, boxes_2: torch.Tensor) -> torch.Tensor:
    """Compute the IoU of the cartesian product of two sets of boxes.

    Each box in each set shall be (x1, y1, x2, y2), with ``x2 > x1`` and
    ``y2 > y1``.

    Args:
        boxes_1: a tensor of bounding boxes in :math:`(B1, 4)`.
        boxes_2: a tensor of bounding boxes in :math:`(B2, 4)`.

    Returns:
        a tensor in dimensions :math:`(B1, B2)`, representing the
        intersection-over-union of each of the boxes in set 1 with respect
        to each of the boxes in set 2.

    Raises:
        AssertionError: if any box in either set has a non-positive width
            or a non-positive height.

    Example:
        >>> boxes_1 = torch.tensor([[40, 40, 60, 60], [30, 40, 50, 60]])
        >>> boxes_2 = torch.tensor([[40, 50, 60, 70], [30, 40, 40, 50]])
        >>> mean_iou_bbox(boxes_1, boxes_2)
        tensor([[0.3333, 0.0000],
                [0.1429, 0.2500]])
    """
    # TODO: support more box types. e.g. xywh,
    widths_1 = boxes_1[:, 2] - boxes_1[:, 0]  # (n1)
    heights_1 = boxes_1[:, 3] - boxes_1[:, 1]  # (n1)
    widths_2 = boxes_2[:, 2] - boxes_2[:, 0]  # (n2)
    heights_2 = boxes_2[:, 3] - boxes_2[:, 1]  # (n2)

    # Both extents must be strictly positive. The previous `or` accepted boxes
    # whose width was valid but whose height was negative (or vice versa),
    # which produced negative areas and meaningless IoU values.
    if not ((widths_1 > 0).all() and (heights_1 > 0).all()):
        raise AssertionError("Boxes_1 does not follow (x1, y1, x2, y2) format.")
    if not ((widths_2 > 0).all() and (heights_2 > 0).all()):
        raise AssertionError("Boxes_2 does not follow (x1, y1, x2, y2) format.")

    # find intersection: pairwise overlap rectangle, clamped at zero so that
    # disjoint boxes contribute no area
    lower_bounds = torch.max(boxes_1[:, :2].unsqueeze(1), boxes_2[:, :2].unsqueeze(0))  # (n1, n2, 2)
    upper_bounds = torch.min(boxes_1[:, 2:].unsqueeze(1), boxes_2[:, 2:].unsqueeze(0))  # (n1, n2, 2)
    intersection_dims = torch.clamp(upper_bounds - lower_bounds, min=0)  # (n1, n2, 2)
    intersection = intersection_dims[:, :, 0] * intersection_dims[:, :, 1]  # (n1, n2)

    # Find areas of each box in both sets
    areas_set_1 = widths_1 * heights_1  # (n1)
    areas_set_2 = widths_2 * heights_2  # (n2)

    # Find the union
    # PyTorch auto-broadcasts singleton dimensions
    union = areas_set_1.unsqueeze(1) + areas_set_2.unsqueeze(0) - intersection  # (n1, n2)

    return intersection / union  # (n1, n2)