Source code for kornia.contrib.models.rt_detr.model

"""Based on code from PaddleDetection."""

from __future__ import annotations

from dataclasses import dataclass
from enum import Enum
from typing import Optional

from kornia.contrib.models.base import ModelBase
from kornia.contrib.models.rt_detr.architecture.hgnetv2 import PPHGNetV2
from kornia.contrib.models.rt_detr.architecture.hybrid_encoder import HybridEncoder
from kornia.contrib.models.rt_detr.architecture.resnet_d import ResNetD
from kornia.contrib.models.rt_detr.architecture.rtdetr_head import RTDETRHead
from kornia.core import Tensor


[docs]class RTDETRModelType(Enum): """Enum class that maps RT-DETR model type.""" resnet18d = 0 resnet34d = 1 resnet50d = 2 resnet101d = 3 hgnetv2_l = 4 hgnetv2_x = 5
[docs]@dataclass class RTDETRConfig: """Configuration to construct RT-DETR model. Args: model_type: model variant. Available models are - ResNetD-18: ``0``, ``'resnet18d'`` or :attr:`RTDETRModelType.resnet18d` - ResNetD-34: ``1``, ``'resnet34d'`` or :attr:`RTDETRModelType.resnet34d` - ResNetD-50: ``2``, ``'resnet50d'`` or :attr:`RTDETRModelType.resnet50d` - ResNetD-101: ``3``, ``'resnet101d'`` or :attr:`RTDETRModelType.resnet101d` - HGNetV2-L: ``4``, ``'hgnetv2_l'`` or :attr:`RTDETRModelType.hgnetv2_l` - HGNetV2-X: ``5``, ``'hgnetv2_x'`` or :attr:`RTDETRModelType.hgnetv2_x` num_classes: number of classes. checkpoint: URL or local path of model weights. neck_hidden_dim: hidden dim for neck. neck_dim_feedforward: feed-forward network dim for neck. neck_expansion: expansion ratio for neck. head_hidden_dim: hidden dim for head. head_num_queries: number of queries for Deformable DETR transformer decoder. head_num_decoder_layers: number of decoder layers for Deformable DETR transformer decoder. """ model_type: RTDETRModelType | str | int num_classes: int checkpoint: Optional[str] = None neck_hidden_dim: Optional[int] = None neck_dim_feedforward: Optional[int] = None neck_expansion: Optional[float] = None head_hidden_dim: int = 256 head_num_queries: int = 300 head_num_decoder_layers: Optional[int] = None confidence_threshold: float = 0.3
[docs]class RTDETR(ModelBase[RTDETRConfig]): """RT-DETR Object Detection model, as described in https://arxiv.org/abs/2304.08069."""
[docs] def __init__(self, backbone: ResNetD | PPHGNetV2, neck: HybridEncoder, head: RTDETRHead): """Construct RT-DETR Object Detection model. Args: backbone: backbone network for feature extraction. neck: neck network for feature fusion. head: head network to decode features into detection results. """ super().__init__() self.backbone = backbone self.neck = neck self.head = head
[docs] @staticmethod def from_config(config: RTDETRConfig) -> RTDETR: """Construct RT-DETR Object Detection model from a config object. Args: config: configuration object for RT-DETR. .. note:: For ``config.neck_hidden_dim``, ``config.neck_dim_feedforward``, ``config.neck_expansion``, and ``config.head_num_decoder_layers``, if they are ``None``, their values will be replaced with the default values depending on the ``config.model_type``. See the source code for the default values. """ model_type = config.model_type if isinstance(model_type, int): model_type = RTDETRModelType(model_type) elif isinstance(model_type, str): model_type = getattr(RTDETRModelType, model_type) backbone: ResNetD | PPHGNetV2 if model_type == RTDETRModelType.resnet18d: backbone = ResNetD.from_config(18) neck_hidden_dim = config.neck_hidden_dim or 256 neck_dim_feedforward = config.neck_dim_feedforward or 1024 head_num_decoder_layers = config.head_num_decoder_layers or 3 neck_expansion = config.neck_expansion or 0.5 elif model_type == RTDETRModelType.resnet34d: backbone = ResNetD.from_config(34) neck_hidden_dim = config.neck_hidden_dim or 256 neck_dim_feedforward = config.neck_dim_feedforward or 1024 head_num_decoder_layers = config.head_num_decoder_layers or 4 neck_expansion = config.neck_expansion or 0.5 elif model_type == RTDETRModelType.resnet50d: backbone = ResNetD.from_config(50) neck_hidden_dim = config.neck_hidden_dim or 256 neck_dim_feedforward = config.neck_dim_feedforward or 1024 head_num_decoder_layers = config.head_num_decoder_layers or 6 neck_expansion = config.neck_expansion or 1.0 elif model_type == RTDETRModelType.resnet101d: backbone = ResNetD.from_config(101) neck_hidden_dim = config.neck_hidden_dim or 384 neck_dim_feedforward = config.neck_dim_feedforward or 2048 head_num_decoder_layers = config.head_num_decoder_layers or 6 neck_expansion = config.neck_expansion or 1.0 elif model_type == RTDETRModelType.hgnetv2_l: backbone = PPHGNetV2.from_config("L") neck_hidden_dim = config.neck_hidden_dim or 256 neck_dim_feedforward = config.neck_dim_feedforward or 1024 head_num_decoder_layers = config.head_num_decoder_layers or 6 neck_expansion = config.neck_expansion or 1.0 elif model_type == RTDETRModelType.hgnetv2_x: backbone = PPHGNetV2.from_config("X") neck_hidden_dim = config.neck_hidden_dim or 384 neck_dim_feedforward = config.neck_dim_feedforward or 2048 head_num_decoder_layers = config.head_num_decoder_layers or 6 neck_expansion = config.neck_expansion or 1.0 model = RTDETR( backbone, HybridEncoder(backbone.out_channels, neck_hidden_dim, neck_dim_feedforward, neck_expansion), RTDETRHead( num_classes=config.num_classes, hidden_dim=config.head_hidden_dim, num_queries=config.head_num_queries, in_channels=[neck_hidden_dim] * 3, num_decoder_layers=head_num_decoder_layers, ), ) if config.checkpoint: model.load_checkpoint(config.checkpoint) return model
[docs] def forward(self, images: Tensor) -> tuple[Tensor, Tensor]: """Detect objects in an image. Args: images: images to be detected. Shape :math:`(N, C, H, W)`. Returns: - **logits** - Tensor of shape :math:`(N, Q, K)`, where :math:`Q` is the number of queries, :math:`K` is the number of classes. - **boxes** - Tensor of shape :math:`(N, Q, 4)`, where :math:`Q` is the number of queries. """ if self.training: raise RuntimeError("Only evaluation mode is supported. Please call model.eval().") feats = self.backbone(images) feats_buf = self.neck(feats) logits, boxes = self.head(feats_buf) return logits, boxes