Source code for kornia.geometry.liegroup.se3

```# kornia.geometry.se3 module inspired by Sophus-sympy.
# https://github.com/strasdat/Sophus/blob/master/sympy/sophus/se3.py
from __future__ import annotations

from kornia.core import (
Device,
Dtype,
Module,
Parameter,
Tensor,
concatenate,
eye,
stack,
tensor,
where,
zeros_like,
)
from kornia.core.check import KORNIA_CHECK, KORNIA_CHECK_SAME_DEVICES
from kornia.geometry.liegroup.so3 import So3
from kornia.geometry.linalg import batched_dot_product
from kornia.geometry.quaternion import Quaternion
from kornia.geometry.vector import Vector3

[docs]class Se3(Module):
r"""Base class to represent the Se3 group.

The SE(3) is the group of rigid body transformations about the origin of three-dimensional Euclidean
space :math:`R^3` under the operation of composition.
See more: https://ingmec.ual.es/~jlblanco/papers/jlblanco2010geometry3D_techrep.pdf

An element is stored as a rotation (an :class:`So3`) together with a translation
(a :class:`Vector3`, or a tensor wrapped into a ``Parameter``).

Example:
    >>> q = Quaternion.identity()
    >>> s = Se3(q, torch.ones(3))
    >>> torch.allclose(s.r.matrix(), torch.eye(3))
    True
    >>> s.t
    Parameter containing:
    tensor([1., 1., 1.], requires_grad=True)
"""

def __init__(self, rotation: Quaternion | So3, translation: Vector3 | Tensor) -> None:
    """Build an Se3 element from a rotation and a translation.

    A :class:`Quaternion` rotation is wrapped into an :class:`So3`; an :class:`So3` is kept as is.
    A tensor translation is wrapped into a ``Parameter``; a :class:`Vector3` is kept as is.

    Args:
        rotation: the rotational part, either a Quaternion or an So3.
        translation: the translational part, either a Vector3 or a tensor of shape :math:`(B, 3)`.

    Raises:
        TypeError: if ``rotation`` or ``translation`` is of an unsupported type.

    Example:
        >>> from kornia.geometry.quaternion import Quaternion
        >>> q = Quaternion.identity(batch_size=1)
        >>> s = Se3(q, torch.ones((1, 3)))
        >>> torch.allclose(s.r.matrix(), torch.eye(3)[None])
        True
        >>> s.t
        Parameter containing:
        tensor([[1., 1., 1.]], requires_grad=True)
    """
    # Initialise the Module base before any Parameter / sub-module attribute is assigned.
    super().__init__()
    if not isinstance(rotation, (Quaternion, So3)):
        raise TypeError(f"rotation type is {type(rotation)}")
    if not isinstance(translation, (Vector3, Tensor)):
        raise TypeError(f"translation type is {type(translation)}")
    # TODO(@edgarriba): validate the translation shape ["B", "3"] once the shape-check bugs are resolved.
    self._translation: Vector3 | Parameter = (
        Parameter(translation) if isinstance(translation, Tensor) else translation
    )
    self._rotation: So3 = So3(rotation) if isinstance(rotation, Quaternion) else rotation

def __repr__(self) -> str:
    """Return a two-line description: the rotation on the first line, the translation on the second."""
    parts = (f"rotation: {self.r}", f"translation: {self.t}")
    return "\n".join(parts)

def __getitem__(self, idx: int | slice) -> Se3:
    """Select part of a batched transformation.

    Args:
        idx: index or slice applied to both the rotation and the translation.

    Returns:
        A new Se3 built from the indexed rotation and translation.
    """
    rotation_part = self._rotation[idx]
    translation_part = self._translation[idx]
    return Se3(rotation_part, translation_part)

def __mul__(self, right: Se3 | Vector3 | Tensor) -> Se3 | Vector3 | Tensor:
    """Compose with another Se3 transformation, or apply this transformation to points.

    Args:
        right: either another Se3 transformation, or points given as a Vector3 / tensor.

    Returns:
        If ``right`` is an Se3, the composition ``self * right`` as a new Se3
        (rotation ``R1 R2``, translation ``t1 + R1 t2``). Otherwise the transformed
        points ``R * right + t``.

    Raises:
        TypeError: if ``right`` is none of the supported types.
    """
    so3 = self.so3
    t = self.t
    if isinstance(right, Se3):
        # https://github.com/strasdat/Sophus/blob/master/sympy/sophus/se3.py#L97
        _r = so3 * right.so3
        _t = t + so3 * right.t
        return Se3(_r, _t)
    elif isinstance(right, (Vector3, Tensor)):
        # KORNIA_CHECK_SHAPE(right, ["B", "N"])  # FIXME: resolve shape bugs. @edgarriba
        # Unwrap a Vector3 to its underlying tensor, but keep a Parameter/Tensor translation
        # as is: ``Tensor.data`` returns a detached tensor, which would silently stop the
        # gradient from reaching the translation (the Se3 branch above keeps it).
        t_data = t.data if isinstance(t, Vector3) else t
        return so3 * right + t_data
    else:
        raise TypeError(f"Unsupported type: {type(right)}")

@property
def so3(self) -> So3:
    """Rotational part of the transformation as an :class:`So3` (same object as ``r``)."""
    return self._rotation

@property
def quaternion(self) -> Quaternion:
    """Rotational part as the :class:`Quaternion` held by the underlying So3."""
    rotation = self._rotation
    return rotation.q

@property
def r(self) -> So3:
    """Short alias for the rotational part, returned as an :class:`So3`."""
    return self._rotation

@property
def t(self) -> Vector3 | Tensor:
    """Short alias for the translational part.

    This is a :class:`Vector3`, or a ``Parameter`` of shape :math:`(B, 3)` when the
    element was built from a plain tensor.
    """
    return self._translation

@property
def rotation(self) -> So3:
    """Rotational part of the transformation as an :class:`So3`."""
    return self._rotation

@property
def translation(self) -> Vector3 | Tensor:
    """Translational part of the transformation.

    This is a :class:`Vector3`, or a ``Parameter`` of shape :math:`(B, 3)` when the
    element was built from a plain tensor.
    """
    return self._translation

[docs]    @staticmethod
def exp(v: Tensor) -> Se3:
"""Converts elements of lie algebra to elements of lie group.

The input is split as ``v = (upsilon, omega)``: the first three entries are the
translational part and the last three the rotational part. The resulting rotation is
``So3.exp(omega)`` and the translation is ``V @ upsilon`` with
``V = I + (1 - cos(theta)) / theta^2 * omega^ + (theta - sin(theta)) / theta^3 * (omega^)^2``
and ``theta = |omega|``. When ``theta == 0`` the translation is ``upsilon`` itself.

Args:
    v: vector of shape :math:`(B, 6)`.

Returns:
    The Se3 element ``exp(v)``. Its tensor translation is wrapped into a ``Parameter``.

Example:
    >>> v = torch.zeros((1, 6))
    >>> s = Se3.exp(v)
    >>> torch.allclose(s.r.matrix(), torch.eye(3)[None])
    True
    >>> s.t
    Parameter containing:
    tensor([[0., 0., 0.]], requires_grad=True)
"""
# KORNIA_CHECK_SHAPE(v, ["B", "6"])  # FIXME: resolve shape bugs. @edgarriba
# v = (upsilon, omega): translational part first, rotational part last.
upsilon = v[..., :3]
omega = v[..., 3:]
omega_hat = So3.hat(omega)
omega_hat_sq = omega_hat @ omega_hat
# theta = |omega|, the rotation angle.
theta = batched_dot_product(omega, omega).sqrt()
R = So3.exp(omega)
# V = I + (1 - cos theta)/theta^2 * omega^ + (theta - sin theta)/theta^3 * omega^2.
# At theta == 0 both coefficients evaluate to 0/0 = NaN; the `where` below discards V then.
V = (
eye(3, device=v.device, dtype=v.dtype)
+ ((1 - theta.cos()) / (theta**2))[..., None, None] * omega_hat
+ ((theta - theta.sin()) / (theta**3))[..., None, None] * omega_hat_sq
)
# (upsilon[..., None, :] * V).sum(-1) is the batched matrix-vector product V @ upsilon.
# NOTE(review): `where` masks the NaN entries of V only in the forward pass; gradients
# w.r.t. v at theta == 0 may still come out as NaN - confirm before optimizing from identity.
U = where(theta[..., None] != 0.0, (upsilon[..., None, :] * V).sum(-1), upsilon)
return Se3(R, U)

[docs]    def log(self) -> Tensor:
"""Converts elements of lie group to elements of lie algebra.

The rotational part is ``omega = self.r.log()``. The translational part is ``V^-1 @ t`` with
``V^-1 = I - 0.5 * omega^ + (1 - theta * cos(theta/2) / (2 * sin(theta/2))) / theta^2 * (omega^)^2``
and ``theta = |omega|``. When ``theta == 0`` the translation is returned unchanged.

Returns:
    tensor laid out as ``(upsilon, omega)`` (translational part first), of shape
    :math:`(B, 6)` for a batched element or :math:`(6,)` for an unbatched one.

Example:
    >>> from kornia.geometry.quaternion import Quaternion
    >>> q = Quaternion.identity()
    >>> Se3(q, torch.zeros(3)).log()
    tensor([0., 0., 0., 0., 0., 0.], grad_fn=<CatBackward0>)
"""
omega = self.r.log()
theta = batched_dot_product(omega, omega).sqrt()
# NOTE(review): for a Parameter translation `.data` returns a detached tensor, so no
# gradient reaches the translation through log(); confirm this is intended.
t = self.t.data
omega_hat = So3.hat(omega)
omega_hat_sq = omega_hat @ omega_hat
# Closed-form inverse of the V matrix built in exp(). At theta == 0 the omega_hat_sq
# coefficient evaluates to 0/0 = NaN; the `where` below returns t unchanged in that case.
V_inv = (
eye(3, device=omega.device, dtype=omega.dtype)
- 0.5 * omega_hat
+ ((1 - theta * (theta / 2).cos() / (2 * (theta / 2).sin())) / theta.pow(2))[..., None, None] * omega_hat_sq
)
# (t[..., None, :] * V_inv).sum(-1) is the batched matrix-vector product V_inv @ t.
t = where(theta[..., None] != 0.0, (t[..., None, :] * V_inv).sum(-1), t)
return concatenate((t, omega), -1)

@staticmethod
def hat(v: Tensor) -> Tensor:
    """Converts elements from vector space to lie algebra.

    The input is laid out as ``v = (upsilon, omega)``. ``omega`` (last three entries) becomes the
    top-left 3x3 skew block ``So3.hat(omega)``. ``upsilon`` (first three entries) becomes the last
    column of the top three rows. The bottom row is all zeros.

    Args:
        v: vector of shape :math:`(B, 6)`.

    Returns:
        matrix of shape :math:`(B, 4, 4)`.

    Example:
        >>> v = torch.ones((1, 6))
        >>> m = Se3.hat(v)
        >>> m
        tensor([[[ 0., -1.,  1.,  1.],
                 [ 1.,  0., -1.,  1.],
                 [-1.,  1.,  0.,  1.],
                 [ 0.,  0.,  0.,  0.]]])
    """
    # KORNIA_CHECK_SHAPE(v, ["B", "6"])  # FIXME: resolve shape bugs. @edgarriba
    upsilon, omega = v[..., :3], v[..., 3:]
    # (..., 3, 4): [omega^ | upsilon]
    rt = concatenate((So3.hat(omega), upsilon[..., None]), -1)
    # (..., 1, 4): the all-zero bottom row of an se(3) matrix.
    bottom = zeros_like(rt[..., :1, :])
    return concatenate((rt, bottom), -2)

@staticmethod
def vee(omega: Tensor) -> Tensor:
    """Converts elements from lie algebra to vector space.

    The translational part is read from the first three entries of the last column. The
    rotational part is read from the top-left 3x3 skew block via ``So3.vee``.

    Args:
        omega: 4x4-matrix representing lie algebra of shape :math:`(B,4,4)`.

    Returns:
        vector of shape :math:`(B,6)` laid out as ``(upsilon, omega)``.

    Example:
        >>> v = torch.ones((1, 6))
        >>> omega_hat = Se3.hat(v)
        >>> Se3.vee(omega_hat)
        tensor([[1., 1., 1., 1., 1., 1.]])
    """
    # KORNIA_CHECK_SHAPE(omega, ["B", "4", "4"])  # FIXME: resolve shape bugs. @edgarriba
    head = omega[..., :3, -1]
    tail = So3.vee(omega[..., :3, :3])
    return concatenate((head, tail), -1)

@classmethod
def identity(cls, batch_size: int | None = None, device: Device | None = None, dtype: Dtype = None) -> Se3:
    """Create an Se3 element with identity rotation and zero translation.

    Args:
        batch_size: the batch size of the underlying data. ``None`` creates an unbatched element.
        device: device on which the data is allocated.
        dtype: data type of the underlying data.

    Example:
        >>> s = Se3.identity()
        >>> torch.allclose(s.r.matrix(), torch.eye(3))
        True
        >>> s.t
        x: 0.0
        y: 0.0
        z: 0.0
    """
    zero_translation = tensor([0.0] * 3, device=device, dtype=dtype)
    if batch_size is not None:
        # One zero row per batch element: shape (batch_size, 3).
        zero_translation = zero_translation.repeat(batch_size, 1)
    identity_rotation = So3.identity(batch_size, device, dtype)
    return cls(identity_rotation, Vector3(zero_translation))

def matrix(self) -> Tensor:
    """Returns the homogeneous matrix representation of shape :math:`(B, 4, 4)`.

    The top-left 3x3 block is the rotation matrix. The first three entries of the last
    column hold the translation. The bottom row is ``[0, 0, 0, 1]``. An unbatched element
    yields a :math:`(4, 4)` matrix.

    Example:
        >>> s = Se3(So3.identity(), torch.ones(3))
        >>> s.matrix()
        tensor([[1., 0., 0., 1.],
                [0., 1., 0., 1.],
                [0., 0., 1., 1.],
                [0., 0., 0., 1.]], grad_fn=<CatBackward0>)
    """
    t = self.t
    # Unwrap a Vector3 to its tensor, but keep a Parameter as is: ``Tensor.data`` would
    # detach the translation from autograd.
    t_data = t.data if isinstance(t, Vector3) else t
    # (..., 3, 4): rotation block with the translation appended as the last column.
    rt = concatenate((self.r.matrix(), t_data[..., None]), -1)
    # (..., 1, 4): homogeneous bottom row [0, 0, 0, 1]. This is built with already-imported
    # helpers instead of `pad`, which this module does not import.
    last_row = zeros_like(rt[..., :1, :])
    last_row[..., -1] = 1.0
    return concatenate((rt, last_row), -2)

@classmethod
def from_matrix(cls, matrix: Tensor) -> Se3:
    """Create an Se3 element from a homogeneous transformation matrix.

    The rotation is taken from the top-left 3x3 block and the translation from the first
    three entries of the last column. The bottom row is not inspected.

    Args:
        matrix: tensor of shape :math:`(B, 4, 4)`.

    Example:
        >>> s = Se3.from_matrix(torch.eye(4))
        >>> torch.allclose(s.r.matrix(), torch.eye(3))
        True
        >>> s.t
        Parameter containing:
        tensor([0., 0., 0.], requires_grad=True)
    """
    # TODO(@edgarriba): enforce shape ["B", "4", "4"] once the shape-check bugs are resolved.
    rotation = So3.from_matrix(matrix[..., :3, :3])
    translation = matrix[..., :3, -1]
    return cls(rotation, translation)

@classmethod
def from_qxyz(cls, qxyz: Tensor) -> Se3:
    """Create an Se3 element from a quaternion followed by a translation vector.

    The first four entries are passed to ``So3.from_wxyz`` as the quaternion ``(w, x, y, z)``.
    The last three entries are the translation, stored as a :class:`Vector3`.

    Args:
        qxyz: tensor of shape :math:`(B, 7)`.

    Example:
        >>> qxyz = torch.tensor([1., 0., 0., 0., 1., 2., 3.])
        >>> s = Se3.from_qxyz(qxyz)
        >>> torch.allclose(s.r.matrix(), torch.eye(3))
        True
        >>> s.t
        x: 1.0
        y: 2.0
        z: 3.0
    """
    # TODO(@edgarriba): enforce shape ["B", "7"] once the shape-check bugs are resolved.
    rotation = So3.from_wxyz(qxyz[..., :4])
    translation = Vector3(qxyz[..., 4:])
    return cls(rotation, translation)

def inverse(self) -> Se3:
    """Returns the inverse transformation.

    For ``T = (R, t)`` the inverse is ``(R^-1, -(R^-1 t))``.

    Raises:
        TypeError: if negating the translation unexpectedly yields an ``int``.

    Example:
        >>> s = Se3(So3.identity(), torch.ones(3))
        >>> s_inv = s.inverse()
        >>> torch.allclose(s_inv.r.matrix(), torch.eye(3))
        True
    """
    r_inv = self.r.inverse()
    negated_t = -1 * self.t
    # Presumably a type-narrowing guard: `-1 * x` may be typed as possibly returning int.
    if isinstance(negated_t, int):
        raise TypeError('Unexpected integer from `-1 * translation`')
    return Se3(r_inv, r_inv * negated_t)

@classmethod
def random(cls, batch_size: int | None = None, device: Device | None = None, dtype: Dtype = None) -> Se3:
    """Create an Se3 element with a random rotation and a random translation.

    Args:
        batch_size: the batch size of the underlying data. ``None`` creates an unbatched
            element; otherwise it must be at least 1.
        device: device on which the data is allocated.
        dtype: data type of the underlying data.

    Example:
        >>> s = Se3.random()
        >>> s = Se3.random(batch_size=3)
    """
    if batch_size is not None:
        KORNIA_CHECK(batch_size >= 1, msg="batch_size must be positive")
    shape: tuple[int, ...] = () if batch_size is None else (batch_size,)
    random_rotation = So3.random(batch_size, device, dtype)
    random_translation = Vector3.random(shape, device, dtype)
    return cls(random_rotation, random_translation)

@classmethod
def rot_x(cls, x: Tensor) -> Se3:
    """Construct a pure rotation about the x-axis with zero translation.

    Args:
        x: the x-axis rotation angle.
    """
    zero = zeros_like(x)
    zero_translation = stack([zero, zero, zero], -1)
    return cls(So3.rot_x(x), zero_translation)

@classmethod
def rot_y(cls, y: Tensor) -> Se3:
    """Construct a pure rotation about the y-axis with zero translation.

    Args:
        y: the y-axis rotation angle.
    """
    zero = zeros_like(y)
    zero_translation = stack([zero, zero, zero], -1)
    return cls(So3.rot_y(y), zero_translation)

@classmethod
def rot_z(cls, z: Tensor) -> Se3:
    """Construct a pure rotation about the z-axis with zero translation.

    Args:
        z: the z-axis rotation angle.
    """
    zero = zeros_like(z)
    zero_translation = stack([zero, zero, zero], -1)
    return cls(So3.rot_z(z), zero_translation)

@classmethod
def trans(cls, x: Tensor, y: Tensor, z: Tensor) -> Se3:
    """Construct a pure translation (identity rotation).

    ``x``, ``y`` and ``z`` are checked with KORNIA_CHECK to share one shape and one device.
    A 0-d input gives an unbatched element; otherwise the leading dimension is the batch size.

    Args:
        x: the x-axis translation.
        y: the y-axis translation.
        z: the z-axis translation.
    """
    KORNIA_CHECK(x.shape == y.shape)
    KORNIA_CHECK(y.shape == z.shape)
    KORNIA_CHECK_SAME_DEVICES([x, y, z])
    batch_size = None if len(x.shape) == 0 else x.shape[0]
    identity_rotation = So3.identity(batch_size, x.device, x.dtype)
    return cls(identity_rotation, stack((x, y, z), -1))

@classmethod
def trans_x(cls, x: Tensor) -> Se3:
    """Construct a pure translation along the x-axis (identity rotation).

    Args:
        x: the x-axis translation.
    """
    zero = zeros_like(x)
    return cls.trans(x, zero, zero)

@classmethod
def trans_y(cls, y: Tensor) -> Se3:
    """Construct a pure translation along the y-axis (identity rotation).

    Args:
        y: the y-axis translation.
    """
    zero = zeros_like(y)
    return cls.trans(zero, y, zero)

@classmethod
def trans_z(cls, z: Tensor) -> Se3:
    """Construct a pure translation along the z-axis (identity rotation).

    Args:
        z: the z-axis translation.
    """
    zero = zeros_like(z)
    return cls.trans(zero, zero, z)

"""Returns the adjoint matrix of shape :math:`(B, 6, 6)`.

Example:
>>> s = Se3.identity()