import torch
from torch import Tensor, nn
def _rgb_to_y(r: Tensor, g: Tensor, b: Tensor) -> Tensor:
y: Tensor = 0.299 * r + 0.587 * g + 0.114 * b
return y
def rgb_to_ycbcr(image: Tensor) -> Tensor:
    r"""Convert an RGB image to YCbCr.

    .. image:: _static/img/rgb_to_ycbcr.png

    The image data is assumed to be in the range of (0, 1). The chroma
    channels are offset by 0.5 so that they are centred in that range.

    Args:
        image: RGB Image to be converted to YCbCr with shape :math:`(*, 3, H, W)`.

    Returns:
        YCbCr version of the image with shape :math:`(*, 3, H, W)`.

    Raises:
        TypeError: if ``image`` is not a :class:`torch.Tensor`.
        ValueError: if ``image`` has fewer than 3 dimensions or its channel
            dimension (``-3``) is not of size 3.

    Examples:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> output = rgb_to_ycbcr(input)  # 2x3x4x5
    """
    if not isinstance(image, Tensor):
        raise TypeError(f"Input type is not a Tensor. Got {type(image)}")

    if len(image.shape) < 3 or image.shape[-3] != 3:
        raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

    r: Tensor = image[..., 0, :, :]
    g: Tensor = image[..., 1, :, :]
    b: Tensor = image[..., 2, :, :]

    # Offset that shifts the signed colour differences into the (0, 1) range.
    delta: float = 0.5
    y: Tensor = _rgb_to_y(r, g, b)
    # Scaled colour-difference signals. ycbcr_to_rgb applies approximately the
    # reciprocal factors (1.773 ~ 1/0.564 and 1.403 ~ 1/0.713) to invert them.
    cb: Tensor = (b - y) * 0.564 + delta
    cr: Tensor = (r - y) * 0.713 + delta
    return torch.stack([y, cb, cr], -3)
def rgb_to_y(image: Tensor) -> Tensor:
    r"""Extract the luma (Y) channel of an RGB image.

    Args:
        image: RGB Image to be converted to Y with shape :math:`(*, 3, H, W)`.

    Returns:
        Y version of the image with shape :math:`(*, 1, H, W)`. The channel
        axis is kept with size 1 rather than squeezed away.

    Raises:
        TypeError: if ``image`` is not a :class:`torch.Tensor`.
        ValueError: if ``image`` has fewer than 3 dimensions or its channel
            dimension (``-3``) is not of size 3.

    Examples:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> output = rgb_to_y(input)  # 2x1x4x5
    """
    if not isinstance(image, Tensor):
        raise TypeError(f"Input type is not a Tensor. Got {type(image)}")
    if len(image.shape) < 3 or image.shape[-3] != 3:
        raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

    # Splitting the 3-channel axis into size-1 chunks keeps that axis, so the
    # weighted sum below comes out as (*, 1, H, W).
    red, green, blue = image.split(1, dim=-3)
    return _rgb_to_y(red, green, blue)
def ycbcr_to_rgb(image: Tensor) -> Tensor:
    r"""Convert a YCbCr image to RGB.

    The image data is assumed to be in the range of (0, 1). The result is
    not clamped, so some inputs can map outside that range.

    Args:
        image: YCbCr Image to be converted to RGB with shape :math:`(*, 3, H, W)`.

    Returns:
        RGB version of the image with shape :math:`(*, 3, H, W)`.

    Raises:
        TypeError: if ``image`` is not a :class:`torch.Tensor`.
        ValueError: if ``image`` has fewer than 3 dimensions or its channel
            dimension (``-3``) is not of size 3.

    Examples:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> output = ycbcr_to_rgb(input)  # 2x3x4x5
    """
    if not isinstance(image, Tensor):
        raise TypeError(f"Input type is not a Tensor. Got {type(image)}")

    if len(image.shape) < 3 or image.shape[-3] != 3:
        raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

    y: Tensor = image[..., 0, :, :]
    cb: Tensor = image[..., 1, :, :]
    cr: Tensor = image[..., 2, :, :]

    # Remove the 0.5 chroma offset that rgb_to_ycbcr adds, recovering signed
    # colour-difference values.
    delta: float = 0.5
    cb_shifted: Tensor = cb - delta
    cr_shifted: Tensor = cr - delta

    r: Tensor = y + 1.403 * cr_shifted
    g: Tensor = y - 0.714 * cr_shifted - 0.344 * cb_shifted
    b: Tensor = y + 1.773 * cb_shifted
    return torch.stack([r, g, b], -3)
class RgbToYcbcr(nn.Module):
    r"""Convert an image from RGB to YCbCr.

    Module wrapper around :func:`rgb_to_ycbcr`. It defines no parameters or
    buffers of its own.

    The image data is assumed to be in the range of (0, 1).

    Returns:
        YCbCr version of the image.

    Shape:
        - image: :math:`(*, 3, H, W)`
        - output: :math:`(*, 3, H, W)`

    Examples:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> ycbcr = RgbToYcbcr()
        >>> output = ycbcr(input)  # 2x3x4x5
    """

    def forward(self, image: Tensor) -> Tensor:
        return rgb_to_ycbcr(image)
class YcbcrToRgb(nn.Module):
    r"""Convert an image from YCbCr to RGB.

    Module wrapper around :func:`ycbcr_to_rgb`. It defines no parameters or
    buffers of its own.

    The image data is assumed to be in the range of (0, 1).

    Returns:
        RGB version of the image.

    Shape:
        - image: :math:`(*, 3, H, W)`
        - output: :math:`(*, 3, H, W)`

    Examples:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> rgb = YcbcrToRgb()
        >>> output = rgb(input)  # 2x3x4x5
    """

    def forward(self, image: Tensor) -> Tensor:
        return ycbcr_to_rgb(image)