"""Based on the original code from Meta Platforms, Inc. and affiliates.

https://github.com/facebookresearch/segment-anything/blob/3518c86b78b3bc9cf4fbe3d18e682fad1c79dc51/segment_anything/build_sam.py
https://github.com/facebookresearch/segment-anything/blob/3518c86b78b3bc9cf4fbe3d18e682fad1c79dc51/segment_anything/modeling/sam.py
"""
from __future__ import annotations
import warnings
from dataclasses import dataclass
from enum import Enum
from typing import Any, Optional
import torch
from kornia.contrib.models import SegmentationResults
from kornia.contrib.models.base import ModelBase
from kornia.contrib.models.sam.architecture.common import LayerNorm
from kornia.contrib.models.sam.architecture.image_encoder import ImageEncoderViT
from kornia.contrib.models.sam.architecture.mask_decoder import MaskDecoder
from kornia.contrib.models.sam.architecture.prompt_encoder import PromptEncoder
from kornia.contrib.models.sam.architecture.transformer import TwoWayTransformer
from kornia.contrib.models.tiny_vit import TinyViT
from kornia.core import Tensor
from kornia.core.check import KORNIA_CHECK, KORNIA_CHECK_SHAPE
class SamModelType(Enum):
    """Enumerate the SAM model variants that can be built by name.

    The integer values are part of the public contract: a config may refer to a variant by its integer value,
    by its member name as a string, or by the enum member itself.
    """

    # ViT-based image encoders.
    vit_h = 0
    vit_l = 1
    vit_b = 2
    # TinyViT-based image encoder.
    mobile_sam = 3
@dataclass
class SamConfig:
    """Encapsulate the Config to build a SAM model.

    Args:
        model_type: the available models are:

            - 0, 'vit_h' or :func:`kornia.contrib.sam.SamModelType.vit_h`
            - 1, 'vit_l' or :func:`kornia.contrib.sam.SamModelType.vit_l`
            - 2, 'vit_b' or :func:`kornia.contrib.sam.SamModelType.vit_b`
            - 3, 'mobile_sam', or :func:`kornia.contrib.sam.SamModelType.mobile_sam`

        checkpoint: URL or a path for a file with the weights of the model
        pretrained: If True and no `checkpoint` is given, the default weights for `model_type` are loaded.
            Ignored (with a warning) when `checkpoint` is provided.
        encoder_embed_dim: Patch embedding dimension.
        encoder_depth: Depth of ViT.
        encoder_num_heads: Number of attention heads in each ViT block.
        encoder_global_attn_indexes: Encoder indexes for blocks using global attention.

    The `encoder_*` fields are only consulted when `model_type` does not select one of the known variants.
    """

    model_type: Optional[str | int | SamModelType] = None
    checkpoint: Optional[str] = None
    pretrained: bool = False

    encoder_embed_dim: Optional[int] = None
    encoder_depth: Optional[int] = None
    encoder_num_heads: Optional[int] = None
    encoder_global_attn_indexes: Optional[tuple[int, ...]] = None
