Source code for kornia.color.lab

"""The RGB to Lab color transformations were translated from scikit image's rgb2lab and lab2rgb.

https://github.com/scikit-image/scikit-image/blob/a48bf6774718c64dade4548153ae16065b595ca9/skimage/color/colorconv.py
"""


import torch
from torch import nn

from .rgb import linear_rgb_to_rgb, rgb_to_linear_rgb
from .xyz import rgb_to_xyz, xyz_to_rgb


def rgb_to_lab(image: torch.Tensor) -> torch.Tensor:
    r"""Convert an RGB image to Lab.

    .. image:: _static/img/rgb_to_lab.png

    The input RGB image is assumed to be in the range of :math:`[0, 1]`. Lab
    color is computed using the D65 illuminant and Observer 2.

    Args:
        image: RGB Image to be converted to Lab with shape :math:`(*, 3, H, W)`.

    Returns:
        Lab version of the image with shape :math:`(*, 3, H, W)`.
        The L channel values are in the range 0..100. a and b are in the range -128..127.

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> output = rgb_to_lab(input)  # 2x3x4x5
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")

    if len(image.shape) < 3 or image.shape[-3] != 3:
        raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

    # Convert from sRGB to Linear RGB
    lin_rgb = rgb_to_linear_rgb(image)

    xyz_im: torch.Tensor = rgb_to_xyz(lin_rgb)

    # normalize for D65 white point
    xyz_ref_white = torch.tensor([0.95047, 1.0, 1.08883], device=xyz_im.device, dtype=xyz_im.dtype)[..., :, None, None]
    xyz_normalized = torch.div(xyz_im, xyz_ref_white)

    threshold = 0.008856
    power = torch.pow(xyz_normalized.clamp(min=threshold), 1 / 3.0)
    scale = 7.787 * xyz_normalized + 4.0 / 29.0
    xyz_int = torch.where(xyz_normalized > threshold, power, scale)

    x: torch.Tensor = xyz_int[..., 0, :, :]
    y: torch.Tensor = xyz_int[..., 1, :, :]
    z: torch.Tensor = xyz_int[..., 2, :, :]

    L: torch.Tensor = (116.0 * y) - 16.0
    a: torch.Tensor = 500.0 * (x - y)
    _b: torch.Tensor = 200.0 * (y - z)

    out: torch.Tensor = torch.stack([L, a, _b], dim=-3)

    return out

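# Illustrative sketch, not part of the kornia API: the scalar form of the
# piecewise nonlinearity applied to the normalized XYZ channels above. The
# threshold 0.008856 is (6/29)**3 and the slope 7.787 is 1 / (3 * (6/29)**2),
# so the two branches agree in value and slope at the threshold; the channels
# are then combined as L = 116*f(Y/Yn) - 16, a = 500*(f(X/Xn) - f(Y/Yn)) and
# b = 200*(f(Y/Yn) - f(Z/Zn)).
def _lab_f_scalar(t: float) -> float:
    """Scalar reference for the xyz_int computation in rgb_to_lab (sketch only)."""
    if t > 0.008856:
        return t ** (1.0 / 3.0)
    return 7.787 * t + 4.0 / 29.0
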
def lab_to_rgb(image: torch.Tensor, clip: bool = True) -> torch.Tensor:
    r"""Convert a Lab image to RGB.

    The L channel is assumed to be in the range of :math:`[0, 100]`. a and b
    channels are in the range of :math:`[-128, 127]`.

    Args:
        image: Lab image to be converted to RGB with shape :math:`(*, 3, H, W)`.
        clip: Whether to apply clipping to ensure the output RGB values are in the range :math:`[0, 1]`.

    Returns:
        RGB version of the image with shape :math:`(*, 3, H, W)`.
        The output RGB image values are in the range of :math:`[0, 1]`.

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> output = lab_to_rgb(input)  # 2x3x4x5
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")

    if len(image.shape) < 3 or image.shape[-3] != 3:
        raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

    L: torch.Tensor = image[..., 0, :, :]
    a: torch.Tensor = image[..., 1, :, :]
    _b: torch.Tensor = image[..., 2, :, :]

    fy = (L + 16.0) / 116.0
    fx = (a / 500.0) + fy
    fz = fy - (_b / 200.0)

    # if color data out of range: Z < 0
    fz = fz.clamp(min=0.0)

    fxyz = torch.stack([fx, fy, fz], dim=-3)

    # Convert from Lab to XYZ
    power = torch.pow(fxyz, 3.0)
    scale = (fxyz - 4.0 / 29.0) / 7.787
    xyz = torch.where(fxyz > 0.2068966, power, scale)

    # For D65 white point
    xyz_ref_white = torch.tensor([0.95047, 1.0, 1.08883], device=xyz.device, dtype=xyz.dtype)[..., :, None, None]
    xyz_im = xyz * xyz_ref_white

    rgbs_im: torch.Tensor = xyz_to_rgb(xyz_im)

    # https://github.com/richzhang/colorization-pytorch/blob/66a1cb2e5258f7c8f374f582acc8b1ef99c13c27/util/util.py#L107
    # rgbs_im = torch.where(rgbs_im < 0, torch.zeros_like(rgbs_im), rgbs_im)

    # Convert from RGB Linear to sRGB
    rgb_im = linear_rgb_to_rgb(rgbs_im)

    # Clip to [0, 1] https://www.w3.org/Graphics/Color/srgb
    if clip:
        rgb_im = torch.clamp(rgb_im, min=0.0, max=1.0)

    return rgb_im

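# Illustrative usage sketch, not part of the kornia API: for in-gamut sRGB
# inputs, converting to Lab and back should approximately reproduce the input.
# The helper below is never called by the library; the tolerance is an
# assumption chosen for float32 round-off, not a documented guarantee.
def _lab_roundtrip_sketch() -> bool:
    """Round-trip rgb -> lab -> rgb on random in-gamut data (sketch only)."""
    img = torch.rand(1, 3, 8, 8)
    restored = lab_to_rgb(rgb_to_lab(img))
    return bool(torch.allclose(img, restored, atol=1e-3))
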
class RgbToLab(nn.Module):
    r"""Convert an image from RGB to Lab.

    The image data is assumed to be in the range of :math:`[0, 1]`. Lab color
    is computed using the D65 illuminant and Observer 2.

    Returns:
        Lab version of the image.

    Shape:
        - image: :math:`(*, 3, H, W)`
        - output: :math:`(*, 3, H, W)`

    Examples:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> lab = RgbToLab()
        >>> output = lab(input)  # 2x3x4x5

    Reference:
        [1] https://docs.opencv.org/4.0.1/de/d25/imgproc_color_conversions.html

        [2] https://www.easyrgb.com/en/math.php

        [3] https://github.com/torch/image/blob/dc061b98fb7e946e00034a5fc73e883a299edc7f/generic/image.c#L1467
    """

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        return rgb_to_lab(image)

class LabToRgb(nn.Module):
    r"""Convert an image from Lab to RGB.

    Returns:
        RGB version of the image. Range may not be in :math:`[0, 1]`.

    Shape:
        - image: :math:`(*, 3, H, W)`
        - output: :math:`(*, 3, H, W)`

    Examples:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> rgb = LabToRgb()
        >>> output = rgb(input)  # 2x3x4x5

    References:
        [1] https://docs.opencv.org/4.0.1/de/d25/imgproc_color_conversions.html

        [2] https://www.easyrgb.com/en/math.php

        [3] https://github.com/torch/image/blob/dc061b98fb7e946e00034a5fc73e883a299edc7f/generic/image.c#L1518
    """

    def forward(self, image: torch.Tensor, clip: bool = True) -> torch.Tensor:
        return lab_to_rgb(image, clip)

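# Illustrative usage sketch, not part of the kornia API: the two nn.Module
# wrappers compose like any other PyTorch modules, e.g. inside an
# nn.Sequential pipeline; up to the optional output clipping, the composition
# approximately reproduces its in-gamut input. The helper is never called by
# the library.
def _module_usage_sketch() -> torch.Tensor:
    """Chain RgbToLab and LabToRgb in an nn.Sequential (sketch only)."""
    pipeline = nn.Sequential(RgbToLab(), LabToRgb())
    return pipeline(torch.rand(2, 3, 4, 5))  # output shape (2, 3, 4, 5)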