# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023
from __future__ import annotations
from typing import Any
import torch
from torch import nn
from kornia.contrib.models.efficient_vit.nn.ops import ( # type: ignore
ConvLayer,
DSConv,
EfficientViTBlock,
FusedMBConv,
IdentityLayer,
MBConv,
OpSequential,
ResBlock,
ResidualBlock,
)
from kornia.contrib.models.efficient_vit.utils import build_kwargs_from_config
class EfficientViTBackbone(nn.Module):
    """EfficientViT backbone that returns the output of every stage.

    The network is built from three parts, in this order:

    * an input stem: a stride-2 ``ConvLayer`` followed by ``depth_list[0]`` residual local blocks;
    * two stages (indices 1 and 2) made only of local blocks, the first block of each using stride 2;
    * the remaining stages (index 3 onwards), each made of one stride-2 local block followed by
      ``d`` ``EfficientViTBlock`` modules.

    Args:
        width_list: output channels of each stage; index 0 is the input stem.
        depth_list: number of repeated blocks of each stage, aligned with ``width_list``.
        in_channels: number of channels of the input tensor.
        dim: value forwarded as ``dim`` to every ``EfficientViTBlock``.
        expand_ratio: expansion ratio forwarded to the local blocks after the stem and to every
            ``EfficientViTBlock``. The stem blocks always use ``1``.
        norm: normalization identifier forwarded to the layers.
        act_func: activation identifier forwarded to the layers.

    Note:
        ``width_list`` and ``depth_list`` are paired with :func:`zip`, so surplus entries in the
        longer list are silently ignored.
    """

    def __init__(
        self,
        width_list: list[int],
        depth_list: list[int],
        in_channels: int = 3,
        dim: int = 32,
        expand_ratio: float = 4,
        norm: str = "bn2d",
        act_func: str = "hswish",
    ) -> None:
        super().__init__()
        # Filled while building: output channels of the stem and of every stage, in order.
        self.width_list: list[int] = []

        # input stem
        input_stem: list[nn.Module] = [
            ConvLayer(in_channels=in_channels, out_channels=width_list[0], stride=2, norm=norm, act_func=act_func)
        ]
        for _ in range(depth_list[0]):
            block = self.build_local_block(
                in_channels=width_list[0],
                out_channels=width_list[0],
                stride=1,
                expand_ratio=1,
                norm=norm,
                act_func=act_func,
            )
            input_stem.append(ResidualBlock(block, IdentityLayer()))
        # From here on ``in_channels`` tracks the channels produced by the previously built block.
        in_channels = width_list[0]
        self.input_stem = OpSequential(input_stem)
        self.width_list.append(in_channels)

        # stages
        stages: list[nn.Module] = []

        # Stages 1-2: local blocks only.
        for w, d in zip(width_list[1:3], depth_list[1:3]):
            stage: list[nn.Module] = []
            for i in range(d):
                # Only the first block of a stage is strided.
                stride = 2 if i == 0 else 1
                block = self.build_local_block(
                    in_channels=in_channels,
                    out_channels=w,
                    stride=stride,
                    expand_ratio=expand_ratio,
                    norm=norm,
                    act_func=act_func,
                )
                # The strided block is wrapped without a shortcut branch.
                block = ResidualBlock(block, IdentityLayer() if stride == 1 else None)
                stage.append(block)
                in_channels = w
            stages.append(OpSequential(stage))
            self.width_list.append(in_channels)

        # Stages 3+: one strided local block, then ``d`` EfficientViT blocks.
        for w, d in zip(width_list[3:], depth_list[3:]):
            stage = []
            block = self.build_local_block(
                in_channels=in_channels,
                out_channels=w,
                stride=2,
                expand_ratio=expand_ratio,
                norm=norm,
                act_func=act_func,
                fewer_norm=True,
            )
            stage.append(ResidualBlock(block, None))
            in_channels = w

            for _ in range(d):
                stage.append(
                    EfficientViTBlock(
                        in_channels=in_channels, dim=dim, expand_ratio=expand_ratio, norm=norm, act_func=act_func
                    )
                )
            stages.append(OpSequential(stage))
            self.width_list.append(in_channels)
        self.stages = nn.ModuleList(stages)

    @staticmethod
    def build_local_block(
        in_channels: int,
        out_channels: int,
        stride: int,
        expand_ratio: float,
        norm: str,
        act_func: str,
        fewer_norm: bool = False,
    ) -> nn.Module:
        """Create one local convolutional block.

        Args:
            in_channels: input channels of the block.
            out_channels: output channels of the block.
            stride: stride forwarded to the block.
            expand_ratio: ``1`` selects ``DSConv``; any other value selects ``MBConv`` and is forwarded to it.
            norm: normalization identifier.
            act_func: activation identifier; it is passed for every tuple position except the last,
                which receives ``None``.
            fewer_norm: when ``True``, ``norm`` is passed only in the last tuple position and ``use_bias``
                is enabled for the preceding positions. When ``False``, ``norm`` is passed as a single
                value and ``use_bias`` is ``False``.

        Returns:
            The constructed ``DSConv`` or ``MBConv`` module.
        """
        if expand_ratio == 1:
            block = DSConv(
                in_channels=in_channels,
                out_channels=out_channels,
                stride=stride,
                use_bias=(True, False) if fewer_norm else False,
                norm=(None, norm) if fewer_norm else norm,
                act_func=(act_func, None),
            )
        else:
            block = MBConv(
                in_channels=in_channels,
                out_channels=out_channels,
                stride=stride,
                expand_ratio=expand_ratio,
                use_bias=(True, True, False) if fewer_norm else False,
                norm=(None, None, norm) if fewer_norm else norm,
                act_func=(act_func, act_func, None),
            )
        return block

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        """Run the backbone and collect every intermediate output.

        Args:
            x: input tensor.

        Returns:
            A dictionary with the keys ``"input"`` (the unmodified input), ``"stage0"`` (output of the
            input stem), ``"stage1"`` ... ``"stageN"`` (output of each entry of ``self.stages``) and
            ``"stage_final"`` (same tensor as the last stage output).
        """
        output_dict = {"input": x}
        x = self.input_stem(x)
        output_dict["stage0"] = x
        # Stage numbering starts at 1 because the stem already occupies "stage0".
        for stage_id, stage in enumerate(self.stages, 1):
            x = stage(x)
            output_dict[f"stage{stage_id}"] = x
        output_dict["stage_final"] = x
        return output_dict
