Source code for kornia.metrics.mean_average_precision

from typing import List, Tuple

import torch

from .mean_iou import mean_iou_bbox


def _image_index(per_image_labels: List[torch.Tensor], device: torch.device) -> torch.Tensor:
    """Return, for every label across all images, the index of the image it came from."""
    image_ids: List[int] = []
    for i, labels in enumerate(per_image_labels):
        image_ids.extend([i] * labels.size(0))
    return torch.tensor(image_ids, device=device, dtype=torch.long)


def _interpolated_precision(
    cumul_precision: torch.Tensor,
    cumul_recall: torch.Tensor,
    recall_thresholds: List[float],
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Average, over the recall thresholds, the best precision reached at recall >= threshold."""
    precisions = torch.zeros((len(recall_thresholds)), device=device, dtype=dtype)
    for i, t in enumerate(recall_thresholds):
        recalls_above_t = cumul_recall >= t
        # Thresholds that are never reached keep a precision of 0.
        if recalls_above_t.any():
            precisions[i] = cumul_precision[recalls_above_t].max()
    return precisions.mean()


def mean_average_precision(
    pred_boxes: List[torch.Tensor],
    pred_labels: List[torch.Tensor],
    pred_scores: List[torch.Tensor],
    gt_boxes: List[torch.Tensor],
    gt_labels: List[torch.Tensor],
    n_classes: int,
    threshold: float = 0.5,
) -> Tuple[torch.Tensor, dict]:
    """Calculate the Mean Average Precision (mAP) of detected objects.

    Code altered from https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Object-Detection/blob/master/utils.py#L271.
    Background class (0 index) is excluded.

    Every list argument holds one tensor per image, so all five lists must have the same length.
    Average precision is computed per class with 11 recall thresholds (0.0, 0.1, ..., 1.0); a class
    without any detection or without any ground truth object gets an average precision of 0.

    Args:
        pred_boxes: a tensor list of predicted bounding boxes.
        pred_labels: a tensor list of predicted labels.
        pred_scores: a tensor list of predicted labels' scores.
        gt_boxes: a tensor list of ground truth bounding boxes.
        gt_labels: a tensor list of ground truth labels.
        n_classes: the number of classes.
        threshold: count as a positive if the overlap is greater than the threshold.

    Returns:
        mean average precision (mAP), dictionary mapping each class index (starting at 1) to its
        average precision.

    Raises:
        AssertionError: if the input lists do not all have the same length, or if the number of
            boxes, labels (and scores) disagree once concatenated over all images.

    Examples:
        >>> boxes, labels, scores = torch.tensor([[100, 50, 150, 100.]]), torch.tensor([1]), torch.tensor([.7])
        >>> gt_boxes, gt_labels = torch.tensor([[100, 50, 150, 100.]]), torch.tensor([1])
        >>> mean_average_precision([boxes], [labels], [scores], [gt_boxes], [gt_labels], 2)
        (tensor(1.), {1: 1.0})
    """
    # These are all lists of tensors of the same length, i.e. number of images.
    if not len(pred_boxes) == len(pred_labels) == len(pred_scores) == len(gt_boxes) == len(gt_labels):
        raise AssertionError("All input lists must contain one tensor per image and have the same length.")

    # Store all (true) objects in a single continuous tensor while keeping track of the image it is from.
    _gt_boxes = torch.cat(gt_boxes, dim=0)  # (n_objects, 4)
    _gt_labels = torch.cat(gt_labels, dim=0)  # (n_objects)
    _gt_images = _image_index(gt_labels, _gt_boxes.device)  # (n_objects)
    if not _gt_images.size(0) == _gt_boxes.size(0) == _gt_labels.size(0):
        raise AssertionError("Ground truth boxes and labels must describe the same number of objects.")

    # Store all detections in a single continuous tensor while keeping track of the image it is from.
    _pred_boxes = torch.cat(pred_boxes, dim=0)  # (n_detections, 4)
    _pred_labels = torch.cat(pred_labels, dim=0)  # (n_detections)
    _pred_scores = torch.cat(pred_scores, dim=0)  # (n_detections)
    _pred_images = _image_index(pred_labels, _pred_boxes.device)  # (n_detections)
    if not _pred_images.size(0) == _pred_boxes.size(0) == _pred_labels.size(0) == _pred_scores.size(0):
        raise AssertionError("Predicted boxes, labels and scores must describe the same number of detections.")

    # Recall thresholds are the same for every class, so build them once.
    recall_thresholds = torch.arange(start=0, end=1.1, step=0.1).tolist()  # (11)

    # Calculate APs for each class (except background).
    average_precisions = torch.zeros(
        (n_classes - 1), device=_pred_boxes.device, dtype=_pred_boxes.dtype
    )  # (n_classes - 1)
    for c in range(1, n_classes):
        # Extract only objects with this class.
        gt_class_mask = _gt_labels == c
        gt_class_images = _gt_images[gt_class_mask]  # (n_class_objects)
        gt_class_boxes = _gt_boxes[gt_class_mask]  # (n_class_objects, 4)
        n_class_objects = gt_class_boxes.size(0)

        # Extract only detections with this class.
        pred_class_mask = _pred_labels == c
        pred_class_images = _pred_images[pred_class_mask]  # (n_class_detections)
        pred_class_boxes = _pred_boxes[pred_class_mask]  # (n_class_detections, 4)
        pred_class_scores = _pred_scores[pred_class_mask]  # (n_class_detections)
        n_class_detections = pred_class_boxes.size(0)

        # Without detections there is nothing to rank; without ground truth every detection is a
        # false positive and recall is undefined. Either way the AP of this class stays 0.
        if n_class_detections == 0 or n_class_objects == 0:
            continue

        # Keep track of which true objects with this class have already been 'detected'.
        gt_class_boxes_detected = torch.zeros(
            n_class_objects, dtype=torch.bool, device=gt_class_images.device
        )  # (n_class_objects)
        # Positions of the objects inside the class-level tensors; used to map an image-level
        # match back to its entry in 'gt_class_boxes_detected'.
        class_object_ids = torch.arange(n_class_objects, device=gt_class_images.device, dtype=torch.long)

        # Sort detections in decreasing order of confidence/scores.
        _, sort_ind = torch.sort(pred_class_scores, dim=0, descending=True)  # (n_class_detections)
        pred_class_images = pred_class_images[sort_ind]  # (n_class_detections)
        pred_class_boxes = pred_class_boxes[sort_ind]  # (n_class_detections, 4)

        # In the order of decreasing scores, check if true or false positive.
        gt_positives = torch.zeros(
            (n_class_detections,), dtype=pred_class_boxes.dtype, device=pred_class_boxes.device
        )  # (n_class_detections)
        false_positives = torch.zeros(
            (n_class_detections,), dtype=pred_class_boxes.dtype, device=pred_class_boxes.device
        )  # (n_class_detections)
        for d in range(n_class_detections):
            this_detection_box = pred_class_boxes[d].unsqueeze(0)  # (1, 4)

            # Find objects of this class that live in the same image as the detection.
            in_this_image = gt_class_images == pred_class_images[d]  # (n_class_objects)
            object_boxes = gt_class_boxes[in_this_image]  # (n_class_objects_in_img, 4)

            # If no such object in this image, then the detection is a false positive.
            if object_boxes.size(0) == 0:
                false_positives[d] = 1
                continue

            # Find maximum overlap of this detection with objects in this image of this class.
            # NOTE(review): mean_iou_bbox is expected to return a (1, n_class_objects_in_img)
            # overlap matrix here - confirm against its implementation.
            overlaps = mean_iou_bbox(this_detection_box, object_boxes)
            max_overlap, ind = torch.max(overlaps.squeeze(0), dim=0)  # (), () - scalars

            # 'ind' indexes the image-level 'object_boxes'; translate it to the class-level index.
            original_ind = class_object_ids[in_this_image][ind]

            # A match requires enough overlap with an object that has not been claimed by a
            # higher-scoring detection; anything else is a false positive.
            if max_overlap.item() > threshold and not gt_class_boxes_detected[original_ind]:
                gt_positives[d] = 1
                gt_class_boxes_detected[original_ind] = True  # this object is now accounted for
            else:
                false_positives[d] = 1

        # Compute cumulative precision and recall at each detection in the order of decreasing scores.
        cumul_gt_positives = torch.cumsum(gt_positives, dim=0)  # (n_class_detections)
        cumul_false_positives = torch.cumsum(false_positives, dim=0)  # (n_class_detections)
        cumul_precision = cumul_gt_positives / (
            cumul_gt_positives + cumul_false_positives + 1e-10
        )  # (n_class_detections)
        # Recall is relative to the number of ground truth objects of THIS class, not of all classes.
        cumul_recall = cumul_gt_positives / n_class_objects  # (n_class_detections)

        # Mean of the maximum precisions corresponding to recalls above each threshold.
        average_precisions[c - 1] = _interpolated_precision(
            cumul_precision, cumul_recall, recall_thresholds, _gt_boxes.device, _gt_boxes.dtype
        )  # c is in [1, n_classes - 1]

    # Calculate Mean Average Precision (mAP).
    mean_ap = average_precisions.mean()

    # Keep class-wise average precisions in a dictionary keyed by class index.
    ap_dict = {c + 1: v for c, v in enumerate(average_precisions.tolist())}

    return mean_ap, ap_dict