Source code for kornia.io.io
try:
import kornia_rs
except ImportError:
kornia_rs = None
import os
from enum import Enum
import torch
from kornia.color import rgb_to_grayscale, rgba_to_rgb
from kornia.color.gray import grayscale_to_rgb
from kornia.color.rgb import rgb_to_rgba
from kornia.core import Tensor
from kornia.testing import KORNIA_CHECK
class ImageLoadType(Enum):
    """Target pixel layout that a caller can request when loading an image.

    Each member names a color space together with a bit depth; the member
    values are plain integers and carry no meaning beyond identity.
    """

    # hand back the decoded image without any conversion
    UNCHANGED = 0
    # 8-bit variants
    GRAY8 = 1
    RGB8 = 2
    RGBA8 = 3
    # 32-bit floating point variants
    GRAY32 = 4
    RGB32 = 5
def load_image_to_tensor(path_file: str, device: str) -> Tensor:
    """Decode an image file through kornia_rs and return it channels-first.

    Args:
        path_file: path of the image file handed to ``kornia_rs.read_image_rs``.
        device: name of the torch device the resulting tensor is moved to.

    Return:
        A cloned tensor whose last decoded axis (channels) has been moved to
        the front, i.e. laid out as CxHxW.
    """
    # kornia_rs reads and decodes the file. Per the upstream notes, it relies on
    # image-rs plus a self-maintained dlpack-rs: the decoded bytes are wrapped in
    # a cv::Tensor that does not own its memory and is handed from Rust to Python
    # as a PyCapsule.
    capsule = kornia_rs.read_image_rs(path_file)

    # torch's dlpack importer is used for convenience to obtain a zero-copy view.
    # TODO: evaluate other potential APIs so the result can also be returned as
    # numpy, jax or mxnet arrays, since the kornia_rs cv::Tensor allows it.
    channels_last = torch.utils.dlpack.from_dlpack(capsule)  # type: ignore  # HxWxC

    # Place the data on the requested device and reorder the axes to CxHxW.
    target_device = torch.device(device)
    channels_first = channels_last.to(target_device).permute(2, 0, 1)

    # Clone so the caller receives a tensor that owns its data.
    return channels_first.clone()
def to_float32(image: Tensor) -> Tensor:
    """Convert a uint8 image tensor to float32, dividing every value by 255.

    The input dtype is validated with ``KORNIA_CHECK`` and must be ``torch.uint8``.
    """
    KORNIA_CHECK(image.dtype == torch.uint8)
    as_float = image.to(torch.float32)
    return as_float / 255.0
def to_uint8(image: Tensor) -> Tensor:
    """Convert a float32 image tensor to uint8 after multiplying it by 255.

    The input dtype is validated with ``KORNIA_CHECK`` and must be
    ``torch.float32``. No explicit rounding or clamping is applied before the
    cast to uint8.
    """
    KORNIA_CHECK(image.dtype == torch.float32)
    scaled = image * 255.0
    return scaled.to(torch.uint8)
def load_image(path_file: str, desired_type: ImageLoadType, device: str = "cpu") -> Tensor:
    """Read an image file and decode using the Kornia Rust backend.

    The file is decoded to a :math:`(C,H,W)` tensor and then converted to the
    requested color space / dtype. Conversions are implemented for decoded
    ``uint8`` images with 1 (gray), 3 (RGB) or 4 (RGBA) channels.

    Args:
        path_file: Path to a valid image file.
        desired_type: the desired image type, defined by color space and dtype.
        device: the device where you want to get your image placed.

    Return:
        Image tensor with shape :math:`(C,H,W)`: one channel for the gray
        types, three for the RGB types, four for ``RGBA8`` and the decoded
        channel count for ``UNCHANGED``. When the decoded image cannot be
        converted to ``desired_type`` (dtype other than ``uint8`` or an
        unsupported channel count) an empty tensor is returned.

    Raises:
        ModuleNotFoundError: if ``kornia_rs`` could not be imported.
        NotImplementedError: if ``desired_type`` is not a known
            :class:`ImageLoadType` member.
    """
    if kornia_rs is None:
        raise ModuleNotFoundError("The io API is not available: `pip install kornia_rs` in a Linux system.")

    KORNIA_CHECK(os.path.isfile(path_file), f"Invalid file: {path_file}")

    image: Tensor = load_image_to_tensor(path_file, device)  # CxHxW

    if desired_type == ImageLoadType.UNCHANGED:
        return image

    channels = image.shape[0]
    # every conversion below is only defined for 8-bit decoded data
    is_uint8 = image.dtype == torch.uint8

    if desired_type == ImageLoadType.GRAY8:
        if is_uint8 and channels == 1:
            return image
        if is_uint8 and channels == 3:
            return rgb_to_grayscale(image)
        if is_uint8 and channels == 4:
            # the alpha removal is done in float, then brought back to 8 bits
            return to_uint8(rgb_to_grayscale(rgba_to_rgb(to_float32(image))))
    elif desired_type == ImageLoadType.RGB8:
        if is_uint8 and channels == 3:
            return image
        if is_uint8 and channels == 1:
            return grayscale_to_rgb(image)
        if is_uint8 and channels == 4:
            # previously unhandled: an RGBA file silently produced an empty tensor
            return to_uint8(rgba_to_rgb(to_float32(image)))
    elif desired_type == ImageLoadType.RGBA8:
        if is_uint8 and channels == 4:
            # previously unhandled: the file already has the requested layout
            return image
        if is_uint8 and channels == 3:
            # alpha is filled with 0.0, the value this function has always used
            return to_uint8(rgb_to_rgba(to_float32(image), 0.0))
        if is_uint8 and channels == 1:
            # previously unhandled: expand gray to RGB, then follow the RGB path
            return to_uint8(rgb_to_rgba(grayscale_to_rgb(to_float32(image)), 0.0))
    elif desired_type == ImageLoadType.GRAY32:
        if is_uint8 and channels == 1:
            return to_float32(image)
        if is_uint8 and channels == 3:
            return rgb_to_grayscale(to_float32(image))
        if is_uint8 and channels == 4:
            return rgb_to_grayscale(rgba_to_rgb(to_float32(image)))
    elif desired_type == ImageLoadType.RGB32:
        if is_uint8 and channels == 3:
            return to_float32(image)
        if is_uint8 and channels == 1:
            return grayscale_to_rgb(to_float32(image))
        if is_uint8 and channels == 4:
            # previously unhandled: an RGBA file silently produced an empty tensor
            return rgba_to_rgb(to_float32(image))
    else:
        raise NotImplementedError(f"Unknown type: {desired_type}")

    # Kept for backward compatibility: inputs that no branch can convert
    # still yield an empty tensor rather than raising.
    return Tensor([])