class Sam(ModelBase[SamConfig]):
    """Segment Anything Model (SAM): predicts object masks from an image and input prompts."""

    # Threshold handed to every `SegmentationResults` produced by `forward`.
    mask_threshold: float = 0.0

    def __init__(
        self, image_encoder: ImageEncoderViT | TinyViT, prompt_encoder: PromptEncoder, mask_decoder: MaskDecoder
    ) -> None:
        """SAM predicts object masks from an image and input prompts.

        Args:
            image_encoder: The backbone used to encode the image into image embeddings that allow for efficient mask
                           prediction.
            prompt_encoder: Encodes various types of input prompts.
            mask_decoder: Predicts masks from the image embeddings and encoded prompts.
        """
        super().__init__()
        self.image_encoder = image_encoder
        self.prompt_encoder = prompt_encoder
        self.mask_decoder = mask_decoder

    @staticmethod
    def from_config(config: SamConfig) -> Sam:
        """Build/load the SAM model based on its config.

        Args:
            config: The SamConfig data structure. If the model_type is available, build from it, otherwise will use
                    the encoder parameters set (`encoder_embed_dim`, `encoder_depth`, `encoder_num_heads` and
                    `encoder_global_attn_indexes`).

        Returns:
            The respective SAM model

        Raises:
            KeyError: if `model_type` is a string that does not name a known model.
            ValueError: if `model_type` is an integer that is not a valid `SamModelType` value, or if
                `pretrained=True` is requested for a custom encoder config without an explicit `checkpoint`.
            NotImplementedError: if neither a known `model_type` nor a complete encoder config is provided.

        Example:
            >>> from kornia.contrib.models.sam import SamConfig
            >>> sam_model = Sam.from_config(SamConfig('vit_b'))
        """
        # Normalise the int / str spellings of the model type to the enum.
        model_type = config.model_type
        if isinstance(model_type, int):
            model_type = SamModelType(model_type)
        elif isinstance(model_type, str):
            _map_sam_type = {
                "vit_h": SamModelType.vit_h,
                "vit_l": SamModelType.vit_l,
                "vit_b": SamModelType.vit_b,
                "mobile_sam": SamModelType.mobile_sam,
            }
            model_type = _map_sam_type[model_type]

        if model_type == SamModelType.vit_b:
            model = _build_sam(
                encoder_embed_dim=768, encoder_depth=12, encoder_num_heads=12, encoder_global_attn_indexes=(2, 5, 8, 11)
            )

        elif model_type == SamModelType.vit_l:
            model = _build_sam(
                encoder_embed_dim=1024,
                encoder_depth=24,
                encoder_num_heads=16,
                encoder_global_attn_indexes=(5, 11, 17, 23),
            )

        elif model_type == SamModelType.vit_h:
            model = _build_sam(
                encoder_embed_dim=1280,
                encoder_depth=32,
                encoder_num_heads=16,
                encoder_global_attn_indexes=(7, 15, 23, 31),
            )

        elif model_type == SamModelType.mobile_sam:
            # TODO: merge this with _build_sam()
            prompt_embed_dim = 256
            image_size = 1024
            vit_patch_size = 16
            image_embedding_size = image_size // vit_patch_size

            model = Sam(
                image_encoder=TinyViT.from_config("5m", img_size=image_size, mobile_sam=True),
                prompt_encoder=PromptEncoder(
                    embed_dim=prompt_embed_dim,
                    image_embedding_size=(image_embedding_size, image_embedding_size),
                    input_image_size=(image_size, image_size),
                    mask_in_chans=16,
                ),
                mask_decoder=MaskDecoder(
                    num_multimask_outputs=3,
                    transformer=TwoWayTransformer(depth=2, embedding_dim=prompt_embed_dim, mlp_dim=2048, num_heads=8),
                    transformer_dim=prompt_embed_dim,
                    iou_head_depth=3,
                    iou_head_hidden_dim=256,
                ),
                # pixel_mean=[123.675, 116.28, 103.53],
                # pixel_std=[58.395, 57.12, 57.375],
            )

        elif (
            isinstance(config.encoder_embed_dim, int)
            and isinstance(config.encoder_depth, int)
            and isinstance(config.encoder_num_heads, int)
            # The field is declared as `tuple[int, ...]`; checking for `int` here made this branch unreachable.
            and isinstance(config.encoder_global_attn_indexes, (tuple, list))
        ):
            model = _build_sam(
                encoder_embed_dim=config.encoder_embed_dim,
                encoder_depth=config.encoder_depth,
                # `SamConfig` has no `num_heads` attribute; the field is `encoder_num_heads`.
                encoder_num_heads=config.encoder_num_heads,
                encoder_global_attn_indexes=tuple(config.encoder_global_attn_indexes),
            )

        else:
            raise NotImplementedError("Unexpected config. The model_type should be provide or the encoder configs.")

        checkpoint = config.checkpoint

        if config.pretrained:
            if checkpoint is None:
                # Default weights exist only for the named variants; a custom encoder has no published checkpoint.
                if not isinstance(model_type, SamModelType):
                    raise ValueError(
                        "pretrained=True requires a known model_type; provide a checkpoint for custom encoder configs."
                    )
                checkpoint = {
                    SamModelType.vit_b: "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
                    SamModelType.vit_l: "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
                    SamModelType.vit_h: "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
                    SamModelType.mobile_sam: "https://github.com/ChaoningZhang/MobileSAM/raw/a509aac54fdd7af59f843135f2f7cee307283c88/weights/mobile_sam.pt",
                }[model_type]
            else:
                warnings.warn("checkpoint is not None. pretrained=True is ignored")

        if checkpoint:
            model.load_checkpoint(checkpoint)

        return model

    @torch.no_grad()
    def forward(
        self, images: Tensor, batched_prompts: list[dict[str, Any]], multimask_output: bool
    ) -> list[SegmentationResults]:
        """Predicts masks end-to-end from provided images and prompts.

        This method expects that the images have already been pre-processed, at least been normalized, resized and
        padded to be compatible with the `self.image_encoder`.

        .. note:: For each image :math:`(3, H, W)`, it is possible to input a batch (:math:`K`) of :math:`N` prompts,
                 the results are batched by the number of prompts batch. So given a prompt with :math:`K=5`, and
                 :math:`N=10`, the results will look like :math:`5xCxHxW` where :math:`C` is determined by
                 multimask_output. And within each of these masks :math:`(5xC)`, it should be possible to find
                 :math:`N` instances if the model succeeds.

        Args:
            images: The image as a torch tensor in :math:`(B, 3, H, W)` format, already transformed for input to the
                    model.
            batched_prompts: A list over the batch of images (list length should be :math:`B`), each a dictionary with
                             the following keys. If it does not have the respective prompt, it should not be included
                             in this dictionary. The options are:

                - "points": tuple of (Tensor, Tensor) with the coordinate keypoints and their respective labels.
                            the tuple should look like (keypoints, labels), where:

                            - The keypoints (a tensor) are a batched point prompts for this image, with shape
                              :math:`(K, N, 2)`. Already transformed to the input frame of the model.
                            - The labels (a tensor) are a batched labels for point prompts, with shape :math:`(K, N)`.
                              Where 1 indicates a foreground point and 0 indicates a background point.

                - "boxes": (Tensor) Batched box inputs, with shape :math:`(K, 4)`. Already transformed to the input
                           frame of the model.
                - "mask_inputs": (Tensor) Batched mask inputs to the model, in the form :math:`(K, 1, H, W)`.

            multimask_output: Whether the model should predict multiple disambiguating masks, or return a single mask.

        Returns:
            A list over input images, where each element is a SegmentationResults with the following.
                - logits: Low resolution logits with shape :math:`(K, C, H, W)`. Can be passed as mask input to
                          subsequent iterations of prediction. Where :math:`K` is the number of input prompts,
                          :math:`C` is determined by multimask_output, and :math:`H=W=256` are the model output size.
                - scores: The model's predictions of mask quality (iou prediction), in shape BxC.
        """
        KORNIA_CHECK_SHAPE(images, ["B", "3", "H", "W"])
        KORNIA_CHECK(
            images.shape[0] == len(batched_prompts),
            "The number of images (`B`) should match with the length of prompts!",
        )

        # The whole batch is encoded once; prompts are then decoded image by image.
        image_embeddings = self.image_encoder(images)

        outputs = []
        for prompt_record, curr_embedding in zip(batched_prompts, image_embeddings):
            # Embed prompts
            sparse_embeddings, dense_embeddings = self.prompt_encoder(
                points=prompt_record.get("points", None),
                boxes=prompt_record.get("boxes", None),
                masks=prompt_record.get("mask_inputs", None),
            )

            # Predict masks
            low_res_logits, iou_predictions = self.mask_decoder(
                # Re-add the leading batch dimension dropped by iterating over the embeddings.
                image_embeddings=curr_embedding[None, ...],
                image_pe=self.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
            )

            # Save results
            outputs.append(SegmentationResults(low_res_logits, iou_predictions, self.mask_threshold))

        return outputs
