kornia.contrib¶
Models¶
Base¶
- class kornia.models.base.ModelBase(*args, **kwargs)[source]¶
Abstract model class with some utility functions.
- compile(*, fullgraph=False, dynamic=False, backend='inductor', mode=None, options=None, disable=False)[source]¶
Compile this model with torch.compile().
- Parameters:
fullgraph (bool, optional) – Whether Dynamo should require a single full graph. Default: False
dynamic (bool, optional) – Whether dynamic shape tracing is enabled. Default: False
backend (str, optional) – Compilation backend name passed to torch.compile(). Default: "inductor"
mode (Optional[str], optional) – Optional backend-specific compilation mode. Default: None
options (Optional[dict[Any,Any]], optional) – Optional backend-specific option dictionary. Default: None
disable (bool, optional) – If True, return an uncompiled model wrapper according to PyTorch’s compile semantics. Default: False
- Return type:
- Returns:
Compiled model object with the same high-level interface as this instance.
EfficientViT¶
- class kornia.models.efficient_vit.EfficientViT(backbone)[source]¶
EfficientViT backbone model.
- __init__(backbone)[source]¶
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(images)[source]¶
Extract features from the input images.
- Parameters:
images (Tensor) – input images tensor of shape \((B, C, H, W)\).
- Returns:
a dictionary containing the features.
- Return type:
Dict[str, torch.Tensor]
- static from_config(config)[source]¶
Build the EfficientViT model from a configuration object.
- Parameters:
config (EfficientViTConfig) – EfficientViT configuration object. See EfficientViTConfig.
- Returns:
the EfficientViT model.
- Return type:
- class kornia.models.efficient_vit.EfficientViTConfig(checkpoint=<factory>)[source]¶
Configuration to construct EfficientViT model.
Model weights can be loaded from a checkpoint URL or local path. The model weights are hosted on HuggingFace’s model hub: https://huggingface.co/kornia.
- Parameters:
checkpoint (str, optional) – URL or local path of model weights. Default: <factory>
Backbones¶
- class kornia.models.efficient_vit.backbone.EfficientViTBackbone(width_list, depth_list, in_channels=3, dim=32, expand_ratio=4, norm='bn2d', act_func='hswish')[source]¶
Implement the EfficientViT backbone architecture.
EfficientViT is a high-speed vision transformer designed for efficient inference on mobile and edge devices by optimizing the attention mechanism and structural blocks.
- Parameters:
depth_list (list[int]) – List of depths (number of blocks) for each stage.
in_channels (int, optional) – Number of input image channels. Default: 3.
dim (int, optional) – Dimension of the query, key, and value tensors in the attention mechanism. Default: 32.
expand_ratio (float, optional) – Expansion ratio for the MBConv blocks. Default: 4.
norm (str, optional) – Normalization layer type. Default: “bn2d”.
act_func (str, optional) – Activation function type. Default: “hswish”.
- static build_local_block(in_channels, out_channels, stride, expand_ratio, norm, act_func, fewer_norm=False)[source]¶
Build the local convolution block used between EfficientViT stages.
- Parameters:
in_channels (int) – Number of input feature channels.
out_channels (int) – Number of output feature channels.
stride (int) – Spatial stride for the block.
expand_ratio (float) – Expansion ratio used by MBConv-style blocks.
norm (str) – Normalization layer name.
act_func (str) – Activation function name.
fewer_norm (bool, optional) – If True, omit selected normalization layers. Default: False
- Return type:
- Returns:
Depthwise-separable or inverted-bottleneck convolution block.
- kornia.models.efficient_vit.backbone.efficientvit_backbone_b0(**kwargs)[source]¶
Create EfficientViT B0.
- Return type:
- kornia.models.efficient_vit.backbone.efficientvit_backbone_b1(**kwargs)[source]¶
Create EfficientViT B1.
- Return type:
- kornia.models.efficient_vit.backbone.efficientvit_backbone_b2(**kwargs)[source]¶
Create EfficientViT B2.
- Return type:
- kornia.models.efficient_vit.backbone.efficientvit_backbone_b3(**kwargs)[source]¶
Create EfficientViT B3.
- Return type:
- class kornia.models.efficient_vit.backbone.EfficientViTLargeBackbone(width_list, depth_list, in_channels=3, qkv_dim=32, norm='bn2d', act_func='gelu')[source]¶
Implement the large-scale variant of the EfficientViT backbone.
This backbone is designed for high-resolution dense prediction tasks. It utilizes multi-scale linear attention to achieve a global receptive field while maintaining linear computational complexity relative to the input resolution.
- Parameters:
width_list (list[int]) – List of channel widths for each stage of the backbone.
depth_list (list[int]) – List of number of blocks for each stage.
in_channels (int, optional) – Number of input image channels. Default: 3.
qkv_dim (int, optional) – The internal dimension for query, key, and value projections in the attention layers. Default: 32.
norm (str, optional) – Normalization layer type to use (e.g., “bn2d”, “ln”). Default: “bn2d”.
act_func (str, optional) – Activation function type to use (e.g., “gelu”, “relu”). Default: “gelu”.
- static build_local_block(stage_id, in_channels, out_channels, stride, expand_ratio, norm, act_func, fewer_norm=False)[source]¶
Build a local block for an EfficientViT large stage.
- Parameters:
stage_id (int) – Index of the stage being constructed.
in_channels (int) – Number of input feature channels.
out_channels (int) – Number of output feature channels.
stride (int) – Spatial stride for the block.
expand_ratio (float) – Expansion ratio controlling intermediate channels.
norm (str) – Normalization layer name.
act_func (str) – Activation function name.
fewer_norm (bool, optional) – If True, use the reduced-normalization variant. Default: False
- Return type:
- Returns:
Residual, fused-MBConv, or MBConv block chosen for the stage.
- kornia.models.efficient_vit.backbone.efficientvit_backbone_l0(**kwargs)[source]¶
Create EfficientViT L0.
- Return type:
- kornia.models.efficient_vit.backbone.efficientvit_backbone_l1(**kwargs)[source]¶
Create EfficientViT L1.
- Return type:
- kornia.models.efficient_vit.backbone.efficientvit_backbone_l2(**kwargs)[source]¶
Create EfficientViT L2.
- Return type:
Structures¶
- class kornia.models.structures.SegmentationResults(logits, scores, mask_threshold=0.0, _original_res_logits=None)[source]¶
Encapsulate the results obtained by a Segmentation model.
- Parameters:
- property binary_masks: Tensor¶
Binary mask generated from logits considering the mask_threshold.
The shape will be the same as that of the logits, \((B, C, H, W)\), where \(C\) is the number of masks predicted.
Note
If you run original_res_logits, this will generate the masks based on the original resolution logits. Otherwise, this will use the low resolution logits (self.logits).
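The thresholding described above can be pictured with a small pure-Python sketch. It is not kornia's implementation: `binarize` is a hypothetical helper, a 2-D list stands in for one \((H, W)\) slice of the logits, and the strict `>` comparison is an assumption.

```python
def binarize(logits, mask_threshold=0.0):
    """Return a boolean mask that is True where a logit exceeds the threshold.

    `logits` is a 2-D list standing in for one (H, W) slice of the
    (B, C, H, W) logits tensor. Hypothetical helper, for illustration only.
    """
    return [[value > mask_threshold for value in row] for row in logits]


# Low resolution logits for a single predicted mask (made-up values).
low_res_logits = [
    [-2.5, 0.3, 1.2],
    [-0.1, 0.0, 4.0],
]
mask = binarize(low_res_logits)
```

With the default threshold of 0.0, only strictly positive logits end up in the mask under this assumption.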
- original_res_logits(input_size, original_size, image_size_encoder)[source]¶
Remove padding and upscale the logits to the original image size.
Resize to image encoder input -> remove padding (bottom and right) -> Resize to original size
Note
This method sets an internal original_res_logits, which will be used, if available, for the binary masks.
- Parameters:
input_size (tuple[int,int]) – The size of the image input to the model, in (H, W) format. Used to remove padding.
original_size (tuple[int,int]) – The original size of the image before resizing for input to the model, in (H, W) format.
image_size_encoder (Optional[tuple[int,int]]) – The size of the input image for the image encoder, in (H, W) format. Used to resize the logits back to the encoder resolution before removing the padding.
- Return type:
- Returns:
Batched logits in \((K, C, H, W)\) format, where (H, W) is given by original_size.
- class kornia.models.structures.Prompts(points=None, boxes=None, masks=None)[source]¶
Encapsulate the prompts inputs for a Model.
- Parameters:
points (Optional[tuple[Tensor,Tensor]], optional) – A tuple with the keypoints (coordinates x, y) and their respective labels. Shape \((K, N, 2)\) for the keypoints, and \((K, N)\) for the labels. Default: None
boxes (Optional[Tensor], optional) – Batched box inputs, with shape \((K, 4)\). Expected to be in xyxy format. Default: None
masks (Optional[Tensor], optional) – Batched mask prompts to the model with shape \((K, 1, H, W)\). Default: None
VisualPrompter¶
- class kornia.contrib.visual_prompter.VisualPrompter(config=None, device=None, dtype=None)[source]¶
Allow the user to run multiple queries with multiple prompts for a model.
At the moment, only the SAM model is supported. The model is loaded based on the given config.
By default, the images are resized so that their long side matches image_encoder.img_size. This Prompter class transforms the images and the prompts before prediction. The image is also passed automatically to the preprocess_image method, which normalizes the image and pads it to the size expected by the SAM model \((\text{image_encoder.img_size}, \text{image_encoder.img_size})\). By default, the image is normalized with the mean and standard deviation of the SAM dataset.
- Parameters:
config (Optional[SamConfig], optional) – A model config to generate the model. Currently only the SAM model is supported. Default: None
device (Optional[device], optional) – The desired device to use the model. Default: None
dtype (Optional[dtype], optional) – The desired dtype to use the model. Default: None
Example
>>> # prompter = VisualPrompter() # Will load the vit h for default
>>> # You can load a custom SAM type for modifying the config
>>> prompter = VisualPrompter(SamConfig('vit_b'))
>>> image = torch.rand(3, 25, 30)
>>> prompter.set_image(image)
>>> boxes = Boxes(
...     torch.tensor(
...         [[[[0, 0], [0, 10], [10, 0], [10, 10]]]],
...         device=prompter.device,
...         dtype=torch.float32
...     ),
...     mode='xyxy'
... )
>>> prediction = prompter.predict(boxes=boxes)
>>> prediction.logits.shape
torch.Size([1, 3, 256, 256])
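The long-side resize and the bottom/right padding described for this class come down to simple size arithmetic. The sketch below only illustrates that arithmetic; it does not call kornia. The helper name `long_side_resize_and_pad`, the rounding rule, and the encoder size of 1024 are assumptions made for the example.

```python
def long_side_resize_and_pad(height, width, encoder_size):
    """Return the resized (H, W) and the (bottom, right) padding.

    Assumptions (not taken from kornia's source): the scale factor is
    encoder_size / max(height, width), and sizes are rounded to the
    nearest integer.
    """
    scale = encoder_size / max(height, width)
    new_h = int(height * scale + 0.5)
    new_w = int(width * scale + 0.5)
    # Padding is only added at the bottom and on the right.
    pad_bottom = encoder_size - new_h
    pad_right = encoder_size - new_w
    return (new_h, new_w), (pad_bottom, pad_right)


# A 25 x 30 image and a hypothetical encoder expecting 1024 x 1024 inputs.
resized, padding = long_side_resize_and_pad(25, 30, 1024)
```

Under these assumptions the long side (the width) reaches the encoder size exactly, so all of the padding goes to the bottom.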
- compile(*, fullgraph=False, dynamic=False, backend='inductor', mode=None, options=None, disable=False)[source]¶
Apply torch.compile(…)/dynamo API into the VisualPrompter API.
Note
For more information about the dynamo API check the official docs https://pytorch.org/docs/stable/generated/torch.compile.html
- Parameters:
fullgraph (bool, optional) – If False, the model may be broken into several subgraphs; if True, the whole model must be captured in a single graph. Default: False
dynamic (bool, optional) – Use dynamic shape tracing. Default: False
backend (str, optional) – Backend to be used. Default: "inductor"
mode (Optional[str], optional) – Can be either “default”, “reduce-overhead” or “max-autotune”. Default: None
options (Optional[dict[Any,Any]], optional) – A dictionary of options to pass to the backend. Default: None
disable (bool, optional) – Turn torch.compile() into a no-op for testing. Default: False
- Return type:
Example
>>> # prompter = VisualPrompter()
>>> # prompter.compile() # You should have torch >= 2.0.0 installed
>>> # Use the prompter methods ...
- predict(keypoints=None, keypoints_labels=None, boxes=None, masks=None, multimask_output=True, output_original_size=True)[source]¶
Predict masks for the given image based on the input prompts.
- Parameters:
keypoints (Union[Keypoints,Tensor,None], optional) – Point prompts to the model. Each point is in (X,Y) in pixels. Shape \((K, N, 2)\), where N is the number of points and K the number of prompts. Default: None
keypoints_labels (Optional[Tensor], optional) – Labels for the point prompts. 1 indicates a foreground point and 0 indicates a background point. Shape \((K, N)\), where N is the number of points and K the number of prompts. Default: None
boxes (Union[Boxes,Tensor,None], optional) – A box prompt to the model. If a torch.Tensor, it should be in xyxy mode. Shape \((K, 4)\). Default: None
masks (Optional[Tensor], optional) – A low resolution mask input to the model, typically coming from a previous prediction iteration. Has shape \((K, 1, H, W)\), where for SAM, H=W=256. Default: None
multimask_output (bool, optional) – If true, the model will return three masks. For ambiguous input prompts (such as a single click), this will often produce better masks than a single prediction. If only a single mask is needed, the model’s predicted quality score can be used to select the best mask. For non-ambiguous prompts, such as multiple input prompts, multimask_output=False can give better results. Default: True
output_original_size (bool, optional) – If true, the logits of SegmentationResults will be post-processed to match the original input image size. Default: True
- Return type:
- Returns:
A prediction with the logits and scores (IoU of each predicted mask)
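When multimask_output is enabled and only one mask per prompt is needed, the predicted quality scores can be used to pick it. The following is a minimal pure-Python sketch of that selection; `best_mask_indices` is a hypothetical helper and a nested list stands in for the \((K, C)\) score tensor.

```python
def best_mask_indices(scores):
    """For each prompt, return the index of the mask with the highest score.

    `scores` is a K x C nested list standing in for the (K, C) score tensor,
    where K is the number of prompts and C the number of masks per prompt.
    """
    return [max(range(len(row)), key=row.__getitem__) for row in scores]


# Two prompts, three candidate masks each (made-up quality scores).
scores = [
    [0.71, 0.93, 0.88],
    [0.65, 0.40, 0.52],
]
best = best_mask_indices(scores)
```

With real tensors the same selection would typically be an argmax over the mask dimension.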
- preprocess_image(x, mean=None, std=None)[source]¶
Normalize and pad a torch.Tensor.
To normalize the tensor: the mean and std passed as arguments take priority; if they are None, the default SAM dataset values are used.
To pad the tensor: the torch.Tensor is padded on the right and bottom to match the size of self.model.image_encoder.img_size.
- Parameters:
- Return type:
- Returns:
The preprocessed image (normalized if mean and std are available, and padded to the encoder size).
- preprocess_prompts(keypoints=None, keypoints_labels=None, boxes=None, masks=None)[source]¶
Validate and preprocess the given prompts to be aligned with the input image.
- Return type:
- reset_image()[source]¶
Clear cached image state and prompt-transform metadata.
This method invalidates previously computed image embeddings and resets all size/transform bookkeeping so a new call to set_image() starts from a clean state.
In practice, this resets:
- transformed-image parameters,
- original/input/encoder spatial sizes,
- cached image embeddings,
- is_image_set status flag.
- Return type:
- set_image(image, mean=None, std=None)[source]¶
Set the embeddings from the given image with the image_encoder of the model.
Prepare the given image with the selected transforms and the preprocess method.
- Parameters:
image (Tensor) – RGB image. Normally an image with values in the range [0, 1]; the preprocessing step normalizes the pixel values with the mean and std. Expected to be of float32 dtype. Shape \((3, H, W)\).
mean (Optional[Tensor], optional) – mean value of dataset for normalization. Default: None
std (Optional[Tensor], optional) – standard deviation of dataset for normalization. Default: None
- Return type:
Edge Detection¶
- class kornia.contrib.EdgeDetector(model, pre_processor, post_processor, name=None)[source]¶
EdgeDetector is a module that wraps an edge detection model.
This is a high-level API that wraps edge detection models like kornia.models.DexiNed.
- Parameters:
Example
>>> from kornia.models.dexined import DexiNed
>>> from kornia.models.processors import ResizePreProcessor, ResizePostProcessor
>>> model = DexiNed(pretrained=True)
>>> detector = EdgeDetector(model, ResizePreProcessor(352, 352), ResizePostProcessor())
>>> img = torch.rand(1, 3, 320, 320)
>>> out = detector(img)
Face Detection¶
- class kornia.contrib.FaceDetector(top_k=5000, confidence_threshold=0.3, nms_threshold=0.3, keep_top_k=750)[source]¶
Detect faces in a given image using YuNet model.
This is a high-level API that wraps the kornia.models.YuNet model for face detection. By default, it uses the method described in [FYP+21].
- Parameters:
top_k (int, optional) – the maximum number of detections to return before the nms. Default: 5000
confidence_threshold (float, optional) – the threshold used to discard detections. Default: 0.3
nms_threshold (float, optional) – the threshold used by the nms for iou. Default: 0.3
keep_top_k (int, optional) – the maximum number of detections to return after the nms. Default: 750
- Returns:
A list of B tensors with shape \((N,15)\) to be used with kornia.contrib.FaceDetectorResult.
Example
>>> img = torch.rand(1, 3, 320, 320)
>>> detect = FaceDetector()
>>> res = detect(img)
- class kornia.contrib.FaceKeypoint(value)[source]¶
Define the keypoints detected in a face.
The left/right convention is based on the screen viewer.
- EYE_LEFT = 0¶
- EYE_RIGHT = 1¶
- MOUTH_LEFT = 3¶
- MOUTH_RIGHT = 4¶
- NOSE = 2¶
- class kornia.contrib.FaceDetectorResult(data)[source]¶
Encapsulate the results obtained by the kornia.contrib.FaceDetector.
- Parameters:
data (Tensor) – the encoded results coming from the feature detector with shape \((14,)\).
- property bottom_right: Tensor¶
The [x y] position of the bottom-right coordinate of the bounding box.
- get_keypoint(keypoint)[source]¶
Get the [x y] position of a given facial keypoint.
- Parameters:
keypoint (FaceKeypoint) – the keypoint type to return the position.
- Return type:
Interactive Demo¶
Visit the Kornia face detection demo on the Hugging Face Spaces.
Object Detection¶
- class kornia.contrib.object_detection.BoundingBoxDataFormat(value)[source]¶
Enum class that maps bounding box data format.
- XYWH = 0¶
- XYXY = 1¶
- CXCYWH = 2¶
- CENTER_XYWH = 2¶
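The three formats differ only in how the same rectangle is encoded. The sketch below converts between them in plain Python, assuming the usual conventions (XYWH: top-left corner plus width and height; XYXY: top-left and bottom-right corners; CXCYWH: center plus width and height). These conventions and the helper names are assumptions for illustration, not kornia functions.

```python
def xywh_to_xyxy(box):
    """(x, y, w, h) with a top-left origin -> (x_min, y_min, x_max, y_max)."""
    x, y, w, h = box
    return (x, y, x + w, y + h)


def xywh_to_cxcywh(box):
    """(x, y, w, h) with a top-left origin -> (center_x, center_y, w, h)."""
    x, y, w, h = box
    return (x + w / 2, y + h / 2, w, h)


def cxcywh_to_xywh(box):
    """(center_x, center_y, w, h) -> (x, y, w, h) with a top-left origin."""
    cx, cy, w, h = box
    return (cx - w / 2, cy - h / 2, w, h)


box_xywh = (10.0, 20.0, 30.0, 40.0)
box_xyxy = xywh_to_xyxy(box_xywh)
box_cxcywh = xywh_to_cxcywh(box_xywh)
```

Converting to the center format and back returns the original box, which is a quick sanity check for any conversion code of this kind.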
- class kornia.contrib.object_detection.BoundingBox(data, data_format)[source]¶
Bounding box data class.
Useful for representing bounding boxes in different formats for object detection.
- Parameters:
- data_format: BoundingBoxDataFormat¶
- class kornia.contrib.object_detection.ObjectDetectorResult(class_id, confidence, bbox)[source]¶
Object detection result.
- Parameters:
class_id (int) – class id of the detected object.
confidence (float) – confidence score of the detected object.
bbox (BoundingBox) – bounding box of the detected object in xywh format.
- bbox: BoundingBox¶
- class kornia.contrib.object_detection.ObjectDetector(model, pre_processor, post_processor)[source]¶
Wrap an object detection model and perform pre-processing and post-processing.
- compile(*, fullgraph=False, dynamic=False, backend='inductor', mode=None, options=None, disable=False)[source]¶
Compile the internal object detection model with torch.compile().
- Return type:
- forward(images)[source]¶
Detect objects in a given list of images.
- Parameters:
images (Union[Tensor,list[Tensor]]) – Either a list of RGB images, each a torch.Tensor with shape \((3, H, W)\), or a single torch.Tensor with shape \((B, 3, H, W)\).
- Return type:
- Returns:
list of detections found in each image. For item in a batch, shape is \((D, 6)\), where \(D\) is the number of detections in the given image, \(6\) represents class id, score, and xywh bounding box.
- static from_config(config)[source]¶
Build ObjectDetector from config.
This is a placeholder to satisfy the abstract method requirement. Use kornia.contrib.object_detection.RTDETRDetectorBuilder.build() or instantiate ObjectDetector directly.
- Parameters:
config (Any) – Configuration object (not used, kept for interface compatibility).
- Return type:
- Returns:
ObjectDetector instance.
- to_onnx(onnx_name=None, image_size=640, include_pre_and_post_processor=True, save=True, additional_metadata=None, **kwargs)[source]¶
Export an RT-DETR object detection model to ONNX format.
Either model_name or config must be provided. If neither is provided, a default pretrained model (rtdetr_r18vd) will be built.
- Parameters:
onnx_name (Optional[str], optional) – The name of the output ONNX file. If not provided, a default name in the format “Kornia-<ClassName>.onnx” will be used. Default: None
image_size (Optional[int], optional) – The size to which input images will be resized during preprocessing. If None, image_size will be dynamic. For RTDETR, recommended scales include [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]. Default: 640
include_pre_and_post_processor (bool, optional) – Whether to include the pre-processor and post-processor in the exported model. Default: True
save (bool, optional) – Whether to save the model or load it. Default: True
additional_metadata (Optional[list[tuple[str,str]]], optional) – Additional metadata to add to the ONNX model. Default: None
kwargs (Any) – Additional arguments for the ONNX conversion.
- Return type:
ModelProto
- class kornia.contrib.object_detection.ResizePreProcessor(height, width, interpolation_mode='bilinear')[source]¶
Resize a list of image tensors to the given size.
Additionally, also returns the original image sizes for further post-processing.
- Parameters:
Example
>>> import torch
>>> from kornia.models.processors import ResizePreProcessor
>>> processor = ResizePreProcessor(height=224, width=224)
>>> imgs = torch.randn(2, 3, 480, 640)
>>> resized, sizes = processor(imgs)
>>> print(resized.shape, sizes.shape)
torch.Size([2, 3, 224, 224]) torch.Size([2, 2])
- forward(imgs)[source]¶
Resize input images to the target size.
- Parameters:
imgs (Union[Tensor,List[Tensor]]) – Input images, either a tensor of shape \((B, C, H, W)\) or a list of tensors of shape \((C, H, W)\).
- Returns:
Tuple containing:
resized_imgs: Resized images as a tensor of shape \((B, C, H_{\text{new}}, W_{\text{new}})\).
original_sizes: Original image sizes of shape \((B, 2)\) containing (height, width).
- kornia.contrib.object_detection.results_from_detections(detections, format)[source]¶
Convert a detection torch.Tensor to a list of ObjectDetectorResult.
- Parameters:
detections (Tensor) – torch.Tensor with shape \((D, 6)\), where \(D\) is the number of detections in the given image, \(6\) represents class id, score, and xywh bounding box.
format (str|BoundingBoxDataFormat) – detection format.
- Return type:
- Returns:
list of
ObjectDetectorResult.
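The \((D, 6)\) layout described above (class id, score, then four box values) can be unpacked row by row. The sketch below shows the idea with plain lists and a stand-in dataclass; `Detection` and `rows_to_detections` are hypothetical names and do not reproduce kornia's ObjectDetectorResult or BoundingBox classes.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Detection:
    """Stand-in result record (hypothetical, for illustration only)."""

    class_id: int
    confidence: float
    bbox_xywh: Tuple[float, float, float, float]


def rows_to_detections(rows: List[List[float]]) -> List[Detection]:
    """Convert D rows of [class_id, score, x, y, w, h] into Detection records."""
    results = []
    for row in rows:
        class_id, score, x, y, w, h = row
        results.append(Detection(int(class_id), float(score), (x, y, w, h)))
    return results


# Two detections for one image (made-up values).
rows = [
    [0.0, 0.92, 10.0, 20.0, 30.0, 40.0],
    [17.0, 0.35, 5.0, 5.0, 8.0, 12.0],
]
detections = rows_to_detections(rows)
```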
Real-Time Detection Transformer (RT-DETR)¶
- class kornia.models.rt_detr.RTDETRModelType(value)[source]¶
Enum class that maps RT-DETR model type.
- resnet18d = 0¶
- resnet34d = 1¶
- resnet50d = 2¶
- resnet101d = 3¶
- hgnetv2_l = 4¶
- hgnetv2_x = 5¶
- resnet50d_m = 6¶
- class kornia.models.rt_detr.RTDETRConfig(model_type, num_classes, input_size=640, checkpoint=None, neck_hidden_dim=None, neck_dim_feedforward=None, neck_expansion=None, head_hidden_dim=256, head_num_queries=300, head_num_decoder_layers=None, confidence_threshold=0.3)[source]¶
Configuration to construct RT-DETR model.
- Parameters:
model_type (RTDETRModelType|str|int) – model variant. Available models are:
- ResNetD-18: 0, 'resnet18d' or RTDETRModelType.resnet18d
- ResNetD-34: 1, 'resnet34d' or RTDETRModelType.resnet34d
- ResNetD-50: 2, 'resnet50d' or RTDETRModelType.resnet50d
- ResNetD-101: 3, 'resnet101d' or RTDETRModelType.resnet101d
- HGNetV2-L: 4, 'hgnetv2_l' or RTDETRModelType.hgnetv2_l
- HGNetV2-X: 5, 'hgnetv2_x' or RTDETRModelType.hgnetv2_x
num_classes (int) – number of classes.
checkpoint (Optional[str], optional) – URL or local path of model weights. Default: None
neck_hidden_dim (Optional[int], optional) – hidden dim for neck. Default: None
neck_dim_feedforward (Optional[int], optional) – feed-forward network dim for neck. Default: None
neck_expansion (Optional[float], optional) – expansion ratio for neck. Default: None
head_hidden_dim (int, optional) – hidden dim for head. Default: 256
head_num_queries (int, optional) – number of queries for Deformable DETR transformer decoder. Default: 300
head_num_decoder_layers (Optional[int], optional) – number of decoder layers for Deformable DETR transformer decoder. Default: None
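Because model_type accepts an integer, a string, or an enum member, some normalization step has to map all three onto one enum value. The sketch below shows one way this can be done with the standard library; the local `ModelType` enum merely mirrors the values listed for RTDETRModelType, and `parse_model_type` is a hypothetical helper, so kornia's own resolution logic may differ.

```python
from enum import Enum


class ModelType(Enum):
    """Local stand-in that mirrors the values listed for RTDETRModelType."""

    resnet18d = 0
    resnet34d = 1
    resnet50d = 2
    resnet101d = 3
    hgnetv2_l = 4
    hgnetv2_x = 5
    resnet50d_m = 6


def parse_model_type(value):
    """Accept an enum member, a member name, or an integer value."""
    if isinstance(value, ModelType):
        return value
    if isinstance(value, str):
        return ModelType[value]  # lookup by member name
    return ModelType(value)  # lookup by integer value


by_int = parse_model_type(0)
by_name = parse_model_type("resnet18d")
by_member = parse_model_type(ModelType.resnet18d)
```

All three spellings resolve to the same member, which is why the parameter list can document them as equivalent.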
- static from_name(model_name, num_classes=80)[source]¶
Load model without pretrained weights.
- Parameters:
- Return type:
- model_type: RTDETRModelType | str | int¶
- class kornia.models.rt_detr.RTDETR(backbone, encoder, decoder)[source]¶
RT-DETR Object Detection model, as described in https://arxiv.org/abs/2304.08069.
- __init__(backbone, encoder, decoder)[source]¶
Construct RT-DETR Object Detection model.
- Parameters:
backbone (ResNetD|PPHGNetV2) – backbone network for feature extraction.
encoder (HybridEncoder) – neck network for feature fusion.
decoder (RTDETRHead) – head network to decode features into detection results.
- forward(images)[source]¶
Detect objects in an image.
- Parameters:
images (Tensor) – images to be detected. Shape \((N, C, H, W)\).
- Return type:
- Returns:
logits - Tensor of shape \((N, Q, K)\), where \(Q\) is the number of queries, \(K\) is the number of classes.
boxes - Tensor of shape \((N, Q, 4)\), where \(Q\) is the number of queries.
- static from_config(config)[source]¶
Construct RT-DETR Object Detection model from a config object.
- Parameters:
config (RTDETRConfig) – configuration object for RT-DETR.
- Return type:
Note
If config.neck_hidden_dim, config.neck_dim_feedforward, config.neck_expansion, and config.head_num_decoder_layers are None, their values will be replaced with default values that depend on config.model_type. See the source code for the default values.
- class kornia.models.rt_detr.DETRPostProcessor(confidence_threshold=None, num_classes=80, num_top_queries=300, confidence_filtering=True, filter_as_zero=False)[source]¶
Convert raw DETR model outputs into final bounding box detections.
This module applies the softmax function to scores and transforms normalized bounding box coordinates into the pixel coordinate system of the input image.
- Parameters:
num_classes (int, optional) – The number of object classes. Default: 80
confidence_threshold (Optional[float], optional) – The threshold to filter out low-confidence detections. Default: None
num_top_queries (int, optional) – The number of top queries to consider for each image. Default: 300
confidence_filtering (bool, optional) – Whether to apply confidence-based filtering. Default: True
filter_as_zero (bool, optional) – If True, boxes below the confidence threshold are set to zero instead of being removed. Default: False
- forward(logits, boxes, original_sizes)[source]¶
Post-process outputs from DETR.
- Parameters:
logits (Tensor) – tensor with shape \((N, Q, K)\), where \(N\) is the batch size, \(Q\) is the number of queries, \(K\) is the number of classes.
boxes (Tensor) – tensor with shape \((N, Q, 4)\), where \(N\) is the batch size, \(Q\) is the number of queries.
original_sizes (Tensor) – tensor with shape \((N, 2)\), where \(N\) is the batch size and each element represents the image size of (img_height, img_width).
- Return type:
- Returns:
Processed detections. For each image, the detections have shape (D, 6), where D is the number of detections in that image and 6 represents (class_id, confidence_score, x, y, w, h).
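The box handling and confidence filtering can be sketched in plain Python for a single image. This is an illustration under stated assumptions, not kornia's code: the input boxes are assumed to be normalized (cx, cy, w, h), scores are taken as already computed, the comparison is `>=`, and filter_as_zero is assumed to zero the whole row. `postprocess_image` is a hypothetical helper.

```python
def postprocess_image(class_ids, scores, boxes_cxcywh, image_size,
                      confidence_threshold=0.3, filter_as_zero=False):
    """Return rows of [class_id, score, x, y, w, h] in pixel coordinates.

    Assumes normalized (cx, cy, w, h) input boxes and (height, width)
    image_size. Rows below the threshold are dropped, or replaced by
    zeros when filter_as_zero is True (both behaviours are assumptions).
    """
    img_h, img_w = image_size
    detections = []
    for class_id, score, (cx, cy, w, h) in zip(class_ids, scores, boxes_cxcywh):
        # Center format -> top-left corner, then scale to pixels.
        x = (cx - w / 2) * img_w
        y = (cy - h / 2) * img_h
        row = [float(class_id), score, x, y, w * img_w, h * img_h]
        if score >= confidence_threshold:
            detections.append(row)
        elif filter_as_zero:
            detections.append([0.0] * 6)
    return detections


class_ids = [3, 7]
scores = [0.9, 0.1]
boxes = [(0.5, 0.5, 0.25, 0.5), (0.25, 0.25, 0.5, 0.5)]
kept = postprocess_image(class_ids, scores, boxes, image_size=(480, 640))
padded = postprocess_image(class_ids, scores, boxes, image_size=(480, 640),
                           filter_as_zero=True)
```

Keeping zeroed rows gives every image the same number of rows, which is one plausible reason to offer a filter_as_zero option for fixed-shape export paths.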
Image Segmentation¶
- kornia.contrib.connected_components(image, num_iterations=100)[source]¶
Compute the Connected-component labelling (CCL) algorithm.
The implementation is an adaptation of the following repository:
https://gist.github.com/efirdc/5d8bd66859e574c683a504a4690ae8bc
Warning
This is an experimental API subject to changes and optimization improvements.
Note
See a working example here.
- Parameters:
- Return type:
- Returns:
The labels image with the same shape of the input image.
Example
>>> img = torch.rand(2, 1, 4, 5)
>>> img_labels = connected_components(img, num_iterations=100)
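One way to see why an iteration count is needed is a label-propagation view of connected components: every foreground pixel starts with a unique label and repeatedly takes the maximum label in its 3 x 3 neighbourhood. The pure-Python sketch below follows that idea; it is assumed, not verified, to resemble the referenced gist, and the early exit on convergence is an addition for the example. `label_components` is a hypothetical helper.

```python
def label_components(mask, num_iterations=100):
    """Label 8-connected foreground regions of a 2-D 0/1 mask.

    Each foreground pixel starts with a unique positive label; on every
    iteration it takes the maximum label found in its 3 x 3 neighbourhood.
    Background pixels stay at 0.
    """
    height, width = len(mask), len(mask[0])
    labels = [
        [(r * width + c + 1) if mask[r][c] else 0 for c in range(width)]
        for r in range(height)
    ]
    for _ in range(num_iterations):
        updated = [row[:] for row in labels]
        for r in range(height):
            for c in range(width):
                if not mask[r][c]:
                    continue
                updated[r][c] = max(
                    labels[rr][cc]
                    for rr in range(max(r - 1, 0), min(r + 2, height))
                    for cc in range(max(c - 1, 0), min(c + 2, width))
                )
        if updated == labels:  # converged early (not part of the original API)
            break
        labels = updated
    return labels


mask = [
    [1, 1, 0, 0, 0],
    [1, 0, 0, 0, 1],
    [0, 0, 0, 1, 1],
]
labels = label_components(mask)
```

A label travels at most one pixel per iteration, so elongated regions need more iterations before every pixel agrees on a single label.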
Segment Anything (SAM)¶
- class kornia.models.sam.SamModelType(value)[source]¶
Map the SAM model types.
- vit_h = 0¶
- vit_l = 1¶
- vit_b = 2¶
- mobile_sam = 3¶
- class kornia.models.sam.SamConfig(model_type=None, checkpoint=None, pretrained=False, encoder_embed_dim=None, encoder_depth=None, encoder_num_heads=None, encoder_global_attn_indexes=None)[source]¶
Encapsulate the Config to build a SAM model.
- Parameters:
model_type (Union[str,int,SamModelType,None], optional) – the available models are:
- 0, ‘vit_h’ or kornia.contrib.sam.SamModelType.vit_h()
- 1, ‘vit_l’ or kornia.contrib.sam.SamModelType.vit_l()
- 2, ‘vit_b’ or kornia.contrib.sam.SamModelType.vit_b()
- 3, ‘mobile_sam’, or kornia.contrib.sam.SamModelType.mobile_sam()
Default: None
checkpoint (Optional[str], optional) – URL or a path for a file with the weights of the model. Default: None
encoder_embed_dim (Optional[int], optional) – Patch embedding dimension. Default: None
encoder_depth (Optional[int], optional) – Depth of ViT. Default: None
encoder_num_heads (Optional[int], optional) – Number of attention heads in each ViT block. Default: None
encoder_global_attn_indexes (Optional[tuple[int,...]], optional) – Encoder indexes for blocks using global attention. Default: None
- model_type: str | int | SamModelType | None = None¶
- class kornia.models.sam.Sam(image_encoder, prompt_encoder, mask_decoder)[source]¶
Implement the Segment Anything Model (SAM) wrapper.
This class coordinates the image encoder, prompt encoder, and mask decoder.
- __init__(image_encoder, prompt_encoder, mask_decoder)[source]¶
SAM predicts object masks from an image and input prompts.
- Parameters:
image_encoder (ImageEncoderViT|TinyViT) – The backbone used to encode the image into image embeddings that allow for efficient mask prediction.
prompt_encoder (PromptEncoder) – Encodes various types of input prompts.
mask_decoder (MaskDecoder) – Predicts masks from the image embeddings and encoded prompts.
- forward(images, batched_prompts, multimask_output)[source]¶
Predicts masks end-to-end from provided images and prompts.
This method expects that the images have already been pre-processed: at least normalized, resized and padded to be compatible with self.image_encoder.
Note
For each image \((3, H, W)\), it is possible to input a batch (\(K\)) of \(N\) prompts; the results are batched by the number of prompt batches. So given prompts with \(K=5\) and \(N=10\), the results will have shape \(5xCxHxW\), where \(C\) is determined by multimask_output. Within each of these masks \((5xC)\), it should be possible to find \(N\) instances if the model succeeds.
- Parameters:
images (Tensor) – The image as a torch tensor in \((B, 3, H, W)\) format, already transformed for input to the model.
batched_prompts (list[dict[str,Any]]) – A list over the batch of images (list length should be \(B\)), each a dictionary with the following keys. If a prompt type is not available, its key should be omitted from the dictionary. The options are:
- “points”: tuple of (torch.Tensor, torch.Tensor) with the keypoint coordinates and their respective labels. The tuple should look like (keypoints, labels), where the keypoints (a tensor) are batched point prompts for this image, with shape \((K, N, 2)\), already transformed to the input frame of the model. The labels (a tensor) are batched labels for the point prompts, with shape \((K, N)\), where 1 indicates a foreground point and 0 indicates a background point.
- “boxes”: (torch.Tensor) Batched box inputs, with shape \((K, 4)\). Already transformed to the input frame of the model.
- “mask_inputs”: (torch.Tensor) Batched mask inputs to the model, in the form \((K, 1, H, W)\).
multimask_output (bool) – Whether the model should predict multiple disambiguating masks, or return a single mask.
- Returns:
A list over input images, where each element is a SegmentationResults with the following:
logits: Low resolution logits with shape \((K, C, H, W)\). Can be passed as mask input to subsequent iterations of prediction. Where \(K\) is the number of input prompts, \(C\) is determined by multimask_output, and \(H=W=256\) are the model output size.
scores: The model’s predictions of mask quality (IoU prediction), with shape \((K, C)\).
- static from_config(config)[source]¶
Build/load the SAM model based on its config.
- Parameters:
config (SamConfig) – The SamConfig data structure. If the model_type is available, build from it; otherwise the parameters set in the config are used.
- Return type:
- Returns:
The respective SAM model
Example
>>> from kornia.models.sam import SamConfig
>>> sam_model = Sam.from_config(SamConfig('vit_b'))
Image Patches¶
- kornia.contrib.compute_padding(original_size, window_size, stride=None)[source]¶
Compute the padding required to ensure that chaining extract_tensor_patches() and combine_tensor_patches() produces the expected result.
- Parameters:
original_size (Union[int,Tuple[int,int]]) – the size of the original torch.Tensor.
window_size (Union[int,Tuple[int,int]]) – the size of the sliding window used while extracting patches.
stride (Union[int,Tuple[int,int],None], optional) – The stride of the sliding window. Optional: if not specified, window_size will be used. Default: None
- Returns:
The required padding as a tuple of four ints: (top, bottom, left, right)
Example
>>> image = torch.arange(12).view(1, 1, 4, 3)
>>> padding = compute_padding((4,3), (3,3))
>>> out = extract_tensor_patches(image, window_size=(3, 3), stride=(3, 3), padding=padding)
>>> combine_tensor_patches(out, original_size=(4, 3), window_size=(3, 3), stride=(3, 3), unpadding=padding)
tensor([[[[ 0,  1,  2],
          [ 3,  4,  5],
          [ 6,  7,  8],
          [ 9, 10, 11]]]])
Note
This function will be implicitly used in extract_tensor_patches() and combine_tensor_patches() if allow_auto_(un)padding is set to True.
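The padding arithmetic can be sketched in plain Python. This is a minimal illustration, not kornia's implementation: `sketch_padding` is a hypothetical helper, and it assumes the padding needed along each axis is split as evenly as possible, with any odd pixel going to the bottom or right.

```python
def sketch_padding(original_size, window_size, stride=None):
    """Return (top, bottom, left, right) so that windows tile the padded image exactly."""
    if stride is None:
        stride = window_size  # fall back to the window size, as described above

    def axis_padding(size, window, step):
        # Pixels left over after the last full window position along this axis.
        remainder = (size - window) % step
        # Pixels to add so the last window ends exactly at the border.
        total = (step - remainder) % step
        # Assumption: split evenly, extra pixel goes to the far side.
        return total // 2, total - total // 2

    top, bottom = axis_padding(original_size[0], window_size[0], stride[0])
    left, right = axis_padding(original_size[1], window_size[1], stride[1])
    return top, bottom, left, right


print(sketch_padding((4, 3), (3, 3)))
```

For the 4x3 image and 3x3 window used in the example above, the height needs two extra rows to fit a second window, while the width already fits.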
- kornia.contrib.extract_tensor_patches(input, window_size, stride=1, padding=0, allow_auto_padding=False)[source]¶
Extract patches from tensors and stack them.
See ExtractTensorPatches for details.
- Parameters:
input (Tensor) – torch.Tensor image where to extract the patches with shape \((B, C, H, W)\).
window_size (Union[int, Tuple[int, int]]) – the size of the sliding window and the output patch size.
stride (Union[int, Tuple[int, int]], optional) – stride of the sliding window. Default: 1
padding (Union[int, Tuple[int, int], Tuple[int, int, int, int]], optional) – zero-padding added to both sides of the input. Default: 0
allow_auto_padding (bool, optional) – whether to allow automatic padding if the window and stride do not fit into the image. Default: False
- Return type:
- Returns:
the torch.Tensor with the extracted patches with shape \((B, N, C, H_{out}, W_{out})\).
Examples
>>> input = torch.arange(9.).view(1, 1, 3, 3)
>>> patches = extract_tensor_patches(input, (2, 3))
>>> input
tensor([[[[0., 1., 2.],
          [3., 4., 5.],
          [6., 7., 8.]]]])
>>> patches[:, -1]
tensor([[[[3., 4., 5.],
          [6., 7., 8.]]]])
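The sliding-window behaviour in the example can be mimicked on a nested list. The sketch below is only an illustration of the window traversal (single channel, no batch, no padding); `sketch_extract_patches` is a hypothetical helper, not part of kornia.

```python
def sketch_extract_patches(image, window_size, stride=(1, 1)):
    """Slide a window over a 2D nested list and return the patches in row-major order."""
    height, width = len(image), len(image[0])
    win_h, win_w = window_size
    patches = []
    for top in range(0, height - win_h + 1, stride[0]):        # top to bottom
        for left in range(0, width - win_w + 1, stride[1]):    # left to right
            patches.append([row[left:left + win_w] for row in image[top:top + win_h]])
    return patches


image = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
patches = sketch_extract_patches(image, (2, 3))
print(len(patches))  # number of patches N
print(patches[-1])   # last patch, matching patches[:, -1] above
```

With a 2x3 window and stride 1 on a 3x3 input, only two window positions fit, so N is 2 and the last patch covers the bottom two rows.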
- kornia.contrib.combine_tensor_patches(patches, original_size, window_size, stride, allow_auto_unpadding=False, unpadding=0, eps=1e-8)[source]¶
Restore input from patches.
See CombineTensorPatches for details.
- Parameters:
patches (Tensor) – patched torch.Tensor with shape \((B, N, C, H_{out}, W_{out})\).
original_size (Union[int, Tuple[int, int]]) – the size of the original torch.Tensor and the output size.
window_size (Union[int, Tuple[int, int]]) – the size of the sliding window used while extracting patches.
stride (Union[int, Tuple[int, int]]) – stride of the sliding window.
unpadding (Union[int, Tuple[int, int], Tuple[int, int, int, int]], optional) – the padding to remove from both sides of the input. Default: 0
allow_auto_unpadding (bool, optional) – whether to allow automatic unpadding of the input if the window and stride do not fit into the original_size. Default: False
eps (float, optional) – small value used to prevent division by zero. Default: 1e-8
- Return type:
- Returns:
The combined patches in an image torch.Tensor with shape \((B, C, H, W)\).
Example
>>> out = extract_tensor_patches(torch.arange(16).view(1, 1, 4, 4), window_size=(2, 2), stride=(2, 2))
>>> combine_tensor_patches(out, original_size=(4, 4), window_size=(2, 2), stride=(2, 2))
tensor([[[[ 0,  1,  2,  3],
          [ 4,  5,  6,  7],
          [ 8,  9, 10, 11],
          [12, 13, 14, 15]]]])
Note
This function is supposed to be used in conjunction with extract_tensor_patches().
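The role of `eps` is easiest to see when patches overlap. The one-dimensional sketch below assumes that overlapping positions are restored by summing the patch values and dividing by how many patches cover each position, with `eps` guarding the division. That averaging scheme is an assumption made for illustration, and `sketch_combine_1d` is a hypothetical helper.

```python
def sketch_combine_1d(patches, original_size, stride, eps=1e-8):
    """Rebuild a 1D signal from (possibly overlapping) patches by averaging overlaps."""
    total = [0.0] * original_size   # sum of patch values per position
    count = [0.0] * original_size   # number of patches covering each position
    for index, patch in enumerate(patches):
        start = index * stride
        for offset, value in enumerate(patch):
            total[start + offset] += value
            count[start + offset] += 1.0
    # eps keeps the division finite for positions no patch covers.
    return [t / (c + eps) for t, c in zip(total, count)]


signal = [0.0, 1.0, 2.0, 3.0, 4.0]
patches = [signal[i:i + 3] for i in range(0, 3, 2)]  # windows of 3 with stride 2
restored = sketch_combine_1d(patches, original_size=5, stride=2)
print(restored)
```

The middle sample is covered by both patches, so its sum is divided by two and the original value is recovered up to the small `eps` bias.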
- class kornia.contrib.ExtractTensorPatches(window_size, stride=1, padding=0, allow_auto_padding=False)[source]¶
nn.Module that extracts patches from tensors and stacks them.
In the simplest case, the output value of the operator with input size \((B, C, H, W)\) is \((B, N, C, H_{out}, W_{out})\).
- where
\(B\) is the batch size.
\(N\) denotes the total number of extracted patches, stacked in left-right and top-bottom order.
\(C\) denotes the number of input channels.
\(H\), \(W\) denote the height and width of the input in pixels.
\(H_{out}\), \(W_{out}\) denote the patch size defined in the function signature.
window_size is the size of the sliding window and defines the shape of each output patch.
stride controls the stride to apply to the sliding window and regulates the overlap between the extracted patches.
padding controls the amount of implicit zero-padding on both sides of each dimension.
allow_auto_padding allows automatic calculation of the padding required to fit the window and stride into the image.
The parameters window_size, stride and padding can be either:
a single int – in which case the same value is used for the height and width dimension.
a tuple of two ints – in which case, the first int is used for the height dimension, and the second int for the width dimension.
padding can also be a tuple of four ints – in which case, the first two ints are for the height dimension while the last two ints are for the width dimension.
- Parameters:
input – torch.Tensor image where to extract the patches with shape \((B, C, H, W)\).
window_size (Union[int, Tuple[int, int]]) – the size of the sliding window and the output patch size.
stride (Union[int, Tuple[int, int]], optional) – stride of the sliding window. Default: 1
padding (Union[int, Tuple[int, int], Tuple[int, int, int, int]], optional) – zero-padding added to both sides of the input. Default: 0
allow_auto_padding – whether to allow automatic padding if the window and stride do not fit into the image.
- Shape:
Input: \((B, C, H, W)\)
Output: \((B, N, C, H_{out}, W_{out})\)
- Returns:
the torch.Tensor with the extracted patches.
Examples
>>> input = torch.arange(9.).view(1, 1, 3, 3)
>>> patches = extract_tensor_patches(input, (2, 3))
>>> input
tensor([[[[0., 1., 2.],
          [3., 4., 5.],
          [6., 7., 8.]]]])
>>> patches[:, -1]
tensor([[[[3., 4., 5.],
          [6., 7., 8.]]]])
- class kornia.contrib.CombineTensorPatches(original_size, window_size, stride=None, unpadding=0, allow_auto_unpadding=False)[source]¶
nn.Module that combines patches back into full tensors.
In the simplest case, the output value of the operator with input size \((B, N, C, H_{out}, W_{out})\) is \((B, C, H, W)\).
- where
\(B\) is the batch size.
\(N\) denotes the total number of extracted patches, stacked in left-right and top-bottom order.
\(C\) denotes the number of input channels.
\(H\), \(W\) denote the height and width of the original image in pixels.
\(H_{out}\), \(W_{out}\) denote the patch size defined in the function signature.
original_size is the size of the original image prior to extracting torch.Tensor patches and defines the shape of the output tensor.
window_size is the size of the sliding window used while extracting torch.Tensor patches.
stride controls the stride to apply to the sliding window and regulates the overlap between the extracted patches.
unpadding is the amount of padding to be removed. If specified, this value must be the same as the padding used while extracting torch.Tensor patches.
allow_auto_unpadding allows automatic calculation of the padding required to fit the window and stride into the image. This must be used if the allow_auto_padding flag was used for extracting the patches.
The parameters original_size, window_size, stride, and unpadding can be either:
a single int – in which case the same value is used for the height and width dimension.
a tuple of two ints – in which case, the first int is used for the height dimension, and the second int for the width dimension.
unpadding can also be a tuple of four ints – in which case, the first two ints are for the height dimension while the last two ints are for the width dimension.
- Parameters:
patches – patched torch.Tensor with shape \((B, N, C, H_{out}, W_{out})\).
original_size (Tuple[int, int]) – the size of the original torch.Tensor and the output size.
window_size (Union[int, Tuple[int, int]]) – the size of the sliding window used while extracting patches.
stride (Union[int, Tuple[int, int], None], optional) – stride of the sliding window. Default: None
unpadding (Union[int, Tuple[int, int], Tuple[int, int, int, int]], optional) – the padding to remove from both sides of the input. Default: 0
allow_auto_unpadding (bool, optional) – whether to allow automatic unpadding of the input if the window and stride do not fit into the original_size. Default: False
eps – small value used to prevent division by zero.
- Shape:
Input: \((B, N, C, H_{out}, W_{out})\)
Output: \((B, C, H, W)\)
Example
>>> out = extract_tensor_patches(torch.arange(16).view(1, 1, 4, 4), window_size=(2, 2), stride=(2, 2))
>>> combine_tensor_patches(out, original_size=(4, 4), window_size=(2, 2), stride=(2, 2))
tensor([[[[ 0,  1,  2,  3],
          [ 4,  5,  6,  7],
          [ 8,  9, 10, 11],
          [12, 13, 14, 15]]]])
Note
This module is supposed to be used in conjunction with ExtractTensorPatches.
Image Classification¶
- class kornia.models.vit.VisionTransformer(image_size=224, patch_size=16, in_channels=3, embed_dim=768, depth=12, num_heads=12, dropout_rate=0.0, dropout_attn=0.0, backbone=None)[source]¶
Vision transformer (ViT) module.
The module is expected to be used as operator for different vision tasks.
The method is inspired by existing implementations of the paper [DBK+21].
Warning
This is an experimental API subject to changes in favor of flexibility.
- Parameters:
image_size (int, optional) – the size of the input image. Default: 224
patch_size (int, optional) – the size of the patch to compute the embedding. Default: 16
in_channels (int, optional) – the number of channels for the input. Default: 3
embed_dim (int, optional) – the embedding dimension inside the transformer encoder. Default: 768
depth (int, optional) – the depth of the transformer. Default: 12
num_heads (int, optional) – the number of attention heads. Default: 12
dropout_rate (float, optional) – dropout rate. Default: 0.0
dropout_attn (float, optional) – attention dropout rate. Default: 0.0
backbone (Module | None, optional) – an nn.Module to compute the image patch embeddings. Default: None
Example
>>> img = torch.rand(1, 3, 224, 224)
>>> vit = VisionTransformer(image_size=224, patch_size=16)
>>> vit(img).shape
torch.Size([1, 197, 768])
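The token dimension of 197 in the example shape can be reproduced with a short calculation. The sketch assumes square, non-overlapping patches plus one extra class token; the class token is inferred from the output shape rather than stated in this section, and `sketch_token_count` is a hypothetical helper.

```python
def sketch_token_count(image_size=224, patch_size=16, class_tokens=1):
    """Number of tokens for a square image split into square, non-overlapping patches."""
    patches_per_side = image_size // patch_size       # 224 // 16 = 14
    return patches_per_side * patches_per_side + class_tokens


print(sketch_token_count())  # 197
```

The last dimension of the output, 768, is simply `embed_dim`.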
- property encoder_results: list[Tensor]¶
Return intermediate outputs captured by the transformer encoder.
- Returns:
List of tensors produced by the encoder blocks. Each tensor stores token embeddings for a layer, typically shaped \((B, N, D)\), where \(B\) is batch size, \(N\) is token count, and \(D\) is embedding dimension.
- forward(x)[source]¶
Encode an image batch into Vision Transformer token embeddings.
- Parameters:
x (Tensor) – Image tensor with shape \((B, C, H, W)\), where \(B\) is batch size, \(C\) must match self.in_channels, and \(H\) and \(W\) are expected to match self.image_size.
- Return type:
- Returns:
Normalized token embedding tensor produced by patch embedding and the transformer encoder. The output shape follows the encoder layout, usually \((B, N, D)\).
- static from_config(variant, pretrained=False, **kwargs)[source]¶
Build ViT model based on the given config string.
The format is vit_{size}/{patch_size}. E.g. vit_b/16 means ViT-Base, patch size 16x16. If pretrained=True, AugReg weights are loaded. The weights are hosted on HuggingFace’s model hub: https://huggingface.co/kornia.
Note
The available weights are: vit_l/16, vit_b/16, vit_s/16, vit_ti/16, vit_b/32, vit_s/32.
- Parameters:
variant (str) – ViT model variant, e.g. vit_b/16.
pretrained (bool, optional) – whether to load pre-trained AugReg weights. Default: False
kwargs (Any) – other keyword arguments that will be passed to kornia.models.vit.VisionTransformer().
- Return type:
- Returns:
The respective ViT model
Example
>>> from kornia.models.vit import VisionTransformer
>>> vit_model = VisionTransformer.from_config("vit_b/16", pretrained=True)
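To make the `vit_{size}/{patch_size}` naming concrete, the sketch below splits a variant string into its two parts. `sketch_parse_variant` is a hypothetical helper written for illustration; it is not how kornia parses the string, and it does not check whether weights exist for the variant.

```python
def sketch_parse_variant(variant):
    """Split a string such as 'vit_b/16' into (size, patch_size)."""
    prefix, _, patch = variant.partition("/")   # 'vit_b', '16'
    model, _, size = prefix.partition("_")      # 'vit', 'b'
    if model != "vit" or not size or not patch.isdigit():
        raise ValueError(f"unexpected variant string: {variant!r}")
    return size, int(patch)


print(sketch_parse_variant("vit_b/16"))   # ('b', 16)
print(sketch_parse_variant("vit_ti/16"))  # ('ti', 16)
```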
- class kornia.models.vit_mobile.MobileViT(mode='xxs', in_channels=3, patch_size=(2, 2), dropout=0.0)[source]¶
MobileViT module. The default arguments are for MobileViT XXS.
Paper: https://arxiv.org/abs/2110.02178 Based on: https://github.com/chinhsuanwu/mobilevit-pytorch
- Parameters:
mode (str, optional) – ‘xxs’, ‘xs’ or ‘s’. Default: "xxs"
in_channels (int, optional) – the number of channels for the input image. Default: 3
patch_size (Tuple[int, int], optional) – image_size must be divisible by patch_size. Default: (2, 2)
dropout (float, optional) – dropout ratio in Transformer. Default: 0.0
Example
>>> img = torch.rand(1, 3, 256, 256)
>>> mvit = MobileViT(mode='xxs')
>>> mvit(img).shape
torch.Size([1, 320, 8, 8])
- class kornia.contrib.TinyViT(img_size=224, in_chans=3, num_classes=1000, embed_dims=(96, 192, 384, 768), depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), window_sizes=(7, 7, 14, 7), mlp_ratio=4.0, drop_rate=0.0, drop_path_rate=0.0, use_checkpoint=False, mbconv_expand_ratio=4.0, local_conv_size=3, activation=nn.GELU, mobile_sam=False)[source]¶
TinyViT model, as described in https://arxiv.org/abs/2207.10666.
- Parameters:
img_size (int, optional) – Size of input image. Default: 224
in_chans (int, optional) – Number of input image’s channels. Default: 3
num_classes (int, optional) – Number of output classes. Default: 1000
embed_dims (Sequence[int], optional) – List of embedding dimensions. Default: (96, 192, 384, 768)
depths (Sequence[int], optional) – List of block counts for each downsampling stage. Default: (2, 2, 6, 2)
num_heads (Sequence[int], optional) – List of attention heads used in self-attention for each downsampling stage. Default: (3, 6, 12, 24)
window_sizes (Sequence[int], optional) – List of self-attention’s window size for each downsampling stage. Default: (7, 7, 14, 7)
mlp_ratio (float, optional) – Ratio of MLP dimension to embedding dimension in self-attention. Default: 4.0
drop_rate (float, optional) – Dropout rate. Default: 0.0
drop_path_rate (float, optional) – Stochastic depth rate. Default: 0.0
use_checkpoint (bool, optional) – Whether to use activation checkpointing to trade compute for memory. Default: False
mbconv_expand_ratio (float, optional) – Expansion ratio used in MBConv block. Default: 4.0
local_conv_size (int, optional) – Kernel size of convolution used in TinyViTBlock. Default: 3
activation (type[Module], optional) – activation function. Default: nn.GELU
mobile_sam – Whether to use modifications for MobileSAM.
- forward(x)[source]¶
Classify images if mobile_sam=False, produce feature maps if mobile_sam=True.
- Return type:
- static from_config(variant, pretrained=False, **kwargs)[source]¶
Create a TinyViT model from pre-defined variants.
- Parameters:
variant (str) – TinyViT variant. Possible values: '5m', '11m', '21m'.
pretrained (bool | str, optional) – whether to use pre-trained weights. Possible values: False, True, 'in22k', 'in1k'. For TinyViT-21M (variant='21m'), 'in1k_384', 'in1k_512' are also available. Default: False
**kwargs (Any) – other keyword arguments that will be passed to TinyViT.
- Return type:
Note
When img_size is different from the pre-trained size, bicubic interpolation will be performed on attention biases. When using pretrained=True, the ImageNet-1k checkpoint ('in1k') is used. For feature extraction or fine-tuning, the ImageNet-22k checkpoint ('in22k') is preferred.
Image Stitching¶
- class kornia.contrib.ImageStitcher(matcher, estimator='ransac', blending_method='naive')[source]¶
Stitch two images with overlapping fields of view.
- Parameters:
matcher (Module) – image feature matching module.
estimator (str, optional) – method to compute the homography, either “vanilla” or “ransac”. “ransac” is slower but more accurate. Default: "ransac"
blending_method (str, optional) – method to blend two images together. Only “naive” is currently supported. Default: "naive"
Note
Current implementation requires strict image ordering from left to right.
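import kornia as K
import kornia.feature as KF
import matplotlib.pyplot as plt
import torch
from kornia.contrib import ImageStitcher
# img_left and img_right are assumed to be image tensors already on the GPU.
IS = ImageStitcher(KF.LoFTR(pretrained='outdoor'), estimator='ransac').cuda()
# Compute the stitched result with less GPU memory cost.
with torch.inference_mode():
    out = IS(img_left, img_right)
# Show the result
plt.imshow(K.tensor_to_image(out))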
Lambda¶
- class kornia.contrib.Lambda(func)[source]¶
Applies user-defined lambda as a transform.
- Parameters:
- Returns:
The output of the user-defined lambda.
Example
>>> import kornia
>>> x = torch.rand(1, 3, 5, 5)
>>> f = Lambda(lambda x: kornia.color.rgb_to_grayscale(x))
>>> f(x).shape
torch.Size([1, 1, 5, 5])
Distance Transform¶
- kornia.contrib.distance_transform(image, kernel_size=3, h=0.35)[source]¶
Approximates the Euclidean distance transform of images/volumes using cascaded convolution operations.
The value at each pixel/voxel represents the distance to the nearest non-zero element. It uses the method described in [PDP20]. The transformation is applied independently across the channel dimension.
- Parameters:
- Return type:
- Returns:
tensor with the same shape as input.
Example
>>> # 2D example:
>>> tensor = torch.zeros(1, 1, 5, 5)
>>> tensor[:,:, 1, 2] = 1
>>> dt = distance_transform(tensor)
>>> # 3D example:
>>> volume = torch.zeros(1, 1, 5, 5, 5)
>>> volume[:, :, 2, 2, 2] = 1
>>> dt = distance_transform(volume)
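For intuition about the quantity being approximated, the exact Euclidean distance transform of a small 2D grid can be computed by brute force. The sketch below is a reference for the definition only ("distance to the nearest non-zero element"); it is not the cascaded-convolution method used by this function, and `sketch_distance_transform` is a hypothetical helper.

```python
import math


def sketch_distance_transform(image):
    """Exact Euclidean distance from every pixel to the nearest non-zero pixel."""
    height, width = len(image), len(image[0])
    # Assumes at least one non-zero pixel; min() would fail on an empty source list.
    sources = [(r, c) for r in range(height) for c in range(width) if image[r][c] != 0]
    return [
        [min(math.hypot(r - sr, c - sc) for sr, sc in sources) for c in range(width)]
        for r in range(height)
    ]


image = [[0] * 5 for _ in range(5)]
image[1][2] = 1  # same non-zero location as the 2D example above
dt = sketch_distance_transform(image)
print(dt[1][2], dt[1][3], dt[3][2])  # 0.0 1.0 2.0
```

The brute-force version costs one distance evaluation per pixel and source pair, which is why an approximation built from convolutions is attractive for batched tensors.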
- kornia.contrib.diamond_square(output_size, roughness=0.5, random_scale=1.0, random_fn=torch.rand, normalize_range=None, device=None, dtype=None)[source]¶
Generate Plasma Fractal Images using the diamond square algorithm.
See: https://en.wikipedia.org/wiki/Diamond-square_algorithm
- Parameters:
output_size (Tuple[int, int, int, int]) – a tuple of integers with the BxCxHxW of the image to be generated.
roughness (Union[float, Tensor], optional) – the scale value to apply at each recursion step. Default: 0.5
random_scale (Union[float, Tensor], optional) – the initial value of the scale for recursion. Default: 1.0
random_fn (Callable[..., Tensor], optional) – the callable function to use to sample a random torch.Tensor. Default: torch.rand
normalize_range (Optional[Tuple[float, float]], optional) – whether to normalize the output map using min-max. If a range is specified, min-max normalization is applied between the provided range. Default: None
device (Optional[device], optional) – the torch device on which to place the output map. Default: None
dtype (Optional[dtype], optional) – the torch dtype of the output map. Default: None
- Return type:
- Returns:
A torch.Tensor with shape \((B,C,H,W)\) containing the fractal image.
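The diamond-square recursion itself can be sketched on a single square grid in plain Python. This is a minimal illustration of the textbook algorithm under stated assumptions: a grid side of \(2^n + 1\), uniform noise centred on zero, and a noise scale multiplied by `roughness` after each level. The name `sketch_diamond_square` and its arguments are hypothetical and do not mirror the kornia signature.

```python
import random


def sketch_diamond_square(n, roughness=0.5, random_scale=1.0, seed=0):
    """Fill a (2**n + 1) x (2**n + 1) grid with the diamond-square algorithm."""
    rng = random.Random(seed)
    size = 2 ** n + 1
    grid = [[0.0] * size for _ in range(size)]
    for r in (0, size - 1):              # seed the four corners
        for c in (0, size - 1):
            grid[r][c] = rng.random()
    step, scale = size - 1, random_scale
    while step > 1:
        half = step // 2
        # Diamond step: centre of each square = mean of its four corners + noise.
        for r in range(half, size, step):
            for c in range(half, size, step):
                mean = (grid[r - half][c - half] + grid[r - half][c + half]
                        + grid[r + half][c - half] + grid[r + half][c + half]) / 4.0
                grid[r][c] = mean + (rng.random() - 0.5) * scale
        # Square step: edge midpoints = mean of the in-bounds neighbours + noise.
        for r in range(0, size, half):
            start = half if (r // half) % 2 == 0 else 0
            for c in range(start, size, step):
                neighbours = [
                    grid[r + dr][c + dc]
                    for dr, dc in ((-half, 0), (half, 0), (0, -half), (0, half))
                    if 0 <= r + dr < size and 0 <= c + dc < size
                ]
                grid[r][c] = sum(neighbours) / len(neighbours) + (rng.random() - 0.5) * scale
        scale *= roughness               # less noise at finer levels
        step = half
    return grid


grid = sketch_diamond_square(2)
print(len(grid), len(grid[0]))  # 5 5
```

Smaller `roughness` values shrink the noise faster from one level to the next, which gives smoother maps.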
KMeans¶
- class kornia.contrib.KMeans(num_clusters, cluster_centers, tolerance=10e-4, max_iterations=0, seed=None)[source]¶
Implements the kmeans clustering algorithm with euclidean distance as similarity measure.
- Parameters:
num_clusters (int) – number of clusters the data has to be assigned to.
cluster_centers (Tensor | None) – torch.Tensor of starting cluster centres; can be passed instead of num_clusters.
tolerance (float, optional) – the algorithm terminates if the shift in centers is less than tolerance. Default: 10e-4
max_iterations (int, optional) – number of iterations to run the algorithm for. Default: 0
seed (int | None, optional) – number used to set the torch manual seed for reproducibility. Default: None
Example
>>> kmeans = kornia.contrib.KMeans(3, None, 10e-4, 100, 0)
>>> kmeans.fit(torch.rand((1000, 5)))
>>> predictions = kmeans.predict(torch.rand((10, 5)))
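The interplay of `tolerance` and `max_iterations` can be illustrated with a plain-Python version of the standard Lloyd iteration. This is a sketch of the general algorithm, not kornia's implementation: `sketch_kmeans` is a hypothetical helper, it takes explicit starting centers, and it measures the shift as the summed Euclidean movement of all centers, which is an assumption.

```python
def squared_distance(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def sketch_kmeans(points, centers, tolerance=1e-3, max_iterations=100):
    """Alternate assignment and mean update until the centers stop moving."""
    centers = [tuple(c) for c in centers]
    labels = [0] * len(points)
    for _ in range(max_iterations):
        # Assignment step: each point goes to its nearest center.
        labels = [
            min(range(len(centers)), key=lambda k: squared_distance(p, centers[k]))
            for p in points
        ]
        # Update step: each center becomes the mean of its members.
        new_centers = []
        for k, old in enumerate(centers):
            members = [p for p, label in zip(points, labels) if label == k]
            if members:
                new_centers.append(tuple(sum(col) / len(members) for col in zip(*members)))
            else:
                new_centers.append(old)  # keep an empty cluster's center unchanged
        shift = sum(squared_distance(a, b) ** 0.5 for a, b in zip(centers, new_centers))
        centers = new_centers
        if shift < tolerance:            # converged before max_iterations
            break
    return centers, labels


points = [(0.0, 0.0), (0.0, 1.0), (10.0, 10.0), (10.0, 11.0)]
centers, labels = sketch_kmeans(points, centers=[(0.0, 0.0), (10.0, 10.0)])
print(centers, labels)
```

On this toy data the centers settle after the first update, so the loop exits on the tolerance check in the second iteration.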
- property cluster_assignments: Tensor¶
Return cluster labels assigned during the most recent fit call.
- property cluster_centers: Tensor¶
Return the current cluster centers.
- Returns:
A tensor with shape \((C, D)\), where C is the number of clusters and D is the feature dimension of each sample.
If fit() has already been called, this returns the learned final centers. Otherwise, it returns the initialization provided during construction.
- Raises:
TypeError – If no initial centers were provided and fit has not been run.