kornia.enhance
The functions in this section perform normalisations and intensity transformations.
-
adjust_brightness
(input: torch.Tensor, brightness_factor: Union[float, torch.Tensor]) → torch.Tensor
Adjust brightness of an image.
This implementation aligns with OpenCV, not PIL. Hence, the output differs from TorchVision. The input image is expected to be in the range of [0, 1].
- Parameters
input (torch.Tensor) – image to be adjusted in the shape of \((*, N)\).
brightness_factor (Union[float, torch.Tensor]) – Brightness adjust factor per element in the batch. 0 does not modify the input image, while any other number modifies the brightness.
- Returns
Adjusted image in the shape of \((*, N)\).
- Return type
torch.Tensor
Example
>>> x = torch.ones(1, 1, 3, 3)
>>> adjust_brightness(x, 1.)
tensor([[[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]]])
>>> x = torch.ones(2, 5, 3, 3)
>>> y = torch.ones(2)
>>> adjust_brightness(x, y).shape
torch.Size([2, 5, 3, 3])
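Since the OpenCV-aligned behaviour is an additive shift rather than PIL's multiplicative blend, the operation can be pictured with a small sketch (a hypothetical helper, not the library source):

import torch

def adjust_brightness_sketch(x: torch.Tensor, factor) -> torch.Tensor:
    # OpenCV-style brightness: add the factor and clamp to the valid range.
    # PIL/TorchVision instead scale the image, hence the differing outputs.
    factor = torch.as_tensor(factor, dtype=x.dtype, device=x.device)
    while factor.dim() < x.dim():
        factor = factor[..., None]  # broadcast a per-sample factor over (C, H, W)
    return torch.clamp(x + factor, 0.0, 1.0)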
-
adjust_contrast
(input: torch.Tensor, contrast_factor: Union[float, torch.Tensor]) → torch.Tensor
Adjust contrast of an image.
This implementation aligns with OpenCV, not PIL. Hence, the output differs from TorchVision. The input image is expected to be in the range of [0, 1].
- Parameters
input (torch.Tensor) – Image to be adjusted in the shape of \((*, N)\).
contrast_factor (Union[float, torch.Tensor]) – Contrast adjust factor per element in the batch. 0 generates a completely black image, 1 does not modify the input image, and any other non-negative number scales the contrast by this factor.
- Returns
Adjusted image in the shape of \((*, N)\).
- Return type
torch.Tensor
Example
>>> x = torch.ones(1, 1, 3, 3)
>>> adjust_contrast(x, 0.5)
tensor([[[[0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000]]]])
>>> x = torch.ones(2, 5, 3, 3)
>>> y = torch.ones(2)
>>> adjust_contrast(x, y).shape
torch.Size([2, 5, 3, 3])
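Contrast, under the same OpenCV-style convention, is a multiplicative scaling followed by clamping; a hypothetical sketch:

import torch

def adjust_contrast_sketch(x: torch.Tensor, factor) -> torch.Tensor:
    # OpenCV-style contrast: scale intensities towards black and clamp.
    # factor == 0 yields a black image; factor == 1 is the identity.
    factor = torch.as_tensor(factor, dtype=x.dtype, device=x.device)
    while factor.dim() < x.dim():
        factor = factor[..., None]
    return torch.clamp(x * factor, 0.0, 1.0)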
-
adjust_gamma
(input: torch.Tensor, gamma: Union[float, torch.Tensor], gain: Union[float, torch.Tensor] = 1.0) → torch.Tensor
Perform gamma correction on an image.
The input image is expected to be in the range of [0, 1].
- Parameters
input (torch.Tensor) – Image to be adjusted in the shape of \((*, N)\).
gamma (Union[float, torch.Tensor]) – Non-negative real number, same as \(\gamma\) in the equation. gamma larger than 1 makes the shadows darker, while gamma smaller than 1 makes dark regions lighter.
gain (Union[float, torch.Tensor], optional) – The constant multiplier. Default 1.
- Returns
Adjusted image in the shape of \((*, N)\).
- Return type
torch.Tensor
Example
>>> x = torch.ones(1, 1, 3, 3)
>>> adjust_gamma(x, 1.0, 2.0)
tensor([[[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]]])
>>> x = torch.ones(2, 5, 3, 3)
>>> y1 = torch.ones(2) * 1.0
>>> y2 = torch.ones(2) * 2.0
>>> adjust_gamma(x, y1, y2).shape
torch.Size([2, 5, 3, 3])
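The equation referred to above is the standard power-law transform; a scalar-only sketch of the idea (assumed, not the library source):

import torch

def adjust_gamma_sketch(x: torch.Tensor, gamma: float, gain: float = 1.0) -> torch.Tensor:
    # Power-law (gamma) correction: out = gain * x ** gamma, clamped to [0, 1].
    # gamma > 1 darkens shadows; gamma < 1 lifts dark regions.
    return torch.clamp(gain * torch.pow(x, gamma), 0.0, 1.0)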
-
adjust_hue
(input: torch.Tensor, hue_factor: Union[float, torch.Tensor]) → torch.Tensor
Adjust hue of an image.
The input image is expected to be an RGB image in the range of [0, 1].
- Parameters
input (torch.Tensor) – Image to be adjusted in the shape of \((*, 3, H, W)\).
hue_factor (Union[float, torch.Tensor]) – How much to shift the hue channel. Should be in [-PI, PI], where PI and -PI give a complete reversal of the hue channel in HSV space (an image with complementary colors) and 0 gives the original image.
- Returns
Adjusted image in the shape of \((*, 3, H, W)\).
- Return type
torch.Tensor
Example
>>> x = torch.ones(1, 3, 3, 3)
>>> adjust_hue(x, 3.141516)
tensor([[[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],
<BLANKLINE>
         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],
<BLANKLINE>
         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]]])
>>> x = torch.ones(2, 3, 3, 3)
>>> y = torch.ones(2) * 3.141516
>>> adjust_hue(x, y).shape
torch.Size([2, 3, 3, 3])
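Conceptually, the shift is applied to the hue channel in HSV space with wrap-around; a sketch built on kornia's colour conversions, assuming rgb_to_hsv expresses hue in radians:

import math
import torch
from kornia.color import rgb_to_hsv, hsv_to_rgb

def adjust_hue_sketch(x: torch.Tensor, hue_factor: float) -> torch.Tensor:
    # Convert to HSV, shift the hue channel, wrap around 2*pi, convert back.
    h, s, v = torch.chunk(rgb_to_hsv(x), chunks=3, dim=-3)
    h = (h + hue_factor) % (2.0 * math.pi)
    return hsv_to_rgb(torch.cat([h, s, v], dim=-3))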
-
adjust_saturation
(input: torch.Tensor, saturation_factor: Union[float, torch.Tensor]) → torch.Tensor
Adjust color saturation of an image.
The input image is expected to be an RGB image in the range of [0, 1].
- Parameters
input (torch.Tensor) – Image/Tensor to be adjusted in the shape of \((*, 3, H, W)\).
saturation_factor (Union[float, torch.Tensor]) – How much to adjust the saturation. 0 will give a black and white image, 1 will give the original image while 2 will enhance the saturation by a factor of 2.
- Returns
Adjusted image in the shape of \((*, 3, H, W)\).
- Return type
torch.Tensor
Example
>>> x = torch.ones(1, 3, 3, 3)
>>> adjust_saturation(x, 2.)
tensor([[[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],
<BLANKLINE>
         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],
<BLANKLINE>
         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]]])
>>> x = torch.ones(2, 3, 3, 3)
>>> y = torch.ones(2)
>>> out = adjust_saturation(x, y)
>>> torch.nn.functional.mse_loss(x, out)
tensor(0.)
-
add_weighted
(src1: torch.Tensor, alpha: float, src2: torch.Tensor, beta: float, gamma: float) → torch.Tensor
Calculates the weighted sum of two Tensors.
The function calculates the weighted sum of two Tensors as follows:
\[out = src1 * alpha + src2 * beta + gamma\]
- Parameters
src1 (torch.Tensor) – Tensor of shape \((B, C, H, W)\).
alpha (float) – weight of the src1 elements.
src2 (torch.Tensor) – Tensor of same size and channel number as src1 \((B, C, H, W)\).
beta (float) – weight of the src2 elements.
gamma (float) – scalar added to each sum.
- Returns
Weighted Tensor of shape \((B, C, H, W)\).
- Return type
torch.Tensor
Example
>>> input1 = torch.rand(1, 1, 5, 5)
>>> input2 = torch.rand(1, 1, 5, 5)
>>> output = add_weighted(input1, 0.5, input2, 0.5, 1.0)
>>> output.shape
torch.Size([1, 1, 5, 5])
-
normalize
(data: torch.Tensor, mean: Union[torch.Tensor, float], std: Union[torch.Tensor, float]) → torch.Tensor
Normalize a tensor image with mean and standard deviation.
\[\text{input[channel] = (input[channel] - mean[channel]) / std[channel]}\]
where mean is \((M_1, ..., M_n)\) and std is \((S_1, ..., S_n)\) for \(n\) channels.
- Parameters
data (torch.Tensor) – Image tensor of size \((*, C, ...)\).
mean (Union[torch.Tensor, float]) – Mean for each channel.
std (Union[torch.Tensor, float]) – Standard deviations for each channel.
- Returns
Normalised tensor with same size as input \((*, C, ...)\).
- Return type
torch.Tensor
Examples
>>> x = torch.rand(1, 4, 3, 3)
>>> out = normalize(x, 0.0, 255.)
>>> out.shape
torch.Size([1, 4, 3, 3])
>>> x = torch.rand(1, 4, 3, 3)
>>> mean = torch.zeros(1, 4)
>>> std = 255. * torch.ones(1, 4)
>>> out = normalize(x, mean, std)
>>> out.shape
torch.Size([1, 4, 3, 3])
-
normalize_min_max
(x: torch.Tensor, min_val: float = 0.0, max_val: float = 1.0, eps: float = 1e-06) → torch.Tensor
Normalise an image tensor by its minimum and maximum values and rescale it to a given range.
The data is normalised using the following formulation:
\[y_i = (b - a) * \frac{x_i - \text{min}(x)}{\text{max}(x) - \text{min}(x)} + a\]
where \(a\) is \(\text{min\_val}\) and \(b\) is \(\text{max\_val}\).
- Parameters
x (torch.Tensor) – The image tensor to be normalised with shape \((B, C, ...)\).
min_val (float) – The minimum value for the new range. Default: 0.
max_val (float) – The maximum value for the new range. Default: 1.
eps (float) – Float number to avoid zero division. Default: 1e-6.
- Returns
The normalised image tensor with same shape as input \((B, C, ...)\).
- Return type
torch.Tensor
Example
>>> x = torch.rand(1, 5, 3, 3)
>>> x_norm = normalize_min_max(x, min_val=-1., max_val=1.)
>>> x_norm.min()
tensor(-1.)
>>> x_norm.max()
tensor(1.0000)
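A direct transcription of the formula, assuming the min/max reduction runs per channel (the exact reduction axes are an implementation detail of the library):

import torch

def normalize_min_max_sketch(x: torch.Tensor, min_val: float = 0.0,
                             max_val: float = 1.0, eps: float = 1e-6) -> torch.Tensor:
    # y = (b - a) * (x - min(x)) / (max(x) - min(x)) + a, per (B, C) slice.
    B, C = x.shape[0], x.shape[1]
    x_flat = x.view(B, C, -1)
    x_min = x_flat.min(dim=-1, keepdim=True).values
    x_max = x_flat.max(dim=-1, keepdim=True).values
    y = (max_val - min_val) * (x_flat - x_min) / (x_max - x_min + eps) + min_val
    return y.view_as(x)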
-
denormalize
(data: torch.Tensor, mean: Union[torch.Tensor, float], std: Union[torch.Tensor, float]) → torch.Tensor
Denormalize a tensor image with mean and standard deviation.
\[\text{input[channel] = (input[channel] * std[channel]) + mean[channel]}\]
where mean is \((M_1, ..., M_n)\) and std is \((S_1, ..., S_n)\) for \(n\) channels.
- Parameters
data (torch.Tensor) – Image tensor of size \((*, C, ...)\).
mean (Union[torch.Tensor, float]) – Mean for each channel.
std (Union[torch.Tensor, float]) – Standard deviations for each channel.
- Returns
Denormalised tensor with same size as input \((*, C, ...)\).
- Return type
torch.Tensor
Examples
>>> x = torch.rand(1, 4, 3, 3)
>>> out = denormalize(x, 0.0, 255.)
>>> out.shape
torch.Size([1, 4, 3, 3])
>>> x = torch.rand(1, 4, 3, 3, 3)
>>> mean = torch.zeros(1, 4)
>>> std = 255. * torch.ones(1, 4)
>>> out = denormalize(x, mean, std)
>>> out.shape
torch.Size([1, 4, 3, 3, 3])
-
zca_mean
(inp: torch.Tensor, dim: int = 0, unbiased: bool = True, eps: float = 1e-06, return_inverse: bool = False) → Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]
Computes the ZCA whitening matrix and mean vector. The output can be used with linear_transform(). See ZCAWhitening for details.
- Parameters
inp (torch.Tensor) – input data tensor
dim (int) – Specifies the dimension that serves as the samples dimension. Default = 0
unbiased (bool) – Whether to use the unbiased estimate of the covariance matrix. Default = True
eps (float) – a small number used for numerical stability. Default = 1e-6
return_inverse (bool) – Whether to return the inverse ZCA transform. Default = False
- shapes:
inp: \((D_0,...,D_{\text{dim}},...,D_N)\) is a batch of N-D tensors.
transform_matrix: \((\Pi_{d=0,d\neq \text{dim}}^N D_d, \Pi_{d=0,d\neq \text{dim}}^N D_d)\)
mean_vector: \((1, \Pi_{d=0,d\neq \text{dim}}^N D_d)\)
inv_transform: same shape as the transform matrix
- Returns
A tuple containing the ZCA matrix, the mean vector, and the inverse ZCA matrix if return_inverse is set to True (otherwise None).
- Return type
Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]
Examples
>>> x = torch.tensor([[0, 1], [1, 0], [-1, 0], [0, -1]], dtype=torch.float32)
>>> transform_matrix, mean_vector, _ = zca_mean(x)  # returns transformation matrix and data mean
>>> x = torch.rand(3, 20, 2, 2)
>>> transform_matrix, mean_vector, inv_transform = zca_mean(x, dim=1, return_inverse=True)
>>> # transform_matrix.size() equals (12, 12) and mean_vector.size() equals (1, 12)
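The whitening matrix follows the eigendecomposition route described under ZCAWhitening below; a minimal sketch for dim = 0 with an unbiased covariance estimate (a hypothetical helper, not the library source):

import torch

def zca_matrix_sketch(x: torch.Tensor, eps: float = 1e-6):
    # Flatten to (N, D), estimate mean and covariance, build T = U S^(-1/2) U^T.
    x_flat = x.reshape(x.shape[0], -1)
    mean = x_flat.mean(dim=0, keepdim=True)                 # (1, D)
    x_center = x_flat - mean
    cov = x_center.t() @ x_center / (x_flat.shape[0] - 1)   # unbiased estimate
    U, S, _ = torch.svd(cov)                                # cov is symmetric PSD
    T = U @ torch.diag(1.0 / torch.sqrt(S + eps)) @ U.t()
    return T, mean

Whitening then reduces to a linear_transform: (x_flat - mean) @ T.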
-
zca_whiten
(inp: torch.Tensor, dim: int = 0, unbiased: bool = True, eps: float = 1e-06) → torch.Tensor
Applies the ZCA whitening transform. See ZCAWhitening for details.
- Parameters
inp (torch.Tensor) – input data tensor
dim (int) – Specifies the dimension that serves as the samples dimension. Default = 0
unbiased (bool) – Whether to use the unbiased estimate of the covariance matrix. Default = True
eps (float) – a small number used for numerical stability. Default = 1e-6
- Returns
Whitened input data
- Return type
torch.Tensor
Examples
>>> x = torch.tensor([[0, 1], [1, 0], [-1, 0]], dtype=torch.float32)
>>> zca_whiten(x)
tensor([[ 0.0000,  1.1547],
        [ 1.0000, -0.5773],
        [-1.0000, -0.5773]])
-
linear_transform
(inp: torch.Tensor, transform_matrix: torch.Tensor, mean_vector: torch.Tensor, dim: int = 0) → torch.Tensor
Given a transformation matrix and a mean vector, this function flattens the input tensor along the given dimension, subtracts the mean vector from it, computes the dot product with the transformation matrix, and reshapes the result to the original input shape.
\[\mathbf{X}_{T} = (\mathbf{X} - \mathbf{\mu})T\]
- Parameters
inp (torch.Tensor) – Input data \(X\)
transform_matrix (torch.Tensor) – Transform matrix \(T\)
mean_vector (torch.Tensor) – mean vector \(\mu\)
dim (int) – Batch dimension. Default = 0
- shapes:
inp: \((D_0,...,D_{\text{dim}},...,D_N)\) is a batch of N-D tensors.
transform_matrix: \((\Pi_{d=0,d\neq \text{dim}}^N D_d, \Pi_{d=0,d\neq \text{dim}}^N D_d)\)
mean_vector: \((1, \Pi_{d=0,d\neq \text{dim}}^N D_d)\)
- Returns
Transformed data
- Return type
torch.Tensor
Example
>>> # Example where dim = 3
>>> inp = torch.ones((10, 3, 4, 5))
>>> transform_mat = torch.ones((10 * 3 * 4, 10 * 3 * 4))
>>> mean = 2 * torch.ones((1, 10 * 3 * 4))
>>> out = linear_transform(inp, transform_mat, mean, 3)
>>> print(out.shape, out.unique())  # Should be a (10, 3, 4, 5) tensor of -120s
torch.Size([10, 3, 4, 5]) tensor([-120.])
>>> # Example where dim = 0
>>> inp = torch.ones((10, 2))
>>> transform_mat = torch.ones((2, 2))
>>> mean = torch.zeros((1, 2))
>>> out = linear_transform(inp, transform_mat, mean)
>>> print(out.shape, out.unique())  # Should be a (10, 2) tensor of 2s
torch.Size([10, 2]) tensor([2.])
-
histogram
(x: torch.Tensor, bins: torch.Tensor, bandwidth: torch.Tensor, epsilon: float = 1e-10) → torch.Tensor
Function that estimates the histogram of the input tensor.
The calculation uses kernel density estimation which requires a bandwidth (smoothing) parameter.
- Parameters
x (torch.Tensor) – Input tensor to compute the histogram with shape \((B, D)\).
bins (torch.Tensor) – The bin centers to use for the histogram, with shape \((N_{bins},)\).
bandwidth (torch.Tensor) – Gaussian smoothing factor with shape [1].
epsilon (float) – A scalar, for numerical stability. Default: 1e-10.
- Returns
Computed histogram of shape \((B, N_{bins})\).
- Return type
torch.Tensor
Examples
>>> x = torch.rand(1, 10)
>>> bins = torch.linspace(0, 255, 128)
>>> hist = histogram(x, bins, bandwidth=torch.tensor(0.9))
>>> hist.shape
torch.Size([1, 128])
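Kernel density estimation here means every sample contributes a soft, differentiable weight to each bin centre instead of a hard count; one common formulation (the library's exact kernel and normalization may differ):

import torch

def soft_histogram_sketch(x: torch.Tensor, bins: torch.Tensor,
                          bandwidth: torch.Tensor, epsilon: float = 1e-10) -> torch.Tensor:
    # residuals: distance of every sample (B, D) to every bin center (N_bins,).
    residuals = x.unsqueeze(-1) - bins.view(1, 1, -1)            # (B, D, N_bins)
    weights = torch.exp(-0.5 * (residuals / bandwidth) ** 2)     # Gaussian kernel
    hist = weights.sum(dim=1)                                    # (B, N_bins)
    return hist / (hist.sum(dim=-1, keepdim=True) + epsilon)     # sums to ~1 per row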
-
histogram2d
(x1: torch.Tensor, x2: torch.Tensor, bins: torch.Tensor, bandwidth: torch.Tensor, epsilon: float = 1e-10) → torch.Tensor
Function that estimates the 2D histogram of two input tensors.
The calculation uses kernel density estimation which requires a bandwidth (smoothing) parameter.
- Parameters
x1 (torch.Tensor) – Input tensor to compute the histogram with shape \((B, D1)\).
x2 (torch.Tensor) – Input tensor to compute the histogram with shape \((B, D2)\).
bins (torch.Tensor) – The bin centers to use for the histogram, with shape \((N_{bins},)\).
bandwidth (torch.Tensor) – Gaussian smoothing factor with shape [1].
epsilon (float) – A scalar, for numerical stability. Default: 1e-10.
- Returns
Computed histogram of shape \((B, N_{bins}, N_{bins})\).
- Return type
torch.Tensor
Examples
>>> x1 = torch.rand(2, 32)
>>> x2 = torch.rand(2, 32)
>>> bins = torch.linspace(0, 255, 128)
>>> hist = histogram2d(x1, x2, bins, bandwidth=torch.tensor(0.9))
>>> hist.shape
torch.Size([2, 128, 128])
-
solarize
(input: torch.Tensor, thresholds: Union[float, torch.Tensor] = 0.5, additions: Union[float, torch.Tensor, None] = None) → torch.Tensor
For each pixel in the image less than the threshold, we add the 'addition' amount to it and then clip the pixel value to be between 0 and 1.0. The value of 'addition' is between -0.5 and 0.5.
- Parameters
input (torch.Tensor) – image tensor with shapes like \((B, C, H, W)\) to solarize.
thresholds (float or torch.Tensor) – solarize thresholds. If a float or a one-element tensor, the input will be solarized across the whole batch. If a 1-d tensor, the input will be solarized element-wise, with len(thresholds) == len(input).
additions (optional, float or torch.Tensor) – between -0.5 and 0.5. Default: None. If None, no addition will be performed. If a float or a one-element tensor, the same addition will be added across the whole batch. If a 1-d tensor, additions will be added element-wise, with len(additions) == len(input).
- Returns
The solarized images with shape \((B, C, H, W)\).
- Return type
torch.Tensor
Example
>>> x = torch.rand(1, 4, 3, 3)
>>> out = solarize(x, thresholds=0.5, additions=0.)
>>> out.shape
torch.Size([1, 4, 3, 3])
>>> x = torch.rand(2, 4, 3, 3)
>>> thresholds = torch.tensor([0.8, 0.7])
>>> out = solarize(x, thresholds)
>>> out.shape
torch.Size([2, 4, 3, 3])
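One common reading of the operation: shift the intensities by 'addition', clip, then invert the pixels at or above the threshold. A scalar-only sketch under those assumptions (the exact order and comparison are implementation details):

import torch

def solarize_sketch(x: torch.Tensor, threshold: float = 0.5,
                    addition: float = None) -> torch.Tensor:
    # Optionally shift intensities and clip, then invert pixels >= threshold.
    if addition is not None:
        x = torch.clamp(x + addition, 0.0, 1.0)
    return torch.where(x < threshold, x, 1.0 - x)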
-
posterize
(input: torch.Tensor, bits: Union[int, torch.Tensor]) → torch.Tensor
Reduce the number of bits for each color channel.
Non-differentiable function, torch.uint8 involved.
- Parameters
input (torch.Tensor) – image tensor with shapes like \((B, C, H, W)\) to posterize.
bits (int or torch.Tensor) – number of high bits to keep. Must be in range [0, 8]. If an int or a one-element tensor, the input will be posterized by this number of bits. If a 1-d tensor, the input will be posterized element-wise, with len(bits) == input.shape[0]. If an n-d tensor, the input will be posterized element-channel-wise, with bits.shape == input.shape[:len(bits.shape)].
- Returns
Image with reduced color channels with shape \((B, C, H, W)\).
- Return type
torch.Tensor
Example
>>> x = torch.rand(1, 6, 3, 3)
>>> out = posterize(x, bits=8)
>>> torch.testing.assert_allclose(x, out)
>>> x = torch.rand(2, 6, 3, 3)
>>> bits = torch.tensor([0, 8])
>>> posterize(x, bits).shape
torch.Size([2, 6, 3, 3])
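The bit reduction amounts to masking the uint8 representation, which is why the op is non-differentiable; a scalar-bits sketch:

import torch

def posterize_sketch(x: torch.Tensor, bits: int) -> torch.Tensor:
    # Keep only the `bits` most significant bits of each 8-bit channel value.
    mask = ~(2 ** (8 - bits) - 1) & 0xFF       # e.g. bits=8 -> 0xFF, bits=0 -> 0x00
    x_uint8 = (x * 255.0).to(torch.uint8)
    return (x_uint8 & mask).float() / 255.0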
-
sharpness
(input: torch.Tensor, factor: Union[float, torch.Tensor]) → torch.Tensor
Apply sharpness to the input tensor.
Implements the sharpness function from PIL using torch ops. This implementation refers to: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py#L326
- Parameters
input (torch.Tensor) – image tensor with shapes like (C, H, W) or (B, C, H, W) to sharpen.
factor (float or torch.Tensor) – factor of sharpness strength. Must be above 0. If a float or a one-element tensor, the input will be sharpened by the same factor across the whole batch. If a 1-d tensor, the input will be sharpened element-wise, with len(factor) == len(input).
- Returns
Sharpened image or images.
- Return type
torch.Tensor
Example
>>> _ = torch.manual_seed(0)
>>> sharpness(torch.randn(1, 1, 5, 5), 0.5)
tensor([[[[-1.1258, -1.1524, -0.2506, -0.4339,  0.8487],
          [ 0.6920, -0.1580, -1.0576,  0.1765, -0.1577],
          [ 1.4437,  0.1998,  0.1799,  0.6588, -0.1435],
          [-0.1116, -0.3068,  0.8381,  1.3477,  0.0537],
          [ 0.6181, -0.4128, -0.8411, -2.3160, -0.1023]]]])
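The referenced autoaugment implementation smooths the image with a fixed 3x3 kernel and blends the result with the original, leaving the borders untouched (which is why the edge values in the example equal the input). A sketch under those assumptions:

import torch
import torch.nn.functional as F

def sharpness_sketch(x: torch.Tensor, factor: float) -> torch.Tensor:
    # x: (B, C, H, W). Fixed smoothing kernel used by PIL, normalized by 13.
    kernel = torch.tensor([[1., 1., 1.],
                           [1., 5., 1.],
                           [1., 1., 1.]], device=x.device, dtype=x.dtype) / 13.0
    C = x.shape[-3]
    weight = kernel.view(1, 1, 3, 3).repeat(C, 1, 1, 1)   # depthwise kernel
    blurred = F.conv2d(x, weight, padding=1, groups=C)
    out = blurred + factor * (x - blurred)  # factor=0 -> blurred, factor=1 -> original
    # Only the interior is sharpened; borders keep the original pixels.
    out[..., 0, :], out[..., -1, :] = x[..., 0, :], x[..., -1, :]
    out[..., :, 0], out[..., :, -1] = x[..., :, 0], x[..., :, -1]
    return out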
-
equalize
(input: torch.Tensor) → torch.Tensor
Apply equalize on the input tensor.
Implements the equalize function from PIL using PyTorch ops, based on the uint8 format: https://github.com/tensorflow/tpu/blob/5f71c12a020403f863434e96982a840578fdd127/models/official/efficientnet/autoaugment.py#L355
- Parameters
input (torch.Tensor) – image tensor to equalize with shapes like \((C, H, W)\) or \((B, C, H, W)\).
- Returns
Equalized image or images with the same shape as the input.
- Return type
torch.Tensor
Example
>>> _ = torch.manual_seed(0)
>>> x = torch.rand(1, 2, 3, 3)
>>> equalize(x)
tensor([[[[0.4963, 0.7682, 0.0885],
          [0.1320, 0.3074, 0.6341],
          [0.4901, 0.8964, 0.4556]],
<BLANKLINE>
         [[0.6323, 0.3489, 0.4017],
          [0.0223, 0.1689, 0.2939],
          [0.5185, 0.6977, 0.8000]]]])
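The PIL-style recipe builds a per-channel lookup table from the cumulative histogram of the uint8 image; a single-channel sketch of that recipe (a hypothetical helper, not the library source):

import torch

def equalize_channel_sketch(x: torch.Tensor) -> torch.Tensor:
    # x: a single (H, W) channel in [0, 1]; quantize to uint8 and build the LUT.
    im = (x * 255.0).to(torch.uint8)
    hist = torch.histc(im.float(), bins=256, min=0, max=255)
    nonzero = hist[hist > 0]
    step = (hist.sum() - nonzero[-1]) // 255          # PIL-style step size
    if step == 0:
        return x                                      # flat histogram: no-op
    lut = (torch.cumsum(hist, dim=0) + step // 2) // step
    lut = torch.cat([torch.zeros(1), lut[:-1]]).clamp(0, 255)
    return lut[im.long()] / 255.0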
Modules
-
class
Normalize
(mean: Union[torch.Tensor, float], std: Union[torch.Tensor, float])
Normalize a tensor image with mean and standard deviation.
\[\text{input[channel] = (input[channel] - mean[channel]) / std[channel]}\]
where mean is \((M_1, ..., M_n)\) and std is \((S_1, ..., S_n)\) for \(n\) channels.
- Parameters
mean (Union[torch.Tensor, float]) – Mean for each channel.
std (Union[torch.Tensor, float]) – Standard deviations for each channel.
- Shape:
Input: Image tensor of size \((*, C, ...)\).
Output: Normalised tensor with same size as input \((*, C, ...)\).
Examples
>>> x = torch.rand(1, 4, 3, 3)
>>> out = Normalize(0.0, 255.)(x)
>>> out.shape
torch.Size([1, 4, 3, 3])
>>> x = torch.rand(1, 4, 3, 3)
>>> mean = torch.zeros(1, 4)
>>> std = 255. * torch.ones(1, 4)
>>> out = Normalize(mean, std)(x)
>>> out.shape
torch.Size([1, 4, 3, 3])
-
class
Denormalize
(mean: Union[torch.Tensor, float], std: Union[torch.Tensor, float])
Denormalize a tensor image with mean and standard deviation.
\[\text{input[channel] = (input[channel] * std[channel]) + mean[channel]}\]
where mean is \((M_1, ..., M_n)\) and std is \((S_1, ..., S_n)\) for \(n\) channels.
- Parameters
mean (Union[torch.Tensor, float]) – Mean for each channel.
std (Union[torch.Tensor, float]) – Standard deviations for each channel.
- Shape:
Input: Image tensor of size \((*, C, ...)\).
Output: Denormalised tensor with same size as input \((*, C, ...)\).
Examples
>>> x = torch.rand(1, 4, 3, 3)
>>> out = Denormalize(0.0, 255.)(x)
>>> out.shape
torch.Size([1, 4, 3, 3])
>>> x = torch.rand(1, 4, 3, 3, 3)
>>> mean = torch.zeros(1, 4)
>>> std = 255. * torch.ones(1, 4)
>>> out = Denormalize(mean, std)(x)
>>> out.shape
torch.Size([1, 4, 3, 3, 3])
-
class
ZCAWhitening
(dim: int = 0, eps: float = 1e-06, unbiased: bool = True, detach_transforms: bool = True, compute_inv: bool = False)
Computes the ZCA whitening matrix transform and the mean vector, and applies the transform to the data. The data tensor is flattened, and the mean \(\mathbf{\mu}\) and covariance matrix \(\mathbf{\Sigma}\) are computed from the flattened data \(\mathbf{X} \in \mathbb{R}^{N \times D}\), where \(N\) is the sample size and \(D\) is the flattened dimensionality (e.g. for a tensor with size 5x3x2x2, \(N = 5\) and \(D = 12\)). The ZCA whitening transform is given by:
\[\mathbf{X}_{\text{zca}} = (\mathbf{X} - \mathbf{\mu})(US^{-\frac{1}{2}}U^T)^T\]
where \(U\) are the eigenvectors of \(\Sigma\) and \(S\) contains the corresponding eigenvalues of \(\Sigma\). After the transform is applied, the output is reshaped to the original input shape.
- Parameters
dim (int) – Determines the dimension that represents the samples axis. Default = 0
eps (float) – a small number used for numerical stability. Default=1e-6
unbiased (bool) – Whether to use the unbiased estimate of the covariance matrix. Default = True
compute_inv (bool) – Compute the inverse transform matrix. Default=False
detach_transforms (bool) – Detaches gradient from the ZCA fitting. Default=True
- shape:
x: \((D_0,...,D_{\text{dim}},...,D_N)\) is a batch of N-D tensors.
x_whiten: \((D_0,...,D_{\text{dim}},...,D_N)\) same shape as input.
Examples
>>> x = torch.tensor([[0, 1], [1, 0], [-1, 0], [0, -1]], dtype=torch.float32)
>>> zca = ZCAWhitening().fit(x)
>>> x_whiten = zca(x)
>>> zca = ZCAWhitening()
>>> x_whiten = zca(x, include_fit=True)  # Includes the fitting step
>>> x_whiten = zca(x)  # Can run now without the fitting step
>>> # Enable backprop through the ZCA fitting process
>>> zca = ZCAWhitening(detach_transforms=False)
>>> x_whiten = zca(x, include_fit=True)  # Includes the fitting step
Note
This implementation uses
svd()
which yields NaNs in the backwards step if the singular values are not unique. See here for more information.
References
[1] Stanford PCA & ZCA whitening tutorial
-
fit
(x: torch.Tensor)
Fits ZCA whitening matrices to the data.
- Parameters
x (torch.Tensor) – Input data
- Returns
Returns a fitted ZCAWhitening object instance.
- Return type
ZCAWhitening
-
forward
(x: torch.Tensor, include_fit: bool = False) → torch.Tensor
Applies the whitening transform to the data.
- Parameters
x (torch.Tensor) – Input data
include_fit (bool) – Indicates whether to fit the data as part of the forward pass
- Returns
The transformed data
- Return type
torch.Tensor
-
inverse_transform
(x: torch.Tensor) → torch.Tensor
Applies the inverse transform to the whitened data.
- Parameters
x (torch.Tensor) – Whitened data
- Returns
Original data
- Return type
torch.Tensor
-
class
AdjustBrightness
(brightness_factor: Union[float, torch.Tensor])
Adjust brightness of an image.
This implementation aligns with OpenCV, not PIL. Hence, the output differs from TorchVision. The input image is expected to be in the range of [0, 1].
- Parameters
brightness_factor (Union[float, torch.Tensor]) – Brightness adjust factor per element in the batch. 0 does not modify the input image, while any other number modifies the brightness.
- Shape:
Input: Image/Input to be adjusted in the shape of \((*, N)\).
Output: Adjusted image in the shape of \((*, N)\).
Example
>>> x = torch.ones(1, 1, 3, 3)
>>> AdjustBrightness(1.)(x)
tensor([[[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]]])
>>> x = torch.ones(2, 5, 3, 3)
>>> y = torch.ones(2)
>>> AdjustBrightness(y)(x).shape
torch.Size([2, 5, 3, 3])
-
class
AdjustContrast
(contrast_factor: Union[float, torch.Tensor])
Adjust contrast of an image.
This implementation aligns with OpenCV, not PIL. Hence, the output differs from TorchVision. The input image is expected to be in the range of [0, 1].
- Parameters
contrast_factor (Union[float, torch.Tensor]) – Contrast adjust factor per element in the batch. 0 generates a completely black image, 1 does not modify the input image, and any other non-negative number scales the contrast by this factor.
- Shape:
Input: Image/Input to be adjusted in the shape of \((*, N)\).
Output: Adjusted image in the shape of \((*, N)\).
Example
>>> x = torch.ones(1, 1, 3, 3)
>>> AdjustContrast(0.5)(x)
tensor([[[[0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000]]]])
>>> x = torch.ones(2, 5, 3, 3)
>>> y = torch.ones(2)
>>> AdjustContrast(y)(x).shape
torch.Size([2, 5, 3, 3])
-
class
AdjustSaturation
(saturation_factor: Union[float, torch.Tensor])
Adjust color saturation of an image.
The input image is expected to be an RGB image in the range of [0, 1].
- Parameters
saturation_factor (Union[float, torch.Tensor]) – How much to adjust the saturation. 0 will give a black and white image, 1 will give the original image while 2 will enhance the saturation by a factor of 2.
- Shape:
Input: Image/Tensor to be adjusted in the shape of \((*, 3, H, W)\).
Output: Adjusted image in the shape of \((*, 3, H, W)\).
Example
>>> x = torch.ones(1, 3, 3, 3)
>>> AdjustSaturation(2.)(x)
tensor([[[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],
<BLANKLINE>
         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],
<BLANKLINE>
         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]]])
>>> x = torch.ones(2, 3, 3, 3)
>>> y = torch.ones(2)
>>> out = AdjustSaturation(y)(x)
>>> torch.nn.functional.mse_loss(x, out)
tensor(0.)
-
class
AdjustHue
(hue_factor: Union[float, torch.Tensor])
Adjust hue of an image.
The input image is expected to be an RGB image in the range of [0, 1].
- Parameters
hue_factor (Union[float, torch.Tensor]) – How much to shift the hue channel. Should be in [-PI, PI], where PI and -PI give a complete reversal of the hue channel in HSV space (an image with complementary colors) and 0 gives the original image.
- Shape:
Input: Image/Tensor to be adjusted in the shape of \((*, 3, H, W)\).
Output: Adjusted image in the shape of \((*, 3, H, W)\).
Example
>>> x = torch.ones(1, 3, 3, 3)
>>> AdjustHue(3.141516)(x)
tensor([[[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],
<BLANKLINE>
         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],
<BLANKLINE>
         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]]])
>>> x = torch.ones(2, 3, 3, 3)
>>> y = torch.ones(2) * 3.141516
>>> AdjustHue(y)(x).shape
torch.Size([2, 3, 3, 3])
-
class
AdjustGamma
(gamma: Union[float, torch.Tensor], gain: Union[float, torch.Tensor] = 1.0)
Perform gamma correction on an image.
The input image is expected to be in the range of [0, 1].
- Parameters
gamma (Union[float, torch.Tensor]) – Non-negative real number, same as \(\gamma\) in the equation. gamma larger than 1 makes the shadows darker, while gamma smaller than 1 makes dark regions lighter.
gain (Union[float, torch.Tensor], optional) – The constant multiplier. Default 1.
- Shape:
Input: Image to be adjusted in the shape of \((*, N)\).
Output: Adjusted image in the shape of \((*, N)\).
Example
>>> x = torch.ones(1, 1, 3, 3)
>>> AdjustGamma(1.0, 2.0)(x)
tensor([[[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]]])
>>> x = torch.ones(2, 5, 3, 3)
>>> y1 = torch.ones(2) * 1.0
>>> y2 = torch.ones(2) * 2.0
>>> AdjustGamma(y1, y2)(x).shape
torch.Size([2, 5, 3, 3])
-
class
AddWeighted
(alpha: float, beta: float, gamma: float)
Calculates the weighted sum of two Tensors.
The function calculates the weighted sum of two Tensors as follows:
\[out = src1 * alpha + src2 * beta + gamma\]
- Parameters
alpha (float) – weight of the src1 elements.
beta (float) – weight of the src2 elements.
gamma (float) – scalar added to each sum.
- Shape:
Input1: Tensor of shape \((B, C, H, W)\).
Input2: Tensor of shape \((B, C, H, W)\).
Output: Weighted tensor of shape \((B, C, H, W)\).
Example
>>> input1 = torch.rand(1, 1, 5, 5)
>>> input2 = torch.rand(1, 1, 5, 5)
>>> output = AddWeighted(0.5, 0.5, 1.0)(input1, input2)
>>> output.shape
torch.Size([1, 1, 5, 5])