Source code for kornia.color.hsv

import math

import torch
import torch.nn as nn


def rgb_to_hsv(image: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    r"""Convert an image from RGB to HSV.

    .. image:: _static/img/rgb_to_hsv.png

    The image data is assumed to be in the range of (0, 1).

    Args:
        image: RGB Image to be converted to HSV with shape of :math:`(*, 3, H, W)`.
        eps: scalar added to the value channel when computing saturation, to
            enforce numerical stability.

    Returns:
        HSV version of the image with shape of :math:`(*, 3, H, W)`.
        The H channel values are in the range 0..2pi. S and V are in the range 0..1.

    Raises:
        TypeError: if ``image`` is not a :class:`torch.Tensor`.
        ValueError: if ``image`` has fewer than three dimensions or its
            channel dimension (``-3``) is not of size 3.

    .. note::
       See a working example `here <https://kornia-tutorials.readthedocs.io/en/latest/
       color_conversions.html>`__.

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> output = rgb_to_hsv(input)  # 2x3x4x5
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")

    if len(image.shape) < 3 or image.shape[-3] != 3:
        raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

    # Value (brightness) is the per-pixel maximum over the channel axis.
    value = image.max(-3).values
    value_per_channel = value.unsqueeze(-3)

    # Index of the FIRST channel attaining the maximum. The first or last
    # occurrence is not guaranteed by ``max`` before torch 1.6.0
    # (https://github.com/pytorch/pytorch/issues/20414), so isolate the first
    # hit explicitly: a running count of hits equals 1 from the first hit
    # onwards until the second, and the mask removes the non-hit positions.
    is_peak = image == value_per_channel
    first_peak = (is_peak.cumsum(-3) == 1) & is_peak
    peak_channel = first_peak.max(-3).indices

    # Chroma is the spread between the largest and smallest channel.
    chroma = value - image.min(-3).values
    saturation = chroma / (value + eps)

    # Grey pixels have zero chroma; substitute 1 so the hue division below
    # is well defined (their hue numerator is 0 anyway).
    safe_chroma = torch.where(chroma == 0, torch.ones_like(chroma), chroma)

    # Distance of every channel from the maximum, split per channel.
    from_max_r, from_max_g, from_max_b = torch.unbind(value_per_channel - image, dim=-3)

    # One hue candidate per possible dominant channel (R, G, B), expressed in
    # units of chroma with the 0 / 2 / 4 sextant offsets folded in.
    hue_candidates = torch.stack(
        [
            from_max_b - from_max_g,
            2.0 * safe_chroma + from_max_r - from_max_b,
            4.0 * safe_chroma + from_max_g - from_max_r,
        ],
        dim=-3,
    )
    hue = hue_candidates.gather(-3, peak_channel.unsqueeze(-3)).squeeze(-3)

    # Normalise to a fraction of a full turn, wrap negatives, scale to radians.
    hue = hue / safe_chroma
    hue = (hue / 6.0) % 1.0
    hue = 2 * math.pi * hue

    return torch.stack([hue, saturation, value], dim=-3)
def hsv_to_rgb(image: torch.Tensor) -> torch.Tensor:
    r"""Convert an image from HSV to RGB.

    The H channel values are assumed to be in the range 0..2pi. S and V are in the range 0..1.

    Args:
        image: HSV Image to be converted to RGB with shape of :math:`(*, 3, H, W)`.

    Returns:
        RGB version of the image with shape of :math:`(*, 3, H, W)`.

    Raises:
        TypeError: if ``image`` is not a :class:`torch.Tensor`.
        ValueError: if ``image`` has fewer than three dimensions or its
            channel dimension (``-3``) is not of size 3.

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> output = hsv_to_rgb(input)  # 2x3x4x5
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")

    if len(image.shape) < 3 or image.shape[-3] != 3:
        raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

    # Hue is rescaled from radians to a fraction of a full turn.
    h: torch.Tensor = image[..., 0, :, :] / (2 * math.pi)
    s: torch.Tensor = image[..., 1, :, :]
    v: torch.Tensor = image[..., 2, :, :]

    # Split the hue circle into six sectors: ``hi`` is the sector index and
    # ``f`` the fractional position inside that sector.
    h6: torch.Tensor = h * 6
    hi: torch.Tensor = torch.floor(h6) % 6
    f: torch.Tensor = (h6 % 6) - hi

    # A Python scalar is used for the constant 1 so that no CPU tensor has to
    # be allocated and copied to ``image.device`` on every call.
    p: torch.Tensor = v * (1.0 - s)
    q: torch.Tensor = v * (1.0 - f * s)
    t: torch.Tensor = v * (1.0 - (1.0 - f) * s)

    # Lookup table of 18 planes: planes 0-5 hold the red value for sectors
    # 0..5, planes 6-11 the green value, planes 12-17 the blue value.
    # Gathering at ``hi``, ``hi + 6`` and ``hi + 12`` therefore selects the
    # (R, G, B) triple that belongs to each pixel's sector.
    hi = hi.long()
    indices: torch.Tensor = torch.stack([hi, hi + 6, hi + 12], dim=-3)
    out = torch.stack(
        (
            v, q, p, p, t, v,  # red per sector
            t, v, v, q, p, p,  # green per sector
            p, p, t, v, v, q,  # blue per sector
        ),
        dim=-3,
    )
    out = torch.gather(out, -3, indices)

    return out
class RgbToHsv(nn.Module):
    r"""Module wrapper that converts an image from RGB to HSV.

    The conversion itself is delegated to :func:`rgb_to_hsv`; this class only
    stores the stability constant so the operation can be used as a layer.

    The image data is assumed to be in the range of (0, 1).

    Args:
        eps: scalar to enforce numerical stability.

    Returns:
        HSV version of the image.

    Shape:
        - image: :math:`(*, 3, H, W)`
        - output: :math:`(*, 3, H, W)`

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> hsv = RgbToHsv()
        >>> output = hsv(input)  # 2x3x4x5
    """

    def __init__(self, eps: float = 1e-6) -> None:
        super().__init__()
        # Forwarded unchanged to ``rgb_to_hsv`` on every forward pass.
        self.eps = eps

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """Return the HSV conversion of ``image`` using the stored ``eps``."""
        return rgb_to_hsv(image, eps=self.eps)
class HsvToRgb(nn.Module):
    r"""Module wrapper that converts an image from HSV to RGB.

    The conversion itself is delegated to :func:`hsv_to_rgb`; the module holds
    no state of its own.

    H channel values are assumed to be in the range 0..2pi. S and V are in the range 0..1.

    Returns:
        RGB version of the image.

    Shape:
        - image: :math:`(*, 3, H, W)`
        - output: :math:`(*, 3, H, W)`

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> rgb = HsvToRgb()
        >>> output = rgb(input)  # 2x3x4x5
    """

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """Return the RGB conversion of ``image``."""
        converted = hsv_to_rgb(image)
        return converted