def efficientvit_backbone_b0(**kwargs: Any) -> EfficientViTBackbone:
    """Build an :class:`EfficientViTBackbone` with the ``b0`` configuration.

    The configuration is ``width_list=[8, 16, 32, 64, 128]``, ``depth_list=[1, 2, 2, 2, 2]`` and ``dim=16``.

    Args:
        **kwargs: extra options. They are passed through
            ``build_kwargs_from_config(kwargs, EfficientViTBackbone)`` and the result is forwarded
            to the constructor.

    Returns:
        The constructed backbone.
    """
    # NOTE(review): if the helper lets ``width_list``, ``depth_list`` or ``dim`` through, this call
    # raises ``TypeError`` for a duplicated keyword argument -- confirm against the helper.
    return EfficientViTBackbone(
        width_list=[8, 16, 32, 64, 128],
        depth_list=[1, 2, 2, 2, 2],
        dim=16,
        **build_kwargs_from_config(kwargs, EfficientViTBackbone),
    )
def efficientvit_backbone_b1(**kwargs: Any) -> EfficientViTBackbone:
    """Build an :class:`EfficientViTBackbone` with the ``b1`` configuration.

    The configuration is ``width_list=[16, 32, 64, 128, 256]``, ``depth_list=[1, 2, 3, 3, 4]`` and ``dim=16``.

    Args:
        **kwargs: extra options. They are passed through
            ``build_kwargs_from_config(kwargs, EfficientViTBackbone)`` and the result is forwarded
            to the constructor.

    Returns:
        The constructed backbone.
    """
    # NOTE(review): if the helper lets ``width_list``, ``depth_list`` or ``dim`` through, this call
    # raises ``TypeError`` for a duplicated keyword argument -- confirm against the helper.
    return EfficientViTBackbone(
        width_list=[16, 32, 64, 128, 256],
        depth_list=[1, 2, 3, 3, 4],
        dim=16,
        **build_kwargs_from_config(kwargs, EfficientViTBackbone),
    )
