from __future__ import annotations
import torch
from kornia.color.rgb import bgr_to_rgb
from kornia.core import Module, Tensor, concatenate
from kornia.core.check import KORNIA_CHECK_IS_TENSOR
def grayscale_to_rgb(image: Tensor) -> Tensor:
    r"""Expand a single-channel grayscale image into a three-channel RGB image.

    .. image:: _static/img/grayscale_to_rgb.png

    The gray channel is placed three times along the channel axis, so the R, G and B
    planes of the result all hold the input values. The image data is assumed to be in
    the range of (0, 1).

    Args:
        image: grayscale image tensor to be converted to RGB with shape :math:`(*,1,H,W)`.

    Returns:
        RGB version of the image with shape :math:`(*,3,H,W)`.

    Raises:
        ValueError: if ``image`` has fewer than three dimensions or its channel
            dimension (third from last) does not have size 1.

    Example:
        >>> input = torch.rand(2, 1, 4, 5)
        >>> rgb = grayscale_to_rgb(input) # 2x3x4x5
    """
    KORNIA_CHECK_IS_TENSOR(image)

    # The channel axis is the third from last; it must exist and hold exactly one plane.
    is_single_channel = image.ndim >= 3 and image.shape[-3] == 1
    if not is_single_channel:
        raise ValueError(f"Input size must have a shape of (*, 1, H, W). Got {image.shape}.")

    # Join three references to the same tensor along the channel axis.
    return concatenate([image] * 3, -3)
def rgb_to_grayscale(image: Tensor, rgb_weights: Tensor | None = None) -> Tensor:
    r"""Convert a RGB image to grayscale version of image.

    .. image:: _static/img/rgb_to_grayscale.png

    Floating point image data is assumed to be in the range of (0, 1). ``uint8`` images
    are converted with integer fixed-point weights and stay ``uint8``.

    Args:
        image: RGB image to be converted to grayscale with shape :math:`(*,3,H,W)`.
        rgb_weights: Weights that will be applied on each channel (RGB), given as a tensor
            with three entries along its first dimension (e.g. shape :math:`(3,)`).
            The sum of the weights should add up to one. The weights are cast to the dtype
            and device of ``image``, so fractional weights are truncated when ``image`` has
            an integer dtype. If ``None``, ``(0.299, 0.587, 0.114)`` is used for floating
            point images and the fixed-point equivalent ``(76, 150, 29) / 255`` for
            ``uint8`` images.

    Returns:
        grayscale version of the image with shape :math:`(*,1,H,W)`.

    Raises:
        ValueError: if ``image`` has fewer than three dimensions or its channel
            dimension (third from last) does not have size 3.
        TypeError: if ``rgb_weights`` is ``None`` and ``image.dtype`` is not one of
            ``uint8``, ``float16``, ``float32`` or ``float64``.

    .. note::
       See a working example `here <https://kornia.github.io/tutorials/nbs/color_conversions.html>`__.

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> gray = rgb_to_grayscale(input) # 2x1x4x5
    """
    KORNIA_CHECK_IS_TENSOR(image)

    if len(image.shape) < 3 or image.shape[-3] != 3:
        raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

    # unpack the color image channels with RGB order (slices keep the channel axis)
    r: Tensor = image[..., 0:1, :, :]
    g: Tensor = image[..., 1:2, :, :]
    b: Tensor = image[..., 2:3, :, :]

    if rgb_weights is None:
        # 8 bit images
        if image.dtype == torch.uint8:
            # The fixed-point weights (76, 150, 29) sum to 255. Multiplying them with
            # uint8 pixels directly in uint8 wraps around modulo 256 (a white pixel
            # would come out as 1), so accumulate in int32 and scale back by 255 with
            # round-to-nearest before returning to uint8.
            acc = 76 * r.to(torch.int32) + 150 * g.to(torch.int32) + 29 * b.to(torch.int32)
            return torch.div(acc + 127, 255, rounding_mode="floor").to(torch.uint8)
        # floating point images
        if image.dtype in (torch.float16, torch.float32, torch.float64):
            rgb_weights = torch.tensor([0.299, 0.587, 0.114], device=image.device, dtype=image.dtype)
        else:
            raise TypeError(f"Unknown data type: {image.dtype}")
    else:
        # is tensor that we make sure is in the same device/dtype
        rgb_weights = rgb_weights.to(image)

    w_r, w_g, w_b = rgb_weights.unbind()
    return w_r * r + w_g * g + w_b * b
