Base Classes
This is the base class for creating a new transform. The user only needs to override generate_parameters and apply_transform, and optionally compute_transformation.
- class AugmentationBase2D(return_transform: bool = False, same_on_batch: bool = False, p: float = 0.5, p_batch: float = 1.0, keepdim: bool = False)

AugmentationBase2D is the base class for customized augmentation implementations. For any augmentation, implementing generate_parameters and apply_transform is required, while compute_transformation is only required when return_transform is True.
- Parameters
  - p (float) – probability of applying the augmentation. This parameter controls the augmentation probability element-wise for a batch. Default: 0.5.
  - p_batch (float) – probability of applying the augmentation to a batch as a whole. This parameter controls the augmentation probability batch-wise. Default: 1.0.
  - return_transform (bool) – if True, return the matrix describing the geometric transformation applied to each input tensor. If False and the input is a tuple, the applied transformation will not be concatenated. Default: False.
  - same_on_batch (bool) – apply the same transformation across the batch. Default: False.
  - keepdim (bool) – whether to keep the output shape the same as the input (True) or broadcast it to the batch form (False). Default: False.
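The combined effect of p and p_batch can be sketched in plain torch. This is a conceptual sketch of the sampling semantics described above, not kornia's internal implementation: p_batch first gates the batch as a whole, then p draws an element-wise mask deciding which samples are augmented.

```python
import torch

def batch_prob_mask(batch_size: int, p: float, p_batch: float) -> torch.Tensor:
    # Batch-wise gate: with probability (1 - p_batch), skip the whole batch.
    batch_on = torch.rand(()) < p_batch
    # Element-wise mask: each sample is augmented with probability p.
    elem_on = torch.rand(batch_size) < p
    # A sample is augmented only if both the batch gate and its own draw pass.
    return elem_on & batch_on

mask = batch_prob_mask(4, p=0.5, p_batch=1.0)  # boolean mask of shape (4,)
```

With same_on_batch=True the element-wise draws would additionally share a single random parameter set, but the masking logic above stays the same.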
- generate_parameters(batch_shape: torch.Size) → Dict[str, torch.Tensor]
- compute_transformation(input: torch.Tensor, params: Dict[str, torch.Tensor]) → torch.Tensor
- apply_transform(input: torch.Tensor, params: Dict[str, torch.Tensor]) → torch.Tensor
- class AugmentationBase3D(return_transform: bool = False, same_on_batch: bool = False, p: float = 0.5, p_batch: float = 1.0, keepdim: bool = False)

AugmentationBase3D is the base class for customized 3D augmentation implementations. For any augmentation, implementing generate_parameters and apply_transform is required, while compute_transformation is only required when return_transform is True.
- Parameters
  - p (float) – probability of applying the augmentation. This parameter controls the augmentation probability element-wise for a batch. Default: 0.5.
  - p_batch (float) – probability of applying the augmentation to a batch as a whole. This parameter controls the augmentation probability batch-wise. Default: 1.0.
  - return_transform (bool) – if True, return the matrix describing the geometric transformation applied to each input tensor. If False and the input is a tuple, the applied transformation will not be concatenated. Default: False.
  - same_on_batch (bool) – apply the same transformation across the batch. Default: False.
  - keepdim (bool) – whether to keep the output shape the same as the input (True) or broadcast it to the batch form (False). Default: False.
- generate_parameters(batch_shape: torch.Size) → Dict[str, torch.Tensor]
- compute_transformation(input: torch.Tensor, params: Dict[str, torch.Tensor]) → torch.Tensor
- apply_transform(input: torch.Tensor, params: Dict[str, torch.Tensor]) → torch.Tensor
Create your own transformation:
import torch
import kornia as K
from kornia.augmentation import AugmentationBase2D

class MyRandomTransform(AugmentationBase2D):
    def __init__(self, return_transform: bool = False) -> None:
        super().__init__(return_transform)

    def generate_parameters(self, input_shape: torch.Size):
        # Generate the random parameters for your use case.
        angles_rad: torch.Tensor = torch.rand(input_shape[0]) * K.pi
        angles_deg = K.rad2deg(angles_rad)
        return dict(angles=angles_deg)

    def compute_transformation(self, input, params):
        B, _, H, W = input.shape
        # Compute the 2x3 rotation matrix about each image center.
        angles: torch.Tensor = params['angles'].type_as(input)
        center = torch.tensor([[W / 2, H / 2]] * B).type_as(input)
        transform = K.get_rotation_matrix2d(
            center, angles, torch.ones_like(angles))
        return transform

    def apply_transform(self, input, params):
        _, _, H, W = input.shape
        # Compute the transformation matrix from the sampled parameters.
        transform = self.compute_transformation(input, params)
        # Apply the transformation and return the warped tensor.
        output = K.warp_affine(input, transform, (H, W))
        return output
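For reference, the rotation performed above by get_rotation_matrix2d and warp_affine can be approximated in plain torch with affine_grid and grid_sample. This is a hedged sketch, not kornia's implementation: it works in normalized [-1, 1] coordinates (so rotation about the center needs no translation term) and assumes square images, since normalized coordinates skew rotations on non-square inputs.

```python
import torch
import torch.nn.functional as F

def rotate(images: torch.Tensor, angles_deg: torch.Tensor) -> torch.Tensor:
    # Rotate each image in the batch about its center.
    # Sketch assumption: square images (H == W); non-square inputs would
    # need an aspect-ratio correction in normalized coordinates.
    B, C, H, W = images.shape
    rad = angles_deg.to(images.dtype) * torch.pi / 180.0
    cos, sin = torch.cos(rad), torch.sin(rad)
    # Build a batch of 2x3 affine matrices (rotation only, zero translation).
    theta = torch.zeros(B, 2, 3, dtype=images.dtype, device=images.device)
    theta[:, 0, 0] = cos
    theta[:, 0, 1] = -sin
    theta[:, 1, 0] = sin
    theta[:, 1, 1] = cos
    # Sample the input along the rotated grid.
    grid = F.affine_grid(theta, [B, C, H, W], align_corners=False)
    return F.grid_sample(images, grid, align_corners=False)

x = torch.rand(2, 3, 8, 8)
y = rotate(x, torch.tensor([0.0, 90.0]))  # same shape as x
```

A zero-degree angle yields an identity warp, which is a convenient sanity check when wiring up a custom apply_transform.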