# Source code for torchgeometry.losses.one_hot

from typing import Optional

import torch

def one_hot(labels: torch.Tensor,
            num_classes: int,
            device: Optional[torch.device] = None,
            dtype: Optional[torch.dtype] = None,
            eps: Optional[float] = 1e-6) -> torch.Tensor:
    r"""Converts an integer label 2D tensor to a one-hot 3D tensor.

    Args:
        labels (torch.Tensor): tensor with labels of shape :math:`(N, H, W)`,
            where N is batch size. Each value is an integer representing
            the correct classification.
        num_classes (int): number of classes in labels.
        device (Optional[torch.device]): the desired device of returned
            tensor. Default: if None, uses the current device for the
            default tensor type (see torch.set_default_tensor_type()).
            device will be the CPU for CPU tensor types and the current
            CUDA device for CUDA tensor types.
        dtype (Optional[torch.dtype]): the desired data type of returned
            tensor. Default: if None, infers data type from values.
        eps (Optional[float]): small constant added to every element of the
            result so that no entry is exactly zero (keeps downstream log
            operations numerically stable). Default: 1e-6.

    Returns:
        torch.Tensor: the labels in one-hot form with shape
        :math:`(N, C, H, W)`. Note the output is not exactly binary:
        ``eps`` is added to every element.

    Raises:
        TypeError: if ``labels`` is not a torch.Tensor.
        ValueError: if ``labels`` is not a 3D int64 tensor, or if
            ``num_classes`` is smaller than one.

    Examples::
        >>> labels = torch.LongTensor([[[0, 1], [2, 0]]])
        >>> tgm.losses.one_hot(labels, num_classes=3)  # values offset by eps
        tensor([[[[1., 0.],
                  [0., 1.]],
                 [[0., 1.],
                  [0., 0.]],
                 [[0., 0.],
                  [1., 0.]]]])
    """
    if not torch.is_tensor(labels):
        raise TypeError("Input labels type is not a torch.Tensor. Got {}"
                        .format(type(labels)))
    if not len(labels.shape) == 3:
        raise ValueError("Invalid depth shape, we expect BxHxW. Got: {}"
                         .format(labels.shape))
    if not labels.dtype == torch.int64:
        raise ValueError(
            "labels must be of the same dtype torch.int64. Got: {}".format(
                labels.dtype))
    if num_classes < 1:
        # Fixed: the message previously said "bigger than one", which
        # contradicted this check (num_classes == 1 is accepted).
        raise ValueError("The number of classes must be greater than zero."
                         " Got: {}".format(num_classes))
    batch_size, height, width = labels.shape
    one_hot = torch.zeros(batch_size, num_classes, height, width,
                          device=device, dtype=dtype)
    # Scatter 1.0 along the class dimension at each label index, then add
    # eps so no element of the result is exactly zero.
    return one_hot.scatter_(1, labels.unsqueeze(1), 1.0) + eps