Training API (experimental)#

Kornia provides a Training API for training and fine-tuning the deep learning models supported by the library.


To use the Training API, first install the extra dependencies: pip install kornia[x]

Why a Training API ?#

Kornia includes deep learning models that eventually need to be updated through fine-tuning. Our aim is an API flexible enough to be used across our vision models, one that lets users override methods or pass callbacks dynamically to ease debugging and experimentation.


We do not aim to be a general-purpose training library; instead, we let Kornia users experiment with training our models.

Design Principles#

  • Kornia's golden rule is to avoid heavy dependencies.

  • Our models are simple enough so that a light training API can fulfill our needs.

  • Give flexible, full control over the training/validation loops and allow customizing the pipeline.

  • Decouple the model definition from the training pipeline.

  • Use plain PyTorch abstractions and recipes to write your own routines.

  • Use the accelerate library to scale the problem.

Trainer Usage#

The entry point for training with Kornia is the Trainer class.

The main API is a self-contained module that relies heavily on accelerate to scale training across multiple GPUs, TPUs, or fp16 by following standard PyTorch recipes. The API consumes standard PyTorch components, and you decide how much of the work Kornia does for you.

  1. Define your model

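import kornia
import torch
import torch.nn as nn

model = nn.Sequential(
  kornia.contrib.VisionTransformer(image_size=32, patch_size=16),
  kornia.contrib.ClassificationHead(num_classes=10),  # CIFAR-10 has 10 classes
)

  2. Create the datasets and dataloaders for training and validation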

import torchvision
import torchvision.transforms as T

# datasets
train_dataset = torchvision.datasets.CIFAR10(
  root=config.data_path, train=True, download=True, transform=T.ToTensor())

valid_dataset = torchvision.datasets.CIFAR10(
  root=config.data_path, train=False, download=True, transform=T.ToTensor())

# dataloaders
train_dataloader = torch.utils.data.DataLoader(
  train_dataset, batch_size=config.batch_size, shuffle=True)

valid_daloader = torch.utils.data.DataLoader(
  valid_dataset, batch_size=config.batch_size, shuffle=True)

  3. Create your loss function, optimizer and scheduler

# loss function
criterion = nn.CrossEntropyLoss()

# optimizer and scheduler
optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
  optimizer, config.num_epochs * len(train_dataloader))

  4. Create the Trainer and execute the training pipeline

trainer = kornia.train.Trainer(
  model, train_dataloader, valid_daloader, criterion, optimizer, scheduler, config,
)
trainer.fit()  # execute your training !
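
The snippets above read `config.data_path`, `config.batch_size` and `config.num_epochs` without showing where `config` comes from. The following is a minimal sketch, assuming a plain dataclass is an acceptable configuration object. The class name `TrainConfig` and all default values are hypothetical and not part of Kornia, and this page does not confirm which configuration type the Trainer expects. The `lr` field mirrors the `lr` override used in the sweep commands later on this page.

```python
from dataclasses import dataclass, asdict


@dataclass
class TrainConfig:
    """Hypothetical configuration object; not part of the Kornia API."""

    data_path: str = "./data"  # where the dataset is downloaded (assumed default)
    batch_size: int = 32       # samples per batch
    num_epochs: int = 50       # number of passes over the training set
    lr: float = 1e-3           # learning rate handed to the optimizer


# Override only the fields that differ from the defaults.
config = TrainConfig(batch_size=64)
print(asdict(config))
```

Keeping every hyperparameter in one object makes it straightforward to log a run or to override single values from the command line.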

Customize [callbacks]#

At this point you might think: is this API generic enough?

Of course not! What is next? Let’s have fun and customize.

The Trainer internals are clearly separated so that, for example, you can subclass it and override only the evaluate() method to suit your needs. We provide predefined classes for generic problems, such as ImageClassifierTrainer and SemanticSegmentationTrainer.


More trainers will come as we include more models.

You can easily customize the Trainer by creating your own class, or through callbacks as follows:

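import torch
# `accuracy` is assumed to be a top-k accuracy helper such as kornia.metrics.accuracy
from kornia.metrics import accuracy

def my_evaluate(self) -> dict:
  self.model.eval()
  stats = {}
  for sample_id, sample in enumerate(self.valid_dataloader):
    source, target = sample  # this might change with new pytorch dataset structure

    # perform the preprocess and augmentations in batch
    img = self.preprocess(source)
    with torch.no_grad():  # no gradients are needed for validation
      # Forward
      out = self.model(img)
      # Loss computation
      val_loss = self.criterion(out, target)

    # measure accuracy and record loss
    acc1, acc5 = accuracy(out.detach(), target, topk=(1, 5))
    stats = {"loss": val_loss.item(), "top1": acc1.item(), "top5": acc5.item()}
  return stats  # statistics of the last batch; average over batches in practice

# create the trainer and pass the evaluate method as follows
trainer = K.train.Trainer(..., callbacks={"evaluate": my_evaluate})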

Still not convinced?

You can even override the whole fit() method and implement your own loops. The trainer uses the Accelerator to set everything up and move the data to the device for you; the rest of the story is just PyTorch :)

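def my_fit(self):  # this is a custom pytorch training loop
  self.model.train()
  for epoch in range(self.num_epochs):
    for source, targets in self.train_dataloader:
      self.optimizer.zero_grad()

      output = self.model(source)
      loss = self.criterion(output, targets)

      # backward pass and parameter update; assumes the trainer exposes
      # `accelerator` and `optimizer` attributes
      self.accelerator.backward(loss)
      self.optimizer.step()

    stats = self.evaluate()  # once per epoch; do whatever you want with validation

# create the trainer and pass the fit method as follows
trainer = K.train.Trainer(..., callbacks={"fit": my_fit})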


The following hooks are available to override: preprocess, augmentations, evaluate, fit, on_checkpoint, on_epoch_end, on_before_model
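
To make the callback mechanism concrete, here is a minimal sketch of how a trainer could bind the functions passed through `callbacks` as methods. `MiniTrainer` is a hypothetical stand-in written for illustration only. It is not Kornia's implementation, and the real Trainer may handle callbacks differently.

```python
import types


class MiniTrainer:
    """Toy stand-in for a trainer with overridable hooks (illustrative only)."""

    def __init__(self, num_epochs, callbacks=None):
        self.num_epochs = num_epochs
        # Bind each callback to this instance so that it receives `self`
        # and replaces the default method of the same name.
        for name, fn in (callbacks or {}).items():
            if not hasattr(self, name):
                raise ValueError(f"unknown hook: {name}")
            setattr(self, name, types.MethodType(fn, self))

    def evaluate(self):
        # Default hook: returns a placeholder statistic.
        return {"loss": float("nan")}

    def fit(self):
        # Default loop: call the (possibly overridden) evaluate hook once per epoch.
        return [self.evaluate() for _ in range(self.num_epochs)]


def my_evaluate(self):
    # Custom hook: the trainer state is reachable through `self`.
    return {"loss": 0.0, "epochs": self.num_epochs}


trainer = MiniTrainer(num_epochs=2, callbacks={"evaluate": my_evaluate})
history = trainer.fit()
print(history)
```

Because the hooks are looked up by name, `callbacks` has to be a dict that maps a hook name to a function; a set such as `{"evaluate", my_evaluate}` has no `.items()` and would fail.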

Preprocess and augmentations#

Taking a pre-trained model from an external source and assuming that you can fine-tune it on your data by changing just a few things in the model is usually a bad assumption in practice.

Fine-tuning a model needs a lot of tricks, which usually means designing a good augmentation or preprocessing strategy before you execute the training pipeline. For this reason, callbacks let you pass the preprocess and augmentation functions, which makes debugging and experimentation easier.

def preprocess(x):
  return x.float() / 255.

augmentations = nn.Sequential(
  K.augmentation.PatchSequential(
    K.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=0.8),
    grid_size=(2, 2),  # cifar-10 is 32x32 and vit is patch 16
  ),
)

# create the trainer and pass the augmentation or preprocess
trainer = K.train.ImageClassifierTrainer(...,
  callbacks={"preprocess": preprocess, "augmentations": augmentations})

Callbacks utilities#

We also provide utilities to save model checkpoints and to stop training early. To use them, pass instances of the ModelCheckpoint and EarlyStopping classes as callbacks:

def my_evaluate(self):
  # stats = StatsTracker()
  # loss = nn.CrossEntropyLoss()
  # rest of your evaluation loop
  out = self.on_model(self.model, sample)
  val_loss = self.compute_loss(out, sample["mask"])
  stats.update("loss", val_loss.item(), batch_size)
  return stats.as_dict()

model_checkpoint = ModelCheckpoint(
  filepath="./outputs", monitor="loss",
)

early_stop = EarlyStopping(
  monitor="loss", patience=2,
)

trainer = SemanticSegmentationTrainer(...,
  callbacks={"on_epoch_end": early_stop, "on_checkpoint": model_checkpoint},
)


With the code above, training will look something like this:

best_loss = inf
epoch 1:
avg_loss = 1.0
best_loss = 1.0
checkpoint = ./outputs/
epoch 2:
avg_loss = 0.9
best_loss = 0.9
checkpoint = ./outputs/
epoch 3:
avg_loss = 1.1
early_stop.counter += 1
epoch 4:
avg_loss = 1.2
early_stop.counter += 1
early_stop.patience == early_stop.counter
training ends here
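
The trace above can be replayed with a few lines of plain Python. This is a minimal sketch of the bookkeeping only, assuming that a lower monitored value is better, that the counter resets on improvement, and that training stops once the counter reaches `patience`. The function `simulate` is hypothetical and does not use Kornia's `ModelCheckpoint` or `EarlyStopping` classes.

```python
import math


def simulate(avg_losses, patience=2):
    """Replay checkpoint and early-stop decisions for a list of epoch losses."""
    best_loss = math.inf
    counter = 0
    checkpoints = []  # epochs at which a checkpoint would be written
    for epoch, avg_loss in enumerate(avg_losses, start=1):
        if avg_loss < best_loss:
            # Improvement: remember the new best value and save a checkpoint.
            best_loss = avg_loss
            counter = 0
            checkpoints.append(epoch)
        else:
            # No improvement: count it and stop once patience is exhausted.
            counter += 1
            if counter == patience:
                return epoch, best_loss, checkpoints
    return len(avg_losses), best_loss, checkpoints


# The fifth value is never reached because training stops after epoch 4.
last_epoch, best_loss, checkpoints = simulate([1.0, 0.9, 1.1, 1.2, 0.5])
print(last_epoch, best_loss, checkpoints)  # prints: 4 0.9 [1, 2]
```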

Hyperparameter sweeps#

Use hydra to implement an easy search strategy for your hyperparameters as follows:


Check out the toy example here.

python ./train/image_classifier/ num_epochs=50 batch_size=32
python ./train/image_classifier/ --multirun lr=1e-3,1e-4
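
With `--multirun`, Hydra's default sweeper launches one job per combination of the comma-separated override values. The sketch below only illustrates that expansion in plain Python. `expand_overrides` is a hypothetical helper and not part of Hydra, and the `batch_size` values are made up for the example.

```python
from itertools import product


def expand_overrides(overrides):
    """Expand overrides such as 'lr=1e-3,1e-4' into one override list per job."""
    keys, choices = [], []
    for item in overrides:
        key, values = item.split("=", 1)
        keys.append(key)
        choices.append(values.split(","))
    # The Cartesian product yields every combination of the listed values.
    return [
        [f"{key}={value}" for key, value in zip(keys, combo)]
        for combo in product(*choices)
    ]


jobs = expand_overrides(["lr=1e-3,1e-4", "batch_size=32,64"])
for job in jobs:
    print(job)
```

Two values for each of two parameters give four jobs, so the number of runs grows multiplicatively with every swept parameter.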

Distributed Training#

The Kornia Trainer relies heavily on accelerate to decouple your training scripts from the details of running them in a distributed environment.


We haven’t yet tested all the possibilities for distributed training. Expect some adventures, or join us and help to iterate :)

The recipes below are taken from the accelerate library:

  • single CPU:

    • from a server without GPU

      python ./train/image_classifier/
    • from any server by passing cpu=True to the Accelerator.

      python ./train/image_classifier/ --data_path path_to_data --cpu
    • from any server with Accelerate launcher

      accelerate launch --cpu ./train/image_classifier/ --data_path path_to_data
  • single GPU:

    python ./train/image_classifier/  # from a server with a GPU
  • with fp16 (mixed-precision)

    • from any server by passing fp16=True to the Accelerator.

      python ./train/image_classifier/ --data_path path_to_data --fp16
    • from any server with Accelerate launcher

      accelerate launch --fp16 ./train/image_classifier/ --data_path path_to_data
  • multi GPUs (using PyTorch distributed mode)

    • With Accelerate config and launcher

      accelerate config  # This will create a config file on your server
      accelerate launch ./train/image_classifier/ --data_path path_to_data  # This will run the script on your server
    • With traditional PyTorch launcher

      python -m torch.distributed.launch --nproc_per_node 2 --use_env ./train/image_classifier/ --data_path path_to_data
  • multi GPUs, multi node (several machines, using PyTorch distributed mode)

    • With Accelerate config and launcher, on each machine:

      accelerate config  # This will create a config file on each server
      accelerate launch ./train/image_classifier/ --data_path path_to_data  # This will run the script on each server
    • With PyTorch launcher only

      python -m torch.distributed.launch --nproc_per_node 2 \
        --use_env \
        --node_rank 0 \
        --master_addr master_node_ip_address \
        ./train/image_classifier/ --data_path path_to_data  # On the first server
      python -m torch.distributed.launch --nproc_per_node 2 \
        --use_env \
        --node_rank 1 \
        --master_addr master_node_ip_address \
        ./train/image_classifier/ --data_path path_to_data  # On the second server
  • (multi) TPUs

    • With Accelerate config and launcher

      accelerate config  # This will create a config file on your TPU server
      accelerate launch ./train/image_classifier/ --data_path path_to_data  # This will run the script on each server
    • In PyTorch: Add an xmp.spawn line in your script as you usually do.