<?xml version="1.0" encoding="utf-8" standalone="yes"?><rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom" xmlns:content="http://purl.org/rss/1.0/modules/content/"><channel><title>Tutorial on Hunter Heidenreich | Senior AI Research Scientist</title><link>https://hunterheidenreich.com/tags/tutorial/</link><description>Recent content in Tutorial on Hunter Heidenreich | Senior AI Research Scientist</description><image><title>Hunter Heidenreich | Senior AI Research Scientist</title><url>https://hunterheidenreich.com/img/avatar.webp</url><link>https://hunterheidenreich.com/img/avatar.webp</link></image><generator>Hugo -- 0.147.7</generator><language>en-US</language><copyright>2026 Hunter Heidenreich</copyright><lastBuildDate>Sat, 30 May 2026 00:00:00 +0000</lastBuildDate><atom:link href="https://hunterheidenreich.com/tags/tutorial/index.xml" rel="self" type="application/rss+xml"/><item><title>Embedded-Atom Method User Guide: Voter's 1994 Chapter</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-simulation/classical-methods/embedded-atom-method-voter-1994/</link><pubDate>Sun, 14 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-simulation/classical-methods/embedded-atom-method-voter-1994/</guid><description>Comprehensive user guide for the Embedded-Atom Method (EAM), covering theory, potential fitting, and applications to intermetallics.</description><content:encoded><![CDATA[<h2 id="contribution-systematizing-the-embedded-atom-method">Contribution: Systematizing the Embedded-Atom Method</h2>
<p>This is a <strong>Systematization</strong> paper (specifically a handbook chapter) with a strong secondary <strong>Method</strong> projection.</p>
<p>Its primary goal is to serve as a &ldquo;users&rsquo; guide&rdquo; to the Embedded-Atom Method (EAM). The text organizes existing knowledge:</p>
<ul>
<li>It traces the physical origins of EAM from Density Functional Theory (DFT) and Effective Medium Theory.</li>
<li>It synthesizes &ldquo;closely related methods&rdquo; (Second Moment Approximation, Glue Model), showing they are mathematically equivalent or very similar to EAM.</li>
<li>It provides a pedagogical, step-by-step methodology for fitting potentials to experimental data.</li>
</ul>
<h2 id="motivation-bridging-the-gap-between-dft-and-pair-potentials">Motivation: Bridging the Gap Between DFT and Pair Potentials</h2>
<p>The primary motivation is to bridge the gap between accurate, expensive electronic structure calculations and fast, inaccurate pair potentials.</p>
<ul>
<li><strong>Computational Efficiency</strong>: First-principles methods scale as $O(N^3)$ or worse, limiting simulations to $&lt;100$ atoms (in 1994). Pair potentials scale as $O(N)$ but fail to capture the essential many-body physics of metals.</li>
<li><strong>Physical Accuracy</strong>: Simple pair potentials cannot accurately model metallic defects; they predict zero Cauchy pressure ($C_{12} - C_{44} = 0$) and equate vacancy formation energy to cohesive energy, both of which are incorrect for transition metals.</li>
<li><strong>Practical Utility</strong>: There was a need for a clear guide on how to construct and apply these potentials for large-scale simulations ($10^6+$ atoms) of fracture and defects.</li>
</ul>
<h2 id="novelty-a-unified-framework-and-robust-fitting-recipe">Novelty: A Unified Framework and Robust Fitting Recipe</h2>
<p>As a review chapter, the novelty lies in the synthesis and the specific, reproducible recipe for potential construction. Central to this synthesis is the core EAM energy functional:</p>
<p>$$E_{\text{tot}} = \sum_i \left( F(\bar{\rho}_i) + \frac{1}{2} \sum_{j \neq i} \phi(r_{ij}) \right)$$</p>
<p>where the total energy $E_{\text{tot}}$ depends on embedding an atom $i$ into a local background electron density $\bar{\rho}_i = \sum_{j \neq i} \rho(r_{ij})$, plus a repulsive pair interaction $\phi(r_{ij})$.</p>
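<p>To make the functional concrete, here is a minimal sketch that evaluates $E_{\text{tot}}$ for a handful of atoms. It borrows the Morse pair term and the modified 4s density form that the chapter uses, but every numerical parameter is an illustrative placeholder, and the square-root <code>embed</code> function is a hypothetical stand-in (second-moment style) for the numerically derived embedding function. It also uses a plain double loop with no cutoff or neighbor list, so it does not show the linear scaling of a production code.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import math

# All parameter values below are illustrative placeholders, not fitted values.
D_M, ALPHA_M, R_M = 1.0, 1.5, 2.5   # Morse depth (eV), stiffness (1/A), minimum (A)
BETA = 3.0                           # density decay parameter (1/A)


def phi(r):
    """Morse pair interaction with minimum -D_M at r = R_M."""
    return D_M * (1.0 - math.exp(-ALPHA_M * (r - R_M))) ** 2 - D_M


def rho(r):
    """Density contribution of one neighbor at distance r."""
    return r ** 6 * (math.exp(-BETA * r) + 2 ** 9 * math.exp(-2.0 * BETA * r))


def embed(rho_bar):
    """Hypothetical square-root embedding function (an assumption of this sketch)."""
    return -math.sqrt(rho_bar)


def eam_energy(positions):
    """E_tot = sum_i [ F(rho_bar_i) + 1/2 sum_{j != i} phi(r_ij) ]."""
    total = 0.0
    for i, p_i in enumerate(positions):
        rho_bar = 0.0
        pair = 0.0
        for j, p_j in enumerate(positions):
            if i == j:
                continue
            r = math.dist(p_i, p_j)
            rho_bar += rho(r)   # background density seen by atom i
            pair += phi(r)      # pair term, halved below to avoid double counting
        total += embed(rho_bar) + 0.5 * pair
    return total


# Two atoms separated by the Morse minimum distance.
dimer = [(0.0, 0.0, 0.0), (R_M, 0.0, 0.0)]
print(eam_energy(dimer))
</code></pre></div>
<p>For the dimer, each atom sees the density of exactly one neighbor, so the total reduces to $2F(\rho(R_M)) + \phi(R_M)$, which is an easy hand check on the double-counting factor of $\frac{1}{2}$.</p>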
<ul>
<li><strong>Unified Framework</strong>: It explicitly maps the &ldquo;Second Moment Approximation&rdquo; (Tight Binding) and the &ldquo;Glue Model&rdquo; onto the fundamental EAM framework above, clarifying that they differ primarily in terminology or specific functional choices (e.g., square root embedding functions).</li>
<li><strong>Cross-Potential Fitting Recipe</strong>: It details a robust method for fitting alloy potentials (specifically Ni-Al-B) by using &ldquo;transformation invariance&rdquo;, scaling the density and shifting the embedding function to fit alloy properties without disturbing pure element fits.</li>
<li><strong>Specific Parameters</strong>: It publishes optimized potential parameters for Ni, Al, and B that accurately reproduce properties like the Boron interstitial preference in $\text{Ni}_3\text{Al}$.</li>
</ul>
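<p>The &ldquo;transformation invariance&rdquo; mentioned above can be checked numerically. The sketch below assumes the standard companion changes that keep a single-element energy fixed: adding $g\bar{\rho}$ to the embedding function is compensated by $\phi&rsquo;(r) = \phi(r) - 2g\rho(r)$, and scaling the density by $s$ is compensated by evaluating the embedding function at $\bar{\rho}/s$. The component functions, the constants <code>g</code> and <code>s</code>, and the four-atom cluster are all hypothetical placeholders.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import math


def phi(r):
    """Illustrative Morse pair term (placeholder parameters)."""
    return (1.0 - math.exp(-1.5 * (r - 2.5))) ** 2 - 1.0


def rho(r):
    """Illustrative exponentially decaying density (placeholder)."""
    return math.exp(-1.2 * r)


def embed(rho_bar):
    """Hypothetical square-root embedding function."""
    return -math.sqrt(rho_bar)


def eam_energy(positions, embed_fn, rho_fn, phi_fn):
    """Single-element EAM energy with user-supplied component functions."""
    total = 0.0
    for i, p_i in enumerate(positions):
        dists = [math.dist(p_i, p_j) for j, p_j in enumerate(positions) if j != i]
        rho_bar = sum(rho_fn(r) for r in dists)
        total += embed_fn(rho_bar) + 0.5 * sum(phi_fn(r) for r in dists)
    return total


g, s = 0.37, 2.5  # arbitrary transformation constants


# Transformation 1: add a linear term to F, compensated in the pair term.
def embed_shifted(x):
    return embed(x) + g * x


def phi_shifted(r):
    return phi(r) - 2.0 * g * rho(r)


# Transformation 2: scale the density and un-scale the argument of F.
def rho_scaled(r):
    return s * rho(r)


def embed_scaled(x):
    return embed(x / s)


cluster = [(0.0, 0.0, 0.0), (2.4, 0.0, 0.0), (0.0, 2.6, 0.0), (1.1, 1.2, 2.0)]
e_ref = eam_energy(cluster, embed, rho, phi)
e_shift = eam_energy(cluster, embed_shifted, rho, phi_shifted)
e_scale = eam_energy(cluster, embed_scaled, rho_scaled, phi)
print(e_ref, e_shift, e_scale)
</code></pre></div>
<p>All three energies agree to floating-point precision. Because the single-element energy is unchanged, $g$ and $s$ for each element act as free knobs that only show up once two species interact, which is what lets the alloy fit proceed without disturbing the pure-element fits.</p>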
<h2 id="validation-computational-benchmarks-and-simulations">Validation: Computational Benchmarks and Simulations</h2>
<p>The &ldquo;experiments&rdquo; described are computational validations and simulations using the fitted Ni-Al-B potential:</p>
<ol>
<li>
<p><strong>Potential Fitting</strong>:</p>
<ul>
<li>Pure elements (Ni, Al) were fitted to elastic constants, vacancy formation energies, and diatomic data. The Ni fit achieved $\chi_{\text{rms}} = 0.75\%$ and Al achieved $\chi_{\text{rms}} = 3.85\%$.</li>
<li>Boron was fitted using hypothetical crystal structures (fcc, bcc) calculated via LMTO (Linear Muffin-Tin Orbital) since experimental data for fcc B does not exist.</li>
</ul>
</li>
<li>
<p><strong>Molecular Statics (Validation)</strong>:</p>
<ul>
<li><strong>Surface Relaxation</strong>: Demonstrated that EAM captures the oscillatory relaxation of atomic layers near a free surface, a many-body effect that pair potentials fail to capture.</li>
<li><strong>Defect Energetics</strong>: Calculated formation energies for Boron interstitials in $\text{Ni}_3\text{Al}$. Found the 6Ni-octahedral site is most stable ($-4.59$ eV relative to an isolated B atom and unperturbed crystal), followed by the 4Ni-2Al octahedral site ($-3.65$ eV) and the 3Ni-1Al tetrahedral site ($-2.99$ eV), consistent with channeling experiments.</li>
</ul>
</li>
<li>
<p><strong>Molecular Dynamics (Application)</strong>:</p>
<ul>
<li><strong>Grain Boundary (GB) Cleavage</strong>: Simulated the fracture of a (210) tilt grain boundary in $\text{Ni}_3\text{Al}$ at a strain rate of $5 \times 10^{10}$ s$^{-1}$.</li>
<li><strong>Comparison</strong>: Compared pure $\text{Ni}_3\text{Al}$ boundaries vs. those doped with Boron and substitutional Nickel.</li>
</ul>
</li>
</ol>
<h2 id="key-outcomes-eam-efficiency-and-boron-strengthening">Key Outcomes: EAM Efficiency and Boron Strengthening</h2>
<ul>
<li><strong>EAM Efficiency</strong>: Confirmed that EAM scales linearly with atom count ($N$), requiring only 2-5 times the computational work of pair potentials.</li>
<li><strong>Boron Strengthening Mechanism</strong>: The simulations suggested that Boron segregates to grain boundaries and, specifically when co-segregated with Ni, significantly increases cohesion.
<ul>
<li>The maximum stress for the enriched boundary was approximately 22 GPa, compared to approximately 19 GPa for the clean boundary.</li>
<li>The B-doped boundary required approximately 44% more work to cleave than the undoped boundary.</li>
<li>The fracture mode shifted from cleaving along the GB to failure in the bulk.</li>
</ul>
</li>
<li><strong>Grain Boundary Segregation</strong>: Molecular statics calculations found B interstitial energies at the GB as low as $-6.9$ eV, compared to $-4.59$ eV in the bulk, consistent with experimental observations of boron segregation to grain boundaries.</li>
<li><strong>Limitations</strong>: The author concludes that while EAM is excellent for metals, it lacks the angular dependence required for strongly covalent materials (like $\text{MoSi}_2$) or directional bonding.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<p>The chapter provides nearly all details required to implement the described potential from scratch.</p>
<h3 id="data">Data</h3>
<ul>
<li><strong>Experimental/Reference Data</strong>: Used for fitting the cost function $\chi_{\text{rms}}$.
<ul>
<li><strong>Pure Elements</strong>: Lattice constants ($a_0$), cohesive energy ($E_{\text{coh}}$), bulk modulus ($B$), elastic constants ($C_{11}, C_{12}, C_{44}$), vacancy formation energy ($E_{\text{vac}}^f$), and diatomic bond length/strength ($R_e, D_e$).</li>
<li><strong>Alloys</strong>: Heat of solution and defect energies (APB, SISF) for $\text{Ni}_3\text{Al}$.</li>
<li><strong>Hypothetical Data</strong>: LMTO first-principles data used for unobserved phases (e.g., fcc Boron, B2 NiB) to constrain the fit.</li>
</ul>
</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Component Functions</strong>:
<ul>
<li><strong>Pair Potential $\phi(r)$</strong>: Morse potential form:
$$\phi(r) = D_M \left\{1 - \exp[-\alpha_M(r - R_M)]\right\}^2 - D_M$$</li>
<li><strong>Density Function $\rho(r)$</strong>: Modified hydrogenic 4s orbital:
$$\rho(r) = r^6(e^{-\beta r} + 2^9 e^{-2\beta r})$$</li>
<li><strong>Embedding Function $F(\bar{\rho})$</strong>: Derived numerically to force the crystal energy to match the &ldquo;Universal Energy Relation&rdquo; (Rose et al.) as a function of lattice constant.</li>
</ul>
</li>
<li><strong>Fitting Strategy</strong>:
<ul>
<li><strong>Smooth Cutoff</strong>: A polynomial smoothing function ($h_{\text{smooth}}$) applied at $r_{\text{cut}}$ to ensure continuous derivatives.</li>
<li><strong>Simplex Algorithm</strong>: Used to optimize parameters ($D_M, R_M, \alpha_M, \beta, r_{\text{cut}}$).</li>
<li><strong>Alloy Invariance</strong>: Used transformations $F&rsquo;(\rho) = F(\rho) + g\rho$ and $\rho&rsquo;(r) = s\rho(r)$ to fit cross-potentials without altering pure-element properties.</li>
</ul>
</li>
</ul>
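<p>The numerically derived embedding function can be sketched as follows: for each lattice constant $a$ of a perfect crystal, compute the background density $\bar{\rho}(a)$ and the pair sum, then define $F$ so that the total energy per atom equals the universal energy relation. The sketch assumes the usual Rose form $E(a^*) = -E_{\text{coh}}(1 + a^*)e^{-a^*}$ with $a^* = (a/a_0 - 1)/\sqrt{E_{\text{coh}}/(9B\Omega)}$, an fcc lattice truncated at the third neighbor shell in place of the smooth cutoff, and roughly nickel-like placeholder inputs. None of the numbers are the fitted parameters from the chapter.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import math

# Illustrative fcc inputs (placeholders, not the fitted values).
A0 = 3.52        # equilibrium lattice constant (A)
E_COH = 4.45     # cohesive energy (eV/atom)
BULK = 1.13      # bulk modulus (eV/A^3)

# fcc neighbor shells: (distance in units of a, number of neighbors).
FCC_SHELLS = [(math.sqrt(0.5), 12), (1.0, 6), (math.sqrt(1.5), 24)]


def phi(r):
    """Morse pair term with placeholder parameters."""
    return 1.5 * (1.0 - math.exp(-1.8 * (r - 2.2))) ** 2 - 1.5


def rho(r):
    """Modified 4s-style density with a placeholder beta."""
    beta = 3.6
    return r ** 6 * (math.exp(-beta * r) + 2 ** 9 * math.exp(-2.0 * beta * r))


def rose_energy(a):
    """Universal energy relation: E(a*) = -E_coh (1 + a*) exp(-a*)."""
    omega = A0 ** 3 / 4.0                           # atomic volume in fcc
    scale = math.sqrt(E_COH / (9.0 * BULK * omega))
    a_star = (a / A0 - 1.0) / scale
    return -E_COH * (1.0 + a_star) * math.exp(-a_star)


def embedding_point(a):
    """Return (rho_bar, F) for the perfect fcc crystal at lattice constant a."""
    rho_bar = sum(n * rho(d * a) for d, n in FCC_SHELLS)
    pair = 0.5 * sum(n * phi(d * a) for d, n in FCC_SHELLS)
    # F is whatever remains after the pair sum is subtracted from the target energy.
    return rho_bar, rose_energy(a) - pair


# Tabulate F over a range of lattice constants; a real code would then
# interpolate this table to evaluate F at arbitrary densities.
table = [embedding_point(A0 * (0.8 + 0.02 * k)) for k in range(21)]
print(table[10])
</code></pre></div>
<p>By construction the crystal energy at $a_0$ equals $-E_{\text{coh}}$, and the curvature of the Rose curve ties the result to the bulk modulus, so those properties are matched before the remaining parameters are optimized.</p>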
<h3 id="models">Models</h3>
<ul>
<li><strong>Parameters</strong>: The text provides the exact optimized parameters for the Ni-Al-B potential in <strong>Table 2</strong> (Pure elements) and <strong>Table 5</strong> (Cross-potentials).
<ul>
<li>Example Ni parameters: $D_M=1.5335$ eV, $\alpha_M=1.7728$ Å$^{-1}$, $r_{\text{cut}}=4.7895$ Å.</li>
</ul>
</li>
</ul>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>1994 Context</strong>: Mentions that simulations of $10^6$ atoms were possible on the &ldquo;fastest computers available&rdquo;.</li>
<li><strong>Scaling</strong>: Explicitly notes computational work scales as $O(N)$, roughly 2-5x slower than pair potentials.</li>
</ul>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Voter, A. F. (1994). Chapter 4: The Embedded-Atom Method. In <em>Intermetallic Compounds: Vol. 1, Principles</em>, edited by J. H. Westbrook and R. L. Fleischer. John Wiley &amp; Sons Ltd.</p>
<p><strong>Publication</strong>: Intermetallic Compounds: Vol. 1, Principles (1994)</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@incollection</span>{voterEmbeddedAtomMethod1994,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{The Embedded-Atom Method}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Voter, Arthur F.}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{Intermetallic Compounds: Vol. 1, Principles}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">editor</span> = <span style="color:#e6db74">{Westbrook, J. H. and Fleischer, R. L.}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{1994}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{John Wiley &amp; Sons Ltd}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{77--90}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">chapter</span> = <span style="color:#e6db74">{4}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://www.ctcms.nist.gov/potentials/">NIST Interatomic Potentials Repository</a> (Modern repository often hosting EAM files)</li>
<li><a href="/notes/chemistry/molecular-simulation/classical-methods/embedded-atom-method/">Original EAM Paper (1984)</a></li>
<li><a href="/notes/chemistry/molecular-simulation/classical-methods/embedded-atom-method-review-1993/">EAM Review (1993)</a></li>
</ul>
]]></content:encoded></item><item><title>Importance Weighted Autoencoders: Beyond the Standard VAE</title><link>https://hunterheidenreich.com/posts/importance-weighted-autoencoders/</link><pubDate>Wed, 05 Nov 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/importance-weighted-autoencoders/</guid><description>The key difference between multi-sample VAEs and IWAEs: how log-of-averages creates a tighter bound on log-likelihood.</description><content:encoded><![CDATA[<p>If you&rsquo;ve worked with Variational Autoencoders (VAEs), you&rsquo;ve almost certainly used the standard $\mathcal{L}_1$ objective, or ELBO. It&rsquo;s trained by taking <em>one</em> sample ($k=1$) from the recognition network to calculate the loss.</p>
<p>A natural question follows: &ldquo;What if I use more samples? Won&rsquo;t that make it better?&rdquo;</p>
<p>Using more samples improves performance when paired with the correct objective function. Averaging the loss over $k$ samples yields minimal gains. Changing the objective itself is where the real gain comes from. This post explores the difference between a &ldquo;multi-sample VAE&rdquo; and the <strong>Importance Weighted Autoencoder (IWAE)</strong>, a model that uses the <em>same architecture</em> as a VAE but is trained with a different objective that optimizes a tighter bound on the log-likelihood.</p>
<p>All ideas here are based on the fantastic paper: <a href="https://arxiv.org/abs/1509.00519">&ldquo;Importance Weighted Autoencoders&rdquo;</a> by Burda, Grosse, and Salakhutdinov.</p>
<h2 id="the-two-ways-to-use-k-samples">The Two Ways to Use $k$ Samples</h2>
<p>Let&rsquo;s say we have our encoder $q(h|x)$ and decoder $p(x,h)$. We decide to use $k=5$ samples instead of $k=1$. We have two main options for how to calculate our loss.</p>
<h3 id="option-1-the-multi-sample-vae-the-naive-way">Option 1: The &ldquo;Multi-Sample VAE&rdquo; (The Naive Way)</h3>
<p>This is the most straightforward idea. For each input $x$ in our batch:</p>
<ol>
<li>Draw 5 samples ($h_1, \dots, h_5$) from $q(h|x)$.</li>
<li>Calculate the standard VAE $\mathcal{L}_1$ loss for <em>each</em> sample.</li>
<li>Average these 5 losses together.</li>
</ol>
<p>This is an <strong>average of logs</strong>. As the IWAE paper shows experimentally, this approach gives you a more stable gradient, but the final performance (in terms of log-likelihood) is &ldquo;only slightly&rdquo; better. You&rsquo;re paying a 5x computational cost for a marginal gain because you&rsquo;re still optimizing the same &ldquo;loose&rdquo; $\mathcal{L}_1$ bound.</p>
<h3 id="option-2-the-importance-weighted-autoencoder-iwae-the-right-way">Option 2: The Importance Weighted Autoencoder (IWAE) (The Right Way)</h3>
<p>The IWAE takes a different approach. For each input $x$:</p>
<ol>
<li>Draw 5 samples ($h_1, \dots, h_5$) from $q(h|x)$.</li>
<li>Calculate an &ldquo;importance weight&rdquo; $w_i$ for each sample.</li>
<li>Average these 5 <em>weights</em> together.</li>
<li>Take the <em>logarithm</em> of that average.</li>
</ol>
<p>This is a <strong>log of an average</strong>, and the difference matters.</p>
<h2 id="the-math-average-of-logs-vs-log-of-averages">The Math: Average-of-Logs vs. Log-of-Averages</h2>
<p>Let&rsquo;s make this concrete. The standard VAE $\mathcal{L}_1$ objective is:</p>
<p>$$
\mathcal{L}_1(x) = \mathbb{E} _{h\sim q(h|x)} \left[ \log \frac{p(x,h)}{q(h|x)} \right]
$$</p>
<p>A <strong>multi-sample VAE</strong> simply gets a better estimate of this same value:</p>
<p>$$
\mathcal{L} _{\text{VAE}, k}(x) \approx  \frac{1}{k} \sum _{i=1}^{k} \log w_i \quad \text{where} \quad w_i = \frac{p(x,h_i)}{q(h_i|x)}
$$</p>
<p>The <strong>IWAE</strong> objective, $\mathcal{L}_k$, is fundamentally different:</p>
<p>$$
\mathcal{L} _k (x) = \mathbb{E} _{h_1, \dots, h_k \sim q(h|x)} \left[ \log \left( \frac{1}{k} \sum _{i=1}^{k} \frac{p(x,h_i)}{q(h_i|x)} \right) \right]
$$</p>
<p>In practice, we estimate this with a single Monte Carlo sample (of $k$ latents):</p>
<p>$$
\mathcal{L} _k (x) \approx \log \left( \frac{1}{k} \sum _{i=1}^{k} w_i \right)
$$</p>
<p>Because the logarithm is a concave function, Jensen&rsquo;s Inequality tells us that the &ldquo;log of an average&rdquo; is <em>always</em> greater than or equal to the &ldquo;average of logs.&rdquo;</p>
<p>$$
\mathcal{L}_k(x) \ge \mathcal{L}_1(x)
$$</p>
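<p>This means the IWAE is optimizing a <strong>tighter lower bound</strong> on the true log-likelihood of the data: it is never looser than the standard ELBO, and the paper shows the bound tightens monotonically as $k$ grows.</p>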
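<p>A quick numerical check makes the inequality tangible. The weights below are made-up numbers standing in for $w_i = p(x,h_i)/q(h_i|x)$ on a single input, so this is only an illustration of Jensen&rsquo;s Inequality and not output from a trained model.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import math

# Hypothetical importance weights w_i = p(x, h_i) / q(h_i | x) for one input x.
weights = [0.02, 0.5, 0.1, 0.9, 0.3]
k = len(weights)

# Multi-sample VAE estimate: average of logs.
average_of_logs = sum(math.log(w) for w in weights) / k

# IWAE estimate: log of the average weight.
log_of_average = math.log(sum(weights) / k)

# The two estimates coincide when every weight is identical.
equal = [0.3] * k
gap_when_equal = math.log(sum(equal) / k) - sum(math.log(w) for w in equal) / k

print(average_of_logs, log_of_average, gap_when_equal)
</code></pre></div>
<p>The first estimate is dragged down by the single poor sample ($w = 0.02$), while the second is dominated by the good samples. The gap closes only when the weights stop varying, which is why the inequality is stated with $\ge$.</p>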
<h2 id="why-does-this-log-of-average-matter">Why Does This &ldquo;Log-of-Average&rdquo; Matter?</h2>
<p>This mathematical property provides two practical benefits.</p>
<h3 id="1-better-density-estimation">1. Better Density Estimation</h3>
<p>Because $\mathcal{L}_k$ is a tighter bound on the true $\log p(x)$, optimizing it pushes the model to learn a much better generative distribution. The paper shows that IWAEs achieve &ldquo;significantly higher log-likelihoods&rdquo; than VAEs.</p>
<h3 id="2-richer-latent-representations">2. Richer Latent Representations</h3>
<p>This is the most interesting part. The standard VAE $\mathcal{L}_1$ objective &ldquo;harshly penalizes&rdquo; the model if its <em>one</em> sample $h$ is a poor explanation for $x$. This pressure forces the recognition network $q(h|x)$ to be &ldquo;overly simplified&rdquo; to avoid bad samples, which can lead to many latent dimensions becoming inactive (the paper reports the number of &ldquo;active units&rdquo; per model).</p>
<p>The IWAE objective is more flexible. It only needs <em>one</em> of the $k$ samples to be good. This &ldquo;increased flexibility&rdquo; allows the model to learn far more complex posterior distributions and &ldquo;richer latent space representations.&rdquo; The paper&rsquo;s experiments confirm this, showing IWAEs learn to use many more &ldquo;active units&rdquo; in their latent space.</p>
<h2 id="what-this-looks-like-in-code-pytorch">What This Looks Like in Code (PyTorch)</h2>
<p>The implementation difference makes this crystal clear.</p>
<p>First, the &ldquo;k-sample&rdquo; trick: for a batch <code>x</code> of shape <code>[B, D]</code> and <code>k=5</code> samples, we repeat <code>x</code> to get <code>x_repeated</code> of shape <code>[B*k, D]</code>. We do all our forward passes on this large tensor.</p>
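<p>One detail worth checking is the order of the repeated rows, because grouping the $k$ samples per input later with a row-major reshape to <code>[B, k]</code> only works if the copies of each input sit next to each other. The pure-Python sketch below uses strings as stand-ins for rows of <code>x</code>; the <code>group</code> helper is a hypothetical imitation of a reshape. In PyTorch, <code>x.repeat_interleave(k, dim=0)</code> gives the interleaved layout and <code>x.repeat(k, 1)</code> gives the tiled one. Which of the two a given implementation uses is an assumption you should verify.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">B, k = 3, 2
batch = ["x0", "x1", "x2"]  # stand-ins for the B rows of x

# Interleaved repeat: each row appears k times in a row.
interleaved = [row for row in batch for _ in range(k)]

# Tiled repeat: the whole batch is copied k times.
tiled = batch * k


def group(flat, rows, cols):
    """Row-major reshape of a flat list into rows x cols."""
    return [flat[r * cols:(r + 1) * cols] for r in range(rows)]


print(group(interleaved, B, k))  # each row holds the samples of a single input
print(group(tiled, B, k))        # rows mix samples from different inputs
</code></pre></div>
<p>With the tiled layout, the equivalent fix is to reshape to <code>[k, B]</code> and reduce over the first dimension instead.</p>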
<h3 id="vae-multi-sample-k--1-loss">VAE (Multi-Sample, k &gt; 1) Loss</h3>
<p>Here, we can still use the analytical KL divergence, which is a big simplification.</p>
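<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> torch
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> torch.nn.functional <span style="color:#f92672">as</span> F
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># x_repeated has shape [B*k, 784]</span>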
</span></span><span style="display:flex;"><span><span style="color:#75715e"># mu, logvar have shape [B*k, latent_dim]</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># recon_x has shape [B*k, 784]</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># recon_loss_all shape: [B*k]</span>
</span></span><span style="display:flex;"><span>recon_loss_all <span style="color:#f92672">=</span> F<span style="color:#f92672">.</span>binary_cross_entropy(recon_x, x_repeated, reduction<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;none&#39;</span>)<span style="color:#f92672">.</span>sum(dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># kl_loss_all shape: [B*k]</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># We use the simple, analytical KL term!</span>
</span></span><span style="display:flex;"><span>kl_loss_all <span style="color:#f92672">=</span> <span style="color:#f92672">-</span><span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>sum(<span style="color:#ae81ff">1</span> <span style="color:#f92672">+</span> logvar <span style="color:#f92672">-</span> mu<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">-</span> logvar<span style="color:#f92672">.</span>exp(), dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># total_loss_all shape: [B*k]</span>
</span></span><span style="display:flex;"><span>total_loss_all <span style="color:#f92672">=</span> recon_loss_all <span style="color:#f92672">+</span> kl_loss_all
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># --- The Key Step ---</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Just average all B*k losses. This is the &#34;average of logs&#34;.</span>
</span></span><span style="display:flex;"><span>loss <span style="color:#f92672">=</span> total_loss_all<span style="color:#f92672">.</span>mean()
</span></span></code></pre></div><h3 id="iwae-k--1-loss">IWAE (k &gt; 1) Loss</h3>
<p>Here, we must compute the exact log-probabilities of the <em>specific samples</em> we drew.</p>
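<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> math
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> torch
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> torch.nn.functional <span style="color:#f92672">as</span> F
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Helper function to compute log-prob of a sample from a Gaussian</span>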
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">log_prob_gaussian</span>(sample, mu, logvar):
</span></span><span style="display:flex;"><span>    const <span style="color:#f92672">=</span> <span style="color:#f92672">-</span><span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> sample<span style="color:#f92672">.</span>shape[<span style="color:#f92672">-</span><span style="color:#ae81ff">1</span>] <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>log(<span style="color:#ae81ff">2</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>tensor(math<span style="color:#f92672">.</span>pi))
</span></span><span style="display:flex;"><span>    log_det <span style="color:#f92672">=</span> <span style="color:#f92672">-</span><span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>sum(logvar, dim<span style="color:#f92672">=-</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>    log_exp <span style="color:#f92672">=</span> <span style="color:#f92672">-</span><span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>sum((sample <span style="color:#f92672">-</span> mu)<span style="color:#f92672">**</span><span style="color:#ae81ff">2</span> <span style="color:#f92672">/</span> torch<span style="color:#f92672">.</span>exp(logvar), dim<span style="color:#f92672">=-</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> const <span style="color:#f92672">+</span> log_det <span style="color:#f92672">+</span> log_exp
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># --- Get the 3 log-prob components ---</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># x_repeated, recon_x, z_samples, mu_repeated, logvar_repeated</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># all have a first dimension of [B*k]</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># 1. log p(x|h_i): Log-Reconstruction Probability</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># log_p_x_given_h shape: [B*k]</span>
</span></span><span style="display:flex;"><span>log_p_x_given_h <span style="color:#f92672">=</span> <span style="color:#f92672">-</span>F<span style="color:#f92672">.</span>binary_cross_entropy(recon_x, x_repeated, reduction<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;none&#39;</span>)<span style="color:#f92672">.</span>sum(dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># 2. log p(h_i): Log-Prior Probability (under N(0, I))</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># log_p_h shape: [B*k]</span>
</span></span><span style="display:flex;"><span>log_p_h <span style="color:#f92672">=</span> log_prob_gaussian(z_samples, torch<span style="color:#f92672">.</span>zeros_like(z_samples), torch<span style="color:#f92672">.</span>zeros_like(z_samples)) <span style="color:#75715e"># mu=0, logvar=0 (tensors, since the helper calls torch.sum and torch.exp on them)</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># 3. log q(h_i|x): Log-Encoder Probability</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># log_q_h_given_x shape: [B*k]</span>
</span></span><span style="display:flex;"><span>log_q_h_given_x <span style="color:#f92672">=</span> log_prob_gaussian(z_samples, mu_repeated, logvar_repeated)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># --- The Key Step ---</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Combine to get the log-importance-weight</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># log_w shape: [B*k]</span>
</span></span><span style="display:flex;"><span>log_w <span style="color:#f92672">=</span> log_p_x_given_h <span style="color:#f92672">+</span> log_p_h <span style="color:#f92672">-</span> log_q_h_given_x
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Reshape to [B, k] to group samples by their original input</span>
</span></span><span style="display:flex;"><span>log_w_matrix <span style="color:#f92672">=</span> log_w<span style="color:#f92672">.</span>view(B, k) <span style="color:#75715e"># B is original batch size</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># --- Apply the IWAE Objective (Log-Sum-Exp Trick) ---</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># This is the &#34;log of the average&#34;</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># log( (1/k) * sum(exp(log_w)) ) = logsumexp(log_w) - log(k)</span>
</span></span><span style="display:flex;"><span>log_iwae_bound_per_x <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>logsumexp(log_w_matrix, dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>) <span style="color:#f92672">-</span> math<span style="color:#f92672">.</span>log(k)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># The objective is to MAXIMIZE this bound, so the loss is its negative</span>
</span></span><span style="display:flex;"><span>loss <span style="color:#f92672">=</span> <span style="color:#f92672">-</span>log_iwae_bound_per_x<span style="color:#f92672">.</span>mean()
</span></span></code></pre></div><h3 id="the-critical-implementation-detail">The Critical Implementation Detail</h3>
<p>Notice the key difference in the final step:</p>
<ul>
<li><strong>VAE</strong>: <code>loss = total_loss_all.mean()</code> (average of individual losses)</li>
<li><strong>IWAE</strong>: <code>loss = -(torch.logsumexp(log_w_matrix, dim=1) - math.log(k)).mean()</code> (log of averaged weights)</li>
</ul>
<p>This seemingly small change implements the fundamental mathematical difference between optimizing an &ldquo;average of logs&rdquo; versus a &ldquo;log of averages.&rdquo;</p>
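<p>The reason to reach for <code>logsumexp</code> instead of literally exponentiating, averaging, and taking a log is numerical: log-weights are often very negative, and their exponentials can underflow to zero. The sketch below shows the failure and the fix in plain Python. The log-weights are made-up values, deliberately extreme to force the underflow, and both helper functions are hypothetical names for this illustration.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import math

# Hypothetical log-weights for one input, chosen to force underflow.
log_w = [-1000.0, -1002.0, -1005.0]


def naive_log_mean_exp(values):
    """Direct translation of log((1/k) * sum(exp(v)))."""
    return math.log(sum(math.exp(v) for v in values) / len(values))


def stable_log_mean_exp(values):
    """Subtract the max before exponentiating, then add it back."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values)) - math.log(len(values))


try:
    naive = naive_log_mean_exp(log_w)
except ValueError:
    # Every exp() underflowed to 0.0, so log(0.0) raised a domain error.
    naive = None

stable = stable_log_mean_exp(log_w)
print(naive, stable)
</code></pre></div>
<p>The stable version returns a finite value slightly below the largest log-weight, which matches the intuition that the best of the $k$ samples dominates the IWAE bound.</p>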
<h2 id="when-to-use-each-approach">When to Use Each Approach</h2>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>When to Use</th>
          <th>Key Benefit</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>VAE ($k=1$)</strong></td>
          <td>Your <strong>default baseline</strong>. It&rsquo;s fast, simple, and often &ldquo;good enough&rdquo; for many tasks.</td>
          <td>Speed and simplicity.</td>
      </tr>
      <tr>
          <td><strong>Multi-Sample VAE ($k&gt;1$)</strong></td>
          <td>When you want slightly more stable gradients but aren&rsquo;t ready for the full IWAE complexity.</td>
          <td>Marginal improvement with minimal code changes.</td>
      </tr>
      <tr>
          <td><strong>IWAE ($k&gt;1$)</strong></td>
          <td>When your baseline VAE is <strong>insufficient</strong>. Specifically, if you need:<br>1. The best possible log-likelihood.<br>2. To activate more latent dimensions or learn richer representations.</td>
          <td>Better performance and richer latents, at the cost of compute (scales linearly with $k$).</td>
      </tr>
  </tbody>
</table>
<h2 id="the-computational-trade-off">The Computational Trade-off</h2>
<p>Both approaches scale linearly with $k$. If you use $k=5$ samples, you&rsquo;re doing roughly 5x the computation. The question is whether you get 5x the benefit.</p>
<p>For multi-sample VAEs, the answer is usually &ldquo;no&rdquo;. You get more stable gradients but only marginal performance improvements.</p>
<p>For IWAEs, the answer is often &ldquo;yes&rdquo;. You get meaningfully better log-likelihoods and richer latent representations that can be worth the computational cost.</p>
<h2 id="conclusion">Conclusion</h2>
<p>The next time you use more samples with your VAE, switch to the IWAE objective to get the full benefit of the computational cost of $k &gt; 1$.</p>
<p>The mathematical insight is simple but powerful: Jensen&rsquo;s Inequality tells us that the &ldquo;log of an average&rdquo; is always greater than or equal to the &ldquo;average of logs.&rdquo; By optimizing this tighter bound, IWAEs achieve better density estimation and learn richer latent representations than standard VAEs.</p>
<p>The implementation requires evaluating exact log-probabilities at the specific samples you drew, instead of relying on the analytical KL term. The result is a more powerful model using the exact same architecture.</p>
<p><strong>Want to dive deeper?</strong> Check out the <a href="https://arxiv.org/abs/1509.00519">original IWAE paper</a> for experimental results and theoretical analysis, or explore my <a href="/posts/modern-variational-autoencoder-in-pytorch/">VAE tutorial</a> for hands-on implementation details.</p>
]]></content:encoded></item><item><title>What is Optical Chemical Structure Recognition (OCSR)?</title><link>https://hunterheidenreich.com/posts/what-is-ocsr/</link><pubDate>Sat, 11 Oct 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/what-is-ocsr/</guid><description>A micro-review of Optical Chemical Structure Recognition (OCSR), covering rule-based systems to modern deep learning models.</description><content:encoded><![CDATA[<h2 id="introduction">Introduction</h2>
<p>Decades of chemical research, breakthroughs in medicine, and novel materials are archived in journals, patents, and textbooks.
A huge portion of this knowledge is stored as images, a format inaccessible to standard computational tools.
This makes both data retrieval and computational analysis harder: knowledge locked in image form is invisible to search, text mining, and downstream model training.</p>
<p>This is the central challenge that <strong>Optical Chemical Structure Recognition (OCSR)</strong> aims to solve. At its heart, OCSR is to chemistry what OCR (Optical Character Recognition) is to text: a technology that teaches computers to extract chemical information directly from 2D diagrams of molecules. It&rsquo;s the bridge between a picture of a molecule and a machine-readable format like <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> (Simplified Molecular Input Line Entry System) that can be stored, searched, and used to power new discoveries.</p>















<figure class="post-figure center ">
    <img src="/img/ocsr/img2smiles.webp"
         alt="The transformation from a 2D chemical structure image to a SMILES representation."
         title="The transformation from a 2D chemical structure image to a SMILES representation."
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The transformation from a 2D chemical structure image to a SMILES representation.</figcaption>
    
</figure>

<p>Teaching a computer to read a chemical structure requires specialized techniques.</p>
<h2 id="the-complexity-of-chemical-graphs">The Complexity of Chemical Graphs</h2>
<p>Recognizing a molecule goes beyond standard Optical Character Recognition (OCR) because a molecule is a <em>graph</em>: a collection of atoms (nodes) connected by bonds (edges).</p>
<blockquote>
<p>(While this simplified view excludes complex structures like coordination compounds and polymers, it provides a highly effective starting point for this discussion.)</p></blockquote>
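<p>To make the graph view concrete, here is a minimal sketch of such a data structure in Python. The <code>MolGraph</code> class and its method names are hypothetical, chosen only for illustration; real toolkits use much richer representations (charges, stereochemistry, aromaticity).</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">from dataclasses import dataclass, field


@dataclass
class MolGraph:
    """Toy molecular graph: atoms are nodes, bonds are edges."""
    atoms: list = field(default_factory=list)  # element symbols, indexed by position
    bonds: dict = field(default_factory=dict)  # frozenset({i, j}) -> bond order

    def add_atom(self, symbol):
        self.atoms.append(symbol)
        return len(self.atoms) - 1  # index of the new node

    def add_bond(self, i, j, order=1):
        self.bonds[frozenset((i, j))] = order

    def neighbors(self, i):
        return sorted(j for pair in self.bonds if i in pair for j in pair if j != i)


# Ethanol (SMILES: CCO) with hydrogens left implicit
ethanol = MolGraph()
c1, c2, o = (ethanol.add_atom(s) for s in ("C", "C", "O"))
ethanol.add_bond(c1, c2)
ethanol.add_bond(c2, o)
print(ethanol.atoms, ethanol.neighbors(c2))
</code></pre></div><p>An OCSR system has to recover exactly this kind of structure (which atoms exist and which pairs are bonded) from pixels alone.</p>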
<p>An OCSR system must overcome several hurdles:</p>
<ul>
<li><strong>Varying Styles:</strong> Chemical drawings vary widely across publications. Bond lengths, angles, and fonts can differ dramatically from one document to another.</li>
</ul>















<figure class="post-figure center ">
    <img src="/img/ocsr/acs.orglett.2c02187_1.webp"
         alt="An example from the Colored Background OCSR Benchmark, showing a complex and colorful chemical structure."
         title="An example from the Colored Background OCSR Benchmark, showing a complex and colorful chemical structure."
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">An example from the <a href="https://huggingface.co/datasets/hheiden/Colored_Background_OCSR_benchmark">Colored Background OCSR Benchmark</a>, showing a complex and colorful chemical structure.</figcaption>
    
</figure>

<ul>
<li><strong>Image Quality:</strong> Older documents might be scanned at low resolutions, containing noise, blur, or other artifacts that make interpretation difficult.</li>
</ul>















<figure class="post-figure center ">
    <img src="/img/ocsr/2008239616_449_chem.webp"
         alt="A challenging chemical structure image from the JPO benchmark, difficult due to its low quality."
         title="A challenging chemical structure image from the JPO benchmark, difficult due to its low quality."
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">A challenging chemical structure image from the <a href="https://huggingface.co/datasets/hheiden/JPO_OCSR_benchmark">JPO benchmark</a>, difficult due to its low quality.</figcaption>
    
</figure>

<ul>
<li><strong>Structural Complexity:</strong> From simple rings to sprawling polymers and complex <strong>Markush structures</strong> (common in patents to represent a whole family of related compounds), the variety is immense.</li>
</ul>















<figure class="post-figure center ">
    <img src="/img/ocsr/markush.webp"
         alt="An example of a Markush structure, illustrating the complexity and variety of chemical compounds."
         title="An example of a Markush structure, illustrating the complexity and variety of chemical compounds."
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">An example of a Markush structure, illustrating the complexity and variety of chemical compounds.</figcaption>
    
</figure>

<h2 id="the-evolution-of-ocsr">The Evolution of OCSR</h2>
<p>The quest to automate this process has evolved significantly, moving from brittle, hand-coded systems to sophisticated AI that can learn from data.</p>
<h3 id="act-1-the-rule-based-pioneers-ocr-10">Act 1: The Rule-Based Pioneers (OCR-1.0)</h3>
<p>The first OCSR systems, developed in the early 1990s, represent what we can now call the <strong>&ldquo;OCR-1.0&rdquo; era</strong>. Tools like <a href="https://pubs.acs.org/doi/10.1021/ci00008a018">Kekulé</a>, and later open-source solutions like <a href="/notes/chemistry/optical-structure-recognition/rule-based/osra/">OSRA</a> and <a href="https://github.com/ncats/molvec">MolVec</a>, operated like meticulous draftsmen. Their approach was methodical:</p>
<ol>
<li><strong>Vectorize the Image:</strong> Convert the pixel-based image into a collection of lines and shapes</li>
<li><strong>Identify Components:</strong> Use a set of hard-coded rules to classify these components. &ldquo;This thick line is a wedge bond.&rdquo; &ldquo;This group of pixels is the letter &lsquo;O&rsquo;.&rdquo;</li>
<li><strong>Reconstruct the Graph:</strong> Piece together the identified atoms and bonds into a coherent molecular graph</li>
</ol>
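<p>The &ldquo;Identify Components&rdquo; step can be illustrated with a toy rule set. This is a sketch under loud assumptions: the <code>Segment</code> fields, the <code>classify</code> function, and every threshold are invented for illustration and are not taken from OSRA, MolVec, or any other tool.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">from dataclasses import dataclass


@dataclass
class Segment:
    length: float       # in pixels
    start_width: float  # stroke width at one end, in pixels
    end_width: float    # stroke width at the other end, in pixels


def classify(seg):
    """Hand-written rules with made-up thresholds (illustrative only)."""
    if seg.end_width >= 3 * seg.start_width:
        return "wedge bond"   # stroke widens along its length
    if seg.length >= 20:
        return "single bond"  # long stroke of roughly constant width
    return "unknown"          # anything else needs yet another rule


print(classify(Segment(length=40, start_width=1.0, end_width=4.0)))
print(classify(Segment(length=40, start_width=1.0, end_width=1.2)))
print(classify(Segment(length=12, start_width=1.0, end_width=1.2)))
</code></pre></div><p>The third segment is an ordinary bond from a drawing rendered at a smaller scale, yet it falls through to <code>unknown</code> because the length threshold was tuned for a different style.</p>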
<p>This rule-based approach was a real first step, but it was brittle. It struggled with the messiness of real-world documents and was expensive to maintain because each new drawing style or failure case required new rules.</p>
<p>These systems were also designed as interactive tools to assist human experts in digitizing chemical structures.
The working assumption was that a human would review and correct the output.</p>
<p>As a concrete case-study, consider the (reproduced) results from <a href="https://arxiv.org/abs/2411.11098">MolParser</a>:</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Method</th>
          <th style="text-align: center">USPTO</th>
          <th style="text-align: center">UoB</th>
          <th style="text-align: center">CLEF</th>
          <th style="text-align: center">JPO</th>
          <th style="text-align: center">ColoredBG</th>
          <th style="text-align: center">USPTO-10K</th>
          <th style="text-align: center">WildMol-10K</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>Rule-based methods</strong></td>
          <td style="text-align: center"></td>
          <td style="text-align: center"></td>
          <td style="text-align: center"></td>
          <td style="text-align: center"></td>
          <td style="text-align: center"></td>
          <td style="text-align: center"></td>
          <td style="text-align: center"></td>
      </tr>
      <tr>
          <td style="text-align: left">OSRA 2.1 *</td>
          <td style="text-align: center">89.3</td>
          <td style="text-align: center">86.3</td>
          <td style="text-align: center"><strong>93.4</strong></td>
          <td style="text-align: center">56.3</td>
          <td style="text-align: center">5.5</td>
          <td style="text-align: center">89.7</td>
          <td style="text-align: center">26.3</td>
      </tr>
      <tr>
          <td style="text-align: left">MolVec 0.9.7 *</td>
          <td style="text-align: center">91.6</td>
          <td style="text-align: center">79.7</td>
          <td style="text-align: center">81.2</td>
          <td style="text-align: center">66.8</td>
          <td style="text-align: center">8.0</td>
          <td style="text-align: center">92.4</td>
          <td style="text-align: center">26.4</td>
      </tr>
      <tr>
          <td style="text-align: left">Imago 2.0 *</td>
          <td style="text-align: center">89.4</td>
          <td style="text-align: center">63.9</td>
          <td style="text-align: center">68.2</td>
          <td style="text-align: center">41.0</td>
          <td style="text-align: center">2.0</td>
          <td style="text-align: center">89.9</td>
          <td style="text-align: center">6.9</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Only synthetic training</strong></td>
          <td style="text-align: center"></td>
          <td style="text-align: center"></td>
          <td style="text-align: center"></td>
          <td style="text-align: center"></td>
          <td style="text-align: center"></td>
          <td style="text-align: center"></td>
          <td style="text-align: center"></td>
      </tr>
      <tr>
          <td style="text-align: left">Img2Mol *</td>
          <td style="text-align: center">30.0</td>
          <td style="text-align: center">68.1</td>
          <td style="text-align: center">17.9</td>
          <td style="text-align: center">16.1</td>
          <td style="text-align: center">3.5</td>
          <td style="text-align: center">33.7</td>
          <td style="text-align: center">24.4</td>
      </tr>
      <tr>
          <td style="text-align: left">MolGrapher †*</td>
          <td style="text-align: center">91.5</td>
          <td style="text-align: center"><strong>94.9</strong></td>
          <td style="text-align: center">90.5</td>
          <td style="text-align: center">67.5</td>
          <td style="text-align: center">7.5</td>
          <td style="text-align: center">93.3</td>
          <td style="text-align: center">45.5</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Real data finetuning</strong></td>
          <td style="text-align: center"></td>
          <td style="text-align: center"></td>
          <td style="text-align: center"></td>
          <td style="text-align: center"></td>
          <td style="text-align: center"></td>
          <td style="text-align: center"></td>
          <td style="text-align: center"></td>
      </tr>
      <tr>
          <td style="text-align: left">DECIMER 2.7 *</td>
          <td style="text-align: center">59.9</td>
          <td style="text-align: center">88.3</td>
          <td style="text-align: center">72.0</td>
          <td style="text-align: center">64.0</td>
          <td style="text-align: center">14.5</td>
          <td style="text-align: center">82.4</td>
          <td style="text-align: center">56.0</td>
      </tr>
      <tr>
          <td style="text-align: left">MolScribe *</td>
          <td style="text-align: center"><u>93.1</u></td>
          <td style="text-align: center">87.4</td>
          <td style="text-align: center">88.9</td>
          <td style="text-align: center">76.2</td>
          <td style="text-align: center">21.0</td>
          <td style="text-align: center"><strong>96.0</strong></td>
          <td style="text-align: center">66.4</td>
      </tr>
      <tr>
          <td style="text-align: left">MolParser-Tiny (Ours)</td>
          <td style="text-align: center">93.0</td>
          <td style="text-align: center">91.6</td>
          <td style="text-align: center"><u>91.0</u></td>
          <td style="text-align: center">75.6</td>
          <td style="text-align: center"><strong>58.5</strong></td>
          <td style="text-align: center">89.5</td>
          <td style="text-align: center">73.1</td>
      </tr>
      <tr>
          <td style="text-align: left">MolParser-Small (Ours)</td>
          <td style="text-align: center"><strong>93.1</strong></td>
          <td style="text-align: center">91.1</td>
          <td style="text-align: center">90.8</td>
          <td style="text-align: center">76.2</td>
          <td style="text-align: center">57.0</td>
          <td style="text-align: center"><u>94.8</u></td>
          <td style="text-align: center">76.3</td>
      </tr>
      <tr>
          <td style="text-align: left">MolParser-Base (Ours)</td>
          <td style="text-align: center">93.0</td>
          <td style="text-align: center"><u>91.8</u></td>
          <td style="text-align: center">90.7</td>
          <td style="text-align: center"><strong>78.9</strong></td>
          <td style="text-align: center">57.0</td>
          <td style="text-align: center">94.5</td>
          <td style="text-align: center"><strong>76.9</strong></td>
      </tr>
  </tbody>
</table>
<blockquote>
<p><strong>Table 2. Comparison of our method with existing OCSR models.</strong> We report the accuracy. We use <strong>bold</strong> to indicate the best performance and <u>underline</u> to denote the second-best performance. *: re-implemented results. †: results from original publications.</p></blockquote>
<p>In this table, the rule-based methods (OSRA, MolVec, Imago) perform well on cleaner datasets like USPTO, scoring roughly 89 to 92%, but falter on harder benchmarks: none exceeds 8.0% on ColoredBG or 26.4% on WildMol-10K. The learned methods gain the most on exactly these hard sets, especially when fine-tuned on real data: MolScribe reaches 21.0% on ColoredBG and 66.4% on WildMol-10K, and the MolParser variants reach 57.0 to 58.5% and 73.1 to 76.9%, respectively.</p>
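<p>The size of that gap is easy to read off with a few lines of Python. The numbers below are copied from the table (USPTO, ColoredBG, WildMol-10K) for a handful of rows; nothing here is new data, and the variable names are illustrative.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"># Accuracy (%) copied from the table: (USPTO, ColoredBG, WildMol-10K)
scores = {
    "OSRA 2.1": (89.3, 5.5, 26.3),
    "MolVec 0.9.7": (91.6, 8.0, 26.4),
    "Imago 2.0": (89.4, 2.0, 6.9),
    "MolScribe": (93.1, 21.0, 66.4),
    "MolParser-Base": (93.0, 57.0, 76.9),
}

# Accuracy lost when moving from clean USPTO images to colored backgrounds
drops = {name: round(uspto - colored, 1) for name, (uspto, colored, _) in scores.items()}

for name, drop in drops.items():
    print(f"{name:15s} USPTO to ColoredBG drop: {drop:5.1f} points")
</code></pre></div>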
<h3 id="act-2-the-ai-fork-in-the-road-2010s-2020s">Act 2: The AI Fork in the Road (2010s-2020s)</h3>
<p>The rise of deep learning in the 2010s brought new paradigms that could learn from data. Here, the field split into two distinct paths.</p>
<h4 id="path-a-the-rise-of-the-specialists-graph-based-ai">Path A: The Rise of the Specialists (Graph-Based AI)</h4>
<p>Some models replaced the hard-coded rules with AI components. Systems like <a href="https://github.com/DS4SD/MolGrapher">MolGrapher</a> and <a href="https://github.com/thomas0809/MolScribe">MolScribe</a> use a two-stage process:</p>
<ul>
<li><strong>Atom Detection:</strong> A neural network first identifies all the atoms in the image</li>
<li><strong>Bond Prediction:</strong> A second process then predicts the connections (bonds) between those atoms to form the final graph</li>
</ul>
<p>These are highly specialized tools, trained specifically for the task of building a molecular graph.</p>
<h4 id="path-b-the-rise-of-the-generalists-lvlms">Path B: The Rise of the Generalists (LVLMs)</h4>
<p>Another, more direct method treats OCSR as an image captioning task. This approach aligns with the broader trend of <strong>Large Vision-Language Models (LVLMs)</strong>: massive, general-purpose AIs like GPT-4V. Models like <a href="https://github.com/Kohulan/DECIMER-Image_Transformer">DECIMER</a> and <a href="/notes/chemistry/optical-structure-recognition/vision-language/mol-parser/">MolParser</a> look at a molecular image and directly generate its textual representation, most commonly a <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES string</a>. This direct, end-to-end approach is powerful, though it requires enormous datasets to train effectively.</p>
<h2 id="the-next-frontier-the-ocr-20-vision-2024">The Next Frontier: The OCR-2.0 Vision (2024+)</h2>
<p>Recently, a proposal has emerged that charts a third path forward: <strong>OCR-2.0</strong>. This vision, proposed by <a href="https://arxiv.org/abs/2409.01704">Wei et al.</a> in 2024, argues for a new class of models that combine the best of both worlds. An OCR-2.0 model should be:</p>
<ol>
<li><strong>End-to-End:</strong> A single, unified model that simplifies maintenance</li>
<li><strong>Efficient &amp; Low-Cost:</strong> A specialized, highly efficient perception engine. The paper argues that using a giant LVLM for a pure recognition task is often inefficient</li>
<li><strong>Versatile:</strong> Capable of handling diverse artificial optical signals</li>
</ol>
<p>The flagship model for this theory is <a href="https://huggingface.co/stepfun-ai/GOT-OCR2_0">GOT (General OCR Theory)</a>. It&rsquo;s a single, unified model that can read an image and output structured text for a wide variety of inputs. It can translate a molecular diagram into a SMILES string, transcribe sheet music into musical notation, parse a bar chart into a data table, and describe a geometric shape using code.</p>
<p>This demonstrates that OCSR can be integrated into broader systems for processing human visual information. The same OCR-2.0 philosophy extends beyond chemistry: <a href="/research/gutenocr-grounded-vision-language-frontend/">GutenOCR</a>, for instance, applies grounded vision-language modeling to general document OCR, producing both text transcriptions and bounding-box outputs from a single model.</p>
<h2 id="pushing-the-boundaries-of-recognition">Pushing the Boundaries of Recognition</h2>
<p>OCR-2.0 models like GOT push for <em>breadth</em>, while other state-of-the-art research pursues <em>depth</em> on the uniquely complex task of chemical recognition.</p>
<h3 id="deepening-reasoning-with-a-visual-chain-of-thought">Deepening Reasoning with a &ldquo;Visual Chain of Thought&rdquo;</h3>
<p>The <a href="https://arxiv.org/abs/2506.07553">GTR-Mol-VLM</a> model makes recognition more intelligent by mimicking how a person might analyze a complex diagram. The model traverses the molecule step-by-step, predicting an atom, then its bond, then the next atom, and so on. This &ldquo;Visual Chain of Thought&rdquo; improves accuracy, especially for complex molecules. It also faithfully recognizes abbreviations like &ldquo;Ph&rdquo; as single units, better representing the source image.</p>
<h3 id="deepening-application-with-visual-fingerprinting">Deepening Application with &ldquo;Visual Fingerprinting&rdquo;</h3>
<p><a href="https://link.springer.com/article/10.1186/s13321-025-01091-4">Subgrapher</a> rethinks the end goal. Many applications (like searching a patent database) require only the identification of specific molecular features. Subgrapher detects key functional groups and backbones directly from the image and creates a visual fingerprint. This approach mirrors identifying a person by key features (&ldquo;has glasses, a mustache&rdquo;), making it well-suited to finding matches in a large set.</p>
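<p>Fingerprint-based matching can be sketched in a few lines. This is an illustration under assumptions, not Subgrapher&rsquo;s actual method: the substructure labels, the tiny database, and the choice of Tanimoto (Jaccard) similarity are all stand-ins chosen for clarity.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">def tanimoto(a, b):
    """Tanimoto (Jaccard) similarity between two sets of substructure labels."""
    if not a and not b:
        return 1.0
    return len(a.intersection(b)) / len(a.union(b))


# Hypothetical database: molecule name -> substructure labels detected in its image
database = {
    "aspirin": {"benzene", "carboxylic_acid", "ester"},
    "benzoic_acid": {"benzene", "carboxylic_acid"},
    "ethyl_acetate": {"ester"},
}

query = {"benzene", "carboxylic_acid"}  # labels detected in a query image

# Rank database entries from most to least similar to the query
ranked = sorted(database, key=lambda name: tanimoto(query, database[name]), reverse=True)
print(ranked)
</code></pre></div><p>Because the comparison only touches small sets of features, it never needs a fully reconstructed molecular graph, which is the appeal of the fingerprint approach for retrieval.</p>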
<h2 id="why-it-matters">Why It Matters</h2>
<p>The evolution of OCSR directly enables practical scientific advances in three areas.</p>
<h3 id="searching-past-knowledge">Searching Past Knowledge</h3>
<p>OCSR digitizes decades of research from patents and journals, making it searchable and accessible for data mining. Imagine being able to search through every molecule ever published with a simple query. Or consider the practical impact: pharmaceutical companies can now automatically scan thousands of patent documents to ensure their new drug candidates don&rsquo;t infringe existing intellectual property, a process that previously required substantial manual review by patent analysts.</p>
<h3 id="accelerating-drug-discovery">Accelerating Drug Discovery</h3>
<p>By extracting vast datasets of molecules, scientists can train AI models to predict drug efficacy and toxicity, speeding up the discovery pipeline. The more molecular data we can digitize, the better our predictive models become.</p>
<h3 id="building-universal-document-intelligence">Building Universal Document Intelligence</h3>
<p>OCSR contributes to building AI systems capable of processing complex human documents. A scientific paper is a mix of text, equations, charts, tables, and molecular diagrams. Unified OCR-2.0 models are the key to making all of this knowledge searchable holistically.</p>
<h2 id="looking-forward">Looking Forward</h2>
<p>The goal is a loop where scientific knowledge, regardless of how it is stored, can be fed back into systems that read, search, and reason over it.</p>
<p>From the rule-based systems of the 1990s to today&rsquo;s models that read many printed diagrams reliably (though hard cases like low-quality scans and Markush structures remain open), OCSR has improved a great deal. As accuracy, efficiency, and breadth improve, more of the chemical literature becomes machine-readable.</p>
<p>This entire process begins with teaching a computer how to read a picture.</p>
]]></content:encoded></item><item><title>Converting SMILES and SELFIES to 2D Molecular Images</title><link>https://hunterheidenreich.com/posts/visualizing-smiles-and-selfies-strings/</link><pubDate>Fri, 12 Sep 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/visualizing-smiles-and-selfies-strings/</guid><description>A guide to generating 2D molecular structure images from SMILES and SELFIES strings using Python, RDKit, and PIL.</description><content:encoded><![CDATA[<p>Lately, I&rsquo;ve spent a lot of time staring at datasets full of 1D molecular strings.
With time, I find I get better at recognizing functional groups and substructures like <code>C(=O)O</code> (carboxylic acid) or <code>c1ccccc1</code> (benzene ring) in SMILES.
However, anything really complex is beyond my personal visualization capabilities.</p>
<p>I ran into this recently while debugging a generative model.
Sometimes the grammar of the string provides the clue as to what is going wrong.
Other times, actually <em>seeing</em> the molecule is what helps.
I had a terminal full of generated strings and needed to verify their structures visually.
I needed a streamlined way to generate these images locally.
A lightweight script turns that text into a properly formatted image directly from the terminal.</p>
<h2 id="smiles-vs-selfies">SMILES vs. SELFIES</h2>
<p>There are two primary string representations you will encounter in modern cheminformatics:</p>
<ol>
<li><strong><a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a></strong>: The industry standard. It uses simple rules (<code>C</code> for carbon, <code>=</code> for double bonds, parentheses for branches). It is compact and machine-parseable. However, random SMILES strings are often invalid (e.g., unclosed rings or invalid valences).</li>
<li><strong><a href="/notes/chemistry/molecular-representations/notations/selfies/">SELFIES</a></strong>: Designed specifically for machine learning. It is a robust representation where <em>every</em> string corresponds to a valid molecular graph. This makes it ideal for generative models. Note that it is more verbose than SMILES.</li>
</ol>
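<p>To make the &ldquo;unclosed rings&rdquo; failure mode concrete, here is a rough syntax check in plain Python. The helper <code>looks_well_formed</code> is hypothetical and deliberately naive: it is not a SMILES parser, it does not check valences, and it ignores two-digit ring closures written as <code>%nn</code>. A real toolkit such as RDKit does far more.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">def looks_well_formed(smiles):
    """Rough check: balanced branches and paired ring-closure digits."""
    depth = 0            # current nesting level of branch parentheses
    open_rings = set()   # ring-closure digits seen an odd number of times
    in_bracket = False   # True while inside a [...] bracket atom
    for ch in smiles:
        if in_bracket:
            in_bracket = ch != "]"  # digits in [...] are isotopes, charges, H counts
        elif ch == "[":
            in_bracket = True
        elif ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
            if depth == -1:
                return False        # closing a branch that was never opened
        elif ch.isdigit():
            open_rings ^= {ch}      # first sight opens the ring, second closes it
    return depth == 0 and not open_rings and not in_bracket


for s in ("c1ccccc1", "C1CCCC", "CC(=O)O", "CC(=O", "[13CH4]"):
    print(s, looks_well_formed(s))
</code></pre></div><p>SELFIES sidesteps this class of errors by construction, which is why no comparable check is needed before decoding a SELFIES string.</p>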
<p>I often need to visualize both formats. Let&rsquo;s build a single, robust Python tool to handle them.</p>
<h2 id="the-quick-win-native-rdkit">The Quick Win: Native RDKit</h2>
<p>If you just need a quick image from a SMILES string and don&rsquo;t care about the image dimensions or adding a legend, RDKit can do this in a few lines:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">from</span> rdkit <span style="color:#f92672">import</span> Chem
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> rdkit.Chem <span style="color:#f92672">import</span> Draw
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>mol <span style="color:#f92672">=</span> Chem<span style="color:#f92672">.</span>MolFromSmiles(<span style="color:#e6db74">&#34;CCO&#34;</span>)
</span></span><span style="display:flex;"><span>Draw<span style="color:#f92672">.</span>MolToFile(mol, <span style="color:#e6db74">&#34;ethanol.png&#34;</span>)
</span></span></code></pre></div><p>The native RDKit method is fast for quick checks. However, custom rendering provides necessary control over image dimensions, formula subscripts, and handling multiple input formats like SELFIES.</p>
<h2 id="building-a-custom-renderer-for-precise-control">Building a Custom Renderer for Precise Control</h2>
<p>Let&rsquo;s build a fuller tool using RDKit for chemical processing, the <code>selfies</code> library for decoding, and PIL for image manipulation.</p>
<h3 id="core-dependencies">Core Dependencies</h3>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> selfies <span style="color:#66d9ef">as</span> sf
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> rdkit <span style="color:#f92672">import</span> Chem
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> rdkit.Chem <span style="color:#f92672">import</span> Draw, rdDepictor, rdMolDescriptors
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> PIL <span style="color:#f92672">import</span> Image, ImageDraw, ImageFont
</span></span></code></pre></div><p>RDKit handles the chemical logic, <code>selfies</code> translates SELFIES to SMILES, and PIL gives us fine control over the final image appearance.</p>
<h3 id="the-main-conversion-function">The Main Conversion Function</h3>
<p>Here is the core conversion logic. Notice the Python type hints on the signature.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">string_to_png</span>(mol_string: str, output_file: str, size: int <span style="color:#f92672">=</span> <span style="color:#ae81ff">500</span>, is_selfies: bool <span style="color:#f92672">=</span> <span style="color:#66d9ef">False</span>) <span style="color:#f92672">-&gt;</span> <span style="color:#66d9ef">None</span>:
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;Generates a 2D molecule image with a chemical formula legend from SMILES or SELFIES.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Decode SELFIES to SMILES if necessary</span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">if</span> is_selfies:
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">try</span>:
</span></span><span style="display:flex;"><span>            smiles <span style="color:#f92672">=</span> sf<span style="color:#f92672">.</span>decoder(mol_string)
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">except</span> <span style="color:#a6e22e">Exception</span> <span style="color:#66d9ef">as</span> e:
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">raise</span> <span style="color:#a6e22e">ValueError</span>(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#34;Invalid SELFIES string: </span><span style="color:#e6db74">{</span>mol_string<span style="color:#e6db74">}</span><span style="color:#e6db74">&#34;</span>) <span style="color:#f92672">from</span> e
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">else</span>:
</span></span><span style="display:flex;"><span>        smiles <span style="color:#f92672">=</span> mol_string
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    mol <span style="color:#f92672">=</span> Chem<span style="color:#f92672">.</span>MolFromSmiles(smiles)
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">if</span> <span style="color:#f92672">not</span> mol:
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">raise</span> <span style="color:#a6e22e">ValueError</span>(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#34;Could not generate molecule from SMILES: </span><span style="color:#e6db74">{</span>smiles<span style="color:#e6db74">}</span><span style="color:#e6db74">&#34;</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Generate 2D coordinates and formula</span>
</span></span><span style="display:flex;"><span>    rdDepictor<span style="color:#f92672">.</span>Compute2DCoords(mol)
</span></span><span style="display:flex;"><span>    formula <span style="color:#f92672">=</span> rdMolDescriptors<span style="color:#f92672">.</span>CalcMolFormula(mol)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Render the molecule</span>
</span></span><span style="display:flex;"><span>    img <span style="color:#f92672">=</span> Draw<span style="color:#f92672">.</span>MolToImage(mol, size<span style="color:#f92672">=</span>(size, size))<span style="color:#f92672">.</span>convert(<span style="color:#e6db74">&#34;RGBA&#34;</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Create a canvas with extra space at the bottom for the legend</span>
</span></span><span style="display:flex;"><span>    legend_height <span style="color:#f92672">=</span> int(size <span style="color:#f92672">*</span> <span style="color:#ae81ff">0.1</span>)
</span></span><span style="display:flex;"><span>    canvas <span style="color:#f92672">=</span> Image<span style="color:#f92672">.</span>new(<span style="color:#e6db74">&#34;RGBA&#34;</span>, (size, size <span style="color:#f92672">+</span> legend_height), <span style="color:#e6db74">&#34;white&#34;</span>)
</span></span><span style="display:flex;"><span>    canvas<span style="color:#f92672">.</span>paste(img, (<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">0</span>))
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    draw <span style="color:#f92672">=</span> ImageDraw<span style="color:#f92672">.</span>Draw(canvas)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Define dynamic font sizes</span>
</span></span><span style="display:flex;"><span>    font_reg <span style="color:#f92672">=</span> get_font(int(size <span style="color:#f92672">*</span> <span style="color:#ae81ff">0.03</span>))
</span></span><span style="display:flex;"><span>    font_sub <span style="color:#f92672">=</span> get_font(int(size <span style="color:#f92672">*</span> <span style="color:#ae81ff">0.02</span>))
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Draw the legend</span>
</span></span><span style="display:flex;"><span>    x <span style="color:#f92672">=</span> int(size <span style="color:#f92672">*</span> <span style="color:#ae81ff">0.02</span>)
</span></span><span style="display:flex;"><span>    y <span style="color:#f92672">=</span> size <span style="color:#f92672">+</span> int(size <span style="color:#f92672">*</span> <span style="color:#ae81ff">0.02</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Draw &#34;Formula: &#34; label</span>
</span></span><span style="display:flex;"><span>    draw<span style="color:#f92672">.</span>text((x, y), <span style="color:#e6db74">&#34;Formula: &#34;</span>, fill<span style="color:#f92672">=</span><span style="color:#e6db74">&#34;black&#34;</span>, font<span style="color:#f92672">=</span>font_reg)
</span></span><span style="display:flex;"><span>    x <span style="color:#f92672">+=</span> draw<span style="color:#f92672">.</span>textlength(<span style="color:#e6db74">&#34;Formula: &#34;</span>, font<span style="color:#f92672">=</span>font_reg)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Draw formula with subscript handling for numbers</span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">for</span> char <span style="color:#f92672">in</span> formula:
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Use smaller font and lower y-offset for numbers (subscripts)</span>
</span></span><span style="display:flex;"><span>        font <span style="color:#f92672">=</span> font_sub <span style="color:#66d9ef">if</span> char<span style="color:#f92672">.</span>isdigit() <span style="color:#66d9ef">else</span> font_reg
</span></span><span style="display:flex;"><span>        y_offset <span style="color:#f92672">=</span> int(size <span style="color:#f92672">*</span> <span style="color:#ae81ff">0.005</span>) <span style="color:#66d9ef">if</span> char<span style="color:#f92672">.</span>isdigit() <span style="color:#66d9ef">else</span> <span style="color:#ae81ff">0</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        draw<span style="color:#f92672">.</span>text((x, y <span style="color:#f92672">+</span> y_offset), char, fill<span style="color:#f92672">=</span><span style="color:#e6db74">&#34;black&#34;</span>, font<span style="color:#f92672">=</span>font)
</span></span><span style="display:flex;"><span>        x <span style="color:#f92672">+=</span> draw<span style="color:#f92672">.</span>textlength(char, font<span style="color:#f92672">=</span>font)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Draw original string</span>
</span></span><span style="display:flex;"><span>    label <span style="color:#f92672">=</span> <span style="color:#e6db74">&#34;SELFIES&#34;</span> <span style="color:#66d9ef">if</span> is_selfies <span style="color:#66d9ef">else</span> <span style="color:#e6db74">&#34;SMILES&#34;</span>
</span></span><span style="display:flex;"><span>    draw<span style="color:#f92672">.</span>text((x, y), <span style="color:#e6db74">f</span><span style="color:#e6db74">&#34; | </span><span style="color:#e6db74">{</span>label<span style="color:#e6db74">}</span><span style="color:#e6db74">: </span><span style="color:#e6db74">{</span>mol_string<span style="color:#e6db74">}</span><span style="color:#e6db74">&#34;</span>, fill<span style="color:#f92672">=</span><span style="color:#e6db74">&#34;black&#34;</span>, font<span style="color:#f92672">=</span>font_reg)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    canvas<span style="color:#f92672">.</span>save(output_file)
</span></span><span style="display:flex;"><span>    print(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#34;Saved: </span><span style="color:#e6db74">{</span>output_file<span style="color:#e6db74">}</span><span style="color:#e6db74">&#34;</span>)
</span></span></code></pre></div><p>This function handles everything: SELFIES decoding, validation, coordinate generation, image creation, and legend drawing.</p>
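<p>The subscript logic in the legend loop can be isolated and checked without Pillow or RDKit. The sketch below is a minimal illustration, not part of the original script: <code>split_formula_runs</code> is a hypothetical helper that groups a formula string into regular and subscript runs using the same <code>isdigit()</code> test, so multi-digit counts such as the <code>10</code> in <code>C10H22</code> stay together.</p>

```python
from itertools import groupby


def split_formula_runs(formula: str) -> list[tuple[str, bool]]:
    """Group a formula into (text, is_subscript) runs.

    Hypothetical helper: digits are treated as subscripts, mirroring the
    char.isdigit() test in the legend loop; all other characters are regular.
    """
    return [
        ("".join(chars), is_digit)
        for is_digit, chars in groupby(formula, key=str.isdigit)
    ]


# Aspirin's formula splits into alternating element and count runs.
print(split_formula_runs("C9H8O4"))
# → [('C', False), ('9', True), ('H', False), ('8', True), ('O', False), ('4', True)]
```

<p>Each run could then be drawn with a single <code>draw.text</code> call, picking <code>font_sub</code> when the flag is true.</p>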
<h3 id="font-handling">Font Handling</h3>
<p>We need a helper to handle fonts robustly across systems:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">get_font</span>(size: int, font_name: str <span style="color:#f92672">=</span> <span style="color:#e6db74">&#34;arial.ttf&#34;</span>):
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;Attempts to load a TTF font, falls back to default if unavailable.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">try</span>:
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> ImageFont<span style="color:#f92672">.</span>truetype(font_name, size)
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">except</span> <span style="color:#a6e22e">IOError</span>:
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> ImageFont<span style="color:#f92672">.</span>load_default()
</span></span></code></pre></div><h2 id="examples-in-action">Examples in Action</h2>
<p>Let&rsquo;s see the tool in action with some common molecules, comparing the SMILES and SELFIES inputs.</p>
<h3 id="simple-molecules">Simple Molecules</h3>
<figure class="post-figure center ">
    <img src="/img/smiles2img/ethanol_demo.webp"
         alt="Ethanol molecular structure with formula C2H6O"
         title="Ethanol molecular structure with formula C2H6O"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption"><strong>Ethanol</strong>: A simple alcohol. The SMILES is <code>CCO</code>, while the SELFIES is <code>[C][C][O]</code>.</figcaption>
    
</figure>

<h3 id="aromatic-compounds">Aromatic Compounds</h3>
<figure class="post-figure center ">
    <img src="/img/smiles2img/benzene_demo.webp"
         alt="Benzene molecular structure with formula C6H6"
         title="Benzene molecular structure with formula C6H6"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption"><strong>Benzene</strong>: The classic aromatic ring. SMILES uses numbers for ring closures (<code>C1=CC=CC=C1</code>), while SELFIES uses explicit tokens (<code>[C][=C][C][=C][C][=C][Ring1][=Branch1]</code>).</figcaption>
    
</figure>

<h3 id="complex-pharmaceuticals">Complex Pharmaceuticals</h3>
<figure class="post-figure center ">
    <img src="/img/smiles2img/aspirin_demo.webp"
         alt="Aspirin molecular structure with formula C9H8O4"
         title="Aspirin molecular structure with formula C9H8O4"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption"><strong>Aspirin</strong>: A more complex molecule showing how the tool handles branched structures and multiple functional groups.</figcaption>
    
</figure>

<h2 id="going-further-vector-graphics-svg">Going Further: Vector Graphics (SVG)</h2>
<p>For publication-quality figures, use a vector format (SVG or PDF). Vector graphics scale to any size without pixelation.</p>
<p>RDKit handles this natively with <code>rdMolDraw2D</code>:</p>
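<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> selfies <span style="color:#66d9ef">as</span> sf
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> rdkit <span style="color:#f92672">import</span> Chem
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> rdkit.Chem <span style="color:#f92672">import</span> rdDepictor
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> rdkit.Chem.Draw <span style="color:#f92672">import</span> rdMolDraw2D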
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">string_to_svg</span>(mol_string: str, output_file: str, size: int <span style="color:#f92672">=</span> <span style="color:#ae81ff">500</span>, is_selfies: bool <span style="color:#f92672">=</span> <span style="color:#66d9ef">False</span>) <span style="color:#f92672">-&gt;</span> <span style="color:#66d9ef">None</span>:
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;Generates a 2D molecule SVG image.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">if</span> is_selfies:
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">try</span>:
</span></span><span style="display:flex;"><span>            mol_string <span style="color:#f92672">=</span> sf<span style="color:#f92672">.</span>decoder(mol_string)
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">except</span> <span style="color:#a6e22e">Exception</span> <span style="color:#66d9ef">as</span> e:
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">raise</span> <span style="color:#a6e22e">ValueError</span>(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#34;Invalid SELFIES string: </span><span style="color:#e6db74">{</span>mol_string<span style="color:#e6db74">}</span><span style="color:#e6db74">&#34;</span>) <span style="color:#f92672">from</span> e
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    mol <span style="color:#f92672">=</span> Chem<span style="color:#f92672">.</span>MolFromSmiles(mol_string)
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">if</span> <span style="color:#f92672">not</span> mol:
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">raise</span> <span style="color:#a6e22e">ValueError</span>(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#34;Invalid string: </span><span style="color:#e6db74">{</span>mol_string<span style="color:#e6db74">}</span><span style="color:#e6db74">&#34;</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    rdDepictor<span style="color:#f92672">.</span>Compute2DCoords(mol)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    d <span style="color:#f92672">=</span> rdMolDraw2D<span style="color:#f92672">.</span>MolDraw2DSVG(size, size)
</span></span><span style="display:flex;"><span>    d<span style="color:#f92672">.</span>DrawMolecule(mol)
</span></span><span style="display:flex;"><span>    d<span style="color:#f92672">.</span>FinishDrawing()
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">with</span> open(output_file, <span style="color:#e6db74">&#34;w&#34;</span>) <span style="color:#66d9ef">as</span> f:
</span></span><span style="display:flex;"><span>        f<span style="color:#f92672">.</span>write(d<span style="color:#f92672">.</span>GetDrawingText())
</span></span><span style="display:flex;"><span>    print(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#34;Saved: </span><span style="color:#e6db74">{</span>output_file<span style="color:#e6db74">}</span><span style="color:#e6db74">&#34;</span>)
</span></span></code></pre></div><p>This writes a resolution-independent SVG file. Note that this path omits the custom PIL-based legend, so the formula and input string do not appear in the image. Choose the format to suit the job: PNG for quick checks and slides, SVG for journal submissions.</p>
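<p>If a caption is still wanted on the SVG, one option is to post-process the SVG text. The sketch below is an assumption-laden illustration rather than part of the original script: <code>add_svg_caption</code> is a hypothetical helper, it assumes a square drawing of <code>size</code> pixels, and it places the caption inside the existing canvas, so it may overlap a molecule that fills the frame.</p>

```python
from xml.sax.saxutils import escape


def add_svg_caption(svg_text: str, caption: str, size: int = 500) -> str:
    """Insert a <text> element just before the closing </svg> tag.

    Hypothetical helper: positions follow the PNG legend's proportions
    (2% margin, 3% font size) using integer arithmetic.
    """
    head, closing, tail = svg_text.rpartition("</svg>")
    if not closing:
        raise ValueError("No closing </svg> tag found")

    x = size // 50                      # 2% left margin
    y = size - size // 50               # baseline 2% above the bottom edge
    font_size = max(1, size * 3 // 100)  # 3% of the canvas size
    element = (
        f'<text x="{x}" y="{y}" font-family="sans-serif" '
        f'font-size="{font_size}" fill="black">{escape(caption)}</text>\n'
    )
    return head + element + closing + tail


# Minimal stand-in for the text returned by GetDrawingText().
demo_svg = '<svg xmlns="http://www.w3.org/2000/svg" width="500" height="500"></svg>\n'
print(add_svg_caption(demo_svg, "SMILES: CCO"))
```

<p>The caption is XML-escaped before insertion, which matters for strings containing characters such as <code>&amp;</code> or <code>&gt;</code>.</p>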
<h2 id="command-line-interface">Command-Line Interface</h2>
<p>The tool uses Python&rsquo;s standard <code>argparse</code> library for the command-line interface. It writes SVG when the output filename ends in <code>.svg</code> or when <code>--svg</code> is passed, and a <code>--selfies</code> flag tells it to treat the input as SELFIES.</p>
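<p>The format decision can be sketched as a small pure function. This mirrors the extension check in the complete script, but <code>resolve_output</code> is a hypothetical name introduced here for illustration; the script performs the same steps inline.</p>

```python
import os


def resolve_output(output: str, force_svg: bool = False) -> tuple[bool, str]:
    """Return (is_svg, output_path) for a requested output filename.

    Hypothetical helper: SVG is chosen when the name ends in .svg
    (case-insensitive) or when force_svg is set.
    """
    is_svg = force_svg or output.lower().endswith(".svg")
    if is_svg and not output.lower().endswith(".svg"):
        # SVG was forced with a non-SVG name, so swap the extension.
        output = os.path.splitext(output)[0] + ".svg"
    return is_svg, output


print(resolve_output("ethanol.png"))                  # → (False, 'ethanol.png')
print(resolve_output("ethanol.SVG"))                  # → (True, 'ethanol.SVG')
print(resolve_output("ethanol.png", force_svg=True))  # → (True, 'ethanol.svg')
```

<p>Typical invocations of the script look like this:</p>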
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bash" data-lang="bash"><span style="display:flex;"><span><span style="color:#75715e"># Basic SMILES usage</span>
</span></span><span style="display:flex;"><span>python mol2img.py <span style="color:#e6db74">&#34;CCO&#34;</span> -o ethanol.png
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># SELFIES usage</span>
</span></span><span style="display:flex;"><span>python mol2img.py <span style="color:#e6db74">&#34;[C][C][O]&#34;</span> -o ethanol.png --selfies
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Generate SVG for publication</span>
</span></span><span style="display:flex;"><span>python mol2img.py <span style="color:#e6db74">&#34;CCO&#34;</span> -o ethanol.svg
</span></span></code></pre></div><h2 id="download-the-complete-script">Download the Complete Script</h2>
<p>You can copy the complete <code>mol2img.py</code> script directly from the code block below. For a fuller version with an SVG fallback, type hints, and batch (grid) rendering, see the <a href="/projects/molecular-string-renderer/">Molecular String Renderer project</a>.</p>
<h3 id="installation-and-setup">Installation and Setup</h3>
<p>Before using the script, install the required dependencies:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bash" data-lang="bash"><span style="display:flex;"><span>pip install rdkit pillow selfies
</span></span></code></pre></div><h3 id="complete-script">Complete Script</h3>
<details>
<summary>Click to expand the complete mol2img.py script</summary>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> argparse
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> sys
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> os
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> selfies <span style="color:#66d9ef">as</span> sf
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> rdkit <span style="color:#f92672">import</span> Chem
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> rdkit.Chem <span style="color:#f92672">import</span> Draw, rdDepictor, rdMolDescriptors
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> rdkit.Chem.Draw <span style="color:#f92672">import</span> rdMolDraw2D
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> PIL <span style="color:#f92672">import</span> Image, ImageDraw, ImageFont
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">get_font</span>(size: int, font_name: str <span style="color:#f92672">=</span> <span style="color:#e6db74">&#34;arial.ttf&#34;</span>):
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;Attempts to load a TTF font, falls back to default if unavailable.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">try</span>:
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> ImageFont<span style="color:#f92672">.</span>truetype(font_name, size)
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">except</span> <span style="color:#a6e22e">IOError</span>:
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> ImageFont<span style="color:#f92672">.</span>load_default()
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">string_to_svg</span>(mol_string: str, output_file: str, size: int <span style="color:#f92672">=</span> <span style="color:#ae81ff">500</span>, is_selfies: bool <span style="color:#f92672">=</span> <span style="color:#66d9ef">False</span>) <span style="color:#f92672">-&gt;</span> <span style="color:#66d9ef">None</span>:
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;Generates a 2D molecule SVG image.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">if</span> is_selfies:
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">try</span>:
</span></span><span style="display:flex;"><span>            mol_string <span style="color:#f92672">=</span> sf<span style="color:#f92672">.</span>decoder(mol_string)
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">except</span> <span style="color:#a6e22e">Exception</span> <span style="color:#66d9ef">as</span> e:
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">raise</span> <span style="color:#a6e22e">ValueError</span>(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#34;Invalid SELFIES string: </span><span style="color:#e6db74">{</span>mol_string<span style="color:#e6db74">}</span><span style="color:#e6db74">&#34;</span>) <span style="color:#f92672">from</span> e
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    mol <span style="color:#f92672">=</span> Chem<span style="color:#f92672">.</span>MolFromSmiles(mol_string)
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">if</span> <span style="color:#f92672">not</span> mol:
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">raise</span> <span style="color:#a6e22e">ValueError</span>(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#34;Invalid string: </span><span style="color:#e6db74">{</span>mol_string<span style="color:#e6db74">}</span><span style="color:#e6db74">&#34;</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    rdDepictor<span style="color:#f92672">.</span>Compute2DCoords(mol)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    d <span style="color:#f92672">=</span> rdMolDraw2D<span style="color:#f92672">.</span>MolDraw2DSVG(size, size)
</span></span><span style="display:flex;"><span>    d<span style="color:#f92672">.</span>DrawMolecule(mol)
</span></span><span style="display:flex;"><span>    d<span style="color:#f92672">.</span>FinishDrawing()
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">with</span> open(output_file, <span style="color:#e6db74">&#34;w&#34;</span>) <span style="color:#66d9ef">as</span> f:
</span></span><span style="display:flex;"><span>        f<span style="color:#f92672">.</span>write(d<span style="color:#f92672">.</span>GetDrawingText())
</span></span><span style="display:flex;"><span>    print(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#34;Saved: </span><span style="color:#e6db74">{</span>output_file<span style="color:#e6db74">}</span><span style="color:#e6db74">&#34;</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">string_to_png</span>(mol_string: str, output_file: str, size: int <span style="color:#f92672">=</span> <span style="color:#ae81ff">500</span>, is_selfies: bool <span style="color:#f92672">=</span> <span style="color:#66d9ef">False</span>) <span style="color:#f92672">-&gt;</span> <span style="color:#66d9ef">None</span>:
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;Generates a 2D molecule image with a chemical formula legend.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">if</span> is_selfies:
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">try</span>:
</span></span><span style="display:flex;"><span>            smiles <span style="color:#f92672">=</span> sf<span style="color:#f92672">.</span>decoder(mol_string)
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">except</span> <span style="color:#a6e22e">Exception</span> <span style="color:#66d9ef">as</span> e:
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">raise</span> <span style="color:#a6e22e">ValueError</span>(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#34;Invalid SELFIES string: </span><span style="color:#e6db74">{</span>mol_string<span style="color:#e6db74">}</span><span style="color:#e6db74">&#34;</span>) <span style="color:#f92672">from</span> e
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">else</span>:
</span></span><span style="display:flex;"><span>        smiles <span style="color:#f92672">=</span> mol_string
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    mol <span style="color:#f92672">=</span> Chem<span style="color:#f92672">.</span>MolFromSmiles(smiles)
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">if</span> <span style="color:#f92672">not</span> mol:
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">raise</span> <span style="color:#a6e22e">ValueError</span>(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#34;Could not generate molecule from string: </span><span style="color:#e6db74">{</span>mol_string<span style="color:#e6db74">}</span><span style="color:#e6db74">&#34;</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Generate 2D coordinates and formula</span>
</span></span><span style="display:flex;"><span>    rdDepictor<span style="color:#f92672">.</span>Compute2DCoords(mol)
</span></span><span style="display:flex;"><span>    formula <span style="color:#f92672">=</span> rdMolDescriptors<span style="color:#f92672">.</span>CalcMolFormula(mol)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Render the molecule</span>
</span></span><span style="display:flex;"><span>    img <span style="color:#f92672">=</span> Draw<span style="color:#f92672">.</span>MolToImage(mol, size<span style="color:#f92672">=</span>(size, size))<span style="color:#f92672">.</span>convert(<span style="color:#e6db74">&#34;RGBA&#34;</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Create a canvas with extra space at the bottom for the legend</span>
</span></span><span style="display:flex;"><span>    legend_height <span style="color:#f92672">=</span> int(size <span style="color:#f92672">*</span> <span style="color:#ae81ff">0.1</span>)
</span></span><span style="display:flex;"><span>    canvas <span style="color:#f92672">=</span> Image<span style="color:#f92672">.</span>new(<span style="color:#e6db74">&#34;RGBA&#34;</span>, (size, size <span style="color:#f92672">+</span> legend_height), <span style="color:#e6db74">&#34;white&#34;</span>)
</span></span><span style="display:flex;"><span>    canvas<span style="color:#f92672">.</span>paste(img, (<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">0</span>))
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    draw <span style="color:#f92672">=</span> ImageDraw<span style="color:#f92672">.</span>Draw(canvas)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Define dynamic font sizes</span>
</span></span><span style="display:flex;"><span>    font_reg <span style="color:#f92672">=</span> get_font(int(size <span style="color:#f92672">*</span> <span style="color:#ae81ff">0.03</span>))
</span></span><span style="display:flex;"><span>    font_sub <span style="color:#f92672">=</span> get_font(int(size <span style="color:#f92672">*</span> <span style="color:#ae81ff">0.02</span>))
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Draw the legend</span>
</span></span><span style="display:flex;"><span>    x <span style="color:#f92672">=</span> int(size <span style="color:#f92672">*</span> <span style="color:#ae81ff">0.02</span>)
</span></span><span style="display:flex;"><span>    y <span style="color:#f92672">=</span> size <span style="color:#f92672">+</span> int(size <span style="color:#f92672">*</span> <span style="color:#ae81ff">0.02</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Draw &#34;Formula: &#34; label</span>
</span></span><span style="display:flex;"><span>    draw<span style="color:#f92672">.</span>text((x, y), <span style="color:#e6db74">&#34;Formula: &#34;</span>, fill<span style="color:#f92672">=</span><span style="color:#e6db74">&#34;black&#34;</span>, font<span style="color:#f92672">=</span>font_reg)
</span></span><span style="display:flex;"><span>    x <span style="color:#f92672">+=</span> draw<span style="color:#f92672">.</span>textlength(<span style="color:#e6db74">&#34;Formula: &#34;</span>, font<span style="color:#f92672">=</span>font_reg)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Draw formula with subscript handling for numbers</span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">for</span> char <span style="color:#f92672">in</span> formula:
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Use smaller font and lower y-offset for numbers (subscripts)</span>
</span></span><span style="display:flex;"><span>        font <span style="color:#f92672">=</span> font_sub <span style="color:#66d9ef">if</span> char<span style="color:#f92672">.</span>isdigit() <span style="color:#66d9ef">else</span> font_reg
</span></span><span style="display:flex;"><span>        y_offset <span style="color:#f92672">=</span> int(size <span style="color:#f92672">*</span> <span style="color:#ae81ff">0.005</span>) <span style="color:#66d9ef">if</span> char<span style="color:#f92672">.</span>isdigit() <span style="color:#66d9ef">else</span> <span style="color:#ae81ff">0</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        draw<span style="color:#f92672">.</span>text((x, y <span style="color:#f92672">+</span> y_offset), char, fill<span style="color:#f92672">=</span><span style="color:#e6db74">&#34;black&#34;</span>, font<span style="color:#f92672">=</span>font)
</span></span><span style="display:flex;"><span>        x <span style="color:#f92672">+=</span> draw<span style="color:#f92672">.</span>textlength(char, font<span style="color:#f92672">=</span>font)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Draw original string</span>
</span></span><span style="display:flex;"><span>    label <span style="color:#f92672">=</span> <span style="color:#e6db74">&#34;SELFIES&#34;</span> <span style="color:#66d9ef">if</span> is_selfies <span style="color:#66d9ef">else</span> <span style="color:#e6db74">&#34;SMILES&#34;</span>
</span></span><span style="display:flex;"><span>    draw<span style="color:#f92672">.</span>text((x, y), <span style="color:#e6db74">f</span><span style="color:#e6db74">&#34; | </span><span style="color:#e6db74">{</span>label<span style="color:#e6db74">}</span><span style="color:#e6db74">: </span><span style="color:#e6db74">{</span>mol_string<span style="color:#e6db74">}</span><span style="color:#e6db74">&#34;</span>, fill<span style="color:#f92672">=</span><span style="color:#e6db74">&#34;black&#34;</span>, font<span style="color:#f92672">=</span>font_reg)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    canvas<span style="color:#f92672">.</span>save(output_file)
</span></span><span style="display:flex;"><span>    print(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#34;Saved: </span><span style="color:#e6db74">{</span>output_file<span style="color:#e6db74">}</span><span style="color:#e6db74">&#34;</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">if</span> __name__ <span style="color:#f92672">==</span> <span style="color:#e6db74">&#34;__main__&#34;</span>:
</span></span><span style="display:flex;"><span>    parser <span style="color:#f92672">=</span> argparse<span style="color:#f92672">.</span>ArgumentParser(description<span style="color:#f92672">=</span><span style="color:#e6db74">&#34;Convert a SMILES or SELFIES string to a 2D molecular image.&#34;</span>)
</span></span><span style="display:flex;"><span>    parser<span style="color:#f92672">.</span>add_argument(<span style="color:#e6db74">&#34;string&#34;</span>, help<span style="color:#f92672">=</span><span style="color:#e6db74">&#34;The molecular string to convert&#34;</span>)
</span></span><span style="display:flex;"><span>    parser<span style="color:#f92672">.</span>add_argument(<span style="color:#e6db74">&#34;-o&#34;</span>, <span style="color:#e6db74">&#34;--output&#34;</span>, default<span style="color:#f92672">=</span><span style="color:#e6db74">&#34;molecule.png&#34;</span>, help<span style="color:#f92672">=</span><span style="color:#e6db74">&#34;Output filename (default: molecule.png)&#34;</span>)
</span></span><span style="display:flex;"><span>    parser<span style="color:#f92672">.</span>add_argument(<span style="color:#e6db74">&#34;--size&#34;</span>, type<span style="color:#f92672">=</span>int, default<span style="color:#f92672">=</span><span style="color:#ae81ff">500</span>, help<span style="color:#f92672">=</span><span style="color:#e6db74">&#34;Image width/height in pixels (default: 500)&#34;</span>)
</span></span><span style="display:flex;"><span>    parser<span style="color:#f92672">.</span>add_argument(<span style="color:#e6db74">&#34;--svg&#34;</span>, action<span style="color:#f92672">=</span><span style="color:#e6db74">&#34;store_true&#34;</span>, help<span style="color:#f92672">=</span><span style="color:#e6db74">&#34;Force SVG output (overrides filename extension)&#34;</span>)
</span></span><span style="display:flex;"><span>    parser<span style="color:#f92672">.</span>add_argument(<span style="color:#e6db74">&#34;--selfies&#34;</span>, action<span style="color:#f92672">=</span><span style="color:#e6db74">&#34;store_true&#34;</span>, help<span style="color:#f92672">=</span><span style="color:#e6db74">&#34;Treat the input string as SELFIES.&#34;</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    args <span style="color:#f92672">=</span> parser<span style="color:#f92672">.</span>parse_args()
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">try</span>:
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Determine format based on flag or file extension</span>
</span></span><span style="display:flex;"><span>        is_svg <span style="color:#f92672">=</span> args<span style="color:#f92672">.</span>svg <span style="color:#f92672">or</span> args<span style="color:#f92672">.</span>output<span style="color:#f92672">.</span>lower()<span style="color:#f92672">.</span>endswith(<span style="color:#e6db74">&#34;.svg&#34;</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">if</span> is_svg:
</span></span><span style="display:flex;"><span>            <span style="color:#75715e"># Ensure extension is correct if not present</span>
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">if</span> <span style="color:#f92672">not</span> args<span style="color:#f92672">.</span>output<span style="color:#f92672">.</span>lower()<span style="color:#f92672">.</span>endswith(<span style="color:#e6db74">&#34;.svg&#34;</span>):
</span></span><span style="display:flex;"><span>                args<span style="color:#f92672">.</span>output <span style="color:#f92672">=</span> os<span style="color:#f92672">.</span>path<span style="color:#f92672">.</span>splitext(args<span style="color:#f92672">.</span>output)[<span style="color:#ae81ff">0</span>] <span style="color:#f92672">+</span> <span style="color:#e6db74">&#34;.svg&#34;</span>
</span></span><span style="display:flex;"><span>            string_to_svg(args<span style="color:#f92672">.</span>string, args<span style="color:#f92672">.</span>output, args<span style="color:#f92672">.</span>size, args<span style="color:#f92672">.</span>selfies)
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">else</span>:
</span></span><span style="display:flex;"><span>            string_to_png(args<span style="color:#f92672">.</span>string, args<span style="color:#f92672">.</span>output, args<span style="color:#f92672">.</span>size, args<span style="color:#f92672">.</span>selfies)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">except</span> <span style="color:#a6e22e">Exception</span> <span style="color:#66d9ef">as</span> e:
</span></span><span style="display:flex;"><span>        print(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#34;Error: </span><span style="color:#e6db74">{</span>e<span style="color:#e6db74">}</span><span style="color:#e6db74">&#34;</span>)
</span></span><span style="display:flex;"><span>        sys<span style="color:#f92672">.</span>exit(<span style="color:#ae81ff">1</span>)
</span></span></code></pre></div></details>
]]></content:encoded></item><item><title>Exponential Random Numbers: Two Classic Algorithms</title><link>https://hunterheidenreich.com/posts/random-number-tricks/</link><pubDate>Sun, 31 Aug 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/random-number-tricks/</guid><description>Compare inverse transform sampling and von Neumann's rejection method for exponential random numbers with Python implementations and performance.</description><content:encoded><![CDATA[<h2 id="introduction">Introduction</h2>
<p>In the early days of computing, generating random numbers was a significant computational challenge. In a landmark 1951 paper, mathematician John von Neumann detailed various &ldquo;cooking recipes&rdquo; for producing and using random numbers on machines like the ENIAC. While much of the paper focuses on generating <em>uniform</em> random digits, he also described ingenious methods for generating numbers from more complex, non-uniform probability distributions.</p>
<p>One of the most fundamental needs in scientific simulation (from modeling radioactive decay to calculating particle free-paths in molecular dynamics) is sampling from an <strong>exponential distribution</strong> with probability density function:</p>
<p>$$f(x) = e^{-x} \quad \text{for } x \ge 0$$</p>
<p>Today&rsquo;s standard approach is elegant and direct, but it requires computing a natural logarithm (a computationally expensive operation on early hardware). To sidestep this limitation, von Neumann described a fascinating alternative that uses only basic comparisons, resembling what he called &ldquo;a well known game of chance Twenty-One, or Black Jack.&rdquo;</p>
<p>In this post, we&rsquo;ll explore both methods: the modern inverse transform approach and von Neumann&rsquo;s ingenious comparison-based algorithm. We&rsquo;ll implement them in Python, verify their correctness, and compare their performance, empirically testing the trade-offs von Neumann identified nearly 75 years ago.</p>
<hr>
<h2 id="method-1-the-standard-approach-inverse-transform-sampling">Method 1: The Standard Approach (Inverse Transform Sampling)</h2>
<p>The most common method for sampling from a given distribution is <strong>inverse transform sampling</strong>. This method relies on a fundamental principle: if $U$ is a uniform random variable on the interval (0, 1) and $F(x)$ is a continuous, strictly increasing cumulative distribution function (CDF), then the following random variable $X$ has CDF $F$:</p>
<p>$$X = F^{-1}(U)$$</p>
<p>For the exponential distribution, the CDF is $F(x) = 1 - e^{-x}$. To find the inverse, we set $U = 1 - e^{-X}$ and solve for $X$:</p>
<p>$$
\begin{align}
e^{-X} &amp;= 1 - U \\
-X &amp;= \ln(1 - U) \\
X &amp;= -\ln(1 - U)
\end{align}
$$</p>
<p>Here&rsquo;s a useful simplification: since $U$ is uniformly distributed on (0, 1), the quantity $(1 - U)$ is also uniformly distributed on (0, 1). Therefore, we can use the simpler formula:</p>
<p>$$X = -\ln(U)$$</p>
<p>This gives us an efficient method for generating exponentially distributed numbers, provided the logarithm function is computationally accessible.</p>
<h3 id="python-implementation">Python Implementation</h3>
<p>Here&rsquo;s a straightforward implementation using NumPy:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> numpy <span style="color:#66d9ef">as</span> np
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">exponential_inverse_transform</span>(n_samples<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>):
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    Generate samples from an exponential distribution using inverse transform sampling.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    Args:
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">        n_samples (int): Number of samples to generate.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    Returns:
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">        np.ndarray: Array of exponentially distributed samples.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    &#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Generate uniform random numbers on (0, 1]; rand() samples [0, 1), so flip it to avoid log(0)</span>
</span></span><span style="display:flex;"><span>    U <span style="color:#f92672">=</span> <span style="color:#ae81ff">1.0</span> <span style="color:#f92672">-</span> np<span style="color:#f92672">.</span>random<span style="color:#f92672">.</span>rand(n_samples)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Apply the inverse transform</span>
</span></span><span style="display:flex;"><span>    X <span style="color:#f92672">=</span> <span style="color:#f92672">-</span>np<span style="color:#f92672">.</span>log(U)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> X
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Generate 100,000 samples for testing</span>
</span></span><span style="display:flex;"><span>n_samples <span style="color:#f92672">=</span> <span style="color:#ae81ff">100000</span>
</span></span><span style="display:flex;"><span>inverse_samples <span style="color:#f92672">=</span> exponential_inverse_transform(n_samples)
</span></span></code></pre></div><hr>
<h2 id="method-2-von-neumanns-ingenious-trick-rejection-sampling">Method 2: Von Neumann&rsquo;s Ingenious Trick (Rejection Sampling)</h2>
<p>Von Neumann proposed a clever alternative that avoids transcendental functions entirely. His procedure, which he noted &ldquo;resembles a well known game of chance Twenty-One, or Black Jack,&rdquo; generates sequences of uniform random numbers and accepts or rejects them based on simple comparison rules.</p>
<p>The algorithm works as follows to generate a single exponential sample $X$:</p>
<ol>
<li>
<p><strong>Initialize</strong>: Start with an integer offset <code>k = 0</code>, which will form the integer part of the final result.</p>
</li>
<li>
<p><strong>Generate a trial sequence</strong>:</p>
<ul>
<li>Generate uniform random numbers $Y_1, Y_2, Y_3, \ldots$ from (0, 1)</li>
<li>Find the smallest integer <code>n</code> such that the sequence is no longer strictly decreasing</li>
<li>That is, find <code>n</code> where $Y_1 &gt; Y_2 &gt; \cdots &gt; Y_n$ but $Y_n \leq Y_{n+1}$</li>
</ul>
</li>
<li>
<p><strong>Accept or reject</strong>:</p>
<ul>
<li>If <code>n</code> is <strong>odd</strong>: Accept the trial. Return $X = Y_1 + k$ and terminate.</li>
<li>If <code>n</code> is <strong>even</strong>: Reject the trial. Increment <code>k</code> by 1 and start a new trial.</li>
</ul>
</li>
</ol>
<p>This process terminates with probability 1 and produces samples that follow the exponential distribution exactly. As von Neumann elegantly put it, the machine has &ldquo;in effect computed a logarithm by performing only discriminations on the relative magnitude of numbers.&rdquo;</p>
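<p>A short sketch of why the parity rule works, using only the steps above (here $N$ denotes the run length <code>n</code> and $K$ the final value of <code>k</code>; both symbols are introduced just for this derivation). Given $Y_1 = x$, the next $n - 1$ draws all fall below $x$ in decreasing order with probability $x^{n-1}/(n-1)!$, so:</p>
<p>$$P(N = n \mid Y_1 = x) = \frac{x^{n-1}}{(n-1)!} - \frac{x^n}{n!}$$</p>
<p>Summing over odd $n$ gives the alternating series for the exponential:</p>
<p>$$P(N \text{ odd} \mid Y_1 = x) = 1 - x + \frac{x^2}{2!} - \frac{x^3}{3!} + \cdots = e^{-x}$$</p>
<p>An accepted $Y_1$ therefore has density proportional to $e^{-x}$ on (0, 1). Each trial is rejected with probability $1 - \int_0^1 e^{-x}\,dx = e^{-1}$, so $P(K = k) = e^{-k}(1 - e^{-1})$, and the density of $X = K + Y_1$ at the point $k + x$ is:</p>
<p>$$e^{-k}\left(1 - e^{-1}\right) \cdot \frac{e^{-x}}{1 - e^{-1}} = e^{-(k + x)}$$</p>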
<h3 id="python-implementation-1">Python Implementation</h3>
<p>This implementation requires more careful state management due to the nested trial structure:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> numpy <span style="color:#66d9ef">as</span> np
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">exponential_von_neumann</span>(n_samples<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>):
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    Generate samples from an exponential distribution using von Neumann&#39;s
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    comparison-based rejection sampling method.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    Args:
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">        n_samples (int): Number of samples to generate.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    Returns:
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">        tuple[np.ndarray, float]: Array of samples and average uniform draws per sample.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    &#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>    samples <span style="color:#f92672">=</span> []
</span></span><span style="display:flex;"><span>    total_uniform_draws <span style="color:#f92672">=</span> <span style="color:#ae81ff">0</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">for</span> _ <span style="color:#f92672">in</span> range(n_samples):
</span></span><span style="display:flex;"><span>        k <span style="color:#f92672">=</span> <span style="color:#ae81ff">0</span>  <span style="color:#75715e"># Integer offset</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">while</span> <span style="color:#66d9ef">True</span>:  <span style="color:#75715e"># Trial loop</span>
</span></span><span style="display:flex;"><span>            <span style="color:#75715e"># Generate decreasing sequence</span>
</span></span><span style="display:flex;"><span>            y_prev <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>random<span style="color:#f92672">.</span>rand()
</span></span><span style="display:flex;"><span>            total_uniform_draws <span style="color:#f92672">+=</span> <span style="color:#ae81ff">1</span>
</span></span><span style="display:flex;"><span>            y1 <span style="color:#f92672">=</span> y_prev  <span style="color:#75715e"># Store first value</span>
</span></span><span style="display:flex;"><span>            n <span style="color:#f92672">=</span> <span style="color:#ae81ff">1</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>            <span style="color:#75715e"># Find length of decreasing sequence</span>
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">while</span> <span style="color:#66d9ef">True</span>:
</span></span><span style="display:flex;"><span>                y_curr <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>random<span style="color:#f92672">.</span>rand()
</span></span><span style="display:flex;"><span>                total_uniform_draws <span style="color:#f92672">+=</span> <span style="color:#ae81ff">1</span>
</span></span><span style="display:flex;"><span>                <span style="color:#66d9ef">if</span> y_prev <span style="color:#f92672">&lt;=</span> y_curr:
</span></span><span style="display:flex;"><span>                    <span style="color:#66d9ef">break</span>  <span style="color:#75715e"># Sequence no longer decreasing</span>
</span></span><span style="display:flex;"><span>                y_prev <span style="color:#f92672">=</span> y_curr
</span></span><span style="display:flex;"><span>                n <span style="color:#f92672">+=</span> <span style="color:#ae81ff">1</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>            <span style="color:#75715e"># Accept if n is odd, reject if even</span>
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">if</span> n <span style="color:#f92672">%</span> <span style="color:#ae81ff">2</span> <span style="color:#f92672">==</span> <span style="color:#ae81ff">1</span>:  <span style="color:#75715e"># Accept</span>
</span></span><span style="display:flex;"><span>                samples<span style="color:#f92672">.</span>append(y1 <span style="color:#f92672">+</span> k)
</span></span><span style="display:flex;"><span>                <span style="color:#66d9ef">break</span>
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">else</span>:  <span style="color:#75715e"># Reject</span>
</span></span><span style="display:flex;"><span>                k <span style="color:#f92672">+=</span> <span style="color:#ae81ff">1</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    avg_draws <span style="color:#f92672">=</span> total_uniform_draws <span style="color:#f92672">/</span> n_samples
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> np<span style="color:#f92672">.</span>array(samples), avg_draws
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Generate samples using von Neumann&#39;s method</span>
</span></span><span style="display:flex;"><span>von_neumann_samples, avg_draws <span style="color:#f92672">=</span> exponential_von_neumann(n_samples)
</span></span><span style="display:flex;"><span>print(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#34;Von Neumann method used </span><span style="color:#e6db74">{</span>avg_draws<span style="color:#e6db74">:</span><span style="color:#e6db74">.2f</span><span style="color:#e6db74">}</span><span style="color:#e6db74"> uniform draws per sample on average.&#34;</span>)
</span></span></code></pre></div><div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-console" data-lang="console"><span style="display:flex;"><span>Von Neumann method used 4.30 uniform draws per sample on average.
</span></span></code></pre></div><p>The algorithm requires approximately <strong>4.3</strong> uniform draws per exponential sample, matching the theoretical value $e^2/(e-1) = 4.30$.</p>
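<p>The theoretical value can be checked numerically. The sketch below assumes only the algorithm as described above: a run of length $n$ survives with probability $1/n!$, each trial uses one extra draw to break the run, and a trial is accepted when the run length is odd. The variable names and the truncation constant <code>TERMS</code> are illustrative choices.</p>

```python
import math

# Run length N: P(N >= n) = 1/n!, because only one of the n! orderings
# of Y_1..Y_n is strictly decreasing.
TERMS = 30  # series truncation; 1/30! is far below float precision

# Each trial consumes N + 1 uniforms (the run plus the draw that breaks it),
# and E[N] = sum over n >= 1 of P(N >= n).
draws_per_trial = 1 + sum(1 / math.factorial(n) for n in range(1, TERMS))

# A trial is accepted when N is odd: P(N = n) = 1/n! - 1/(n+1)!.
p_accept = sum(
    1 / math.factorial(n) - 1 / math.factorial(n + 1)
    for n in range(1, TERMS, 2)
)

# Trials are independent, so expected draws = (draws per trial) / P(accept).
draws_per_sample = draws_per_trial / p_accept

print(f"draws per trial:  {draws_per_trial:.4f}  (e = {math.e:.4f})")
print(f"P(accept):        {p_accept:.4f}  (1 - 1/e = {1 - 1 / math.e:.4f})")
print(f"draws per sample: {draws_per_sample:.4f}")
```

<p>Each trial costs $e$ draws on average and succeeds with probability $1 - 1/e$, which multiplies out to $e^2/(e-1)$.</p>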
<hr>
<h2 id="verification-and-comparison">Verification and Comparison</h2>
<p>The critical test: do both methods actually produce the same distribution? And how do their performance characteristics compare?</p>
<h3 id="visual-verification">Visual Verification</h3>
<p>Let&rsquo;s plot histograms of samples from both methods alongside the theoretical probability density function $f(x) = e^{-x}$:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> matplotlib.pyplot <span style="color:#66d9ef">as</span> plt
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> seaborn <span style="color:#66d9ef">as</span> sns
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Configure plot aesthetics</span>
</span></span><span style="display:flex;"><span>sns<span style="color:#f92672">.</span>set_style(<span style="color:#e6db74">&#34;whitegrid&#34;</span>)
</span></span><span style="display:flex;"><span>plt<span style="color:#f92672">.</span>figure(figsize<span style="color:#f92672">=</span>(<span style="color:#ae81ff">12</span>, <span style="color:#ae81ff">7</span>))
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Plot histograms for both methods</span>
</span></span><span style="display:flex;"><span>plt<span style="color:#f92672">.</span>hist(inverse_samples, bins<span style="color:#f92672">=</span><span style="color:#ae81ff">50</span>, density<span style="color:#f92672">=</span><span style="color:#66d9ef">True</span>, alpha<span style="color:#f92672">=</span><span style="color:#ae81ff">0.7</span>,
</span></span><span style="display:flex;"><span>         label<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;Inverse Transform&#39;</span>, color<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;skyblue&#39;</span>)
</span></span><span style="display:flex;"><span>plt<span style="color:#f92672">.</span>hist(von_neumann_samples, bins<span style="color:#f92672">=</span><span style="color:#ae81ff">50</span>, density<span style="color:#f92672">=</span><span style="color:#66d9ef">True</span>, alpha<span style="color:#f92672">=</span><span style="color:#ae81ff">0.7</span>,
</span></span><span style="display:flex;"><span>         label<span style="color:#f92672">=</span><span style="color:#e6db74">&#34;Von Neumann&#39;s Method&#34;</span>, color<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;lightcoral&#39;</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Overlay theoretical PDF</span>
</span></span><span style="display:flex;"><span>x <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>linspace(<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">8</span>, <span style="color:#ae81ff">400</span>)
</span></span><span style="display:flex;"><span>pdf <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>exp(<span style="color:#f92672">-</span>x)
</span></span><span style="display:flex;"><span>plt<span style="color:#f92672">.</span>plot(x, pdf, <span style="color:#e6db74">&#39;r-&#39;</span>, linewidth<span style="color:#f92672">=</span><span style="color:#ae81ff">2</span>, label<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;Theoretical PDF ($e^{-x}$)&#39;</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>plt<span style="color:#f92672">.</span>title(<span style="color:#e6db74">&#39;Exponential Sampling Methods vs. Theoretical Distribution&#39;</span>, fontsize<span style="color:#f92672">=</span><span style="color:#ae81ff">16</span>)
</span></span><span style="display:flex;"><span>plt<span style="color:#f92672">.</span>xlabel(<span style="color:#e6db74">&#39;x&#39;</span>, fontsize<span style="color:#f92672">=</span><span style="color:#ae81ff">12</span>)
</span></span><span style="display:flex;"><span>plt<span style="color:#f92672">.</span>ylabel(<span style="color:#e6db74">&#39;Density&#39;</span>, fontsize<span style="color:#f92672">=</span><span style="color:#ae81ff">12</span>)
</span></span><span style="display:flex;"><span>plt<span style="color:#f92672">.</span>legend()
</span></span><span style="display:flex;"><span>plt<span style="color:#f92672">.</span>xlim(<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">8</span>)
</span></span><span style="display:flex;"><span>plt<span style="color:#f92672">.</span>tight_layout()
</span></span><span style="display:flex;"><span>plt<span style="color:#f92672">.</span>show()
</span></span></code></pre></div>
<figure class="post-figure center ">
    <img src="/img/exponential_random_gens.webp"
         alt="Comparison of exponential sampling methods showing histograms from both inverse transform and von Neumann methods overlaid with the theoretical exponential distribution"
         title="Comparison of exponential sampling methods showing histograms from both inverse transform and von Neumann methods overlaid with the theoretical exponential distribution"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Both sampling methods reproduce the exponential distribution $f(x) = e^{-x}$</figcaption>
    
</figure>

<p>The empirical histograms from both methods track the theoretical curve closely. This is a visual check, not a proof, but it is strong evidence that both algorithms sample the target exponential distribution.</p>
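<p>A histogram can hide small discrepancies, so a numerical check is a useful complement. The sketch below is an addition to the original analysis: it compares the sample mean and variance to their theoretical value of 1 and computes the largest gap between the empirical CDF and $F(x) = 1 - e^{-x}$ (the Kolmogorov-Smirnov distance). The helper name <code>ks_distance_to_exponential</code>, the seed, and the freshly generated <code>check_samples</code> are assumptions made to keep the block self-contained.</p>

```python
import numpy as np

def ks_distance_to_exponential(samples):
    """Largest gap between the empirical CDF and F(x) = 1 - exp(-x)."""
    x = np.sort(np.asarray(samples, dtype=float))
    n = x.size
    cdf = 1.0 - np.exp(-x)
    # The empirical CDF jumps at each sorted sample, so check both sides of the jump
    upper = np.arange(1, n + 1) / n - cdf
    lower = cdf - np.arange(0, n) / n
    return float(max(upper.max(), lower.max()))

rng = np.random.default_rng(0)  # seed chosen arbitrarily for reproducibility
check_samples = -np.log(1.0 - rng.random(100_000))  # inverse transform on (0, 1]

print(f"mean        = {check_samples.mean():.3f}  (theory: 1)")
print(f"variance    = {check_samples.var():.3f}  (theory: 1)")
print(f"KS distance = {ks_distance_to_exponential(check_samples):.4f}")
```

<p>Passing <code>von_neumann_samples</code> or <code>inverse_samples</code> to the same helper gives a like-for-like comparison. For 100,000 samples from the correct distribution, the distance should be on the order of $1/\sqrt{n}$, a few thousandths.</p>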
<h3 id="performance-analysis">Performance Analysis</h3>
<p>Mathematical elegance often diverges from computational efficiency. Von Neumann himself observed that on the ENIAC, it was actually &ldquo;slightly quicker to use a truncated power series for log(1-T)&rdquo; than to perform all the comparisons his method required.</p>
<p>Let&rsquo;s benchmark both approaches in a modern Python environment:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> time
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Benchmark inverse transform method</span>
</span></span><span style="display:flex;"><span>start_time <span style="color:#f92672">=</span> time<span style="color:#f92672">.</span>time()
</span></span><span style="display:flex;"><span>_ <span style="color:#f92672">=</span> exponential_inverse_transform(n_samples)
</span></span><span style="display:flex;"><span>inverse_time <span style="color:#f92672">=</span> time<span style="color:#f92672">.</span>time() <span style="color:#f92672">-</span> start_time
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Benchmark von Neumann method</span>
</span></span><span style="display:flex;"><span>start_time <span style="color:#f92672">=</span> time<span style="color:#f92672">.</span>time()
</span></span><span style="display:flex;"><span>_ <span style="color:#f92672">=</span> exponential_von_neumann(n_samples)
</span></span><span style="display:flex;"><span>vn_time <span style="color:#f92672">=</span> time<span style="color:#f92672">.</span>time() <span style="color:#f92672">-</span> start_time
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>print(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#34;Inverse Transform:  </span><span style="color:#e6db74">{</span>inverse_time<span style="color:#e6db74">:</span><span style="color:#e6db74">.4f</span><span style="color:#e6db74">}</span><span style="color:#e6db74"> seconds&#34;</span>)
</span></span><span style="display:flex;"><span>print(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#34;Von Neumann Method: </span><span style="color:#e6db74">{</span>vn_time<span style="color:#e6db74">:</span><span style="color:#e6db74">.4f</span><span style="color:#e6db74">}</span><span style="color:#e6db74"> seconds&#34;</span>)
</span></span><span style="display:flex;"><span>print(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#34;Speedup factor: </span><span style="color:#e6db74">{</span>vn_time <span style="color:#f92672">/</span> inverse_time<span style="color:#e6db74">:</span><span style="color:#e6db74">.1f</span><span style="color:#e6db74">}</span><span style="color:#e6db74">x&#34;</span>)
</span></span></code></pre></div><div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-console" data-lang="console"><span style="display:flex;"><span>Inverse Transform:  0.0018 seconds
</span></span><span style="display:flex;"><span>Von Neumann Method: 0.1860 seconds
</span></span><span style="display:flex;"><span>Speedup factor: 103.3x
</span></span></code></pre></div><p>The gap is large. The vectorized NumPy implementation of inverse transform sampling, which relies on a highly optimized C-backed logarithm, outperforms the Python-looped von Neumann implementation by roughly two orders of magnitude. Much of that gap is Python interpreter overhead, not the algorithm itself, so a vectorized or JIT-compiled version of von Neumann&rsquo;s method would narrow it. Even so, the method still needs about 4.3 uniform draws and data-dependent branching per sample, so the inverse transform is likely to remain the practical winner on modern hardware with fast floating-point units. This is consistent with von Neumann&rsquo;s own observation: the &ldquo;theoretically elegant&rdquo; method that avoids transcendental functions often yields to direct computation.</p>
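<p>For readers who want a fairer comparison, here is one possible vectorized variant of von Neumann&rsquo;s method. It is a sketch, not the implementation benchmarked above: the function name <code>exponential_von_neumann_vectorized</code> and the use of <code>np.random.default_rng</code> are choices made here, and no timing claim is attached to it. All unfinished samples run their trials in lockstep, and only the runs that are still decreasing draw another uniform.</p>

```python
import numpy as np

def exponential_von_neumann_vectorized(n_samples, rng=None):
    """Von Neumann's comparison method with all pending samples run in lockstep."""
    rng = np.random.default_rng() if rng is None else rng
    result = np.empty(n_samples)
    k = np.zeros(n_samples)         # integer offset of each sample
    pending = np.arange(n_samples)  # indices of samples not yet accepted

    while pending.size:
        m = pending.size
        y1 = rng.random(m)          # first draw of each trial
        y_prev = y1.copy()
        n = np.ones(m, dtype=np.int64)
        running = np.ones(m, dtype=bool)  # trials whose run is still decreasing

        while running.any():
            idx = np.flatnonzero(running)
            y_curr = rng.random(idx.size)
            still = y_curr < y_prev[idx]  # run continues only if strictly decreasing
            y_prev[idx[still]] = y_curr[still]
            n[idx[still]] += 1
            running[idx[~still]] = False

        accepted = n % 2 == 1             # odd run length: accept the trial
        done = pending[accepted]
        result[done] = y1[accepted] + k[done]
        k[pending[~accepted]] += 1        # even run length: reject, bump the offset
        pending = pending[~accepted]

    return result

vectorized_samples = exponential_von_neumann_vectorized(50_000, np.random.default_rng(1))
print(f"mean = {vectorized_samples.mean():.3f}, min = {vectorized_samples.min():.5f}")
```

<p>The inner loop runs only as long as the longest decreasing run in the batch, and the outer loop only as long as the largest integer part, so the number of Python-level iterations grows slowly with the batch size.</p>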
<h2 id="conclusion">Conclusion</h2>
<p>This exploration offers a window into the ingenuity of early computational mathematics. Von Neumann&rsquo;s comparison-based algorithm demonstrates remarkable mathematical creativity, showing how to &ldquo;compute a logarithm&rdquo; using only basic machine operations. Our implementation reproduces the algorithm, producing samples whose histogram matches the exponential distribution.</p>
<p>The performance comparison is in line with von Neumann&rsquo;s own pragmatic assessment. His rejection sampling method is intellectually elegant and historically significant, but the direct logarithmic approach was slightly quicker on the ENIAC by his own account and is far faster in our NumPy benchmark. It serves as a timeless reminder in scientific computing: theoretical beauty often diverges from computational practicality.</p>
<p>The enduring value of von Neumann&rsquo;s work lies in the fundamental insight that creative mathematical thinking can circumvent apparent computational limitations. Understanding alternative methods deepens our appreciation for the rich landscape of algorithmic possibilities, even when the direct approach proves superior.</p>
]]></content:encoded></item><item><title>Implementing the Müller-Brown Potential in PyTorch</title><link>https://hunterheidenreich.com/posts/muller-brown-in-pytorch/</link><pubDate>Wed, 27 Aug 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/muller-brown-in-pytorch/</guid><description>Guide to implementing the Müller-Brown potential in PyTorch, comparing analytical vs automatic differentiation with performance analysis.</description><content:encoded><![CDATA[<h2 id="introduction">Introduction</h2>
<p>The Müller-Brown potential reads, in hindsight, like an adversarial example for optimization algorithms.</p>
<p>Designed in 1979 to break naive path-finding methods, this deceptively simple 2D surface features deep minima, high barriers, and tricky saddle points. For nearly five decades, it has served as a ground-truth benchmark for computational chemistry.</p>
<p>Today, it finds new life as a testbed for machine learning. In the 1970s, chemists struggled to find transition states, saddle points where standard gradient descent fails catastrophically. Modern machine learning engineers face a strikingly similar challenge: escaping saddle points in high-dimensional loss landscapes. The Müller-Brown potential was the original stress test for these algorithms, and it remains a perfect, low-cost sandbox for benchmarking modern optimizers.</p>
<p>Whether you&rsquo;re training neural network potentials or benchmarking reinforcement learning agents for exploration, the Müller-Brown potential offers a fast, noise-free, and mathematically exact environment.</p>
<p>In this guide, we&rsquo;ll implement it in PyTorch, work through the engineering trade-off between <strong>analytical derivatives</strong> and <strong>Autograd</strong>, and measure the roughly <strong>4x</strong> force-evaluation speedup of the compiled analytical kernel over the Autograd reference.</p>
<h2 id="the-problem-finding-saddle-points-in-the-1970s">The Problem: Finding Saddle Points in the 1970s</h2>
<p>In the 1970s, finding energy minima was straightforward: follow the gradient downhill. But finding transition states proved much more challenging. These saddle points are maxima along the reaction coordinate but minima in all other directions, like standing at the top of a mountain pass.</p>
<figure class="post-figure center ">
    <img src="/img/muller-brown/muller-brown-saddle.webp"
         alt="Diagram showing gradient descent getting stuck at a saddle point, where the surface curves up in one direction and down in another"
         title="Diagram showing gradient descent getting stuck at a saddle point, where the surface curves up in one direction and down in another"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The saddle point problem: standard gradient descent sees a minimum in one direction but a maximum in another. Modern optimizers like SGD and Adam face this same challenge in high-dimensional loss landscapes.</figcaption>
    
</figure>

<p>Standard local optimizers, whether derivative-free simplex methods from the 1970s or modern gradient-based SGD and Adam, are designed to minimize a function. Start them near a saddle point, and they either stall where the gradient vanishes or slide into the nearest valley. Finding the saddle itself requires specialized algorithms that climb along one direction while descending along all the others.</p>
<p>The computational reality made this worse. Early quantum chemistry programs like ATMOL and Gaussian made energy calculations possible, but each computation was expensive. Gradients required even more resources, and second derivatives were rarely computed.</p>
<p>This created a catch-22: sophisticated algorithms were needed to find saddle points, but researchers couldn&rsquo;t afford to test them on real molecular systems. Every calculation represented a major investment of time and computational resources.</p>
<h2 id="müller-and-browns-solution">Müller and Brown&rsquo;s Solution</h2>
<p>Müller and Brown&rsquo;s insight<sup id="fnref:1"><a href="#fn:1" class="footnote-ref" role="doc-noteref">1</a></sup> was to create a simple analytical test function that captured the essential difficulties of real chemical systems without the computational cost. Their potential offered three key advantages:</p>
<ul>
<li><strong>Negligible computational cost</strong> - Evaluate millions of points instantly</li>
<li><strong>Analytical derivatives</strong> - Exact gradients and Hessians available immediately</li>
<li><strong>Realistic challenges</strong> - Multiple minima, saddle points, and curved pathways</li>
</ul>
<p>The clever part was the deliberate design to break naive approaches. Early methods often assumed linear paths between reactants and products. The Müller-Brown potential has a curved minimum energy path that punishes this assumption. Try to take shortcuts, and algorithms climb over high-energy barriers.</p>
<h2 id="the-mathematical-foundation">The Mathematical Foundation</h2>
<p>The Müller-Brown potential combines four two-dimensional Gaussian functions:</p>
<p>$$V(x,y) = \sum_{k=1}^{4} A_k \exp\left[a_k(x-x_k^0)^2 + b_k(x-x_k^0)(y-y_k^0) + c_k(y-y_k^0)^2\right]$$</p>
<p>Each term contributes a different &ldquo;bump&rdquo; or &ldquo;well&rdquo; to the landscape. $A_k$ sets the amplitude, $a_k$ and $c_k$ set the widths along $x$ and $y$, $b_k$ sets the orientation (tilt), and $(x_k^0, y_k^0)$ is the center.</p>
<h3 id="the-standard-parameters">The Standard Parameters</h3>
<p>The specific parameter values that define the canonical Müller-Brown surface are:</p>
<table>
  <thead>
      <tr>
          <th>k</th>
          <th>$A_k$</th>
          <th>$a_k$</th>
          <th>$b_k$</th>
          <th>$c_k$</th>
          <th>$x_k^0$</th>
          <th>$y_k^0$</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>1</td>
          <td>-200</td>
          <td>-1</td>
          <td>0</td>
          <td>-10</td>
          <td>1</td>
          <td>0</td>
      </tr>
      <tr>
          <td>2</td>
          <td>-100</td>
          <td>-1</td>
          <td>0</td>
          <td>-10</td>
          <td>0</td>
          <td>0.5</td>
      </tr>
      <tr>
          <td>3</td>
          <td>-170</td>
          <td>-6.5</td>
          <td>11</td>
          <td>-6.5</td>
          <td>-0.5</td>
          <td>1.5</td>
      </tr>
      <tr>
          <td>4</td>
          <td>15</td>
          <td>0.7</td>
          <td>0.6</td>
          <td>0.7</td>
          <td>-1</td>
          <td>1</td>
      </tr>
  </tbody>
</table>
<p>Notice that the first three terms have negative amplitudes (creating energy wells), while the fourth has a positive amplitude and positive exponent coefficients, so it grows away from its center and acts as a rising wall. The large cross-term $b_3 = 11$ in the third term tilts that well relative to the axes, which gives the surface its characteristic curved pathways.</p>
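<p>The formula and parameter table above translate directly into a few lines of code. The following is a minimal pure-Python sketch, not the PyTorch implementation developed later; the names <code>MB_TERMS</code> and <code>muller_brown</code> are illustrative choices made here.</p>

```python
import math

# One tuple per term k = 1..4, copied from the parameter table:
# (A_k, a_k, b_k, c_k, x_k^0, y_k^0)
MB_TERMS = [
    (-200.0, -1.0, 0.0, -10.0, 1.0, 0.0),
    (-100.0, -1.0, 0.0, -10.0, 0.0, 0.5),
    (-170.0, -6.5, 11.0, -6.5, -0.5, 1.5),
    (15.0, 0.7, 0.6, 0.7, -1.0, 1.0),
]


def muller_brown(x: float, y: float) -> float:
    """Sum the four exponential terms of the Müller-Brown potential."""
    total = 0.0
    for A, a, b, c, x0, y0 in MB_TERMS:
        dx, dy = x - x0, y - y0
        total += A * math.exp(a * dx * dx + b * dx * dy + c * dy * dy)
    return total


# Energy at the (rounded) coordinates of the deepest minimum
energy_ma = muller_brown(-0.558, 1.442)
print(f"V(-0.558, 1.442) = {energy_ma:.2f}")
```

<p>Evaluated at the rounded coordinates of the deepest well, the result should land close to the commonly quoted energy of about -146.7.</p>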
<p><a href="/muller-brown-optimized"><strong>View Interactive Müller-Brown Potential Energy Surface →</strong></a></p>
<h3 id="the-resulting-landscape">The Resulting Landscape</h3>















<figure class="post-figure center ">
    <img src="/img/muller-brown/muller-brown-potential-surface.webp"
         alt="Müller-Brown Potential Energy Surface showing the three minima (dark blue regions) and two saddle points"
         title="Müller-Brown Potential Energy Surface showing the three minima (dark blue regions) and two saddle points"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The Müller-Brown potential energy surface showing the three minima (dark blue regions) and two saddle points.</figcaption>
    
</figure>

<p>This simple formula creates a surprisingly rich topography with exactly the features needed to challenge optimization algorithms:</p>
<table>
  <thead>
      <tr>
          <th><strong>Stationary Point</strong></th>
          <th><strong>Coordinates</strong></th>
          <th><strong>Energy</strong></th>
          <th><strong>Type</strong></th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MA (Reactant)</td>
          <td>(-0.558, 1.442)</td>
          <td>-146.70</td>
          <td>Deep minimum</td>
      </tr>
      <tr>
          <td>MC (Intermediate)</td>
          <td>(-0.050, 0.467)</td>
          <td>-80.77</td>
          <td>Shallow minimum</td>
      </tr>
      <tr>
          <td>MB (Product)</td>
          <td>(0.623, 0.028)</td>
          <td>-108.17</td>
          <td>Medium minimum</td>
      </tr>
      <tr>
          <td>S1</td>
          <td>(-0.822, 0.624)</td>
          <td>-40.66</td>
          <td>First saddle point</td>
      </tr>
      <tr>
          <td>S2</td>
          <td>(0.212, 0.293)</td>
          <td>-72.25</td>
          <td>Second saddle point</td>
      </tr>
  </tbody>
</table>
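<p>The entries in this table can be checked numerically. The sketch below is one way to do it, assuming analytical first and second derivatives of each exponential term: a few Newton iterations polish the rounded coordinates (Newton converges to any nearby stationary point, minimum or saddle), and the signs of the Hessian eigenvalues classify the result. The helper names <code>energy_grad_hess</code>, <code>refine</code>, and <code>classify</code> are illustrative.</p>

```python
import math

# (A_k, a_k, b_k, c_k, x_k^0, y_k^0) for the four terms
MB_TERMS = [
    (-200.0, -1.0, 0.0, -10.0, 1.0, 0.0),
    (-100.0, -1.0, 0.0, -10.0, 0.0, 0.5),
    (-170.0, -6.5, 11.0, -6.5, -0.5, 1.5),
    (15.0, 0.7, 0.6, 0.7, -1.0, 1.0),
]


def energy_grad_hess(x, y):
    """Return V, its gradient (gx, gy), and Hessian entries (hxx, hxy, hyy)."""
    v = gx = gy = hxx = hxy = hyy = 0.0
    for A, a, b, c, x0, y0 in MB_TERMS:
        dx, dy = x - x0, y - y0
        e = A * math.exp(a * dx * dx + b * dx * dy + c * dy * dy)
        qx = 2.0 * a * dx + b * dy  # d(exponent)/dx
        qy = b * dx + 2.0 * c * dy  # d(exponent)/dy
        v += e
        gx += e * qx
        gy += e * qy
        hxx += e * (qx * qx + 2.0 * a)
        hxy += e * (qx * qy + b)
        hyy += e * (qy * qy + 2.0 * c)
    return v, (gx, gy), (hxx, hxy, hyy)


def refine(x, y, steps=20):
    """Newton iterations toward the nearby stationary point of any type."""
    for _ in range(steps):
        _, (gx, gy), (hxx, hxy, hyy) = energy_grad_hess(x, y)
        det = hxx * hyy - hxy * hxy
        # Solve H d = -g for the 2x2 symmetric Hessian
        x += (-gx * hyy + gy * hxy) / det
        y += (gx * hxy - gy * hxx) / det
    return x, y


def classify(x, y):
    """Label a stationary point by the signs of its Hessian eigenvalues."""
    _, _, (hxx, hxy, hyy) = energy_grad_hess(x, y)
    mean = 0.5 * (hxx + hyy)
    spread = math.hypot(0.5 * (hxx - hyy), hxy)
    lo, hi = mean - spread, mean + spread
    if lo > 0:
        return "minimum"
    return "saddle" if hi > 0 else "maximum"


TABLE = {
    "MA": (-0.558, 1.442), "MC": (-0.050, 0.467), "MB": (0.623, 0.028),
    "S1": (-0.822, 0.624), "S2": (0.212, 0.293),
}

results = {}
for name, (x, y) in TABLE.items():
    xr, yr = refine(x, y)
    results[name] = (xr, yr, energy_grad_hess(xr, yr)[0], classify(xr, yr))
    print(name, results[name])
```

<p>Plain Newton iteration is used here only because the starting guesses are already very close; far from a stationary point it can diverge or jump to a different one, which is exactly why dedicated saddle-search algorithms exist.</p>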
<h3 id="the-key-challenge-curved-pathways">The Key Challenge: Curved Pathways</h3>
<p>The path from the deep reactant minimum (MA) to the product minimum (MB) doesn&rsquo;t go directly over a single barrier. Instead, it follows a curved route:</p>
<ol>
<li><strong>MA → S1 → MC</strong>: First transition over the higher, rate-limiting barrier (S1) into an intermediate basin</li>
<li><strong>MC → S2 → MB</strong>: Second transition over a much lower barrier (S2) to the product</li>
</ol>
<p>This two-step pathway breaks linear interpolation methods. Algorithms that draw a straight line from reactant to product miss both the intermediate minimum and the correct transition states, climbing over much higher energy regions instead.</p>















<figure class="post-figure center ">
    <img src="/img/muller-brown/naive-versus-minimum-path.webp"
         alt="Two-panel comparison showing naive linear interpolation versus minimum energy path. Left panel shows the contour map with both paths overlaid. Right panel shows the energy profile along each path, revealing the naive path hits a far higher barrier."
         title="Two-panel comparison showing naive linear interpolation versus minimum energy path. Left panel shows the contour map with both paths overlaid. Right panel shows the energy profile along each path, revealing the naive path hits a far higher barrier."
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Why naive optimization fails: The left panel shows a straight-line path (red dashed) versus the true minimum energy path (green solid) on the potential surface. The right panel reveals the energetic cost. The naive path climbs a summit roughly 53 reduced units above the highest point of the curved path. This is the &lsquo;adversarial&rsquo; nature of the Müller-Brown surface.</figcaption>
    
</figure>

<p>The energy profile comparison makes the failure mode concrete. A naive optimizer following the red dashed path would encounter a barrier roughly <strong>53 reduced units higher</strong> than necessary (the naive summit sits at about +13 while the true path tops out near -41, the S1 saddle). The green minimum energy path navigates through the valleys, passing through the intermediate basin MC and crossing only the low-lying saddle points S1 and S2.</p>
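<p>The size of this penalty can be estimated in a few lines. The sketch below samples the potential along the straight segment from MA to MB and compares its highest point with the tabulated S1 energy; the sampling density <code>n</code> and the variable names are arbitrary choices made for illustration.</p>

```python
import math

# (A_k, a_k, b_k, c_k, x_k^0, y_k^0) for the four terms
MB_TERMS = [
    (-200.0, -1.0, 0.0, -10.0, 1.0, 0.0),
    (-100.0, -1.0, 0.0, -10.0, 0.0, 0.5),
    (-170.0, -6.5, 11.0, -6.5, -0.5, 1.5),
    (15.0, 0.7, 0.6, 0.7, -1.0, 1.0),
]


def muller_brown(x, y):
    """Müller-Brown energy at a single point."""
    total = 0.0
    for A, a, b, c, x0, y0 in MB_TERMS:
        dx, dy = x - x0, y - y0
        total += A * math.exp(a * dx * dx + b * dx * dy + c * dy * dy)
    return total


MA, MB = (-0.558, 1.442), (0.623, 0.028)
E_S1 = -40.66  # tabulated energy of the highest saddle on the true path

# Sample the energy along the straight line from MA to MB
n = 1000
profile = []
for i in range(n + 1):
    t = i / n
    x = MA[0] + t * (MB[0] - MA[0])
    y = MA[1] + t * (MB[1] - MA[1])
    profile.append(muller_brown(x, y))

peak = max(profile)
excess = peak - E_S1
print(f"straight-line summit: {peak:.1f}, excess over S1: {excess:.1f}")
```

<p>The summit sits early on the segment, on the MA side, where the straight line leaves the deep well before reaching either of the other two.</p>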
<h2 id="why-it-works-as-a-benchmark">Why It Works as a Benchmark</h2>
<p>The Müller-Brown potential has served as a computational chemistry benchmark for over four decades because of four key characteristics:</p>
<p><strong>Low dimensionality</strong>: As a 2D surface, you can visualize the entire landscape and see exactly why algorithms succeed or fail.</p>
<p><strong>Analytical form</strong>: Energy and gradient calculations cost virtually nothing, enabling exhaustive testing impossible with quantum mechanical surfaces.</p>
<p><strong>Non-trivial topology</strong>: The curved minimum energy path and shallow intermediate minimum challenge sophisticated methods while remaining manageable.</p>
<p><strong>Known ground truth</strong>: All minima and saddle points are precisely known, providing unambiguous success metrics.</p>















<figure class="post-figure center ">
    <img src="/img/muller-brown/muller-brown-basins-of-attraction.webp"
         alt="Basins of attraction map showing which regions of the Müller-Brown surface lead to each minimum under gradient descent. Blue region flows to MA, green to MC, and yellow to MB."
         title="Basins of attraction map showing which regions of the Müller-Brown surface lead to each minimum under gradient descent. Blue region flows to MA, green to MC, and yellow to MB."
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Basins of attraction: the &lsquo;optimization map&rsquo; of the Müller-Brown surface. Each color indicates which minimum a gradient descent optimizer will reach from that starting point. Saddle points (red x) sit precisely at the basin boundaries, the unstable equilibria that algorithms must navigate to find reaction paths.</figcaption>
    
</figure>

<p>This basin map reveals why the Müller-Brown potential is such an effective benchmark. Standard gradient descent from any point in the blue region inevitably falls into the deep MA minimum; from green, into MC; from yellow, into MB. The saddle points S1 and S2 lie exactly on the boundaries between basins: an infinitesimal perturbation at one of these points sends a descent optimizer into one valley or the other. Finding these saddle points requires algorithms that identify and stabilize at these boundary regions.</p>
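<p>A basin label for a single starting point can be computed with plain steepest descent. The sketch below is a minimal version under stated assumptions: a fixed step of <code>1e-4</code> (small enough for the stiffest curvature near the minima), a fixed number of steps, and three hand-picked starting points near the wells. A full basin map would repeat this over a grid.</p>

```python
import math

# (A_k, a_k, b_k, c_k, x_k^0, y_k^0) for the four terms
MB_TERMS = [
    (-200.0, -1.0, 0.0, -10.0, 1.0, 0.0),
    (-100.0, -1.0, 0.0, -10.0, 0.0, 0.5),
    (-170.0, -6.5, 11.0, -6.5, -0.5, 1.5),
    (15.0, 0.7, 0.6, 0.7, -1.0, 1.0),
]

MINIMA = {"MA": (-0.558, 1.442), "MC": (-0.050, 0.467), "MB": (0.623, 0.028)}


def gradient(x, y):
    """Analytical gradient of the Müller-Brown potential."""
    gx = gy = 0.0
    for A, a, b, c, x0, y0 in MB_TERMS:
        dx, dy = x - x0, y - y0
        e = A * math.exp(a * dx * dx + b * dx * dy + c * dy * dy)
        gx += e * (2.0 * a * dx + b * dy)
        gy += e * (b * dx + 2.0 * c * dy)
    return gx, gy


def descend(x, y, step=1e-4, n_steps=5000):
    """Fixed-step steepest descent (step size is an illustrative choice)."""
    for _ in range(n_steps):
        gx, gy = gradient(x, y)
        x -= step * gx
        y -= step * gy
    return x, y


def basin(x, y):
    """Name of the tabulated minimum closest to where descent ends up."""
    xf, yf = descend(x, y)
    return min(
        MINIMA,
        key=lambda k: math.hypot(xf - MINIMA[k][0], yf - MINIMA[k][1]),
    )


starts = [(-0.7, 1.2), (-0.1, 0.5), (0.5, 0.1)]
labels = [basin(x, y) for x, y in starts]
print(labels)
```

<p>Starting points far outside the plotted window need more care: the fourth term grows exponentially there, so a fixed step that is stable near the minima can overshoot badly.</p>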
<p>What makes this particularly valuable is the contrast with other classic potentials. While the Lennard-Jones potential<sup id="fnref:2"><a href="#fn:2" class="footnote-ref" role="doc-noteref">2</a></sup> serves as the benchmark for equilibrium properties with its single energy minimum, Müller-Brown explicitly models reactive landscapes. Its multiple minima and connecting barriers make it the testing ground for algorithms that find reaction paths, the methods that reveal how chemistry actually happens.</p>
<h3 id="applications-across-decades">Applications Across Decades</h3>
<p>The potential has evolved with the field&rsquo;s changing focus:</p>
<p><strong>1980s-1990s</strong>: Testing path-finding methods like Nudged Elastic Band (NEB)<sup id="fnref:3"><a href="#fn:3" class="footnote-ref" role="doc-noteref">3</a></sup>, which creates discrete representations of reaction pathways and optimizes them to find minimum energy paths.</p>
<p><strong>2000s-2010s</strong>: Validating Transition Path Sampling (TPS) methods<sup id="fnref:4"><a href="#fn:4" class="footnote-ref" role="doc-noteref">4</a></sup> that harvest statistical ensembles of reactive trajectories.</p>
<p><strong>2020s</strong>: Benchmarking machine learning models and generative approaches that learn to sample transition paths or approximate potential energy surfaces.</p>
<h2 id="modern-applications-in-machine-learning">Modern Applications in Machine Learning</h2>
<p>The rise of machine learning has given the Müller-Brown potential renewed purpose. Modern <strong>Machine Learning Interatomic Potentials (MLIPs)</strong><sup id="fnref:5"><a href="#fn:5" class="footnote-ref" role="doc-noteref">5</a></sup><sup id="fnref:6"><a href="#fn:6" class="footnote-ref" role="doc-noteref">6</a></sup> aim to bridge the gap between quantum mechanical accuracy and classical force field efficiency by training flexible models on expensive quantum chemistry data.</p>
<p>This creates a benchmarking challenge: with countless ML architectures available, how do you objectively compare them? The Müller-Brown potential provides an ideal solution, an exactly known potential energy surface that can generate unlimited, noise-free training data.</p>
<p>This enables researchers to ask fundamental questions:</p>
<ul>
<li>How well does a given architecture learn complex, curved surfaces?</li>
<li>How many training points are needed for acceptable accuracy?</li>
<li>How does the model behave when extrapolating beyond training data?</li>
<li>Can it correctly identify minima and saddle points?</li>
</ul>















<figure class="post-figure center ">
    <img src="/img/muller-brown/muller-brown-ml-benchmark.webp"
         alt="Three-panel comparison showing the analytical Müller-Brown surface (left), a neural network at epoch 50 with high error (middle), and a converged neural network at epoch 1000 (right)"
         title="Three-panel comparison showing the analytical Müller-Brown surface (left), a neural network at epoch 50 with high error (middle), and a converged neural network at epoch 1000 (right)"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Visualizing the ML benchmark: A comparison of the analytical ground truth (left) versus a neural network potential at early (middle) and late (right) stages of training. Notice how the model in early training quickly grasps the deep minima regions but struggles significantly with the complex topography of the saddle points and energy barriers: the curved pathways are smoothed into simpler shapes. This illustrates why the Müller-Brown surface remains a challenging test case for modern architectures.</figcaption>
    
</figure>

<p>The potential has evolved from a simple model system into a <strong>reference benchmark</strong>: a fixed, exactly-known surface against which AI learning capacity is measured. Any prediction error is due to model limitations, not data quality.</p>
<p>Beyond static benchmarking, a PyTorch implementation enables <strong>differentiable simulation</strong>. Because the potential, forces, and integrator are all differentiable tensor operations, gradients can be backpropagated <em>through time</em> (via the trajectory) to optimize force field parameters or control policies directly. This capability connects classical molecular simulation to modern gradient-based machine learning in a single computational graph.</p>
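<p>As a rough illustration of backpropagating through a trajectory, the sketch below rolls out a short overdamped (steepest-descent) trajectory and differentiates the final energy with respect to the starting position. It is a standalone toy, not the integrator used in this guide: the step size, step count, and the function names <code>potential</code> and <code>rollout</code> are assumptions made for the example.</p>

```python
import torch

# Standard Müller-Brown parameters
A = torch.tensor([-200.0, -100.0, -170.0, 15.0], dtype=torch.float64)
a = torch.tensor([-1.0, -1.0, -6.5, 0.7], dtype=torch.float64)
b = torch.tensor([0.0, 0.0, 11.0, 0.6], dtype=torch.float64)
c = torch.tensor([-10.0, -10.0, -6.5, 0.7], dtype=torch.float64)
xc = torch.tensor([1.0, 0.0, -0.5, -1.0], dtype=torch.float64)
yc = torch.tensor([0.0, 0.5, 1.5, 1.0], dtype=torch.float64)


def potential(pos):
    """Müller-Brown energy for one 2D position tensor of shape (2,)."""
    dx, dy = pos[0] - xc, pos[1] - yc
    return torch.sum(A * torch.exp(a * dx**2 + b * dx * dy + c * dy**2))


def rollout(start, n_steps=50, dt=1e-4):
    """Overdamped dynamics, with every step kept on the autograd graph."""
    pos = start
    for _ in range(n_steps):
        # create_graph=True makes the force itself differentiable
        (grad,) = torch.autograd.grad(potential(pos), pos, create_graph=True)
        pos = pos - dt * grad
    return pos


start = torch.tensor([-0.7, 1.2], dtype=torch.float64, requires_grad=True)
final = rollout(start)
loss = potential(final)  # any differentiable function of the endpoint works
loss.backward()          # backpropagates through all 50 steps

print("final position:", final.detach())
print("d(loss)/d(start):", start.grad)
```

<p>The same pattern applies when the quantity being optimized is a force field parameter or a control input instead of the starting position: anything that enters the rollout as a tensor with <code>requires_grad=True</code> receives a gradient.</p>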
<p>These benchmarking principles extend beyond abstract test cases. Real molecular dynamics simulations, such as those studying <a href="/posts/adatom-cu-diffusion/">adatom diffusion on metal surfaces</a>, face similar challenges in understanding energy landscapes and transition pathways. The Müller-Brown potential provides a controlled environment for developing methods that eventually tackle these complex realistic systems.</p>
<h2 id="extension-to-higher-dimensions">Extension to Higher Dimensions</h2>
<p>The canonical Müller-Brown potential can be extended beyond two dimensions to create more challenging test cases that better reflect real molecular systems. This extensibility demonstrates why it remains such an effective template for computational method development.</p>
<h3 id="why-higher-dimensions-matter">Why Higher Dimensions Matter</h3>
<p>Real molecules have dozens or hundreds of degrees of freedom. Understanding how algorithms scale with dimensionality is crucial for practical applications. Higher-dimensional extensions allow researchers to systematically test:</p>
<ul>
<li><strong>Algorithm scaling</strong> - Does performance degrade gracefully as dimensions increase?</li>
<li><strong>Model robustness</strong> - Do machine learning approaches maintain accuracy in high-dimensional spaces?</li>
<li><strong>Parallel efficiency</strong> - Can massively parallel methods exploit additional dimensions effectively?</li>
</ul>
<h3 id="extension-approaches">Extension Approaches</h3>
<p><strong>Harmonic constraints</strong>: Add quadratic wells in orthogonal dimensions while preserving the complex 2D landscape<sup id="fnref:7"><a href="#fn:7" class="footnote-ref" role="doc-noteref">7</a></sup>:</p>
<p>$$V_{5D}(x_1, x_2, x_3, x_4, x_5) = V(x_1, x_3) + \kappa(x_2^2 + x_4^2 + x_5^2)$$</p>
<p>The parameter $\kappa$ controls constraint strength: small values create nearly flat directions that test algorithmic efficiency.</p>
<p><strong>Collective variables</strong>: Define new coordinates that mix multiple dimensions<sup id="fnref:8"><a href="#fn:8" class="footnote-ref" role="doc-noteref">8</a></sup>:</p>
<p>$$\tilde{x} = \sqrt{x_1^2 + x_2^2 + \epsilon x_5^2},\quad \tilde{y} = \sqrt{x_3^2 + x_4^2}$$</p>
<p>where $\epsilon \ll 1$. The 5D potential is then $V_{5D}(x_1, \dots, x_5) = V(\tilde{x}, \tilde{y})$, which embeds the original surface in a higher-dimensional space.</p>
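<p>Both extension approaches are thin wrappers around the 2D function. The sketch below follows the two formulas above; the default values for <code>kappa</code> and <code>eps</code> and the function names are illustrative assumptions, not values taken from the cited papers.</p>

```python
import math

# (A_k, a_k, b_k, c_k, x_k^0, y_k^0) for the four terms
MB_TERMS = [
    (-200.0, -1.0, 0.0, -10.0, 1.0, 0.0),
    (-100.0, -1.0, 0.0, -10.0, 0.0, 0.5),
    (-170.0, -6.5, 11.0, -6.5, -0.5, 1.5),
    (15.0, 0.7, 0.6, 0.7, -1.0, 1.0),
]


def muller_brown(x, y):
    """Canonical 2D Müller-Brown energy."""
    total = 0.0
    for A, a, b, c, x0, y0 in MB_TERMS:
        dx, dy = x - x0, y - y0
        total += A * math.exp(a * dx * dx + b * dx * dy + c * dy * dy)
    return total


def harmonic_5d(x, kappa=1.0):
    """V(x1, x3) plus quadratic wells in the orthogonal coordinates."""
    x1, x2, x3, x4, x5 = x
    return muller_brown(x1, x3) + kappa * (x2 * x2 + x4 * x4 + x5 * x5)


def collective_5d(x, eps=1e-2):
    """V evaluated on collective variables that mix the five coordinates."""
    x1, x2, x3, x4, x5 = x
    x_tilde = math.sqrt(x1 * x1 + x2 * x2 + eps * x5 * x5)
    y_tilde = math.sqrt(x3 * x3 + x4 * x4)
    return muller_brown(x_tilde, y_tilde)


# With the extra coordinates at zero, the harmonic version reduces to 2D
print(harmonic_5d((-0.558, 0.0, 1.442, 0.0, 0.0)))
# Moving along x2 only adds the quadratic penalty kappa * x2**2
print(harmonic_5d((-0.558, 1.0, 1.442, 0.0, 0.0), kappa=2.0))
# The collective-variable version depends only on the two radii
print(collective_5d((0.623, 0.0, 0.028, 0.0, 0.0)))
print(collective_5d((0.0, 0.623, 0.0, 0.028, 0.0)))
```

<p>The last two calls return the same energy because only the combinations $\tilde{x}$ and $\tilde{y}$ enter the potential, so each stationary point of the 2D surface becomes a continuous family of points in 5D.</p>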
<h3 id="value-for-algorithm-development">Value for Algorithm Development</h3>
<p>This extensibility makes the Müller-Brown potential ideal for systematic testing:</p>
<ul>
<li><strong>Progressive complexity</strong>: Debug on 2D, then scale to higher dimensions</li>
<li><strong>Ground truth preservation</strong>: Known minima and saddle points remain in the active subspace</li>
<li><strong>Realistic challenges</strong>: Captures the &ldquo;needle in a haystack&rdquo; problem of transition state finding while maintaining analytical tractability</li>
</ul>
<p>These extensions transform a simple 2D benchmark into a scalable testbed for modern computational methods, probing specific challenges in high-dimensional optimization.</p>
<h2 id="implementation-in-pytorch">Implementation in PyTorch</h2>
<p>Now let&rsquo;s implement the Müller-Brown potential in PyTorch. A practical implementation needs to handle batch processing, support both analytical and automatic differentiation, and be optimized for performance.</p>
<h3 id="core-implementation">Core Implementation</h3>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> torch
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> torch.nn <span style="color:#66d9ef">as</span> nn
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> torch <span style="color:#f92672">import</span> Tensor
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">class</span> <span style="color:#a6e22e">MuellerBrownPotential</span>(nn<span style="color:#f92672">.</span>Module):
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;Müller-Brown potential with a torch.compile-accelerated force kernel.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">__init__</span>(
</span></span><span style="display:flex;"><span>        self,
</span></span><span style="display:flex;"><span>        device: str <span style="color:#f92672">|</span> torch<span style="color:#f92672">.</span>device <span style="color:#f92672">=</span> <span style="color:#e6db74">&#34;cpu&#34;</span>,
</span></span><span style="display:flex;"><span>        dtype: torch<span style="color:#f92672">.</span>dtype <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>float64,
</span></span><span style="display:flex;"><span>        use_autograd: bool <span style="color:#f92672">=</span> <span style="color:#66d9ef">False</span>
</span></span><span style="display:flex;"><span>    ):
</span></span><span style="display:flex;"><span>        super()<span style="color:#f92672">.</span><span style="color:#a6e22e">__init__</span>()
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>use_autograd <span style="color:#f92672">=</span> use_autograd
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Standard Müller-Brown parameters</span>
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>register_buffer(
</span></span><span style="display:flex;"><span>            <span style="color:#e6db74">&#34;A&#34;</span>, torch<span style="color:#f92672">.</span>tensor([<span style="color:#f92672">-</span><span style="color:#ae81ff">200.0</span>, <span style="color:#f92672">-</span><span style="color:#ae81ff">100.0</span>, <span style="color:#f92672">-</span><span style="color:#ae81ff">170.0</span>, <span style="color:#ae81ff">15.0</span>],
</span></span><span style="display:flex;"><span>                             device<span style="color:#f92672">=</span>device, dtype<span style="color:#f92672">=</span>dtype)
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>register_buffer(
</span></span><span style="display:flex;"><span>            <span style="color:#e6db74">&#34;a&#34;</span>, torch<span style="color:#f92672">.</span>tensor([<span style="color:#f92672">-</span><span style="color:#ae81ff">1.0</span>, <span style="color:#f92672">-</span><span style="color:#ae81ff">1.0</span>, <span style="color:#f92672">-</span><span style="color:#ae81ff">6.5</span>, <span style="color:#ae81ff">0.7</span>],
</span></span><span style="display:flex;"><span>                             device<span style="color:#f92672">=</span>device, dtype<span style="color:#f92672">=</span>dtype)
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>register_buffer(
</span></span><span style="display:flex;"><span>            <span style="color:#e6db74">&#34;b&#34;</span>, torch<span style="color:#f92672">.</span>tensor([<span style="color:#ae81ff">0.0</span>, <span style="color:#ae81ff">0.0</span>, <span style="color:#ae81ff">11.0</span>, <span style="color:#ae81ff">0.6</span>],
</span></span><span style="display:flex;"><span>                             device<span style="color:#f92672">=</span>device, dtype<span style="color:#f92672">=</span>dtype)
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>register_buffer(
</span></span><span style="display:flex;"><span>            <span style="color:#e6db74">&#34;c&#34;</span>, torch<span style="color:#f92672">.</span>tensor([<span style="color:#f92672">-</span><span style="color:#ae81ff">10.0</span>, <span style="color:#f92672">-</span><span style="color:#ae81ff">10.0</span>, <span style="color:#f92672">-</span><span style="color:#ae81ff">6.5</span>, <span style="color:#ae81ff">0.7</span>],
</span></span><span style="display:flex;"><span>                             device<span style="color:#f92672">=</span>device, dtype<span style="color:#f92672">=</span>dtype)
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>register_buffer(
</span></span><span style="display:flex;"><span>            <span style="color:#e6db74">&#34;x_centers&#34;</span>, torch<span style="color:#f92672">.</span>tensor([<span style="color:#ae81ff">1.0</span>, <span style="color:#ae81ff">0.0</span>, <span style="color:#f92672">-</span><span style="color:#ae81ff">0.5</span>, <span style="color:#f92672">-</span><span style="color:#ae81ff">1.0</span>],
</span></span><span style="display:flex;"><span>                                    device<span style="color:#f92672">=</span>device, dtype<span style="color:#f92672">=</span>dtype)
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>register_buffer(
</span></span><span style="display:flex;"><span>            <span style="color:#e6db74">&#34;y_centers&#34;</span>, torch<span style="color:#f92672">.</span>tensor([<span style="color:#ae81ff">0.0</span>, <span style="color:#ae81ff">0.5</span>, <span style="color:#ae81ff">1.5</span>, <span style="color:#ae81ff">1.0</span>],
</span></span><span style="display:flex;"><span>                                    device<span style="color:#f92672">=</span>device, dtype<span style="color:#f92672">=</span>dtype)
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">forward</span>(self, coordinates: Tensor) <span style="color:#f92672">-&gt;</span> Tensor:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Compute potential energy.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> _calculate_potential(
</span></span><span style="display:flex;"><span>            coordinates, self<span style="color:#f92672">.</span>A, self<span style="color:#f92672">.</span>a, self<span style="color:#f92672">.</span>b, self<span style="color:#f92672">.</span>c,
</span></span><span style="display:flex;"><span>            self<span style="color:#f92672">.</span>x_centers, self<span style="color:#f92672">.</span>y_centers
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">force</span>(self, coordinates: Tensor) <span style="color:#f92672">-&gt;</span> Tensor:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Compute forces (negative gradient).&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">if</span> self<span style="color:#f92672">.</span>use_autograd:
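</span></span><span style="display:flex;"><span>            <span style="color:#75715e"># Detach so the input tensor is not mutated in place</span>
</span></span><span style="display:flex;"><span>            coordinates <span style="color:#f92672">=</span> coordinates<span style="color:#f92672">.</span>detach()<span style="color:#f92672">.</span>requires_grad_(<span style="color:#66d9ef">True</span>)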
</span></span><span style="display:flex;"><span>            potential <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>forward(coordinates)
</span></span><span style="display:flex;"><span>            grad <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>autograd<span style="color:#f92672">.</span>grad(potential<span style="color:#f92672">.</span>sum(), coordinates)[<span style="color:#ae81ff">0</span>]
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">return</span> <span style="color:#f92672">-</span>grad
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">else</span>:
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">return</span> _calculate_force(
</span></span><span style="display:flex;"><span>                coordinates, self<span style="color:#f92672">.</span>A, self<span style="color:#f92672">.</span>a, self<span style="color:#f92672">.</span>b, self<span style="color:#f92672">.</span>c,
</span></span><span style="display:flex;"><span>                self<span style="color:#f92672">.</span>x_centers, self<span style="color:#f92672">.</span>y_centers
</span></span><span style="display:flex;"><span>            )
</span></span></code></pre></div><p>The implementation stores the parameters with <code>register_buffer</code>, a subtle but important PyTorch best practice. Buffers move with the module when you call <code>.to(device)</code>, which avoids the device mismatch errors that plain tensor attributes commonly cause. Buffers are also included in the module&rsquo;s <code>state_dict</code> and, by default, are broadcast from rank 0 under <strong>DistributedDataParallel (DDP)</strong>, so every replica sees the same values when training scales to multiple GPUs. The <code>use_autograd</code> flag switches between analytical and automatic differentiation.</p>
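<p>The difference between a buffer and a plain tensor attribute is easy to see in isolation. The sketch below uses a hypothetical <code>TinyPotential</code> module and a dtype conversion in place of a device move, so it runs without a GPU; <code>.to()</code> treats devices and dtypes the same way for this purpose.</p>

```python
import torch
import torch.nn as nn


class TinyPotential(nn.Module):
    """Hypothetical module contrasting a buffer with a plain attribute."""

    def __init__(self):
        super().__init__()
        # Tracked by the module: follows .to() and appears in state_dict
        self.register_buffer(
            "A",
            torch.tensor([-200.0, -100.0, -170.0, 15.0], dtype=torch.float64),
        )
        # Plain attribute: invisible to .to() and to state_dict
        self.plain = torch.tensor([1.0], dtype=torch.float64)


module = TinyPotential().to(torch.float32)

print(module.A.dtype)                     # buffer was converted
print(module.plain.dtype)                 # plain attribute was left behind
print(list(module.state_dict().keys()))   # only the buffer is saved
print(len(list(module.parameters())))     # buffers are not trainable parameters
```

<p>Because buffers are not returned by <code>parameters()</code>, an optimizer built from the module never updates the potential&rsquo;s constants, which is the desired behavior for a fixed benchmark surface.</p>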
<h3 id="compiling-the-force-kernel">Compiling the Force Kernel</h3>
<p>Two decisions make the force path fast while keeping the rest flexible:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">_calculate_potential</span>(coordinates: Tensor, A: Tensor, a: Tensor,
</span></span><span style="display:flex;"><span>                        b: Tensor, c: Tensor, x_centers: Tensor,
</span></span><span style="display:flex;"><span>                        y_centers: Tensor) <span style="color:#f92672">-&gt;</span> Tensor:
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;Energy. Left eager (uncompiled) so autograd second derivatives,
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    e.g. the Hessian, keep working: torch.compile does not support
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    double-backward, and energy is computed per save, not per step.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>    coords <span style="color:#f92672">=</span> coordinates<span style="color:#f92672">.</span>view(<span style="color:#f92672">-</span><span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">2</span>)
</span></span><span style="display:flex;"><span>    x, y <span style="color:#f92672">=</span> coords[:, <span style="color:#ae81ff">0</span>], coords[:, <span style="color:#ae81ff">1</span>]
</span></span><span style="display:flex;"><span>    dx <span style="color:#f92672">=</span> x<span style="color:#f92672">.</span>unsqueeze(<span style="color:#f92672">-</span><span style="color:#ae81ff">1</span>) <span style="color:#f92672">-</span> x_centers
</span></span><span style="display:flex;"><span>    dy <span style="color:#f92672">=</span> y<span style="color:#f92672">.</span>unsqueeze(<span style="color:#f92672">-</span><span style="color:#ae81ff">1</span>) <span style="color:#f92672">-</span> y_centers
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    potential <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>sum(
</span></span><span style="display:flex;"><span>        A <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>exp(a <span style="color:#f92672">*</span> dx<span style="color:#f92672">**</span><span style="color:#ae81ff">2</span> <span style="color:#f92672">+</span> b <span style="color:#f92672">*</span> dx <span style="color:#f92672">*</span> dy <span style="color:#f92672">+</span> c <span style="color:#f92672">*</span> dy<span style="color:#f92672">**</span><span style="color:#ae81ff">2</span>),
</span></span><span style="display:flex;"><span>        dim<span style="color:#f92672">=-</span><span style="color:#ae81ff">1</span>
</span></span><span style="display:flex;"><span>    )
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> potential<span style="color:#f92672">.</span>view(coordinates<span style="color:#f92672">.</span>shape[:<span style="color:#f92672">-</span><span style="color:#ae81ff">1</span>])
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#a6e22e">@torch.compile</span>(dynamic<span style="color:#f92672">=</span><span style="color:#66d9ef">True</span>)
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">_calculate_force</span>(coordinates: Tensor, A: Tensor, a: Tensor,
</span></span><span style="display:flex;"><span>                    b: Tensor, c: Tensor, x_centers: Tensor,
</span></span><span style="display:flex;"><span>                    y_centers: Tensor) <span style="color:#f92672">-&gt;</span> Tensor:
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;Forces (negative gradient). Compiled: this is the hot path,
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    called once per simulation step.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>    coords <span style="color:#f92672">=</span> coordinates<span style="color:#f92672">.</span>view(<span style="color:#f92672">-</span><span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">2</span>)
</span></span><span style="display:flex;"><span>    x, y <span style="color:#f92672">=</span> coords[:, <span style="color:#ae81ff">0</span>], coords[:, <span style="color:#ae81ff">1</span>]
</span></span><span style="display:flex;"><span>    dx <span style="color:#f92672">=</span> x<span style="color:#f92672">.</span>unsqueeze(<span style="color:#f92672">-</span><span style="color:#ae81ff">1</span>) <span style="color:#f92672">-</span> x_centers
</span></span><span style="display:flex;"><span>    dy <span style="color:#f92672">=</span> y<span style="color:#f92672">.</span>unsqueeze(<span style="color:#f92672">-</span><span style="color:#ae81ff">1</span>) <span style="color:#f92672">-</span> y_centers
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    exp_terms <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>exp(a <span style="color:#f92672">*</span> dx<span style="color:#f92672">**</span><span style="color:#ae81ff">2</span> <span style="color:#f92672">+</span> b <span style="color:#f92672">*</span> dx <span style="color:#f92672">*</span> dy <span style="color:#f92672">+</span> c <span style="color:#f92672">*</span> dy<span style="color:#f92672">**</span><span style="color:#ae81ff">2</span>)
</span></span><span style="display:flex;"><span>    A_exp <span style="color:#f92672">=</span> A <span style="color:#f92672">*</span> exp_terms
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    grad_x <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>sum(A_exp <span style="color:#f92672">*</span> (<span style="color:#ae81ff">2</span> <span style="color:#f92672">*</span> a <span style="color:#f92672">*</span> dx <span style="color:#f92672">+</span> b <span style="color:#f92672">*</span> dy), dim<span style="color:#f92672">=-</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>    grad_y <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>sum(A_exp <span style="color:#f92672">*</span> (b <span style="color:#f92672">*</span> dx <span style="color:#f92672">+</span> <span style="color:#ae81ff">2</span> <span style="color:#f92672">*</span> c <span style="color:#f92672">*</span> dy), dim<span style="color:#f92672">=-</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    forces <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>stack([<span style="color:#f92672">-</span>grad_x, <span style="color:#f92672">-</span>grad_y], dim<span style="color:#f92672">=-</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>    new_shape <span style="color:#f92672">=</span> list(coordinates<span style="color:#f92672">.</span>shape[:<span style="color:#f92672">-</span><span style="color:#ae81ff">1</span>]) <span style="color:#f92672">+</span> [<span style="color:#ae81ff">2</span>]
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> forces<span style="color:#f92672">.</span>view(new_shape)
</span></span></code></pre></div><p>The force kernel is decorated with <code>@torch.compile(dynamic=True)</code>, which traces it once and hands it to TorchInductor to fuse the pointwise operations (the exponentials and polynomial terms) and cut Python&rsquo;s dispatch overhead in the inner loop that runs millions of times per simulation. The <code>dynamic=True</code> flag keeps a single compiled trace valid across particle counts, so changing the batch size does not trigger a recompile. The energy is left eager on purpose: <code>torch.compile</code> does not support double-backward, so leaving <code>forward</code> uncompiled keeps autograd second derivatives (the Hessian) available, and since the energy is only evaluated when an observable is saved, it is not on the hot path anyway.</p>
<h3 id="performance-analytical-vs-automatic-differentiation">Performance: Analytical vs. Automatic Differentiation</h3>
<p>A key design decision is whether to use analytical derivatives or automatic differentiation. I benchmarked both on an Apple M1 Max (CPU, PyTorch 2.x). To get a stable measurement, each configuration runs 100 warm-up iterations, and I then report the median wall-clock time over 5 runs of 1000 iterations, which filters out operating-system jitter.</p>
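<p>The measurement protocol can be sketched with the standard library alone. This is a minimal illustration, not the repository&rsquo;s benchmarking script: the function name <code>median_time_per_call</code> and its arguments are hypothetical, and it assumes a CPU workload where no device synchronization is needed.</p>

```python
import statistics
import time
from typing import Callable


def median_time_per_call(fn: Callable[[], object],
                         warmup: int = 100,
                         runs: int = 5,
                         iters: int = 1000) -> float:
    """Median seconds per call over `runs` timed batches of `iters` calls.

    Hypothetical helper mirroring the protocol described in the text.
    """
    # Warm-up: lets caches, allocators, and any JIT compilation settle
    # before anything is timed.
    for _ in range(warmup):
        fn()

    per_call = []
    for _ in range(runs):
        start = time.perf_counter()
        for _ in range(iters):
            fn()
        elapsed = time.perf_counter() - start
        per_call.append(elapsed / iters)

    # The median ignores the occasional slow batch caused by OS scheduling.
    return statistics.median(per_call)
```

<p>Each force implementation would be wrapped in a zero-argument callable and passed in as <code>fn</code>; comparing the two medians then gives the speedup at one batch size.</p>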















<figure class="post-figure center ">
    <img src="/img/muller-brown/muller-brown-throughput-analysis.webp"
         alt="Throughput comparison showing the analytical force kernel outperforming autograd across batch sizes"
         title="Throughput comparison showing the analytical force kernel outperforming autograd across batch sizes"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Force-evaluation throughput, analytical vs autograd, across batch sizes. The analytical kernel is about 4x faster (3-7x depending on batch size).</figcaption>
    
</figure>
















<figure class="post-figure center ">
    <img src="/img/muller-brown/muller-brown-time-per-particle.webp"
         alt="Per-particle computation time showing analytical derivatives maintain sub-microsecond performance for large systems"
         title="Per-particle computation time showing analytical derivatives maintain sub-microsecond performance for large systems"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Per-particle computation time. Analytical derivatives maintain sub-microsecond performance for large systems.</figcaption>
    
</figure>

<p>The speedup comes from bypassing the computational graph that PyTorch&rsquo;s autograd engine builds. With a hand-derived analytical gradient, that machinery is skipped entirely. Every autograd-based force evaluation must:</p>
<ol>
<li><strong>Build a tape</strong> of operations during the forward pass</li>
<li><strong>Traverse the graph</strong> backward to compute gradients</li>
<li><strong>Allocate intermediate tensors</strong> for each node</li>
</ol>
<p>For iterative workloads like molecular dynamics (millions of force evaluations per trajectory), eliminating this overhead matters. The analytical kernel computes the forces directly in one compiled, fused pass: no graph, no tape, and no per-node intermediate tensors.</p>
<p><strong>When to use each approach:</strong></p>
<ul>
<li><strong>Analytical</strong>: Best for production molecular dynamics where forces are computed millions of times. The speedup directly reduces wall-clock simulation time.</li>
<li><strong>Autograd</strong>: Better for prototyping, machine learning training loops, or implementing new potentials, where correctness verification is paramount. Its convenience, and gradients that match the energy function by construction, often outweigh the performance cost during development.</li>
</ul>
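<p>A hand-derived gradient is only worth its speed if it is right. One dependency-free way to check it is a central finite-difference comparison, sketched below. The constants are the standard published Müller-Brown parameters, which I assume match the ones used in this post; the function names are illustrative and this is not the repository&rsquo;s test suite.</p>

```python
import math

# Standard Müller-Brown parameters (assumed to match the post's values).
A = (-200.0, -100.0, -170.0, 15.0)
a = (-1.0, -1.0, -6.5, 0.7)
b = (0.0, 0.0, 11.0, 0.6)
c = (-10.0, -10.0, -6.5, 0.7)
X0 = (1.0, 0.0, -0.5, -1.0)
Y0 = (0.0, 0.5, 1.5, 1.0)


def energy(x: float, y: float) -> float:
    """Sum of four anisotropic Gaussian-like terms."""
    total = 0.0
    for k in range(4):
        dx, dy = x - X0[k], y - Y0[k]
        total += A[k] * math.exp(a[k] * dx * dx + b[k] * dx * dy + c[k] * dy * dy)
    return total


def force_analytical(x: float, y: float) -> tuple[float, float]:
    """Negative gradient, using the same expressions as the force kernel."""
    gx = gy = 0.0
    for k in range(4):
        dx, dy = x - X0[k], y - Y0[k]
        e = A[k] * math.exp(a[k] * dx * dx + b[k] * dx * dy + c[k] * dy * dy)
        gx += e * (2.0 * a[k] * dx + b[k] * dy)
        gy += e * (b[k] * dx + 2.0 * c[k] * dy)
    return -gx, -gy


def force_finite_difference(x: float, y: float, h: float = 1e-5) -> tuple[float, float]:
    """Central differences of the energy; slow, but independent of the derivation."""
    fx = -(energy(x + h, y) - energy(x - h, y)) / (2.0 * h)
    fy = -(energy(x, y + h) - energy(x, y - h)) / (2.0 * h)
    return fx, fy
```

<p>Agreement at a handful of points, including points away from the minima, catches sign errors and swapped coefficients in the derivation. In the PyTorch code the same role is played by comparing the analytical kernel against the autograd reference.</p>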
<h3 id="molecular-dynamics-simulations">Molecular Dynamics Simulations</h3>
<p>To demonstrate the PyTorch implementation in action, I performed Langevin dynamics simulations in different energy basins. These simulations reveal how particles behave when confined to different regions of the potential energy surface.</p>
<h4 id="simulation-parameters">Simulation Parameters</h4>
<p>I ran 3600 time steps of 0.01 time units each, with a friction coefficient of 1.0 and a temperature of 25.0 in reduced units. Each simulation started from the minimum of its basin and shows the characteristic thermal fluctuations around that local minimum.</p>
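<p>To make these parameters concrete, here is a pure-Python sketch of one BAOAB Langevin trajectory on the Müller-Brown surface. It is an illustration under stated assumptions, not the repository&rsquo;s simulator: unit mass, $k_B = 1$ (so the temperature enters as <code>kT</code>), zero initial velocity, and the standard published potential parameters are all my assumptions, and the function names are hypothetical.</p>

```python
import math
import random

# Standard Müller-Brown terms as (A, a, b, c, x0, y0); assumed values.
TERMS = (
    (-200.0, -1.0, 0.0, -10.0, 1.0, 0.0),
    (-100.0, -1.0, 0.0, -10.0, 0.0, 0.5),
    (-170.0, -6.5, 11.0, -6.5, -0.5, 1.5),
    (15.0, 0.7, 0.6, 0.7, -1.0, 1.0),
)


def force(x: float, y: float) -> tuple[float, float]:
    """Analytical force (negative gradient) of the potential."""
    fx = fy = 0.0
    for A, a, b, c, x0, y0 in TERMS:
        dx, dy = x - x0, y - y0
        e = A * math.exp(a * dx * dx + b * dx * dy + c * dy * dy)
        fx -= e * (2.0 * a * dx + b * dy)
        fy -= e * (b * dx + 2.0 * c * dy)
    return fx, fy


def baoab(x: float, y: float, n_steps: int = 3600, dt: float = 0.01,
          gamma: float = 1.0, kT: float = 25.0, mass: float = 1.0,
          seed: int = 0) -> list[tuple[float, float]]:
    """Run BAOAB Langevin dynamics and return the list of positions."""
    rng = random.Random(seed)
    c1 = math.exp(-gamma * dt)                   # velocity damping per step
    c2 = math.sqrt((1.0 - c1 * c1) * kT / mass)  # matching noise amplitude
    vx = vy = 0.0
    fx, fy = force(x, y)
    traj = [(x, y)]
    for _ in range(n_steps):
        # B: half-step velocity kick from the current force
        vx += 0.5 * dt * fx / mass
        vy += 0.5 * dt * fy / mass
        # A: half-step position drift
        x += 0.5 * dt * vx
        y += 0.5 * dt * vy
        # O: exact Ornstein-Uhlenbeck update (friction plus thermal noise)
        vx = c1 * vx + c2 * rng.gauss(0.0, 1.0)
        vy = c1 * vy + c2 * rng.gauss(0.0, 1.0)
        # A: second half-step drift
        x += 0.5 * dt * vx
        y += 0.5 * dt * vy
        # B: second half-step kick with the force at the new position
        fx, fy = force(x, y)
        vx += 0.5 * dt * fx / mass
        vy += 0.5 * dt * fy / mass
        traj.append((x, y))
    return traj
```

<p>Calling <code>baoab(-0.558, 1.442)</code> starts a run in basin MA. Setting <code>kT=0.0</code> turns the thermostat into pure damping, which is a handy sanity check: the particle should simply relax to the nearest minimum.</p>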
<h4 id="basin-ma-deep-reactant-minimum">Basin MA: Deep Reactant Minimum</h4>
<p>The deepest energy well (-146.70 in reduced units) shows highly constrained motion due to the steep energy barriers surrounding it.</p>















<figure class="post-figure center ">
    <img src="/img/muller-brown/muller-brown-basin-ma-position-distributions.webp"
         alt="Position distributions in Basin MA showing tight confinement"
         title="Position distributions in Basin MA showing tight confinement"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Position distributions in Basin MA. The particle remains tightly confined around (-0.558, 1.442) due to the deep potential well.</figcaption>
    
</figure>
















<figure class="post-figure center ">
    <img src="/img/muller-brown/muller-brown-basin-ma-time-series.webp"
         alt="Time evolution of coordinates in Basin MA"
         title="Time evolution of coordinates in Basin MA"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Time series showing small-amplitude oscillations around the equilibrium position. The deep well severely restricts thermal motion.</figcaption>
    
</figure>
















<figure class="post-figure center ">
    <img src="/img/muller-brown/muller-brown-basin-ma-trajectory.webp"
         alt="Trajectory overlaid on potential surface for Basin MA"
         title="Trajectory overlaid on potential surface for Basin MA"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Trajectory visualization showing the particle&rsquo;s motion confined to a small region around the minimum. High energy barriers prevent escape on this time scale.</figcaption>
    
</figure>

<div style="position: relative; padding-bottom: 56.25%; height: 0; overflow: hidden;">
      <iframe allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share; fullscreen" loading="eager" referrerpolicy="strict-origin-when-cross-origin" src="https://www.youtube-nocookie.com/embed/woVM90qXUQs?autoplay=0&amp;controls=1&amp;end=0&amp;loop=0&amp;mute=0&amp;start=0" style="position: absolute; top: 0; left: 0; width: 100%; height: 100%; border:0;" title="YouTube video"></iframe>
    </div>

<h4 id="basin-mb-product-minimum">Basin MB: Product Minimum</h4>
<p>The product minimum (-108.17 in reduced units) shows intermediate behavior between the deep reactant well and shallow intermediate basin.</p>















<figure class="post-figure center ">
    <img src="/img/muller-brown/muller-brown-basin-mb-position-distributions.webp"
         alt="Position distributions in Basin MB showing moderate confinement"
         title="Position distributions in Basin MB showing moderate confinement"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Position distributions in Basin MB. The particle shows moderate thermal motion around (0.623, 0.028), with confinement between the deep MA basin and shallow MC basin.</figcaption>
    
</figure>
















<figure class="post-figure center ">
    <img src="/img/muller-brown/muller-brown-basin-mb-time-series.webp"
         alt="Time evolution showing moderate amplitude fluctuations in Basin MB"
         title="Time evolution showing moderate amplitude fluctuations in Basin MB"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Time series demonstrating moderate amplitude fluctuations. The particle explores a region larger than MA but more constrained than the shallow MC basin.</figcaption>
    
</figure>
















<figure class="post-figure center ">
    <img src="/img/muller-brown/muller-brown-basin-mb-trajectory.webp"
         alt="Basin MB trajectory showing balanced exploration"
         title="Basin MB trajectory showing balanced exploration"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The trajectory shows balanced thermal exploration within the product basin. The moderate well depth allows reasonable sampling while maintaining basin confinement.</figcaption>
    
</figure>

<div style="position: relative; padding-bottom: 56.25%; height: 0; overflow: hidden;">
      <iframe allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share; fullscreen" loading="eager" referrerpolicy="strict-origin-when-cross-origin" src="https://www.youtube-nocookie.com/embed/gdAHme07bGs?autoplay=0&amp;controls=1&amp;end=0&amp;loop=0&amp;mute=0&amp;start=0" style="position: absolute; top: 0; left: 0; width: 100%; height: 100%; border:0;" title="YouTube video"></iframe>
    </div>

<h4 id="transition-example">Transition Example</h4>
<p>Running a longer simulation demonstrates transitions between basins:</p>















<figure class="post-figure center ">
    <img src="/img/muller-brown/muller-brown-transition-trajectory.webp"
         alt="Transition trajectory between basins"
         title="Transition trajectory between basins"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The trajectory illustrates the particle&rsquo;s movement between the different basins, highlighting the energy barriers and pathways involved.</figcaption>
    
</figure>
















<figure class="post-figure center ">
    <img src="/img/muller-brown/muller-brown-transition-time-series.webp"
         alt="Time evolution of positions during transition"
         title="Time evolution of positions during transition"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Time series showing the evolution of the particle&rsquo;s position during the transition between basins.</figcaption>
    
</figure>
















<figure class="post-figure center ">
    <img src="/img/muller-brown/muller-brown-transition-position-distributions.webp"
         alt="Transition trajectory on potential surface"
         title="Transition trajectory on potential surface"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The trajectory on the potential surface highlights the energy landscape the particle navigates during the transition.</figcaption>
    
</figure>

<div style="position: relative; padding-bottom: 56.25%; height: 0; overflow: hidden;">
      <iframe allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share; fullscreen" loading="eager" referrerpolicy="strict-origin-when-cross-origin" src="https://www.youtube-nocookie.com/embed/dVFe_4KZbps?autoplay=0&amp;controls=1&amp;end=0&amp;loop=0&amp;mute=0&amp;start=0" style="position: absolute; top: 0; left: 0; width: 100%; height: 100%; border:0;" title="YouTube video"></iframe>
    </div>

<h3 id="integration-with-modern-workflows">Integration with Modern Workflows</h3>
<p>The PyTorch implementation integrates naturally with machine learning workflows. It can serve as a component in larger computational graphs, enable gradient-based optimization, and bridge classical molecular simulation with deep learning approaches.</p>
<p>For readers interested in broader molecular ML applications, this implementation pairs well with other molecular representation methods. The <a href="/posts/molecular-descriptor-coulomb-matrix/">Coulomb matrix approach</a> offers complementary perspectives on encoding molecular structure, while the <a href="/posts/kabsch-algorithm/#the-math">Kabsch algorithm</a> provides essential tools for structural alignment.</p>
<p>The complete implementation is available on <a href="https://github.com/hunter-heidenreich/Muller-Brown-Potential">GitHub</a><sup id="fnref:9"><a href="#fn:9" class="footnote-ref" role="doc-noteref">9</a></sup><sup id="fnref:10"><a href="#fn:10" class="footnote-ref" role="doc-noteref">10</a></sup><sup id="fnref:11"><a href="#fn:11" class="footnote-ref" role="doc-noteref">11</a></sup>, including benchmarking scripts, visualization tools, the test suite, and examples for optimization and molecular dynamics.</p>
<h2 id="conclusion">Conclusion</h2>
<p>The Müller-Brown potential exemplifies how a well-designed benchmark can evolve with a field. Born from 1970s computational constraints, it provided a simple way to test algorithms when quantum chemistry calculations were expensive. Its clever design, simple enough to compute instantly, complex enough to break naive approaches, made it invaluable for algorithm development.</p>
<p>Today, it serves new purposes in the machine learning era. This PyTorch implementation pairs a hand-derived analytical force kernel (compiled with <code>torch.compile</code>) with an autograd reference, a BAOAB Langevin sampler, and a test suite that checks the sampler against the canonical distribution. The analytical kernel&rsquo;s roughly 4x advantage matters for intensive simulations, while PyTorch&rsquo;s flexibility enables integration with neural network potentials and enhanced sampling methods.</p>
<p>The potential&rsquo;s evolution from practical necessity to pedagogical tool to machine learning benchmark demonstrates the value of foundational test cases. As computational chemistry continues evolving, reliable standards like the Müller-Brown potential become even more important for rigorous method development and comparison.</p>
<p>For the complete implementation with benchmarking scripts, the BAOAB Langevin simulator, and visualization tools, see the <a href="https://github.com/hunter-heidenreich/Muller-Brown-Potential">GitHub repository</a>. The full project with architecture details and performance results is at the <a href="/projects/muller-brown-pytorch/">Müller-Brown Potential: A PyTorch ML Testbed project page</a>.</p>
<div class="footnotes" role="doc-endnotes">
<hr>
<ol>
<li id="fn:1">
<p>Müller, K., &amp; Brown, L. D. (1979). Location of saddle points and minimum energy paths by a constrained simplex optimization procedure. <em>Theoretica Chimica Acta</em>, 53, 75-93. <a href="https://link.springer.com/article/10.1007/BF00547608">https://link.springer.com/article/10.1007/BF00547608</a>&#160;<a href="#fnref:1" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:2">
<p>Lennard-Jones, J. E. (1931). Cohesion. <em>Proceedings of the Physical Society</em>, 43(5), 461-482. <a href="https://doi.org/10.1088/0959-5309/43/5/301">https://doi.org/10.1088/0959-5309/43/5/301</a>&#160;<a href="#fnref:2" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:3">
<p>Henkelman, G., &amp; Jónsson, H. (2000). Improved tangent estimate in the nudged elastic band method for finding minimum energy paths and saddle points. <em>Journal of Chemical Physics</em>, 113(22), 9901-9904. <a href="https://doi.org/10.1063/1.1329672">https://doi.org/10.1063/1.1329672</a>&#160;<a href="#fnref:3" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:4">
<p>Dellago, C., Bolhuis, P. G., Csajka, F. S., &amp; Chandler, D. (1998). Transition path sampling and the calculation of rate constants. <em>Journal of Chemical Physics</em>, 108(5), 1964-1977. <a href="https://doi.org/10.1063/1.475562">https://doi.org/10.1063/1.475562</a>&#160;<a href="#fnref:4" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:5">
<p>Behler, J., &amp; Parrinello, M. (2007). Generalized neural-network representation of high-dimensional potential-energy surfaces. <em>Physical Review Letters</em>, 98(14), 146401. <a href="https://doi.org/10.1103/PhysRevLett.98.146401">https://doi.org/10.1103/PhysRevLett.98.146401</a>&#160;<a href="#fnref:5" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:6">
<p>Smith, J. S., Isayev, O., &amp; Roitberg, A. E. (2017). ANI-1: an extensible neural network potential with DFT accuracy at force field computational cost. <em>Chemical Science</em>, 8(4), 3192-3203. <a href="https://doi.org/10.1039/C6SC05720A">https://doi.org/10.1039/C6SC05720A</a>&#160;<a href="#fnref:6" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:7">
<p>Sipka, M., Dietschreit, J. C. B., Grajciar, L., &amp; Gómez-Bombarelli, R. (2023). Differentiable simulations for enhanced sampling of rare events. In <em>International Conference on Machine Learning</em> (pp. 31990-32007). PMLR.&#160;<a href="#fnref:7" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:8">
<p>Sun, L., Vandermause, J., Batzner, S., Xie, Y., Clark, D., Chen, W., &amp; Kozinsky, B. (2022). Multitask machine learning of collective variables for enhanced sampling of rare events. <em>Journal of Chemical Theory and Computation</em>, 18(4), 2341-2353.&#160;<a href="#fnref:8" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:9">
<p>LED-Molecular Repository - Original implementation of the Müller-Brown potential. <a href="https://github.com/cselab/LED-Molecular">https://github.com/cselab/LED-Molecular</a>&#160;<a href="#fnref:9" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:10">
<p>Vlachas, P. R., Zavadlav, J., Praprotnik, M., &amp; Koumoutsakos, P. (2022). Accelerated simulations of molecular systems through learning of effective dynamics. <em>Journal of Chemical Theory and Computation</em>, 18(1), 538-549. <a href="https://pubs.acs.org/doi/10.1021/acs.jctc.1c00809">https://pubs.acs.org/doi/10.1021/acs.jctc.1c00809</a>&#160;<a href="#fnref:10" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:11">
<p>Vlachas, P. R., Arampatzis, G., Uhler, C., &amp; Koumoutsakos, P. (2022). Multiscale simulations of complex systems by learning their effective dynamics. <em>Nature Machine Intelligence</em>. <a href="https://www.nature.com/articles/s42256-022-00464-w">https://www.nature.com/articles/s42256-022-00464-w</a>&#160;<a href="#fnref:11" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
</ol>
</div>
]]></content:encoded></item><item><title>Modernizing Rahman's 1964 Argon Simulation</title><link>https://hunterheidenreich.com/posts/rahman-1964-lammps-liquid-argon/</link><pubDate>Sat, 23 Aug 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/rahman-1964-lammps-liquid-argon/</guid><description>How I used modern software engineering (caching, vectorization, and dependency locking) to reproduce a 60-year-old physics milestone.</description><content:encoded><![CDATA[<p>Some papers invent entire fields. Aneesur Rahman&rsquo;s 1964 paper, <strong>&ldquo;Correlations in the Motion of Atoms in Liquid Argon&rdquo;</strong>, is the &ldquo;Hello World&rdquo; of molecular dynamics (MD). Using a computer with less memory than a modern microwave, Rahman solved Newton&rsquo;s equations for 864 atoms and proved that liquids have distinct, quantifiable structure.</p>
<p>The physics of liquid argon is a solved problem. We know the answer.</p>
<p>So, why replicate it in 2025? <strong>To apply modern engineering standards to legacy science.</strong></p>
<p>This project served as an exercise in <strong>software archaeology</strong>: taking a vintage scientific workflow and rebuilding it with a modular Python analysis pipeline. I wanted to see if I could replace Rahman&rsquo;s &ldquo;write-once&rdquo; Fortran mentality with modern reproducibility, type safety, and intelligent caching.</p>
<p>The full source code is available on <a href="https://github.com/hunter-heidenreich/argon-simulation">GitHub</a>. The complete project overview, including analysis results and pipeline architecture, is on the <a href="/projects/rahman-1964-replication/">Rahman 1964 Replication project page</a>.</p>
<div style="position: relative; padding-bottom: 56.25%; height: 0; overflow: hidden;">
      <iframe allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share; fullscreen" loading="eager" referrerpolicy="strict-origin-when-cross-origin" src="https://www.youtube-nocookie.com/embed/KjFixUt6bnQ?autoplay=0&amp;controls=1&amp;end=0&amp;loop=0&amp;mute=0&amp;start=0" style="position: absolute; top: 0; left: 0; width: 100%; height: 100%; border:0;" title="YouTube video"></iframe>
    </div>

<hr>
<h2 id="engineering-the-pipeline">Engineering the Pipeline</h2>
<p>The most interesting part of this project isn&rsquo;t the simulation engine (LAMMPS handles that); it&rsquo;s the architecture of the analysis suite. Pairwise analyses such as the RDF scale as $O(N^2)$ in the number of atoms, and iterating on plots can be painfully slow if you re-compute trajectory data every time.</p>
<p>Why bother? Don&rsquo;t modern MD packages come with analysis tools?
Well, some say that writing is thinking.
By getting into the weeds of how an algorithm works or how an analysis is performed, you sometimes gain insights and a deeper understanding that a plug-and-play tool would obscure.</p>
<h3 id="intelligent-caching">Intelligent Caching</h3>
<p>I built the <code>argon_sim</code> package with a decorator-based caching layer. The system hashes the source file&rsquo;s modification time and the function&rsquo;s arguments to avoid re-calculating the Radial Distribution Function (RDF) or Van Hove correlations on every script run.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#a6e22e">@cached_computation</span>(<span style="color:#e6db74">&#34;gr&#34;</span>)
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">compute_radial_distribution</span>(filename: str, dr: float <span style="color:#f92672">=</span> <span style="color:#ae81ff">0.05</span>):
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># ... expensive O(N^2) distance calculations ...</span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> r_values, g_r, density
</span></span></code></pre></div><p>If I tweak a plot axis, the script runs instantly, loading pre-computed arrays from disk instead of re-running the $O(N^2)$ computation. If I change the simulation trajectory, the cache invalidates automatically.</p>
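<p>A decorator with this behavior can be sketched in a few lines. What follows is my own minimal reconstruction from the description above, not the actual <code>argon_sim</code> code: the <code>cache_dir</code> parameter, the pickle storage format, and the key layout are all assumptions.</p>

```python
import functools
import hashlib
import os
import pickle
from pathlib import Path


def cached_computation(name: str, cache_dir: str = ".cache"):
    """Cache a function's result on disk, keyed by the input file's
    modification time and the call arguments (hypothetical sketch)."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(filename, *args, **kwargs):
            # A changed trajectory file gets a new mtime, hence a new key.
            mtime = os.stat(filename).st_mtime_ns
            key_src = repr((str(filename), mtime, args, sorted(kwargs.items())))
            digest = hashlib.sha256(key_src.encode()).hexdigest()[:16]
            path = Path(cache_dir) / f"{name}_{digest}.pkl"

            if path.exists():
                # Cache hit: skip the expensive computation entirely.
                with path.open("rb") as fh:
                    return pickle.load(fh)

            result = func(filename, *args, **kwargs)
            path.parent.mkdir(parents=True, exist_ok=True)
            with path.open("wb") as fh:
                pickle.dump(result, fh)
            return result
        return wrapper
    return decorator
```

<p>Because the modification time is part of the key, stale entries are never read again; they just sit on disk until cleaned up. Pickle is convenient for NumPy arrays and tuples, but it should only be used for cache files you wrote yourself.</p>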
<h3 id="vectorization--memory-management">Vectorization &amp; Memory Management</h3>
<p>Rahman likely relied on nested loops. Pure Python is too slow for that, so I used <strong>NumPy broadcasting</strong> to vectorize the calculation of atomic displacements.</p>
<p>However, calculating an $864 \times 864$ distance matrix for each of roughly 5,000 frames consumes significant RAM if everything is held at once. I implemented a <strong>chunked MSD (Mean Square Displacement) algorithm</strong> that processes the trajectory in blocks. The chunking trades some vectorization speed for a bounded memory footprint, so the analysis is not limited by having to hold the full set of arrays in RAM.</p>
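<p>The chunking idea can be sketched as follows. This is an illustrative version, not the code in <code>argon_sim</code>: the function name, the <code>chunk</code> parameter, and the assumption that the trajectory is an unwrapped array of shape <code>(n_frames, n_atoms, 3)</code> are mine.</p>

```python
import numpy as np


def msd_chunked(traj: np.ndarray, max_lag: int, chunk: int = 256) -> np.ndarray:
    """Mean square displacement for lags 0..max_lag.

    traj: unwrapped positions, shape (n_frames, n_atoms, 3).
    Time origins are processed `chunk` frames at a time, so the largest
    temporary array has shape (chunk, n_atoms, 3) regardless of run length.
    """
    n_frames, n_atoms = traj.shape[0], traj.shape[1]
    msd = np.zeros(max_lag + 1)
    for lag in range(1, max_lag + 1):
        n_origins = n_frames - lag
        total = 0.0
        for start in range(0, n_origins, chunk):
            stop = min(start + chunk, n_origins)
            # Vectorized over all origins in the block and all atoms at once.
            disp = traj[start + lag:stop + lag] - traj[start:stop]
            total += float(np.sum(disp * disp))
        msd[lag] = total / (n_origins * n_atoms)
    return msd
```

<p>A larger <code>chunk</code> means fewer Python-level iterations and more memory per step; a smaller one does the opposite. The result does not depend on the chunk size, which makes the function easy to test.</p>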
<h3 id="reproducibility-as-a-feature">Reproducibility as a Feature</h3>
<p>Academic code is notorious for &ldquo;it works on my machine.&rdquo; To combat this, I used <strong><code>uv</code></strong> for dependency management, locking the exact environment state. The entire workflow (from simulation to final figure generation) is abstracted into a <code>Makefile</code>.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bash" data-lang="bash"><span style="display:flex;"><span><span style="color:#75715e"># One command to run the physics, analyze data, and generate plots</span>
</span></span><span style="display:flex;"><span>make workflow
</span></span></code></pre></div><hr>
<h2 id="the-simulation-1964-vs-2025">The Simulation: 1964 vs. 2025</h2>
<p>I preserved Rahman&rsquo;s physical parameters exactly to ensure a fair comparison:</p>
<ul>
<li><strong>System</strong>: 864 Argon atoms</li>
<li><strong>Potential</strong>: Lennard-Jones ($\sigma = 3.4$ Å, $\epsilon/k_B = 120$ K)</li>
<li><strong>Target</strong>: 94.4 K, 1.374 g/cm³</li>
</ul>
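<p>For readers who think in Lennard-Jones reduced units, the state point converts with a few lines of arithmetic. The sketch below assumes the standard atomic weight of argon (39.948 g/mol) and the modern Avogadro constant, so the box length it prints may differ slightly from a value computed with other constants.</p>

```python
N_ATOMS = 864
SIGMA_ANGSTROM = 3.4         # Lennard-Jones sigma
EPS_OVER_KB = 120.0          # Lennard-Jones epsilon / k_B, in kelvin
TEMPERATURE_K = 94.4
DENSITY_G_CM3 = 1.374
ARGON_MOLAR_MASS = 39.948    # g/mol (assumed standard atomic weight)
AVOGADRO = 6.02214076e23     # 1/mol

# Atoms per cubic angstrom (1 cm^3 = 1e24 cubic angstroms).
number_density = DENSITY_G_CM3 / ARGON_MOLAR_MASS * AVOGADRO / 1e24

# Edge of the cubic box that holds 864 atoms at this density.
box_length = (N_ATOMS / number_density) ** (1.0 / 3.0)

t_star = TEMPERATURE_K / EPS_OVER_KB             # reduced temperature
rho_star = number_density * SIGMA_ANGSTROM ** 3  # reduced density

print(f"T* = {t_star:.3f}, rho* = {rho_star:.3f}, L = {box_length:.2f} angstrom")
```

<p>With these constants the state point comes out at roughly $T^* \approx 0.79$ and $\rho^* \approx 0.81$, a dense liquid not far above the triple point, in a box about 34.7 Å on a side.</p>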
<p>However, I modernized the <em>numerical</em> methods to ensure stability:</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Feature</th>
          <th style="text-align: left">Rahman (1964)</th>
          <th style="text-align: left">This Work (2025)</th>
          <th style="text-align: left">Why it Matters</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>Integration</strong></td>
          <td style="text-align: left">Predictor-Corrector</td>
          <td style="text-align: left">Velocity Verlet</td>
          <td style="text-align: left">Better energy conservation over long runs</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Timestep</strong></td>
          <td style="text-align: left">10 fs</td>
          <td style="text-align: left">2 fs</td>
          <td style="text-align: left">Rahman&rsquo;s step was aggressive; 2 fs ensures numerical stability</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Equilibration</strong></td>
          <td style="text-align: left">Velocity Scaling</td>
          <td style="text-align: left">1 ns NVT</td>
          <td style="text-align: left">Rahman couldn&rsquo;t afford long equilibrations; I melted the crystal properly to remove bias</td>
      </tr>
  </tbody>
</table>
<p>The production run lasted 10 ps in the NVE ensemble, generating 5,001 frames. Temperature remained within 1% of target with an RMS fluctuation of 0.0165.</p>
<hr>
<h2 id="validation-results">Validation Results</h2>
<p>The replication was quantitatively successful. The analysis pipeline faithfully reproduced every key signature of liquid argon.</p>
<h3 id="the-cage-effect">The Cage Effect</h3>
<p>This is the paper&rsquo;s crown jewel. In a gas, velocity correlations decay exponentially. In a liquid, Rahman discovered that atoms get trapped by their neighbors and bounce back, causing the correlation to go <em>negative</em>.</p>















<figure class="post-figure center ">
    <img src="/img/rahman-1964-argon-molecular-dynamics/rahman-argon-velocity-autocorrelation.webp"
         alt="Velocity Autocorrelation Function"
         title="Velocity Autocorrelation Function"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The VACF dips below zero at 0.3 ps. This &lsquo;negative correlation&rsquo; is the signature of the cage effect: atoms rattling against their neighbors.</figcaption>
    
</figure>

<p>My simulation reproduces this dip: the normalized VACF reaches a minimum of $-0.083$, matching Rahman&rsquo;s observation. The Fourier transform of the VACF (the frequency spectrum) shows a peak at $\beta \approx 0.25$, which corresponds physically to the frequency of atomic collisions within the cage.</p>
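<p>For readers who want to see the quantity behind this plot, the sketch below computes a normalized velocity autocorrelation function, $\langle \mathbf{v}(0) \cdot \mathbf{v}(t) \rangle / \langle \mathbf{v}(0) \cdot \mathbf{v}(0) \rangle$, by averaging over atoms and time origins. It is a minimal pure-Python illustration, not the analysis pipeline used for the figures; the function name <code>vacf</code>, the nested-list trajectory layout, and the toy data are all assumptions made for the example.</p>

```python
def vacf(velocities, max_lag):
    """Normalized velocity autocorrelation for lags 0..max_lag.

    velocities[t][i] is the (vx, vy, vz) tuple of atom i at frame t
    (a hypothetical layout chosen for this sketch).
    """
    n_frames = len(velocities)
    n_atoms = len(velocities[0])
    result = []
    for lag in range(max_lag + 1):
        total = 0.0
        count = 0
        # Average v(t0) . v(t0 + lag) over every time origin and atom.
        for t0 in range(n_frames - lag):
            for i in range(n_atoms):
                a = velocities[t0][i]
                b = velocities[t0 + lag][i]
                total += a[0] * b[0] + a[1] * b[1] + a[2] * b[2]
                count += 1
        result.append(total / count)
    # Divide by the lag-0 value so the curve starts at 1.
    return [c / result[0] for c in result]


# Toy trajectory: one atom whose velocity flips sign every frame,
# an exaggerated "bounce" that gives a negative correlation at lag 1.
toy = [[(1.0 if t % 2 == 0 else -1.0, 0.0, 0.0)] for t in range(6)]
curve = vacf(toy, max_lag=2)
```

<p>The toy curve comes out as 1, -1, 1: the sign flip at lag 1 is the same kind of negative correlation the cage effect produces, only exaggerated. A production analysis would vectorize this loop, for example with NumPy.</p>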















<figure class="post-figure center ">
    <img src="/img/rahman-1964-argon-molecular-dynamics/rahman-argon-vacf-frequency-spectrum.webp"
         alt="Frequency spectrum of the VACF showing characteristic peak from atomic caging effects"
         title="Frequency spectrum of the VACF showing characteristic peak from atomic caging effects"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Frequency spectrum of the VACF showing characteristic peak from atomic caging effects</figcaption>
    
</figure>

<h3 id="structural-fingerprints">Structural Fingerprints</h3>
<p>The Radial Distribution Function $g(r)$ and its Fourier transform, the Structure Factor $S(k)$, are the &ldquo;fingerprints&rdquo; of a liquid&rsquo;s structure.</p>















<figure class="post-figure center ">
    <img src="/img/rahman-1964-argon-molecular-dynamics/rahman-argon-radial-distribution-function.webp"
         alt="Radial Distribution Function and Structure Factor"
         title="Radial Distribution Function and Structure Factor"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The sharp first peak (3.82 Å) shows defined nearest neighbors, while the decay shows the lack of long-range order. My calculated peaks match Rahman&rsquo;s to within about 3%.</figcaption>
    
</figure>

<p>The agreement here is striking. My first peak appeared at <strong>3.82 Å</strong> (Rahman: 3.7 Å), a difference of about 0.1 Å, or roughly 3%. The small discrepancy may stem from my longer equilibration, which let the system relax further into the liquid state than Rahman&rsquo;s 1960s hardware permitted.</p>
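<p>A radial distribution function like the one plotted above can be estimated from a single frame by histogramming pair distances and dividing by the count expected for an ideal gas at the same density. The sketch below is a minimal pure-Python version under stated assumptions: a cubic periodic box, the minimum-image convention, and hypothetical names (<code>radial_distribution</code>, <code>box</code>, <code>dr</code>, <code>r_max</code>) that are not taken from the replication code.</p>

```python
import math


def radial_distribution(positions, box, dr, r_max):
    """Histogram-based g(r) for one frame in a cubic periodic box.

    positions: list of (x, y, z); box: edge length. r_max should not
    exceed box / 2 so the minimum-image convention stays valid.
    """
    n = len(positions)
    n_bins = int(r_max / dr)
    counts = [0] * n_bins
    for i in range(n - 1):
        for j in range(i + 1, n):
            d2 = 0.0
            for a, b in zip(positions[i], positions[j]):
                d = a - b
                d -= box * round(d / box)  # minimum-image convention
                d2 += d * d
            k = int(math.sqrt(d2) / dr)
            if k in range(n_bins):
                counts[k] += 2  # each pair contributes to both atoms
    density = n / box ** 3
    g = []
    for k in range(n_bins):
        # Volume of the spherical shell between k*dr and (k+1)*dr.
        shell = 4.0 / 3.0 * math.pi * ((k + 1) ** 3 - k ** 3) * dr ** 3
        g.append(counts[k] / (n * density * shell))
    return g


# Two atoms that are close only across the periodic boundary (1.05 apart).
g = radial_distribution([(0.2, 0.0, 0.0), (9.15, 0.0, 0.0)],
                        box=10.0, dr=0.1, r_max=5.0)
```

<p>In practice $g(r)$ is averaged over many frames, and the $O(N^2)$ pair loop is replaced by a vectorized or neighbor-list implementation.</p>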
<h3 id="diffusion-and-non-gaussian-behavior">Diffusion and Non-Gaussian Behavior</h3>
<p>By calculating the Mean Square Displacement (MSD), I derived a diffusion coefficient of <strong>$D = 2.47 \times 10^{-5}$ cm²/s</strong>, which deviates by less than <strong>2%</strong> from Rahman&rsquo;s reported $2.43 \times 10^{-5}$ cm²/s.</p>
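<p>The diffusion coefficient follows from the Einstein relation, $D = \lim_{t \to \infty} \frac{1}{6} \frac{d}{dt} \langle r^2(t) \rangle$ in three dimensions, so in practice one fits a line to the MSD in its linear (diffusive) regime. The sketch below shows such a fit and the unit conversion from Å²/ps to cm²/s. The function name, the fitting window, and the synthetic MSD values are illustrative assumptions, since the post does not specify how the fit was done.</p>

```python
def diffusion_coefficient(times_ps, msd_A2):
    """Einstein relation in 3D: D = slope / 6 for the linear MSD regime.

    times_ps in picoseconds, msd_A2 in square angstroms; returns cm^2/s.
    """
    n = len(times_ps)
    t_mean = sum(times_ps) / n
    m_mean = sum(msd_A2) / n
    # Ordinary least-squares slope of MSD against time.
    num = sum((t - t_mean) * (m - m_mean) for t, m in zip(times_ps, msd_A2))
    den = sum((t - t_mean) ** 2 for t in times_ps)
    slope = num / den  # A^2 / ps
    return slope / 6.0 * 1e-4  # 1 A^2/ps = 1e-4 cm^2/s


# Synthetic MSD with slope 1.482 A^2/ps, chosen so that D = 2.47e-5 cm^2/s.
times = [2.0 + 0.5 * k for k in range(10)]
msd = [0.8 + 1.482 * t for t in times]
d_sim = diffusion_coefficient(times, msd)

# Relative deviation of the post's value from Rahman's reported value.
deviation = (2.47e-5 - 2.43e-5) / 2.43e-5
```

<p>With these synthetic inputs the fit returns $2.47 \times 10^{-5}$ cm²/s, and <code>deviation</code> evaluates to about 0.016, i.e. roughly 1.6% relative to Rahman&rsquo;s value.</p>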















<figure class="post-figure center ">
    <img src="/img/rahman-1964-argon-molecular-dynamics/rahman-argon-mean-square-displacement.webp"
         alt="Mean Square Displacement vs time showing ballistic to diffusive transition"
         title="Mean Square Displacement vs time showing ballistic to diffusive transition"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Mean Square Displacement vs. time showing ballistic to diffusive transition</figcaption>
    
</figure>

<p>More interestingly, I reproduced the &ldquo;Non-Gaussian&rdquo; parameters. Standard diffusion assumes a Gaussian distribution of displacements. Rahman found (and I confirmed) that liquid atoms deviate from this. They exhibit &ldquo;jump&rdquo; and &ldquo;wait&rdquo; dynamics, a behavior that standard Brownian motion models fail to capture.</p>
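<p>The simplest of these parameters, $\alpha_2(t) = \frac{3 \langle r^4(t) \rangle}{5 \langle r^2(t) \rangle^2} - 1$, vanishes when displacements are Gaussian in three dimensions, so any non-zero value measures the deviation. The sketch below evaluates it for two made-up displacement sets; the function name and the sample data are assumptions for illustration and are not drawn from the simulation.</p>

```python
import random


def alpha_2(displacements):
    """Non-Gaussian parameter for 3D displacement vectors at one lag time."""
    r2 = [dx * dx + dy * dy + dz * dz for dx, dy, dz in displacements]
    mean_r2 = sum(r2) / len(r2)
    mean_r4 = sum(v * v for v in r2) / len(r2)
    # For a 3D Gaussian, mean(r^4) = (5/3) * mean(r^2)^2, so this is 0.
    return 3.0 * mean_r4 / (5.0 * mean_r2 ** 2) - 1.0


rng = random.Random(0)
# Gaussian displacements: alpha_2 should be close to zero.
gaussian = [(rng.gauss(0, 1), rng.gauss(0, 1), rng.gauss(0, 1))
            for _ in range(20000)]
# Every displacement has the same length: strongly non-Gaussian.
fixed_length = [(1.0, 0.0, 0.0), (0.0, -1.0, 0.0), (0.0, 0.0, 1.0)]

a2_gauss = alpha_2(gaussian)
a2_fixed = alpha_2(fixed_length)
```

<p>The fixed-length set gives $\alpha_2 = 3/5 - 1 = -0.4$, while the Gaussian sample stays near zero up to sampling noise. In a trajectory analysis, the same function would be applied to the displacements at each lag time to trace $\alpha_2(t)$.</p>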















<figure class="post-figure center ">
    <img src="/img/rahman-1964-argon-molecular-dynamics/rahman-argon-non-gaussian-parameters.webp"
         alt="Non-Gaussian parameters showing deviation from simple diffusive behavior"
         title="Non-Gaussian parameters showing deviation from simple diffusive behavior"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Evidence that atoms do not follow a simple random walk. The non-zero alpha parameters indicate heterogeneous dynamics.</figcaption>
    
</figure>

<h3 id="advanced-analysis-van-hove-functions">Advanced Analysis: Van Hove Functions</h3>
<p>Rahman also explored advanced properties like the Van Hove correlation function $G(r,t)$, which describes how liquid structure evolves over time.</p>















<figure class="post-figure center ">
    <img src="/img/rahman-1964-argon-molecular-dynamics/rahman-argon-van-hove-correlation.webp"
         alt="Van Hove distinct correlation function G_d(r,t) at two time points"
         title="Van Hove distinct correlation function G_d(r,t) at two time points"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Van Hove distinct correlation function showing how neighbor coordination shells &lsquo;melt&rsquo; as time progresses</figcaption>
    
</figure>

<p>At 1.0 ps, the structure remains well-defined with clear shells. By 2.5 ps, it becomes increasingly diffuse. Rahman compared this evolution to theoretical predictions (the Vineyard approximation) and found that theory predicted overly rapid structural decay. My results confirm this finding.</p>















<figure class="post-figure center ">
    <img src="/img/rahman-1964-argon-molecular-dynamics/rahman-argon-delayed-convolution.webp"
         alt="Delayed convolution approximation testing Rahman&#39;s theoretical improvement"
         title="Delayed convolution approximation testing Rahman&#39;s theoretical improvement"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Testing Rahman&rsquo;s &lsquo;delayed convolution approximation&rsquo; (his proposed improvement over existing theory)</figcaption>
    
</figure>

<hr>
<h2 id="system-validation">System Validation</h2>
<p>Before analyzing the physics, basic sanity checks confirmed that the system was in thermal equilibrium.</p>















<figure class="post-figure center ">
    <img src="/img/rahman-1964-argon-molecular-dynamics/rahman-argon-temperature-stability.webp"
         alt="Temperature vs time plot showing excellent temperature control around 94.4 K target"
         title="Temperature vs time plot showing excellent temperature control around 94.4 K target"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Temperature vs. Time - 5001 frames showing excellent temperature control with mean 94.73 K</figcaption>
    
</figure>

<p>Mean temperature was 94.73 K (0.33 K off target) with a standard deviation of 1.56 K.</p>















<figure class="post-figure center ">
    <img src="/img/rahman-1964-argon-molecular-dynamics/rahman-argon-maxwell-boltzmann-velocity.webp"
         alt="Maxwell-Boltzmann velocity distribution"
         title="Maxwell-Boltzmann velocity distribution"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Maxwell-Boltzmann velocity distribution from 12.9 million velocity components</figcaption>
    
</figure>

<p>The velocity distribution from 12.9 million velocity components produces a clean Maxwell-Boltzmann distribution, as expected for thermal equilibrium. The distribution widths at various heights closely match Rahman&rsquo;s results: 1.77, 2.48, and 3.56 compared to his 1.77, 2.52, and 3.52.</p>
<hr>
<h2 id="conclusion">Conclusion</h2>
<p>Replicating a 60-year-old paper might seem like a solved puzzle, but it teaches a valuable lesson in computational science. Rahman relied on brilliance and raw mathematical intuition because he lacked compute power. Today, pairing modern compute with disciplined software practices makes the same result reproducible and auditable.</p>
<p>Applying modern software engineering (<strong>modular architecture, caching, and automated workflows</strong>) to classical physics reproduces the past and builds a foundation that makes the <em>next</em> discovery easier, faster, and more reliable.</p>
<p>The quantitative agreement is striking: diffusion coefficients within 2%, the first structural peak within about 0.1 Å, and velocity-distribution widths within about 2%. This level of reproducibility, achieved with completely different hardware and software, validates something fundamental: Rahman&rsquo;s physical model was remarkably sound, and his computational methodology was scientifically rigorous despite 1960s constraints.</p>
<p>The cage effect, velocity correlations, and structural evolution are fundamental characteristics of how matter behaves at the atomic scale, as relevant today as they were six decades ago.</p>
]]></content:encoded></item><item><title>Modern PyTorch VAEs: A Detailed Implementation Guide</title><link>https://hunterheidenreich.com/posts/modern-variational-autoencoder-in-pytorch/</link><pubDate>Sun, 03 Mar 2024 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/modern-variational-autoencoder-in-pytorch/</guid><description>Complete PyTorch VAE tutorial: Copy-paste code, ELBO derivation, KL annealing, and stable softplus parameterization.</description><content:encoded><![CDATA[<h2 id="what-is-a-variational-autoencoder">What is a Variational Autoencoder?</h2>
<p>A Variational Autoencoder (VAE) is a type of <strong>generative model</strong>, meaning its primary purpose is to learn the underlying structure of a dataset so it can generate new, similar data.</p>
<p>Whether the data is images, raw audio clips, or 2D graphs of drug-like molecules, a VAE aims to capture the essential features that define the data distribution. Once trained, it should be able to create entirely new samples that resemble the training data without simply copying specific examples.</p>
<p>Introduced by Kingma and Welling in 2013 (<a href="/notes/machine-learning/generative-models/autoencoding-variational-bayes/">Auto-Encoding Variational Bayes</a>, <a href="https://arxiv.org/abs/1312.6114">Paper</a>), VAEs are used for:</p>
<ul>
<li><strong>Generation</strong>: Creating new data (images, music, text).</li>
<li><strong>Dimensionality Reduction</strong>: Compressing data into a much smaller, meaningful representation (a &ldquo;latent space&rdquo;).</li>
<li><strong>Imputation</strong>: Intelligently filling in missing data (e.g., denoising images).</li>
</ul>
<p>Importantly, they aim to provide a structured and continuous latent space, which allows for smooth interpolation between data points and meaningful manipulations of generated samples (for example, optimizing a property by searching the latent space).</p>
<h2 id="tldr-the-complete-pytorch-implementation">TL;DR: The Complete PyTorch Implementation</h2>
<p>For those who just want the code, here is a complete, modern VAE implementation in PyTorch. It features <strong>softplus standard deviation parameterization</strong> for numerical stability and a <strong>custom training step</strong> that handles the ELBO loss correctly.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> torch
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> torch.nn <span style="color:#66d9ef">as</span> nn
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> torch.nn.functional <span style="color:#66d9ef">as</span> F
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> dataclasses <span style="color:#f92672">import</span> dataclass
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#a6e22e">@dataclass</span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">class</span> <span style="color:#a6e22e">VAEOutput</span>:
</span></span><span style="display:flex;"><span>    z: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    mu: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    std: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    x_recon: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    loss: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    loss_recon: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    loss_kl: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">class</span> <span style="color:#a6e22e">VAE</span>(nn<span style="color:#f92672">.</span>Module):
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">__init__</span>(self, input_dim<span style="color:#f92672">=</span><span style="color:#ae81ff">784</span>, hidden_dim<span style="color:#f92672">=</span><span style="color:#ae81ff">512</span>, latent_dim<span style="color:#f92672">=</span><span style="color:#ae81ff">16</span>):
</span></span><span style="display:flex;"><span>        super()<span style="color:#f92672">.</span><span style="color:#a6e22e">__init__</span>()
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>encoder <span style="color:#f92672">=</span> nn<span style="color:#f92672">.</span>Sequential(
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(input_dim, hidden_dim),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Tanh(),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(hidden_dim, hidden_dim),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Tanh()
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>fc_mu <span style="color:#f92672">=</span> nn<span style="color:#f92672">.</span>Linear(hidden_dim, latent_dim)
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>fc_std <span style="color:#f92672">=</span> nn<span style="color:#f92672">.</span>Linear(hidden_dim, latent_dim)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>decoder <span style="color:#f92672">=</span> nn<span style="color:#f92672">.</span>Sequential(
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(latent_dim, hidden_dim),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Tanh(),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(hidden_dim, hidden_dim),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Tanh(),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(hidden_dim, input_dim)
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">encode</span>(self, x):
</span></span><span style="display:flex;"><span>        h <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>encoder(x)
</span></span><span style="display:flex;"><span>        mu <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>fc_mu(h)
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Softplus + epsilon for stable std deviation</span>
</span></span><span style="display:flex;"><span>        std <span style="color:#f92672">=</span> F<span style="color:#f92672">.</span>softplus(self<span style="color:#f92672">.</span>fc_std(h)) <span style="color:#f92672">+</span> <span style="color:#ae81ff">1e-6</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> mu, std
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">reparameterize</span>(self, mu, std):
</span></span><span style="display:flex;"><span>        eps <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>randn_like(std)
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> mu <span style="color:#f92672">+</span> eps <span style="color:#f92672">*</span> std
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">decode</span>(self, z):
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> self<span style="color:#f92672">.</span>decoder(z)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">forward</span>(self, x, kl_weight<span style="color:#f92672">=</span><span style="color:#ae81ff">1.0</span>):
</span></span><span style="display:flex;"><span>        mu, std <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>encode(x)
</span></span><span style="display:flex;"><span>        z <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>reparameterize(mu, std)
</span></span><span style="display:flex;"><span>        x_recon <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>decode(z)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># 1. Reconstruction Loss (Binary Cross Entropy for MNIST)</span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Sum over features, mean over batch</span>
</span></span><span style="display:flex;"><span>        recon_loss <span style="color:#f92672">=</span> F<span style="color:#f92672">.</span>binary_cross_entropy_with_logits(x_recon, x, reduction<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;none&#39;</span>)<span style="color:#f92672">.</span>sum(dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>)<span style="color:#f92672">.</span>mean()
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># 2. KL Divergence</span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Analytic KL for Normal distributions</span>
</span></span><span style="display:flex;"><span>        kl_loss <span style="color:#f92672">=</span> <span style="color:#f92672">-</span><span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>sum(<span style="color:#ae81ff">1</span> <span style="color:#f92672">+</span> torch<span style="color:#f92672">.</span>log(std<span style="color:#f92672">**</span><span style="color:#ae81ff">2</span>) <span style="color:#f92672">-</span> mu<span style="color:#f92672">**</span><span style="color:#ae81ff">2</span> <span style="color:#f92672">-</span> std<span style="color:#f92672">**</span><span style="color:#ae81ff">2</span>, dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>)<span style="color:#f92672">.</span>mean()
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># 3. Total loss: negative ELBO, with the KL term scaled by kl_weight</span>
</span></span><span style="display:flex;"><span>        loss <span style="color:#f92672">=</span> recon_loss <span style="color:#f92672">+</span> (kl_weight <span style="color:#f92672">*</span> kl_loss)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> VAEOutput(z, mu, std, x_recon, loss, recon_loss, kl_loss)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># --- Training Loop Example ---</span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">train_step</span>(model, batch, optimizer, kl_weight<span style="color:#f92672">=</span><span style="color:#ae81ff">1.0</span>):
</span></span><span style="display:flex;"><span>    model<span style="color:#f92672">.</span>train()
</span></span><span style="display:flex;"><span>    optimizer<span style="color:#f92672">.</span>zero_grad()
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Forward pass</span>
</span></span><span style="display:flex;"><span>    output <span style="color:#f92672">=</span> model(batch, kl_weight)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Backward pass</span>
</span></span><span style="display:flex;"><span>    output<span style="color:#f92672">.</span>loss<span style="color:#f92672">.</span>backward()
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Gradient clipping (recommended)</span>
</span></span><span style="display:flex;"><span>    torch<span style="color:#f92672">.</span>nn<span style="color:#f92672">.</span>utils<span style="color:#f92672">.</span>clip_grad_norm_(model<span style="color:#f92672">.</span>parameters(), max_norm<span style="color:#f92672">=</span><span style="color:#ae81ff">1.0</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    optimizer<span style="color:#f92672">.</span>step()
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> output<span style="color:#f92672">.</span>loss<span style="color:#f92672">.</span>item()
</span></span></code></pre></div><h3 id="the-core-idea-learning-to-generate">The Core Idea: Learning to Generate</h3>
<p>The VAE is built on a key assumption: our complex, high-dimensional data (like a $28 \times 28$ pixel image, $\mathbf{x}$) is actually <em>generated</em> by some simpler, low-dimensional, unobserved variable (a &ldquo;latent&rdquo; variable, $\mathbf{z}$).</p>
<blockquote>
<p><strong>A Physical Metaphor: Water Molecules and Phase Diagrams</strong></p>
<p>Consider a glass of water. At the microscopic level, you have more than $10^{24}$ $\text{H}_2\text{O}$ molecules bouncing around in an incredibly high-dimensional space. Each molecule has a position, a velocity, and interactions with its neighbors, far too much detail to track directly. Yet we can describe the <em>macroscopic behavior</em> of all these molecules using just two simple variables: <strong>temperature</strong> and <strong>pressure</strong>. These two dimensions create a &ldquo;phase diagram&rdquo; that tells us whether our water will be ice, liquid, or vapor. The temperature and pressure are &ldquo;latent variables&rdquo; that capture the essential physics governing this complex molecular dance.</p></blockquote>















<figure class="post-figure center ">
    <img src="/img/vae-tut/phase-diagram.webp"
         alt="Water phase diagram showing solid, liquid, and gas phases as functions of temperature and pressure"
         title="Water phase diagram showing solid, liquid, and gas phases as functions of temperature and pressure"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">A water phase diagram: Complex molecular behavior reduced to two simple variables (temperature and pressure). This illustrates how high-dimensional systems can often be understood through low-dimensional latent representations.</figcaption>
    
</figure>

<p>A VAE makes the same assumption: complex data (like images) emerges from simpler underlying factors. A handwritten digit might be generated by latent factors like &ldquo;pen thickness,&rdquo; &ldquo;writing angle,&rdquo; &ldquo;digit style,&rdquo; and &ldquo;size,&rdquo; a much simpler description than tracking all 784 pixel values independently.</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/hypothetical-mnist-factors.webp"
         alt="Hypothetical illustration of MNIST digits generated from latent factors like pen thickness, angle, style, and size"
         title="Hypothetical illustration of MNIST digits generated from latent factors like pen thickness, angle, style, and size"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">A hypothetical illustration showing how MNIST digits could be generated from a few latent factors like pen thickness, writing angle, digit style, and size.</figcaption>
    
</figure>

<p>The VAE learns two functions: one that maps from complex data ($\mathbf{x}$) to these descriptive factors ($\mathbf{z}$), and another that maps from these factors back to the data. It accomplishes this with two main components, typically implemented as neural networks:</p>
<p><strong>1. The Encoder (Recognition Model)</strong></p>
<p>This network takes a complex data point $\mathbf{x}$ (an image) and determines the &ldquo;knob settings&rdquo; $\mathbf{z}$ that could explain or generate it. This allows us to <em>compress</em> or <em>understand</em> the data.</p>
<p>$$q_{\phi}(\mathbf{z} | \mathbf{x})$$</p>
<p>It&rsquo;s like examining a container of molecules and summarizing their complex arrangement into key parameters like temperature and pressure.</p>
<p>Crucially, the encoder outputs the <em>parameters</em> of a probability distribution (a simple Gaussian) that describes $\mathbf{z}$.</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/encoding-diagram.webp"
         alt="Diagram mapping MNIST five to a Gaussian distribution in latent space with mean and standard deviation"
         title="Diagram mapping MNIST five to a Gaussian distribution in latent space with mean and standard deviation"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The Encoder maps an input image (e.g., an MNIST digit &lsquo;5&rsquo;) to a Gaussian distribution in latent space, characterized by a mean vector and a standard deviation vector.</figcaption>
    
</figure>

<p>For each input $\mathbf{x}$, the encoder network outputs:</p>
<ul>
<li>A vector of means, $\mathbf{\mu}$</li>
<li>A vector of standard deviations, $\mathbf{\sigma}$</li>
</ul>
<p>These parameters define our approximation $q_{\phi}(\mathbf{z} | \mathbf{x}) = \mathcal{N}(\mathbf{z} \mid \mathbf{\mu}, \mathbf{\sigma}^2\mathbf{I})$. We then <em>sample</em> from this distribution to get the $\mathbf{z}$ that we feed to the decoder. This probabilistic step is what forces the latent space to be continuous and structured. It forces similar inputs to map to nearby regions in latent space, enabling smooth interpolation and generation.</p>
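<p>The sampling step is usually written as a shift and scale of standard normal noise, $\mathbf{z} = \mathbf{\mu} + \mathbf{\sigma} \odot \mathbf{\epsilon}$ with $\mathbf{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$, which is what the <code>reparameterize</code> method in the implementation above does. The plain-Python sketch below checks that this construction yields samples with the requested mean and standard deviation. The values of <code>mu</code> and <code>sigma</code> are made up, standing in for encoder outputs for a single input.</p>

```python
import random
import statistics


def reparameterize(mu, sigma, rng):
    """Draw z = mu + sigma * eps with eps ~ N(0, 1), one value per dimension."""
    return [m + s * rng.gauss(0.0, 1.0) for m, s in zip(mu, sigma)]


rng = random.Random(0)
mu, sigma = [1.5, -0.5], [0.2, 2.0]  # made-up encoder outputs for one input
samples = [reparameterize(mu, sigma, rng) for _ in range(20000)]

# Empirical statistics per latent dimension.
first_dim = [z[0] for z in samples]
second_dim = [z[1] for z in samples]
mean_first = statistics.mean(first_dim)
std_second = statistics.stdev(second_dim)
```

<p>Because the randomness enters only through $\mathbf{\epsilon}$, the sample is a deterministic function of $\mathbf{\mu}$ and $\mathbf{\sigma}$, which is what lets gradients flow back into the encoder during training.</p>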
<p><strong>2. The Decoder (Generative Model)</strong></p>
<p>This network learns the &ldquo;generative process.&rdquo; It takes a simple latent vector $\mathbf{z}$ and reconstructs the complex data $\mathbf{x}$. This allows us to <em>generate</em> new data by feeding it a random $\mathbf{z}$ and observing what image $\mathbf{x}$ it produces.</p>
<p>$$p_{\theta}(\mathbf{x} | \mathbf{z})$$</p>
<p>The decoder reverses the encoder: it takes the simple latent representation and &ldquo;paints&rdquo; the full, complex image from it. It&rsquo;s like taking temperature and pressure values and producing a detailed arrangement of water molecules consistent with those conditions. The goal is to reproduce the original input as closely as possible.</p>
<p>After training, we have two networks that can be used for a variety of purposes:</p>
<ul>
<li><strong>Generation</strong>: If the latent space is well-structured, we can sample random $\mathbf{z}$ vectors from a simple distribution (like a standard normal) and feed them into the Decoder to generate new images. This is particularly useful for searching for data points with desired properties, like in drug discovery, where we might want to generate molecules with specific characteristics.</li>
<li><strong>Compression</strong>: The Encoder can compress complex data into a low-dimensional latent space, which can be useful for visualization or as a feature extractor for other tasks.</li>
</ul>
<h3 id="the-variational-problem">The &ldquo;Variational&rdquo; Problem</h3>
<p>Calculating the <em>true</em> distribution of latent variables $p_{\theta}(\mathbf{z}|\mathbf{x})$ (the posterior) is mathematically intractable.</p>
<p>This intractability arises from Bayes&rsquo; theorem:</p>
<p>$$p_{\theta}(\mathbf{z} | \mathbf{x}) = \frac{p_{\theta}(\mathbf{x} | \mathbf{z}) p_{\theta}(\mathbf{z})}{p_{\theta}(\mathbf{x})}$$</p>
<p>Breaking down each component:</p>
<ul>
<li>$p_{\theta}(\mathbf{x} | \mathbf{z})$ is our decoder, which is straightforward to compute given our likelihood model.</li>
<li>$p_{\theta}(\mathbf{z})$ is our prior over latent variables, typically a simple distribution like a standard normal, making it easy to compute.</li>
<li>$p_{\theta}(\mathbf{x})$ is the marginal likelihood of the data. And here lies the problem. It requires integrating over all possible latent variables that could have generated $\mathbf{x}$:
$$p_{\theta}(\mathbf{x}) = \int p_{\theta}(\mathbf{x} | \mathbf{z}) p_{\theta}(\mathbf{z}) d\mathbf{z}$$
It is the normalization factor that ensures the posterior is a valid probability distribution (i.e., integrates to 1 over all $\mathbf{z}$).</li>
</ul>
<p>This integral is intractable because the likelihood is a complex, nonlinear function that must be integrated over a high-dimensional latent space. In general, no closed-form solution exists, and the cost of numerical integration on a grid grows exponentially with the latent dimension.</p>
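<p>To make the integral concrete, consider a one-dimensional toy model in which it can actually be evaluated: $z \sim \mathcal{N}(0, 1)$ and $x \mid z \sim \mathcal{N}(z, s^2)$, so that $p(x) = \mathcal{N}(x \mid 0, 1 + s^2)$ exactly. The sketch below estimates the same quantity by averaging $p(x \mid z)$ over samples from the prior. The model, the function names, and the numbers are assumptions chosen for illustration; a neural-network decoder has no such closed form, and this naive estimator typically needs far more samples as the latent dimension grows.</p>

```python
import math
import random


def normal_pdf(x, mean, var):
    """Density of a univariate normal distribution."""
    return math.exp(-0.5 * (x - mean) ** 2 / var) / math.sqrt(2.0 * math.pi * var)


def marginal_mc(x, noise_var, n_samples, rng):
    """Estimate p(x) = E_{z ~ N(0,1)}[p(x | z)] by sampling the prior."""
    total = 0.0
    for _ in range(n_samples):
        z = rng.gauss(0.0, 1.0)               # z ~ p(z)
        total += normal_pdf(x, z, noise_var)  # toy "decoder" with mean z
    return total / n_samples


rng = random.Random(0)
x, noise_var = 0.7, 0.25
estimate = marginal_mc(x, noise_var, 50000, rng)

# Closed form, available only because this toy model is linear-Gaussian.
exact = normal_pdf(x, 0.0, 1.0 + noise_var)
```

<p>In one dimension the Monte Carlo estimate lands close to the exact value of about 0.293. The variational approach avoids this sampling problem by learning where in latent space the mass of $p_{\theta}(\mathbf{z} | \mathbf{x})$ is likely to be.</p>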
<p>This is where the &ldquo;variational&rdquo; approach provides the solution. We approximate the true posterior by learning an encoder, $q_{\phi}(\mathbf{z} | \mathbf{x})$, that serves as a variational approximation to this intractable true distribution. The VAE&rsquo;s training process optimizes this approximation to be as accurate as possible, pushing this learned distribution closer to the true posterior.</p>
<h3 id="the-vae-objective-a-balancing-act">The VAE Objective: A Balancing Act</h3>
<p>To get these two networks (parameterized by $\theta$ and $\phi$) to work together, we train them jointly with a special loss function. This objective has two parts that balance two different goals:</p>
<h4 id="1-reconstruction-loss">1. Reconstruction Loss</h4>
<p>$$E_{q_{\phi}(\mathbf{z} | \mathbf{x})}[\log p_{\theta}(\mathbf{x} | \mathbf{z})]$$</p>
<p>This term asks: &ldquo;How well can we reconstruct our original image?&rdquo; It forces the VAE to be good at its job. The process goes:</p>
<ol>
<li>Take an input point $\mathbf{x}$.</li>
<li>Use the <strong>Encoder</strong> to get its latent representation $\mathbf{z} \sim q_{\phi}(\mathbf{z} | \mathbf{x})$.</li>
<li>Use the <strong>Decoder</strong> to generate a new image $\mathbf{x}&#39;$ from $\mathbf{z}$, $\mathbf{x}&#39; \sim p_{\theta}(\mathbf{x} | \mathbf{z})$.</li>
<li>Compare $\mathbf{x}$ and $\mathbf{x}&#39;$.</li>
</ol>
<p>The reconstruction loss is the negative of this expected log-likelihood: it measures the mismatch between the original and the reconstructed image.</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/reconstruction-loss-graphic.webp"
         alt="Graphic illustrating the reconstruction loss between original and reconstructed images"
         title="Graphic illustrating the reconstruction loss between original and reconstructed images"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The Reconstruction Loss measures how closely the Decoder&rsquo;s output matches the original input image.</figcaption>
    
</figure>

<ul>
<li><strong>For continuous inputs</strong> (like general images), this is often Mean Squared Error (MSE), which corresponds to a Gaussian likelihood with fixed variance.</li>
<li><strong>For inputs in $[0, 1]$</strong> (like MNIST pixel intensities, which after <code>ToTensor()</code> are continuous values in $[0, 1]$), we use Binary Cross-Entropy (BCE). We treat each pixel as an independent Bernoulli variable whose target is its intensity. The decoder outputs the <em>logits</em> for each pixel, and the BCE-with-logits loss (e.g., <code>F.binary_cross_entropy_with_logits</code>) is the numerically stable way to compute the negative log-likelihood.</li>
<li><strong>More generally</strong>, you can output parameters of a desired output distribution. What if you wanted a mixture of Gaussians? The decoder could output the means, variances, and mixture weights, and you could compute the negative log-likelihood accordingly.</li>
</ul>
<p>This loss pushes the encoder to produce useful $\mathbf{z}$ vectors and pushes the decoder to learn how to interpret them accurately.</p>
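<p>As a rough sketch of the Bernoulli case above, the snippet below computes the reconstruction term with <code>F.binary_cross_entropy_with_logits</code>, summing over pixels and averaging over the batch. The function name <code>reconstruction_loss</code>, the choice of reduction, and the random stand-in tensors are assumptions made for illustration; they are not necessarily what the full implementation does.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import torch
import torch.nn.functional as F


def reconstruction_loss(x_logits, x):
    """Bernoulli negative log-likelihood: summed over pixels, averaged over the batch."""
    # reduction="none" keeps one loss value per pixel so we control the reduction.
    per_pixel = F.binary_cross_entropy_with_logits(x_logits, x, reduction="none")
    # Sum over every non-batch dimension (channels, height, width).
    per_example = per_pixel.flatten(start_dim=1).sum(dim=1)
    return per_example.mean()


torch.manual_seed(0)
x = torch.rand(4, 1, 28, 28)          # stand-in for a batch of images in [0, 1]
x_logits = torch.randn(4, 1, 28, 28)  # stand-in for decoder outputs (logits)
loss = reconstruction_loss(x_logits, x)
</code></pre></div>
<p>Summing over pixels before averaging over the batch keeps this term on a per-example scale, which is the same scale on which the KL divergence below is usually computed.</p>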
<h4 id="2-the-kl-divergence-the-regularizer">2. The KL Divergence (The Regularizer)</h4>
<p>$$D_{KL}(q_{\phi}(\mathbf{z} | \mathbf{x}) || p_{\theta}(\mathbf{z}))$$</p>
<p>On its own, the reconstruction loss might &ldquo;cheat.&rdquo; The encoder could learn to map every image to a different, specific point in the latent space, essentially &ldquo;memorizing&rdquo; the data. While this minimizes reconstruction error, it creates a meaningless latent space that fails at generation.</p>
<p>The KL divergence term fixes this. It&rsquo;s a regularizer that forces the latent space to be organized and smooth.</p>
<p>We force the encoder&rsquo;s output, $q_{\phi}(\mathbf{z} | \mathbf{x})$, to be close to a simple, predefined <em>prior distribution</em>, $p_{\theta}(\mathbf{z})$. This prior is almost always a standard normal distribution because it is mathematically convenient, easy to sample from, and encourages a well-behaved latent space.</p>

<figure class="post-figure center ">
    <img src="/img/vae-tut/kl-loss-graphic.webp"
         alt="Graphic illustrating the KL divergence between the encoder&#39;s output distribution and the prior distribution"
         title="Graphic illustrating the KL divergence between the encoder&#39;s output distribution and the prior distribution"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The KL Divergence measures how much the Encoder&rsquo;s output distribution diverges from the simple prior distribution.</figcaption>
    
</figure>

<p>This regularization term acts as a penalty, measuring how much the encoder&rsquo;s output distribution diverges from the simple prior. By minimizing this KL divergence, we encourage the model to:</p>
<ul>
<li><strong>Avoid overfitting</strong> by preventing the encoder from memorizing specific locations for each input</li>
<li><strong>Create meaningful clusters</strong> where similar inputs map to nearby regions in the latent space</li>
<li><strong>Maintain continuity</strong> so that points close together in latent space (like different variations of the digit &ldquo;7&rdquo;) decode into visually similar outputs</li>
</ul>
<p>This smooth, structured latent space is what enables generation: we can sample random points from our prior distribution and decode them into realistic new data.</p>
<p>Ultimately, the optimizer finds a balance between these two objectives: reconstructing the data well while keeping the latent space organized and regularized.</p>
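<p>When the encoder outputs a diagonal Gaussian $\mathcal{N}(\mathbf{\mu}, \mathbf{\sigma}^2\mathbf{I})$ and the prior is the standard normal, the KL term has a well-known closed form: $D_{KL} = \frac{1}{2}\sum_j \left(\mu_j^2 + \sigma_j^2 - 1 - \log \sigma_j^2\right)$. The sketch below evaluates it under those two assumptions and cross-checks the result against <code>torch.distributions</code>. The name <code>gaussian_kl</code> and the log-variance parameterization are illustrative choices.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import torch
from torch.distributions import Normal, kl_divergence


def gaussian_kl(mu, logvar):
    """KL( N(mu, diag(exp(logvar))) || N(0, I) ), summed over latent dimensions."""
    return 0.5 * torch.sum(mu.pow(2) + logvar.exp() - 1.0 - logvar, dim=-1)


# Two hypothetical encoder outputs: the first already matches the prior exactly.
mu = torch.tensor([[0.0, 0.0], [1.0, -2.0]])
logvar = torch.tensor([[0.0, 0.0], [0.5, -1.0]])
kl = gaussian_kl(mu, logvar)

# Cross-check against PyTorch's built-in KL between Normal distributions.
q = Normal(mu, torch.exp(0.5 * logvar))
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
kl_reference = kl_divergence(q, p).sum(dim=-1)
</code></pre></div>
<p>The first row incurs no penalty because it coincides with the prior; the second is penalized both for its shifted mean and for its non-unit variance.</p>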
<h3 id="the-reparameterization-trick-making-it-all-trainable">The Reparameterization Trick: Making it All Trainable</h3>
<p>We have a problem. The training process requires sampling:</p>
<ol>
<li>Encoder produces $\mathbf{\mu}$ and $\mathbf{\sigma}$.</li>
<li>We <strong>sample</strong> $\mathbf{z} \sim \mathcal{N}(\mathbf{\mu}, \mathbf{\sigma}^2\mathbf{I})$.</li>
<li>Decoder uses $\mathbf{z}$ to reconstruct $\mathbf{x}&rsquo;$.</li>
<li>We calculate the loss.</li>
</ol>
<p>The &ldquo;sampling&rdquo; step is a random, non-differentiable operation. We can&rsquo;t backpropagate the reconstruction loss from the decoder <em>through</em> this random node to update the encoder&rsquo;s weights.</p>
<p>The <strong>reparameterization trick</strong> makes the sampling process differentiable. Instead of sampling $\mathbf{z}$ directly, we express it as a deterministic function of $\mathbf{\mu}$, $\mathbf{\sigma}$, and an independent noise vector:</p>
<ol>
<li>Sample a random noise vector $\mathbf{\epsilon}$ from a simple, fixed distribution (e.g., the standard normal $\mathcal{N}(\mathbf{0}, \mathbf{I})$).</li>
<li>Compute $\mathbf{z}$ as: $\mathbf{z} = \mathbf{\mu} + \mathbf{\sigma} \odot \mathbf{\epsilon}$</li>
</ol>
<p>This simple change moves the randomness &ldquo;outside&rdquo; the network. The gradient can now flow deterministically from $\mathbf{z}$ back through the $\mathbf{\mu}$ and $\mathbf{\sigma}$ nodes to the encoder network. This is the key engineering insight that allows us to train the entire model end-to-end with standard backpropagation.</p>
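<p>A small sketch can make the gradient flow concrete. Here <code>mu</code> and <code>logvar</code> are stand-ins for encoder outputs (the log-variance parameterization is an assumption), and <code>z.sum()</code> is a placeholder for whatever loss the decoder would produce.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import torch

torch.manual_seed(0)

# Stand-ins for encoder outputs; requires_grad lets us inspect their gradients.
mu = torch.zeros(3, requires_grad=True)
logvar = torch.zeros(3, requires_grad=True)

std = torch.exp(0.5 * logvar)
epsilon = torch.randn_like(std)  # all randomness enters here, as a constant input
z = mu + std * epsilon           # deterministic in mu and std once epsilon is fixed

# Any scalar function of z works as a placeholder for the downstream loss.
z.sum().backward()

# Analytically: dz/dmu = 1 and dz/dlogvar = 0.5 * std * epsilon.
grad_mu = mu.grad
grad_logvar = logvar.grad
</code></pre></div>
<p>Both gradients are populated after the backward pass, which would not be possible if $\mathbf{z}$ were drawn by a sampling call that does not track its parameters.</p>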
<h3 id="where-does-this-objective-come-from-the-math">Where Does This Objective Come From? (The Math)</h3>
<p>This two-part loss function is derived directly from the goal of maximizing the marginal likelihood of the data, $\log p_{\theta}(\mathbf{x})$.</p>
<p>For a single data point $\mathbf{x}^{(i)}$, we can write:
$$\log p_{\theta}(\mathbf{x}^{(i)}) = D_{KL}(q_\phi(\mathbf{z} | \mathbf{x}^{(i)}) || p_{\theta}(\mathbf{z} | \mathbf{x}^{(i)})) + \mathcal{L}(\theta, \phi; \mathbf{x}^{(i)})$$</p>
<ul>
<li>The first term is the KL divergence between our encoder&rsquo;s approximation and the (intractable) true posterior. This is non-negative, and unfortunately we cannot compute it.</li>
<li>The second term, $\mathcal{L}$, is the Variational Lower Bound (also known as the Evidence Lower Bound, or ELBO). Since the KL term is $\ge 0$, we know that $\log p_{\theta}(\mathbf{x}^{(i)}) \ge \mathcal{L}$.</li>
</ul>
<p>By maximizing this lower bound $\mathcal{L}$, we push up the &ldquo;floor&rdquo; on the true likelihood of our data. This is a problem we can solve.</p>
<p>When we expand this $\mathcal{L}$ term, we get our famous two-part objective:</p>
<p>$$\mathcal{L}(\theta, \phi; \mathbf{x}^{(i)}) = E_{q_{\phi}(\mathbf{z} | \mathbf{x}^{(i)})}[\log p_{\theta}(\mathbf{x}^{(i)} | \mathbf{z})] - D_{KL}(q_\phi(\mathbf{z} | \mathbf{x}^{(i)}) || p_{\theta}(\mathbf{z}))$$</p>
<ul>
<li><strong>Term 1:</strong> The expected log-likelihood of reconstructing $\mathbf{x}^{(i)}$ from $\mathbf{z}$. Maximizing this is the same as minimizing the Reconstruction Loss.</li>
<li><strong>Term 2:</strong> The negative KL divergence between our encoder and the simple prior. Maximizing this is the same as minimizing the KL Divergence Loss.</li>
</ul>
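<p>For readers who want the intermediate steps, here is a short derivation sketch, dropping the superscript $(i)$ for brevity. Because $\log p_{\theta}(\mathbf{x})$ does not depend on $\mathbf{z}$, we can take its expectation under $q_{\phi}(\mathbf{z} | \mathbf{x})$, substitute $p_{\theta}(\mathbf{x}) = p_{\theta}(\mathbf{x}, \mathbf{z}) / p_{\theta}(\mathbf{z} | \mathbf{x})$, and multiply and divide by $q_{\phi}(\mathbf{z} | \mathbf{x})$ inside the logarithm:</p>
<p>$$\log p_{\theta}(\mathbf{x}) = E_{q_{\phi}(\mathbf{z} | \mathbf{x})}\left[\log \frac{p_{\theta}(\mathbf{x}, \mathbf{z})}{q_{\phi}(\mathbf{z} | \mathbf{x})}\right] + E_{q_{\phi}(\mathbf{z} | \mathbf{x})}\left[\log \frac{q_{\phi}(\mathbf{z} | \mathbf{x})}{p_{\theta}(\mathbf{z} | \mathbf{x})}\right]$$</p>
<p>The second expectation is $D_{KL}(q_\phi(\mathbf{z} | \mathbf{x}) || p_{\theta}(\mathbf{z} | \mathbf{x}))$, and the first is $\mathcal{L}$. Factoring the joint as $p_{\theta}(\mathbf{x}, \mathbf{z}) = p_{\theta}(\mathbf{x} | \mathbf{z})\, p_{\theta}(\mathbf{z})$ then splits $\mathcal{L}$ into the two terms:</p>
<p>$$\mathcal{L}(\theta, \phi; \mathbf{x}) = E_{q_{\phi}(\mathbf{z} | \mathbf{x})}[\log p_{\theta}(\mathbf{x} | \mathbf{z})] + E_{q_{\phi}(\mathbf{z} | \mathbf{x})}\left[\log \frac{p_{\theta}(\mathbf{z})}{q_{\phi}(\mathbf{z} | \mathbf{x})}\right] = E_{q_{\phi}(\mathbf{z} | \mathbf{x})}[\log p_{\theta}(\mathbf{x} | \mathbf{z})] - D_{KL}(q_\phi(\mathbf{z} | \mathbf{x}) || p_{\theta}(\mathbf{z}))$$</p>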
<p>Thus, the VAE&rsquo;s objective balances these two critical goals: faithfully reconstructing the data while maintaining a simple, regularized latent structure that is useful for generation.</p>
<h3 id="from-elbo-to-practical-loss">From ELBO to Practical Loss</h3>
<p>Remember, our goal is to <strong>maximize</strong> the ELBO:</p>
<p>$$\mathcal{L}(\theta, \phi; \mathbf{x}) = E_{q_{\phi}(\mathbf{z} | \mathbf{x})}[\log p_{\theta}(\mathbf{x} | \mathbf{z})] - D_{KL}(q_\phi(\mathbf{z} | \mathbf{x}) || p_{\theta}(\mathbf{z}))$$</p>
<p>Since deep learning libraries are built to <strong>minimize</strong> a loss function, we simply flip the sign and <strong>minimize the negative ELBO ($-\mathcal{L}$)</strong>.</p>
<p>This gives us our final, practical loss function:</p>
<p>$$\text{Loss} = -\mathcal{L} = -E_{q_{\phi}(\mathbf{z} | \mathbf{x})}[\log p_{\theta}(\mathbf{x} | \mathbf{z})] + D_{KL}(q_\phi(\mathbf{z} | \mathbf{x}) || p_{\theta}(\mathbf{z}))$$</p>
<p>This is the function you actually implement. Minimizing this loss achieves both of our goals:</p>
<ol>
<li>It <strong>minimizes the Reconstruction Loss</strong> (equivalently, it maximizes the expected log-likelihood of the data under the decoder).</li>
<li>It <strong>minimizes the KL Divergence</strong>, forcing the encoder to match the prior.</li>
</ol>
<h2 id="modern-pytorch-vae-implementation">Modern PyTorch VAE Implementation</h2>
<p>Now that we understand the VAE architecture and objective, let&rsquo;s implement a modern VAE in PyTorch. I&rsquo;ll focus primarily on the model and loss function here, though the full code is available <a href="https://github.com/hunter-heidenreich/vae">on GitHub</a>.</p>
<p>My VAE implementation uses a configuration <code>dataclass</code>, an output <code>dataclass</code>, and a VAE class extending <code>nn.Module</code>.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#e6db74">&#34;&#34;&#34;Variational Autoencoder (VAE) model implementation.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> dataclasses <span style="color:#f92672">import</span> dataclass
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> torch
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> torch.nn <span style="color:#66d9ef">as</span> nn
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> torch.nn.functional <span style="color:#66d9ef">as</span> F
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">get_activation</span>(activation: str) <span style="color:#f92672">-&gt;</span> nn<span style="color:#f92672">.</span>Module:
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;Get activation function by name.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>    activation_lower <span style="color:#f92672">=</span> activation<span style="color:#f92672">.</span>lower()
</span></span><span style="display:flex;"><span>    ACTIVATION_MAP <span style="color:#f92672">=</span> {
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;relu&#34;</span>: nn<span style="color:#f92672">.</span>ReLU(),
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;tanh&#34;</span>: nn<span style="color:#f92672">.</span>Tanh(),
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;sigmoid&#34;</span>: nn<span style="color:#f92672">.</span>Sigmoid(),
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;leaky_relu&#34;</span>: nn<span style="color:#f92672">.</span>LeakyReLU(),
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;elu&#34;</span>: nn<span style="color:#f92672">.</span>ELU(),
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;gelu&#34;</span>: nn<span style="color:#f92672">.</span>GELU(),
</span></span><span style="display:flex;"><span>    }
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">if</span> activation_lower <span style="color:#f92672">not</span> <span style="color:#f92672">in</span> ACTIVATION_MAP:
</span></span><span style="display:flex;"><span>        supported <span style="color:#f92672">=</span> <span style="color:#e6db74">&#34;, &#34;</span><span style="color:#f92672">.</span>join(ACTIVATION_MAP<span style="color:#f92672">.</span>keys())
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">raise</span> <span style="color:#a6e22e">ValueError</span>(
</span></span><span style="display:flex;"><span>            <span style="color:#e6db74">f</span><span style="color:#e6db74">&#34;Unsupported activation &#39;</span><span style="color:#e6db74">{</span>activation<span style="color:#e6db74">}</span><span style="color:#e6db74">&#39;. Supported: </span><span style="color:#e6db74">{</span>supported<span style="color:#e6db74">}</span><span style="color:#e6db74">&#34;</span>
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> ACTIVATION_MAP[activation_lower]
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#a6e22e">@dataclass</span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">class</span> <span style="color:#a6e22e">VAEConfig</span>:
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;VAE model configuration specifying architecture and behavior.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    hidden_dim: int
</span></span><span style="display:flex;"><span>    latent_dim: int
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    input_shape: tuple[int, int, int] <span style="color:#f92672">=</span> (<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">28</span>, <span style="color:#ae81ff">28</span>)  <span style="color:#75715e"># Default: MNIST</span>
</span></span><span style="display:flex;"><span>    activation: str <span style="color:#f92672">=</span> <span style="color:#e6db74">&#34;tanh&#34;</span>  <span style="color:#75715e"># Default: tanh, what was used in the original VAE paper</span>
</span></span><span style="display:flex;"><span>    use_softplus_std: bool <span style="color:#f92672">=</span> <span style="color:#66d9ef">False</span>  <span style="color:#75715e"># Whether to use softplus for std parameterization</span>
</span></span><span style="display:flex;"><span>    n_samples: int <span style="color:#f92672">=</span> <span style="color:#ae81ff">1</span>  <span style="color:#75715e"># Number of latent samples per input during training</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#a6e22e">@dataclass</span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">class</span> <span style="color:#a6e22e">VAEOutput</span>:
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;VAE forward pass output containing all relevant tensors and optional losses.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    x_logits: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    z: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    mu: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    std: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    x_recon: torch<span style="color:#f92672">.</span>Tensor <span style="color:#f92672">|</span> <span style="color:#66d9ef">None</span> <span style="color:#f92672">=</span> <span style="color:#66d9ef">None</span>
</span></span><span style="display:flex;"><span>    loss: torch<span style="color:#f92672">.</span>Tensor <span style="color:#f92672">|</span> <span style="color:#66d9ef">None</span> <span style="color:#f92672">=</span> <span style="color:#66d9ef">None</span>
</span></span><span style="display:flex;"><span>    loss_recon: torch<span style="color:#f92672">.</span>Tensor <span style="color:#f92672">|</span> <span style="color:#66d9ef">None</span> <span style="color:#f92672">=</span> <span style="color:#66d9ef">None</span>
</span></span><span style="display:flex;"><span>    loss_kl: torch<span style="color:#f92672">.</span>Tensor <span style="color:#f92672">|</span> <span style="color:#66d9ef">None</span> <span style="color:#f92672">=</span> <span style="color:#66d9ef">None</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">class</span> <span style="color:#a6e22e">VAE</span>(nn<span style="color:#f92672">.</span>Module):
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;Variational Autoencoder with support for deterministic and probabilistic reconstruction.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    DEFAULT_EPS <span style="color:#f92672">=</span> <span style="color:#ae81ff">1e-8</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">__init__</span>(self, config: VAEConfig) <span style="color:#f92672">-&gt;</span> <span style="color:#66d9ef">None</span>:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Initialize VAE with given configuration.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">        Args:
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">            config: VAE configuration specifying architecture and behavior
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">        &#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        super()<span style="color:#f92672">.</span><span style="color:#a6e22e">__init__</span>()
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>config <span style="color:#f92672">=</span> config
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Build encoder: input -&gt; hidden -&gt; latent parameters (mu, sigma)</span>
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>encoder <span style="color:#f92672">=</span> nn<span style="color:#f92672">.</span>Sequential(
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Flatten(),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(
</span></span><span style="display:flex;"><span>                int(torch<span style="color:#f92672">.</span>prod(torch<span style="color:#f92672">.</span>tensor(config<span style="color:#f92672">.</span>input_shape))), config<span style="color:#f92672">.</span>hidden_dim
</span></span><span style="display:flex;"><span>            ),
</span></span><span style="display:flex;"><span>            get_activation(config<span style="color:#f92672">.</span>activation),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(config<span style="color:#f92672">.</span>hidden_dim, config<span style="color:#f92672">.</span>latent_dim <span style="color:#f92672">*</span> <span style="color:#ae81ff">2</span>),
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Build decoder: latent -&gt; hidden -&gt; reconstructed input</span>
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>decoder <span style="color:#f92672">=</span> nn<span style="color:#f92672">.</span>Sequential(
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(config<span style="color:#f92672">.</span>latent_dim, config<span style="color:#f92672">.</span>hidden_dim),
</span></span><span style="display:flex;"><span>            get_activation(config<span style="color:#f92672">.</span>activation),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(
</span></span><span style="display:flex;"><span>                config<span style="color:#f92672">.</span>hidden_dim, int(torch<span style="color:#f92672">.</span>prod(torch<span style="color:#f92672">.</span>tensor(config<span style="color:#f92672">.</span>input_shape)))
</span></span><span style="display:flex;"><span>            ),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Unflatten(<span style="color:#ae81ff">1</span>, config<span style="color:#f92672">.</span>input_shape),
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">encode</span>(self, x: torch<span style="color:#f92672">.</span>Tensor) <span style="color:#f92672">-&gt;</span> tuple[torch<span style="color:#f92672">.</span>Tensor, torch<span style="color:#f92672">.</span>Tensor]:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Encode input to latent distribution parameters.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        encoder_output <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>encoder(x)
</span></span><span style="display:flex;"><span>        mu, sigma <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>chunk(encoder_output, <span style="color:#ae81ff">2</span>, dim<span style="color:#f92672">=-</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> mu, sigma
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">decode</span>(self, z: torch<span style="color:#f92672">.</span>Tensor) <span style="color:#f92672">-&gt;</span> torch<span style="color:#f92672">.</span>Tensor:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Decode latent representation to reconstruction logits.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> self<span style="color:#f92672">.</span>decoder(z)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">reparameterize</span>(self, mu: torch<span style="color:#f92672">.</span>Tensor, std: torch<span style="color:#f92672">.</span>Tensor) <span style="color:#f92672">-&gt;</span> torch<span style="color:#f92672">.</span>Tensor:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Apply reparameterization trick for differentiable sampling.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        epsilon <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>randn_like(std)
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> mu <span style="color:#f92672">+</span> std <span style="color:#f92672">*</span> epsilon
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">forward</span>(
</span></span><span style="display:flex;"><span>        self,
</span></span><span style="display:flex;"><span>        x: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>        compute_loss: bool <span style="color:#f92672">=</span> <span style="color:#66d9ef">True</span>,
</span></span><span style="display:flex;"><span>        reconstruct: bool <span style="color:#f92672">=</span> <span style="color:#66d9ef">False</span>,
</span></span><span style="display:flex;"><span>        eps: float <span style="color:#f92672">=</span> DEFAULT_EPS,
</span></span><span style="display:flex;"><span>    ) <span style="color:#f92672">-&gt;</span> VAEOutput:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Forward pass through the VAE.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">        Args:
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">            x: Input tensor of shape (batch_size, *input_shape)
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">            compute_loss: Whether to compute VAE loss components
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">            reconstruct: Whether to also return sigmoid(x_logits) as x_recon
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">            eps: Small epsilon value for numerical stability
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">        Returns:
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">            VAEOutput containing all relevant tensors and optionally computed losses
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">        &#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Prepare input for multiple sampling if needed</span>
</span></span><span style="display:flex;"><span>        x_expanded <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>_expand_for_sampling(x) <span style="color:#66d9ef">if</span> self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>n_samples <span style="color:#f92672">&gt;</span> <span style="color:#ae81ff">1</span> <span style="color:#66d9ef">else</span> x
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Encode and sample from latent space</span>
</span></span><span style="display:flex;"><span>        mu, sigma <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>encode(x)
</span></span><span style="display:flex;"><span>        std <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>_sigma_to_std(sigma, eps<span style="color:#f92672">=</span>eps)
</span></span><span style="display:flex;"><span>        mu_expanded, std_expanded <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>_expand_latent_params(mu, std)
</span></span><span style="display:flex;"><span>        z <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>reparameterize(mu_expanded, std_expanded)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Decode latent samples</span>
</span></span><span style="display:flex;"><span>        x_logits <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>decode(z)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Create output object</span>
</span></span><span style="display:flex;"><span>        output <span style="color:#f92672">=</span> VAEOutput(
</span></span><span style="display:flex;"><span>            x_logits<span style="color:#f92672">=</span>x_logits,
</span></span><span style="display:flex;"><span>            z<span style="color:#f92672">=</span>z,
</span></span><span style="display:flex;"><span>            mu<span style="color:#f92672">=</span>mu,
</span></span><span style="display:flex;"><span>            std<span style="color:#f92672">=</span>std,
</span></span><span style="display:flex;"><span>            x_recon<span style="color:#f92672">=</span>torch<span style="color:#f92672">.</span>sigmoid(x_logits) <span style="color:#66d9ef">if</span> reconstruct <span style="color:#66d9ef">else</span> <span style="color:#66d9ef">None</span>,
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Compute losses if requested</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">if</span> compute_loss:
</span></span><span style="display:flex;"><span>            loss, loss_recon, loss_kl <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>_compute_loss(
</span></span><span style="display:flex;"><span>                x_expanded, x_logits, mu, sigma, std
</span></span><span style="display:flex;"><span>            )
</span></span><span style="display:flex;"><span>            output<span style="color:#f92672">.</span>loss <span style="color:#f92672">=</span> loss
</span></span><span style="display:flex;"><span>            output<span style="color:#f92672">.</span>loss_recon <span style="color:#f92672">=</span> loss_recon
</span></span><span style="display:flex;"><span>            output<span style="color:#f92672">.</span>loss_kl <span style="color:#f92672">=</span> loss_kl
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> output
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># ==================== Helper Methods ====================</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">_sigma_to_std</span>(
</span></span><span style="display:flex;"><span>        self, sigma: torch<span style="color:#f92672">.</span>Tensor, eps: float <span style="color:#f92672">=</span> DEFAULT_EPS
</span></span><span style="display:flex;"><span>    ) <span style="color:#f92672">-&gt;</span> torch<span style="color:#f92672">.</span>Tensor:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Convert raw encoder output (pre-softplus value or log-variance) to standard deviation.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">if</span> self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>use_softplus_std:
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">return</span> F<span style="color:#f92672">.</span>softplus(sigma) <span style="color:#f92672">+</span> eps
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">else</span>:
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">return</span> torch<span style="color:#f92672">.</span>exp(<span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> sigma)  <span style="color:#75715e"># sigma represents log-variance</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">_expand_for_sampling</span>(self, x: torch<span style="color:#f92672">.</span>Tensor) <span style="color:#f92672">-&gt;</span> torch<span style="color:#f92672">.</span>Tensor:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Expand input tensor for multiple sampling.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        shape_dims <span style="color:#f92672">=</span> [<span style="color:#ae81ff">1</span>] <span style="color:#f92672">*</span> len(self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>input_shape)
</span></span><span style="display:flex;"><span>        x_expanded <span style="color:#f92672">=</span> x<span style="color:#f92672">.</span>unsqueeze(<span style="color:#ae81ff">1</span>)<span style="color:#f92672">.</span>repeat(<span style="color:#ae81ff">1</span>, self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>n_samples, <span style="color:#f92672">*</span>shape_dims)
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> x_expanded<span style="color:#f92672">.</span>view(<span style="color:#f92672">-</span><span style="color:#ae81ff">1</span>, <span style="color:#f92672">*</span>self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>input_shape)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">_expand_latent_params</span>(
</span></span><span style="display:flex;"><span>        self, mu: torch<span style="color:#f92672">.</span>Tensor, std: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    ) <span style="color:#f92672">-&gt;</span> tuple[torch<span style="color:#f92672">.</span>Tensor, torch<span style="color:#f92672">.</span>Tensor]:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Expand latent parameters for multiple sampling.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">if</span> self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>n_samples <span style="color:#f92672">==</span> <span style="color:#ae81ff">1</span>:
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">return</span> mu, std
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        mu_expanded <span style="color:#f92672">=</span> (
</span></span><span style="display:flex;"><span>            mu<span style="color:#f92672">.</span>unsqueeze(<span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>            <span style="color:#f92672">.</span>repeat(<span style="color:#ae81ff">1</span>, self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>n_samples, <span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>            <span style="color:#f92672">.</span>view(<span style="color:#f92672">-</span><span style="color:#ae81ff">1</span>, self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>latent_dim)
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>        std_expanded <span style="color:#f92672">=</span> (
</span></span><span style="display:flex;"><span>            std<span style="color:#f92672">.</span>unsqueeze(<span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>            <span style="color:#f92672">.</span>repeat(<span style="color:#ae81ff">1</span>, self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>n_samples, <span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>            <span style="color:#f92672">.</span>view(<span style="color:#f92672">-</span><span style="color:#ae81ff">1</span>, self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>latent_dim)
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> mu_expanded, std_expanded
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># ==================== Loss Computation ====================</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">_compute_loss</span>(
</span></span><span style="display:flex;"><span>        self,
</span></span><span style="display:flex;"><span>        x: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>        x_logits: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>        mu: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>        sigma: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>        std: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>    ) <span style="color:#f92672">-&gt;</span> tuple[torch<span style="color:#f92672">.</span>Tensor, torch<span style="color:#f92672">.</span>Tensor, torch<span style="color:#f92672">.</span>Tensor]:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Compute VAE loss components for deterministic reconstruction.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        loss_recon <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>_compute_reconstruction_loss(x, x_logits)
</span></span><span style="display:flex;"><span>        loss_kl <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>_compute_kl_loss(mu, sigma, std)
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> loss_recon <span style="color:#f92672">+</span> loss_kl, loss_recon, loss_kl
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">_compute_reconstruction_loss</span>(
</span></span><span style="display:flex;"><span>        self, x: torch<span style="color:#f92672">.</span>Tensor, x_logits: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    ) <span style="color:#f92672">-&gt;</span> torch<span style="color:#f92672">.</span>Tensor:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Compute reconstruction loss using binary cross-entropy.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> F<span style="color:#f92672">.</span>binary_cross_entropy_with_logits(
</span></span><span style="display:flex;"><span>            x_logits, x, reduction<span style="color:#f92672">=</span><span style="color:#e6db74">&#34;sum&#34;</span>
</span></span><span style="display:flex;"><span>        ) <span style="color:#f92672">/</span> x<span style="color:#f92672">.</span>size(<span style="color:#ae81ff">0</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">_compute_kl_loss</span>(
</span></span><span style="display:flex;"><span>        self,
</span></span><span style="display:flex;"><span>        mu: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>        sigma: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>        std: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>        eps: float <span style="color:#f92672">=</span> DEFAULT_EPS,
</span></span><span style="display:flex;"><span>    ) <span style="color:#f92672">-&gt;</span> torch<span style="color:#f92672">.</span>Tensor:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Compute KL divergence between latent distribution and standard normal prior.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Analytical KL: KL(N(μ,σ²) || N(0,1)) = 0.5 * Σ(μ² + σ² - 1 - log(σ²))</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">if</span> self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>use_softplus_std:
</span></span><span style="display:flex;"><span>            <span style="color:#75715e"># sigma is just the raw output, need to use std directly: σ</span>
</span></span><span style="display:flex;"><span>            kl_per_sample <span style="color:#f92672">=</span> <span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>sum(
</span></span><span style="display:flex;"><span>                mu<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">+</span> std<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">-</span> <span style="color:#ae81ff">1</span> <span style="color:#f92672">-</span> torch<span style="color:#f92672">.</span>log(std<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">+</span> eps), dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>
</span></span><span style="display:flex;"><span>            )
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">else</span>:
</span></span><span style="display:flex;"><span>            <span style="color:#75715e"># sigma represents log-variance parameterization: log(σ²)</span>
</span></span><span style="display:flex;"><span>            kl_per_sample <span style="color:#f92672">=</span> <span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>sum(mu<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">+</span> sigma<span style="color:#f92672">.</span>exp() <span style="color:#f92672">-</span> <span style="color:#ae81ff">1</span> <span style="color:#f92672">-</span> sigma, dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> kl_per_sample<span style="color:#f92672">.</span>mean()
</span></span></code></pre></div><h3 id="loss-scaling">Loss Scaling</h3>
<p>Both components of the VAE loss should be summed over the data (or latent) dimensions and then averaged over the batch. A common mistake is using the <code>reduction=&quot;mean&quot;</code> option in PyTorch loss functions, which averages over every element in the tensor, data dimensions included.</p>
<ul>
<li>The <strong>KL Divergence</strong> (<code>loss_kl</code>) is a penalization term. Each dimension of the latent space has the potential to add complexity and deviate from the prior. As you increase the latent dimensionality, you typically see the KL loss increase in magnitude. That&rsquo;s the cost of having a more expressive latent space.</li>
<li>The <strong>Reconstruction Loss</strong> (<code>loss_recon</code>) measures how well the model reconstructs the input data, and it should scale with input dimensionality (this can bias the model toward better reconstruction for higher-dimensional data).</li>
</ul>
<p>In the case of MNIST, if we used <code>reduction=&quot;mean&quot;</code> for BCE, it would be averaged over all $784 \times \text{batch size}$ pixels, making it tiny compared to the KL loss. The KL term would dominate, and the model would learn to ignore the input, potentially leading to posterior collapse.</p>
<p>While modern optimizers can handle a variety of scenarios and you can still learn effective models with imperfect scaling, the original VAE paper used the scaling described above, and I recommend following that convention.</p>
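<p>To make the size of the scaling gap concrete, here is a minimal, framework-free sketch. It uses random logits and targets as stand-ins for decoder output and MNIST pixels; the batch size and the helper <code>bce_with_logits</code> are illustrative assumptions, not part of the implementation above. Dividing the summed loss by the batch size mirrors <code>reduction=&quot;sum&quot;</code> followed by the batch average, while dividing by every element mirrors <code>reduction=&quot;mean&quot;</code>.</p>

```python
import math
import random

random.seed(0)

BATCH_SIZE = 4     # hypothetical toy batch
NUM_PIXELS = 784   # one flattened 28 x 28 MNIST image


def bce_with_logits(logit: float, target: float) -> float:
    """Numerically stable binary cross-entropy for a single pixel."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


# Random stand-ins for decoder logits and binarized pixel targets.
logits = [[random.gauss(0.0, 1.0) for _ in range(NUM_PIXELS)] for _ in range(BATCH_SIZE)]
targets = [[float(random.randint(0, 1)) for _ in range(NUM_PIXELS)] for _ in range(BATCH_SIZE)]

total = sum(
    bce_with_logits(logit, target)
    for row_logits, row_targets in zip(logits, targets)
    for logit, target in zip(row_logits, row_targets)
)

# Sum over pixels, average over the batch (the convention recommended here).
loss_sum_over_pixels = total / BATCH_SIZE
# Average over every element, which is what reduction="mean" would do.
loss_mean_over_all = total / (BATCH_SIZE * NUM_PIXELS)

# The two differ by a factor equal to the number of pixels.
print(loss_sum_over_pixels / loss_mean_over_all)
```

<p>The ratio equals the number of pixels, so with the element-wise mean the reconstruction term is shrunk by a factor of 784 relative to a KL term that is still summed over latent dimensions.</p>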
<h3 id="mitigating-posterior-collapse-kl-annealingwarmup">Mitigating Posterior Collapse: KL Annealing/Warmup</h3>
<p>One common issue in training VAEs, especially with powerful decoders (like RNNs or deep CNNs), is <strong>posterior collapse</strong>. This tends to happen when the KL term dominates the loss early in training. The encoder quickly learns to output the prior distribution ($q(z|x) \approx p(z)$) to drive the KL loss to zero, so the latent code $z$ carries no information about the input. The decoder then models the data on its own (for example, as a purely autoregressive model) and ignores the latent input.</p>
<p>To prevent this, we often use <strong>KL Annealing</strong> (or Warmup). We introduce a weight $\beta$ for the KL term that starts at 0 and slowly increases to 1 over the first $N$ steps or epochs.</p>
<p>$$ \mathcal{L} = \mathcal{L}_{recon} + \beta \cdot D_{KL} $$</p>
<p>This allows the model to focus purely on reconstruction first (using the full latent capacity), and then slowly adds the regularization pressure.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#75715e"># Simple Linear Annealing Scheduler</span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">get_kl_weight</span>(step, total_steps, max_val<span style="color:#f92672">=</span><span style="color:#ae81ff">1.0</span>):
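</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">if</span> total_steps <span style="color:#f92672">&lt;=</span> <span style="color:#ae81ff">0</span>:  <span style="color:#75715e"># no warmup requested: avoid dividing by zero</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> max_val
</span></span><span style="display:flex;"><span>    val <span style="color:#f92672">=</span> (step <span style="color:#f92672">/</span> total_steps) <span style="color:#f92672">*</span> max_val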
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> min(max(val, <span style="color:#ae81ff">0.0</span>), max_val)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># In your training loop:</span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">for</span> epoch <span style="color:#f92672">in</span> range(epochs):
</span></span><span style="display:flex;"><span>    beta <span style="color:#f92672">=</span> get_kl_weight(epoch, warmup_epochs)
</span></span><span style="display:flex;"><span>    loss <span style="color:#f92672">=</span> recon_loss <span style="color:#f92672">+</span> beta <span style="color:#f92672">*</span> kl_loss
</span></span></code></pre></div><h4 id="parameterizing-standard-deviation">Parameterizing Standard Deviation</h4>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">_sigma_to_std</span>(self, sigma: torch<span style="color:#f92672">.</span>Tensor, eps: float <span style="color:#f92672">=</span> DEFAULT_EPS) <span style="color:#f92672">-&gt;</span> torch<span style="color:#f92672">.</span>Tensor:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Convert sigma parameter to standard deviation.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">if</span> self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>bound_std <span style="color:#f92672">is</span> <span style="color:#f92672">not</span> <span style="color:#66d9ef">None</span>:
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">return</span> torch<span style="color:#f92672">.</span>sigmoid(sigma) <span style="color:#f92672">*</span> self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>bound_std <span style="color:#f92672">+</span> eps
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">elif</span> self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>use_softplus_std:
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">return</span> F<span style="color:#f92672">.</span>softplus(sigma) <span style="color:#f92672">+</span> eps
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">else</span>:
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">return</span> torch<span style="color:#f92672">.</span>exp(<span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> sigma)  <span style="color:#75715e"># sigma represents log-variance</span>
</span></span></code></pre></div><p>Parameterizing the mean of the latent distribution is straightforward since $\mu \in \mathbb{R}$. However, the standard deviation $\sigma$ must be strictly positive (as must the variance $\sigma^2$). A network layer produces unconstrained real outputs, so we need a transformation that maps those outputs to positive values.</p>
<p><strong>Log-Variance</strong>
One common approach is to have the network output the <strong>log-variance</strong> ($\log \sigma^2$). This is what the original VAE paper did. The idea is to allow the network to output any real number and treat that value as the log-variance, $s = \log \sigma^2$. We can then compute the standard deviation as $\sigma = \exp(0.5 s)$, which is always positive.</p>
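<p>The KL divergence formula simplifies nicely with this parameterization. Substituting $\sigma_i^2 = \exp(s_i)$ and $\log \sigma_i^2 = s_i$ gives:</p>
<p>$$
\text{KL}( \mathcal{N}(\mu, \sigma^2) || \mathcal{N}(0, 1) ) = \frac{1}{2} \sum_{i=1}^d (\mu_i^2 + \sigma_i^2 - 1 - \log \sigma_i^2) = \frac{1}{2} \sum_{i=1}^d (\mu_i^2 + e^{s_i} - 1 - s_i)
$$</p>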
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>sum(mu<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">+</span> s<span style="color:#f92672">.</span>exp() <span style="color:#f92672">-</span> <span style="color:#ae81ff">1</span> <span style="color:#f92672">-</span> s, dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>)
</span></span></code></pre></div><p><strong>Softplus Standard Deviation</strong>
An alternative is to have the network output $\sigma$ directly. This must be handled with care to ensure positivity. Strictly positive activations like <code>softplus</code> are required. Activations like <code>ReLU</code> can output zero, leading to numerical instability (during training) and deterministic behavior (during sampling). Additionally, adding a small epsilon value ensures numerical stability by preventing $\sigma$ from being exactly zero.</p>
<p>The KL divergence formula becomes slightly more complex:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>sum(
</span></span><span style="display:flex;"><span>    mu<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">+</span> std<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">-</span> <span style="color:#ae81ff">1</span> <span style="color:#f92672">-</span> torch<span style="color:#f92672">.</span>log(std<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">+</span> eps), dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>
</span></span><span style="display:flex;"><span>)
</span></span></code></pre></div><p><strong>Bounded Standard Deviation</strong>
Another option is to bound the standard deviation to a maximum value using a <code>sigmoid</code> transformation (or similar). This replaces mapping to $(0, \infty)$ with mapping to $(0, \text{bound})$. This helps prevent extremely high variance values that might destabilize training, at the cost of limiting the expressiveness of the latent distribution. As with <code>softplus</code>, adding a small epsilon ensures numerical stability by keeping $\sigma$ away from zero.</p>
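<p>The three mappings described above can be sketched on plain floats as follows. This is a minimal illustration, not the class implementation: <code>EPS</code> stands in for <code>DEFAULT_EPS</code> (its value here is an assumption), the default bound of 10 is chosen for illustration, and the helper names are hypothetical. The sketch also checks numerically that the log-variance and direct-$\sigma$ forms of the KL term agree when they describe the same distribution.</p>

```python
import math

EPS = 1e-8  # stand-in for DEFAULT_EPS; the actual value is an assumption


def softplus(s: float) -> float:
    """Numerically stable log(1 + exp(s))."""
    return max(s, 0.0) + math.log1p(math.exp(-abs(s)))


def sigmoid(s: float) -> float:
    """Numerically stable logistic function."""
    if s >= 0:
        return 1.0 / (1.0 + math.exp(-s))
    e = math.exp(s)
    return e / (1.0 + e)


def std_from_log_variance(s: float) -> float:
    return math.exp(0.5 * s)  # s is interpreted as log(sigma^2)


def std_from_softplus(s: float) -> float:
    return softplus(s) + EPS


def std_from_bounded(s: float, bound: float = 10.0) -> float:
    return sigmoid(s) * bound + EPS


def kl_from_log_variance(mu: float, s: float) -> float:
    """Per-dimension KL when s = log(sigma^2)."""
    return 0.5 * (mu**2 + math.exp(s) - 1.0 - s)


def kl_from_std(mu: float, std: float) -> float:
    """Per-dimension KL when the standard deviation is available directly."""
    return 0.5 * (mu**2 + std**2 - 1.0 - math.log(std**2 + EPS))


# Every mapping yields a strictly positive standard deviation.
for raw in (-20.0, 0.0, 20.0):
    print(raw, std_from_log_variance(raw), std_from_softplus(raw), std_from_bounded(raw))

# Both KL forms agree (up to the epsilon inside the log) for the same distribution.
mu, s = 0.7, -1.2
kl_a = kl_from_log_variance(mu, s)
kl_b = kl_from_std(mu, std_from_log_variance(s))
print(abs(kl_a - kl_b))
```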
<p><strong>Gradient Behavior</strong>
All parameterizations can work well in practice, but they have different gradient behaviors. Think of $g(s)$ as a transformation function from the network output to the proper domain of $\sigma$ (or $\sigma^2$); in the log-variance case, $g(s) = \exp(s)$, in the softplus case, $g(s) = \text{softplus}(s) + \epsilon$, and in the bounded case, $g(s) = \text{bound} \cdot \text{sig}(s) + \epsilon$.</p>
<p>The gradient of the loss with respect to these outputs can be written using the chain rule:</p>
<p>$$
\frac{\partial \mathcal{L}}{\partial s} = \frac{\partial \mathcal{L}}{\partial g(s)} \cdot \frac{\partial g(s)}{\partial s}
$$</p>
<p>where $\frac{\partial \mathcal{L}}{\partial g(s)}$ is the gradient with respect to the transformed quantity ($\sigma$ or $\sigma^2$) and $\frac{\partial g(s)}{\partial s}$ is the derivative of the transformation function.</p>
<p>We need to guard against two pathological cases:</p>
<ul>
<li>$\frac{\partial g(s)}{\partial s} \rightarrow 0$: This leads to vanishing gradients, making it hard for the network to learn.</li>
<li>$\frac{\partial g(s)}{\partial s} \rightarrow \infty$: This leads to exploding gradients, causing instability during training and potentially divergence.</li>
</ul>
<p>The log-variance parameterization, with its exponential transformation that is its own derivative, exhibits both issues at extreme values. If $s \rightarrow -\infty$, then $\sigma \rightarrow 0$ and the gradient vanishes. If $s \rightarrow \infty$, then $\sigma \rightarrow \infty$ and the gradient explodes. Since the interval $(0, 1)$ is mapped to $(-\infty, 0)$ in log-space, it&rsquo;s much more difficult for the network to drive $\sigma$ to small values. In practice, exploding gradients at high values have been more problematic in my experience. Gradient clipping, learning rate scheduling, and clamping the log-variance output to a maximum value can help mitigate this.</p>
<p>What about softplus? The derivative of <code>softplus</code> is the <code>sigmoid</code> function, which smoothly maps $(-\infty, \infty)$ to $(0, 1)$. The derivative is therefore always bounded by unity, preventing explosion (barring explosion from other parts of the network). However, as $s \rightarrow -\infty$, the derivative approaches zero, leading to vanishing gradients. Adding a small epsilon keeps $\sigma$ from getting too close to zero, but it does not change the derivative, so learning can still slow down in this regime.</p>
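<p>For bounded standard deviation, the derivative of the <code>sigmoid</code> function is also bounded, preventing exploding gradients. (The gradient of <code>sigmoid</code> is defined in terms of itself: $\text{sig}'(x) = \text{sig}(x)(1 - \text{sig}(x))$; its maximum value is $0.25$ at $x=0$, so the derivative of the bounded transformation is at most $0.25 \cdot \text{bound}$.) Because this derivative decays toward zero as $x \rightarrow \pm\infty$, gradients can vanish when $\sigma$ saturates near either zero or the bound.</p>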
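<p>As a rough numerical illustration of the discussion above, the sketch below evaluates the derivative of each transformation at a few raw outputs. The function names and the bound of 10 are assumptions made for this example; the derivatives themselves follow from $g(s) = \exp(s)$, $g(s) = \text{softplus}(s) + \epsilon$, and $g(s) = \text{bound} \cdot \text{sig}(s) + \epsilon$.</p>

```python
import math


def sigmoid(s: float) -> float:
    """Numerically stable logistic function."""
    if s >= 0:
        return 1.0 / (1.0 + math.exp(-s))
    e = math.exp(s)
    return e / (1.0 + e)


def d_log_variance(s: float) -> float:
    """Derivative of g(s) = exp(s): the exponential is its own derivative."""
    return math.exp(s)


def d_softplus(s: float) -> float:
    """Derivative of g(s) = softplus(s) + eps: the sigmoid (eps drops out)."""
    return sigmoid(s)


def d_bounded(s: float, bound: float = 10.0) -> float:
    """Derivative of g(s) = bound * sigmoid(s) + eps."""
    sig = sigmoid(s)
    return bound * sig * (1.0 - sig)


for s in (-10.0, 0.0, 10.0):
    print(
        f"s={s:+.0f}  log-variance: {d_log_variance(s):.3g}  "
        f"softplus: {d_softplus(s):.3g}  bounded: {d_bounded(s):.3g}"
    )
```

<p>At large positive $s$ only the exponential derivative grows without limit, while at large negative $s$ all three derivatives shrink toward zero; the bounded variant also shrinks at large positive $s$.</p>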















<figure class="post-figure center ">
    <img src="/img/vae-tut/gradient_behaviors.webp"
         alt="Graph comparing gradient behaviors of log-variance, softplus, and bounded standard deviation parameterizations"
         title="Graph comparing gradient behaviors of log-variance, softplus, and bounded standard deviation parameterizations"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Gradient behaviors of different standard deviation parameterizations: Log-Variance (exponential), Softplus, and Bounded Standard Deviation (sigmoid). Each has unique characteristics affecting training stability.</figcaption>
    
</figure>

<h2 id="experiments">Experiments</h2>
<h3 id="2d-mnist-vae-with-different-std-dev-parameterizations">2D MNIST VAE with Different Std. Dev. Parameterizations</h3>
<p>First, let&rsquo;s run an experiment that is close to what was done in the original VAE paper. We&rsquo;ll use MNIST as our dataset, a simple feedforward architecture with <code>tanh</code> activations, and the log-variance parameterization for the latent distribution.</p>
<p>Some of the differences from the original paper include:</p>
<ul>
<li>Using a hidden size of 512 (the original used 500)</li>
<li>Using the AdamW optimizer (the original used vanilla Adagrad)</li>
<li>Applying a similar amount of weight decay, though AdamW applies it quite differently from the original setup</li>
<li>Focusing primarily on 2D latent spaces (for now)</li>
</ul>
<p>This results in a network with 807,700 parameters.
I train each model for 150 epochs at most and highlight the best based on the reconstruction loss on the test set.
Just for fun, I sweep across different standard deviation parameterizations and learning rate warmup strategies.</p>
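<p>As a sanity check on the parameter count, the short sketch below tallies weights and biases for one layout that reproduces it. The layout is an assumption on my part here: a single 512-unit hidden layer in each of the encoder and decoder, with separate linear heads for $\mu$ and the raw $\sigma$ output. The helper <code>linear_params</code> is hypothetical.</p>

```python
def linear_params(n_in: int, n_out: int) -> int:
    """Weights plus biases of one fully connected layer."""
    return n_in * n_out + n_out


INPUT_DIM = 28 * 28  # flattened MNIST image
HIDDEN_DIM = 512
LATENT_DIM = 2

# Assumed layout: one hidden layer per network, two encoder heads.
layers = {
    "encoder hidden": linear_params(INPUT_DIM, HIDDEN_DIM),
    "encoder mu head": linear_params(HIDDEN_DIM, LATENT_DIM),
    "encoder sigma head": linear_params(HIDDEN_DIM, LATENT_DIM),
    "decoder hidden": linear_params(LATENT_DIM, HIDDEN_DIM),
    "decoder output": linear_params(HIDDEN_DIM, INPUT_DIM),
}

for name, count in layers.items():
    print(f"{name:>20}: {count:,}")

total_params = sum(layers.values())
print(f"{'total':>20}: {total_params:,}")  # 807,700
```

<p>Almost all of the parameters sit in the two large input and output layers; the 2D latent heads contribute only about two thousand.</p>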
<table>
  <thead>
      <tr>
          <th>Std. Dev. Param</th>
          <th>Warmup Steps</th>
          <th>Test Recon. Loss</th>
          <th>Test KL Loss</th>
          <th>Test Total Loss</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Log-Variance</td>
          <td>0</td>
          <td>140.88</td>
          <td>6.96</td>
          <td>147.84</td>
      </tr>
      <tr>
          <td>Log-Variance</td>
          <td>600</td>
          <td>141.41</td>
          <td>6.63</td>
          <td>148.04</td>
      </tr>
      <tr>
          <td>Softplus</td>
          <td>0</td>
          <td>141.51</td>
          <td>6.56</td>
          <td>148.07</td>
      </tr>
      <tr>
          <td>Softplus</td>
          <td>600</td>
          <td>140.37</td>
          <td>6.67</td>
          <td>147.04</td>
      </tr>
      <tr>
          <td>Bounded Std. Dev. (10)</td>
          <td>0</td>
          <td>140.96</td>
          <td>6.82</td>
          <td>147.79</td>
      </tr>
      <tr>
          <td>Bounded Std. Dev. (10)</td>
          <td>600</td>
          <td>141.78</td>
          <td>6.68</td>
          <td>148.45</td>
      </tr>
  </tbody>
</table>
<p>From this summary table, all three parameterizations work well. The differences in final loss values are quite small. This could be due to the simplicity of the dataset and model architecture, further amplified by forcing the network to compress images into a very low-dimensional latent space (2D).</p>
<p>Since the softplus parameterization with learning rate warmup achieved the best reconstruction loss, let&rsquo;s visualize some of its training dynamics and results more closely.</p>
<h4 id="loss-dynamics">Loss Dynamics</h4>
<p>To understand the VAE&rsquo;s behavior, we must look at the ELBO and its two components: the Reconstruction Loss and the KL Divergence.</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-elbo_epochs.webp"
         alt="Plot showing training and testing ELBO across 150 epochs for the softplus parameterization with learning rate warmup"
         title="Plot showing training and testing ELBO across 150 epochs for the softplus parameterization with learning rate warmup"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Total ELBO: Training and testing ELBO across 150 epochs.</figcaption>
    
</figure>
















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-reconstruction_loss_epochs.webp"
         alt="Plot showing training and testing reconstruction loss across 150 epochs for the softplus parameterization with learning rate warmup"
         title="Plot showing training and testing reconstruction loss across 150 epochs for the softplus parameterization with learning rate warmup"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Reconstruction Loss: Training and testing reconstruction loss across 150 epochs.</figcaption>
    
</figure>
















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-kl_loss_epochs.webp"
         alt="Plot showing training and testing KL divergence loss across 150 epochs for the softplus parameterization with learning rate warmup"
         title="Plot showing training and testing KL divergence loss across 150 epochs for the softplus parameterization with learning rate warmup"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">KL Divergence: Training and testing KL divergence loss across 150 epochs.</figcaption>
    
</figure>

<p>These plots reveal a clear narrative:</p>
<ol>
<li><strong>Rapid Initial Learning:</strong> Performance skyrockets in the first ~15 epochs.</li>
<li><strong>Overfitting:</strong> The <strong>Reconstruction Loss</strong> (middle) flatlines for the test set while continuing to improve for training, a classic sign of memorization.</li>
<li><strong>The Balancing Act:</strong> The <strong>KL Divergence</strong> (bottom) initially rises (&ldquo;The Cost of Learning&rdquo;) as the model stretches the latent space to encode digits, then saturates.</li>
<li><strong>Equilibrium:</strong> The total <strong>ELBO</strong> (top) improves slowly, driven by the model finding the optimal trade-off between reconstruction and regularization. Notice that the test and train KL curves track each other closely: a sign of good regularization!</li>
</ol>
<p><strong>Visualizing the VAE Trade-Off: BCE vs. KL</strong></p>
<p>While the line plots visualize progress over time, they miss the evolving <em>relationship</em> between our two competing objectives.</p>
<p>A VAE is fundamentally a multi-objective optimization problem. We want to:</p>
<ol>
<li>Minimize Reconstruction Loss (BCE)</li>
<li>Minimize KL Divergence</li>
</ol>
<p>Combining them as the ELBO is common and effective, though it can mask some of the underlying dynamics.</p>
<p>These two goals are in direct conflict. To get perfect reconstruction (BCE = 0), the encoder would need to &ldquo;memorize&rdquo; each input, mapping it to a unique, precise point in latent space. This would cause the KL divergence to skyrocket, as these specific, &ldquo;pointy&rdquo; distributions are nothing like our smooth <code>N(0, 1)</code> prior.</p>
<p>Conversely, to get perfect KL divergence (KL = 0), the encoder must <em>always</em> output <code>N(0, 1)</code>, regardless of the input. This perfectly matches the prior. Since the latent code $\mathbf{z}$ now contains zero information about the input $\mathbf{x}$, the decoder can only learn to output the &ldquo;average&rdquo; image, resulting in terrible reconstruction.</p>
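<p>The two extremes can be illustrated with the per-dimension closed-form KL term. This is a small sketch with made-up numbers: the &ldquo;pointy&rdquo; posterior (mean 3, standard deviation 0.01) is a hypothetical example of an encoder that pins an input to a precise location, not a value measured in these experiments.</p>

```python
import math


def kl_to_standard_normal(mu: float, std: float) -> float:
    """KL(N(mu, std^2) || N(0, 1)) for a single latent dimension."""
    return 0.5 * (mu**2 + std**2 - 1.0 - math.log(std**2))


# Posterior identical to the prior: no KL cost, but no information about x.
kl_collapsed = kl_to_standard_normal(mu=0.0, std=1.0)

# Hypothetical "memorizing" posterior: narrow and far from the origin.
kl_pointy = kl_to_standard_normal(mu=3.0, std=0.01)

print(kl_collapsed, kl_pointy)

# Shrinking the standard deviation at a fixed mean keeps raising the cost.
for std in (1.0, 0.1, 0.01, 0.001):
    print(std, kl_to_standard_normal(mu=0.0, std=std))
```

<p>The cost is zero only when the posterior matches the prior, and it grows as the posterior becomes narrower or moves away from the origin, which is the price the encoder pays for being specific about each input.</p>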
<p>The training process is a search for the best compromise.</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-loss_scatter_epochs.webp"
         alt="Scatter plot of Test BCE vs KL Divergence, showing the training path from epoch 0 to 150"
         title="Scatter plot of Test BCE vs KL Divergence, showing the training path from epoch 0 to 150"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The training path on the Test set, plotting Reconstruction Loss (BCE) vs. KL Divergence. The model&rsquo;s journey clearly shows the trade-off between these two objectives.</figcaption>
    
</figure>

<p>This plot shows the test set&rsquo;s BCE (y-axis) vs. KL Divergence (x-axis) at every evaluation step. The color gradient from cool (blue) to warm (red) represents the training progress from Epoch 0 to 150.</p>
<p>Here&rsquo;s how to interpret this training path:</p>
<ol>
<li>
<p><strong>The Start (Green Diamond, ~Epoch 0):</strong> The model starts at the top-left.</p>
<ul>
<li><strong>High BCE (Reconstruction):</strong> The decoder is random and hasn&rsquo;t learned to reconstruct anything. Reconstruction is terrible.</li>
<li><strong>Low KL Divergence:</strong> The <em>encoder</em> is also untrained. With a typical small-weight initialization, its outputs barely depend on the input: the means and log-variances both sit near 0, so every $q_{\phi}(\mathbf{z} | \mathbf{x})$ is already close to the prior $p_{\theta}(\mathbf{z})$ and the KL penalty is low. The model isn&rsquo;t encoding any useful information yet, so it&rsquo;s not paying a high price for it.</li>
</ul>
</li>
<li>
<p><strong>Phase 1: The Initial Plunge (Blue Path):</strong> The path moves almost <em>straight down</em>.</p>
<ul>
<li><strong>BCE Plummets:</strong> The model&rsquo;s first and easiest task is to learn to reconstruct <em>something</em>. The optimizer finds massive, easy gains by making the decoder output &ldquo;blurry digits&rdquo; to replace the initial noise.</li>
<li><strong>KL Stays Low:</strong> The model achieves this huge reconstruction win without needing to learn a very complex latent space. It&rsquo;s the &ldquo;low-hanging fruit&rdquo; of training.</li>
</ul>
</li>
<li>
<p><strong>Phase 2: The Trade-Off (The &ldquo;Elbow&rdquo;):</strong> The path stops dropping vertically and starts moving to the <em>right and down</em>.</p>
<ul>
<li><strong>&ldquo;Spending&rdquo; KL to &ldquo;Buy&rdquo; Reconstruction:</strong> This is the true VAE trade-off in action. The easy wins are gone. To make the reconstructions sharper and more accurate (lowering BCE further), the model must now learn a more complex, informative latent representation.</li>
<li>It &ldquo;stretches&rdquo; the latent distributions $q_{\phi}(\mathbf{z} | \mathbf{x})$ to encode more details about each specific digit. This &ldquo;stretching&rdquo; moves it further from the simple <code>N(0, 1)</code> prior, and the KL divergence (the &ldquo;cost&rdquo;) goes up.</li>
</ul>
</li>
<li>
<p><strong>The End Game (Red Path &amp; Star):</strong> The path settles in the bottom-right corner.</p>
<ul>
<li><strong>Finding the &ldquo;Elbow&rdquo;:</strong> The model finds an equilibrium. It has pushed the KL divergence as high as it&rsquo;s &ldquo;worth&rdquo; for the reconstruction gains it gets. Trying to get even better reconstruction (moving further down) would cost an enormous, disproportionate amount in KL divergence (moving far to the right), and the total loss would increase.</li>
<li><strong>Best Recon (Orange Star):</strong> The best reconstruction model (Epoch 118) is found right at this &ldquo;elbow,&rdquo; representing the best-found balance point on the trade-off frontier.</li>
</ul>
</li>
</ol>
<p>This single plot visualizes the entire training dynamic as a journey along the <strong>Pareto frontier</strong>: the set of optimal solutions where you can&rsquo;t improve one objective (BCE) without worsening the other (KL).</p>
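<p>As a rough illustration of what &ldquo;Pareto frontier&rdquo; means here, the sketch below filters a list of (KL, BCE) checkpoints down to the non-dominated ones. The checkpoint values are hypothetical numbers shaped loosely like the path described above, not measurements from this run, and <code>pareto_front</code> is a helper name introduced for this example.</p>

```python
def pareto_front(points):
    """Return the points not dominated by any other point (both objectives minimized)."""
    front = []
    for p in points:
        dominated = any(
            q != p and q[0] <= p[0] and q[1] <= p[1]
            for q in points
        )
        if not dominated:
            front.append(p)
    return sorted(front)


# Hypothetical (KL, BCE) checkpoints along a training path.
checkpoints = [
    (1.0, 540.0),  # start: untrained decoder
    (1.0, 210.0),  # after the initial plunge (same KL, much lower BCE)
    (3.5, 165.0),  # trade-off phase
    (5.5, 148.0),
    (6.7, 140.0),  # end game
    (6.9, 141.0),  # slightly worse on both axes than the previous point
]

front = pareto_front(checkpoints)
for kl, bce in front:
    print(f"KL={kl:4.1f}  BCE={bce:6.1f}")
```

<p>In this toy data, the starting point and the final noisy checkpoint are dominated and drop out. The remaining points are the ones where lowering BCE further requires paying more KL.</p>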
<h4 id="generative-performance">Generative Performance</h4>
<p>Let&rsquo;s take a look at how well this model can decode samples.</p>
<p><strong>Reconstruction Performance</strong></p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-reconstructions.webp"
         alt="Grid of original and reconstructed MNIST images from the test set using the trained VAE model"
         title="Grid of original and reconstructed MNIST images from the test set using the trained VAE model"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Original (top row) vs. Reconstructed (bottom row) MNIST images from the test set using the trained VAE model.</figcaption>
    
</figure>

<p>Immediately, we see a couple of key points:</p>
<ul>
<li>Reconstructions are quite blurry compared to the originals. This is expected given the low capacity of the model and the extreme compression into a 2D latent space. General structure is typically preserved, while fine details are lost.</li>
<li>The network struggles with 4s and 9s, often mixing them up or producing ambiguous shapes. This is a common failure mode in MNIST models due to the similarity of these digits.</li>
</ul>
<p><strong>Sampling from the Prior</strong></p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-samples.webp"
         alt="Grid of MNIST-like images generated by sampling from the prior distribution using the trained VAE model"
         title="Grid of MNIST-like images generated by sampling from the prior distribution using the trained VAE model"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">MNIST-like images generated by sampling from the prior distribution using the trained VAE model.</figcaption>
    
</figure>

<p>If we sample from the prior <code>N(0, 1)</code> and decode those samples, we get a variety of digit-like images. Almost every digit class appears in this random sample, so the prior covers a rich range of digits. Again, we see the standard blurriness.</p>
<p><strong>Sweeping the Latent Space</strong></p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-generation_interpolation.webp"
         alt="Grid of images generated by sweeping across the 2D latent space of the trained VAE model"
         title="Grid of images generated by sweeping across the 2D latent space of the trained VAE model"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Images generated by sweeping across the 2D latent space of the trained VAE model.</figcaption>
    
</figure>

<p>We can select two data points at random (here, two zeros), embed them into our latent space, and then walk across that latent space to interpolate between them.
Here, the walk takes us from a zero that is askew to one that is more upright.</p>
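<p>A walk like this can be sketched as linear interpolation between two latent codes. This assumes a straight-line path, which is one common choice and not necessarily the one used for the figure. The latent coordinates below are made up, and each interpolated code would then be passed through the decoder (not shown).</p>

```python
def interpolate_latents(z_start, z_end, num_steps):
    """Return num_steps latent codes evenly spaced on the segment from z_start to z_end."""
    if num_steps < 2:
        raise ValueError("num_steps must be at least 2")
    path = []
    for i in range(num_steps):
        t = i / (num_steps - 1)  # runs from 0.0 to 1.0 inclusive
        path.append([(1.0 - t) * a + t * b for a, b in zip(z_start, z_end)])
    return path


# Hypothetical posterior means for two encoded zeros in a 2D latent space.
z_a = [-1.2, 0.4]
z_b = [-0.6, 1.0]

path = interpolate_latents(z_a, z_b, num_steps=5)
for z in path:
    print([round(v, 3) for v in z])
```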















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-generation_latent_sweep.webp"
         alt="2D latent sweep, varying one dimension at a time while holding the other constant"
         title="2D latent sweep, varying one dimension at a time while holding the other constant"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">2D latent sweep, varying one dimension at a time while holding the other constant.</figcaption>
    
</figure>

<p>Finally, we can also sweep each latent dimension independently to see how they affect the generated images.</p>
<ol>
<li>Sweeping <code>z_1</code> (top row), we see a 5 become an 8 and then a 9. The slant shifts from left to right as we sweep the dimension.</li>
<li>Sweeping <code>z_2</code> (bottom row), we see a 4 become a 9 and then an 8. Then it becomes a 3, a 2, some nonsense, and a 6.</li>
</ol>
<p>So clearly each latent dimension is encoding some high-level features of the digits, and we can manipulate those features by moving in latent space.</p>
<h4 id="inspecting-the-latent-space">Inspecting the Latent Space</h4>
<p>What does the actual latent space look like?</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-latent_combined.webp"
         alt="2D latent space visualization with points colored by their true digit labels"
         title="2D latent space visualization with points colored by their true digit labels"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">2D latent space visualization with points colored by their true digit labels (left) and 2D heatmap of latent space density (right).</figcaption>
    
</figure>

<p>Even without class information, the network organizes the latent space to encode digit structure effectively. It also becomes immediately apparent why the model confuses 4s and 9s so often: their region of the latent space is a dense mixture of the two.</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-latent_marginals.webp"
         alt="1D histograms of each latent dimension compared to the standard normal distribution"
         title="1D histograms of each latent dimension compared to the standard normal distribution"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">1D histograms of each latent dimension compared to the standard normal distribution.</figcaption>
    
</figure>

<p>We can also look at the marginal distributions of each latent dimension to see how well they match the prior <code>N(0, 1)</code>. Here, <code>z_1</code> is closer to the prior than <code>z_2</code>. <code>z_2</code> exhibits a bimodal marginal distribution, indicating that the encoder is using this dimension to separate two distinct clusters of data.</p>
<p>We also might want to understand how the log-variance of the latent distributions behaves.</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-logvar_combined.webp"
         alt="2D latent space visualization with log-variance values and 1D histograms of log-variance for each latent dimension"
         title="2D latent space visualization with log-variance values and 1D histograms of log-variance for each latent dimension"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">2D latent space visualization with log-variance values with respect to digit class (left) and 2D heatmap of log-variance magnitude (right).</figcaption>
    
</figure>

<p>For the most part, the posterior log-variances are similar across the latent space. Some digits are encoded with slightly more concentrated posteriors than others, though in general the difference is slight.</p>
<h3 id="beyond-2d-higher-dimensional-latent-spaces">Beyond 2D: Higher-Dimensional Latent Spaces</h3>
<p>What happens as we increase the latent dimensionality? Beyond 2D, we must use dimensionality reduction to visualize the latent space, which gives us only an approximate sense of how it is organized.</p>
<table>
  <thead>
      <tr>
          <th>Latent Dimensionality</th>
          <th>Test Recon. Loss</th>
          <th>Test KL Loss</th>
          <th>Test Total Loss</th>
          <th>KL per Dim</th>
          <th>Active Dims (KL &gt; 0.1)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>2</td>
          <td>140.37</td>
          <td>6.67</td>
          <td>147.04</td>
          <td>3.34</td>
          <td>2</td>
      </tr>
      <tr>
          <td>4</td>
          <td>114.94</td>
          <td>10.61</td>
          <td>125.56</td>
          <td>2.65</td>
          <td>4</td>
      </tr>
      <tr>
          <td>8</td>
          <td>89.46</td>
          <td>16.84</td>
          <td>106.31</td>
          <td>2.11</td>
          <td>8</td>
      </tr>
      <tr>
          <td>16</td>
          <td>76.57</td>
          <td>23.65</td>
          <td>100.21</td>
          <td>1.48</td>
          <td>16</td>
      </tr>
      <tr>
          <td>32</td>
          <td>74.65</td>
          <td>25.59</td>
          <td>100.25</td>
          <td>0.80</td>
          <td>24</td>
      </tr>
  </tbody>
</table>
<p>As we double the dimensionality, we see a dominant trend at first:</p>
<ul>
<li>The reconstruction loss goes down</li>
<li>The KL loss goes up</li>
<li>The KL loss per dimension goes down</li>
</ul>
<p>Something odd happens when we jump from 16 to 32 latent dimensions: only 24 of the 32 dimensions stay active, and the total test loss stops improving (100.21 vs. 100.25). The remaining dimensions become degenerate and stop encoding useful information.
This could be an indication that we need to choose our hyperparameters a little more cautiously. Perhaps we need a different architecture. Or maybe there is an intrinsic limit to the dimensionality needed for this dataset, past which it&rsquo;s not really helpful to keep scaling the latent dimension.</p>
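<p>The last two table columns can be reproduced with a few lines of Python. In the sketch below, the total KL values are copied from the table, while the per-dimension vector for the 32-D model is a hypothetical breakdown (only the 24 active / 8 inactive split comes from the table). <code>count_active_dims</code> is a helper name introduced for this example.</p>

```python
def count_active_dims(kl_per_dim, threshold=0.1):
    """Count latent dimensions whose average KL exceeds the threshold."""
    return sum(1 for kl in kl_per_dim if kl > threshold)


# Test KL totals copied from the table above, keyed by latent dimensionality.
test_kl = {2: 6.67, 4: 10.61, 8: 16.84, 16: 23.65, 32: 25.59}
mean_kl_per_dim = {dim: kl / dim for dim, kl in test_kl.items()}

# Hypothetical per-dimension KL for the 32-D model: 24 informative
# dimensions plus 8 collapsed ones that sit almost exactly on the prior.
hypothetical_kl_32 = [1.063] * 24 + [0.01] * 8

for dim, value in mean_kl_per_dim.items():
    print(f"D={dim:>2}: mean KL per dim = {value:.2f}")
print("Active dims (hypothetical 32-D vector):", count_active_dims(hypothetical_kl_32))
```

<p>Note that the mean KL per dimension averages over inactive dimensions too, so the 0.80 figure for the 32-D model understates how much each active dimension carries.</p>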
<h4 id="training-dynamics">Training Dynamics</h4>















<figure class="post-figure center ">
    <img src="/img/vae-tut/loss_scatter_epochs.webp"
         alt="Scatter plot of Test BCE vs KL Divergence for different latent dimensionalities, showing training paths from epoch 0 to 150"
         title="Scatter plot of Test BCE vs KL Divergence for different latent dimensionalities, showing training paths from epoch 0 to 150"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The training paths on the Test set for different latent dimensionalities, plotting Reconstruction Loss (BCE) vs. KL Divergence. Each path shows the model&rsquo;s journey, clearly illustrating the trade-off between these two objectives.</figcaption>
    
</figure>

<p>The training dynamics show the battle between reconstruction and KL divergence for different latent dimensionalities. As we increase the latent dimensionality, the oscillation in the KL divergence becomes more pronounced. Particularly chaotic is the $D=16$ case, which struggles to find a stable equilibrium. By the time we expand to $D=32$, the KL penalty seems to overpower the ability to encode information in the latent space, leading to many inactive dimensions. There, the KL term drops in staircase-like steps without a clear gain in reconstruction quality.</p>
<h4 id="reconstruction-and-generation">Reconstruction and Generation</h4>
<p>As we increase the latent dimensionality, the reconstruction quality improves significantly.</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/reconstructions.webp"
         alt="Grid of original and reconstructed MNIST images from the test set using trained VAE models with different latent dimensionalities"
         title="Grid of original and reconstructed MNIST images from the test set using trained VAE models with different latent dimensionalities"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Original (top row) vs. Reconstructed (bottom row) MNIST images from the test set using trained VAE models with different latent dimensionalities.</figcaption>
    
</figure>

<p>As we increase the dimensionality, we see the increase in quality we&rsquo;d expect given the reduction in BCE reconstruction loss. In the jump to 4D, we&rsquo;re able to better resolve the differences between 4s and 9s. Images become much sharper by the time we hit 16 dimensions. The differences between 16 and 32 dimensions, however, are marginal.</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/samples.webp"
         alt="Grid of MNIST-like images generated by sampling from the prior distribution using trained VAE models with different latent dimensionalities"
         title="Grid of MNIST-like images generated by sampling from the prior distribution using trained VAE models with different latent dimensionalities"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">MNIST-like images generated by sampling from the prior distribution using trained VAE models with different latent dimensionalities.</figcaption>
    
</figure>

<p>Sampling quality also improves with latent dimensionality. Images are sharper as we increase the dimensionality. However, the space seems to get sparser as we increase to the largest dimensionalities, which makes sense given the size and nature of our dataset.</p>
<h4 id="latent-space-visualizations">Latent Space Visualizations</h4>















<figure class="post-figure center ">
    <img src="/img/vae-tut/latent_combined.webp"
         alt="2D PCA projections of higher-dimensional latent spaces colored by their true digit labels"
         title="2D PCA projections of higher-dimensional latent spaces colored by their true digit labels"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">2D PCA projections of higher-dimensional latent spaces colored by their true digit labels.</figcaption>
    
</figure>

<p>The challenge with visualizing higher-dimensional latent spaces is that we must reduce their dimensionality to 2D. The more latent dimensions there are, the smaller the share of the total variance that two principal components can capture. The 4D and 8D plots suggest increasingly better separation of the numeric classes. However, the 16D and 32D plots only show 10-20% of the variance and give a misleading image of overlap.</p>
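<p>A back-of-the-envelope baseline shows why so little variance survives the projection. Assume, purely for illustration, that the encoded points were perfectly isotropic, with equal variance along every latent axis. A trained VAE only approximates this, and real spectra are not flat. Under that assumption, the fraction of variance kept by the top two principal components is simply 2 divided by the dimensionality.</p>

```python
def top_k_variance_fraction(variances, k=2):
    """Fraction of total variance captured by the k largest principal-axis variances."""
    ordered = sorted(variances, reverse=True)
    return sum(ordered[:k]) / sum(ordered)


# Flat spectrum (assumed): every principal axis carries the same variance.
flat_baseline = {
    dim: top_k_variance_fraction([1.0] * dim)
    for dim in (4, 8, 16, 32)
}

for dim, fraction in flat_baseline.items():
    print(f"D={dim:>2}: top-2 PCs keep {fraction:.1%} of the variance")
```

<p>Any structure in the data concentrates variance in the leading components and raises these fractions, but the baseline makes clear that a 2D view of a 16D or 32D space necessarily hides most of what the model has learned.</p>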
<h2 id="conclusion">Conclusion</h2>
<p>In this tutorial, we&rsquo;ve journeyed from the core theory of Variational Autoencoders to a practical, modern PyTorch implementation and a series of experiments on the MNIST dataset. Our findings highlight several key takeaways for practitioners:</p>
<ol>
<li>
<p><strong>The VAE is a Balancing Act:</strong> The fundamental tension between reconstruction fidelity and latent space regularization is the core of the VAE. Our visualization of the BCE vs. KL loss trade-off clearly showed training as a search for an optimal point on this Pareto frontier, where improving one objective necessarily means sacrificing the other.</p>
</li>
<li>
<p><strong>Latent Dimensionality is a Critical Hyperparameter:</strong> Increasing the latent dimension consistently improved reconstruction quality with diminishing returns. As we saw in the jump from 16 to 32 dimensions, too much capacity can lead to &ldquo;inactive&rdquo; dimensions, where the KL penalty overpowers the model&rsquo;s ability to encode useful information. This demonstrates that choosing the right latent size is crucial for both performance and efficiency.</p>
</li>
<li>
<p><strong>VAEs Learn Meaningful Unsupervised Representations:</strong> Without any labels, our VAE successfully organized the latent space, clustering similar digits and enabling smooth interpolations. This underscores the power of VAEs for unsupervised learning, dimensionality reduction, and discovering the underlying structure in complex data.</p>
</li>
<li>
<p><strong>Implementation Details Matter:</strong> While different standard deviation parameterizations yielded similar results on this simple problem, understanding their gradient behaviors is key for tackling more complex datasets where training stability can be a major challenge. Proper loss scaling is similarly crucial to prevent one term from dominating the other and leading to issues like posterior collapse.</p>
</li>
</ol>
<p>While the classic VAE produces characteristically blurry reconstructions, it remains a foundational generative model. The principles we&rsquo;ve explored here (the ELBO, the reparameterization trick, and the trade-off between reconstruction and regularization) are central to many more advanced generative models used today.</p>
<p><strong>Questions or feedback?</strong> Feel free to reach out. I&rsquo;d love to hear about your experiences with VAE experiments!</p>
]]></content:encoded></item><item><title>Hearing Molecular Shape via Coulomb Matrix Eigenvalues</title><link>https://hunterheidenreich.com/posts/alkane-constitutional-isomer-classification/</link><pubDate>Sat, 24 Feb 2024 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/alkane-constitutional-isomer-classification/</guid><description>Explore molecular shape recognition using Coulomb matrix eigenvalues. An analysis of alkane isomers, clustering limits, and supervised classification.</description><content:encoded><![CDATA[<h2 id="introduction">Introduction</h2>
<p>Can you determine a molecule&rsquo;s shape from mathematical fingerprints alone? This question drives some of the most fundamental challenges in computational chemistry and machine learning. In the broader ML context, this is the classic search for the right <em>inductive bias</em> or <em>invariant representation</em>. Whether we are processing messy documents, natural language, or molecular dynamics, finding a representation that captures essential structure while ignoring irrelevant variations is critical.</p>
<p>I recently encountered a paper with an intriguing title: <a href="https://doi.org/10.1021/acs.jcim.0c00631">&ldquo;Can One Hear the Shape of a Molecule (from its Coulomb Matrix Eigenvalues)?&rdquo;</a> The title references Mark Kac&rsquo;s famous mathematical question <a href="https://www.math.ucdavis.edu/~hunter/m207b/kac.pdf">&ldquo;Can One Hear the Shape of a Drum?&rdquo;</a>, which asks whether a drum&rsquo;s sound frequencies uniquely determine its shape.</p>
<p>The molecular version asks: can we determine a molecule&rsquo;s structure from the eigenvalues of its <a href="/posts/molecular-descriptor-coulomb-matrix/">Coulomb matrix</a>?</p>
<p>Molecular representations are the foundation of machine learning in chemistry. If eigenvalues can capture structural information, they become powerful features for property prediction. Successfully separating simple structural differences is a prerequisite for handling more complex molecules.</p>
<p>The original authors tested this hypothesis using alkane constitutional isomers (molecules with identical formulas but different structural arrangements). I decided to replicate and extend their work to better understand both the methods and their limitations.</p>
<p>In this post, we will explore molecular representation through eigenvalue analysis, covering data generation, unsupervised clustering approaches, and supervised classification methods. I&rsquo;ll also explore log-transformed Coulomb matrices, which can reveal structural details that standard matrices miss.</p>
<h2 id="why-alkanes-make-ideal-test-cases">Why Alkanes Make Ideal Test Cases</h2>
<p><a href="https://en.wikipedia.org/wiki/Alkane">Alkanes</a> are the simplest organic molecules: carbon and hydrogen connected by single bonds with the general formula $C_{n}H_{2n+2}$.</p>
<p>What makes them perfect for testing molecular representations is their constitutional isomers: molecules with identical formulas but different structural arrangements. For small alkanes ($n \leq 3$), atoms can connect in only one way. Starting with butane ($n = 4$), multiple arrangements become possible:</p>















<figure class="post-figure center ">
    <img src="/img/alkane-constitutional-isomers/4-Butane-3D-balls.webp"
         alt="Butane as a ball-and-stick model."
         title="Butane as a ball-and-stick model."
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Butane: a linear chain</figcaption>
    
</figure>
















<figure class="post-figure center ">
    <img src="/img/alkane-constitutional-isomers/4-Isobutane-3D-balls.webp"
         alt="Isobutane as a ball-and-stick model."
         title="Isobutane as a ball-and-stick model."
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Isobutane: a branched structure</figcaption>
    
</figure>

<p>The number of isomers grows rapidly with molecular size. By undecane ($n = 11$), there are 159 different structural arrangements:</p>
<table>
  <thead>
      <tr>
          <th>Alkane</th>
          <th>n</th>
          <th>Isomers</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Butane</td>
          <td>4</td>
          <td>2</td>
      </tr>
      <tr>
          <td>Pentane</td>
          <td>5</td>
          <td>3</td>
      </tr>
      <tr>
          <td>Hexane</td>
          <td>6</td>
          <td>5</td>
      </tr>
      <tr>
          <td>Heptane</td>
          <td>7</td>
          <td>9</td>
      </tr>
      <tr>
          <td>Octane</td>
          <td>8</td>
          <td>18</td>
      </tr>
      <tr>
          <td>Nonane</td>
          <td>9</td>
          <td>35</td>
      </tr>
      <tr>
          <td>Decane</td>
          <td>10</td>
          <td>75</td>
      </tr>
      <tr>
          <td>Undecane</td>
          <td>11</td>
          <td>159</td>
      </tr>
  </tbody>
</table>
<p>This creates a natural classification challenge: can Coulomb matrix eigenvalues distinguish these structural differences? Successfully separating simple alkane isomers is a prerequisite for handling more complex molecules.</p>
<h2 id="computational-pipeline">Computational Pipeline</h2>
<p>The analysis requires three computational steps:</p>
<ol>
<li><strong>Generate constitutional isomers</strong> for each alkane formula</li>
<li><strong>Create multiple 3D conformations</strong> for each isomer</li>
<li><strong>Calculate Coulomb matrix eigenvalues</strong> for each conformation</li>
</ol>
<h3 id="generating-constitutional-isomers">Generating Constitutional Isomers</h3>
<p>Enumerating all possible carbon skeletons is a combinatorial problem. I used <a href="https://github.com/MehmetAzizYirik/MAYGEN">MAYGEN</a>, an open-source Java tool for generating molecular structures from chemical formulas.</p>
<p>For butane ($C_{4}H_{10}$):</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bash" data-lang="bash"><span style="display:flex;"><span>java -jar MAYGEN-1.8.jar -v -m -f C4H10 -smi -o butane_conformers.smi
</span></span></code></pre></div><p>This generates:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-plaintext" data-lang="plaintext"><span style="display:flex;"><span>CCCC
</span></span><span style="display:flex;"><span>CC(C)C
</span></span></code></pre></div><p>The first is n-butane (linear), the second is isobutane (branched). We can automate this across all alkanes:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> os
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>os<span style="color:#f92672">.</span>makedirs(<span style="color:#e6db74">&#39;isomers&#39;</span>, exist_ok<span style="color:#f92672">=</span><span style="color:#66d9ef">True</span>)
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">for</span> n <span style="color:#f92672">in</span> range(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">12</span>):
</span></span><span style="display:flex;"><span>    cmd <span style="color:#f92672">=</span> <span style="color:#e6db74">f</span><span style="color:#e6db74">&#34;java -jar MAYGEN-1.8.jar -f C</span><span style="color:#e6db74">{</span>n<span style="color:#e6db74">}</span><span style="color:#e6db74">H</span><span style="color:#e6db74">{</span><span style="color:#ae81ff">2</span><span style="color:#f92672">*</span>n <span style="color:#f92672">+</span> <span style="color:#ae81ff">2</span><span style="color:#e6db74">}</span><span style="color:#e6db74"> -smi -o isomers/C</span><span style="color:#e6db74">{</span>n<span style="color:#e6db74">}</span><span style="color:#e6db74">H</span><span style="color:#e6db74">{</span><span style="color:#ae81ff">2</span><span style="color:#f92672">*</span>n <span style="color:#f92672">+</span> <span style="color:#ae81ff">2</span><span style="color:#e6db74">}</span><span style="color:#e6db74">.smi&#34;</span>
</span></span><span style="display:flex;"><span>    os<span style="color:#f92672">.</span>system(cmd)
</span></span></code></pre></div><h3 id="generating-3d-conformations">Generating 3D Conformations</h3>
<p>For machine learning applications, we need multiple 3D structures of each isomer to capture conformational flexibility. I used <a href="https://github.com/rdkit/rdkit">RDKit</a>&rsquo;s ETKDG method, which <a href="https://doi.org/10.1021/acs.jcim.7b00505">remains competitive</a> with commercial alternatives:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> rdkit.Chem
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> rdkit.Chem <span style="color:#f92672">import</span> AllChem
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">smiles_str_to_rdkit_mol</span>(smiles_str: str) <span style="color:#f92672">-&gt;</span> rdkit<span style="color:#f92672">.</span>Chem<span style="color:#f92672">.</span>Mol:
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;Convert a SMILES string to an RDKit mol object.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    Args:
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    - smiles_str (str): A SMILES string representing a molecule.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    Returns:
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    - mol (rdkit.Chem.Mol): An RDKit mol object representing the molecule.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    &#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Convert SMILES string to RDKit mol object</span>
</span></span><span style="display:flex;"><span>    mol <span style="color:#f92672">=</span> rdkit<span style="color:#f92672">.</span>Chem<span style="color:#f92672">.</span>MolFromSmiles(smiles_str)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Add hydrogens to the molecule</span>
</span></span><span style="display:flex;"><span>    mol <span style="color:#f92672">=</span> rdkit<span style="color:#f92672">.</span>Chem<span style="color:#f92672">.</span>AddHs(mol)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Assign 3D coordinates to the molecule</span>
</span></span><span style="display:flex;"><span>    AllChem<span style="color:#f92672">.</span>EmbedMolecule(mol)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> mol
</span></span></code></pre></div><h3 id="computing-coulomb-matrix-eigenvalues">Computing Coulomb Matrix Eigenvalues</h3>
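<p>The Coulomb matrix encodes 3D structure in a rotation- and translation-invariant way, because its entries depend only on nuclear charges and interatomic distances. Taking its eigenvalues additionally removes the dependence on atom ordering, so they should capture structural information while remaining invariant to molecular orientation.</p>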
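<p>To see where the invariance comes from, here is a small pure-Python sketch of the standard Coulomb matrix definition: $0.5 Z_i^{2.4}$ on the diagonal and $Z_i Z_j / |R_i - R_j|$ off the diagonal. It is not the DScribe implementation used in this post. The three-atom geometry is a toy with made-up coordinates, and the distance unit is left unspecified because conventions differ between implementations.</p>

```python
import math


def coulomb_matrix(charges, positions):
    """Coulomb matrix: 0.5 * Z_i**2.4 on the diagonal, Z_i * Z_j / |R_i - R_j| elsewhere."""
    n = len(charges)
    matrix = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            if i == j:
                matrix[i][j] = 0.5 * charges[i] ** 2.4
            else:
                distance = math.dist(positions[i], positions[j])
                matrix[i][j] = charges[i] * charges[j] / distance
    return matrix


def rigid_move(positions, shift):
    """Rotate 90 degrees about the z-axis, then translate by `shift`."""
    return [
        (-y + shift[0], x + shift[1], z + shift[2])
        for x, y, z in positions
    ]


# Toy three-atom geometry (hypothetical coordinates, not an optimized structure).
charges = [6, 1, 1]
positions = [(0.0, 0.0, 0.0), (1.09, 0.0, 0.0), (-0.36, 1.03, 0.0)]

original = coulomb_matrix(charges, positions)
moved = coulomb_matrix(charges, rigid_move(positions, shift=(1.5, -2.0, 0.25)))

unchanged = all(
    math.isclose(a, b, rel_tol=1e-9)
    for row_a, row_b in zip(original, moved)
    for a, b in zip(row_a, row_b)
)
print("Matrix unchanged by rotation + translation:", unchanged)
```

<p>Reordering the atoms would still permute the rows and columns of this matrix, which is why the pipeline works with its eigenvalues instead of the raw entries.</p>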
<p>First, I wrote a helper function to convert RDKit molecules into <a href="https://ase-lib.org/">ASE</a> <code>Atoms</code> objects that <a href="https://singroup.github.io/dscribe/">DScribe</a> can process:</p>
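<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> ase
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> rdkit.Chem
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">rdkit_mol_to_ase_atoms</span>(rdkit_mol: rdkit<span style="color:#f92672">.</span>Chem<span style="color:#f92672">.</span>Mol) <span style="color:#f92672">-&gt;</span> ase<span style="color:#f92672">.</span>Atoms: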
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;Convert an RDKit molecule to an ASE Atoms object.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    Args:
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">        rdkit_mol: RDKit molecule object.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    Returns:
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">        ASE Atoms object.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    &#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>    ase_atoms <span style="color:#f92672">=</span> ase<span style="color:#f92672">.</span>Atoms(
</span></span><span style="display:flex;"><span>        numbers<span style="color:#f92672">=</span>[
</span></span><span style="display:flex;"><span>            atom<span style="color:#f92672">.</span>GetAtomicNum() <span style="color:#66d9ef">for</span> atom <span style="color:#f92672">in</span> rdkit_mol<span style="color:#f92672">.</span>GetAtoms()
</span></span><span style="display:flex;"><span>        ],
</span></span><span style="display:flex;"><span>        positions<span style="color:#f92672">=</span>rdkit_mol<span style="color:#f92672">.</span>GetConformer()<span style="color:#f92672">.</span>GetPositions()
</span></span><span style="display:flex;"><span>    )
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> ase_atoms
</span></span></code></pre></div><p>Then I computed Coulomb matrix eigenvalues using DScribe, with an optional log transformation:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">ase_atoms_to_coloumb_matrix_eigenvalues</span>(
</span></span><span style="display:flex;"><span>    ase_atoms: ase<span style="color:#f92672">.</span>Atoms,
</span></span><span style="display:flex;"><span>    log: bool <span style="color:#f92672">=</span> <span style="color:#66d9ef">False</span>
</span></span><span style="display:flex;"><span>) <span style="color:#f92672">-&gt;</span> np<span style="color:#f92672">.</span>ndarray:
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;Convert an ASE Atoms object to a Coulomb matrix and calculate its eigenvalues.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    Args:
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">        ase_atoms: ASE Atoms object.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">        log: Whether to log transform the Coulomb matrix prior to calculating the eigenvalues.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    Returns:
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">        Eigenvalues of the Coulomb matrix.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    &#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Create a Coulomb matrix</span>
</span></span><span style="display:flex;"><span>    coulomb_matrix <span style="color:#f92672">=</span> dscribe<span style="color:#f92672">.</span>descriptors<span style="color:#f92672">.</span>CoulombMatrix(
</span></span><span style="display:flex;"><span>        n_atoms_max<span style="color:#f92672">=</span>ase_atoms<span style="color:#f92672">.</span>get_global_number_of_atoms(),
</span></span><span style="display:flex;"><span>    )
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Calculate the Coulomb matrix</span>
</span></span><span style="display:flex;"><span>    coulomb_matrix <span style="color:#f92672">=</span> coulomb_matrix<span style="color:#f92672">.</span>create(ase_atoms)
</span></span><span style="display:flex;"><span>    coulomb_matrix <span style="color:#f92672">=</span> coulomb_matrix<span style="color:#f92672">.</span>reshape(
</span></span><span style="display:flex;"><span>        ase_atoms<span style="color:#f92672">.</span>get_global_number_of_atoms(),
</span></span><span style="display:flex;"><span>        ase_atoms<span style="color:#f92672">.</span>get_global_number_of_atoms())
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">if</span> log:
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Log transform the Coulomb matrix</span>
</span></span><span style="display:flex;"><span>        coulomb_matrix <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>log(coulomb_matrix)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Calculate the eigenvalues of the Coulomb matrix</span>
</span></span><span style="display:flex;"><span>    eigenvalues <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>linalg<span style="color:#f92672">.</span>eigvals(coulomb_matrix)
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> eigenvalues
</span></span></code></pre></div><p>Combining these functions enables efficient data generation:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#75715e"># Generate 1000 conformations per isomer for each alkane</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># gen_n_spectra() combines the above functions to generate multiple conformations</span>
</span></span><span style="display:flex;"><span>os<span style="color:#f92672">.</span>makedirs(<span style="color:#e6db74">&#39;spectra&#39;</span>, exist_ok<span style="color:#f92672">=</span><span style="color:#66d9ef">True</span>)
</span></span><span style="display:flex;"><span>n_confs <span style="color:#f92672">=</span> <span style="color:#ae81ff">1000</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">for</span> n <span style="color:#f92672">in</span> range(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">12</span>):
</span></span><span style="display:flex;"><span>    print(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#39;Generating spectra for C</span><span style="color:#e6db74">{</span>n<span style="color:#e6db74">}</span><span style="color:#e6db74">H</span><span style="color:#e6db74">{</span><span style="color:#ae81ff">2</span><span style="color:#f92672">*</span>n <span style="color:#f92672">+</span> <span style="color:#ae81ff">2</span><span style="color:#e6db74">}</span><span style="color:#e6db74">&#39;</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">with</span> open(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#39;isomers/C</span><span style="color:#e6db74">{</span>n<span style="color:#e6db74">}</span><span style="color:#e6db74">H</span><span style="color:#e6db74">{</span><span style="color:#ae81ff">2</span><span style="color:#f92672">*</span>n <span style="color:#f92672">+</span> <span style="color:#ae81ff">2</span><span style="color:#e6db74">}</span><span style="color:#e6db74">.smi&#39;</span>) <span style="color:#66d9ef">as</span> f:
</span></span><span style="display:flex;"><span>        smiles_list <span style="color:#f92672">=</span> [line<span style="color:#f92672">.</span>strip() <span style="color:#66d9ef">for</span> line <span style="color:#f92672">in</span> f]
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">for</span> i, smiles <span style="color:#f92672">in</span> enumerate(smiles_list):
</span></span><span style="display:flex;"><span>        spectra <span style="color:#f92672">=</span> gen_n_spectra(n_confs, smiles, log<span style="color:#f92672">=</span><span style="color:#66d9ef">False</span>)
</span></span><span style="display:flex;"><span>        np<span style="color:#f92672">.</span>save(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#39;spectra/C</span><span style="color:#e6db74">{</span>n<span style="color:#e6db74">}</span><span style="color:#e6db74">H</span><span style="color:#e6db74">{</span><span style="color:#ae81ff">2</span><span style="color:#f92672">*</span>n <span style="color:#f92672">+</span> <span style="color:#ae81ff">2</span><span style="color:#e6db74">}</span><span style="color:#e6db74">_</span><span style="color:#e6db74">{</span>i<span style="color:#e6db74">}</span><span style="color:#e6db74">.npy&#39;</span>, spectra)
</span></span></code></pre></div><h2 id="reproducing-the-original-results">Reproducing the Original Results</h2>
<p>To validate the computational pipeline, I replicated key figures from the original paper. Matching them is evidence that the implementation captures the phenomena the authors observed.</p>
<p>The full generation loop uses 1000 conformations per isomer, skips blank lines in each SMILES file, and zero-pads the isomer index in the output filename:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span>os<span style="color:#f92672">.</span>makedirs(<span style="color:#e6db74">&#39;spectra&#39;</span>, exist_ok<span style="color:#f92672">=</span><span style="color:#66d9ef">True</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>n_confs <span style="color:#f92672">=</span> <span style="color:#ae81ff">1000</span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">for</span> n <span style="color:#f92672">in</span> range(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">12</span>):
</span></span><span style="display:flex;"><span>    print(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#39;Generating spectra for C</span><span style="color:#e6db74">{</span>n<span style="color:#e6db74">}</span><span style="color:#e6db74">H</span><span style="color:#e6db74">{</span><span style="color:#ae81ff">2</span><span style="color:#f92672">*</span>n <span style="color:#f92672">+</span> <span style="color:#ae81ff">2</span><span style="color:#e6db74">}</span><span style="color:#e6db74">&#39;</span>)
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">with</span> open(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#39;isomers/C</span><span style="color:#e6db74">{</span>n<span style="color:#e6db74">}</span><span style="color:#e6db74">H</span><span style="color:#e6db74">{</span><span style="color:#ae81ff">2</span><span style="color:#f92672">*</span>n <span style="color:#f92672">+</span> <span style="color:#ae81ff">2</span><span style="color:#e6db74">}</span><span style="color:#e6db74">.smi&#39;</span>) <span style="color:#66d9ef">as</span> f:
</span></span><span style="display:flex;"><span>        lines <span style="color:#f92672">=</span> f<span style="color:#f92672">.</span>readlines()
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">for</span> i, line <span style="color:#f92672">in</span> enumerate(lines):
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">if</span> <span style="color:#f92672">not</span> line<span style="color:#f92672">.</span>strip():
</span></span><span style="display:flex;"><span>                <span style="color:#66d9ef">continue</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>            smiles <span style="color:#f92672">=</span> line<span style="color:#f92672">.</span>strip()
</span></span><span style="display:flex;"><span>            spectra <span style="color:#f92672">=</span> gen_n_spectra(n_confs, smiles, log<span style="color:#f92672">=</span><span style="color:#66d9ef">False</span>)
</span></span><span style="display:flex;"><span>            np<span style="color:#f92672">.</span>save(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#39;spectra/C</span><span style="color:#e6db74">{</span>n<span style="color:#e6db74">}</span><span style="color:#e6db74">H</span><span style="color:#e6db74">{</span><span style="color:#ae81ff">2</span><span style="color:#f92672">*</span>n <span style="color:#f92672">+</span> <span style="color:#ae81ff">2</span><span style="color:#e6db74">}</span><span style="color:#e6db74">_</span><span style="color:#e6db74">{</span>i<span style="color:#e6db74">:</span><span style="color:#e6db74">03d</span><span style="color:#e6db74">}</span><span style="color:#e6db74">.npy&#39;</span>, spectra)
</span></span></code></pre></div><p>After generation, we can load the data into a structured format:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> re
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> glob <span style="color:#f92672">import</span> glob
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>spectra <span style="color:#f92672">=</span> {}
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">for</span> n <span style="color:#f92672">in</span> range(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">12</span>):
</span></span><span style="display:flex;"><span>    spectra[n] <span style="color:#f92672">=</span> {}
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">for</span> f <span style="color:#f92672">in</span> glob(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#39;spectra/C</span><span style="color:#e6db74">{</span>n<span style="color:#e6db74">}</span><span style="color:#e6db74">H</span><span style="color:#e6db74">{</span><span style="color:#ae81ff">2</span><span style="color:#f92672">*</span>n <span style="color:#f92672">+</span> <span style="color:#ae81ff">2</span><span style="color:#e6db74">}</span><span style="color:#e6db74">_*.npy&#39;</span>):
</span></span><span style="display:flex;"><span>        j <span style="color:#f92672">=</span> int(re<span style="color:#f92672">.</span>search(<span style="color:#e6db74">rf</span><span style="color:#e6db74">&#39;C</span><span style="color:#e6db74">{</span>n<span style="color:#e6db74">}</span><span style="color:#e6db74">H</span><span style="color:#e6db74">{</span><span style="color:#ae81ff">2</span><span style="color:#f92672">*</span>n <span style="color:#f92672">+</span> <span style="color:#ae81ff">2</span><span style="color:#e6db74">}</span><span style="color:#e6db74">_(\d+)\.npy&#39;</span>, f)<span style="color:#f92672">.</span>group(<span style="color:#ae81ff">1</span>))
</span></span><span style="display:flex;"><span>        spectra[n][j] <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>load(f)
</span></span></code></pre></div><h3 id="largest-eigenvalues-across-alkane-series">Largest Eigenvalues Across Alkane Series</h3>
<p>The first analysis examines how the largest Coulomb matrix eigenvalue varies across constitutional isomers for each alkane formula. The plot shows how far a single eigenvalue goes toward separating molecular formulas, and how widely the isomers within each formula spread.</p>
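<p>The plotting code indexes column 0 of each stored spectrum as the largest eigenvalue. That only holds if every row is sorted in descending order, which <code>np.linalg.eigvals</code> does not guarantee on its own. Since <code>gen_n_spectra()</code> is not shown here, the following is a sketch of one way to get that ordering, assuming the matrix is real and symmetric; <code>descending_spectrum</code> is a hypothetical helper, not part of the pipeline above.</p>

```python
import numpy as np


def descending_spectrum(matrix: np.ndarray) -> np.ndarray:
    """Return the eigenvalues of a real symmetric matrix, largest first."""
    # eigvalsh exploits symmetry and returns real eigenvalues in ascending order
    return np.linalg.eigvalsh(matrix)[::-1]


# Two small symmetric matrices standing in for per-conformer Coulomb matrices
matrices = [
    np.array([[2.0, 1.0], [1.0, 2.0]]),  # eigenvalues 3 and 1
    np.array([[5.0, 0.0], [0.0, 4.0]]),  # eigenvalues 5 and 4
]

# One row per conformer, one column per eigenvalue rank
stacked = np.stack([descending_spectrum(m) for m in matrices])
largest = stacked[:, 0]
print(largest)
```

<p>With this layout, <code>stacked[:, 0]</code> plays the same role as <code>spectra[n][i][:, 0]</code> in the plots.</p>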
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span>fig, ax <span style="color:#f92672">=</span> plt<span style="color:#f92672">.</span>subplots(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">1</span>, figsize<span style="color:#f92672">=</span>(<span style="color:#ae81ff">10</span>, <span style="color:#ae81ff">5</span>))
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">for</span> n <span style="color:#f92672">in</span> range(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">12</span>):
</span></span><span style="display:flex;"><span>    eigenvalues <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>array([spectra[n][i][:, <span style="color:#ae81ff">0</span>]<span style="color:#f92672">.</span>mean() <span style="color:#66d9ef">for</span> i <span style="color:#f92672">in</span> spectra[n]])
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Scatter per-isomer means as black dots, jittered horizontally so they do not stack on the center line</span>
</span></span><span style="display:flex;"><span>    jitter <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>random<span style="color:#f92672">.</span>normal(<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">0.1</span>, size<span style="color:#f92672">=</span>len(eigenvalues))
</span></span><span style="display:flex;"><span>    ax<span style="color:#f92672">.</span>scatter(np<span style="color:#f92672">.</span>full(len(eigenvalues), n) <span style="color:#f92672">+</span> jitter, eigenvalues, color<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;black&#39;</span>, s<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>, alpha<span style="color:#f92672">=</span><span style="color:#ae81ff">0.3</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Plot median</span>
</span></span><span style="display:flex;"><span>    ax<span style="color:#f92672">.</span>scatter([n], [np<span style="color:#f92672">.</span>median(eigenvalues)], color<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;red&#39;</span>, alpha<span style="color:#f92672">=</span><span style="color:#ae81ff">0.5</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Plot range</span>
</span></span><span style="display:flex;"><span>    ax<span style="color:#f92672">.</span>plot([n <span style="color:#f92672">-</span> <span style="color:#ae81ff">0.5</span>, n <span style="color:#f92672">+</span> <span style="color:#ae81ff">0.5</span>], [np<span style="color:#f92672">.</span>min(eigenvalues), np<span style="color:#f92672">.</span>min(eigenvalues)], <span style="color:#e6db74">&#39;k-&#39;</span>, alpha<span style="color:#f92672">=</span><span style="color:#ae81ff">0.5</span>)
</span></span><span style="display:flex;"><span>    ax<span style="color:#f92672">.</span>plot([n <span style="color:#f92672">-</span> <span style="color:#ae81ff">0.5</span>, n <span style="color:#f92672">+</span> <span style="color:#ae81ff">0.5</span>], [np<span style="color:#f92672">.</span>max(eigenvalues), np<span style="color:#f92672">.</span>max(eigenvalues)], <span style="color:#e6db74">&#39;k-&#39;</span>, alpha<span style="color:#f92672">=</span><span style="color:#ae81ff">0.5</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>ax<span style="color:#f92672">.</span>set_xlabel(<span style="color:#e6db74">&#39;Molecular formula&#39;</span>)
</span></span><span style="display:flex;"><span>ax<span style="color:#f92672">.</span>set_ylabel(<span style="color:#e6db74">&#39;Largest eigenvalue&#39;</span>)
</span></span><span style="display:flex;"><span>ax<span style="color:#f92672">.</span>set_xticks(range(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">12</span>))
</span></span><span style="display:flex;"><span>ax<span style="color:#f92672">.</span>set_xticklabels([<span style="color:#e6db74">f</span><span style="color:#e6db74">&#39;C</span><span style="color:#e6db74">{</span>n<span style="color:#e6db74">}</span><span style="color:#e6db74">H</span><span style="color:#e6db74">{</span><span style="color:#ae81ff">2</span><span style="color:#f92672">*</span>n <span style="color:#f92672">+</span> <span style="color:#ae81ff">2</span><span style="color:#e6db74">}</span><span style="color:#e6db74">&#39;</span> <span style="color:#66d9ef">for</span> n <span style="color:#f92672">in</span> range(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">12</span>)])
</span></span><span style="display:flex;"><span>ax<span style="color:#f92672">.</span>set_title(<span style="color:#e6db74">&#39;Largest eigenvalues of the Coulomb matrix for alkane constitutional isomers&#39;</span>)
</span></span><span style="display:flex;"><span>plt<span style="color:#f92672">.</span>savefig(<span style="color:#e6db74">&#39;alkane_coulomb_matrix_largest_eigenvalues.webp&#39;</span>, bbox_inches<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;tight&#39;</span>)
</span></span></code></pre></div>














<figure class="post-figure center ">
    <img src="/img/alkane-constitutional-isomers/alkane_coulomb_matrix_largest_eigenvalues.webp"
         alt="Largest eigenvalues of the Coulomb matrix for alkane constitutional isomers."
         title="Largest eigenvalues of the Coulomb matrix for alkane constitutional isomers."
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Largest eigenvalues show sub-linear growth with molecular size and increasing overlap between isomers.</figcaption>
    
</figure>

<p>Our results match the original paper. We observe sub-linear growth in the largest eigenvalue with carbon number, and critically, increasing overlap between isomers as molecules grow larger. The largest eigenvalue alone cannot reliably distinguish constitutional isomers for larger alkanes.</p>
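<p>One way to put a number on "sub-linear" is to fit a power law on log-log axes and read off the exponent, where a slope below 1 means the quantity grows more slowly than the carbon count. The sketch below is an assumption about how one might do this, not something taken from the original paper, and the curve it fits is synthetic.</p>

```python
import numpy as np


def growth_exponent(sizes, values):
    """Slope of log(values) against log(sizes); below 1 indicates sub-linear growth."""
    slope, _intercept = np.polyfit(np.log(sizes), np.log(values), deg=1)
    return float(slope)


carbon_counts = np.arange(1, 12)
# Synthetic curve used only to exercise the function, not measured eigenvalues
synthetic_values = 10.0 * carbon_counts ** 0.7
exponent = growth_exponent(carbon_counts, synthetic_values)
print(round(exponent, 3))
```

<p>With the <code>spectra</code> dictionary loaded earlier, <code>values</code> could instead be the per-formula medians of the largest eigenvalue.</p>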
<h3 id="eigenvalue-distributions-for-heptane-isomers">Eigenvalue Distributions for Heptane Isomers</h3>
<p>Looking deeper at a specific case, I analyzed the probability density functions of the largest eigenvalue for heptane ($C_7H_{16}$) isomers. Heptane has nine constitutional isomers, which makes it a good test of discriminative power. The same loop also produces the plots for butane through hexane.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#66d9ef">for</span> n_sel <span style="color:#f92672">in</span> range(<span style="color:#ae81ff">4</span>, <span style="color:#ae81ff">8</span>):
</span></span><span style="display:flex;"><span>    fig, ax <span style="color:#f92672">=</span> plt<span style="color:#f92672">.</span>subplots(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">1</span>, figsize<span style="color:#f92672">=</span>(<span style="color:#ae81ff">10</span>, <span style="color:#ae81ff">5</span>))
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    smiles <span style="color:#f92672">=</span> []
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">with</span> open(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#39;isomers/C</span><span style="color:#e6db74">{</span>n_sel<span style="color:#e6db74">}</span><span style="color:#e6db74">H</span><span style="color:#e6db74">{</span><span style="color:#ae81ff">2</span><span style="color:#f92672">*</span>n_sel <span style="color:#f92672">+</span> <span style="color:#ae81ff">2</span><span style="color:#e6db74">}</span><span style="color:#e6db74">.smi&#39;</span>) <span style="color:#66d9ef">as</span> f:
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">for</span> line <span style="color:#f92672">in</span> f:
</span></span><span style="display:flex;"><span>            smiles<span style="color:#f92672">.</span>append(line<span style="color:#f92672">.</span>strip())
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">for</span> i <span style="color:#f92672">in</span> range(len(spectra[n_sel])):
</span></span><span style="display:flex;"><span>        eigenvalues <span style="color:#f92672">=</span> spectra[n_sel][i][:, <span style="color:#ae81ff">0</span>]
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># kde plot with the following params</span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># - Gaussian kernel</span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># - bandwidth with Silverman&#39;s rule of thumb</span>
</span></span><span style="display:flex;"><span>        sns<span style="color:#f92672">.</span>kdeplot(eigenvalues, bw_method<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;silverman&#39;</span>, label<span style="color:#f92672">=</span>get_iupac_name(smiles[i]), ax<span style="color:#f92672">=</span>ax)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    ax<span style="color:#f92672">.</span>set_xlabel(<span style="color:#e6db74">&#39;Largest eigenvalue&#39;</span>)
</span></span><span style="display:flex;"><span>    ax<span style="color:#f92672">.</span>set_ylabel(<span style="color:#e6db74">&#39;Density&#39;</span>)
</span></span><span style="display:flex;"><span>    ax<span style="color:#f92672">.</span>set_title(<span style="color:#e6db74">&#39;PDF of the largest eigenvalue for $C_&#39;</span> <span style="color:#f92672">+</span> str(n_sel) <span style="color:#f92672">+</span> <span style="color:#e6db74">&#39;H_{&#39;</span> <span style="color:#f92672">+</span> str(<span style="color:#ae81ff">2</span><span style="color:#f92672">*</span>n_sel <span style="color:#f92672">+</span> <span style="color:#ae81ff">2</span>) <span style="color:#f92672">+</span> <span style="color:#e6db74">&#39;}$&#39;</span>)
</span></span><span style="display:flex;"><span>    ax<span style="color:#f92672">.</span>legend()
</span></span><span style="display:flex;"><span>    plt<span style="color:#f92672">.</span>savefig(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#39;pdf_largest_eigenvalue_C</span><span style="color:#e6db74">{</span>n_sel<span style="color:#e6db74">}</span><span style="color:#e6db74">H</span><span style="color:#e6db74">{</span><span style="color:#ae81ff">2</span><span style="color:#f92672">*</span>n_sel <span style="color:#f92672">+</span> <span style="color:#ae81ff">2</span><span style="color:#e6db74">}</span><span style="color:#e6db74">.webp&#39;</span>, bbox_inches<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;tight&#39;</span>)
</span></span><span style="display:flex;"><span>    plt<span style="color:#f92672">.</span>close()
</span></span></code></pre></div>
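<p>For reference, here is what the Silverman setting amounts to, assuming seaborn hands <code>bw_method</code> to a SciPy-style <code>gaussian_kde</code>. For one-dimensional data the Silverman factor is $(3n/4)^{-1/5}$, and the Gaussian kernel width is that factor times the sample standard deviation (seaborn then multiplies by <code>bw_adjust</code>, which defaults to 1). The helper name and the sample values are made up for illustration.</p>

```python
import statistics


def silverman_bandwidth(samples):
    """Kernel bandwidth for 1D data under the SciPy-style Silverman rule."""
    n = len(samples)
    # Silverman factor for d = 1 dimension: (n * (d + 2) / 4) ** (-1 / (d + 4))
    factor = (n * 3.0 / 4.0) ** (-1.0 / 5.0)
    # Scale by the sample standard deviation (ddof = 1)
    return factor * statistics.stdev(samples)


# Made-up eigenvalues, for illustration only
samples = [100.0, 101.0, 102.0, 103.0, 104.0]
bandwidth = silverman_bandwidth(samples)
print(round(bandwidth, 3))
```

<p>The bandwidth shrinks slowly with sample size, so with 1000 conformers per isomer the smoothing is dominated by the spread of each isomer's eigenvalues.</p>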














<figure class="post-figure center ">
    <img src="/img/alkane-constitutional-isomers/pdf_largest_eigenvalue_C7H16.webp"
         alt="PDFs of the largest eigenvalue for heptane isomers."
         title="PDFs of the largest eigenvalue for heptane isomers."
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Heptane isomers show distinct eigenvalue ranges: n-heptane (smallest), 2,2,3-trimethylbutane (largest), with others overlapping.</figcaption>
    
</figure>

<p>The pattern is clear:</p>
<ul>
<li><strong>n-heptane</strong> (linear chain) has the smallest eigenvalues</li>
<li><strong>2,2,3-trimethylbutane</strong> (highly branched) has the largest</li>
<li><strong>Seven other isomers</strong> fall in between with substantial overlap</li>
</ul>
<p>This demonstrates the fundamental limitation: while extreme structural differences (linear vs. highly branched) create separable eigenvalue distributions, intermediate structures become indistinguishable.</p>
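<p>To move from "looks separable" to a number, one option is the overlap coefficient: the area shared by two estimated densities, which is 0 for disjoint distributions and 1 for identical ones. The sketch below estimates it from histograms on a common grid. The function name, bin count, and Gaussian samples are all assumptions for illustration; real inputs would be columns such as <code>spectra[7][i][:, 0]</code>.</p>

```python
import numpy as np


def overlap_coefficient(a, b, bins=50):
    """Histogram estimate of the area shared by two 1D densities (0 = disjoint, 1 = identical)."""
    lo = min(np.min(a), np.min(b))
    hi = max(np.max(a), np.max(b))
    # Shared bin edges so the two histograms are directly comparable
    edges = np.linspace(lo, hi, bins + 1)
    pa, _ = np.histogram(a, bins=edges, density=True)
    pb, _ = np.histogram(b, bins=edges, density=True)
    width = edges[1] - edges[0]
    return float(np.sum(np.minimum(pa, pb)) * width)


rng = np.random.default_rng(0)
# Synthetic stand-ins, not real eigenvalues
linear = rng.normal(100.0, 1.0, size=1000)
branched = rng.normal(110.0, 1.0, size=1000)
intermediate = rng.normal(100.5, 1.0, size=1000)

well_separated = overlap_coefficient(linear, branched)
heavily_mixed = overlap_coefficient(linear, intermediate)
print(well_separated, heavily_mixed)
```

<p>Computing this for every pair of isomers would give a matrix that shows which structures the largest eigenvalue can and cannot tell apart.</p>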
<p>For smaller alkanes, the separation is more promising:</p>















<figure class="post-figure center ">
    <img src="/img/alkane-constitutional-isomers/pdf_largest_eigenvalue_C4H10.webp"
         alt="PDFs for butane isomers."
         title="PDFs for butane isomers."
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Butane (n=4): Clean separation between linear and branched structures.</figcaption>
    
</figure>
















<figure class="post-figure center ">
    <img src="/img/alkane-constitutional-isomers/pdf_largest_eigenvalue_C5H12.webp"
         alt="PDFs for pentane isomers."
         title="PDFs for pentane isomers."
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Pentane (n=5): Good separation between most isomers.</figcaption>
    
</figure>
















<figure class="post-figure center ">
    <img src="/img/alkane-constitutional-isomers/pdf_largest_eigenvalue_C6H14.webp"
         alt="PDFs for hexane isomers."
         title="PDFs for hexane isomers."
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Hexane (n=6): Some isomers (2-methylpentane, 3-methylpentane) become difficult to distinguish.</figcaption>
    
</figure>

<p>The progression is clear: eigenvalue-based discrimination works well for small alkanes and degrades as molecular complexity increases.</p>
<h3 id="two-dimensional-eigenvalue-space">Two-Dimensional Eigenvalue Space</h3>
<p>Can we improve discrimination by using multiple eigenvalues? For butane, plotting the first two eigenvalues reveals interesting structure:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span>fig, ax <span style="color:#f92672">=</span> plt<span style="color:#f92672">.</span>subplots(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">1</span>, figsize<span style="color:#f92672">=</span>(<span style="color:#ae81ff">10</span>, <span style="color:#ae81ff">5</span>))
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>n_sel <span style="color:#f92672">=</span> <span style="color:#ae81ff">4</span>
</span></span><span style="display:flex;"><span>smiles <span style="color:#f92672">=</span> []
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">with</span> open(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#39;isomers/C</span><span style="color:#e6db74">{</span>n_sel<span style="color:#e6db74">}</span><span style="color:#e6db74">H</span><span style="color:#e6db74">{</span><span style="color:#ae81ff">2</span><span style="color:#f92672">*</span>n_sel <span style="color:#f92672">+</span> <span style="color:#ae81ff">2</span><span style="color:#e6db74">}</span><span style="color:#e6db74">.smi&#39;</span>) <span style="color:#66d9ef">as</span> f:
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">for</span> line <span style="color:#f92672">in</span> f:
</span></span><span style="display:flex;"><span>        smiles<span style="color:#f92672">.</span>append(line<span style="color:#f92672">.</span>strip())
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">for</span> i <span style="color:#f92672">in</span> range(len(spectra[n_sel])):
</span></span><span style="display:flex;"><span>    eigenvalues <span style="color:#f92672">=</span> spectra[n_sel][i][:, :<span style="color:#ae81ff">2</span>]
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    ax<span style="color:#f92672">.</span>scatter(eigenvalues[:, <span style="color:#ae81ff">0</span>], eigenvalues[:, <span style="color:#ae81ff">1</span>], label<span style="color:#f92672">=</span>get_iupac_name(smiles[i]))
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>ax<span style="color:#f92672">.</span>set_xlabel(<span style="color:#e6db74">&#39;Largest eigenvalue&#39;</span>)
</span></span><span style="display:flex;"><span>ax<span style="color:#f92672">.</span>set_ylabel(<span style="color:#e6db74">&#39;Second largest eigenvalue&#39;</span>)
</span></span><span style="display:flex;"><span>ax<span style="color:#f92672">.</span>set_title(<span style="color:#e6db74">&#39;2D plot of the first two eigenvalues for $C_4H_</span><span style="color:#e6db74">{10}</span><span style="color:#e6db74">$ conformers&#39;</span>)
</span></span><span style="display:flex;"><span>ax<span style="color:#f92672">.</span>legend()
</span></span><span style="display:flex;"><span>plt<span style="color:#f92672">.</span>savefig(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#39;2d_largest_eigenvalue_C</span><span style="color:#e6db74">{</span>n_sel<span style="color:#e6db74">}</span><span style="color:#e6db74">H</span><span style="color:#e6db74">{</span><span style="color:#ae81ff">2</span><span style="color:#f92672">*</span>n_sel <span style="color:#f92672">+</span> <span style="color:#ae81ff">2</span><span style="color:#e6db74">}</span><span style="color:#e6db74">.webp&#39;</span>, bbox_inches<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;tight&#39;</span>)
</span></span></code></pre></div>














<figure class="post-figure center ">
    <img src="/img/alkane-constitutional-isomers/2d_largest_eigenvalue_C4H10.webp"
         alt="2D eigenvalue space for butane isomers."
         title="2D eigenvalue space for butane isomers."
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Perfect linear separation of butane isomers using the first two eigenvalues.</figcaption>
    
</figure>

<p>The two isomers cluster distinctly, demonstrating that multi-dimensional eigenvalue features can achieve perfect separation for simple cases. The outlier point in the lower right likely results from conformational sampling noise.</p>
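<p>One way to check a claim of linear separability numerically is a perceptron, which is guaranteed to converge when a separating line exists. The sketch below is illustrative only: <code>linearly_separable</code> is a hypothetical helper, and the points are synthetic stand-ins, not the butane eigenvalues.</p>

```python
import numpy as np


def linearly_separable(points_a, points_b, max_epochs=1000):
    """Return True if a perceptron finds a line that splits the two point sets."""
    X = np.vstack([points_a, points_b]).astype(float)
    X -= X.mean(axis=0)                        # centering speeds up convergence
    X = np.hstack([X, np.ones((len(X), 1))])   # constant column acts as the bias
    y = np.concatenate([np.ones(len(points_a)), -np.ones(len(points_b))])

    w = np.zeros(X.shape[1])
    for _ in range(max_epochs):
        mistakes = 0
        for xi, yi in zip(X, y):
            if yi * (w @ xi) <= 0:             # misclassified or on the boundary
                w += yi * xi
                mistakes += 1
        if mistakes == 0:
            return True                        # every point is on its correct side
    return False                               # no separating line found in time


# Synthetic stand-ins for two isomers' (largest, second-largest) eigenvalue pairs
cluster_a = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
cluster_b = np.array([[4.0, 4.0], [4.0, 5.0], [5.0, 4.0], [5.0, 5.0]])
print(linearly_separable(cluster_a, cluster_b))
```

<p>A <code>False</code> result only means no separating line was found within <code>max_epochs</code>, so for tightly spaced clusters the epoch budget may need to be raised before drawing conclusions.</p>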
<h3 id="dimensionality-and-information-content">Dimensionality and Information Content</h3>
<p>How many eigenvalues do we actually need? Principal component analysis reveals the effective dimensionality of the eigenvalue representations:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">from</span> sklearn.decomposition <span style="color:#f92672">import</span> PCA
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>fig, ax <span style="color:#f92672">=</span> plt<span style="color:#f92672">.</span>subplots(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">1</span>, figsize<span style="color:#f92672">=</span>(<span style="color:#ae81ff">10</span>, <span style="color:#ae81ff">5</span>))
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>n_components <span style="color:#f92672">=</span> {}
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">for</span> n <span style="color:#f92672">in</span> range(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">12</span>):
</span></span><span style="display:flex;"><span>    n_components[n] <span style="color:#f92672">=</span> []
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">for</span> i <span style="color:#f92672">in</span> range(len(spectra[n])):
</span></span><span style="display:flex;"><span>        eigenvalues <span style="color:#f92672">=</span> spectra[n][i]
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># PCA</span>
</span></span><span style="display:flex;"><span>        pca <span style="color:#f92672">=</span> PCA(n_components<span style="color:#f92672">=</span><span style="color:#ae81ff">0.99</span>, svd_solver<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;full&#39;</span>, whiten<span style="color:#f92672">=</span><span style="color:#66d9ef">False</span>, random_state<span style="color:#f92672">=</span><span style="color:#ae81ff">42</span>)
</span></span><span style="display:flex;"><span>        pca<span style="color:#f92672">.</span>fit(eigenvalues)
</span></span><span style="display:flex;"><span>        n_components[n]<span style="color:#f92672">.</span>append(pca<span style="color:#f92672">.</span>n_components_)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    ax<span style="color:#f92672">.</span>scatter([n] <span style="color:#f92672">*</span> len(n_components[n]), n_components[n], alpha<span style="color:#f92672">=</span><span style="color:#ae81ff">0.3</span>)
</span></span><span style="display:flex;"><span>    ax<span style="color:#f92672">.</span>plot([n <span style="color:#f92672">-</span> <span style="color:#ae81ff">0.25</span>, n <span style="color:#f92672">+</span> <span style="color:#ae81ff">0.25</span>], [np<span style="color:#f92672">.</span>mean(n_components[n]), np<span style="color:#f92672">.</span>mean(n_components[n])], <span style="color:#e6db74">&#39;k-&#39;</span>, alpha<span style="color:#f92672">=</span><span style="color:#ae81ff">0.5</span>)  <span style="color:#75715e"># Draw a line for the mean</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>ax<span style="color:#f92672">.</span>plot([<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">11</span>], [<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">11</span>], <span style="color:#e6db74">&#39;k--&#39;</span>, alpha<span style="color:#f92672">=</span><span style="color:#ae81ff">0.5</span>, label<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;y = num carbon&#39;</span>)
</span></span><span style="display:flex;"><span>ax<span style="color:#f92672">.</span>plot([<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">11</span>], [<span style="color:#ae81ff">5</span>, <span style="color:#ae81ff">35</span>], <span style="color:#e6db74">&#39;r--&#39;</span>, alpha<span style="color:#f92672">=</span><span style="color:#ae81ff">0.5</span>, label<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;y = num atoms&#39;</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>ax<span style="color:#f92672">.</span>set_xlabel(<span style="color:#e6db74">&#39;Number of carbon atoms&#39;</span>)
</span></span><span style="display:flex;"><span>ax<span style="color:#f92672">.</span>set_ylabel(<span style="color:#e6db74">&#39;Number of principal components&#39;</span>)
</span></span><span style="display:flex;"><span>ax<span style="color:#f92672">.</span>set_title(<span style="color:#e6db74">&#39;99% variance explained by number of principal components&#39;</span>)
</span></span><span style="display:flex;"><span>plt<span style="color:#f92672">.</span>legend()
</span></span><span style="display:flex;"><span>plt<span style="color:#f92672">.</span>savefig(<span style="color:#e6db74">&#39;99_variance_explained.webp&#39;</span>, bbox_inches<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;tight&#39;</span>)
</span></span></code></pre></div>














<figure class="post-figure center ">
    <img src="/img/alkane-constitutional-isomers/99_variance_explained.webp"
         alt="Principal components needed for 99% variance."
         title="Principal components needed for 99% variance."
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The eigenvalue space compresses efficiently. Far fewer components than the $3n+2$ eigenvalue dimensions are needed.</figcaption>
    
</figure>

<p>Key observations:</p>
<ul>
<li><strong>High compressibility</strong>: Far fewer than the $3n+2$ eigenvalue dimensions (one per atom) are needed</li>
<li><strong>Linear scaling</strong>: The number of principal components grows roughly linearly with carbon number</li>
<li><strong>Efficient representation</strong>: The eigenvalue space has lower effective dimensionality than expected</li>
</ul>
<p>This suggests the eigenvalue dimensions are highly correlated, enabling significant dimensionality reduction with little information loss (99% of the variance is retained).</p>
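<p>To see why correlated features compress so well, here is a minimal sketch on synthetic data (not the alkane spectra). The helper <code>components_for_variance</code> is a hypothetical NumPy stand-in for the 99% variance criterion used above. Its threshold handling may differ from scikit-learn's at exact ties.</p>

```python
import numpy as np


def components_for_variance(X, threshold=0.99):
    """Smallest number of principal components whose variance share reaches threshold."""
    centered = X - X.mean(axis=0)
    singular_values = np.linalg.svd(centered, compute_uv=False)
    variance_ratio = singular_values ** 2 / np.sum(singular_values ** 2)
    cumulative = np.cumsum(variance_ratio)
    count = int(np.searchsorted(cumulative, threshold)) + 1
    return min(count, len(singular_values))


rng = np.random.default_rng(0)
latent = rng.normal(size=(500, 2))                  # two hidden factors
mixing = np.array([[1.0] * 5 + [0.0] * 5,           # factor 1 drives columns 0-4
                   [0.0] * 5 + [1.0] * 5])          # factor 2 drives columns 5-9
correlated = latent @ mixing + 1e-3 * rng.normal(size=(500, 10))
independent = rng.normal(size=(500, 10))

print(components_for_variance(correlated))   # ten columns driven by two factors
print(components_for_variance(independent))  # no shared structure to exploit
```

<p>Ten columns driven by two hidden factors collapse to two components, while ten independent columns do not compress at all. The eigenvalue spectra sit between these extremes.</p>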
<h2 id="log-transformed-coulomb-matrices">Log-Transformed Coulomb Matrices</h2>
<p>As explored in our <a href="/posts/molecular-descriptor-coulomb-matrix/">previous post on Coulomb matrices</a>, log transformation can reveal different structural information. Standard Coulomb matrices emphasize heavy-atom interactions. Log transformation expands the influence of hydrogen atoms by mapping magnitudes in $(0,1]$ to $(-\infty,0]$, while compressing the much larger heavy-atom entries.</p>
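<p>As a rough illustration of that effect, the sketch below builds a Coulomb matrix for a toy three-atom fragment and applies an element-wise natural log. This is an assumption about what the transform does; the actual <code>gen_n_spectra(log=True)</code> implementation may differ. The coordinates are illustrative values in bohr, not an optimized geometry, and <code>coulomb_matrix</code> is a hypothetical helper.</p>

```python
import numpy as np


def coulomb_matrix(charges, coords):
    """Standard Coulomb matrix: 0.5 * Z^2.4 on the diagonal, Z_i * Z_j / r_ij elsewhere."""
    charges = np.asarray(charges, dtype=float)
    coords = np.asarray(coords, dtype=float)
    dists = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)
    np.fill_diagonal(dists, 1.0)               # placeholder to avoid dividing by zero
    cm = np.outer(charges, charges) / dists
    np.fill_diagonal(cm, 0.5 * charges ** 2.4)
    return cm


# One carbon and two hydrogens at made-up positions (assumed units: bohr)
charges = [6, 1, 1]
coords = [[0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [0.0, 2.0, 0.0]]

cm = coulomb_matrix(charges, coords)
log_cm = np.log(cm)                            # assumed element-wise natural log

print(np.round(cm, 3))
print(np.round(log_cm, 3))
print(np.linalg.eigvalsh(cm))                  # eigenvalues of the plain matrix
print(np.linalg.eigvalsh(log_cm))              # eigenvalues after the log transform
```

<p>In this toy case the hydrogen-only entries fall below 1 and turn negative under the log, while the carbon self-term shrinks from about 37 to about 3.6. The log-transformed matrix also picks up a negative eigenvalue that the plain matrix does not have.</p>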
<p>I generated equivalent datasets using log-transformed matrices to test how this affects discriminative power.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span>os<span style="color:#f92672">.</span>makedirs(<span style="color:#e6db74">&#39;spectra&#39;</span>, exist_ok<span style="color:#f92672">=</span><span style="color:#66d9ef">True</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>n_confs <span style="color:#f92672">=</span> <span style="color:#ae81ff">1000</span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">for</span> n <span style="color:#f92672">in</span> range(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">12</span>):
</span></span><span style="display:flex;"><span>    print(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#39;Generating spectra for C</span><span style="color:#e6db74">{</span>n<span style="color:#e6db74">}</span><span style="color:#e6db74">H</span><span style="color:#e6db74">{</span><span style="color:#ae81ff">2</span><span style="color:#f92672">*</span>n <span style="color:#f92672">+</span> <span style="color:#ae81ff">2</span><span style="color:#e6db74">}</span><span style="color:#e6db74">&#39;</span>)
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">with</span> open(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#39;isomers/C</span><span style="color:#e6db74">{</span>n<span style="color:#e6db74">}</span><span style="color:#e6db74">H</span><span style="color:#e6db74">{</span><span style="color:#ae81ff">2</span><span style="color:#f92672">*</span>n <span style="color:#f92672">+</span> <span style="color:#ae81ff">2</span><span style="color:#e6db74">}</span><span style="color:#e6db74">.smi&#39;</span>) <span style="color:#66d9ef">as</span> f:
</span></span><span style="display:flex;"><span>        lines <span style="color:#f92672">=</span> f<span style="color:#f92672">.</span>readlines()
</span></span><span style="display:flex;"><span>        print(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#39;</span><span style="color:#ae81ff">\t</span><span style="color:#e6db74">Number of SMILES strings: </span><span style="color:#e6db74">{</span>len(lines)<span style="color:#e6db74">}</span><span style="color:#e6db74">&#39;</span>)
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">for</span> i, line <span style="color:#f92672">in</span> enumerate(tqdm(lines)):
</span></span><span style="display:flex;"><span>            print(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#39;</span><span style="color:#ae81ff">\t\t</span><span style="color:#e6db74">{</span>i <span style="color:#f92672">+</span> <span style="color:#ae81ff">1</span><span style="color:#e6db74">}</span><span style="color:#e6db74">/</span><span style="color:#e6db74">{</span>len(lines)<span style="color:#e6db74">}</span><span style="color:#e6db74"> - </span><span style="color:#e6db74">{</span>line<span style="color:#f92672">.</span>strip()<span style="color:#e6db74">}</span><span style="color:#e6db74">&#39;</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">if</span> <span style="color:#f92672">not</span> line<span style="color:#f92672">.</span>strip():
</span></span><span style="display:flex;"><span>                <span style="color:#66d9ef">continue</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">if</span> os<span style="color:#f92672">.</span>path<span style="color:#f92672">.</span>exists(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#39;spectra/log-C</span><span style="color:#e6db74">{</span>n<span style="color:#e6db74">}</span><span style="color:#e6db74">H</span><span style="color:#e6db74">{</span><span style="color:#ae81ff">2</span><span style="color:#f92672">*</span>n <span style="color:#f92672">+</span> <span style="color:#ae81ff">2</span><span style="color:#e6db74">}</span><span style="color:#e6db74">_</span><span style="color:#e6db74">{</span>i<span style="color:#e6db74">:</span><span style="color:#e6db74">03d</span><span style="color:#e6db74">}</span><span style="color:#e6db74">.npy&#39;</span>):
</span></span><span style="display:flex;"><span>                <span style="color:#66d9ef">continue</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>            smiles <span style="color:#f92672">=</span> line<span style="color:#f92672">.</span>strip()
</span></span><span style="display:flex;"><span>            spectra <span style="color:#f92672">=</span> gen_n_spectra(n<span style="color:#f92672">=</span>n_confs, smiles_str<span style="color:#f92672">=</span>smiles, log<span style="color:#f92672">=</span><span style="color:#66d9ef">True</span>)
</span></span><span style="display:flex;"><span>            np<span style="color:#f92672">.</span>save(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#39;spectra/log-C</span><span style="color:#e6db74">{</span>n<span style="color:#e6db74">}</span><span style="color:#e6db74">H</span><span style="color:#e6db74">{</span><span style="color:#ae81ff">2</span><span style="color:#f92672">*</span>n <span style="color:#f92672">+</span> <span style="color:#ae81ff">2</span><span style="color:#e6db74">}</span><span style="color:#e6db74">_</span><span style="color:#e6db74">{</span>i<span style="color:#e6db74">:</span><span style="color:#e6db74">03d</span><span style="color:#e6db74">}</span><span style="color:#e6db74">.npy&#39;</span>, spectra)
</span></span></code></pre></div><h3 id="log-transformed-eigenvalue-distributions">Log-Transformed Eigenvalue Distributions</h3>
<p>The log transformation substantially changes the eigenvalue landscape:</p>















<figure class="post-figure center ">
    <img src="/img/alkane-constitutional-isomers/alkane_log_coulomb_matrix_largest_eigenvalues.webp"
         alt="Log-transformed eigenvalues across alkane series."
         title="Log-transformed eigenvalues across alkane series."
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Log transformation emphasizes hydrogen interactions, creating larger eigenvalue ranges and more negative values.</figcaption>
    
</figure>

<p>Log-transformed versions exhibit distinct characteristics:</p>
<ul>
<li><strong>Span large negative values</strong> due to hydrogen atom emphasis</li>
<li><strong>Show increasing variance</strong> between isomers as molecular size grows</li>
<li><strong>Demonstrate greater discrimination potential</strong> for some isomers</li>
</ul>
<p>This comes with trade-offs. The distributions can become significantly broader, as seen in the heptane analysis:</p>















<figure class="post-figure center ">
    <img src="/img/alkane-constitutional-isomers/pdf_log_largest_eigenvalue_C7H16.webp"
         alt="Log-transformed eigenvalue PDFs for heptane."
         title="Log-transformed eigenvalue PDFs for heptane."
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Log transformation creates wider, more overlapping distributions that may reduce discrimination power.</figcaption>
    
</figure>

<p>The log scale on the y-axis is necessary because unbranched isomers become nearly invisible due to the highly concentrated distributions of branched isomers.</p>
<h3 id="two-dimensional-log-transformed-space">Two-Dimensional Log-Transformed Space</h3>
<p>The 2D eigenvalue plot for log-transformed butane shows similar clustering behavior:</p>















<figure class="post-figure center ">
    <img src="/img/alkane-constitutional-isomers/2d_log_largest_eigenvalue_C4H10.webp"
         alt="2D log-transformed eigenvalue space for butane."
         title="2D log-transformed eigenvalue space for butane."
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Log transformation brings isomers closer together while maintaining separability for simple cases.</figcaption>
    
</figure>

<p>The transformation reduces the separation distance, yet linear discriminability remains intact for this simple case.</p>
<h3 id="dimensionality-of-log-transformed-features">Dimensionality of Log-Transformed Features</h3>
<p>Principal component analysis of log-transformed eigenvalues reveals similar compression properties:</p>















<figure class="post-figure center ">
    <img src="/img/alkane-constitutional-isomers/99_variance_explained_log.webp"
         alt="Principal components for log-transformed eigenvalues."
         title="Principal components for log-transformed eigenvalues."
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Log transformation requires slightly more principal components while maintaining efficient compression.</figcaption>
    
</figure>

<p>The log-transformed features show comparable dimensionality reduction with marginally higher component requirements.</p>
<h2 id="testing-eigenvalue-separability">Testing Eigenvalue Separability</h2>
<p>Our exploratory analysis revealed concerning patterns that hint at fundamental limitations: high correlation between eigenvalue dimensions, rapid dimensionality compression via PCA, and overlapping distributions for larger molecules ($n \geq 6$).</p>
<p>These findings leave the question open. We now test eigenvalues directly to see if they can actually separate constitutional isomers without supervision.</p>
<p>We&rsquo;ll use two complementary clustering metrics to measure how well eigenvalues separate constitutional isomers. This is a fair test because we only compare isomers with identical molecular formulas, which keeps the number of eigenvalue dimensions constant.</p>
<h3 id="dunn-index-global-cluster-quality">Dunn Index: Global Cluster Quality</h3>
<p>The <a href="https://en.wikipedia.org/wiki/Dunn_index">Dunn Index</a> provides a single metric capturing cluster quality. It asks: &ldquo;Are the closest different clusters still farther apart than the most spread-out individual cluster?&rdquo;</p>
<p>$$
\text{Dunn Index} = \frac{\text{smallest distance between different clusters}}{\text{largest diameter within any cluster}}
$$</p>
<p>Higher values indicate better separation. When it approaches zero, clusters become indistinguishable, exactly what we suspected from the overlapping eigenvalue distributions observed earlier.</p>
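<p>The ratio above can be sketched in a few lines of NumPy. This is a minimal version under stated assumptions: it uses the closest pair of points as the between-cluster distance and the farthest pair of points as the within-cluster diameter, which is one common variant of the Dunn Index. The name <code>dunn_index_sketch</code> is hypothetical, and the <code>dunn_index</code> function used below may define these terms differently.</p>

```python
import numpy as np


def pairwise_distances(a, b):
    """Euclidean distances between every row of a and every row of b."""
    squared = (
        (a ** 2).sum(axis=1)[:, None]
        + (b ** 2).sum(axis=1)[None, :]
        - 2.0 * a @ b.T
    )
    return np.sqrt(np.maximum(squared, 0.0))   # clip tiny negatives from round-off


def dunn_index_sketch(clusters):
    """clusters: list of (n_points, n_features) arrays, one array per isomer."""
    clusters = [np.asarray(c, dtype=float) for c in clusters]

    # Largest spread inside any single cluster
    diameter = max(pairwise_distances(c, c).max() for c in clusters)

    # Smallest gap between points belonging to two different clusters
    distance = min(
        pairwise_distances(clusters[i], clusters[j]).min()
        for i in range(len(clusters))
        for j in range(i + 1, len(clusters))
    )
    return {
        'diameter': float(diameter),
        'distance': float(distance),
        'dunn_index': float(distance / diameter),
    }


# Two tight synthetic clusters on a line: width 1, gap 3
toy = [np.array([[0.0, 0.0], [1.0, 0.0]]), np.array([[4.0, 0.0], [5.0, 0.0]])]
print(dunn_index_sketch(toy))
```

<p>The double loop over cluster pairs is what makes the cost grow quadratically with the number of isomers.</p>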
<p>Computing the Dunn Index for each alkane series:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> time
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>dunn_scores <span style="color:#f92672">=</span> {}
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">for</span> n <span style="color:#f92672">in</span> range(<span style="color:#ae81ff">4</span>, <span style="color:#ae81ff">12</span>):
</span></span><span style="display:flex;"><span>    tik <span style="color:#f92672">=</span> time<span style="color:#f92672">.</span>time()
</span></span><span style="display:flex;"><span>    dunn_scores[n] <span style="color:#f92672">=</span> dunn_index([spectra[n][i] <span style="color:#66d9ef">for</span> i <span style="color:#f92672">in</span> spectra[n]])
</span></span><span style="display:flex;"><span>    tok <span style="color:#f92672">=</span> time<span style="color:#f92672">.</span>time()
</span></span><span style="display:flex;"><span>    dunn_scores[n][<span style="color:#e6db74">&#39;time&#39;</span>] <span style="color:#f92672">=</span> tok <span style="color:#f92672">-</span> tik
</span></span><span style="display:flex;"><span>    print(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#39;C</span><span style="color:#e6db74">{</span>n<span style="color:#e6db74">}</span><span style="color:#e6db74">H</span><span style="color:#e6db74">{</span><span style="color:#ae81ff">2</span><span style="color:#f92672">*</span>n <span style="color:#f92672">+</span> <span style="color:#ae81ff">2</span><span style="color:#e6db74">}</span><span style="color:#e6db74">:&#39;</span>, dunn_scores[n])
</span></span></code></pre></div><div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-plaintext" data-lang="plaintext"><span style="display:flex;"><span>C4H10: {&#39;diameter&#39;: 21.43072917950398, &#39;distance&#39;: 8.316362440688767, &#39;dunn_index&#39;: 0.3880578383978837, &#39;time&#39;: 0.06010293960571289}
</span></span><span style="display:flex;"><span>C5H12: {&#39;diameter&#39;: 23.449286379564892, &#39;distance&#39;: 2.4693042873545856, &#39;dunn_index&#39;: 0.10530402705587172, &#39;time&#39;: 0.10832405090332031}
</span></span><span style="display:flex;"><span>C6H14: {&#39;diameter&#39;: 19.602363375467938, &#39;distance&#39;: 1.4477574259511048, &#39;dunn_index&#39;: 0.07385626917634591, &#39;time&#39;: 0.28030991554260254}
</span></span><span style="display:flex;"><span>C7H16: {&#39;diameter&#39;: 20.065014927470955, &#39;distance&#39;: 0.4050094394280803, &#39;dunn_index&#39;: 0.02018485612355977, &#39;time&#39;: 1.0307331085205078}
</span></span><span style="display:flex;"><span>C8H18: {&#39;diameter&#39;: 24.794154667613665, &#39;distance&#39;: 0.5013450168168625, &#39;dunn_index&#39;: 0.020220290771668196, &#39;time&#39;: 4.199508905410767}
</span></span><span style="display:flex;"><span>C9H20: {&#39;diameter&#39;: 21.811025941686033, &#39;distance&#39;: 0.34381162248560415, &#39;dunn_index&#39;: 0.01576320267578513, &#39;time&#39;: 17.400264978408813}
</span></span><span style="display:flex;"><span>C10H22: {&#39;diameter&#39;: 27.180773716656066, &#39;distance&#39;: 0.4986608768730121, &#39;dunn_index&#39;: 0.0183460883811206, &#39;time&#39;: 86.00787401199341}
</span></span><span style="display:flex;"><span>C11H24: {&#39;diameter&#39;: 25.58731511020692, &#39;distance&#39;: 0.5490373275460223, &#39;dunn_index&#39;: 0.021457402825629343, &#39;time&#39;: 424.4431610107422}
</span></span></code></pre></div><p>The computation time grows dramatically, reaching over 7 minutes for $C_{11}H_{24}$, because the number of cluster pairs scales quadratically with the number of isomers (159 isomers require 12,561 pairwise cluster comparisons).</p>
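<p>The pair count behind that growth is easy to tally. The sketch below uses the known constitutional isomer counts for $C_4$ through $C_{11}$ and assumes every pair of isomer clusters is compared once.</p>

```python
from math import comb

# Number of constitutional isomers for C4 through C11 alkanes
isomer_counts = {4: 2, 5: 3, 6: 5, 7: 9, 8: 18, 9: 35, 10: 75, 11: 159}

# Each unordered pair of isomers is one between-cluster comparison
cluster_pairs = {n: comb(k, 2) for n, k in isomer_counts.items()}

for n, pairs in cluster_pairs.items():
    print(f'C{n}H{2*n + 2}: {isomer_counts[n]} isomers, {pairs} cluster pairs')
```

<p>Going from $C_{10}H_{22}$ to $C_{11}H_{24}$ multiplies the pair count by about 4.5, roughly in line with the fivefold jump in measured runtime.</p>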
<p>The results reveal a clear trend:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span>fig, axs <span style="color:#f92672">=</span> plt<span style="color:#f92672">.</span>subplots(<span style="color:#ae81ff">2</span>, <span style="color:#ae81ff">2</span>, figsize<span style="color:#f92672">=</span>(<span style="color:#ae81ff">15</span>, <span style="color:#ae81ff">10</span>))
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># in axs[0, 0] - diameter vs number of carbon atoms</span>
</span></span><span style="display:flex;"><span>axs[<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">0</span>]<span style="color:#f92672">.</span>plot(
</span></span><span style="display:flex;"><span>    list(range(<span style="color:#ae81ff">4</span>, <span style="color:#ae81ff">12</span>)),
</span></span><span style="display:flex;"><span>    [dunn_scores[n][<span style="color:#e6db74">&#39;diameter&#39;</span>] <span style="color:#66d9ef">for</span> n <span style="color:#f92672">in</span> range(<span style="color:#ae81ff">4</span>, <span style="color:#ae81ff">12</span>)],
</span></span><span style="display:flex;"><span>    marker<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;o&#39;</span>
</span></span><span style="display:flex;"><span>)
</span></span><span style="display:flex;"><span>axs[<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">0</span>]<span style="color:#f92672">.</span>set_xlabel(<span style="color:#e6db74">&#39;Number of carbon atoms&#39;</span>)
</span></span><span style="display:flex;"><span>axs[<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">0</span>]<span style="color:#f92672">.</span>set_ylabel(<span style="color:#e6db74">&#39;Diameter&#39;</span>)
</span></span><span style="display:flex;"><span>axs[<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">0</span>]<span style="color:#f92672">.</span>set_title(<span style="color:#e6db74">&#39;Diameter vs number of carbon atoms&#39;</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># in axs[0, 1] - distance vs number of carbon atoms</span>
</span></span><span style="display:flex;"><span>axs[<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">1</span>]<span style="color:#f92672">.</span>plot(
</span></span><span style="display:flex;"><span>    list(range(<span style="color:#ae81ff">4</span>, <span style="color:#ae81ff">12</span>)),
</span></span><span style="display:flex;"><span>    [dunn_scores[n][<span style="color:#e6db74">&#39;distance&#39;</span>] <span style="color:#66d9ef">for</span> n <span style="color:#f92672">in</span> range(<span style="color:#ae81ff">4</span>, <span style="color:#ae81ff">12</span>)],
</span></span><span style="display:flex;"><span>    marker<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;o&#39;</span>
</span></span><span style="display:flex;"><span>)
</span></span><span style="display:flex;"><span>axs[<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">1</span>]<span style="color:#f92672">.</span>set_xlabel(<span style="color:#e6db74">&#39;Number of carbon atoms&#39;</span>)
</span></span><span style="display:flex;"><span>axs[<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">1</span>]<span style="color:#f92672">.</span>set_ylabel(<span style="color:#e6db74">&#39;Distance&#39;</span>)
</span></span><span style="display:flex;"><span>axs[<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">1</span>]<span style="color:#f92672">.</span>set_title(<span style="color:#e6db74">&#39;Distance vs number of carbon atoms&#39;</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># in axs[1, 0] - dunn index vs number of carbon atoms</span>
</span></span><span style="display:flex;"><span>axs[<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">0</span>]<span style="color:#f92672">.</span>plot(
</span></span><span style="display:flex;"><span>    list(range(<span style="color:#ae81ff">4</span>, <span style="color:#ae81ff">12</span>)),
</span></span><span style="display:flex;"><span>    [dunn_scores[n][<span style="color:#e6db74">&#39;dunn_index&#39;</span>] <span style="color:#66d9ef">for</span> n <span style="color:#f92672">in</span> range(<span style="color:#ae81ff">4</span>, <span style="color:#ae81ff">12</span>)],
</span></span><span style="display:flex;"><span>    marker<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;o&#39;</span>
</span></span><span style="display:flex;"><span>)
</span></span><span style="display:flex;"><span>axs[<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">0</span>]<span style="color:#f92672">.</span>set_xlabel(<span style="color:#e6db74">&#39;Number of carbon atoms&#39;</span>)
</span></span><span style="display:flex;"><span>axs[<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">0</span>]<span style="color:#f92672">.</span>set_ylabel(<span style="color:#e6db74">&#39;Dunn index&#39;</span>)
</span></span><span style="display:flex;"><span>axs[<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">0</span>]<span style="color:#f92672">.</span>set_title(<span style="color:#e6db74">&#39;Dunn index vs number of carbon atoms&#39;</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># in axs[1, 1] - time vs number of carbon atoms</span>
</span></span><span style="display:flex;"><span>axs[<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">1</span>]<span style="color:#f92672">.</span>plot(
</span></span><span style="display:flex;"><span>    list(range(<span style="color:#ae81ff">4</span>, <span style="color:#ae81ff">12</span>)),
</span></span><span style="display:flex;"><span>    [dunn_scores[n][<span style="color:#e6db74">&#39;time&#39;</span>] <span style="color:#66d9ef">for</span> n <span style="color:#f92672">in</span> range(<span style="color:#ae81ff">4</span>, <span style="color:#ae81ff">12</span>)],
</span></span><span style="display:flex;"><span>    marker<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;o&#39;</span>
</span></span><span style="display:flex;"><span>)
</span></span><span style="display:flex;"><span>axs[<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">1</span>]<span style="color:#f92672">.</span>set_xlabel(<span style="color:#e6db74">&#39;Number of carbon atoms&#39;</span>)
</span></span><span style="display:flex;"><span>axs[<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">1</span>]<span style="color:#f92672">.</span>set_ylabel(<span style="color:#e6db74">&#39;Time (s)&#39;</span>)
</span></span><span style="display:flex;"><span>axs[<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">1</span>]<span style="color:#f92672">.</span>set_title(<span style="color:#e6db74">&#39;Time vs number of carbon atoms&#39;</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>plt<span style="color:#f92672">.</span>tight_layout()
</span></span><span style="display:flex;"><span>plt<span style="color:#f92672">.</span>savefig(<span style="color:#e6db74">&#39;dunn_index_vs_num_carbon_atoms.webp&#39;</span>, bbox_inches<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;tight&#39;</span>)
</span></span></code></pre></div>














<figure class="post-figure center ">
    <img src="/img/alkane-constitutional-isomers/dunn_index_vs_num_carbon_atoms.webp"
         alt="Dunn Index analysis showing separability metrics, distances, and computation time versus molecular size"
         title="Dunn Index analysis showing separability metrics, distances, and computation time versus molecular size"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Dunn Index analysis reveals deteriorating separability as molecular complexity increases.</figcaption>
    
</figure>

<p>The trend confirms our earlier concerns:</p>
<ul>
<li><strong>$C_{4}H_{10}$</strong>: Excellent separation (Dunn Index = 0.39) between butane and isobutane</li>
<li><strong>$C_{5}H_{12}$ to $C_{6}H_{14}$</strong>: Rapid decline in separability</li>
<li><strong>$C_{7}H_{16}$ and beyond</strong>: Poor separation (Dunn Index $\approx$ 0.02)</li>
</ul>
<p>This validates our computational pipeline and matches the original paper&rsquo;s findings. For larger molecules, eigenvalue clusters become nearly indistinguishable, confirming the overlapping distributions we observed earlier.</p>
<h3 id="silhouette-analysis-individual-conformation-assessment">Silhouette Analysis: Individual Conformation Assessment</h3>
<p>The Dunn Index provides the global view. We must also consider individual molecules. The <a href="https://en.wikipedia.org/wiki/Silhouette_(clustering)">silhouette score</a> evaluates each conformation separately, asking: &ldquo;Is this molecule closer to its own isomer family or to a different one?&rdquo;</p>
<p>For each molecular conformation $i$:</p>
<p>$$
s(i) = \frac{b(i) - a(i)}{\max(a(i), b(i))}
$$</p>
<p>where:</p>
<ul>
<li>$a(i)$ = average distance to other conformations of the <strong>same</strong> isomer</li>
<li>$b(i)$ = average distance to conformations of the <strong>nearest different</strong> isomer</li>
</ul>
<p><strong>Interpretation:</strong></p>
<ul>
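<li><strong>Score near +1</strong>: Conformation sits well inside its own isomer&rsquo;s cluster (good clustering)</li>
<li><strong>Score near 0</strong>: Conformation lies on the boundary between two isomers</li>
<li><strong>Score below 0 (toward -1)</strong>: Conformation is closer to a different isomer than to its own (counted as a misclassification)</li>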
</ul>
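<p>To make the formula concrete, here is a minimal sketch that computes $s(i)$ by hand on a tiny made-up 1D dataset and checks it against scikit-learn&rsquo;s <code>silhouette_samples</code>. The data and the helper name <code>silhouette_by_hand</code> are illustrative assumptions, not part of the alkane pipeline.</p>

```python
import numpy as np
from sklearn.metrics import silhouette_samples

# Toy 1D data: the point at index 2 carries label 0 but sits next to
# the label-1 group, so its silhouette score should be negative.
X = np.array([[0.0], [0.2], [2.0], [2.1], [2.3], [2.5]])
y = np.array([0, 0, 0, 1, 1, 1])


def silhouette_by_hand(X, y, i):
    """Compute s(i) = (b - a) / max(a, b) for a single sample."""
    dists = np.linalg.norm(X - X[i], axis=1)
    same = (y == y[i])
    same[i] = False  # a(i) excludes the sample itself
    a = dists[same].mean()
    # b(i): smallest mean distance to the members of any other label
    b = min(dists[y == label].mean() for label in np.unique(y) if label != y[i])
    return (b - a) / max(a, b)


manual = np.array([silhouette_by_hand(X, y, i) for i in range(len(X))])
reference = silhouette_samples(X, y)

print(np.round(manual, 3))
print(np.allclose(manual, reference))
```

<p>For the point at index 2, $a = 1.9$ and $b = 0.3$, giving $s \approx -0.84$: exactly the kind of conformation the analysis below counts as misclassified.</p>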
<p>This enables two critical measurements:</p>
<ol>
<li>How many isomers have <strong>any</strong> misclassified conformations?</li>
<li>What fraction of <strong>individual conformations</strong> get misclassified?</li>
</ol>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">from</span> sklearn.metrics <span style="color:#f92672">import</span> silhouette_samples
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> tqdm <span style="color:#f92672">import</span> tqdm
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>s_scores <span style="color:#f92672">=</span> {}
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">for</span> n <span style="color:#f92672">in</span> tqdm(range(<span style="color:#ae81ff">4</span>, <span style="color:#ae81ff">12</span>)):
</span></span><span style="display:flex;"><span>    X <span style="color:#f92672">=</span> []
</span></span><span style="display:flex;"><span>    y <span style="color:#f92672">=</span> []
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">for</span> i <span style="color:#f92672">in</span> spectra[n]:
</span></span><span style="display:flex;"><span>        X<span style="color:#f92672">.</span>append(spectra[n][i])
</span></span><span style="display:flex;"><span>        y<span style="color:#f92672">.</span>extend(np<span style="color:#f92672">.</span>full(spectra[n][i]<span style="color:#f92672">.</span>shape[<span style="color:#ae81ff">0</span>], i))
</span></span><span style="display:flex;"><span>    X <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>concatenate(X)
</span></span><span style="display:flex;"><span>    y <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>array(y)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    s_scores[n] <span style="color:#f92672">=</span> silhouette_samples(X, y)
</span></span></code></pre></div><p>Computing both clustering quality metrics:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#75715e"># Metric 1: Fraction of isomers with ANY negative scores</span>
</span></span><span style="display:flex;"><span>neg_iso <span style="color:#f92672">=</span> {}
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">for</span> n <span style="color:#f92672">in</span> range(<span style="color:#ae81ff">4</span>, <span style="color:#ae81ff">12</span>):
</span></span><span style="display:flex;"><span>    n_iso <span style="color:#f92672">=</span> s_scores[n]<span style="color:#f92672">.</span>shape[<span style="color:#ae81ff">0</span>] <span style="color:#f92672">//</span> <span style="color:#ae81ff">1000</span>
</span></span><span style="display:flex;"><span>    n_has_neg <span style="color:#f92672">=</span> <span style="color:#ae81ff">0</span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">for</span> i <span style="color:#f92672">in</span> range(n_iso):
</span></span><span style="display:flex;"><span>        chunk <span style="color:#f92672">=</span> s_scores[n][i <span style="color:#f92672">*</span> <span style="color:#ae81ff">1000</span>:(i <span style="color:#f92672">+</span> <span style="color:#ae81ff">1</span>) <span style="color:#f92672">*</span> <span style="color:#ae81ff">1000</span>]
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">if</span> np<span style="color:#f92672">.</span>any(chunk <span style="color:#f92672">&lt;</span> <span style="color:#ae81ff">0</span>):
</span></span><span style="display:flex;"><span>            n_has_neg <span style="color:#f92672">+=</span> <span style="color:#ae81ff">1</span>
</span></span><span style="display:flex;"><span>    neg_iso[n] <span style="color:#f92672">=</span> n_has_neg <span style="color:#f92672">/</span> n_iso
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Metric 2: Individual conformation misclassification rates</span>
</span></span><span style="display:flex;"><span>neg_confs <span style="color:#f92672">=</span> {}
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">for</span> n <span style="color:#f92672">in</span> range(<span style="color:#ae81ff">4</span>, <span style="color:#ae81ff">12</span>):
</span></span><span style="display:flex;"><span>    n_iso <span style="color:#f92672">=</span> s_scores[n]<span style="color:#f92672">.</span>shape[<span style="color:#ae81ff">0</span>] <span style="color:#f92672">//</span> <span style="color:#ae81ff">1000</span>
</span></span><span style="display:flex;"><span>    neg_confs[n] <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>zeros(n_iso)
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">for</span> i <span style="color:#f92672">in</span> range(n_iso):
</span></span><span style="display:flex;"><span>        isomer_scores <span style="color:#f92672">=</span> s_scores[n][i <span style="color:#f92672">*</span> <span style="color:#ae81ff">1000</span>:(i <span style="color:#f92672">+</span> <span style="color:#ae81ff">1</span>) <span style="color:#f92672">*</span> <span style="color:#ae81ff">1000</span>]
</span></span><span style="display:flex;"><span>        neg_confs[n][i] <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>sum(isomer_scores <span style="color:#f92672">&lt;</span> <span style="color:#ae81ff">0</span>) <span style="color:#f92672">/</span> isomer_scores<span style="color:#f92672">.</span>shape[<span style="color:#ae81ff">0</span>]
</span></span></code></pre></div><h4 id="isomer-level-analysis">Isomer-Level Analysis</h4>















<figure class="post-figure center ">
    <img src="/img/alkane-constitutional-isomers/fraction_of_negative_silhouette_scores_vs_num_carbon_atoms.webp"
         alt="Chart showing fraction of isomers with at least one misclassified conformation"
         title="Chart showing fraction of isomers with at least one misclassified conformation"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Fraction of isomers with at least one misclassified conformation (a stringent test of cluster purity).</figcaption>
    
</figure>

<p>The trend is concerning: by $C_{11}H_{24}$, 97% of isomers have at least one conformation that would be misclassified. This metric is deliberately strict. Even a single misplaced conformation marks the entire isomer as problematic.</p>
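<p>The metric code above slices the score array in blocks of 1000, which assumes every isomer contributes exactly 1000 conformations stored contiguously. A hedged alternative, assuming the label array <code>y</code> from the silhouette loop is kept alongside the scores, groups by label instead. The function name <code>negative_score_summary</code> and the toy inputs are hypothetical.</p>

```python
import numpy as np


def negative_score_summary(scores, labels):
    """Return (fraction of isomers with any negative score,
    per-isomer fraction of negative scores), grouped by label."""
    isomers = np.unique(labels)
    per_isomer = np.array(
        [np.mean(scores[labels == iso] < 0) for iso in isomers]
    )
    frac_any = np.mean(per_isomer > 0)
    return frac_any, per_isomer


# Toy input: isomer 0 has one negative score out of three, isomer 1 has none
scores = np.array([0.5, 0.4, -0.1, 0.3, 0.2, 0.6])
labels = np.array([0, 0, 0, 1, 1, 1])

frac_any, per_isomer = negative_score_summary(scores, labels)
print(frac_any)
print(per_isomer)
```

<p>This version gives the same numbers as the block-slicing code when the 1000-per-isomer assumption holds, and keeps working if some isomers have fewer conformations.</p>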
<h4 id="conformation-level-analysis">Conformation-Level Analysis</h4>















<figure class="post-figure center ">
    <img src="/img/alkane-constitutional-isomers/fraction_of_negative_silhouette_scores_vs_num_carbon_atoms_individual.webp"
         alt="Chart showing individual misclassification rates per isomer with horizontal lines showing range for each molecular size"
         title="Chart showing individual misclassification rates per isomer with horizontal lines showing range for each molecular size"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Individual misclassification rates per isomer. Each point represents one isomer; horizontal lines show the range for each molecular size.</figcaption>
    
</figure>

<p>The individual analysis reveals dramatic variation:</p>
<ul>
<li><strong>$C_{4}H_{10}$</strong>: Perfect clustering (0% misclassification), confirming our earlier 2D separation plots</li>
<li><strong>$C_{5}H_{12}$ to $C_{6}H_{14}$</strong>: Modest problems (1-8% misclassification rates)</li>
<li><strong>$C_{11}H_{24}$</strong>: On average, 35% of each isomer&rsquo;s conformations are misclassified</li>
</ul>
<p>Some isomers experience up to 99.5% conformation misclassification (they become essentially unrecognizable in eigenvalue space). This directly connects to our earlier observation: mathematical representations that appear elegant may lack the structural nuances needed for practical discrimination.</p>
<h2 id="supervised-learning-finding-hidden-structure">Supervised Learning: Finding Hidden Structure</h2>
<p>Both clustering metrics deliver the same conclusion: Coulomb matrix eigenvalues alone struggle to reliably distinguish constitutional isomers for larger alkanes. The mathematical elegance of eigenvalues encounters practical limitations as molecular complexity increases.</p>
<p>Supervised learning offers an alternative approach. Providing labels allows models to extract hidden patterns that elude clustering algorithms. The mathematical structure often requires explicit guidance for discovery.</p>
<p>I&rsquo;ll focus on two baseline approaches: k-nearest neighbors and logistic regression. These represent fundamentally different learning paradigms: one memorizes patterns, the other learns linear boundaries. Comparing them gives us insight into what types of structure might exist in eigenvalue space.</p>
<h2 id="k-nearest-neighbors-pattern-recognition-through-memory">k-Nearest Neighbors: Pattern Recognition Through Memory</h2>
<p>k-NN represents the simplest supervised learning approach: it stores all training examples and classifies new samples based on their closest neighbors. If eigenvalue patterns truly distinguish isomers, nearby points in eigenvalue space should belong to the same class.</p>
<p>This directly tests the local structure. Local neighborhoods often preserve meaningful distinctions even when global structure appears diffuse.</p>
<h3 id="testing-different-feature-representations">Testing Different Feature Representations</h3>
<p>We compare three approaches: full eigenvalue vectors, top 10 eigenvalues only, and PCA-reduced representations.</p>
<p>Testing 1-nearest neighbor with full dimensionality:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">from</span> sklearn.model_selection <span style="color:#f92672">import</span> cross_val_score, StratifiedKFold
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> sklearn.neighbors <span style="color:#f92672">import</span> KNeighborsClassifier
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>df_1nn <span style="color:#f92672">=</span> []
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">for</span> n <span style="color:#f92672">in</span> range(<span style="color:#ae81ff">4</span>, <span style="color:#ae81ff">12</span>):
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Prepare the data for CnH2n+2</span>
</span></span><span style="display:flex;"><span>    X, y <span style="color:#f92672">=</span> prep_data(n<span style="color:#f92672">=</span>n)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Create knn classifier</span>
</span></span><span style="display:flex;"><span>    knn <span style="color:#f92672">=</span> KNeighborsClassifier(n_neighbors<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Set up stratified 5-fold cross-validation</span>
</span></span><span style="display:flex;"><span>    cv <span style="color:#f92672">=</span> StratifiedKFold(n_splits<span style="color:#f92672">=</span><span style="color:#ae81ff">5</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Perform cross-validation. Since &#39;cross_val_score&#39; computes accuracy, we compute misclassification rate by subtracting accuracy from 1.</span>
</span></span><span style="display:flex;"><span>    acc_scores <span style="color:#f92672">=</span> cross_val_score(knn, X, y, cv<span style="color:#f92672">=</span>cv, scoring<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;accuracy&#39;</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Convert accuracy scores to misclassification error rates</span>
</span></span><span style="display:flex;"><span>    misclassification_error_rates <span style="color:#f92672">=</span> <span style="color:#ae81ff">1</span> <span style="color:#f92672">-</span> acc_scores
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Calculate the average and standard deviation of the misclassification error rates</span>
</span></span><span style="display:flex;"><span>    avg_misclassification_error <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>mean(misclassification_error_rates)
</span></span><span style="display:flex;"><span>    std_misclassification_error <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>std(misclassification_error_rates)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    print(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#39;C</span><span style="color:#e6db74">{</span>n<span style="color:#e6db74">}</span><span style="color:#e6db74">H</span><span style="color:#e6db74">{</span><span style="color:#ae81ff">2</span><span style="color:#f92672">*</span>n <span style="color:#f92672">+</span> <span style="color:#ae81ff">2</span><span style="color:#e6db74">}</span><span style="color:#e6db74">: </span><span style="color:#e6db74">{</span>avg_misclassification_error<span style="color:#e6db74">:</span><span style="color:#e6db74">.2%</span><span style="color:#e6db74">}</span><span style="color:#e6db74"> ± </span><span style="color:#e6db74">{</span>std_misclassification_error<span style="color:#e6db74">:</span><span style="color:#e6db74">.2%</span><span style="color:#e6db74">}</span><span style="color:#e6db74">&#39;</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    df_1nn<span style="color:#f92672">.</span>append({
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#39;molecule&#39;</span>: <span style="color:#e6db74">f</span><span style="color:#e6db74">&#39;C</span><span style="color:#e6db74">{</span>n<span style="color:#e6db74">}</span><span style="color:#e6db74">H</span><span style="color:#e6db74">{</span><span style="color:#ae81ff">2</span><span style="color:#f92672">*</span>n <span style="color:#f92672">+</span> <span style="color:#ae81ff">2</span><span style="color:#e6db74">}</span><span style="color:#e6db74">&#39;</span>,
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#39;avg_misclassification_error&#39;</span>: avg_misclassification_error,
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#39;std_misclassification_error&#39;</span>: std_misclassification_error,
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#39;n&#39;</span>: n,
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#39;representation&#39;</span>: <span style="color:#e6db74">&#39;full&#39;</span>,
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#39;model&#39;</span>: <span style="color:#e6db74">&#39;1nn&#39;</span>,
</span></span><span style="display:flex;"><span>    })
</span></span></code></pre></div><p>The results are remarkable compared to unsupervised clustering:</p>
<pre><code>C4H10: 0.00% ± 0.00%
C5H12: 0.00% ± 0.00%
C6H14: 0.00% ± 0.00%
C7H16: 0.07% ± 0.05%
C8H18: 0.11% ± 0.05%
C9H20: 0.51% ± 0.09%
C10H22: 1.31% ± 0.09%
C11H24: 3.24% ± 0.09%
</code></pre>
<p><strong>Perfect classification</strong> for molecules up to $C_{6}H_{14}$, with low error rates even for $C_{11}H_{24}$ (3.24%). This is a large improvement over clustering, where 97% of $C_{11}H_{24}$ isomers had at least one misclassified conformation and, on average, 35% of each isomer&rsquo;s conformations were misclassified.</p>
<p><strong>Note on feature scaling:</strong> Standardizing features significantly degraded performance; eigenvalue magnitudes carry crucial structural information.</p>
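<p>One plausible mechanism for the scaling effect can be sketched on synthetic data. The example below is an assumption-laden illustration, not the alkane spectra: a single large-magnitude feature separates the classes, and twenty small-magnitude features are pure noise. Standardizing inflates the noise features to the same scale as the informative one, which hurts a distance-based classifier.</p>

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
n_per_class, n_classes, n_noise = 100, 3, 20

# Synthetic stand-in, NOT the alkane data: one large-magnitude feature
# carries the class signal, the rest are small-magnitude noise.
y = np.repeat(np.arange(n_classes), n_per_class)
signal = (10.0 * y + rng.normal(0.0, 1.0, size=y.size)).reshape(-1, 1)
noise = rng.normal(0.0, 0.01, size=(y.size, n_noise))
X = np.hstack([signal, noise])

cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
raw = KNeighborsClassifier(n_neighbors=1)
# The scaler sits inside the pipeline so it is refit on each training fold
scaled = make_pipeline(StandardScaler(), KNeighborsClassifier(n_neighbors=1))

raw_error = 1 - cross_val_score(raw, X, y, cv=cv, scoring='accuracy').mean()
scaled_error = 1 - cross_val_score(scaled, X, y, cv=cv, scoring='accuracy').mean()
print(f'raw: {raw_error:.2%}, standardized: {scaled_error:.2%}')
```

<p>Swapping in <code>prep_data(n=n)</code> for the synthetic arrays would run the same comparison on the real eigenvalue vectors.</p>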
<p>Comparing performance across different feature representations:</p>















<figure class="post-figure center ">
    <img src="/img/alkane-constitutional-isomers/alkane-classification-1nn.webp"
         alt="1-NN performance across different representations"
         title="1-NN performance across different representations"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">1-NN classification performance across different eigenvalue representations shows similar results, with slight advantages for full representations on larger molecules.</figcaption>
    
</figure>

<p><strong>Key insights:</strong></p>
<ul>
<li><strong>Representation choice matters little</strong> for 1-NN: full, top-10, and PCA representations perform nearly identically</li>
<li><strong>PCA slightly outperforms</strong> top-10 eigenvalues for larger molecules, capturing more structural variance</li>
<li><strong>Perfect classification</strong> persists through $C_{6}H_{14}$ regardless of representation</li>
</ul>
<p>Because the top-10 and PCA representations nearly match the full eigenvalue vector, discriminative information appears to concentrate in the largest eigenvalues, which is consistent with our earlier PCA findings.</p>
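<p>The three representations can be compared in one loop. This is a minimal sketch under stated assumptions: the data is a synthetic stand-in for <code>prep_data</code>, the columns are assumed to be eigenvalues sorted largest-first (so &ldquo;top 10&rdquo; is just the first ten columns), and the helper <code>top_k_columns</code> is hypothetical.</p>

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import FunctionTransformer


def top_k_columns(X, k=10):
    # Assumes each row is already sorted with the largest eigenvalue first
    return X[:, :k]


# Synthetic stand-in for prep_data(n): 3 classes, 26 features per sample
rng = np.random.default_rng(0)
y = np.repeat(np.arange(3), 100)
X = rng.normal(size=(300, 26)) + y[:, None]
X = -np.sort(-X, axis=1)  # sort each row in descending order

models = {
    'full': KNeighborsClassifier(n_neighbors=1),
    'top10': make_pipeline(
        FunctionTransformer(top_k_columns, kw_args={'k': 10}),
        KNeighborsClassifier(n_neighbors=1),
    ),
    # PCA lives inside the pipeline so it is fit on training folds only
    'pca10': make_pipeline(PCA(n_components=10), KNeighborsClassifier(n_neighbors=1)),
}

cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
errors = {}
for name, model in models.items():
    acc = cross_val_score(model, X, y, cv=cv, scoring='accuracy')
    errors[name] = 1 - acc.mean()
    print(f'{name}: {errors[name]:.2%}')
```

<p>Keeping the dimensionality reduction inside the pipeline means each cross-validation fold fits PCA on its own training split, which avoids leaking test information into the projection.</p>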
<h3 id="the-neighbor-count-effect">The Neighbor Count Effect</h3>
<p>Testing k-NN with different neighbor counts (k=1, 3, 5) reveals a counterintuitive pattern:</p>















<figure class="post-figure center ">
    <img src="/img/alkane-constitutional-isomers/alkane-classification-knn.webp"
         alt="k-NN performance for different k values"
         title="k-NN performance for different k values"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">k-NN classification performance decreases as k increases. More neighbors actually hurt accuracy.</figcaption>
    
</figure>

<p><strong>Why does performance degrade with more neighbors?</strong> This connects directly to our earlier clustering analysis. The eigenvalue space lacks cleanly separated clusters: conformations of different isomers interleave, so only the very nearest neighbor reliably shares a label. When k-NN looks beyond that immediate neighbor, it increasingly finds examples from different classes.</p>
<p>This validates our unsupervised findings: in the absence of clear cluster boundaries, examining more neighbors introduces noise.</p>
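<p>An extreme toy case shows how this can happen. In the sketch below (synthetic data, not the alkane spectra), samples come in tight same-class pairs, and neighboring pairs alternate between two classes. The closest neighbor is usually the pair partner, while every further neighbor belongs to the other class.</p>

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.neighbors import KNeighborsClassifier

# 200 tight pairs spaced 1.0 apart on a line; pairs alternate between
# class 0 and class 1, so the classes are interleaved at every scale
# larger than the pair width of 0.01.
centers = np.arange(200, dtype=float)
X = np.concatenate([centers, centers + 0.01]).reshape(-1, 1)
y = np.concatenate([centers % 2, centers % 2]).astype(int)

cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
errors = {}
for k in (1, 3, 5):
    knn = KNeighborsClassifier(n_neighbors=k)
    acc = cross_val_score(knn, X, y, cv=cv, scoring='accuracy')
    errors[k] = 1 - acc.mean()
    print(f'k={k}: {errors[k]:.2%}')
```

<p>Here k=1 fails mainly when a sample&rsquo;s partner lands in the same test fold, whereas k=3 and k=5 are outvoted by the surrounding opposite-class pairs. Real eigenvalue spectra are far less adversarial, but the direction of the effect is the same.</p>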
<h2 id="logistic-regression-learning-linear-decision-boundaries">Logistic Regression: Learning Linear Decision Boundaries</h2>
<p>Logistic regression represents a fundamentally different approach: it learns linear decision boundaries in eigenvalue space. If eigenvalues encode structural information linearly, this should work well.</p>
<p>We&rsquo;ll focus on PCA-reduced representations to keep computation manageable, using insights from the k-NN analysis.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">from</span> sklearn.linear_model <span style="color:#f92672">import</span> LogisticRegression
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> sklearn.pipeline <span style="color:#f92672">import</span> Pipeline
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> sklearn.decomposition <span style="color:#f92672">import</span> PCA
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>df_lr <span style="color:#f92672">=</span> []
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">for</span> n <span style="color:#f92672">in</span> range(<span style="color:#ae81ff">4</span>, <span style="color:#ae81ff">12</span>):
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Prepare the data for CnH2n+2</span>
</span></span><span style="display:flex;"><span>    X, y <span style="color:#f92672">=</span> prep_data(n<span style="color:#f92672">=</span>n)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Create logistic regression classifier with PCA</span>
</span></span><span style="display:flex;"><span>    lr <span style="color:#f92672">=</span> Pipeline([
</span></span><span style="display:flex;"><span>        (<span style="color:#e6db74">&#39;pca&#39;</span>, PCA(n_components<span style="color:#f92672">=</span><span style="color:#ae81ff">10</span>)),
</span></span><span style="display:flex;"><span>        (<span style="color:#e6db74">&#39;lr&#39;</span>, LogisticRegression(
</span></span><span style="display:flex;"><span>            max_iter<span style="color:#f92672">=</span><span style="color:#ae81ff">10_000</span>,
</span></span><span style="display:flex;"><span>            penalty<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;l2&#39;</span>,
</span></span><span style="display:flex;"><span>            solver<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;lbfgs&#39;</span>,
</span></span><span style="display:flex;"><span>            C<span style="color:#f92672">=</span><span style="color:#ae81ff">10.0</span>,  <span style="color:#75715e"># Reduced regularization</span>
</span></span><span style="display:flex;"><span>            random_state<span style="color:#f92672">=</span><span style="color:#ae81ff">42</span>,
</span></span><span style="display:flex;"><span>            n_jobs<span style="color:#f92672">=-</span><span style="color:#ae81ff">1</span>,
</span></span><span style="display:flex;"><span>        ))
</span></span><span style="display:flex;"><span>    ])
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># 5-fold stratified cross-validation</span>
</span></span><span style="display:flex;"><span>    cv <span style="color:#f92672">=</span> StratifiedKFold(n_splits<span style="color:#f92672">=</span><span style="color:#ae81ff">5</span>)
</span></span><span style="display:flex;"><span>    acc_scores <span style="color:#f92672">=</span> cross_val_score(lr, X, y, cv<span style="color:#f92672">=</span>cv, scoring<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;accuracy&#39;</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Convert to misclassification rates</span>
</span></span><span style="display:flex;"><span>    avg_error <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>mean(<span style="color:#ae81ff">1</span> <span style="color:#f92672">-</span> acc_scores)
</span></span><span style="display:flex;"><span>    std_error <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>std(<span style="color:#ae81ff">1</span> <span style="color:#f92672">-</span> acc_scores)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    print(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#39;C</span><span style="color:#e6db74">{</span>n<span style="color:#e6db74">}</span><span style="color:#e6db74">H</span><span style="color:#e6db74">{</span><span style="color:#ae81ff">2</span><span style="color:#f92672">*</span>n <span style="color:#f92672">+</span> <span style="color:#ae81ff">2</span><span style="color:#e6db74">}</span><span style="color:#e6db74">: </span><span style="color:#e6db74">{</span>avg_error<span style="color:#e6db74">:</span><span style="color:#e6db74">.2%</span><span style="color:#e6db74">}</span><span style="color:#e6db74"> ± </span><span style="color:#e6db74">{</span>std_error<span style="color:#e6db74">:</span><span style="color:#e6db74">.2%</span><span style="color:#e6db74">}</span><span style="color:#e6db74">&#39;</span>)
</span></span></code></pre></div><p>Comparing k-NN versus logistic regression performance:</p>















<figure class="post-figure center ">
    <img src="/img/alkane-constitutional-isomers/alkane-classification-1nn-lr.webp"
         alt="Comparison of 1-NN and Logistic Regression performance"
         title="Comparison of 1-NN and Logistic Regression performance"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">k-NN significantly outperforms logistic regression, especially for larger molecules. The performance gap widens as molecular complexity increases.</figcaption>
    
</figure>

<p><strong>Key observations:</strong></p>
<ul>
<li><strong>k-NN dominates</strong> across all molecular sizes</li>
<li><strong>Linear boundaries fail</strong> for larger molecules. This suggests nonlinear eigenvalue relationships.</li>
<li><strong>Performance gap grows</strong> with molecular complexity, indicating increasingly nonlinear structural patterns</li>
</ul>
<p>Logistic regression&rsquo;s weaker performance indicates that the discriminative patterns in eigenvalue space are not linearly separable. Capturing these relationships requires memory-based or other nonlinear approaches.</p>
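<p>The contrast between a memory-based and a linear model is easy to reproduce on a standard toy problem. The sketch below uses scikit-learn&rsquo;s <code>make_circles</code> (two concentric rings) as an assumed stand-in for nonlinear class structure; it says nothing about the actual geometry of the eigenvalue space.</p>

```python
from sklearn.datasets import make_circles
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier

# Two concentric rings: no straight line can separate the classes
X, y = make_circles(n_samples=400, factor=0.5, noise=0.05, random_state=0)

knn_acc = cross_val_score(KNeighborsClassifier(n_neighbors=1), X, y, cv=5).mean()
lr_acc = cross_val_score(LogisticRegression(max_iter=1000), X, y, cv=5).mean()

print(f'1-NN accuracy: {knn_acc:.2%}')
print(f'Logistic regression accuracy: {lr_acc:.2%}')
```

<p>1-NN follows the curved boundary because it only consults local neighbors, while logistic regression is limited to a single hyperplane and lands near chance.</p>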
<h2 id="implications-for-molecular-representation">Implications for Molecular Representation</h2>
<p>Our supervised learning experiments reveal a nuanced picture of Coulomb matrix eigenvalues as molecular descriptors. Eigenvalues preserve sufficient local structure for nearest-neighbor classification to work remarkably well, despite lacking clean global clusters.</p>
<p>This analysis reveals important lessons about molecular representations:</p>
<ol>
<li><strong>Empirical performance and mathematical elegance are separate axes</strong>: an elegant descriptor can still fail in practice as the space gets more complex.</li>
<li><strong>Context matters</strong>: Representations exhibit distinct performance characteristics under supervised versus unsupervised conditions.</li>
<li><strong>Molecular complexity is challenging</strong>: Even simple alkanes test our best descriptors.</li>
<li><strong>Local vs. global structure</strong>: Local neighborhood structures often contain highly discriminative information.</li>
</ol>
<p>For practitioners working with molecular representations, it is crucial to test multiple learning paradigms. Supervised and unsupervised approaches often yield different insights.</p>
<p><strong>Why this matters beyond the alkane case:</strong> molecular representations are the input layer for property prediction, and understanding where a simple descriptor like Coulomb-matrix eigenvalues fails (overlapping clusters for larger molecules, nonlinear class structure that defeats logistic regression) is what motivates moving to graph- or coordinate-aware models. The failure modes here are the argument for richer representations.</p>
<p>The data pipeline that generated the datasets used in this analysis is available at the <a href="/projects/isomer-dataset-generation/">Synthetic Isomer Data Generation Pipeline project page</a>.</p>
]]></content:encoded></item><item><title>Coulomb Matrices for Molecular Machine Learning</title><link>https://hunterheidenreich.com/posts/molecular-descriptor-coulomb-matrix/</link><pubDate>Sat, 10 Feb 2024 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/molecular-descriptor-coulomb-matrix/</guid><description>Learn how Coulomb matrices encode 3D molecular structure for machine learning from basic theory to Python implementation and practical limitations.</description><content:encoded><![CDATA[<h2 id="introduction">Introduction</h2>
<p>When working with machine learning in chemistry, one of the first challenges you encounter is how to represent molecules in a way that algorithms can understand. You can&rsquo;t just feed raw atomic coordinates into a model. The representation needs to be invariant to rotation, translation, and atom ordering, since these operations don&rsquo;t change the molecule&rsquo;s fundamental properties.</p>
<p>The Coulomb matrix, introduced by Rupp et al. in 2012 <a href="#ref-1">[1]</a>, provides a straightforward solution to this problem. While newer methods have largely superseded it for practical applications, the Coulomb matrix remains an excellent starting point for understanding how molecular descriptors work.</p>
<p>The key insight is simple: we encode pairwise relationships between atoms in a way that captures the essential physics while maintaining the required invariances.</p>
<h2 id="the-coulomb-matrix-theory-and-intuition">The Coulomb Matrix: Theory and Intuition</h2>
<p>The Coulomb matrix encodes molecular structure in a symmetric $N \times N$ matrix, where $N$ is the number of atoms. Each element $C_{ij}$ is defined as:</p>
<p>$$
C_{ij} = \begin{cases} 0.5 Z_i^{2.4} &amp; \text{if } i = j, \\ \frac{Z_i Z_j}{|\mathbf{R}_i - \mathbf{R}_j|} &amp; \text{if } i \neq j, \end{cases}
$$</p>
<p>Here, $Z_i$ is the atomic number of atom $i$, and $\mathbf{R}_i$ is its position in 3D space. The diagonal elements ($0.5 Z_i^{2.4}$) approximate atomic self-energies; the form comes from a polynomial fit of atomic energies to nuclear charge <a href="#ref-1">[1]</a>. The off-diagonal elements are the Coulomb repulsion between nuclei $i$ and $j$: proportional to the product of their charges and inversely proportional to the distance between them, just like electrostatic potential energy <a href="#ref-3">[3]</a>.</p>
<p>This construction gives us several useful properties:</p>
<ul>
<li><strong>Rotation and translation invariant</strong>: Only relative distances matter</li>
<li><strong>Symmetric</strong>: $C_{ij} = C_{ji}$, which is physically sensible</li>
<li><strong>Size-dependent</strong>: Larger molecules produce larger matrices ($N \times N$ for $N$ atoms)</li>
<li><strong>Captures 3D structure</strong>: Nearby atoms have larger interaction terms</li>
</ul>
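<p>To make the definition concrete, here is a minimal from-scratch sketch in NumPy. The function name <code>coulomb_matrix</code> and the three-atom, water-like geometry are illustrative assumptions (the coordinates are made up, not a relaxed structure), and the result is expressed in whatever length units the positions use. The last two lines check the symmetry and rigid-motion invariance listed above.</p>

```python
import numpy as np

def coulomb_matrix(Z, R):
    """Build the Coulomb matrix from atomic numbers Z (N,) and positions R (N, 3)."""
    Z = np.asarray(Z, dtype=float)
    R = np.asarray(R, dtype=float)
    # Pairwise distances |R_i - R_j| via broadcasting
    dist = np.linalg.norm(R[:, None, :] - R[None, :, :], axis=-1)
    # Avoid division by zero on the diagonal; those entries are overwritten below
    np.fill_diagonal(dist, 1.0)
    C = np.outer(Z, Z) / dist
    np.fill_diagonal(C, 0.5 * Z ** 2.4)
    return C

# Hypothetical water-like geometry (illustrative coordinates only)
Z = np.array([8, 1, 1])
R = np.array([[0.0, 0.0, 0.0], [0.96, 0.0, 0.0], [-0.24, 0.93, 0.0]])
C = coulomb_matrix(Z, R)

# Rotate by 90 degrees about the z-axis, then translate
rot = np.array([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
R_moved = R @ rot.T + np.array([1.0, -2.0, 3.0])
C_moved = coulomb_matrix(Z, R_moved)

print(np.allclose(C, C.T))      # symmetric
print(np.allclose(C, C_moved))  # unchanged by rotation and translation
```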
<p>While more sophisticated methods exist today <a href="#ref-2">[2]</a>, the Coulomb matrix&rsquo;s simplicity makes it ideal for understanding the fundamentals of molecular representation.</p>
<h3 id="hands-on-example-bicyclobutane">Hands-on Example: Bicyclobutane</h3>
<p>Let&rsquo;s calculate the Coulomb matrix for <a href="https://en.wikipedia.org/wiki/Bicyclobutane">bicyclobutane</a>, a strained but stable bicyclic system (bicyclo[1.1.0]butane, C4H6, two cis-fused cyclopropane rings). This example will show you exactly how the theory translates to practice.</p>















<figure class="post-figure center ">
    <img src="https://upload.wikimedia.org/wikipedia/commons/b/b4/Bicyclobutane-2.svg"
         alt="Bicyclobutane"
         title="Bicyclobutane"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Bicyclobutane structure (Smokefoot, Public domain, via Wikimedia Commons)</figcaption>
    
</figure>

<p>I&rsquo;ll use Python with the Atomic Simulation Environment (<code>ase</code>) for molecular structure <a href="#ref-4">[4]</a> and <code>dscribe</code> for the Coulomb matrix calculation <a href="#ref-2">[2]</a>:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">from</span> ase.build <span style="color:#f92672">import</span> molecule
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> ase.visualize <span style="color:#f92672">import</span> view
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Load the bicyclobutane structure</span>
</span></span><span style="display:flex;"><span>bicyclobutane <span style="color:#f92672">=</span> molecule(<span style="color:#e6db74">&#39;bicyclobutane&#39;</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Optional: visualize the structure</span>
</span></span><span style="display:flex;"><span>view(bicyclobutane, viewer<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;x3d&#39;</span>)
</span></span></code></pre></div>














<figure class="post-figure center ">
    <img src="/img/bicyclobutane_ase_1.webp"
         alt="Bicyclobutane 3D structure"
         title="Bicyclobutane 3D structure"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">3D structure of bicyclobutane</figcaption>
    
</figure>

<p>Now we calculate the Coulomb matrix using DScribe:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">from</span> dscribe.descriptors <span style="color:#f92672">import</span> CoulombMatrix
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Set up the descriptor</span>
</span></span><span style="display:flex;"><span>cm <span style="color:#f92672">=</span> CoulombMatrix(n_atoms_max<span style="color:#f92672">=</span>len(bicyclobutane))
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Calculate and reshape into matrix form</span>
</span></span><span style="display:flex;"><span>cm_bicyclobutane <span style="color:#f92672">=</span> cm<span style="color:#f92672">.</span>create(bicyclobutane)
</span></span><span style="display:flex;"><span>cm_bicyclobutane <span style="color:#f92672">=</span> cm_bicyclobutane<span style="color:#f92672">.</span>reshape(len(bicyclobutane), len(bicyclobutane))
</span></span></code></pre></div><h3 id="visualizing-the-results">Visualizing the Results</h3>
<p>The Coulomb matrix can be visualized as a heatmap. Let&rsquo;s look at both the raw matrix and its logarithm:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> matplotlib.pyplot <span style="color:#66d9ef">as</span> plt
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> numpy <span style="color:#66d9ef">as</span> np
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Raw Coulomb matrix</span>
</span></span><span style="display:flex;"><span>plt<span style="color:#f92672">.</span>figure(figsize<span style="color:#f92672">=</span>(<span style="color:#ae81ff">8</span>, <span style="color:#ae81ff">8</span>), dpi<span style="color:#f92672">=</span><span style="color:#ae81ff">150</span>)
</span></span><span style="display:flex;"><span>plt<span style="color:#f92672">.</span>imshow(cm_bicyclobutane, cmap<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;coolwarm&#39;</span>)
</span></span><span style="display:flex;"><span>plt<span style="color:#f92672">.</span>colorbar(label<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;Magnitude&#39;</span>)
</span></span><span style="display:flex;"><span>plt<span style="color:#f92672">.</span>title(<span style="color:#e6db74">&#39;Coulomb Matrix for Bicyclobutane&#39;</span>)
</span></span><span style="display:flex;"><span>plt<span style="color:#f92672">.</span>show()
</span></span></code></pre></div>














<figure class="post-figure center ">
    <img src="/img/cm_bicyclobutane.webp"
         alt="Coulomb matrix of bicyclobutane"
         title="Coulomb matrix of bicyclobutane"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Coulomb matrix for bicyclobutane</figcaption>
    
</figure>

<p>The raw matrix shows clear patterns:</p>
<ul>
<li><strong>Large diagonal elements</strong>: Carbon atoms (Z=6) dominate, with self-energy terms of $0.5 \cdot 6^{2.4} \approx 36.9$</li>
<li><strong>Smaller off-diagonal elements</strong>: Represent pairwise interactions</li>
<li><strong>Minimal hydrogen contribution</strong>: Hydrogen atoms (Z=1) have a diagonal term of just $0.5$</li>
</ul>
<p>Because the carbon terms dwarf everything else, taking the element-wise logarithm compresses the dynamic range and reveals more detail:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span>plt<span style="color:#f92672">.</span>figure(figsize<span style="color:#f92672">=</span>(<span style="color:#ae81ff">8</span>, <span style="color:#ae81ff">8</span>), dpi<span style="color:#f92672">=</span><span style="color:#ae81ff">150</span>)
</span></span><span style="display:flex;"><span>plt<span style="color:#f92672">.</span>imshow(np<span style="color:#f92672">.</span>log(cm_bicyclobutane), cmap<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;coolwarm&#39;</span>)
</span></span><span style="display:flex;"><span>plt<span style="color:#f92672">.</span>colorbar(label<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;log(Magnitude)&#39;</span>)
</span></span><span style="display:flex;"><span>plt<span style="color:#f92672">.</span>title(<span style="color:#e6db74">&#39;Log Coulomb Matrix for Bicyclobutane&#39;</span>)
</span></span><span style="display:flex;"><span>plt<span style="color:#f92672">.</span>show()
</span></span></code></pre></div>














<figure class="post-figure center ">
    <img src="/img/cm_bicyclobutane_log.webp"
         alt="Log Coulomb matrix of bicyclobutane"
         title="Log Coulomb matrix of bicyclobutane"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Log-scale reveals more structural detail</figcaption>
    
</figure>

<h3 id="eigenvalue-analysis">Eigenvalue Analysis</h3>
<p>The eigenvalues of the Coulomb matrix provide another perspective on molecular structure:</p>















<figure class="post-figure center ">
    <img src="/img/cm_bicyclobutane_eigenvalues.webp"
         alt="Eigenvalues of Coulomb matrix"
         title="Eigenvalues of Coulomb matrix"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Eigenvalues of the Coulomb matrix</figcaption>
    
</figure>
















<figure class="post-figure center ">
    <img src="/img/cm_bicyclobutane_log_eigenvalues.webp"
         alt="Eigenvalues of log Coulomb matrix"
         title="Eigenvalues of log Coulomb matrix"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Eigenvalues on logarithmic scale</figcaption>
    
</figure>

<p>These eigenvalues are often used as features themselves: the sorted spectrum has only $N$ entries instead of $N^2$ and does not depend on atom ordering, providing a more compact representation than the full matrix.</p>
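<p>A minimal sketch of this step, assuming the Coulomb matrix is already available as a square NumPy array. The helper name <code>coulomb_eigenvalues</code>, the descending sort, and the small stand-in matrix are illustrative choices, not the code that produced the figures above.</p>

```python
import numpy as np

def coulomb_eigenvalues(C):
    """Eigenvalues of a symmetric Coulomb matrix, largest first."""
    # eigvalsh assumes a symmetric input and returns real eigenvalues in ascending order
    return np.linalg.eigvalsh(C)[::-1]

# Hypothetical 3-atom Coulomb matrix (values chosen for illustration only)
C = np.array([[36.9, 6.0, 5.0],
              [6.0, 0.5, 1.0],
              [5.0, 1.0, 0.5]])

eigvals = coulomb_eigenvalues(C)
print(eigvals.shape)  # one eigenvalue per atom: N features instead of N * N
```

<p>Applied to the <code>cm_bicyclobutane</code> array from earlier, the same call would return 10 eigenvalues, one per atom of C4H6.</p>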
<h2 id="practical-limitations">Practical Limitations</h2>
<p>The Coulomb matrix has significant limitations that explain why it&rsquo;s been largely superseded by modern methods. Understanding these constraints is crucial for knowing when and how to use this descriptor.</p>
<h3 id="the-size-problem">The Size Problem</h3>
<p>Every molecule must be represented by the same size matrix, which creates several issues:</p>
<ul>
<li><strong>Padding overhead</strong>: Small molecules get padded with zeros up to the maximum size</li>
<li><strong>Quadratic scaling</strong>: An $N$-atom molecule requires $N^2$ features</li>
<li><strong>Fixed maximum size</strong>: You can&rsquo;t represent molecules larger than your preset limit</li>
<li><strong>Inefficient storage</strong>: Most elements are zero for small molecules in large matrices</li>
</ul>
<p>For a dataset ranging from 5-atom to 50-atom molecules, every molecule needs a $50 \times 50$ matrix. That&rsquo;s 2,500 features, and for a 5-atom molecule only 25 of them (1%) are nonzero.</p>
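<p>The arithmetic is easy to verify with a small sketch. The helper <code>pad_and_flatten</code> is a hypothetical stand-in for what a descriptor library does internally, and the all-ones matrix is just a placeholder for a real 5-atom Coulomb matrix.</p>

```python
import numpy as np

def pad_and_flatten(C, n_atoms_max):
    """Zero-pad an (n, n) Coulomb matrix to (n_atoms_max, n_atoms_max) and flatten it."""
    n = C.shape[0]
    if n > n_atoms_max:
        # The fixed maximum size: larger molecules simply cannot be represented
        raise ValueError(f"molecule has {n} atoms, but n_atoms_max is {n_atoms_max}")
    padded = np.zeros((n_atoms_max, n_atoms_max))
    padded[:n, :n] = C
    return padded.ravel()

# Placeholder 5-atom matrix: all ones, so every real entry is nonzero
C_small = np.ones((5, 5))
features = pad_and_flatten(C_small, n_atoms_max=50)

print(features.size)             # 2500
print(np.mean(features == 0.0))  # 0.99
```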
<h3 id="permutation-sensitivity">Permutation Sensitivity</h3>
<p>The raw Coulomb matrix is not invariant to atom ordering: reordering the atoms in your input file permutes its rows and columns. The standard solution is to sort atoms by the L2 norm of their matrix rows, but this introduces its own problems:</p>
<ul>
<li><strong>Symmetry breaking</strong>: Equivalent atoms might be ordered differently</li>
<li><strong>Numerical instability</strong>: Small coordinate changes can flip the ordering</li>
<li><strong>Loss of chemical intuition</strong>: The sorted order doesn&rsquo;t reflect meaningful chemistry</li>
</ul>
<p>Interestingly, Montavon et al. found that adding controlled noise to the row norms before sorting, which yields several differently ordered matrices per molecule, can actually improve machine learning performance <a href="#ref-5">[5]</a>.</p>
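<p>The row-norm sorting described above can be sketched in a few lines. The function name <code>sort_by_row_norm</code> and the 3-atom matrix are hypothetical; the example shows that a permuted matrix differs from the original, while both sort to the same canonical form as long as the row norms are distinct.</p>

```python
import numpy as np

def sort_by_row_norm(C):
    """Reorder rows and columns so the row L2 norms are in descending order."""
    order = np.argsort(-np.linalg.norm(C, axis=1), kind="stable")
    return C[np.ix_(order, order)]

# Hypothetical 3-atom Coulomb matrix (values chosen for illustration only)
C = np.array([[36.9, 6.0, 5.0],
              [6.0, 0.5, 1.0],
              [5.0, 1.0, 0.5]])

# Same molecule with the atoms listed in a different order
perm = [2, 0, 1]
C_perm = C[np.ix_(perm, perm)]

print(np.allclose(C, C_perm))                                      # False
print(np.allclose(sort_by_row_norm(C), sort_by_row_norm(C_perm)))  # True
```

<p>When two rows have nearly equal norms, a tiny change in coordinates can swap their order, which is the instability noted in the list above.</p>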
<h3 id="limited-scope">Limited Scope</h3>
<p>The Coulomb matrix works well only for specific types of systems:</p>
<ul>
<li><strong>Small molecules</strong>: Performance degrades for large systems due to size scaling</li>
<li><strong>Gas-phase</strong>: Not suitable for periodic systems like crystals or surfaces</li>
<li><strong>Single conformations</strong>: Each 3D structure gets its own matrix</li>
<li><strong>Non-reactive</strong>: Doesn&rsquo;t capture bond-breaking or formation</li>
</ul>
<p>For periodic systems, you&rsquo;d need specialized variants like the Ewald sum matrix <a href="#ref-6">[6]</a>.</p>
<h2 id="why-learn-it-anyway">Why Learn It Anyway?</h2>
<p>Given these limitations, why spend time understanding the Coulomb matrix? Several reasons:</p>
<p><strong>Educational value</strong>: It&rsquo;s conceptually straightforward and provides excellent intuition for how molecular descriptors work. The mathematical formulation is simple enough to implement from scratch.</p>
<p><strong>Historical importance</strong>: Many subsequent methods build on ideas first explored with Coulomb matrices. Understanding this foundation helps you appreciate why newer methods were developed.</p>
<p><strong>Benchmarking</strong>: It remains useful as a baseline method for comparing new descriptors on small molecular datasets.</p>
<p><strong>Proof of concept</strong>: For exploratory work on small, well-defined datasets, the Coulomb matrix can still provide quick insights.</p>
<p>If you&rsquo;re working on practical problems with larger datasets or diverse molecular sizes, consider modern alternatives like graph neural networks, descriptors from DScribe&rsquo;s extended library, or learned representations from transformer models.</p>
<h2 id="putting-it-in-context">Putting It in Context</h2>
<p>To see the Coulomb matrix applied to real problems, I&rsquo;ve written a detailed guide using it for molecular classification:</p>
<ul>
<li><a href="/posts/alkane-constitutional-isomer-classification/">Coulomb Matrix Eigenvalues: Can You Hear the Shape of a Molecule?</a>: A comprehensive analysis of alkane isomers, from unsupervised clustering limits to supervised classification successes.</li>
</ul>
<p>For comparison with modern approaches, check out my post on <a href="/posts/geom-conformer-generation-dataset/">3D conformer generation with the GEOM dataset</a>, which showcases more sophisticated molecular representations. For technical specifications and benchmarks, see the <a href="/notes/chemistry/datasets/geom/">GEOM dataset card</a>.</p>
<p>The Coulomb matrix may be dated, but it remains an excellent entry point into the world of molecular machine learning. Once you understand its strengths and limitations, you&rsquo;ll be better equipped to appreciate why the field has moved toward more sophisticated approaches.</p>
<hr>
<p><em>Have questions about molecular descriptors or want to discuss other approaches to molecular machine learning? I&rsquo;d be happy to explore these topics further.</em></p>
<h2 id="references">References</h2>
<ul>
<li><a id="ref-1"></a>[1] M. Rupp, A. Tkatchenko, K.-R. Müller, and O. A. von Lilienfeld, &ldquo;Fast and Accurate Modeling of Molecular Atomization Energies with Machine Learning,&rdquo; Physical Review Letters, 108(5), 058301 (2012). <a href="https://doi.org/10.1103/PhysRevLett.108.058301">https://doi.org/10.1103/PhysRevLett.108.058301</a> <a href="https://arxiv.org/abs/1109.2618">arXiv:1109.2618</a></li>
<li><a id="ref-2"></a>[2] L. Himanen, M. O. J. Jäger, E. V. Morooka, F. F. Canova, Y. S. Ranawat, D. Z. Gao, P. Rinke, and A. S. Foster, &ldquo;DScribe: Library of descriptors for machine learning in materials science,&rdquo; Computer Physics Communications, 247, 106949 (2020). <a href="https://doi.org/10.1016/j.cpc.2019.106949">https://doi.org/10.1016/j.cpc.2019.106949</a> <a href="https://arxiv.org/abs/1904.08875">arXiv:1904.08875</a></li>
<li><a id="ref-3"></a>[3] J. Schrier, &ldquo;Can one hear the shape of a molecule (from its Coulomb matrix eigenvalues)?,&rdquo; Journal of Chemical Information and Modeling, 60(8), 3804-3811 (2020). <a href="https://doi.org/10.1021/acs.jcim.0c00631">https://doi.org/10.1021/acs.jcim.0c00631</a></li>
<li><a id="ref-4"></a>[4] A. H. Larsen, J. J. Mortensen, J. Blomqvist, I. E. Castelli, R. Christensen, M. Dułak, J. Friis, M. N. Groves, B. Hammer, C. Hargus, E. D. Hermes, P. C. Jennings, P. B. Jensen, J. Kermode, J. R. Kitchin, E. L. Kolsbjerg, J. Kubal, K. Kaasbjerg, S. Lysgaard, J. B. Maronsson, T. Maxson, T. Olsen, L. Pastewka, A. Peterson, C. Rostgaard, J. Schiøtz, O. Schütt, M. Strange, K. S. Thygesen, T. Vegge, L. Vilhelmsen, M. Walter, Z. Zeng, and K. W. Jacobsen, &ldquo;The Atomic Simulation Environment - A Python library for working with atoms,&rdquo; J. Phys.: Condens. Matter, 29, 273002 (2017). <a href="https://doi.org/10.1088/1361-648X/aa680e">https://doi.org/10.1088/1361-648X/aa680e</a> <a href="https://ase-lib.org/index.html">documentation</a></li>
<li><a id="ref-5"></a>[5] G. Montavon, K. Hansen, S. Fazli, M. Rupp, F. Biegler, A. Ziehe, A. Tkatchenko, A. Lilienfeld, and K.-R. Müller, &ldquo;Learning invariant representations of molecules for atomization energy prediction,&rdquo; Advances in Neural Information Processing Systems, 25 (2012). Available online: <a href="https://proceedings.neurips.cc/paper_files/paper/2012/file/115f89503138416a242f40fb7d7f338e-Paper.pdf">https://proceedings.neurips.cc/paper_files/paper/2012/file/115f89503138416a242f40fb7d7f338e-Paper.pdf</a></li>
<li><a id="ref-6"></a>[6] F. Faber, A. Lindmaa, O. A. von Lilienfeld, and R. Armiento, &ldquo;Crystal structure representations for machine learning models of formation energies,&rdquo; International Journal of Quantum Chemistry, 115(16), 1094-1101 (2015). <a href="https://doi.org/10.1002/qua.24917">https://doi.org/10.1002/qua.24917</a></li>
</ul>
]]></content:encoded></item><item><title>Kabsch Algorithm: NumPy, PyTorch, TensorFlow, and JAX</title><link>https://hunterheidenreich.com/posts/kabsch-algorithm/</link><pubDate>Tue, 03 Oct 2023 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/kabsch-algorithm/</guid><description>Learn about the Kabsch algorithm for optimal point alignment with implementations in NumPy, PyTorch, TensorFlow, and JAX for ML applications.</description><content:encoded><![CDATA[<h2 id="what-is-the-kabsch-algorithm">What is the Kabsch Algorithm?</h2>
<p>In computer vision and scientific computing, a common problem arises: given two sets of paired points, what rigid-body transformation best aligns them? The Kabsch algorithm provides a closed-form solution.</p>















<figure class="post-figure center ">
    <img src="/img/scientific-computing/kabsch-alignment-before-and-after.webp"
         alt="Visualization of two point sets before and after Kabsch alignment"
         title="Visualization of two point sets before and after Kabsch alignment"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The Kabsch algorithm optimally rotates and translates the blue points to align with the red points.</figcaption>
    
</figure>

<p>What are some concrete situations where this crops up?</p>
<ul>
<li><strong>Molecular Dynamics</strong>: Your points are a set of atoms (with physically relevant types), and you want to compare two molecular conformations. Are they the same structure with minor noise or rotation? Or are they different conformations, like a different folding of a protein? This is especially helpful when applying generative models to chemical structures. For example, if you are building a <a href="/notes/chemistry/molecular-simulation/ml-potentials/denoise-vae/">3D Molecular VAE</a> in PyTorch or working with <a href="/notes/machine-learning/generative-models/flow-matching-for-generative-modeling/">Flow Matching models</a>, Kabsch alignment ensures your generative loss function remains rotationally invariant.</li>
<li><strong>Computer Vision</strong>: You have two point clouds from 3D scans of an object taken from different angles. You want to align them to reconstruct the full shape. Or perhaps you&rsquo;re generating 3D shapes from 2D images and need to compare the generated shape to a ground truth scan. Anytime a 3D system is represented as a point cloud, the Kabsch algorithm can help with alignment.</li>
</ul>
<p>Of course, existing libraries implement this algorithm. However, I often find it beneficial to implement algorithms from scratch to build intuition. Furthermore, modern machine learning applications often require automatic differentiation, so we will also implement the algorithm in PyTorch, TensorFlow, and JAX.</p>
<p>Below, we&rsquo;ll cover the math behind the Kabsch algorithm (and its scaling variant, the <strong>Kabsch-Umeyama</strong> algorithm) and provide complete implementations in <strong>NumPy</strong>, <strong>PyTorch</strong>, <strong>TensorFlow</strong>, and <strong>JAX</strong> (the latter three differentiable), demonstrating both single-pair and batched computations for ML applications.</p>
<h2 id="the-math">The Math</h2>















<figure class="post-figure center ">
    <img src="/img/scientific-computing/kabsch-algorithm-basic-animation.webp"
         alt="Animation showing the iterative steps of centroid alignment and rotation"
         title="Animation showing the iterative steps of centroid alignment and rotation"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Visualizing the alignment process: first centering the datasets, then finding the optimal rotation.</figcaption>
    
</figure>

<p>Let&rsquo;s say we have two sets of paired points,
$P=\{\mathbf{p}_i\} \in \mathbb{R}^{N \times D}$ and $Q=\{\mathbf{q}_i\} \in \mathbb{R}^{N \times D}$, for $i = 1, \dots, N$
(where $D$ is the dimensionality and $N$ is the number of points).
We want to find a translation vector $\mathbf{t}$ and rotation matrix $R$ to transform $P$ to align with $Q$.</p>
<p>The optimization problem is:</p>
<p>$$
\min_{\mathbf{t}, \ R} \mathcal{L}(\mathbf{t}, R) = \frac{1}{2} \sum_{i=1}^N \lVert \mathbf{q}_i - (R\mathbf{p}_i + \mathbf{t}) \rVert^2
$$</p>
<p>where the minimizers $\mathbf{t}^\ast \in \mathbb{R}^D$ and $R^\ast \in \mathbb{R}^{D \times D}$ are the optimal translation and rotation, and $R$ is constrained to be a proper rotation ($R^T R = I$, $\det R = 1$).</p>
<p>Often we use a weighted version with weights $w_i$ (e.g., atomic masses in molecular dynamics):</p>
<p>$$
\min_{\mathbf{t}, \ R} \mathcal{L}(\mathbf{t}, R) = \frac{1}{2} \sum_{i=1}^N w_i \lVert \mathbf{q}_i - (R\mathbf{p}_i + \mathbf{t}) \rVert^2
$$</p>
<h3 id="the-translation">The Translation</h3>
<p>The translation and rotation are coupled, but they separate cleanly once we work in centroid-centered coordinates. Compute the centroids (averages) of both point sets:</p>
<p>$$
\bar{\mathbf{p}} = \frac{1}{N} \sum_{i=1}^N \mathbf{p}_i \quad \text{and} \quad \bar{\mathbf{q}} = \frac{1}{N} \sum_{i=1}^N \mathbf{q}_i
$$</p>
<p>For any fixed rotation $R$, the translation that minimizes $\mathcal{L}$ is found by setting $\partial \mathcal{L} / \partial \mathbf{t} = 0$. It maps the rotated source centroid onto the target centroid:</p>
<p>$$
\mathbf{t} = \bar{\mathbf{q}} - R\bar{\mathbf{p}}
$$</p>
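<p>The step from the stationarity condition to this result is short. Differentiating the unweighted objective with respect to $\mathbf{t}$ and setting the gradient to zero gives:</p>
<p>$$
\frac{\partial \mathcal{L}}{\partial \mathbf{t}} = -\sum_{i=1}^N \left( \mathbf{q}_i - R\mathbf{p}_i - \mathbf{t} \right) = 0 \quad \Longrightarrow \quad N\mathbf{t} = \sum_{i=1}^N \mathbf{q}_i - R \sum_{i=1}^N \mathbf{p}_i \quad \Longrightarrow \quad \mathbf{t} = \bar{\mathbf{q}} - R\bar{\mathbf{p}}
$$</p>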
<p>A tempting shortcut is to write $\mathbf{t} = \bar{\mathbf{q}} - \bar{\mathbf{p}}$, but that is only correct when $R = I$. In general the translation depends on the rotation, so we compute it <em>after</em> solving for $R$. Substituting this optimal $\mathbf{t}$ back into the objective cancels the centroids and leaves a rotation-only problem in the centered coordinates $\mathbf{p}_i^\prime = \mathbf{p}_i - \bar{\mathbf{p}}$ and $\mathbf{q}_i^\prime = \mathbf{q}_i - \bar{\mathbf{q}}$:</p>
<p>$$
\mathcal{L}(R) = \frac{1}{2} \sum_{i=1}^N \lVert \mathbf{q}_i^\prime - R\mathbf{p}_i^\prime \rVert^2
$$</p>
<p>which is what the next section solves.</p>
<h3 id="the-rotation-matrix">The Rotation Matrix</h3>
<p>We now minimize $\mathcal{L}(R)$ over rotations, using the centered points $\mathbf{p}_i^\prime$ and $\mathbf{q}_i^\prime$ from above. Compute the cross-covariance matrix between the centered sets:</p>
<p>$$
C = P^{\prime T} Q^\prime = \sum_{i=1}^N \mathbf{p}_i^{\prime} \mathbf{q}_i^{\prime T} \in \mathbb{R}^{D \times D}
$$</p>
<p>This is a fairly lightweight operation since $D$ is typically small (e.g., 3 for 3D points), even if $N$ is large.</p>
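<p>As a quick sanity check on shapes, the matrix product $P^{\prime T} Q^\prime$ (with points stored as rows) accumulates one $D \times D$ outer product per point pair. The sketch below uses random points and variable names of my own choosing to confirm that the two forms agree.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
P = rng.normal(size=(100, 3))  # N = 100 points in D = 3 dimensions
Q = rng.normal(size=(100, 3))

# Center both sets on their centroids
P_c = P - P.mean(axis=0)
Q_c = Q - Q.mean(axis=0)

# Cross-covariance as a single matrix product
C = P_c.T @ Q_c

# The same matrix as an explicit sum of per-point outer products
C_outer = sum(np.outer(p, q) for p, q in zip(P_c, Q_c))

print(C.shape)                  # (3, 3), independent of N
print(np.allclose(C, C_outer))  # True
```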
<p>With $C$ in hand, we want to compute its Singular Value Decomposition (SVD):</p>
<p>$$
C = U \Sigma V^T
$$</p>
<p>The SVD is the most involved step of the algorithm, scaling cubically with $D$ (i.e., $O(D^3)$).
However, since we&rsquo;re often interested in cases where $D$ is small (e.g., 2D or 3D points), this is manageable.</p>
<p>Next, we check for improper rotations (i.e., reflections) and correct for them where necessary:</p>
<p>$$
d = \text{sign}(\det(V U^T))
$$</p>
<p>If $d = -1$, we need to flip the last column of $V$ in the final rotation matrix.</p>
<p>Let $B = \text{diag}(1, 1, d)$ (for $D = 3$; in general, the identity matrix with its last diagonal entry replaced by $d$).
The optimal rotation matrix is then:</p>
<p>$$
R^\ast = V B U^T
$$</p>
<h3 id="summary">Summary</h3>
<p>In a nutshell, the Kabsch algorithm boils down to:</p>
<ol>
<li>Compute centroids of $P$ and $Q$ ($\bar{\mathbf{p}}$ and $\bar{\mathbf{q}}$)</li>
<li>Center both point sets by subtracting centroids: $P^\prime$ and $Q^\prime$</li>
<li>Compute cross-covariance matrix $C = P^{\prime T} Q^\prime$</li>
<li>Compute SVD: $C = U \Sigma V^T$ (<em>expensive step</em>)</li>
<li>Compute $d = \text{sign}(\det(V U^T))$ and $B = \text{diag}(1, 1, d)$</li>
<li>Optimal rotation: $R^\ast = V B U^T$</li>
<li>Optimal translation (using the rotation from step 6): $\mathbf{t}^\ast = \bar{\mathbf{q}} - R^\ast\bar{\mathbf{p}}$</li>
</ol>
<p>The resulting root-mean-square deviation (RMSD) between aligned point sets is</p>
<p>$$
\text{RMSD} = \sqrt{\frac{1}{N} \sum_{i=1}^N \lVert \mathbf{q}_i - (R^\ast\mathbf{p}_i + \mathbf{t}^\ast) \rVert^2}
$$</p>















<figure class="post-figure center ">
    <img src="/img/scientific-computing/kabsch-algorithm-visualized-rmsd.webp"
         alt="Diagram illustrating Root Mean Square Deviation (RMSD) distances"
         title="Diagram illustrating Root Mean Square Deviation (RMSD) distances"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">RMSD measures the average distance between the aligned points.</figcaption>
    
</figure>

<p>which is frequently used as a measure of similarity between molecular structures or as a metric in loss functions for ML applications.</p>
<h3 id="the-kabsch-umeyama-algorithm-scaling">The Kabsch-Umeyama Algorithm (Scaling)</h3>
<p>While the standard Kabsch algorithm solves for optimal rotation and translation, the <strong>Kabsch-Umeyama algorithm</strong> extends this by also finding an optimal <strong>scaling factor</strong> $c$. This is essential when aligning structures of different scales, such as a 3D scan versus a ground truth model.</p>
<p><em>(Note: This is sometimes searched for as the &ldquo;Absch-Umeyama algorithm&rdquo; due to typos, but the correct attribution is to Shinji Umeyama based on Wolfgang Kabsch&rsquo;s work.)</em></p>
<p>The method estimates the transformation $\mathbf{q}_i \approx c R \mathbf{p}_i + \mathbf{t}$. The optimal scale is the sum of the (reflection-corrected) singular values of the cross-covariance, $\text{tr}(\Sigma B)$, divided by the total squared deviation of the source points about their centroid, $\sum_i \lVert \mathbf{p}_i^\prime \rVert^2$. If both quantities are normalized by $N$, the denominator is the variance of the source points. See the <a href="/notes/biology/computational-biology/umeyama-similarity-transformation/">Umeyama paper notes</a> for the full derivation.</p>
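<p>Here is a minimal NumPy sketch of the scaled variant, under the same unweighted setup as the Kabsch math above. The function name <code>umeyama_numpy</code> and the synthetic test transform are my own illustrative assumptions; this is not the Kabsch-Horn Cookbook API, and it makes no attempt to handle degenerate inputs.</p>

```python
import numpy as np

def umeyama_numpy(P, Q):
    """Estimate scale c, rotation R, and translation t so that Q ~ c * P @ R.T + t."""
    centroid_P = P.mean(axis=0)
    centroid_Q = Q.mean(axis=0)
    p = P - centroid_P
    q = Q - centroid_Q

    # Cross-covariance and its SVD, as in the standard Kabsch algorithm
    H = p.T @ q
    U, S, Vt = np.linalg.svd(H)

    # Reflection correction: d is -1 only if the unconstrained solution is improper
    d = 1.0 if np.linalg.det(Vt.T @ U.T) >= 0 else -1.0
    b = np.ones(P.shape[1])
    b[-1] = d
    R = Vt.T @ np.diag(b) @ U.T

    # Scale: corrected singular values over total squared deviation of the source
    c = np.sum(S * b) / np.sum(p ** 2)
    t = centroid_Q - c * R @ centroid_P
    return c, R, t

# Synthetic check with a known similarity transform (all values are made up)
rng = np.random.default_rng(0)
P = rng.normal(size=(20, 3))
theta = 0.7
R_true = np.array([[np.cos(theta), -np.sin(theta), 0.0],
                   [np.sin(theta), np.cos(theta), 0.0],
                   [0.0, 0.0, 1.0]])
t_true = np.array([1.0, -2.0, 0.5])
Q = 2.5 * P @ R_true.T + t_true

c, R, t = umeyama_numpy(P, Q)
print(np.isclose(c, 2.5), np.allclose(R, R_true), np.allclose(t, t_true))
```

<p>Setting $c = 1$ recovers the plain Kabsch solution, so the only extra work is the final ratio.</p>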
<p><strong>A Note on SVD and Automatic Differentiation</strong></p>
<p>While modern frameworks allow us to backpropagate through the Singular Value Decomposition (SVD), it comes with a known stability issue: if the cross-covariance matrix has identical (degenerate) singular values (which can occur if the point clouds are perfectly aligned or have certain symmetries), the gradient of the SVD approaches infinity, causing <code>NaN</code> values during backpropagation. If you plan to use this algorithm as a loss function for a neural network, it is often necessary to add a tiny epsilon to the matrix before computing the SVD, or to utilize an SVD gradient patch. The <a href="/projects/kabsch-horn-cookbook/">Kabsch-Horn Cookbook</a> library provides a SafeSVD primitive that floors the singular-value-gap denominator at machine epsilon in the backward pass, producing finite gradients at degenerate inputs across PyTorch, JAX, TensorFlow, and MLX.</p>
<h2 id="implementation">Implementation</h2>
<p>Let&rsquo;s implement the algorithm in different frameworks. Note that for simplicity, the following implementations cover the <strong>unweighted</strong> Kabsch algorithm. If your application (like molecular dynamics) requires weights (e.g., atomic masses), the <a href="/projects/kabsch-horn-cookbook/">Kabsch-Horn Cookbook</a> library provides per-point weighted alignment out of the box.</p>
<h3 id="numpy">NumPy</h3>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> numpy <span style="color:#66d9ef">as</span> np
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">kabsch_numpy</span>(P, Q):
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    Computes the optimal rotation and translation to align two sets of points (P -&gt; Q),
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    and their RMSD.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    :param P: A Nx3 matrix of points
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    :param Q: A Nx3 matrix of points
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    :return: A tuple containing the optimal rotation matrix, the optimal
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">             translation vector, and the RMSD.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    &#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">assert</span> P<span style="color:#f92672">.</span>shape <span style="color:#f92672">==</span> Q<span style="color:#f92672">.</span>shape, <span style="color:#e6db74">&#34;Matrix dimensions must match&#34;</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Compute centroids</span>
</span></span><span style="display:flex;"><span>    centroid_P <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>mean(P, axis<span style="color:#f92672">=</span><span style="color:#ae81ff">0</span>)
</span></span><span style="display:flex;"><span>    centroid_Q <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>mean(Q, axis<span style="color:#f92672">=</span><span style="color:#ae81ff">0</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Center the points</span>
</span></span><span style="display:flex;"><span>    p <span style="color:#f92672">=</span> P <span style="color:#f92672">-</span> centroid_P
</span></span><span style="display:flex;"><span>    q <span style="color:#f92672">=</span> Q <span style="color:#f92672">-</span> centroid_Q
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Compute the covariance matrix</span>
</span></span><span style="display:flex;"><span>    H <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>dot(p<span style="color:#f92672">.</span>T, q)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># SVD</span>
</span></span><span style="display:flex;"><span>    U, S, Vt <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>linalg<span style="color:#f92672">.</span>svd(H)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Ensure a proper rotation (det = +1): flip the last row of Vt if V U^T is a reflection</span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">if</span> np<span style="color:#f92672">.</span>linalg<span style="color:#f92672">.</span>det(np<span style="color:#f92672">.</span>dot(Vt<span style="color:#f92672">.</span>T, U<span style="color:#f92672">.</span>T)) <span style="color:#f92672">&lt;</span> <span style="color:#ae81ff">0.0</span>:
</span></span><span style="display:flex;"><span>        Vt[<span style="color:#f92672">-</span><span style="color:#ae81ff">1</span>, :] <span style="color:#f92672">*=</span> <span style="color:#f92672">-</span><span style="color:#ae81ff">1.0</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Optimal rotation</span>
</span></span><span style="display:flex;"><span>    R <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>dot(Vt<span style="color:#f92672">.</span>T, U<span style="color:#f92672">.</span>T)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Optimal translation (depends on R, so computed after it)</span>
</span></span><span style="display:flex;"><span>    t <span style="color:#f92672">=</span> centroid_Q <span style="color:#f92672">-</span> np<span style="color:#f92672">.</span>dot(R, centroid_P)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># RMSD</span>
</span></span><span style="display:flex;"><span>    rmsd <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>sqrt(np<span style="color:#f92672">.</span>sum(np<span style="color:#f92672">.</span>square(np<span style="color:#f92672">.</span>dot(p, R<span style="color:#f92672">.</span>T) <span style="color:#f92672">-</span> q)) <span style="color:#f92672">/</span> P<span style="color:#f92672">.</span>shape[<span style="color:#ae81ff">0</span>])
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> R, t, rmsd
</span></span></code></pre></div><p>Here&rsquo;s a quick test to verify correctness:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">test_numpy</span>():
</span></span><span style="display:flex;"><span>    np<span style="color:#f92672">.</span>random<span style="color:#f92672">.</span>seed(<span style="color:#ae81ff">12345</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    P <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>random<span style="color:#f92672">.</span>randn(<span style="color:#ae81ff">100</span>, <span style="color:#ae81ff">3</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    alpha <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>random<span style="color:#f92672">.</span>rand() <span style="color:#f92672">*</span> <span style="color:#ae81ff">2</span> <span style="color:#f92672">*</span> np<span style="color:#f92672">.</span>pi
</span></span><span style="display:flex;"><span>    R <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>array([[np<span style="color:#f92672">.</span>cos(alpha), <span style="color:#f92672">-</span>np<span style="color:#f92672">.</span>sin(alpha), <span style="color:#ae81ff">0</span>],
</span></span><span style="display:flex;"><span>                    [np<span style="color:#f92672">.</span>sin(alpha), np<span style="color:#f92672">.</span>cos(alpha), <span style="color:#ae81ff">0</span>],
</span></span><span style="display:flex;"><span>                    [<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">1</span>]])
</span></span><span style="display:flex;"><span>    t <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>random<span style="color:#f92672">.</span>randn(<span style="color:#ae81ff">3</span>) <span style="color:#f92672">*</span> <span style="color:#ae81ff">10</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    Q <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>dot(P, R<span style="color:#f92672">.</span>T) <span style="color:#f92672">+</span> t
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    R_opt, t_opt, rmsd <span style="color:#f92672">=</span> kabsch_numpy(P, Q)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    print(<span style="color:#e6db74">&#39;RMSD: </span><span style="color:#e6db74">{}</span><span style="color:#e6db74">&#39;</span><span style="color:#f92672">.</span>format(rmsd))
</span></span><span style="display:flex;"><span>    print(<span style="color:#e6db74">&#39;R:</span><span style="color:#ae81ff">\n</span><span style="color:#e6db74">{}</span><span style="color:#e6db74">&#39;</span><span style="color:#f92672">.</span>format(R))
</span></span><span style="display:flex;"><span>    print(<span style="color:#e6db74">&#39;R_opt:</span><span style="color:#ae81ff">\n</span><span style="color:#e6db74">{}</span><span style="color:#e6db74">&#39;</span><span style="color:#f92672">.</span>format(R_opt))
</span></span><span style="display:flex;"><span>    print(<span style="color:#e6db74">&#39;t:</span><span style="color:#ae81ff">\n</span><span style="color:#e6db74">{}</span><span style="color:#e6db74">&#39;</span><span style="color:#f92672">.</span>format(t))
</span></span><span style="display:flex;"><span>    print(<span style="color:#e6db74">&#39;t_opt:</span><span style="color:#ae81ff">\n</span><span style="color:#e6db74">{}</span><span style="color:#e6db74">&#39;</span><span style="color:#f92672">.</span>format(t_opt))
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    l2_t <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>linalg<span style="color:#f92672">.</span>norm(t <span style="color:#f92672">-</span> t_opt)
</span></span><span style="display:flex;"><span>    l2_R <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>linalg<span style="color:#f92672">.</span>norm(R <span style="color:#f92672">-</span> R_opt)
</span></span><span style="display:flex;"><span>    print(<span style="color:#e6db74">&#39;l2_t: </span><span style="color:#e6db74">{}</span><span style="color:#e6db74">&#39;</span><span style="color:#f92672">.</span>format(l2_t))
</span></span><span style="display:flex;"><span>    print(<span style="color:#e6db74">&#39;l2_R: </span><span style="color:#e6db74">{}</span><span style="color:#e6db74">&#39;</span><span style="color:#f92672">.</span>format(l2_R))
</span></span></code></pre></div><p>Running this test shows the algorithm correctly recovers the rotation and translation:</p>
<pre><code>RMSD: 3.2111501877699246e-15
R:
[[-0.8475392 -0.5307328  0.       ]
 [ 0.5307328 -0.8475392  0.       ]
 [ 0.         0.         1.       ]]
R_opt:
[[-8.47539198e-01 -5.30732803e-01 -2.95434260e-16]
 [ 5.30732803e-01 -8.47539198e-01  2.92859649e-16]
 [ 0.00000000e+00 -2.77555756e-16  1.00000000e+00]]
t:
[ 5.99726796  1.50078468 -3.34633977]
t_opt:
[ 5.99726796  1.50078468 -3.34633977]
l2_t: 2.7012892057857038e-15
l2_R: 8.028174304721057e-16
</code></pre>
<p>Both the rotation and the translation are recovered to within floating-point precision (the residuals <code>l2_t</code> and <code>l2_R</code> are on the order of <code>1e-15</code>).</p>
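<p>The test above uses a pure rotation, so the determinant check is never exercised. As a complementary sketch, the snippet below passes a mirrored copy of the points to a compact restatement of the same steps and checks that the result is still a proper rotation. The helper name <code>kabsch_sketch</code> and the choice of mirror plane are illustrative assumptions, not part of the implementation above.</p>

```python
import numpy as np


def kabsch_sketch(P, Q):
    """Compact restatement of the steps above (illustrative helper name)."""
    centroid_P = P.mean(axis=0)
    centroid_Q = Q.mean(axis=0)
    p = P - centroid_P
    q = Q - centroid_Q

    # Covariance matrix and its SVD
    U, S, Vt = np.linalg.svd(p.T @ q)

    # Flip the last row of Vt if V U^T would be a reflection
    if np.linalg.det(Vt.T @ U.T) < 0.0:
        Vt[-1, :] *= -1.0

    R = Vt.T @ U.T
    t = centroid_Q - R @ centroid_P
    rmsd = np.sqrt(np.sum((p @ R.T - q) ** 2) / P.shape[0])
    return R, t, rmsd


rng = np.random.default_rng(0)
P = rng.standard_normal((50, 3))

# Assumed test case: mirror the points through the xy-plane (z -> -z)
mirror = np.diag([1.0, 1.0, -1.0])
Q = P @ mirror.T

R, t, rmsd = kabsch_sketch(P, Q)

# R stays a proper rotation even though Q is a mirror image of P
print("det(R):", np.linalg.det(R))
print("orthogonal:", np.allclose(R @ R.T, np.eye(3)))
# The RMSD is non-zero: no rotation can undo a reflection of generic points
print("RMSD:", rmsd)
```

<p>In this setup, skipping the determinant check would return the mirror itself with an RMSD of zero, which is an orthogonal matrix but not a rotation.</p>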
<p>For batch processing, the same steps run over a leading batch dimension:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">kabsch_numpy_batched</span>(P, Q):
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    Computes the optimal rotation and translation to align two sets of points (P -&gt; Q),
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    and their RMSD.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    :param P: A BxNx3 matrix of points
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    :param Q: A BxNx3 matrix of points
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    :return: A tuple containing the optimal rotation matrices (Bx3x3), the optimal
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">             translation vectors (Bx3), and the per-batch RMSDs (B).
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    &#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">assert</span> P<span style="color:#f92672">.</span>shape <span style="color:#f92672">==</span> Q<span style="color:#f92672">.</span>shape, <span style="color:#e6db74">&#34;Matrix dimensions must match&#34;</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Compute centroids</span>
</span></span><span style="display:flex;"><span>    centroid_P <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>mean(P, axis<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>, keepdims<span style="color:#f92672">=</span><span style="color:#66d9ef">True</span>)  <span style="color:#75715e"># Bx1x3</span>
</span></span><span style="display:flex;"><span>    centroid_Q <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>mean(Q, axis<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>, keepdims<span style="color:#f92672">=</span><span style="color:#66d9ef">True</span>)  <span style="color:#75715e"># Bx1x3</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Center the points</span>
</span></span><span style="display:flex;"><span>    p <span style="color:#f92672">=</span> P <span style="color:#f92672">-</span> centroid_P  <span style="color:#75715e"># BxNx3</span>
</span></span><span style="display:flex;"><span>    q <span style="color:#f92672">=</span> Q <span style="color:#f92672">-</span> centroid_Q  <span style="color:#75715e"># BxNx3</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Compute the covariance matrix</span>
</span></span><span style="display:flex;"><span>    H <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>matmul(p<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">2</span>, <span style="color:#ae81ff">1</span>), q)  <span style="color:#75715e"># Bx3x3</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># SVD</span>
</span></span><span style="display:flex;"><span>    U, S, Vt <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>linalg<span style="color:#f92672">.</span>svd(H)  <span style="color:#75715e"># Bx3x3</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Ensure proper rotations (det = +1): flip the last row of Vt where V U^T is a reflection</span>
</span></span><span style="display:flex;"><span>    d <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>linalg<span style="color:#f92672">.</span>det(np<span style="color:#f92672">.</span>matmul(Vt<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">2</span>, <span style="color:#ae81ff">1</span>), U<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">2</span>, <span style="color:#ae81ff">1</span>)))
</span></span><span style="display:flex;"><span>    flip <span style="color:#f92672">=</span> d <span style="color:#f92672">&lt;</span> <span style="color:#ae81ff">0.0</span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">if</span> flip<span style="color:#f92672">.</span>any():
</span></span><span style="display:flex;"><span>        Vt[flip, <span style="color:#f92672">-</span><span style="color:#ae81ff">1</span>, :] <span style="color:#f92672">*=</span> <span style="color:#f92672">-</span><span style="color:#ae81ff">1.0</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Optimal rotation</span>
</span></span><span style="display:flex;"><span>    R <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>matmul(Vt<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">2</span>, <span style="color:#ae81ff">1</span>), U<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">2</span>, <span style="color:#ae81ff">1</span>))  <span style="color:#75715e"># Bx3x3</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Optimal translation (depends on R, so computed after it)</span>
</span></span><span style="display:flex;"><span>    t <span style="color:#f92672">=</span> centroid_Q<span style="color:#f92672">.</span>squeeze(<span style="color:#ae81ff">1</span>) <span style="color:#f92672">-</span> np<span style="color:#f92672">.</span>matmul(centroid_P, R<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">2</span>, <span style="color:#ae81ff">1</span>))<span style="color:#f92672">.</span>squeeze(<span style="color:#ae81ff">1</span>)  <span style="color:#75715e"># Bx3</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># RMSD</span>
</span></span><span style="display:flex;"><span>    rmsd <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>sqrt(np<span style="color:#f92672">.</span>sum(np<span style="color:#f92672">.</span>square(np<span style="color:#f92672">.</span>matmul(p, R<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">2</span>, <span style="color:#ae81ff">1</span>)) <span style="color:#f92672">-</span> q), axis<span style="color:#f92672">=</span>(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">2</span>)) <span style="color:#f92672">/</span> P<span style="color:#f92672">.</span>shape[<span style="color:#ae81ff">1</span>])
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> R, t, rmsd
</span></span></code></pre></div><h3 id="pytorch">PyTorch</h3>


<p><details >
  <summary markdown="span">📝 Important Update (February 15, 2026)</summary>
  <strong>Bug Fix Notice:</strong> The PyTorch implementation has been updated to use the &ldquo;B-matrix&rdquo; broadcasting approach. This eliminates in-place tensor modification (which breaks <code>autograd</code>) and data-dependent control flow (which breaks <code>torch.compile</code> and <code>torch.vmap</code>).
</details></p>
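<p>As a sketch of what the &ldquo;B-matrix&rdquo; approach computes, write the SVD of the covariance matrix as $H = U S V^\top$ and let $d = \det(V U^\top)$. The conditional row flip from the NumPy version is then equivalent to inserting a diagonal correction matrix:</p>
<p>$$R = V B U^\top, \qquad B = \operatorname{diag}\big(1,\ 1,\ \operatorname{sign}(d)\big)$$</p>
<p>Since $B$ is diagonal, $(V B)_{ij} = V_{ij}\, b_j$: multiplying by $B$ scales the columns of $V$, which is what an elementwise product with a broadcast row vector does. The symbol $b_j$ for the diagonal entries of $B$ is notation introduced here for illustration.</p>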

<p>The PyTorch implementation now uses broadcasting to ensure differentiability:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> torch
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">kabsch_torch</span>(P, Q):
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    Computes the optimal rotation and translation to align two sets of points (P -&gt; Q),
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    and their RMSD.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    :param P: A Nx3 matrix of points
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    :param Q: A Nx3 matrix of points
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    :return: A tuple containing the optimal rotation matrix, the optimal
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">             translation vector, and the RMSD.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    &#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">assert</span> P<span style="color:#f92672">.</span>shape <span style="color:#f92672">==</span> Q<span style="color:#f92672">.</span>shape, <span style="color:#e6db74">&#34;Matrix dimensions must match&#34;</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Compute centroids</span>
</span></span><span style="display:flex;"><span>    centroid_P <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>mean(P, dim<span style="color:#f92672">=</span><span style="color:#ae81ff">0</span>)
</span></span><span style="display:flex;"><span>    centroid_Q <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>mean(Q, dim<span style="color:#f92672">=</span><span style="color:#ae81ff">0</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Center the points</span>
</span></span><span style="display:flex;"><span>    p <span style="color:#f92672">=</span> P <span style="color:#f92672">-</span> centroid_P
</span></span><span style="display:flex;"><span>    q <span style="color:#f92672">=</span> Q <span style="color:#f92672">-</span> centroid_Q
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Compute the covariance matrix</span>
</span></span><span style="display:flex;"><span>    H <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>matmul(p<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">1</span>), q)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># SVD</span>
</span></span><span style="display:flex;"><span>    U, S, Vt <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>linalg<span style="color:#f92672">.</span>svd(H)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># 1. Calculate determinant</span>
</span></span><span style="display:flex;"><span>    d <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>det(torch<span style="color:#f92672">.</span>matmul(Vt<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">1</span>), U<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">1</span>)))
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># 2. Build diagonal B tensor without in-place mutation</span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># torch.stack builds the vector out-of-place, so the autograd graph stays intact</span>
</span></span><span style="display:flex;"><span>    B_diag <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>stack([torch<span style="color:#f92672">.</span>tensor(<span style="color:#ae81ff">1.0</span>, device<span style="color:#f92672">=</span>d<span style="color:#f92672">.</span>device, dtype<span style="color:#f92672">=</span>d<span style="color:#f92672">.</span>dtype),
</span></span><span style="display:flex;"><span>                          torch<span style="color:#f92672">.</span>tensor(<span style="color:#ae81ff">1.0</span>, device<span style="color:#f92672">=</span>d<span style="color:#f92672">.</span>device, dtype<span style="color:#f92672">=</span>d<span style="color:#f92672">.</span>dtype),
</span></span><span style="display:flex;"><span>                          torch<span style="color:#f92672">.</span>sign(d)])
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># 3. Scale columns of Vt.T via broadcasting, then multiply by U^T</span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Vt.T: (3, 3). B_diag: (3) -&gt; B_diag[None, :]: (1, 3)</span>
</span></span><span style="display:flex;"><span>    R <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>matmul(Vt<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">1</span>) <span style="color:#f92672">*</span> B_diag[<span style="color:#66d9ef">None</span>, :], U<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">1</span>))
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Optimal translation (depends on R, so computed after it)</span>
</span></span><span style="display:flex;"><span>    t <span style="color:#f92672">=</span> centroid_Q <span style="color:#f92672">-</span> centroid_P <span style="color:#f92672">@</span> R<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># RMSD</span>
</span></span><span style="display:flex;"><span>    rmsd <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>sqrt(torch<span style="color:#f92672">.</span>sum(torch<span style="color:#f92672">.</span>square(torch<span style="color:#f92672">.</span>matmul(p, R<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">1</span>)) <span style="color:#f92672">-</span> q)) <span style="color:#f92672">/</span> P<span style="color:#f92672">.</span>shape[<span style="color:#ae81ff">0</span>])
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> R, t, rmsd
</span></span></code></pre></div><p>And our batched version:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">kabsch_torch_batched</span>(P, Q):
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    Computes the optimal rotation and translation to align two sets of points (P -&gt; Q),
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    and their RMSD, in a batched manner.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    :param P: A BxNx3 matrix of points
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    :param Q: A BxNx3 matrix of points
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    :return: A tuple containing the optimal rotation matrices (Bx3x3), the optimal
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">             translation vectors (Bx3), and the per-batch RMSDs (B).
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    &#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">assert</span> P<span style="color:#f92672">.</span>shape <span style="color:#f92672">==</span> Q<span style="color:#f92672">.</span>shape, <span style="color:#e6db74">&#34;Matrix dimensions must match&#34;</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Compute centroids</span>
</span></span><span style="display:flex;"><span>    centroid_P <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>mean(P, dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>, keepdim<span style="color:#f92672">=</span><span style="color:#66d9ef">True</span>)  <span style="color:#75715e"># Bx1x3</span>
</span></span><span style="display:flex;"><span>    centroid_Q <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>mean(Q, dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>, keepdim<span style="color:#f92672">=</span><span style="color:#66d9ef">True</span>)  <span style="color:#75715e"># Bx1x3</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Center the points</span>
</span></span><span style="display:flex;"><span>    p <span style="color:#f92672">=</span> P <span style="color:#f92672">-</span> centroid_P  <span style="color:#75715e"># BxNx3</span>
</span></span><span style="display:flex;"><span>    q <span style="color:#f92672">=</span> Q <span style="color:#f92672">-</span> centroid_Q  <span style="color:#75715e"># BxNx3</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Compute the covariance matrix</span>
</span></span><span style="display:flex;"><span>    H <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>matmul(p<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">2</span>), q)  <span style="color:#75715e"># Bx3x3</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># SVD</span>
</span></span><span style="display:flex;"><span>    U, S, Vt <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>linalg<span style="color:#f92672">.</span>svd(H)  <span style="color:#75715e"># Bx3x3</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># 1. Calculate batched determinant</span>
</span></span><span style="display:flex;"><span>    d <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>det(torch<span style="color:#f92672">.</span>matmul(Vt<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">2</span>), U<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">2</span>)))  <span style="color:#75715e"># B</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># 2. Build batched B_diag without in-place mutation or control flow</span>
</span></span><span style="display:flex;"><span>    ones <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>ones_like(d)
</span></span><span style="display:flex;"><span>    B_diag <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>stack([ones, ones, torch<span style="color:#f92672">.</span>sign(d)], dim<span style="color:#f92672">=-</span><span style="color:#ae81ff">1</span>) <span style="color:#75715e"># Bx3</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># 3. Scale columns of Vt.T and multiply</span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Vt.T: (B, 3, 3). B_diag: (B, 3). B_diag[:, None, :]: (B, 1, 3).</span>
</span></span><span style="display:flex;"><span>    R <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>matmul(Vt<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">2</span>) <span style="color:#f92672">*</span> B_diag[:, <span style="color:#66d9ef">None</span>, :], U<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">2</span>))
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Optimal translation (depends on R, so computed after it)</span>
</span></span><span style="display:flex;"><span>    t <span style="color:#f92672">=</span> centroid_Q<span style="color:#f92672">.</span>squeeze(<span style="color:#ae81ff">1</span>) <span style="color:#f92672">-</span> torch<span style="color:#f92672">.</span>matmul(centroid_P, R<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">2</span>))<span style="color:#f92672">.</span>squeeze(<span style="color:#ae81ff">1</span>)  <span style="color:#75715e"># Bx3</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># RMSD</span>
</span></span><span style="display:flex;"><span>    rmsd <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>sqrt(torch<span style="color:#f92672">.</span>sum(torch<span style="color:#f92672">.</span>square(torch<span style="color:#f92672">.</span>matmul(p, R<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">2</span>)) <span style="color:#f92672">-</span> q), dim<span style="color:#f92672">=</span>(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">2</span>)) <span style="color:#f92672">/</span> P<span style="color:#f92672">.</span>shape[<span style="color:#ae81ff">1</span>])
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> R, t, rmsd
</span></span></code></pre></div><h3 id="tensorflow">TensorFlow</h3>
<p>TensorFlow's <code>tf.linalg.svd</code> returns <code>S</code>, <code>U</code>, and <code>V</code> in that order, and the third output is $V$ itself, not $V^T$. Because tensors are immutable and the function may be traced into a graph (e.g., via <code>@tf.function</code>), we avoid explicit conditional branching: the diagonal of the correction matrix $B$ is built as a vector and applied by broadcasting.</p>
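<p>As a small aside, the broadcasting trick can be checked in isolation. The NumPy sketch below uses random stand-in matrices for $U$ and $V$, and <code>b</code> is an assumed name for the diagonal of $B$. It suggests that scaling the columns of $V$ by a vector gives the same result as multiplying by the full diagonal matrix, which is why no branch or in-place update is needed.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
V = rng.normal(size=(3, 3))  # stand-in for V from the SVD
U = rng.normal(size=(3, 3))  # stand-in for U from the SVD
b = np.array([1.0, 1.0, -1.0])  # diagonal of B when a reflection is detected

# Explicit form: R = V @ diag(b) @ U^T
R_explicit = V @ np.diag(b) @ U.T

# Broadcast form: b[None, :] has shape (1, 3), so column j of V is scaled by b[j]
R_broadcast = (V * b[None, :]) @ U.T

assert np.allclose(R_explicit, R_broadcast)
```

<p>The TensorFlow code that follows uses the same pattern, with <code>tf.sign(d)</code> supplying the last entry of the vector.</p>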
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> tensorflow <span style="color:#66d9ef">as</span> tf
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">kabsch_tensorflow</span>(P, Q):
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    Computes the optimal rotation and translation to align two sets of points (P -&gt; Q),
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    and their RMSD.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    :param P: A Nx3 matrix of points
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    :param Q: A Nx3 matrix of points
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    :return: A tuple containing the optimal rotation matrix, the optimal
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">             translation vector, and the RMSD.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    &#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>    P <span style="color:#f92672">=</span> tf<span style="color:#f92672">.</span>convert_to_tensor(P, dtype<span style="color:#f92672">=</span>tf<span style="color:#f92672">.</span>float32)
</span></span><span style="display:flex;"><span>    Q <span style="color:#f92672">=</span> tf<span style="color:#f92672">.</span>convert_to_tensor(Q, dtype<span style="color:#f92672">=</span>tf<span style="color:#f92672">.</span>float32)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">assert</span> P<span style="color:#f92672">.</span>shape <span style="color:#f92672">==</span> Q<span style="color:#f92672">.</span>shape, <span style="color:#e6db74">&#34;Matrix dimensions must match&#34;</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Compute centroids</span>
</span></span><span style="display:flex;"><span>    centroid_P <span style="color:#f92672">=</span> tf<span style="color:#f92672">.</span>reduce_mean(P, axis<span style="color:#f92672">=</span><span style="color:#ae81ff">0</span>)
</span></span><span style="display:flex;"><span>    centroid_Q <span style="color:#f92672">=</span> tf<span style="color:#f92672">.</span>reduce_mean(Q, axis<span style="color:#f92672">=</span><span style="color:#ae81ff">0</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Center the points</span>
</span></span><span style="display:flex;"><span>    p <span style="color:#f92672">=</span> P <span style="color:#f92672">-</span> centroid_P
</span></span><span style="display:flex;"><span>    q <span style="color:#f92672">=</span> Q <span style="color:#f92672">-</span> centroid_Q
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Compute the covariance matrix</span>
</span></span><span style="display:flex;"><span>    H <span style="color:#f92672">=</span> tf<span style="color:#f92672">.</span>matmul(tf<span style="color:#f92672">.</span>transpose(p), q)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># SVD</span>
</span></span><span style="display:flex;"><span>    S, U, V <span style="color:#f92672">=</span> tf<span style="color:#f92672">.</span>linalg<span style="color:#f92672">.</span>svd(H)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># 1. Calculate determinant</span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Note: V in TF SVD is V, not V^T.</span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># R = V * U^T. Det(R) = Det(V * U^T)</span>
</span></span><span style="display:flex;"><span>    d <span style="color:#f92672">=</span> tf<span style="color:#f92672">.</span>linalg<span style="color:#f92672">.</span>det(tf<span style="color:#f92672">.</span>matmul(V, tf<span style="color:#f92672">.</span>transpose(U)))
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># 2. Build diagonal B tensor: [1.0, 1.0, sign(d)]</span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Use static shape 3 if possible, or infer from D. Assuming D=3 here.</span>
</span></span><span style="display:flex;"><span>    B_diag <span style="color:#f92672">=</span> tf<span style="color:#f92672">.</span>stack([<span style="color:#ae81ff">1.0</span>, <span style="color:#ae81ff">1.0</span>, tf<span style="color:#f92672">.</span>sign(d)])
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># 3. Scale columns of V via broadcasting (V * B_diag), then multiply by U^T</span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># V is DxD, B_diag is D. V * B_diag[None, :] multiplies each column j by B_diag[j]</span>
</span></span><span style="display:flex;"><span>    R <span style="color:#f92672">=</span> tf<span style="color:#f92672">.</span>matmul(V <span style="color:#f92672">*</span> B_diag[<span style="color:#66d9ef">None</span>, :], tf<span style="color:#f92672">.</span>transpose(U))
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Optimal translation (depends on R, so computed after it)</span>
</span></span><span style="display:flex;"><span>    t <span style="color:#f92672">=</span> centroid_Q <span style="color:#f92672">-</span> tf<span style="color:#f92672">.</span>linalg<span style="color:#f92672">.</span>matvec(R, centroid_P)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># RMSD</span>
</span></span><span style="display:flex;"><span>    rmsd <span style="color:#f92672">=</span> tf<span style="color:#f92672">.</span>sqrt(tf<span style="color:#f92672">.</span>reduce_sum(tf<span style="color:#f92672">.</span>square(tf<span style="color:#f92672">.</span>matmul(p, tf<span style="color:#f92672">.</span>transpose(R)) <span style="color:#f92672">-</span> q)) <span style="color:#f92672">/</span> P<span style="color:#f92672">.</span>shape[<span style="color:#ae81ff">0</span>])
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> R, t, rmsd
</span></span></code></pre></div><p>and a batched version:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">kabsch_tensorflow_batched</span>(P, Q):
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    Computes the optimal rotation and translation to align two sets of points (P -&gt; Q),
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    and their RMSD.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    :param P: A BxNx3 matrix of points
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    :param Q: A BxNx3 matrix of points
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    :return: A tuple containing the optimal rotation matrix, the optimal
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">             translation vector, and the RMSD.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    &#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>    P <span style="color:#f92672">=</span> tf<span style="color:#f92672">.</span>convert_to_tensor(P, dtype<span style="color:#f92672">=</span>tf<span style="color:#f92672">.</span>float32)
</span></span><span style="display:flex;"><span>    Q <span style="color:#f92672">=</span> tf<span style="color:#f92672">.</span>convert_to_tensor(Q, dtype<span style="color:#f92672">=</span>tf<span style="color:#f92672">.</span>float32)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">assert</span> P<span style="color:#f92672">.</span>shape <span style="color:#f92672">==</span> Q<span style="color:#f92672">.</span>shape, <span style="color:#e6db74">&#34;Matrix dimensions must match&#34;</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Compute centroids</span>
</span></span><span style="display:flex;"><span>    centroid_P <span style="color:#f92672">=</span> tf<span style="color:#f92672">.</span>reduce_mean(P, axis<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>, keepdims<span style="color:#f92672">=</span><span style="color:#66d9ef">True</span>)
</span></span><span style="display:flex;"><span>    centroid_Q <span style="color:#f92672">=</span> tf<span style="color:#f92672">.</span>reduce_mean(Q, axis<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>, keepdims<span style="color:#f92672">=</span><span style="color:#66d9ef">True</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Center the points</span>
</span></span><span style="display:flex;"><span>    p <span style="color:#f92672">=</span> P <span style="color:#f92672">-</span> centroid_P
</span></span><span style="display:flex;"><span>    q <span style="color:#f92672">=</span> Q <span style="color:#f92672">-</span> centroid_Q
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Compute the covariance matrix</span>
</span></span><span style="display:flex;"><span>    H <span style="color:#f92672">=</span> tf<span style="color:#f92672">.</span>matmul(tf<span style="color:#f92672">.</span>transpose(p, perm<span style="color:#f92672">=</span>[<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">2</span>, <span style="color:#ae81ff">1</span>]), q)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># SVD</span>
</span></span><span style="display:flex;"><span>    S, U, V <span style="color:#f92672">=</span> tf<span style="color:#f92672">.</span>linalg<span style="color:#f92672">.</span>svd(H)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># 1. Calculate batched determinant</span>
</span></span><span style="display:flex;"><span>    d <span style="color:#f92672">=</span> tf<span style="color:#f92672">.</span>linalg<span style="color:#f92672">.</span>det(tf<span style="color:#f92672">.</span>matmul(V, tf<span style="color:#f92672">.</span>transpose(U, perm<span style="color:#f92672">=</span>[<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">2</span>, <span style="color:#ae81ff">1</span>])))
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># 2. Build batched B_diag: shape (B, 3)</span>
</span></span><span style="display:flex;"><span>    ones <span style="color:#f92672">=</span> tf<span style="color:#f92672">.</span>ones_like(d)
</span></span><span style="display:flex;"><span>    B_diag <span style="color:#f92672">=</span> tf<span style="color:#f92672">.</span>stack([ones, ones, tf<span style="color:#f92672">.</span>sign(d)], axis<span style="color:#f92672">=-</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># 3. Scale columns of V (Broadcasting adds the middle dimension)</span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># V: (B, 3, 3), B_diag: (B, 3) -&gt; B_diag[:, None, :]: (B, 1, 3)</span>
</span></span><span style="display:flex;"><span>    R <span style="color:#f92672">=</span> tf<span style="color:#f92672">.</span>matmul(V <span style="color:#f92672">*</span> B_diag[:, <span style="color:#66d9ef">None</span>, :], tf<span style="color:#f92672">.</span>transpose(U, perm<span style="color:#f92672">=</span>[<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">2</span>, <span style="color:#ae81ff">1</span>]))
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Optimal translation (depends on R, so computed after it)</span>
</span></span><span style="display:flex;"><span>    t <span style="color:#f92672">=</span> tf<span style="color:#f92672">.</span>squeeze(centroid_Q, axis<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>) <span style="color:#f92672">-</span> tf<span style="color:#f92672">.</span>linalg<span style="color:#f92672">.</span>matvec(R, tf<span style="color:#f92672">.</span>squeeze(centroid_P, axis<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>))  <span style="color:#75715e"># Bx3</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># RMSD</span>
</span></span><span style="display:flex;"><span>    rmsd <span style="color:#f92672">=</span> tf<span style="color:#f92672">.</span>sqrt(tf<span style="color:#f92672">.</span>reduce_sum(tf<span style="color:#f92672">.</span>square(tf<span style="color:#f92672">.</span>matmul(p, tf<span style="color:#f92672">.</span>transpose(R, perm<span style="color:#f92672">=</span>[<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">2</span>, <span style="color:#ae81ff">1</span>])) <span style="color:#f92672">-</span> q), axis<span style="color:#f92672">=</span>(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">2</span>)) <span style="color:#f92672">/</span> P<span style="color:#f92672">.</span>shape[<span style="color:#ae81ff">1</span>])
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> R, t, rmsd
</span></span></code></pre></div><h3 id="jax">JAX</h3>
<p>The JAX implementation closely mirrors NumPy, replacing <code>np</code> with <code>jnp</code>. However, we again use the broadcasting $B$-matrix approach, which avoids both in-place assignment (JAX arrays are immutable) and a Python <code>if</code> on the determinant (which fails when the function is traced by <code>jax.jit</code>).</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> jax.numpy <span style="color:#66d9ef">as</span> jnp
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">kabsch_jax</span>(P, Q):
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    Computes the optimal rotation and translation to align two sets of points (P -&gt; Q),
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    and their RMSD.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    :param P: A Nx3 matrix of points
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    :param Q: A Nx3 matrix of points
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    :return: A tuple containing the optimal rotation matrix, the optimal
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">             translation vector, and the RMSD.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    &#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>    P <span style="color:#f92672">=</span> jnp<span style="color:#f92672">.</span>array(P)
</span></span><span style="display:flex;"><span>    Q <span style="color:#f92672">=</span> jnp<span style="color:#f92672">.</span>array(Q)
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">assert</span> P<span style="color:#f92672">.</span>shape <span style="color:#f92672">==</span> Q<span style="color:#f92672">.</span>shape, <span style="color:#e6db74">&#34;Matrix dimensions must match&#34;</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Compute centroids</span>
</span></span><span style="display:flex;"><span>    centroid_P <span style="color:#f92672">=</span> jnp<span style="color:#f92672">.</span>mean(P, axis<span style="color:#f92672">=</span><span style="color:#ae81ff">0</span>)
</span></span><span style="display:flex;"><span>    centroid_Q <span style="color:#f92672">=</span> jnp<span style="color:#f92672">.</span>mean(Q, axis<span style="color:#f92672">=</span><span style="color:#ae81ff">0</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Center the points</span>
</span></span><span style="display:flex;"><span>    p <span style="color:#f92672">=</span> P <span style="color:#f92672">-</span> centroid_P
</span></span><span style="display:flex;"><span>    q <span style="color:#f92672">=</span> Q <span style="color:#f92672">-</span> centroid_Q
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Compute the covariance matrix</span>
</span></span><span style="display:flex;"><span>    H <span style="color:#f92672">=</span> jnp<span style="color:#f92672">.</span>dot(p<span style="color:#f92672">.</span>T, q)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># SVD</span>
</span></span><span style="display:flex;"><span>    U, S, Vt <span style="color:#f92672">=</span> jnp<span style="color:#f92672">.</span>linalg<span style="color:#f92672">.</span>svd(H)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># 1. Calculate determinant</span>
</span></span><span style="display:flex;"><span>    d <span style="color:#f92672">=</span> jnp<span style="color:#f92672">.</span>linalg<span style="color:#f92672">.</span>det(jnp<span style="color:#f92672">.</span>dot(Vt<span style="color:#f92672">.</span>T, U<span style="color:#f92672">.</span>T))
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># 2. Build diagonal B array</span>
</span></span><span style="display:flex;"><span>    B_diag <span style="color:#f92672">=</span> jnp<span style="color:#f92672">.</span>array([<span style="color:#ae81ff">1.0</span>, <span style="color:#ae81ff">1.0</span>, jnp<span style="color:#f92672">.</span>sign(d)])
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># 3. Scale columns of Vt.T and multiply by U.T</span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Vt.T is V.</span>
</span></span><span style="display:flex;"><span>    R <span style="color:#f92672">=</span> jnp<span style="color:#f92672">.</span>dot(Vt<span style="color:#f92672">.</span>T <span style="color:#f92672">*</span> B_diag[<span style="color:#66d9ef">None</span>, :], U<span style="color:#f92672">.</span>T)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Optimal translation (depends on R, so computed after it)</span>
</span></span><span style="display:flex;"><span>    t <span style="color:#f92672">=</span> centroid_Q <span style="color:#f92672">-</span> jnp<span style="color:#f92672">.</span>dot(R, centroid_P)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># RMSD</span>
</span></span><span style="display:flex;"><span>    rmsd <span style="color:#f92672">=</span> jnp<span style="color:#f92672">.</span>sqrt(jnp<span style="color:#f92672">.</span>sum(jnp<span style="color:#f92672">.</span>square(jnp<span style="color:#f92672">.</span>dot(p, R<span style="color:#f92672">.</span>T) <span style="color:#f92672">-</span> q)) <span style="color:#f92672">/</span> P<span style="color:#f92672">.</span>shape[<span style="color:#ae81ff">0</span>])
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> R, t, rmsd
</span></span></code></pre></div><p>and batched:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">kabsch_jax_batched</span>(P, Q):
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    Computes the optimal rotation and translation to align two sets of points (P -&gt; Q),
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    and their RMSD.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    :param P: A BxNx3 matrix of points
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    :param Q: A BxNx3 matrix of points
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    :return: A tuple containing the optimal rotation matrix, the optimal
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">             translation vector, and the RMSD.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    &#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>    P <span style="color:#f92672">=</span> jnp<span style="color:#f92672">.</span>array(P)
</span></span><span style="display:flex;"><span>    Q <span style="color:#f92672">=</span> jnp<span style="color:#f92672">.</span>array(Q)
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">assert</span> P<span style="color:#f92672">.</span>shape <span style="color:#f92672">==</span> Q<span style="color:#f92672">.</span>shape, <span style="color:#e6db74">&#34;Matrix dimensions must match&#34;</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Compute centroids</span>
</span></span><span style="display:flex;"><span>    centroid_P <span style="color:#f92672">=</span> jnp<span style="color:#f92672">.</span>mean(P, axis<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>, keepdims<span style="color:#f92672">=</span><span style="color:#66d9ef">True</span>)  <span style="color:#75715e"># Bx1x3</span>
</span></span><span style="display:flex;"><span>    centroid_Q <span style="color:#f92672">=</span> jnp<span style="color:#f92672">.</span>mean(Q, axis<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>, keepdims<span style="color:#f92672">=</span><span style="color:#66d9ef">True</span>)  <span style="color:#75715e"># Bx1x3</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Center the points</span>
</span></span><span style="display:flex;"><span>    p <span style="color:#f92672">=</span> P <span style="color:#f92672">-</span> centroid_P  <span style="color:#75715e"># BxNx3</span>
</span></span><span style="display:flex;"><span>    q <span style="color:#f92672">=</span> Q <span style="color:#f92672">-</span> centroid_Q  <span style="color:#75715e"># BxNx3</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Compute the covariance matrix</span>
</span></span><span style="display:flex;"><span>    H <span style="color:#f92672">=</span> jnp<span style="color:#f92672">.</span>matmul(p<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">2</span>, <span style="color:#ae81ff">1</span>), q)  <span style="color:#75715e"># Bx3x3</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># SVD</span>
</span></span><span style="display:flex;"><span>    U, S, Vt <span style="color:#f92672">=</span> jnp<span style="color:#f92672">.</span>linalg<span style="color:#f92672">.</span>svd(H)  <span style="color:#75715e"># Bx3x3</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># 1. Calculate batched determinant</span>
</span></span><span style="display:flex;"><span>    d <span style="color:#f92672">=</span> jnp<span style="color:#f92672">.</span>linalg<span style="color:#f92672">.</span>det(jnp<span style="color:#f92672">.</span>matmul(Vt<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">2</span>, <span style="color:#ae81ff">1</span>), U<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">2</span>, <span style="color:#ae81ff">1</span>)))
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># 2. Build batched B_diag</span>
</span></span><span style="display:flex;"><span>    ones <span style="color:#f92672">=</span> jnp<span style="color:#f92672">.</span>ones_like(d)
</span></span><span style="display:flex;"><span>    B_diag <span style="color:#f92672">=</span> jnp<span style="color:#f92672">.</span>stack([ones, ones, jnp<span style="color:#f92672">.</span>sign(d)], axis<span style="color:#f92672">=-</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># 3. Scale columns of Vt.T and multiply by U.T</span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Vt.T: (B, 3, 3). B_diag: (B, 3).</span>
</span></span><span style="display:flex;"><span>    R <span style="color:#f92672">=</span> jnp<span style="color:#f92672">.</span>matmul(Vt<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">2</span>, <span style="color:#ae81ff">1</span>) <span style="color:#f92672">*</span> B_diag[:, <span style="color:#66d9ef">None</span>, :], U<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">2</span>, <span style="color:#ae81ff">1</span>))
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Optimal translation (depends on R, so computed after it)</span>
</span></span><span style="display:flex;"><span>    t <span style="color:#f92672">=</span> centroid_Q<span style="color:#f92672">.</span>squeeze(<span style="color:#ae81ff">1</span>) <span style="color:#f92672">-</span> jnp<span style="color:#f92672">.</span>matmul(centroid_P, R<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">2</span>, <span style="color:#ae81ff">1</span>))<span style="color:#f92672">.</span>squeeze(<span style="color:#ae81ff">1</span>)  <span style="color:#75715e"># Bx3</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># RMSD</span>
</span></span><span style="display:flex;"><span>    rmsd <span style="color:#f92672">=</span> jnp<span style="color:#f92672">.</span>sqrt(jnp<span style="color:#f92672">.</span>sum(jnp<span style="color:#f92672">.</span>square(jnp<span style="color:#f92672">.</span>matmul(p, R<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">2</span>, <span style="color:#ae81ff">1</span>)) <span style="color:#f92672">-</span> q), axis<span style="color:#f92672">=</span>(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">2</span>)) <span style="color:#f92672">/</span> P<span style="color:#f92672">.</span>shape[<span style="color:#ae81ff">1</span>])
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> R, t, rmsd
</span></span></code></pre></div>
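<p>All of the functions above return <code>R</code> and <code>t</code> under the same convention: each point maps as $q_i \approx R p_i + t$, which for row-stacked points reads <code>P @ R.T + t</code>. The following is a minimal NumPy sanity check of that convention, not part of the original implementations. <code>kabsch_reference</code> is a hypothetical helper that mirrors the branch-free logic, and the rotation angle and translation are arbitrary test values.</p>

```python
import numpy as np


def kabsch_reference(P, Q):
    """NumPy mirror of the branch-free functions above (illustrative helper)."""
    centroid_P = P.mean(axis=0)
    centroid_Q = Q.mean(axis=0)
    p = P - centroid_P
    q = Q - centroid_Q
    U, S, Vt = np.linalg.svd(p.T @ q)
    d = np.linalg.det(Vt.T @ U.T)
    B_diag = np.array([1.0, 1.0, np.sign(d)])
    R = (Vt.T * B_diag[None, :]) @ U.T
    t = centroid_Q - R @ centroid_P
    rmsd = np.sqrt(np.sum((p @ R.T - q) ** 2) / P.shape[0])
    return R, t, rmsd


# Build a known proper rotation (about the z-axis) and a translation
theta = 0.7
R_true = np.array([
    [np.cos(theta), -np.sin(theta), 0.0],
    [np.sin(theta), np.cos(theta), 0.0],
    [0.0, 0.0, 1.0],
])
t_true = np.array([1.0, -2.0, 0.5])

rng = np.random.default_rng(0)
P = rng.normal(size=(10, 3))
Q = P @ R_true.T + t_true  # row-vector form of q_i = R p_i + t

R, t, rmsd = kabsch_reference(P, Q)
P_aligned = P @ R.T + t  # apply the recovered transform to P

assert np.allclose(P_aligned, Q)
```

<p>With noise-free input the recovered <code>R</code> and <code>t</code> should match the ones used to build <code>Q</code>, and the RMSD should be zero up to floating-point error. The same check can be pointed at any of the framework-specific functions by converting their outputs to NumPy arrays.</p>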














<figure class="post-figure center ">
    <img src="/img/scientific-computing/kabsch-animated-protein-conformational-alignment-analysis.webp"
         alt="Animation of a protein structure being aligned using the Kabsch algorithm"
         title="Animation of a protein structure being aligned using the Kabsch algorithm"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Real-world application: Aligning protein conformations to analyze structural changes.</figcaption>
    
</figure>

<h2 id="extensions">Extensions</h2>
<p>The Kabsch algorithm has several important extensions that go beyond the formulation covered here:</p>
<ul>
<li><strong>Quaternion Form</strong>: The algorithm can be reformulated using quaternions for better numerical stability, particularly useful in applications requiring high precision.</li>
<li><strong>Iterative Versions</strong>: Variants that are more robust to noise and scale better to large point sets. They can also be advantageous in setups with limited computational resources.</li>
<li><strong>Weighted Kabsch</strong>: Extensions that incorporate point weights (e.g., atomic masses in molecular dynamics). While SciPy provides a <a href="https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.transform.Rotation.align_vectors.html#scipy.spatial.transform.Rotation.align_vectors">weighted version</a>, it lacks batch processing capabilities.</li>
<li><strong>The Umeyama Algorithm</strong>: If your point sets are rotated, translated, and scaled differently, the Umeyama algorithm is the direct extension of Kabsch. It solves the same optimization problem but introduces a scaling factor $c$, finding the optimal alignment for $Q \approx c R P + t$.</li>
</ul>
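<p>For reference, here is a sketch of the closed-form Umeyama solution. The notation is an assumption made for this note and may differ from the symbols used earlier in this post: $\bar{P}$ and $\bar{Q}$ are the centroids, $p_i = P_i - \bar{P}$ and $q_i = Q_i - \bar{Q}$ are the centered points, and $N$ is the number of points.</p>
<p>$$\Sigma = \frac{1}{N}\sum_{i=1}^{N} q_i p_i^\top = U D V^\top, \qquad S = \operatorname{diag}\left(1, 1, \det(U)\det(V)\right)$$</p>
<p>$$R = U S V^\top, \qquad c = \frac{\operatorname{tr}(D S)}{\sigma_P^2}, \qquad \sigma_P^2 = \frac{1}{N}\sum_{i=1}^{N} \lVert p_i \rVert^2, \qquad t = \bar{Q} - c R \bar{P}$$</p>
<p>Fixing $c = 1$ reduces this to the rotation-plus-translation problem that Kabsch solves.</p>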
<p>Several of these extensions are implemented in the <a href="/projects/kabsch-horn-cookbook/">Kabsch-Horn Cookbook</a> library, which provides differentiable Kabsch, Horn, and Umeyama alignment across NumPy, PyTorch, JAX, TensorFlow, and MLX.</p>
<h2 id="further-reading">Further Reading</h2>
<ul>
<li><a href="https://en.wikipedia.org/wiki/Kabsch_algorithm">Wikipedia, Kabsch Algorithm</a></li>
<li><a href="https://zalo.github.io/blog/kabsch/">Zalo on Kabsch</a>: An interactive shape matching demo.</li>
</ul>
<h3 id="original-papers">Original Papers</h3>
<ul>
<li><strong>[Kabsch 1976]</strong> Kabsch, W. (1976). &ldquo;A solution for the best rotation to relate two sets of vectors.&rdquo; <em>Acta Crystallographica Section A</em>, 32(5), 922-923. <a href="https://doi.org/10.1107/S0567739476001873">DOI: 10.1107/S0567739476001873</a>
<em>The original paper: a closed-form, non-iterative optimal-rotation solution derived via Lagrange multipliers and eigendecomposition of $\tilde{R}R$ (the SVD reformulation came later; see Arun et al. 1987).</em> See also: <a href="/notes/biology/computational-biology/kabsch-algorithm/">paper notes</a>.</li>
<li><strong>[Kabsch 1978]</strong> Kabsch, W. (1978). &ldquo;A discussion of the solution for the best rotation to relate two sets of vectors.&rdquo; <em>Acta Crystallographica Section A</em>, 34(5), 827-828. <a href="https://doi.org/10.1107/S0567739478001680">DOI: 10.1107/S0567739478001680</a>
<em>The follow-up paper correcting for improper rotations (reflections).</em></li>
<li><strong>[Arun et al. 1987]</strong> Arun, K. S., Huang, T. S., &amp; Blostein, S. D. (1987). &ldquo;Least-Squares Fitting of Two 3-D Point Sets.&rdquo; <em>IEEE Transactions on Pattern Analysis and Machine Intelligence</em>, PAMI-9(5), 698-700. <a href="https://doi.org/10.1109/TPAMI.1987.4767965">DOI: 10.1109/TPAMI.1987.4767965</a>
<em>The first SVD-based formulation for 3D point set alignment.</em> See also: <a href="/notes/biology/computational-biology/arun-svd-point-fitting/">paper notes</a>.</li>
<li><strong>[Horn et al. 1988]</strong> Horn, B. K. P., Hilden, H. M., &amp; Negahdaripour, S. (1988). &ldquo;Closed-form solution of absolute orientation using orthonormal matrices.&rdquo; <em>Journal of the Optical Society of America A</em>, 5(7), 1127-1135. <a href="https://doi.org/10.1364/JOSAA.5.001127">DOI: 10.1364/JOSAA.5.001127</a>
<em>The matrix square root (polar decomposition) approach to the same problem.</em> See also: <a href="/notes/biology/computational-biology/horn-orthonormal-matrices/">paper notes</a>.</li>
<li><strong>[Horn 1987]</strong> Horn, B. K. P. (1987). &ldquo;Closed-form solution of absolute orientation using unit quaternions.&rdquo; <em>Journal of the Optical Society of America A</em>, 4(4), 629-642. <a href="https://doi.org/10.1364/JOSAA.4.000629">DOI: 10.1364/JOSAA.4.000629</a>
<em>An alternative quaternion-based closed-form solution that also handles scale.</em> See also: <a href="/notes/biology/computational-biology/horn-absolute-orientation/">paper notes</a>.</li>
<li><strong>[Umeyama 1991]</strong> Umeyama, S. (1991). &ldquo;Least-squares estimation of transformation parameters between two point patterns.&rdquo; <em>IEEE Transactions on Pattern Analysis and Machine Intelligence</em>, 13(4), 376-380. <a href="https://doi.org/10.1109/34.88573">DOI: 10.1109/34.88573</a>
<em>The extension of the algorithm to include optimal scaling in addition to rotation and translation.</em> See also: <a href="/notes/biology/computational-biology/umeyama-similarity-transformation/">paper notes</a>.</li>
</ul>
]]></content:encoded></item><item><title>LAMMPS Tutorial: Copper and Platinum Adatom Diffusion</title><link>https://hunterheidenreich.com/posts/adatom-cu-diffusion/</link><pubDate>Wed, 27 Sep 2023 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/adatom-cu-diffusion/</guid><description>LAMMPS tutorial for copper and platinum surface diffusion simulation and ML training data generation. Includes setup, analysis, and Ovito visualization.</description><content:encoded><![CDATA[<h2 id="introduction">Introduction</h2>
<p>Understanding how individual atoms move on crystal surfaces is fundamental to materials science, catalysis, and nanotechnology. This atomic-scale motion, called adatom diffusion, drives processes like thin film growth and surface chemical reactions.</p>
<p>While learning molecular dynamics simulations for my graduate work, I discovered these simulations generate valuable training data for machine learning models. This tutorial walks through simulating copper adatom diffusion on a Cu(100) surface using LAMMPS, building on Eric N. Hahn&rsquo;s excellent <a href="https://www.ericnhahn.com/tutorials/lammps-tutorials/adatom">adatom tutorial</a>.</p>
<p><strong>What you&rsquo;ll learn:</strong></p>
<ul>
<li>Setting up LAMMPS for surface diffusion simulations</li>
<li>Understanding simulation parameters and their impact</li>
<li>Visualizing results with Ovito</li>
<li>Analyzing trajectory data for ML applications</li>
<li>Connecting simulation data to machine learning workflows</li>
</ul>
<p>In this tutorial, we will explore both copper (Cu) and platinum (Pt) to show how atomic properties affect diffusion behavior, generating data for training element-aware ML models.</p>
<h2 id="prerequisites">Prerequisites</h2>
<p>Before starting this tutorial, you&rsquo;ll need:</p>
<ul>
<li><strong>LAMMPS</strong> with EAM potential support (version 2020 or later recommended)</li>
<li><strong>Python 3.x</strong> with matplotlib for analysis scripts</li>
<li><strong>Ovito</strong> (free version) for trajectory visualization</li>
<li><strong>Cu01.eam.alloy</strong> potential file from the <a href="https://www.ctcms.nist.gov/potentials/">NIST repository</a></li>
<li>Basic familiarity with molecular dynamics concepts (atoms, forces, timesteps)</li>
</ul>
<h2 id="understanding-adatoms-and-surface-diffusion">Understanding Adatoms and Surface Diffusion</h2>
<h3 id="what-is-an-adatom">What is an Adatom?</h3>
<p>An <strong>adatom</strong> (adsorbed atom) sits on a crystal surface but isn&rsquo;t incorporated into the bulk structure. Adatoms have fewer bonds than fully coordinated bulk atoms, making them highly mobile and reactive.</p>















<figure class="post-figure center ">
    <img src="/img/posts/crystal-surface.webp"
         alt="Ball model representation of a real (atomically rough) crystal surface with steps, kinks, adatoms, and vacancies in a closely-packed crystalline material. Adsorbed molecules, substitutional and interstitial atoms are also illustrated."
         title="Ball model representation of a real (atomically rough) crystal surface with steps, kinks, adatoms, and vacancies in a closely-packed crystalline material. Adsorbed molecules, substitutional and interstitial atoms are also illustrated."
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Ball model representation of a real (atomically rough) crystal surface with steps, kinks, adatoms, and vacancies in a closely-packed crystalline material. Adsorbed molecules, substitutional and interstitial atoms are also illustrated. (<a href="https://creativecommons.org/licenses/by-sa/4.0/deed.en">CC-BY-SA-4.0: ShutterWaves</a>)</figcaption>
    
</figure>

<h3 id="why-study-adatom-diffusion">Why Study Adatom Diffusion?</h3>
<p>Adatom diffusion is important for several technological processes:</p>
<ul>
<li><strong>Thin film growth</strong>: Adatoms are the building blocks of deposited films</li>
<li><strong>Catalysis</strong>: Many reactions happen at these mobile surface atoms</li>
<li><strong>Corrosion</strong>: How surface atoms move affects material degradation</li>
<li><strong>Self-assembly</strong>: Adatom movement enables formation of ordered structures</li>
</ul>
<p>From a <strong>machine learning perspective</strong>, adatom diffusion is an ideal test case because:</p>
<ul>
<li>Well-understood physics provides ground truth for validation</li>
<li>Small system size enables extensive simulation</li>
<li>Behavior varies significantly with temperature and atomic species</li>
<li>Data can be generated systematically across different conditions</li>
</ul>
<h3 id="why-cu100">Why Cu(100)?</h3>
<p>Cu(100) surfaces are well-studied in literature, making them excellent benchmarks. The face-centered cubic (fcc) structure creates clear diffusion pathways, and copper&rsquo;s moderate binding energy lets us observe diffusion at reasonable temperatures without extreme computational demands.</p>
<h2 id="simulation-overview">Simulation Overview</h2>
<p>Before diving into the code details, let&rsquo;s understand the simulation design:</p>
<h3 id="key-simulation-parameters">Key Simulation Parameters</h3>
<table>
  <thead>
      <tr>
          <th>Parameter</th>
          <th>Value</th>
          <th>Why this choice</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>System size</strong></td>
          <td>$8 \times 8 \times 6$ unit cells</td>
          <td>Large enough to avoid edge effects while keeping simulation time reasonable</td>
      </tr>
      <tr>
          <td><strong>Ensemble</strong></td>
          <td>NVT-like (NVE integration with a velocity-rescaling thermostat on the reservoir layers)</td>
          <td>Appropriate for surface studies where pressure isn&rsquo;t the focus</td>
      </tr>
      <tr>
          <td><strong>Potential</strong></td>
          <td>EAM (Embedded Atom Method)</td>
          <td>Captures metallic bonding better than simple pair potentials</td>
      </tr>
      <tr>
          <td><strong>Time step</strong></td>
          <td>5 fs</td>
          <td>Small enough for numerical stability while allowing reasonable run times</td>
      </tr>
      <tr>
          <td><strong>Duration</strong></td>
          <td>500 ps</td>
          <td>Long enough to see multiple diffusion events</td>
      </tr>
      <tr>
          <td><strong>Temperature</strong></td>
          <td>600 K initial velocities and an 850 K thermostat, both applied to the reservoir layers just above the frozen base</td>
          <td>Drives thermal energy up from the substrate into the free surface where the adatom diffuses</td>
      </tr>
  </tbody>
</table>
<h3 id="simulation-strategy">Simulation Strategy</h3>
<p>The approach uses a <strong>thermal gradient setup</strong>:</p>
<ul>
<li>Bottom layers: Fixed to represent bulk crystal</li>
<li>Middle layers: Heated to 850 K for thermal energy</li>
<li>Top layers and adatom: Equilibrate to $\sim 600$ K for diffusion</li>
<li>This lets thermal energy propagate up from the heated reservoir to the free surface where the adatom diffuses</li>
</ul>
<p>The complete LAMMPS script implementing this approach:</p>
<pre tabindex="0"><code class="language-lammps" data-lang="lammps">### Original Created by Eric N. Hahn  ###
### ericnhahn@gmail.com ###

### Modifications by Hunter Heidenreich, CSE lab (Harvard, 2023)
### hheidenreich@g.harvard.edu
### 2023-09-01

### Simulating adatoms ###
### Version 0.2 ###


units metal
dimension 3
boundary p p s
atom_style atomic

lattice fcc 3.614
variable cubel equal 4
variable fixer1 equal &#34;v_cubel+2&#34;
variable fixer2 equal &#34;v_cubel+1.49&#34;
region  box block -${cubel} ${cubel} -${cubel} ${cubel} -${fixer1} 1 units lattice
region cbox block -${cubel} ${cubel} -${cubel} ${cubel} -${fixer1} 0 units lattice
create_box 1 box
create_atoms 1 region cbox
create_atoms 1 single -0.5 0 0.5 units lattice
region hold block INF INF INF INF -${fixer1} -${fixer2} units lattice
region temp block INF INF INF INF -${fixer2} -${cubel} units lattice
group hold region hold
group temp region temp

pair_style eam/alloy
pair_coeff * * Cu01.eam.alloy Cu

timestep        0.005
compute         new all temp
velocity        temp create 600 12345
fix heater temp temp/rescale 1 850 850 5 1
fix nve all nve
fix freeze hold setforce 0 0 0

variable e     equal pe
variable k     equal ke
variable t     equal etotal
variable T     equal temp
fix energy all ave/time 1 50 50 v_k v_e v_t v_T file energy_avg.txt

minimize 1.0e-4 1.0e-6 1000 10000

dump eve all custom 5 dump.lammpstrj id type xu yu zu   # fx fy fz  # uncomment for forces
dump_modify eve sort id

thermo 50
run 100000  # 100_000 * 5 fs = 500 ps
</code></pre><h2 id="line-by-line-breakdown">Line-by-Line Breakdown</h2>
<p>Let&rsquo;s examine each part of the LAMMPS script:</p>
<h3 id="simulation-setup">Simulation Setup</h3>
<h4 id="units">Units</h4>
<pre tabindex="0"><code class="language-lammps" data-lang="lammps">units metal
</code></pre><p>Sets simulation units to &ldquo;metal&rdquo; units (a standard choice for metallic systems). Key conversions: length in $\text{\AA}$, energy in eV, time in ps. Full details in the <a href="https://docs.lammps.org/units.html">LAMMPS documentation</a>.</p>
<pre tabindex="0"><code class="language-lammps" data-lang="lammps">dimension 3
</code></pre><p>Sets 3D simulation.</p>
<pre tabindex="0"><code class="language-lammps" data-lang="lammps">boundary p p s
</code></pre><p>Boundary conditions: periodic in x,y (infinite surface) and shrink-wrapped in z (finite surface height). This allows the adatom to potentially leave the surface if needed.</p>
<pre tabindex="0"><code class="language-lammps" data-lang="lammps">atom_style atomic
</code></pre><p>Uses the &ldquo;atomic&rdquo; style, which treats atoms as point masses without internal structure. This is standard for metallic systems.</p>
<h4 id="lattice">Lattice</h4>
<pre tabindex="0"><code class="language-lammps" data-lang="lammps">lattice fcc 3.614
</code></pre><p>Defines face-centered cubic lattice with experimental Cu lattice constant ($3.614 \text{ \AA}$).</p>
<pre tabindex="0"><code class="language-lammps" data-lang="lammps">variable cubel equal 4
variable fixer1 equal &#34;v_cubel+2&#34;
variable fixer2 equal &#34;v_cubel+1.49&#34;
</code></pre><p>Define variables for the simulation box dimensions. <code>cubel=4</code> is the half-width of the slab in lattice units, while <code>fixer1</code> (6) and <code>fixer2</code> (5.49) are the depths that bound the frozen and heated regions.</p>
<pre tabindex="0"><code class="language-lammps" data-lang="lammps">region  box block -${cubel} ${cubel} -${cubel} ${cubel} -${fixer1} 1 units lattice
region cbox block -${cubel} ${cubel} -${cubel} ${cubel} -${fixer1} 0 units lattice
</code></pre><p>Define regions: <code>box</code> for the entire simulation volume and <code>cbox</code> for crystal creation. <code>cbox</code> stops at <code>z = 0</code> while <code>box</code> extends one lattice unit higher, leaving room above the surface for the adatom.</p>
<pre tabindex="0"><code class="language-lammps" data-lang="lammps">create_box 1 box
create_atoms 1 region cbox
create_atoms 1 single -0.5 0 0.5 units lattice
</code></pre><p>Create simulation box, populate with Cu atoms, then add single adatom at specified position.</p>
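<p>One way to read the adatom coordinates <code>-0.5 0 0.5</code> is that they continue the fcc stacking one half-layer above the <code>z = 0</code> surface plane, which places the atom in a fourfold hollow site. The sketch below checks this. It assumes the usual fcc convention of an atom at the origin with cubic axes along x, y, and z; <code>is_fcc_site</code> is a hypothetical helper, not part of LAMMPS.</p>

```python
from fractions import Fraction


def is_fcc_site(x, y, z):
    """Return True if (x, y, z), in units of the lattice constant, is a site
    of an fcc lattice with an atom at the origin and cubic axes along x, y, z."""
    doubled = [Fraction(str(c)) * 2 for c in (x, y, z)]
    # Every coordinate must be a multiple of 0.5 ...
    if any(d.denominator != 1 for d in doubled):
        return False
    # ... and the doubled coordinates must sum to an even number.
    return sum(doubled) % 2 == 0


# Adatom position from `create_atoms 1 single -0.5 0 0.5 units lattice`
print(is_fcc_site(-0.5, 0, 0.5))  # hollow site one half-layer above z = 0
print(is_fcc_site(0, 0, 0.5))     # directly above the surface atom at the origin
```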
<pre tabindex="0"><code class="language-lammps" data-lang="lammps">region hold block INF INF INF INF -${fixer1} -${fixer2} units lattice
region temp block INF INF INF INF -${fixer2} -${cubel} units lattice
group hold region hold
group temp region temp
</code></pre><p>Define atom groups: <code>hold</code> (frozen bottom layers) and <code>temp</code> (heated middle layers for thermal energy).</p>
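<p>As a rough illustration of how the two regions divide the slab, the sketch below assigns each ideal (100) layer to a group. It assumes layers sit every half lattice constant from <code>z = -6</code> to <code>z = 0</code>, treats region bounds as inclusive, and ignores floating-point edge cases for atoms lying exactly on a boundary. <code>classify_layer</code> is a hypothetical helper.</p>

```python
A_CU = 3.614             # lattice constant in angstrom, from the script
CUBEL = 4                # variable cubel
FIXER1 = CUBEL + 2       # variable fixer1 -> 6
FIXER2 = CUBEL + 1.49    # variable fixer2 -> 5.49


def classify_layer(z):
    """Assign a layer at height z (lattice units) to a group,
    treating the region bounds as inclusive."""
    if -FIXER1 <= z <= -FIXER2:
        return "hold"    # frozen base
    if -FIXER2 < z <= -CUBEL:
        return "temp"    # heated reservoir
    return "free"        # unconstrained layers up to the surface


# Ideal fcc(100) layers: every half lattice constant from z = -6 to z = 0.
layers = [-FIXER1 + 0.5 * i for i in range(2 * FIXER1 + 1)]
for z in layers:
    print(f"z = {z:5.1f} a = {z * A_CU:8.3f} A -> {classify_layer(z)}")
```

<p>Under these assumptions the two deepest layers are frozen, the next three are heated, and the remaining eight are free. Using an offset of 1.49 instead of 1.5 puts the layer at <code>z = -5.5</code> inside <code>hold</code>.</p>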
<pre tabindex="0"><code class="language-lammps" data-lang="lammps">pair_style eam/alloy
pair_coeff * * Cu01.eam.alloy Cu
</code></pre><p>Use <a href="/notes/chemistry/molecular-simulation/classical-methods/embedded-atom-method/">Embedded Atom Method (EAM)</a> potential for metallic bonding. The Cu01.eam.alloy potential from <a href="https://doi.org/10.1103/PhysRevB.63.224106">Mishin et al.</a> is available from the <a href="https://www.ctcms.nist.gov/potentials/testing/entry/2001--Mishin-Y-Mehl-M-J-Papaconstantopoulos-D-A-et-al--Cu-1/">NIST repository</a>.</p>
<pre tabindex="0"><code class="language-lammps" data-lang="lammps">timestep        0.005
</code></pre><p>A 5 femtosecond timestep (0.005 ps in metal units), small enough for numerical stability.</p>
<h4 id="initial-conditions">Initial Conditions</h4>
<pre tabindex="0"><code class="language-lammps" data-lang="lammps">velocity        temp create 600 12345
</code></pre><p>Initialize velocities for the atoms in the <code>temp</code> group, drawn for a temperature of 600 K using random seed 12345.</p>
<pre tabindex="0"><code class="language-lammps" data-lang="lammps">fix heater temp temp/rescale 1 850 850 5 1
fix nve all nve
fix freeze hold setforce 0 0 0
</code></pre><p>Three fixes control dynamics:</p>
<ul>
<li><code>heater</code>: Maintains 850 K in middle layers</li>
<li><code>nve</code>: Velocity Verlet integration for all atoms</li>
<li><code>freeze</code>: Sets forces to zero for bottom atoms</li>
</ul>
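<p>The arguments <code>1 850 850 5 1</code> of the heater fix are, in order, the invocation interval, start and stop temperatures, window, and fraction. The sketch below mirrors the behavior described in the LAMMPS documentation for <code>fix temp/rescale</code>; it is an illustration with a hypothetical function name, not the LAMMPS implementation.</p>

```python
import math


def rescale_factor(t_current, t_target, window, fraction):
    """Velocity scaling factor for one thermostat invocation.

    Returns 1.0 (no change) while the temperature is within `window` of the
    target. Otherwise the temperature is moved a `fraction` of the way to
    the target. Kinetic temperature scales with v**2, hence the square root.
    Assumes t_current > 0.
    """
    if abs(t_current - t_target) <= window:
        return 1.0
    t_new = t_current + fraction * (t_target - t_current)
    return math.sqrt(t_new / t_current)


# Reservoir starting at 600 K with the settings from the script
print(rescale_factor(600.0, 850.0, window=5.0, fraction=1.0))
# Already within 5 K of the target: velocities are left alone
print(rescale_factor(848.0, 850.0, window=5.0, fraction=1.0))
```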
<pre tabindex="0"><code class="language-lammps" data-lang="lammps">variable e     equal pe
variable k     equal ke
variable t     equal etotal
variable T     equal temp
fix energy all ave/time 1 50 50 v_k v_e v_t v_T file energy_avg.txt
</code></pre><p>Track energies and temperature, averaging every 50 timesteps and writing to file.</p>
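<p>The three integers in <code>ave/time 1 50 50</code> are <code>Nevery</code>, <code>Nrepeat</code>, and <code>Nfreq</code>. A small sketch of which timesteps feed each written average, following the description in the LAMMPS documentation (<code>sampled_steps</code> is a hypothetical helper):</p>

```python
def sampled_steps(nevery, nrepeat, nfreq, output_step):
    """Timesteps whose values contribute to the average written at output_step."""
    if output_step % nfreq != 0:
        raise ValueError("averages are only written on multiples of nfreq")
    first = output_step - (nrepeat - 1) * nevery
    return list(range(first, output_step + 1, nevery))


# Settings from the script: every step is sampled, 50 samples per average.
steps = sampled_steps(1, 50, 50, output_step=100)
print(steps[0], steps[-1], len(steps))
```

<p>With these settings every timestep contributes to exactly one row of <code>energy_avg.txt</code>.</p>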
<h3 id="execution">Execution</h3>
<h4 id="minimization">Minimization</h4>
<pre tabindex="0"><code class="language-lammps" data-lang="lammps">minimize 1.0e-4 1.0e-6 1000 10000
</code></pre><p>Relax initial structure. Should converge quickly, indicating the system is already well-optimized.</p>
<h4 id="output-setup">Output Setup</h4>
<pre tabindex="0"><code class="language-lammps" data-lang="lammps">dump eve all custom 5 dump.lammpstrj id type xu yu zu   # fx fy fz  # uncomment for forces
dump_modify eve sort id
</code></pre><p>Write unwrapped atomic positions (<code>xu yu zu</code>) every 5 timesteps, sorted by atom ID. Uncomment the force components if you need them for analysis.</p>
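<p>A minimal sketch of reading this dump format back into Python with only the standard library. It assumes the plain-text layout LAMMPS writes for <code>dump custom</code> with an orthogonal box, and it assumes the adatom is the atom with the highest ID because it is created last. Both function names are hypothetical, and the sample frame is made up purely to show the layout.</p>

```python
def read_dump_frames(lines):
    """Yield (timestep, {atom_id: (xu, yu, zu)}) from a LAMMPS text dump."""
    lines = iter(lines)
    for line in lines:
        if not line.startswith("ITEM: TIMESTEP"):
            continue
        timestep = int(next(lines))
        next(lines)                        # "ITEM: NUMBER OF ATOMS"
        n_atoms = int(next(lines))
        for _ in range(4):                 # "ITEM: BOX BOUNDS" + 3 bound lines
            next(lines)
        columns = next(lines).split()[2:]  # column names after "ITEM: ATOMS"
        i_id = columns.index("id")
        i_xyz = [columns.index(name) for name in ("xu", "yu", "zu")]
        atoms = {}
        for _ in range(n_atoms):
            fields = next(lines).split()
            atoms[int(fields[i_id])] = tuple(float(fields[i]) for i in i_xyz)
        yield timestep, atoms


def adatom_track(frames):
    """Positions of the highest-ID atom, assumed here to be the adatom."""
    return [(step, atoms[max(atoms)]) for step, atoms in frames]


# Tiny made-up frame, only to show the expected layout of the file.
SAMPLE = """ITEM: TIMESTEP
5
ITEM: NUMBER OF ATOMS
2
ITEM: BOX BOUNDS pp pp ss
-14.456 14.456
-14.456 14.456
-21.684 1.807
ITEM: ATOMS id type xu yu zu
1 1 0.0 0.0 0.0
2 1 -1.807 0.0 1.807
"""

print(adatom_track(read_dump_frames(SAMPLE.splitlines())))
```

<p>For a real run, pass an open file handle for <code>dump.lammpstrj</code> instead of the sample lines.</p>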
<h4 id="production-run">Production Run</h4>
<pre tabindex="0"><code class="language-lammps" data-lang="lammps">thermo 50
run 100000  # 100_000 * 5 fs = 500 ps
</code></pre><p>Run simulation for 500 ps with thermo output every 50 steps.</p>
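<p>The run length and output sizes follow directly from the numbers in the script. A quick arithmetic check (the frame and row counts are approximate, since the exact first and last outputs depend on the step counter after minimization):</p>

```python
TIMESTEP_PS = 0.005   # `timestep 0.005` in metal units (picoseconds)
N_STEPS = 100_000     # `run 100000`
DUMP_EVERY = 5        # `dump eve all custom 5 ...`
AVG_EVERY = 50        # Nfreq in `fix energy all ave/time 1 50 50 ...`

timestep_fs = TIMESTEP_PS * 1000          # picoseconds -> femtoseconds
total_ps = N_STEPS * TIMESTEP_PS          # simulated time
n_frames = N_STEPS // DUMP_EVERY          # trajectory frames, approximately
n_avg_rows = N_STEPS // AVG_EVERY         # rows in energy_avg.txt, approximately
skip_ps = 30 * AVG_EVERY * TIMESTEP_PS    # time covered by 30 averaged rows

print(f"timestep: {timestep_fs:.1f} fs, run length: {total_ps:.1f} ps")
print(f"about {n_frames} frames and {n_avg_rows} averaged rows")
print(f"30 averaged rows cover {skip_ps:.1f} ps")
```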
<h2 id="visualization-and-analysis">Visualization and Analysis</h2>
<p>Visualize results using <a href="https://www.ovito.org/">Ovito</a>, a free atomistic visualization tool:</p>
<ol>
<li>Open the trajectory file in Ovito</li>
<li>Color atoms by z-coordinate</li>
<li>Restrict height range to $0\text{-}2 \text{ \AA}$ for surface focus</li>
<li>Animate to observe diffusion events</li>
</ol>
<div style="position: relative; padding-bottom: 56.25%; height: 0; overflow: hidden;">
      <iframe allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share; fullscreen" loading="eager" referrerpolicy="strict-origin-when-cross-origin" src="https://www.youtube-nocookie.com/embed/nIdbNqEEPys?autoplay=0&amp;controls=1&amp;end=0&amp;loop=0&amp;mute=0&amp;start=0" style="position: absolute; top: 0; left: 0; width: 100%; height: 100%; border:0;" title="YouTube video"></iframe>
    </div>

<h2 id="analysis-results">Analysis Results</h2>
<p>The simulation generates rich data for machine learning applications:</p>
<h3 id="energy-analysis">Energy Analysis</h3>
<p>Energy fluctuations reveal thermal motion patterns:</p>















<figure class="post-figure center ">
    <img src="/img/adatom_cu_energy_avg.webp"
         alt="Average kinetic energy, potential energy, total energy, and temperature over time."
         title="Average kinetic energy, potential energy, total energy, and temperature over time."
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Energy and temperature evolution over 500 ps simulation.</figcaption>
    
</figure>

<p>After skipping the first 30 logged data points (each averaged over 50 timesteps, so roughly the first 1500 timesteps, or 7.5 ps, of equilibration), these fluctuations enable:</p>
<ul>
<li><strong>Anomaly detection</strong>: Identifying unusual diffusion events</li>
<li><strong>Temperature prediction</strong>: Estimating local temperature from atomic motion</li>
<li><strong>Stability analysis</strong>: Detecting equilibrium states</li>
</ul>
<h3 id="trajectory-analysis">Trajectory Analysis</h3>
<p>Adatom motion reveals diffusion mechanisms:</p>















<figure class="post-figure center ">
    <img src="/img/adatom_cu_xy.webp"
         alt="x and y coordinates of the adatom over time."
         title="x and y coordinates of the adatom over time."
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Adatom surface trajectory showing random walk behavior.</figcaption>
    
</figure>

<p>This data enables:</p>
<ul>
<li><strong>Path prediction</strong>: Training models for future position forecasting</li>
<li><strong>Diffusion coefficient estimation</strong>: Learning temperature-mobility relationships</li>
<li><strong>Transition state identification</strong>: Detecting hops between stable sites</li>
</ul>
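<p>As one concrete example of the diffusion coefficient estimate, the sketch below computes a time-averaged mean squared displacement from the unwrapped x and y coordinates and applies the two-dimensional Einstein relation, MSD = 4Dt. The function names are hypothetical. With a single adatom and only a handful of hops the estimate will be noisy, and short lags are dominated by vibration within a site, so treat it as a starting point.</p>

```python
def mean_squared_displacement(xy, lag):
    """Time-averaged MSD (in A^2) of a 2D trajectory at a lag given in frames."""
    if not 0 < lag < len(xy):
        raise ValueError("lag must be between 1 and len(xy) - 1")
    total = 0.0
    for (x0, y0), (x1, y1) in zip(xy, xy[lag:]):
        total += (x1 - x0) ** 2 + (y1 - y0) ** 2
    return total / (len(xy) - lag)


def diffusion_coefficient_2d(msd, lag_time_ps):
    """Einstein relation in two dimensions: MSD = 4 * D * t (D in A^2/ps)."""
    return msd / (4.0 * lag_time_ps)


# Synthetic straight-line trajectory, used only to exercise the functions.
track = [(float(i), 0.0) for i in range(10)]
msd = mean_squared_displacement(track, lag=2)
print(msd, diffusion_coefficient_2d(msd, lag_time_ps=0.05))
```

<p>For the dump settings in this tutorial, one frame corresponds to 5 timesteps, or 0.025 ps, so a lag of <code>n</code> frames is <code>0.025 * n</code> ps.</p>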















<figure class="post-figure center ">
    <img src="/img/adatom_cu_z.webp"
         alt="z coordinate of the adatom over time."
         title="z coordinate of the adatom over time."
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Height fluctuations revealing exchange events with surface atoms.</figcaption>
    
</figure>

<p>Z-coordinate data shows <strong>exchange events</strong> where the adatom swaps with surface atoms (crucial for surface chemistry understanding). This enables:</p>
<ul>
<li><strong>Event classification</strong>: Distinguishing diffusion vs. exchange mechanisms</li>
<li><strong>Activation barrier estimation</strong>: Learning energy landscapes from fluctuations</li>
<li><strong>Surface coordination analysis</strong>: Correlating height with local environment</li>
</ul>
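<p>A simple way to flag candidate exchange events is to look for stretches where the tracked atom sits well below its usual adatom height. Because atom IDs are preserved, the original adatom ends up in the surface layer after an exchange and its z-coordinate drops toward the surface plane. The sketch below is a heuristic; the threshold of a quarter lattice constant (midway between the ideal surface plane and the ideal adatom height) is an assumption to tune against your own data, and the function name is hypothetical.</p>

```python
A_CU = 3.614                 # lattice constant in angstrom
Z_THRESHOLD = A_CU / 4.0     # assumed cutoff between surface layer and adatom height


def below_threshold_intervals(z_values, threshold):
    """Return (start, end) frame-index pairs, end exclusive, where z < threshold."""
    intervals = []
    start = None
    for i, z in enumerate(z_values):
        if z < threshold and start is None:
            start = i                       # entering the surface layer
        elif z >= threshold and start is not None:
            intervals.append((start, i))    # back above the surface
            start = None
    if start is not None:
        intervals.append((start, len(z_values)))
    return intervals


# Made-up heights in angstrom, only to show the output format.
z_track = [1.8, 1.7, 0.2, 0.1, 1.8, 0.3]
print(below_threshold_intervals(z_track, Z_THRESHOLD))
```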
<h3 id="machine-learning-applications">Machine Learning Applications</h3>
<p>This simulation produces multiple data types for ML training:</p>
<ol>
<li><strong>Coordinate trajectories</strong>: Neural network potential inputs or graph neural network features</li>
<li><strong>Energy time series</strong>: Regression model features for system property prediction</li>
<li><strong>Event annotations</strong>: Supervised learning labels for diffusion mechanism classification</li>
<li><strong>Environmental descriptors</strong>: Local atomic arrangement features</li>
</ol>
<p>Systematic MD simulations generate large, labeled datasets across varied conditions.</p>
<h2 id="extending-to-platinum-mass-and-bonding-effects">Extending to Platinum: Mass and Bonding Effects</h2>
<p>To understand how different elements behave, we can extend this framework to platinum (Pt). Platinum&rsquo;s higher atomic mass and stronger metallic bonding create notably different diffusion behavior, providing comparative data for machine learning.</p>
<h3 id="key-differences-from-copper">Key Differences from Copper</h3>
<table>
  <thead>
      <tr>
          <th>Parameter</th>
          <th>Copper (Cu)</th>
          <th>Platinum (Pt)</th>
          <th>Impact</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Atomic mass</strong></td>
          <td>63.5 u</td>
          <td>195.1 u</td>
          <td>Slower diffusion, longer correlation times</td>
      </tr>
      <tr>
          <td><strong>Lattice const.</strong></td>
          <td>3.614 Å</td>
          <td>3.96 Å</td>
          <td>Larger diffusion barriers, different pathways</td>
      </tr>
      <tr>
          <td><strong>Potential</strong></td>
          <td>Mishin et al.</td>
          <td>Zhou et al.</td>
          <td>Different interaction strengths</td>
      </tr>
      <tr>
          <td><strong>Melting point</strong></td>
          <td>1358 K</td>
          <td>2041 K</td>
          <td>Stronger surface binding</td>
      </tr>
  </tbody>
</table>
<h3 id="modifying-the-lammps-script">Modifying the LAMMPS Script</h3>
<p>The platinum simulation uses the exact same framework as the copper case, with three simple element-specific modifications:</p>
<ol>
<li><strong>Lattice constant</strong>: Change <code>lattice fcc 3.614</code> to <code>lattice fcc 3.96</code></li>
<li><strong>Potential file</strong>: Change <code>Cu01.eam.alloy</code> to <code>Pt_Zhou04.eam.alloy</code> (available from the <a href="https://www.ctcms.nist.gov/potentials/testing/entry/2004--Zhou-X-W-Johnson-R-A-Wadley-H-N-G--Pt/">NIST repository</a>)</li>
<li><strong>Element specification</strong>: Change <code>Cu</code> to <code>Pt</code> in the <code>pair_coeff</code> line</li>
</ol>
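<p>If you prefer to generate the platinum input from the copper one, the substitutions can be scripted. The sketch below applies the changes listed above as plain string replacements, handling the potential file and element together because they share the <code>pair_coeff</code> line. The function name is hypothetical.</p>

```python
CU_TO_PT = [
    ("lattice fcc 3.614", "lattice fcc 3.96"),
    ("pair_coeff * * Cu01.eam.alloy Cu", "pair_coeff * * Pt_Zhou04.eam.alloy Pt"),
]


def make_pt_script(cu_script):
    """Return the Pt version of the Cu input script text."""
    pt_script = cu_script
    for old, new in CU_TO_PT:
        if old not in pt_script:
            raise ValueError(f"expected line not found: {old!r}")
        pt_script = pt_script.replace(old, new)
    return pt_script


# Two-line excerpt of the Cu script, enough to show the substitution.
excerpt = "lattice fcc 3.614\npair_coeff * * Cu01.eam.alloy Cu\n"
print(make_pt_script(excerpt))
```

<p>Failing loudly when a line is missing guards against silently producing a half-converted script if the copper input has been edited.</p>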
<p>These simple changes capture the essential physics differences between elements while maintaining the same simulation protocol, which is ideal for generating comparative datasets for ML training.</p>
<h3 id="expected-behavior-vs-copper">Expected Behavior vs. Copper</h3>
<p>When you run the analysis scripts on the platinum trajectory, you should expect to see:</p>
<ul>
<li><strong>Slower motion</strong>: Heavier atoms move more slowly at the same temperature. Platinum&rsquo;s ~3x greater mass reduces diffusion rates.</li>
<li><strong>Higher energy barriers</strong>: Stronger metallic bonding creates deeper potential wells, requiring more thermal energy for diffusion hops.</li>
<li><strong>Different pathways</strong>: The larger lattice constant changes the energy landscape, potentially favoring different diffusion mechanisms.</li>
</ul>
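<p>The mass effect on thermal motion can be estimated from equipartition, where the root-mean-square speed scales as $1/\sqrt{m}$. The sketch below uses the atomic masses from the table above. It covers only the thermal speed; the hop rate also depends on the energy barrier, which this estimate does not capture.</p>

```python
import math

K_B = 1.380649e-23        # Boltzmann constant, J/K
AMU = 1.66053906660e-27   # atomic mass unit, kg


def v_rms(mass_u, temperature_k):
    """Root-mean-square thermal speed in m/s from equipartition."""
    return math.sqrt(3.0 * K_B * temperature_k / (mass_u * AMU))


# Same temperature for both elements, masses from the comparison table.
ratio = v_rms(195.1, 600.0) / v_rms(63.5, 600.0)
print(f"v_rms(Pt) / v_rms(Cu) = {ratio:.2f}")
```

<p>The ratio is independent of temperature and equals $\sqrt{63.5 / 195.1} \approx 0.57$.</p>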
<p>Comparing Cu and Pt trajectories enables training element-aware models that account for atomic mass effects, binding strengths, and temperature scaling across different metals.</p>
<h2 id="code-and-data">Code and Data</h2>
<p>The complete simulation scripts and analysis tools are available for reproducibility:</p>
<h3 id="energy-analysis-script">Energy Analysis Script</h3>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#75715e"># Hunter Heidenreich, 2023</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Plots the energy of a simulation over time.</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> matplotlib.pyplot <span style="color:#66d9ef">as</span> plt
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> argparse <span style="color:#f92672">import</span> ArgumentParser
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">if</span> __name__ <span style="color:#f92672">==</span> <span style="color:#e6db74">&#39;__main__&#39;</span>:
</span></span><span style="display:flex;"><span>    parser <span style="color:#f92672">=</span> ArgumentParser()
</span></span><span style="display:flex;"><span>    parser<span style="color:#f92672">.</span>add_argument(<span style="color:#e6db74">&#39;--input&#39;</span>, type<span style="color:#f92672">=</span>str, required<span style="color:#f92672">=</span><span style="color:#66d9ef">True</span>)
</span></span><span style="display:flex;"><span>    parser<span style="color:#f92672">.</span>add_argument(<span style="color:#e6db74">&#39;--output&#39;</span>, type<span style="color:#f92672">=</span>str, required<span style="color:#f92672">=</span><span style="color:#66d9ef">True</span>)
</span></span><span style="display:flex;"><span>    parser<span style="color:#f92672">.</span>add_argument(<span style="color:#e6db74">&#39;--skip&#39;</span>, type<span style="color:#f92672">=</span>int, default<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>    args <span style="color:#f92672">=</span> parser<span style="color:#f92672">.</span>parse_args()
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Parse energy data</span>
</span></span><span style="display:flex;"><span>    data <span style="color:#f92672">=</span> {<span style="color:#e6db74">&#39;ts&#39;</span>: [], <span style="color:#e6db74">&#39;kes&#39;</span>: [], <span style="color:#e6db74">&#39;pes&#39;</span>: [], <span style="color:#e6db74">&#39;tes&#39;</span>: [], <span style="color:#e6db74">&#39;Ts&#39;</span>: []}
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">with</span> open(args<span style="color:#f92672">.</span>input, <span style="color:#e6db74">&#39;r&#39;</span>) <span style="color:#66d9ef">as</span> f:
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">for</span> line <span style="color:#f92672">in</span> f:
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">if</span> line<span style="color:#f92672">.</span>startswith(<span style="color:#e6db74">&#39;#&#39;</span>) <span style="color:#f92672">or</span> <span style="color:#f92672">not</span> line<span style="color:#f92672">.</span>strip():
</span></span><span style="display:flex;"><span>                <span style="color:#66d9ef">continue</span>
</span></span><span style="display:flex;"><span>            t, v_k, v_e, v_t, v_T <span style="color:#f92672">=</span> map(float, line<span style="color:#f92672">.</span>split())
</span></span><span style="display:flex;"><span>            data[<span style="color:#e6db74">&#39;ts&#39;</span>]<span style="color:#f92672">.</span>append(t)
</span></span><span style="display:flex;"><span>            data[<span style="color:#e6db74">&#39;kes&#39;</span>]<span style="color:#f92672">.</span>append(v_k)
</span></span><span style="display:flex;"><span>            data[<span style="color:#e6db74">&#39;pes&#39;</span>]<span style="color:#f92672">.</span>append(v_e)
</span></span><span style="display:flex;"><span>            data[<span style="color:#e6db74">&#39;tes&#39;</span>]<span style="color:#f92672">.</span>append(v_t)
</span></span><span style="display:flex;"><span>            data[<span style="color:#e6db74">&#39;Ts&#39;</span>]<span style="color:#f92672">.</span>append(v_T)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Skip initial equilibration</span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">for</span> key <span style="color:#f92672">in</span> data:
</span></span><span style="display:flex;"><span>        data[key] <span style="color:#f92672">=</span> data[key][args<span style="color:#f92672">.</span>skip:]
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Create subplots</span>
</span></span><span style="display:flex;"><span>    fig, axs <span style="color:#f92672">=</span> plt<span style="color:#f92672">.</span>subplots(<span style="color:#ae81ff">2</span>, <span style="color:#ae81ff">2</span>, figsize<span style="color:#f92672">=</span>(<span style="color:#ae81ff">16</span>, <span style="color:#ae81ff">12</span>))
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    plots <span style="color:#f92672">=</span> [(<span style="color:#e6db74">&#39;Kinetic Energy&#39;</span>, <span style="color:#e6db74">&#39;kes&#39;</span>), (<span style="color:#e6db74">&#39;Potential Energy&#39;</span>, <span style="color:#e6db74">&#39;pes&#39;</span>),
</span></span><span style="display:flex;"><span>             (<span style="color:#e6db74">&#39;Total Energy&#39;</span>, <span style="color:#e6db74">&#39;tes&#39;</span>), (<span style="color:#e6db74">&#39;Temperature&#39;</span>, <span style="color:#e6db74">&#39;Ts&#39;</span>)]
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">for</span> ax, (title, key) <span style="color:#f92672">in</span> zip(axs<span style="color:#f92672">.</span>flat, plots):
</span></span><span style="display:flex;"><span>        ax<span style="color:#f92672">.</span>plot(data[<span style="color:#e6db74">&#39;ts&#39;</span>], data[key])
</span></span><span style="display:flex;"><span>        ax<span style="color:#f92672">.</span>set_xlabel(<span style="color:#e6db74">&#39;TimeStep&#39;</span>)
</span></span><span style="display:flex;"><span>        ax<span style="color:#f92672">.</span>set_ylabel(title)
</span></span><span style="display:flex;"><span>        ax<span style="color:#f92672">.</span>set_title(title)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    plt<span style="color:#f92672">.</span>tight_layout()
</span></span><span style="display:flex;"><span>    plt<span style="color:#f92672">.</span>savefig(args<span style="color:#f92672">.</span>output, dpi<span style="color:#f92672">=</span><span style="color:#ae81ff">300</span>, bbox_inches<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;tight&#39;</span>)
</span></span></code></pre></div><h3 id="trajectory-analysis-script">Trajectory Analysis Script</h3>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#75715e"># Hunter Heidenreich, 2023</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Plots the coordinates of the adatom.</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> matplotlib.pyplot <span style="color:#66d9ef">as</span> plt
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> argparse <span style="color:#f92672">import</span> ArgumentParser
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">if</span> __name__ <span style="color:#f92672">==</span> <span style="color:#e6db74">&#39;__main__&#39;</span>:
</span></span><span style="display:flex;"><span>    parser <span style="color:#f92672">=</span> ArgumentParser()
</span></span><span style="display:flex;"><span>    parser<span style="color:#f92672">.</span>add_argument(<span style="color:#e6db74">&#39;--input&#39;</span>, type<span style="color:#f92672">=</span>str, required<span style="color:#f92672">=</span><span style="color:#66d9ef">True</span>)
</span></span><span style="display:flex;"><span>    parser<span style="color:#f92672">.</span>add_argument(<span style="color:#e6db74">&#39;--output&#39;</span>, type<span style="color:#f92672">=</span>str, required<span style="color:#f92672">=</span><span style="color:#66d9ef">True</span>)
</span></span><span style="display:flex;"><span>    parser<span style="color:#f92672">.</span>add_argument(<span style="color:#e6db74">&#39;--id&#39;</span>, type<span style="color:#f92672">=</span>int, default<span style="color:#f92672">=</span><span style="color:#ae81ff">1665</span>,
</span></span><span style="display:flex;"><span>                       help<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;Atom ID to track (the adatom is the last created atom)&#39;</span>)
</span></span><span style="display:flex;"><span>    parser<span style="color:#f92672">.</span>add_argument(<span style="color:#e6db74">&#39;--do_z&#39;</span>, action<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;store_true&#39;</span>,
</span></span><span style="display:flex;"><span>                       help<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;Plot z-coordinate instead of xy scatter&#39;</span>)
</span></span><span style="display:flex;"><span>    args <span style="color:#f92672">=</span> parser<span style="color:#f92672">.</span>parse_args()
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    coords <span style="color:#f92672">=</span> {<span style="color:#e6db74">&#39;x&#39;</span>: [], <span style="color:#e6db74">&#39;y&#39;</span>: [], <span style="color:#e6db74">&#39;z&#39;</span>: []}
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">with</span> open(args<span style="color:#f92672">.</span>input, <span style="color:#e6db74">&#39;r&#39;</span>) <span style="color:#66d9ef">as</span> f:
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">for</span> line <span style="color:#f92672">in</span> f:
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">if</span> line<span style="color:#f92672">.</span>startswith(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#39;</span><span style="color:#e6db74">{</span>args<span style="color:#f92672">.</span>id<span style="color:#e6db74">}</span><span style="color:#e6db74"> &#39;</span>):
</span></span><span style="display:flex;"><span>                x, y, z <span style="color:#f92672">=</span> map(float, line<span style="color:#f92672">.</span>split()[<span style="color:#ae81ff">2</span>:<span style="color:#ae81ff">5</span>])
</span></span><span style="display:flex;"><span>                coords[<span style="color:#e6db74">&#39;x&#39;</span>]<span style="color:#f92672">.</span>append(x)
</span></span><span style="display:flex;"><span>                coords[<span style="color:#e6db74">&#39;y&#39;</span>]<span style="color:#f92672">.</span>append(y)
</span></span><span style="display:flex;"><span>                coords[<span style="color:#e6db74">&#39;z&#39;</span>]<span style="color:#f92672">.</span>append(z)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    plt<span style="color:#f92672">.</span>figure(figsize<span style="color:#f92672">=</span>(<span style="color:#ae81ff">10</span>, <span style="color:#ae81ff">8</span>))
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">if</span> args<span style="color:#f92672">.</span>do_z:
</span></span><span style="display:flex;"><span>        plt<span style="color:#f92672">.</span>plot(range(len(coords[<span style="color:#e6db74">&#39;z&#39;</span>])), coords[<span style="color:#e6db74">&#39;z&#39;</span>], <span style="color:#e6db74">&#39;b-&#39;</span>, linewidth<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>        plt<span style="color:#f92672">.</span>xlabel(<span style="color:#e6db74">&#39;Dump Frame Index&#39;</span>)
</span></span><span style="display:flex;"><span>        plt<span style="color:#f92672">.</span>ylabel(<span style="color:#e6db74">&#39;Z Coordinate (Å)&#39;</span>)
</span></span><span style="display:flex;"><span>        plt<span style="color:#f92672">.</span>title(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#39;Height vs. Time for Adatom </span><span style="color:#e6db74">{</span>args<span style="color:#f92672">.</span>id<span style="color:#e6db74">}</span><span style="color:#e6db74">&#39;</span>)
</span></span><span style="display:flex;"><span>        plt<span style="color:#f92672">.</span>grid(<span style="color:#66d9ef">True</span>, alpha<span style="color:#f92672">=</span><span style="color:#ae81ff">0.3</span>)
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">else</span>:
</span></span><span style="display:flex;"><span>        plt<span style="color:#f92672">.</span>scatter(coords[<span style="color:#e6db74">&#39;x&#39;</span>], coords[<span style="color:#e6db74">&#39;y&#39;</span>], s<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>, alpha<span style="color:#f92672">=</span><span style="color:#ae81ff">0.7</span>, c<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;red&#39;</span>)
</span></span><span style="display:flex;"><span>        plt<span style="color:#f92672">.</span>xlabel(<span style="color:#e6db74">&#39;X Coordinate (Å)&#39;</span>)
</span></span><span style="display:flex;"><span>        plt<span style="color:#f92672">.</span>ylabel(<span style="color:#e6db74">&#39;Y Coordinate (Å)&#39;</span>)
</span></span><span style="display:flex;"><span>        plt<span style="color:#f92672">.</span>title(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#39;XY Trajectory for Adatom </span><span style="color:#e6db74">{</span>args<span style="color:#f92672">.</span>id<span style="color:#e6db74">}</span><span style="color:#e6db74">&#39;</span>)
</span></span><span style="display:flex;"><span>        plt<span style="color:#f92672">.</span>axis(<span style="color:#e6db74">&#39;equal&#39;</span>)
</span></span><span style="display:flex;"><span>        plt<span style="color:#f92672">.</span>grid(<span style="color:#66d9ef">True</span>, alpha<span style="color:#f92672">=</span><span style="color:#ae81ff">0.3</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    plt<span style="color:#f92672">.</span>savefig(args<span style="color:#f92672">.</span>output, dpi<span style="color:#f92672">=</span><span style="color:#ae81ff">300</span>, bbox_inches<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;tight&#39;</span>)
</span></span></code></pre></div><h2 id="summary-and-next-steps">Summary and Next Steps</h2>
<p>This tutorial demonstrates how molecular dynamics generates valuable ML training data for materials science. Adatom diffusion provides an ideal starting point because it:</p>
<ul>
<li><strong>Has interpretable physics</strong>: Well-understood mechanisms enable ML validation</li>
<li><strong>Shows diverse behaviors</strong>: Temperature-dependent dynamics create rich datasets</li>
<li><strong>Scales efficiently</strong>: Small systems allow extensive parameter exploration</li>
<li><strong>Connects to applications</strong>: Direct relevance to catalysis and surface engineering</li>
</ul>
<h3 id="whats-next">What&rsquo;s Next</h3>
<p>Future posts will extend this framework:</p>
<ol>
<li><strong>Mixed-metal surfaces</strong>: Alloy effects on diffusion pathways</li>
<li><strong>Stepped surfaces</strong>: How defects alter atomic mobility</li>
<li><strong>ML implementation</strong>: Training neural networks on simulation data</li>
</ol>
<h3 id="broader-applications">Broader Applications</h3>
<p>These simulation techniques enable various ML applications:</p>
<ul>
<li><strong>Neural network potentials</strong>: Replacing expensive quantum calculations with trained models</li>
<li><strong>Rare event sampling</strong>: ML-enhanced diffusion pathway identification</li>
<li><strong>Catalyst design</strong>: Predicting surface modification effects on reactivity</li>
<li><strong>Materials discovery</strong>: Screening alloy compositions for desired properties</li>
</ul>
<h3 id="getting-started">Getting Started</h3>
<p>To reproduce these simulations:</p>
<ol>
<li>Install LAMMPS with EAM potential support</li>
<li>Download <code>Cu01.eam.alloy</code> from the <a href="https://www.ctcms.nist.gov/potentials/entry/2001--Mishin-Y-Mehl-M-J-Papaconstantopoulos-D-A-et-al--Cu-1/">NIST repository</a> and place it in your working directory</li>
<li>Save the LAMMPS script as <code>adatom_cu.lammps</code> and run:
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bash" data-lang="bash"><span style="display:flex;"><span>lammps -in adatom_cu.lammps
</span></span></code></pre></div></li>
<li>Analyze the results with the Python scripts:
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bash" data-lang="bash"><span style="display:flex;"><span>python plot_energy.py --input energy_avg.txt --output energy.png --skip <span style="color:#ae81ff">30</span>
</span></span><span style="display:flex;"><span>python plot_trajectory.py --input dump.lammpstrj --output trajectory_xy.png
</span></span><span style="display:flex;"><span>python plot_trajectory.py --input dump.lammpstrj --output trajectory_z.png --do_z
</span></span></code></pre></div></li>
<li>Visualize in Ovito by opening <code>dump.lammpstrj</code></li>
<li>Experiment with different temperatures, orientations, or elements</li>
</ol>
<hr>
<p>The full project, including the simulation architecture and automated analysis pipeline, is documented on the <a href="/projects/lammps-adatom-diffusion/">Automated Adatom Diffusion Workflow project page</a>.</p>
<p><em>Questions about the simulation setup or interested in applying these techniques to your research? Feel free to reach out. I&rsquo;m always happy to discuss molecular dynamics and machine learning applications.</em></p>
<h2 id="references">References</h2>
<ul>
<li><a href="https://www.lammps.org/">LAMMPS</a></li>
<li><a href="https://www.ovito.org/">Ovito</a></li>
<li><a href="https://www.ctcms.nist.gov/potentials/">NIST Interatomic Potentials Repository</a></li>
<li><a href="https://doi.org/10.1103/PhysRevB.63.224106">Mishin et al.</a></li>
</ul>
]]></content:encoded></item><item><title>Generating Mini-Protein Trajectories with GROMACS</title><link>https://hunterheidenreich.com/posts/mini-proteins/</link><pubDate>Thu, 21 Sep 2023 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/mini-proteins/</guid><description>Systematic GROMACS workflows for simulating mini-proteins across multiple amino acids to generate diverse MD trajectories for ML applications.</description><content:encoded><![CDATA[<h2 id="introduction">Introduction</h2>
<p>When developing machine learning models for protein dynamics, I needed training data, lots of it. Most researchers start with alanine dipeptide, a tiny system built from a single alanine residue with capped ends (the name refers to its two peptide bonds) that&rsquo;s become the &ldquo;hello world&rdquo; of protein simulation. It&rsquo;s small enough to simulate quickly but complex enough to show interesting conformational behavior.</p>
<p>I wanted more diversity in my training data. Different amino acid side chains behave differently, and I was curious how this would affect model performance. So I extended the typical alanine dipeptide approach to include eight other amino acids, creating a small collection of &ldquo;mini-proteins&rdquo; for ML studies.</p>
<p>These dipeptides give a controlled testbed for studying how different chemical properties (aromatic rings, flexibility, branching) affect molecular dynamics, and for generating training data that varies those properties systematically.</p>
<h2 id="what-are-mini-proteins">What Are Mini-Proteins?</h2>
<p>In this context, &ldquo;mini-proteins&rdquo; are single amino acid residues capped with acetyl and N-methyl groups (Ace-X-Nme, where X is the amino acid). These systems act as the simplest possible models that still capture essential protein-like behavior.</p>
<p>These systems are popular in computational studies because they:</p>
<ul>
<li>Simulate quickly (seconds to minutes instead of hours)</li>
<li>Have well-characterized behavior for validation</li>
<li>Show enough complexity to be interesting</li>
<li>Can be systematically varied to study different chemical effects</li>
</ul>
<h2 id="getting-started">Getting Started</h2>
<p>The complete workflow and scripts are available on GitHub: <a href="https://github.com/hunter-heidenreich/mini-proteins/">mini-proteins</a>. The full project overview is on the <a href="/projects/mini-protein-trajectories/">Mini-Protein Trajectory Generation project page</a>.</p>
<h3 id="requirements">Requirements</h3>
<ul>
<li>Linux system with GROMACS installed</li>
<li>Python 3 with numpy and matplotlib</li>
<li>Basic familiarity with molecular dynamics concepts</li>
</ul>
<h3 id="quick-start">Quick Start</h3>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bash" data-lang="bash"><span style="display:flex;"><span>git clone https://github.com/hunter-heidenreich/mini-proteins.git
</span></span><span style="display:flex;"><span>cd mini-proteins
</span></span><span style="display:flex;"><span>ID<span style="color:#f92672">=</span>ala sh scripts/run.sh
</span></span></code></pre></div><p>This runs the complete pipeline: energy minimization, solvation, equilibration, and production simulation. The default settings generate 1 ns of trajectory data saved every 100 fs. I chose high temporal resolution for my ML models, but you can adjust this in <code>config/md_langevin.mdp</code>.</p>
<p>For longer production runs (recommended for most applications), increase the simulation time to ~100 ns and reduce the save frequency to manage file sizes.</p>
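<p>To see what these settings imply for frame counts and disk usage, here is a minimal sketch of the arithmetic. The atom count, the longer save interval, and the size formula (single-precision positions, velocities, and forces, with headers ignored) are illustrative assumptions, not values taken from the repository; <code>nsteps</code> refers to the standard GROMACS step-count setting.</p>

```python
def trajectory_frames(total_ps, dt_ps, nstout):
    """Return (nsteps, frames) for a run lasting total_ps picoseconds."""
    nsteps = round(total_ps / dt_ps)   # value for the GROMACS nsteps setting
    frames = nsteps // nstout          # frames written after the initial one
    return nsteps, frames


def estimate_size_gb(frames, n_atoms, vectors_per_atom=3, bytes_per_float=4):
    """Rough size: x, v, and f vectors in single precision, headers ignored."""
    return frames * n_atoms * vectors_per_atom * 3 * bytes_per_float / 1e9


if __name__ == "__main__":
    n_atoms = 2000  # hypothetical atom count for a solvated dipeptide box
    for label, total_ps, nstout in [("default", 1_000.0, 100),
                                    ("long run", 100_000.0, 1_000)]:
        nsteps, frames = trajectory_frames(total_ps, dt_ps=0.001, nstout=nstout)
        size = estimate_size_gb(frames, n_atoms)
        print(f"{label}: nsteps={nsteps}, frames={frames}, ~{size:.2f} GB")
```

<p>With a 1 fs timestep, the default 1 ns run saved every 100 steps yields 10,000 frames. Scaling to 100 ns at the same interval multiplies that by 100, which is why a longer save interval is worth considering.</p>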
<h2 id="the-collection">The Collection</h2>
<p>I&rsquo;ve included nine different amino acid dipeptides, each with distinct chemical properties:</p>
<p><strong>Flexible systems</strong>: Glycine (smallest side chain), Alanine (methyl group)</p>
<p><strong>Branched systems</strong>: Valine, Isoleucine, Leucine (different branching patterns)</p>
<p><strong>Aromatic systems</strong>: Phenylalanine, Tryptophan (different ring structures)</p>
<p><strong>Special cases</strong>: Proline (ring constraint), Methionine (sulfur chemistry)</p>
<p>This systematic set allows studying how different chemical features affect dynamics:</p>
<ul>
<li>Does the flexibility of glycine lead to more diverse conformational sampling?</li>
<li>How do aromatic rings in tryptophan affect folding pathways?</li>
<li>Does the ring constraint in proline create different energy landscapes?</li>
</ul>
<p>These fundamental questions provide systematic data to test ML models against known chemical intuition, building confidence in the approach.</p>
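<p>One common way to make &ldquo;conformational sampling&rdquo; measurable is to track backbone dihedral angles over a trajectory. The post does not prescribe this analysis, so treat the following as a sketch under that assumption: a NumPy function that computes the signed dihedral angle defined by four atomic positions. Which four atoms to pass for a given angle depends on your topology and is left to the reader.</p>

```python
import numpy as np


def dihedral_deg(p0, p1, p2, p3):
    """Signed dihedral angle (degrees) about the p1-p2 axis."""
    p0, p1, p2, p3 = (np.asarray(p, dtype=float) for p in (p0, p1, p2, p3))
    b0 = p0 - p1
    b1 = p2 - p1
    b2 = p3 - p2
    b1 = b1 / np.linalg.norm(b1)          # unit vector along the central bond
    v = b0 - np.dot(b0, b1) * b1          # b0 projected off the central bond
    w = b2 - np.dot(b2, b1) * b1          # b2 projected off the central bond
    x = np.dot(v, w)
    y = np.dot(np.cross(b1, v), w)
    return float(np.degrees(np.arctan2(y, x)))


if __name__ == "__main__":
    # Four points with the central bond along z and a quarter turn between ends.
    print(dihedral_deg([1, 0, 0], [0, 0, 0], [0, 0, 1], [0, 1, 1]))
```

<p>Applying such a function to every saved frame gives an angle time series per residue, which can then be histogrammed to compare, for example, glycine against proline.</p>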
<p>Ideally, a neural network trained on this dataset should learn physical <em>invariances</em>. Training on both aliphatic (Val, Leu, Ile) and aromatic (Phe, Trp) systems exposes the model to contrasting side-chain chemistry, which should encourage it to learn how electron density (π-systems vs. σ-bonds) shapes local potential energy surfaces.</p>
<h3 id="generating-ml-ready-trajectory-data">Generating ML-Ready Trajectory Data</h3>
<p>Generating raw coordinates is easy; generating <strong>ML-ready data</strong> requires specific configurations. Standard MD simulations compress trajectory files to save space, discarding high-frequency velocity and force data. To train Neural Network Potentials (NNPs), I configured the GROMACS pipeline differently.</p>
<p>The fastest way to generate trajectory data is using the <code>run.sh</code> script:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bash" data-lang="bash"><span style="display:flex;"><span>ID<span style="color:#f92672">=</span>ala sh scripts/run.sh
</span></span></code></pre></div><p>where <code>ID</code> is the three-letter amino acid code (here, <code>ala</code> for alanine).</p>
<p>This script performs energy minimization, solvation, neutralization, NVT equilibration, NPT equilibration, and production simulation. The resulting trajectory saves to the <code>out/ID/data</code> directory.</p>
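<p>To generate data for the whole collection, the same command can be looped over residues. The sketch below builds one job per residue and defaults to a dry run that only prints the commands. Only <code>ala</code> appears in the commands above, so the other three-letter codes are assumed to follow the same convention; check the repository before running them.</p>

```python
import os
import subprocess

# Three-letter codes for the nine residues in this post. Only "ala" is
# confirmed by the post's commands; the rest are assumed to match.
RESIDUES = ["ala", "gly", "val", "ile", "leu", "phe", "trp", "pro", "met"]


def build_jobs(residues=RESIDUES, script="scripts/run.sh"):
    """Return (command, env) pairs equivalent to `ID=xxx sh scripts/run.sh`."""
    jobs = []
    for res in residues:
        env = dict(os.environ, ID=res)   # copy the environment and set ID
        jobs.append((["sh", script], env))
    return jobs


def run_all(dry_run=True):
    """Print each command; execute it only when dry_run is False."""
    for cmd, env in build_jobs():
        print(f"ID={env['ID']} {' '.join(cmd)}")
        if not dry_run:
            subprocess.run(cmd, env=env, check=True)


if __name__ == "__main__":
    run_all(dry_run=True)
```

<p>Running the jobs sequentially keeps each simulation free to use all available cores; switch to <code>dry_run=False</code> from the repository root once the printed commands look right.</p>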
<h4 id="why-this-pipeline-differs-from-standard-tutorials">Why This Pipeline Differs from Standard Tutorials</h4>
<p>A key deviation from standard tutorials is the use of <strong>Stochastic Dynamics (Langevin)</strong> as the integrator. This adds friction and noise terms to the equations of motion, ensuring correct thermodynamic sampling:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-ini" data-lang="ini"><span style="display:flex;"><span><span style="color:#75715e">; config/md_langevin.mdp</span>
</span></span><span style="display:flex;"><span><span style="color:#a6e22e">integrator</span>  <span style="color:#f92672">=</span> <span style="color:#e6db74">sd        ; Stochastic dynamics (Langevin)</span>
</span></span><span style="display:flex;"><span><span style="color:#a6e22e">dt</span>          <span style="color:#f92672">=</span> <span style="color:#e6db74">0.001     ; 1 fs timestep</span>
</span></span><span style="display:flex;"><span><span style="color:#a6e22e">nstxout</span>     <span style="color:#f92672">=</span> <span style="color:#e6db74">100       ; Save coordinates every 100 steps</span>
</span></span><span style="display:flex;"><span><span style="color:#a6e22e">nstvout</span>     <span style="color:#f92672">=</span> <span style="color:#e6db74">100       ; Save velocities every 100 steps</span>
</span></span><span style="display:flex;"><span><span style="color:#a6e22e">nstfout</span>     <span style="color:#f92672">=</span> <span style="color:#e6db74">100       ; Save forces every 100 steps</span>
</span></span><span style="display:flex;"><span><span style="color:#a6e22e">tc-grps</span>     <span style="color:#f92672">=</span> <span style="color:#e6db74">Protein Non-Protein</span>
</span></span><span style="display:flex;"><span><span style="color:#a6e22e">tau_t</span>       <span style="color:#f92672">=</span> <span style="color:#e6db74">0.1  0.1  ; Inverse friction constant (ps)</span>
</span></span><span style="display:flex;"><span><span style="color:#a6e22e">ref_t</span>       <span style="color:#f92672">=</span> <span style="color:#e6db74">298  298  ; Reference temperature (K)</span>
</span></span></code></pre></div><p>The critical settings for ML applications:</p>
<ol>
<li><strong>Langevin Dynamics (<code>sd</code>)</strong>: Ensures proper canonical (NVT) sampling, providing a robust alternative to the velocity-rescaling thermostat often used in tutorials</li>
<li><strong>Uncompressed Force Output (<code>nstfout = 100</code>)</strong>: Writing to <code>.trr</code> format captures the precise atomic forces acting on every atom, essential for force-matching in NNP training</li>
<li><strong>High-Frequency Sampling (0.1 ps)</strong>: Saving frames every 100 fs preserves short-time correlations between consecutive frames, which the 10 ps snapshots typical of standard tutorials discard</li>
</ol>
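<p>To make the friction and noise terms concrete, here is a toy Langevin integrator for independent 1D harmonic oscillators in reduced units. It is a sketch using the BAOAB splitting, which I picked for illustration; it is not the algorithm GROMACS uses for <code>sd</code>. All parameter values are arbitrary, and <code>tau_t</code> plays the same role as in the config above: its inverse is the friction rate.</p>

```python
import numpy as np


def langevin_baoab(n_particles=5000, n_steps=2000, dt=0.01,
                   tau_t=1.0, kT=1.0, mass=1.0, k_spring=1.0, seed=0):
    """Integrate independent 1D harmonic oscillators with Langevin dynamics."""
    rng = np.random.default_rng(seed)
    x = np.zeros(n_particles)
    v = np.zeros(n_particles)
    gamma = 1.0 / tau_t                       # friction rate
    c1 = np.exp(-gamma * dt)                  # velocity damping over one step
    c2 = np.sqrt(kT / mass * (1.0 - c1**2))   # noise amplitude matching c1
    for _ in range(n_steps):
        v += 0.5 * dt * (-k_spring * x) / mass              # B: half kick
        x += 0.5 * dt * v                                   # A: half drift
        v = c1 * v + c2 * rng.standard_normal(n_particles)  # O: friction + noise
        x += 0.5 * dt * v                                   # A: half drift
        v += 0.5 * dt * (-k_spring * x) / mass              # B: half kick
    return x, v


if __name__ == "__main__":
    x, v = langevin_baoab()
    # Equipartition predicts both averages approach kT (1.0 in these units).
    print("mass * mean(v^2):", np.mean(v**2))
    print("k * mean(x^2):   ", np.mean(x**2))
```

<p>The particles start at rest at the origin, so any thermal energy they end up with comes from the noise term, balanced by friction. That balance is what lets the integrator act as its own thermostat.</p>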
<p><strong>Note</strong>: A production simulation currently runs for 1 nanosecond, saved every 0.1 picoseconds (100 fs). For most applications, increase this to 100 nanoseconds and adjust the save frequency to avoid large data files. I targeted 100 fs because I needed correlated time data for ML models; other applications may require a lower frequency.</p>
<p>You can also run each step individually (see <code>scripts/run.sh</code> for examples).</p>
<h2 id="the-systems">The Systems</h2>
<p>Here are the nine amino acid dipeptides I&rsquo;ve included, each chosen for different chemical properties:</p>
<h3 id="alanine-dipeptide-the-standard">Alanine Dipeptide: The Standard</h3>















<figure class="post-figure center ">
    <img src="/img/alanine-dipeptide-molecular-dynamics.webp"
         alt="Alanine dipeptide molecular dynamics simulation animation"
         title="Alanine dipeptide molecular dynamics simulation animation"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Alanine Dipeptide</figcaption>
    
</figure>

<p>The classic starting point for protein folding studies. The small methyl side chain provides a simple yet challenging system.</p>
<h3 id="glycine-dipeptide-maximum-flexibility">Glycine Dipeptide: Maximum Flexibility</h3>















<figure class="post-figure center ">
    <img src="/img/glycine-dipeptide-molecular-dynamics.webp"
         alt="Glycine dipeptide molecular dynamics simulation animation"
         title="Glycine dipeptide molecular dynamics simulation animation"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Glycine Dipeptide</figcaption>
    
</figure>

<p>No side chain means maximum backbone flexibility. Great for studying how constraints affect conformational sampling.</p>
<h3 id="proline-dipeptide-built-in-rigidity">Proline Dipeptide: Built-in Rigidity</h3>















<figure class="post-figure center ">
    <img src="/img/proline-dipeptide-molecular-dynamics.webp"
         alt="Proline dipeptide molecular dynamics simulation animation"
         title="Proline dipeptide molecular dynamics simulation animation"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Proline Dipeptide</figcaption>
    
</figure>

<p>The ring structure creates backbone constraints. Interesting comparison to glycine&rsquo;s flexibility.</p>
<h3 id="aromatic-systems">Aromatic Systems</h3>















<figure class="post-figure center ">
    <img src="/img/phenylalanine-dipeptide-molecular-dynamics.webp"
         alt="Phenylalanine dipeptide molecular dynamics simulation animation"
         title="Phenylalanine dipeptide molecular dynamics simulation animation"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Phenylalanine Dipeptide</figcaption>
    
</figure>

<p><strong>Phenylalanine</strong>: Simple benzene ring for studying aromatic interactions.</p>















<figure class="post-figure center ">
    <img src="/img/tryptophan-dipeptide-molecular-dynamics.webp"
         alt="Tryptophan dipeptide molecular dynamics simulation animation"
         title="Tryptophan dipeptide molecular dynamics simulation animation"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Tryptophan Dipeptide</figcaption>
    
</figure>

<p><strong>Tryptophan</strong>: Larger indole ring system with more complex aromatic chemistry.</p>
<h3 id="branched-aliphatic-systems">Branched Aliphatic Systems</h3>















<figure class="post-figure center ">
    <img src="/img/valine-dipeptide-molecular-dynamics.webp"
         alt="Valine dipeptide molecular dynamics simulation animation"
         title="Valine dipeptide molecular dynamics simulation animation"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Valine Dipeptide</figcaption>
    
</figure>

<p><strong>Valine</strong>: β-branched, creates steric constraints near the backbone.</p>















<figure class="post-figure center ">
    <img src="/img/isoleucine-dipeptide-molecular-dynamics.webp"
         alt="Isoleucine dipeptide molecular dynamics simulation animation"
         title="Isoleucine dipeptide molecular dynamics simulation animation"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Isoleucine Dipeptide</figcaption>
    
</figure>

<p><strong>Isoleucine</strong>: Also β-branched, but with an ethyl and a methyl group on the β-carbon, giving a different steric profile than valine.</p>















<figure class="post-figure center ">
    <img src="/img/leucine-dipeptide-molecular-dynamics.webp"
         alt="Leucine dipeptide molecular dynamics simulation animation"
         title="Leucine dipeptide molecular dynamics simulation animation"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Leucine Dipeptide</figcaption>
    
</figure>

<p><strong>Leucine</strong>: γ-branched, so the branch point sits one carbon farther from the backbone, leaving more conformational freedom.</p>
<h3 id="special-chemistry">Special Chemistry</h3>















<figure class="post-figure center ">
    <img src="/img/methionine-dipeptide-molecular-dynamics.webp"
         alt="Methionine dipeptide molecular dynamics simulation animation"
         title="Methionine dipeptide molecular dynamics simulation animation"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Methionine Dipeptide</figcaption>
    
</figure>

<p><strong>Methionine</strong>: Sulfur chemistry, different from the others and interesting for studying heteroatom effects.</p>
<h2 id="whats-next">What&rsquo;s Next?</h2>
<p>These mini-protein simulations have been useful for my ML work, providing systematic training data with controlled chemical variation. These simple systems have helped me understand how different amino acid properties affect molecular behavior, knowledge that&rsquo;s valuable when working with larger, more complex proteins.</p>
<p>The primary value of this pipeline lies in the <strong>force extraction</strong> workflow. Having atomic forces alongside coordinates enables training NNPs via force matching; force information is a richer training signal than energies alone. Tools like <a href="https://github.com/torchmd/torchmd-net">TorchMD-Net</a>, <a href="https://github.com/mir-group/nequip">NequIP</a>, and <a href="https://github.com/ACEsuit/mace">MACE</a> can train on this kind of data once it is converted to their input formats.</p>
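<p>For readers new to force matching, the idea can be sketched as a loss that penalizes errors in both energies and per-atom forces. The function below is a NumPy illustration only: the weights are hypothetical, and real training code would compute predicted forces as gradients of a model&rsquo;s energy inside an autodiff framework.</p>

```python
import numpy as np


def force_matching_loss(e_pred, e_ref, f_pred, f_ref,
                        w_energy=1.0, w_force=100.0):
    """Weighted sum of mean squared energy and force errors.

    Energies have shape (n_frames,); forces have shape (n_frames, n_atoms, 3).
    """
    e_pred = np.asarray(e_pred, dtype=float)
    e_ref = np.asarray(e_ref, dtype=float)
    f_pred = np.asarray(f_pred, dtype=float)
    f_ref = np.asarray(f_ref, dtype=float)
    energy_term = np.mean((e_pred - e_ref) ** 2)
    # Squared length of each atom's force error vector, averaged over atoms and frames.
    force_term = np.mean(np.sum((f_pred - f_ref) ** 2, axis=-1))
    return float(w_energy * energy_term + w_force * force_term)


if __name__ == "__main__":
    f_ref = np.zeros((2, 1, 3))    # two frames, one atom
    f_pred = f_ref.copy()
    f_pred[0, 0, 0] = 1.0          # one force component off by 1
    print(force_matching_loss([1.0, 2.0], [1.0, 1.0], f_pred, f_ref))
```

<p>Each frame contributes one energy but three force components per atom, which is why forces carry so much more training signal per frame.</p>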
<p>The scripts are designed to be easily modified for different amino acids or simulation conditions. I&rsquo;ve tried to make the workflow straightforward while keeping it flexible.</p>
<p>This work complements my other molecular dynamics projects:</p>
<ul>
<li><a href="/posts/adatom-cu-diffusion/">LAMMPS Tutorial: Copper and Platinum Adatom Diffusion</a>: Learning LAMMPS for surface simulations and extending to different elements</li>
</ul>
<p>Together, these projects have given me a solid foundation in MD simulations for generating ML training data across different molecular systems.</p>
<hr>
<p><em>Find the complete code and documentation on <a href="https://github.com/hunter-heidenreich/mini-proteins">GitHub</a>. Questions or suggestions? I&rsquo;d love to hear from you, especially if you&rsquo;ve found interesting ways to extend or improve the approach.</em></p>
<h2 id="acknowledgements">Acknowledgements</h2>
<p>The scripts build on the <a href="https://cbp-unitn.gitlab.io/qcb22-23/QCB/tutorial2_gromacs">GROMACS tutorial</a> by Luca Tubiana at the University of Trento.</p>
]]></content:encoded></item><item><title>5 Axes of Multi-Arm Bandit Problems: A Practical Guide</title><link>https://hunterheidenreich.com/posts/a-roadmap-to-multi-arm-bandit-algorithms/</link><pubDate>Tue, 10 Nov 2020 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/a-roadmap-to-multi-arm-bandit-algorithms/</guid><description>Explore 5 key dimensions of multi-arm bandit problems to help practitioners better navigate the exploration-exploitation tradeoff in ML applications.</description><content:encoded><![CDATA[<h2 id="what-is-a-multi-arm-bandit-problem">What is a Multi-Arm Bandit Problem?</h2>
<p>Multi-arm bandit problems are a fundamental class of sequential decision-making problems in machine learning. They&rsquo;re less complex than a full reinforcement learning problem, but they capture a lot of the essential challenges of learning from interaction.</p>
<p>The name comes from the analogy of a gambler facing multiple slot machines (or &ldquo;one-armed bandits&rdquo;) and trying to figure out which one pays out the most over time. Do you keep playing the machine that&rsquo;s paid out well so far, or try others to see if they&rsquo;re better? This is the exploration-exploitation dilemma at the heart of bandit algorithms.</p>
<p>Bandit algorithms solve this by learning the reward distributions for each arm over time, balancing the need to explore new options with exploiting what they&rsquo;ve learned.</p>
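<p>To make that balance concrete, here is a minimal sketch of an epsilon-greedy learner in Python. It is one of the simplest strategies rather than a recommendation, and the arm probabilities, <code>epsilon</code>, and the function name are all assumptions chosen for illustration.</p>

```python
import random


def epsilon_greedy(true_means, rounds=5000, epsilon=0.1, seed=0):
    """Play a Bernoulli bandit with an epsilon-greedy learner.

    true_means holds each arm's hidden success probability (hypothetical
    values chosen for this illustration).
    """
    rng = random.Random(seed)
    n_arms = len(true_means)
    pulls = [0] * n_arms        # how often each arm was played
    estimates = [0.0] * n_arms  # running mean reward per arm

    for _ in range(rounds):
        if epsilon > rng.random():
            arm = rng.randrange(n_arms)  # explore: pick a random arm
        else:
            # exploit: pick the arm with the best estimate so far
            arm = max(range(n_arms), key=lambda a: estimates[a])

        # Bernoulli reward: 1.0 with probability true_means[arm]
        reward = 1.0 if true_means[arm] > rng.random() else 0.0

        # incremental update of the running mean
        pulls[arm] += 1
        estimates[arm] += (reward - estimates[arm]) / pulls[arm]

    return pulls, estimates


pulls, estimates = epsilon_greedy([0.2, 0.5, 0.7])
print(pulls)      # pull counts per arm
print(estimates)  # estimated mean reward per arm
```

<p>With probability <code>epsilon</code> the learner explores a random arm; otherwise it exploits the arm with the best running estimate.</p>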
<figure class="post-figure center ">
    <img src="/img/multi-arm-bandits/multi-arm-bandit-conceptual-graphic.webp"
         alt="Illustration of a vintage slot machine with multiple arms, representing the multi-arm bandit problem in machine learning where algorithms must balance exploration and exploitation"
         title="Illustration of a vintage slot machine with multiple arms, representing the multi-arm bandit problem in machine learning where algorithms must balance exploration and exploitation"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Multi-arm bandit algorithms balance exploration and exploitation</figcaption>
    
</figure>

<p>What I&rsquo;ve found helpful is thinking about any bandit problem along five key dimensions. Asking these five questions helps quickly identify which approaches work best for a given problem:</p>
<div style="display: flex; flex-direction: column; gap: 16px; max-width: 700px; margin: 30px auto; font-family: system-ui, -apple-system, sans-serif;">
  <div style="background-color: #5A9BD5; color: white; padding: 20px 24px; border-radius: 10px; box-shadow: 0 4px 6px rgba(0,0,0,0.05);">
    <h3 style="margin: 0 0 8px 0; font-size: 1.25rem; font-weight: 700; color: white;">1. Action Space</h3>
    <p style="margin: 0; font-size: 1.05rem; line-height: 1.5;"><em>What does the problem action space look like?</em> <br>Consider whether your options are finite vs. infinite, or single vs. combinatorial.</p>
  </div>
  <div style="background-color: #55C3BA; color: white; padding: 20px 24px; border-radius: 10px; box-shadow: 0 4px 6px rgba(0,0,0,0.05);">
    <h3 style="margin: 0 0 8px 0; font-size: 1.25rem; font-weight: 700; color: white;">2. Problem Structure</h3>
    <p style="margin: 0; font-size: 1.05rem; line-height: 1.5;"><em>Is there any structure to the problem?</em> <br>Determine whether choosing certain actions provides information about the expected rewards of other actions.</p>
  </div>
  <div style="background-color: #5CC382; color: white; padding: 20px 24px; border-radius: 10px; box-shadow: 0 4px 6px rgba(0,0,0,0.05);">
    <h3 style="margin: 0 0 8px 0; font-size: 1.25rem; font-weight: 700; color: white;">3. External Information</h3>
    <p style="margin: 0; font-size: 1.05rem; line-height: 1.5;"><em>Is there external information that my learner has access to?</em> <br>Assess the availability of contextual information (like user data or environment state) before an action is chosen.</p>
  </div>
  <div style="background-color: #56B14E; color: white; padding: 20px 24px; border-radius: 10px; box-shadow: 0 4px 6px rgba(0,0,0,0.05);">
    <h3 style="margin: 0 0 8px 0; font-size: 1.25rem; font-weight: 700; color: white;">4. Reward Mechanism</h3>
    <p style="margin: 0; font-size: 1.05rem; line-height: 1.5;"><em>How are rewards generated?</em> <br>Identify if the environment's payouts are stochastic (random but consistent), non-stationary (changing over time), or adversarial.</p>
  </div>
  <div style="background-color: #81B653; color: white; padding: 20px 24px; border-radius: 10px; box-shadow: 0 4px 6px rgba(0,0,0,0.05);">
    <h3 style="margin: 0 0 8px 0; font-size: 1.25rem; font-weight: 700; color: white;">5. Learner Feedback</h3>
    <p style="margin: 0; font-size: 1.05rem; line-height: 1.5;"><em>What kind of feedback does my learner receive each round?</em> <br>Clarify if the feedback loop is strict bandit (only seeing the reward for the chosen arm), full, partial, or semi-bandit feedback.</p>
  </div>
</div>
<p>Jump to <strong><a href="#real-examples">Real Examples</a></strong> to see these dimensions in practice.</p>
<hr>
<h2 id="1-what-can-you-do-action-space">1. What Can You Do? (Action Space)</h2>
<p>The first question is simple: what options do you have? Action spaces are generally categorized along two dimensions: size and complexity.</p>
<p><strong>Finite vs. Infinite (Size)</strong></p>
<ul>
<li><strong>Finite (Discretized):</strong> This is the scenario most people picture: the agent chooses among a small number of discrete options. Examples include selecting an arm in a standard 3-arm bandit or testing 5 different website layouts.</li>
<li><strong>Infinite (Continuous):</strong> Sometimes your action has continuous parameters. For example, selecting a bid price anywhere in the range of (0, 1), or adjusting a recommendation algorithm&rsquo;s temperature parameter. This requires different mathematical approaches.</li>
</ul>
<p><strong>Single vs. Combinatorial (Complexity)</strong></p>
<ul>
<li><strong>Single Action:</strong> Selection of exactly 1 action per round (e.g., showing a user <em>one</em> specific ad).</li>
<li><strong>Combinatorial Actions:</strong> Selection of several actions at once, represented as a vector. For example, selecting a subset of edges in a graph to form a path from node $s$ to node $t$. If you are picking 10 movies to populate a Netflix homepage at once, those selections might interact with one another.</li>
</ul>
<h2 id="2-do-your-choices-tell-you-about-other-choices-problem-structure">2. Do Your Choices Tell You About Other Choices? (Problem Structure)</h2>
<p>This is often overlooked but crucial: does trying option A teach you anything about option B?</p>
<ul>
<li><strong>Independent (Unstructured):</strong> Information gained from one action provides <em>zero insight</em> into the expected reward of other actions. For example, knowing the open rate of an email campaign tells you nothing about the click-through rate of a separate social media post.</li>
<li><strong>Correlated (Structured):</strong> Information gained from one action provides <em>valuable hints</em> about similar actions. If you test a 10% discount and see high conversions, that suggests a 15% discount will also perform well, helping you map the underlying demand curve.</li>
</ul>
<h2 id="3-what-extra-information-do-you-have-context">3. What Extra Information Do You Have? (Context)</h2>
<p>Real-world problems rarely happen in isolation. The question is: what additional predictive information might help you make better decisions <em>before</em> you pull the lever?</p>
<p>When you incorporate these state variables, you move from a Standard Bandit to a <strong>Contextual Bandit</strong>. This data usually falls into two buckets:</p>
<ul>
<li><strong>User Context:</strong> State information tied directly to the individual, such as age, location, or past purchase history.</li>
<li><strong>Environmental Context:</strong> State information tied to the surrounding conditions at the exact moment of decision, such as time of day, seasonality, or device type (mobile vs. desktop).</li>
</ul>
<h2 id="4-how-do-rewards-work-reward-mechanism">4. How Do Rewards Work? (Reward Mechanism)</h2>
<p>This dimension is about understanding the nature of the feedback you&rsquo;re getting. Are the rewards predictable, changing over time, or actively working against you?</p>
<blockquote>
<p><strong>Stochastic (Stable but Random)</strong>
Each action&rsquo;s rewards are drawn IID (Independent and Identically Distributed) from a fixed distribution, so the underlying mean rewards do not shift over time.
<em>Example: Clinical trials. A drug has a true underlying effectiveness rate, but individual patients respond with random variation around that mean.</em></p></blockquote>
<blockquote>
<p><strong>Non-Stationary (Changing Over Time)</strong>
Reward distributions <em>do</em> shift over time, usually following some underlying rule. This is a realistic relaxation of the stochastic model, but comes at a learning cost.
<em>Example: Stock trading or ad performance. What worked last month might not work today.</em></p></blockquote>
<blockquote>
<p><strong>Adversarial (Actively Working Against You)</strong>
An adversary selects the worst-case rewards for your options, often with full knowledge of your learner&rsquo;s policy. Randomization is key to remaining unpredictable.
<em>Example: Cybersecurity defenses. Attackers actively adapt their strategies to exploit your algorithm.</em></p></blockquote>
<h2 id="5-how-much-do-you-learn-from-each-decision-feedback">5. How Much Do You Learn From Each Decision? (Feedback)</h2>
<p>The last dimension is about information flow. How much do you learn each time you make a choice?</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Feedback Type</th>
          <th style="text-align: left">What You Learn</th>
          <th style="text-align: left">Real-World Example</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>Bandit</strong></td>
          <td style="text-align: left">You only observe the reward for the specific action selected. You have no knowledge of what could have been gained from other options.</td>
          <td style="text-align: left">A literal slot machine payout.</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Semi-Bandit</strong></td>
          <td style="text-align: left">Common in combinatorial settings. You see the individual rewards associated with each <em>sub-action</em> you took.</td>
          <td style="text-align: left">Learning exactly which specific edges of a graph &ldquo;dropped&rdquo; or succeeded in a path-finding algorithm.</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Full</strong></td>
          <td style="text-align: left">You see all reward signals for <em>every</em> action, including the ones you didn&rsquo;t take.</td>
          <td style="text-align: left">Analyzing historical stock market data where you can see all alternative outcomes.</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Partial Monitoring</strong></td>
          <td style="text-align: left">Feedback may reveal the reward only indirectly, and in some rounds <em>not</em> at all. You are occasionally flying blind.</td>
          <td style="text-align: left"><em>(Note: Because the core feedback loop is broken, this is generally not considered a true bandit problem!)</em></td>
      </tr>
  </tbody>
</table>
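<p>The first three rows of the table can be sketched as a single function that decides what the learner is allowed to see. This is an illustrative toy, not a standard API: the <code>observe</code> helper, its arguments, and the reward values are all made up, and partial monitoring is left out.</p>

```python
def observe(rewards, chosen, feedback):
    """Return what the learner gets to see for one round.

    rewards  -- reward of every base action this round (hidden from the learner)
    chosen   -- indices of the base actions that were played
    feedback -- "bandit", "semi-bandit", or "full"
    """
    if feedback == "bandit":
        # only the total reward of the played action(s)
        return sum(rewards[i] for i in chosen)
    if feedback == "semi-bandit":
        # one reward per played sub-action
        return {i: rewards[i] for i in chosen}
    if feedback == "full":
        # every reward, including actions that were not played
        return dict(enumerate(rewards))
    raise ValueError(f"unknown feedback type: {feedback}")


rewards = [0.0, 1.0, 0.5, 1.0]  # made-up rewards for four base actions
chosen = [1, 2]                 # the learner plays actions 1 and 2 together

print(observe(rewards, chosen, "bandit"))       # 1.5
print(observe(rewards, chosen, "semi-bandit"))  # {1: 1.0, 2: 0.5}
print(observe(rewards, chosen, "full"))         # {0: 0.0, 1: 1.0, 2: 0.5, 3: 1.0}
```

<p>The hidden <code>rewards</code> list is identical in all three calls; only the amount of it that reaches the learner changes.</p>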
<h2 id="real-examples">Real Examples</h2>
<p>If you are trying to figure out which bandit algorithm to use for your project, it helps to map out your problem first. Here is how these five dimensions play out in three common real-world scenarios:</p>
<h3 id="1-e-commerce-recommendations">1. E-commerce Recommendations</h3>
<figure class="post-figure center ">
    <img src="/img/multi-arm-bandits/multi-arm-bandit-for-ecommerce.webp"
         alt="Illustration of a multi-arm bandit algorithm applied to e-commerce recommendations, showing a user interacting with a carousel of product recommendations on a website"
         title="Illustration of a multi-arm bandit algorithm applied to e-commerce recommendations, showing a user interacting with a carousel of product recommendations on a website"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The algorithm must populate an entire &lsquo;Recommended for You&rsquo; carousel with multiple items simultaneously.</figcaption>
    
</figure>

<ul>
<li><strong>Action Space:</strong> Combinatorial <em>(Pick multiple products to show at once)</em></li>
<li><strong>Problem Structure:</strong> Structured <em>(Similar products perform similarly)</em></li>
<li><strong>Context:</strong> Available <em>(User history, time of day, device type)</em></li>
<li><strong>Reward Mechanism:</strong> Stochastic <em>(Relatively stable underlying user preferences)</em></li>
<li><strong>Feedback:</strong> Semi-Bandit <em>(You see clicks on each individual item you showed, and nothing for items you left out)</em></li>
</ul>
<h3 id="2-online-ad-bidding">2. Online Ad Bidding</h3>
<figure class="post-figure center ">
    <img src="/img/multi-arm-bandits/multi-arm-bandit-for-ad-bidding.webp"
         alt="Illustration of a multi-arm bandit algorithm applied to online ad bidding, showing a marketer adjusting bid prices in real-time auctions for ad placements"
         title="Illustration of a multi-arm bandit algorithm applied to online ad bidding, showing a marketer adjusting bid prices in real-time auctions for ad placements"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The algorithm must learn the optimal price to bid in real-time auctions.</figcaption>
    
</figure>

<ul>
<li><strong>Action Space:</strong> Infinite / Continuous <em>(Bid any amount from 0.01 to 10.00)</em></li>
<li><strong>Problem Structure:</strong> Structured <em>(Similar bid prices yield similar win rates)</em></li>
<li><strong>Context:</strong> Available <em>(User demographics, search terms, ad relevance)</em></li>
<li><strong>Reward Mechanism:</strong> Non-stationary <em>(Market conditions and competitor budgets change constantly)</em></li>
<li><strong>Feedback:</strong> Bandit <em>(You only see the results from your winning bids)</em></li>
</ul>
<h3 id="3-content-personalization">3. Content Personalization</h3>
<p><em>A media site dynamically selects articles or videos based on trending topics and user habits to create a personalized homepage.</em></p>
<ul>
<li><strong>Action Space:</strong> Combinatorial <em>(Select a layout of multiple articles/videos)</em></li>
<li><strong>Problem Structure:</strong> Structured <em>(Content categories and tags have predictable patterns)</em></li>
<li><strong>Context:</strong> Available <em>(User profile, current geographic trends, browsing history)</em></li>
<li><strong>Reward Mechanism:</strong> Non-stationary <em>(User interests and news cycles evolve over time)</em></li>
<li><strong>Feedback:</strong> Semi-Bandit <em>(You see the performance of the individual content pieces within the selected layout)</em></li>
</ul>
<h2 id="the-algorithm-selection-cheat-sheet">The Algorithm Selection Cheat Sheet</h2>
<p>Here is how the five dimensions translate into algorithm choices, one step per dimension:</p>
<p><strong>Step 1: What are my options? (Action Space)</strong></p>
<ul>
<li><strong>Few discrete choices?</strong> $\rightarrow$ Start with standard <strong>UCB</strong> (Upper Confidence Bound) or <strong>Thompson Sampling</strong>.</li>
<li><strong>Continuous parameters?</strong> $\rightarrow$ Look into <strong>Gaussian Process</strong> methods.</li>
</ul>
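<p>For the few-discrete-choices case, a minimal sketch of UCB1 might look like the following. It uses the textbook index (empirical mean plus a bonus of $\sqrt{2 \ln t / n_a}$), and the Bernoulli arm probabilities and function name are assumptions for illustration only.</p>

```python
import math
import random


def ucb1(true_means, rounds=5000, seed=0):
    """Run UCB1 on a Bernoulli bandit and return the pull count per arm."""
    rng = random.Random(seed)
    n_arms = len(true_means)
    pulls = [0] * n_arms
    means = [0.0] * n_arms

    for t in range(1, rounds + 1):
        if n_arms >= t:
            arm = t - 1  # play every arm once to initialize
        else:
            # empirical mean plus an exploration bonus that shrinks with more pulls
            arm = max(
                range(n_arms),
                key=lambda a: means[a] + math.sqrt(2.0 * math.log(t) / pulls[a]),
            )

        # Bernoulli reward: 1.0 with probability true_means[arm]
        reward = 1.0 if true_means[arm] > rng.random() else 0.0
        pulls[arm] += 1
        means[arm] += (reward - means[arm]) / pulls[arm]

    return pulls


pulls = ucb1([0.2, 0.5, 0.7])  # hypothetical arm probabilities
print(pulls)                   # pull counts per arm
```

<p>Arms that have been tried rarely keep a large bonus, so they are revisited until the data rules them out. No exploration rate needs to be tuned by hand.</p>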
<p><strong>Step 2: Do my choices relate to each other? (Structure)</strong></p>
<ul>
<li><strong>Yes?</strong> $\rightarrow$ Use <strong>Linear Bandits</strong> or kernel methods to share information across arms.</li>
<li><strong>No?</strong> $\rightarrow$ Treat each option completely independently.</li>
</ul>
<p><strong>Step 3: What extra information do I have? (Context)</strong></p>
<ul>
<li><strong>Rich context available?</strong> $\rightarrow$ Use contextual bandits (LinUCB is the standard choice here).</li>
<li><strong>No context?</strong> $\rightarrow$ Stick with standard, context-free approaches.</li>
</ul>
<p><strong>Step 4: How do rewards behave? (Mechanism)</strong></p>
<ul>
<li><strong>Stable?</strong> $\rightarrow$ <strong>UCB</strong> and <strong>Thompson Sampling</strong> are your go-to choices.</li>
<li><strong>Changing over time?</strong> $\rightarrow$ Add forgetting factors (discounting) or use sliding-window variants of standard algorithms.</li>
<li><strong>Actively working against me?</strong> $\rightarrow$ You need adversarial approaches, most notably the <strong>EXP3</strong> algorithm.</li>
</ul>
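<p>As a rough sketch of the adversarial case, the textbook EXP3 update can be written in a few lines. It assumes rewards lie in $[0, 1]$; <code>reward_fn</code> is a hypothetical callback standing in for the environment, and the toy environment below is not truly adversarial.</p>

```python
import math
import random


def exp3(reward_fn, n_arms, rounds=2000, gamma=0.1, seed=0):
    """EXP3 for rewards in [0, 1]; returns the final sampling probabilities."""
    rng = random.Random(seed)
    weights = [1.0] * n_arms

    for t in range(rounds):
        total = sum(weights)
        # mix the weight-based distribution with uniform exploration
        probs = [(1 - gamma) * w / total + gamma / n_arms for w in weights]
        arm = rng.choices(range(n_arms), weights=probs)[0]

        reward = reward_fn(t, arm)      # only the chosen arm's reward is seen
        estimate = reward / probs[arm]  # importance-weighted reward estimate
        weights[arm] *= math.exp(gamma * estimate / n_arms)

        # rescale so the weights stay in floating-point range
        top = max(weights)
        weights = [w / top for w in weights]

    total = sum(weights)
    return [(1 - gamma) * w / total + gamma / n_arms for w in weights]


# hypothetical environment: arm 2 always pays 1, the others pay 0
final_probs = exp3(lambda t, arm: 1.0 if arm == 2 else 0.0, n_arms=3)
print(final_probs)
```

<p>The learner always samples from a distribution and never plays deterministically, which is what keeps it unpredictable to an adversary.</p>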
<p><strong>Step 5: How much do I learn each time? (Feedback)</strong></p>
<ul>
<li><strong>Just my choice?</strong> $\rightarrow$ Standard bandit algorithms.</li>
<li><strong>Everything?</strong> $\rightarrow$ Online gradient descent or multiplicative weights.</li>
<li><strong>Something in between?</strong> $\rightarrow$ Look for specialized algorithms that exploit your specific combinatorial structure.</li>
</ul>
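<p>For the full-feedback case, a minimal multiplicative-weights (Hedge) sketch is shown here. The learning rate <code>eta</code> and the reward table are made-up values, and rewards are assumed to lie in $[0, 1]$.</p>

```python
import math


def hedge(reward_rounds, eta=0.5):
    """Multiplicative weights (Hedge) under full feedback.

    reward_rounds is a list of rounds; each round lists the reward in [0, 1]
    of every action, which the learner sees in full after choosing.
    """
    n_actions = len(reward_rounds[0])
    weights = [1.0] * n_actions

    for rewards in reward_rounds:
        # every action is updated, not just the one that was played
        weights = [w * math.exp(eta * r) for w, r in zip(weights, rewards)]
        top = max(weights)
        weights = [w / top for w in weights]  # rescale for numerical stability

    total = sum(weights)
    return [w / total for w in weights]


# made-up data: action 1 earns the most reward overall
rounds = [[0.2, 0.9, 0.4], [0.5, 0.6, 0.1], [0.3, 0.8, 0.7]] * 20
probs = hedge(rounds)
print(probs)  # nearly all probability mass ends up on action 1
```

<p>Because every action&rsquo;s reward is visible, no importance weighting is needed. That correction is the main thing that separates this from bandit-feedback algorithms such as EXP3.</p>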
<p>This framework keeps the focus exactly where it needs to be: on what actually matters for the problem at hand.</p>
<h2 id="summary">Summary</h2>
<p>These five questions have helped me navigate bandit problems more systematically:</p>
<ol>
<li><strong>What can you do?</strong> (Action space: few vs many options, single vs multiple choices)</li>
<li><strong>Do choices relate?</strong> (Whether trying one option teaches you about others)</li>
<li><strong>What extra info do you have?</strong> (Context that might improve decisions)</li>
<li><strong>How do rewards work?</strong> (Stable, changing, or adversarial)</li>
<li><strong>How much do you learn?</strong> (Feedback from just your choice vs everything)</li>
</ol>
<p>This framework helps avoid getting overwhelmed by the academic literature and focuses attention on what matters for real problems. Structured problems let you learn faster by sharing information between options, while adversarial settings require completely different approaches.</p>
<h2 id="if-you-want-to-learn-more">If You Want to Learn More</h2>
<ul>
<li><strong><a href="https://tor-lattimore.com/downloads/book/book.pdf">Bandit Algorithms by Lattimore &amp; Szepesvári</a></strong> - The comprehensive textbook (free PDF)</li>
<li><strong><a href="https://arxiv.org/abs/1904.07272">Introduction to Multi-Armed Bandits</a></strong> - Aleksandrs Slivkins&rsquo; survey paper is more accessible</li>
</ul>
<h2 id="get-your-hands-dirty">Get Your Hands Dirty</h2>
<p>Want to see these dimensions in code? I&rsquo;ve implemented the foundational algorithms (Explore-Then-Commit and Follow-The-Leader) in Python. Check out the <strong><a href="https://github.com/hunter-heidenreich/Bandit-Algorithms">Bandit Algorithms Repository</a></strong> to run the simulations yourself.</p>
]]></content:encoded></item><item><title>Breaking Down Machine Learning for the Average Person</title><link>https://hunterheidenreich.com/posts/breaking-down-ml-for-the-average-person/</link><pubDate>Tue, 04 Dec 2018 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/breaking-down-ml-for-the-average-person/</guid><description>Discover how machine learning actually works through three fundamental approaches, explained with everyday examples you already know and use.</description><content:encoded><![CDATA[<h2 id="machine-learning">Machine Learning</h2>
<p>Machine learning is about teaching computer programs to improve at tasks through experience. We show algorithms examples and let them discover patterns in data.</p>
<p>There are three main approaches to machine learning: supervised learning, unsupervised learning, and reinforcement learning. Each works differently and suits different types of problems.</p>
<figure class="post-figure center ">
    <img src="/img/types-of-machine-learning.webp"
         alt="Diagram showing the three main types of machine learning: supervised, unsupervised, and reinforcement learning"
         title="Diagram showing the three main types of machine learning: supervised, unsupervised, and reinforcement learning"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Three fundamental approaches to machine learning, each suited to different types of problems and data</figcaption>
    
</figure>

<p>Each type addresses different kinds of problems and works with different data requirements.</p>
<h3 id="supervised-learning">Supervised Learning</h3>
<p>Supervised learning works like teaching with examples and answers. You show the algorithm many input-output pairs, and it learns to predict outputs for new inputs it hasn&rsquo;t seen before.</p>
<p>The algorithm learns by comparing its predictions to the correct answers. Over time, it gets better at finding patterns that connect inputs to outputs. Once trained, it can make predictions on new data.</p>
<p>Common examples you encounter:</p>
<ul>
<li><strong>Email Spam Filtering</strong>: Email systems learn to identify spam by training on thousands of emails labeled as spam or legitimate.</li>
<li><strong>Advertisement Targeting</strong>: Algorithms predict which ads you might click based on your browsing history and demographics.</li>
<li><strong>Face Recognition</strong>: Social media platforms use tagged photos to learn who appears in new images.</li>
</ul>
<h3 id="unsupervised-learning">Unsupervised Learning</h3>
<p>Unsupervised learning works without correct answers. Instead, algorithms analyze data to find patterns, group similar items, or discover structure that wasn&rsquo;t obvious before.</p>
<p>This approach is useful because most real-world data doesn&rsquo;t come with labels. Unsupervised algorithms can process large amounts of data to find patterns that might not be obvious to humans.</p>
<p>Examples include:</p>
<ul>
<li><strong>Recommendation Systems</strong>: Netflix and YouTube analyze viewing patterns to suggest content, even without explicit ratings.</li>
<li><strong>Customer Segmentation</strong>: Companies group customers by purchasing behavior for targeted marketing.</li>
<li><strong>Problem Identification</strong>: Tech companies automatically group similar bug reports to identify common issues.</li>
</ul>
<h3 id="reinforcement-learning">Reinforcement Learning</h3>
<p>Reinforcement learning works through trial and error. An algorithm tries different actions in an environment and learns from the consequences, getting rewards for good choices and penalties for poor ones.</p>
<p>This mirrors how many animals learn: through consequences. Good behavior gets rewards, bad behavior gets correction.</p>
<p>Consider an algorithm learning to play Mario:</p>
<ul>
<li><strong>Agent</strong>: The learning algorithm</li>
<li><strong>Environment</strong>: The game world</li>
<li><strong>Actions</strong>: Controller inputs (jump, run, etc.)</li>
<li><strong>State</strong>: Current game screen</li>
<li><strong>Reward</strong>: Points gained or lost</li>
</ul>
<p>The algorithm tries different button combinations, sees what happens, and gradually learns strategies that lead to higher scores.</p>
<p>Real applications include:</p>
<ul>
<li><strong>Game AI</strong>: AlphaGo and similar systems learned to play complex games through self-play.</li>
<li><strong>Robotics</strong>: Factory robots learn assembly processes through trial and error in simulated environments.</li>
<li><strong>Resource Management</strong>: Google uses reinforcement learning to manage data center cooling, reducing energy costs.</li>
</ul>
<h3 id="putting-it-together">Putting It Together</h3>
<p>In practice, these approaches often work together. Many real systems combine different learning methods depending on the problem and available data.</p>
<p>For example:</p>
<ul>
<li>A game AI might use supervised learning to recognize objects and reinforcement learning for strategy</li>
<li>A language model might learn word relationships without supervision, then improve with supervised training</li>
<li>A recommendation system could group users without labels, then use supervised learning to predict preferences</li>
</ul>
<p>These three approaches cover most machine learning applications. Understanding them helps explain how the AI systems we use daily actually work, whether it&rsquo;s email filters, recommendation engines, or game-playing algorithms.</p>
<p>Machine learning is pattern recognition and learning from data, not magic. The more we understand these basics, the better we can work with and build these systems.</p>
]]></content:encoded></item><item><title>Understanding GANs: From Fundamentals to Objective Functions</title><link>https://hunterheidenreich.com/posts/what-is-a-gan/</link><pubDate>Sat, 18 Aug 2018 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/what-is-a-gan/</guid><description>A complete guide to Generative Adversarial Networks (GANs), covering intuitive explanations, mathematical foundations, and objective functions.</description><content:encoded><![CDATA[<h2 id="understanding-generative-models">Understanding Generative Models</h2>
<p>Modern generative AI is dominated by diffusion models and autoregressive transformers. The adversarial training dynamics and objective functions introduced by <a href="https://arxiv.org/abs/1406.2661">Generative Adversarial Networks</a> (GANs) still inform how the field thinks about loss-function design and training stability today. Before diving into GANs, let&rsquo;s establish what we&rsquo;re trying to accomplish with generative models.</p>
<p><strong>The core goal</strong>: Create a system that can generate new, realistic data that appears to come from the same distribution as our training data.</p>
<p>Think of having a model that can create images, text, or audio that are difficult to distinguish from human-created content. This is what generative modeling aims to achieve.</p>
<h3 id="the-mathematical-foundation">The Mathematical Foundation</h3>
<p>Generative models aim to estimate the probability distribution of real data. If we have parameters $\theta$, we want to find the optimal $\theta^*$ that maximizes the likelihood of observing our real samples:</p>
<p>$$
\theta^* = \arg\max_\theta \prod_{i=1}^{n} p_\theta(x_i)
$$</p>
<p>This is equivalent to minimizing a measure of discrepancy between our estimated distribution and the true data distribution. A common choice is the <a href="https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence">Kullback-Leibler Divergence</a>, which is not a true distance because it is asymmetric. In the limit of many samples, maximizing the log-likelihood is equivalent to minimizing the KL divergence from the data distribution to the model distribution.</p>
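<p>The link between maximum likelihood and KL divergence can be spelled out in a few lines. This is a sketch that writes $p_{\text{data}}$ for the true data distribution and assumes the samples $x_i$ are drawn independently from it. Taking logs turns the product into a sum without moving the maximizer:</p>
<p>$$
\theta^* = \arg\max_\theta \frac{1}{n} \sum_{i=1}^{n} \log p_\theta(x_i) \approx \arg\max_\theta \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log p_\theta(x)]
$$</p>
<p>The KL divergence from the data distribution to the model splits into two terms, and only the second depends on $\theta$:</p>
<p>$$
D_{\text{KL}}(p_{\text{data}} \,\|\, p_\theta) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log p_{\text{data}}(x)] - \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log p_\theta(x)]
$$</p>
<p>Minimizing the divergence over $\theta$ therefore means maximizing the expected log-likelihood, which the sample average approximates for large $n$.</p>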
<h3 id="two-approaches-to-generative-modeling">Two Approaches to Generative Modeling</h3>
<h4 id="explicit-distribution-models">Explicit Distribution Models</h4>
<p>These models define an explicit probability distribution and refine it through training.</p>
<p><strong>Example</strong>: <a href="https://arxiv.org/abs/1606.05908">Variational Auto-Encoders</a> (VAEs) require:</p>
<ul>
<li>An explicitly assumed prior distribution</li>
<li>A likelihood distribution</li>
<li>A &ldquo;variational approximation&rdquo; to the intractable posterior, which yields a tractable lower bound on the likelihood to optimize</li>
</ul>
<h4 id="implicit-distribution-models">Implicit Distribution Models</h4>
<p>These models never write down an explicit density. Instead, they learn a procedure that produces samples from the learned distribution. GANs exemplify this implicit approach, learning distributions through adversarial competition.</p>
<figure class="post-figure center ">
    <img src="/img/gen_ai_types.webp"
         alt="Types of deep generative models showing taxonomy"
         title="Types of deep generative models showing taxonomy"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption"><strong>Taxonomy of Deep Generative Models</strong>: GANs fall into the implicit density category, learning distributions through adversarial training. <em>Source: NeurIPS 2016 tutorial on Generative Adversarial Networks</em></figcaption>
    
</figure>

<h2 id="the-gan-architecture-a-game-of-deception">The GAN Architecture: A Game of Deception</h2>
<p>Generative Adversarial Networks get their name from three key components:</p>
<ul>
<li><strong>Generative</strong>: They create new data</li>
<li><strong>Adversarial</strong>: Two networks compete against each other</li>
<li><strong>Networks</strong>: Built using neural networks</li>
</ul>
<p>The core innovation is the adversarial setup: two neural networks compete against each other, driving mutual improvement.</p>
<figure class="post-figure center ">
    <img src="/img/GAN-70.webp"
         alt="Diagram showing data flow through a GAN architecture"
         title="Diagram showing data flow through a GAN architecture"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption"><strong>GAN Data Flow</strong>: The generator creates fake samples from random noise, while the discriminator tries to distinguish real from fake data. This adversarial competition drives both networks to improve.</figcaption>
    
</figure>

<h3 id="the-generator-the-forger">The Generator: The Forger</h3>
<p><strong>Role</strong>: Create convincing fake data from random noise</p>
<p>The generator network $G$ learns a mapping function:
$$z \rightarrow G(z) \approx x_{\text{real}}$$</p>
<p>Where:</p>
<ul>
<li>$z$ is a random latent vector (the &ldquo;noise&rdquo;)</li>
<li>$G(z)$ is the generated sample</li>
<li>The goal is making $G(z)$ indistinguishable from real data</li>
</ul>
<p><strong>Key insight</strong>: The latent vector $z$ lives in a continuous space, so in a well-trained generator small changes in $z$ tend to produce smooth, meaningful changes in the generated output.</p>
<h3 id="the-discriminator-the-detective">The Discriminator: The Detective</h3>
<p><strong>Role</strong>: Distinguish between real and generated samples</p>
<p>The discriminator network $D$ outputs a probability:
$$D(x) = P(x \text{ is real})$$</p>
<ul>
<li>$D(x) \approx 1$ for real samples</li>
<li>$D(x) \approx 0$ for fake samples</li>
</ul>
<p>It functions as an &ldquo;authenticity detector&rdquo; that progressively improves.</p>
<h3 id="the-adversarial-competition">The Adversarial Competition</h3>
<p>This adversarial dynamic drives the training process. The generator and discriminator have <strong>directly opposing objectives</strong>:</p>
<table>
  <thead>
      <tr>
          <th>Generator Goal</th>
          <th>Discriminator Goal</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Fool the discriminator</td>
          <td>Correctly classify all samples</td>
      </tr>
      <tr>
          <td>Maximize $D(G(z))$</td>
          <td>Maximize $D(x_{\text{real}})$ and minimize $D(G(z))$</td>
      </tr>
      <tr>
          <td>&ldquo;Create convincing fakes&rdquo;</td>
          <td>&ldquo;Never be fooled&rdquo;</td>
      </tr>
  </tbody>
</table>
<p>This creates a dynamic where both networks continuously improve:</p>
<ul>
<li>Generator creates better fakes to fool the discriminator</li>
<li>Discriminator becomes better at detecting fakes</li>
<li>The cycle continues until equilibrium</li>
</ul>
<figure class="post-figure center ">
    <img src="/img/GAN-SUMMARY-50.webp"
         alt="Illustration of GAN training process showing adversarial competition"
         title="Illustration of GAN training process showing adversarial competition"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption"><strong>The Adversarial Training Process</strong>: Through competition, both networks improve. The generator learns to create increasingly realistic samples while the discriminator becomes more discerning.</figcaption>
    
</figure>

<h2 id="learning-through-metaphors">Learning Through Metaphors</h2>
<p>Relatable analogies often clarify complex concepts. Here are two metaphors that capture different aspects of how GANs work.</p>
<h3 id="the-art-forger-vs-critic">The Art Forger vs. Critic</h3>
<p><strong>Generator = Art Forger</strong><br>
<strong>Discriminator = Art Critic</strong></p>
<p>A criminal forger tries to create fake masterpieces, while an art critic must identify authentic works. Each interaction teaches both parties:</p>
<ul>
<li>The forger learns what makes art look authentic</li>
<li>The critic develops a keener eye for detecting fakes</li>
<li>Eventually, the forger becomes so skilled that even experts can&rsquo;t tell the difference</li>
</ul>
<p><em>This captures the adversarial nature and continuous improvement aspect of GANs.</em></p>
<h3 id="the-counterfeiter-vs-bank-teller">The Counterfeiter vs. Bank Teller</h3>
<p><strong>Generator = Counterfeiter</strong><br>
<strong>Discriminator = Bank Teller</strong></p>
<p>Day 1: The counterfeiter brings a crayon drawing of a dollar bill. Even a new teller spots the fake.</p>
<p>Day 100: The counterfeiter has learned better techniques. The teller has developed expertise in security features.</p>
<p>Day 1000: The fake money is so convincing that detecting it requires advanced equipment.</p>
<p><em>This illustrates the progressive improvement and escalating sophistication in both networks.</em></p>
<h2 id="the-mathematical-foundation-1">The Mathematical Foundation</h2>
<p>Now let&rsquo;s examine the mathematical framework that makes GANs work. The core of GAN training is solving a <strong>minimax optimization problem</strong>.</p>
<h3 id="the-minimax-objective">The Minimax Objective</h3>
<p>$$
\min_{G} \max_{D} V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]
$$</p>
<p><strong>Breaking this down:</strong></p>
<ul>
<li>$\mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)]$: The expected log-probability for real data.
<ul>
<li><strong>Discriminator&rsquo;s Goal</strong>: Maximize this term to correctly classify real samples.</li>
</ul>
</li>
<li>$\mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$: The expected log-probability for fake data being correctly identified as fake.
<ul>
<li><strong>Discriminator&rsquo;s Goal</strong>: Maximize this term.</li>
<li><strong>Generator&rsquo;s Goal</strong>: Minimize this term to fool the discriminator.</li>
</ul>
</li>
</ul>
<h3 id="why-minimax">Why &ldquo;Minimax&rdquo;?</h3>
<ul>
<li><strong>Discriminator ($D$)</strong>: Tries to <strong>maximize</strong> the objective → Better at distinguishing real from fake.</li>
<li><strong>Generator ($G$)</strong>: Tries to <strong>minimize</strong> the objective → Better at fooling the discriminator.</li>
</ul>
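<p>To make the two terms concrete, here is a minimal sketch that estimates $V(D, G)$ from a handful of discriminator outputs. The function name <code>value_function</code> and the sample values are assumptions chosen for illustration; they do not come from a trained model.</p>

```python
import math

def value_function(d_real, d_fake):
    """Monte Carlo estimate of V(D, G) from discriminator outputs in (0, 1)."""
    real_term = sum(math.log(d) for d in d_real) / len(d_real)
    fake_term = sum(math.log(1.0 - d) for d in d_fake) / len(d_fake)
    return real_term + fake_term

# Hypothetical discriminator outputs, chosen only for illustration.
confident = value_function(d_real=[0.9, 0.95, 0.85], d_fake=[0.1, 0.05, 0.2])
fooled = value_function(d_real=[0.5, 0.5, 0.5], d_fake=[0.5, 0.5, 0.5])

print(f"confident discriminator: V = {confident:.3f}")
print(f"fooled discriminator:    V = {fooled:.3f}")  # equals -log(4)
```

<p>A discriminator that separates the two sets well pushes $V$ toward 0 (its maximum), while a discriminator that outputs 0.5 everywhere gives $V = -\log 4$.</p>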
<h3 id="a-practical-challenge-vanishing-gradients">A Practical Challenge: Vanishing Gradients</h3>
<p>The minimax objective presents a practical problem early in training. When the generator is poor, the discriminator can easily distinguish real from fake samples with high confidence ($D(G(z)) \approx 0$). This causes $\log(1 - D(G(z)))$ to saturate and results in vanishing gradients for the generator, which effectively stalls learning.</p>
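<p><strong>The Solution</strong>: Instead of minimizing $\log(1 - D(G(z)))$, practitioners typically train the generator to <strong>maximize</strong> $\log D(G(z))$. This non-saturating heuristic keeps the same goal (push $D(G(z))$ toward 1) but provides much stronger gradients when the discriminator confidently rejects generated samples, which keeps early training from stalling.</p>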
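<p>The difference between the two generator losses is easy to check numerically. The sketch below assumes the discriminator ends in a sigmoid, so $D(G(z)) = \sigma(a)$ for some logit $a$, and compares the derivative of each loss with respect to that logit. The function names are illustrative.</p>

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def saturating_grad(a):
    """d/da of log(1 - sigmoid(a)): the original minimax generator loss."""
    return -sigmoid(a)

def non_saturating_grad(a):
    """d/da of -log(sigmoid(a)): the non-saturating generator loss."""
    return -(1.0 - sigmoid(a))

# a is the discriminator's logit on a generated sample; very negative
# values mean D(G(z)) is close to 0 (a confidently rejected fake).
for a in (-5.0, 0.0, 5.0):
    print(f"logit={a:+.0f}  D={sigmoid(a):.4f}  "
          f"saturating={saturating_grad(a):+.4f}  "
          f"non-saturating={non_saturating_grad(a):+.4f}")
```

<p>At a logit of -5 the saturating loss passes back a gradient of magnitude below 0.01, while the non-saturating loss passes back a gradient of magnitude close to 1. Both have the same sign, so they push the generator in the same direction.</p>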
<h3 id="the-training-process">The Training Process</h3>
<p>The beauty of GANs lies in their alternating optimization:</p>
<ol>
<li><strong>Fix $G$, train $D$</strong>: Update the discriminator toward optimality for the current generator (in practice, one or a few gradient steps)</li>
<li><strong>Fix $D$, train $G$</strong>: Improve the generator against the current discriminator</li>
<li><strong>Repeat</strong>: Continue, ideally until the two networks approach a Nash equilibrium</li>
</ol>
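<p>The alternating steps above can be sketched on a toy problem. Everything here is an assumption made for illustration: the real data is $\mathcal{N}(4, 1)$, the generator is a single shift parameter ($G(z) = z + \theta$), the discriminator is a logistic regression, and the gradients are written out by hand. It takes one discriminator step per generator step and uses the non-saturating generator objective.</p>

```python
import numpy as np

def sigmoid(t):
    # Clip the logit so exp() cannot overflow.
    return 1.0 / (1.0 + np.exp(-np.clip(t, -30.0, 30.0)))

def train_toy_gan(steps=200, batch=64, lr_d=0.05, lr_g=0.05, seed=0):
    rng = np.random.default_rng(seed)
    theta = 0.0        # generator parameter: G(z) = z + theta
    w, b = 0.0, 0.0    # discriminator parameters: D(x) = sigmoid(w * x + b)
    for _ in range(steps):
        # Step 1: fix G, take one gradient-ascent step on V(D, G) for D.
        x_real = rng.normal(4.0, 1.0, size=batch)
        x_fake = rng.normal(0.0, 1.0, size=batch) + theta
        d_real, d_fake = sigmoid(w * x_real + b), sigmoid(w * x_fake + b)
        grad_w = np.mean((1.0 - d_real) * x_real) - np.mean(d_fake * x_fake)
        grad_b = np.mean(1.0 - d_real) - np.mean(d_fake)
        w, b = w + lr_d * grad_w, b + lr_d * grad_b

        # Step 2: fix D, take one ascent step on log D(G(z)) for G
        # (the non-saturating generator objective).
        x_fake = rng.normal(0.0, 1.0, size=batch) + theta
        d_fake = sigmoid(w * x_fake + b)
        theta = theta + lr_g * np.mean((1.0 - d_fake) * w)
    return theta, w, b

theta, w, b = train_toy_gan()
print(f"generator shift theta = {theta:.2f} (real data mean is 4.0)")
```

<p>With such a simple discriminator the parameters can oscillate instead of settling, which is a small-scale view of why GAN training is hard to stabilize.</p>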
<h3 id="theoretical-goal-nash-equilibrium">Theoretical Goal: Nash Equilibrium</h3>
<p>At the theoretical equilibrium, the discriminator outputs $D(x) = 0.5$ for every input, meaning it can&rsquo;t distinguish between real and generated data. This happens exactly when $p_{\text{generator}} = p_{\text{data}}$: the generator has learned the true data distribution. In practice, training is not guaranteed to reach this point.</p>
<h2 id="the-evolution-of-objective-functions">The Evolution of Objective Functions</h2>
<p>The objective function is the mathematical heart of any GAN. It defines how we measure the &ldquo;distance&rdquo; between our generated distribution and the real data distribution. This choice profoundly impacts:</p>
<ul>
<li><strong>Training stability</strong>: Some objectives lead to more stable convergence</li>
<li><strong>Sample quality</strong>: Different losses emphasize different aspects of realism</li>
<li><strong>Mode collapse</strong>: The tendency to generate limited variety</li>
<li><strong>Computational efficiency</strong>: Some objectives are faster to compute</li>
</ul>
<p>The original GAN uses Jensen-Shannon Divergence (JSD), but researchers have discovered many alternatives that address specific limitations. Let&rsquo;s explore this evolution.</p>
<h3 id="the-original-gan-jensen-shannon-divergence">The Original GAN: Jensen-Shannon Divergence</h3>
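<p>With an optimal discriminator, the original GAN&rsquo;s generator objective reduces (up to constants) to minimizing the Jensen-Shannon Divergence between the data distribution $P$ and the generated distribution $Q$:</p>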
<p>$$
\text{JSD}(P, Q) = \frac{1}{2} \text{KL}(P \| M) + \frac{1}{2} \text{KL}(Q \| M)
$$</p>
<p>Where $M = \frac{1}{2}(P + Q)$ is the average distribution, and $\text{KL}$ is the <a href="https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence">Kullback-Leibler Divergence</a>.</p>
<p><strong>Strengths</strong>: Solid theoretical foundation, introduced adversarial training<br>
<strong>Limitations</strong>: Can suffer from vanishing gradients and mode collapse</p>
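<p>A small numerical sketch shows where the vanishing-gradient limitation comes from. For discrete distributions with no overlapping support, the JSD equals $\log 2$ no matter how far apart they are, so it gives the generator no signal about which direction reduces the gap. The helper names below are illustrative.</p>

```python
import math

def kl(p, q):
    """KL(p || q) for discrete distributions; terms with p_i = 0 contribute 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def jsd(p, q):
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

identical = jsd([0.5, 0.5, 0.0, 0.0], [0.5, 0.5, 0.0, 0.0])
near = jsd([1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0])
far = jsd([1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0])
print(identical, near, far)  # near and far both equal log(2)
```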
<h3 id="wasserstein-gan-wgan">Wasserstein GAN (WGAN)</h3>
<p>The <a href="https://arxiv.org/abs/1701.07875">Wasserstein GAN</a> replaced Jensen-Shannon divergence with the Earth-Mover (Wasserstein) distance, which gives meaningful gradients even when the real and generated distributions do not overlap.</p>
<h4 id="understanding-earth-mover-distance">Understanding Earth-Mover Distance</h4>
<p>The Wasserstein distance, also known as Earth-Mover distance, has an intuitive interpretation:</p>
<blockquote>
<p><strong>Imagine two probability distributions as piles of dirt.</strong> The Earth-Mover distance measures the minimum cost to transform one pile into the other, where cost = mass × distance moved.</p></blockquote>
<p>Mathematically:</p>
<p>$$
W_p(\mu, \nu) = \left( \inf_{\gamma \in \Gamma(\mu, \nu)} \int_{M \times M} d(x, y)^p \, d\gamma(x, y) \right)^{1/p}
$$</p>
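<p>In one dimension with $p = 1$, the optimal transport cost has a simple closed form: it is the area between the two cumulative distribution functions. The sketch below uses that fact for histograms on a shared, evenly spaced grid. The function name and grid are assumptions for illustration, and the shortcut does not carry over to higher dimensions.</p>

```python
from itertools import accumulate

def wasserstein_1d(p, q, spacing=1.0):
    """W_1 between two histograms on the same evenly spaced 1D grid.

    In one dimension with p = 1, the optimal transport cost equals the
    area between the two cumulative distribution functions.
    """
    cdf_p, cdf_q = accumulate(p), accumulate(q)
    return spacing * sum(abs(a - b) for a, b in zip(cdf_p, cdf_q))

pile = [1.0, 0.0, 0.0, 0.0]
print(wasserstein_1d(pile, [0.0, 1.0, 0.0, 0.0]))  # 1.0: move all mass one bin
print(wasserstein_1d(pile, [0.0, 0.0, 0.0, 1.0]))  # 3.0: move all mass three bins
```

<p>Unlike a divergence that saturates on non-overlapping supports, the distance keeps growing as the piles move apart, which is what gives the generator a usable gradient.</p>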
<h4 id="why-earth-mover-distance-matters">Why Earth-Mover Distance Matters</h4>
<table>
  <thead>
      <tr>
          <th>Jensen-Shannon Divergence</th>
          <th>Earth-Mover Distance</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Can be discontinuous</td>
          <td><strong>Always continuous</strong></td>
      </tr>
      <tr>
          <td>May have vanishing gradients</td>
          <td><strong>Meaningful gradients everywhere</strong></td>
      </tr>
      <tr>
          <td>Limited convergence guarantees</td>
          <td><strong>Broader convergence properties</strong></td>
      </tr>
  </tbody>
</table>
<h4 id="wgan-implementation">WGAN Implementation</h4>
<p>Since the infimum over all couplings is intractable to compute directly, WGAN uses the <strong>Kantorovich-Rubinstein duality</strong>:</p>
<ol>
<li><strong>Train a critic function</strong> $f$ to approximate the Wasserstein distance</li>
<li><strong>Constrain the critic</strong> to be 1-Lipschitz (using weight clipping)</li>
<li><strong>Optimize the generator</strong> to minimize this distance</li>
</ol>
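<p>For reference, the Kantorovich-Rubinstein duality rewrites the $p = 1$ Wasserstein distance as a supremum over 1-Lipschitz functions, which is the quantity the critic approximates. Here $P_r$ and $P_g$ denote the real and generated distributions, following the WGAN paper&rsquo;s notation:</p>
<p>$$
W_1(P_r, P_g) = \sup_{\|f\|_L \le 1} \mathbb{E}_{x \sim P_r}[f(x)] - \mathbb{E}_{x \sim P_g}[f(x)]
$$</p>
<p>The critic is trained to maximize the difference of expectations, and the generator is trained to reduce it.</p>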















<figure class="post-figure center ">
    <img src="/img/wasserstein.webp"
         alt="WGAN training results showing stable convergence"
         title="WGAN training results showing stable convergence"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption"><strong>WGAN Results</strong>: Demonstrating improved training stability and meaningful loss curves. <em>Source: Wasserstein GAN paper</em></figcaption>
    
</figure>

<h4 id="key-wgan-benefits">Key WGAN Benefits</h4>
<p><strong>Meaningful loss function</strong>: Loss correlates with sample quality<br>
<strong>Improved stability</strong>: Less prone to mode collapse<br>
<strong>Theoretical guarantees</strong>: Solid mathematical foundation<br>
<strong>Better convergence</strong>: Works even when distributions don&rsquo;t overlap</p>
<h3 id="improved-wgan-solving-the-weight-clipping-problem">Improved WGAN: Solving the Weight Clipping Problem</h3>
<p><a href="https://arxiv.org/abs/1704.00028">Improved WGAN</a> (WGAN-GP) addresses a critical flaw in the original WGAN: <strong>weight clipping</strong>.</p>
<h4 id="the-problem-with-weight-clipping">The Problem with Weight Clipping</h4>
<p>Original WGAN clips weights to maintain the 1-Lipschitz constraint:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#75715e"># Problematic approach</span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">for</span> param <span style="color:#f92672">in</span> critic<span style="color:#f92672">.</span>parameters():
</span></span><span style="display:flex;"><span>    param<span style="color:#f92672">.</span>data<span style="color:#f92672">.</span>clamp_(<span style="color:#f92672">-</span><span style="color:#ae81ff">0.01</span>, <span style="color:#ae81ff">0.01</span>)
</span></span></code></pre></div><p><strong>Issues with clipping</strong>:</p>
<ul>
<li>Forces critic to use extremely simple functions</li>
<li>Pushes weights toward extreme values ($\pm c$)</li>
<li>Can lead to poor gradient flow</li>
<li>Capacity limitations hurt performance</li>
</ul>
<h4 id="the-gradient-penalty-solution">The Gradient Penalty Solution</h4>
<p>WGAN-GP introduces a <strong>gradient penalty term</strong> to constrain the critic:</p>
<p>$$
L = E_{\tilde{x} \sim P_g}[D(\tilde{x})] - E_{x \sim P_r}[D(x)] + \lambda E_{\hat{x}}[(||\nabla_{\hat{x}} D(\hat{x})||_2 - 1)^2]
$$</p>
<p>Where $\hat{x}$ are points sampled uniformly along straight lines between real and generated data points.</p>
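<p>Here is a minimal NumPy sketch of the penalty term. To stay self-contained it assumes a toy critic $D(x) = \tanh(w \cdot x)$ whose input gradient can be written analytically; in a real implementation the gradient would come from automatic differentiation. The weights, data, and function names are hypothetical, and $\lambda = 10$ follows the default reported in the WGAN-GP paper.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
w = np.array([1.5, -0.5])          # parameters of a toy critic (hypothetical)

def critic(x):
    return np.tanh(x @ w)           # D(x) = tanh(w . x), one score per row

def critic_grad(x):
    # Analytic gradient of tanh(w . x) with respect to x, one row per sample.
    return (1.0 - np.tanh(x @ w) ** 2)[:, None] * w

def wgan_gp_critic_loss(x_real, x_fake, lam=10.0):
    # Sample points uniformly along lines between real and generated pairs.
    eps = rng.uniform(0.0, 1.0, size=(len(x_real), 1))
    x_hat = eps * x_real + (1.0 - eps) * x_fake
    grad_norm = np.linalg.norm(critic_grad(x_hat), axis=1)
    penalty = lam * np.mean((grad_norm - 1.0) ** 2)
    return np.mean(critic(x_fake)) - np.mean(critic(x_real)) + penalty, penalty

x_real = rng.normal(2.0, 1.0, size=(8, 2))
x_fake = rng.normal(0.0, 1.0, size=(8, 2))
loss, penalty = wgan_gp_critic_loss(x_real, x_fake)
print(f"critic loss = {loss:.3f}, gradient penalty = {penalty:.3f}")
```

<p>The penalty is two-sided: it grows whenever the gradient norm at an interpolated point moves away from 1 in either direction.</p>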
<p><strong>Advantages</strong>:</p>
<ul>
<li>No capacity limitations</li>
<li>Better gradient flow</li>
<li>More stable training</li>
<li>Works across different architectures</li>
</ul>
<h3 id="lsgan-the-power-of-least-squares">LSGAN: The Power of Least Squares</h3>
<p><a href="https://arxiv.org/abs/1611.04076">Least Squares GAN</a> takes a different approach. It replaces the logarithmic loss with <strong>L2 (least squares) loss</strong>.</p>
<h4 id="motivation-beyond-binary-classification">Motivation: Beyond Binary Classification</h4>
<p>Traditional GANs use log loss, which focuses primarily on correct classification:</p>
<ul>
<li>Real sample correctly classified → minimal penalty</li>
<li>Fake sample correctly classified → minimal penalty</li>
<li>Distance from decision boundary ignored</li>
</ul>
<h4 id="l2-loss-distance-matters">L2 Loss: Distance Matters</h4>
<p>LSGAN uses L2 loss, which <strong>penalizes proportionally to distance</strong>:</p>
<p>$$
\min_D V_{LSGAN}(D) = \frac{1}{2}E_{x \sim p_{data}(x)}[(D(x) - b)^2] + \frac{1}{2}E_{z \sim p_z(z)}[(D(G(z)) - a)^2]
$$</p>
<p>$$
\min_G V_{LSGAN}(G) = \frac{1}{2}E_{z \sim p_z(z)}[(D(G(z)) - c)^2]
$$</p>
<p>Where typically: $a = 0$ (fake label), $b = c = 1$ (real label)</p>
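<p>The two objectives translate directly into code. This sketch computes both losses from lists of raw discriminator outputs using the labels $a = 0$, $b = c = 1$; the function names and sample outputs are assumptions for illustration.</p>

```python
def mean(values):
    return sum(values) / len(values)

def lsgan_losses(d_real, d_fake, a=0.0, b=1.0, c=1.0):
    """LSGAN discriminator and generator losses from discriminator outputs."""
    d_loss = (0.5 * mean([(d - b) ** 2 for d in d_real])
              + 0.5 * mean([(d - a) ** 2 for d in d_fake]))
    g_loss = 0.5 * mean([(d - c) ** 2 for d in d_fake])
    return d_loss, g_loss

# Hypothetical outputs. A fake scored at 0.4 is slightly off the generator's
# target c = 1; a fake scored at -2.0 is far off and is penalized
# quadratically more.
print(lsgan_losses(d_real=[1.0, 1.0], d_fake=[0.4, 0.4]))
print(lsgan_losses(d_real=[1.0, 1.0], d_fake=[-2.0, -2.0]))
```

<p>Note that the discriminator loss also grows in the second case: scoring fakes far below the target $a$ is penalized, which discourages arbitrarily confident outputs.</p>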
<h4 id="benefits-of-l2-loss">Benefits of L2 Loss</h4>
<table>
  <thead>
      <tr>
          <th>Log Loss</th>
          <th>L2 Loss</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Binary focus</td>
          <td><strong>Distance-aware</strong></td>
      </tr>
      <tr>
          <td>Can saturate</td>
          <td><strong>Informative gradients</strong></td>
      </tr>
      <tr>
          <td>Sharp decision boundary</td>
          <td><strong>Smooth decision regions</strong></td>
      </tr>
  </tbody>
</table>















<figure class="post-figure center ">
    <img src="/img/lsgan-result.webp"
         alt="LSGAN generated samples showing improved quality"
         title="LSGAN generated samples showing improved quality"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption"><strong>LSGAN Results</strong>: Demonstrating improved sample quality through distance-aware loss functions. <em>Source: LSGAN paper</em></figcaption>
    
</figure>

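<p><strong>Key insight</strong>: With labels chosen so that $b - c = 1$ and $b - a = 2$ (for example $a = -1$, $b = 1$, $c = 0$), the LSGAN objective minimizes the Pearson $\chi^2$ divergence, which tends to give a smoother optimization landscape than JSD.</p>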
<h3 id="relaxed-wasserstein-gan-rwgan">Relaxed Wasserstein GAN (RWGAN)</h3>
<p><a href="https://arxiv.org/abs/1705.07164">Relaxed WGAN</a> bridges the gap between WGAN and WGAN-GP, proposing a <strong>general framework</strong> for designing GAN objectives.</p>
<h4 id="key-innovations">Key Innovations</h4>
<p><strong>Asymmetric weight clamping</strong>: RWGAN introduces an asymmetric approach that provides better balance.</p>
<p><strong>Relaxed Wasserstein divergences</strong>: A generalized framework that extends the Wasserstein distance, enabling systematic design of new GAN variants while maintaining theoretical guarantees.</p>
<h4 id="benefits">Benefits</h4>
<ul>
<li>Better convergence properties than standard WGAN</li>
<li>Framework for designing new loss functions and GAN architectures</li>
<li>Competitive performance with other Wasserstein-based methods</li>
</ul>
<p><strong>Key insight</strong>: RWGAN parameterized with KL divergence shows excellent performance while maintaining the theoretical foundations that make Wasserstein GANs attractive.</p>
<h3 id="statistical-distance-approaches">Statistical Distance Approaches</h3>
<p>Several GAN variants focus on minimizing specific statistical distances between distributions.</p>
<h4 id="mcgan-mean-and-covariance-matching">McGAN: Mean and Covariance Matching</h4>
<p><a href="https://arxiv.org/abs/1702.08398">McGAN</a> belongs to the Integral Probability Metric (IPM) family, using <strong>statistical moments</strong> as the distance measure.</p>
<p><strong>Approach</strong>: Match first and second-order statistics:</p>
<ul>
<li><strong>Mean matching</strong>: Align distribution centers</li>
<li><strong>Covariance matching</strong>: Align distribution shapes</li>
</ul>
<p>Moment-matching objectives like this are conceptually related to settings where aligning statistical moments matters, such as matching a generated distribution to a target physical distribution (e.g., molecular conformations). McGAN itself, however, was introduced and demonstrated as an IPM method for image generation.</p>
<p><strong>Limitation</strong>: Relies on weight clipping like original WGAN.</p>
<h4 id="gmmn-maximum-mean-discrepancy">GMMN: Maximum Mean Discrepancy</h4>
<p><a href="https://arxiv.org/abs/1502.02761">Generative Moment Matching Networks</a> (GMMN) eliminate the discriminator entirely, directly minimizing <strong>Maximum Mean Discrepancy (MMD)</strong>.</p>
<p><strong>MMD Intuition</strong>: Compare distributions by their means in a high-dimensional feature space:</p>
<p>$$
\text{MMD}^2(X, Y) = ||E[\phi(x)] - E[\phi(y)]||^2
$$</p>
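<p>In practice the feature map $\phi$ is never computed explicitly. Expanding the squared norm leaves only inner products, which a kernel evaluates directly. The sketch below uses a Gaussian kernel and the simple biased estimator; the bandwidth, sample sizes, and function names are assumptions for illustration.</p>

```python
import numpy as np

def gaussian_kernel(x, y, sigma=1.0):
    # Pairwise squared distances between rows of x and rows of y.
    sq_dists = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq_dists / (2.0 * sigma ** 2))

def mmd2(x, y, sigma=1.0):
    """Biased (V-statistic) estimate of MMD^2 using the kernel trick."""
    return (gaussian_kernel(x, x, sigma).mean()
            + gaussian_kernel(y, y, sigma).mean()
            - 2.0 * gaussian_kernel(x, y, sigma).mean())

rng = np.random.default_rng(0)
x = rng.normal(0.0, 1.0, size=(200, 2))
same = rng.normal(0.0, 1.0, size=(200, 2))
shifted = rng.normal(3.0, 1.0, size=(200, 2))
print(f"same distribution:    {mmd2(x, same):.4f}")
print(f"shifted distribution: {mmd2(x, shifted):.4f}")
```

<p>Each estimate builds full pairwise kernel matrices, so the cost grows quadratically with the batch size.</p>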
<p><strong>Benefits</strong>:</p>
<ul>
<li>Simple, discriminator-free training</li>
<li>Theoretical guarantees</li>
<li>Can incorporate autoencoders for better MMD estimation</li>
</ul>
<p><strong>Drawbacks</strong>:</p>
<ul>
<li>Computationally expensive</li>
<li>Often weaker empirical results</li>
</ul>
<h4 id="mmd-gan-learning-better-kernels">MMD GAN: Learning Better Kernels</h4>
<p><a href="https://arxiv.org/abs/1705.08584">MMD GAN</a> improves on GMMN by <strong>learning the kernel</strong> adversarially instead of relying on a fixed Gaussian kernel.</p>
<p><strong>Innovation</strong>: Combine GAN adversarial training with MMD objective for the best of both worlds.</p>
<h3 id="different-distance-metrics">Different Distance Metrics</h3>
<h4 id="cramer-gan-addressing-sample-bias">Cramer GAN: Addressing Sample Bias</h4>
<p><a href="https://arxiv.org/abs/1705.10743">Cramer GAN</a> identifies a critical issue with WGAN: <strong>biased sample gradients</strong>.</p>
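<p><strong>The Problem</strong>: The Cramer GAN paper lists three properties a distance should have for training with sample-based gradients. WGAN&rsquo;s Wasserstein distance satisfies the first two but not the third (the annotations below refer to the Wasserstein distance):</p>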
<ol>
<li><strong>Sum invariance</strong> (satisfied)</li>
<li><strong>Scale sensitivity</strong> (satisfied)</li>
<li><strong>Unbiased sample gradients</strong> (not satisfied)</li>
</ol>
<p><strong>The Solution</strong>: Use the <strong>Cramer distance</strong>, which satisfies all three properties:</p>
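<p>$$
l_2^2(\mu, \nu) = \int_{-\infty}^{\infty} (F_\mu(x) - F_\nu(x))^2 \, dx
$$</p>
<p>Here $F_\mu$ and $F_\nu$ are the cumulative distribution functions of two univariate distributions. For multivariate data, Cramer GAN uses the energy distance, which generalizes this quantity.</p>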
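<p>Treating the univariate Cramer distance as the integrated squared difference between cumulative distribution functions, a minimal sketch for histograms on a shared, evenly spaced grid looks like this. The function name and example histograms are assumptions for illustration.</p>

```python
from itertools import accumulate

def cramer_distance_sq(p, q, spacing=1.0):
    """Squared Cramer distance between histograms on a shared 1D grid."""
    cdf_p, cdf_q = accumulate(p), accumulate(q)
    return spacing * sum((a - b) ** 2 for a, b in zip(cdf_p, cdf_q))

print(cramer_distance_sq([1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0]))  # 3.0
print(cramer_distance_sq([0.5, 0.5, 0.0, 0.0], [0.0, 0.5, 0.5, 0.0]))  # 0.5
```

<p>Like the Earth-Mover distance, it grows as mass moves farther away, so it stays informative for distributions with disjoint supports.</p>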
<p><strong>Benefit</strong>: More reliable gradients lead to better training dynamics.</p>
<h4 id="fisher-gan-chi-square-distance">Fisher GAN: Chi-Square Distance</h4>
<p><a href="https://arxiv.org/abs/1705.09675">Fisher GAN</a> uses a <strong>data-dependent constraint</strong> on the critic&rsquo;s second-order moments (variance).</p>
<p><strong>Key Innovation</strong>: The constraint naturally bounds the critic without manual techniques:</p>
<ul>
<li>No weight clipping needed</li>
<li>No gradient penalties required</li>
<li>Constraint emerges from the objective itself</li>
</ul>
<p><strong>Distance</strong>: Approximates the <strong>Chi-square distance</strong> as critic capacity increases:</p>
<p>$$
\chi^2(P, Q) = \int \frac{(P(x) - Q(x))^2}{Q(x)} dx
$$</p>
<p>The Fisher IPM can be interpreted as a Mahalanobis-type distance between the mean feature embeddings of the two distributions, so it accounts for how the critic&rsquo;s features vary and correlate. The second-moment constraint keeps the critic bounded, and as the critic&rsquo;s capacity increases, the Fisher IPM approaches the Chi-square distance.</p>
<p><strong>Benefits</strong>:</p>
<ul>
<li>Efficient computation</li>
<li>Training stability</li>
<li>Unconstrained critic capacity</li>
</ul>
<h3 id="beyond-traditional-gans-alternative-approaches">Beyond Traditional GANs: Alternative Approaches</h3>
<p>The following variants explore fundamentally different architectures and training paradigms.</p>
<h4 id="ebgan-energy-based-discrimination">EBGAN: Energy-Based Discrimination</h4>
<p><a href="https://arxiv.org/abs/1609.03126">Energy-Based GAN</a> replaces the discriminator with an <strong>autoencoder</strong>.</p>
<p><strong>Key insight</strong>: Use reconstruction error as the discrimination signal:</p>
<ul>
<li>Good data → Low reconstruction error</li>
<li>Poor data → High reconstruction error</li>
</ul>
<p><strong>Architecture</strong>:</p>
<ol>
<li>Train autoencoder on real data</li>
<li>Generator creates samples</li>
<li>Poor generated samples have high reconstruction loss</li>
<li>This loss drives generator improvement</li>
</ol>
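<p>The steps above can be sketched with the margin-based losses from the EBGAN paper, where the discriminator&rsquo;s output is the autoencoder&rsquo;s reconstruction error (the &ldquo;energy&rdquo;). The function name, the margin value, and the sample energies below are assumptions for illustration.</p>

```python
def ebgan_losses(energy_real, energy_fake, margin=1.0):
    """EBGAN-style losses from autoencoder reconstruction errors (energies)."""
    # Discriminator: lower the energy of real data, and push the energy of
    # generated data up, but only until it reaches the margin.
    d_loss = energy_real + max(0.0, margin - energy_fake)
    # Generator: produce samples the autoencoder reconstructs well.
    g_loss = energy_fake
    return d_loss, g_loss

print(ebgan_losses(energy_real=0.1, energy_fake=0.3))  # fake below margin
print(ebgan_losses(energy_real=0.1, energy_fake=2.0))  # fake beyond margin
```

<p>Once a generated sample&rsquo;s energy exceeds the margin, the hinge term is zero and the discriminator stops spending effort on it.</p>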
<p><strong>Benefits</strong>:</p>
<ul>
<li>Fast and stable training</li>
<li>Robust to hyperparameter changes</li>
<li>No need to balance discriminator/generator</li>
</ul>
<h4 id="began-boundary-equilibrium">BEGAN: Boundary Equilibrium</h4>
<p><a href="https://arxiv.org/abs/1703.10717">BEGAN</a> combines EBGAN&rsquo;s autoencoder approach with WGAN-style loss functions.</p>
<p><strong>Innovation</strong>: Dynamic equilibrium parameter $k_t$ that balances:</p>
<ul>
<li>Real data reconstruction quality</li>
<li>Generated data reconstruction quality</li>
</ul>
<p><strong>Equilibrium equation</strong>:</p>
<p>$$
L_D = L(x) - k_t L(G(z))
$$</p>
<p>$$
k_{t+1} = k_t + \lambda(\gamma L(x) - L(G(z)))
$$</p>
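<p>In these equations $L(\cdot)$ is the autoencoder reconstruction loss, $\gamma$ sets the target ratio between generated and real reconstruction losses, and $\lambda$ is the step size for $k_t$. The sketch below applies the update with the two losses held fixed, which is an assumption made purely to show how $k_t$ drifts. The default values and the clamp to $[0, 1]$ follow my reading of the BEGAN paper and should be treated as illustrative.</p>

```python
def began_step(loss_real, loss_fake, k, gamma=0.5, lambda_k=0.001):
    """One BEGAN bookkeeping step given the two reconstruction losses."""
    d_loss = loss_real - k * loss_fake          # L_D = L(x) - k_t * L(G(z))
    g_loss = loss_fake                          # L_G = L(G(z))
    k_next = k + lambda_k * (gamma * loss_real - loss_fake)
    k_next = min(max(k_next, 0.0), 1.0)         # keep k_t inside [0, 1]
    return d_loss, g_loss, k_next

k = 0.0
# Hypothetical losses held fixed to show how k_t drifts; in real training
# both losses change every step.
for _ in range(1000):
    _, _, k = began_step(loss_real=0.20, loss_fake=0.05, k=k)
print(f"k after 1000 steps: {k:.3f}")
```

<p>When generated samples are reconstructed too easily relative to the target ratio, $k_t$ rises and the discriminator puts more weight on pushing up their reconstruction loss.</p>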
<h4 id="magan-adaptive-margins">MAGAN: Adaptive Margins</h4>
<p><a href="https://arxiv.org/abs/1704.03817">MAGAN</a> improves EBGAN by making the margin in the hinge loss <strong>adaptive over time</strong>.</p>
<p><strong>Concept</strong>: Start with a large margin, gradually reduce it as training progresses:</p>
<ul>
<li>Early training: Focus on major differences</li>
<li>Later training: Fine-tune subtle details</li>
</ul>
<p><strong>Result</strong>: Better sample quality and training stability.</p>
<h2 id="summary-the-evolution-of-gan-objectives">Summary: The Evolution of GAN Objectives</h2>
<p>The evolution of GAN objective functions reflects the field&rsquo;s progression toward more stable and theoretically grounded training procedures. Each variant addresses specific limitations in earlier approaches.</p>
<h3 id="complete-reference-table">Complete Reference Table</h3>
<table>
  <thead>
      <tr>
          <th><strong>GAN Variant</strong></th>
          <th><strong>Key Innovation</strong></th>
          <th><strong>Main Benefit</strong></th>
          <th><strong>Limitation</strong></th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Original GAN</strong></td>
          <td>Jensen-Shannon divergence</td>
          <td>Foundation of adversarial training</td>
          <td>Vanishing gradients, mode collapse</td>
      </tr>
      <tr>
          <td><strong>WGAN</strong></td>
          <td>Earth-Mover distance</td>
          <td>Meaningful loss, better stability</td>
          <td>Weight clipping issues</td>
      </tr>
      <tr>
          <td><strong>WGAN-GP</strong></td>
          <td>Gradient penalty</td>
          <td>Solves weight clipping problems</td>
          <td>Additional hyperparameter tuning</td>
      </tr>
      <tr>
          <td><strong>LSGAN</strong></td>
          <td>Least squares loss</td>
          <td>Better gradients, less saturation</td>
          <td>May converge to non-optimal points</td>
      </tr>
      <tr>
          <td><strong>RWGAN</strong></td>
          <td>Relaxed Wasserstein framework</td>
          <td>General framework for new designs</td>
          <td>Complex theoretical setup</td>
      </tr>
      <tr>
          <td><strong>McGAN</strong></td>
          <td>Mean/covariance matching</td>
          <td>Simple statistical alignment</td>
          <td>Limited by weight clipping</td>
      </tr>
      <tr>
          <td><strong>GMMN</strong></td>
          <td>Maximum mean discrepancy</td>
          <td>No discriminator needed</td>
          <td>Computationally expensive</td>
      </tr>
      <tr>
          <td><strong>MMD GAN</strong></td>
          <td>Adversarial kernels for MMD</td>
          <td>Improved GMMN performance</td>
          <td>Still computationally heavy</td>
      </tr>
      <tr>
          <td><strong>Cramer GAN</strong></td>
          <td>Cramer distance</td>
          <td>Unbiased sample gradients</td>
          <td>Complex implementation</td>
      </tr>
      <tr>
          <td><strong>Fisher GAN</strong></td>
          <td>Chi-square distance</td>
          <td>Self-constraining critic</td>
          <td>Limited empirical validation</td>
      </tr>
      <tr>
          <td><strong>EBGAN</strong></td>
          <td>Autoencoder discriminator</td>
          <td>Fast, stable training</td>
          <td>Requires careful regularization</td>
      </tr>
      <tr>
          <td><strong>BEGAN</strong></td>
          <td>Boundary equilibrium</td>
          <td>Dynamic training balance</td>
          <td>Additional equilibrium parameter</td>
      </tr>
      <tr>
          <td><strong>MAGAN</strong></td>
          <td>Adaptive margin</td>
          <td>Progressive refinement</td>
          <td>Margin scheduling complexity</td>
      </tr>
  </tbody>
</table>
<h3 id="practical-recommendations">Practical Recommendations</h3>
<p>For practitioners, the choice depends on specific requirements and engineering tradeoffs:</p>
<ul>
<li><strong>WGAN-GP</strong>: A strong balance of stability and performance for many applications, though results can be sensitive to the gradient penalty coefficient $\lambda$.</li>
<li><strong>LSGAN</strong>: Simpler implementation with good empirical results.</li>
<li><strong>EBGAN</strong>: Fast experimentation and prototyping.</li>
<li><strong>Original GAN</strong>: Educational purposes and understanding fundamentals.</li>
</ul>
<p><strong>Real-World Impact:</strong> In my work training VLMs on terabyte-scale multimodal data and forecasting chaotic physical systems, these foundational dynamics still matter. Most generation today runs on diffusion models or autoregressive transformers, but the loss-design and training-stability lessons that came out of GAN research carry over. The choice of objective function shapes generation quality, training stability, and compute cost.</p>
<hr>
<p><strong>Acknowledgments</strong>: This post was inspired by the excellent survey &ldquo;<a href="https://arxiv.org/abs/1711.05914">How Generative Adversarial Networks and Their Variants Work: An Overview of GAN</a>&rdquo;.</p>
]]></content:encoded></item><item><title>Word Embeddings in NLP: An Introduction</title><link>https://hunterheidenreich.com/posts/intro-to-word-embeddings/</link><pubDate>Sun, 05 Aug 2018 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/intro-to-word-embeddings/</guid><description>Learn about word embeddings in NLP: from basic one-hot encoding to contextual models like ELMo. Guide with examples.</description><content:encoded><![CDATA[<h2 id="understanding-word-embeddings">Understanding Word Embeddings</h2>
<p>A word embedding maps words to real-valued vectors:</p>
<p>$$
\text{word} \rightarrow \mathbb{R}^n
$$</p>
<p>where $n$ represents the dimensionality of the embedding space.</p>
<p>The goal is simple: position semantically similar words close together in vector space. This dense representation typically uses hundreds of dimensions, a massive reduction from the vocabulary-sized vectors (often tens of thousands to millions of dimensions) required by one-hot encoding.</p>
<p>Word embeddings are grounded in <a href="https://en.wikipedia.org/wiki/Distributional_semantics">Zellig Harris&rsquo; distributional hypothesis</a>: words appearing in similar contexts tend to have similar meanings. This forms the foundation of distributional semantics.</p>















<figure class="post-figure center ">
    <img src="/img/distributional_semantics-50.webp"
         alt="Distributional semantics visualization"
         title="Distributional semantics visualization"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Words embedded in three-dimensional space, organized by semantic similarity</figcaption>
    
</figure>

<p>Different embedding algorithms capture various aspects of this distributional principle. This post explores the main methods for creating word embeddings and their applications in natural language processing.</p>
<p>While modern foundation models and large Vision-Language Models rely on subword tokenizers (like BPE) and Transformer embedding layers, the goal is the same: mapping discrete text to a continuous vector space where math can capture meaning. These foundational techniques build the intuition for the embedding layers in today&rsquo;s models.</p>
<h2 id="why-word-embeddings-matter-in-nlp">Why Word Embeddings Matter in NLP</h2>
<p>Computers require numerical representations to apply machine learning algorithms to text. Word embeddings bridge this gap by converting text into dense vectors that preserve semantic and syntactic relationships.</p>
<p><strong>Key advantages:</strong></p>
<ol>
<li><strong>Dense representation</strong>: Hundreds of dimensions provide a compact alternative to vocabulary-sized sparse vectors.</li>
<li><strong>Semantic preservation</strong>: Similar words cluster together in vector space.</li>
<li><strong>Mathematical operations</strong>: Enable analogical reasoning ($\text{king} - \text{man} + \text{woman} \approx \text{queen}$).</li>
<li><strong>Transfer learning</strong>: Pre-trained embeddings work across multiple tasks and domains.</li>
</ol>
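<p>The analogy arithmetic in the list above can be demonstrated with hand-built vectors. The two axes (royalty and gender) and every value below are made up for illustration; real embeddings are learned from data and the relationship only holds approximately.</p>

```python
import math

# Hand-built 2D "embeddings" (axes: royalty, gender). Purely illustrative;
# real embeddings are learned and have hundreds of dimensions.
vectors = {
    "king":   [1.0,  1.0],
    "queen":  [1.0, -1.0],
    "man":    [0.0,  1.0],
    "woman":  [0.0, -1.0],
    "prince": [0.9,  0.9],
    "girl":   [0.0, -0.9],
}

def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.hypot(*u) * math.hypot(*v))

target = [k - m + w for k, m, w in
          zip(vectors["king"], vectors["man"], vectors["woman"])]
# As is common for analogy queries, exclude the three input words.
candidates = {word: vec for word, vec in vectors.items()
              if word not in {"king", "man", "woman"}}
best = max(candidates, key=lambda word: cosine(target, candidates[word]))
print(best)  # queen
```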
<p>Modern deep learning architectures leverage these properties extensively. The development of universal, pre-trained embeddings was a significant step forward. We can use versatile embeddings that generalize across applications, eliminating the need to train task-specific representations from scratch.</p>
<h2 id="word-embedding-approaches">Word Embedding Approaches</h2>
<h3 id="one-hot-encoding-and-count-vectorization">One-Hot Encoding and Count Vectorization</h3>
<p>One-hot encoding represents the simplest approach to word vectorization. Each word gets a unique dimension in a vocabulary-sized vector, marked with 1 for presence and 0 elsewhere. Count vectorization extends this by counting the occurrences of each word in a document.</p>















<figure class="post-figure center ">
    <img src="/img/word_vector_onehot-50.webp"
         alt="One-hot encoding visualization"
         title="One-hot encoding visualization"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">One-hot encoding creates sparse vectors with single active dimensions</figcaption>
    
</figure>

<p><strong>Characteristics:</strong></p>
<ul>
<li><strong>High dimensionality</strong>: Vector length equals vocabulary size.</li>
<li><strong>Extreme sparsity</strong>: Most dimensions contain zeros.</li>
<li><strong>No relationships</strong>: Treats all words as equally distant.</li>
<li><strong>Simplicity</strong>: Easy to implement and understand.</li>
</ul>
<p>While lacking semantic information, count vectorization serves as a foundation for more complex methods. Let&rsquo;s look at a practical implementation using scikit-learn&rsquo;s <code>CountVectorizer</code>.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">from</span> sklearn.feature_extraction.text <span style="color:#f92672">import</span> CountVectorizer
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Initialize the vectorizer</span>
</span></span><span style="display:flex;"><span>vectorizer <span style="color:#f92672">=</span> CountVectorizer()
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Sample text for demonstration</span>
</span></span><span style="display:flex;"><span>sample_text <span style="color:#f92672">=</span> [<span style="color:#e6db74">&#34;One of the most basic ways we can numerically represent words &#34;</span>
</span></span><span style="display:flex;"><span>               <span style="color:#e6db74">&#34;is through the one-hot encoding method (also sometimes called &#34;</span>
</span></span><span style="display:flex;"><span>               <span style="color:#e6db74">&#34;count vectorizing).&#34;</span>]
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Fit the vectorizer to our text data</span>
</span></span><span style="display:flex;"><span>vectorizer<span style="color:#f92672">.</span>fit(sample_text)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Examine the vocabulary and word indices</span>
</span></span><span style="display:flex;"><span>print(<span style="color:#e6db74">&#39;Vocabulary:&#39;</span>)
</span></span><span style="display:flex;"><span>print(vectorizer<span style="color:#f92672">.</span>vocabulary_)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Transform text to vectors</span>
</span></span><span style="display:flex;"><span>vector <span style="color:#f92672">=</span> vectorizer<span style="color:#f92672">.</span>transform(sample_text)
</span></span><span style="display:flex;"><span>print(<span style="color:#e6db74">&#39;Full vector:&#39;</span>)
</span></span><span style="display:flex;"><span>print(vector<span style="color:#f92672">.</span>toarray())
</span></span></code></pre></div><p>At scale, count vectorization introduces engineering challenges. With millions of documents, the vocabulary grows large, and the sparse matrices become expensive to store and compute on. In these scaling scenarios, practitioners often turn to the <strong>Hashing Trick</strong> (via <code>HashingVectorizer</code>) to bound the dimensionality, or they move entirely to the dense embeddings discussed later in this post.</p>
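<p>To show the idea behind the hashing trick, here is a pure-Python sketch. It is not scikit-learn&rsquo;s implementation: <code>HashingVectorizer</code> uses a different hash function and tokenizer, and the function name, the md5 hash, and the tiny <code>n_features</code> here are assumptions chosen for readability.</p>

```python
import hashlib
import re

def hashed_counts(text, n_features=16):
    """Map a document to a fixed-length count vector via the hashing trick."""
    vector = [0] * n_features
    for token in re.findall(r"\w+", text.lower()):
        # md5 is used only because it is deterministic across runs, unlike
        # Python's built-in hash() for strings.
        digest = hashlib.md5(token.encode("utf-8")).hexdigest()
        vector[int(digest, 16) % n_features] += 1
    return vector

doc = "the cat sat on the mat"
vec = hashed_counts(doc)
print(len(vec), sum(vec))  # 16 6
```

<p>The vector length stays fixed no matter how many distinct words appear, and no vocabulary has to be stored. The tradeoff is that different words can collide in the same slot, and the mapping cannot be inverted to recover the words.</p>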
<p>We can see count vectorization in action with a real dataset, building a simple text classifier for the <a href="https://www.kaggle.com/datasets/crawford/20-newsgroups">20 Newsgroups dataset</a>:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">from</span> sklearn.datasets <span style="color:#f92672">import</span> fetch_20newsgroups
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> sklearn.feature_extraction.text <span style="color:#f92672">import</span> CountVectorizer
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> sklearn.naive_bayes <span style="color:#f92672">import</span> MultinomialNB
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> sklearn <span style="color:#f92672">import</span> metrics
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Load train and test splits, removing metadata for a cleaner signal</span>
</span></span><span style="display:flex;"><span>newsgroups_train <span style="color:#f92672">=</span> fetch_20newsgroups(subset<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;train&#39;</span>,
</span></span><span style="display:flex;"><span>                                      remove<span style="color:#f92672">=</span>(<span style="color:#e6db74">&#39;headers&#39;</span>, <span style="color:#e6db74">&#39;footers&#39;</span>, <span style="color:#e6db74">&#39;quotes&#39;</span>))
</span></span><span style="display:flex;"><span>newsgroups_test <span style="color:#f92672">=</span> fetch_20newsgroups(subset<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;test&#39;</span>,
</span></span><span style="display:flex;"><span>                                     remove<span style="color:#f92672">=</span>(<span style="color:#e6db74">&#39;headers&#39;</span>, <span style="color:#e6db74">&#39;footers&#39;</span>, <span style="color:#e6db74">&#39;quotes&#39;</span>))
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Initialize and fit vectorizer on training data</span>
</span></span><span style="display:flex;"><span>vectorizer <span style="color:#f92672">=</span> CountVectorizer()
</span></span><span style="display:flex;"><span>X_train <span style="color:#f92672">=</span> vectorizer<span style="color:#f92672">.</span>fit_transform(newsgroups_train<span style="color:#f92672">.</span>data)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Build and train classifier</span>
</span></span><span style="display:flex;"><span>classifier <span style="color:#f92672">=</span> MultinomialNB(alpha<span style="color:#f92672">=</span><span style="color:#ae81ff">0.01</span>)
</span></span><span style="display:flex;"><span>classifier<span style="color:#f92672">.</span>fit(X_train, newsgroups_train<span style="color:#f92672">.</span>target)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Transform test data and make predictions</span>
</span></span><span style="display:flex;"><span>X_test <span style="color:#f92672">=</span> vectorizer<span style="color:#f92672">.</span>transform(newsgroups_test<span style="color:#f92672">.</span>data)
</span></span><span style="display:flex;"><span>y_pred <span style="color:#f92672">=</span> classifier<span style="color:#f92672">.</span>predict(X_test)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Evaluate performance</span>
</span></span><span style="display:flex;"><span>accuracy <span style="color:#f92672">=</span> metrics<span style="color:#f92672">.</span>accuracy_score(newsgroups_test<span style="color:#f92672">.</span>target, y_pred)
</span></span><span style="display:flex;"><span>print(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#39;Accuracy: </span><span style="color:#e6db74">{</span>accuracy<span style="color:#e6db74">:</span><span style="color:#e6db74">.3f</span><span style="color:#e6db74">}</span><span style="color:#e6db74">&#39;</span>)
</span></span></code></pre></div><p>This provides a solid baseline. To capture actual semantic meaning and reduce dimensionality, we must move beyond simple counting.</p>
<h3 id="tf-idf-term-frequency-inverse-document-frequency">TF-IDF (Term Frequency-Inverse Document Frequency)</h3>
<p><a href="https://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.TfidfVectorizer.html">TF-IDF</a> extends count vectorization by weighting each term according to how informative it is across a document collection. It combines two factors:</p>
<ul>
<li><strong>Term Frequency (TF)</strong>: How often a word appears in a document</li>
<li><strong>Inverse Document Frequency (IDF)</strong>: How rare a word is across all documents</li>
</ul>
<p>This weighting scheme reduces the impact of common words (like &ldquo;the&rdquo; or &ldquo;and&rdquo;) while emphasizing distinctive terms that appear frequently in specific documents but rarely elsewhere.</p>
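<p>A small worked example may make the weighting concrete. The sketch below uses the textbook form, term frequency times $\log(N / \text{df})$, on three made-up sentences; the corpus and the helper name <code>tf_idf</code> are assumptions for illustration only.</p>

```python
import math
from collections import Counter

docs = [
    "the cat sat on the mat",
    "the dog sat on the log",
    "the cat chased the dog",
]
tokenized = [doc.split() for doc in docs]
n_docs = len(tokenized)

# Document frequency: in how many documents does each term appear?
doc_freq = Counter(term for tokens in tokenized for term in set(tokens))


def tf_idf(tokens):
    """Textbook TF-IDF: (count / doc length) * log(N / document frequency)."""
    counts = Counter(tokens)
    return {
        term: (count / len(tokens)) * math.log(n_docs / doc_freq[term])
        for term, count in counts.items()
    }


weights = tf_idf(tokenized[0])
for term, weight in sorted(weights.items(), key=lambda item: -item[1]):
    print(f"{term:>4}: {weight:.3f}")
```

<p>In the first document, &ldquo;the&rdquo; appears in every document and so receives a weight of zero, while &ldquo;mat&rdquo; appears nowhere else and receives the highest weight. Note that <code>TfidfVectorizer</code> applies a smoothed IDF and L2-normalizes each row by default, so its numbers will differ from this textbook variant.</p>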
<p><strong>Advantages:</strong></p>
<ul>
<li>Captures document-level importance</li>
<li>Reduces impact of stop words</li>
<li>Effective for information retrieval tasks</li>
</ul>
<p><strong>Limitations:</strong></p>
<ul>
<li>Still high-dimensional and sparse</li>
<li>No semantic relationships between terms</li>
<li>Context-independent representation</li>
</ul>
<h3 id="co-occurrence-matrices">Co-Occurrence Matrices</h3>
<p>Co-occurrence matrices capture word relationships by recording which terms appear together within defined contexts (sentences, paragraphs, or fixed windows). For a vocabulary of size $|V|$, the resulting matrix is $|V| \times |V|$, and each entry counts how often the corresponding pair of words co-occurs.</p>
<figure class="post-figure center ">
    <img src="/img/Word_co-occurrence_network_%28range_3_words%29_-_ENG-50.webp"
         alt="Co-occurrence network visualization"
         title="Co-occurrence network visualization"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Co-occurrence relationships within a three-word window</figcaption>
    
</figure>

<p><strong>Key properties:</strong></p>
<ul>
<li><strong>Global statistics</strong>: Captures corpus-wide word relationships</li>
<li><strong>Symmetric relationships</strong>: With a symmetric context window, the count for a pair (a, b) equals the count for (b, a)</li>
<li><strong>Extreme dimensionality</strong>: Vocabulary size squared creates storage challenges</li>
<li><strong>Sparse representation</strong>: Most word pairs never co-occur</li>
</ul>
<p>While computationally expensive to store and process, co-occurrence matrices form the foundation for advanced methods like GloVe that compress this information into dense representations.</p>
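<p>As a rough sketch of how such counts might be gathered with a fixed window, consider the following. The function name <code>cooccurrence_counts</code>, the window size, and the toy sentence are illustrative assumptions; storing only observed pairs in a <code>Counter</code> is one simple way to exploit the sparsity noted above.</p>

```python
from collections import Counter


def cooccurrence_counts(tokens, window=2):
    """Count symmetric co-occurrences within `window` tokens on either side."""
    counts = Counter()
    for i, word in enumerate(tokens):
        # Look only to the right, then record both orderings for symmetry
        for neighbor in tokens[i + 1 : i + 1 + window]:
            counts[(word, neighbor)] += 1
            counts[(neighbor, word)] += 1
    return counts


tokens = "the cat sat on the mat".split()
counts = cooccurrence_counts(tokens, window=2)
vocab = sorted(set(tokens))

# Symmetry: both orderings of a pair carry the same count
print(counts[("cat", "sat")], counts[("sat", "cat")])

# Fraction of the |V| x |V| matrix that is non-zero
print(len(counts) / len(vocab) ** 2)
```

<p>Even on a six-word sentence some cells stay empty. On a real corpus with a large vocabulary, the non-zero fraction is typically far smaller, which is why sparse storage matters.</p>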
<h2 id="neural-network-based-embeddings">Neural Network-Based Embeddings</h2>
<h3 id="neural-probabilistic-language-models">Neural Probabilistic Language Models</h3>
<p><a href="https://www.jmlr.org/papers/volume3/bengio03a/bengio03a.pdf">Neural probabilistic models</a> pioneered the use of neural networks for learning word embeddings. These models learn dense representations as a byproduct of language modeling, predicting the next word in a sequence.</p>
<figure class="post-figure center ">
    <img src="/img/bengio-npm-50.webp"
         alt="Neural probabilistic model diagram"
         title="Neural probabilistic model diagram"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Architecture of neural probabilistic language models</figcaption>
    
</figure>

<p><strong>Training process:</strong></p>
<ol>
<li>Initialize random dense embeddings for each vocabulary word</li>
<li>Feed the embeddings of the preceding words into a network that predicts the next word</li>
<li>Update embeddings through backpropagation based on prediction errors</li>
<li>Resulting embeddings capture patterns useful for the training task</li>
</ol>
<p>This approach demonstrated that task-specific embeddings could be learned jointly with model objectives, establishing the foundation for modern embedding methods.</p>
<h3 id="word2vec">Word2Vec</h3>
<p><a href="https://code.google.com/archive/p/word2vec/">Word2Vec</a> made word embeddings practical at scale by introducing efficient training algorithms for massive corpora. It popularized compelling vector arithmetic properties, enabling analogical reasoning like the famous &ldquo;$\text{king} - \text{man} + \text{woman} \approx \text{queen}$&rdquo; example (a vector-offset regularity first reported by Mikolov, Yih &amp; Zweig (2013) on recurrent-network language-model embeddings).</p>
<figure class="post-figure center ">
    <img src="/img/Word_vector_illustration.webp"
         alt="Word2Vec vector arithmetic visualization"
         title="Word2Vec vector arithmetic visualization"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Word2Vec demonstrates analogical relationships through vector arithmetic</figcaption>
    
</figure>

<p><strong>Two training architectures:</strong></p>
<h4 id="continuous-bag-of-words-cbow">Continuous Bag-of-Words (CBOW)</h4>
<p>Predicts target words from surrounding context words. Given a window of context words, the model learns to predict the central word.</p>
<h4 id="skip-gram">Skip-Gram</h4>
<p>Predicts context words from target words. Given a central word, the model learns to predict surrounding words within a defined window.</p>
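<p>The difference between the two architectures shows up in how training examples are built from a sentence. The sketch below only generates the examples and does not train a model; the helper names and the toy sentence are hypothetical.</p>

```python
def context_window(tokens, index, window):
    """Return the tokens within `window` positions of tokens[index]."""
    left = tokens[max(0, index - window) : index]
    right = tokens[index + 1 : index + 1 + window]
    return left + right


def skipgram_pairs(tokens, window=2):
    """(center, context) pairs: one training example per context word."""
    return [
        (center, context)
        for i, center in enumerate(tokens)
        for context in context_window(tokens, i, window)
    ]


def cbow_examples(tokens, window=2):
    """(context words, center) examples: one training example per position."""
    return [
        (context_window(tokens, i, window), center)
        for i, center in enumerate(tokens)
    ]


tokens = "the quick brown fox jumps".split()
print(skipgram_pairs(tokens, window=1))
print(cbow_examples(tokens, window=1))
```

<p>Skip-gram produces one example per (center, context) pair, so it sees more examples per sentence, while CBOW produces one example per position with the context words grouped together.</p>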
<p><strong>Key advantages:</strong></p>
<ul>
<li><strong>Computational efficiency</strong>: Much faster than neural probabilistic models</li>
<li><strong>Scalable training</strong>: Can process billion-word corpora effectively</li>
<li><strong>Quality embeddings</strong>: Captures semantic and syntactic relationships</li>
<li><strong>Flexible context</strong>: Window size controls topical vs. functional similarity</li>
</ul>
<p>The choice of window size significantly impacts learned relationships. Larger windows capture topical associations, while smaller windows focus on syntactic and functional similarities.</p>
<h3 id="glove-global-vectors">GloVe (Global Vectors)</h3>
<p><a href="https://nlp.stanford.edu/projects/glove/">GloVe</a> combines the best aspects of matrix factorization methods (which capture global corpus statistics) and local context window approaches like Word2Vec. Matrix factorization methods excel at global patterns but struggle with analogical reasoning, while Word2Vec captures local relationships but may miss global structure.</p>
<p><strong>Key innovation:</strong>
GloVe trains on a global word-context co-occurrence matrix, incorporating corpus-wide statistical information while maintaining the analogical reasoning capabilities that made Word2Vec successful.</p>
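<p>For reference, the weighted least-squares objective from the GloVe paper can be written as follows, where $X_{ij}$ is the co-occurrence count of words $i$ and $j$, $w_i$ and $\tilde{w}_j$ are the word and context vectors, $b_i$ and $\tilde{b}_j$ are bias terms, and $V$ is the vocabulary size:</p>
<p>$$J = \sum_{i,j=1}^{V} f(X_{ij}) \left( w_i^\top \tilde{w}_j + b_i + \tilde{b}_j - \log X_{ij} \right)^2$$</p>
<p>The weighting function $f$ keeps very frequent pairs from dominating the loss and gives zero weight to pairs that never co-occur:</p>
<p>$$f(x) = \begin{cases} (x / x_{\max})^{\alpha} &amp; \text{if } x &lt; x_{\max} \\ 1 &amp; \text{otherwise} \end{cases}$$</p>
<p>The paper uses $x_{\max} = 100$ and $\alpha = 3/4$. Because $f(0) = 0$, only the non-zero entries of the co-occurrence matrix contribute, so training cost scales with the number of observed pairs, not with $V^2$.</p>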
<p><strong>Advantages over Word2Vec:</strong></p>
<ul>
<li><strong>Global optimization</strong>: Leverages entire corpus statistics</li>
<li><strong>Competitive performance</strong>: The original paper reports gains over Word2Vec on word similarity and analogy tasks, though the margin depends on hyperparameters and training corpus</li>
<li><strong>Stable training</strong>: More consistent convergence due to global objective function</li>
</ul>
<p>The result is embeddings that capture both local syntactic patterns and global semantic relationships more effectively.</p>
<h2 id="contextual-embedding-methods">Subword and Hierarchical Embedding Methods</h2>
<h3 id="fasttext">FastText</h3>
<p><a href="https://github.com/facebookresearch/fastText">FastText</a> addresses a critical limitation of previous methods: handling out-of-vocabulary (OOV) words. By incorporating subword information, FastText can generate meaningful representations for previously unseen words.</p>
<p><strong>Subword approach:</strong></p>
<ul>
<li>Decomposes words into character n-grams (typically 3-6 characters)</li>
<li>Represents words as sums of their component n-grams</li>
<li>Trains using skip-gram objective with negative sampling</li>
</ul>
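<p>The decomposition step can be sketched as below. This is a simplified illustration under stated assumptions: <code>char_ngrams</code> is a hypothetical helper, and the real library additionally keeps the whole word as its own unit and hashes n-grams into a fixed number of buckets, both of which are omitted here.</p>

```python
def char_ngrams(word, min_n=3, max_n=6):
    """Character n-grams of `word`, wrapped in boundary markers."""
    wrapped = "<" + word + ">"  # angle brackets mark word start and end
    return [
        wrapped[i : i + n]
        for n in range(min_n, max_n + 1)
        for i in range(len(wrapped) - n + 1)
    ]


known = set(char_ngrams("navigate"))
unseen = set(char_ngrams("circumnavigate"))

# n-grams the unseen word shares with a word seen in training
shared = sorted(known.intersection(unseen))
print(len(shared), "shared n-grams, for example:", shared[:5])
```

<p>Because the unseen word&rsquo;s vector is a sum over its n-gram vectors, every shared n-gram contributes information learned from the known word.</p>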
<p><strong>Key advantages:</strong></p>
<ul>
<li><strong>OOV handling</strong>: Can embed unseen words using known subword components</li>
<li><strong>Morphological awareness</strong>: Captures relationships between related word forms</li>
<li><strong>Multilingual support</strong>: Facebook released pre-trained embeddings for 294 languages</li>
<li><strong>Robust performance</strong>: Particularly effective for morphologically rich languages</li>
</ul>
<p>For example, if the model knows &ldquo;navigate,&rdquo; it can produce a meaningful representation for &ldquo;circumnavigate&rdquo; by leveraging shared subword components, even if &ldquo;circumnavigate&rdquo; wasn&rsquo;t in the training data.</p>
<h3 id="poincaré-embeddings">Poincaré Embeddings</h3>
<p><a href="https://radimrehurek.com/gensim/models/poincare.html">Poincaré embeddings</a> introduce a novel approach by learning representations in hyperbolic space. This geometric innovation specifically targets hierarchical relationships in data.</p>
<p><strong>Hyperbolic geometry advantages:</strong></p>
<ul>
<li><strong>Natural hierarchy encoding</strong>: Distance represents similarity, while norm encodes hierarchical level</li>
<li><strong>Efficient representation</strong>: Requires fewer dimensions for hierarchical data</li>
<li><strong>Mathematical elegance</strong>: Leverages properties of hyperbolic space for embedding optimization</li>
</ul>
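<p>The distance function of the Poincaré ball model shows why norm can encode depth. The sketch below implements the standard formula $d(u, v) = \operatorname{arcosh}\left(1 + 2 \frac{\lVert u - v \rVert^2}{(1 - \lVert u \rVert^2)(1 - \lVert v \rVert^2)}\right)$; the three example points are made up for illustration and are not learned embeddings.</p>

```python
import math


def poincare_distance(u, v):
    """Hyperbolic distance between two points inside the unit ball."""
    sq_norm_u = sum(x * x for x in u)
    sq_norm_v = sum(x * x for x in v)
    sq_diff = sum((a - b) ** 2 for a, b in zip(u, v))
    return math.acosh(1 + 2 * sq_diff / ((1 - sq_norm_u) * (1 - sq_norm_v)))


root = [0.0, 0.0]    # at the origin: high in the hierarchy
leaf_a = [0.9, 0.0]  # near the boundary: deep in the hierarchy
leaf_b = [0.0, 0.9]

print(poincare_distance(root, leaf_a))
print(poincare_distance(leaf_a, leaf_b))
```

<p>Distances grow rapidly near the boundary, so the two leaves end up almost as far apart as the path through the root. This tree-like behavior is what lets a low-dimensional ball hold a large hierarchy.</p>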
<p><strong>Applications:</strong>
Particularly effective for data with inherent hierarchical structure, such as:</p>
<ul>
<li>WordNet taxonomies</li>
<li>Organizational charts</li>
<li>Computer network topologies</li>
<li>Knowledge graphs</li>
</ul>
<p>The <a href="https://arxiv.org/abs/1705.08039">original paper</a> reports that low-dimensional Poincaré embeddings reconstruct the WordNet noun hierarchy more accurately than Euclidean embeddings of much higher dimensionality.</p>
<h2 id="contextual-embeddings">Contextual Embeddings</h2>
<h3 id="elmo-embeddings-from-language-models">ELMo (Embeddings from Language Models)</h3>
<p><a href="https://github.com/allenai/allennlp-models">ELMo</a> marks a shift from static to contextual word representations. Instead of assigning each word a single fixed vector, ELMo computes a representation for every token from the full sentence in which it appears.</p>
<p><strong>Architecture:</strong></p>
<ul>
<li><strong>Bidirectional LSTM</strong>: Processes text in both forward and backward directions</li>
<li><strong>Character-level input</strong>: Handles OOV words and captures morphological patterns</li>
<li><strong>Multi-layer representations</strong>: Combines different abstraction levels</li>
</ul>
<p><strong>Layer specialization:</strong></p>
<ul>
<li><strong>Lower layers</strong>: Excel at syntactic tasks (POS tagging, parsing)</li>
<li><strong>Higher layers</strong>: Capture semantic relationships (word sense disambiguation)</li>
<li><strong>Combined layers</strong>: A learned, task-specific weighted sum of all layers typically works better than using the top layer alone</li>
</ul>
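<p>The layer combination can be sketched as a softmax-weighted sum scaled by a scalar, following the form described in the ELMo paper. Everything concrete below is an assumption for illustration: the function names, the three-dimensional toy vectors, and the hand-picked scores, which a downstream task would learn during training.</p>

```python
import math


def softmax(scores):
    """Normalize raw scores into positive weights that sum to one."""
    exps = [math.exp(s - max(scores)) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def mix_layers(layer_vectors, layer_scores, gamma=1.0):
    """Scalar-weighted sum of per-layer vectors for a single token."""
    weights = softmax(layer_scores)
    dim = len(layer_vectors[0])
    return [
        gamma * sum(w * layer[d] for w, layer in zip(weights, layer_vectors))
        for d in range(dim)
    ]


# Toy vectors for one token from three layers (made-up numbers)
layers = [
    [1.0, 0.0, 0.0],  # character-based token layer
    [0.0, 1.0, 0.0],  # first biLSTM layer
    [0.0, 0.0, 1.0],  # second biLSTM layer
]

# Equal scores give a plain average of the layers
print(mix_layers(layers, layer_scores=[0.0, 0.0, 0.0]))
# A higher score shifts weight toward that layer
print(mix_layers(layers, layer_scores=[2.0, 0.0, 0.0]))
```

<p>A syntactic task could learn to weight the lower layers more heavily, while a semantic task could favor the higher ones, without retraining the underlying language model.</p>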
<p><strong>Key innovation:</strong>
ELMo embeddings vary by context. The word &ldquo;bank&rdquo; receives different representations in &ldquo;river bank&rdquo; versus &ldquo;financial bank,&rdquo; addressing polysemy directly through contextual awareness.</p>
<p>Adding these context-sensitive representations to existing task models improved results across a range of NLP benchmarks, including question answering, textual entailment, and named entity recognition.</p>
<h3 id="probabilistic-fasttext">Probabilistic FastText</h3>
<p><a href="https://github.com/benathi/multisense-prob-fasttext">Probabilistic FastText</a> addresses polysemy (words with multiple meanings) through probabilistic modeling. Traditional embeddings conflate different word senses into single representations, limiting their precision.</p>
<p><strong>The polysemy problem:</strong>
Consider &ldquo;rock&rdquo; which can mean:</p>
<ul>
<li>Rock music (genre)</li>
<li>A stone (geological object)</li>
<li>Rocking motion (verb)</li>
</ul>
<p>Standard embeddings average these meanings, producing representations that may not capture any sense precisely.</p>
<p><strong>Probabilistic approach:</strong>
Probabilistic FastText represents words as Gaussian mixture models: probability distributions that can capture multiple distinct meanings as separate components.</p>
<p><strong>Advantages:</strong></p>
<ul>
<li><strong>Multi-sense representation</strong>: Each word sense gets its own distribution</li>
<li><strong>Context sensitivity</strong>: Can select appropriate sense based on usage context</li>
<li><strong>Uncertainty quantification</strong>: Probabilistic framework captures embedding confidence</li>
</ul>
<p>This approach provides a more nuanced treatment of lexical ambiguity, particularly valuable for words with distinct, context-dependent meanings.</p>
<h2 id="summary-and-future-directions">Summary and Future Directions</h2>
<p>Word embeddings have evolved from simple one-hot encodings to contextual representations that capture nuanced linguistic relationships. Each approach offers distinct advantages:</p>
<p><strong>Static embeddings</strong> (Word2Vec, GloVe, FastText) provide:</p>
<ul>
<li>Computational efficiency for large-scale applications</li>
<li>Pre-trained models available for numerous languages</li>
<li>Clear analogical reasoning capabilities</li>
<li>Good performance on many downstream tasks</li>
</ul>
<p><strong>Contextual embeddings</strong> (ELMo, BERT, GPT) offer:</p>
<ul>
<li>Dynamic representations based on sentence context</li>
<li>Better handling of polysemy and word sense disambiguation</li>
<li>Strong performance on complex NLP tasks</li>
<li>Ability to capture subtle contextual nuances</li>
</ul>
<p><strong>Choosing the right approach</strong> depends on:</p>
<ul>
<li><strong>Task requirements</strong>: Static embeddings for efficiency, contextual for accuracy</li>
<li><strong>Data availability</strong>: Pre-trained models vs. domain-specific training</li>
<li><strong>Computational constraints</strong>: Static embeddings require less processing power</li>
<li><strong>Language coverage</strong>: Consider availability of pre-trained models for target languages</li>
</ul>
<p>The field continues advancing toward more efficient contextual models, better multilingual representations, and embeddings that capture increasingly complex linguistic phenomena.</p>
<p>For a from-scratch Word2Vec implementation in PyTorch (Skip-gram and CBOW, with hierarchical softmax and negative sampling) that takes these concepts further, see the <a href="/projects/modern-word2vec/">PyTorch Word2Vec project</a>.</p>
]]></content:encoded></item></channel></rss>