def bgr_to_grayscale(image: Tensor) -> Tensor:
    r"""Convert a BGR image to grayscale.

    The channels are first reordered to RGB with :func:`bgr_to_rgb` and the result is
    handed to :func:`rgb_to_grayscale` using its default weights. The image data is
    assumed to be in the range of (0, 1).

    Args:
        image: BGR image to be converted to grayscale with shape :math:`(*,3,H,W)`.

    Returns:
        grayscale version of the image with shape :math:`(*,1,H,W)`.

    Raises:
        ValueError: if ``image`` has fewer than three dimensions or its channel
            dimension (third from last) does not have size 3.

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> gray = bgr_to_grayscale(input) # 2x1x4x5
    """
    KORNIA_CHECK_IS_TENSOR(image)

    # The channel axis is the third from last; it must exist and hold three planes.
    has_three_channels = image.ndim >= 3 and image.shape[-3] == 3
    if not has_three_channels:
        raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

    # Flip B,G,R -> R,G,B, then reuse the RGB conversion.
    return rgb_to_grayscale(bgr_to_rgb(image))
class GrayscaleToRgb(Module):
    r"""Module wrapper around :func:`grayscale_to_rgb`.

    Converts a grayscale image to its RGB version. The image data is assumed to be in
    the range of (0, 1).

    Shape:
        - image: :math:`(*, 1, H, W)`
        - output: :math:`(*, 3, H, W)`

    reference:
        https://docs.opencv.org/4.0.1/de/d25/imgproc_color_conversions.html

    Example:
        >>> input = torch.rand(2, 1, 4, 5)
        >>> rgb = GrayscaleToRgb()
        >>> output = rgb(input) # 2x3x4x5
    """

    def forward(self, image: Tensor) -> Tensor:
        """Return the RGB version of ``image`` by delegating to :func:`grayscale_to_rgb`."""
        rgb_image: Tensor = grayscale_to_rgb(image)
        return rgb_image
class RgbToGrayscale(Module):
    r"""Module to convert a RGB image to grayscale version of image.

    The image data is assumed to be in the range of (0, 1).

    Args:
        rgb_weights: optional weights applied on each channel (RGB), forwarded to
            :func:`rgb_to_grayscale`. When omitted, :func:`rgb_to_grayscale` is left to
            choose its own default weights for the dtype of the incoming image.

    Shape:
        - image: :math:`(*, 3, H, W)`
        - output: :math:`(*, 1, H, W)`

    reference:
        https://docs.opencv.org/4.0.1/de/d25/imgproc_color_conversions.html

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> gray = RgbToGrayscale()
        >>> output = gray(input) # 2x1x4x5
    """

    def __init__(self, rgb_weights: Tensor | None = None) -> None:
        super().__init__()
        # ``self.rgb_weights`` always holds a tensor so code that reads the attribute
        # keeps working; the default instance is remembered separately so ``forward``
        # can tell "never customised" apart from "caller supplied weights".
        self._default_rgb_weights: Tensor = Tensor([0.299, 0.587, 0.114])
        self.rgb_weights = self._default_rgb_weights if rgb_weights is None else rgb_weights

    def forward(self, image: Tensor) -> Tensor:
        """Return the grayscale version of ``image`` via :func:`rgb_to_grayscale`."""
        # Passing the float defaults explicitly would make rgb_to_grayscale cast them to
        # the image dtype; for uint8 images they would truncate to (0, 0, 0) and yield
        # an all-zero output. Passing None instead lets the function pick the weights
        # that match the image dtype. Weights assigned by the caller (at construction or
        # later) are always forwarded as given.
        uses_default = self.rgb_weights is self._default_rgb_weights
        weights = None if uses_default else self.rgb_weights
        return rgb_to_grayscale(image, rgb_weights=weights)
class BgrToGrayscale(Module):
    r"""Module wrapper around :func:`bgr_to_grayscale`.

    Converts a BGR image to its grayscale version. The image data is assumed to be in
    the range of (0, 1). First flips to RGB, then converts.

    Shape:
        - image: :math:`(*, 3, H, W)`
        - output: :math:`(*, 1, H, W)`

    reference:
        https://docs.opencv.org/4.0.1/de/d25/imgproc_color_conversions.html

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> gray = BgrToGrayscale()
        >>> output = gray(input) # 2x1x4x5
    """

    def forward(self, image: Tensor) -> Tensor:
        """Return the grayscale version of ``image`` by delegating to :func:`bgr_to_grayscale`."""
        gray_image: Tensor = bgr_to_grayscale(image)
        return gray_image