# Source code for kornia.color.gray

from __future__ import annotations

import torch

from kornia.color.rgb import bgr_to_rgb
from kornia.core import Module, Tensor, concatenate
from kornia.core.check import KORNIA_CHECK_IS_TENSOR


def grayscale_to_rgb(image: Tensor) -> Tensor:
    r"""Replicate a single-channel grayscale image into a three-channel RGB image.

    .. image:: _static/img/grayscale_to_rgb.png

    The image data is assumed to be in the range of (0, 1).

    Args:
        image: grayscale image tensor with shape :math:`(*,1,H,W)`.

    Returns:
        RGB image with shape :math:`(*,3,H,W)` whose three channels all repeat the input channel.

    Raises:
        ValueError: if ``image`` has fewer than three dimensions or its third-from-last
            (channel) dimension is not 1.

    Example:
        >>> input = torch.randn(2, 1, 4, 5)
        >>> gray = grayscale_to_rgb(input) # 2x3x4x5
    """
    KORNIA_CHECK_IS_TENSOR(image)

    # Guard: the channel axis (dim -3) must exist and hold exactly one plane.
    has_single_channel = image.dim() >= 3 and image.shape[-3] == 1
    if not has_single_channel:
        raise ValueError(f"Input size must have a shape of (*, 1, H, W). Got {image.shape}.")

    channels = [image] * 3
    return concatenate(channels, -3)
def rgb_to_grayscale(image: Tensor, rgb_weights: Tensor | None = None) -> Tensor:
    r"""Convert a RGB image to grayscale version of image.

    .. image:: _static/img/rgb_to_grayscale.png

    Floating point image data is assumed to be in the range of (0, 1); ``torch.uint8`` image data is
    assumed to be in the range of (0, 255).

    Args:
        image: RGB image to be converted to grayscale with shape :math:`(*,3,H,W)`.
        rgb_weights: Weights that will be applied on each channel (RGB).
            The sum of the weights should add up to one. When ``None``, ``(0.299, 0.587, 0.114)`` is used.

    Returns:
        grayscale version of the image with shape :math:`(*,1,H,W)`. For ``torch.uint8`` input the
        result is rounded and clamped to (0, 255) and returned as ``torch.uint8``.

    Raises:
        ValueError: if ``image`` does not have shape :math:`(*,3,H,W)`.
        TypeError: if ``rgb_weights`` is ``None`` and ``image`` is neither ``torch.uint8`` nor
            ``torch.float16``/``torch.float32``/``torch.float64``.

    .. note::
       See a working example `here <https://kornia.github.io/tutorials/nbs/color_conversions.html>`__.

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> gray = rgb_to_grayscale(input) # 2x1x4x5
    """
    KORNIA_CHECK_IS_TENSOR(image)

    if len(image.shape) < 3 or image.shape[-3] != 3:
        raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

    if image.dtype == torch.uint8:
        # 8 bit images: multiplying and summing directly in uint8 wraps around (e.g. 150 * 255 > 255),
        # and casting fractional weights to uint8 truncates them to zero. Work in float32 instead.
        return _rgb_to_grayscale_uint8(image, rgb_weights)

    if rgb_weights is None:
        # floating point images
        if image.dtype in (torch.float16, torch.float32, torch.float64):
            rgb_weights = torch.tensor([0.299, 0.587, 0.114], device=image.device, dtype=image.dtype)
        else:
            raise TypeError(f"Unknown data type: {image.dtype}")
    else:
        # is tensor that we make sure is in the same device/dtype
        rgb_weights = rgb_weights.to(image)

    return _weighted_channel_sum(image, rgb_weights)


def _weighted_channel_sum(image: Tensor, rgb_weights: Tensor) -> Tensor:
    """Return ``w_r * R + w_g * G + w_b * B`` for a :math:`(*,3,H,W)` image, keeping a size-1 channel dim."""
    # unpack the color image channels with RGB order
    r: Tensor = image[..., 0:1, :, :]
    g: Tensor = image[..., 1:2, :, :]
    b: Tensor = image[..., 2:3, :, :]

    w_r, w_g, w_b = rgb_weights.unbind()
    return w_r * r + w_g * g + w_b * b


