Source code for kornia.augmentation.auto.rand_augment.rand_augment

from typing import Dict, Iterator, List, Optional, Tuple, Union, cast

import torch
from torch.distributions import Categorical

from kornia.augmentation.auto.base import SUBPLOLICY_CONFIG, PolicyAugmentBase
from kornia.augmentation.auto.operations import OperationBase
from kornia.augmentation.auto.operations.policy import PolicySequential
from kornia.augmentation.container.params import ParamItem
from kornia.core import Module, Tensor

from . import ops

default_policy: List[SUBPLOLICY_CONFIG] = [
    [("auto_contrast", 0, 1)],
    [("equalize", 0, 1)],
    [("invert", 0, 1)],
    [("rotate", -30.0, 30.0)],
    [("posterize", 0.0, 4)],
    [("solarize", 0.0, 1.0)],
    [("solarize_add", 0.0, 0.43)],
    [("color", 0.1, 1.9)],
    [("contrast", 0.1, 1.9)],
    [("brightness", 0.1, 1.9)],
    [("sharpness", 0.1, 1.9)],
    [("shear_x", -0.3, 0.3)],
    [("shear_y", -0.3, 0.3)],
    # (CutoutAbs, 0, 40),
    [("translate_x", -0.1, 0.1)],
    [("translate_x", -0.1, 0.1)],
]


[docs]class RandAugment(PolicyAugmentBase): """Apply RandAugment :cite:`cubuk2020randaugment` augmentation strategies. Args: n: the number of augmentations to apply sequentially. m: magnitude for all the augmentations, ranged from [0, 30]. policy: candidate transformations. If None, a default candidate list will be used. transformation_matrix_mode: computation mode for the chained transformation matrix, via `.transform_matrix` attribute. If `silent`, transformation matrix will be computed silently and the non-rigid modules will be ignored as identity transformations. If `rigid`, transformation matrix will be computed silently and the non-rigid modules will trigger errors. If `skip`, transformation matrix will be totally ignored. Examples: >>> import kornia.augmentation as K >>> in_tensor = torch.rand(5, 3, 30, 30) >>> aug = K.AugmentationSequential(RandAugment(n=2, m=10)) >>> aug(in_tensor).shape torch.Size([5, 3, 30, 30]) """ def __init__( self, n: int, m: int, policy: Optional[List[SUBPLOLICY_CONFIG]] = None, transformation_matrix_mode: str = "silent", ) -> None: if m <= 0 or m >= 30: raise ValueError(f"Expect `m` in [0, 30]. Got {m}.") if policy is None: _policy = default_policy else: _policy = policy super().__init__(_policy, transformation_matrix_mode=transformation_matrix_mode) selection_weights = torch.tensor([1.0 / len(self)] * len(self)) self.rand_selector = Categorical(selection_weights) self.n = n self.m = m def compose_subpolicy_sequential(self, subpolicy: SUBPLOLICY_CONFIG) -> PolicySequential: if len(subpolicy) != 1: raise RuntimeError(f"Each policy must have only one operation for RandAugment. Got {len(subpolicy)}.") name, low, high = subpolicy[0] return PolicySequential(*[getattr(ops, name)(low, high)]) def get_forward_sequence(self, params: Optional[List[ParamItem]] = None) -> Iterator[Tuple[str, Module]]: if params is None: idx = self.rand_selector.sample((self.n,)) return self.get_children_by_indices(idx) return self.get_children_by_params(params)
[docs] def forward_parameters(self, batch_shape: torch.Size) -> List[ParamItem]: named_modules: Iterator[Tuple[str, Module]] = self.get_forward_sequence() params: List[ParamItem] = [] mod_param: Union[Dict[str, Tensor], List[ParamItem]] m = torch.tensor([self.m / 30] * batch_shape[0]) for name, module in named_modules: # The Input PolicySequential only got one child. op = cast(PolicySequential, module)[0] op = cast(OperationBase, op) mag = None if op.magnitude_range is not None: minval, maxval = op.magnitude_range mag = m * float(maxval - minval) + minval mod_param = op.forward_parameters(batch_shape, mag=mag) # Compose it param = ParamItem(name, [ParamItem(next(iter(module.named_children()))[0], mod_param)]) params.append(param) return params