def _build_sam(
    encoder_embed_dim: int, encoder_depth: int, encoder_num_heads: int, encoder_global_attn_indexes: tuple[int, ...]
) -> Sam:
    """Assemble a SAM model around a ViT image encoder.

    Only the image encoder is parameterised; the prompt encoder and mask decoder use fixed settings.

    Args:
        encoder_embed_dim: Patch embedding dimension of the ViT encoder.
        encoder_depth: Depth of the ViT encoder.
        encoder_num_heads: Number of attention heads in each ViT block.
        encoder_global_attn_indexes: Encoder indexes for blocks using global attention.

    Returns:
        A `Sam` instance built from the three freshly constructed components.
    """
    embed_dim = 256
    input_side = 1024
    patch = 16
    # Side length of the embedding grid produced by patchifying the input image.
    grid_side = input_side // patch

    # Components are created in the same order as they appear in the `Sam` constructor call.
    encoder = ImageEncoderViT(
        depth=encoder_depth,
        embed_dim=encoder_embed_dim,
        img_size=input_side,
        mlp_ratio=4,
        norm_layer=LayerNorm,
        num_heads=encoder_num_heads,
        patch_size=patch,
        qkv_bias=True,
        use_rel_pos=True,
        global_attn_indexes=encoder_global_attn_indexes,
        window_size=14,
        # The encoder output channels match the prompt embedding dimension.
        out_chans=embed_dim,
    )

    prompts = PromptEncoder(
        embed_dim=embed_dim,
        image_embedding_size=(grid_side, grid_side),
        input_image_size=(input_side, input_side),
        mask_in_chans=16,
    )

    two_way = TwoWayTransformer(depth=2, embedding_dim=embed_dim, mlp_dim=2048, num_heads=8)
    decoder = MaskDecoder(
        num_multimask_outputs=3,
        transformer=two_way,
        transformer_dim=embed_dim,
        iou_head_depth=3,
        iou_head_hidden_dim=256,
    )

    # NOTE: pixel_mean=[123.675, 116.28, 103.53] and pixel_std=[58.395, 57.12, 57.375] are not passed to `Sam`.
    return Sam(image_encoder=encoder, prompt_encoder=prompts, mask_decoder=decoder)