Source code for kornia.contrib.models.rt_detr.post_processor

from __future__ import annotations

from typing import Any

from kornia.core import Module, Tensor, concatenate, tensor


[docs]class DETRPostProcessor(Module): def __init__(self, confidence_threshold: float) -> None: super().__init__() self.confidence_threshold = confidence_threshold
[docs] def forward(self, data: dict[str, Tensor], meta: dict[str, Any]) -> list[Tensor]: """Post-process outputs from DETR. Args: data: dictionary with keys ``logits`` and ``boxes``. ``logits`` has shape :math:`(N, Q, K)` and ``boxes`` has shape :math:`(N, Q, 4)`, where :math:`Q` is the number of queries, :math:`K` is the number of classes. meta: dictionary containing meta information. It must have key ``original_size``, which is the original image size of input images. Each tuple represent (img_height, img_width). Returns: Processed detections. For each image, the detections have shape (D, 6), where D is the number of detections in that image, 6 represent (class_id, confidence_score, x, y, w, h). """ logits = data["logits"] boxes = data["boxes"] original_sizes = meta["original_size"] # NOTE: boxes are not clipped to image dimensions # https://github.com/PaddlePaddle/PaddleDetection/blob/5d1f888362241790000950e2b63115dc8d1c6019/ppdet/modeling/post_process.py#L446 # box format is cxcywh # convert to xywh # bboxes[..., :2] -= bboxes[..., 2:] * 0.5 # in-place operation is not torch.compile()-friendly cxcy = boxes[..., :2] wh = boxes[..., 2:] boxes = concatenate([cxcy - wh * 0.5, wh], -1) boxes = boxes * tensor(original_sizes, device=boxes.device, dtype=boxes.dtype).flip(1).repeat(1, 2).unsqueeze(1) scores = logits.sigmoid() # RT-DETR was trained with focal loss. thus sigmoid is used instead of softmax # the original code is slightly different # it allows 1 bounding box to have multiple classes (multi-label) scores, labels = scores.max(-1) detections = [] for i in range(scores.shape[0]): mask = scores[i] >= self.confidence_threshold labels_i = labels[i, mask].unsqueeze(-1) scores_i = scores[i, mask].unsqueeze(-1) boxes_i = boxes[i, mask] detections.append(concatenate([labels_i, scores_i, boxes_i], -1)) return detections