# Source code for kornia.color.gray

from typing import Optional

import torch
import torch.nn as nn

from kornia.color.rgb import bgr_to_rgb


def grayscale_to_rgb(image: torch.Tensor) -> torch.Tensor:
    r"""Replicate a single-channel grayscale image into three RGB channels.

    .. image:: _static/img/grayscale_to_rgb.png

    The image data is assumed to be in the range of (0, 1).

    Args:
        image: grayscale image to be converted to RGB with shape :math:`(*,1,H,W)`.

    Returns:
        a newly allocated tensor of shape :math:`(*,3,H,W)` whose three channels
        are all copies of ``image``.

    Raises:
        TypeError: if ``image`` is not a ``torch.Tensor``.
        ValueError: if ``image`` has fewer than three dimensions or its channel
          dimension (``-3``) is not of size 1.

    Example:
        >>> input = torch.rand(2, 1, 4, 5)
        >>> gray = grayscale_to_rgb(input) # 2x3x4x5
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")

    dims = image.shape
    if len(dims) < 3 or dims[-3] != 1:
        raise ValueError(f"Input size must have a shape of (*, 1, H, W). Got {image.shape}.")

    # Stack three copies along the channel axis; the result does not alias ``image``.
    return torch.cat((image,) * 3, dim=-3)
def rgb_to_grayscale(image: torch.Tensor, rgb_weights: Optional[torch.Tensor] = None) -> torch.Tensor:
    r"""Convert a RGB image to grayscale version of image.

    .. image:: _static/img/rgb_to_grayscale.png

    The image data is assumed to be in the range of (0, 1).

    Args:
        image: RGB image to be converted to grayscale with shape :math:`(*,3,H,W)`.
        rgb_weights: Weights that will be applied on each channel (RGB), with shape
          :math:`(3)` or :math:`(*,3)`. A batched :math:`(*,3)` tensor supplies one
          weight triple per leading index and is broadcast against the leading
          dimensions of ``image``. The sum of the weights should add up to one.
          When ``None`` (default), ``[0.299, 0.587, 0.114]`` is used.

    Returns:
        grayscale version of the image with shape :math:`(*,1,H,W)`.

    Raises:
        TypeError: if ``image`` or ``rgb_weights`` is not a ``torch.Tensor``, or if
          ``image`` is not floating point and its dtype differs from ``rgb_weights``.
        ValueError: if ``image`` is not :math:`(*,3,H,W)` or ``rgb_weights`` is not
          :math:`(*,3)`.

    .. note::
       See a working example `here <https://kornia-tutorials.readthedocs.io/en/latest/
       color_conversions.html>`__.

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> gray = rgb_to_grayscale(input) # 2x1x4x5
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")

    if len(image.shape) < 3 or image.shape[-3] != 3:
        raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

    if rgb_weights is None:
        # Build the default per call rather than sharing one definition-time tensor
        # that any caller could mutate in place. It is created as float32, like the
        # previous default, and converted with ``.to(image)`` below, so results match.
        rgb_weights = torch.tensor([0.299, 0.587, 0.114], device=image.device)

    if not isinstance(rgb_weights, torch.Tensor):
        raise TypeError(f"rgb_weights is not a torch.Tensor. Got {type(rgb_weights)}")

    # ``dim() < 1`` guards ``shape[-1]`` against 0-dim tensors (previously an IndexError).
    if rgb_weights.dim() < 1 or rgb_weights.shape[-1] != 3:
        raise ValueError(f"rgb_weights must have a shape of (*, 3). Got {rgb_weights.shape}")

    # Integer images are only combined with weights of the very same dtype.
    if not torch.is_floating_point(image) and (image.dtype != rgb_weights.dtype):
        raise TypeError(
            f"Input image and rgb_weights should be of same dtype. Got {image.dtype} and {rgb_weights.dtype}"
        )

    r: torch.Tensor = image[..., 0:1, :, :]
    g: torch.Tensor = image[..., 1:2, :, :]
    b: torch.Tensor = image[..., 2:3, :, :]

    # (*, 3) -> (*, 3, 1, 1) so the weight channel axis lines up with axis -3 of the
    # image. Slicing keeps a size-1 channel axis, so each factor is (*, 1, 1, 1) and
    # broadcasts against its (*, 1, H, W) plane. For plain (3,) weights each factor
    # is (1, 1, 1), which multiplies exactly like the scalar used before.
    weights = rgb_weights.to(image)[..., None, None]
    w_r = weights[..., 0:1, :, :]
    w_g = weights[..., 1:2, :, :]
    w_b = weights[..., 2:3, :, :]
    return w_r * r + w_g * g + w_b * b
