# Source code for kornia.utils.grid

from typing import Optional

import torch


def create_meshgrid(
        height: int,
        width: int,
        normalized_coordinates: Optional[bool] = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None) -> torch.Tensor:
    """Generates a coordinate grid for an image.

    When the flag ``normalized_coordinates`` is set to True, the grid is
    normalized to be in the range [-1, 1] to be consistent with the PyTorch
    function ``grid_sample``.
    http://pytorch.org/docs/master/nn.html#torch.nn.functional.grid_sample

    Args:
        height (int): the image height (rows).
        width (int): the image width (cols).
        normalized_coordinates (Optional[bool]): whether to normalize
          coordinates in the range [-1, 1] in order to be consistent with the
          PyTorch function grid_sample. When False, coordinates are pixel
          indices in ``[0, width - 1]`` (x) and ``[0, height - 1]`` (y).
        device (Optional[torch.device]): device of the returned grid.
          Default: ``None`` (PyTorch's current default device).
        dtype (Optional[torch.dtype]): dtype of the returned grid.
          Default: ``None`` (PyTorch's current default dtype).

    Return:
        torch.Tensor: a grid tensor with shape :math:`(1, H, W, 2)`. The last
        dimension is ordered ``(x, y)``: ``grid[0, h, w] == (x_w, y_h)``.
    """
    if normalized_coordinates:
        xs = torch.linspace(-1, 1, width, device=device, dtype=dtype)
        ys = torch.linspace(-1, 1, height, device=device, dtype=dtype)
    else:
        xs = torch.linspace(0, width - 1, width, device=device, dtype=dtype)
        ys = torch.linspace(0, height - 1, height, device=device, dtype=dtype)
    # Broadcast the 1-D axes rather than calling torch.meshgrid. Without an
    # explicit ``indexing=`` argument, meshgrid warns on recent PyTorch, and
    # that argument does not exist on older releases.
    grid_x = xs.view(1, width).expand(height, width)  # HxW, x varies along W
    grid_y = ys.view(height, 1).expand(height, width)  # HxW, y varies along H
    return torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0)  # 1xHxWx2
def create_meshgrid3d(
        depth: int,
        height: int,
        width: int,
        normalized_coordinates: Optional[bool] = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None) -> torch.Tensor:
    """Generates a coordinate grid for a volume.

    When the flag ``normalized_coordinates`` is set to True, the grid is
    normalized to be in the range [-1, 1] to be consistent with the PyTorch
    function ``grid_sample``.
    http://pytorch.org/docs/master/nn.html#torch.nn.functional.grid_sample

    Args:
        depth (int): the volume depth (size of the D dimension).
        height (int): the volume height (rows).
        width (int): the volume width (cols).
        normalized_coordinates (Optional[bool]): whether to normalize
          coordinates in the range [-1, 1] in order to be consistent with the
          PyTorch function grid_sample. When False, coordinates are indices
          in ``[0, depth - 1]``, ``[0, width - 1]`` and ``[0, height - 1]``.
        device (Optional[torch.device]): device of the returned grid.
          Default: ``None`` (PyTorch's current default device).
        dtype (Optional[torch.dtype]): dtype of the returned grid.
          Default: ``None`` (PyTorch's current default dtype).

    Return:
        torch.Tensor: a grid tensor with shape :math:`(1, D, H, W, 3)`. The
        last dimension is ordered ``(z, x, y)``:
        ``grid[0, d, h, w] == (z_d, x_w, y_h)``.

    Note:
        For 5-D inputs, ``torch.nn.functional.grid_sample`` expects the last
        grid dimension ordered ``(x, y, z)``, so this grid would need its last
        dimension reordered before being passed there.
    """
    def _axis(steps: int) -> torch.Tensor:
        # 1-D coordinates along one axis, normalized or in index units.
        if normalized_coordinates:
            return torch.linspace(-1, 1, steps, device=device, dtype=dtype)
        return torch.linspace(0, steps - 1, steps, device=device, dtype=dtype)

    zs, ys, xs = _axis(depth), _axis(height), _axis(width)
    shape = (depth, height, width)
    # expand() only creates broadcast views; stack() then allocates the
    # DxHxWx3 result once.
    grid = torch.stack([
        zs.view(depth, 1, 1).expand(shape),   # z varies along D
        xs.view(1, 1, width).expand(shape),   # x varies along W
        ys.view(1, height, 1).expand(shape),  # y varies along H
    ], dim=-1)  # DxHxWx3, ordered (z, x, y)
    return grid.unsqueeze(0)  # 1xDxHxWx3