# Source code for kornia.utils.image

from typing import Union
from PIL import Image

import numpy as np
import torch


def image_to_tensor(image: Union[np.ndarray, Image.Image], keepdim: bool = True) -> torch.Tensor:
    """Converts a numpy or PIL image to a PyTorch tensor image.

    Args:
        image (numpy.ndarray or PIL Image): image of the form :math:`(H, W, C)`,
            :math:`(H, W)` or :math:`(B, H, W, C)`.
        keepdim (bool): If ``False`` unsqueeze the result to add a leading batch
            dimension of size 1. Ignored for :math:`(B, H, W, C)` inputs, which
            are already batched. Default: ``True``

    Returns:
        torch.Tensor: float32 tensor. With ``keepdim=True`` the shape is
        :math:`(1, H, W)` for a :math:`(H, W)` input, :math:`(C, H, W)` for a
        :math:`(H, W, C)` input and :math:`(B, C, H, W)` for a
        :math:`(B, H, W, C)` input. With ``keepdim=False`` the 2d/3d results get
        an extra leading dimension of size 1, i.e. :math:`(1, 1, H, W)` and
        :math:`(1, C, H, W)`.

    Raises:
        TypeError: if ``image`` is neither a numpy array nor a PIL image.
        ValueError: if ``image`` does not have 2, 3 or 4 dimensions.
    """
    if not isinstance(image, (np.ndarray, Image.Image)):
        raise TypeError("Input type must be a numpy.ndarray or PIL Image. Got {}".format(
            type(image)))

    if isinstance(image, Image.Image):
        image = np.array(image)

    if len(image.shape) > 4 or len(image.shape) < 2:
        raise ValueError(
            "Input size must be a two, three or four dimensional array")

    input_shape = image.shape
    # torch.from_numpy rejects arrays with negative strides (e.g. the common
    # ``img[..., ::-1]`` BGR->RGB flip). ascontiguousarray returns the array
    # unchanged when it is already C-contiguous and copies it otherwise.
    tensor: torch.Tensor = torch.from_numpy(np.ascontiguousarray(image)).to(torch.float)

    if len(input_shape) == 2:
        # (H, W) -> (1, H, W)
        tensor = tensor.unsqueeze(0)
    elif len(input_shape) == 3:
        # (H, W, C) -> (C, H, W)
        tensor = tensor.permute(2, 0, 1)
    elif len(input_shape) == 4:
        # (B, H, W, C) -> (B, C, H, W)
        tensor = tensor.permute(0, 3, 1, 2)
        keepdim = True  # already batched: no need to unsqueeze
    else:
        # Defensive: unreachable given the dimensionality check above.
        raise ValueError(
            "Cannot process image with shape {}".format(input_shape))

    return tensor.unsqueeze(0) if not keepdim else tensor
def tensor_to_image(tensor: torch.Tensor) -> np.ndarray:
    """Converts a PyTorch tensor image to a numpy image.

    In case the tensor is in the GPU, it will be copied back to CPU. For a CPU
    tensor the returned array shares memory with ``tensor`` (``Tensor.numpy``
    does not copy).

    Args:
        tensor (torch.Tensor): image of the form :math:`(H, W)`,
            :math:`(C, H, W)` or :math:`(B, C, H, W)`.

    Returns:
        numpy.ndarray: image of the form :math:`(H, W)`, :math:`(H, W, C)` or
        :math:`(B, H, W, C)`. Singleton dimensions are dropped so grayscale images
        come out as :math:`(H, W)` (the layout ``plt.imshow`` expects):

        - a :math:`(1, H, W)` input gives :math:`(H, W)`;
        - for 4d inputs, a batch of one and a single channel are each dropped,
          e.g. :math:`(1, C, H, W) \\to (H, W, C)`,
          :math:`(B, 1, H, W) \\to (B, H, W)` and
          :math:`(1, 1, H, W) \\to (H, W)`.

    Raises:
        TypeError: if ``tensor`` is not a ``torch.Tensor``.
        ValueError: if ``tensor`` does not have 2, 3 or 4 dimensions.
    """
    if not torch.is_tensor(tensor):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(tensor)))

    if len(tensor.shape) > 4 or len(tensor.shape) < 2:
        raise ValueError(
            "Input size must be a two, three or four dimensional tensor")

    input_shape = tensor.shape
    # Detach first so the device copy is not recorded by autograd.
    image: np.ndarray = tensor.detach().cpu().numpy()

    if len(input_shape) == 2:
        # (H, W) -> (H, W): already in numpy image layout.
        pass
    elif len(input_shape) == 3:
        if input_shape[0] == 1:
            # Grayscale (1, H, W) -> (H, W) for proper plt.imshow. Squeeze only
            # the channel axis: a bare squeeze() would also drop H or W == 1.
            image = image.squeeze(0)
        else:
            # (C, H, W) -> (H, W, C)
            image = image.transpose(1, 2, 0)
    elif len(input_shape) == 4:
        # (B, C, H, W) -> (B, H, W, C)
        image = image.transpose(0, 2, 3, 1)
        if input_shape[0] == 1:
            image = image.squeeze(0)  # drop a batch of one
        if input_shape[1] == 1:
            image = image.squeeze(-1)  # drop a single channel (now the last axis)
    else:
        # Defensive: unreachable given the dimensionality check above.
        raise ValueError(
            "Cannot process tensor with shape {}".format(input_shape))

    return image