"""The RGB <-> Lab color conversions were translated from scikit-image's ``rgb2lab`` and ``lab2rgb``.
https://github.com/scikit-image/scikit-image/blob/a48bf6774718c64dade4548153ae16065b595ca9/skimage/color/colorconv.py
"""
import torch
from torch import nn
from .rgb import linear_rgb_to_rgb, rgb_to_linear_rgb
from .xyz import rgb_to_xyz, xyz_to_rgb
def rgb_to_lab(image: torch.Tensor) -> torch.Tensor:
    r"""Convert an RGB image to CIE Lab.

    .. image:: _static/img/rgb_to_lab.png

    The input is expected to be sRGB with values in :math:`[0, 1]`. Lab values are
    computed relative to the D65 illuminant with the 2-degree standard observer.

    Args:
        image: RGB image to convert, with shape :math:`(*, 3, H, W)`.

    Returns:
        Lab image with shape :math:`(*, 3, H, W)`. The L channel lies in
        :math:`[0, 100]`; the a and b channels lie in :math:`[-128, 127]`.

    Raises:
        TypeError: If ``image`` is not a :class:`torch.Tensor`.
        ValueError: If ``image`` has fewer than 3 dimensions or its third-to-last
            dimension is not 3.

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> output = rgb_to_lab(input)  # 2x3x4x5
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")

    if len(image.shape) < 3 or image.shape[-3] != 3:
        raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

    # sRGB -> linear RGB -> CIE XYZ
    xyz = rgb_to_xyz(rgb_to_linear_rgb(image))

    # Normalize by the D65 reference white; shape (3, 1, 1) broadcasts over (*, 3, H, W).
    d65_white = torch.tensor([0.95047, 1.0, 1.08883], device=xyz.device, dtype=xyz.dtype).view(3, 1, 1)
    ratio = xyz / d65_white

    # Piecewise CIE f(t): cube root above the breakpoint, linear segment below it.
    eps = 0.008856
    # Clamping first keeps the cube root on strictly positive inputs, so the branch that
    # torch.where discards cannot inject inf/NaN gradients.
    cube_root = ratio.clamp(min=eps).pow(1.0 / 3.0)
    linear = ratio * 7.787 + 4.0 / 29.0
    f_xyz = torch.where(ratio > eps, cube_root, linear)

    fx, fy, fz = f_xyz.unbind(dim=-3)

    lightness = 116.0 * fy - 16.0
    green_red = 500.0 * (fx - fy)
    blue_yellow = 200.0 * (fy - fz)

    return torch.stack((lightness, green_red, blue_yellow), dim=-3)
def lab_to_rgb(image: torch.Tensor, clip: bool = True) -> torch.Tensor:
    r"""Convert a Lab image to RGB.

    The L channel is assumed to be in the range :math:`[0, 100]`, and the a and b
    channels in the range :math:`[-128, 127]`. The conversion uses the D65 white point.

    Args:
        image: Lab image to be converted to RGB with shape :math:`(*, 3, H, W)`.
        clip: Whether to clamp the output RGB values to :math:`[0, 1]`.

    Returns:
        RGB version of the image with shape :math:`(*, 3, H, W)`. Values are in
        :math:`[0, 1]` when ``clip`` is True; with ``clip=False`` they may fall outside
        that range.

    Raises:
        TypeError: If ``image`` is not a :class:`torch.Tensor`.
        ValueError: If ``image`` has fewer than 3 dimensions or its third-to-last
            dimension is not 3.

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> output = lab_to_rgb(input)  # 2x3x4x5
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")

    if len(image.shape) < 3 or image.shape[-3] != 3:
        raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

    L: torch.Tensor = image[..., 0, :, :]
    a: torch.Tensor = image[..., 1, :, :]
    _b: torch.Tensor = image[..., 2, :, :]

    # Invert the Lab equations to recover f(X/Xn), f(Y/Yn) and f(Z/Zn).
    fy = (L + 16.0) / 116.0
    fx = (a / 500.0) + fy
    fz = fy - (_b / 200.0)

    # Out-of-range Lab input can drive fz (and therefore Z) below zero; clamp it.
    fz = fz.clamp(min=0.0)

    fxyz = torch.stack([fx, fy, fz], dim=-3)

    # Invert the piecewise f(t): cube above the breakpoint, undo the linear segment below.
    # 0.2068966 ~= 6/29 ~= f(0.008856), the breakpoint used in the forward (RGB -> Lab) direction.
    power = torch.pow(fxyz, 3.0)
    scale = (fxyz - 4.0 / 29.0) / 7.787
    xyz = torch.where(fxyz > 0.2068966, power, scale)

    # Undo the normalization by the D65 reference white; shape (3, 1, 1) broadcasts.
    xyz_ref_white = torch.tensor([0.95047, 1.0, 1.08883], device=xyz.device, dtype=xyz.dtype)[
        ..., :, None, None
    ]
    xyz_im = xyz * xyz_ref_white

    # XYZ -> linear RGB -> sRGB
    lin_rgb: torch.Tensor = xyz_to_rgb(xyz_im)
    rgb_im = linear_rgb_to_rgb(lin_rgb)

    # Clip to [0, 1]; see https://www.w3.org/Graphics/Color/srgb
    if clip:
        rgb_im = torch.clamp(rgb_im, min=0.0, max=1.0)

    return rgb_im
class RgbToLab(nn.Module):
    r"""Module wrapper that converts an RGB image to CIE Lab.

    The input is expected to be sRGB with values in :math:`[0, 1]`. Lab values are
    computed relative to the D65 illuminant with the 2-degree standard observer.

    Returns:
        Lab version of the image.

    Shape:
        - image: :math:`(*, 3, H, W)`
        - output: :math:`(*, 3, H, W)`

    Examples:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> lab = RgbToLab()
        >>> output = lab(input)  # 2x3x4x5

    References:
        [1] https://docs.opencv.org/4.0.1/de/d25/imgproc_color_conversions.html
        [2] https://www.easyrgb.com/en/math.php
        [3] https://github.com/torch/image/blob/dc061b98fb7e946e00034a5fc73e883a299edc7f/generic/image.c#L1467
    """

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """Return ``rgb_to_lab(image)``."""
        lab_image = rgb_to_lab(image)
        return lab_image
class LabToRgb(nn.Module):
    r"""Module wrapper that converts a Lab image to RGB.

    The L channel is expected in :math:`[0, 100]` and the a and b channels in
    :math:`[-128, 127]`.

    Returns:
        RGB version of the image. By default (``clip=True`` in :meth:`forward`) the
        values are clamped to :math:`[0, 1]`; with ``clip=False`` they may fall outside
        that range.

    Shape:
        - image: :math:`(*, 3, H, W)`
        - output: :math:`(*, 3, H, W)`

    Examples:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> rgb = LabToRgb()
        >>> output = rgb(input)  # 2x3x4x5

    References:
        [1] https://docs.opencv.org/4.0.1/de/d25/imgproc_color_conversions.html
        [2] https://www.easyrgb.com/en/math.php
        [3] https://github.com/torch/image/blob/dc061b98fb7e946e00034a5fc73e883a299edc7f/generic/image.c#L1518
    """

    def forward(self, image: torch.Tensor, clip: bool = True) -> torch.Tensor:
        """Convert ``image`` from Lab to RGB via ``lab_to_rgb``.

        Args:
            image: Lab image with shape :math:`(*, 3, H, W)`.
            clip: Whether to clamp the output RGB values to :math:`[0, 1]`.

        Returns:
            RGB image with shape :math:`(*, 3, H, W)`.
        """
        return lab_to_rgb(image, clip)