from typing import Any, Dict, List, Optional, Tuple, Union

import torch

from kornia.augmentation import random_generator as rg
from kornia.augmentation._2d.mix.base import MixAugmentationBaseV2
from kornia.constants import DataKey, Resample
from kornia.core import Tensor, as_tensor, concatenate, pad, zeros
from kornia.geometry.boxes import Boxes
from kornia.geometry.transform import crop_by_indices, crop_by_transform_mat, get_perspective_transform
from kornia.testing import KORNIA_UNWRAP
from kornia.utils import eye_like

__all__ = ["RandomMosaic"]

class RandomMosaic(MixAugmentationBaseV2):
    r"""Mosaic augmentation.

    Given a certain number of images, mosaic transform combines them into one output image.
    The output image is composed of the parts from each sub-image. To mess up each image
    individually, refer to :class:`kornia.augmentation.RandomJigsaw`.

    The mosaic transform steps are as follows:

        1. Concatenate the selected images into a super-image.
        2. Crop out the outcome image according to the top-left corner and crop size.

    Args:
        output_size: the output tensor width and height after mosaicing.
        mosaic_grid: the number of images and image arrangement. e.g. (2, 2) means
            each output will mix 4 images in a 2x2 grid.
        start_ratio_range: top-left (x, y) position for cropping the mosaic images.
        min_bbox_size: minimum area of bounding boxes. Default to 0.
        data_keys: the input type sequential for applying augmentations.
            Accepts "input", "mask", "bbox", "bbox_xyxy", "bbox_xywh", "keypoints".
        p: probability of applying the transformation. It is forwarded to the base class
            as ``p``; the batch-level probability ``p_batch`` is fixed to 1.0.
        keepdim: whether to keep the output shape the same as input ``True`` or broadcast it
            to the batch form ``False``.
        padding_mode: Type of padding. Should be: constant, reflect, replicate.
        resample: the interpolation mode.
        align_corners: interpolation flag.
        cropping_mode: The used algorithm to crop. ``slice`` will use advanced slicing to
            extract the tensor based on the sampled indices. ``resample`` will use
            `warp_affine` using the affine transformation to extract and resize at once.
            Use `slice` for efficiency, or `resample` for proper differentiability.

    Examples:
        >>> mosaic = RandomMosaic((300, 300), data_keys=["input", "bbox_xyxy"])
        >>> boxes = torch.tensor([[
        ...     [70, 5, 150, 100],
        ...     [60, 180, 175, 220],
        ... ]]).repeat(8, 1, 1)
        >>> input = torch.randn(8, 3, 224, 224)
        >>> out = mosaic(input, boxes)
        >>> out[0].shape, out[1].shape
        (torch.Size([8, 3, 300, 300]), torch.Size([8, 8, 4]))
    """

    def __init__(
        self,
        output_size: Optional[Tuple[int, int]] = None,
        mosaic_grid: Tuple[int, int] = (2, 2),
        start_ratio_range: Tuple[float, float] = (0.3, 0.7),
        min_bbox_size: float = 0.0,
        # NOTE(review): shared mutable default kept for interface compatibility; it is only
        # forwarded to the base class here -- confirm the base class never mutates it.
        data_keys: List[Union[str, int, DataKey]] = [DataKey.INPUT],
        p: float = 0.7,
        keepdim: bool = False,
        padding_mode: str = "constant",
        resample: Union[str, int, Resample] = Resample.BILINEAR.name,
        align_corners: bool = True,
        cropping_mode: str = "slice",
    ) -> None:
        super().__init__(p=p, p_batch=1.0, same_on_batch=False, keepdim=keepdim, data_keys=data_keys)
        self.start_ratio_range = start_ratio_range
        self._param_generator = rg.MosaicGenerator(output_size, mosaic_grid, start_ratio_range)
        self.flags = dict(
            mosaic_grid=mosaic_grid,
            output_size=output_size,
            min_bbox_size=min_bbox_size,
            padding_mode=padding_mode,
            # Normalise str / int / enum inputs into a ``Resample`` member once, up front.
            resample=Resample.get(resample),
            align_corners=align_corners,
            cropping_mode=cropping_mode,
        )

    def apply_transform_mask(self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any]) -> Tensor:
        """Mask inputs are not supported by this augmentation.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    @torch.no_grad()
    def apply_transform_boxes(self, input: Boxes, params: Dict[str, Tensor], flags: Dict[str, Any]) -> Boxes:
        """Move the boxes of every mosaic cell into the cropped output frame and merge them.

        Args:
            input: the boxes of the batch.
            params: sampled parameters; reads ``src``, ``dst``, ``batch_shapes``,
                ``batch_prob`` and ``permutation``.
            flags: static flags; reads ``mosaic_grid`` and ``min_bbox_size``.

        Returns:
            The merged boxes, clamped to the output extent and filtered by ``min_bbox_size``.
        """
        src_box = as_tensor(params["src"], device=input.device, dtype=input.dtype)
        dst_box = as_tensor(params["dst"], device=input.device, dtype=input.dtype)
        # Boxes is BxNx4x2 only.
        batch_shapes = as_tensor(params["batch_shapes"], device=input.device, dtype=input.dtype)

        batch_prob = params["batch_prob"]
        # ``batch_prob`` is used below as a per-sample boolean mask, so its length is the
        # batch size; the ``offset`` buffer and the clamp bounds are all sized with it.
        batch_size = len(batch_prob)
        offset = zeros((batch_size, 2), device=input.device, dtype=input.dtype)  # Bx2
        # NOTE: the clamp upper bound is taken from the first sample's ``dst`` box (its third
        # corner) and repeated for every sample.
        offset_end = dst_box[0, 2].repeat(batch_size, 1)
        # Indices of the samples selected by ``batch_prob``.
        idx = torch.arange(0, batch_size, device=input.device, dtype=torch.long)[batch_prob]

        grid_h, grid_w = flags['mosaic_grid'][0], flags['mosaic_grid'][1]
        maybe_out_boxes: Optional[Boxes] = None
        for i in range(grid_h):
            for j in range(grid_w):
                # Shift of cell (i, j) inside the super-image, relative to the crop origin.
                _offset = offset.clone()
                _offset[idx, 0] = batch_shapes[:, -2] * i - src_box[:, 0, 0]
                _offset[idx, 1] = batch_shapes[:, -1] * j - src_box[:, 0, 1]

                _box = input.clone()
                _idx = i * grid_w + j
                # Bring in the boxes of the images that were permuted into this cell.
                _box._data[params["permutation"][:, 0]] = _box._data[params["permutation"][:, _idx]]
                _box.translate(_offset, inplace=True)
                # zero-out unrelated batch elements.
                _box._data[~batch_prob] = 0
                if maybe_out_boxes is None:
                    # The first cell seeds the result; unselected samples keep their input boxes.
                    _box._data[~batch_prob] = input._data[~batch_prob]
                    maybe_out_boxes = _box
                else:
                    KORNIA_UNWRAP(maybe_out_boxes, Boxes).merge(_box, inplace=True)

        out_boxes: Boxes = KORNIA_UNWRAP(maybe_out_boxes, Boxes)
        out_boxes.clamp(offset, offset_end, inplace=True)
        out_boxes.filter_boxes_by_area(flags["min_bbox_size"], inplace=True)
        return out_boxes

    def apply_transform_keypoint(self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any]) -> Tensor:
        """Keypoint inputs are not supported by this augmentation.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    def apply_transform_class(self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any]) -> Tensor:
        """Class/tag inputs are not supported by this augmentation.

        Raises:
            RuntimeError: always.
        """
        raise RuntimeError(f"{self.__class__.__name__} does not support `TAG` types.")

    @torch.no_grad()
    def _compose_images(self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any]) -> Tensor:
        """Concatenate the permuted images into one super-image.

        For each ``i`` in ``mosaic_grid[0]`` the images of that group are concatenated along
        dim ``-2``; the resulting groups are then concatenated along dim ``-1``.
        """
        grid_h, grid_w = flags['mosaic_grid'][0], flags['mosaic_grid'][1]
        out = []
        for i in range(grid_h):
            out_row = []
            for j in range(grid_w):
                img_idx = grid_w * i + j
                # Column ``img_idx`` of the permutation picks, per sample, which image of the
                # batch fills this cell.
                image = input[params["permutation"][:, img_idx]]
                out_row.append(image)
            out.append(concatenate(out_row, -2))
        return concatenate(out, -1)

    def compute_transformation(self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any]) -> Tensor:
        """Return the crop transform for the configured ``cropping_mode``.

        Returns:
            The perspective transform mapping ``params["src"]`` to ``params["dst"]`` for
            ``"resample"``, or ``eye_like(3, input)`` for ``"slice"``.

        Raises:
            NotImplementedError: if ``cropping_mode`` is neither ``"resample"`` nor ``"slice"``.
        """
        if flags["cropping_mode"] == "resample":
            transform: Tensor = get_perspective_transform(params["src"].to(input), params["dst"].to(input))
            return transform
        if flags["cropping_mode"] == "slice":  # Skip the computation for slicing.
            return eye_like(3, input)
        raise NotImplementedError(f"Not supported type: {flags['cropping_mode']}.")

    def _crop_images(
        self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any], transform: Optional[Tensor] = None
    ) -> Tensor:
        """Crop the composed super-image down to ``output_size``.

        Args:
            input: the composed super-image.
            params: sampled parameters; ``src`` is read in ``"slice"`` mode.
            flags: static flags; falls back to ``self.flags`` when ``None``.
            transform: the crop transform, required in ``"resample"`` mode.

        Raises:
            TypeError: in ``"resample"`` mode when ``transform`` is not a tensor.
            NotImplementedError: if ``cropping_mode`` is neither ``"resample"`` nor ``"slice"``.
        """
        flags = self.flags if flags is None else flags
        if flags["cropping_mode"] == "resample":  # uses bilinear interpolation to crop
            if not isinstance(transform, Tensor):
                raise TypeError(f'Expected the transform to be a Tensor. Gotcha {type(transform)}')
            # Translate the pad-style mode names accepted by this class into the names passed
            # to ``crop_by_transform_mat``; any other value is forwarded unchanged.
            if flags['padding_mode'] == "constant":
                padding_mode = "zeros"
            elif flags['padding_mode'] == "replicate":
                padding_mode = "border"
            elif flags['padding_mode'] == "reflect":
                padding_mode = "reflection"
            else:
                padding_mode = flags['padding_mode']

            return crop_by_transform_mat(
                input,
                transform,
                flags["output_size"],
                mode=flags["resample"].name.lower(),
                padding_mode=padding_mode,
                align_corners=flags["align_corners"],
            )
        if flags["cropping_mode"] == "slice":  # uses advanced slicing to crop
            return crop_by_indices(input, params["src"], flags["output_size"], shape_compensation="pad")
        raise NotImplementedError(f"Not supported type: {flags['cropping_mode']}.")

    def apply_non_transform(
        self, input: Tensor, params: Dict[str, Tensor], flags: Optional[Dict[str, Any]] = None
    ) -> Tensor:
        """Pad the input at the right/bottom up to ``output_size`` when one is configured.

        Returns:
            The padded input, or the input unchanged when ``flags`` or ``output_size`` is ``None``.
        """
        if flags is not None and flags["output_size"] is not None:
            output_size = KORNIA_UNWRAP(flags["output_size"], Tuple[int, int])
            # Pad spec is [left, right, top, bottom] over the last two dims.
            return pad(input, [0, output_size[1] - input.shape[-1], 0, output_size[0] - input.shape[-2]])
        # NOTE: resize is not suitable for being consistent with bounding boxes, so the
        # input is padded (above) rather than resized to ``output_size``.
        return input

    def apply_transform(
        self, input: Tensor, params: Dict[str, Tensor], maybe_flags: Optional[Dict[str, Any]] = None
    ) -> Tensor:
        """Compose the super-image, compute the crop transform and crop to the output."""
        flags = KORNIA_UNWRAP(maybe_flags, Dict[str, Any])
        output = self._compose_images(input, params, flags=flags)
        transform = self.compute_transformation(output, params, flags=flags)
        output = self._crop_images(output, params, flags=flags, transform=transform)
        return output