from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from kornia.augmentation import random_generator as rg
from kornia.augmentation._2d.geometric.base import GeometricAugmentationBase2D
from kornia.constants import Resample
from kornia.core import Tensor, pad, tensor
from kornia.geometry.boxes import Boxes
from kornia.geometry.keypoints import Keypoints
from kornia.geometry.transform import crop_by_indices, crop_by_transform_mat, get_perspective_transform
class RandomCrop(GeometricAugmentationBase2D):
r"""Crop random patches of a tensor image on a given size.
.. image:: _static/img/RandomCrop.png
Args:
size: Desired output size (out_h, out_w) of the crop.
Must be Tuple[int, int], then out_h = size[0], out_w = size[1].
padding: Optional padding on each border
of the image. Default is None, i.e no padding. If a sequence of length
4 is provided, it is used to pad left, top, right, bottom borders
respectively. If a sequence of length 2 is provided, it is used to
pad left/right, top/bottom borders, respectively.
pad_if_needed: It will pad the image if smaller than the
desired size to avoid raising an exception. Since cropping is done
after padding, the padding seems to be done at a random offset.
fill: Pixel fill value for constant fill. Default is 0. If a tuple of
length 3, it is used to fill R, G, B channels respectively.
This value is only used when the padding_mode is constant.
padding_mode: Type of padding. Should be: constant, reflect, replicate.
resample: the interpolation mode.
same_on_batch: apply the same transformation across the batch.
align_corners: interpolation flag.
p: probability of applying the transformation for the whole batch.
keepdim: whether to keep the output shape the same as input (True) or broadcast it
to the batch form (False).
cropping_mode: The used algorithm to crop. ``slice`` will use advanced slicing to extract the tensor based
on the sampled indices. ``resample`` will use `warp_affine` using the affine transformation
to extract and resize at once. Use `slice` for efficiency, or `resample` for proper
differentiability.
Shape:
- Input: :math:`(C, H, W)` or :math:`(B, C, H, W)`, Optional: :math:`(B, 3, 3)`
- Output: :math:`(B, C, out_h, out_w)`
Note:
Input tensor must be float and normalized into [0, 1] for the best differentiability support.
Additionally, this function accepts another transformation tensor (:math:`(B, 3, 3)`), then the
applied transformation will be merged into the input transformation tensor and returned.
Examples:
>>> import torch
>>> _ = torch.manual_seed(0)
>>> inputs = torch.arange(1*1*3*3.).view(1, 1, 3, 3)
>>> aug = RandomCrop((2, 2), p=1., cropping_mode="resample")
>>> out = aug(inputs)
>>> out
tensor([[[[3., 4.],
[6., 7.]]]])
>>> aug.inverse(out, padding_mode="replicate")
tensor([[[[3., 4., 4.],
[3., 4., 4.],
[6., 7., 7.]]]])
To apply the exact augmentation again, you may take advantage of the previous parameter state:
>>> input = torch.randn(1, 3, 32, 32)
>>> aug = RandomCrop((2, 2), p=1., cropping_mode="resample")
>>> (aug(input) == aug(input, params=aug._params)).all()
tensor(True)
"""
def __init__(
    self,
    size: Tuple[int, int],
    padding: Optional[Union[int, Tuple[int, int], Tuple[int, int, int, int]]] = None,
    pad_if_needed: Optional[bool] = False,
    fill: int = 0,
    padding_mode: str = "constant",
    resample: Union[str, int, Resample] = Resample.BILINEAR.name,
    same_on_batch: bool = False,
    align_corners: bool = True,
    p: float = 1.0,
    keepdim: bool = False,
    cropping_mode: str = "slice",
) -> None:
    """Store the crop configuration and create the crop parameter generator.

    The constructor arguments are kept verbatim in ``self.flags`` (``resample`` is
    normalised through ``Resample.get``), and ``size`` is handed to ``rg.CropGenerator``.
    """
    # PyTorch has no ragged tensors, so cropping is decided batch-wise: the
    # element-wise probability is fixed to 1.0 and `p` is forwarded as `p_batch`.
    super().__init__(p=1.0, same_on_batch=same_on_batch, p_batch=p, keepdim=keepdim)
    self._param_generator = rg.CropGenerator(size)
    self.flags = dict(
        size=size,
        padding=padding,
        pad_if_needed=pad_if_needed,
        fill=fill,
        padding_mode=padding_mode,
        resample=Resample.get(resample),
        align_corners=align_corners,
        cropping_mode=cropping_mode,
    )
