Source code for kornia.enhance.normalize

"""Module containing functionals for intensity normalisation."""

from typing import List, Tuple, Union

import torch
import torch.nn as nn

__all__ = ["normalize", "normalize_min_max", "denormalize", "Normalize", "Denormalize"]


class Normalize(nn.Module):
    r"""Normalize a tensor image with mean and standard deviation.

    .. math::
        \text{input[channel] = (input[channel] - mean[channel]) / std[channel]}

    Where `mean` is :math:`(M_1, ..., M_n)` and `std` :math:`(S_1, ..., S_n)` for `n` channels.

    Args:
        mean: Mean for each channel.
        std: Standard deviations for each channel.

    Shape:
        - Input: Image tensor of size :math:`(*, C, ...)`.
        - Output: Normalised tensor with same size as input :math:`(*, C, ...)`.

    Examples:
        >>> x = torch.rand(1, 4, 3, 3)
        >>> out = Normalize(0.0, 255.)(x)
        >>> out.shape
        torch.Size([1, 4, 3, 3])

        >>> x = torch.rand(1, 4, 3, 3)
        >>> mean = torch.zeros(4)
        >>> std = 255. * torch.ones(4)
        >>> out = Normalize(mean, std)(x)
        >>> out.shape
        torch.Size([1, 4, 3, 3])
    """

    def __init__(
        self,
        mean: Union[torch.Tensor, Tuple[float], List[float], float],
        std: Union[torch.Tensor, Tuple[float], List[float], float],
    ) -> None:
        super().__init__()

        if isinstance(mean, float):
            mean = torch.tensor([mean])

        if isinstance(std, float):
            std = torch.tensor([std])

        if isinstance(mean, (tuple, list)):
            mean = torch.tensor(mean)

        if isinstance(std, (tuple, list)):
            std = torch.tensor(std)

        self.mean = mean
        self.std = std

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return normalize(input, self.mean, self.std)

    def __repr__(self):
        repr = f"(mean={self.mean}, std={self.std})"
        return self.__class__.__name__ + repr
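A minimal usage sketch for the module form, assuming the public ``kornia.enhance`` import path and ImageNet statistics used purely as illustrative values:

import torch
from kornia.enhance import Normalize

# Example per-channel statistics (ImageNet values, for illustration only).
mean = torch.tensor([0.485, 0.456, 0.406])
std = torch.tensor([0.229, 0.224, 0.225])

preprocess = Normalize(mean, std)

images = torch.rand(8, 3, 224, 224)  # batch of 3-channel images in [0, 1]
normalized = preprocess(images)      # same shape, standardised per channel
print(normalized.shape)              # torch.Size([8, 3, 224, 224])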
def normalize(data: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    r"""Normalize an image/video tensor with mean and standard deviation.

    .. math::
        \text{input[channel] = (input[channel] - mean[channel]) / std[channel]}

    Where `mean` is :math:`(M_1, ..., M_n)` and `std` :math:`(S_1, ..., S_n)` for `n` channels.

    Args:
        data: Image tensor of size :math:`(B, C, *)`.
        mean: Mean for each channel.
        std: Standard deviations for each channel.

    Returns:
        Normalised tensor with same size as input :math:`(B, C, *)`.

    Examples:
        >>> x = torch.rand(1, 4, 3, 3)
        >>> out = normalize(x, torch.tensor([0.0]), torch.tensor([255.]))
        >>> out.shape
        torch.Size([1, 4, 3, 3])

        >>> x = torch.rand(1, 4, 3, 3)
        >>> mean = torch.zeros(4)
        >>> std = 255. * torch.ones(4)
        >>> out = normalize(x, mean, std)
        >>> out.shape
        torch.Size([1, 4, 3, 3])
    """
    shape = data.shape

    if len(mean.shape) == 0 or mean.shape[0] == 1:
        mean = mean.expand(shape[1])
    if len(std.shape) == 0 or std.shape[0] == 1:
        std = std.expand(shape[1])

    # Allow broadcast on channel dimension
    if mean.shape and mean.shape[0] != 1:
        if mean.shape[0] != data.shape[1] and mean.shape[:2] != data.shape[:2]:
            raise ValueError(f"mean length and number of channels do not match. Got {mean.shape} and {data.shape}.")

    # Allow broadcast on channel dimension
    if std.shape and std.shape[0] != 1:
        if std.shape[0] != data.shape[1] and std.shape[:2] != data.shape[:2]:
            raise ValueError(f"std length and number of channels do not match. Got {std.shape} and {data.shape}.")

    mean = torch.as_tensor(mean, device=data.device, dtype=data.dtype)
    std = torch.as_tensor(std, device=data.device, dtype=data.dtype)

    if mean.shape:
        mean = mean[..., :, None]
    if std.shape:
        std = std[..., :, None]

    out: torch.Tensor = (data.view(shape[0], shape[1], -1) - mean) / std

    return out.view(shape)
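A short sketch of the channel broadcasting behaviour implemented above, assuming the functional is imported from ``kornia.enhance``:

import torch
from kornia.enhance import normalize

x = torch.rand(2, 3, 4, 4)

# A single-element mean/std is expanded to every channel...
out_a = normalize(x, torch.tensor([0.5]), torch.tensor([0.5]))

# ...while a length-C tensor applies one value per channel.
out_b = normalize(x, torch.tensor([0.4, 0.5, 0.6]), torch.tensor([0.2, 0.2, 0.2]))

assert out_a.shape == x.shape and out_b.shape == x.shape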
class Denormalize(nn.Module):
    r"""Denormalize a tensor image with mean and standard deviation.

    .. math::
        \text{input[channel] = (input[channel] * std[channel]) + mean[channel]}

    Where `mean` is :math:`(M_1, ..., M_n)` and `std` :math:`(S_1, ..., S_n)` for `n` channels.

    Args:
        mean: Mean for each channel.
        std: Standard deviations for each channel.

    Shape:
        - Input: Image tensor of size :math:`(*, C, ...)`.
        - Output: Denormalised tensor with same size as input :math:`(*, C, ...)`.

    Examples:
        >>> x = torch.rand(1, 4, 3, 3)
        >>> out = Denormalize(0.0, 255.)(x)
        >>> out.shape
        torch.Size([1, 4, 3, 3])

        >>> x = torch.rand(1, 4, 3, 3, 3)
        >>> mean = torch.zeros(1, 4)
        >>> std = 255. * torch.ones(1, 4)
        >>> out = Denormalize(mean, std)(x)
        >>> out.shape
        torch.Size([1, 4, 3, 3, 3])
    """

    def __init__(self, mean: Union[torch.Tensor, float], std: Union[torch.Tensor, float]) -> None:
        super().__init__()

        self.mean = mean
        self.std = std

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return denormalize(input, self.mean, self.std)

    def __repr__(self):
        repr = f"(mean={self.mean}, std={self.std})"
        return self.__class__.__name__ + repr
def denormalize(data: torch.Tensor, mean: Union[torch.Tensor, float], std: Union[torch.Tensor, float]) -> torch.Tensor:
    r"""Denormalize an image/video tensor with mean and standard deviation.

    .. math::
        \text{input[channel] = (input[channel] * std[channel]) + mean[channel]}

    Where `mean` is :math:`(M_1, ..., M_n)` and `std` :math:`(S_1, ..., S_n)` for `n` channels.

    Args:
        data: Image tensor of size :math:`(B, C, *)`.
        mean: Mean for each channel.
        std: Standard deviations for each channel.

    Returns:
        Denormalised tensor with same size as input :math:`(B, C, *)`.

    Examples:
        >>> x = torch.rand(1, 4, 3, 3)
        >>> out = denormalize(x, 0.0, 255.)
        >>> out.shape
        torch.Size([1, 4, 3, 3])

        >>> x = torch.rand(1, 4, 3, 3, 3)
        >>> mean = torch.zeros(1, 4)
        >>> std = 255. * torch.ones(1, 4)
        >>> out = denormalize(x, mean, std)
        >>> out.shape
        torch.Size([1, 4, 3, 3, 3])
    """
    if not isinstance(data, torch.Tensor):
        raise TypeError(f"data should be a tensor. Got {type(data)}")

    shape = data.shape

    if isinstance(mean, float):
        mean = torch.tensor([mean] * shape[1], device=data.device, dtype=data.dtype)

    if isinstance(std, float):
        std = torch.tensor([std] * shape[1], device=data.device, dtype=data.dtype)

    if not isinstance(mean, torch.Tensor):
        raise TypeError(f"mean should be a tensor or a float. Got {type(mean)}")

    if not isinstance(std, torch.Tensor):
        raise TypeError(f"std should be a tensor or a float. Got {type(std)}")

    # Allow broadcast on channel dimension
    if mean.shape and mean.shape[0] != 1:
        if mean.shape[0] != data.shape[-3] and mean.shape[:2] != data.shape[:2]:
            raise ValueError(f"mean length and number of channels do not match. Got {mean.shape} and {data.shape}.")

    # Allow broadcast on channel dimension
    if std.shape and std.shape[0] != 1:
        if std.shape[0] != data.shape[-3] and std.shape[:2] != data.shape[:2]:
            raise ValueError(f"std length and number of channels do not match. Got {std.shape} and {data.shape}.")

    mean = torch.as_tensor(mean, device=data.device, dtype=data.dtype)
    std = torch.as_tensor(std, device=data.device, dtype=data.dtype)

    if mean.shape:
        mean = mean[..., :, None]
    if std.shape:
        std = std[..., :, None]

    out: torch.Tensor = (data.view(shape[0], shape[1], -1) * std) + mean

    return out.view(shape)
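A quick round-trip sketch: since ``denormalize`` inverts ``normalize`` for the same statistics, mapping there and back should recover the input up to floating-point error. Import path assumed to be ``kornia.enhance``:

import torch
from kornia.enhance import denormalize, normalize

x = torch.rand(1, 3, 5, 5)
mean = torch.tensor([0.5, 0.5, 0.5])
std = torch.tensor([0.25, 0.25, 0.25])

# Map into the normalised space and back again.
x_norm = normalize(x, mean, std)
x_back = denormalize(x_norm, mean, std)

assert torch.allclose(x, x_back, atol=1e-6)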
def normalize_min_max(x: torch.Tensor, min_val: float = 0.0, max_val: float = 1.0, eps: float = 1e-6) -> torch.Tensor:
    r"""Normalise an image/video tensor by MinMax and re-scale the values to a given range.

    The data is normalised using the following formulation:

    .. math::
        y_i = (b - a) * \frac{x_i - \text{min}(x)}{\text{max}(x) - \text{min}(x)} + a

    where :math:`a` is :math:`\text{min_val}` and :math:`b` is :math:`\text{max_val}`.

    Args:
        x: The image tensor to be normalised with shape :math:`(B, C, *)`.
        min_val: The minimum value for the new range.
        max_val: The maximum value for the new range.
        eps: Float number to avoid zero division.

    Returns:
        The normalised image tensor with same shape as input :math:`(B, C, *)`.

    Example:
        >>> x = torch.rand(1, 5, 3, 3)
        >>> x_norm = normalize_min_max(x, min_val=-1., max_val=1.)
        >>> x_norm.min()
        tensor(-1.)
        >>> x_norm.max()
        tensor(1.0000)
    """
    if not isinstance(x, torch.Tensor):
        raise TypeError(f"data should be a tensor. Got: {type(x)}.")

    if not isinstance(min_val, float):
        raise TypeError(f"'min_val' should be a float. Got: {type(min_val)}.")

    if not isinstance(max_val, float):
        raise TypeError(f"'max_val' should be a float. Got: {type(max_val)}.")

    if len(x.shape) < 3:
        raise ValueError(f"Input shape must be at least a 3d tensor. Got: {x.shape}.")

    shape = x.shape
    B, C = shape[0], shape[1]

    x_min: torch.Tensor = x.view(B, C, -1).min(-1)[0].view(B, C, 1)
    x_max: torch.Tensor = x.view(B, C, -1).max(-1)[0].view(B, C, 1)

    x_out: torch.Tensor = (max_val - min_val) * (x.view(B, C, -1) - x_min) / (x_max - x_min + eps) + min_val

    return x_out.view(shape)
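A small sketch of the per-slice behaviour: the minimum and maximum are computed independently for each (batch, channel) slice, so inputs with very different dynamic ranges all end up covering the target range. Import path assumed to be ``kornia.enhance``:

import torch
from kornia.enhance import normalize_min_max

# Two single-channel images with very different dynamic ranges.
x = torch.stack([torch.rand(1, 8, 8), 100.0 * torch.rand(1, 8, 8)])  # (2, 1, 8, 8)

# Each (batch, channel) slice is rescaled on its own to [0, 1].
y = normalize_min_max(x, min_val=0.0, max_val=1.0)

print(y.view(2, -1).min(dim=1)[0])  # ~tensor([0., 0.])
print(y.view(2, -1).max(dim=1)[0])  # ~tensor([1., 1.]) (slightly below 1 due to eps)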