from __future__ import annotations
from typing import Any, Optional, Sequence, Union
import torch
from kornia.augmentation._2d.mix.base import MixAugmentationBaseV2
from kornia.augmentation.utils import _validate_input_dtype
from kornia.constants import DataKey
from kornia.core import Tensor, tensor
from kornia.core.check import KORNIA_CHECK
# Public names exported by this module (used by `from <module> import *`).
__all__ = ["RandomTransplantation"]
class RandomTransplantation(MixAugmentationBaseV2):
    r"""RandomTransplantation augmentation.

    .. image:: _static/img/RandomTransplantation.png

    Randomly transplant (copy and paste) image features and corresponding segmentation masks between images in a batch.
    The transplantation transform works as follows:

        1. Based on the parameter `p`, a certain number of images in the batch are selected as acceptor of a
           transplantation.
        2. For each acceptor, the preceding image in the batch is selected as donor (with wrap-around:
           :math:`i - 1 \mod B`).
        3. From the donor, a random label is selected and the corresponding image features and segmentation mask are
           transplanted to the acceptor.

    The augmentation is described in `Semantic segmentation of surgical hyperspectral images under geometric domain
    shifts` :cite:`sellner2023semantic`.

    Args:
        excluded_labels: sequence of labels which should not be transplanted from a donor. This can be useful if only
          parts of the image are annotated and the non-annotated regions (with a specific label index) should be
          excluded from the augmentation. If no label is left in the donor image, nothing is transplanted.
        p: probability for applying an augmentation to an image. This parameter controls how many images in a batch
          receive a transplant.
        p_batch: probability for applying an augmentation to a batch. This param controls the augmentation
          probabilities batch-wise.
        data_keys: the input type sequential for applying augmentations. There must be at least one "mask" tensor. If no
          data keys are given, the first tensor is assumed to be `DataKey.INPUT` and the second tensor `DataKey.MASK`.
          Accepts "input", "mask".

    Note:
        - This augmentation requires that segmentation masks are available for all images in the batch and that at
          least some objects in the image are annotated.
        - When using this class directly (`RandomTransplantation()(...)`), it works for arbitrary spatial dimensions
          including 2D and 3D images. When wrapping in :class:`kornia.augmentation.AugmentationSequential`, use
          :class:`kornia.augmentation.RandomTransplantation` for 2D and
          :class:`kornia.augmentation.RandomTransplantation3D` for 3D images.

    Inputs:
        - Segmentation mask tensor which is used to determine the objects for transplantation: :math:`(B, *)`.
        - (optional) Additional image or mask tensors where the features are transplanted based on the first
          segmentation mask: :math:`(B, C, *)` (`DataKey.INPUT`) or :math:`(B, *)` (`DataKey.MASK`).

    Returns:
        Tensor | list[Tensor]:

        Tensor:
            - Augmented mask tensors: :math:`(B, *)`.
        list[Tensor]:
            - Augmented mask tensors: :math:`(B, *)`.
            - Additional augmented image or mask tensors: :math:`(B, C, *)` (`DataKey.INPUT`) or :math:`(B, *)`
              (`DataKey.MASK`).

    Examples:
        >>> import torch
        >>> rng = torch.manual_seed(0)
        >>> aug = RandomTransplantation(p=1.)
        >>> image = torch.randn(2, 3, 5, 5)
        >>> mask = torch.randint(0, 3, (2, 5, 5))
        >>> mask
        tensor([[[0, 0, 1, 1, 0],
                 [1, 2, 0, 0, 0],
                 [1, 2, 1, 1, 0],
                 [0, 0, 0, 0, 2],
                 [2, 2, 2, 0, 2]],
        <BLANKLINE>
                [[2, 0, 0, 2, 1],
                 [2, 1, 0, 2, 1],
                 [2, 0, 1, 0, 2],
                 [2, 2, 2, 0, 2],
                 [2, 1, 0, 0, 0]]])
        >>> image_out, mask_out = aug(image, mask)
        >>> image_out.shape
        torch.Size([2, 3, 5, 5])
        >>> mask_out.shape
        torch.Size([2, 5, 5])
        >>> mask_out
        tensor([[[2, 0, 1, 2, 0],
                 [2, 2, 0, 2, 0],
                 [2, 2, 1, 1, 2],
                 [2, 2, 2, 0, 2],
                 [2, 2, 2, 0, 2]],
        <BLANKLINE>
                [[0, 0, 0, 2, 0],
                 [2, 1, 0, 0, 0],
                 [2, 0, 1, 0, 0],
                 [0, 0, 0, 0, 2],
                 [2, 1, 0, 0, 0]]])
        >>> aug._params["selected_labels"]  # Image 0 received label 2 from image 1 and image 1 label 0 from image 0
        tensor([2, 0])

        You can apply the same augmentation again in which case the same objects get transplanted between the images:

        >>> aug._params["selection"]  # The pixels (objects) which get transplanted
        tensor([[[ True, False, False,  True, False],
                 [ True, False, False,  True, False],
                 [ True, False, False, False,  True],
                 [ True,  True,  True, False,  True],
                 [ True, False, False, False, False]],
        <BLANKLINE>
                [[ True,  True, False, False,  True],
                 [False, False,  True,  True,  True],
                 [False, False, False, False,  True],
                 [ True,  True,  True,  True, False],
                 [False, False, False,  True, False]]])
        >>> image2 = torch.zeros(2, 3, 5, 5)
        >>> image2[1] = 1
        >>> image2[:, 0]
        tensor([[[0., 0., 0., 0., 0.],
                 [0., 0., 0., 0., 0.],
                 [0., 0., 0., 0., 0.],
                 [0., 0., 0., 0., 0.],
                 [0., 0., 0., 0., 0.]],
        <BLANKLINE>
                [[1., 1., 1., 1., 1.],
                 [1., 1., 1., 1., 1.],
                 [1., 1., 1., 1., 1.],
                 [1., 1., 1., 1., 1.],
                 [1., 1., 1., 1., 1.]]])
        >>> image_out2, mask_out2 = aug(image2, mask, params=aug._params)
        >>> image_out2[:, 0]
        tensor([[[1., 0., 0., 1., 0.],
                 [1., 0., 0., 1., 0.],
                 [1., 0., 0., 0., 1.],
                 [1., 1., 1., 0., 1.],
                 [1., 0., 0., 0., 0.]],
        <BLANKLINE>
                [[0., 0., 1., 1., 0.],
                 [1., 1., 0., 0., 0.],
                 [1., 1., 1., 1., 0.],
                 [0., 0., 0., 0., 1.],
                 [1., 1., 1., 0., 1.]]])
    """

    def __init__(
        self,
        excluded_labels: Optional[Union[Sequence[int], Tensor]] = None,
        p: float = 0.5,
        p_batch: float = 1.0,
        data_keys: Optional[list[str | int | DataKey]] = None,
    ) -> None:
        """Store the excluded labels (as a 1-dimensional tensor) and the default data keys.

        See the class docstring for the meaning of the arguments.
        """
        super().__init__(p=p, p_batch=p_batch)

        # Normalise the excluded labels to a tensor so that they can be compared against mask labels later
        if excluded_labels is None:
            excluded_labels = []
        if not isinstance(excluded_labels, Tensor):
            excluded_labels = tensor(excluded_labels)
        self.excluded_labels: Tensor = excluded_labels
        KORNIA_CHECK(
            self.excluded_labels.ndim == 1,
            f"excluded_labels must be a 1-dimensional sequence, but got {self.excluded_labels.ndim} dimensions.",
        )

        if data_keys is None:
            data_keys = [DataKey.INPUT, DataKey.MASK]
        self.data_keys = [DataKey.get(inp) for inp in data_keys]

        # Position of the channel dimension in image inputs (masks do not have this dimension)
        self._channel_dim = 1

    def apply_non_transform_mask(self, input: Tensor, params: dict[str, Tensor], flags: dict[str, Any]) -> Tensor:
        """Return the mask unchanged (masks which do not receive a transplant are passed through)."""
        return input

    def transform_input(self, acceptor: Tensor, donor: Tensor, selection: Tensor) -> Tensor:  # type: ignore[override]
        """Copy the selected pixels of every channel from `donor` into `acceptor`.

        Args:
            acceptor: image tensor which receives the transplant. It is modified in place.
            donor: image tensor from which the features are copied.
            selection: boolean tensor without channel dimension which marks the pixels to transplant.

        Returns:
            The modified `acceptor` tensor.
        """
        # Expand selection to the channel dimension
        selection = selection.unsqueeze(dim=self._channel_dim).expand_as(donor)
        acceptor[selection] = donor[selection]
        return acceptor

    def transform_mask(self, acceptor: Tensor, donor: Tensor, selection: Tensor) -> Tensor:  # type: ignore[override]
        """Copy the selected pixels from the `donor` mask into the `acceptor` mask.

        Args:
            acceptor: mask tensor which receives the transplant. It is modified in place.
            donor: mask tensor from which the labels are copied.
            selection: boolean tensor which marks the pixels to transplant.

        Returns:
            The modified `acceptor` tensor.
        """
        acceptor[selection] = donor[selection]
        return acceptor

    def params_from_input(
        self,
        *input: Tensor,
        data_keys: list[DataKey],
        params: dict[str, Tensor],
        extra_args: Optional[dict[DataKey, dict[str, Any]]] = None,
    ) -> dict[str, Tensor]:
        """Compute parameters for the transformation which are based on one or more input tensors.

        This function is, for example, called by :class:`kornia.augmentation.container.ops.AugmentationSequentialOps`
        before the augmentation is applied on the individual input tensors.

        Only the entries which are missing in `params` are computed (`acceptor_indices`, `donor_indices`,
        `selected_labels`, `selection`); entries which are already present are used as given. `params` is updated in
        place.

        Args:
            *input: All input tensors passed to the augmentation pipeline.
            data_keys: Associated data key for every input tensor.
            params: Dictionary of parameters computed so far by the augmentation pipeline (e.g. including the
              `batch_prob`).
            extra_args: Optional dictionary of extra arguments with specific options for different input types.

        Returns:
            Updated dictionary of parameters with the necessary information to apply the augmentation on all input
            tensors separately.
        """
        KORNIA_CHECK(
            len(data_keys) == len(input),
            f"Length of keys ({len(data_keys)}) does not match number of inputs ({len(input)}).",
        )

        # The first mask key will be used for the transplantation
        mask: Tensor = input[data_keys.index(DataKey.MASK)]

        for _input, key in zip(input, data_keys):
            if key == DataKey.INPUT:
                KORNIA_CHECK(
                    _input.ndim == mask.ndim + 1,
                    "Every image input must have one additional dimension (channel dimension) than the segmentation "
                    f"mask, but got {_input.ndim} for the input image and {mask.ndim} for the segmentation mask.",
                )
                KORNIA_CHECK(
                    mask.size() == torch.Size([s for i, s in enumerate(_input.size()) if i != self._channel_dim]),
                    "The dimensions of the input image and segmentation mask must match except for the channel "
                    f"dimension, but got {_input.size()} for the input image and {mask.size()} for the segmentation "
                    "mask.",
                )

        if "acceptor_indices" not in params:
            params["acceptor_indices"] = torch.where(params["batch_prob"] > 0.5)[0]

        if "donor_indices" not in params:
            # The donor of image i is image i - 1 (wrapping around at the start of the batch)
            params["donor_indices"] = (params["acceptor_indices"] - 1) % len(params["batch_prob"])

        donor_indices: Tensor = params["donor_indices"]

        # One entry per donor (None if the donor has no transplantable label). This list only exists if the labels
        # are drawn here; it keeps every label associated with its own donor even if some donors contribute nothing.
        labels_per_donor: Optional[list[Optional[Tensor]]] = None

        if "selected_labels" not in params:
            if self.excluded_labels.device != mask.device:
                self.excluded_labels = self.excluded_labels.to(mask.device)

            labels_per_donor = []
            for donor_index in donor_indices:
                # Select a random label from the donor image
                labels = mask[donor_index].unique()

                # Remove any label which is part of the excluded labels
                labels = labels[(labels.view(1, -1) != self.excluded_labels.view(-1, 1)).all(dim=0)]

                if len(labels) > 0:
                    labels_per_donor.append(labels[torch.randperm(len(labels))[0]])
                else:
                    # Nothing to transplant from this donor
                    labels_per_donor.append(None)

            donor_labels: list[Tensor] = [label for label in labels_per_donor if label is not None]
            params["selected_labels"] = torch.stack(donor_labels) if len(donor_labels) > 0 else torch.empty(0)

        if "selection" not in params:
            selection = torch.zeros(
                (len(params["acceptor_indices"]), *mask.shape[1:]), dtype=torch.bool, device=mask.device
            )

            selected_labels: Tensor = params["selected_labels"]
            KORNIA_CHECK(
                selected_labels.ndim == 1,
                f"selected_labels must be a 1-dimensional tensor, but got {selected_labels.ndim} dimensions.",
            )
            KORNIA_CHECK(
                len(selected_labels) <= len(params["acceptor_indices"]),
                f"There cannot be more selected labels ({len(selected_labels)}) than images where this augmentation "
                f"should be applied ({len(params['acceptor_indices'])}).",
            )

            if labels_per_donor is None:
                # The labels were supplied via `params`: pair them positionally with the donors
                labels_per_donor = list(selected_labels)

            for d, (donor_index, selected_label) in enumerate(zip(donor_indices, labels_per_donor)):
                if selected_label is None:
                    # The selection for this acceptor stays empty so that nothing is transplanted
                    continue
                selection[d].masked_fill_(mask[donor_index] == selected_label, True)

            params["selection"] = selection

        return params

    def forward(  # type: ignore[override]
        self,
        *input: Tensor,
        params: Optional[dict[str, Tensor]] = None,
        data_keys: Optional[list[str | int | DataKey]] = None,
        **kwargs: dict[str, Any],
    ) -> Tensor | list[Tensor]:
        """Apply the transplantation to all input tensors.

        Args:
            *input: Input tensors, one per data key.
            params: Optional parameters from a previous call (e.g. `aug._params`) to repeat the same transplantation.
              If given, the dictionary is stored as `self._params` and missing entries are added to it in place.
            data_keys: Optional data keys for the inputs. Defaults to the keys given at construction time.
            **kwargs: Not used by this method.

        Returns:
            The augmented tensor if only one input is given, otherwise a list with one augmented tensor per input.

        Raises:
            NotImplementedError: if a data key other than `DataKey.INPUT` or `DataKey.MASK` is used.
        """
        keys: list[DataKey]
        if data_keys is None:
            keys = self.data_keys
        else:
            keys = [DataKey.get(inp) for inp in data_keys]

        if params is None:
            # The batch parameters are sampled based on the shape of the first mask
            mask: Tensor = input[keys.index(DataKey.MASK)]
            self._params = self.forward_parameters(mask.shape)
        else:
            self._params = params

        if any(k not in self._params for k in ["acceptor_indices", "donor_indices", "selection"]):
            self._params.update(self.params_from_input(*input, data_keys=keys, params=self._params))

        outputs: list[Tensor] = []
        for dcate, _input in zip(keys, input):
            # The acceptor is cloned because the transform_* methods write into it in place
            acceptor = _input[self._params["acceptor_indices"]].clone()
            donor = _input[self._params["donor_indices"]]

            output: Tensor
            if dcate == DataKey.INPUT:
                _validate_input_dtype(_input, accepted_dtypes=[torch.float16, torch.float32, torch.float64])
                applied = self.transform_input(acceptor, donor, self._params["selection"])
                output = self.apply_non_transform(_input, self._params, self.flags)
                # Write the transplanted images back to the positions of the acceptors
                output = output.index_put(
                    (self._params["acceptor_indices"],),
                    self.apply_non_transform_mask(applied, self._params, self.flags),
                )
            elif dcate == DataKey.MASK:
                applied = self.transform_mask(acceptor, donor, self._params["selection"])
                output = self.apply_non_transform_mask(_input, self._params, self.flags)
                # Write the transplanted masks back to the positions of the acceptors
                output = output.index_put(
                    (self._params["acceptor_indices"],),
                    self.apply_non_transform_mask(applied, self._params, self.flags),
                )
            else:
                raise NotImplementedError

            outputs.append(output)

        if len(outputs) == 1:
            return outputs[0]
        else:
            return outputs