Source code for kornia.io.io

from __future__ import annotations

try:
    import kornia_rs
except ImportError:  # pragma: no cover
    kornia_rs = None

import os
from enum import Enum
from pathlib import Path

import torch
from torch.utils import dlpack  # TODO: remove this if kornia relies on torch>=1.10

from kornia.color import rgb_to_grayscale, rgba_to_rgb
from kornia.color.gray import grayscale_to_rgb
from kornia.color.rgb import rgb_to_rgba
from kornia.core import Device, Tensor
from kornia.core.check import KORNIA_CHECK


[docs]class ImageLoadType(Enum): r"""Enum to specify the desired image type.""" UNCHANGED = 0 GRAY8 = 1 RGB8 = 2 RGBA8 = 3 GRAY32 = 4 RGB32 = 5
def load_image_to_tensor(path_file: str, device: Device) -> Tensor: # load the file and decodes using kornia_rs. Internally it uses a package that # combines image-rs a self maintained version of the dlpack-rs. After the decoding, # the obtained stream bits are encapusalted to a cv::Tensor data structure without # memory ownership and passed as PyCapsule from rust to python. cv_tensor = kornia_rs.read_image_rs(path_file) # for convenience use the torch dlpack parser to get a zero copy torch.Tensor # TODO: evaluate other potential API so that we can return in numpy, jax, mxnet since # the kornia_rs cv::Tensor has this ability. th_tensor = dlpack.from_dlpack(cv_tensor) # HxWx3 # move the tensor to the desired device, move the data layout to CHW and clone # to return an owned data tensor. dev = device if isinstance(device, torch.device) or device is None else torch.device(device) return th_tensor.to(device=dev).permute(2, 0, 1).clone() # CxHxW def to_float32(image: Tensor) -> Tensor: KORNIA_CHECK(image.dtype == torch.uint8) return image.float() / 255.0 def to_uint8(image: Tensor) -> Tensor: KORNIA_CHECK(image.dtype == torch.float32) return image.mul(255.0).byte()
[docs]def load_image(path_file: str | Path, desired_type: ImageLoadType, device: Device = "cpu") -> Tensor: """Read an image file and decode using the Kornia Rust backend. Args: path_file: Path to a valid image file. desired_type: the desired image type, defined by color space and dtype. device: the device where you want to get your image placed. Return: Image tensor with shape :math:`(3,H,W)`. """ if kornia_rs is None: # pragma: no cover raise ModuleNotFoundError("The io API is not available: `pip install kornia_rs` in a Linux system.") if isinstance(path_file, Path): path_file = str(path_file) KORNIA_CHECK(os.path.isfile(path_file), f"Invalid file: {path_file}") image: Tensor = load_image_to_tensor(path_file, device) # CxHxW if desired_type == ImageLoadType.UNCHANGED: return image elif desired_type == ImageLoadType.GRAY8: if image.shape[0] == 1 and image.dtype == torch.uint8: return image elif image.shape[0] == 3 and image.dtype == torch.uint8: gray8 = rgb_to_grayscale(image) return gray8 elif image.shape[0] == 4 and image.dtype == torch.uint8: gray32 = rgb_to_grayscale(rgba_to_rgb(to_float32(image))) return to_uint8(gray32) elif desired_type == ImageLoadType.RGB8: if image.shape[0] == 3 and image.dtype == torch.uint8: return image elif image.shape[0] == 1 and image.dtype == torch.uint8: rgb8 = grayscale_to_rgb(image) return rgb8 elif desired_type == ImageLoadType.RGBA8: if image.shape[0] == 3 and image.dtype == torch.uint8: rgba32 = rgb_to_rgba(to_float32(image), 0.0) return to_uint8(rgba32) elif desired_type == ImageLoadType.GRAY32: if image.shape[0] == 1 and image.dtype == torch.uint8: return to_float32(image) elif image.shape[0] == 3 and image.dtype == torch.uint8: gray32 = rgb_to_grayscale(to_float32(image)) return gray32 elif image.shape[0] == 4 and image.dtype == torch.uint8: gray32 = rgb_to_grayscale(rgba_to_rgb(to_float32(image))) return gray32 elif desired_type == ImageLoadType.RGB32: if image.shape[0] == 3 and image.dtype == torch.uint8: return to_float32(image) elif image.shape[0] == 1 and image.dtype == torch.uint8: rgb32 = grayscale_to_rgb(to_float32(image)) return rgb32 raise NotImplementedError(f"Unknown type: {desired_type}")
[docs]def write_image(path_file: str | Path, image: Tensor) -> None: """Save an image file using the Kornia Rust backend. For now, we only support the writing of JPEG of the following types: RGB8. Args: path_file: Path to a valid image file. image: Image tensor with shape :math:`(3,H,W)`. Return: None. """ if kornia_rs is None: # pragma: no cover raise ModuleNotFoundError("The io API is not available: `pip install kornia_rs` in a Linux system.") if isinstance(path_file, Path): path_file = str(path_file) KORNIA_CHECK("jpg" in path_file[-3:], f"Invalid file extension: {path_file}") KORNIA_CHECK(image.dim() == 3 and image.shape[0] == 3, f"Invalid image shape: {image.shape}") KORNIA_CHECK(image.dtype == torch.uint8, f"Invalid image dtype: {image.dtype}") # create the image encoder image_encoder = kornia_rs.ImageEncoder() image_encoder.set_quality(100) # move the tensor to the cpu and clone to avoid memory ownership issues. image = image.cpu().clone() # 3xHxW # move the data layout to HWC and convert to numpy image_np = image.permute(1, 2, 0).numpy() # HxWx3 # encode the image using the kornia_rs image_encoded: list[int] = image_encoder.encode(image_np.tobytes(), image_np.shape) # save the image using the kornia_rs.write_image_jpeg(path_file, image_encoded)