Source code for kornia.utils.grid

from typing import Optional

import torch
from torch import Tensor, stack

from kornia.utils._compat import torch_meshgrid


def create_meshgrid(
    height: int,
    width: int,
    normalized_coordinates: bool = True,
    device: Optional[torch.device] = torch.device('cpu'),
    dtype: torch.dtype = torch.float32,
) -> Tensor:
    """Generate a coordinate grid for an image.

    When the flag ``normalized_coordinates`` is set to True, the grid is
    normalized to be in the range :math:`[-1,1]` to be consistent with the pytorch
    function :py:func:`torch.nn.functional.grid_sample`.

    Args:
        height: the image height (rows).
        width: the image width (cols).
        normalized_coordinates: whether to normalize
          coordinates in the range :math:`[-1,1]` in order to be consistent with the
          PyTorch function :py:func:`torch.nn.functional.grid_sample`. An axis of
          size 1 has its single coordinate mapped to :math:`-1` (instead of NaN).
        device: the device on which the grid will be generated.
        dtype: the data type of the generated grid.

    Return:
        grid tensor with shape :math:`(1, H, W, 2)`, where ``grid[0, h, w]`` holds
        the ``(x, y)`` coordinates of the pixel at row ``h`` and column ``w``.

    Example:
        >>> create_meshgrid(2, 2)
        tensor([[[[-1., -1.],
                  [ 1., -1.]],
        <BLANKLINE>
                 [[-1.,  1.],
                  [ 1.,  1.]]]])

        >>> create_meshgrid(2, 2, normalized_coordinates=False)
        tensor([[[[0., 0.],
                  [1., 0.]],
        <BLANKLINE>
                 [[0., 1.],
                  [1., 1.]]]])
    """
    xs: Tensor = torch.linspace(0, width - 1, width, device=device, dtype=dtype)
    ys: Tensor = torch.linspace(0, height - 1, height, device=device, dtype=dtype)
    if normalized_coordinates:
        # Normalize the 1-D coordinate vectors here instead of normalizing the full
        # grid with ``normalize_pixel_coordinates``: that helper generates new
        # width/height tensors and therefore still raises a TracerWarning.
        # ``max(n - 1, 1)`` guards a size-1 axis, where ``n - 1 == 0`` would
        # otherwise yield 0 / 0 = NaN; its single coordinate becomes -1.
        xs = (xs / max(width - 1, 1) - 0.5) * 2
        ys = (ys / max(height - 1, 1) - 0.5) * 2
    # generate grid by stacking coordinates
    base_grid: Tensor = stack(torch_meshgrid([xs, ys], indexing="ij"), dim=-1)  # WxHx2
    return base_grid.permute(1, 0, 2).unsqueeze(0)  # 1xHxWx2
def create_meshgrid3d(
    depth: int,
    height: int,
    width: int,
    normalized_coordinates: bool = True,
    device: Optional[torch.device] = torch.device('cpu'),
    dtype: torch.dtype = torch.float32,
) -> Tensor:
    """Generate a coordinate grid for a volume.

    When the flag ``normalized_coordinates`` is set to True, the grid is
    normalized to be in the range :math:`[-1,1]` to be consistent with the pytorch
    function :py:func:`torch.nn.functional.grid_sample`.

    Args:
        depth: the volume depth (size of the D axis).
        height: the volume height (rows).
        width: the volume width (cols).
        normalized_coordinates: whether to normalize
          coordinates in the range :math:`[-1,1]` in order to be consistent with the
          PyTorch function :py:func:`torch.nn.functional.grid_sample`. An axis of
          size 1 has its single coordinate mapped to :math:`-1` (instead of NaN).
        device: the device on which the grid will be generated.
        dtype: the data type of the generated grid.

    Return:
        grid tensor with shape :math:`(1, D, H, W, 3)`, where ``grid[0, d, h, w]``
        holds the coordinates in ``(z, x, y)`` order.
    """
    xs: Tensor = torch.linspace(0, width - 1, width, device=device, dtype=dtype)
    ys: Tensor = torch.linspace(0, height - 1, height, device=device, dtype=dtype)
    zs: Tensor = torch.linspace(0, depth - 1, depth, device=device, dtype=dtype)
    if normalized_coordinates:
        # Normalize the 1-D coordinate vectors directly (avoids a TracerWarning).
        # ``max(n - 1, 1)`` guards a size-1 axis, where ``n - 1 == 0`` would
        # otherwise yield 0 / 0 = NaN; its single coordinate becomes -1.
        xs = (xs / max(width - 1, 1) - 0.5) * 2
        ys = (ys / max(height - 1, 1) - 0.5) * 2
        zs = (zs / max(depth - 1, 1) - 0.5) * 2
    # generate grid by stacking coordinates
    # NOTE(review): stacking meshgrid([zs, xs, ys]) gives a (z, x, y) channel order,
    # not the (x, y, z) order grid_sample uses for 5-D inputs. Kept as-is because
    # existing callers may depend on it.
    base_grid = stack(torch_meshgrid([zs, xs, ys], indexing="ij"), dim=-1)  # DxWxHx3
    return base_grid.permute(0, 2, 1, 3).unsqueeze(0)  # 1xDxHxWx3