# Source code for kornia.contrib.models.efficient_vit.backbone

# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023
from __future__ import annotations

from typing import Any

import torch
from torch import nn

from kornia.contrib.models.efficient_vit.nn.ops import (  # type: ignore
    ConvLayer,
    DSConv,
    EfficientViTBlock,
    FusedMBConv,
    IdentityLayer,
    MBConv,
    OpSequential,
    ResBlock,
    ResidualBlock,
)
from kornia.contrib.models.efficient_vit.utils import build_kwargs_from_config


class EfficientViTBackbone(nn.Module):
    """EfficientViT backbone used by the B-series models.

    The network is an input stem followed by one stage per remaining ``(width, depth)`` pair:

    * input stem: a stride-2 ``ConvLayer`` to ``width_list[0]`` channels, followed by
      ``depth_list[0]`` residual local blocks (built with ``expand_ratio=1``, i.e. ``DSConv``);
    * stages built from ``width_list[1:3]`` / ``depth_list[1:3]``: ``d`` local blocks each,
      where only the first block of a stage uses stride 2;
    * stages built from ``width_list[3:]`` / ``depth_list[3:]``: one stride-2 local block
      followed by ``d`` ``EfficientViTBlock`` blocks.

    Args:
        width_list: output channels of the input stem and of every stage.
        depth_list: number of blocks of the input stem and of every stage.
        in_channels: number of channels of the input tensor.
        dim: value forwarded as ``dim`` to every ``EfficientViTBlock``.
        expand_ratio: expansion ratio used by the stage blocks.
        norm: normalization name forwarded to the building blocks.
        act_func: activation name forwarded to the building blocks.
    """

    def __init__(
        self,
        width_list: list[int],
        depth_list: list[int],
        in_channels: int = 3,
        dim: int = 32,
        expand_ratio: float = 4,
        norm: str = "bn2d",
        act_func: str = "hswish",
    ) -> None:
        super().__init__()
        # Output channel count of the input stem and of each stage, in forward order.
        self.width_list: list[int] = []

        # input stem
        input_stem: list[nn.Module] = [
            ConvLayer(in_channels=in_channels, out_channels=width_list[0], stride=2, norm=norm, act_func=act_func)
        ]
        for _ in range(depth_list[0]):
            block = self.build_local_block(
                in_channels=width_list[0],
                out_channels=width_list[0],
                stride=1,
                expand_ratio=1,
                norm=norm,
                act_func=act_func,
            )
            input_stem.append(ResidualBlock(block, IdentityLayer()))
        in_channels = width_list[0]
        self.input_stem = OpSequential(input_stem)
        self.width_list.append(in_channels)

        # convolution-only stages
        stages: list[nn.Module] = []
        for w, d in zip(width_list[1:3], depth_list[1:3]):
            stage: list[nn.Module] = []
            for i in range(d):
                stride = 2 if i == 0 else 1
                block = self.build_local_block(
                    in_channels=in_channels,
                    out_channels=w,
                    stride=stride,
                    expand_ratio=expand_ratio,
                    norm=norm,
                    act_func=act_func,
                )
                # Only non-strided blocks get an identity shortcut.
                block = ResidualBlock(block, IdentityLayer() if stride == 1 else None)
                stage.append(block)
                in_channels = w
            stages.append(OpSequential(stage))
            self.width_list.append(in_channels)

        # stages with a strided local block followed by EfficientViT blocks
        for w, d in zip(width_list[3:], depth_list[3:]):
            stage = []
            block = self.build_local_block(
                in_channels=in_channels,
                out_channels=w,
                stride=2,
                expand_ratio=expand_ratio,
                norm=norm,
                act_func=act_func,
                fewer_norm=True,
            )
            stage.append(ResidualBlock(block, None))
            in_channels = w
            for _ in range(d):
                stage.append(
                    EfficientViTBlock(
                        in_channels=in_channels, dim=dim, expand_ratio=expand_ratio, norm=norm, act_func=act_func
                    )
                )
            stages.append(OpSequential(stage))
            self.width_list.append(in_channels)
        self.stages = nn.ModuleList(stages)

    @staticmethod
    def build_local_block(
        in_channels: int,
        out_channels: int,
        stride: int,
        expand_ratio: float,
        norm: str,
        act_func: str,
        fewer_norm: bool = False,
    ) -> nn.Module:
        """Build a convolutional local block.

        Returns a ``DSConv`` when ``expand_ratio == 1`` and an ``MBConv`` otherwise. With
        ``fewer_norm=True`` only the last layer of the block gets ``norm`` (and no bias); the
        preceding layers get a bias and no normalization. The last layer never has an activation.
        """
        if expand_ratio == 1:
            block = DSConv(
                in_channels=in_channels,
                out_channels=out_channels,
                stride=stride,
                use_bias=(True, False) if fewer_norm else False,
                norm=(None, norm) if fewer_norm else norm,
                act_func=(act_func, None),
            )
        else:
            block = MBConv(
                in_channels=in_channels,
                out_channels=out_channels,
                stride=stride,
                expand_ratio=expand_ratio,
                use_bias=(True, True, False) if fewer_norm else False,
                norm=(None, None, norm) if fewer_norm else norm,
                act_func=(act_func, act_func, None),
            )
        return block

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        """Run the backbone and collect intermediate features.

        Returns:
            A dict with keys ``"input"`` (the input itself), ``"stage0"`` (input stem output),
            ``"stage1"`` … ``"stageN"`` (one per stage) and ``"stage_final"`` (the last computed
            output, i.e. the last stage's output, or the stem output if there are no stages).
        """
        output_dict = {"input": x}
        output_dict["stage0"] = x = self.input_stem(x)
        for stage_id, stage in enumerate(self.stages, 1):
            output_dict[f"stage{stage_id}"] = x = stage(x)
        output_dict["stage_final"] = x
        return output_dict