def efficientvit_backbone_b2(**kwargs: Any) -> EfficientViTBackbone:
    """Build an :class:`EfficientViTBackbone` with the ``b2`` configuration.

    The configuration is ``width_list=[24, 48, 96, 192, 384]``, ``depth_list=[1, 3, 4, 4, 6]`` and ``dim=32``.

    Args:
        **kwargs: extra options. They are passed through
            ``build_kwargs_from_config(kwargs, EfficientViTBackbone)`` and the result is forwarded
            to the constructor.

    Returns:
        The constructed backbone.
    """
    # NOTE(review): if the helper lets ``width_list``, ``depth_list`` or ``dim`` through, this call
    # raises ``TypeError`` for a duplicated keyword argument -- confirm against the helper.
    return EfficientViTBackbone(
        width_list=[24, 48, 96, 192, 384],
        depth_list=[1, 3, 4, 4, 6],
        dim=32,
        **build_kwargs_from_config(kwargs, EfficientViTBackbone),
    )
def efficientvit_backbone_b3(**kwargs: Any) -> EfficientViTBackbone:
    """Build an :class:`EfficientViTBackbone` with the ``b3`` configuration.

    The configuration is ``width_list=[32, 64, 128, 256, 512]``, ``depth_list=[1, 4, 6, 6, 9]`` and ``dim=32``.

    Args:
        **kwargs: extra options. They are passed through
            ``build_kwargs_from_config(kwargs, EfficientViTBackbone)`` and the result is forwarded
            to the constructor.

    Returns:
        The constructed backbone.
    """
    # NOTE(review): if the helper lets ``width_list``, ``depth_list`` or ``dim`` through, this call
    # raises ``TypeError`` for a duplicated keyword argument -- confirm against the helper.
    return EfficientViTBackbone(
        width_list=[32, 64, 128, 256, 512],
        depth_list=[1, 4, 6, 6, 9],
        dim=32,
        **build_kwargs_from_config(kwargs, EfficientViTBackbone),
    )
