
Kabsch Algorithm: NumPy, PyTorch, TensorFlow, and JAX

An algorithm for aligning two sets of points.

The Kabsch Algorithm

In machine learning and molecular dynamics, we often want to compare two structures to understand how similar they are to one another. A vital part of this comparison is aligning the structures so that we can compare them in the same frame of reference, obviating the need to account for the translation and rotation of the structures.

The Kabsch algorithm computes the optimal rigid-body transformation for aligning two sets of paired points. In particular, it yields the optimal rotation matrix between the two sets, which is essential for computing quantities like the root-mean-square deviation (RMSD) between two structures. While my experience with this algorithm is in the context of molecular dynamics, it has wide application across computer vision, graphics, and essentially any particle-based simulation.

In molecular dynamics, for example, the translation and rotation of a structure are irrelevant to its chemical properties and dynamics. If we have a generative model for sampling high-probability conformations of a molecule, we don’t want to penalize the model for sampling the same conformation in different orientations. By incorporating the Kabsch algorithm into the loss function, we can implicitly bake the model’s invariance to translation and rotation into training. A recent example can be seen in the following paper:

Although the Kabsch algorithm is implemented elsewhere, I wanted to implement it myself, both to better understand the algorithm and to provide a reference for others who want to use it in their own work. I made sure to implement it in libraries that support backpropagation (e.g., PyTorch, TensorFlow, JAX) so that it can be used in machine learning applications.

Below, we will cover the math behind the Kabsch algorithm and its implementation in NumPy, PyTorch, TensorFlow, and JAX. We will demonstrate how to run the algorithm on single pairs and on batches of pairs, which is vital for its use in machine learning applications.

The Math

Let’s say we have two sets of paired points, $P=[\mathbf{p}_i] \in \mathbb{R}^{N \times D}$ and $Q=[\mathbf{q}_i] \in \mathbb{R}^{N \times D}$, for $i = 1, \dots, N$ (where $D$ is the dimensionality of the points, and $N$ is the number of points). We seek to find a translation vector $\mathbf{t}$ and rotation matrix $R$ to transform $P$ so that it aligns with $Q$. We can write this as the following optimization problem:

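$$ \mathbf{t}^\ast, R^\ast = \arg\min_{\mathbf{t},\, R} \mathcal{L}(\mathbf{t}, R), \qquad \mathcal{L}(\mathbf{t}, R) = \frac{1}{2} \sum_{i=1}^N \| \mathbf{q}_i - \bar{\mathbf{q}} - R(\mathbf{p}_i + \mathbf{t} - \bar{\mathbf{q}}) \|^2 $$

where $\bar{\mathbf{q}}$ is the centroid of $Q$ (defined below): $P$ is first shifted by $\mathbf{t}$ and then rotated about the centroid of $Q$. Here $\mathbf{t}^\ast \in \mathbb{R}^D$ and $R^\ast \in SO(D)$ (i.e., $R^{\ast T} R^\ast = I$ and $\det R^\ast = 1$) are the optimal translation vector and rotation matrix, respectively.

Frequently, we want to solve a weighted version of this problem, where the contribution of each point pair is weighted by some $w_i \ge 0$, for example the mass of the atom at that point in molecular dynamics:

$$ \mathcal{L}_w(\mathbf{t}, R) = \frac{1}{2} \sum_{i=1}^N w_i \| \mathbf{q}_i - \bar{\mathbf{q}} - R(\mathbf{p}_i + \mathbf{t} - \bar{\mathbf{q}}) \|^2 $$

where $\bar{\mathbf{p}}$ and $\bar{\mathbf{q}}$ are now weighted centroids, e.g., $\bar{\mathbf{p}} = \sum_i w_i \mathbf{p}_i / \sum_i w_i$.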
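The implementations below are unweighted. Here is a minimal NumPy sketch of the weighted variant. It assumes the standard recipe of weighted centroids and a weighted cross-covariance $\sum_i w_i \mathbf{p}_i^\prime \mathbf{q}_i^{\prime T}$, and otherwise reuses the centering, SVD, and reflection steps derived in the next sections. The name kabsch_weighted_numpy is hypothetical and not from the original post.

```python
import numpy as np


def kabsch_weighted_numpy(P, Q, w):
    """Hypothetical weighted Kabsch sketch (not part of the original post).

    :param P: A Nx3 matrix of points
    :param Q: A Nx3 matrix of points
    :param w: A length-N vector of non-negative weights (e.g. atomic masses)
    :return: (R, t, rmsd), with t the shift between weighted centroids and
             R the rotation about the centroid, as in the unweighted case.
    """
    w = np.asarray(w, dtype=float)
    W = w.sum()

    # Weighted centroids and their difference (the translation)
    centroid_P = (w[:, None] * P).sum(axis=0) / W
    centroid_Q = (w[:, None] * Q).sum(axis=0) / W
    t = centroid_Q - centroid_P

    # Center the points and form the weighted cross-covariance sum_i w_i p_i' q_i'^T
    p = P - centroid_P
    q = Q - centroid_Q
    H = p.T @ (w[:, None] * q)

    # SVD and reflection correction, exactly as in the unweighted algorithm
    U, S, Vt = np.linalg.svd(H)
    d = np.sign(np.linalg.det(Vt.T @ U.T))
    B = np.diag(np.append(np.ones(H.shape[0] - 1), d))
    R = Vt.T @ B @ U.T

    # Weighted RMSD (assumption: normalized by the total weight)
    rmsd = np.sqrt(np.sum(w * np.sum((p @ R.T - q) ** 2, axis=1)) / W)
    return R, t, rmsd


# Usage: recover a known rotation about the z-axis with random weights
rng = np.random.default_rng(0)
P = rng.normal(size=(50, 3))
theta = 0.7
R_true = np.array([[np.cos(theta), -np.sin(theta), 0.0],
                   [np.sin(theta), np.cos(theta), 0.0],
                   [0.0, 0.0, 1.0]])
Q = P @ R_true.T + np.array([1.0, -2.0, 0.5])
w = rng.uniform(0.5, 2.0, size=50)
R_w, t_w, rmsd_w = kabsch_weighted_numpy(P, Q, w)
```

Because Q here is an exact rigid transform of P, any positive weights should recover R_true with an RMSD near zero. Weights only change the answer once the data are noisy.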

The Translation Vector

We can find the translation vector without solving a minimization problem: it is simply the difference between the centroids of the two point sets. Since we’re moving $P$ onto $Q$, we subtract the centroid of $P$ from the centroid of $Q$:

$$ \mathbf{t} = \bar{\mathbf{q}} - \bar{\mathbf{p}} $$

where $\bar{\mathbf{p}} = \frac{1}{N} \sum_{i=1}^N \mathbf{p}_i$ and $\bar{\mathbf{q}} = \frac{1}{N} \sum_{i=1}^N \mathbf{q}_i$.

As a sanity check, note that $$ \mathbf{p}_i = \bar{\mathbf{p}} + \mathbf{p}_i^\prime $$ where $\mathbf{p}_i^\prime = \mathbf{p}_i - \bar{\mathbf{p}}$, the residual after centering $\mathbf{p}_i$.

So, $$ \mathbf{p}_i + \mathbf{t} = \bar{\mathbf{p}} + \mathbf{p}_i^\prime + (\bar{\mathbf{q}} - \bar{\mathbf{p}}) = \bar{\mathbf{q}} + \mathbf{p}_i^\prime $$ In other words, we’re shifting the centroid of $P$ to the centroid of $Q$.
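A NumPy sanity check of this identity on random data: shifting every point of $P$ by $\mathbf{t} = \bar{\mathbf{q}} - \bar{\mathbf{p}}$ places its centroid on $\bar{\mathbf{q}}$ while leaving the centered residuals $\mathbf{p}_i^\prime$ unchanged.

```python
import numpy as np

rng = np.random.default_rng(1)
P = rng.normal(size=(20, 3))
Q = rng.normal(size=(20, 3)) + 5.0  # any second point set works here

p_bar = P.mean(axis=0)
q_bar = Q.mean(axis=0)
t = q_bar - p_bar

# p_i + t = q_bar + (p_i - p_bar): shifting by t moves P's centroid onto Q's
shifted = P + t
print(np.allclose(shifted.mean(axis=0), q_bar))   # True
print(np.allclose(shifted, q_bar + (P - p_bar)))  # True
```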

The Rotation Matrix

The rotation matrix is a bit trickier to compute, but only marginally so.

First, we need to center our point sets by subtracting their centroids from each point. $$ \mathbf{p}_i^\prime = \mathbf{p}_i - \bar{\mathbf{p}} \quad \text{and} \quad \mathbf{q}_i^\prime = \mathbf{q}_i - \bar{\mathbf{q}} $$

Once we’ve centered our point sets, we can compute the cross-covariance matrix $C \in \mathbb{R}^{D \times D}$ between them (called H in the code below), where $P^\prime$ and $Q^\prime$ hold the centered points as rows: $$ C = P^{\prime T} Q^\prime = \sum_{i=1}^N \mathbf{p}_i^{\prime} \mathbf{q}_i^{\prime T} \in \mathbb{R}^{D \times D} $$
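As a quick check on the notation, here is a NumPy sketch on random data. The matrix product $P^{\prime T} Q^\prime$ equals the sum of outer products $\sum_i \mathbf{p}_i^\prime \mathbf{q}_i^{\prime T}$, which np.einsum can also express directly.

```python
import numpy as np

rng = np.random.default_rng(2)
P = rng.normal(size=(30, 3))
Q = rng.normal(size=(30, 3))

p = P - P.mean(axis=0)  # rows are p_i'
q = Q - Q.mean(axis=0)  # rows are q_i'

C_matmul = p.T @ q                                            # P'^T Q'
C_outer = sum(np.outer(p_i, q_i) for p_i, q_i in zip(p, q))  # sum_i p_i' q_i'^T
C_einsum = np.einsum("ni,nj->ij", p, q)                       # same sum via einsum

print(C_matmul.shape)                                                   # (3, 3)
print(np.allclose(C_matmul, C_outer), np.allclose(C_matmul, C_einsum))  # True True
```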

We can then compute the singular value decomposition (SVD) of $C$. $$ C = U \Sigma V^T $$

Next, we may need to correct for an improper rotation (i.e., a reflection) by flipping the sign of the last column of $V$, the one paired with the smallest singular value. We can check whether this is necessary by computing the determinant of $VU^T$: $$ d = \text{sign}(\det(V U^T)) $$ If $d = -1$, $VU^T$ is a reflection and we flip the sign of the last column of $V$; otherwise, we leave it as is. Equivalently, let $B = \text{diag}(1, \dots, 1, d)$ be the $D \times D$ identity matrix with $d$ as its last diagonal element.

Finally, we can compute the optimal rotation matrix $R^\ast$. $$ R^\ast = V B U^T $$
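To see when the correction matters, consider a point set $Q$ that is an exact mirror image of $P$. The best orthogonal map $VU^T$ is then the mirror itself, with determinant $-1$, and $B$ turns it into the best proper rotation. A NumPy sketch on random data:

```python
import numpy as np

rng = np.random.default_rng(3)
P = rng.normal(size=(40, 3))
M = np.diag([1.0, 1.0, -1.0])  # mirror through the xy-plane (det = -1)
Q = P @ M.T                    # Q is a mirror image of P, not a rotation of it

p = P - P.mean(axis=0)
q = Q - Q.mean(axis=0)
U, S, Vt = np.linalg.svd(p.T @ q)
V = Vt.T

R_naive = V @ U.T                       # best orthogonal map: here, the mirror M
d = np.sign(np.linalg.det(R_naive))     # -1, so a correction is needed
B = np.diag([1.0, 1.0, d])
R_star = V @ B @ U.T                    # best proper rotation

print(np.isclose(np.linalg.det(R_naive), -1.0), np.isclose(np.linalg.det(R_star), 1.0))  # True True
```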

Summary

Now we can summarize the Kabsch algorithm as follows:

  • Compute the centroids of $P$ and $Q$.
  • Center $P$ and $Q$ by subtracting their centroids from each point.
  • Compute the cross-covariance matrix $C = P^{\prime T} Q^\prime$.
  • Compute the SVD of $C = U \Sigma V^T$.
  • Compute $d = \text{sign}(\det(V U^T))$.
  • Compute $B = \text{diag}(1, \dots, 1, d)$.
  • Compute $R^\ast = V B U^T$.
  • Compute $\mathbf{t}^\ast = \bar{\mathbf{q}} - \bar{\mathbf{p}}$.
  • Return $R^\ast$ and $\mathbf{t}^\ast$.

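And to compare the two point sets, we can compute the RMSD between them after alignment, i.e., after shifting $P$ by $\mathbf{t}^\ast$ and rotating it by $R^\ast$ about the centroid of $Q$: $$ \text{RMSD} = \sqrt{\frac{1}{N} \sum_{i=1}^N \| \mathbf{q}_i^\prime - R^\ast \mathbf{p}_i^\prime \|^2} $$ where $\mathbf{p}_i^\prime$ and $\mathbf{q}_i^\prime$ are the centered points.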
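As a quick numerical check, here is a NumPy sketch on made-up noisy data. The RMSD computed from the centered points equals the RMSD computed in the original coordinates after shifting $P$ by $\mathbf{t}^\ast = \bar{\mathbf{q}} - \bar{\mathbf{p}}$ and rotating about $\bar{\mathbf{q}}$, because $\mathbf{p}_i + \mathbf{t}^\ast - \bar{\mathbf{q}} = \mathbf{p}_i^\prime$.

```python
import numpy as np

rng = np.random.default_rng(4)
P = rng.normal(size=(25, 3))
R_true = np.diag([1.0, -1.0, -1.0])  # 180-degree rotation about the x-axis
Q = P @ R_true.T + 3.0 + rng.normal(scale=0.05, size=(25, 3))  # rotate, shift, add noise

# Kabsch steps from the summary above
p_bar, q_bar = P.mean(axis=0), Q.mean(axis=0)
t_star = q_bar - p_bar
p, q = P - p_bar, Q - q_bar
U, S, Vt = np.linalg.svd(p.T @ q)
d = np.sign(np.linalg.det(Vt.T @ U.T))
R_star = Vt.T @ np.diag([1.0, 1.0, d]) @ U.T

# RMSD from centered points: q_i' - R* p_i'
rmsd_centered = np.sqrt(np.mean(np.sum((q - p @ R_star.T) ** 2, axis=1)))

# RMSD in original coordinates: shift P by t*, then rotate about q_bar
aligned = q_bar + (P + t_star - q_bar) @ R_star.T
rmsd_full = np.sqrt(np.mean(np.sum((Q - aligned) ** 2, axis=1)))

print(np.isclose(rmsd_centered, rmsd_full))  # True
```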

The Implementation

Without further ado, let’s write some code!

NumPy

import numpy as np


def kabsch_numpy(P, Q):
    """
    Computes the optimal rotation and translation to align two sets of points (P -> Q),
    and their RMSD.

    :param P: A Nx3 matrix of points
    :param Q: A Nx3 matrix of points
    :return: A tuple containing the optimal rotation matrix, the optimal
             translation vector, and the RMSD.
    """
    assert P.shape == Q.shape, "Matrix dimensions must match"

    # Compute centroids
    centroid_P = np.mean(P, axis=0)
    centroid_Q = np.mean(Q, axis=0)

    # Optimal translation
    t = centroid_Q - centroid_P

    # Center the points
    p = P - centroid_P
    q = Q - centroid_Q

    # Compute the covariance matrix
    H = np.dot(p.T, q)

    # SVD
    U, S, Vt = np.linalg.svd(H)

    # Validate right-handed coordinate system
    if np.linalg.det(np.dot(Vt.T, U.T)) < 0.0:
        Vt[-1, :] *= -1.0

    # Optimal rotation
    R = np.dot(Vt.T, U.T)

    # RMSD
    rmsd = np.sqrt(np.sum(np.square(np.dot(p, R.T) - q)) / P.shape[0])

    return R, t, rmsd

Furthermore, we can write a small sanity check to make sure that our implementation is correct.

def test_numpy():
    np.random.seed(12345)

    P = np.random.randn(100, 3)

    alpha = np.random.rand() * 2 * np.pi
    R = np.array([[np.cos(alpha), -np.sin(alpha), 0],
                    [np.sin(alpha), np.cos(alpha), 0],
                    [0, 0, 1]])
    t = np.random.randn(3) * 10

    Q = np.dot(P, R.T) + t

    R_opt, t_opt, rmsd = kabsch_numpy(P, Q)

    print('RMSD: {}'.format(rmsd))
    print('R:\n{}'.format(R))
    print('R_opt:\n{}'.format(R_opt))
    print('t:\n{}'.format(t))
    print('t_opt:\n{}'.format(t_opt))

    l2_t = np.linalg.norm(t - t_opt)
    l2_R = np.linalg.norm(R - R_opt)
    print('l2_t: {}'.format(l2_t))
    print('l2_R: {}'.format(l2_R))

Executing this test, I get the following output:

RMSD: 3.176703044042434e-15
R:
[[-0.8475392 -0.5307328  0.       ]
 [ 0.5307328 -0.8475392  0.       ]
 [ 0.         0.         1.       ]]
R_opt:
[[-8.47539198e-01 -5.30732803e-01 -8.26804079e-17]
 [ 5.30732803e-01 -8.47539198e-01 -8.32570352e-19]
 [ 8.20168388e-17  9.60990451e-18  1.00000000e+00]]
t:
[ 5.99726796  1.50078468 -3.34633977]
t_opt:
[ 6.10180512  1.50760274 -3.34633977]
l2_t: 0.10475926000353594
l2_R: 7.538724554724993e-16

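Here the RMSD is at machine precision and the recovered rotation matches the ground truth. The translation, however, is off by about 0.1. This is expected: t_opt is the centroid shift $\bar{\mathbf{q}} - \bar{\mathbf{p}}$, whereas the test generates Q as $\mathbf{q}_i = R\mathbf{p}_i + \mathbf{t}$, whose translation is $\mathbf{t} = \bar{\mathbf{q}} - R\bar{\mathbf{p}}$. The two differ by $(I - R)\bar{\mathbf{p}}$, which is small because $\bar{\mathbf{p}}$ is close to the origin, and which has no z-component because R rotates about the z-axis (hence the identical third entries of t and t_opt).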
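If you need the translation in the more common convention $\mathbf{q}_i \approx R\mathbf{p}_i + \mathbf{t}$ (the one the test uses to build Q), it can be recovered from the Kabsch output as $\mathbf{t} = \bar{\mathbf{q}} - R^\ast\bar{\mathbf{p}}$. The NumPy sketch below regenerates the test data with the same seed and computes the rotation inline, as kabsch_numpy does. The names t_shift and t_std are introduced here for illustration only.

```python
import numpy as np

np.random.seed(12345)
P = np.random.randn(100, 3)
alpha = np.random.rand() * 2 * np.pi
R = np.array([[np.cos(alpha), -np.sin(alpha), 0],
              [np.sin(alpha), np.cos(alpha), 0],
              [0, 0, 1]])
t = np.random.randn(3) * 10
Q = P @ R.T + t  # q_i = R p_i + t

# Kabsch rotation, computed as in kabsch_numpy
centroid_P, centroid_Q = P.mean(axis=0), Q.mean(axis=0)
U, S, Vt = np.linalg.svd((P - centroid_P).T @ (Q - centroid_Q))
if np.linalg.det(Vt.T @ U.T) < 0.0:
    Vt[-1, :] *= -1.0
R_opt = Vt.T @ U.T

t_shift = centroid_Q - centroid_P        # what kabsch_numpy returns
t_std = centroid_Q - R_opt @ centroid_P  # translation for q_i = R p_i + t

print(np.allclose(t_std, t))                                   # True
print(np.allclose(t - t_shift, (np.eye(3) - R) @ centroid_P))  # True
```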

How about a batched version?

def kabsch_numpy_batched(P, Q):
    """
    Computes the optimal rotation and translation to align two sets of points (P -> Q),
    and their RMSD.

    :param P: A BxNx3 matrix of points
    :param Q: A BxNx3 matrix of points
    :return: A tuple containing the optimal rotation matrix, the optimal
             translation vector, and the RMSD.
    """
    assert P.shape == Q.shape, "Matrix dimensions must match"

    # Compute centroids
    centroid_P = np.mean(P, axis=1, keepdims=True)  # Bx1x3
    centroid_Q = np.mean(Q, axis=1, keepdims=True)  # Bx1x3

    # Optimal translation
    t = centroid_Q - centroid_P  # Bx1x3
    t = t.squeeze(1)  # Bx3

    # Center the points
    p = P - centroid_P  # BxNx3
    q = Q - centroid_Q  # BxNx3

    # Compute the covariance matrix
    H = np.matmul(p.transpose(0, 2, 1), q)  # Bx3x3

    # SVD
    U, S, Vt = np.linalg.svd(H)  # Bx3x3

    # Validate right-handed coordinate system
    d = np.linalg.det(np.matmul(Vt.transpose(0, 2, 1), U.transpose(0, 2, 1)))
    flip = d < 0.0
    if flip.any():
        Vt[flip, -1, :] *= -1.0

    # Optimal rotation
    R = np.matmul(Vt.transpose(0, 2, 1), U.transpose(0, 2, 1))  # Bx3x3

    # RMSD
    rmsd = np.sqrt(np.sum(np.square(np.matmul(p, R.transpose(0, 2, 1)) - q), axis=(1, 2)) / P.shape[1])

    return R, t, rmsd

And as before, a simple sanity check:

def test_numpy_batched():
    np.random.seed(12345)

    P = np.random.randn(10, 100, 3)

    alpha = np.random.rand(10) * 2 * np.pi
    R = np.stack([np.array([[np.cos(a), -np.sin(a), 0],
                            [np.sin(a), np.cos(a), 0],
                            [0, 0, 1]]) for a in alpha], axis=0)
    t = np.random.randn(10, 3) * 10

    Q = np.matmul(P, R.transpose(0, 2, 1)) + t[:, None, :]

    R_opt, t_opt, rmsd = kabsch_numpy_batched(P, Q)

    print('RMSD: {}'.format(rmsd.mean()))

    l2_t = np.linalg.norm(t - t_opt, axis=1)
    l2_R = np.linalg.norm(R - R_opt, axis=(1, 2))
    print('l2_t: {}'.format(l2_t.mean()))
    print('l2_R: {}'.format(l2_R.mean()))

Executing this test, I get the following output:

RMSD: 3.751746246898761e-15
l2_t: 0.1473711613814949
l2_R: 7.667528292719723e-16

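Looks good! The rotations match to machine precision, and the nonzero l2_t arises because t_opt is the centroid shift $\bar{\mathbf{q}} - \bar{\mathbf{p}}$ rather than the translation $\bar{\mathbf{q}} - R\bar{\mathbf{p}}$ used to generate Q.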

PyTorch

The remaining implementations are similar, so I will just show the code, without tests (for now).

import torch


def kabsch_torch(P, Q):
    """
    Computes the optimal rotation and translation to align two sets of points (P -> Q),
    and their RMSD.
    :param P: A Nx3 matrix of points
    :param Q: A Nx3 matrix of points
    :return: A tuple containing the optimal rotation matrix, the optimal
             translation vector, and the RMSD.
    """
    assert P.shape == Q.shape, "Matrix dimensions must match"

    # Compute centroids
    centroid_P = torch.mean(P, dim=0)
    centroid_Q = torch.mean(Q, dim=0)

    # Optimal translation
    t = centroid_Q - centroid_P

    # Center the points
    p = P - centroid_P
    q = Q - centroid_Q

    # Compute the covariance matrix
    H = torch.matmul(p.transpose(0, 1), q)

    # SVD
    U, S, Vt = torch.linalg.svd(H)

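    # Validate right-handed coordinate system: B = diag(1, ..., 1, d) flips the
    # last column of V (i.e., the last row of Vt). Building B instead of editing
    # Vt in place keeps the SVD outputs intact for autograd.
    d = torch.sign(torch.det(torch.matmul(Vt.transpose(0, 1), U.transpose(0, 1))))
    B = torch.diag(torch.cat([torch.ones_like(S[:-1]), d.unsqueeze(0)]))

    # Optimal rotation
    R = torch.matmul(torch.matmul(Vt.transpose(0, 1), B), U.transpose(0, 1))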

    # RMSD
    rmsd = torch.sqrt(torch.sum(torch.square(torch.matmul(p, R.transpose(0, 1)) - q)) / P.shape[0])

    return R, t, rmsd

And our batched version:

def kabsch_torch_batched(P, Q):
    """
    Computes the optimal rotation and translation to align two sets of points (P -> Q),
    and their RMSD, in a batched manner.
    :param P: A BxNx3 matrix of points
    :param Q: A BxNx3 matrix of points
    :return: A tuple containing the optimal rotation matrix, the optimal
             translation vector, and the RMSD.
    """
    assert P.shape == Q.shape, "Matrix dimensions must match"

    # Compute centroids
    centroid_P = torch.mean(P, dim=1, keepdim=True)  # Bx1x3
    centroid_Q = torch.mean(Q, dim=1, keepdim=True)  # Bx1x3

    # Optimal translation
    t = centroid_Q - centroid_P  # Bx1x3
    t = t.squeeze(1)  # Bx3

    # Center the points
    p = P - centroid_P  # BxNx3
    q = Q - centroid_Q  # BxNx3

    # Compute the covariance matrix
    H = torch.matmul(p.transpose(1, 2), q)  # Bx3x3

    # SVD
    U, S, Vt = torch.linalg.svd(H)  # Bx3x3

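    # Validate right-handed coordinate system: B = diag(1, ..., 1, d) flips the
    # last column of V (last row of Vt) per batch element, without in-place edits
    d = torch.sign(torch.det(torch.matmul(Vt.transpose(1, 2), U.transpose(1, 2))))  # B
    B = torch.diag_embed(torch.cat([torch.ones_like(S[:, :-1]), d.unsqueeze(-1)], dim=-1))  # Bx3x3

    # Optimal rotation
    R = torch.matmul(torch.matmul(Vt.transpose(1, 2), B), U.transpose(1, 2))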

    # RMSD
    rmsd = torch.sqrt(torch.sum(torch.square(torch.matmul(p, R.transpose(1, 2)) - q), dim=(1, 2)) / P.shape[1])

    return R, t, rmsd

TensorFlow

With TensorFlow, we have to be a bit careful, since tf.linalg.svd does not return its output in the same form as NumPy and PyTorch: S is returned first, and it returns V itself rather than its transpose. See my implementation below:

import tensorflow as tf


def kabsch_tensorflow(P, Q):
    """
    Computes the optimal rotation and translation to align two sets of points (P -> Q),
    and their RMSD.

    :param P: A Nx3 matrix of points
    :param Q: A Nx3 matrix of points
    :return: A tuple containing the optimal rotation matrix, the optimal
             translation vector, and the RMSD.
    """
    assert P.shape == Q.shape, "Matrix dimensions must match"

    # Compute centroids
    centroid_P = tf.reduce_mean(P, axis=0)
    centroid_Q = tf.reduce_mean(Q, axis=0)

    # Optimal translation
    t = centroid_Q - centroid_P

    # Center the points
    p = P - centroid_P
    q = Q - centroid_Q

    # Compute the covariance matrix
    H = tf.matmul(tf.transpose(p), q)

    # SVD
    S, U, V = tf.linalg.svd(H)

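    # Validate right-handed coordinate system. TF tensors are immutable, and V is
    # not transposed here, so the last *column* of V is scaled by d via
    # B = diag(1, ..., 1, d)
    d = tf.sign(tf.linalg.det(tf.matmul(V, tf.transpose(U))))
    B = tf.linalg.diag(tf.concat([tf.ones_like(S[:-1]), tf.reshape(d, [1])], axis=0))

    # Optimal rotation
    R = tf.matmul(tf.matmul(V, B), tf.transpose(U))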

    # RMSD
    rmsd = tf.sqrt(tf.reduce_sum(tf.square(tf.matmul(p, tf.transpose(R)) - q)) / P.shape[0])

    return R, t, rmsd

and a batched version:

def kabsch_tensorflow_batched(P, Q):
    """
    Computes the optimal rotation and translation to align two sets of points (P -> Q),
    and their RMSD.

    :param P: A BxNx3 matrix of points
    :param Q: A BxNx3 matrix of points
    :return: A tuple containing the optimal rotation matrix, the optimal
             translation vector, and the RMSD.
    """
    assert P.shape == Q.shape, "Matrix dimensions must match"

    # Compute centroids
    centroid_P = tf.reduce_mean(P, axis=1, keepdims=True)
    centroid_Q = tf.reduce_mean(Q, axis=1, keepdims=True)

    # Optimal translation
    t = centroid_Q - centroid_P
    t = tf.squeeze(t, axis=1)

    # Center the points
    p = P - centroid_P
    q = Q - centroid_Q

    # Compute the covariance matrix
    H = tf.matmul(tf.transpose(p, perm=[0, 2, 1]), q)

    # SVD
    S, U, V = tf.linalg.svd(H)

    # Validate right-handed coordinate system: scale the last column of each V
    # by d = sign(det(V U^T)) via B = diag(1, ..., 1, d)
    d = tf.sign(tf.linalg.det(tf.matmul(V, tf.transpose(U, perm=[0, 2, 1]))))  # B
    B = tf.linalg.diag(tf.concat([tf.ones_like(S[:, :-1]), tf.expand_dims(d, -1)], axis=-1))  # Bx3x3

    # Optimal rotation
    R = tf.matmul(tf.matmul(V, B), tf.transpose(U, perm=[0, 2, 1]))

    # RMSD
    rmsd = tf.sqrt(tf.reduce_sum(tf.square(tf.matmul(p, tf.transpose(R, perm=[0, 2, 1])) - q), axis=(1, 2)) / P.shape[1])

    return R, t, rmsd
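To make TensorFlow's convention concrete without needing TensorFlow installed, the NumPy sketch below repackages the output of np.linalg.svd into the (s, u, v) order that tf.linalg.svd returns. For real inputs, tf.linalg.svd documents this as $H = u\,\text{diag}(s)\,v^T$. The sketch then checks that both conventions give the same $R^\ast = VBU^T$.

```python
import numpy as np

rng = np.random.default_rng(5)
H = rng.normal(size=(3, 3))  # stand-in for a cross-covariance matrix

# NumPy (like torch.linalg.svd) returns U, S, Vt with H = U @ diag(S) @ Vt
U, S, Vt = np.linalg.svd(H)

# Repackage into tf.linalg.svd's order (s, u, v), where v = Vt.T (not transposed)
s, u, v = S, U, Vt.T
print(np.allclose(H, u @ np.diag(s) @ v.T))  # True

# R* = V B U^T reads v @ B @ u.T in TF's convention and Vt.T @ B @ U.T in NumPy's
d = np.sign(np.linalg.det(v @ u.T))
B = np.diag([1.0, 1.0, d])
R_tf_style = v @ B @ u.T
R_np_style = Vt.T @ B @ U.T
print(np.allclose(R_tf_style, R_np_style), np.isclose(np.linalg.det(R_tf_style), 1.0))  # True True
```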

JAX

The JAX implementation largely mirrors the NumPy one, replacing np (NumPy) with jnp (jax.numpy) and np.linalg with jnp.linalg.

import jax.numpy as jnp


def kabsch_jax(P, Q):
    """
    Computes the optimal rotation and translation to align two sets of points (P -> Q),
    and their RMSD.

    :param P: A Nx3 matrix of points
    :param Q: A Nx3 matrix of points
    :return: A tuple containing the optimal rotation matrix, the optimal
             translation vector, and the RMSD.
    """
    assert P.shape == Q.shape, "Matrix dimensions must match"

    # Compute centroids
    centroid_P = jnp.mean(P, axis=0)
    centroid_Q = jnp.mean(Q, axis=0)

    # Optimal translation
    t = centroid_Q - centroid_P

    # Center the points
    p = P - centroid_P
    q = Q - centroid_Q

    # Compute the covariance matrix
    H = jnp.dot(p.T, q)

    # SVD
    U, S, Vt = jnp.linalg.svd(H)

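    # Validate right-handed coordinate system. JAX arrays are immutable (and a
    # Python `if` on a traced value fails under jit), so scale the last column
    # of V by d via B = diag(1, ..., 1, d) instead of editing Vt in place
    d = jnp.sign(jnp.linalg.det(jnp.dot(Vt.T, U.T)))
    B = jnp.diag(jnp.concatenate([jnp.ones_like(S[:-1]), d[None]]))

    # Optimal rotation
    R = jnp.dot(jnp.dot(Vt.T, B), U.T)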

    # RMSD
    rmsd = jnp.sqrt(jnp.sum(jnp.square(jnp.dot(p, R.T) - q)) / P.shape[0])

    return R, t, rmsd

and batched:

def kabsch_jax_batched(P, Q):
    """
    Computes the optimal rotation and translation to align two sets of points (P -> Q),
    and their RMSD.

    :param P: A BxNx3 matrix of points
    :param Q: A BxNx3 matrix of points
    :return: A tuple containing the optimal rotation matrix, the optimal
             translation vector, and the RMSD.
    """
    assert P.shape == Q.shape, "Matrix dimensions must match"

    # Compute centroids
    centroid_P = jnp.mean(P, axis=1, keepdims=True)  # Bx1x3
    centroid_Q = jnp.mean(Q, axis=1, keepdims=True)  # Bx1x3

    # Optimal translation
    t = centroid_Q - centroid_P  # Bx1x3
    t = t.squeeze(1)  # Bx3

    # Center the points
    p = P - centroid_P  # BxNx3
    q = Q - centroid_Q  # BxNx3

    # Compute the covariance matrix
    H = jnp.matmul(p.transpose(0, 2, 1), q)  # Bx3x3

    # SVD
    U, S, Vt = jnp.linalg.svd(H)  # Bx3x3

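    # Validate right-handed coordinate system without in-place updates (JAX arrays
    # are immutable): V @ diag(1, ..., 1, d) just scales the last column of V by d
    d = jnp.sign(jnp.linalg.det(jnp.matmul(Vt.transpose(0, 2, 1), U.transpose(0, 2, 1))))  # B
    diag = jnp.concatenate([jnp.ones_like(S[:, :-1]), d[:, None]], axis=-1)  # Bx3

    # Optimal rotation R = V B U^T, with V B computed by broadcasting
    R = jnp.matmul(Vt.transpose(0, 2, 1) * diag[:, None, :], U.transpose(0, 2, 1))  # Bx3x3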

    # RMSD
    rmsd = jnp.sqrt(jnp.sum(jnp.square(jnp.matmul(p, R.transpose(0, 2, 1)) - q), axis=(1, 2)) / P.shape[1])

    return R, t, rmsd
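JAX arrays and TensorFlow tensors are immutable, so an in-place update such as Vt[-1, :] *= -1.0 is not available there, and a Python if on a traced value fails under jax.jit. One alternative is to fold the correction into the product $R^\ast = VBU^T$, i.e., to scale the last column of $V$ by $d$. The NumPy sketch below, on random matrices, checks that this matches the in-place flip used in the NumPy batched version.

```python
import numpy as np

rng = np.random.default_rng(6)
H = rng.normal(size=(4, 3, 3))  # a small batch of cross-covariance matrices
U, S, Vt = np.linalg.svd(H)
V = Vt.transpose(0, 2, 1)
d = np.sign(np.linalg.det(V @ U.transpose(0, 2, 1)))  # shape (4,)

# In-place version, as in the NumPy batched code: flip the last row of Vt
Vt_flipped = Vt.copy()
Vt_flipped[d < 0, -1, :] *= -1.0
R_inplace = Vt_flipped.transpose(0, 2, 1) @ U.transpose(0, 2, 1)

# Functional version: scale the last column of V by d (i.e. V @ diag(1, 1, d))
diag = np.concatenate([np.ones_like(S[:, :-1]), d[:, None]], axis=-1)
R_functional = (V * diag[:, None, :]) @ U.transpose(0, 2, 1)

print(np.allclose(R_inplace, R_functional))           # True
print(np.allclose(np.linalg.det(R_functional), 1.0))  # True
```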

What’s next?

So that’s a straightforward implementation of the Kabsch algorithm. However, the story does not end there! For example, the same alignment problem is frequently solved in a quaternion form, which can be more numerically stable. We may also want to consider iterative versions of the algorithm, which can be more robust to noise or scale better. Additionally, we may need the weighted version of the algorithm, which is useful for aligning point clouds with different densities. (SciPy has a version of the weighted Kabsch algorithm, but it is not batched.)
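As one example of the quaternion form, here is a NumPy sketch of Horn's closed-form unit-quaternion method. This is my choice of variant, since the post doesn't name one. It builds a symmetric 4x4 matrix from the same cross-covariance $C$, takes the eigenvector of its largest eigenvalue as the optimal unit quaternion, and converts that quaternion to a rotation matrix. It handles 3D only, and kabsch_quaternion_numpy is a hypothetical name.

```python
import numpy as np


def kabsch_quaternion_numpy(P, Q):
    """Hypothetical sketch of Horn's unit-quaternion method for 3D points (P -> Q)."""
    p = P - P.mean(axis=0)
    q = Q - Q.mean(axis=0)

    # Cross-covariance C = P'^T Q', with entries S_uv = sum_i p_iu * q_iv
    (Sxx, Sxy, Sxz), (Syx, Syy, Syz), (Szx, Szy, Szz) = p.T @ q

    # Symmetric 4x4 matrix whose top eigenvector is the optimal quaternion
    N = np.array([
        [Sxx + Syy + Szz, Syz - Szy,       Szx - Sxz,        Sxy - Syx],
        [Syz - Szy,       Sxx - Syy - Szz, Sxy + Syx,        Szx + Sxz],
        [Szx - Sxz,       Sxy + Syx,       -Sxx + Syy - Szz, Syz + Szy],
        [Sxy - Syx,       Szx + Sxz,       Syz + Szy,        -Sxx - Syy + Szz],
    ])

    # eigh returns eigenvalues in ascending order, so take the last eigenvector
    _, eigvecs = np.linalg.eigh(N)
    w, x, y, z = eigvecs[:, -1]

    # Convert the unit quaternion (w, x, y, z) to a rotation matrix
    return np.array([
        [1 - 2 * (y * y + z * z), 2 * (x * y - w * z),     2 * (x * z + w * y)],
        [2 * (x * y + w * z),     1 - 2 * (x * x + z * z), 2 * (y * z - w * x)],
        [2 * (x * z - w * y),     2 * (y * z + w * x),     1 - 2 * (x * x + y * y)],
    ])


# Usage: recover a 60-degree rotation about the axis (1, 1, 1)/sqrt(3)
rng = np.random.default_rng(7)
P = rng.normal(size=(50, 3))
k = np.ones(3) / np.sqrt(3.0)
K = np.array([[0, -k[2], k[1]], [k[2], 0, -k[0]], [-k[1], k[0], 0]])
theta = np.pi / 3
R_true = np.eye(3) + np.sin(theta) * K + (1 - np.cos(theta)) * K @ K  # Rodrigues' formula
Q = P @ R_true.T + np.array([2.0, 0.0, -1.0])
R_quat = kabsch_quaternion_numpy(P, Q)
print(np.allclose(R_quat, R_true))  # True
```

Because both methods solve the same least-squares problem, the quaternion result should agree with the SVD-based rotation on noisy data too. The quaternion route never produces a reflection, so no $B$ correction is needed.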

But for the time being, I hope this post was useful to you!

References

A short list of references that I found useful:

  • Wikipedia, Kabsch Algorithm
  • [Kabsch, 1976] Kabsch, W. (1976). A solution for the best rotation to relate two sets of vectors. Acta Crystallographica A, 32:922–923.
  • [Kabsch, 1978] Kabsch, W. (1978). A discussion of the solution for the best rotation to relate two sets of vectors. Acta Crystallographica A, 34:827–828.
  • Zalo on Kabsch: A really cool blog post with an interactive shape matching demo.