kornia.geometry.conversions

Angles

kornia.geometry.conversions.rad2deg(tensor)

Function that converts angles from radians to degrees.

Parameters:

tensor (Tensor) – Tensor of arbitrary shape.

Return type:

Tensor

Returns:

Tensor with same shape as input.

Example

>>> input = tensor(3.1415926535)
>>> rad2deg(input)
tensor(180.)

kornia.geometry.conversions.deg2rad(tensor)

Function that converts angles from degrees to radians.

Parameters:

tensor (Tensor) – Tensor of arbitrary shape.

Return type:

Tensor

Returns:

Tensor with same shape as input.

Examples

>>> input = tensor(180.)
>>> deg2rad(input)
tensor(3.1416)
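
Both conversions are a multiplication by a constant (\(180/\pi\) or \(\pi/180\)). A minimal plain-Python sketch of that arithmetic, with hypothetical helper names that are not part of kornia, reproduces the values shown in the two examples above:

```python
import math

def rad_to_deg(angle_rad: float) -> float:
    # degrees = radians * 180 / pi
    return angle_rad * 180.0 / math.pi

def deg_to_rad(angle_deg: float) -> float:
    # radians = degrees * pi / 180
    return angle_deg * math.pi / 180.0

print(round(rad_to_deg(3.1415926535), 4))  # 180.0
print(round(deg_to_rad(180.0), 4))         # 3.1416
```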

kornia.geometry.conversions.pol2cart(rho, phi)

Function that converts polar coordinates to cartesian coordinates.

Parameters:
  • rho (Tensor) – radial distances, Tensor of arbitrary shape.

  • phi (Tensor) – angles in radians, Tensor of the same shape as rho.

Returns:

  • x – Tensor with same shape as input.

  • y – Tensor with same shape as input.

Return type:

Tuple[Tensor, Tensor]

Example

>>> rho = torch.rand(1, 3, 3)
>>> phi = torch.rand(1, 3, 3)
>>> x, y = pol2cart(rho, phi)

kornia.geometry.conversions.cart2pol(x, y, eps=1e-08)

Function that converts cartesian coordinates to polar coordinates.

Parameters:
  • x (Tensor) – Tensor of arbitrary shape.

  • y (Tensor) – Tensor of same arbitrary shape.

  • eps (float, optional) – To avoid division by zero. Default: 1e-08

Returns:

  • rho – Tensor with same shape as input.

  • phi – Tensor with same shape as input, angles in radians.

Return type:

Tuple[Tensor, Tensor]

Example

>>> x = torch.rand(1, 3, 3)
>>> y = torch.rand(1, 3, 3)
>>> rho, phi = cart2pol(x, y)
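
pol2cart and cart2pol are inverses of each other. The NumPy sketch below restates the standard formulas \(x = \rho\cos\phi\), \(y = \rho\sin\phi\) and \(\rho = \sqrt{x^2 + y^2}\), \(\phi = \mathrm{atan2}(y, x)\). Where kornia applies eps is an assumption here (the sketch adds it under the square root), and the helper names are hypothetical.

```python
import numpy as np

def polar_to_cartesian(rho, phi):
    # x = rho * cos(phi), y = rho * sin(phi); phi is in radians
    return rho * np.cos(phi), rho * np.sin(phi)

def cartesian_to_polar(x, y, eps=1e-8):
    # assumption: eps goes under the square root so rho stays strictly positive at the origin
    rho = np.sqrt(x ** 2 + y ** 2 + eps)
    # arctan2 resolves the quadrant and returns angles in [-pi, pi]
    phi = np.arctan2(y, x)
    return rho, phi

rho = np.array([0.5, 1.0, 2.0])
phi = np.array([0.0, np.pi / 4, -np.pi / 2])
x, y = polar_to_cartesian(rho, phi)
rho2, phi2 = cartesian_to_polar(x, y)  # recovers rho (up to eps) and phi
```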

kornia.geometry.conversions.angle_to_rotation_matrix(angle)

Create a rotation matrix out of angles in degrees.

Parameters:

angle (Tensor) – tensor of angles in degrees, any shape \((*)\).

Return type:

Tensor

Returns:

tensor of rotation matrices with shape \((*, 2, 2)\).

Example

>>> input = torch.rand(1, 3)  # Nx3
>>> output = angle_to_rotation_matrix(input)  # Nx3x2x2

Coordinates

kornia.geometry.conversions.convert_points_from_homogeneous(points, eps=1e-08)

Function that converts points from homogeneous to Euclidean space.

Parameters:
  • points (Tensor) – the points to be transformed of shape \((B, N, D)\).

  • eps (float, optional) – to avoid division by zero. Default: 1e-08

Return type:

Tensor

Returns:

the points in Euclidean space \((B, N, D-1)\).

Examples

>>> input = tensor([[0., 0., 1.]])
>>> convert_points_from_homogeneous(input)
tensor([[0., 0.]])

kornia.geometry.conversions.convert_points_to_homogeneous(points)

Function that converts points from Euclidean to homogeneous space.

Parameters:

points (Tensor) – the points to be transformed with shape \((*, N, D)\).

Return type:

Tensor

Returns:

the points in homogeneous coordinates \((*, N, D+1)\).

Examples

>>> input = tensor([[0., 0.]])
>>> convert_points_to_homogeneous(input)
tensor([[0., 0., 1.]])
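
Converting to homogeneous space appends a coordinate equal to 1, and converting back divides by the last coordinate. A NumPy sketch of both directions follows; the eps handling (points whose last coordinate is near zero are left unscaled) is an assumption, not necessarily kornia's behavior, and the helper names are hypothetical.

```python
import numpy as np

def to_homogeneous(points: np.ndarray) -> np.ndarray:
    # (*, N, D) -> (*, N, D + 1): append a coordinate equal to 1
    ones = np.ones(points.shape[:-1] + (1,), dtype=points.dtype)
    return np.concatenate([points, ones], axis=-1)

def from_homogeneous(points: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    # (*, N, D) -> (*, N, D - 1): divide by the last coordinate
    w = points[..., -1:]
    # assumption: leave points with a (near) zero last coordinate unscaled
    safe_w = np.where(np.abs(w) > eps, w, 1.0)
    return points[..., :-1] / safe_w

pts = np.array([[[0.0, 0.0], [2.0, 4.0]]])               # (B=1, N=2, D=2)
hom = to_homogeneous(pts)                                 # (1, 2, 3)
print(hom[0, 1])                                          # [2. 4. 1.]
print(from_homogeneous(np.array([[[4.0, 8.0, 2.0]]])))    # [[[2. 4.]]]
```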

kornia.geometry.conversions.convert_affinematrix_to_homography(A)

Function that converts a batch of affine matrices to homography matrices.

Parameters:

A (Tensor) – the affine matrix with shape \((B,2,3)\).

Return type:

Tensor

Returns:

the homography matrix with shape of \((B,3,3)\).

Examples

>>> A = tensor([[[1., 0., 0.],
...              [0., 1., 0.]]])
>>> convert_affinematrix_to_homography(A)
tensor([[[1., 0., 0.],
         [0., 1., 0.],
         [0., 0., 1.]]])
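
An affine matrix becomes a homography by appending the row [0, 0, 1]. A minimal NumPy sketch with a hypothetical helper name, which also checks that the homography moves a homogeneous point exactly as the affine map moves the Euclidean one:

```python
import numpy as np

def affine_to_homography(A: np.ndarray) -> np.ndarray:
    # A: (B, 2, 3) -> H: (B, 3, 3) by appending the row [0, 0, 1] to every matrix
    batch = A.shape[0]
    last_row = np.tile(np.array([[[0.0, 0.0, 1.0]]], dtype=A.dtype), (batch, 1, 1))
    return np.concatenate([A, last_row], axis=1)

A = np.array([[[1.0, 0.0, 5.0],
               [0.0, 2.0, 3.0]]])
H = affine_to_homography(A)
p = np.array([1.0, 1.0, 1.0])  # the point (1, 1) in homogeneous form
print(H[0] @ p)  # [6. 5. 1.]
```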

kornia.geometry.conversions.denormalize_pixel_coordinates(pixel_coordinates, height, width, eps=1e-08)

Denormalize pixel coordinates.

The input is assumed to be normalized to [-1, 1]: -1 at the extreme left (x = 0) and 1 at the extreme right (x = w-1).

Parameters:
  • pixel_coordinates (Tensor) – the normalized grid coordinates. Shape can be \((*, 2)\).

  • height (int) – the maximum height in the y-axis.

  • width (int) – the maximum width in the x-axis.

  • eps (float, optional) – safe division by zero. Default: 1e-08

Return type:

Tensor

Returns:

the denormalized pixel coordinates with shape \((*, 2)\).

Examples

>>> coords = tensor([[-1., -1.]])
>>> denormalize_pixel_coordinates(coords, 100, 50)
tensor([[0., 0.]])

kornia.geometry.conversions.normalize_pixel_coordinates(pixel_coordinates, height, width, eps=1e-08)

Normalize pixel coordinates between -1 and 1.

The output is -1 at the extreme left (x = 0) and 1 at the extreme right (x = w-1).

Parameters:
  • pixel_coordinates (Tensor) – the grid with pixel coordinates. Shape can be \((*, 2)\).

  • height (int) – the maximum height in the y-axis.

  • width (int) – the maximum width in the x-axis.

  • eps (float, optional) – safe division by zero. Default: 1e-08

Return type:

Tensor

Returns:

the normalized pixel coordinates with shape \((*, 2)\).

Examples

>>> coords = tensor([[50., 100.]])
>>> normalize_pixel_coordinates(coords, 100, 50)
tensor([[1.0408, 1.0202]])
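
Both functions apply the affine map \(x_{norm} = 2x/(w-1) - 1\) (likewise for y with h) or its inverse, with coordinates in (x, y) order. The NumPy sketch below assumes that formula, clamping \(w-1\) and \(h-1\) with eps; it reproduces the values 1.0408 and 1.0202 from the example above (51/49 and 101/99). The helper names are hypothetical.

```python
import numpy as np

def normalize_pixels(coords: np.ndarray, height: int, width: int, eps: float = 1e-8) -> np.ndarray:
    # coords: (*, 2) in (x, y) order; x = 0 -> -1 and x = width - 1 -> 1
    size = np.array([width, height], dtype=np.float64)
    factor = 2.0 / np.maximum(size - 1.0, eps)
    return coords * factor - 1.0

def denormalize_pixels(coords: np.ndarray, height: int, width: int, eps: float = 1e-8) -> np.ndarray:
    # inverse map: -1 -> 0 and 1 -> width - 1 (respectively height - 1)
    size = np.array([width, height], dtype=np.float64)
    factor = 2.0 / np.maximum(size - 1.0, eps)
    return (coords + 1.0) / factor

out = normalize_pixels(np.array([[50.0, 100.0]]), 100, 50)
print(np.round(out, 4))  # [[1.0408 1.0202]]
```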

kornia.geometry.conversions.denormalize_pixel_coordinates3d(pixel_coordinates, depth, height, width, eps=1e-08)

Denormalize pixel coordinates.

The input is assumed to be normalized to [-1, 1]: -1 at the extreme left (x = 0) and 1 at the extreme right (x = w-1).

Parameters:
  • pixel_coordinates (Tensor) – the normalized grid coordinates. Shape can be \((*, 3)\).

  • depth (int) – the maximum depth in the z-axis.

  • height (int) – the maximum height in the y-axis.

  • width (int) – the maximum width in the x-axis.

  • eps (float, optional) – safe division by zero. Default: 1e-08

Return type:

Tensor

Returns:

the denormalized pixel coordinates.

kornia.geometry.conversions.normalize_pixel_coordinates3d(pixel_coordinates, depth, height, width, eps=1e-08)

Normalize pixel coordinates between -1 and 1.

The output is -1 at the extreme left (x = 0) and 1 at the extreme right (x = w-1).

Parameters:
  • pixel_coordinates (Tensor) – the grid with pixel coordinates. Shape can be \((*, 3)\).

  • depth (int) – the maximum depth in the z-axis.

  • height (int) – the maximum height in the y-axis.

  • width (int) – the maximum width in the x-axis.

  • eps (float, optional) – safe division by zero. Default: 1e-08

Return type:

Tensor

Returns:

the normalized pixel coordinates.

kornia.geometry.conversions.normalize_points_with_intrinsics(point_2d, camera_matrix)

Normalizes points with intrinsics. Useful for converting keypoints for use with the essential matrix.

Parameters:
  • point_2d (Tensor) – tensor containing the 2d points in the image pixel coordinates. The shape of the tensor can be \((*, 2)\).

  • camera_matrix (Tensor) – tensor containing the intrinsics camera matrix. The tensor shape must be \((*, 3, 3)\).

Return type:

Tensor

Returns:

tensor of (u, v) cam coordinates with shape \((*, 2)\).

Example

>>> _ = torch.manual_seed(0)
>>> X = torch.rand(1, 2)
>>> K = torch.eye(3)[None]
>>> normalize_points_with_intrinsics(X, K)
tensor([[0.4963, 0.7682]])

kornia.geometry.conversions.denormalize_points_with_intrinsics(point_2d_norm, camera_matrix)

Denormalizes points with intrinsics, the inverse of normalize_points_with_intrinsics. Useful for converting points from normalized camera coordinates back to image pixel coordinates.

Parameters:
  • point_2d_norm (Tensor) – tensor containing the 2d points in normalized camera coordinates. The shape of the tensor can be \((*, 2)\).

  • camera_matrix (Tensor) – tensor containing the intrinsics camera matrix. The tensor shape must be \((*, 3, 3)\).

Return type:

Tensor

Returns:

tensor of (u, v) image pixel coordinates with shape \((*, 2)\).

Example

>>> _ = torch.manual_seed(0)
>>> X = torch.rand(1, 2)
>>> K = torch.eye(3)[None]
>>> denormalize_points_with_intrinsics(X, K)
tensor([[0.4963, 0.7682]])
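
For a pinhole camera without skew, normalization subtracts the principal point and divides by the focal length, \(x = (u - c_x)/f_x\), \(y = (v - c_y)/f_y\), and denormalization reverses it. The sketch below assumes that model (skew ignored, a single 3x3 K instead of a batch) and uses hypothetical helper names. With K equal to the identity, as in the two examples above, both directions return the input unchanged, which is why both examples print the same values.

```python
import numpy as np

def normalize_with_K(points: np.ndarray, K: np.ndarray) -> np.ndarray:
    # points: (*, 2) pixel coordinates (u, v); K: (3, 3) intrinsics, skew ignored
    fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2]
    x = (points[..., 0] - cx) / fx
    y = (points[..., 1] - cy) / fy
    return np.stack([x, y], axis=-1)

def denormalize_with_K(points_norm: np.ndarray, K: np.ndarray) -> np.ndarray:
    # inverse: u = x * fx + cx, v = y * fy + cy
    fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2]
    u = points_norm[..., 0] * fx + cx
    v = points_norm[..., 1] * fy + cy
    return np.stack([u, v], axis=-1)

K = np.array([[500.0, 0.0, 320.0],
              [0.0, 500.0, 240.0],
              [0.0, 0.0, 1.0]])
pix = np.array([[820.0, 240.0]])
print(normalize_with_K(pix, K))  # [[1. 0.]]
```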

Homography

kornia.geometry.conversions.normalize_homography(dst_pix_trans_src_pix, dsize_src, dsize_dst)

Normalize a given homography in pixels to [-1, 1].

Parameters:
  • dst_pix_trans_src_pix (Tensor) – homography/ies from source to destination to be normalized. \((B, 3, 3)\)

  • dsize_src (Tuple[int, int]) – size of the source image (height, width).

  • dsize_dst (Tuple[int, int]) – size of the destination image (height, width).

Return type:

Tensor

Returns:

the normalized homography of shape \((B, 3, 3)\).

kornia.geometry.conversions.denormalize_homography(dst_pix_trans_src_pix, dsize_src, dsize_dst)

De-normalize a given homography from normalized [-1, 1] coordinates to pixel coordinates of the actual height and width.

Parameters:
  • dst_pix_trans_src_pix (Tensor) – homography/ies from source to destination to be denormalized. \((B, 3, 3)\)

  • dsize_src (Tuple[int, int]) – size of the source image (height, width).

  • dsize_dst (Tuple[int, int]) – size of the destination image (height, width).

Return type:

Tensor

Returns:

the denormalized homography of shape \((B, 3, 3)\).
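
Normalizing a pixel homography amounts to conjugating it with the pixel-to-[-1, 1] transform N of each image, \(H_{norm} = N_{dst} H_{pix} N_{src}^{-1}\), and denormalizing inverts that conjugation. The NumPy sketch below assumes N maps x = 0 to -1 and x = w-1 to 1 (the normalize_pixel_coordinates convention above); it divides by zero for images of width or height 1, and all names are hypothetical.

```python
import numpy as np

def pixel_to_normalized(height: int, width: int) -> np.ndarray:
    # maps pixel (x, y, 1) to [-1, 1]: x = 0 -> -1 and x = width - 1 -> 1 (same for y)
    return np.array([[2.0 / (width - 1), 0.0, -1.0],
                     [0.0, 2.0 / (height - 1), -1.0],
                     [0.0, 0.0, 1.0]])

def normalize_H(H_pix: np.ndarray, dsize_src, dsize_dst) -> np.ndarray:
    # dsize_* are (height, width); H_norm = N_dst @ H_pix @ inv(N_src)
    N_src = pixel_to_normalized(*dsize_src)
    N_dst = pixel_to_normalized(*dsize_dst)
    return N_dst @ H_pix @ np.linalg.inv(N_src)

def denormalize_H(H_norm: np.ndarray, dsize_src, dsize_dst) -> np.ndarray:
    # inverse conjugation: H_pix = inv(N_dst) @ H_norm @ N_src
    N_src = pixel_to_normalized(*dsize_src)
    N_dst = pixel_to_normalized(*dsize_dst)
    return np.linalg.inv(N_dst) @ H_norm @ N_src

# pixel-space translation by (10, 5) between two 64x128 (height x width) images
H_pix = np.array([[1.0, 0.0, 10.0],
                  [0.0, 1.0, 5.0],
                  [0.0, 0.0, 1.0]])
H_norm = normalize_H(H_pix, (64, 128), (64, 128))
```

For this translation, H_norm is again a translation, by (20/127, 10/63): the pixel offset scaled by 2/(w-1) and 2/(h-1).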

kornia.geometry.conversions.normalize_homography3d(dst_pix_trans_src_pix, dsize_src, dsize_dst)

Normalize a given homography in pixels to [-1, 1].

Parameters:
  • dst_pix_trans_src_pix (Tensor) – homography/ies from source to destination to be normalized. \((B, 4, 4)\)

  • dsize_src (Tuple[int, int, int]) – size of the source image (depth, height, width).

  • dsize_dst (Tuple[int, int, int]) – size of the destination image (depth, height, width).

Return type:

Tensor

Returns:

the normalized homography.

Shape:

Output: \((B, 4, 4)\)

Quaternion

kornia.geometry.conversions.quaternion_to_angle_axis(quaternion, order=QuaternionCoeffOrder.XYZW)

Convert quaternion vector to angle axis of rotation in radians.

The quaternion should be in (x, y, z, w) or (w, x, y, z) format.

Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h

Parameters:
  • quaternion (Tensor) – tensor with quaternions.

  • order (QuaternionCoeffOrder, optional) – quaternion coefficient order. Note: ‘xyzw’ will be deprecated in favor of ‘wxyz’. Default: QuaternionCoeffOrder.XYZW

Return type:

Tensor

Returns:

tensor with angle axis of rotation.

Shape:
  • Input: \((*, 4)\) where * means any number of dimensions

  • Output: \((*, 3)\)

Example

>>> quaternion = tensor((1., 0., 0., 0.))
>>> quaternion_to_angle_axis(quaternion)
tensor([3.1416, 0.0000, 0.0000])
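
The rotation angle of a unit quaternion with vector part v and scalar part w is \(\theta = 2\,\mathrm{atan2}(\lVert v\rVert, w)\) and the axis is \(v / \lVert v\rVert\). The NumPy sketch below follows that formula in the spirit of the ceres routine cited above, flipping signs when w < 0 so the angle stays within [-π, π]. It is not kornia's code; it assumes (x, y, z, w) input, and the helper name is hypothetical.

```python
import numpy as np

def quat_xyzw_to_angle_axis(q: np.ndarray) -> np.ndarray:
    # q: (*, 4) in (x, y, z, w) order -> (*, 3) angle-axis vector (unit axis * angle in radians)
    v = q[..., :3]
    w = q[..., 3:]
    sin_half = np.linalg.norm(v, axis=-1, keepdims=True)  # equals sin(theta / 2) for a unit quaternion
    # theta = 2 * atan2(|v|, w); for w < 0 negate both arguments to get the shorter rotation
    theta = np.where(w < 0.0,
                     2.0 * np.arctan2(-sin_half, -w),
                     2.0 * np.arctan2(sin_half, w))
    # when |v| ~ 0 the result is ~0 anyway; 2 is the limit of theta / |v| as w -> 1
    safe = sin_half > 1e-12
    scale = np.where(safe, theta / np.where(safe, sin_half, 1.0), 2.0)
    return v * scale

aa = quat_xyzw_to_angle_axis(np.array([1.0, 0.0, 0.0, 0.0]))  # half turn about the x axis
```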

kornia.geometry.conversions.quaternion_to_rotation_matrix(quaternion, order=QuaternionCoeffOrder.XYZW)

Convert a quaternion to a rotation matrix.

The quaternion should be in (x, y, z, w) or (w, x, y, z) format.

Parameters:
  • quaternion (Tensor) – a tensor containing a quaternion to be converted. The tensor can be of shape \((*, 4)\).

  • order (QuaternionCoeffOrder, optional) – quaternion coefficient order. Note: ‘xyzw’ will be deprecated in favor of ‘wxyz’. Default: QuaternionCoeffOrder.XYZW

Return type:

Tensor

Returns:

the rotation matrix of shape \((*, 3, 3)\).

Example

>>> quaternion = tensor((0., 0., 0., 1.))
>>> quaternion_to_rotation_matrix(quaternion, order=QuaternionCoeffOrder.WXYZ)
tensor([[-1.,  0.,  0.],
        [ 0., -1.,  0.],
        [ 0.,  0.,  1.]])
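
For a unit quaternion (w, x, y, z), the standard (Hamilton convention) conversion to a rotation matrix is sketched below in NumPy. Assumptions: (w, x, y, z) input order and the Hamilton convention; the helper name is hypothetical. It reproduces the example above, where (0, 0, 0, 1) in WXYZ order is a half turn about the z axis.

```python
import numpy as np

def quat_wxyz_to_matrix(q: np.ndarray) -> np.ndarray:
    # q: (4,) in (w, x, y, z) order; normalized first so the result is a proper rotation
    w, x, y, z = q / np.linalg.norm(q)
    return np.array([
        [1 - 2 * (y * y + z * z), 2 * (x * y - w * z),     2 * (x * z + w * y)],
        [2 * (x * y + w * z),     1 - 2 * (x * x + z * z), 2 * (y * z - w * x)],
        [2 * (x * z - w * y),     2 * (y * z + w * x),     1 - 2 * (x * x + y * y)],
    ])

R = quat_wxyz_to_matrix(np.array([0.0, 0.0, 0.0, 1.0]))  # diag(-1, -1, 1)
```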

kornia.geometry.conversions.quaternion_log_to_exp(quaternion, eps=1e-08, order=QuaternionCoeffOrder.XYZW)

Apply exponential map to log quaternion.

The quaternion should be in (x, y, z, w) or (w, x, y, z) format.

Parameters:
  • quaternion (Tensor) – a tensor containing a quaternion to be converted. The tensor can be of shape \((*, 3)\).

  • eps (float, optional) – a small number for clamping. Default: 1e-08

  • order (QuaternionCoeffOrder, optional) – quaternion coefficient order. Note: ‘xyzw’ will be deprecated in favor of ‘wxyz’. Default: QuaternionCoeffOrder.XYZW

Return type:

Tensor

Returns:

the quaternion exponential map of shape \((*, 4)\).

Example

>>> quaternion = tensor((0., 0., 0.))
>>> quaternion_log_to_exp(quaternion, eps=torch.finfo(quaternion.dtype).eps,
...                       order=QuaternionCoeffOrder.WXYZ)
tensor([1., 0., 0., 0.])

kornia.geometry.conversions.quaternion_exp_to_log(quaternion, eps=1e-08, order=QuaternionCoeffOrder.XYZW)

Apply the log map to a quaternion.

The quaternion should be in (x, y, z, w) or (w, x, y, z) format, as selected by order.

Parameters:
  • quaternion (Tensor) – a tensor containing a quaternion to be converted. The tensor can be of shape \((*, 4)\).

  • eps (float, optional) – a small number for clamping. Default: 1e-08

  • order (QuaternionCoeffOrder, optional) – quaternion coefficient order. Note: ‘xyzw’ will be deprecated in favor of ‘wxyz’. Default: QuaternionCoeffOrder.XYZW

Return type:

Tensor

Returns:

the quaternion log map of shape \((*, 3)\).

Example

>>> quaternion = tensor((1., 0., 0., 0.))
>>> quaternion_exp_to_log(quaternion, eps=torch.finfo(quaternion.dtype).eps,
...                       order=QuaternionCoeffOrder.WXYZ)
tensor([0., 0., 0.])
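
The log map keeps half the rotation angle: a unit quaternion \((\cos\alpha, \sin\alpha\, n)\) with unit vector n has logarithm \(\alpha n\), and the exponential map reverses this. A NumPy sketch of both maps in (w, x, y, z) order follows, with eps clamping the norm as the eps parameters above describe; the helper names are hypothetical and this is not claimed to be kornia's code.

```python
import numpy as np

def quat_log_to_exp(v: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    # v: (3,) log quaternion -> (4,) unit quaternion (w, x, y, z) = (cos|v|, sin|v| * v / |v|)
    norm = max(float(np.linalg.norm(v)), eps)  # clamp to avoid 0 / 0 at v = 0
    return np.concatenate([[np.cos(norm)], np.sin(norm) * v / norm])

def quat_exp_to_log(q: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    # q: (4,) unit quaternion (w, x, y, z) -> (3,) log quaternion = acos(w) * vec / |vec|
    w, vec = q[0], q[1:]
    norm = max(float(np.linalg.norm(vec)), eps)
    return vec * np.arccos(np.clip(w, -1.0, 1.0)) / norm

q = quat_log_to_exp(np.zeros(3))  # identity rotation (1, 0, 0, 0)
```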

kornia.geometry.conversions.normalize_quaternion(quaternion, eps=1e-12)

Normalize a quaternion.

The quaternion should be in (x, y, z, w) or (w, x, y, z) format.

Parameters:
  • quaternion (Tensor) – a tensor containing a quaternion to be normalized. The tensor can be of shape \((*, 4)\).

  • eps (float, optional) – small value to avoid division by zero. Default: 1e-12

Return type:

Tensor

Returns:

the normalized quaternion of shape \((*, 4)\).

Example

>>> quaternion = tensor((1., 0., 1., 0.))
>>> normalize_quaternion(quaternion)
tensor([0.7071, 0.0000, 0.7071, 0.0000])

Rotation Matrix

kornia.geometry.conversions.rotation_matrix_to_angle_axis(rotation_matrix)

Convert 3x3 rotation matrix to Rodrigues vector in radians.

Parameters:

rotation_matrix (Tensor) – rotation matrix of shape \((N, 3, 3)\).

Return type:

Tensor

Returns:

Rodrigues vector transformation of shape \((N, 3)\).

Example

>>> input = tensor([[1., 0., 0.],
...                 [0., 1., 0.],
...                 [0., 0., 1.]])
>>> rotation_matrix_to_angle_axis(input)
tensor([0., 0., 0.])
>>> input = tensor([[1., 0., 0.],
...                 [0., 0., -1.],
...                 [0., 1., 0.]])
>>> rotation_matrix_to_angle_axis(input)
tensor([1.5708, 0.0000, 0.0000])
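
A rotation matrix can be converted to a Rodrigues vector from its trace, \(\cos\theta = (\mathrm{tr}\,R - 1)/2\), and its skew-symmetric part, which equals \(\sin\theta\) times the axis. The NumPy sketch below (hypothetical name) follows that textbook route with a separate branch for θ ≈ π; kornia's implementation may differ, especially in numerical behavior near θ = π.

```python
import numpy as np

def matrix_to_angle_axis(R: np.ndarray) -> np.ndarray:
    # R: (3, 3) rotation matrix -> (3,) angle-axis (Rodrigues) vector in radians
    cos_theta = np.clip((np.trace(R) - 1.0) / 2.0, -1.0, 1.0)
    theta = np.arccos(cos_theta)
    # half of (R - R^T), read off as a vector, equals sin(theta) * axis
    skew = 0.5 * np.array([R[2, 1] - R[1, 2], R[0, 2] - R[2, 0], R[1, 0] - R[0, 1]])
    sin_theta = np.linalg.norm(skew)
    if sin_theta > 1e-9:
        return skew / sin_theta * theta
    if cos_theta > 0.0:
        return np.zeros(3)  # theta ~ 0: identity rotation
    # theta ~ pi: R = 2 n n^T - I, so (R + I) / 2 = n n^T; use its largest column as the axis
    nnT = 0.5 * (R + np.eye(3))
    axis = nnT[:, np.argmax(np.diag(nnT))]
    return axis / np.linalg.norm(axis) * theta

Rx90 = np.array([[1.0, 0.0, 0.0],
                 [0.0, 0.0, -1.0],
                 [0.0, 1.0, 0.0]])
print(matrix_to_angle_axis(Rx90))  # pi/2 about the x axis
```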

kornia.geometry.conversions.rotation_matrix_to_quaternion(rotation_matrix, eps=1e-08, order=QuaternionCoeffOrder.XYZW)

Convert 3x3 rotation matrix to 4d quaternion vector.

The quaternion vector has components in (w, x, y, z) or (x, y, z, w) format.

Note

The (x, y, z, w) order is going to be deprecated in favor of (w, x, y, z).

Parameters:
  • rotation_matrix (Tensor) – the rotation matrix to convert with shape \((*, 3, 3)\).

  • eps (float, optional) – small value to avoid zero division. Default: 1e-08

  • order (QuaternionCoeffOrder, optional) – quaternion coefficient order. Note: ‘xyzw’ will be deprecated in favor of ‘wxyz’. Default: QuaternionCoeffOrder.XYZW

Return type:

Tensor

Returns:

the rotation in quaternion with shape \((*, 4)\).

Example

>>> input = tensor([[1., 0., 0.],
...                 [0., 1., 0.],
...                 [0., 0., 1.]])
>>> rotation_matrix_to_quaternion(input, eps=torch.finfo(input.dtype).eps,
...                               order=QuaternionCoeffOrder.WXYZ)
tensor([1., 0., 0., 0.])
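
A common robust way to convert a rotation matrix to a quaternion branches on the largest of the trace and the diagonal entries, so the square root is never taken of a near-zero value. The NumPy sketch below (hypothetical name, (w, x, y, z) output) illustrates that technique; it is not necessarily how kornia uses its eps parameter.

```python
import numpy as np

def matrix_to_quat_wxyz(R: np.ndarray) -> np.ndarray:
    # R: (3, 3) rotation matrix -> (4,) unit quaternion in (w, x, y, z) order
    trace = np.trace(R)
    if trace > 0.0:
        s = 2.0 * np.sqrt(trace + 1.0)  # s = 4 * w
        w = 0.25 * s
        x = (R[2, 1] - R[1, 2]) / s
        y = (R[0, 2] - R[2, 0]) / s
        z = (R[1, 0] - R[0, 1]) / s
    elif R[0, 0] > R[1, 1] and R[0, 0] > R[2, 2]:
        s = 2.0 * np.sqrt(1.0 + R[0, 0] - R[1, 1] - R[2, 2])  # s = 4 * x
        w = (R[2, 1] - R[1, 2]) / s
        x = 0.25 * s
        y = (R[0, 1] + R[1, 0]) / s
        z = (R[0, 2] + R[2, 0]) / s
    elif R[1, 1] > R[2, 2]:
        s = 2.0 * np.sqrt(1.0 + R[1, 1] - R[0, 0] - R[2, 2])  # s = 4 * y
        w = (R[0, 2] - R[2, 0]) / s
        x = (R[0, 1] + R[1, 0]) / s
        y = 0.25 * s
        z = (R[1, 2] + R[2, 1]) / s
    else:
        s = 2.0 * np.sqrt(1.0 + R[2, 2] - R[0, 0] - R[1, 1])  # s = 4 * z
        w = (R[1, 0] - R[0, 1]) / s
        x = (R[0, 2] + R[2, 0]) / s
        y = (R[1, 2] + R[2, 1]) / s
        z = 0.25 * s
    return np.array([w, x, y, z])

q_identity = matrix_to_quat_wxyz(np.eye(3))  # (1, 0, 0, 0), as in the example above
```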

Angle Axis

kornia.geometry.conversions.angle_axis_to_quaternion(angle_axis, order=QuaternionCoeffOrder.XYZW)

Convert an angle axis to a quaternion.

The quaternion vector has components in (x, y, z, w) or (w, x, y, z) format.

Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h

Parameters:
  • angle_axis (Tensor) – tensor with angle axis in radians.

  • order (QuaternionCoeffOrder, optional) – quaternion coefficient order. Note: ‘xyzw’ will be deprecated in favor of ‘wxyz’. Default: QuaternionCoeffOrder.XYZW

Return type:

Tensor

Returns:

tensor with quaternion.

Shape:
  • Input: \((*, 3)\) where * means any number of dimensions

  • Output: \((*, 4)\)

Example

>>> angle_axis = tensor((0., 1., 0.))
>>> angle_axis_to_quaternion(angle_axis, order=QuaternionCoeffOrder.WXYZ)
tensor([0.8776, 0.0000, 0.4794, 0.0000])
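
An angle-axis vector with norm θ and unit direction n maps to the quaternion \((\cos(\theta/2), \sin(\theta/2)\,n)\). The NumPy sketch below (hypothetical name, (w, x, y, z) output) restates this and reproduces the example above, where θ = 1 about the y axis gives cos 0.5 ≈ 0.8776 and sin 0.5 ≈ 0.4794.

```python
import numpy as np

def angle_axis_to_quat_wxyz(angle_axis: np.ndarray) -> np.ndarray:
    # angle_axis: (3,) with direction = rotation axis and norm = angle in radians
    theta = np.linalg.norm(angle_axis)
    if theta < 1e-12:
        # first-order approximation near the identity: sin(theta / 2) / theta -> 1 / 2
        return np.concatenate([[1.0], 0.5 * angle_axis])
    half = 0.5 * theta
    return np.concatenate([[np.cos(half)], np.sin(half) * angle_axis / theta])

q = angle_axis_to_quat_wxyz(np.array([0.0, 1.0, 0.0]))
print(np.round(q, 4))  # matches the example above up to rounding
```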

kornia.geometry.conversions.angle_axis_to_rotation_matrix(angle_axis)

Convert 3d vector of axis-angle rotation to 3x3 rotation matrix.

Parameters:

angle_axis (Tensor) – tensor of 3d vector of axis-angle rotations in radians with shape \((N, 3)\).

Return type:

Tensor

Returns:

tensor of rotation matrices of shape \((N, 3, 3)\).

Example

>>> input = tensor([[0., 0., 0.]])
>>> angle_axis_to_rotation_matrix(input)
tensor([[[1., 0., 0.],
         [0., 1., 0.],
         [0., 0., 1.]]])
>>> input = tensor([[1.5708, 0., 0.]])
>>> angle_axis_to_rotation_matrix(input)
tensor([[[ 1.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00, -3.6200e-06, -1.0000e+00],
         [ 0.0000e+00,  1.0000e+00, -3.6200e-06]]])
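
Rodrigues' formula builds the matrix from the unit axis n and the angle θ as \(R = I + \sin\theta\,K + (1 - \cos\theta)\,K^2\), where K is the cross-product matrix of n. A NumPy sketch with a hypothetical name:

```python
import numpy as np

def rodrigues(angle_axis: np.ndarray) -> np.ndarray:
    # angle_axis: (3,) -> (3, 3) rotation matrix
    theta = np.linalg.norm(angle_axis)
    if theta < 1e-12:
        return np.eye(3)
    kx, ky, kz = angle_axis / theta
    # K is the skew-symmetric cross-product matrix of the unit axis
    K = np.array([[0.0, -kz, ky],
                  [kz, 0.0, -kx],
                  [-ky, kx, 0.0]])
    return np.eye(3) + np.sin(theta) * K + (1.0 - np.cos(theta)) * (K @ K)

R = rodrigues(np.array([np.pi / 2, 0.0, 0.0]))  # quarter turn about the x axis
```

The tiny -3.6200e-06 entries in the example above come from 1.5708 being only an approximation of π/2, so cos θ is not exactly zero.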

Euler Angles

kornia.geometry.conversions.quaternion_from_euler(roll, pitch, yaw)

Convert Euler angles to quaternion coefficients.

Euler angles are assumed to be in radians in XYZ convention.

Parameters:
  • roll (Tensor) – the roll Euler angle.

  • pitch (Tensor) – the pitch Euler angle.

  • yaw (Tensor) – the yaw Euler angle.

Return type:

Tuple[Tensor, Tensor, Tensor, Tensor]

Returns:

A tuple with quaternion coefficients in order of wxyz.
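
The sketch below restates the widely used roll-pitch-yaw formula in plain Python, assuming the rotation is composed as \(R = R_z(\text{yaw})\,R_y(\text{pitch})\,R_x(\text{roll})\), i.e. rotations about the fixed X, then Y, then Z axes. That this matches the XYZ convention stated above is our assumption, and the helper name is hypothetical.

```python
import math

def euler_to_quat_wxyz(roll: float, pitch: float, yaw: float):
    # half-angle cosines and sines
    cr, sr = math.cos(roll / 2), math.sin(roll / 2)
    cp, sp = math.cos(pitch / 2), math.sin(pitch / 2)
    cy, sy = math.cos(yaw / 2), math.sin(yaw / 2)
    w = cr * cp * cy + sr * sp * sy
    x = sr * cp * cy - cr * sp * sy
    y = cr * sp * cy + sr * cp * sy
    z = cr * cp * sy - sr * sp * cy
    return w, x, y, z

identity = euler_to_quat_wxyz(0.0, 0.0, 0.0)  # (1.0, 0.0, 0.0, 0.0)
```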

kornia.geometry.conversions.euler_from_quaternion(w, x, y, z)

Convert quaternion coefficients to Euler angles.

Returned angles are in radians in XYZ convention.

Parameters:
  • w (Tensor) – quaternion \(q_w\) coefficient.

  • x (Tensor) – quaternion \(q_x\) coefficient.

  • y (Tensor) – quaternion \(q_y\) coefficient.

  • z (Tensor) – quaternion \(q_z\) coefficient.

Return type:

Tuple[Tensor, Tensor, Tensor]

Returns:

A tuple with the Euler angles roll, pitch, yaw.
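
The inverse direction, under the assumed convention \(R = R_z(\text{yaw})\,R_y(\text{pitch})\,R_x(\text{roll})\), uses atan2 for roll and yaw and asin for pitch. A plain-Python sketch with a hypothetical name; near pitch = ±π/2 (gimbal lock) roll and yaw are no longer unique.

```python
import math

def quat_wxyz_to_euler(w: float, x: float, y: float, z: float):
    # roll: rotation about x
    roll = math.atan2(2.0 * (w * x + y * z), 1.0 - 2.0 * (x * x + y * y))
    # pitch: rotation about y; clamp so rounding error cannot push asin outside [-1, 1]
    sin_pitch = max(-1.0, min(1.0, 2.0 * (w * y - z * x)))
    pitch = math.asin(sin_pitch)
    # yaw: rotation about z
    yaw = math.atan2(2.0 * (w * z + x * y), 1.0 - 2.0 * (y * y + z * z))
    return roll, pitch, yaw

# a rotation of 0.5 rad about the x axis: roll is approximately 0.5, pitch and yaw are 0
print(quat_wxyz_to_euler(math.cos(0.25), math.sin(0.25), 0.0, 0.0))
```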

Pose

kornia.geometry.conversions.Rt_to_matrix4x4(R, t)

Combines a 3x3 rotation matrix R and a 3x1 translation vector t into a 4x4 extrinsics matrix.

Parameters:
  • R (Tensor) – rotation matrix \((B, 3, 3)\).

  • t (Tensor) – translation vector \((B, 3, 1)\).

Return type:

Tensor

Returns:

the extrinsics \((B, 4, 4)\).

Example

>>> R, t = torch.eye(3)[None], torch.ones(3).reshape(1, 3, 1)
>>> Rt_to_matrix4x4(R, t)
tensor([[[1., 0., 0., 1.],
         [0., 1., 0., 1.],
         [0., 0., 1., 1.],
         [0., 0., 0., 1.]]])

kornia.geometry.conversions.matrix4x4_to_Rt(extrinsics)

Converts a 4x4 extrinsics matrix into a 3x3 rotation matrix R and a 3x1 translation vector t.

Parameters:

extrinsics (Tensor) – pose matrix \((B, 4, 4)\).

Returns:

  • R – rotation matrix \((B, 3, 3)\).

  • t – translation vector \((B, 3, 1)\).

Return type:

Tuple[Tensor, Tensor]

Example

>>> ext = torch.eye(4)[None]
>>> matrix4x4_to_Rt(ext)
(tensor([[[1., 0., 0.],
         [0., 1., 0.],
         [0., 0., 1.]]]), tensor([[[0.],
         [0.],
         [0.]]]))
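
Rt_to_matrix4x4 and matrix4x4_to_Rt amount to block assembly and slicing. A NumPy sketch with hypothetical names, assuming batched \((B, 3, 3)\) rotations and \((B, 3, 1)\) translations as documented above:

```python
import numpy as np

def pack_Rt(R: np.ndarray, t: np.ndarray) -> np.ndarray:
    # R: (B, 3, 3), t: (B, 3, 1) -> (B, 4, 4) = [[R, t], [0, 0, 0, 1]]
    B = R.shape[0]
    M = np.zeros((B, 4, 4), dtype=R.dtype)
    M[:, :3, :3] = R
    M[:, :3, 3:] = t
    M[:, 3, 3] = 1.0
    return M

def unpack_Rt(M: np.ndarray):
    # (B, 4, 4) -> rotation block (B, 3, 3) and translation column (B, 3, 1)
    return M[:, :3, :3], M[:, :3, 3:]

R, t = np.eye(3)[None], np.ones((1, 3, 1))
M = pack_Rt(R, t)      # same matrix as the Rt_to_matrix4x4 example
R2, t2 = unpack_Rt(M)  # recovers R and t
```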

kornia.geometry.conversions.worldtocam_to_camtoworld_Rt(R, t)

Converts worldtocam, i.e. the projection from the world to the camera coordinate system (used in Colmap), to camtoworld, i.e. the projection from the camera to the world coordinate system.

Parameters:
  • R (Tensor) – rotation matrix \((B, 3, 3)\).

  • t (Tensor) – translation vector \((B, 3, 1)\).

Returns:

  • Rinv – rotation matrix \((B, 3, 3)\).

  • tinv – translation vector \((B, 3, 1)\).

Return type:

Tuple[Tensor, Tensor]

Example

>>> R, t = torch.eye(3)[None], torch.ones(3).reshape(1, 3, 1)
>>> worldtocam_to_camtoworld_Rt(R, t)
(tensor([[[1., 0., 0.],
         [0., 1., 0.],
         [0., 0., 1.]]]), tensor([[[-1.],
         [-1.],
         [-1.]]]))

kornia.geometry.conversions.camtoworld_to_worldtocam_Rt(R, t)

Converts camtoworld, i.e. the projection from the camera to the world coordinate system, to worldtocam, i.e. the projection from the world to the camera coordinate system (used in Colmap). See https://colmap.github.io/format.html#output-format

Parameters:
  • R (Tensor) – rotation matrix \((B, 3, 3)\).

  • t (Tensor) – translation vector \((B, 3, 1)\).

Returns:

  • Rinv – rotation matrix \((B, 3, 3)\).

  • tinv – translation vector \((B, 3, 1)\).

Return type:

Tuple[Tensor, Tensor]

Example

>>> R, t = torch.eye(3)[None], torch.ones(3).reshape(1, 3, 1)
>>> camtoworld_to_worldtocam_Rt(R, t)
(tensor([[[1., 0., 0.],
         [0., 1., 0.],
         [0., 0., 1.]]]), tensor([[[-1.],
         [-1.],
         [-1.]]]))
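
Both directions invert a rigid transform: if \(x_{cam} = R\,x_{world} + t\), then \(x_{world} = R^\top x_{cam} - R^\top t\), because \(R^{-1} = R^\top\) for a rotation matrix. The same formula works either way, which is consistent with both examples above printing identical results. A batched NumPy sketch with a hypothetical name:

```python
import numpy as np

def invert_Rt(R: np.ndarray, t: np.ndarray):
    # R: (B, 3, 3), t: (B, 3, 1) -> (R^T, -R^T t), the inverse rigid transform
    R_inv = np.transpose(R, (0, 2, 1))
    t_inv = -R_inv @ t
    return R_inv, t_inv

R, t = np.eye(3)[None], np.ones((1, 3, 1))
R_inv, t_inv = invert_Rt(R, t)
print(t_inv[0, :, 0])  # [-1. -1. -1.]
```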

kornia.geometry.conversions.camtoworld_graphics_to_vision_4x4(extrinsics_graphics)

Converts graphics coordinate frame (e.g. OpenGL) to vision coordinate frame (e.g. OpenCV), i.e. flips the y and z axes. Graphics convention: [+x, +y, +z] == [right, up, backwards]. Vision convention: [+x, +y, +z] == [right, down, forwards].

Parameters:

extrinsics_graphics (Tensor) – pose matrix \((B, 4, 4)\) in graphics convention.

Returns:

pose matrix \((B, 4, 4)\) in vision convention.

Return type:

Tensor

Example

>>> ext = torch.eye(4)[None]
>>> camtoworld_graphics_to_vision_4x4(ext)
tensor([[[ 1.,  0.,  0.,  0.],
         [ 0., -1.,  0.,  0.],
         [ 0.,  0., -1.,  0.],
         [ 0.,  0.,  0.,  1.]]])

kornia.geometry.conversions.camtoworld_vision_to_graphics_4x4(extrinsics_vision)

Converts vision coordinate frame (e.g. OpenCV) to graphics coordinate frame (e.g. OpenGL), i.e. flips the y and z axes. Graphics convention: [+x, +y, +z] == [right, up, backwards]. Vision convention: [+x, +y, +z] == [right, down, forwards].

Parameters:

extrinsics_vision (Tensor) – pose matrix \((B, 4, 4)\) in vision convention.

Returns:

pose matrix \((B, 4, 4)\) in graphics convention.

Return type:

Tensor

Example

>>> ext = torch.eye(4)[None]
>>> camtoworld_vision_to_graphics_4x4(ext)
tensor([[[ 1.,  0.,  0.,  0.],
         [ 0., -1.,  0.,  0.],
         [ 0.,  0., -1.,  0.],
         [ 0.,  0.,  0.,  1.]]])
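
Both 4x4 conversions flip the camera's y and z axes, which can be written as a right-multiplication by diag(1, -1, -1, 1). That matrix is its own inverse, consistent with the two examples above printing the same result. A NumPy sketch under that assumption (hypothetical names); right-multiplication changes the camera axes while leaving the camera position, the last column, unchanged.

```python
import numpy as np

# diag(1, -1, -1, 1) flips the camera's y and z axes and is its own inverse
FLIP_YZ = np.diag([1.0, -1.0, -1.0, 1.0])

def flip_camera_yz(pose: np.ndarray) -> np.ndarray:
    # pose: (B, 4, 4) camera-to-world matrix; right-multiplying re-expresses the camera
    # axes, while the last column (camera position) is unchanged
    return pose @ FLIP_YZ

ext = np.eye(4)[None]
print(flip_camera_yz(ext)[0])  # diag(1, -1, -1, 1), as in both examples above
```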

kornia.geometry.conversions.camtoworld_graphics_to_vision_Rt(R, t)

Converts graphics coordinate frame (e.g. OpenGL) to vision coordinate frame (e.g. OpenCV), i.e. flips the y and z axes. Graphics convention: [+x, +y, +z] == [right, up, backwards]. Vision convention: [+x, +y, +z] == [right, down, forwards].

Parameters:
  • R (Tensor) – rotation matrix \((B, 3, 3)\).

  • t (Tensor) – translation vector \((B, 3, 1)\).

Returns:

  • R – rotation matrix \((B, 3, 3)\).

  • t – translation vector \((B, 3, 1)\).

Return type:

Tuple[Tensor, Tensor]

Example

>>> R, t = torch.eye(3)[None], torch.ones(3).reshape(1, 3, 1)
>>> camtoworld_graphics_to_vision_Rt(R, t)
(tensor([[[ 1.,  0.,  0.],
         [ 0., -1.,  0.],
         [ 0.,  0., -1.]]]), tensor([[[1.],
         [1.],
         [1.]]]))

kornia.geometry.conversions.camtoworld_vision_to_graphics_Rt(R, t)

Converts vision coordinate frame (e.g. OpenCV) to graphics coordinate frame (e.g. OpenGL), i.e. flips the y and z axes. Graphics convention: [+x, +y, +z] == [right, up, backwards]. Vision convention: [+x, +y, +z] == [right, down, forwards].

Parameters:
  • R (Tensor) – rotation matrix \((B, 3, 3)\).

  • t (Tensor) – translation vector \((B, 3, 1)\).

Returns:

  • R – rotation matrix \((B, 3, 3)\).

  • t – translation vector \((B, 3, 1)\).

Return type:

Tuple[Tensor, Tensor]

Example

>>> R, t = torch.eye(3)[None], torch.ones(3).reshape(1, 3, 1)
>>> camtoworld_vision_to_graphics_Rt(R, t)
(tensor([[[ 1.,  0.,  0.],
         [ 0., -1.,  0.],
         [ 0.,  0., -1.]]]), tensor([[[1.],
         [1.],
         [1.]]]))

kornia.geometry.conversions.ARKitQTVecs_to_ColmapQTVecs(qvec, tvec)

Converts output of Apple ARKit screen pose (in quaternion representation) to the camera-to-world transformation, expected by Colmap, also in quaternion representation.

Parameters:
  • qvec (Tensor) – ARKit rotation quaternion \((B, 4)\), [x, y, z, w] format.

  • tvec (Tensor) – translation vector \((B, 3, 1)\), [x, y, z]

Returns:

  • qvec – Colmap rotation quaternion \((B, 4)\), [w, x, y, z] format.

  • tvec – translation vector \((B, 3, 1)\), [x, y, z].

Return type:

Tuple[Tensor, Tensor]

Example

>>> q, t = tensor([0, 1, 0, 1.])[None], torch.ones(3).reshape(1, 3, 1)
>>> ARKitQTVecs_to_ColmapQTVecs(q, t)
(tensor([[0.7071, 0.0000, 0.7071, 0.0000]]), tensor([[[-1.0000],
         [-1.0000],
         [ 1.0000]]]))