kornia.enhance

The functions in this section perform normalisations and intensity transformations.

Adjustment

kornia.enhance.add_weighted(src1, alpha, src2, beta, gamma)

Calculate the weighted sum of two Tensors.

_images/add_weighted.png

The function calculates the weighted sum of two Tensors as follows:

\[\text{out} = \text{src1} \cdot \alpha + \text{src2} \cdot \beta + \gamma\]
Parameters:
  • src1 (Tensor) – Tensor with an arbitrary shape, equal to the shape of src2.

  • alpha (Union[float, Tensor]) – weight of the src1 elements.

  • src2 (Tensor) – Tensor with an arbitrary shape, equal to the shape of src1.

  • beta (Union[float, Tensor]) – weight of the src2 elements.

  • gamma (Union[float, Tensor]) – scalar added to each sum.

Return type:

Tensor

Returns:

Weighted Tensor with shape equal to src1 and src2 shapes.

Example

>>> input1 = torch.rand(1, 1, 5, 5)
>>> input2 = torch.rand(1, 1, 5, 5)
>>> output = add_weighted(input1, 0.5, input2, 0.5, 1.0)
>>> output.shape
torch.Size([1, 1, 5, 5])

Notes

If alpha, beta or gamma is a Tensor, its shape must be broadcastable to the shapes of src1 and src2.
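The formula and the broadcasting rule above can be checked with plain PyTorch arithmetic. The following is a minimal sketch, not kornia's implementation; the helper `weighted_sum` is hypothetical, and the per-sample weights are reshaped to \((B, 1, 1, 1)\) so that they broadcast over the remaining dimensions.

```python
import torch


def weighted_sum(src1, alpha, src2, beta, gamma):
    # Hypothetical reference helper: out = src1 * alpha + src2 * beta + gamma.
    # Tensor weights broadcast against src1/src2 under normal PyTorch rules.
    return src1 * alpha + src2 * beta + gamma


src1 = torch.full((2, 1, 2, 2), 0.2)
src2 = torch.full((2, 1, 2, 2), 0.6)
# Per-sample weights shaped (B, 1, 1, 1) so they broadcast over C, H, W.
alpha = torch.tensor([0.5, 1.0]).view(2, 1, 1, 1)
beta = torch.tensor([0.5, 0.0]).view(2, 1, 1, 1)
out = weighted_sum(src1, alpha, src2, beta, 0.1)
# Sample 0: 0.2 * 0.5 + 0.6 * 0.5 + 0.1 = 0.5
# Sample 1: 0.2 * 1.0 + 0.6 * 0.0 + 0.1 = 0.3
```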

kornia.enhance.adjust_brightness(image, factor, clip_output=True)

Adjust the brightness of an image tensor.

_images/adjust_brightness.png

This implementation follows the convention of Szeliski’s book, where brightness is defined as an additive operation applied directly to the raw pixel values, shifting them according to the applied factor and the range of the image values. Beware that other frameworks might use different conventions, which can make exact results difficult to reproduce.

The input image and factor are expected to be in the range [0, 1].

Tip

Applying a large factor might produce clipping or loss of image detail, so we recommend applying small factors. Ideally, adjust image intensity with other techniques such as kornia.enhance.adjust_gamma(). More details are available at the following link: https://scikit-image.org/docs/dev/auto_examples/color_exposure/plot_log_gamma.html#sphx-glr-auto-examples-color-exposure-plot-log-gamma-py

Parameters:
  • image (Tensor) – Image to be adjusted in the shape of \((*, H, W)\).

  • factor (Union[float, Tensor]) – Brightness adjust factor per element in the batch. It’s recommended to bound the factor by [0, 1]. 0 does not modify the input image, while any other number modifies the brightness.

  • clip_output (bool, optional) – whether to clip the output image to the range [0, 1]. Default: True

Return type:

Tensor

Returns:

Adjusted tensor in the shape of \((*, H, W)\).

Note

See a working example here.

Example

>>> x = torch.ones(1, 1, 2, 2)
>>> adjust_brightness(x, 1.)
tensor([[[[1., 1.],
          [1., 1.]]]])
>>> x = torch.ones(2, 5, 3, 3)
>>> y = torch.tensor([0.25, 0.50])
>>> adjust_brightness(x, y).shape
torch.Size([2, 5, 3, 3])
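
The additive convention described above can be sketched in a few lines of PyTorch, assuming that a 1-d factor holds one value per batch element and that clipping simply clamps to [0, 1]. The helper `brightness_additive` is hypothetical and is not kornia's code.

```python
import torch


def brightness_additive(image, factor, clip_output=True):
    # Hypothetical helper: shift every pixel by `factor` (additive brightness).
    factor = torch.as_tensor(factor, dtype=image.dtype)
    if factor.dim() == 1:
        # One factor per batch element: (B,) -> (B, 1, ..., 1) for broadcasting.
        factor = factor.view(-1, *([1] * (image.dim() - 1)))
    out = image + factor
    # Optionally clamp back into the expected [0, 1] range.
    return out.clamp(0.0, 1.0) if clip_output else out


x = torch.full((2, 1, 2, 2), 0.5)
out = brightness_additive(x, torch.tensor([0.25, 0.75]))
# Sample 0 becomes 0.75; sample 1 would be 1.25 and is clipped to 1.0.
```
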
kornia.enhance.adjust_contrast(image, factor, clip_output=True)

Adjust the contrast of an image tensor.

_images/adjust_contrast.png

This implementation follows the convention of Szeliski’s book, where contrast is defined as a multiplicative operation applied directly to the raw pixel values. Beware that other frameworks might use different conventions, which can make exact results difficult to reproduce.

The input image and factor are expected to be in the range [0, 1].

Tip

This is not the preferred way to adjust the contrast of an image. Ideally, use kornia.enhance.adjust_gamma() instead. More details are available at the following link: https://scikit-image.org/docs/dev/auto_examples/color_exposure/plot_log_gamma.html#sphx-glr-auto-examples-color-exposure-plot-log-gamma-py

Parameters:
  • image (Tensor) – Image to be adjusted in the shape of \((*, H, W)\).

  • factor (Union[float, Tensor]) – Contrast adjust factor per element in the batch. 0 generates a completely black image, 1 does not modify the input image, while any other non-negative number scales the pixel values by this factor.

  • clip_output (bool, optional) – whether to clip the output image to the range [0, 1]. Default: True

Return type:

Tensor

Returns:

Adjusted image in the shape of \((*, H, W)\).

Note

See a working example here.

Example

>>> import torch
>>> x = torch.ones(1, 1, 2, 2)
>>> adjust_contrast(x, 0.5)
tensor([[[[0.5000, 0.5000],
          [0.5000, 0.5000]]]])
>>> x = torch.ones(2, 5, 3, 3)
>>> y = torch.tensor([0.65, 0.50])
>>> adjust_contrast(x, y).shape
torch.Size([2, 5, 3, 3])
kornia.enhance.adjust_contrast_with_mean_subtraction(image, factor)

Adjust the contrast of an image tensor by subtracting the mean over channels.

Note

This is a convenience function for compatibility with PIL. For an exact definition of image contrast adjustment, consider using kornia.enhance.adjust_gamma().

Parameters:
  • image (Tensor) – Image to be adjusted in the shape of \((*, H, W)\).

  • factor (Union[float, Tensor]) – Contrast adjust factor per element in the batch. 0 generates a uniform image filled with the mean intensity, 1 does not modify the input image, while any other non-negative number scales the deviation from the mean by this factor.

Return type:

Tensor

Returns:

Adjusted image in the shape of \((*, H, W)\).

Example

>>> import torch
>>> x = torch.ones(1, 1, 2, 2)
>>> adjust_contrast_with_mean_subtraction(x, 0.5)
tensor([[[[1., 1.],
          [1., 1.]]]])
>>> x = torch.ones(2, 5, 3, 3)
>>> y = torch.tensor([0.65, 0.50])
>>> adjust_contrast_with_mean_subtraction(x, y).shape
torch.Size([2, 5, 3, 3])
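
The mean-subtraction variant pulls every pixel towards the mean intensity instead of towards zero, which is consistent with a factor of 0.5 leaving the all-ones example above unchanged. Below is a minimal sketch for a single-channel image; the helper name is hypothetical, and using the plain per-image mean (rather than, for instance, a grayscale mean for RGB inputs) is an assumption of this sketch.

```python
import torch


def contrast_mean_blend(image, factor):
    # Hypothetical helper for a single-channel image of shape (*, 1, H, W).
    # Per-image mean intensity, kept broadcastable against the input.
    mean = image.mean(dim=(-2, -1), keepdim=True)
    # Scale the deviation from the mean, then clip to [0, 1].
    out = (image - mean) * factor + mean
    return out.clamp(0.0, 1.0)


x = torch.tensor([[[[0.2, 0.4], [0.6, 0.8]]]])  # mean intensity is 0.5
half = contrast_mean_blend(x, 0.5)  # values pulled halfway towards 0.5
flat = contrast_mean_blend(x, 0.0)  # every pixel becomes the mean
```
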
kornia.enhance.adjust_gamma(input, gamma, gain=1.0)

Perform gamma correction on an image.

_images/adjust_contrast.png

The input image is expected to be in the range of [0, 1].

Parameters:
  • input (Tensor) – Image to be adjusted in the shape of \((*, H, W)\).

  • gamma (Union[float, Tensor]) – Non-negative real number, same as \(\gamma\) in \(\text{out} = \text{gain} \cdot \text{input}^{\gamma}\). A gamma larger than 1 makes the shadows darker, while a gamma smaller than 1 makes dark regions lighter.

  • gain (Union[float, Tensor], optional) – The constant multiplier. Default: 1.0

Return type:

Tensor

Returns:

Adjusted image in the shape of \((*, H, W)\).

Note

See a working example here.

Example

>>> x = torch.ones(1, 1, 2, 2)
>>> adjust_gamma(x, 1.0, 2.0)
tensor([[[[1., 1.],
          [1., 1.]]]])
>>> x = torch.ones(2, 5, 3, 3)
>>> y1 = torch.ones(2) * 1.0
>>> y2 = torch.ones(2) * 2.0
>>> adjust_gamma(x, y1, y2).shape
torch.Size([2, 5, 3, 3])
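
Gamma correction is a per-pixel power law. A minimal sketch, assuming the form \(\text{out} = \text{gain} \cdot \text{input}^{\gamma}\) followed by clamping to [0, 1] (consistent with the example above, where a gain of 2 still yields ones); `gamma_correct` is a hypothetical helper.

```python
import torch


def gamma_correct(image, gamma, gain=1.0):
    # Hypothetical helper: out = gain * image ** gamma, clamped to [0, 1].
    return torch.clamp(gain * image.pow(gamma), 0.0, 1.0)


x = torch.tensor([0.25, 0.5, 1.0])
darker = gamma_correct(x, 2.0)   # gamma > 1: dark values get darker
lighter = gamma_correct(x, 0.5)  # gamma < 1: dark values get lighter
```
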
kornia.enhance.adjust_hue(image, factor)

Adjust hue of an image.

_images/adjust_hue.png

The image is expected to be an RGB image in the range of [0, 1].

Parameters:
  • image (Tensor) – Image to be adjusted in the shape of \((*, 3, H, W)\).

  • factor (Union[float, Tensor]) – How much to shift the hue channel. Should be in [-PI, PI]. PI and -PI give complete reversal of hue channel in HSV space in positive and negative direction respectively. 0 means no shift. Therefore, both -PI and PI will give an image with complementary colors while 0 gives the original image.

Return type:

Tensor

Returns:

Adjusted image in the shape of \((*, 3, H, W)\).

Note

See a working example here.

Example

>>> x = torch.ones(1, 3, 2, 2)
>>> adjust_hue(x, 3.141516).shape
torch.Size([1, 3, 2, 2])
>>> x = torch.ones(2, 3, 3, 3)
>>> y = torch.ones(2) * 3.141516
>>> adjust_hue(x, y).shape
torch.Size([2, 3, 3, 3])
kornia.enhance.adjust_saturation(image, factor)

Adjust color saturation of an image.

_images/adjust_saturation.png

The image is expected to be an RGB image in the range of [0, 1].

Parameters:
  • image (Tensor) – Image/Tensor to be adjusted in the shape of \((*, 3, H, W)\).

  • factor (Union[float, Tensor]) – How much to adjust the saturation. 0 will give a black and white image, 1 will give the original image while 2 will enhance the saturation by a factor of 2.

  • saturation_mode – The mode to adjust saturation.

Return type:

Tensor

Returns:

Adjusted image in the shape of \((*, 3, H, W)\).

Note

See a working example here.

Example

>>> x = torch.ones(1, 3, 3, 3)
>>> adjust_saturation(x, 2.).shape
torch.Size([1, 3, 3, 3])
>>> x = torch.ones(2, 3, 3, 3)
>>> y = torch.tensor([1., 2.])
>>> adjust_saturation(x, y).shape
torch.Size([2, 3, 3, 3])
kornia.enhance.adjust_sigmoid(image, cutoff=0.5, gain=10, inv=False)

Adjust sigmoid correction on the input image tensor.

The input image is expected to be in the range of [0, 1].

Reference: [1] Gustav J. Braun, “Image Lightness Rescaling Using Sigmoidal Contrast Enhancement Functions”, http://markfairchild.org/PDFs/PAP07.pdf

Parameters:
  • image (Tensor) – Image to be adjusted in the shape of \((*, H, W)\).

  • cutoff (float, optional) – The cutoff of sigmoid function. Default: 0.5

  • gain (float, optional) – The multiplier of sigmoid function. Default: 10

  • inv (bool, optional) – If set to True, the function returns the inverse sigmoid correction. Default: False

Return type:

Tensor

Returns:

Adjusted tensor in the shape of \((*, H, W)\).

Example

>>> x = torch.ones(1, 1, 2, 2)
>>> adjust_sigmoid(x, gain=0)
tensor([[[[0.5000, 0.5000],
          [0.5000, 0.5000]]]])
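
The sigmoid curve maps each pixel through \(1 / (1 + e^{\text{gain} \cdot (\text{cutoff} - x)})\), so with gain=0 every output is 0.5, as in the example above. The sketch below assumes the same form as scikit-image's `adjust_sigmoid`, including the inverse as the complement \(1 - \text{sigmoid}\); `sigmoid_correct` is a hypothetical helper, not kornia's code.

```python
import torch


def sigmoid_correct(image, cutoff=0.5, gain=10.0, inv=False):
    # Hypothetical helper: sigmoid contrast curve centred on `cutoff`.
    out = 1.0 / (1.0 + torch.exp(gain * (cutoff - image)))
    # Assumed inverse: the complement of the forward curve.
    return 1.0 - out if inv else out


x = torch.tensor([0.0, 0.5, 1.0])
y = sigmoid_correct(x)                # dark values pushed down, bright values up
flat = sigmoid_correct(x, gain=0.0)   # gain 0 collapses everything to 0.5
```
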
kornia.enhance.adjust_log(image, gain=1, inv=False, clip_output=True)

Adjust log correction on the input image tensor.

The input image is expected to be in the range of [0, 1].

Reference: [1]: http://www.ece.ucsb.edu/Faculty/Manjunath/courses/ece178W03/EnhancePart1.pdf

Parameters:
  • image (Tensor) – Image to be adjusted in the shape of \((*, H, W)\).

  • gain (float, optional) – The multiplier of logarithmic function. Default: 1

  • inv (bool, optional) – If set to True, the function returns the inverse logarithmic correction. Default: False

  • clip_output (bool, optional) – Whether to clip the output image to the range [0, 1]. Default: True

Return type:

Tensor

Returns:

Adjusted tensor in the shape of \((*, H, W)\).

Example

>>> x = torch.zeros(1, 1, 2, 2)
>>> adjust_log(x, inv=True)
tensor([[[[0., 0.],
          [0., 0.]]]])
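
Logarithmic correction brightens dark values more than bright ones. A minimal sketch, assuming the scikit-image `adjust_log` convention of \(\text{gain} \cdot \log_2(1 + x)\) for the forward curve and \(\text{gain} \cdot (2^x - 1)\) for the inverse; `log_correct` is a hypothetical helper.

```python
import torch


def log_correct(image, gain=1.0, inv=False, clip_output=True):
    # Hypothetical helper: gain * log2(1 + x), or gain * (2**x - 1) when inverted.
    if inv:
        out = gain * (2.0 ** image - 1.0)
    else:
        out = gain * torch.log2(1.0 + image)
    # Optionally clamp back into [0, 1].
    return out.clamp(0.0, 1.0) if clip_output else out


x = torch.tensor([0.0, 0.5, 1.0])
y = log_correct(x)
x_back = log_correct(y, inv=True)  # with gain == 1 the two curves undo each other
```
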
kornia.enhance.invert(image, max_val=Tensor([1.0]))

Invert the values of an input image tensor by its maximum value.

_images/invert.png
Parameters:
  • image (Tensor) – The input tensor to invert, with an arbitrary shape.

  • max_val (Tensor, optional) – The expected maximum value in the input tensor. Its shape must match the input tensor shape or be broadcastable to it. Default: Tensor([1.0])

Return type:

Tensor

Example

>>> img = torch.rand(1, 2, 4, 4)
>>> invert(img).shape
torch.Size([1, 2, 4, 4])
>>> img = 255. * torch.rand(1, 2, 3, 4, 4)
>>> invert(img, torch.as_tensor(255.)).shape
torch.Size([1, 2, 3, 4, 4])
>>> img = torch.rand(1, 3, 4, 4)
>>> invert(img, torch.as_tensor([[[[1.]]]])).shape
torch.Size([1, 3, 4, 4])
kornia.enhance.posterize(input, bits)

Reduce the number of bits for each color channel.

_images/posterize.png

This function is non-differentiable because it operates on torch.uint8 values.

Parameters:
  • input (Tensor) – image tensor with shape \((*, C, H, W)\) to posterize.

  • bits (Union[int, Tensor]) – number of high bits to keep. Must be in the range [0, 8]. If an int or a one-element tensor, the whole input is posterized with this number of bits. If a 1-d tensor, each batch element is posterized with its own number of bits, len(bits) == input.shape[0]. If an n-d tensor, the input is posterized per element and channel, bits.shape == input.shape[:len(bits.shape)]

Return type:

Tensor

Returns:

Image with reduced color channels with shape \((*, C, H, W)\).

Example

>>> x = torch.rand(1, 6, 3, 3)
>>> out = posterize(x, bits=8)
>>> torch.testing.assert_close(x, out)
>>> x = torch.rand(2, 6, 3, 3)
>>> bits = torch.tensor([4, 2])
>>> posterize(x, bits).shape
torch.Size([2, 6, 3, 3])
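
Keeping the `bits` most significant bits of an 8-bit value amounts to a bitwise AND with a mask whose top `bits` bits are set. The sketch below shows that idea on a [0, 1] float tensor; converting with truncation via `.to(torch.uint8)`, handling only a single integer `bits`, and the helper name `keep_high_bits` are all assumptions of this sketch.

```python
import torch


def keep_high_bits(x, bits):
    # Hypothetical helper: x is a float tensor in [0, 1], bits is an int in [0, 8].
    x_u8 = (x * 255).to(torch.uint8)  # float -> uint8 truncates towards zero
    # Mask with the top `bits` bits set, e.g. bits=4 -> 0b11110000 == 240.
    mask = 256 - 2 ** (8 - bits)
    return (x_u8 & mask).float() / 255.0


x = torch.tensor([0.0, 0.5, 1.0])  # -> uint8 values 0, 127, 255
out = keep_high_bits(x, 4)         # -> 0, 112, 240 (divided by 255)
```
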
kornia.enhance.sharpness(input, factor)

Apply sharpness to the input tensor.

_images/sharpness.png

Implements the Sharpness function from PIL using torch ops. This implementation refers to: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py#L326

Parameters:
  • input (Tensor) – image tensor with shape \((*, C, H, W)\) to sharpen.

  • factor (Union[float, Tensor]) – factor of sharpness strength. Must be above 0. If a float or a one-element tensor, the whole batch is sharpened by the same factor. If a 1-d tensor, each batch element is sharpened by its own factor, len(factor) == len(input).

Return type:

Tensor

Returns:

Sharpened image or images with shape \((*, C, H, W)\).

Example

>>> x = torch.rand(1, 1, 5, 5)
>>> sharpness(x, 0.5).shape
torch.Size([1, 1, 5, 5])
kornia.enhance.solarize(input, thresholds=0.5, additions=None)

For each pixel in the image less than the threshold, add ‘addition’ to it and then clip the pixel value to the range [0, 1].

_images/solarize.png

The value of ‘addition’ lies between -0.5 and 0.5.

Parameters:
  • input (Tensor) – image tensor with shapes like \((*, C, H, W)\) to solarize.

  • thresholds (Union[float, Tensor], optional) – solarize thresholds. If a float or a one-element tensor, the same threshold is used across the whole batch. If a 1-d tensor, each batch element is solarized with its own threshold, len(thresholds) == len(input). Default: 0.5

  • additions (Union[Tensor, float, None], optional) – values between -0.5 and 0.5. If None, no addition is performed. If a float or a one-element tensor, the same addition is applied across the whole batch. If a 1-d tensor, each batch element gets its own addition, len(additions) == len(input). Default: None

Return type:

Tensor

Returns:

The solarized images with shape \((*, C, H, W)\).

Example

>>> x = torch.rand(1, 4, 3, 3)
>>> out = solarize(x, thresholds=0.5, additions=0.)
>>> out.shape
torch.Size([1, 4, 3, 3])
>>> x = torch.rand(2, 4, 3, 3)
>>> thresholds = torch.tensor([0.8, 0.5])
>>> additions = torch.tensor([-0.25, 0.25])
>>> solarize(x, thresholds, additions).shape
torch.Size([2, 4, 3, 3])
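
For reference, PIL's `ImageOps.solarize` inverts every pixel at or above the threshold and leaves darker pixels untouched. A minimal sketch of that thresholding rule on [0, 1] tensors follows; it ignores `additions`, and treating `thresholds` as following the same rule is an assumption of this sketch. `solarize_sketch` is a hypothetical helper.

```python
import torch


def solarize_sketch(x, threshold=0.5):
    # Hypothetical helper: keep pixels below the threshold, invert the rest.
    return torch.where(x < threshold, x, 1.0 - x)


x = torch.tensor([0.1, 0.4, 0.6, 0.9])
out = solarize_sketch(x)  # 0.6 -> 0.4 and 0.9 -> 0.1; 0.1 and 0.4 are unchanged
```
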

Interactive Demo

Visit the demo on Hugging Face Spaces.

Equalization

kornia.enhance.equalize(input)

Apply histogram equalization to the input tensor.

_images/equalize.png

Implements the Equalize function from PIL using PyTorch ops, based on the uint8 format: https://github.com/tensorflow/tpu/blob/5f71c12a020403f863434e96982a840578fdd127/models/official/efficientnet/autoaugment.py#L355

Parameters:

input (Tensor) – image tensor to equalize with shape \((*, C, H, W)\).

Return type:

Tensor

Returns:

Equalized image tensor with shape \((*, C, H, W)\).

Example

>>> x = torch.rand(1, 2, 3, 3)
>>> equalize(x).shape
torch.Size([1, 2, 3, 3])
kornia.enhance.equalize_clahe(input, clip_limit=40.0, grid_size=(8, 8), slow_and_differentiable=False)

Apply CLAHE (Contrast Limited Adaptive Histogram Equalization) to the input tensor.

_images/equalize_clahe.png

Note: the LUT computation uses the same approach as OpenCV; this may change in future versions.

Parameters:
  • input (Tensor) – images tensor to equalize with values in the range [0, 1] and shape \((*, C, H, W)\).

  • clip_limit (float, optional) – threshold value for contrast limiting. If 0, clipping is disabled. Default: 40.0

  • grid_size (Tuple[int, int], optional) – number of tiles to be cropped in each direction (GH, GW). Default: (8, 8)

  • slow_and_differentiable (bool, optional) – flag to select the implementation; set it to True for a slower but differentiable version. Default: False

Return type:

Tensor

Returns:

Equalized image or images with shape as the input.

Examples

>>> img = torch.rand(1, 10, 20)
>>> res = equalize_clahe(img)
>>> res.shape
torch.Size([1, 10, 20])
>>> img = torch.rand(2, 3, 10, 20)
>>> res = equalize_clahe(img)
>>> res.shape
torch.Size([2, 3, 10, 20])
kornia.enhance.equalize3d(input)

Equalize the values for a 3D volumetric tensor.

Implements the Equalize function for a sequence of images using PyTorch ops, based on the uint8 format: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py#L352

Parameters:

input (Tensor) – image tensor with shape \((*, C, D, H, W)\) to equalize.

Return type:

Tensor

Returns:

Equalized volume with shape \((B, C, D, H, W)\).

kornia.enhance.histogram(x, bins, bandwidth, epsilon=1e-10)

Estimate the histogram of the input tensor.

The calculation uses kernel density estimation which requires a bandwidth (smoothing) parameter.

Parameters:
  • x (Tensor) – Input tensor to compute the histogram with shape \((B, D)\).

  • bins (Tensor) – The bin centers used to compute the histogram, with shape \((N_{bins})\).

  • bandwidth (Tensor) – Gaussian smoothing factor with shape [1].

  • epsilon (float, optional) – A scalar, for numerical stability. Default: 1e-10

Return type:

Tensor

Returns:

Computed histogram of shape \((B, N_{bins})\).

Examples

>>> x = torch.rand(1, 10)
>>> bins = torch.linspace(0, 255, 128)
>>> hist = histogram(x, bins, bandwidth=torch.tensor(0.9))
>>> hist.shape
torch.Size([1, 128])
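
The kernel density estimate can be sketched as follows: every sample contributes a Gaussian bump centred on itself, evaluated at each bin centre, and the averaged responses are normalised so that each row sums to one. This is a sketch under those assumptions (in particular the per-row normalisation), with a hypothetical helper name; because it uses only smooth operations, the result is differentiable with respect to `x`.

```python
import torch


def kde_histogram(x, centers, bandwidth, eps=1e-10):
    # Hypothetical helper. x: (B, D) samples, centers: (N,) bin centres.
    residuals = x.unsqueeze(-1) - centers.view(1, 1, -1)      # (B, D, N)
    kernel = torch.exp(-0.5 * (residuals / bandwidth) ** 2)  # Gaussian weights
    pdf = kernel.mean(dim=1)                                  # (B, N)
    # Normalise so that each histogram sums to one; eps avoids division by zero.
    return pdf / (pdf.sum(dim=1, keepdim=True) + eps)


x = torch.rand(2, 50, requires_grad=True)
centers = torch.linspace(0, 1, 16)
hist = kde_histogram(x, centers, torch.tensor(0.05))
hist[:, 0].sum().backward()  # gradients flow back to the samples
```
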
kornia.enhance.histogram2d(x1, x2, bins, bandwidth, epsilon=1e-10)

Estimate the 2d histogram of the input tensor.

The calculation uses kernel density estimation which requires a bandwidth (smoothing) parameter.

Parameters:
  • x1 (Tensor) – Input tensor to compute the histogram with shape \((B, D1)\).

  • x2 (Tensor) – Input tensor to compute the histogram with shape \((B, D2)\).

  • bins (Tensor) – The bin centers used to compute the histogram, with shape \((N_{bins})\).

  • bandwidth (Tensor) – Gaussian smoothing factor with shape [1].

  • epsilon (float, optional) – A scalar, for numerical stability. Default: 1e-10.

Return type:

Tensor

Returns:

Computed histogram of shape \((B, N_{bins}, N_{bins})\).

Examples

>>> x1 = torch.rand(2, 32)
>>> x2 = torch.rand(2, 32)
>>> bins = torch.linspace(0, 255, 128)
>>> hist = histogram2d(x1, x2, bins, bandwidth=torch.tensor(0.9))
>>> hist.shape
torch.Size([2, 128, 128])
kornia.enhance.image_histogram2d(image, min=0.0, max=255.0, n_bins=256, bandwidth=None, centers=None, return_pdf=False, kernel='triangular', eps=1e-10)

Estimate the histogram of the input image(s).

The calculation uses kernel density estimation, with a triangular kernel by default.

Parameters:
  • image (Tensor) – Input tensor to compute the histogram with shape \((H, W)\), \((C, H, W)\) or \((B, C, H, W)\).

  • min (float, optional) – Lower end of the interval (inclusive). Default: 0.0

  • max (float, optional) – Upper end of the interval (inclusive). Ignored when centers is specified. Default: 255.0

  • n_bins (int, optional) – The number of histogram bins. Ignored when centers is specified. Default: 256

  • bandwidth (Optional[float], optional) – Smoothing factor. If not specified or equal to -1, \((bandwidth = (max - min) / n_bins)\). Default: None

  • centers (Optional[Tensor], optional) – Centers of the bins with shape \((n_{bins},)\). If not specified or empty, it is calculated as centers of equal width bins of [min, max] range. Default: None

  • return_pdf (bool, optional) – If True, also return probability densities for each bin. Default: False

  • kernel (str, optional) – kernel to perform kernel density estimation (`triangular`, `gaussian`, `uniform`, `epanechnikov`). Default: "triangular"

Return type:

Tuple[Tensor, Tensor]

Returns:

Computed histogram of shape \((bins)\), \((C, bins)\) or \((B, C, bins)\).

Computed probability densities of shape \((bins)\), \((C, bins)\) or \((B, C, bins)\) if return_pdf is True; a tensor of zeros with the shape of the histogram otherwise.

Normalizations

kornia.enhance.normalize(data, mean, std)

Normalize an image/video tensor with mean and standard deviation.

\[\text{input[channel] = (input[channel] - mean[channel]) / std[channel]}\]

Where mean is \((M_1, ..., M_n)\) and std \((S_1, ..., S_n)\) for n channels.

Parameters:
  • data (Tensor) – Image tensor of size \((B, C, *)\).

  • mean (Tensor) – Mean for each channel.

  • std (Tensor) – Standard deviations for each channel.

Return type:

Tensor

Returns:

Normalised tensor with same size as input \((B, C, *)\).

Examples

>>> x = torch.rand(1, 4, 3, 3)
>>> out = normalize(x, torch.tensor([0.0]), torch.tensor([255.]))
>>> out.shape
torch.Size([1, 4, 3, 3])
>>> x = torch.rand(1, 4, 3, 3)
>>> mean = torch.zeros(4)
>>> std = 255. * torch.ones(4)
>>> out = normalize(x, mean, std)
>>> out.shape
torch.Size([1, 4, 3, 3])
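
The per-channel formula, and its inverse used by denormalize(), can be sketched with broadcasting: statistics of shape \((C,)\) are reshaped to \((1, C, 1, \ldots, 1)\) so that they apply to every batch element and every trailing dimension. The helper names below are hypothetical, and the sketch only handles 1-d mean/std tensors.

```python
import torch


def _per_channel(stat, data):
    # Reshape a (C,) statistic to (1, C, 1, ..., 1) to broadcast against data.
    return stat.view((1, -1) + (1,) * (data.dim() - 2))


def normalize_channels(data, mean, std):
    # (x - mean[c]) / std[c] for every channel c.
    return (data - _per_channel(mean, data)) / _per_channel(std, data)


def denormalize_channels(data, mean, std):
    # Inverse mapping: x * std[c] + mean[c].
    return data * _per_channel(std, data) + _per_channel(mean, data)


x = torch.rand(2, 3, 4, 4)
mean = torch.tensor([0.485, 0.456, 0.406])
std = torch.tensor([0.229, 0.224, 0.225])
out = normalize_channels(x, mean, std)
restored = denormalize_channels(out, mean, std)
```
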
kornia.enhance.normalize_min_max(x, min_val=0.0, max_val=1.0, eps=1e-6)

Normalise an image/video tensor by min-max scaling and rescale the values to a given range.

The data is normalised using the following formulation:

\[y_i = (b - a) * \frac{x_i - \text{min}(x)}{\text{max}(x) - \text{min}(x)} + a\]

where \(a\) is \(\text{min_val}\) and \(b\) is \(\text{max_val}\).

Parameters:
  • x (Tensor) – The image tensor to be normalised with shape \((B, C, *)\).

  • min_val (float, optional) – The minimum value for the new range. Default: 0.0

  • max_val (float, optional) – The maximum value for the new range. Default: 1.0

  • eps (float, optional) – Float number to avoid zero division. Default: 1e-6

Return type:

Tensor

Returns:

The normalised image tensor with same shape as input \((B, C, *)\).

Example

>>> x = torch.rand(1, 5, 3, 3)
>>> x_norm = normalize_min_max(x, min_val=-1., max_val=1.)
>>> x_norm.min()
tensor(-1.)
>>> x_norm.max()
tensor(1.0000)
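
A minimal sketch of the min-max formula, assuming the minimum and maximum are taken per sample and per channel over all remaining dimensions; the helper name is hypothetical. In this sketch eps is added to the denominator, so the maximum maps to a value marginally below max_val, which is consistent with the example above printing tensor(1.0000) rather than tensor(1.).

```python
import torch


def min_max_rescale(x, min_val=0.0, max_val=1.0, eps=1e-6):
    # Hypothetical helper. Reduce over every dimension after (B, C).
    dims = tuple(range(2, x.dim()))
    x_min = x.amin(dim=dims, keepdim=True)
    x_max = x.amax(dim=dims, keepdim=True)
    # Map [x_min, x_max] onto [min_val, max_val]; eps guards constant channels.
    return (max_val - min_val) * (x - x_min) / (x_max - x_min + eps) + min_val


x = torch.tensor([[[0.2, 0.4, 0.6]]])  # shape (B=1, C=1, 3)
y = min_max_rescale(x, min_val=-1.0, max_val=1.0)
```
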
kornia.enhance.denormalize(data, mean, std)

Denormalize an image/video tensor with mean and standard deviation.

\[\text{input[channel] = (input[channel] * std[channel]) + mean[channel]}\]

Where mean is \((M_1, ..., M_n)\) and std \((S_1, ..., S_n)\) for n channels.

Parameters:
  • data (Tensor) – Image tensor of size \((B, C, *)\).

  • mean (Union[Tensor, float]) – Mean for each channel.

  • std (Union[Tensor, float]) – Standard deviations for each channel.

Return type:

Tensor

Returns:

Denormalised tensor with same size as input \((B, C, *)\).

Examples

>>> x = torch.rand(1, 4, 3, 3)
>>> out = denormalize(x, 0.0, 255.)
>>> out.shape
torch.Size([1, 4, 3, 3])
>>> x = torch.rand(1, 4, 3, 3, 3)
>>> mean = torch.zeros(1, 4)
>>> std = 255. * torch.ones(1, 4)
>>> out = denormalize(x, mean, std)
>>> out.shape
torch.Size([1, 4, 3, 3, 3])
kornia.enhance.zca_mean(inp, dim=0, unbiased=True, eps=1e-6, return_inverse=False)

Compute the ZCA whitening matrix and mean vector.

The output can be used with linear_transform(). See ZCAWhitening for details.

Parameters:
  • inp (Tensor) – input data tensor.

  • dim (int, optional) – Specifies the dimension that serves as the samples dimension. Default: 0

  • unbiased (bool, optional) – Whether to use the unbiased estimate of the covariance matrix. Default: True

  • eps (float, optional) – a small number used for numerical stability. Default: 1e-6

  • return_inverse (bool, optional) – Whether to return the inverse ZCA transform. Default: False

Shapes:
  • inp: \((D_0,...,D_{\text{dim}},...,D_N)\) is a batch of N-D tensors.

  • transform_matrix: \((\Pi_{d=0,d\neq \text{dim}}^N D_d, \Pi_{d=0,d\neq \text{dim}}^N D_d)\)

  • mean_vector: \((1, \Pi_{d=0,d\neq \text{dim}}^N D_d)\)

  • inv_transform: same shape as the transform matrix

Return type:

Tuple[Tensor, Tensor, Optional[Tensor]]

Returns:

A tuple containing the ZCA matrix, the mean vector and the inverse ZCA matrix. The inverse is returned only if return_inverse is True; otherwise None is returned in its place.

Note

See a working example here.

Examples

>>> x = torch.tensor([[0, 1], [1, 0], [-1, 0], [0, -1]], dtype=torch.float32)
>>> transform_matrix, mean_vector, _ = zca_mean(x)  # Returns transformation matrix and data mean
>>> x = torch.rand(3, 20, 2, 2)
>>> transform_matrix, mean_vector, inv_transform = zca_mean(x, dim=1, return_inverse=True)
>>> # transform_matrix.size() is (12, 12) and mean_vector.size() is (1, 12)
kornia.enhance.zca_whiten(inp, dim=0, unbiased=True, eps=1e-6)

Apply ZCA whitening transform.

See ZCAWhitening for details.

Parameters:
  • inp (Tensor) – input data tensor.

  • dim (int, optional) – Specifies the dimension that serves as the samples dimension. Default: 0

  • unbiased (bool, optional) – Whether to use the unbiased estimate of the covariance matrix. Default: True

  • eps (float, optional) – a small number used for numerical stability. Default: 1e-6

Return type:

Tensor

Returns:

Whitened input data.

Note

See a working example here.

Examples

>>> x = torch.tensor([[0,1],[1,0],[-1,0]], dtype = torch.float32)
>>> zca_whiten(x)
tensor([[ 0.0000,  1.1547],
        [ 1.0000, -0.5773],
        [-1.0000, -0.5773]])
kornia.enhance.linear_transform(inp, transform_matrix, mean_vector, dim=0)

Given a transformation matrix and a mean vector, this function flattens the input tensor along the given dimension and subtracts the mean vector from it. It then computes the dot product with the transformation matrix and reshapes the result to the original input shape.

\[\mathbf{X}_{T} = (\mathbf{X - \mu})(T)\]
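The flatten, centre, multiply and reshape steps can be sketched as follows, assuming the samples dimension is moved to the front while the other dimensions keep their relative order; `apply_linear_transform` is a hypothetical helper rather than kornia's implementation. It reproduces the values of the dim = 3 example further down.

```python
import torch


def apply_linear_transform(inp, transform_matrix, mean_vector, dim=0):
    # Hypothetical helper. Move the samples dimension to the front, flatten the rest.
    moved = inp.movedim(dim, 0)
    flat = moved.reshape(moved.shape[0], -1)       # (N, D)
    # X_T = (X - mu) T
    out = (flat - mean_vector) @ transform_matrix
    # Restore the original layout.
    return out.reshape(moved.shape).movedim(0, dim)


inp = torch.ones(10, 3, 4, 5)
out = apply_linear_transform(inp, torch.ones(120, 120), 2 * torch.ones(1, 120), dim=3)
# Each output is (1 - 2) summed over 10 * 3 * 4 = 120 features, i.e. -120.
```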
Parameters:
  • inp (Tensor) – Input data \(X\).

  • transform_matrix (Tensor) – Transform matrix \(T\).

  • mean_vector (Tensor) – mean vector \(\mu\).

  • dim (int, optional) – Batch dimension. Default: 0

Shapes:
  • inp: \((D_0,...,D_{\text{dim}},...,D_N)\) is a batch of N-D tensors.

  • transform_matrix: \((\Pi_{d=0,d\neq \text{dim}}^N D_d, \Pi_{d=0,d\neq \text{dim}}^N D_d)\)

  • mean_vector: \((1, \Pi_{d=0,d\neq \text{dim}}^N D_d)\)

Return type:

Tensor

Returns:

Transformed data.

Example

>>> # Example where dim = 3
>>> inp = torch.ones((10,3,4,5))
>>> transform_mat = torch.ones((10*3*4,10*3*4))
>>> mean = 2*torch.ones((1,10*3*4))
>>> out = linear_transform(inp, transform_mat, mean, 3)
>>> print(out.shape, out.unique())  # Should be a (10,3,4,5) tensor of -120s
torch.Size([10, 3, 4, 5]) tensor([-120.])
>>> # Example where dim = 0
>>> inp = torch.ones((10,2))
>>> transform_mat = torch.ones((2,2))
>>> mean = torch.zeros((1,2))
>>> out = linear_transform(inp, transform_mat, mean)
>>> print(out.shape, out.unique())  # Should be a (10,2) tensor of 2s
torch.Size([10, 2]) tensor([2.])

Codec

kornia.enhance.jpeg_codec_differentiable(image_rgb, jpeg_quality, quantization_table_y=None, quantization_table_c=None)

Differentiable JPEG encoding-decoding module.

Based on [RDPC24] [SS17], we perform differentiable JPEG encoding-decoding as follows:

_images/jpeg_codec_differentiable.png
\[\text{JPEG}_{\text{diff}}(I, q, QT_{y}, QT_{c}) = \hat{I}\]
Where:
  • \(I\) is the original image to be coded.

  • \(q\) is the JPEG quality controlling the compression strength.

  • \(QT_{y}\) is the luma quantization table.

  • \(QT_{c}\) is the chroma quantization table.

  • \(\hat{I}\) is the resulting JPEG encoded-decoded image.

Parameters:
  • image_rgb (Tensor) – the RGB image to be coded.

  • jpeg_quality (Tensor) – JPEG quality in the range \([0, 100]\) controlling the compression strength.

  • quantization_table_y (Tensor | None, optional) – quantization table for Y channel. Default: None, which will load the standard quantization table.

  • quantization_table_c (Tensor | None, optional) – quantization table for C channels. Default: None, which will load the standard quantization table.

Shape:
  • image_rgb: \((*, 3, H, W)\).

  • jpeg_quality: \((1)\) or \((B)\) (if a batch dimension is used, it must match that of image_rgb).

  • quantization_table_y: \((8, 8)\) or \((B, 8, 8)\) (if a batch dimension is used, it must match that of image_rgb).

  • quantization_table_c: \((8, 8)\) or \((B, 8, 8)\) (if a batch dimension is used, it must match that of image_rgb).

Return type:

Tensor

Returns:

JPEG coded image of the shape \((B, 3, H, W)\)

Example

To perform JPEG coding with the standard quantization tables, just provide a JPEG quality:

>>> img = torch.rand(3, 3, 64, 64, requires_grad=True, dtype=torch.float)
>>> jpeg_quality = torch.tensor((99.0, 25.0, 1.0), requires_grad=True)
>>> img_jpeg = jpeg_codec_differentiable(img, jpeg_quality)
>>> img_jpeg.sum().backward()

You also have the option to provide custom quantization tables:

>>> img = torch.rand(3, 3, 64, 64, requires_grad=True, dtype=torch.float)
>>> jpeg_quality = torch.tensor((99.0, 25.0, 1.0), requires_grad=True)
>>> quantization_table_y = torch.randint(1, 256, size=(3, 8, 8), dtype=torch.float)
>>> quantization_table_c = torch.randint(1, 256, size=(3, 8, 8), dtype=torch.float)
>>> img_jpeg = jpeg_codec_differentiable(img, jpeg_quality, quantization_table_y, quantization_table_c)
>>> img_jpeg.sum().backward()

If you want to control the quantization purely through the quantization tables, use a JPEG quality of 99.5. Setting the JPEG quality to 99.5 leads to a QT scaling of 1; see Eq. 2 of [RDPC24] for details.

>>> img = torch.rand(3, 3, 64, 64, requires_grad=True, dtype=torch.float)
>>> jpeg_quality = torch.ones(3) * 99.5
>>> quantization_table_y = torch.randint(1, 256, size=(3, 8, 8), dtype=torch.float)
>>> quantization_table_c = torch.randint(1, 256, size=(3, 8, 8), dtype=torch.float)
>>> img_jpeg = jpeg_codec_differentiable(img, jpeg_quality, quantization_table_y, quantization_table_c)
>>> img_jpeg.sum().backward()

Modules

class kornia.enhance.Normalize(mean, std)

Normalize a tensor image with mean and standard deviation.

\[\text{input[channel] = (input[channel] - mean[channel]) / std[channel]}\]

Where mean is \((M_1, ..., M_n)\) and std \((S_1, ..., S_n)\) for n channels.

Parameters:
  • mean – Mean for each channel.

  • std – Standard deviations for each channel.

Shape:
  • Input: Image tensor of size \((*, C, ...)\).

  • Output: Normalised tensor with same size as input \((*, C, ...)\).

Examples

>>> x = torch.rand(1, 4, 3, 3)
>>> out = Normalize(0.0, 255.)(x)
>>> out.shape
torch.Size([1, 4, 3, 3])
>>> x = torch.rand(1, 4, 3, 3)
>>> mean = torch.zeros(4)
>>> std = 255. * torch.ones(4)
>>> out = Normalize(mean, std)(x)
>>> out.shape
torch.Size([1, 4, 3, 3])
class kornia.enhance.Denormalize(mean, std)

Denormalize a tensor image with mean and standard deviation.

\[\text{input[channel] = (input[channel] * std[channel]) + mean[channel]}\]

Where mean is \((M_1, ..., M_n)\) and std \((S_1, ..., S_n)\) for n channels.

Parameters:
  • mean – Mean for each channel.

  • std – Standard deviations for each channel.

Shape:
  • Input: Image tensor of size \((*, C, ...)\).

  • Output: Denormalised tensor with same size as input \((*, C, ...)\).

Examples

>>> x = torch.rand(1, 4, 3, 3)
>>> out = Denormalize(0.0, 255.)(x)
>>> out.shape
torch.Size([1, 4, 3, 3])
>>> x = torch.rand(1, 4, 3, 3, 3)
>>> mean = torch.zeros(1, 4)
>>> std = 255. * torch.ones(1, 4)
>>> out = Denormalize(mean, std)(x)
>>> out.shape
torch.Size([1, 4, 3, 3, 3])
class kornia.enhance.ZCAWhitening(dim=0, eps=1e-6, unbiased=True, detach_transforms=True, compute_inv=False)

Compute the ZCA whitening matrix transform and the mean vector, and apply the transform to the data.

The data tensor is flattened, and the mean \(\mathbf{\mu}\) and covariance matrix \(\mathbf{\Sigma}\) are computed from the flattened data \(\mathbf{X} \in \mathbb{R}^{N \times D}\), where \(N\) is the sample size and \(D\) is flattened dimensionality (e.g. for a tensor with size 5x3x2x2 \(N = 5\) and \(D = 12\)). The ZCA whitening transform is given by:

\[\mathbf{X}_{\text{zca}} = (\mathbf{X - \mu})(US^{-\frac{1}{2}}U^T)^T\]

where \(U\) are the eigenvectors of \(\Sigma\) and \(S\) contains the corresponding eigenvalues of \(\Sigma\). After the transform is applied, the output is reshaped back to the input shape.
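
The eigen-decomposition in the formula above can be sketched with `torch.linalg.eigh`, which applies because the covariance matrix is symmetric; note that this swaps in eigh, whereas the note further down says the class itself uses svd(). This is a minimal sketch for a 2-d data matrix with samples along dim 0, using the unbiased covariance and adding eps to the eigenvalues before the inverse square root (the placement of eps is an assumption); `zca_fit` is a hypothetical helper. The whitened data should have an approximately identity covariance.

```python
import torch


def zca_fit(x, eps=1e-6):
    # Hypothetical helper. x: (N, D) data matrix with samples along dim 0.
    mean = x.mean(dim=0, keepdim=True)
    xc = x - mean
    cov = xc.T @ xc / (x.shape[0] - 1)       # unbiased covariance, (D, D)
    evals, evecs = torch.linalg.eigh(cov)    # Sigma = U diag(S) U^T
    # W = U S^{-1/2} U^T (symmetric, so W^T == W)
    w = evecs @ torch.diag(1.0 / torch.sqrt(evals + eps)) @ evecs.T
    return w, mean


torch.manual_seed(0)
mix = torch.tensor([[2.0, 0.0, 0.0], [1.0, 0.5, 0.0], [0.3, 0.2, 1.0]])
x = torch.randn(500, 3) @ mix                # correlated toy data
w, mean = zca_fit(x)
x_zca = (x - mean) @ w.T                     # (X - mu)(U S^{-1/2} U^T)^T
cov_zca = x_zca.T @ x_zca / (x.shape[0] - 1)
```
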

Parameters:
  • dim (int, optional) – Determines the dimension that represents the samples axis. Default: 0

  • eps (float, optional) – a small number used for numerical stability. Default: 1e-6

  • unbiased (bool, optional) – Whether to use the unbiased estimate of the covariance matrix. Default: True

  • compute_inv (bool, optional) – Compute the inverse transform matrix. Default: False

  • detach_transforms (bool, optional) – Detaches gradient from the ZCA fitting. Default: True

Shape:
  • x: \((D_0,...,D_{\text{dim}},...,D_N)\) is a batch of N-D tensors.

  • x_whiten: \((D_0,...,D_{\text{dim}},...,D_N)\) same shape as input.

Note

See a working example here.

Examples

>>> x = torch.tensor([[0,1],[1,0],[-1,0],[0,-1]], dtype = torch.float32)
>>> zca = ZCAWhitening().fit(x)
>>> x_whiten = zca(x)
>>> zca = ZCAWhitening()
>>> x_whiten = zca(x, include_fit = True) # Includes the fitting step
>>> x_whiten = zca(x)  # Can now run without the fitting step
>>> # Enable backprop through ZCA fitting process
>>> zca = ZCAWhitening(detach_transforms = False)
>>> x_whiten = zca(x, include_fit = True) # Includes the fitting step

Note

This implementation uses svd(), which yields NaNs in the backward pass if the singular values are not unique. See here for more information.

References

[1] Stanford PCA & ZCA whitening tutorial

fit(x)#

Fit ZCA whitening matrices to the data.

Parameters:

x (Tensor) – Input data.

Return type:

ZCAWhitening

Returns:

The fitted ZCAWhitening instance.

forward(x, include_fit=False)#

Apply the whitening transform to the data.

Parameters:
  • x (Tensor) – Input data.

  • include_fit (bool, optional) – Indicates whether to fit the data as part of the forward pass. Default: False

Return type:

Tensor

Returns:

The transformed data.

inverse_transform(x)#

Apply the inverse transform to the whitened data.

Parameters:

x (Tensor) – Whitened data.

Return type:

Tensor

Returns:

Original data.
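
For the inverse, the sketch below assumes the fitted state consists of the mean \(\mu\) and \(W = US^{-\frac{1}{2}}U^T\); inverting then means multiplying by \(US^{\frac{1}{2}}U^T\) and adding \(\mu\) back. The helper `fit_zca_sketch` is hypothetical plain PyTorch, not Kornia's code, and the same eps is used in both matrices so that they are exact inverses of each other.

```python
import torch


def fit_zca_sketch(X: torch.Tensor, eps: float = 1e-6):
    """Hypothetical helper: return (mu, W, W_inv) for 2-D data X of shape (N, D)."""
    mu = X.mean(dim=0, keepdim=True)
    Xc = X - mu
    cov = Xc.T @ Xc / (X.shape[0] - 1)
    S, U = torch.linalg.eigh(cov)
    W = U @ torch.diag(torch.rsqrt(S + eps)) @ U.T      # whitening matrix U S^{-1/2} U^T
    W_inv = U @ torch.diag(torch.sqrt(S + eps)) @ U.T   # its inverse U S^{1/2} U^T
    return mu, W, W_inv


torch.manual_seed(0)
mix = torch.eye(4, dtype=torch.float64) + 0.5
X = torch.randn(200, 4, dtype=torch.float64) @ mix
mu, W, W_inv = fit_zca_sketch(X)
X_white = (X - mu) @ W.T          # forward transform
X_back = X_white @ W_inv.T + mu   # inverse transform: undo W, then add the mean back
```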

class kornia.enhance.AdjustBrightness(brightness_factor)#

Adjust the brightness of an image.

This implementation aligns with OpenCV, not PIL. Hence, the output differs from TorchVision. The input image is expected to be in the range of [0, 1].

Parameters:

brightness_factor (Union[float, Tensor]) – Brightness adjust factor per element in the batch. 0 leaves the input image unchanged, while any other value shifts the pixel values by that amount.

Shape:
  • Input: Image/Input to be adjusted in the shape of \((*, N)\).

  • Output: Adjusted image in the shape of \((*, N)\).

Example

>>> x = torch.ones(1, 1, 3, 3)
>>> AdjustBrightness(1.)(x)
tensor([[[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]]])
>>> x = torch.ones(2, 5, 3, 3)
>>> y = torch.ones(2)
>>> AdjustBrightness(y)(x).shape
torch.Size([2, 5, 3, 3])
class kornia.enhance.AdjustContrast(contrast_factor)#

Adjust the contrast of an image.

This implementation aligns with OpenCV, not PIL. Hence, the output differs from TorchVision. The input image is expected to be in the range of [0, 1].

Parameters:

contrast_factor (Union[float, Tensor]) – Contrast adjust factor per element in the batch. 0 generates a completely black image, 1 leaves the input image unchanged, and any other non-negative number scales the pixel values by this factor.

Shape:
  • Input: Image/Input to be adjusted in the shape of \((*, N)\).

  • Output: Adjusted image in the shape of \((*, N)\).

Example

>>> x = torch.ones(1, 1, 3, 3)
>>> AdjustContrast(0.5)(x)
tensor([[[[0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000]]]])
>>> x = torch.ones(2, 5, 3, 3)
>>> y = torch.ones(2)
>>> AdjustContrast(y)(x).shape
torch.Size([2, 5, 3, 3])
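
AdjustBrightness and AdjustContrast differ essentially in one operator. The sketch below assumes that brightness adds the factor to every pixel, contrast multiplies every pixel by it, and both clip to [0, 1]; this is consistent with the examples above (1 + 1 clipped to 1, and 1 × 0.5 = 0.5) but is not Kornia's source. The helper names are hypothetical. A factor of shape (B,) is reshaped to (B, 1, ..., 1) so that each sample in the batch gets its own value.

```python
import torch


def _per_sample(factor, image: torch.Tensor) -> torch.Tensor:
    # Turn a float or a (B,) tensor into shape (B or 1, 1, ..., 1) so it broadcasts over image.
    factor = torch.as_tensor(factor, dtype=image.dtype, device=image.device)
    return factor.reshape(-1, *([1] * (image.dim() - 1)))


def brightness_sketch(image, factor, clip_output=True):
    out = image + _per_sample(factor, image)   # additive shift
    return out.clamp(0.0, 1.0) if clip_output else out


def contrast_sketch(image, factor, clip_output=True):
    out = image * _per_sample(factor, image)   # multiplicative scaling
    return out.clamp(0.0, 1.0) if clip_output else out


x = torch.full((2, 1, 2, 2), 0.5)
factors = torch.tensor([0.25, 0.75])
bright = brightness_sketch(x, factors)   # sample 0 -> 0.75, sample 1 -> 1.25 clipped to 1.0
dark = contrast_sketch(x, factors)       # sample 0 -> 0.125, sample 1 -> 0.375
```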
class kornia.enhance.AdjustSaturation(saturation_factor)#

Adjust color saturation of an image.

The input image is expected to be an RGB image in the range of [0, 1].

Parameters:
  • saturation_factor (Union[float, Tensor]) – How much to adjust the saturation. 0 gives a grayscale image, 1 gives the original image, and 2 enhances the saturation by a factor of 2.

  • saturation_mode – The mode to adjust saturation.

Shape:
  • Input: Image/Tensor to be adjusted in the shape of \((*, 3, H, W)\).

  • Output: Adjusted image in the shape of \((*, 3, H, W)\).

Example

>>> x = torch.ones(1, 3, 3, 3)
>>> AdjustSaturation(2.)(x)
tensor([[[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]]])
>>> x = torch.ones(2, 3, 3, 3)
>>> y = torch.ones(2)
>>> out = AdjustSaturation(y)(x)
>>> torch.nn.functional.mse_loss(x, out)
tensor(0.)
class kornia.enhance.AdjustHue(hue_factor)#

Adjust hue of an image.

This implementation aligns with PIL. Hence, the output is close to TorchVision. The input image is expected to be an RGB image in the range of [0, 1].

Parameters:

hue_factor (Union[float, Tensor]) – How much to shift the hue channel. Should be in [-PI, PI]. PI and -PI give a complete reversal of the hue channel in HSV space in the positive and negative directions, respectively. 0 means no shift. Therefore, both -PI and PI give an image with complementary colors, while 0 gives the original image.

Shape:
  • Input: Image/Tensor to be adjusted in the shape of \((*, 3, H, W)\).

  • Output: Adjusted image in the shape of \((*, 3, H, W)\).

Example

>>> x = torch.ones(1, 3, 3, 3)
>>> AdjustHue(3.141516)(x)
tensor([[[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]]])
>>> x = torch.ones(2, 3, 3, 3)
>>> y = torch.ones(2) * 3.141516
>>> AdjustHue(y)(x).shape
torch.Size([2, 3, 3, 3])
class kornia.enhance.AdjustGamma(gamma, gain=1.0)#

Perform gamma correction on an image.

The input image is expected to be in the range of [0, 1].

Parameters:
  • gamma (Union[float, Tensor]) – Non-negative real number, the same as \(\gamma\) in the gamma-correction equation. A gamma larger than 1 makes the shadows darker, while a gamma smaller than 1 makes dark regions lighter.

  • gain (Union[float, Tensor], optional) – The constant multiplier. Default: 1.0

Shape:
  • Input: Image to be adjusted in the shape of \((*, N)\).

  • Output: Adjusted image in the shape of \((*, N)\).

Example

>>> x = torch.ones(1, 1, 3, 3)
>>> AdjustGamma(1.0, 2.0)(x)
tensor([[[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]]])
>>> x = torch.ones(2, 5, 3, 3)
>>> y1 = torch.ones(2) * 1.0
>>> y2 = torch.ones(2) * 2.0
>>> AdjustGamma(y1, y2)(x).shape
torch.Size([2, 5, 3, 3])
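
A sketch of the gamma correction, assuming the output is gain * input ** gamma clipped to [0, 1], which matches the example above (2 × 1 ** 1 = 2, clipped to 1). `gamma_sketch` is a hypothetical helper, not the Kornia implementation.

```python
import torch


def gamma_sketch(image, gamma, gain=1.0):
    # Hypothetical helper: out = gain * image ** gamma, clipped to [0, 1].
    shape = (-1,) + (1,) * (image.dim() - 1)   # per-sample parameters broadcast over the rest
    gamma = torch.as_tensor(gamma, dtype=image.dtype).reshape(shape)
    gain = torch.as_tensor(gain, dtype=image.dtype).reshape(shape)
    return (gain * image.pow(gamma)).clamp(0.0, 1.0)


x = torch.full((3, 1, 2, 2), 0.25)
out = gamma_sketch(x, torch.tensor([0.5, 1.0, 2.0]))
# gamma < 1 lifts dark values (0.25 -> 0.5); gamma > 1 darkens them (0.25 -> 0.0625)
```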
class kornia.enhance.AdjustSigmoid(cutoff=0.5, gain=10, inv=False)#

Adjust the contrast of an image tensor by performing sigmoid correction on the input image tensor.

The input image is expected to be in the range of [0, 1].

Reference: [1] Gustav J. Braun, “Image Lightness Rescaling Using Sigmoidal Contrast Enhancement Functions”, http://markfairchild.org/PDFs/PAP07.pdf

Parameters:
  • image – Image to be adjusted in the shape of \((*, H, W)\).

  • cutoff (float, optional) – The cutoff of sigmoid function. Default: 0.5

  • gain (float, optional) – The multiplier of sigmoid function. Default: 10

  • inv (bool, optional) – If set to True, the function returns the negative sigmoid correction. Default: False

Example

>>> x = torch.ones(1, 1, 2, 2)
>>> AdjustSigmoid(gain=0)(x)
tensor([[[[0.5000, 0.5000],
          [0.5000, 0.5000]]]])
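
A sketch of the sigmoid correction, assuming the standard form out = 1 / (1 + exp(gain * (cutoff - input))), and out = 1 minus that value when inv=True. With gain=0 the exponent vanishes and every pixel maps to 0.5, as in the example above. `sigmoid_sketch` is hypothetical; it uses torch.sigmoid(gain * (input - cutoff)), which is algebraically the same expression.

```python
import torch


def sigmoid_sketch(image, cutoff=0.5, gain=10.0, inv=False):
    # Hypothetical helper: 1 / (1 + exp(gain * (cutoff - image))) == sigmoid(gain * (image - cutoff))
    out = torch.sigmoid(gain * (image - cutoff))
    return 1.0 - out if inv else out


x = torch.linspace(0, 1, 5)   # [0.00, 0.25, 0.50, 0.75, 1.00]
y = sigmoid_sketch(x)         # S-shaped curve centered on the cutoff: y[2] == 0.5
y_inv = sigmoid_sketch(x, inv=True)
```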
class kornia.enhance.AdjustLog(gain=1, inv=False, clip_output=True)#

Perform log correction on the input image tensor.

The input image is expected to be in the range of [0, 1].

Reference: [1]: http://www.ece.ucsb.edu/Faculty/Manjunath/courses/ece178W03/EnhancePart1.pdf

Parameters:
  • image – Image to be adjusted in the shape of \((*, H, W)\).

  • gain (float, optional) – The multiplier of logarithmic function. Default: 1

  • inv (bool, optional) – If set to True, the function returns the inverse logarithmic correction. Default: False

  • clip_output (bool, optional) – Whether to clip the output image to the range [0, 1]. Default: True

Example

>>> x = torch.zeros(1, 1, 2, 2)
>>> AdjustLog(inv=True)(x)
tensor([[[[0., 0.],
          [0., 0.]]]])
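
A sketch of the log correction, assuming the forward mapping gain * log2(1 + input) and the inverse mapping gain * (2 ** input - 1), clipped to [0, 1] when clip_output=True. With inv=True an all-zero input stays zero (2 ** 0 - 1 = 0), matching the example above. `log_sketch` is a hypothetical helper, not Kornia's code.

```python
import torch


def log_sketch(image, gain=1.0, inv=False, clip_output=True):
    # Hypothetical helper: forward gain * log2(1 + x), inverse gain * (2 ** x - 1).
    if inv:
        out = gain * (torch.pow(2.0, image) - 1.0)
    else:
        out = gain * torch.log2(1.0 + image)
    return out.clamp(0.0, 1.0) if clip_output else out


x = torch.linspace(0, 1, 5)
y = log_sketch(x)                  # lifts dark values: 0.25 -> log2(1.25) ~= 0.32
x_back = log_sketch(y, inv=True)   # with gain=1 the inverse undoes the forward mapping
```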
class kornia.enhance.AddWeighted(alpha, beta, gamma)#

Calculate the weighted sum of two Tensors.

The function calculates the weighted sum of two Tensors as follows:

\[out = src1 * alpha + src2 * beta + gamma\]
Parameters:
  • alpha (Union[float, Tensor]) – weight of the src1 elements as Union[float, Tensor].

  • beta (Union[float, Tensor]) – weight of the src2 elements as Union[float, Tensor].

  • gamma (Union[float, Tensor]) – scalar added to each sum as Union[float, Tensor].

Shape:
  • Input1: Tensor with an arbitrary shape, equal to shape of Input2.

  • Input2: Tensor with an arbitrary shape, equal to shape of Input1.

  • Output: Weighted tensor with shape equal to src1 and src2 shapes.

Example

>>> input1 = torch.rand(1, 1, 5, 5)
>>> input2 = torch.rand(1, 1, 5, 5)
>>> output = AddWeighted(0.5, 0.5, 1.0)(input1, input2)
>>> output.shape
torch.Size([1, 1, 5, 5])

Notes

Tensor-valued alpha, beta and gamma must have shapes broadcastable to the shapes of src1 and src2.
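
To illustrate the broadcasting rule, the formula can be evaluated directly in PyTorch with per-sample weights of shape (B, 1, 1, 1). This only demonstrates the arithmetic of the equation; it does not call AddWeighted.

```python
import torch

torch.manual_seed(0)
src1 = torch.rand(2, 3, 4, 4)
src2 = torch.rand(2, 3, 4, 4)
# Per-sample weights of shape (B, 1, 1, 1) broadcast over channels and pixels.
alpha = torch.tensor([1.0, 0.0]).view(2, 1, 1, 1)
beta = 1.0 - alpha
gamma = 0.0
out = src1 * alpha + src2 * beta + gamma   # sample 0 keeps src1, sample 1 keeps src2
```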

class kornia.enhance.Invert(max_val=torch.tensor(1.0))#

Invert the values of an input tensor with respect to its maximum value, i.e. compute max_val - input.

Parameters:
  • input – The input tensor to invert, with an arbitrary shape.

  • max_val (Tensor, optional) – The expected maximum value in the input tensor. Its shape must match the input tensor shape or at least be broadcastable to it. Default: 1.0.

Example

>>> img = torch.rand(1, 2, 4, 4)
>>> Invert()(img).shape
torch.Size([1, 2, 4, 4])
>>> img = 255. * torch.rand(1, 2, 3, 4, 4)
>>> Invert(torch.as_tensor(255.))(img).shape
torch.Size([1, 2, 3, 4, 4])
>>> img = torch.rand(1, 3, 4, 4)
>>> Invert(torch.as_tensor([[[[1.]]]]))(img).shape
torch.Size([1, 3, 4, 4])
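
Assuming the inversion computes max_val - input elementwise, the operation is easy to check by hand: inverting twice returns the original tensor, up to floating-point rounding. The snippet below is a plain-PyTorch illustration, not a call into Kornia.

```python
import torch

torch.manual_seed(0)
img = 255.0 * torch.rand(1, 3, 4, 4)      # values in [0, 255)
max_val = torch.as_tensor([[[[255.0]]]])  # shape (1, 1, 1, 1), broadcasts over img
inverted = max_val - img                  # assumed formula: v -> max_val - v
restored = max_val - inverted             # a second inversion undoes the first
```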
class kornia.enhance.JPEGCodecDifferentiable(quantization_table_y=None, quantization_table_c=None)#

Differentiable JPEG encoding-decoding module.

Based on [RDPC24] [SS17], we perform differentiable JPEG encoding-decoding as follows:

\[\text{JPEG}_{\text{diff}}(I, q, QT_{y}, QT_{c}) = \hat{I}\]
Where:
  • \(I\) is the original image to be coded.

  • \(q\) is the JPEG quality controlling the compression strength.

  • \(QT_{y}\) is the luma quantization table.

  • \(QT_{c}\) is the chroma quantization table.

  • \(\hat{I}\) is the resulting JPEG encoded-decoded image.
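
Rounding in the quantization step has zero gradient almost everywhere, which would block backpropagation. [SS17] replaces it with a cubic approximation, \(\lfloor x \rceil + (x - \lfloor x \rceil)^3\); the sketch below assumes this is the rounding surrogate meant here, which this page does not state explicitly. `poly_round_sketch` is a hypothetical helper.

```python
import torch


def poly_round_sketch(x: torch.Tensor) -> torch.Tensor:
    # Cubic rounding approximation: the value stays close to round(x), but the
    # gradient is 3 * (x - round(x))**2 instead of the zero gradient of torch.round.
    return torch.round(x) + (x - torch.round(x)) ** 3


x = torch.tensor([0.3, 1.5, -2.2], requires_grad=True)
y = poly_round_sketch(x)   # torch.round(1.5) == 2.0 (round half to even)
y.sum().backward()         # x.grad now holds 3 * (x - round(x))**2
```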

_images/jpeg_codec_differentiable.png

Note

The input (and output) pixel range is \([0, 1]\). To process normalized images, first denormalize them and then normalize the output images again.

Note that this implementation models the JPEG encoding-decoding mapping in a differentiable setting; however, it does not give access to the JPEG-coded byte file itself. For more details please refer to [RDPC24].

This implementation is not meant for data loading. For loading JPEG images please refer to kornia.io. There we provide an optimized Rust implementation for fast JPEG loading.

Parameters:
  • quantization_table_y (Tensor | Parameter | None, optional) – quantization table for Y channel. Default: None, which will load the standard quantization table.

  • quantization_table_c (Tensor | Parameter | None, optional) – quantization table for C channels. Default: None, which will load the standard quantization table.

Shape:
  • quantization_table_y: \((8, 8)\) or \((B, 8, 8)\) (if batched, B must match the batch dimension of image_rgb).

  • quantization_table_c: \((8, 8)\) or \((B, 8, 8)\) (if batched, B must match the batch dimension of image_rgb).

  • image_rgb: \((*, 3, H, W)\).

  • jpeg_quality: \((1)\) or \((B)\) (if batched, B must match the batch dimension of image_rgb).
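
This page does not spell out how jpeg_quality rescales the quantization tables. As an assumption, the sketch below uses the standard IJG (libjpeg) quality scaling: qualities below 50 enlarge the table entries (stronger compression), qualities above 50 shrink them, and 50 leaves them unchanged. `scale_table_sketch` is hypothetical, and the differentiable codec may replace floor and clamp with smooth surrogates.

```python
import torch


def scale_table_sketch(table: torch.Tensor, quality: float) -> torch.Tensor:
    # IJG/libjpeg convention: quality < 50 scales the table up, quality > 50 scales it down.
    scale = 5000.0 / quality if quality < 50 else 200.0 - 2.0 * quality
    scaled = torch.floor((table * scale + 50.0) / 100.0)
    return scaled.clamp(1.0, 255.0)   # keep entries in the valid baseline range


table = torch.tensor([[16.0, 11.0], [12.0, 14.0]])   # a 2x2 corner of an 8x8 table, for brevity
coarse = scale_table_sketch(table, 10)    # larger steps: stronger compression
fine = scale_table_sketch(table, 100)     # all ones: weakest quantization
```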

Example

You can use the differentiable JPEG module with the standard quantization tables as follows:

>>> diff_jpeg_module = JPEGCodecDifferentiable()
>>> img = torch.rand(2, 3, 32, 32, requires_grad=True, dtype=torch.float)
>>> jpeg_quality = torch.tensor((99.0, 1.0), requires_grad=True)
>>> img_jpeg = diff_jpeg_module(img, jpeg_quality)
>>> img_jpeg.sum().backward()

You can also specify custom quantization tables:

>>> quantization_table_y = torch.randint(1, 256, size=(2, 8, 8), dtype=torch.float)
>>> quantization_table_c = torch.randint(1, 256, size=(2, 8, 8), dtype=torch.float)
>>> diff_jpeg_module = JPEGCodecDifferentiable(quantization_table_y, quantization_table_c)
>>> img = torch.rand(2, 3, 32, 32, requires_grad=True, dtype=torch.float)
>>> jpeg_quality = torch.tensor((99.0, 1.0), requires_grad=True)
>>> img_jpeg = diff_jpeg_module(img, jpeg_quality)
>>> img_jpeg.sum().backward()

To learn the quantization tables, pass them as nn.Parameter instances:

>>> quantization_table_y = torch.nn.Parameter(torch.randint(1, 256, size=(2, 8, 8), dtype=torch.float))
>>> quantization_table_c = torch.nn.Parameter(torch.randint(1, 256, size=(2, 8, 8), dtype=torch.float))
>>> diff_jpeg_module = JPEGCodecDifferentiable(quantization_table_y, quantization_table_c)
>>> img = torch.rand(2, 3, 32, 32, requires_grad=True, dtype=torch.float)
>>> jpeg_quality = torch.tensor((99.0, 1.0), requires_grad=True)
>>> img_jpeg = diff_jpeg_module(img, jpeg_quality)
>>> img_jpeg.sum().backward()

ZCA Whitening Interactive Demo#