# Source code for kornia.augmentation.container.augment

import warnings
from typing import Any, Dict, List, Optional, Tuple, Union, cast

from kornia.augmentation._2d.base import RigidAffineAugmentationBase2D
from kornia.augmentation._3d.base import AugmentationBase3D, RigidAffineAugmentationBase3D
from kornia.augmentation.base import _AugmentationBase
from kornia.constants import DataKey, Resample
from kornia.core import Module, Tensor
from kornia.geometry.boxes import Boxes, VideoBoxes
from kornia.geometry.keypoints import Keypoints, VideoKeypoints
from kornia.utils import eye_like, is_autocast_enabled

from .base import TransformMatrixMinIn
from .image import ImageSequential
from .ops import AugmentationSequentialOps, DataType
from .params import ParamItem
from .patch import PatchSequential
from .video import VideoSequential

# Names exported by ``from ... import *``.
__all__ = ["AugmentationSequential"]

# Data keys routed to the bounding-box pre/post-processing handlers.
_BOXES_OPTIONS = {DataKey.BBOX, DataKey.BBOX_XYXY, DataKey.BBOX_XYWH}
# Data keys routed to the keypoint pre/post-processing handlers.
_KEYPOINTS_OPTIONS = {DataKey.KEYPOINTS}
# Data keys whose arguments are passed through pre/post-processing unchanged.
_IMG_MSK_OPTIONS = {DataKey.INPUT, DataKey.MASK}