def compute_padding(self, shape: Tuple[int, ...], flags: Optional[Dict[str, Any]] = None) -> List[int]:
    """Compute the padding to apply to the input before cropping.

    Args:
        shape: shape of the input batch; it must have exactly four entries
            (the error message refers to this layout as BCHW).
        flags: mapping providing ``padding``, ``pad_if_needed`` and ``size``.
            ``self.flags`` is used when None.

    Returns:
        The padding amounts ordered as ``[left, right, top, bottom]``.

    Raises:
        AssertionError: if ``shape`` does not have four entries.
        RuntimeError: if ``padding`` is not None, an int, or a tuple/list of length 2 or 4.
    """
    flags = self.flags if flags is None else flags
    if len(shape) != 4:
        raise AssertionError(f"Expected BCHW. Got {shape}.")

    fixed = flags["padding"]
    if fixed is None:
        padding = [0, 0, 0, 0]  # left, right, top, bottom
    elif isinstance(fixed, int):
        padding = [fixed] * 4
    # Lists are accepted alongside tuples, matching what the error message below promises.
    elif isinstance(fixed, (tuple, list)) and len(fixed) == 2:
        # (left/right, top/bottom)
        padding = [fixed[0], fixed[0], fixed[1], fixed[1]]
    elif isinstance(fixed, (tuple, list)) and len(fixed) == 4:
        # (left, top, right, bottom) reordered to (left, right, top, bottom)
        padding = [fixed[0], fixed[2], fixed[1], fixed[3]]
    else:
        raise RuntimeError(f"Expect `padding` to be a scalar, or length 2/4 list. Got {flags['padding']}.")

    if flags["pad_if_needed"]:
        needed_h = flags["size"][0] - shape[-2]
        needed_w = flags["size"][1] - shape[-1]
        # If the crop is wider than the input, pad equally left and right, but only
        # beyond whatever fixed padding was already requested.
        if needed_w > 0:
            padding[0] = max(padding[0], needed_w)
            padding[1] = max(padding[1], needed_w)
        # Same for the height: pad equally top and bottom.
        if needed_h > 0:
            padding[2] = max(padding[2], needed_h)
            padding[3] = max(padding[3], needed_h)
    return padding
def precrop_padding(
    self, input: Tensor, padding: Optional[List[int]] = None, flags: Optional[Dict[str, Any]] = None
) -> Tensor:
    """Pad ``input`` ahead of cropping.

    Args:
        input: tensor to pad.
        padding: padding amounts forwarded to ``pad``. When None they are derived
            from ``input.shape`` through ``compute_padding``.
        flags: mapping providing ``fill`` and ``padding_mode`` (and the padding
            settings when ``padding`` is None). ``self.flags`` is used when None.

    Returns:
        The padded tensor, or ``input`` itself when every padding amount is falsy.
    """
    flags = self.flags if flags is None else flags
    if padding is None:
        # Use the same flags as for fill/padding_mode below, so that a caller-supplied
        # override is honoured consistently instead of silently falling back to self.flags.
        padding = self.compute_padding(input.shape, flags)
    if any(padding):
        input = pad(input, padding, value=flags["fill"], mode=flags["padding_mode"])
    return input
def compute_transformation(self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any]) -> Tensor:
    """Build the transform that maps ``params["src"]`` onto ``params["dst"]``.

    Both point sets are converted with ``.to(input)`` before being passed to
    ``get_perspective_transform``.

    Raises:
        NotImplementedError: if ``flags["cropping_mode"]`` is neither ``resample`` nor ``slice``.
    """
    if flags["cropping_mode"] not in ("resample", "slice"):
        raise NotImplementedError(f"Not supported type: {flags['cropping_mode']}.")
    src_points = params["src"].to(input)
    dst_points = params["dst"].to(input)
    matrix: Tensor = get_perspective_transform(src_points, dst_points)
    return matrix