class EfficientViTLargeBackbone(nn.Module):
    """Large EfficientViT backbone that returns the output of every stage.

    All parts, including the first one, are stored in ``self.stages``:

    * stage 0: a stride-2 ``ConvLayer`` followed by ``depth_list[0]`` residual ``ResBlock`` modules;
    * stages 1-3: ``d + 1`` local blocks each. The first uses stride 2 with expansion ratio 16 and
      no shortcut; the remaining ``d`` use stride 1 with expansion ratio 4 and an identity shortcut.
      Stages 1 and 2 use ``FusedMBConv``; stage 3 uses ``MBConv`` with ``fewer_norm=True``;
    * stages 4+: one stride-2 ``MBConv`` (expansion ratio 24, ``fewer_norm=True``) followed by
      ``d`` ``EfficientViTBlock`` modules with expansion ratio 6.

    Args:
        width_list: output channels of each stage.
        depth_list: number of repeated blocks of each stage, aligned with ``width_list``.
        in_channels: number of channels of the input tensor.
        qkv_dim: value forwarded as ``dim`` to every ``EfficientViTBlock``.
        norm: normalization identifier forwarded to the layers.
        act_func: activation identifier forwarded to the layers.

    Note:
        ``width_list`` and ``depth_list`` are paired with :func:`zip`, so surplus entries in the
        longer list are silently ignored.
    """

    def __init__(
        self,
        width_list: list[int],
        depth_list: list[int],
        in_channels: int = 3,
        qkv_dim: int = 32,
        norm: str = "bn2d",
        act_func: str = "gelu",
    ) -> None:
        super().__init__()
        # Filled while building: output channels of every stage, in order.
        self.width_list: list[int] = []
        stages: list[nn.Module] = []

        # stage 0
        stage0: list[nn.Module] = [
            ConvLayer(in_channels=in_channels, out_channels=width_list[0], stride=2, norm=norm, act_func=act_func)
        ]
        for _ in range(depth_list[0]):
            block = self.build_local_block(
                stage_id=0,
                in_channels=width_list[0],
                out_channels=width_list[0],
                stride=1,
                expand_ratio=1,
                norm=norm,
                act_func=act_func,
            )
            stage0.append(ResidualBlock(block, IdentityLayer()))
        # From here on ``in_channels`` tracks the channels produced by the previously built block.
        in_channels = width_list[0]
        stages.append(OpSequential(stage0))
        self.width_list.append(in_channels)

        # Stages 1-3: local blocks only.
        for stage_id, (w, d) in enumerate(zip(width_list[1:4], depth_list[1:4]), start=1):
            stage: list[nn.Module] = []
            # ``d + 1`` blocks: the strided block plus ``d`` stride-1 blocks.
            for i in range(d + 1):
                stride = 2 if i == 0 else 1
                block = self.build_local_block(
                    stage_id=stage_id,
                    in_channels=in_channels,
                    out_channels=w,
                    stride=stride,
                    expand_ratio=4 if stride == 1 else 16,
                    norm=norm,
                    act_func=act_func,
                    fewer_norm=stage_id > 2,
                )
                # The strided block is wrapped without a shortcut branch.
                block = ResidualBlock(block, IdentityLayer() if stride == 1 else None)
                stage.append(block)
                in_channels = w
            stages.append(OpSequential(stage))
            self.width_list.append(in_channels)

        # Stages 4+: one strided local block, then ``d`` EfficientViT blocks.
        for stage_id, (w, d) in enumerate(zip(width_list[4:], depth_list[4:]), start=4):
            stage = []
            block = self.build_local_block(
                stage_id=stage_id,
                in_channels=in_channels,
                out_channels=w,
                stride=2,
                expand_ratio=24,
                norm=norm,
                act_func=act_func,
                fewer_norm=True,
            )
            stage.append(ResidualBlock(block, None))
            in_channels = w

            for _ in range(d):
                stage.append(
                    EfficientViTBlock(
                        in_channels=in_channels, dim=qkv_dim, expand_ratio=6, norm=norm, act_func=act_func
                    )
                )
            stages.append(OpSequential(stage))
            self.width_list.append(in_channels)
        self.stages = nn.ModuleList(stages)

    @staticmethod
    def build_local_block(
        stage_id: int,
        in_channels: int,
        out_channels: int,
        stride: int,
        expand_ratio: float,
        norm: str,
        act_func: str,
        fewer_norm: bool = False,
    ) -> nn.Module:
        """Create one local convolutional block.

        The block type is chosen as follows: ``expand_ratio == 1`` gives ``ResBlock``; otherwise
        ``stage_id <= 2`` gives ``FusedMBConv``; otherwise ``MBConv``.

        Args:
            stage_id: index of the stage the block belongs to; only used to pick the block type.
            in_channels: input channels of the block.
            out_channels: output channels of the block.
            stride: stride forwarded to the block.
            expand_ratio: expansion ratio; forwarded to ``FusedMBConv``/``MBConv`` (``ResBlock`` does not
                receive it).
            norm: normalization identifier.
            act_func: activation identifier; it is passed for every tuple position except the last,
                which receives ``None``.
            fewer_norm: when ``True``, ``norm`` is passed only in the last tuple position and ``use_bias``
                is enabled for the preceding positions. When ``False``, ``norm`` is passed as a single
                value and ``use_bias`` is ``False``.

        Returns:
            The constructed ``ResBlock``, ``FusedMBConv`` or ``MBConv`` module.
        """
        if expand_ratio == 1:
            block = ResBlock(
                in_channels=in_channels,
                out_channels=out_channels,
                stride=stride,
                use_bias=(True, False) if fewer_norm else False,
                norm=(None, norm) if fewer_norm else norm,
                act_func=(act_func, None),
            )
        elif stage_id <= 2:
            block = FusedMBConv(
                in_channels=in_channels,
                out_channels=out_channels,
                stride=stride,
                expand_ratio=expand_ratio,
                use_bias=(True, False) if fewer_norm else False,
                norm=(None, norm) if fewer_norm else norm,
                act_func=(act_func, None),
            )
        else:
            block = MBConv(
                in_channels=in_channels,
                out_channels=out_channels,
                stride=stride,
                expand_ratio=expand_ratio,
                use_bias=(True, True, False) if fewer_norm else False,
                norm=(None, None, norm) if fewer_norm else norm,
                act_func=(act_func, act_func, None),
            )
        return block

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        """Run the backbone and collect every intermediate output.

        Args:
            x: input tensor.

        Returns:
            A dictionary with the keys ``"input"`` (the unmodified input), ``"stage0"`` ... ``"stageN"``
            (output of each entry of ``self.stages``) and ``"stage_final"`` (same tensor as the last
            stage output).
        """
        output_dict = {"input": x}
        for stage_id, stage in enumerate(self.stages):
            x = stage(x)
            output_dict[f"stage{stage_id}"] = x
        output_dict["stage_final"] = x
        return output_dict
