Base Classes

AugmentationBase2D and AugmentationBase3D are the base classes for creating new transforms with kornia.augmentation. A subclass only needs to override generate_parameters and apply_transform, and optionally compute_transformation.

Create your own transformations with the following snippet:

import torch
import kornia as K

from kornia.augmentation import AugmentationBase2D

class MyRandomTransform(AugmentationBase2D):
   def __init__(self, return_transform: bool = False) -> None:
      super(MyRandomTransform, self).__init__(return_transform)

   def generate_parameters(self, input_shape: torch.Size):
      # generate the random parameters for your use case.
      angles_rad: torch.Tensor = torch.rand(input_shape[0]) * K.pi
      angles_deg = K.rad2deg(angles_rad)
      return dict(angles=angles_deg)

   def compute_transformation(self, input, params):

      B, _, H, W = input.shape

      # compute transformation
      angles: torch.Tensor = params['angles'].type_as(input)
      center = torch.tensor([[W / 2, H / 2]] * B).type_as(input)
      transform = K.get_rotation_matrix2d(
         center, angles, torch.ones_like(angles))
      return transform

   def apply_transform(self, input, params):

      _, _, H, W = input.shape
      # compute transformation
      transform = self.compute_transformation(input, params)

      # apply transformation and return
      output = K.warp_affine(input, transform, (H, W))
      # apply_transform is documented to return a Tensor, so return only the output
      return output
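
To make the compute_transformation step more concrete, the following plain-Python sketch builds a 2x3 matrix for a rotation about the image center and checks that the center stays fixed. It assumes the OpenCV-style convention for the matrix layout; the helper names rotation_matrix_2d and apply_affine are hypothetical and are not part of kornia.

```python
import math


def rotation_matrix_2d(center, angle_deg, scale=1.0):
    """Build a 2x3 affine matrix rotating about `center`.

    Assumption: OpenCV-style layout with alpha = scale * cos(theta)
    and beta = scale * sin(theta).
    """
    cx, cy = center
    theta = math.radians(angle_deg)
    alpha = scale * math.cos(theta)
    beta = scale * math.sin(theta)
    return [
        [alpha, beta, (1.0 - alpha) * cx - beta * cy],
        [-beta, alpha, beta * cx + (1.0 - alpha) * cy],
    ]


def apply_affine(matrix, point):
    """Apply a 2x3 affine matrix to a single (x, y) point."""
    x, y = point
    return (
        matrix[0][0] * x + matrix[0][1] * y + matrix[0][2],
        matrix[1][0] * x + matrix[1][1] * y + matrix[1][2],
    )


# Same center choice as the snippet above: (W / 2, H / 2).
W, H = 64, 48
center = (W / 2, H / 2)
matrix = rotation_matrix_2d(center, angle_deg=90.0)

# The rotation center is a fixed point of the transform.
fixed = apply_affine(matrix, center)

# A zero angle with unit scale gives the identity mapping.
identity = rotation_matrix_2d(center, angle_deg=0.0)
```

The translation column is what keeps the chosen center in place; without it the rotation would be about the image origin instead.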

class kornia.augmentation.base.AugmentationBase2D(return_transform=None, same_on_batch=False, p=0.5, p_batch=1.0, keepdim=False)

AugmentationBase2D base class for customized augmentation implementations.

For any augmentation, implementations of “generate_parameters” and “apply_transform” are required, while “compute_transformation” is only required when “return_transform” is True.

Parameters
  • p (float, optional) – probability for applying an augmentation. This param controls the augmentation probabilities element-wise for a batch. Default: 0.5

  • p_batch (float, optional) – probability for applying an augmentation to a batch. This param controls the augmentation probabilities batch-wise. Default: 1.0

  • return_transform (Optional[bool], optional) – if True, return the matrix describing the geometric transformation applied to each input tensor. If False and the input is a tuple, the applied transformation won’t be concatenated. Default: None

  • same_on_batch (bool, optional) – apply the same transformation across the batch. Default: False

  • keepdim (bool, optional) – whether to keep the output shape the same as the input (True) or broadcast it to the batch form (False). Default: False
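
The interaction between p, p_batch and same_on_batch can be sketched in plain Python. This is one reading of the parameter descriptions above, not kornia's actual sampling code; sample_apply_mask is a hypothetical helper that returns one boolean per batch element, True where the augmentation would be applied.

```python
import random


def sample_apply_mask(batch_size, p=0.5, p_batch=1.0,
                      same_on_batch=False, rng=None):
    """Return a list of bools: True where the augmentation is applied.

    Assumed semantics: a batch-wise draw with probability p_batch gates
    the whole batch, then element-wise draws with probability p decide
    each element.
    """
    rng = rng or random.Random()

    # Batch-wise gate: if it fails, no element is augmented.
    if rng.random() >= p_batch:
        return [False] * batch_size

    if same_on_batch:
        # A single element-wise draw is shared by the whole batch.
        return [rng.random() < p] * batch_size

    # Independent element-wise draws.
    return [rng.random() < p for _ in range(batch_size)]


always = sample_apply_mask(4, p=1.0, p_batch=1.0)
never_batch = sample_apply_mask(4, p=1.0, p_batch=0.0)
never_elem = sample_apply_mask(4, p=0.0, p_batch=1.0)
shared = sample_apply_mask(8, p=0.5, same_on_batch=True,
                           rng=random.Random(0))
```

Under this reading, p_batch=1.0 with p=0.5 augments roughly half of the elements in every batch, while p_batch=0.5 with p=1.0 augments either the whole batch or none of it.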

generate_parameters(batch_shape)
Return type

Dict[str, Tensor]

compute_transformation(input, params)
Return type

Tensor

apply_transform(input, params, transform=None)
Return type

Tensor
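
How the three methods fit together can be sketched with a toy base class in plain Python. ToyAugmentationBase, ToyShift and ToyScale are hypothetical names used only for illustration; this is not kornia's implementation, and lists of floats stand in for tensors. It shows why compute_transformation is only needed when return_transform is True.

```python
from typing import Any, Dict, Tuple


class ToyAugmentationBase:
    """Hypothetical stand-in for the base class (not kornia code)."""

    def __init__(self, return_transform: bool = False) -> None:
        self.return_transform = return_transform

    def generate_parameters(self, batch_shape: Tuple[int, ...]) -> Dict[str, Any]:
        raise NotImplementedError

    def compute_transformation(self, input: Any, params: Dict[str, Any]) -> Any:
        raise NotImplementedError

    def apply_transform(self, input: Any, params: Dict[str, Any]) -> Any:
        raise NotImplementedError

    def __call__(self, input: Any) -> Any:
        params = self.generate_parameters((len(input),))
        output = self.apply_transform(input, params)
        if self.return_transform:
            # Only this branch needs compute_transformation.
            return output, self.compute_transformation(input, params)
        return output


class ToyShift(ToyAugmentationBase):
    """Adds an offset per element; the 'transform' is that offset."""

    def generate_parameters(self, batch_shape):
        return {"offset": [1.0] * batch_shape[0]}

    def compute_transformation(self, input, params):
        return params["offset"]

    def apply_transform(self, input, params):
        return [x + o for x, o in zip(input, params["offset"])]


class ToyScale(ToyAugmentationBase):
    """Implements only the two required methods."""

    def generate_parameters(self, batch_shape):
        return {"factor": [2.0] * batch_shape[0]}

    def apply_transform(self, input, params):
        return [x * f for x, f in zip(input, params["factor"])]


plain = ToyShift()([1.0, 2.0])
with_matrix = ToyShift(return_transform=True)([1.0, 2.0])
scaled = ToyScale()([1.0, 2.0])
```

ToyScale works as long as return_transform is left False; constructing it with return_transform=True would raise NotImplementedError on the first call.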

class kornia.augmentation.base.AugmentationBase3D(return_transform=None, same_on_batch=False, p=0.5, p_batch=1.0, keepdim=False)

AugmentationBase3D base class for customized augmentation implementations.

For any augmentation, implementations of “generate_parameters” and “apply_transform” are required, while “compute_transformation” is only required when “return_transform” is True.

Parameters
  • p (float, optional) – probability for applying an augmentation. This param controls the augmentation probabilities element-wise for a batch. Default: 0.5

  • p_batch (float, optional) – probability for applying an augmentation to a batch. This param controls the augmentation probabilities batch-wise. Default: 1.0

  • return_transform (Optional[bool], optional) – if True, return the matrix describing the geometric transformation applied to each input tensor. If False and the input is a tuple, the applied transformation won’t be concatenated. Default: None

  • same_on_batch (bool, optional) – apply the same transformation across the batch. Default: False

generate_parameters(batch_shape)
Return type

Dict[str, Tensor]

compute_transformation(input, params)
Return type

Tensor

apply_transform(input, params, transform=None)
Return type

Tensor