from typing import Optional

import torch

def create_meshgrid(
    height: int,
    width: int,
    normalized_coordinates: Optional[bool] = True,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """Generate a coordinate grid for an image.

    When ``normalized_coordinates`` is truthy, coordinates span [-1, 1] so the
    grid can be fed to PyTorch's ``grid_sample``. The first and last pixel
    centres map exactly to -1 and 1, which corresponds to ``grid_sample``'s
    ``align_corners=True`` convention. Otherwise, coordinates are pixel
    indices in ``[0, width - 1]`` and ``[0, height - 1]``.

    Args:
        height: the image height (rows).
        width: the image width (cols).
        normalized_coordinates: whether to normalize coordinates to the range
            [-1, 1] for consistency with ``grid_sample``. ``None`` is treated
            as ``False``.
        device: device on which to create the grid. ``None`` uses torch's
            default device.
        dtype: floating dtype of the grid. ``None`` uses torch's default
            dtype.

    Returns:
        A tensor of shape :math:`(1, H, W, 2)`. ``[..., 0]`` holds the x
        (column) coordinate and ``[..., 1]`` holds the y (row) coordinate.
    """
    # 1-D coordinate axes: xs has `width` entries, ys has `height` entries.
    if normalized_coordinates:
        xs = torch.linspace(-1, 1, width, device=device, dtype=dtype)
        ys = torch.linspace(-1, 1, height, device=device, dtype=dtype)
    else:
        xs = torch.linspace(0, width - 1, width, device=device, dtype=dtype)
        ys = torch.linspace(0, height - 1, height, device=device, dtype=dtype)

    # Broadcast the axes to HxW explicitly instead of torch.meshgrid.
    # meshgrid's default indexing warns on torch >= 1.10, and its 'ij'
    # layout would need an extra transpose here.
    grid_x = xs.view(1, width).expand(height, width)  # x varies along columns
    grid_y = ys.view(height, 1).expand(height, width)  # y varies along rows

    base_grid = torch.stack([grid_x, grid_y], dim=-1)  # HxWx2, (x, y) order
    return base_grid.unsqueeze(0)  # 1xHxWx2