# Source code for kornia.geometry.quaternion

# kornia.geometry.quaternion module inspired by Eigen, Sophus-sympy, and PyQuaternion.
# https://github.com/strasdat/Sophus/blob/master/sympy/sophus/quaternion.py
# https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py
# https://gitlab.com/libeigen/eigen/-/blob/master/Eigen/src/Geometry/Quaternion.h
from math import pi
from typing import Optional, Tuple, Union

from kornia.core import Device, Dtype, Module, Parameter, Tensor, concatenate, rand, stack, tensor, where
from kornia.core.check import KORNIA_CHECK_TYPE
from kornia.geometry.conversions import (
QuaternionCoeffOrder,
angle_axis_to_quaternion,
normalize_quaternion,
quaternion_to_rotation_matrix,
rotation_matrix_to_quaternion,
)
from kornia.geometry.linalg import batched_dot_product

class Quaternion(Module):
    r"""Base class to represent a Quaternion.

    A quaternion is a four dimensional vector representation of a rotation transformation in 3d.
    See more: https://en.wikipedia.org/wiki/Quaternion

    The general definition of a quaternion is given by:

    .. math::

        Q = a + b \cdot \mathbf{i} + c \cdot \mathbf{j} + d \cdot \mathbf{k}

    Thus, we represent a rotation quaternion as a contiguous tensor structure to
    perform rigid bodies transformations:

    .. math::

        Q = \begin{bmatrix} q_w & q_x & q_y & q_z \end{bmatrix}

    Example:
        >>> q = Quaternion.identity(batch_size=4)
        >>> q.data
        Parameter containing:
        tensor([[1., 0., 0., 0.],
                [1., 0., 0., 0.],
                [1., 0., 0., 0.],
                [1., 0., 0., 0.]], requires_grad=True)
        >>> q.real
        tensor([1., 1., 1., 1.], grad_fn=<SelectBackward0>)
        >>> q.vec
        tensor([[0., 0., 0.],
                [0., 0., 0.],
                [0., 0., 0.],
                [0., 0., 0.]], grad_fn=<SliceBackward0>)
    """

    def __init__(self, data: Tensor) -> None:
        """Constructor for the base class.

        Args:
            data: tensor containing the quaternion data with the shape of :math:`(B, 4)`.
                The tensor is wrapped in a ``Parameter``.

        Example:
            >>> data = torch.rand(2, 4)
            >>> q = Quaternion(data)
            >>> q.shape
            (2, 4)
        """
        super().__init__()
        # KORNIA_CHECK_SHAPE(data, ["B", "4"])  # FIXME: resolve shape bugs. @edgarriba
        self._data = Parameter(data)

    def __repr__(self) -> str:
        """Return the string representation of the underlying data."""
        return f"{self.data}"

    def __getitem__(self, idx: Union[int, slice]) -> 'Quaternion':
        """Return a new quaternion built from the indexed underlying data.

        Args:
            idx: index or slice applied to the first dimension of the data.
        """
        return Quaternion(self.data[idx])

    def __neg__(self) -> 'Quaternion':
        """Inverts the sign of the quaternion data.

        Example:
            >>> q = Quaternion.identity()
            >>> (-q).data
            Parameter containing:
            tensor([-1., -0., -0., -0.], requires_grad=True)
        """
        return Quaternion(-self.data)

    def __add__(self, right: 'Quaternion') -> 'Quaternion':
        """Add a given quaternion.

        Args:
            right: the quaternion to add.

        Example:
            >>> q1 = Quaternion.identity()
            >>> q2 = Quaternion(tensor([2., 0., 1., 1.]))
            >>> q3 = q1 + q2
            >>> q3.data
            Parameter containing:
            tensor([3., 0., 1., 1.], requires_grad=True)
        """
        KORNIA_CHECK_TYPE(right, Quaternion)
        return Quaternion(self.data + right.data)

    def __sub__(self, right: 'Quaternion') -> 'Quaternion':
        """Subtract a given quaternion.

        Args:
            right: the quaternion to subtract.

        Example:
            >>> q1 = Quaternion(tensor([2., 0., 1., 1.]))
            >>> q2 = Quaternion.identity()
            >>> q3 = q1 - q2
            >>> q3.data
            Parameter containing:
            tensor([1., 0., 1., 1.], requires_grad=True)
        """
        KORNIA_CHECK_TYPE(right, Quaternion)
        return Quaternion(self.data - right.data)

    def __mul__(self, right: 'Quaternion') -> 'Quaternion':
        """Multiply by a given quaternion (Hamilton product ``self * right``).

        Args:
            right: the quaternion to multiply with, on the right-hand side.
        """
        KORNIA_CHECK_TYPE(right, Quaternion)
        # NOTE: borrowed from sophus sympy. Produce less multiplications compared to others.
        # https://github.com/strasdat/Sophus/blob/785fef35b7d9e0fc67b4964a69124277b7434a44/sympy/sophus/quaternion.py#L19
        new_real = self.real * right.real - batched_dot_product(self.vec, right.vec)
        # The cross product must run over the last (xyz) axis. Without an explicit
        # ``dim`` torch picks the first dimension of size 3, which is the batch
        # axis whenever the batch holds exactly three quaternions.
        cross = self.vec.cross(right.vec, dim=-1)
        new_vec = self.real[..., None] * right.vec + right.real[..., None] * self.vec + cross
        return Quaternion(concatenate((new_real[..., None], new_vec), -1))

    def __div__(self, right: Union[Tensor, 'Quaternion']) -> 'Quaternion':
        """Divide by a tensor of scalars or by another quaternion.

        Args:
            right: either a tensor, in which case every coefficient is divided by
                ``right[..., None]``, or a quaternion, in which case the result is
                ``self * right.inv()``.
        """
        if isinstance(right, Tensor):
            return Quaternion(self.data / right[..., None])
        KORNIA_CHECK_TYPE(right, Quaternion)
        return self * right.inv()

    def __truediv__(self, right: Union[Tensor, 'Quaternion']) -> 'Quaternion':
        """Implement the ``/`` operator; see :meth:`__div__` for the accepted operands."""
        return self.__div__(right)

    def __pow__(self, t: Union[int, float]) -> 'Quaternion':
        """Return the power of a quaternion raised to exponent t.

        The result is computed from the polar form as
        ``|q|**t * (cos(t * theta) + n * sin(t * theta))`` where ``theta`` is the
        polar angle and ``n`` the unit vector along the imaginary part. For unit
        quaternions the magnitude factor is one.

        Args:
            t: raised exponent.

        Example:
            >>> q = Quaternion(tensor([1., .5, 0., 0.]))
            >>> q_pow = q**2
        """
        theta = self.polar_angle[..., None]
        vec_norm = self.vec.norm(dim=-1, keepdim=True)
        # A purely real quaternion has no rotation axis: use a zero axis instead of dividing by zero.
        n = where(vec_norm != 0, self.vec / vec_norm, self.vec * 0)
        # Magnitude factor, so that e.g. ``q ** 2`` agrees with ``q * q`` for non-unit quaternions.
        scale = self.norm(keepdim=True) ** t
        w = (t * theta).cos()
        xyz = (t * theta).sin() * n
        return Quaternion(scale * concatenate((w, xyz), -1))

    @property
    def data(self) -> Tensor:
        """Return the underlying data with shape :math:`(B, 4)`."""
        return self._data

    @property
    def coeffs(self) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Return a tuple with the underlying coefficients in WXYZ order."""
        return self.w, self.x, self.y, self.z

    @property
    def real(self) -> Tensor:
        """Return the real part with shape :math:`(B,)`.

        Alias for :func:`~kornia.geometry.quaternion.Quaternion.w`
        """
        return self.w

    @property
    def vec(self) -> Tensor:
        """Return the vector with the imaginary part with shape :math:`(B, 3)`."""
        return self.data[..., 1:]

    @property
    def q(self) -> Tensor:
        """Return the underlying data with shape :math:`(B, 4)`.

        Alias for :func:`~kornia.geometry.quaternion.Quaternion.data`
        """
        return self.data

    @property
    def scalar(self) -> Tensor:
        """Return a scalar with the real with shape :math:`(B,)`.

        Alias for :func:`~kornia.geometry.quaternion.Quaternion.w`
        """
        return self.real

    @property
    def w(self) -> Tensor:
        """Return the :math:`q_w` with shape :math:`(B,)`."""
        return self.data[..., 0]

    @property
    def x(self) -> Tensor:
        """Return the :math:`q_x` with shape :math:`(B,)`."""
        return self.data[..., 1]

    @property
    def y(self) -> Tensor:
        """Return the :math:`q_y` with shape :math:`(B,)`."""
        return self.data[..., 2]

    @property
    def z(self) -> Tensor:
        """Return the :math:`q_z` with shape :math:`(B,)`."""
        return self.data[..., 3]

    @property
    def shape(self) -> Tuple[int, ...]:
        """Return the shape of the underlying data with shape :math:`(B, 4)`."""
        return tuple(self.data.shape)

    @property
    def polar_angle(self) -> Tensor:
        """Return the polar angle with shape :math:`(B,)`.

        The angle is ``acos(w / |q|)``.

        Example:
            >>> q = Quaternion.identity()
            >>> q.polar_angle.item()
            0.0
        """
        # Rounding can push the ratio marginally outside [-1, 1], where acos returns NaN.
        return (self.scalar / self.norm()).clamp(-1.0, 1.0).acos()

    def matrix(self) -> Tensor:
        """Convert the quaternion to a rotation matrix of shape :math:`(B, 3, 3)`.

        Example:
            >>> q = Quaternion.identity()
            >>> m = q.matrix()
            >>> m
            tensor([[1., 0., 0.],
                    [0., 1., 0.],
                    [0., 0., 1.]], grad_fn=<SqueezeBackward1>)
        """
        return quaternion_to_rotation_matrix(self.data, order=QuaternionCoeffOrder.WXYZ)

    @classmethod
    def from_matrix(cls, matrix: Tensor) -> 'Quaternion':
        """Create a quaternion from a rotation matrix.

        Args:
            matrix: the rotation matrix to convert of shape :math:`(B, 3, 3)`.

        Example:
            >>> m = torch.eye(3)[None]
            >>> q = Quaternion.from_matrix(m)
            >>> q.data
            Parameter containing:
            tensor([[1., 0., 0., 0.]], requires_grad=True)
        """
        return cls(rotation_matrix_to_quaternion(matrix, order=QuaternionCoeffOrder.WXYZ))

    @classmethod
    def from_axis_angle(cls, axis_angle: Tensor) -> 'Quaternion':
        """Create a quaternion from axis-angle representation.

        Args:
            axis_angle: rotation vector of shape :math:`(B, 3)`.

        Example:
            >>> axis_angle = torch.tensor([[1., 0., 0.]])
            >>> q = Quaternion.from_axis_angle(axis_angle)
            >>> q.data
            Parameter containing:
            tensor([[0.8776, 0.4794, 0.0000, 0.0000]], requires_grad=True)
        """
        return cls(angle_axis_to_quaternion(axis_angle, order=QuaternionCoeffOrder.WXYZ))

    @classmethod
    def identity(
        cls, batch_size: Optional[int] = None, device: Optional[Device] = None, dtype: Dtype = None
    ) -> 'Quaternion':
        """Create a quaternion representing an identity rotation.

        Args:
            batch_size: the batch size of the underlying data. When ``None`` the
                data has shape :math:`(4,)`, otherwise :math:`(B, 4)`.
            device: device on which the underlying tensor is created.
            dtype: data type of the underlying tensor.

        Example:
            >>> q = Quaternion.identity()
            >>> q.data
            Parameter containing:
            tensor([1., 0., 0., 0.], requires_grad=True)
        """
        data = tensor([1.0, 0.0, 0.0, 0.0], device=device, dtype=dtype)
        if batch_size is not None:
            data = data.repeat(batch_size, 1)
        return cls(data)

    @classmethod
    def from_coeffs(cls, w: float, x: float, y: float, z: float) -> 'Quaternion':
        """Create a quaternion from the data coefficients.

        Args:
            w: a float representing the :math:`q_w` component.
            x: a float representing the :math:`q_x` component.
            y: a float representing the :math:`q_y` component.
            z: a float representing the :math:`q_z` component.

        Example:
            >>> q = Quaternion.from_coeffs(1., 0., 0., 0.)
            >>> q.data
            Parameter containing:
            tensor([1., 0., 0., 0.], requires_grad=True)
        """
        return cls(tensor([w, x, y, z]))

    # TODO: update signature
    # def random(cls, shape: Optional[List] = None, device = None, dtype = None) -> 'Quaternion':
    @classmethod
    def random(
        cls, batch_size: Optional[int] = None, device: Optional[Device] = None, dtype: Dtype = None
    ) -> 'Quaternion':
        """Create a random unit quaternion of shape :math:`(B, 4)`.

        Uniformly distributed across the rotation space as per: http://planning.cs.uiuc.edu/node198.html

        Args:
            batch_size: the batch size of the underlying data. When ``None`` the
                data has shape :math:`(4,)`.
            device: device on which the random samples are created.
            dtype: data type of the random samples.

        Example:
            >>> q = Quaternion.random()
            >>> q = Quaternion.random(batch_size=2)
        """
        rand_shape = (batch_size,) if batch_size is not None else ()

        # Three independent uniform samples per quaternion.
        r1, r2, r3 = rand((3,) + rand_shape, device=device, dtype=dtype)
        q1 = (1.0 - r1).sqrt() * ((2 * pi * r2).sin())
        q2 = (1.0 - r1).sqrt() * ((2 * pi * r2).cos())
        q3 = r1.sqrt() * (2 * pi * r3).sin()
        q4 = r1.sqrt() * (2 * pi * r3).cos()
        return cls(stack((q1, q2, q3, q4), -1))

    def slerp(self, q1: 'Quaternion', t: float) -> 'Quaternion':
        """Returns a unit quaternion spherically interpolated between quaternions self.q and q1.

        Both quaternions are normalized before interpolating.

        See more: https://en.wikipedia.org/wiki/Slerp

        Args:
            q1: second quaternion to be interpolated between.
            t: interpolation ratio, range [0-1]

        Example:
            >>> q0 = Quaternion.identity()
            >>> q1 = Quaternion(torch.tensor([1., .5, 0., 0.]))
            >>> q2 = q0.slerp(q1, .3)
        """
        KORNIA_CHECK_TYPE(q1, Quaternion)
        q0 = self.normalize()
        q1 = q1.normalize()
        # ``**`` binds tighter than ``*``: this is q0 * ((q0^-1 * q1) ** t).
        return q0 * (q0.inv() * q1) ** t

    def norm(self, keepdim: bool = False) -> Tensor:
        """Return the 2-norm of the quaternion data computed over the last dimension.

        Args:
            keepdim: whether the reduced last dimension is kept with size one.
        """
        # p==2, dim|axis==-1, keepdim
        return self.data.norm(2, -1, keepdim)

    def normalize(self) -> 'Quaternion':
        """Return a new quaternion built from ``normalize_quaternion`` applied to the data."""
        return Quaternion(normalize_quaternion(self.data))

    def conj(self) -> 'Quaternion':
        """Return the conjugate: same real part, negated imaginary part."""
        return Quaternion(concatenate((self.real[..., None], -self.vec), -1))

    def inv(self) -> 'Quaternion':
        """Return the inverse, computed as the conjugate divided by the squared norm."""
        return self.conj() / self.squared_norm()

    def squared_norm(self) -> Tensor:
        """Return the squared norm: the real part squared plus the dot product of the imaginary part with itself."""
        return batched_dot_product(self.vec, self.vec) + self.real**2