def efficientvit_backbone_b0(**kwargs: Any) -> EfficientViTBackbone:
    """Build the EfficientViT-B0 backbone.

    Args:
        kwargs: extra keyword arguments, passed through ``build_kwargs_from_config`` against
            :class:`EfficientViTBackbone` (presumably keeping only the keys its constructor
            accepts — verify). Any key that survives overrides the B0 preset below instead of
            raising a duplicate-keyword ``TypeError``.
    """
    config: dict[str, Any] = {
        "width_list": [8, 16, 32, 64, 128],
        "depth_list": [1, 2, 2, 2, 2],
        "dim": 16,
    }
    # Merge rather than splat next to the preset, so user kwargs cannot collide with it.
    config.update(build_kwargs_from_config(kwargs, EfficientViTBackbone))
    return EfficientViTBackbone(**config)
def efficientvit_backbone_b1(**kwargs: Any) -> EfficientViTBackbone:
    """Build the EfficientViT-B1 backbone.

    Args:
        kwargs: extra keyword arguments, passed through ``build_kwargs_from_config`` against
            :class:`EfficientViTBackbone` (presumably keeping only the keys its constructor
            accepts — verify). Any key that survives overrides the B1 preset below instead of
            raising a duplicate-keyword ``TypeError``.
    """
    config: dict[str, Any] = {
        "width_list": [16, 32, 64, 128, 256],
        "depth_list": [1, 2, 3, 3, 4],
        "dim": 16,
    }
    # Merge rather than splat next to the preset, so user kwargs cannot collide with it.
    config.update(build_kwargs_from_config(kwargs, EfficientViTBackbone))
    return EfficientViTBackbone(**config)
def efficientvit_backbone_b2(**kwargs: Any) -> EfficientViTBackbone:
    """Build the EfficientViT-B2 backbone.

    Args:
        kwargs: extra keyword arguments, passed through ``build_kwargs_from_config`` against
            :class:`EfficientViTBackbone` (presumably keeping only the keys its constructor
            accepts — verify). Any key that survives overrides the B2 preset below instead of
            raising a duplicate-keyword ``TypeError``.
    """
    config: dict[str, Any] = {
        "width_list": [24, 48, 96, 192, 384],
        "depth_list": [1, 3, 4, 4, 6],
        "dim": 32,
    }
    # Merge rather than splat next to the preset, so user kwargs cannot collide with it.
    config.update(build_kwargs_from_config(kwargs, EfficientViTBackbone))
    return EfficientViTBackbone(**config)
def efficientvit_backbone_b3(**kwargs: Any) -> EfficientViTBackbone:
    """Build the EfficientViT-B3 backbone.

    Args:
        kwargs: extra keyword arguments, passed through ``build_kwargs_from_config`` against
            :class:`EfficientViTBackbone` (presumably keeping only the keys its constructor
            accepts — verify). Any key that survives overrides the B3 preset below instead of
            raising a duplicate-keyword ``TypeError``.
    """
    config: dict[str, Any] = {
        "width_list": [32, 64, 128, 256, 512],
        "depth_list": [1, 4, 6, 6, 9],
        "dim": 32,
    }
    # Merge rather than splat next to the preset, so user kwargs cannot collide with it.
    config.update(build_kwargs_from_config(kwargs, EfficientViTBackbone))
    return EfficientViTBackbone(**config)
