import math
import torch
from torch import nn
def rgb_to_hsv(image: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    r"""Convert an image from RGB to HSV.

    .. image:: _static/img/rgb_to_hsv.png

    The image data is assumed to be in the range of (0, 1).

    Args:
        image: RGB Image to be converted to HSV with shape of :math:`(*, 3, H, W)`.
        eps: scalar added to the per-pixel maximum before dividing, to enforce numerical stability.

    Returns:
        HSV version of the image with shape of :math:`(*, 3, H, W)`.
        The H channel values are in the range 0..2pi. S and V are in the range 0..1.

    Raises:
        TypeError: if ``image`` is not a ``torch.Tensor``.
        ValueError: if ``image`` does not have a shape of :math:`(*, 3, H, W)`.

    .. note::
       See a working example `here <https://kornia.github.io/tutorials/nbs/color_conversions.html>`__.

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> output = rgb_to_hsv(input)  # 2x3x4x5
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")

    if len(image.shape) < 3 or image.shape[-3] != 3:
        raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

    # Per-pixel extrema over the channel axis. The argmax tells us which
    # hue-sector formula applies (R, G or B is the dominant channel).
    max_rgb, argmax_rgb = image.max(-3)
    min_rgb = image.min(-3).values  # only the values are needed

    deltac = max_rgb - min_rgb  # chroma

    v = max_rgb
    s = deltac / (max_rgb + eps)

    # Achromatic pixels (chroma == 0) would divide by zero below, so use 1 instead.
    # For those pixels S == 0 and the hue ends up as 0, 2pi/3 or 4pi/3 depending on argmax.
    deltac = torch.where(deltac == 0, torch.ones_like(deltac), deltac)

    # Distance of each channel from the per-pixel maximum.
    rc, gc, bc = torch.unbind(max_rgb.unsqueeze(-3) - image, dim=-3)

    # Candidate hues (in sixths of a turn, after dividing by chroma) for max == R, G, B.
    h1 = bc - gc  # (G - B) / C
    h2 = (rc - bc) + 2.0 * deltac  # (B - R) / C + 2
    h3 = (gc - rc) + 4.0 * deltac  # (R - G) / C + 4

    h = torch.stack((h1, h2, h3), dim=-3) / deltac.unsqueeze(-3)
    # Keep only the candidate that matches the dominant channel of each pixel.
    h = torch.gather(h, dim=-3, index=argmax_rgb.unsqueeze(-3)).squeeze(-3)

    h = (h / 6.0) % 1.0  # wraps negative values (R dominant, G < B) into [0, 1)
    h = 2.0 * math.pi * h  # we return 0/2pi output

    return torch.stack((h, s, v), dim=-3)
def hsv_to_rgb(image: torch.Tensor) -> torch.Tensor:
    r"""Convert an image from HSV to RGB.

    The H channel values are assumed to be in the range 0..2pi. S and V are in the range 0..1.

    Args:
        image: HSV Image to be converted to RGB with shape of :math:`(*, 3, H, W)`.

    Returns:
        RGB version of the image with shape of :math:`(*, 3, H, W)`.

    Raises:
        TypeError: if ``image`` is not a ``torch.Tensor``.
        ValueError: if ``image`` does not have a shape of :math:`(*, 3, H, W)`.

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> output = hsv_to_rgb(input)  # 2x3x4x5
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")

    if len(image.shape) < 3 or image.shape[-3] != 3:
        raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

    h: torch.Tensor = image[..., 0, :, :] / (2 * math.pi)  # hue as a fraction of a full turn
    s: torch.Tensor = image[..., 1, :, :]
    v: torch.Tensor = image[..., 2, :, :]

    # Sector index in {0, ..., 5} and fractional position f inside that sector.
    # The modulo also wraps hues outside [0, 2pi), including negative ones.
    hi: torch.Tensor = torch.floor(h * 6) % 6
    f: torch.Tensor = ((h * 6) % 6) - hi

    # A tensor (rather than a Python float) keeps the arithmetic in the input's dtype/device.
    one: torch.Tensor = torch.tensor(1.0, device=image.device, dtype=image.dtype)
    p: torch.Tensor = v * (one - s)
    q: torch.Tensor = v * (one - f * s)
    t: torch.Tensor = v * (one - (one - f) * s)

    hi = hi.long()
    # The lookup table below is laid out as
    #   [R for sectors 0..5, G for sectors 0..5, B for sectors 0..5],
    # so offsets 0 / 6 / 12 select each pixel's R / G / B value for its sector.
    indices: torch.Tensor = torch.stack([hi, hi + 6, hi + 12], dim=-3)
    out = torch.stack((v, q, p, p, t, v, t, v, v, q, p, p, p, p, t, v, v, q), dim=-3)
    out = torch.gather(out, -3, indices)

    return out
class RgbToHsv(nn.Module):
    r"""Convert an image from RGB to HSV.

    The image data is assumed to be in the range of (0, 1).

    Args:
        eps: scalar to enforce numerical stability.

    Returns:
        HSV version of the image.

    Shape:
        - image: :math:`(*, 3, H, W)`
        - output: :math:`(*, 3, H, W)`

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> hsv = RgbToHsv()
        >>> output = hsv(input)  # 2x3x4x5
    """

    def __init__(self, eps: float = 1e-6) -> None:
        super().__init__()
        # NOTE: this default (1e-6) differs from the functional ``rgb_to_hsv``
        # default (1e-8). It is left unchanged because changing it would alter
        # the saturation output for existing users of the module.
        self.eps = eps

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        return rgb_to_hsv(image, self.eps)
class HsvToRgb(nn.Module):
    r"""Module wrapper that converts an HSV image into RGB.

    Hue is expected in radians within 0..2pi, while saturation and value are
    expected within 0..1. The module holds no state or parameters.

    Returns:
        The image converted to RGB.

    Shape:
        - image: :math:`(*, 3, H, W)`
        - output: :math:`(*, 3, H, W)`

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> rgb = HsvToRgb()
        >>> output = rgb(input)  # 2x3x4x5
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        # Stateless: hand the tensor straight to the functional implementation.
        converted = hsv_to_rgb(image)
        return converted