kornia.augmentation

The classes in this section perform various data augmentation operations.

Kornia provides Torchvision-like augmentation APIs, but their results may not reproduce Torchvision exactly, because Kornia aligns with OpenCV functionality rather than PIL. In addition, Kornia computes purely in floating point, which gives better precision by avoiding float -> uint8 conversions. Specifically, the functions that differ are:

  • AdjustContrast

  • AdjustBrightness

For a detailed comparison, please check out the Colab: Kornia Playground.
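The precision point above can be illustrated without any imaging library. The sketch below is not Kornia, Torchvision, or PIL code; the helper names (`clamp01`, `quantize_uint8`, `scale_float`, `scale_uint8`) are hypothetical and only model the idea of rounding to 8-bit levels after every operation versus staying in floating point.

```python
# Library-independent sketch: all names here are illustrative assumptions.

def clamp01(x: float) -> float:
    """Clamp a value to the [0, 1] intensity range."""
    return min(1.0, max(0.0, x))


def quantize_uint8(x: float) -> float:
    """Round to the nearest of 256 levels, as a uint8 round trip would."""
    return round(clamp01(x) * 255) / 255


def scale_float(pixel: float, factor: float) -> float:
    """Brightness scaling that stays in floating point."""
    return clamp01(pixel * factor)


def scale_uint8(pixel: float, factor: float) -> float:
    """Brightness scaling followed by an 8-bit round trip."""
    return quantize_uint8(pixel * factor)


pixel = 0.5

# Darken by 10x, then brighten by 10x; ideally this returns the input.
float_result = scale_float(scale_float(pixel, 0.1), 10.0)
uint8_result = scale_uint8(scale_uint8(pixel, 0.1), 10.0)

float_error = abs(float_result - pixel)
uint8_error = abs(uint8_result - pixel)
print(f"float error: {float_error:.6f}, uint8 error: {uint8_error:.6f}")
```

In this toy round trip the floating-point path recovers the input, while the quantized path drifts because the rounding error from the first step is amplified by the second.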

Kornia augmentation implementations can be easily used in a TorchVision style using nn.Sequential.

import kornia.augmentation as K
import torch.nn as nn

transform = nn.Sequential(
   K.RandomAffine(360),
   K.ColorJitter(0.2, 0.3, 0.2, 0.3)
)

Kornia augmentation implementations have two additional parameters compared to TorchVision: return_transform and same_on_batch. The former returns the applied transformation matrix along with the output, which makes it possible to undo a geometric transformation, while the latter controls whether the same random parameters are applied to every element of a batch. To enable either behaviour, set the corresponding flag to True.

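import kornia.augmentation as K
import torch.nn as nn

class MyAugmentationPipeline(nn.Module):
   def __init__(self) -> None:
      super(MyAugmentationPipeline, self).__init__()
      # Geometric step: return the transform and share it across the batch.
      self.aff = K.RandomAffine(
         360, return_transform=True, same_on_batch=True
      )
      # Photometric step: same random parameters for every batch element.
      self.jit = K.ColorJitter(0.2, 0.3, 0.2, 0.3, same_on_batch=True)

   def forward(self, input):
      input, transform = self.aff(input)
      # Pass the (input, transform) pair on so the transform is carried through.
      input, transform = self.jit((input, transform))
      return input, transform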
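To make the idea of undoing a geometric transformation concrete, here is a minimal, library-independent sketch. It assumes the transform is a 3x3 homogeneous affine matrix; the helpers `rotation_matrix`, `invert_affine`, and `apply_to_point` are hypothetical names written for this illustration, not Kornia functions.

```python
import math

# Illustrative helpers only; none of these are part of Kornia.

def rotation_matrix(angle_deg: float, cx: float, cy: float):
    """3x3 homogeneous matrix rotating by angle_deg around the point (cx, cy)."""
    a = math.radians(angle_deg)
    c, s = math.cos(a), math.sin(a)
    return [
        [c, -s, cx - c * cx + s * cy],
        [s, c, cy - s * cx - c * cy],
        [0.0, 0.0, 1.0],
    ]


def invert_affine(m):
    """Invert a 3x3 affine matrix whose last row is [0, 0, 1]."""
    (a, b, tx), (c, d, ty), _ = m
    det = a * d - b * c
    ia, ib, ic, id_ = d / det, -b / det, -c / det, a / det
    return [
        [ia, ib, -(ia * tx + ib * ty)],
        [ic, id_, -(ic * tx + id_ * ty)],
        [0.0, 0.0, 1.0],
    ]


def apply_to_point(m, point):
    """Map an (x, y) point through a 3x3 affine matrix."""
    x, y = point
    return (
        m[0][0] * x + m[0][1] * y + m[0][2],
        m[1][0] * x + m[1][1] * y + m[1][2],
    )


forward = rotation_matrix(30.0, cx=16.0, cy=16.0)
point = (4.0, 10.0)
moved = apply_to_point(forward, point)
# Applying the inverse matrix maps the moved point back to where it started.
restored = apply_to_point(invert_affine(forward), moved)
```

The same principle applies to a matrix returned by an augmentation: inverting it gives the mapping from augmented coordinates back to the original ones.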

Example for semantic segmentation using low-level randomness control:

import kornia.augmentation as K
import torch.nn as nn

class MyAugmentationPipeline(nn.Module):
   def __init__(self) -> None:
      super(MyAugmentationPipeline, self).__init__()
      self.aff = K.RandomAffine(360)
      self.jit = K.ColorJitter(0.2, 0.3, 0.2, 0.3)

   def forward(self, input, mask):
      assert input.shape == mask.shape, (
         "Input shape should be consistent with mask shape, "
         f"but got {input.shape}, {mask.shape}"
      )

      # Sample the affine parameters once, then reuse them for image and mask.
      aff_params = self.aff.forward_parameters(input.shape)
      input = self.aff(input, aff_params)
      mask = self.aff(mask, aff_params)

      # Photometric step, repeated on the mask only to illustrate parameter
      # reuse; label masks are normally left unjittered.
      jit_params = self.jit.forward_parameters(input.shape)
      input = self.jit(input, jit_params)
      mask = self.jit(mask, jit_params)
      return input, mask
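The pattern in this example (sample the random parameters once, then apply them to several inputs) can be sketched independently of Kornia. The functions `sample_params` and `apply_params` below are hypothetical stand-ins for the sample-then-apply split and operate on plain nested lists rather than tensors.

```python
import random

# Illustrative stand-ins only; these are not Kornia APIs.

def sample_params(rng: random.Random) -> dict:
    """Draw every random decision once so it can be reused across inputs."""
    return {"hflip": rng.random() < 0.5, "shift": rng.randrange(0, 3)}


def apply_params(grid, params):
    """Apply a horizontal roll and an optional flip to each row of a 2D grid."""
    out = []
    for row in grid:
        k = params["shift"] % len(row)
        # Roll the row to the right by k positions (k == 0 leaves it as-is).
        rolled = row[-k:] + row[:-k] if k else list(row)
        out.append(rolled[::-1] if params["hflip"] else rolled)
    return out


image = [[1, 2, 3], [4, 5, 6]]
# A toy "segmentation mask": marks pixels whose value exceeds 3.
mask = [[int(v > 3) for v in row] for row in image]

params = sample_params(random.Random(0))
aug_image = apply_params(image, params)
aug_mask = apply_params(mask, params)

# Because both used the same parameters, the mask still lines up with the image.
realigned = [[int(v > 3) for v in row] for row in aug_image]
```

Had the mask been augmented with freshly sampled parameters instead, its labels could end up attached to the wrong pixels.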