Image Classification

Image Classification is a fundamental task that aims to understand an image as a whole: the goal is to assign the entire image a single label from a fixed set of classes. Image Classification typically deals with images in which only one object appears. In contrast, object detection combines classification with localization (finding where each object is) and handles the more realistic case in which an image may contain multiple objects.
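The difference shows up in the annotations each task needs. Below is a minimal sketch using hypothetical dataclasses (not a Kornia API) and an assumed (x1, y1, x2, y2) box convention. A classification target is one label per image. A detection target holds a variable-length list of boxes (localization) paired with labels (classification).

```python
from dataclasses import dataclass, field
from typing import List, Tuple

# hypothetical box convention: (x1, y1, x2, y2) pixel corners
Box = Tuple[float, float, float, float]


@dataclass
class ClassificationTarget:
    label: int  # exactly one class index for the whole image


@dataclass
class DetectionTarget:
    boxes: List[Box] = field(default_factory=list)  # where each object is (localization)
    labels: List[int] = field(default_factory=list)  # what each object is (classification)


single_object = ClassificationTarget(label=3)
street_scene = DetectionTarget(
    boxes=[(10.0, 20.0, 50.0, 80.0), (60.0, 15.0, 120.0, 90.0)],
    labels=[1, 7],
)
```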

Kornia provides transformer-based backbones for image classification. Check out the VisionTransformer and ClassificationHead APIs and combine them as follows to build your own classifier:

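import torch
import torch.nn as nn
import kornia.contrib as K

# ViT backbone followed by a classification head (e.g. 1000 classes)
classifier = nn.Sequential(
    K.VisionTransformer(image_size=224, patch_size=16),
    K.ClassificationHead(num_classes=1000),
)

img = torch.rand(1, 3, 224, 224)
out = classifier(img)     # B x num_classes logits
preds = out.argmax(-1)    # B: predicted class index per image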
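To make the shapes concrete, here is a minimal pure-PyTorch sketch of what a classification head does with a transformer backbone's output. It assumes, as in the original ViT design, that the backbone returns B x (1 + num_patches) x D token embeddings with a class token at index 0. `TinyClassificationHead` is hypothetical, and Kornia's ClassificationHead may be implemented differently.

```python
import torch
import torch.nn as nn


class TinyClassificationHead(nn.Module):
    """Hypothetical head: maps the class-token embedding to class logits."""

    def __init__(self, embed_size: int, num_classes: int) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(embed_size)
        self.linear = nn.Linear(embed_size, num_classes)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # tokens: B x (1 + num_patches) x D; keep only the class token at index 0
        cls_token = tokens[:, 0]
        return self.linear(self.norm(cls_token))  # B x num_classes


# 224x224 images cut into 16x16 patches give (224 // 16) ** 2 = 196 patches
batch, num_patches, embed_dim = 2, (224 // 16) ** 2, 768
# random stand-in for the backbone output: class token + one token per patch
tokens = torch.rand(batch, 1 + num_patches, embed_dim)

head = TinyClassificationHead(embed_size=embed_dim, num_classes=1000)
logits = head(tokens)      # 2 x 1000 class scores
preds = logits.argmax(-1)  # 2: one predicted class index per image
```

With image_size=224 and patch_size=16, each image splits into 14 x 14 = 196 patches. Together with the class token, this gives 197 tokens per image, which the head reduces to one score per class.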


Read more about our Vision Transformer (ViT)


To fine-tune a model on your own data, you can use our Training API (experimental).
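One common way to fine-tune is to start from a pretrained backbone, freeze its weights, and train only a new task-specific head. The sketch below shows that pattern in plain PyTorch. The small convolutional `backbone` is a hypothetical stand-in for a pretrained network (for example, one taken from torchvision or timm), and the learning rate is arbitrary. Whether to freeze the backbone or update all weights is a design choice that depends on your data.

```python
import torch
import torch.nn as nn

# hypothetical stand-in for a pretrained backbone: B x 3 x 32 x 32 -> B x 64 features
backbone = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
)

# freeze the pretrained weights so that only the new head is trained
for param in backbone.parameters():
    param.requires_grad = False

# new task-specific head, e.g. 10 classes for CIFAR-10
head = nn.Linear(64, 10)
model = nn.Sequential(backbone, head)

# give the optimizer only the parameters that still require gradients
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=1e-3)

logits = model(torch.rand(4, 3, 32, 32))  # 4 x 10 class scores
```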

We provide ImageClassifierTrainer, which implements a default training structure for basic image classification problems. You can use it with the models provided by Kornia or with models from other libraries in the PyTorch ecosystem, such as torchvision or timm. First, create the model, the datasets and dataloaders, the loss function, the optimizer, and the scheduler:

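    import torch
    import torch.nn as nn
    import torchvision
    import torchvision.transforms as T
    import kornia as K
    from hydra.utils import to_absolute_path
    from kornia.x import ImageClassifierTrainer, ModelCheckpoint

    # `config` holds the training settings used below: data_path, batch_size, lr, num_epochs

    # create the model (embed_dim must be divisible by num_heads)
    model = nn.Sequential(
        K.contrib.VisionTransformer(image_size=32, patch_size=16, embed_dim=128, num_heads=4),
        K.contrib.ClassificationHead(embed_size=128, num_classes=10),
    )

    # create the dataset
    train_dataset = torchvision.datasets.CIFAR10(
        root=to_absolute_path(config.data_path), train=True, download=True, transform=T.ToTensor())

    valid_dataset = torchvision.datasets.CIFAR10(
        root=to_absolute_path(config.data_path), train=False, download=True, transform=T.ToTensor())

    # create the dataloaders (the validation set does not need shuffling)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=8, pin_memory=True)

    valid_daloader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=config.batch_size, shuffle=False, num_workers=8, pin_memory=True)

    # create the loss function
    criterion = nn.CrossEntropyLoss()

    # instantiate the optimizer and scheduler; T_max spans every training step (epochs x batches)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, config.num_epochs * len(train_dataloader))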
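Because T_max is set to config.num_epochs * len(train_dataloader), the scheduler is presumably stepped once per batch. In that case, the learning rate decays along a half cosine from its initial value to eta_min (0 by default) over the whole run. The sketch below checks this in plain PyTorch. It uses made-up numbers (3 epochs of 5 batches, initial learning rate 1e-3) and a dummy parameter in place of a real model.

```python
import math

import torch

# made-up values for illustration: 3 epochs of 5 batches each
num_epochs, steps_per_epoch, base_lr = 3, 5, 1e-3
total_steps = num_epochs * steps_per_epoch  # plays the role of T_max

param = torch.nn.Parameter(torch.zeros(1))  # dummy parameter instead of a model
optimizer = torch.optim.AdamW([param], lr=base_lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, total_steps)

lrs = []
for step in range(total_steps):
    lrs.append(optimizer.param_groups[0]["lr"])  # learning rate used for this batch
    optimizer.step()  # a real loop would compute the loss and call backward() first
    scheduler.step()  # one scheduler step per batch, not per epoch

final_lr = optimizer.param_groups[0]["lr"]

# closed form with eta_min = 0: lr_t = base_lr * (1 + cos(pi * t / T_max)) / 2
expected = [base_lr * (1 + math.cos(math.pi * t / total_steps)) / 2 for t in range(total_steps)]
```

Because T_max counts batches, stepping the scheduler once per epoch instead would leave the learning rate almost unchanged for the whole run.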

Define your augmentations and callbacks:

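    # define some augmentations
    _augmentations = nn.Sequential(
        K.augmentation.PatchSequential(
            K.augmentation.ColorJitter(0.1, 0.1, 0.1, 0.1, p=0.8),
            grid_size=(2, 2),  # cifar-10 is 32x32 and vit is patch 16
            patchwise_apply=False,  # run the same pipeline on every patch
        ),
    )

    # callback: augment the input images and keep the targets unchanged
    def augmentations(self, sample: dict) -> dict:
        out = _augmentations(sample["input"])
        return {"input": out, "target": sample["target"]}

    # callback: save checkpoints to ./outputs, tracking the "top5" metric
    model_checkpoint = ModelCheckpoint(
        filepath="./outputs", monitor="top5",
    )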
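The grid_size=(2, 2) setting reflects simple arithmetic: splitting a 32x32 CIFAR-10 image into a 2x2 grid gives 16x16 patches, the same size as the ViT patches. The sketch below is a rough, hypothetical re-implementation of the patch-wise idea in plain PyTorch, not Kornia's PatchSequential. It draws an independent brightness offset for each patch and wraps it in a callback that augments sample["input"] and passes sample["target"] through. Unlike the callback above, this standalone version takes no `self` argument. The function names and the 0.1 strength are assumptions.

```python
import torch


def patchwise_brightness(images: torch.Tensor, grid: int = 2, strength: float = 0.1) -> torch.Tensor:
    """Hypothetical sketch: shift the brightness of each cell of a grid x grid layout independently."""
    b, _, h, w = images.shape
    ph, pw = h // grid, w // grid  # 32 // 2 = 16, the ViT patch size
    out = images.clone()
    for i in range(grid):
        for j in range(grid):
            rows, cols = slice(i * ph, (i + 1) * ph), slice(j * pw, (j + 1) * pw)
            # one random offset in [-strength, strength) per image and per patch
            offset = (torch.rand(b, 1, 1, 1) * 2 - 1) * strength
            out[:, :, rows, cols] = (out[:, :, rows, cols] + offset).clamp(0, 1)
    return out


def augmentations(sample: dict) -> dict:
    # augment only the images; the labels pass through untouched
    return {"input": patchwise_brightness(sample["input"]), "target": sample["target"]}


sample = {"input": torch.rand(8, 3, 32, 32), "target": torch.randint(0, 10, (8,))}
augmented = augmentations(sample)
```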

Finally, instantiate the ImageClassifierTrainer and execute your training pipeline.

    trainer = ImageClassifierTrainer(
        model, train_dataloader, valid_daloader, criterion, optimizer, scheduler, config,
        callbacks={"augmentations": augmentations, "on_checkpoint": model_checkpoint},
    )
    trainer.fit()  # run the training and validation loops

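The checkpoint callback passed to the trainer monitors a metric named top5. This presumably refers to top-5 accuracy on the validation set, where a prediction counts as correct if the true class is among the five highest-scoring classes. Below is a minimal sketch of that metric in plain PyTorch. `topk_accuracy` is a hypothetical helper, not the trainer's internal implementation.

```python
import torch


def topk_accuracy(logits: torch.Tensor, target: torch.Tensor, k: int = 5) -> float:
    """Fraction of samples whose true class is among the k highest logits."""
    topk = logits.topk(k, dim=-1).indices              # B x k class indices
    hits = (topk == target.unsqueeze(-1)).any(dim=-1)  # B booleans
    return hits.float().mean().item()


# 2 samples, 6 classes: sample 0's true class (4) only ranks fifth
logits = torch.tensor([[0.9, 0.8, 0.7, 0.6, 0.5, 0.1],
                       [0.1, 0.2, 0.3, 0.4, 0.5, 0.9]])
target = torch.tensor([4, 5])

top1 = topk_accuracy(logits, target, k=1)  # only sample 1 is right at top-1
top5 = topk_accuracy(logits, target, k=5)  # both samples are right at top-5
```

With only two samples, top-1 accuracy is 0.5 while top-5 accuracy is 1.0, which shows how much more lenient the top-5 metric is.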
See also

Play with the full example here