from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from torch import Tensor
from kornia.augmentation._2d.intensity.base import IntensityAugmentationBase2D
from kornia.enhance import denormalize
class Denormalize(IntensityAugmentationBase2D):
    r"""Denormalize tensor images with mean and standard deviation.

    .. math::
        \text{input[channel] = (input[channel] * std[channel]) + mean[channel]}

    Where `mean` is :math:`(M_1, ..., M_n)` and `std` :math:`(S_1, ..., S_n)` for `n` channels.

    Args:
        mean: Mean for each channel. A Python scalar (``int`` or ``float``) is wrapped
            into a one-element tensor and a tuple/list is converted with ``torch.tensor``;
            a tensor is stored as given.
        std: Standard deviations for each channel. Converted the same way as ``mean``.
        p: probability of applying the transformation.
        keepdim: whether to keep the output shape the same as input (True) or broadcast it
            to the batch form (False).
        return_transform: forwarded unchanged to the base augmentation class.

    Return:
        Denormalised tensor with same size as input :math:`(*, C, H, W)`.

    .. note::
        ``same_on_batch`` is not a constructor argument of this class; it is always
        passed to the base class as ``True``.

    .. note::
        This function internally uses :func:`kornia.enhance.denormalize`.

    Examples:
        >>> norm = Denormalize(mean=torch.zeros(1, 4), std=torch.ones(1, 4))
        >>> x = torch.rand(1, 4, 3, 3)
        >>> out = norm(x)
        >>> out.shape
        torch.Size([1, 4, 3, 3])
    """

    def __init__(
        self,
        mean: Union[Tensor, Tuple[float, ...], List[float], float],
        std: Union[Tensor, Tuple[float, ...], List[float], float],
        p: float = 1.0,
        keepdim: bool = False,
        return_transform: Optional[bool] = None,
    ) -> None:
        super().__init__(p=p, return_transform=return_transform, same_on_batch=True, keepdim=keepdim)

        def _as_tensor(value: Union[Tensor, Tuple[float, ...], List[float], float]) -> Tensor:
            # Accept ``int`` as well as ``float``: the ``float`` annotation admits ints
            # (PEP 484 numeric tower), but a raw int would otherwise be stored unconverted.
            if isinstance(value, (int, float)):
                return torch.tensor([value])
            if isinstance(value, (tuple, list)):
                return torch.tensor(value)
            # Tensors (and anything else) are stored as provided.
            return value

        # Both statistics are read back by ``apply_transform`` via ``flags``.
        self.flags = dict(mean=_as_tensor(mean), std=_as_tensor(std))

    def apply_transform(
        self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any], transform: Optional[Tensor] = None
    ) -> Tensor:
        """Return ``input * std + mean`` per channel using the stored ``flags`` statistics."""
        return denormalize(input, flags["mean"], flags["std"])