kornia.geometry.calibration

Module with useful functionalities for camera calibration.

The pinhole model is an ideal projection model that does not consider lens distortion when projecting a 3D point \((X, Y, Z)\) onto the image plane. To model the distortion of a projected 2D pixel point \((u,v)\) with the linear pinhole model, we first need to estimate the normalized 2D point coordinates \((\bar{u}, \bar{v})\). For that, we can use the calibration matrix \(\mathbf{K}\) with the following expression

\[\begin{split}\begin{align} \begin{bmatrix} \bar{u}\\ \bar{v}\\ 1 \end{bmatrix} = \mathbf{K}^{-1} \begin{bmatrix} u \\ v \\ 1 \end{bmatrix} \enspace, \end{align}\end{split}\]

which, for a calibration matrix without skew, is equivalent to directly using the internal parameters, namely the focal lengths \(f_u, f_v\) and the principal point \((u_0, v_0)\), to estimate the normalized coordinates

\[\begin{split}\begin{equation} \bar{u} = (u - u_0)/f_u \enspace, \\ \bar{v} = (v - v_0)/f_v \enspace. \end{equation}\end{split}\]
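A minimal sketch in plain Python (no kornia or torch calls; all names are hypothetical) that checks the two normalizations against each other. It assumes a calibration matrix without skew, so that \(\mathbf{K}^{-1}\) has the simple closed form used here:

```python
def inv_k_zero_skew(fu, fv, u0, v0):
    """Closed-form inverse of K = [[fu, 0, u0], [0, fv, v0], [0, 0, 1]]."""
    return [[1.0 / fu, 0.0, -u0 / fu],
            [0.0, 1.0 / fv, -v0 / fv],
            [0.0, 0.0, 1.0]]


def matvec3(m, x):
    """Multiply a 3x3 matrix (nested lists) by a 3-vector."""
    return [sum(m[i][j] * x[j] for j in range(3)) for i in range(3)]


def normalize_direct(u, v, fu, fv, u0, v0):
    """Normalized coordinates computed directly from the intrinsic parameters."""
    return (u - u0) / fu, (v - v0) / fv


fu, fv, u0, v0 = 500.0, 400.0, 320.0, 240.0
u, v = 420.0, 140.0
via_matrix = matvec3(inv_k_zero_skew(fu, fv, u0, v0), [u, v, 1.0])
via_params = normalize_direct(u, v, fu, fv, u0, v0)
print(via_matrix[:2], via_params)  # both approx. (0.2, -0.25)
```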

The normalized distorted point \((\bar{u}_d, \bar{v}_d)\) is given by

\[\begin{split}\begin{align} \begin{bmatrix} \bar{u}_d\\ \bar{v}_d \end{bmatrix} = \dfrac{1+k_1r^2+k_2r^4+k_3r^6}{1+k_4r^2+k_5r^4+k_6r^6} \begin{bmatrix} \bar{u}\\ \bar{v} \end{bmatrix} + \begin{bmatrix} 2p_1\bar{u}\bar{v} + p_2(r^2 + 2\bar{u}^2) + s_1r^2 + s_2r^4\\ 2p_2\bar{u}\bar{v} + p_1(r^2 + 2\bar{v}^2) + s_3r^2 + s_4r^4 \end{bmatrix} \enspace, \end{align}\end{split}\]

where \(r^2 = \bar{u}^2 + \bar{v}^2\). With this model we consider radial \((k_1, k_2, k_3, k_4, k_5, k_6)\), tangential \((p_1, p_2)\), and thin prism \((s_1, s_2, s_3, s_4)\) distortion. To also model tilt distortion \((\tau_x, \tau_y)\), we need an additional step in which we compute the point \((\bar{u}'_d, \bar{v}'_d)\)

\[\begin{split}\begin{align} \begin{bmatrix} \bar{u}'_d\\ \bar{v}'_d\\ 1 \end{bmatrix} = \begin{bmatrix} \mathbf{R}_{33}(\tau_x, \tau_y) & 0 & -\mathbf{R}_{13}(\tau_x, \tau_y)\\ 0 & \mathbf{R}_{33}(\tau_x, \tau_y) & -\mathbf{R}_{23}(\tau_x, \tau_y)\\ 0 & 0 & 1 \end{bmatrix} \mathbf{R}(\tau_x, \tau_y) \begin{bmatrix} \bar{u}_d \\ \bar{v}_d \\ 1 \end{bmatrix} \enspace, \end{align}\end{split}\]

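where \(\mathbf{R}(\tau_x, \tau_y)\) is a 3D rotation matrix composed of a rotation about the \(X\) axis by \(\tau_x\) and a rotation about the \(Y\) axis by \(\tau_y\), and \(\mathbf{R}_{ij}(\tau_x, \tau_y)\) denotes the element in the \(i\)-th row and \(j\)-th column of \(\mathbf{R}(\tau_x, \tau_y)\). The equality holds in homogeneous coordinates, i.e., the right-hand side is divided by its third component to obtain \((\bar{u}'_d, \bar{v}'_d)\). The rotation matrix is defined as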

\[\begin{split}\begin{align} \mathbf{R}(\tau_x, \tau_y) = \begin{bmatrix} \cos \tau_y & 0 & -\sin \tau_y \\ 0 & 1 & 0 \\ \sin \tau_y & 0 & \cos \tau_y \end{bmatrix} \begin{bmatrix} 1 & 0 & 0 \\ 0 & \cos \tau_x & \sin \tau_x \\ 0 & -\sin \tau_x & \cos \tau_x \end{bmatrix} \enspace. \end{align}\end{split}\]
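The tilt step can be sketched in plain Python for a single point. This follows the equations above, including the division by the third homogeneous component; the helper names are hypothetical and this is not kornia's batched tensor implementation:

```python
import math


def matmul3(a, b):
    """Product of two 3x3 matrices given as nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3)] for i in range(3)]


def tilt_rotation(tau_x, tau_y):
    """R(tau_x, tau_y): rotation about Y (tau_y) times rotation about X (tau_x)."""
    cx, sx = math.cos(tau_x), math.sin(tau_x)
    cy, sy = math.cos(tau_y), math.sin(tau_y)
    rot_y = [[cy, 0.0, -sy], [0.0, 1.0, 0.0], [sy, 0.0, cy]]
    rot_x = [[1.0, 0.0, 0.0], [0.0, cx, sx], [0.0, -sx, cx]]
    return matmul3(rot_y, rot_x)


def apply_tilt(xd, yd, tau_x, tau_y):
    """Map a distorted normalized point (xd, yd) to (xd', yd') with the tilt model."""
    r = tilt_rotation(tau_x, tau_y)
    # 0-based indexing: r[2][2] is R_33, r[0][2] is R_13, r[1][2] is R_23.
    proj_z = [[r[2][2], 0.0, -r[0][2]],
              [0.0, r[2][2], -r[1][2]],
              [0.0, 0.0, 1.0]]
    m = matmul3(proj_z, r)
    hx, hy, hw = (m[i][0] * xd + m[i][1] * yd + m[i][2] for i in range(3))
    # The product is homogeneous: divide by its third component.
    return hx / hw, hy / hw


print(apply_tilt(0.1, -0.2, 0.0, 0.0))    # no tilt: the point is unchanged
print(apply_tilt(0.1, -0.2, 0.05, -0.03))
```

One property worth noting: the left factor is built so that the origin of the normalized plane (the principal point) maps to itself for any tilt angles.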

Finally, we map the point back to the original (unnormalized) pixel space using the intrinsic matrix

\[\begin{split}\begin{align} \begin{bmatrix} u_d\\ v_d\\ 1 \end{bmatrix} = \mathbf{K} \begin{bmatrix} \bar{u}'_d\\ \bar{v}'_d\\ 1 \end{bmatrix} \enspace, \end{align}\end{split}\]

which is equivalent to

\[\begin{split}\begin{equation} u_d = f_u \bar{u}'_d + u_0 \enspace, \\ v_d = f_v \bar{v}'_d + v_0 \enspace. \end{equation}\end{split}\]
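Putting the steps together for a single pixel, here is a pure-Python sketch of the forward model (normalize, apply radial, tangential and thin-prism distortion, map back to pixels). It omits tilt, assumes a zero-skew \(\mathbf{K}\), and `distort_pixel` is a hypothetical helper rather than kornia's batched tensor implementation:

```python
def distort_pixel(u, v, fu, fv, u0, v0, k=(0.0,) * 6, p=(0.0, 0.0), s=(0.0,) * 4):
    """Distort one pixel (u, v): normalize, apply the distortion model, map back.

    Tilt distortion is omitted (equivalent to tau_x = tau_y = 0).
    """
    k1, k2, k3, k4, k5, k6 = k
    p1, p2 = p
    s1, s2, s3, s4 = s
    # Pixel -> normalized coordinates (zero-skew K assumed).
    x, y = (u - u0) / fu, (v - v0) / fv
    r2 = x * x + y * y  # squared radius r^2
    r4, r6 = r2 * r2, r2 * r2 * r2
    radial = (1 + k1 * r2 + k2 * r4 + k3 * r6) / (1 + k4 * r2 + k5 * r4 + k6 * r6)
    xd = x * radial + 2 * p1 * x * y + p2 * (r2 + 2 * x * x) + s1 * r2 + s2 * r4
    yd = y * radial + 2 * p2 * x * y + p1 * (r2 + 2 * y * y) + s3 * r2 + s4 * r4
    # Normalized -> pixel coordinates.
    return fu * xd + u0, fv * yd + v0


print(distort_pixel(420.0, 340.0, 500.0, 500.0, 320.0, 240.0, k=(0.1, 0, 0, 0, 0, 0)))
# approx. (420.8, 340.8): r^2 = 0.08, so the radial factor is 1.008
```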

Undistortion

To compensate a set of 2D points for lens distortion, i.e., to estimate the undistorted coordinates of a given set of distorted points, we need to invert the distortion model described above. Because this model has no closed-form inverse in general, the inversion is done iteratively. To undistort an image, instead of estimating the undistorted location of each input pixel, we distort the coordinates of each pixel in the destination (undistorted) image to find its corresponding location in the input image, and then interpolate the input intensities at those locations.
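Inverting the model for points is commonly done by fixed-point iteration, the scheme OpenCV's `undistortPoints` is based on: hold the radial factor and tangential offset at the current estimate, solve the model for \((\bar{u}, \bar{v})\), and repeat. The sketch below does this for one normalized point using only \(k_1, k_2, p_1, p_2\). The `num_iters` parameter of `undistort_points` suggests kornia iterates in a similar way, but this is an illustration, not its code, and the helper names are hypothetical:

```python
def distort(x, y, k1, k2, p1, p2):
    """Forward model restricted to k1, k2, p1, p2 (normalized coordinates)."""
    r2 = x * x + y * y
    radial = 1 + k1 * r2 + k2 * r2 * r2
    return (x * radial + 2 * p1 * x * y + p2 * (r2 + 2 * x * x),
            y * radial + 2 * p2 * x * y + p1 * (r2 + 2 * y * y))


def undistort_fixed_point(xd, yd, k1, k2, p1, p2, num_iters=5):
    """Approximately invert `distort`, starting from the distorted point."""
    x, y = xd, yd
    for _ in range(num_iters):
        r2 = x * x + y * y
        radial = 1 + k1 * r2 + k2 * r2 * r2
        dx = 2 * p1 * x * y + p2 * (r2 + 2 * x * x)
        dy = 2 * p2 * x * y + p1 * (r2 + 2 * y * y)
        # Solve xd = x * radial + dx (and likewise for y) with radial, dx, dy frozen.
        x, y = (xd - dx) / radial, (yd - dy) / radial
    return x, y


coeffs = (-0.2, 0.05, 0.001, -0.002)  # k1, k2, p1, p2 (illustrative values)
x_true, y_true = 0.3, -0.2
xd, yd = distort(x_true, y_true, *coeffs)
for n in (1, 5, 20):
    x, y = undistort_fixed_point(xd, yd, *coeffs, num_iters=n)
    print(n, abs(x - x_true) + abs(y - y_true))  # error shrinks as n grows
```

For mild distortion the update is a contraction, so a handful of iterations is usually enough; strong distortion far from the image center may need more, or may not converge at all.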

kornia.geometry.calibration.undistort_image(image, K, dist)

Compensate an image for lens distortion.

Radial \((k_1, k_2, k_3, k_4, k_5, k_6)\), tangential \((p_1, p_2)\), thin prism \((s_1, s_2, s_3, s_4)\), and tilt \((\tau_x, \tau_y)\) distortion models are considered in this function.

Parameters:
  • image (Tensor) – Input image with shape \((*, C, H, W)\).

  • K (Tensor) – Intrinsic camera matrix with shape \((*, 3, 3)\).

  • dist (Tensor) – Distortion coefficients \((k_1,k_2,p_1,p_2[,k_3[,k_4,k_5,k_6[,s_1,s_2,s_3,s_4[,\tau_x,\tau_y]]]])\). This is a vector with 4, 5, 8, 12 or 14 elements with shape \((*, n)\).
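The bracketed layout of dist can be made explicit with a small helper (hypothetical, for illustration only). It assumes, as the nested brackets suggest, that coefficients omitted from a shorter vector act as zero, which leaves the corresponding terms of the model inactive:

```python
DIST_NAMES = ("k1", "k2", "p1", "p2", "k3", "k4", "k5", "k6",
              "s1", "s2", "s3", "s4", "tau_x", "tau_y")


def unpack_dist(dist):
    """Map a 4/5/8/12/14-element coefficient list to named coefficients.

    Coefficients missing from shorter vectors are filled with 0.0 (an assumption
    that matches the nested-bracket notation, not a documented kornia helper).
    """
    if len(dist) not in (4, 5, 8, 12, 14):
        raise ValueError(f"expected 4, 5, 8, 12 or 14 coefficients, got {len(dist)}")
    padded = list(dist) + [0.0] * (len(DIST_NAMES) - len(dist))
    return dict(zip(DIST_NAMES, padded))


coeffs = unpack_dist([0.1, -0.05, 0.001, 0.002, 0.01])
print(coeffs["k3"], coeffs["k4"], coeffs["tau_x"])  # 0.01 0.0 0.0
```

Note the ordering: \(k_3\) comes after the tangential coefficients, not next to \(k_1, k_2\).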

Return type:

Tensor

Returns:

Undistorted image with shape \((*, C, H, W)\).

Example

>>> import torch
>>> from kornia.geometry.calibration import undistort_image
>>> img = torch.rand(1, 3, 5, 5)
>>> K = torch.eye(3)[None]
>>> dist_coeff = torch.rand(1, 4)
>>> out = undistort_image(img, K, dist_coeff)
>>> out.shape
torch.Size([1, 3, 5, 5])
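The inverse-mapping strategy for images can be sketched in pure Python for a single-channel image stored as a list of rows. Two simplifications are assumptions of this sketch, not kornia's behavior: it samples the nearest source pixel instead of interpolating, and it uses only \(k_1, k_2, p_1, p_2\). `undistort_image_nn` is a hypothetical name:

```python
def distort(x, y, k1, k2, p1, p2):
    """Forward model restricted to k1, k2, p1, p2 (normalized coordinates)."""
    r2 = x * x + y * y
    radial = 1 + k1 * r2 + k2 * r2 * r2
    return (x * radial + 2 * p1 * x * y + p2 * (r2 + 2 * x * x),
            y * radial + 2 * p2 * x * y + p1 * (r2 + 2 * y * y))


def undistort_image_nn(image, fu, fv, u0, v0, k1, k2, p1, p2):
    """Undistort a single-channel image (list of rows) by inverse mapping.

    Each destination pixel is distorted to find where it comes from in the
    source; the nearest source pixel is copied, or 0.0 if it falls outside.
    """
    h, w = len(image), len(image[0])
    out = [[0.0] * w for _ in range(h)]
    for v in range(h):
        for u in range(w):
            x, y = (u - u0) / fu, (v - v0) / fv                # pixel -> normalized
            xd, yd = distort(x, y, k1, k2, p1, p2)             # apply the lens model
            us, vs = round(fu * xd + u0), round(fv * yd + v0)  # normalized -> pixel
            if 0 <= us < w and 0 <= vs < h:
                out[v][u] = image[vs][us]
    return out


img = [[float(10 * r + c) for c in range(5)] for r in range(5)]
unchanged = undistort_image_nn(img, 4.0, 4.0, 2.0, 2.0, 0.0, 0.0, 0.0, 0.0)
warped = undistort_image_nn(img, 4.0, 4.0, 2.0, 2.0, 0.5, 0.0, 0.0, 0.0)
print(unchanged == img, warped[2][2] == img[2][2])  # True True
```

Mapping from destination to source guarantees every output pixel receives a value; mapping forward from the source instead would leave holes wherever the warp stretches the image.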
kornia.geometry.calibration.undistort_points(points, K, dist, new_K=None, num_iters=5)

Compensate a set of 2D image points for lens distortion.

Radial \((k_1, k_2, k_3, k_4, k_5, k_6)\), tangential \((p_1, p_2)\), thin prism \((s_1, s_2, s_3, s_4)\), and tilt \((\tau_x, \tau_y)\) distortion models are considered in this function.

Parameters:
  • points (Tensor) – Input image points with shape \((*, N, 2)\).

  • K (Tensor) – Intrinsic camera matrix with shape \((*, 3, 3)\).

  • dist (Tensor) – Distortion coefficients \((k_1,k_2,p_1,p_2[,k_3[,k_4,k_5,k_6[,s_1,s_2,s_3,s_4[,\tau_x,\tau_y]]]])\). This is a vector with 4, 5, 8, 12 or 14 elements with shape \((*, n)\).

  • new_K (Optional[Tensor], optional) – Intrinsic camera matrix of the undistorted image, used to map the undistorted points back to pixel coordinates. By default, it is the same as K, but you may additionally scale and shift the result by using a different matrix. Shape: \((*, 3, 3)\). Default: None.

  • num_iters (int, optional) – Number of undistortion iterations. Default: 5.

Return type:

Tensor

Returns:

Undistorted 2D points with shape \((*, N, 2)\).

Example

>>> import torch
>>> from kornia.geometry.calibration import undistort_points
>>> _ = torch.manual_seed(0)
>>> x = torch.rand(1, 4, 2)
>>> K = torch.eye(3)[None]
>>> dist = torch.rand(1, 4)
>>> undistort_points(x, K, dist)
tensor([[[-0.1513, -0.1165],
         [ 0.0711,  0.1100],
         [-0.0697,  0.0228],
         [-0.1843, -0.1606]]])
kornia.geometry.calibration.distort_points(points, K, dist, new_K=None)

Distort a set of 2D points based on the lens distortion model.

Radial \((k_1, k_2, k_3, k_4, k_5, k_6)\), tangential \((p_1, p_2)\), thin prism \((s_1, s_2, s_3, s_4)\), and tilt \((\tau_x, \tau_y)\) distortion models are considered in this function.

Parameters:
  • points (Tensor) – Input image points with shape \((*, N, 2)\).

  • K (Tensor) – Intrinsic camera matrix with shape \((*, 3, 3)\).

  • dist (Tensor) – Distortion coefficients \((k_1,k_2,p_1,p_2[,k_3[,k_4,k_5,k_6[,s_1,s_2,s_3,s_4[,\tau_x,\tau_y]]]])\). This is a vector with 4, 5, 8, 12 or 14 elements with shape \((*, n)\).

  • new_K (Optional[Tensor], optional) – Intrinsic camera matrix of the distorted image. By default, it is the same as K but you may additionally scale and shift the result by using a different matrix. Shape: \((*, 3, 3)\). Default: None.

Return type:

Tensor

Returns:

Distorted 2D points with shape \((*, N, 2)\).

Example

>>> import torch
>>> from kornia.geometry.calibration import distort_points
>>> points = torch.rand(1, 1, 2)
>>> K = torch.eye(3)[None]
>>> dist_coeff = torch.rand(1, 4)
>>> points_dist = distort_points(points, K, dist_coeff)
kornia.geometry.calibration.tilt_projection(taux, tauy, return_inverse=False)

Estimate the tilt projection matrix or the inverse tilt projection matrix.

Parameters:
  • taux (Tensor) – Rotation angle in radians around the \(x\)-axis with shape \((*, 1)\).

  • tauy (Tensor) – Rotation angle in radians around the \(y\)-axis with shape \((*, 1)\).

  • return_inverse (bool, optional) – False to obtain the tilt projection matrix, True to obtain its inverse. Default: False

Return type:

Tensor

Returns:

Tilt projection matrix, or its inverse if return_inverse is True, with shape \((*, 3, 3)\).
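For reference, the forward and inverse tilt projection matrices can be sketched in plain Python from the tilt equations of the distortion model. Since \(\mathbf{R}\) is a rotation, the inverse of the product is \(\mathbf{R}^\top\) times the inverse of the left factor, which has a closed form as long as \(\mathbf{R}_{33} \neq 0\). The helper is hypothetical and only mirrors what return_inverse selects; it is not kornia's code:

```python
import math


def matmul3(a, b):
    """Product of two 3x3 matrices given as nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3)] for i in range(3)]


def transpose3(a):
    """Transpose of a 3x3 matrix."""
    return [list(row) for row in zip(*a)]


def tilt_matrices(tau_x, tau_y):
    """Return (tilt projection matrix, its inverse) for angles in radians."""
    cx, sx = math.cos(tau_x), math.sin(tau_x)
    cy, sy = math.cos(tau_y), math.sin(tau_y)
    r = matmul3([[cy, 0.0, -sy], [0.0, 1.0, 0.0], [sy, 0.0, cy]],
                [[1.0, 0.0, 0.0], [0.0, cx, sx], [0.0, -sx, cx]])
    r33, r13, r23 = r[2][2], r[0][2], r[1][2]
    proj_z = [[r33, 0.0, -r13], [0.0, r33, -r23], [0.0, 0.0, 1.0]]
    # Closed-form inverse of proj_z; requires r33 != 0.
    inv_proj_z = [[1.0 / r33, 0.0, r13 / r33], [0.0, 1.0 / r33, r23 / r33], [0.0, 0.0, 1.0]]
    # (proj_z @ r)^-1 = r^-1 @ proj_z^-1, and r^-1 = r^T for a rotation.
    return matmul3(proj_z, r), matmul3(transpose3(r), inv_proj_z)


tilt, inv_tilt = tilt_matrices(0.1, -0.05)
identity = matmul3(tilt, inv_tilt)
print([[round(val, 12) for val in row] for row in identity])
```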

Perspective-n-Point (PnP)

kornia.geometry.calibration.solve_pnp_dlt(world_points, img_points, intrinsics, weights=None, svd_eps=1e-4)

Attempt to solve the Perspective-n-Point (PnP) problem using the Direct Linear Transform (DLT).

Given a batch of \(B\) sets of \(N\) 3D points in world space, the \(N\) corresponding 2D points in image space, and the intrinsic matrices, this function estimates a batch of world-to-camera transformation matrices.

This implementation needs at least 6 points (i.e., \(N \geq 6\)) to provide a solution: the DLT estimates a \(3 \times 4\) matrix up to scale (11 degrees of freedom), and each 2D-3D correspondence provides two linear equations.

This function cannot be used if, for any element of the batch, all the 3D world points are collinear or coplanar. It attempts to detect these configurations and raises an AssertionError if one is found. Note that this check is sensitive to the value of the svd_eps parameter.

Another degenerate configuration occurs when the camera and the points lie on a twisted cubic; this function does not check for it.

Parameters:
  • world_points (Tensor) – A tensor with shape \((B, N, 3)\) representing the points in the world space.

  • img_points (Tensor) – A tensor with shape \((B, N, 2)\) representing the points in the image space.

  • intrinsics (Tensor) – A tensor with shape \((B, 3, 3)\) representing the intrinsic matrices.

  • weights (Optional[Tensor], optional) – This parameter is not used currently and is just a placeholder for API consistency. Default: None

  • svd_eps (float, optional) – A small float value to avoid numerical precision issues. Default: 1e-4

Return type:

Tensor

Returns:

A tensor with shape \((B, 3, 4)\) representing the estimated world to camera transformation matrices (also known as the extrinsic matrices).
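The linear system behind the DLT can be sketched in plain Python. Each correspondence between a world point and an image point already normalized by \(\mathbf{K}^{-1}\) contributes two rows to a homogeneous system \(A\mathbf{p} = 0\) in the 12 entries of the \(3 \times 4\) matrix. The sketch only builds the rows and checks that a known pose satisfies them; solving for \(\mathbf{p}\) (the singular vector of \(A\) with the smallest singular value) and recovering a proper rotation are omitted. This is the textbook formulation; we have not checked that kornia assembles its system identically, and all names are hypothetical:

```python
def dlt_rows(world_pt, norm_img_pt):
    """Two rows of A for one correspondence; unknowns are P = [R | t] row-major."""
    X = list(world_pt) + [1.0]  # homogeneous world point
    x, y = norm_img_pt          # image point already multiplied by K^-1
    zeros = [0.0] * 4
    row_x = [-c for c in X] + zeros + [x * c for c in X]  # x * (p3 . X) - p1 . X = 0
    row_y = zeros + [-c for c in X] + [y * c for c in X]  # y * (p3 . X) - p2 . X = 0
    return [row_x, row_y]


def project(P, world_pt):
    """Normalized image coordinates of a world point under P = [R | t]."""
    X = list(world_pt) + [1.0]
    xc, yc, zc = (sum(P[i][j] * X[j] for j in range(4)) for i in range(3))
    return xc / zc, yc / zc


# Illustrative pose: identity rotation, world origin 5 units in front of the camera.
P_true = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 5.0]]
world = [(5.0, -5.0, 0.0), (0.0, 0.0, 1.5), (2.5, 3.0, 6.0),
         (9.0, -2.0, 3.0), (-4.0, 5.0, 2.0), (-5.0, 5.0, 1.0)]
A = [row for pt in world for row in dlt_rows(pt, project(P_true, pt))]
p = [val for row in P_true for val in row]  # flatten row-major into 12 unknowns
residuals = [sum(a * b for a, b in zip(row, p)) for row in A]
print(len(A), max(abs(res) for res in residuals))  # 12 equations, residual ~ 0
```

With six points the system has 12 equations, enough to pin down the 11 degrees of freedom of \(\mathbf{p}\) up to scale.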

Example

>>> import torch
>>> import kornia
>>> world_points = torch.tensor([[
...     [ 5. , -5. ,  0. ], [ 0. ,  0. ,  1.5],
...     [ 2.5,  3. ,  6. ], [ 9. , -2. ,  3. ],
...     [-4. ,  5. ,  2. ], [-5. ,  5. ,  1. ],
... ]], dtype=torch.float64)
>>>
>>> img_points = torch.tensor([[
...     [1409.1504, -800.936 ], [ 407.0207, -182.1229],
...     [ 392.7021,  177.9428], [1016.838 ,   -2.9416],
...     [ -63.1116,  142.9204], [-219.3874,   99.666 ],
... ]], dtype=torch.float64)
>>>
>>> intrinsics = torch.tensor([[
...     [ 500.,    0.,  250.],
...     [   0.,  500.,  250.],
...     [   0.,    0.,    1.],
... ]], dtype=torch.float64)
>>>
>>> print(world_points.shape, img_points.shape, intrinsics.shape)
torch.Size([1, 6, 3]) torch.Size([1, 6, 2]) torch.Size([1, 3, 3])
>>>
>>> pred_world_to_cam = kornia.geometry.solve_pnp_dlt(world_points, img_points, intrinsics)
>>>
>>> print(pred_world_to_cam.shape)
torch.Size([1, 3, 4])
>>>
>>> pred_world_to_cam
tensor([[[ 0.9392, -0.3432, -0.0130,  1.6734],
         [ 0.3390,  0.9324, -0.1254, -4.3634],
         [ 0.0552,  0.1134,  0.9920,  3.7785]]], dtype=torch.float64)
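As a sanity check, the estimated pose from the example can be used to reproject the world points and compare them with img_points. The sketch below copies the printed values (pose rounded to four decimals) and, like the example's intrinsic matrix, assumes zero skew:

```python
import math

# Values copied from the example above (pose rounded to 4 decimals).
K = [[500.0, 0.0, 250.0], [0.0, 500.0, 250.0], [0.0, 0.0, 1.0]]
pose = [[0.9392, -0.3432, -0.0130, 1.6734],
        [0.3390, 0.9324, -0.1254, -4.3634],
        [0.0552, 0.1134, 0.9920, 3.7785]]
world = [(5.0, -5.0, 0.0), (0.0, 0.0, 1.5), (2.5, 3.0, 6.0),
         (9.0, -2.0, 3.0), (-4.0, 5.0, 2.0), (-5.0, 5.0, 1.0)]
pixels = [(1409.1504, -800.936), (407.0207, -182.1229), (392.7021, 177.9428),
          (1016.838, -2.9416), (-63.1116, 142.9204), (-219.3874, 99.666)]


def reproject(K, pose, X):
    """Pixel coordinates of world point X under K [R | t] (zero-skew K)."""
    Xh = list(X) + [1.0]
    cam = [sum(pose[i][j] * Xh[j] for j in range(4)) for i in range(3)]
    u = K[0][0] * cam[0] / cam[2] + K[0][2]
    v = K[1][1] * cam[1] / cam[2] + K[1][2]
    return u, v


errors = [math.hypot(u - pu, v - pv)
          for (u, v), (pu, pv) in zip((reproject(K, pose, X) for X in world), pixels)]
print(f"mean reprojection error: {sum(errors) / len(errors):.3f} px")
```

Because the pose is rounded to four decimals, a small nonzero error is expected even for an exact solution.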