# Source code for kornia.color.xyz

from typing import Union

import torch
import torch.nn as nn


class RgbToXyz(nn.Module):
    r"""Module form of :func:`rgb_to_xyz`: converts an RGB image to XYZ.

    The input values are assumed to lie in the range (0, 1). The module
    holds no parameters or state; calling it simply delegates to
    :func:`rgb_to_xyz`.

    Shape:
        - image: :math:`(*, 3, H, W)`
        - output: :math:`(*, 3, H, W)`

    Examples:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> xyz = kornia.color.RgbToXyz()
        >>> output = xyz(input)  # 2x3x4x5

    Reference:
        [1] https://docs.opencv.org/4.0.1/de/d25/imgproc_color_conversions.html
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, image: torch.Tensor) -> torch.Tensor:  # type: ignore
        """Convert ``image`` from RGB to XYZ.

        Args:
            image (torch.Tensor): RGB image of shape :math:`(*, 3, H, W)`.

        Returns:
            torch.Tensor: XYZ image with the same shape as ``image``.
        """
        return rgb_to_xyz(image)
class XyzToRgb(nn.Module):
    r"""Module form of :func:`xyz_to_rgb`: converts an XYZ image to RGB.

    The module holds no parameters or state; calling it simply delegates
    to :func:`xyz_to_rgb`.

    Shape:
        - image: :math:`(*, 3, H, W)`
        - output: :math:`(*, 3, H, W)`

    Examples:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> rgb = kornia.color.XyzToRgb()
        >>> output = rgb(input)  # 2x3x4x5

    Reference:
        [1] https://docs.opencv.org/4.0.1/de/d25/imgproc_color_conversions.html
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, image: torch.Tensor) -> torch.Tensor:  # type: ignore
        """Convert ``image`` from XYZ to RGB.

        Args:
            image (torch.Tensor): XYZ image of shape :math:`(*, 3, H, W)`.

        Returns:
            torch.Tensor: RGB image with the same shape as ``image``.
        """
        return xyz_to_rgb(image)
def rgb_to_xyz(image: torch.Tensor) -> torch.Tensor:
    r"""Convert an RGB image to the XYZ color space.

    Every output pixel is a fixed linear combination of the three input
    channels. The input values are used as-is; no gamma decoding or
    clamping is applied. See :class:`~kornia.color.RgbToXyz` for the
    module form.

    Args:
        image (torch.Tensor): RGB image of shape :math:`(*, 3, H, W)`.

    Returns:
        torch.Tensor: XYZ image with the same shape as ``image``.

    Raises:
        TypeError: if ``image`` is not a tensor.
        ValueError: if ``image`` has fewer than three dimensions or its
            channel dimension (``-3``) is not of size 3.
    """
    if not torch.is_tensor(image):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")

    # Short-circuit keeps ``shape[-3]`` from being read on tensors with < 3 dims.
    has_three_channels = image.dim() >= 3 and image.shape[-3] == 3
    if not has_three_channels:
        raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

    red, green, blue = image.unbind(dim=-3)

    # Each row of the RGB -> XYZ matrix sums to the XYZ value of RGB (1, 1, 1),
    # i.e. (0.950456, 1.0, 1.088754).
    x_chan = 0.412453 * red + 0.357580 * green + 0.180423 * blue
    y_chan = 0.212671 * red + 0.715160 * green + 0.072169 * blue
    z_chan = 0.019334 * red + 0.119193 * green + 0.950227 * blue

    return torch.stack([x_chan, y_chan, z_chan], dim=-3)
def xyz_to_rgb(image: torch.Tensor) -> torch.Tensor:
    r"""Convert an XYZ image to the RGB color space.

    Every output pixel is a fixed linear combination of the three input
    channels. The coefficients invert the matrix used by :func:`rgb_to_xyz`,
    so a round trip reproduces the input up to floating-point error. No
    clamping is applied, so out-of-gamut XYZ values yield RGB values outside
    (0, 1). See :class:`~kornia.color.XyzToRgb` for the module form.

    Args:
        image (torch.Tensor): XYZ image of shape :math:`(*, 3, H, W)`.

    Returns:
        torch.Tensor: RGB image with the same shape as ``image``.

    Raises:
        TypeError: if ``image`` is not a tensor.
        ValueError: if ``image`` has fewer than three dimensions or its
            channel dimension (``-3``) is not of size 3.
    """
    if not torch.is_tensor(image):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")

    # Short-circuit keeps ``shape[-3]`` from being read on tensors with < 3 dims.
    has_three_channels = image.dim() >= 3 and image.shape[-3] == 3
    if not has_three_channels:
        raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

    x_chan, y_chan, z_chan = image.unbind(dim=-3)

    red = 3.2404813432005266 * x_chan + -1.5371515162713185 * y_chan + -0.4985363261688878 * z_chan
    green = -0.9692549499965682 * x_chan + 1.8759900014898907 * y_chan + 0.0415559265582928 * z_chan
    blue = 0.0556466391351772 * x_chan + -0.2040413383665112 * y_chan + 1.0573110696453443 * z_chan

    return torch.stack([red, green, blue], dim=-3)