from math import pi
from typing import Optional, Union
import torch
import torch.nn as nn
from kornia.color.hsv import hsv_to_rgb, rgb_to_hsv
from kornia.utils.helpers import _torch_histc_cast
from kornia.utils.image import perform_keep_shape_image, perform_keep_shape_video
def adjust_saturation_raw(input: torch.Tensor, saturation_factor: Union[float, torch.Tensor]) -> torch.Tensor:
    r"""Scale the saturation of an image that is already expressed in HSV.

    Args:
        input: HSV tensor. It is split into three equal chunks (H, S, V) along dimension ``-3``.
        saturation_factor: multiplier applied to the S chunk, as a float or a tensor. Trailing
            singleton dimensions are appended until it has as many dimensions as ``input``.

    Returns:
        Tensor in which H and V are passed through and S is multiplied by the factor and
        clamped to the range [0, 1].

    Raises:
        TypeError: if ``input`` is not a tensor, or ``saturation_factor`` is neither a float
            nor a tensor.
    """
    if not isinstance(input, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")

    if isinstance(saturation_factor, float):
        factor = torch.as_tensor(saturation_factor)
    elif isinstance(saturation_factor, torch.Tensor):
        factor = saturation_factor
    else:
        raise TypeError(
            f"The saturation_factor should be a float number or torch.Tensor.Got {type(saturation_factor)}"
        )

    factor = factor.to(device=input.device, dtype=input.dtype)

    # TODO: find a proper way to check bound values (factor >= 0) in batched tensors.

    # Append trailing axes so a per-sample factor broadcasts over the remaining dimensions.
    while factor.dim() < input.dim():
        factor = factor.unsqueeze(-1)

    hue, saturation, value = input.chunk(3, dim=-3)
    scaled_saturation = (saturation * factor).clamp(min=0, max=1)

    return torch.cat([hue, scaled_saturation, value], dim=-3)
def adjust_saturation(input: torch.Tensor, saturation_factor: Union[float, torch.Tensor]) -> torch.Tensor:
    r"""Adjust color saturation of an image.

    .. image:: _static/img/adjust_saturation.png

    The input image is expected to be an RGB image in the range of [0, 1].

    Args:
        input: Image/Tensor to be adjusted in the shape of :math:`(*, 3, H, W)`.
        saturation_factor: How much to adjust the saturation. 0 will give a black
          and white image, 1 will give the original image while 2 will enhance the
          saturation by a factor of 2.

    Return:
        Adjusted image in the shape of :math:`(*, 3, H, W)`.

    .. note::
       See a working example `here
       <https://kornia-tutorials.readthedocs.io/en/latest/image_enhancement.html>`__.

    Example:
        >>> x = torch.ones(1, 3, 3, 3)
        >>> adjust_saturation(x, 2.).shape
        torch.Size([1, 3, 3, 3])

        >>> x = torch.ones(2, 3, 3, 3)
        >>> y = torch.tensor([1., 2.])
        >>> adjust_saturation(x, y).shape
        torch.Size([2, 3, 3, 3])
    """
    # RGB -> HSV, scale the S channel, then HSV -> RGB.
    return hsv_to_rgb(adjust_saturation_raw(rgb_to_hsv(input), saturation_factor))
def adjust_hue_raw(input: torch.Tensor, hue_factor: Union[float, torch.Tensor]) -> torch.Tensor:
    r"""Shift the hue of an image that is already expressed in HSV.

    Args:
        input: HSV tensor. It is split into three equal chunks (H, S, V) along dimension ``-3``.
        hue_factor: offset added to the H chunk, as a float or a tensor. Trailing singleton
            dimensions are appended until it has as many dimensions as ``input``.

    Returns:
        Tensor in which S and V are passed through and H is replaced by
        ``fmod(H + hue_factor, 2 * pi)``. ``torch.fmod`` keeps the sign of the dividend,
        so a negative sum yields a negative hue.

    Raises:
        TypeError: if ``input`` is not a tensor, or ``hue_factor`` is neither a float nor a tensor.
    """
    if not isinstance(input, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")

    if isinstance(hue_factor, float):
        shift = torch.as_tensor(hue_factor)
    elif isinstance(hue_factor, torch.Tensor):
        shift = hue_factor
    else:
        raise TypeError(
            f"The hue_factor should be a float number or torch.Tensor in the range between"
            f" [-PI, PI]. Got {type(hue_factor)}"
        )

    shift = shift.to(device=input.device, dtype=input.dtype)

    # TODO: find a proper way to check bound values ([-PI, PI]) in batched tensors.

    # Append trailing axes so a per-sample shift broadcasts over the remaining dimensions.
    while shift.dim() < input.dim():
        shift = shift.unsqueeze(-1)

    hue, saturation, value = input.chunk(3, dim=-3)
    full_turn: float = 2 * pi
    shifted_hue = torch.fmod(hue + shift, full_turn)

    return torch.cat([shifted_hue, saturation, value], dim=-3)
def adjust_hue(input: torch.Tensor, hue_factor: Union[float, torch.Tensor]) -> torch.Tensor:
    r"""Adjust hue of an image.

    .. image:: _static/img/adjust_hue.png

    The input image is expected to be an RGB image in the range of [0, 1].

    Args:
        input: Image to be adjusted in the shape of :math:`(*, 3, H, W)`.
        hue_factor: How much to shift the hue channel. Should be in [-PI, PI]. PI
          and -PI give complete reversal of hue channel in HSV space in positive and negative
          direction respectively. 0 means no shift. Therefore, both -PI and PI will give an
          image with complementary colors while 0 gives the original image.

    Return:
        Adjusted image in the shape of :math:`(*, 3, H, W)`.

    .. note::
       See a working example `here
       <https://kornia-tutorials.readthedocs.io/en/latest/image_enhancement.html>`__.

    Example:
        >>> x = torch.ones(1, 3, 2, 2)
        >>> adjust_hue(x, 3.141516).shape
        torch.Size([1, 3, 2, 2])

        >>> x = torch.ones(2, 3, 3, 3)
        >>> y = torch.ones(2) * 3.141516
        >>> adjust_hue(x, y).shape
        torch.Size([2, 3, 3, 3])
    """
    # RGB -> HSV, shift the H channel, then HSV -> RGB.
    return hsv_to_rgb(adjust_hue_raw(rgb_to_hsv(input), hue_factor))
def adjust_gamma(
    input: torch.Tensor, gamma: Union[float, torch.Tensor], gain: Union[float, torch.Tensor] = 1.0
) -> torch.Tensor:
    r"""Perform gamma correction on an image.

    .. image:: _static/img/adjust_contrast.png

    Computes ``gain * input ** gamma`` and clamps the result to [0, 1].
    The input image is expected to be in the range of [0, 1].

    Args:
        input: Image to be adjusted in the shape of :math:`(*, H, W)`.
        gamma: Non negative real number, :math:`\gamma` in the formula above. A gamma larger
          than 1 makes the shadows darker, while a gamma smaller than 1 makes dark regions lighter.
        gain: The constant multiplier. Must be non negative.

    Return:
        Adjusted image in the shape of :math:`(*, H, W)`.

    Raises:
        TypeError: if ``input`` is not a tensor, or ``gamma``/``gain`` is neither a float nor a tensor.
        ValueError: if any element of ``gamma`` or ``gain`` is negative.

    .. note::
       See a working example `here
       <https://kornia-tutorials.readthedocs.io/en/latest/image_enhancement.html>`__.

    Example:
        >>> x = torch.ones(1, 1, 2, 2)
        >>> adjust_gamma(x, 1.0, 2.0)
        tensor([[[[1., 1.],
                  [1., 1.]]]])

        >>> x = torch.ones(2, 5, 3, 3)
        >>> y1 = torch.ones(2) * 1.0
        >>> y2 = torch.ones(2) * 2.0
        >>> adjust_gamma(x, y1, y2).shape
        torch.Size([2, 5, 3, 3])
    """
    if not isinstance(input, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")
    if not isinstance(gamma, (float, torch.Tensor)):
        raise TypeError(f"The gamma should be a positive float or torch.Tensor. Got {type(gamma)}")
    if not isinstance(gain, (float, torch.Tensor)):
        raise TypeError(f"The gain should be a positive float or torch.Tensor. Got {type(gain)}")

    # Python floats become one-element 1-d tensors; everything follows the input's device/dtype.
    exponent = torch.tensor([gamma]) if isinstance(gamma, float) else gamma
    multiplier = torch.tensor([gain]) if isinstance(gain, float) else gain
    exponent = exponent.to(device=input.device, dtype=input.dtype)
    multiplier = multiplier.to(device=input.device, dtype=input.dtype)

    if (exponent < 0.0).any():
        raise ValueError(f"Gamma must be non-negative. Got {exponent}")
    if (multiplier < 0.0).any():
        raise ValueError(f"Gain must be non-negative. Got {multiplier}")

    # Append trailing axes so per-sample values broadcast over the remaining dimensions.
    while exponent.dim() < input.dim():
        exponent = exponent.unsqueeze(-1)
    while multiplier.dim() < input.dim():
        multiplier = multiplier.unsqueeze(-1)

    corrected = multiplier * input.pow(exponent)
    return corrected.clamp(0.0, 1.0)
def adjust_contrast(input: torch.Tensor, contrast_factor: Union[float, torch.Tensor]) -> torch.Tensor:
    r"""Adjust Contrast of an image.

    .. image:: _static/img/adjust_contrast.png

    The image is multiplied by the factor and the result is clamped to [0, 1].
    This implementation aligns OpenCV, not PIL. Hence, the output differs from TorchVision.
    The input image is expected to be in the range of [0, 1].

    Args:
        input: Image to be adjusted in the shape of :math:`(*, H, W)`.
        contrast_factor: Contrast adjust factor per element in the batch. 0 generates a
          completely black image, 1 does not modify the input image while any other
          non-negative number scales the pixel values by this factor.

    Return:
        Adjusted image in the shape of :math:`(*, H, W)`.

    Raises:
        TypeError: if ``input`` is not a tensor, or ``contrast_factor`` is neither a float nor a tensor.
        ValueError: if any element of ``contrast_factor`` is negative.

    .. note::
       See a working example `here
       <https://kornia-tutorials.readthedocs.io/en/latest/image_enhancement.html>`__.

    Example:
        >>> x = torch.ones(1, 1, 2, 2)
        >>> adjust_contrast(x, 0.5)
        tensor([[[[0.5000, 0.5000],
                  [0.5000, 0.5000]]]])

        >>> x = torch.ones(2, 5, 3, 3)
        >>> y = torch.tensor([0.65, 0.50])
        >>> adjust_contrast(x, y).shape
        torch.Size([2, 5, 3, 3])
    """
    if not isinstance(input, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")
    if not isinstance(contrast_factor, (float, torch.Tensor)):
        raise TypeError(f"The factor should be either a float or torch.Tensor. Got {type(contrast_factor)}")

    # A Python float becomes a one-element 1-d tensor.
    factor = torch.tensor([contrast_factor]) if isinstance(contrast_factor, float) else contrast_factor
    factor = factor.to(device=input.device, dtype=input.dtype)

    if (factor < 0).any():
        raise ValueError(f"Contrast factor must be non-negative. Got {factor}")

    # Append trailing axes so a per-sample factor broadcasts over the remaining dimensions.
    while factor.dim() < input.dim():
        factor = factor.unsqueeze(-1)

    return (input * factor).clamp(0.0, 1.0)
def adjust_brightness(input: torch.Tensor, brightness_factor: Union[float, torch.Tensor]) -> torch.Tensor:
    r"""Adjust Brightness of an image.

    .. image:: _static/img/adjust_brightness.png

    The factor is added to the image and the result is clamped to [0, 1].
    This implementation aligns OpenCV, not PIL. Hence, the output differs from TorchVision.
    The input image is expected to be in the range of [0, 1].

    Args:
        input: image to be adjusted in the shape of :math:`(*, H, W)`.
        brightness_factor: Brightness adjust factor per element in the batch. 0 does not
          modify the input image while any other number shifts the brightness.

    Return:
        Adjusted image in the shape of :math:`(*, H, W)`.

    Raises:
        TypeError: if ``input`` is not a tensor, or ``brightness_factor`` is neither a float
            nor a tensor.

    .. note::
       See a working example `here
       <https://kornia-tutorials.readthedocs.io/en/latest/image_enhancement.html>`__.

    Example:
        >>> x = torch.ones(1, 1, 2, 2)
        >>> adjust_brightness(x, 1.)
        tensor([[[[1., 1.],
                  [1., 1.]]]])

        >>> x = torch.ones(2, 5, 3, 3)
        >>> y = torch.tensor([0.25, 0.50])
        >>> adjust_brightness(x, y).shape
        torch.Size([2, 5, 3, 3])
    """
    if not isinstance(input, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")
    if not isinstance(brightness_factor, (float, torch.Tensor)):
        raise TypeError(f"The factor should be either a float or torch.Tensor. Got {type(brightness_factor)}")

    # A Python float becomes a one-element 1-d tensor.
    offset = torch.tensor([brightness_factor]) if isinstance(brightness_factor, float) else brightness_factor
    offset = offset.to(device=input.device, dtype=input.dtype)

    # Append trailing axes so a per-sample offset broadcasts over the remaining dimensions.
    while offset.dim() < input.dim():
        offset = offset.unsqueeze(-1)

    return (input + offset).clamp(0.0, 1.0)
def _solarize(input: torch.Tensor, thresholds: Union[float, torch.Tensor] = 0.5) -> torch.Tensor:
    r"""Keep every pixel below the threshold and replace every other pixel ``p`` by ``1.0 - p``.

    Args:
        input (torch.Tensor): image or batched images to solarize.
        thresholds (float or torch.Tensor): solarize thresholds.
            If float or 0-dim tensor, the same threshold is used for the whole input.
            If 1-d tensor, one threshold per batch element, len(thresholds) == len(input).

    Returns:
        torch.Tensor: Solarized images.

    Raises:
        TypeError: if ``input`` is not a tensor, or ``thresholds`` is neither a float nor a tensor.
        AssertionError: if a non 0-dim ``thresholds`` is not a 1-d vector of length ``input.size(0)``.
    """
    if not isinstance(input, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")
    if not isinstance(thresholds, (float, torch.Tensor)):
        raise TypeError(f"The factor should be either a float or torch.Tensor. Got {type(thresholds)}")

    if isinstance(thresholds, torch.Tensor) and thresholds.dim() != 0:
        if input.size(0) != len(thresholds) or thresholds.dim() != 1:
            raise AssertionError(f"thresholds must be a 1-d vector of shape ({input.size(0)},). Got {thresholds}")

        # TODO: I am not happy about this line, but no easy to do batch-wise operation
        per_sample = thresholds.to(device=input.device, dtype=input.dtype)
        trailing_shape = input.shape[-3:]
        # Blow each scalar threshold up to the trailing (up to three) dimensions of the input.
        thresholds = torch.stack([value.expand(*trailing_shape) for value in per_sample.unbind(0)])

    below = input < thresholds
    return torch.where(below, input, 1.0 - input)
def solarize(
    input: torch.Tensor,
    thresholds: Union[float, torch.Tensor] = 0.5,
    additions: Optional[Union[float, torch.Tensor]] = None,
) -> torch.Tensor:
    r"""Solarize an image, optionally shifting its pixel values first.

    .. image:: _static/img/solarize.png

    When ``additions`` is given it is added to the input and the sum is clamped to
    [0, 1]. The (possibly shifted) image is then solarized: pixels below the threshold
    are kept, all others are replaced by ``1.0 - pixel``.

    Args:
        input: image tensor with shapes like :math:`(*, C, H, W)` to solarize.
        thresholds: solarize thresholds.
            If float or one element tensor, input will be solarized across the whole batch.
            If 1-d tensor, input will be solarized element-wise, len(thresholds) == len(input).
        additions: strictly between -0.5 and 0.5.
            If None, no addition will be performed.
            If float or 0-dim tensor, same addition will be added across the whole batch.
            If 1-d tensor, additions will be added element-wisely, len(additions) == len(input).

    Returns:
        The solarized images with shape :math:`(*, C, H, W)`.

    Raises:
        TypeError: if ``input`` is not a tensor, or ``thresholds``/``additions`` has an
            unsupported type.
        AssertionError: if ``additions`` is out of range or has an unexpected shape.

    Example:
        >>> x = torch.rand(1, 4, 3, 3)
        >>> out = solarize(x, thresholds=0.5, additions=0.)
        >>> out.shape
        torch.Size([1, 4, 3, 3])

        >>> x = torch.rand(2, 4, 3, 3)
        >>> thresholds = torch.tensor([0.8, 0.5])
        >>> additions = torch.tensor([-0.25, 0.25])
        >>> solarize(x, thresholds, additions).shape
        torch.Size([2, 4, 3, 3])
    """
    if not isinstance(input, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")
    if not isinstance(thresholds, (float, torch.Tensor)):
        raise TypeError(f"The factor should be either a float or torch.Tensor. Got {type(thresholds)}")
    if isinstance(thresholds, float):
        thresholds = torch.tensor(thresholds)

    # Nothing to add: solarize the input as it is.
    if additions is None:
        return _solarize(input, thresholds)

    if not isinstance(additions, (float, torch.Tensor)):
        raise TypeError(f"The factor should be either a float or torch.Tensor. Got {type(additions)}")
    if isinstance(additions, float):
        additions = torch.tensor(additions)

    in_range = (additions < 0.5) & (additions > -0.5)
    if not in_range.all():
        raise AssertionError(f"The value of 'addition' is between -0.5 and 0.5. Got {additions}.")

    if additions.dim() != 0:
        if input.size(0) != len(additions) or additions.dim() != 1:
            raise AssertionError(f"additions must be a 1-d vector of shape ({input.size(0)},). Got {additions}")

        # TODO: I am not happy about this line, but no easy to do batch-wise operation
        per_sample = additions.to(device=input.device, dtype=input.dtype)
        trailing_shape = input.shape[-3:]
        # Blow each scalar addition up to the trailing (up to three) dimensions of the input.
        additions = torch.stack([value.expand(*trailing_shape) for value in per_sample.unbind(0)])

    shifted = (input + additions).clamp(0.0, 1.0)
    return _solarize(shifted, thresholds)
@perform_keep_shape_image
def posterize(input: torch.Tensor, bits: Union[int, torch.Tensor]) -> torch.Tensor:
    r"""Reduce the number of bits for each color channel.

    .. image:: _static/img/posterize.png

    Non-differentiable function, ``torch.uint8`` involved.

    Args:
        input: image tensor with shape :math:`(*, C, H, W)` to posterize.
        bits: number of high bits. Must be in range [0, 8].
            If int or one element tensor, input will be posterized by this bits.
            If 1-d tensor, input will be posterized element-wisely, len(bits) == input.shape[0].
            If n-d tensor, input will be posterized element-channel-wisely,
            bits.shape == input.shape[:len(bits.shape)].

    Returns:
        Image with reduced color channels with shape :math:`(*, C, H, W)`.

    Raises:
        TypeError: if ``input`` is not a tensor, or ``bits`` is neither an int nor a tensor.
        AssertionError: if the shape of ``bits`` does not match the leading dimensions of ``input``.

    Example:
        >>> x = torch.rand(1, 6, 3, 3)
        >>> posterize(x, bits=8).shape
        torch.Size([1, 6, 3, 3])

        >>> x = torch.rand(2, 6, 3, 3)
        >>> bits = torch.tensor([4, 2])
        >>> posterize(x, bits).shape
        torch.Size([2, 6, 3, 3])
    """
    if not isinstance(input, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")

    if not isinstance(bits, (int, torch.Tensor)):
        raise TypeError(f"bits type is not an int or torch.Tensor. Got {type(bits)}")

    if isinstance(bits, int):
        bits = torch.tensor(bits)

    # TODO: find a better way to check boundaries on tensors
    # (bits must be integers within range [0, 8]).

    # TODO: Make a differentiable version
    # Current version:
    # Ref: https://github.com/open-mmlab/mmcv/pull/132/files#diff-309c9320c7f71bedffe89a70ccff7f3bR19
    # Ref: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py#L222
    # Potential approach: implementing kornia.LUT with floating points
    # https://github.com/albumentations-team/albumentations/blob/master/albumentations/augmentations/functional.py#L472
    def _left_shift(input: torch.Tensor, shift: torch.Tensor):
        # The uint8 cast truncates the fractional part left over by ``_right_shift``.
        return ((input * 255).to(torch.uint8) * (2 ** shift)).to(input.dtype) / 255.0

    def _right_shift(input: torch.Tensor, shift: torch.Tensor):
        return (input * 255).to(torch.uint8) / (2 ** shift).to(input.dtype) / 255.0

    def _posterize_one(input: torch.Tensor, bits: torch.Tensor):
        # Single bits value condition
        if bits == 0:
            return torch.zeros_like(input)
        if bits == 8:
            return input.clone()
        bits = 8 - bits
        return _left_shift(_right_shift(input, bits), bits)

    # One value for the whole input.
    if len(bits.shape) == 0 or (len(bits.shape) == 1 and len(bits) == 1):
        return _posterize_one(input, bits)

    # One value per batch element.
    if len(bits.shape) == 1:
        if bits.shape[0] != input.shape[0]:
            raise AssertionError(
                f"Batch size must be equal between bits and input. Got {bits.shape[0]}, {input.shape[0]}."
            )
        return torch.stack([_posterize_one(image, image_bits) for image, image_bits in zip(input, bits)], dim=0)

    # One value per element of the leading ``bits.dim()`` dimensions (e.g. per batch and channel).
    if bits.shape != input.shape[: len(bits.shape)]:
        raise AssertionError(
            "Batch and channel must be equal between bits and input. "
            f"Got {bits.shape}, {input.shape[:len(bits.shape)]}."
        )
    _input = input.reshape(-1, *input.shape[len(bits.shape):])
    _bits = bits.flatten()
    # Walk over every flattened (slice, bits) pair. Looping over ``input.shape[0]`` only would
    # process just the first ``batch`` slices and make the final reshape fail.
    res = [_posterize_one(plane, plane_bits) for plane, plane_bits in zip(_input, _bits)]
    return torch.stack(res, dim=0).reshape(*input.shape)
@perform_keep_shape_image
def sharpness(input: torch.Tensor, factor: Union[float, torch.Tensor]) -> torch.Tensor:
    r"""Apply sharpness to the input tensor.

    .. image:: _static/img/sharpness.png

    Implemented Sharpness function from PIL using torch ops. This implementation refers to:
    https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py#L326

    A smoothed copy of the image is built with a 3x3 depthwise convolution (border pixels
    are taken from the original), and that copy is then blended with the original
    according to ``factor``.

    Args:
        input: image tensor with shape :math:`(*, C, H, W)` to sharpen.
        factor: factor of sharpness strength. Must be above 0.
            If float or one element tensor, input will be sharpened by the same factor across the whole batch.
            If 1-d tensor, input will be sharpened element-wisely, len(factor) == len(input).

    Returns:
        Sharpened image or images with shape :math:`(*, C, H, W)`.

    Raises:
        AssertionError: if ``factor`` is not 0-dim and its shape is not ``(input.size(0),)``.

    Example:
        >>> x = torch.rand(1, 1, 5, 5)
        >>> sharpness(x, 0.5).shape
        torch.Size([1, 1, 5, 5])
    """
    if not isinstance(factor, torch.Tensor):
        factor = torch.tensor(factor, device=input.device, dtype=input.dtype)

    if factor.dim() != 0 and factor.shape != torch.Size([input.size(0)]):
        raise AssertionError(
            "Input batch size shall match with factor size if factor is not a 0-dim tensor. "
            f"Got {input.size(0)} and {factor.shape}"
        )

    channels = input.size(1)
    stencil = torch.tensor([[1, 1, 1], [1, 5, 1], [1, 1, 1]], dtype=input.dtype, device=input.device)
    # One normalized 3x3 filter per channel.
    weight = stencil.view(1, 1, 3, 3).repeat(channels, 1, 1, 1) / 13

    # This shall be equivalent to depthwise conv2d:
    # Ref: https://discuss.pytorch.org/t/depthwise-and-separable-convolutions-in-pytorch/7315/2
    smoothed = torch.nn.functional.conv2d(input, weight, bias=None, stride=1, groups=channels)
    smoothed = smoothed.clamp(0.0, 1.0)

    # The convolution is unpadded, so the result lost a one pixel border. Pad it back with
    # zeros and use a padded mask of ones to pick the original pixels on that border.
    border = [1, 1, 1, 1]
    interior_mask = torch.nn.functional.pad(torch.ones_like(smoothed), border)
    smoothed_padded = torch.nn.functional.pad(smoothed, border)
    degenerate = torch.where(interior_mask == 1, smoothed_padded, input)

    if factor.dim() == 0:
        return _blend_one(degenerate, input, factor)

    blended = [
        _blend_one(sample_degenerate, sample, sample_factor)
        for sample_degenerate, sample, sample_factor in zip(degenerate, input, factor)
    ]
    return torch.stack(blended)
def _blend_one(input1: torch.Tensor, input2: torch.Tensor, factor: torch.Tensor) -> torch.Tensor:
    r"""Blend two images into one.

    Computes ``input1 + (input2 - input1) * factor``. A factor of 0 returns ``input1``
    and a factor of 1 returns ``input2`` (the very same objects, not copies). Factors
    outside of (0, 1) extrapolate, in which case the result is clamped to [0, 1].

    Args:
        input1: image tensor with shapes like :math:`(H, W)` or :math:`(D, H, W)`.
        input2: image tensor with shapes like :math:`(H, W)` or :math:`(D, H, W)`.
        factor: factor 0-dim tensor.

    Returns:
        The blended image tensor.

    Raises:
        AssertionError: if ``input1`` or ``input2`` is not a tensor, or ``factor`` is a
            tensor with one or more dimensions.
    """
    if not isinstance(input1, torch.Tensor):
        raise AssertionError(f"`input1` must be a tensor. Got {input1}.")
    if not isinstance(input2, torch.Tensor):
        # The message used to name `input1` here, pointing at the wrong argument.
        raise AssertionError(f"`input2` must be a tensor. Got {input2}.")

    if isinstance(factor, torch.Tensor) and len(factor.size()) != 0:
        raise AssertionError(f"Factor shall be a float or single element tensor. Got {factor}.")

    # Trivial blends: hand back the corresponding input untouched.
    if factor == 0.0:
        return input1
    if factor == 1.0:
        return input2

    res = input1 + (input2 - input1) * factor

    # Interpolation between two in-range images cannot leave the range, so no clamp is needed.
    if 0.0 < factor < 1.0:
        return res

    # Extrapolation may overshoot the valid pixel range.
    return torch.clamp(res, 0, 1)
def _build_lut(histo, step):
    """Build the equalization look-up table from a histogram.

    Args:
        histo: 1-d tensor with the histogram counts.
        step: normalization step; the cumulative histogram is divided by it.

    Returns:
        Tensor with as many entries as ``histo``: the cumulative sum shifted by
        ``step // 2``, divided by ``step`` (both truncating), moved one position to the
        right with a leading zero, and clamped to [0, 255].
    """
    half_step = torch.div(step, 2, rounding_mode='trunc')
    cumulative = histo.cumsum(0)
    scaled = torch.div(cumulative + half_step, step, rounding_mode='trunc')

    # Shift right by one entry: prepend a zero and drop the last value.
    shifted = torch.cat([scaled.new_zeros(1), scaled[:-1]])

    # Clip the counts to be in range. This is done in the C code for image.point.
    return shifted.clamp(0, 255)
# Code taken from: https://github.com/pytorch/vision/pull/796
def _scale_channel(im: torch.Tensor) -> torch.Tensor:
    r"""Scale the data in the channel to implement equalize.

    The channel is scaled to [0, 255], a 256-bin histogram is computed and, unless the
    derived step is zero, every value is remapped through the look-up table built from
    that histogram. The result is scaled back by 1/255.

    Args:
        im: image tensor with shapes like :math:`(H, W)` or :math:`(D, H, W)` and values
            in [0, 1].

    Returns:
        The equalized tensor with the same shape as ``im``.

    Raises:
        ValueError: if the values are not within [0, 1] (up to ``torch.isclose`` tolerance).
        TypeError: if ``im`` does not have 2 or 3 dimensions.
    """
    lowest = im.min()
    highest = im.max()

    if lowest.item() < 0.0 and not torch.isclose(lowest, torch.tensor(0.0, dtype=lowest.dtype)):
        raise ValueError(f"Values in the input tensor must greater or equal to 0.0. Found {lowest.item()}.")

    if highest.item() > 1.0 and not torch.isclose(highest, torch.tensor(1.0, dtype=highest.dtype)):
        raise ValueError(f"Values in the input tensor must lower or equal to 1.0. Found {highest.item()}.")

    ndims = im.dim()
    if ndims != 2 and ndims != 3:
        raise TypeError(f"Input tensor must have 2 or 3 dimensions. Found {ndims}.")

    scaled = im * 255.

    # Compute the histogram of the image channel.
    histogram = _torch_histc_cast(scaled, bins=256, min=0, max=255)

    # For the purposes of computing the step, only the non-empty bins are considered;
    # the last non-empty bin is left out of the total.
    occupied = histogram[histogram != 0].reshape(-1)
    step = torch.div(occupied.sum() - occupied[-1], 255, rounding_mode='trunc')

    # A zero step means there is nothing to remap: hand back the input values.
    if step == 0:
        return scaled / 255.0

    # gather() needs a 1-d index, hence flatten first and restore the shape afterwards.
    lut = _build_lut(histogram, step)
    remapped = torch.gather(lut, 0, scaled.flatten().long()).reshape_as(scaled)

    return remapped / 255.0
@perform_keep_shape_image
def equalize(input: torch.Tensor) -> torch.Tensor:
    r"""Apply equalize on the input tensor.

    .. image:: _static/img/equalize.png

    Implements Equalize function from PIL using PyTorch ops based on uint8 format:
    https://github.com/tensorflow/tpu/blob/5f71c12a020403f863434e96982a840578fdd127/models/official/efficientnet/autoaugment.py#L355

    Args:
        input: image tensor to equalize with shape :math:`(*, C, H, W)`.

    Returns:
        Equalized image tensor with shape :math:`(*, C, H, W)`.

    Example:
        >>> x = torch.rand(1, 2, 3, 3)
        >>> equalize(x).shape
        torch.Size([1, 2, 3, 3])
    """
    equalized_images = []
    for image in input:
        # Every channel is equalized on its own and the channels are stacked back together.
        channels = [_scale_channel(image[index, :, :]) for index in range(image.shape[0])]
        equalized_images.append(torch.stack(channels))
    return torch.stack(equalized_images)
@perform_keep_shape_video
def equalize3d(input: torch.Tensor) -> torch.Tensor:
    r"""Equalize the values for a 3D volumetric tensor.

    Implements Equalize function for a sequence of images using PyTorch ops based on uint8 format:
    https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py#L352

    Args:
        input: image tensor with shape :math:`(*, C, D, H, W)` to equalize.

    Returns:
        Equalized volume with shape :math:`(B, C, D, H, W)`.
    """
    equalized_volumes = []
    for volume in input:
        # Every channel volume is equalized on its own and the channels are stacked back together.
        channels = [_scale_channel(volume[index, :, :, :]) for index in range(volume.shape[0])]
        equalized_volumes.append(torch.stack(channels))
    return torch.stack(equalized_volumes)
def invert(input: torch.Tensor, max_val: torch.Tensor = torch.tensor(1.0)) -> torch.Tensor:
    r"""Invert the values of an input tensor by its maximum value.

    .. image:: _static/img/invert.png

    The result is ``max_val - input``, with ``max_val`` cast to the dtype of ``input``.

    Args:
        input: The input tensor to invert with an arbitrary shape.
        max_val: The expected maximum value in the input tensor. Its shape has to match
          the input tensor shape, or at least has to work with broadcasting.

    Returns:
        The inverted tensor.

    Raises:
        AssertionError: if ``input`` or ``max_val`` is not a tensor.

    Example:
        >>> img = torch.rand(1, 2, 4, 4)
        >>> invert(img).shape
        torch.Size([1, 2, 4, 4])

        >>> img = 255. * torch.rand(1, 2, 3, 4, 4)
        >>> invert(img, torch.tensor(255.)).shape
        torch.Size([1, 2, 3, 4, 4])

        >>> img = torch.rand(1, 3, 4, 4)
        >>> invert(img, torch.tensor([[[[1.]]]])).shape
        torch.Size([1, 3, 4, 4])
    """
    if not isinstance(input, torch.Tensor):
        raise AssertionError(f"Input is not a torch.Tensor. Got: {type(input)}")

    if not isinstance(max_val, torch.Tensor):
        raise AssertionError(f"max_val is not a torch.Tensor. Got: {type(max_val)}")

    ceiling = max_val.to(input.dtype)
    return ceiling - input
class AdjustSaturation(nn.Module):
    r"""Adjust color saturation of an image.

    Module wrapper that stores the factor and forwards to :func:`adjust_saturation`.
    The input image is expected to be an RGB image in the range of [0, 1].

    Args:
        saturation_factor: How much to adjust the saturation. 0 will give a black
          and white image, 1 will give the original image while 2 will enhance the
          saturation by a factor of 2.

    Shape:
        - Input: Image/Tensor to be adjusted in the shape of :math:`(*, 3, H, W)`.
        - Output: Adjusted image in the shape of :math:`(*, 3, H, W)`.

    Example:
        >>> x = torch.ones(1, 3, 3, 3)
        >>> AdjustSaturation(2.)(x).shape
        torch.Size([1, 3, 3, 3])

        >>> x = torch.ones(2, 3, 3, 3)
        >>> y = torch.ones(2)
        >>> out = AdjustSaturation(y)(x)
        >>> torch.nn.functional.mse_loss(x, out)
        tensor(0.)
    """

    def __init__(self, saturation_factor: Union[float, torch.Tensor]) -> None:
        super().__init__()
        # Kept exactly as given; conversion and validation happen in the functional call.
        self.saturation_factor: Union[float, torch.Tensor] = saturation_factor

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        factor = self.saturation_factor
        return adjust_saturation(input, factor)
class AdjustHue(nn.Module):
    r"""Adjust hue of an image.

    Module wrapper that stores the factor and forwards to :func:`adjust_hue`.
    The input image is expected to be an RGB image in the range of [0, 1].

    Args:
        hue_factor: How much to shift the hue channel. Should be in [-PI, PI]. PI
          and -PI give complete reversal of hue channel in HSV space in positive and negative
          direction respectively. 0 means no shift. Therefore, both -PI and PI will give an
          image with complementary colors while 0 gives the original image.

    Shape:
        - Input: Image/Tensor to be adjusted in the shape of :math:`(*, 3, H, W)`.
        - Output: Adjusted image in the shape of :math:`(*, 3, H, W)`.

    Example:
        >>> x = torch.ones(1, 3, 3, 3)
        >>> AdjustHue(3.141516)(x).shape
        torch.Size([1, 3, 3, 3])

        >>> x = torch.ones(2, 3, 3, 3)
        >>> y = torch.ones(2) * 3.141516
        >>> AdjustHue(y)(x).shape
        torch.Size([2, 3, 3, 3])
    """

    def __init__(self, hue_factor: Union[float, torch.Tensor]) -> None:
        super().__init__()
        # Kept exactly as given; conversion and validation happen in the functional call.
        self.hue_factor: Union[float, torch.Tensor] = hue_factor

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        factor = self.hue_factor
        return adjust_hue(input, factor)
class AdjustGamma(nn.Module):
    r"""Perform gamma correction on an image.

    Module wrapper that stores ``gamma`` and ``gain`` and forwards to :func:`adjust_gamma`.
    The input image is expected to be in the range of [0, 1].

    Args:
        gamma: Non negative real number, the exponent :math:`\gamma` of the correction.
          A gamma larger than 1 makes the shadows darker, while a gamma smaller than 1
          makes dark regions lighter.
        gain: The constant multiplier.

    Shape:
        - Input: Image to be adjusted in the shape of :math:`(*, N)`.
        - Output: Adjusted image in the shape of :math:`(*, N)`.

    Example:
        >>> x = torch.ones(1, 1, 3, 3)
        >>> AdjustGamma(1.0, 2.0)(x)
        tensor([[[[1., 1., 1.],
                  [1., 1., 1.],
                  [1., 1., 1.]]]])

        >>> x = torch.ones(2, 5, 3, 3)
        >>> y1 = torch.ones(2) * 1.0
        >>> y2 = torch.ones(2) * 2.0
        >>> AdjustGamma(y1, y2)(x).shape
        torch.Size([2, 5, 3, 3])
    """

    def __init__(self, gamma: Union[float, torch.Tensor], gain: Union[float, torch.Tensor] = 1.0) -> None:
        super().__init__()
        # Kept exactly as given; conversion and validation happen in the functional call.
        self.gamma: Union[float, torch.Tensor] = gamma
        self.gain: Union[float, torch.Tensor] = gain

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        exponent, multiplier = self.gamma, self.gain
        return adjust_gamma(input, exponent, multiplier)
class AdjustContrast(nn.Module):
    r"""Adjust Contrast of an image.

    Module wrapper that stores the factor and forwards to :func:`adjust_contrast`.
    This implementation aligns OpenCV, not PIL. Hence, the output differs from TorchVision.
    The input image is expected to be in the range of [0, 1].

    Args:
        contrast_factor: Contrast adjust factor per element in the batch. 0 generates a
          completely black image, 1 does not modify the input image while any other
          non-negative number scales the pixel values by this factor.

    Shape:
        - Input: Image/Input to be adjusted in the shape of :math:`(*, N)`.
        - Output: Adjusted image in the shape of :math:`(*, N)`.

    Example:
        >>> x = torch.ones(1, 1, 3, 3)
        >>> AdjustContrast(0.5)(x)
        tensor([[[[0.5000, 0.5000, 0.5000],
                  [0.5000, 0.5000, 0.5000],
                  [0.5000, 0.5000, 0.5000]]]])

        >>> x = torch.ones(2, 5, 3, 3)
        >>> y = torch.ones(2)
        >>> AdjustContrast(y)(x).shape
        torch.Size([2, 5, 3, 3])
    """

    def __init__(self, contrast_factor: Union[float, torch.Tensor]) -> None:
        super().__init__()
        # Kept exactly as given; conversion and validation happen in the functional call.
        self.contrast_factor: Union[float, torch.Tensor] = contrast_factor

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        factor = self.contrast_factor
        return adjust_contrast(input, factor)
class AdjustBrightness(nn.Module):
    r"""Adjust Brightness of an image.

    Module wrapper that stores the factor and forwards to :func:`adjust_brightness`.
    This implementation aligns OpenCV, not PIL. Hence, the output differs from TorchVision.
    The input image is expected to be in the range of [0, 1].

    Args:
        brightness_factor: Brightness adjust factor per element in the batch. 0 does not
          modify the input image while any other number shifts the brightness.

    Shape:
        - Input: Image/Input to be adjusted in the shape of :math:`(*, N)`.
        - Output: Adjusted image in the shape of :math:`(*, N)`.

    Example:
        >>> x = torch.ones(1, 1, 3, 3)
        >>> AdjustBrightness(1.)(x)
        tensor([[[[1., 1., 1.],
                  [1., 1., 1.],
                  [1., 1., 1.]]]])

        >>> x = torch.ones(2, 5, 3, 3)
        >>> y = torch.ones(2)
        >>> AdjustBrightness(y)(x).shape
        torch.Size([2, 5, 3, 3])
    """

    def __init__(self, brightness_factor: Union[float, torch.Tensor]) -> None:
        super().__init__()
        # Kept exactly as given; conversion and validation happen in the functional call.
        self.brightness_factor: Union[float, torch.Tensor] = brightness_factor

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        factor = self.brightness_factor
        return adjust_brightness(input, factor)
class Invert(nn.Module):
    r"""Invert the values of an input tensor by its maximum value.

    Module wrapper around :func:`invert`. A plain tensor ``max_val`` is registered as a
    buffer of the module; an ``nn.Parameter`` is assigned as a regular attribute and
    thereby registered as a parameter.

    Args:
        max_val: The expected maximum value in the input tensor. Its shape has to match
          the input tensor shape, or at least has to work with broadcasting. Default: 1.0.

    Example:
        >>> img = torch.rand(1, 2, 4, 4)
        >>> Invert()(img).shape
        torch.Size([1, 2, 4, 4])

        >>> img = 255. * torch.rand(1, 2, 3, 4, 4)
        >>> Invert(torch.tensor(255.))(img).shape
        torch.Size([1, 2, 3, 4, 4])

        >>> img = torch.rand(1, 3, 4, 4)
        >>> Invert(torch.tensor([[[[1.]]]]))(img).shape
        torch.Size([1, 3, 4, 4])
    """

    def __init__(self, max_val: torch.Tensor = torch.tensor(1.0)) -> None:
        super().__init__()
        if isinstance(max_val, nn.Parameter):
            # Parameters are picked up by nn.Module's attribute handling.
            self.max_val = max_val
        else:
            self.register_buffer("max_val", max_val)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        ceiling = self.max_val
        return invert(input, ceiling)