# Source code for kornia.augmentation._2d.intensity.normalize

from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import Tensor

from kornia.augmentation._2d.intensity.base import IntensityAugmentationBase2D
from kornia.enhance import normalize

[docs]class Normalize(IntensityAugmentationBase2D):
r"""Normalize tensor images with mean and standard deviation.

.. math::
\text{input[channel] = (input[channel] - mean[channel]) / std[channel]}

Where mean is :math:(M_1, ..., M_n) and std :math:(S_1, ..., S_n) for n channels,

Args:
mean: Mean for each channel.
std: Standard deviations for each channel.
p: probability of applying the transformation.
keepdim: whether to keep the output shape the same as input (True) or broadcast it
to the batch form (False).

Return:
Normalised tensor with same size as input :math:(*, C, H, W).

.. note::
This function internally uses :func:kornia.enhance.normalize.

Examples:

>>> norm = Normalize(mean=torch.zeros(4), std=torch.ones(4))
>>> x = torch.rand(1, 4, 3, 3)
>>> out = norm(x)
>>> out.shape
torch.Size([1, 4, 3, 3])
"""

def __init__(
self,
mean: Union[Tensor, Tuple[float], List[float], float],
std: Union[Tensor, Tuple[float], List[float], float],
p: float = 1.0,
keepdim: bool = False,
return_transform: Optional[bool] = None,
) -> None:
super().__init__(p=p, return_transform=return_transform, same_on_batch=True, keepdim=keepdim)
if isinstance(mean, float):
mean = torch.tensor([mean])

if isinstance(std, float):
std = torch.tensor([std])

if isinstance(mean, (tuple, list)):
mean = torch.tensor(mean)

if isinstance(std, (tuple, list)):
std = torch.tensor(std)

self.flags = dict(mean=mean, std=std)

def apply_transform(
self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any], transform: Optional[Tensor] = None
) -> Tensor:
return normalize(input, flags["mean"], flags["std"])