def bgr_to_grayscale(image: torch.Tensor) -> torch.Tensor:
    r"""Convert a BGR image to a single-channel grayscale image.

    The image data is assumed to be in the range of (0, 1). The channels are
    first reordered to RGB with ``bgr_to_rgb`` and the result is passed to
    ``rgb_to_grayscale`` with its default weights.

    Args:
        image: BGR image to be converted to grayscale with shape :math:`(*,3,H,W)`.

    Returns:
        grayscale version of the image with shape :math:`(*,1,H,W)`.

    Raises:
        TypeError: if ``image`` is not a ``torch.Tensor``.
        ValueError: if ``image`` does not have shape :math:`(*,3,H,W)`.

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> gray = bgr_to_grayscale(input) # 2x1x4x5
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")

    has_three_channels = len(image.shape) >= 3 and image.shape[-3] == 3
    if not has_three_channels:
        raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

    return rgb_to_grayscale(bgr_to_rgb(image))
class GrayscaleToRgb(nn.Module):
    r"""Module that turns a grayscale image into a 3-channel RGB image.

    Every forward call delegates to :func:`grayscale_to_rgb`, so the same
    input checks apply. The image data is assumed to be in the range of (0, 1).

    Shape:
        - image: :math:`(*, 1, H, W)`
        - output: :math:`(*, 3, H, W)`

    reference:
        https://docs.opencv.org/4.0.1/de/d25/imgproc_color_conversions.html

    Example:
        >>> input = torch.rand(2, 1, 4, 5)
        >>> rgb = GrayscaleToRgb()
        >>> output = rgb(input)  # 2x3x4x5
    """

    def forward(self, image: torch.Tensor) -> torch.Tensor:  # type: ignore
        converted: torch.Tensor = grayscale_to_rgb(image)
        return converted
class RgbToGrayscale(nn.Module):
    r"""Module to convert a RGB image to grayscale version of image.

    The image data is assumed to be in the range of (0, 1).

    Args:
        rgb_weights: Weights that will be applied on each channel (RGB). When
          ``None`` (default), a new tensor ``[0.299, 0.587, 0.114]`` is created for
          this instance. The tensor is stored as a plain attribute and passed to
          ``rgb_to_grayscale`` on every forward call.

    Shape:
        - image: :math:`(*, 3, H, W)`
        - output: :math:`(*, 1, H, W)`

    reference:
        https://docs.opencv.org/4.0.1/de/d25/imgproc_color_conversions.html

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> gray = RgbToGrayscale()
        >>> output = gray(input)  # 2x1x4x5
    """

    def __init__(self, rgb_weights: Optional[torch.Tensor] = None) -> None:
        super().__init__()
        if rgb_weights is None:
            # Fresh tensor per instance: a tensor default in the signature is built once
            # at class-definition time and shared (mutably) by every instance.
            rgb_weights = torch.tensor([0.299, 0.587, 0.114])
        self.rgb_weights = rgb_weights

    def forward(self, image: torch.Tensor) -> torch.Tensor:  # type: ignore
        return rgb_to_grayscale(image, rgb_weights=self.rgb_weights)
class BgrToGrayscale(nn.Module):
    r"""Module that turns a BGR image into a single-channel grayscale image.

    Every forward call delegates to :func:`bgr_to_grayscale`, which flips the
    channels to RGB before converting. The image data is assumed to be in the
    range of (0, 1).

    Shape:
        - image: :math:`(*, 3, H, W)`
        - output: :math:`(*, 1, H, W)`

    reference:
        https://docs.opencv.org/4.0.1/de/d25/imgproc_color_conversions.html

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> gray = BgrToGrayscale()
        >>> output = gray(input)  # 2x1x4x5
    """

    def forward(self, image: torch.Tensor) -> torch.Tensor:  # type: ignore
        gray: torch.Tensor = bgr_to_grayscale(image)
        return gray