class EfficientViTLargeBackbone(nn.Module):
    """EfficientViT backbone used by the L-series models.

    All stages, including the first, live in ``self.stages``:

    * stage 0: a stride-2 ``ConvLayer`` to ``width_list[0]`` channels, followed by
      ``depth_list[0]`` residual ``ResBlock`` blocks (built with ``expand_ratio=1``);
    * stages 1-3 (``width_list[1:4]`` / ``depth_list[1:4]``): ``d + 1`` local blocks each. The
      first block uses stride 2 and ``expand_ratio=16``, the rest stride 1 and ``expand_ratio=4``.
      Stages 1-2 use ``FusedMBConv``; stage 3 uses ``MBConv`` with ``fewer_norm``;
    * stages 4+ (``width_list[4:]`` / ``depth_list[4:]``): one stride-2 ``MBConv``
      (``expand_ratio=24``) followed by ``d`` ``EfficientViTBlock`` blocks (``expand_ratio=6``).

    Args:
        width_list: output channels of every stage.
        depth_list: depth parameter of every stage (see above for how it is used per stage).
        in_channels: number of channels of the input tensor.
        qkv_dim: value forwarded as ``dim`` to every ``EfficientViTBlock``.
        norm: normalization name forwarded to the building blocks.
        act_func: activation name forwarded to the building blocks.
    """

    def __init__(
        self,
        width_list: list[int],
        depth_list: list[int],
        in_channels: int = 3,
        qkv_dim: int = 32,
        norm: str = "bn2d",
        act_func: str = "gelu",
    ) -> None:
        super().__init__()
        # Output channel count of each stage, in forward order.
        self.width_list: list[int] = []
        stages: list[nn.Module] = []

        # stage 0
        stage0: list[nn.Module] = [
            ConvLayer(in_channels=in_channels, out_channels=width_list[0], stride=2, norm=norm, act_func=act_func)
        ]
        for _ in range(depth_list[0]):
            block = self.build_local_block(
                stage_id=0,
                in_channels=width_list[0],
                out_channels=width_list[0],
                stride=1,
                expand_ratio=1,
                norm=norm,
                act_func=act_func,
            )
            stage0.append(ResidualBlock(block, IdentityLayer()))
        in_channels = width_list[0]
        stages.append(OpSequential(stage0))
        self.width_list.append(in_channels)

        # convolution-only stages 1-3
        for stage_id, (w, d) in enumerate(zip(width_list[1:4], depth_list[1:4]), start=1):
            stage: list[nn.Module] = []
            # d + 1 blocks: one strided block plus d non-strided blocks.
            for i in range(d + 1):
                stride = 2 if i == 0 else 1
                block = self.build_local_block(
                    stage_id=stage_id,
                    in_channels=in_channels,
                    out_channels=w,
                    stride=stride,
                    expand_ratio=4 if stride == 1 else 16,
                    norm=norm,
                    act_func=act_func,
                    fewer_norm=stage_id > 2,
                )
                # Only non-strided blocks get an identity shortcut.
                block = ResidualBlock(block, IdentityLayer() if stride == 1 else None)
                stage.append(block)
                in_channels = w
            stages.append(OpSequential(stage))
            self.width_list.append(in_channels)

        # stages with a strided local block followed by EfficientViT blocks
        for stage_id, (w, d) in enumerate(zip(width_list[4:], depth_list[4:]), start=4):
            stage = []
            block = self.build_local_block(
                stage_id=stage_id,
                in_channels=in_channels,
                out_channels=w,
                stride=2,
                expand_ratio=24,
                norm=norm,
                act_func=act_func,
                fewer_norm=True,
            )
            stage.append(ResidualBlock(block, None))
            in_channels = w
            for _ in range(d):
                stage.append(
                    EfficientViTBlock(
                        in_channels=in_channels, dim=qkv_dim, expand_ratio=6, norm=norm, act_func=act_func
                    )
                )
            stages.append(OpSequential(stage))
            self.width_list.append(in_channels)
        self.stages = nn.ModuleList(stages)

    @staticmethod
    def build_local_block(
        stage_id: int,
        in_channels: int,
        out_channels: int,
        stride: int,
        expand_ratio: float,
        norm: str,
        act_func: str,
        fewer_norm: bool = False,
    ) -> nn.Module:
        """Build a convolutional local block.

        Returns a ``ResBlock`` when ``expand_ratio == 1``, otherwise a ``FusedMBConv`` for
        ``stage_id <= 2`` and an ``MBConv`` for later stages. With ``fewer_norm=True`` only the
        last layer of the block gets ``norm`` (and no bias); the preceding layers get a bias and
        no normalization. The last layer never has an activation.
        """
        if expand_ratio == 1:
            block = ResBlock(
                in_channels=in_channels,
                out_channels=out_channels,
                stride=stride,
                use_bias=(True, False) if fewer_norm else False,
                norm=(None, norm) if fewer_norm else norm,
                act_func=(act_func, None),
            )
        elif stage_id <= 2:
            block = FusedMBConv(
                in_channels=in_channels,
                out_channels=out_channels,
                stride=stride,
                expand_ratio=expand_ratio,
                use_bias=(True, False) if fewer_norm else False,
                norm=(None, norm) if fewer_norm else norm,
                act_func=(act_func, None),
            )
        else:
            block = MBConv(
                in_channels=in_channels,
                out_channels=out_channels,
                stride=stride,
                expand_ratio=expand_ratio,
                use_bias=(True, True, False) if fewer_norm else False,
                norm=(None, None, norm) if fewer_norm else norm,
                act_func=(act_func, act_func, None),
            )
        return block

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        """Run the backbone and collect intermediate features.

        Returns:
            A dict with keys ``"input"`` (the input itself), ``"stage0"`` … ``"stageN"`` (one per
            entry of ``self.stages``) and ``"stage_final"`` (the last stage's output).
        """
        output_dict = {"input": x}
        for stage_id, stage in enumerate(self.stages):
            output_dict[f"stage{stage_id}"] = x = stage(x)
        output_dict["stage_final"] = x
        return output_dict
