# Source code for kornia.geometry.quaternion

# kornia.geometry.quaternion module inspired by Eigen, Sophus-sympy, and PyQuaternion.
# https://github.com/strasdat/Sophus/blob/master/sympy/sophus/quaternion.py
# https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py
# https://gitlab.com/libeigen/eigen/-/blob/master/Eigen/src/Geometry/Quaternion.h
from math import pi
from typing import Optional, Tuple, Union

from kornia.core import Device, Dtype, Module, Parameter, Tensor, concatenate, rand, stack, tensor, where
from kornia.core.check import KORNIA_CHECK_TYPE
from kornia.geometry.conversions import (
    QuaternionCoeffOrder,
    angle_axis_to_quaternion,
    normalize_quaternion,
    quaternion_to_rotation_matrix,
    rotation_matrix_to_quaternion,
)
from kornia.geometry.linalg import batched_dot_product


class Quaternion(Module):
    r"""Base class to represent a Quaternion.

    A quaternion is a four dimensional vector representation of a rotation transformation in 3d.
    See more: https://en.wikipedia.org/wiki/Quaternion

    The general definition of a quaternion is given by:

    .. math::

        Q = a + b \cdot \mathbf{i} + c \cdot \mathbf{j} + d \cdot \mathbf{k}

    Thus, we represent a rotation quaternion as a contiguous tensor structure to
    perform rigid body transformations:

    .. math::

        Q = \begin{bmatrix} q_w & q_x & q_y & q_z \end{bmatrix}

    The coefficients are stored in WXYZ order along the last dimension of the
    underlying tensor.

    Example:
        >>> q = Quaternion.identity(batch_size=4)
        >>> q.data
        Parameter containing:
        tensor([[1., 0., 0., 0.],
                [1., 0., 0., 0.],
                [1., 0., 0., 0.],
                [1., 0., 0., 0.]], requires_grad=True)
        >>> q.real
        tensor([1., 1., 1., 1.], grad_fn=<SelectBackward0>)
        >>> q.vec
        tensor([[0., 0., 0.],
                [0., 0., 0.],
                [0., 0., 0.],
                [0., 0., 0.]], grad_fn=<SliceBackward0>)
    """

    def __init__(self, data: Tensor) -> None:
        """Constructor for the base class.

        Args:
            data: tensor containing the quaternion data with the shape of :math:`(B, 4)`.

        Example:
            >>> data = torch.rand(2, 4)
            >>> q = Quaternion(data)
            >>> q.shape
            (2, 4)
        """
        super().__init__()
        # KORNIA_CHECK_SHAPE(data, ["B", "4"])  # FIXME: resolve shape bugs. @edgarriba
        # Stored as a Parameter so the coefficients are registered with the Module.
        self._data = Parameter(data)

    def __repr__(self) -> str:
        return f"{self.data}"

    def __getitem__(self, idx: Union[int, slice]) -> 'Quaternion':
        """Return a new quaternion built from ``self.data[idx]``."""
        return Quaternion(self.data[idx])

    def __neg__(self) -> 'Quaternion':
        """Inverts the sign of the quaternion data.

        Example:
            >>> q = Quaternion.identity()
            >>> -q.data
            tensor([-1., -0., -0., -0.], grad_fn=<NegBackward0>)
        """
        return Quaternion(-self.data)

    def __add__(self, right: 'Quaternion') -> 'Quaternion':
        """Add a given quaternion.

        Args:
            right: the quaternion to add.

        Example:
            >>> q1 = Quaternion.identity()
            >>> q2 = Quaternion(tensor([2., 0., 1., 1.]))
            >>> q3 = q1 + q2
            >>> q3.data
            Parameter containing:
            tensor([3., 0., 1., 1.], requires_grad=True)
        """
        KORNIA_CHECK_TYPE(right, Quaternion)
        return Quaternion(self.data + right.data)

    def __sub__(self, right: 'Quaternion') -> 'Quaternion':
        """Subtract a given quaternion.

        Args:
            right: the quaternion to subtract.

        Example:
            >>> q1 = Quaternion(tensor([2., 0., 1., 1.]))
            >>> q2 = Quaternion.identity()
            >>> q3 = q1 - q2
            >>> q3.data
            Parameter containing:
            tensor([1., 0., 1., 1.], requires_grad=True)
        """
        KORNIA_CHECK_TYPE(right, Quaternion)
        return Quaternion(self.data - right.data)

    def __mul__(self, right: 'Quaternion') -> 'Quaternion':
        """Return the Hamilton product ``self * right``.

        Args:
            right: the quaternion to multiply with on the right.
        """
        KORNIA_CHECK_TYPE(right, Quaternion)
        # NOTE: borrowed from sophus sympy. Produce less multiplications compared to others.
        # https://github.com/strasdat/Sophus/blob/785fef35b7d9e0fc67b4964a69124277b7434a44/sympy/sophus/quaternion.py#L19
        new_real = self.real * right.real - batched_dot_product(self.vec, right.vec)
        # dim=-1 is required: without it Tensor.cross picks the FIRST dimension of
        # size 3, which is the batch axis when the batch size happens to be 3.
        new_vec = (
            self.real[..., None] * right.vec
            + right.real[..., None] * self.vec
            + self.vec.cross(right.vec, dim=-1)
        )
        return Quaternion(concatenate((new_real[..., None], new_vec), -1))

    def __div__(self, right: Union[Tensor, 'Quaternion']) -> 'Quaternion':
        """Divide by a per-quaternion scalar tensor, or multiply by the inverse of a quaternion.

        Args:
            right: either a tensor with one scalar per quaternion (shape :math:`(B,)`),
                which divides all four coefficients, or a quaternion whose inverse is
                multiplied on the right.
        """
        if isinstance(right, Tensor):
            # broadcast the per-quaternion scalar across the 4 coefficients
            return Quaternion(self.data / right[..., None])
        KORNIA_CHECK_TYPE(right, Quaternion)
        return self * right.inv()

    def __truediv__(self, right: Union[Tensor, 'Quaternion']) -> 'Quaternion':
        """Python 3 ``/`` operator; delegates to :meth:`__div__`."""
        return self.__div__(right)

    def __pow__(self, t: Union[int, float]) -> 'Quaternion':
        r"""Return the power of a quaternion raised to exponent t.

        Uses the polar form :math:`q = |q| (\cos\theta + \mathbf{n} \sin\theta)`, so that
        :math:`q^t = |q|^t (\cos(t\theta) + \mathbf{n} \sin(t\theta))`. For unit
        quaternions the magnitude factor is 1.

        Args:
            t: raised exponent.

        Example:
            >>> q = Quaternion(tensor([1., .5, 0., 0.]))
            >>> q_pow = q**2
        """
        theta = self.polar_angle[..., None]
        vec_norm = self.vec.norm(dim=-1, keepdim=True)
        # unit rotation axis; a purely real quaternion has no axis, so use zeros
        n = where(vec_norm != 0, self.vec / vec_norm, self.vec * 0)
        w = (t * theta).cos()
        xyz = (t * theta).sin() * n
        # restore the magnitude |q|^t so that e.g. q**2 == q * q for non-unit q
        magnitude = self.norm(keepdim=True) ** t
        return Quaternion(concatenate((w, xyz), -1) * magnitude)

    @property
    def data(self) -> Tensor:
        """Return the underlying data with shape :math:`(B, 4)`."""
        return self._data

    @property
    def coeffs(self) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Return a tuple with the underlying coefficients in WXYZ order."""
        return self.w, self.x, self.y, self.z

    @property
    def real(self) -> Tensor:
        """Return the real part with shape :math:`(B,)`.

        Alias for :func:`~kornia.geometry.quaternion.Quaternion.w`
        """
        return self.w

    @property
    def vec(self) -> Tensor:
        """Return the vector with the imaginary part with shape :math:`(B, 3)`."""
        return self.data[..., 1:]

    @property
    def q(self) -> Tensor:
        """Return the underlying data with shape :math:`(B, 4)`.

        Alias for :func:`~kornia.geometry.quaternion.Quaternion.data`
        """
        return self.data

    @property
    def scalar(self) -> Tensor:
        """Return a scalar with the real with shape :math:`(B,)`.

        Alias for :func:`~kornia.geometry.quaternion.Quaternion.w`
        """
        return self.real

    @property
    def w(self) -> Tensor:
        """Return the :math:`q_w` with shape :math:`(B,)`."""
        return self.data[..., 0]

    @property
    def x(self) -> Tensor:
        """Return the :math:`q_x` with shape :math:`(B,)`."""
        return self.data[..., 1]

    @property
    def y(self) -> Tensor:
        """Return the :math:`q_y` with shape :math:`(B,)`."""
        return self.data[..., 2]

    @property
    def z(self) -> Tensor:
        """Return the :math:`q_z` with shape :math:`(B,)`."""
        return self.data[..., 3]

    @property
    def shape(self) -> Tuple[int, ...]:
        """Return the shape of the underlying data with shape :math:`(B, 4)`."""
        return tuple(self.data.shape)

    @property
    def polar_angle(self) -> Tensor:
        """Return the polar angle :math:`\\arccos(q_w / |q|)` with shape :math:`(B,)`.

        Example:
            >>> q = Quaternion.identity()
            >>> q.polar_angle
            tensor(0., grad_fn=<AcosBackward0>)
        """
        return (self.scalar / self.norm()).acos()

    def matrix(self) -> Tensor:
        """Convert the quaternion to a rotation matrix of shape :math:`(B, 3, 3)`.

        Example:
            >>> q = Quaternion.identity()
            >>> m = q.matrix()
            >>> m
            tensor([[1., 0., 0.],
                    [0., 1., 0.],
                    [0., 0., 1.]], grad_fn=<SqueezeBackward1>)
        """
        return quaternion_to_rotation_matrix(self.data, order=QuaternionCoeffOrder.WXYZ)

    @classmethod
    def from_matrix(cls, matrix: Tensor) -> 'Quaternion':
        """Create a quaternion from a rotation matrix.

        Args:
            matrix: the rotation matrix to convert of shape :math:`(B, 3, 3)`.

        Example:
            >>> m = torch.eye(3)[None]
            >>> q = Quaternion.from_matrix(m)
            >>> q.data
            Parameter containing:
            tensor([[1., 0., 0., 0.]], requires_grad=True)
        """
        return cls(rotation_matrix_to_quaternion(matrix, order=QuaternionCoeffOrder.WXYZ))

    @classmethod
    def from_axis_angle(cls, axis_angle: Tensor) -> 'Quaternion':
        """Create a quaternion from axis-angle representation.

        Args:
            axis_angle: rotation vector of shape :math:`(B, 3)`.

        Example:
            >>> axis_angle = torch.tensor([[1., 0., 0.]])
            >>> q = Quaternion.from_axis_angle(axis_angle)
            >>> q.data
            Parameter containing:
            tensor([[0.8776, 0.4794, 0.0000, 0.0000]], requires_grad=True)
        """
        return cls(angle_axis_to_quaternion(axis_angle, order=QuaternionCoeffOrder.WXYZ))

    @classmethod
    def identity(
        cls, batch_size: Optional[int] = None, device: Optional[Device] = None, dtype: Optional[Dtype] = None
    ) -> 'Quaternion':
        """Create a quaternion representing an identity rotation.

        Args:
            batch_size: the batch size of the underlying data. When ``None`` the data
                has shape :math:`(4,)`; otherwise :math:`(B, 4)`.
            device: the device of the underlying data.
            dtype: the dtype of the underlying data.

        Example:
            >>> q = Quaternion.identity()
            >>> q.data
            Parameter containing:
            tensor([1., 0., 0., 0.], requires_grad=True)
        """
        data = tensor([1.0, 0.0, 0.0, 0.0], device=device, dtype=dtype)
        if batch_size is not None:
            data = data.repeat(batch_size, 1)
        return cls(data)

    @classmethod
    def from_coeffs(cls, w: float, x: float, y: float, z: float) -> 'Quaternion':
        """Create a quaternion from the data coefficients.

        Args:
            w: a float representing the :math:`q_w` component.
            x: a float representing the :math:`q_x` component.
            y: a float representing the :math:`q_y` component.
            z: a float representing the :math:`q_z` component.

        Example:
            >>> q = Quaternion.from_coeffs(1., 0., 0., 0.)
            >>> q.data
            Parameter containing:
            tensor([1., 0., 0., 0.], requires_grad=True)
        """
        return cls(tensor([w, x, y, z]))

    # TODO: update signature
    # def random(cls, shape: Optional[List] = None, device = None, dtype = None) -> 'Quaternion':
    @classmethod
    def random(
        cls, batch_size: Optional[int] = None, device: Optional[Device] = None, dtype: Optional[Dtype] = None
    ) -> 'Quaternion':
        """Create a random unit quaternion of shape :math:`(B, 4)`.

        Uniformly distributed across the rotation space as per: http://planning.cs.uiuc.edu/node198.html

        Args:
            batch_size: the batch size of the underlying data.
            device: the device of the underlying data.
            dtype: the dtype of the underlying data.

        Example:
            >>> q = Quaternion.random()
            >>> q = Quaternion.random(batch_size=2)
        """
        rand_shape = (batch_size,) if batch_size is not None else ()
        r1, r2, r3 = rand((3,) + rand_shape, device=device, dtype=dtype)
        q1 = (1.0 - r1).sqrt() * ((2 * pi * r2).sin())
        q2 = (1.0 - r1).sqrt() * ((2 * pi * r2).cos())
        q3 = r1.sqrt() * (2 * pi * r3).sin()
        q4 = r1.sqrt() * (2 * pi * r3).cos()
        return cls(stack((q1, q2, q3, q4), -1))

    def slerp(self, q1: 'Quaternion', t: float) -> 'Quaternion':
        """Returns a unit quaternion spherically interpolated between quaternions self.q and q1.

        See more: https://en.wikipedia.org/wiki/Slerp

        Both endpoints are normalized first. No sign flip is applied to ``q1``, so
        when the two inputs have a negative dot product the interpolation follows
        the longer arc.

        Args:
            q1: second quaternion to be interpolated between.
            t: interpolation ratio, range [0-1]

        Example:
            >>> q0 = Quaternion.identity()
            >>> q1 = Quaternion(torch.tensor([1., .5, 0., 0.]))
            >>> q2 = q0.slerp(q1, .3)
        """
        KORNIA_CHECK_TYPE(q1, Quaternion)
        q0 = self.normalize()
        q1 = q1.normalize()
        return q0 * (q0.inv() * q1) ** t

    def norm(self, keepdim: bool = False) -> Tensor:
        """Return the Euclidean (L2) norm over the 4 coefficients.

        Args:
            keepdim: keep the reduced last dimension (shape :math:`(B, 1)` instead of :math:`(B,)`).
        """
        # p==2, dim|axis==-1, keepdim
        return self.data.norm(2, -1, keepdim)

    def normalize(self) -> 'Quaternion':
        """Return a new quaternion with the data normalized by ``normalize_quaternion``."""
        return Quaternion(normalize_quaternion(self.data))

    def conj(self) -> 'Quaternion':
        """Return the conjugate: same real part, negated vector part."""
        return Quaternion(concatenate((self.real[..., None], -self.vec), -1))

    def inv(self) -> 'Quaternion':
        """Return the inverse, computed as the conjugate divided by the squared norm."""
        return self.conj() / self.squared_norm()

    def squared_norm(self) -> Tensor:
        """Return the squared norm :math:`q_w^2 + q_x^2 + q_y^2 + q_z^2` with shape :math:`(B,)`."""
        return batched_dot_product(self.vec, self.vec) + self.real**2