def apply_transform_keypoint(
    self, input: Keypoints, params: Dict[str, Tensor], flags: Dict[str, Any], transform: Optional[Tensor] = None
) -> Keypoints:
    """Pad the keypoints by ``params["padding_size"]``, then apply the parent keypoint transform."""
    # The padding sizes are moved to the keypoints' device before `Keypoints.pad` is called.
    padded = input.pad(params["padding_size"].to(device=input.device))
    return super().apply_transform_keypoint(input=padded, params=params, flags=flags, transform=transform)
def apply_transform_box(
    self, input: Boxes, params: Dict[str, Tensor], flags: Dict[str, Any], transform: Optional[Tensor] = None
) -> Boxes:
    """Pad the boxes by ``params["padding_size"]``, then apply the parent box transform."""
    # The padding sizes are handed to `Boxes.pad` as stored in `params`.
    padded = input.pad(params["padding_size"])
    return super().apply_transform_box(input=padded, params=params, flags=flags, transform=transform)
def apply_transform(
    self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any], transform: Optional[Tensor] = None
) -> Tensor:
    """Pad the input when required, then crop it with the configured cropping mode.

    Args:
        input: tensor to crop.
        params: sampled parameters. ``padding_size`` (when present as a Tensor) supplies
            the pre-crop padding; ``src`` is used by the ``slice`` mode.
        flags: augmentation flags; ``self.flags`` is used when None.
        transform: transformation matrix, required by the ``resample`` mode.

    Returns:
        The cropped tensor.

    Raises:
        TypeError: in ``resample`` mode when ``transform`` is not a Tensor.
        NotImplementedError: when ``cropping_mode`` is neither ``resample`` nor ``slice``.
    """
    flags = self.flags if flags is None else flags

    pad_amounts: Optional[List[int]] = None
    stored_padding = params.get("padding_size")
    if isinstance(stored_padding, Tensor):
        # Collapse the per-sample rows into one flat list of padding amounts.
        pad_amounts = stored_padding.unique(dim=0).cpu().squeeze().tolist()
    input = self.precrop_padding(input, pad_amounts, flags)

    cropping_mode = flags["cropping_mode"]
    if cropping_mode == "slice":  # uses advanced slicing to crop
        return crop_by_indices(input, params["src"], flags["size"])
    if cropping_mode != "resample":
        raise NotImplementedError(f"Not supported type: {flags['cropping_mode']}.")

    # `resample` mode: crop through `crop_by_transform_mat` using the configured interpolation.
    if not isinstance(transform, Tensor):
        raise TypeError(f"Expected the `transform` be a Tensor. Got {type(transform)}.")
    # Translate the pad-style mode name into the name handed to `crop_by_transform_mat`;
    # any other value is passed through untouched.
    mode_names = {"constant": "zeros", "replicate": "border", "reflect": "reflection"}
    padding_mode = mode_names.get(flags["padding_mode"], flags["padding_mode"])
    return crop_by_transform_mat(
        input,
        transform,
        flags["size"],
        mode=flags["resample"].name.lower(),
        padding_mode=padding_mode,
        align_corners=flags["align_corners"],
    )
def inverse_transform(
    self,
    input: Tensor,
    flags: Dict[str, Any],
    transform: Optional[Tensor] = None,
    size: Optional[Tuple[int, int]] = None,
) -> Tensor:
    """Apply ``transform`` to ``input`` via ``crop_by_transform_mat``, producing an output of ``size``.

    Args:
        input: tensor to transform.
        flags: augmentation flags (``cropping_mode``, ``padding_mode``, ``resample``,
            ``align_corners`` are read).
        transform: transformation matrix; its first two rows are used.
        size: output size passed to ``crop_by_transform_mat``.

    Raises:
        NotImplementedError: if ``cropping_mode`` is not ``resample``.
        RuntimeError: if ``size`` is None.
        TypeError: if ``transform`` is not a Tensor.
    """
    if flags["cropping_mode"] != "resample":
        raise NotImplementedError(
            f"`inverse` is only applicable for resample cropping mode. Got {flags['cropping_mode']}."
        )
    if size is None:
        raise RuntimeError("`size` has to be a tuple. Got None.")
    if not isinstance(transform, Tensor):
        raise TypeError(f"Expected the `transform` be a Tensor. Got {type(transform)}.")

    # Translate the pad-style mode name into the name handed to `crop_by_transform_mat`;
    # any other value is passed through untouched.
    mode_names = {"constant": "zeros", "replicate": "border", "reflect": "reflection"}
    padding_mode = mode_names.get(flags["padding_mode"], flags["padding_mode"])
    interpolation = flags["resample"].name.lower()
    return crop_by_transform_mat(
        input,
        transform[:, :2, :],
        size,
        interpolation,
        padding_mode=padding_mode,
        align_corners=flags["align_corners"],
    )