class AugmentationSequential(TransformMatrixMinIn, ImageSequential):
    r"""AugmentationSequential for handling multiple input types like inputs, masks, keypoints at once.

    .. image:: _static/img/AugmentationSequential.png

    Args:
        *args: a list of kornia augmentation modules.

        data_keys: the input type sequential for applying augmentations. Accepts "input", "image", "mask",
            "bbox", "bbox_xyxy", "bbox_xywh", "keypoints".

        same_on_batch: apply the same transformation across the batch. If None, it will not overwrite the
            function-wise settings.

        keepdim: whether to keep the output shape the same as input (True) or broadcast it to the batch form
            (False). If None, it will not overwrite the function-wise settings.

        random_apply: randomly select a sublist (order agnostic) of args to apply transformation.
            If int, a fixed number of transformations will be selected.
            If (a,), x number of transformations (a <= x <= len(args)) will be selected.
            If (a, b), x number of transformations (a <= x <= b) will be selected.
            If True, the whole list of args will be processed as a sequence in a random order.
            If False, the whole list of args will be processed as a sequence in original order.

        random_apply_weights: forwarded unchanged to the parent sequential container. Used together with
            ``random_apply`` (see the ``OneOf`` example below).

        transformation_matrix_mode: computation mode for the chained transformation matrix, via
            `.transform_matrix` attribute.
            If `silent`, transformation matrix will be computed silently and the non-rigid modules will be
            ignored as identity transformations.
            If `rigid`, transformation matrix will be computed silently and the non-rigid modules will trigger
            errors.
            If `skip`, transformation matrix will be totally ignored.

        extra_args: to control the behaviour for each datakeys. By default, masks are handled by nearest
            interpolation strategies. The mapping is copied on construction, so later changes to the dict
            passed in do not affect this instance.

    .. note::
        Mix augmentations (e.g. RandomMixUp, RandomCutMix) can only be working with "input"/"image" data key.
        It is not clear how to deal with the conversions of masks, bounding boxes and keypoints.

    .. note::
        See a working example `here <https://kornia.github.io/tutorials/nbs/data_augmentation_sequential.html>`__.

    Examples:
        >>> import kornia
        >>> input = torch.randn(2, 3, 5, 6)
        >>> mask = torch.ones(2, 3, 5, 6)
        >>> bbox = torch.tensor([[
        ...     [1., 1.],
        ...     [2., 1.],
        ...     [2., 2.],
        ...     [1., 2.],
        ... ]]).expand(2, 1, -1, -1)
        >>> points = torch.tensor([[[1., 1.]]]).expand(2, -1, -1)
        >>> aug_list = AugmentationSequential(
        ...     kornia.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
        ...     kornia.augmentation.RandomAffine(360, p=1.0),
        ...     data_keys=["input", "mask", "bbox", "keypoints"],
        ...     same_on_batch=False,
        ...     random_apply=10,
        ... )
        >>> out = aug_list(input, mask, bbox, points)
        >>> [o.shape for o in out]
        [torch.Size([2, 3, 5, 6]), torch.Size([2, 3, 5, 6]), torch.Size([2, 1, 4, 2]), torch.Size([2, 1, 2])]
        >>> # apply the exact augmentation again.
        >>> out_rep = aug_list(input, mask, bbox, points, params=aug_list._params)
        >>> [(o == o_rep).all() for o, o_rep in zip(out, out_rep)]
        [tensor(True), tensor(True), tensor(True), tensor(True)]
        >>> # inverse the augmentations
        >>> out_inv = aug_list.inverse(*out)
        >>> [o.shape for o in out_inv]
        [torch.Size([2, 3, 5, 6]), torch.Size([2, 3, 5, 6]), torch.Size([2, 1, 4, 2]), torch.Size([2, 1, 2])]

    This example demonstrates the integration of VideoSequential and AugmentationSequential.

        >>> import kornia
        >>> input = torch.randn(2, 3, 5, 6)[None]
        >>> mask = torch.ones(2, 3, 5, 6)[None]
        >>> bbox = torch.tensor([[
        ...     [1., 1.],
        ...     [2., 1.],
        ...     [2., 2.],
        ...     [1., 2.],
        ... ]]).expand(2, 1, -1, -1)[None]
        >>> points = torch.tensor([[[1., 1.]]]).expand(2, -1, -1)[None]
        >>> aug_list = AugmentationSequential(
        ...     VideoSequential(
        ...         kornia.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
        ...         kornia.augmentation.RandomAffine(360, p=1.0),
        ...     ),
        ...     data_keys=["input", "mask", "bbox", "keypoints"]
        ... )
        >>> out = aug_list(input, mask, bbox, points)
        >>> [o.shape for o in out]  # doctest: +ELLIPSIS
        [torch.Size([1, 2, 3, 5, 6]), torch.Size([1, 2, 3, 5, 6]), ...([1, 2, 1, 4, 2]), torch.Size([1, 2, 1, 2])]

    Perform ``OneOf`` transformation with ``random_apply=1`` and ``random_apply_weights``
    in ``AugmentationSequential``.

        >>> import kornia
        >>> input = torch.randn(2, 3, 5, 6)[None]
        >>> mask = torch.ones(2, 3, 5, 6)[None]
        >>> bbox = torch.tensor([[
        ...     [1., 1.],
        ...     [2., 1.],
        ...     [2., 2.],
        ...     [1., 2.],
        ... ]]).expand(2, 1, -1, -1)[None]
        >>> points = torch.tensor([[[1., 1.]]]).expand(2, -1, -1)[None]
        >>> aug_list = AugmentationSequential(
        ...     VideoSequential(
        ...         kornia.augmentation.RandomAffine(360, p=1.0),
        ...     ),
        ...     VideoSequential(
        ...         kornia.augmentation.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
        ...     ),
        ...     data_keys=["input", "mask", "bbox", "keypoints"],
        ...     random_apply=1,
        ...     random_apply_weights=[0.5, 0.3]
        ... )
        >>> out = aug_list(input, mask, bbox, points)
        >>> [o.shape for o in out]  # doctest: +ELLIPSIS
        [torch.Size([1, 2, 3, 5, 6]), torch.Size([1, 2, 3, 5, 6]), ...([1, 2, 1, 4, 2]), torch.Size([1, 2, 1, 2])]
    """

    # NOTE: the mutable defaults below are kept so the public signature stays unchanged.
    # They are safe because `data_keys` is only read and `extra_args` is copied before being stored.
    def __init__(
        self,
        *args: Union[_AugmentationBase, ImageSequential],
        data_keys: Union[List[str], List[int], List[DataKey]] = [DataKey.INPUT],
        same_on_batch: Optional[bool] = None,
        keepdim: Optional[bool] = None,
        random_apply: Union[int, bool, Tuple[int, int]] = False,
        random_apply_weights: Optional[List[float]] = None,
        transformation_matrix_mode: str = "silent",
        extra_args: Dict[DataKey, Dict[str, Any]] = {
            DataKey.MASK: {"resample": Resample.NEAREST, "align_corners": None}
        },
    ) -> None:
        self._transform_matrix: Optional[Tensor]
        self._transform_matrices: List[Optional[Tensor]] = []

        super().__init__(
            *args,
            same_on_batch=same_on_batch,
            keepdim=keepdim,
            random_apply=random_apply,
            random_apply_weights=random_apply_weights,
        )

        self._parse_transformation_matrix_mode(transformation_matrix_mode)

        # Module types whose `transform_matrix` is collected into the chained matrix.
        self._valid_ops_for_transform_computation: Tuple[Any, ...] = (
            RigidAffineAugmentationBase2D,
            RigidAffineAugmentationBase3D,
            AugmentationSequential,
        )

        self.data_keys = [DataKey.get(inp) for inp in data_keys]

        if not all(in_type in DataKey for in_type in self.data_keys):
            raise AssertionError(f"`data_keys` must be in {DataKey}. Got {self.data_keys}.")

        if self.data_keys[0] != DataKey.INPUT:
            raise NotImplementedError(f"The first input must be {DataKey.INPUT}.")

        self.transform_op = AugmentationSequentialOps(self.data_keys)

        self.contains_video_sequential: bool = False
        self.contains_3d_augmentation: bool = False
        for arg in args:
            if isinstance(arg, PatchSequential) and not arg.is_intensity_only():
                warnings.warn(
                    "Geometric transformation detected in PatchSequential, which would break bbox, mask.",
                    stacklevel=2,  # attribute the warning to the code constructing this object
                )
            if isinstance(arg, VideoSequential):
                self.contains_video_sequential = True
            # NOTE: only for images are supported for 3D.
            if isinstance(arg, AugmentationBase3D):
                self.contains_3d_augmentation = True

        self._transform_matrix = None
        # Store a private copy: the default dict is shared by every call of `__init__`, so keeping a
        # reference to it would let one instance's modifications leak into all the others.
        self.extra_args = {key: dict(options) for key, options in extra_args.items()}

    def clear_state(self) -> None:
        """Reset the accumulated transformation-matrix state, then the parent container's state."""
        self._reset_transform_matrix_state()
        return super().clear_state()

    def _update_transform_matrix_for_valid_op(self, module: Module) -> None:
        """Record ``module.transform_matrix`` in the list of per-module matrices."""
        self._transform_matrices.append(module.transform_matrix)

    def identity_matrix(self, input: Tensor) -> Tensor:
        """Return identity matrix.

        ``eye_like`` is called with size 4 when a 3D augmentation is registered and with size 3 otherwise.
        """
        if self.contains_3d_augmentation:
            return eye_like(4, input)
        return eye_like(3, input)

    def inverse(  # type: ignore[override]
        self,
        *args: DataType,
        params: Optional[List[ParamItem]] = None,
        data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None,
    ) -> Union[DataType, List[DataType]]:
        """Reverse the transformation applied.

        Number of input tensors must align with the number of ``data_keys``. If ``data_keys`` is not set, use
        ``self.data_keys`` by default.

        Args:
            *args: one argument per data key, in the same order as the data keys.
            params: parameters to invert, applied in reverse order. If None, the parameters stored by the
                last forward pass (``self._params``) are used.
            data_keys: optional per-call override of the data keys.

        Returns:
            The single processed item if exactly one was produced, otherwise the list of processed items.

        Raises:
            ValueError: if ``params`` is None and no forward pass has stored parameters yet.
            AssertionError: if the number of inputs differs from the number of data keys.
        """
        self.transform_op.data_keys = self.transform_op.preproc_datakeys(data_keys)
        try:
            self._validate_args_datakeys(*args, data_keys=self.transform_op.data_keys)
            in_args = self._arguments_preproc(*args, data_keys=self.transform_op.data_keys)

            if params is None:
                if self._params is None:
                    raise ValueError(
                        "No parameters available for inversing, please run a forward pass first "
                        "or passing valid params into this function."
                    )
                params = self._params

            outputs: List[DataType] = in_args
            for param in params[::-1]:
                module = self.get_submodule(param.name)
                outputs = self.transform_op.inverse(  # type: ignore
                    *outputs, module=module, param=param, extra_args=self.extra_args
                )
                if not isinstance(outputs, (list, tuple)):
                    # Make sure we are unpacking a list whilst post-proc
                    outputs = [outputs]

            outputs = self._arguments_postproc(args, outputs, data_keys=self.transform_op.data_keys)  # type: ignore
        finally:
            # Restore it back, exactly as `forward` does, so a per-call `data_keys` override does not
            # stay in effect for later calls (also when this call fails part-way).
            self.transform_op.data_keys = self.data_keys

        if len(outputs) == 1 and isinstance(outputs, list):
            return outputs[0]
        return outputs

    def _validate_args_datakeys(self, *args: DataType, data_keys: List[DataKey]) -> None:
        """Check that exactly one argument was supplied per data key.

        Raises:
            AssertionError: if the two counts differ.
        """
        if len(args) != len(data_keys):
            raise AssertionError(
                f"The number of inputs must align with the number of data_keys. Got {len(args)} and {len(data_keys)}."
            )
        # TODO: validate args batching, and its consistency

    def _arguments_preproc(self, *args: DataType, data_keys: List[DataKey]) -> List[DataType]:
        """Convert each raw argument into the representation used while applying the modules.

        Image, mask and class arguments are passed through unchanged; keypoints and boxes go through
        ``_preproc_keypoints`` / ``_preproc_boxes``.

        Raises:
            NotImplementedError: if a data key has no handler.
        """
        inp: List[DataType] = []
        for arg, dcate in zip(args, data_keys):
            key = DataKey.get(dcate)
            if key in _IMG_MSK_OPTIONS:
                inp.append(arg)
            elif key in _KEYPOINTS_OPTIONS:
                inp.append(self._preproc_keypoints(arg, dcate))
            elif key in _BOXES_OPTIONS:
                inp.append(self._preproc_boxes(arg, dcate))
            elif key is DataKey.CLASS:
                inp.append(arg)
            else:
                raise NotImplementedError(f"input type of {dcate} is not implemented.")
        return inp

    def _arguments_postproc(
        self, in_args: List[DataType], out_args: List[DataType], data_keys: List[DataKey]
    ) -> List[DataType]:
        """Convert each processed argument back to the form of the matching original input.

        Args:
            in_args: the arguments as originally supplied by the caller.
            out_args: the processed arguments, in the same order.
            data_keys: the data key of each argument.

        Raises:
            NotImplementedError: if a data key has no handler.
        """
        out: List[DataType] = []
        for in_arg, out_arg, dcate in zip(in_args, out_args, data_keys):
            key = DataKey.get(dcate)
            if key in _IMG_MSK_OPTIONS:
                # It is tensor type already.
                out.append(out_arg)
                # TODO: may add the float to integer (for masks), etc.

            elif key in _KEYPOINTS_OPTIONS:
                _out_k = self._postproc_keypoint(in_arg, cast(Keypoints, out_arg), dcate)
                # Under autocast, cast the result back to the dtype of the original input.
                if is_autocast_enabled() and isinstance(in_arg, (Tensor, Keypoints)):
                    if isinstance(_out_k, list):
                        _out_k = [i.type(in_arg.dtype) for i in _out_k]
                    else:
                        _out_k = _out_k.type(in_arg.dtype)
                out.append(_out_k)

            elif key in _BOXES_OPTIONS:
                _out_b = self._postproc_boxes(in_arg, cast(Boxes, out_arg), dcate)
                # Under autocast, cast the result back to the dtype of the original input.
                if is_autocast_enabled() and isinstance(in_arg, (Tensor, Boxes)):
                    if isinstance(_out_b, list):
                        _out_b = [i.type(in_arg.dtype) for i in _out_b]
                    else:
                        _out_b = _out_b.type(in_arg.dtype)
                out.append(_out_b)

            elif key is DataKey.CLASS:
                out.append(out_arg)

            else:
                raise NotImplementedError(f"input type of {dcate} is not implemented.")

        return out

    def forward(  # type: ignore[override]
        self,
        *args: DataType,
        params: Optional[List[ParamItem]] = None,
        data_keys: Optional[Union[List[str], List[int], List[DataKey]]] = None,
    ) -> Union[DataType, List[DataType]]:
        """Compute multiple tensors simultaneously according to ``self.data_keys``.

        Args:
            *args: one argument per data key, in the same order as the data keys.
            params: parameters to apply. If None, they are generated from the shape of the ``INPUT``
                argument.
            data_keys: optional per-call override of the data keys.

        Returns:
            The single processed item if exactly one was produced, otherwise the list of processed items.
            The parameters that were applied are stored in ``self._params``.

        Raises:
            ValueError: if ``params`` is None and there is no ``INPUT`` data key, or the ``INPUT``
                argument is not a tensor.
            AssertionError: if the number of inputs differs from the number of data keys.
        """
        self.clear_state()

        self.transform_op.data_keys = self.transform_op.preproc_datakeys(data_keys)
        try:
            self._validate_args_datakeys(*args, data_keys=self.transform_op.data_keys)
            in_args = self._arguments_preproc(*args, data_keys=self.transform_op.data_keys)

            if params is None:
                # image data must exist if params is not provided.
                if DataKey.INPUT in self.transform_op.data_keys:
                    inp = in_args[self.transform_op.data_keys.index(DataKey.INPUT)]
                    if not isinstance(inp, (Tensor,)):
                        raise ValueError(f"`INPUT` should be a tensor but `{type(inp)}` received.")
                    # A video input shall be BCDHW while an image input shall be BCHW
                    if self.contains_video_sequential or self.contains_3d_augmentation:
                        _, out_shape = self.autofill_dim(inp, dim_range=(3, 5))
                    else:
                        _, out_shape = self.autofill_dim(inp, dim_range=(2, 4))
                    params = self.forward_parameters(out_shape)
                else:
                    raise ValueError("`params` must be provided whilst INPUT is not in data_keys.")

            outputs: Union[Tensor, List[DataType]] = in_args
            for param in params:
                module = self.get_submodule(param.name)
                outputs = self.transform_op.transform(  # type: ignore
                    *outputs, module=module, param=param, extra_args=self.extra_args
                )
                if not isinstance(outputs, (list, tuple)):
                    # Make sure we are unpacking a list whilst post-proc
                    outputs = [outputs]
                self._update_transform_matrix_by_module(module)

            outputs = self._arguments_postproc(args, outputs, data_keys=self.transform_op.data_keys)  # type: ignore
        finally:
            # Restore it back (in a `finally` so a failing call does not leave the override in place).
            self.transform_op.data_keys = self.data_keys

        self._params = params

        if len(outputs) == 1 and isinstance(outputs, list):
            return outputs[0]
        return outputs

    @staticmethod
    def _resolve_boxes_mode(dcate: DataKey) -> str:
        """Map a bounding-box data key to the box ``mode`` string used for tensor conversion.

        Raises:
            ValueError: if ``dcate`` is not one of the bounding-box data keys.
        """
        key = DataKey.get(dcate)
        if key == DataKey.BBOX:
            return "vertices_plus"
        if key == DataKey.BBOX_XYXY:
            return "xyxy_plus"
        if key == DataKey.BBOX_XYWH:
            return "xywh"
        raise ValueError(f"Unsupported mode `{key.name}`.")

    def _preproc_boxes(self, arg: DataType, dcate: DataKey) -> Boxes:
        """Wrap a bounding-box argument into a ``Boxes`` object (``VideoBoxes`` for video pipelines).

        An argument that already is a ``Boxes`` instance is returned as is.

        Raises:
            ValueError: if ``dcate`` is not a bounding-box data key.
            NotImplementedError: if a 3D augmentation is registered and ``arg`` is not a ``Boxes``.
        """
        # Resolved first so an unsupported key is rejected even for ready-made `Boxes` inputs.
        mode = self._resolve_boxes_mode(dcate)
        if isinstance(arg, (Boxes,)):
            return arg
        if self.contains_video_sequential:
            arg = cast(Tensor, arg)
            # NOTE: `mode` is not passed on this path.
            return VideoBoxes.from_tensor(arg)
        if self.contains_3d_augmentation:
            raise NotImplementedError("3D box handlers are not yet supported.")
        arg = cast(Tensor, arg)
        return Boxes.from_tensor(arg, mode=mode)

    def _postproc_boxes(self, in_arg: DataType, out_arg: Boxes, dcate: DataKey) -> Union[Tensor, List[Tensor], Boxes]:
        """Return processed boxes in the form of the original input.

        If the caller supplied a ``Boxes`` object, ``out_arg`` is returned unchanged; otherwise it is
        converted with ``to_tensor`` using the mode that matches ``dcate``.

        Raises:
            ValueError: if ``dcate`` is not a bounding-box data key.
        """
        mode = self._resolve_boxes_mode(dcate)

        # TODO: handle 3d scenarios
        if isinstance(in_arg, (Boxes,)):
            return out_arg
        return out_arg.to_tensor(mode=mode)

    def _preproc_keypoints(self, arg: DataType, dcate: DataKey) -> Keypoints:
        """Wrap a keypoint argument into a ``Keypoints`` object (``VideoKeypoints`` for video pipelines).

        Raises:
            NotImplementedError: if a 3D augmentation is registered (and no video sequential is).
        """
        # NOTE(review): unlike `_preproc_boxes`, the video/3D branches are checked before the
        # `Keypoints` pass-through, so a ready-made `Keypoints` object takes those branches first.
        # Confirm whether that ordering is intended.
        if self.contains_video_sequential:
            arg = cast(Union[Tensor, List[Tensor]], arg)
            return VideoKeypoints.from_tensor(arg)
        if self.contains_3d_augmentation:
            raise NotImplementedError("3D keypoint handlers are not yet supported.")
        if isinstance(arg, (Keypoints,)):
            return arg
        arg = cast(Tensor, arg)
        # TODO: Add List[Tensor] in the future.
        return Keypoints.from_tensor(arg)

    def _postproc_keypoint(
        self, in_arg: DataType, out_arg: Keypoints, dcate: DataKey
    ) -> Union[Tensor, List[Tensor], Keypoints]:
        """Return processed keypoints in the form of the original input.

        If the caller supplied a ``Keypoints`` object, ``out_arg`` is returned unchanged; otherwise it
        is converted with ``to_tensor``. ``dcate`` is accepted but not used.
        """
        if isinstance(in_arg, (Keypoints,)):
            return out_arg
        return out_arg.to_tensor()