from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast
import torch
import kornia.augmentation as K
from kornia.augmentation.base import _AugmentationBase
from kornia.augmentation.utils import override_parameters
from kornia.core import Module, Tensor, as_tensor
from kornia.utils import eye_like
from .base import ImageSequentialBase
from .params import ParamItem
__all__ = ["ImageSequential"]
[docs]class ImageSequential(ImageSequentialBase):
r"""Sequential for creating kornia image processing pipeline.
Args:
*args : a list of kornia augmentation and image operation modules.
same_on_batch: apply the same transformation across the batch.
If None, it will not overwrite the function-wise settings.
keepdim: whether to keep the output shape the same as input (True) or broadcast it
to the batch form (False). If None, it will not overwrite the function-wise settings.
random_apply: randomly select a sublist (order agnostic) of args to
apply transformation. The selection probability aligns to the ``random_apply_weights``.
If int, a fixed number of transformations will be selected.
If (a,), x number of transformations (a <= x <= len(args)) will be selected.
If (a, b), x number of transformations (a <= x <= b) will be selected.
If True, the whole list of args will be processed as a sequence in a random order.
If False, the whole list of args will be processed as a sequence in original order.
random_apply_weights: a list of selection weights for each operation. The length shall be as
same as the number of operations. By default, operations are sampled uniformly.
.. note::
Transformation matrix returned only considers the transformation applied in ``kornia.augmentation`` module.
Those transformations in ``kornia.geometry`` will not be taken into account.
Examples:
>>> _ = torch.manual_seed(77)
>>> import kornia
>>> input = torch.randn(2, 3, 5, 6)
>>> aug_list = ImageSequential(
... kornia.color.BgrToRgb(),
... kornia.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
... kornia.filters.MedianBlur((3, 3)),
... kornia.augmentation.RandomAffine(360, p=1.0),
... kornia.enhance.Invert(),
... kornia.augmentation.RandomMixUpV2(p=1.0),
... same_on_batch=True,
... random_apply=10,
... )
>>> out = aug_list(input)
>>> out.shape
torch.Size([2, 3, 5, 6])
Reproduce with provided params.
>>> out2 = aug_list(input, params=aug_list._params)
>>> torch.equal(out, out2)
True
Perform ``OneOf`` transformation with ``random_apply=1`` and ``random_apply_weights`` in ``ImageSequential``.
>>> import kornia
>>> input = torch.randn(2, 3, 5, 6)
>>> aug_list = ImageSequential(
... kornia.color.BgrToRgb(),
... kornia.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
... kornia.filters.MedianBlur((3, 3)),
... kornia.augmentation.RandomAffine(360, p=1.0),
... random_apply=1,
... random_apply_weights=[0.5, 0.3, 0.2, 0.5]
... )
>>> out= aug_list(input)
>>> out.shape
torch.Size([2, 3, 5, 6])
"""
def __init__(
self,
*args: Module,
same_on_batch: Optional[bool] = None,
keepdim: Optional[bool] = None,
random_apply: Union[int, bool, Tuple[int, int]] = False,
random_apply_weights: Optional[List[float]] = None,
if_unsupported_ops: str = "raise",
) -> None:
super().__init__(*args, same_on_batch=same_on_batch, keepdim=keepdim)
self.random_apply = self._read_random_apply(random_apply, len(args))
if random_apply_weights is not None and len(random_apply_weights) != len(self):
raise ValueError(
"The length of `random_apply_weights` must be as same as the number of operations."
f"Got {len(random_apply_weights)} and {len(self)}."
)
self.random_apply_weights = as_tensor(random_apply_weights or torch.ones((len(self),)))
self.if_unsupported_ops = if_unsupported_ops
def _read_random_apply(
self, random_apply: Union[int, bool, Tuple[int, int]], max_length: int
) -> Union[Tuple[int, int], bool]:
"""Process the scenarios for random apply."""
if isinstance(random_apply, (bool,)) and random_apply is False:
random_apply = False
elif isinstance(random_apply, (bool,)) and random_apply is True:
random_apply = (max_length, max_length + 1)
elif isinstance(random_apply, (int,)):
random_apply = (random_apply, random_apply + 1)
elif (
isinstance(random_apply, (tuple,))
and len(random_apply) == 2
and isinstance(random_apply[0], (int,))
and isinstance(random_apply[1], (int,))
):
random_apply = (random_apply[0], random_apply[1] + 1)
elif isinstance(random_apply, (tuple,)) and len(random_apply) == 1 and isinstance(random_apply[0], (int,)):
random_apply = (random_apply[0], max_length + 1)
else:
raise ValueError(f"Non-readable random_apply. Got {random_apply}.")
if random_apply is not False and not (
isinstance(random_apply, (tuple,))
and len(random_apply) == 2
and isinstance(random_apply[0], (int,))
and isinstance(random_apply[0], (int,))
):
raise AssertionError(f"Expect a tuple of (int, int). Got {random_apply}.")
return random_apply
def get_random_forward_sequence(self, with_mix: bool = True) -> Tuple[Iterator[Tuple[str, Module]], bool]:
"""Get a forward sequence when random apply is in need.
Args:
with_mix: if to require a mix augmentation for the sequence.
Note:
Mix augmentations (e.g. RandomMixUp) will be only applied once even in a random forward.
"""
if isinstance(self.random_apply, tuple):
num_samples = int(torch.randint(*self.random_apply, (1,)).item())
else:
raise TypeError(f"random apply should be a tuple. Gotcha {type(self.random_apply)}")
multinomial_weights = self.random_apply_weights.clone()
# Mix augmentation can only be applied once per forward
mix_indices = self.get_mix_augmentation_indices(self.named_children())
# kick out the mix augmentations
multinomial_weights[mix_indices] = 0
indices = torch.multinomial(
multinomial_weights,
num_samples,
# enable replacement if non-mix augmentation is less than required
replacement=num_samples > multinomial_weights.sum().item(),
)
mix_added = False
if with_mix and len(mix_indices) != 0:
# Make the selection fair.
if (torch.rand(1) < ((len(mix_indices) + len(indices)) / len(self))).item():
indices[-1] = torch.multinomial((~multinomial_weights.bool()).float(), 1)
indices = indices[torch.randperm(len(indices))]
mix_added = True
return self.get_children_by_indices(indices), mix_added
def get_mix_augmentation_indices(self, named_modules: Iterator[Tuple[str, Module]]) -> List[int]:
"""Get all the mix augmentations since they are label-involved.
Special operations needed for label-involved augmentations.
"""
# NOTE: MixV2 will not be a special op in the future.
return [idx for idx, (_, child) in enumerate(named_modules) if isinstance(child, K.MixAugmentationBaseV2)]
def get_forward_sequence(self, params: Optional[List[ParamItem]] = None) -> Iterator[Tuple[str, Module]]:
if params is None:
# Mix augmentation can only be applied once per forward
mix_indices = self.get_mix_augmentation_indices(self.named_children())
if self.random_apply:
return self.get_random_forward_sequence()[0]
if len(mix_indices) > 1:
raise ValueError(
"Multiple mix augmentation is prohibited without enabling random_apply."
f"Detected {len(mix_indices)} mix augmentations."
)
return self.named_children()
return self.get_children_by_params(params)
def forward_parameters(self, batch_shape: torch.Size) -> List[ParamItem]:
named_modules: Iterator[Tuple[str, Module]] = self.get_forward_sequence()
params: List[ParamItem] = []
mod_param: Union[Dict[str, Tensor], List[ParamItem]]
for name, module in named_modules:
if isinstance(module, (_AugmentationBase, K.MixAugmentationBaseV2, ImageSequentialBase)):
mod_param = module.forward_parameters(batch_shape)
param = ParamItem(name, mod_param)
else:
param = ParamItem(name, None)
batch_shape = _get_new_batch_shape(param, batch_shape)
params.append(param)
return params
def identity_matrix(self, input: Tensor) -> Tensor:
"""Return identity matrix."""
return eye_like(3, input)
def get_transformation_matrix(
self,
input: Tensor,
params: Optional[List[ParamItem]] = None,
recompute: bool = False,
extra_args: Dict[str, Any] = {},
) -> Optional[Tensor]:
"""Compute the transformation matrix according to the provided parameters.
Args:
input: the input tensor.
params: params for the sequence.
recompute: if to recompute the transformation matrix according to the params.
default: False.
"""
if params is None:
raise NotImplementedError("requires params to be provided.")
named_modules: Iterator[Tuple[str, Module]] = self.get_forward_sequence(params)
# Define as 1 for broadcasting
res_mat: Optional[Tensor] = None
for (_, module), param in zip(named_modules, params if params is not None else []):
if isinstance(module, (K.GeometricAugmentationBase2D,)) and isinstance(param.data, dict):
ori_shape = input.shape
try:
input = module.transform_tensor(input)
except ValueError:
# Ignore error for 5-dim video
pass
# Standardize shape
if recompute:
flags = override_parameters(module.flags, extra_args, in_place=False)
mat = module.generate_transformation_matrix(input, param.data, flags)
elif module._transform_matrix is not None:
mat = as_tensor(module._transform_matrix, device=input.device, dtype=input.dtype)
else:
raise RuntimeError(f"{module}._transform_matrix is None while `recompute=False`.")
res_mat = mat if res_mat is None else mat @ res_mat
input = module.transform_output_tensor(input, ori_shape)
if module.keepdim and ori_shape != input.shape:
res_mat = res_mat.squeeze()
elif isinstance(module, (ImageSequentialBase,)):
# If not augmentationSequential
if isinstance(module, (K.AugmentationSequential,)) and not recompute:
mat = as_tensor(module._transform_matrix, device=input.device, dtype=input.dtype)
else:
maybe_param_data = cast(Optional[List[ParamItem]], param.data)
_mat = module.get_transformation_matrix(
input, maybe_param_data, recompute=recompute, extra_args=extra_args
)
mat = module.identity_matrix(input) if _mat is None else _mat
res_mat = mat if res_mat is None else mat @ res_mat
return res_mat
# TODO: Make this as a class property to avoid running every time.
def is_intensity_only(self, strict: bool = True) -> bool:
"""Check if all transformations are intensity-based.
Args:
strict: if strict is False, it will allow non-augmentation Modules to be passed.
e.g. `kornia.enhance.AdjustBrightness` will be recognized as non-intensity module
if strict is set to True.
Note: patch processing would break the continuity of labels (e.g. bbounding boxes, masks).
"""
for arg in self.children():
if isinstance(arg, (ImageSequential,)) and not arg.is_intensity_only(strict):
return False
elif isinstance(arg, (ImageSequential,)):
pass
elif isinstance(arg, K.IntensityAugmentationBase2D):
pass
elif strict:
# disallow non-registered ops if in strict mode
# TODO: add an ops register module
return False
return True
def _get_new_batch_shape(param: ParamItem, batch_shape: torch.Size) -> torch.Size:
"""Get the new batch shape if the augmentation changes the image size.
Note:
Augmentations that change the image size must provide the parameter `output_size`.
"""
if param.data is None:
return batch_shape
if isinstance(param.data, list):
for p in param.data:
batch_shape = _get_new_batch_shape(p, batch_shape)
elif "output_size" in param.data:
if not (param.data["batch_prob"] > 0.5)[0]:
# Augmentations that change the image size must be applied equally to all elements in batch.
# If the augmentation is not applied, return the same batch shape.
return batch_shape
new_batch_shape = list(batch_shape)
new_batch_shape[-2:] = param.data["output_size"][0]
batch_shape = torch.Size(new_batch_shape)
return batch_shape