def efficientvit_backbone_l0(**kwargs: Any) -> EfficientViTLargeBackbone:
    """Build an :class:`EfficientViTLargeBackbone` with the ``l0`` configuration.

    The configuration is ``width_list=[32, 64, 128, 256, 512]`` and ``depth_list=[1, 1, 1, 4, 4]``.

    Args:
        **kwargs: extra options. They are passed through
            ``build_kwargs_from_config(kwargs, EfficientViTLargeBackbone)`` and the result is
            forwarded to the constructor.

    Returns:
        The constructed backbone.
    """
    # NOTE(review): if the helper lets ``width_list`` or ``depth_list`` through, this call raises
    # ``TypeError`` for a duplicated keyword argument -- confirm against the helper.
    return EfficientViTLargeBackbone(
        width_list=[32, 64, 128, 256, 512],
        depth_list=[1, 1, 1, 4, 4],
        **build_kwargs_from_config(kwargs, EfficientViTLargeBackbone),
    )
def efficientvit_backbone_l1(**kwargs: Any) -> EfficientViTLargeBackbone:
    """Build an :class:`EfficientViTLargeBackbone` with the ``l1`` configuration.

    The configuration is ``width_list=[32, 64, 128, 256, 512]`` and ``depth_list=[1, 1, 1, 6, 6]``.

    Args:
        **kwargs: extra options. They are passed through
            ``build_kwargs_from_config(kwargs, EfficientViTLargeBackbone)`` and the result is
            forwarded to the constructor.

    Returns:
        The constructed backbone.
    """
    # NOTE(review): if the helper lets ``width_list`` or ``depth_list`` through, this call raises
    # ``TypeError`` for a duplicated keyword argument -- confirm against the helper.
    return EfficientViTLargeBackbone(
        width_list=[32, 64, 128, 256, 512],
        depth_list=[1, 1, 1, 6, 6],
        **build_kwargs_from_config(kwargs, EfficientViTLargeBackbone),
    )
def efficientvit_backbone_l2(**kwargs: Any) -> EfficientViTLargeBackbone:
    """Build an :class:`EfficientViTLargeBackbone` with the ``l2`` configuration.

    The configuration is ``width_list=[32, 64, 128, 256, 512]`` and ``depth_list=[1, 2, 2, 8, 8]``.

    Args:
        **kwargs: extra options. They are passed through
            ``build_kwargs_from_config(kwargs, EfficientViTLargeBackbone)`` and the result is
            forwarded to the constructor.

    Returns:
        The constructed backbone.
    """
    # NOTE(review): if the helper lets ``width_list`` or ``depth_list`` through, this call raises
    # ``TypeError`` for a duplicated keyword argument -- confirm against the helper.
    return EfficientViTLargeBackbone(
        width_list=[32, 64, 128, 256, 512],
        depth_list=[1, 2, 2, 8, 8],
        **build_kwargs_from_config(kwargs, EfficientViTLargeBackbone),
    )
def efficientvit_backbone_l3(**kwargs: Any) -> EfficientViTLargeBackbone:
    """Build an :class:`EfficientViTLargeBackbone` with the ``l3`` configuration.

    The configuration is ``width_list=[64, 128, 256, 512, 1024]`` and ``depth_list=[1, 2, 2, 8, 8]``.

    Args:
        **kwargs: extra options. They are passed through
            ``build_kwargs_from_config(kwargs, EfficientViTLargeBackbone)`` and the result is
            forwarded to the constructor.

    Returns:
        The constructed backbone.
    """
    # NOTE(review): if the helper lets ``width_list`` or ``depth_list`` through, this call raises
    # ``TypeError`` for a duplicated keyword argument -- confirm against the helper.
    return EfficientViTLargeBackbone(
        width_list=[64, 128, 256, 512, 1024],
        depth_list=[1, 2, 2, 8, 8],
        **build_kwargs_from_config(kwargs, EfficientViTLargeBackbone),
    )