The Kabsch Algorithm
When working with molecular dynamics or computer vision, we often need to compare structures or point clouds. The challenge is that identical structures can appear different due to translation and rotation. The Kabsch algorithm solves this by finding the optimal rigid body transformation to align two sets of paired points.
This algorithm computes the rotation matrix that best aligns two point sets, which is essential for calculating metrics like root-mean-square deviation (RMSD). I’ve used it primarily in molecular dynamics, but it’s also valuable in computer vision, graphics, and particle simulations.
In molecular applications, a molecule’s translation and rotation don’t affect its chemical properties. For instance, when training generative models for molecular conformations, we shouldn’t penalize the model for producing the same structure in a different orientation. The Kabsch algorithm helps by aligning structures before comparison, which makes the comparison, and therefore the training loss, invariant to these transformations.
Recent examples include the Direct Molecular Conformation Generation work, which uses Kabsch alignment in its loss function alongside permutation-invariant terms.
While implementations exist in various libraries, implementing the algorithm myself provided deeper understanding. More importantly, I needed versions that work with automatic differentiation frameworks (PyTorch, TensorFlow, JAX) for machine learning applications.
Below, we’ll cover the math behind the Kabsch algorithm and its implementation in NumPy, PyTorch, TensorFlow, and JAX, demonstrating both single-pair and batched computations for machine learning applications.
The Math
Let’s say we have two sets of paired points, $P=[\mathbf{p}_i] \in \mathbb{R}^{N \times D}$ and $Q=[\mathbf{q}_i] \in \mathbb{R}^{N \times D}$, for $i = 1, \dots, N$ (where $D$ is the dimensionality and $N$ is the number of points). We want to find a translation vector $\mathbf{t}$ and rotation matrix $R$ to transform $P$ to align with $Q$.
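Following the convention of the implementations below, $R$ rotates $P$ about its own centroid $\bar{\mathbf{p}}$ (defined below) and $\mathbf{t}$ then translates it. The optimization problem is:
$$ \mathbf{t}^\ast, R^\ast = \arg\min_{\mathbf{t},\, R} \mathcal{L}(\mathbf{t}, R), \qquad \mathcal{L}(\mathbf{t}, R) = \frac{1}{2} \sum_{i=1}^N \left\| \mathbf{q}_i - \left[ R(\mathbf{p}_i - \bar{\mathbf{p}}) + \bar{\mathbf{p}} + \mathbf{t} \right] \right\|^2 $$
where $\mathbf{t}^\ast \in \mathbb{R}^D$ is the optimal translation and $R^\ast \in SO(D)$ is the optimal rotation, i.e., a $D \times D$ orthogonal matrix with $\det R^\ast = 1$.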
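Often we use a weighted version with positive weights $w_i$ (e.g., atomic masses in molecular dynamics), in which case $\bar{\mathbf{p}}$ and $\bar{\mathbf{q}}$ denote weighted centroids:
$$ \mathcal{L}(\mathbf{t}, R) = \frac{1}{2} \sum_{i=1}^N w_i \left\| \mathbf{q}_i - \left[ R(\mathbf{p}_i - \bar{\mathbf{p}}) + \bar{\mathbf{p}} + \mathbf{t} \right] \right\|^2 $$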
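For the weighted objective, the steps below carry over with weighted centroids and a weighted cross-covariance matrix. This is the standard generalization, stated here without proof and assuming all $w_i > 0$, with $\mathbf{p}_i^\prime = \mathbf{p}_i - \bar{\mathbf{p}}$ and $\mathbf{q}_i^\prime = \mathbf{q}_i - \bar{\mathbf{q}}$ centered on these weighted centroids:

$$ \bar{\mathbf{p}} = \frac{\sum_{i=1}^N w_i \mathbf{p}_i}{\sum_{i=1}^N w_i}, \qquad \bar{\mathbf{q}} = \frac{\sum_{i=1}^N w_i \mathbf{q}_i}{\sum_{i=1}^N w_i}, \qquad C = \sum_{i=1}^N w_i \, \mathbf{p}_i^\prime \mathbf{q}_i^{\prime T} $$

Setting every $w_i = 1$ recovers the unweighted formulas.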
The Translation Vector
The translation vector is straightforward to find. It is simply the difference between the centroids:
$$ \mathbf{t} = \bar{\mathbf{q}} - \bar{\mathbf{p}} $$
where $\bar{\mathbf{p}} = \frac{1}{N} \sum_{i=1}^N \mathbf{p}_i$ and $\bar{\mathbf{q}} = \frac{1}{N} \sum_{i=1}^N \mathbf{q}_i$.
This shifts the centroid of $P$ to match the centroid of $Q$. After translation, we have:
$$ \mathbf{p}_i + \mathbf{t} = \bar{\mathbf{q}} + (\mathbf{p}_i - \bar{\mathbf{p}}) $$
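A quick numerical check of this step, as a sketch: assuming the transform is applied as a rotation of $P$ about its own centroid followed by the translation $\mathbf{t}$ (the convention the implementations below use), the loss is smallest at $\mathbf{t} = \bar{\mathbf{q}} - \bar{\mathbf{p}}$ for any fixed rotation, and shifting $\mathbf{t}$ by $\boldsymbol{\delta}$ adds exactly $\frac{N}{2}\|\boldsymbol{\delta}\|^2$. The helper `aligned_loss` is hypothetical, written only for this illustration.

```python
import numpy as np


def aligned_loss(P, Q, R, t):
    """Hypothetical helper: 0.5 * sum ||q_i - [R (p_i - p_bar) + p_bar + t]||^2."""
    p_bar = P.mean(axis=0)
    aligned = (P - p_bar) @ R.T + p_bar + t  # rotate about P's centroid, then translate
    return 0.5 * np.sum((Q - aligned) ** 2)


rng = np.random.default_rng(0)
N = 50
P = rng.normal(size=(N, 3))
Q = rng.normal(size=(N, 3)) + 5.0

# An arbitrary fixed rotation about the z-axis
theta = 0.7
R = np.array([[np.cos(theta), -np.sin(theta), 0.0],
              [np.sin(theta), np.cos(theta), 0.0],
              [0.0, 0.0, 1.0]])

t_star = Q.mean(axis=0) - P.mean(axis=0)
base = aligned_loss(P, Q, R, t_star)

# Each perturbation raises the loss by exactly N/2 * ||delta||^2
for delta in 0.1 * rng.normal(size=(5, 3)):
    increase = aligned_loss(P, Q, R, t_star + delta) - base
    print(np.isclose(increase, 0.5 * N * delta @ delta))  # True
```

The identity holds because the residuals at $\mathbf{t}^\ast$ sum to zero, so the cross term vanishes and only the $\frac{N}{2}\|\boldsymbol{\delta}\|^2$ term remains.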
The Rotation Matrix
Finding the optimal rotation requires a few more steps.
First, center both point sets by subtracting their centroids:
$$ \mathbf{p}_i^\prime = \mathbf{p}_i - \bar{\mathbf{p}} \quad \text{and} \quad \mathbf{q}_i^\prime = \mathbf{q}_i - \bar{\mathbf{q}} $$
Next, compute the cross-covariance matrix between the centered sets:
$$ C = P^{\prime T} Q^\prime = \sum_{i=1}^N \mathbf{p}_i^{\prime} \mathbf{q}_i^{\prime T} \in \mathbb{R}^{D \times D} $$
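As a sanity check, the matrix product $P^{\prime T} Q^\prime$ (points stored as rows) equals the sum of outer products $\sum_i \mathbf{p}_i^\prime \mathbf{q}_i^{\prime T}$. A NumPy sketch with random data, also showing the same contraction written with `np.einsum`:

```python
import numpy as np

rng = np.random.default_rng(1)
P = rng.normal(size=(20, 3))
Q = rng.normal(size=(20, 3))

# Centered point sets P' and Q' (one point per row)
p = P - P.mean(axis=0)
q = Q - Q.mean(axis=0)

C_matrix = p.T @ q  # P'^T Q', shape (D, D)
C_outer = sum(np.outer(p_i, q_i) for p_i, q_i in zip(p, q))  # sum_i p_i' q_i'^T
C_einsum = np.einsum("ni,nj->ij", p, q)  # same contraction written with einsum

print(np.allclose(C_matrix, C_outer), np.allclose(C_matrix, C_einsum))  # True True
```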
Compute the singular value decomposition (SVD) of $C$:
$$ C = U \Sigma V^T $$
Check for improper rotations (reflections) and correct if necessary:
$$ d = \text{sign}(\det(V U^T)) $$
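If $d = -1$, the unconstrained optimum $V U^T$ is a reflection rather than a rotation. We fix this by flipping the sign of the last column of $V$, which belongs to the smallest singular value (the SVD routines used below return singular values in descending order). Equivalently, let $B = \text{diag}(1, \dots, 1, d)$ with $D - 1$ ones; for $D = 3$, $B = \text{diag}(1, 1, d)$.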
Finally, compute the optimal rotation matrix:
$$ R^\ast = V B U^T $$
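To see why the sign correction matters, the sketch below aligns a point set with its mirror image, using a hypothetical NumPy helper `kabsch_rotation` that implements $R^\ast = V B U^T$. Without the correction, the SVD returns the reflection $\text{diag}(1, 1, -1)$, which fits perfectly but has determinant $-1$; with $B$, the result is a proper rotation.

```python
import numpy as np


def kabsch_rotation(p, q, fix_reflection=True):
    """Hypothetical helper: R = V B U^T for centered (N, D) arrays p and q."""
    U, S, Vt = np.linalg.svd(p.T @ q)  # C = U diag(S) V^T
    V = Vt.T
    B = np.eye(p.shape[1])
    if fix_reflection:
        B[-1, -1] = np.sign(np.linalg.det(V @ U.T))  # d = +1 or -1
    return V @ B @ U.T


rng = np.random.default_rng(2)
P = rng.normal(size=(30, 3))
Q = P * np.array([1.0, 1.0, -1.0])  # mirror image of P through the xy-plane
p = P - P.mean(axis=0)
q = Q - Q.mean(axis=0)

R_naive = kabsch_rotation(p, q, fix_reflection=False)
R_fixed = kabsch_rotation(p, q)

print(np.round(np.linalg.det(R_naive), 6))  # -1.0: the unconstrained optimum is a reflection
print(np.round(np.linalg.det(R_fixed), 6))  # 1.0: a proper rotation
```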
Summary
The Kabsch algorithm:
- Compute centroids of $P$ and $Q$
- Center both point sets by subtracting centroids
- Compute cross-covariance matrix $C = P^{\prime T} Q^\prime$
- Compute SVD: $C = U \Sigma V^T$
- Compute $d = \text{sign}(\det(V U^T))$ and $B = \text{diag}(1, 1, d)$
- Optimal rotation: $R^\ast = V B U^T$
- Optimal translation: $\mathbf{t}^\ast = \bar{\mathbf{q}} - \bar{\mathbf{p}}$
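Applying the transform as $R^\ast(\mathbf{p}_i - \bar{\mathbf{p}}) + \bar{\mathbf{p}} + \mathbf{t}^\ast = R^\ast \mathbf{p}_i^\prime + \bar{\mathbf{q}}$, the RMSD between the aligned point sets is:
$$ \text{RMSD} = \sqrt{\frac{1}{N} \sum_{i=1}^N \left\| \mathbf{q}_i - \left[ R^\ast(\mathbf{p}_i - \bar{\mathbf{p}}) + \bar{\mathbf{p}} + \mathbf{t}^\ast \right] \right\|^2} = \sqrt{\frac{1}{N} \sum_{i=1}^N \left\| \mathbf{q}_i^\prime - R^\ast \mathbf{p}_i^\prime \right\|^2} $$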
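The point of aligning before comparing is that the RMSD no longer depends on where $Q$ sits or how it is oriented. A NumPy sketch (the helper `kabsch_rmsd` is hypothetical and mirrors the steps above) compares a noisy copy of $P$ before and after applying a random rigid motion to it:

```python
import numpy as np


def kabsch_rmsd(P, Q):
    """Hypothetical helper: RMSD after optimal rotation about the centroid and translation."""
    p = P - P.mean(axis=0)
    q = Q - Q.mean(axis=0)
    U, S, Vt = np.linalg.svd(p.T @ q)
    B = np.eye(P.shape[1])
    B[-1, -1] = np.sign(np.linalg.det(Vt.T @ U.T))  # reflection correction
    R = Vt.T @ B @ U.T
    return np.sqrt(np.sum((p @ R.T - q) ** 2) / P.shape[0])


rng = np.random.default_rng(3)
P = rng.normal(size=(40, 3))
Q = P + 0.05 * rng.normal(size=(40, 3))  # noisy copy of P

# Random proper rotation: orthogonalize a random matrix, then force det = +1
A, _ = np.linalg.qr(rng.normal(size=(3, 3)))
if np.linalg.det(A) < 0:
    A[:, 0] *= -1.0
Q_moved = Q @ A.T + np.array([3.0, -2.0, 7.0])

print(kabsch_rmsd(P, Q), kabsch_rmsd(P, Q_moved))  # equal up to round-off
```

Without alignment, the plain RMSD between `P` and `Q_moved` would be dominated by the rigid motion rather than the noise.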
Implementation
Let’s implement the algorithm in different frameworks.
NumPy
```python
import numpy as np


def kabsch_numpy(P, Q):
    """
    Computes the optimal rotation and translation to align two sets of points (P -> Q),
    and their RMSD.

    :param P: A Nx3 matrix of points
    :param Q: A Nx3 matrix of points
    :return: A tuple containing the optimal rotation matrix, the optimal
        translation vector, and the RMSD.
    """
    assert P.shape == Q.shape, "Matrix dimensions must match"
    # Compute centroids
    centroid_P = np.mean(P, axis=0)
    centroid_Q = np.mean(Q, axis=0)
    # Optimal translation
    t = centroid_Q - centroid_P
    # Center the points
    p = P - centroid_P
    q = Q - centroid_Q
    # Compute the covariance matrix
    H = np.dot(p.T, q)
    # SVD
    U, S, Vt = np.linalg.svd(H)
    # Validate right-handed coordinate system: if V U^T is a reflection,
    # flip the last row of Vt (i.e., the last column of V)
    if np.linalg.det(np.dot(Vt.T, U.T)) < 0.0:
        Vt[-1, :] *= -1.0
    # Optimal rotation
    R = np.dot(Vt.T, U.T)
    # RMSD
    rmsd = np.sqrt(np.sum(np.square(np.dot(p, R.T) - q)) / P.shape[0])
    return R, t, rmsd
```
Here’s a quick test to verify correctness:
```python
def test_numpy():
    np.random.seed(12345)
    P = np.random.randn(100, 3)
    alpha = np.random.rand() * 2 * np.pi
    R = np.array([[np.cos(alpha), -np.sin(alpha), 0],
                  [np.sin(alpha), np.cos(alpha), 0],
                  [0, 0, 1]])
    t = np.random.randn(3) * 10
    Q = np.dot(P, R.T) + t
    R_opt, t_opt, rmsd = kabsch_numpy(P, Q)
    print('RMSD: {}'.format(rmsd))
    print('R:\n{}'.format(R))
    print('R_opt:\n{}'.format(R_opt))
    print('t:\n{}'.format(t))
    print('t_opt:\n{}'.format(t_opt))
    l2_t = np.linalg.norm(t - t_opt)
    l2_R = np.linalg.norm(R - R_opt)
    print('l2_t: {}'.format(l2_t))
    print('l2_R: {}'.format(l2_R))


test_numpy()
```
Running this test prints:
```text
RMSD: 3.176703044042434e-15
R:
[[-0.8475392 -0.5307328 0. ]
[ 0.5307328 -0.8475392 0. ]
[ 0. 0. 1. ]]
R_opt:
[[-8.47539198e-01 -5.30732803e-01 -8.26804079e-17]
[ 5.30732803e-01 -8.47539198e-01 -8.32570352e-19]
[ 8.20168388e-17 9.60990451e-18 1.00000000e+00]]
t:
[ 5.99726796 1.50078468 -3.34633977]
t_opt:
[ 6.10180512 1.50760274 -3.34633977]
l2_t: 0.10475926000353594
l2_R: 7.538724554724993e-16
```
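The RMSD and `l2_R` are at the level of floating-point round-off, so the rotation is recovered exactly. The translation error `l2_t` is not a precision effect, however. The test builds `Q` as $R\mathbf{p}_i + \mathbf{t}$, a rotation about the origin, whereas `kabsch_numpy` returns the centroid difference $\bar{\mathbf{q}} - \bar{\mathbf{p}}$, the translation that goes with a rotation about the centroid of $P$. The two conventions are related by $\mathbf{t} = \bar{\mathbf{q}} - R^\ast \bar{\mathbf{p}}$, which is why only the $z$-component, along the rotation axis, agrees.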
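If a downstream tool expects the transform in the form $R\mathbf{p}_i + \mathbf{t}$ (for example, to build a 4×4 homogeneous matrix), the centroid-based translation can be converted with $\mathbf{t} = \bar{\mathbf{q}} - R^\ast \bar{\mathbf{p}}$. A minimal sketch with a hypothetical helper `origin_translation`, checked against a known rigid motion:

```python
import numpy as np


def origin_translation(P, Q, R):
    """Hypothetical helper: translation t such that q_i ~ R p_i + t."""
    return Q.mean(axis=0) - P.mean(axis=0) @ R.T  # q_bar - R p_bar (row-vector form)


rng = np.random.default_rng(12345)
P = rng.normal(size=(100, 3))
alpha = 1.3
R = np.array([[np.cos(alpha), -np.sin(alpha), 0.0],
              [np.sin(alpha), np.cos(alpha), 0.0],
              [0.0, 0.0, 1.0]])
t = rng.normal(size=3) * 10
Q = P @ R.T + t

t_origin = origin_translation(P, Q, R)
t_centroid = Q.mean(axis=0) - P.mean(axis=0)  # what kabsch_numpy returns as t
print(np.linalg.norm(t - t_origin))    # close to zero
print(np.linalg.norm(t - t_centroid))  # nonzero unless R fixes the centroid of P
```

With the rotation returned by `kabsch_numpy` in the test above, `origin_translation(P, Q, R_opt)` should likewise recover `t` up to round-off, since `R_opt` matches `R`.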
For batch processing:
```python
def kabsch_numpy_batched(P, Q):
    """
    Computes the optimal rotation and translation to align two sets of points (P -> Q),
    and their RMSD.

    :param P: A BxNx3 matrix of points
    :param Q: A BxNx3 matrix of points
    :return: A tuple containing the optimal rotation matrix, the optimal
        translation vector, and the RMSD.
    """
    assert P.shape == Q.shape, "Matrix dimensions must match"
    # Compute centroids
    centroid_P = np.mean(P, axis=1, keepdims=True)  # Bx1x3
    centroid_Q = np.mean(Q, axis=1, keepdims=True)  # Bx1x3
    # Optimal translation
    t = centroid_Q - centroid_P  # Bx1x3
    t = t.squeeze(1)  # Bx3
    # Center the points
    p = P - centroid_P  # BxNx3
    q = Q - centroid_Q  # BxNx3
    # Compute the covariance matrix
    H = np.matmul(p.transpose(0, 2, 1), q)  # Bx3x3
    # SVD
    U, S, Vt = np.linalg.svd(H)  # Bx3x3
    # Validate right-handed coordinate system
    d = np.linalg.det(np.matmul(Vt.transpose(0, 2, 1), U.transpose(0, 2, 1)))  # B
    flip = d < 0.0
    if flip.any():
        Vt[flip, -1, :] *= -1.0
    # Optimal rotation
    R = np.matmul(Vt.transpose(0, 2, 1), U.transpose(0, 2, 1))  # Bx3x3
    # RMSD
    rmsd = np.sqrt(np.sum(np.square(np.matmul(p, R.transpose(0, 2, 1)) - q), axis=(1, 2)) / P.shape[1])
    return R, t, rmsd
```
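The batched version above branches on `flip.any()` and edits `Vt` in place. A branch-free alternative, sketched here in NumPy with a hypothetical helper `batched_rotation`, builds the correction matrix $B = \text{diag}(1, 1, d)$ for every batch element instead. This pattern avoids data-dependent control flow and in-place updates, which matters when porting to frameworks with immutable arrays or traced control flow.

```python
import numpy as np


def batched_rotation(p, q):
    """Hypothetical helper: Kabsch rotations for centered (B, N, D) batches p and q."""
    H = np.einsum("bni,bnj->bij", p, q)  # per-batch cross-covariance, (B, D, D)
    U, S, Vt = np.linalg.svd(H)
    V, Ut = Vt.transpose(0, 2, 1), U.transpose(0, 2, 1)
    d = np.sign(np.linalg.det(V @ Ut))  # (B,), +1 or -1
    B_mat = np.tile(np.eye(H.shape[-1]), (H.shape[0], 1, 1))  # (B, D, D) identities
    B_mat[:, -1, -1] = d  # diag(1, ..., 1, d) per batch element
    return V @ B_mat @ Ut


rng = np.random.default_rng(4)
P = rng.normal(size=(8, 25, 3))
P = P - P.mean(axis=1, keepdims=True)  # center each point set

c, s = np.cos(0.5), np.sin(0.5)
R_true = np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
# First half: exact rotations of P. Second half: mirror images (reflections).
Q = np.concatenate([P[:4] @ R_true.T, P[4:] * np.array([1.0, 1.0, -1.0])])

R = batched_rotation(P, Q)
print(np.allclose(R[:4], R_true))     # True: exact rotations are recovered
print(np.round(np.linalg.det(R), 6))  # all 1.0: no reflections returned
```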
PyTorch
📝 Important Update (August 10, 2025)
Bug Fix Notice: A critical indexing error in the PyTorch implementation has been corrected. The original code incorrectly used `Vt[:, -1] *= -1.0` and `Vt[flip, -1] *= -1.0` when it should have been `Vt[-1, :] *= -1.0` and `Vt[flip, -1, :] *= -1.0`, respectively. This fix ensures the algorithm flips the last row of `Vt` (the right singular vector belonging to the smallest singular value) rather than the last column when correcting for improper rotations. Thanks to Jakub for reporting this issue!
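NumPy's `svd` also returns $V^T$ (as `Vt`), so the row/column distinction in this notice can be checked there. A sketch on a mirrored point set, where $\det(VU^T) = -1$ and the correction is required: flipping the last row of `Vt` is the same as flipping the last column of $V$, while flipping the last column of `Vt` still yields a proper rotation, just a worse-fitting one.

```python
import numpy as np

rng = np.random.default_rng(5)
P = rng.normal(size=(25, 3))
p = P - P.mean(axis=0)
q = p * np.array([1.0, 1.0, -1.0])  # mirror image, so det(V U^T) = -1

U, S, Vt = np.linalg.svd(p.T @ q)
print(np.linalg.det(Vt.T @ U.T) < 0)  # True: correction needed

V_fixed = Vt.T.copy()
V_fixed[:, -1] *= -1.0  # the math: flip the last column of V
Vt_row = Vt.copy()
Vt_row[-1, :] *= -1.0   # corrected code: flip the last row of Vt
Vt_col = Vt.copy()
Vt_col[:, -1] *= -1.0   # original bug: flip the last column of Vt

print(np.allclose(Vt_row.T, V_fixed))  # True: the same operation


def rmsd(R):
    return np.sqrt(np.mean(np.sum((p @ R.T - q) ** 2, axis=1)))


print(rmsd(Vt_row.T @ U.T), rmsd(Vt_col.T @ U.T))  # the buggy rotation fits worse
```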
The PyTorch implementation follows the same structure:
```python
import torch


def kabsch_torch(P, Q):
    """
    Computes the optimal rotation and translation to align two sets of points (P -> Q),
    and their RMSD.

    :param P: A Nx3 matrix of points
    :param Q: A Nx3 matrix of points
    :return: A tuple containing the optimal rotation matrix, the optimal
        translation vector, and the RMSD.
    """
    assert P.shape == Q.shape, "Matrix dimensions must match"
    # Compute centroids
    centroid_P = torch.mean(P, dim=0)
    centroid_Q = torch.mean(Q, dim=0)
    # Optimal translation
    t = centroid_Q - centroid_P
    # Center the points
    p = P - centroid_P
    q = Q - centroid_Q
    # Compute the covariance matrix
    H = torch.matmul(p.transpose(0, 1), q)
    # SVD
    U, S, Vt = torch.linalg.svd(H)
    # Validate right-handed coordinate system
    if torch.det(torch.matmul(Vt.transpose(0, 1), U.transpose(0, 1))) < 0.0:
        # Clone first: autograd saves the SVD outputs for the backward pass,
        # so editing Vt in place would make backward() fail
        Vt = Vt.clone()
        Vt[-1, :] *= -1.0
    # Optimal rotation
    R = torch.matmul(Vt.transpose(0, 1), U.transpose(0, 1))
    # RMSD
    rmsd = torch.sqrt(torch.sum(torch.square(torch.matmul(p, R.transpose(0, 1)) - q)) / P.shape[0])
    return R, t, rmsd
```
And our batched version:
📝 Important Update (August 10, 2025)
Bug Fix Notice: The batched PyTorch implementation was updated in the same way: `Vt[flip, -1] *= -1.0` now reads `Vt[flip, -1, :] *= -1.0`, which makes explicit that the last row is flipped for the selected batches. (For a 3D tensor the two expressions index the same elements, since trailing dimensions default to a full slice.)
```python
def kabsch_torch_batched(P, Q):
    """
    Computes the optimal rotation and translation to align two sets of points (P -> Q),
    and their RMSD, in a batched manner.

    :param P: A BxNx3 matrix of points
    :param Q: A BxNx3 matrix of points
    :return: A tuple containing the optimal rotation matrix, the optimal
        translation vector, and the RMSD.
    """
    assert P.shape == Q.shape, "Matrix dimensions must match"
    # Compute centroids
    centroid_P = torch.mean(P, dim=1, keepdim=True)  # Bx1x3
    centroid_Q = torch.mean(Q, dim=1, keepdim=True)  # Bx1x3
    # Optimal translation
    t = centroid_Q - centroid_P  # Bx1x3
    t = t.squeeze(1)  # Bx3
    # Center the points
    p = P - centroid_P  # BxNx3
    q = Q - centroid_Q  # BxNx3
    # Compute the covariance matrix
    H = torch.matmul(p.transpose(1, 2), q)  # Bx3x3
    # SVD
    U, S, Vt = torch.linalg.svd(H)  # Bx3x3
    # Validate right-handed coordinate system
    d = torch.det(torch.matmul(Vt.transpose(1, 2), U.transpose(1, 2)))  # B
    flip = d < 0.0
    if flip.any().item():
        # Clone first: autograd saves the SVD outputs for the backward pass,
        # so editing Vt in place would make backward() fail
        Vt = Vt.clone()
        Vt[flip, -1, :] *= -1.0
    # Optimal rotation
    R = torch.matmul(Vt.transpose(1, 2), U.transpose(1, 2))
    # RMSD
    rmsd = torch.sqrt(torch.sum(torch.square(torch.matmul(p, R.transpose(1, 2)) - q), dim=(1, 2)) / P.shape[1])
    return R, t, rmsd
```
TensorFlow
The TensorFlow implementation requires attention to different SVD output ordering (S is returned first, and V is not transposed):
```python
import tensorflow as tf


def kabsch_tensorflow(P, Q):
    """
    Computes the optimal rotation and translation to align two sets of points (P -> Q),
    and their RMSD.

    :param P: A Nx3 matrix of points
    :param Q: A Nx3 matrix of points
    :return: A tuple containing the optimal rotation matrix, the optimal
        translation vector, and the RMSD.
    """
    assert P.shape == Q.shape, "Matrix dimensions must match"
    # Compute centroids
    centroid_P = tf.reduce_mean(P, axis=0)
    centroid_Q = tf.reduce_mean(Q, axis=0)
    # Optimal translation
    t = centroid_Q - centroid_P
    # Center the points
    p = P - centroid_P
    q = Q - centroid_Q
    # Compute the covariance matrix
    H = tf.matmul(tf.transpose(p), q)
    # SVD (TensorFlow returns s, u, v with H = u @ diag(s) @ v^T)
    S, U, V = tf.linalg.svd(H)
    # Validate right-handed coordinate system. V is not transposed here, so the
    # fix negates its last column; TF tensors are immutable, so rebuild V
    d = tf.linalg.det(tf.matmul(V, tf.transpose(U)))
    if d < 0.0:
        V = tf.concat([V[:, :-1], -V[:, -1:]], axis=1)
    # Optimal rotation
    R = tf.matmul(V, tf.transpose(U))
    # RMSD
    rmsd = tf.sqrt(tf.reduce_sum(tf.square(tf.matmul(p, tf.transpose(R)) - q)) / P.shape[0])
    return R, t, rmsd
```
and a batched version:
```python
def kabsch_tensorflow_batched(P, Q):
    """
    Computes the optimal rotation and translation to align two sets of points (P -> Q),
    and their RMSD, in a batched manner.

    :param P: A BxNx3 matrix of points
    :param Q: A BxNx3 matrix of points
    :return: A tuple containing the optimal rotation matrix, the optimal
        translation vector, and the RMSD.
    """
    assert P.shape == Q.shape, "Matrix dimensions must match"
    # Compute centroids
    centroid_P = tf.reduce_mean(P, axis=1, keepdims=True)  # Bx1x3
    centroid_Q = tf.reduce_mean(Q, axis=1, keepdims=True)  # Bx1x3
    # Optimal translation
    t = centroid_Q - centroid_P
    t = tf.squeeze(t, axis=1)  # Bx3
    # Center the points
    p = P - centroid_P
    q = Q - centroid_Q
    # Compute the covariance matrix
    H = tf.matmul(tf.transpose(p, perm=[0, 2, 1]), q)  # Bx3x3
    # SVD
    S, U, V = tf.linalg.svd(H)
    # Validate right-handed coordinate system: negate only the last column of V,
    # and only for batch elements whose V U^T is a reflection
    d = tf.linalg.det(tf.matmul(V, tf.transpose(U, perm=[0, 2, 1])))  # B
    sign = tf.where(d < 0.0, -tf.ones_like(d), tf.ones_like(d))  # B
    V = tf.concat([V[:, :, :-1], V[:, :, -1:] * sign[:, None, None]], axis=2)
    # Optimal rotation
    R = tf.matmul(V, tf.transpose(U, perm=[0, 2, 1]))
    # RMSD
    rmsd = tf.sqrt(tf.reduce_sum(tf.square(tf.matmul(p, tf.transpose(R, perm=[0, 2, 1])) - q), axis=[1, 2]) / P.shape[1])
    return R, t, rmsd
```
JAX
The JAX implementation closely mirrors NumPy, replacing `np` with `jnp`:
```python
import jax.numpy as jnp


def kabsch_jax(P, Q):
    """
    Computes the optimal rotation and translation to align two sets of points (P -> Q),
    and their RMSD.

    :param P: A Nx3 matrix of points
    :param Q: A Nx3 matrix of points
    :return: A tuple containing the optimal rotation matrix, the optimal
        translation vector, and the RMSD.
    """
    assert P.shape == Q.shape, "Matrix dimensions must match"
    # Compute centroids
    centroid_P = jnp.mean(P, axis=0)
    centroid_Q = jnp.mean(Q, axis=0)
    # Optimal translation
    t = centroid_Q - centroid_P
    # Center the points
    p = P - centroid_P
    q = Q - centroid_Q
    # Compute the covariance matrix
    H = jnp.dot(p.T, q)
    # SVD
    U, S, Vt = jnp.linalg.svd(H)
    # Validate right-handed coordinate system. JAX arrays are immutable, so flip
    # the last row of Vt with a functional .at[] update; jnp.where (instead of a
    # Python if) keeps the function traceable under jax.jit
    d = jnp.linalg.det(jnp.dot(Vt.T, U.T))
    Vt = Vt.at[-1, :].multiply(jnp.where(d < 0.0, -1.0, 1.0))
    # Optimal rotation
    R = jnp.dot(Vt.T, U.T)
    # RMSD
    rmsd = jnp.sqrt(jnp.sum(jnp.square(jnp.dot(p, R.T) - q)) / P.shape[0])
    return R, t, rmsd
```
and batched:
```python
def kabsch_jax_batched(P, Q):
    """
    Computes the optimal rotation and translation to align two sets of points (P -> Q),
    and their RMSD.

    :param P: A BxNx3 matrix of points
    :param Q: A BxNx3 matrix of points
    :return: A tuple containing the optimal rotation matrix, the optimal
        translation vector, and the RMSD.
    """
    assert P.shape == Q.shape, "Matrix dimensions must match"
    # Compute centroids
    centroid_P = jnp.mean(P, axis=1, keepdims=True)  # Bx1x3
    centroid_Q = jnp.mean(Q, axis=1, keepdims=True)  # Bx1x3
    # Optimal translation
    t = centroid_Q - centroid_P  # Bx1x3
    t = t.squeeze(1)  # Bx3
    # Center the points
    p = P - centroid_P  # BxNx3
    q = Q - centroid_Q  # BxNx3
    # Compute the covariance matrix
    H = jnp.matmul(p.transpose(0, 2, 1), q)  # Bx3x3
    # SVD
    U, S, Vt = jnp.linalg.svd(H)  # Bx3x3
    # Validate right-handed coordinate system: boolean-mask assignment is not
    # available on immutable JAX arrays, so scale the last row of each Vt by +/-1
    d = jnp.linalg.det(jnp.matmul(Vt.transpose(0, 2, 1), U.transpose(0, 2, 1)))  # B
    sign = jnp.where(d < 0.0, -1.0, 1.0)  # B
    Vt = Vt.at[:, -1, :].multiply(sign[:, None])
    # Optimal rotation
    R = jnp.matmul(Vt.transpose(0, 2, 1), U.transpose(0, 2, 1))  # Bx3x3
    # RMSD
    rmsd = jnp.sqrt(jnp.sum(jnp.square(jnp.matmul(p, R.transpose(0, 2, 1)) - q), axis=(1, 2)) / P.shape[1])
    return R, t, rmsd
```
Extensions
The Kabsch algorithm has several important extensions:
Quaternion Form: The algorithm can be reformulated using quaternions for better numerical stability, particularly useful in applications requiring high precision.
Iterative Versions: More robust variants that handle noise better and have improved scaling properties for large point sets.
Weighted Kabsch: Extensions that incorporate point weights (e.g., atomic masses in molecular dynamics). While SciPy provides a weighted version, it lacks batch processing capabilities.
These extensions address specific computational needs while maintaining the core mathematical principles of the algorithm.
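As a sketch of the weighted case with batching, assuming nonnegative weights `w` of shape (B, N) that do not all vanish within a batch element, the NumPy code below replaces the centroids with weighted means and the cross-covariance with $\sum_i w_i \mathbf{p}_i^\prime \mathbf{q}_i^{\prime T}$. The function name `kabsch_weighted_batched` is hypothetical; it returns the centroid-difference translation and a weighted RMSD, following the conventions of the implementations above.

```python
import numpy as np


def kabsch_weighted_batched(P, Q, w):
    """
    Hypothetical weighted, batched Kabsch (P -> Q).
    :param P: A BxNxD array of points
    :param Q: A BxNxD array of points
    :param w: A BxN array of nonnegative weights
    :return: Rotations (BxDxD), translations (BxD), weighted RMSDs (B)
    """
    assert P.shape == Q.shape, "Matrix dimensions must match"
    w = w / np.sum(w, axis=1, keepdims=True)  # normalize weights per batch element
    # Weighted centroids, (B, D)
    centroid_P = np.einsum("bn,bnd->bd", w, P)
    centroid_Q = np.einsum("bn,bnd->bd", w, Q)
    t = centroid_Q - centroid_P
    p = P - centroid_P[:, None, :]
    q = Q - centroid_Q[:, None, :]
    # Weighted cross-covariance: sum_i w_i p_i' q_i'^T, (B, D, D)
    H = np.einsum("bn,bni,bnj->bij", w, p, q)
    U, S, Vt = np.linalg.svd(H)
    V, Ut = Vt.transpose(0, 2, 1), U.transpose(0, 2, 1)
    # Reflection correction B = diag(1, ..., 1, d) per batch element
    B_mat = np.tile(np.eye(P.shape[2]), (P.shape[0], 1, 1))
    B_mat[:, -1, -1] = np.sign(np.linalg.det(V @ Ut))
    R = V @ B_mat @ Ut
    # Weighted RMSD: sqrt(sum_i w_i ||R p_i' - q_i'||^2) with normalized weights
    residual = p @ R.transpose(0, 2, 1) - q
    wrmsd = np.sqrt(np.einsum("bn,bnd->b", w, residual ** 2))
    return R, t, wrmsd


# Usage: recover a known rotation about the y-axis under random weights
rng = np.random.default_rng(6)
P = rng.normal(size=(4, 30, 3))
c, s = np.cos(1.1), np.sin(1.1)
R_true = np.array([[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]])
Q = P @ R_true.T + np.array([1.0, 2.0, 3.0])
w = rng.uniform(0.5, 2.0, size=(4, 30))

R, t, wrmsd = kabsch_weighted_batched(P, Q, w)
print(np.allclose(R, R_true), np.allclose(wrmsd, 0.0, atol=1e-10))  # True True
```

Because the weights only enter through the centroids and the covariance, uniform weights reduce this to the unweighted batched version.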
References
- Wikipedia, Kabsch Algorithm
- [Kabsch, 1976] Kabsch, W. (1976). A solution for the best rotation to relate two sets of vectors. Acta Crystallographica Section A, 32:922-923.
- [Kabsch, 1978] Kabsch, W. (1978). A discussion of the solution for the best rotation to relate two sets of vectors. Acta Crystallographica Section A, 34:827-828.
- Zalo on Kabsch: An interactive shape matching demo.