Source code for kornia.color.xyz

import torch
import torch.nn as nn


def rgb_to_xyz(image: torch.Tensor) -> torch.Tensor:
    r"""Convert a RGB image to XYZ.

    .. image:: _static/img/rgb_to_xyz.png

    Each output channel is a fixed linear combination of the R, G and B
    channels, which are read from dimension ``-3`` of the input.

    Args:
        image: RGB Image to be converted to XYZ with shape :math:`(*, 3, H, W)`.

    Returns:
        XYZ version of the image with shape :math:`(*, 3, H, W)`.

    Raises:
        TypeError: if ``image`` is not a :class:`torch.Tensor`.
        ValueError: if ``image`` has fewer than three dimensions or its
            dimension ``-3`` does not have size 3.

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> output = rgb_to_xyz(input)  # 2x3x4x5
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")

    if len(image.shape) < 3 or image.shape[-3] != 3:
        raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

    # Split the channel dimension; the ellipsis keeps any leading batch dims.
    r: torch.Tensor = image[..., 0, :, :]
    g: torch.Tensor = image[..., 1, :, :]
    b: torch.Tensor = image[..., 2, :, :]

    # One row of the 3x3 conversion matrix per output channel.
    x: torch.Tensor = 0.412453 * r + 0.357580 * g + 0.180423 * b
    y: torch.Tensor = 0.212671 * r + 0.715160 * g + 0.072169 * b
    z: torch.Tensor = 0.019334 * r + 0.119193 * g + 0.950227 * b

    # Re-insert the channel dimension at the position it was taken from.
    out: torch.Tensor = torch.stack([x, y, z], -3)
    return out
def xyz_to_rgb(image: torch.Tensor) -> torch.Tensor:
    r"""Convert a XYZ image to RGB.

    Each output channel is a fixed linear combination of the X, Y and Z
    channels, which are read from dimension ``-3`` of the input.

    Args:
        image: XYZ Image to be converted to RGB with shape :math:`(*, 3, H, W)`.

    Returns:
        RGB version of the image with shape :math:`(*, 3, H, W)`.

    Raises:
        TypeError: if ``image`` is not a :class:`torch.Tensor`.
        ValueError: if ``image`` has fewer than three dimensions or its
            dimension ``-3`` does not have size 3.

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> output = xyz_to_rgb(input)  # 2x3x4x5
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")

    if len(image.shape) < 3 or image.shape[-3] != 3:
        raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

    # Split the channel dimension; the ellipsis keeps any leading batch dims.
    x: torch.Tensor = image[..., 0, :, :]
    y: torch.Tensor = image[..., 1, :, :]
    z: torch.Tensor = image[..., 2, :, :]

    # One row of the 3x3 conversion matrix per output channel.
    # NOTE(review): these coefficients look like the inverse of the matrix
    # used by rgb_to_xyz -- confirm before changing either set.
    r: torch.Tensor = 3.2404813432005266 * x - 1.5371515162713185 * y - 0.4985363261688878 * z
    g: torch.Tensor = -0.9692549499965682 * x + 1.8759900014898907 * y + 0.0415559265582928 * z
    b: torch.Tensor = 0.0556466391351772 * x - 0.2040413383665112 * y + 1.0573110696453443 * z

    # Re-insert the channel dimension at the position it was taken from.
    out: torch.Tensor = torch.stack([r, g, b], dim=-3)
    return out
class RgbToXyz(nn.Module):
    r"""Convert an image from RGB to XYZ.

    The image data is assumed to be in the range of (0, 1).

    Returns:
        XYZ version of the image.

    Shape:
        - image: :math:`(*, 3, H, W)`
        - output: :math:`(*, 3, H, W)`

    Examples:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> xyz = RgbToXyz()
        >>> output = xyz(input)  # 2x3x4x5

    Reference:
        [1] https://docs.opencv.org/4.0.1/de/d25/imgproc_color_conversions.html
    """

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """Apply :func:`rgb_to_xyz` to ``image`` and return the result."""
        # The module holds no parameters or state; it only wraps the function.
        return rgb_to_xyz(image)
class XyzToRgb(nn.Module):
    r"""Convert an image from XYZ to RGB.

    Returns:
        RGB version of the image.

    Shape:
        - image: :math:`(*, 3, H, W)`
        - output: :math:`(*, 3, H, W)`

    Examples:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> rgb = XyzToRgb()
        >>> output = rgb(input)  # 2x3x4x5

    Reference:
        [1] https://docs.opencv.org/4.0.1/de/d25/imgproc_color_conversions.html
    """

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """Apply :func:`xyz_to_rgb` to ``image`` and return the result."""
        # The module holds no parameters or state; it only wraps the function.
        return xyz_to_rgb(image)