def efficientvit_backbone_l0(**kwargs: Any) -> EfficientViTLargeBackbone:
    """Build the EfficientViT-L0 backbone.

    Args:
        kwargs: extra keyword arguments, passed through ``build_kwargs_from_config`` against
            :class:`EfficientViTLargeBackbone` (presumably keeping only the keys its constructor
            accepts — verify). Any key that survives overrides the L0 preset below instead of
            raising a duplicate-keyword ``TypeError``.
    """
    config: dict[str, Any] = {
        "width_list": [32, 64, 128, 256, 512],
        "depth_list": [1, 1, 1, 4, 4],
    }
    # Merge rather than splat next to the preset, so user kwargs cannot collide with it.
    config.update(build_kwargs_from_config(kwargs, EfficientViTLargeBackbone))
    return EfficientViTLargeBackbone(**config)
def efficientvit_backbone_l1(**kwargs: Any) -> EfficientViTLargeBackbone:
    """Build the EfficientViT-L1 backbone.

    Args:
        kwargs: extra keyword arguments, passed through ``build_kwargs_from_config`` against
            :class:`EfficientViTLargeBackbone` (presumably keeping only the keys its constructor
            accepts — verify). Any key that survives overrides the L1 preset below instead of
            raising a duplicate-keyword ``TypeError``.
    """
    config: dict[str, Any] = {
        "width_list": [32, 64, 128, 256, 512],
        "depth_list": [1, 1, 1, 6, 6],
    }
    # Merge rather than splat next to the preset, so user kwargs cannot collide with it.
    config.update(build_kwargs_from_config(kwargs, EfficientViTLargeBackbone))
    return EfficientViTLargeBackbone(**config)
def efficientvit_backbone_l2(**kwargs: Any) -> EfficientViTLargeBackbone:
    """Build the EfficientViT-L2 backbone.

    Args:
        kwargs: extra keyword arguments, passed through ``build_kwargs_from_config`` against
            :class:`EfficientViTLargeBackbone` (presumably keeping only the keys its constructor
            accepts — verify). Any key that survives overrides the L2 preset below instead of
            raising a duplicate-keyword ``TypeError``.
    """
    config: dict[str, Any] = {
        "width_list": [32, 64, 128, 256, 512],
        "depth_list": [1, 2, 2, 8, 8],
    }
    # Merge rather than splat next to the preset, so user kwargs cannot collide with it.
    config.update(build_kwargs_from_config(kwargs, EfficientViTLargeBackbone))
    return EfficientViTLargeBackbone(**config)
def efficientvit_backbone_l3(**kwargs: Any) -> EfficientViTLargeBackbone:
    """Build the EfficientViT-L3 backbone.

    Args:
        kwargs: extra keyword arguments, passed through ``build_kwargs_from_config`` against
            :class:`EfficientViTLargeBackbone` (presumably keeping only the keys its constructor
            accepts — verify). Any key that survives overrides the L3 preset below instead of
            raising a duplicate-keyword ``TypeError``.
    """
    config: dict[str, Any] = {
        "width_list": [64, 128, 256, 512, 1024],
        "depth_list": [1, 2, 2, 8, 8],
    }
    # Merge rather than splat next to the preset, so user kwargs cannot collide with it.
    config.update(build_kwargs_from_config(kwargs, EfficientViTLargeBackbone))
    return EfficientViTLargeBackbone(**config)