def _rgb_to_grayscale_uint8(image: Tensor, rgb_weights: Tensor | None) -> Tensor:
    """Convert a ``torch.uint8`` RGB image to grayscale, computing in float32 and rounding back to uint8."""
    if rgb_weights is None:
        weights = torch.tensor([0.299, 0.587, 0.114], device=image.device, dtype=torch.float32)
    else:
        weights = rgb_weights.to(device=image.device, dtype=torch.float32)

    gray = _weighted_channel_sum(image.to(torch.float32), weights)
    # Clamp guards against weights that sum to more than one (or negative weights) before the cast.
    return gray.round().clamp(0, 255).to(torch.uint8)
def bgr_to_grayscale(image: Tensor) -> Tensor:
    r"""Convert a BGR image to grayscale.

    The image data is assumed to be in the range of (0, 1). The channels are first reordered
    to RGB with :func:`bgr_to_rgb`, then :func:`rgb_to_grayscale` is applied with its default weights.

    Args:
        image: BGR image to be converted to grayscale with shape :math:`(*,3,H,W)`.

    Returns:
        grayscale version of the image with shape :math:`(*,1,H,W)`.

    Raises:
        ValueError: if ``image`` does not have shape :math:`(*,3,H,W)`.

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> gray = bgr_to_grayscale(input) # 2x1x4x5
    """
    KORNIA_CHECK_IS_TENSOR(image)

    if not (len(image.shape) >= 3 and image.shape[-3] == 3):
        raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

    return rgb_to_grayscale(bgr_to_rgb(image))
class GrayscaleToRgb(Module):
    r"""Module wrapper around :func:`grayscale_to_rgb`.

    The image data is assumed to be in the range of (0, 1).

    Shape:
        - image: :math:`(*, 1, H, W)`
        - output: :math:`(*, 3, H, W)`

    reference:
        https://docs.opencv.org/4.0.1/de/d25/imgproc_color_conversions.html

    Example:
        >>> input = torch.rand(2, 1, 4, 5)
        >>> rgb = GrayscaleToRgb()
        >>> output = rgb(input) # 2x3x4x5
    """

    def forward(self, image: Tensor) -> Tensor:
        rgb_image = grayscale_to_rgb(image)
        return rgb_image
class RgbToGrayscale(Module):
    r"""Module wrapper around :func:`rgb_to_grayscale`.

    The image data is assumed to be in the range of (0, 1).

    Args:
        rgb_weights: per-channel (R, G, B) weights forwarded to :func:`rgb_to_grayscale`. When ``None``,
            a floating point tensor ``[0.299, 0.587, 0.114]`` is stored and used instead.

    Shape:
        - image: :math:`(*, 3, H, W)`
        - output: :math:`(*, 1, H, W)`

    reference:
        https://docs.opencv.org/4.0.1/de/d25/imgproc_color_conversions.html

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> gray = RgbToGrayscale()
        >>> output = gray(input) # 2x1x4x5
    """

    def __init__(self, rgb_weights: Tensor | None = None) -> None:
        super().__init__()
        # Stored as a plain attribute (not a buffer); rgb_to_grayscale moves it to the input's device.
        self.rgb_weights = Tensor([0.299, 0.587, 0.114]) if rgb_weights is None else rgb_weights

    def forward(self, image: Tensor) -> Tensor:
        weights = self.rgb_weights
        return rgb_to_grayscale(image, rgb_weights=weights)
class BgrToGrayscale(Module):
    r"""Module wrapper around :func:`bgr_to_grayscale`.

    The image data is assumed to be in the range of (0, 1). The input channels are first
    reordered from BGR to RGB, then converted to grayscale.

    Shape:
        - image: :math:`(*, 3, H, W)`
        - output: :math:`(*, 1, H, W)`

    reference:
        https://docs.opencv.org/4.0.1/de/d25/imgproc_color_conversions.html

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> gray = BgrToGrayscale()
        >>> output = gray(input) # 2x1x4x5
    """

    def forward(self, image: Tensor) -> Tensor:
        gray_image = bgr_to_grayscale(image)
        return gray_image