def inverse_inputs(
    self,
    input: Tensor,
    params: Dict[str, Tensor],
    flags: Dict[str, Any],
    transform: Optional[Tensor] = None,
    **kwargs: Any,
) -> Tensor:
    """Invert the crop through the parent class, then undo the pre-crop padding.

    The padding is only undone when every entry of ``params["batch_prob"]`` is truthy;
    otherwise the parent's result is returned unchanged.

    Raises:
        NotImplementedError: if ``cropping_mode`` is not ``resample``.
    """
    if flags["cropping_mode"] != "resample":
        raise NotImplementedError(
            f"`inverse` is only applicable for resample cropping mode. Got {flags['cropping_mode']}."
        )
    restored = super().inverse_inputs(input, params, flags, transform, **kwargs)
    if params["batch_prob"].all():
        applied = params["padding_size"].unique(dim=0).cpu().squeeze().tolist()
        # Negating each amount turns the padding call into its reverse.
        reverse_padding = [-applied[side] for side in range(4)]
        return self.precrop_padding(restored, reverse_padding)
    return restored
def inverse_boxes(
    self,
    input: Boxes,
    params: Dict[str, Tensor],
    flags: Dict[str, Any],
    transform: Optional[Tensor] = None,
    **kwargs: Any,
) -> Boxes:
    """Invert the crop on boxes through the parent class, then unpad them.

    ``Boxes.unpad`` is only called when every entry of ``params["batch_prob"]`` is
    truthy; otherwise the parent's result is returned unchanged.

    Raises:
        NotImplementedError: if ``cropping_mode`` is not ``resample``.
    """
    if flags["cropping_mode"] != "resample":
        raise NotImplementedError(
            f"`inverse` is only applicable for resample cropping mode. Got {flags['cropping_mode']}."
        )
    restored = super().inverse_boxes(input, params, flags, transform, **kwargs)
    if params["batch_prob"].all():
        return restored.unpad(params["padding_size"])
    return restored
def inverse_keypoints(
    self,
    input: Keypoints,
    params: Dict[str, Tensor],
    flags: Dict[str, Any],
    transform: Optional[Tensor] = None,
    **kwargs: Any,
) -> Keypoints:
    """Invert the crop on keypoints through the parent class, then unpad them.

    ``Keypoints.unpad`` is only called when every entry of ``params["batch_prob"]`` is
    truthy; otherwise the parent's result is returned unchanged. The padding sizes are
    moved to ``input.device`` first.

    Raises:
        NotImplementedError: if ``cropping_mode`` is not ``resample``.
    """
    if flags["cropping_mode"] != "resample":
        raise NotImplementedError(
            f"`inverse` is only applicable for resample cropping mode. Got {flags['cropping_mode']}."
        )
    restored = super().inverse_keypoints(input, params, flags, transform, **kwargs)
    if params["batch_prob"].all():
        applied_padding = params["padding_size"].to(device=input.device)
        return restored.unpad(applied_padding)
    return restored
# Override parameters for precrop
def forward_parameters(self, batch_shape: Tuple[int, ...]) -> Dict[str, Tensor]:
    """Sample the parameters for the padded input shape and record the padding used.

    The parent class generates its parameters for ``batch_shape`` enlarged by the
    padding from ``compute_padding``; the padding itself is stored under
    ``"padding_size"`` with one row per batch element.
    """
    left, right, top, bottom = self.compute_padding(batch_shape)
    padded_height = batch_shape[2] + top + bottom
    padded_width = batch_shape[3] + left + right
    padded_shape = torch.Size((*batch_shape[:2], padded_height, padded_width))
    # One identical row of (left, right, top, bottom) per batch element.
    padding_size = tensor((left, right, top, bottom), dtype=torch.long).expand(batch_shape[0], -1)
    sampled = super().forward_parameters(padded_shape)
    sampled["padding_size"] = padding_size
    return sampled