<?xml version="1.0" encoding="utf-8" standalone="yes"?><rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom" xmlns:content="http://purl.org/rss/1.0/modules/content/"><channel><title>Optimization on Hunter Heidenreich | Senior AI Research Scientist</title><link>https://hunterheidenreich.com/tags/optimization/</link><description>Recent content in Optimization on Hunter Heidenreich | Senior AI Research Scientist</description><image><title>Hunter Heidenreich | Senior AI Research Scientist</title><url>https://hunterheidenreich.com/img/avatar.webp</url><link>https://hunterheidenreich.com/img/avatar.webp</link></image><generator>Hugo -- 0.147.7</generator><language>en-US</language><copyright>2026 Hunter Heidenreich</copyright><lastBuildDate>Sun, 31 May 2026 00:00:00 +0000</lastBuildDate><atom:link href="https://hunterheidenreich.com/tags/optimization/index.xml" rel="self" type="application/rss+xml"/><item><title>Graph Grammar and ILP for Carbon Fixation Pathways</title><link>https://hunterheidenreich.com/notes/biology/computational-biology/carbon-fixation-pathway-design/</link><pubDate>Sun, 12 Apr 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/biology/computational-biology/carbon-fixation-pathway-design/</guid><description>Graph-based chemical space expansion with ILP flow queries discovers novel autocatalytic carbon fixation pathways competitive with CETCH and rTCA.</description><content:encoded><![CDATA[<h2 id="a-graph-grammar-and-ilp-framework-for-pathway-discovery">A Graph-Grammar and ILP Framework for Pathway Discovery</h2>
<p>Abel et al. present a Method paper that couples generative chemical space expansion with <a href="https://en.wikipedia.org/wiki/Integer_programming">integer linear programming</a> (ILP) pathway queries to systematically propose artificial carbon fixation pathways. The workflow uses the cheminformatics package MØD to iteratively expand a reaction hypergraph from a seed set of metabolites and rule-based enzyme reactions, then queries the resulting network for autocatalytic flows producing a chosen target molecule. Post-hoc annotation with eQuilibrator Gibbs energies and cofactor accounting ranks candidates by thermodynamic feasibility. Applied to the Acetyl-CoA-Succinyl-CoA pathway family plus selected synthetic and theoretical pathways, the framework recovers the natural pathways and proposes two new theoretical autocatalytic cycles (an 11-step Acetyl-CoA cycle and a 12-step Malate cycle) whose efficiency, measured in ATP and redox cofactors per fixed carbon, is comparable to the synthetic CETCH cycle and the natural <a href="https://en.wikipedia.org/wiki/Reverse_Krebs_cycle">rTCA</a>.</p>
<h2 id="why-computational-pathway-design-for-carbon-fixation">Why Computational Pathway Design for Carbon Fixation</h2>
<p>Fixing atmospheric CO$_2$ or bicarbonate into value-added chemicals is a thermodynamically unfavorable process that nature solves through enzymatic cascades coupled to cofactor-driven reactions. Seven natural carbon fixation pathways are known, along with several artificial proposals, and the Acetyl-CoA-Succinyl-CoA family is particularly appealing as a design template because each member overlaps structurally with at least one other and each exhibits <a href="https://en.wikipedia.org/wiki/Autocatalysis">autocatalysis</a>. Prior approaches to artificial pathway design (e.g., Erb Lab CETCH, HOPAC) rely heavily on manual heuristics, database searches, and extensive in-vitro optimization including <a href="https://en.wikipedia.org/wiki/Directed_evolution">directed evolution</a>. Earlier computational work (Löwe and Kremling, 2021) uses <a href="https://en.wikipedia.org/wiki/Flux_balance_analysis">flux balance analysis</a> and expert curation that requires complete kinetic parameterization, making generative exploration infeasible. Abel et al. target the design stage directly: a computational approach that can quickly enumerate many topologically distinct pathway candidates without requiring a priori kinetic parameters.</p>
<h2 id="generative-chemical-space-expansion-with-graph-grammar-rules">Generative Chemical Space Expansion with Graph-Grammar Rules</h2>
<p>The core innovation is treating the chemical reaction network (CRN) as a <a href="https://en.wikipedia.org/wiki/Hypergraph">directed multi-hypergraph</a> $H = (V, E)$ where vertices in $V$ are molecules and each hyperedge $e \in E$ is a directed pair $(e_{tail}, e_{head})$ of multisets representing reactants and products. This hyperedge formalization directly captures the many-to-many nature of biochemical reactions.</p>
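<p>As a minimal sketch of this data structure (not MØD's actual API; the class names, the reaction name, and the molecules <code>A</code>, <code>B</code>, <code>C</code> are illustrative assumptions), a hyperedge can be stored as a pair of multisets, with <code>collections.Counter</code> standing in for a multiset:</p>

```python
from collections import Counter
from dataclasses import dataclass, field


@dataclass(frozen=True)
class HyperEdge:
    """A directed hyperedge: multiset of reactants (tail) -> multiset of products (head)."""
    name: str
    tail: tuple  # sorted (molecule, multiplicity) pairs
    head: tuple

    @staticmethod
    def make(name, reactants, products):
        # Counter collapses repeated molecules into multiplicities.
        tail = tuple(sorted(Counter(reactants).items()))
        head = tuple(sorted(Counter(products).items()))
        return HyperEdge(name, tail, head)


@dataclass
class HyperGraph:
    vertices: set = field(default_factory=set)
    edges: list = field(default_factory=list)  # a list, so parallel edges are allowed

    def add_reaction(self, name, reactants, products):
        edge = HyperEdge.make(name, reactants, products)
        self.vertices.update(reactants, products)
        self.edges.append(edge)
        return edge


crn = HyperGraph()
# Toy reaction with made-up stoichiometry: 2 A + B -> C
crn.add_reaction("r1", ["A", "A", "B"], ["C"])
```

<p>Keeping edges in a list rather than a set is what makes this a multi-hypergraph in the sketch: two distinct rules may produce the same reactant and product multisets.</p>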
<p>Reactions are specified as graph transformation rules written in the Graph Modeling Language (GML). A rule defines the bond rewiring at a reaction center plus a tunable molecular context around that center. A rule with no context is fully promiscuous (every oxidoreductase class reaction, say); a rule with rich context mimics a specific enzyme. This rule-based formalism lets one rule represent an entire reaction class, so the CRN can be expanded without enumerating every possible enzyme-substrate pair in advance. Expansion proceeds iteratively: the rules act on the current molecule pool, producing new molecules and new hyperedges, until a user-defined step count is reached. Two biochemical sanity constraints bound the combinatorial explosion: molecules are restricted to at most 6 carbon atoms in the backbone (excluding the CoA moiety), and at most one CoA group per molecule.</p>
<p>Pathway discovery is then an ILP flow query over the CRN. A pathway is a hyperflow: an assignment of integer flow values to hyperedges such that internal molecules balance between production and consumption, leaving only designated source and sink molecules with net flow. The main optimization objective minimizes the number of reactions used and, as a tiebreaker, the magnitude of flow on those reactions:</p>
<p>$$
\min \sum_{e \in E} \left(w \cdot z_e + x_e\right)
$$</p>
<p>where $z_e$ is a boolean indicator that hyperedge $e$ carries flow, $x_e$ is the integer flow on $e$, and the weight $w = 1000$ prioritizes minimizing the edge count over the total flow magnitude. Autocatalysis is encoded as a constraint on the autocatalyst molecule $a$: its inflow and outflow are both positive, with outflow strictly exceeding inflow so the cycle nets at least one additional molecule of the autocatalyst.</p>
<p>$$
0 &lt; x_a^{in} &lt; x_a^{out}
$$</p>
<p>Only the autocatalyst itself, cofactors, and CO$_2$/HCO$_3^-$ are permitted as sources and sinks, so any valid flow represents a net reaction that fixes carbon and regenerates the autocatalyst. Unlike classical flux balance analysis, which optimizes continuous flux distributions at steady state, the integer-valued ILP formulation emphasizes pathway structure (which reactions are active) rather than flux magnitude.</p>
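<p>To make the objective and the autocatalysis constraint concrete, here is a small sketch that solves a toy instance by exhaustive enumeration instead of an ILP solver (the paper uses Gurobi). The five-reaction network, the flow bound <code>MAX_FLOW</code>, and the reading of inflow and outflow as the amount of autocatalyst consumed and produced are all assumptions made for illustration:</p>

```python
from itertools import product

# Toy network (made-up): each reaction is (reactants, products) with stoichiometric counts.
REACTIONS = {
    "r1": ({"A": 1, "CO2": 1}, {"B": 1}),
    "r2": ({"B": 1, "CO2": 1}, {"C": 1}),
    "r3": ({"C": 1}, {"A": 2}),
    "r4": ({"B": 1}, {"D": 1}),            # detour
    "r5": ({"D": 1, "CO2": 1}, {"C": 1}),  # detour
}
INTERNAL = {"B", "C", "D"}  # must balance; A and CO2 may act as source/sink
W = 1000                    # edge-count weight from the objective
MAX_FLOW = 2                # small bound so exhaustive search stays cheap


def totals(flow):
    """Return (consumed, produced) amounts per molecule for a flow assignment."""
    consumed, produced = {}, {}
    for name, x in flow.items():
        tail, head = REACTIONS[name]
        for mol, n in tail.items():
            consumed[mol] = consumed.get(mol, 0) + n * x
        for mol, n in head.items():
            produced[mol] = produced.get(mol, 0) + n * x
    return consumed, produced


def feasible(flow, autocatalyst="A"):
    consumed, produced = totals(flow)
    balanced = all(consumed.get(m, 0) == produced.get(m, 0) for m in INTERNAL)
    a_in, a_out = consumed.get(autocatalyst, 0), produced.get(autocatalyst, 0)
    return balanced and 0 < a_in < a_out  # autocatalysis constraint


def objective(flow):
    # z_e = 1 if the edge carries flow; W makes edge count dominate total flow.
    return sum(W * (x > 0) + x for x in flow.values())


names = list(REACTIONS)
candidates = (dict(zip(names, xs)) for xs in product(range(MAX_FLOW + 1), repeat=len(names)))
best = min((f for f in candidates if feasible(f)), key=objective)
```

<p>In this toy network the three-reaction cycle <code>r1</code>, <code>r2</code>, <code>r3</code> wins over the four-reaction detour through <code>D</code> because each active edge costs 1000, while doubling the flow on an edge costs only 1.</p>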
<p>Solutions are post-annotated with two feasibility measures. The first is cofactor accounting, split into ATP/ADP as an energy proxy and reduced redox cofactors (NAD(P)H, ubiquinone, Ferredoxin) as an electron proxy. The second is the standard Gibbs free energy of the net reaction computed via the eQuilibrator 3.0 component-contribution method at pH 7 and ionic strength 0.1 M using the eQuilibrator API 0.6.0:</p>
<p>$$
\Delta_r G'^{\circ} = \sum \Delta_f G'^{\circ}_{\text{products}} - \sum \Delta_f G'^{\circ}_{\text{reactants}}
$$</p>
<h2 id="experimental-setup-queries-and-comparison-to-literature">Experimental Setup, Queries, and Comparison to Literature</h2>
<p>The seed pool for expansion contains 49 intermediates drawn from the Acetyl-CoA-Succinyl-CoA family (rTCA, DC/4-HB, 3-HP/4-HB, 3-HP bicycle), the synthetic CETCH cycle, and theoretical pathways proposed by Bar-Even et al., plus 20 helper molecules (cofactors, water, CO$_2$). Rule contexts were derived from <a href="https://en.wikipedia.org/wiki/KEGG">KEGG</a> enzyme entries. The <a href="https://en.wikipedia.org/wiki/Calvin_cycle">Calvin-Benson-Bassham cycle</a> and the non-autocatalytic <a href="https://en.wikipedia.org/wiki/Wood%E2%80%93Ljungdahl_pathway">Wood-Ljungdahl</a> and reductive glycine pathways were excluded.</p>
<p>Expansion statistics (Table 4 in the paper):</p>
<table>
  <thead>
      <tr>
          <th>Expansion steps</th>
          <th>Molecules (vertices)</th>
          <th>Reactions (hyperedges)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>1</td>
          <td>165</td>
          <td>220</td>
      </tr>
      <tr>
          <td>2</td>
          <td>318</td>
          <td>942</td>
      </tr>
      <tr>
          <td>5</td>
          <td>996</td>
          <td>29,266</td>
      </tr>
  </tbody>
</table>
<p>At one expansion step, flow queries recover only the input pathways with no recombinations. Two expansion steps produce sufficient novelty for recombined pathways while keeping ILP runtimes tractable. Five steps make flow queries computationally prohibitive without adding biological insight. All reported analyses use the two-step CRN.</p>
<p>Three benchmark flow queries target autocatalytic pathways producing Acetyl-CoA, Malate, and Propionyl-CoA. Each query is run to return 1000 topologically distinct optimal solutions (under the ILP objective, solutions with equal length are equally optimal). All flow queries were solved with Gurobi 11.0.3 under an academic license on a consumer laptop (AMD Ryzen 7 5700U, 16 GB RAM, Windows 11). The full 1000-solution search took just under 18 hours.</p>
<h2 id="two-novel-autocatalytic-cycles-competitive-with-synthetic-pathways">Two Novel Autocatalytic Cycles Competitive with Synthetic Pathways</h2>
<p>The shortest-pathway queries yield two novel theoretical autocatalytic cycles: an 11-step Acetyl-CoA cycle and a 12-step Malate cycle. Comparison to natural, theoretical, and synthetic pathways on the four standard measures (steps, ATP units, cofactors, carbon units fixed per cycle):</p>
<table>
  <thead>
      <tr>
          <th>Pathway</th>
          <th>Status</th>
          <th>Steps</th>
          <th>ATP</th>
          <th>Cofactors</th>
          <th>C fixed</th>
          <th>ATP/C</th>
          <th>Cof/C</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Shortest Acetyl-CoA (this work)</td>
          <td>Theoretical</td>
          <td>11</td>
          <td>2</td>
          <td>5</td>
          <td>2</td>
          <td>1</td>
          <td>2.5</td>
      </tr>
      <tr>
          <td>Shortest Malate (this work)</td>
          <td>Theoretical</td>
          <td>12</td>
          <td>3</td>
          <td>8</td>
          <td>4</td>
          <td>0.75</td>
          <td>2</td>
      </tr>
      <tr>
          <td>CETCH</td>
          <td>Synthetic</td>
          <td>11</td>
          <td>1</td>
          <td>4</td>
          <td>2</td>
          <td>0.5</td>
          <td>2</td>
      </tr>
      <tr>
          <td>rGPS-MCG</td>
          <td>Synthetic</td>
          <td>18</td>
          <td>4</td>
          <td>6</td>
          <td>3</td>
          <td>1.33</td>
          <td>2</td>
      </tr>
      <tr>
          <td>C4-glyoxylate / alanine</td>
          <td>Theoretical</td>
          <td>9</td>
          <td>2</td>
          <td>2</td>
          <td>2</td>
          <td>1</td>
          <td>1</td>
      </tr>
      <tr>
          <td>rTCA</td>
          <td>Natural</td>
          <td>12</td>
          <td>4</td>
          <td>7</td>
          <td>4</td>
          <td>1</td>
          <td>1.75</td>
      </tr>
      <tr>
          <td>3HP/4HB</td>
          <td>Natural</td>
          <td>16</td>
          <td>4</td>
          <td>6</td>
          <td>2</td>
          <td>2</td>
          <td>3</td>
      </tr>
      <tr>
          <td>DC/4HB</td>
          <td>Natural</td>
          <td>14</td>
          <td>4</td>
          <td>7</td>
          <td>2</td>
          <td>2</td>
          <td>3.5</td>
      </tr>
      <tr>
          <td>3HP-bicycle</td>
          <td>Natural</td>
          <td>19</td>
          <td>3</td>
          <td>4</td>
          <td>2</td>
          <td>1.5</td>
          <td>2</td>
      </tr>
  </tbody>
</table>
<p>The 11-step Acetyl-CoA cycle matches CETCH in length and carbon units fixed while using one more ATP and one more redox cofactor. Per the table, the Malate cycle is the same length as rTCA (12 steps) and fixes the same four carbons, using one fewer ATP (3 vs. 4) but one more redox cofactor (8 vs. 7).</p>
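<p>The per-carbon columns are plain ratios of the raw counts. A quick sketch that recomputes them for four rows of the table (the dictionary layout and function name are my own; the numbers are copied from the table):</p>

```python
# (steps, ATP, redox cofactors, carbons fixed) copied from the comparison table
PATHWAYS = {
    "Shortest Acetyl-CoA": (11, 2, 5, 2),
    "Shortest Malate": (12, 3, 8, 4),
    "CETCH": (11, 1, 4, 2),
    "rTCA": (12, 4, 7, 4),
}


def per_carbon(name):
    """Return (ATP per fixed carbon, cofactors per fixed carbon), rounded to 2 places."""
    steps, atp, cof, carbons = PATHWAYS[name]
    return round(atp / carbons, 2), round(cof / carbons, 2)


ratios = {name: per_carbon(name) for name in PATHWAYS}
```

<p>Normalizing by fixed carbons is what makes the 12-step Malate cycle (four carbons per turn) comparable to the 11-step cycles that fix two.</p>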
<p>Across the 1000-solution benchmarks (Table 2 of the paper), the Acetyl-CoA query is the most cofactor-efficient per step (0.69 cofactors/step; average 7.6 total), while Propionyl-CoA and Malate average 0.89 and 0.88 cofactors/step. Gibbs energies average $\Delta_r G'^{\circ} = -150.66$ kJ/mol for Acetyl-CoA, $-165.82$ kJ/mol for Propionyl-CoA, and $-196.98$ kJ/mol for Malate, making the Malate query the most thermodynamically driven of the three. Three specific Acetyl-CoA solutions inspected in detail share a common rTCA-like core with a glyoxylate shunt and differ mainly along the oxaloacetate-to-malyl-CoA branch; their totals range from $\Delta_r G'^{\circ}_{\text{total}} = -80$ kJ/mol (the one-ATP solution) to $-168$ kJ/mol.</p>
<p>All solutions rely on <a href="https://en.wikipedia.org/wiki/Ferredoxin">Ferredoxin</a>-dependent carboxylating enzymes (pyruvate:ferredoxin oxidoreductase and 2-ketoglutarate:ferredoxin oxidoreductase). Reduced ferredoxin is a stronger electron donor than NAD(P)H (its reduction potential is more negative), which helps drive the carboxylations, but these enzymes are oxygen-sensitive and would restrict wet-lab implementation to anaerobic conditions or engineered anaerobic strains.</p>
<h2 id="findings-limitations-and-future-directions">Findings, Limitations, and Future Directions</h2>
<p>The workflow produces pathway candidates whose efficiency approaches the best synthetic designs while running on a consumer laptop, and it generalizes to any chemical space that can be formalized by graph-transformation rules. Because the ILP returns many equally optimal solutions, a downstream filtering step can select candidates matching user criteria (oxygen sensitivity, specific cofactor preference, enzyme availability).</p>
<p>Acknowledged limitations include: the topology-only search ignores enzyme kinetics, so candidates that look thermodynamically favorable might be bottlenecked in practice; the carbon-count and CoA restrictions are necessary to bound combinatorial blow-up but also constrain the discoverable space; reliance on Ferredoxin complicates implementation; and enzyme availability varies across organisms, which matters for recombination-based designs. The authors point to kinetic modeling, cofactor-recycling system inclusion, and incorporation of metabolic reactions outside the canonical carbon fixation space as future directions.</p>
<p>The paper positions itself as a design-stage tool rather than an end-to-end in-vitro pipeline. The authors frame the contribution as idea generation that complements, not replaces, the subsequent experimental optimization (enzyme engineering, directed evolution) that has carried prior synthetic pathway work from theory to in-vitro success.</p>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Seed molecules</td>
          <td>Curated Acetyl-CoA-Succinyl-CoA family + CETCH + Bar-Even theoretical</td>
          <td>49 metabolites + 20 cofactors</td>
          <td>Tables S1-S2</td>
      </tr>
      <tr>
          <td>Reaction rules</td>
          <td>KEGG enzyme entries, GML-encoded</td>
          <td>Rules listed in Figure S1</td>
          <td>Conservative context</td>
      </tr>
      <tr>
          <td>CRN (2-step expansion)</td>
          <td>Generated by MØD</td>
          <td>318 molecules, 942 reactions</td>
          <td>Primary analysis space</td>
      </tr>
      <tr>
          <td>Thermodynamic data</td>
          <td>eQuilibrator 3.0 component-contribution</td>
          <td>All molecules in space</td>
          <td>pH 7, ionic strength 0.1 M</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<p>Graph-grammar rule expansion via MØD 1.0.0 with a 6-carbon backbone cap and at most one CoA moiety per molecule. ILP flow queries formulated with the edge-minimization objective in Equation (1) and the autocatalysis constraint in Equation (2). Natural pathway presence first verified via set operations on the CRN, then reconfirmed by constraining the ILP to pass through core intermediates. The pathway solution enumeration is structural: 1000 topologically distinct solutions per query at the optimal objective value.</p>
<h3 id="models">Models</h3>
<p>No machine-learning models. The pipeline is symbolic: graph transformations, hypergraph flow constraints, and component-contribution free energy estimates.</p>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Acetyl-CoA</th>
          <th>Propionyl-CoA</th>
          <th>Malate</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Avg steps</td>
          <td>11</td>
          <td>15</td>
          <td>12</td>
      </tr>
      <tr>
          <td>Avg cofactors</td>
          <td>7.6</td>
          <td>13.3</td>
          <td>10.6</td>
      </tr>
      <tr>
          <td>Cofactors/step</td>
          <td>0.69</td>
          <td>0.89</td>
          <td>0.88</td>
      </tr>
      <tr>
          <td>Avg $\Delta_r G'^{\circ}$ (kJ/mol)</td>
          <td>$-150.66$</td>
          <td>$-165.82$</td>
          <td>$-196.98$</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Gurobi 11.0.3 (academic license) on a consumer laptop: AMD Ryzen 7 5700U, 16 GB RAM, Windows 11. Full 1000-solution runs for the three benchmark queries completed in just under 18 hours total.</p>
<h3 id="artifacts-and-licensing">Artifacts and licensing</h3>
<ul>
<li>Code and output pathways: <a href="https://github.com/anne-susann/C_fixation_pathway_design">github.com/anne-susann/C_fixation_pathway_design</a> (MIT License)</li>
<li>MØD cheminformatics package (version 1.0.0)</li>
<li>eQuilibrator API version 0.6.0</li>
<li>Gurobi 11.0.3</li>
</ul>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Abel, A.-S., Lauber, N., Andersen, J. L., Fagerberg, R., Merkle, D. E., &amp; Flamm, C. (2026). Computational approaches in chemical space exploration for carbon fixation pathways. <em>npj Systems Biology and Applications</em>, 12(1), 17. <a href="https://doi.org/10.1038/s41540-025-00641-8">https://doi.org/10.1038/s41540-025-00641-8</a></p>
<p><strong>Publication</strong>: npj Systems Biology and Applications, 2026</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{abel2026computational,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Computational approaches in chemical space exploration for carbon fixation pathways}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Abel, Anne-Susann and Lauber, Nino and Andersen, Jakob Lykke and Fagerberg, Rolf and Merkle, Daniel Elmar and Flamm, Christoph}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{npj Systems Biology and Applications}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{12}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{17}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2026}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Nature Portfolio}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1038/s41540-025-00641-8}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>ACSESS: Diverse Optimal Molecules in the SMU</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/chemical-space/acsess-diverse-optimal-molecules/</link><pubDate>Sat, 11 Apr 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/chemical-space/acsess-diverse-optimal-molecules/</guid><description>Rupakheti et al. extend ACSESS to find diverse molecules with favorable properties without exhaustive enumeration of chemical space.</description><content:encoded><![CDATA[<h2 id="diversity-biased-search-of-the-small-molecule-universe">Diversity-Biased Search of the Small Molecule Universe</h2>
<p>The small molecule universe (SMU), estimated at over $10^{60}$ synthetically feasible organic molecules under ~500 Da, is far too large for exhaustive enumeration and evaluation. This paper extends the ACSESS (Algorithm for Chemical Space Exploration with Stochastic Search) framework to simultaneously optimize molecular diversity and a targeted physical property. The key insight is that enforcing diversity at each iteration prevents the search from collapsing into local optima, a failure mode common in standard <a href="/notes/chemistry/molecular-design/generation/search-based/genetic-algorithms-molecule-generation-baselines/">genetic algorithms</a>.</p>
<h2 id="motivation-diversity-vs-fitness">Motivation: Diversity vs. Fitness</h2>
<p>Standard genetic algorithms optimize fitness effectively but sacrifice diversity: they converge to a few high-fitness regions while ignoring equally good solutions elsewhere. Exhaustive enumeration guarantees completeness but is computationally infeasible beyond ~20 heavy atoms. ACSESS bridges this gap by maintaining a maximally diverse library throughout the optimization process, ensuring coverage of multiple fitness peaks without needing to evaluate every candidate.</p>
<h2 id="the-property-optimizing-acsess-algorithm">The Property-Optimizing ACSESS Algorithm</h2>
<p>The method has four iterative steps:</p>
<ol>
<li><strong>Initialize</strong> a library (from a single molecule or a seed collection)</li>
<li><strong>Breed</strong> new compounds via mutations and crossovers</li>
<li><strong>Filter</strong> by property threshold, removing compounds below a cutoff</li>
<li><strong>Select</strong> a maximally diverse subset of qualifying structures</li>
</ol>
<p>The property threshold increases linearly with each iteration, starting low (to prevent population collapse) and gradually rising until the desired fitness level is reached. Diversity is enforced via either a maximin algorithm (maximizing nearest-neighbor distance) or cell-based partitioning (linear scaling for large libraries).</p>
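<p>The loop can be sketched on a toy problem. This is not the authors' implementation: molecules are replaced by 2D points, breeding by Gaussian mutation, and diversity selection by a greedy farthest-point heuristic (one common way to approximate maximin). The function names, the two-peak fitness, and all parameter values are assumptions for illustration:</p>

```python
import math
import random


def maximin_select(points, k):
    """Greedy maximin: repeatedly add the point farthest from the chosen set."""
    chosen = [points[0]]
    while len(chosen) < min(k, len(points)):
        rest = [p for p in points if p not in chosen]
        chosen.append(max(rest, key=lambda p: min(math.dist(p, c) for c in chosen)))
    return chosen


def acsess(fitness, seed_library, n_iter=20, final_cutoff=0.8, size=10, rng=None):
    rng = rng or random.Random(0)
    library = list(seed_library)
    for it in range(1, n_iter + 1):
        # Breed: Gaussian mutation stands in for chemical mutations and crossovers.
        children = [(x + rng.gauss(0, 0.3), y + rng.gauss(0, 0.3))
                    for x, y in library for _ in range(5)]
        pool = list(set(library + children))
        # Filter: the cutoff rises linearly toward final_cutoff.
        cutoff = final_cutoff * it / n_iter
        survivors = [p for p in pool if fitness(p) >= cutoff]
        if not survivors:  # keep the old library rather than collapse
            continue
        # Select: keep a maximally diverse subset of the survivors.
        library = maximin_select(survivors, size)
    return library


def two_peaks(p):
    """Toy fitness with two separated optima at (2, 0) and (-2, 0)."""
    x, y = p
    return max(math.exp(-((x - 2) ** 2 + y ** 2)), math.exp(-((x + 2) ** 2 + y ** 2)))


library = acsess(two_peaks, [(2.0, 0.0), (-2.0, 0.0), (0.0, 0.0)])
```

<p>Because selection ignores fitness once a point clears the cutoff, the library is pushed apart rather than toward the single best point, which is the behavior the method relies on to cover several peaks.</p>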
<p>Molecules are represented in a 40-dimensional chemical space using Moreau-Broto autocorrelation descriptors. The descriptor encodes correlations of atomic properties as a function of topological distance (bond distance) $d$:</p>
<p>$$
AC(d, p) = \sum_{i \leq j} p_{i} \, p_{j} \, \delta(d_{ij} - d)
$$</p>
<p>where $p_{i}$ is an atomic property of atom $i$ and $d_{ij}$ is the shortest bond path between atoms $i$ and $j$. Five atomic properties are used: atomic number, Gasteiger-Marsili partial charge, atomic polarizability, topological steric index, and unity ($p_{i} = 1$ for all $i$, effectively counting atom pairs at each distance). Topological distance $d$ ranges from 0 to 7, yielding $5 \times 8 = 40$ descriptor components. Descriptors are mean-centered and normalized to unit variance before computing distances.</p>
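<p>A small sketch of the descriptor on a hand-built molecular graph may help. It uses only two of the five properties (atomic number and unity) and the heavy-atom graph of ethanol; the function names and the adjacency-dict representation are my own assumptions, not the paper's code:</p>

```python
from collections import deque


def bond_distances(adjacency):
    """All-pairs shortest bond-path lengths via breadth-first search."""
    dist = {}
    for start in adjacency:
        seen = {start: 0}
        queue = deque([start])
        while queue:
            atom = queue.popleft()
            for nbr in adjacency[atom]:
                if nbr not in seen:
                    seen[nbr] = seen[atom] + 1
                    queue.append(nbr)
        dist[start] = seen
    return dist


def autocorrelation(adjacency, prop, max_d=7):
    """AC(d, p) for d = 0..max_d, summing over atom pairs with i <= j."""
    dist = bond_distances(adjacency)
    atoms = sorted(adjacency)
    ac = [0.0] * (max_d + 1)
    for a, i in enumerate(atoms):
        for j in atoms[a:]:
            d = dist[i].get(j)
            if d is not None and d <= max_d:
                ac[d] += prop[i] * prop[j]
    return ac


# Heavy-atom graph of ethanol (C-C-O), indexed 0, 1, 2.
ethanol = {0: [1], 1: [0, 2], 2: [1]}
atomic_number = {0: 6, 1: 6, 2: 8}
unity = {i: 1 for i in ethanol}

descriptor = autocorrelation(ethanol, atomic_number) + autocorrelation(ethanol, unity)
```

<p>With the unity property the components simply count atom pairs at each bond distance (3, 2, 1 for ethanol at $d = 0, 1, 2$). Concatenating all five properties in the same way would give the 40 components described above.</p>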
<p>Chemical space distance is the Euclidean distance between descriptor vectors:</p>
<p>$$
D_{ij} = \sqrt{\sum_{k=1}^{N} (d_{ik} - d_{jk})^2}
$$</p>
<p>Library diversity is measured as the average nearest-neighbor distance:</p>
<p>$$
D_{\min} = \frac{1}{M} \sqrt{\sum_{i=1}^{M} \min_{j \neq i} (D_{ij}^2)}
$$</p>
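<p>Both formulas are short enough to check by hand. The sketch below follows the diversity formula exactly as written in this note (a root of summed squared nearest-neighbor distances, scaled by $1/M$); the function names and the three test points are illustrative assumptions:</p>

```python
import math


def euclidean(u, v):
    """Euclidean distance between two descriptor vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


def diversity(library):
    """(1/M) * sqrt(sum_i min_{j != i} D_ij^2), following the formula as written above."""
    m = len(library)
    total = 0.0
    for i, u in enumerate(library):
        nearest = min(euclidean(u, v) for j, v in enumerate(library) if j != i)
        total += nearest ** 2
    return math.sqrt(total) / m


# Three collinear points, each 5 units from its nearest neighbor.
points = [(0.0, 0.0), (3.0, 4.0), (6.0, 8.0)]
score = diversity(points)
```

<p>For these points every nearest-neighbor distance is 5, so the score is $\sqrt{75}/3 \approx 2.89$. In practice the vectors would be the standardized 40-dimensional descriptors.</p>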
<h2 id="validation-on-nkp-fitness-landscapes">Validation on NKp Fitness Landscapes</h2>
<p>The <a href="https://en.wikipedia.org/wiki/NK_model">NKp model</a> maps binary strings of length $N$ to fitness values in $[0, 1]$. The fitness of a string $g$ is:</p>
<p>$$
\Phi(g) = \frac{1}{N} \sum_{i=1}^{N} \varphi_{i}(g)
$$</p>
<p>where each $\varphi_{i} \in [0, 1]$ is a randomly drawn fitness contribution. Ruggedness is controlled by $K$ (the number of other positions that each contribution depends on), and neutrality by $p$ (in the standard NKp model, the probability that a given fitness contribution is set to zero). Using $N = 19$, $K = 9$, $p = 0.9$ ($2^{19} = 524{,}288$ total strings, comparable to GDB-9 size), the global maximum was ~0.3. Both ACSESS and SGA were initialized with the same diverse subset and ran for 30 iterations across 10 independent runs:</p>
<ul>
<li>ACSESS found the global optimum in 100% of runs (vs. 60% for SGA)</li>
<li>ACSESS discovered ~15 of 19 globally optimal strings on average (vs. ~3 for SGA)</li>
<li>ACSESS solutions had higher average fitness than SGA solutions</li>
</ul>
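<p>For readers who want to reproduce the benchmark landscape, here is a toy NKp generator. It assumes the usual NKp convention, in which each position depends on itself plus $K$ randomly chosen others and each table entry is zero with probability $p$; the function name, the seed, and the small parameters ($N = 8$, $K = 2$, $p = 0.5$) are my choices so that full enumeration is instant:</p>

```python
import random
from itertools import product


def make_nkp(n, k, p, seed=0):
    """Return a fitness function Phi(g) for an NKp landscape (toy sketch)."""
    rng = random.Random(seed)
    # Each position i depends on itself plus k other randomly chosen positions.
    neighbors = [[i] + rng.sample([j for j in range(n) if j != i], k) for i in range(n)]
    # Contribution tables: zero with probability p, else uniform in [0, 1).
    tables = [
        {bits: (0.0 if rng.random() < p else rng.random())
         for bits in product((0, 1), repeat=k + 1)}
        for _ in range(n)
    ]

    def fitness(g):
        # Phi(g) is the mean of the N per-position contributions.
        return sum(tables[i][tuple(g[j] for j in neighbors[i])] for i in range(n)) / n

    return fitness


phi = make_nkp(n=8, k=2, p=0.5)
landscape = {g: phi(g) for g in product((0, 1), repeat=8)}
best = max(landscape.values())
```

<p>A high $p$ zeroes most contributions, which is consistent with a global maximum far below 1 (about 0.3 in the reported $N = 19$, $K = 9$, $p = 0.9$ setting).</p>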
<h2 id="validation-on-gdb-9-dipole-moments">Validation on GDB-9 Dipole Moments</h2>
<p>The method was tested on all ~300,000 molecules in GDB-9 (up to 9 heavy atoms; allowed atom types: C, N, O, S, Cl). For each molecule, the Boltzmann-averaged dipole moment was computed at the <a href="https://en.wikipedia.org/wiki/Austin_Model_1">AM1 level</a> (Gaussian 09):</p>
<p>$$
D = \frac{\sum_{i \in C} \mu_{i} \, e^{-\beta E_{i}}}{\sum_{i \in C} e^{-\beta E_{i}}}
$$</p>
<p>where $\mu_{i}$ and $E_{i}$ are the dipole moment and internal energy of conformation $i$, and $\beta = 1 / (k_{\text{B}} T)$ at $T = 298$ K. Conformations (including stereoisomers) were generated using OpenEye OMEGA. The target was molecules with dipole moments $\geq 5.5$ D (the 90th percentile). ACSESS first generated a maximally diverse seed set, then ran 60 iterations of fitness-biased optimization. All methods were initialized from the same diverse seed and compared over multiple runs.</p>
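<p>The Boltzmann average itself is a weighted mean and can be sketched in a few lines. The conformer dipoles and energies below are hypothetical, energies are assumed to be molar (kJ/mol) so the gas constant takes the place of $k_{\text{B}}$, and the function name is my own:</p>

```python
import math

R_KJ_PER_MOL_K = 8.314462618e-3  # gas constant; plays the role of k_B for molar energies


def boltzmann_average(dipoles, energies_kj_mol, temperature=298.0):
    """Boltzmann-weighted mean of per-conformer dipole moments."""
    beta = 1.0 / (R_KJ_PER_MOL_K * temperature)
    e_min = min(energies_kj_mol)  # shifting energies leaves the ratio unchanged
    weights = [math.exp(-beta * (e - e_min)) for e in energies_kj_mol]
    return sum(mu * w for mu, w in zip(dipoles, weights)) / sum(weights)


# Hypothetical conformers: dipoles in debye, relative energies in kJ/mol.
avg = boltzmann_average([5.0, 7.0, 2.0], [0.0, 0.0, 50.0])
```

<p>Here the two degenerate low-energy conformers dominate and the 50 kJ/mol conformer contributes almost nothing, so the average is very close to 6 D. Subtracting the minimum energy before exponentiating avoids underflow without changing the result.</p>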
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>Dipole Moment (D)</th>
          <th>Diversity (eq. 4)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>GA-Roulette</td>
          <td>5.8 $\pm$ 0.03</td>
          <td>6.5 $\pm$ 0.7</td>
      </tr>
      <tr>
          <td>GA-Tournament</td>
          <td>6.4 $\pm$ 0.08</td>
          <td>3.5 $\pm$ 0.7</td>
      </tr>
      <tr>
          <td>GA-Elitism</td>
          <td>6.74 $\pm$ 0.08</td>
          <td>5.4 $\pm$ 0.4</td>
      </tr>
      <tr>
          <td><strong>ACSESS</strong></td>
          <td><strong>6.05 $\pm$ 0.05</strong></td>
          <td><strong>9.7 $\pm$ 0.6</strong></td>
      </tr>
  </tbody>
</table>
<p>ACSESS achieved about 1.8 times the diversity of the highest-fitness SGA variant (9.7 vs. 5.4 for GA-Elitism) and about 1.5 times that of the most diverse one (6.5 for GA-Roulette), while maintaining competitive fitness. Its diversity (~9.7) approached the diversity of the full enumerated high-fitness subset of GDB-9 (~12). <a href="https://en.wikipedia.org/wiki/Self-organizing_map">Self-organizing map</a> (SOM) visualizations confirmed that ACSESS covered high-activity regions that SGAs missed entirely.</p>
<p>Only ~30,000 fitness evaluations were needed to locate diverse optimal regions in the 300,000-molecule space, a 10x efficiency gain over exhaustive enumeration.</p>
<h2 id="limitations">Limitations</h2>
<ul>
<li>Tested only on relatively small chemical spaces (GDB-9 with ~300k molecules and 19-bit NKp with ~500k strings); scaling to the full SMU ($10^{60}$) remains a research direction</li>
<li>Property evaluation (AM1 dipole moments with conformer generation) is the computational bottleneck, not the ACSESS algorithm itself</li>
<li>The 40-dimensional autocorrelation descriptor space may not capture all relevant structural features for every optimization target</li>
<li>Comparison is limited to simple genetic algorithms; more sophisticated evolutionary strategies were not benchmarked</li>
</ul>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<p>The published ACSESS implementation relies on proprietary software, which limits full reproducibility.</p>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://doi.org/10.1038/sdata.2014.22">GDB-9</a></td>
          <td>Dataset</td>
          <td>CC-BY-4.0</td>
          <td>Publicly available enumerated chemical universe (~300k molecules)</td>
      </tr>
  </tbody>
</table>
<ul>
<li><strong>Code</strong>: No public source code was released. The implementation depends on OpenEye OEChem TK (molecule generation), OpenEye MolProp TK (filtering), and OpenEye OMEGA TK (conformer generation), all of which require commercial licenses.</li>
<li><strong>Property calculations</strong>: Dipole moments were computed at the AM1 level using Gaussian 09, also commercial software.</li>
<li><strong>NKp landscape</strong>: Fully specified by parameters ($N = 19$, $K = 9$, $p = 0.9$) and standard NKp model equations, making this portion independently reproducible.</li>
<li><strong>Hardware</strong>: No specific compute requirements reported.</li>
<li><strong>Reproducibility status</strong>: Partially Reproducible. The algorithm is well-described and the NKp experiments could be reimplemented, but the molecular experiments require OpenEye and Gaussian 09 licenses, and no reference implementation was released.</li>
</ul>
<h2 id="paper-information">Paper Information</h2>
<ul>
<li><strong>Journal</strong>: Journal of Chemical Information and Modeling, Vol. 55, No. 3, pp. 529-537</li>
<li><strong>Published</strong>: January 16, 2015</li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{rupakheti2015strategy,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Strategy To Discover Diverse Optimal Molecules in the Small Molecule Universe}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Rupakheti, Chetan and Virshup, Aaron M. and Yang, Weitao and Beratan, David N.}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Journal of Chemical Information and Modeling}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{55}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{3}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{529--537}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2015}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{American Chemical Society}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1021/ci500749q}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>T5: Exploring Transfer Learning Limits</title><link>https://hunterheidenreich.com/notes/natural-language-processing/language-models/t5-text-to-text-transfer-transformer/</link><pubDate>Wed, 08 Apr 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/natural-language-processing/language-models/t5-text-to-text-transfer-transformer/</guid><description>Raffel et al. systematically study transfer learning for NLP with a text-to-text framework, ablating architectures, objectives, data, and multi-task mixing.</description><content:encoded><![CDATA[<h2 id="a-systematic-study-of-nlp-transfer-learning">A systematic study of NLP transfer learning</h2>
<p>This is a <strong>systematization paper</strong> that provides a comprehensive empirical survey of transfer learning techniques for NLP. Rather than proposing a single new method, T5 introduces a unified text-to-text framework and uses it as a testbed to systematically compare pre-training objectives, architectures, unlabeled data sources, transfer approaches, and multi-task mixing strategies. The scale of the ablation study (covering dozens of configurations) and the release of C4, pre-trained models, and code make it both a reference guide and a resource.</p>
<h2 id="unifying-nlp-tasks-as-text-to-text">Unifying NLP tasks as text-to-text</h2>
<p>The core design decision is to cast every NLP task as a text-to-text problem: both the input and output are text strings, with a task-specific prefix. Classification, regression, summarization, translation, and question answering all use the same model, loss function (cross-entropy on output tokens), and decoding procedure. This simplicity enables fair comparison across tasks and training strategies.</p>
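<p>A minimal sketch of the casting step, assuming a dictionary-style example record. The prefix strings and the rounding of regression targets to increments of 0.2 follow examples given in the T5 paper; the function name, task keys, and field names are hypothetical and chosen only for illustration.</p>

```python
def to_text_to_text(task, **fields):
    """Render one example as an (input_text, target_text) pair of strings."""
    if task == "translate_en_de":
        return "translate English to German: " + fields["source"], fields["target"]
    if task == "stsb":
        # Regression becomes text: round the score to the nearest 0.2 and print it.
        score = round(fields["score"] * 5) / 5
        text = ("stsb sentence1: " + fields["sentence1"]
                + " sentence2: " + fields["sentence2"])
        return text, f"{score:.1f}"
    if task == "summarize":
        return "summarize: " + fields["document"], fields["summary"]
    raise ValueError(f"unknown task: {task}")


# Every task yields the same kind of record, so one model and one loss suffice.
example = to_text_to_text(
    "stsb", sentence1="A man is singing.", sentence2="A man sings.", score=2.57
)
```

<p>Because the target is always a string, a similarity score is trained and decoded exactly like a translation or a summary.</p>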
<p>The model architecture is a standard encoder-decoder Transformer. The paper finds that this form outperforms decoder-only (language model) and encoder-only (BERT-style) variants in the text-to-text setting, while having similar computational cost to decoder-only models despite twice the parameters (the encoder processes the input only once, then the decoder attends to it).</p>
<h2 id="multi-task-mixing-strategies-and-findings">Multi-task mixing: strategies and findings</h2>
<p>The most thesis-relevant contribution is the systematic ablation of multi-task mixing strategies (Section 3.5.2). When training on multiple tasks simultaneously (which in the text-to-text framework simply means mixing data from different sources), the central question is how to set the proportion of data from each task.</p>
<h3 id="three-mixing-strategies">Three mixing strategies</h3>
<p><strong>Examples-proportional mixing.</strong> Sample in proportion to each dataset&rsquo;s size, with an artificial cap $K$ on the maximum dataset size. Without the cap, the unsupervised pre-training data (orders of magnitude larger) would dominate all batches. The mixing rate for task $m$ is:</p>
<p>$$
r_{m} = \frac{\min(e_{m}, K)}{\sum_{n} \min(e_{n}, K)}
$$</p>
<p>where $e_{m}$ is the number of examples in task $m$&rsquo;s dataset.</p>
<p><strong>Temperature-scaled mixing.</strong> Raise each mixing rate $r_{m}$ to the power $1/T$ and renormalize so the rates sum to 1. At $T=1$ this equals examples-proportional mixing; as $T$ increases, proportions approach equal mixing. The rates $r_{m}$ are computed with a large cap of $K = 2^{21}$ before temperature scaling is applied.</p>
<p><strong>Equal mixing.</strong> Sample uniformly from all tasks. Included as a negative reference: the model overfits on low-resource tasks and underfits on high-resource tasks.</p>
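<p>The three strategies above can be sketched in a few lines. This is an illustrative implementation of the formulas as stated here, not the paper's code; the function names and the dataset sizes in <code>sizes</code> are hypothetical.</p>

```python
def examples_proportional(sizes, cap):
    """r_m = min(e_m, K) / sum_n min(e_n, K)."""
    clipped = {task: min(n, cap) for task, n in sizes.items()}
    total = sum(clipped.values())
    return {task: n / total for task, n in clipped.items()}


def temperature_scaled(sizes, temperature, cap=2**21):
    """Raise each capped rate to the power 1/T, then renormalize."""
    rates = examples_proportional(sizes, cap)
    scaled = {task: r ** (1.0 / temperature) for task, r in rates.items()}
    total = sum(scaled.values())
    return {task: s / total for task, s in scaled.items()}


def equal_mixing(sizes):
    """Uniform sampling over tasks, ignoring dataset size."""
    return {task: 1.0 / len(sizes) for task in sizes}


# Hypothetical dataset sizes (number of examples), for illustration only.
sizes = {"big": 2**22, "mid": 2**18, "small": 2**14}
capped = examples_proportional(sizes, cap=2**18)  # "big" is clipped to "mid"
warm = temperature_scaled(sizes, temperature=2.0)  # flatter than T = 1
```

<p>With the cap at $2^{18}$, the hypothetical "big" and "mid" tasks receive identical rates, which shows how $K$ keeps a very large dataset from dominating the mixture.</p>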
<h3 id="results">Results</h3>
<table>
  <thead>
      <tr>
          <th>Mixing strategy</th>
          <th>GLUE</th>
          <th>CNN/DM</th>
          <th>SQuAD</th>
          <th>SuperGLUE</th>
          <th>EnDe</th>
          <th>EnFr</th>
          <th>EnRo</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Baseline (pre-train/fine-tune)</td>
          <td>83.28</td>
          <td>19.24</td>
          <td>80.88</td>
          <td>71.36</td>
          <td>26.98</td>
          <td>39.82</td>
          <td>27.65</td>
      </tr>
      <tr>
          <td>Equal</td>
          <td>76.13</td>
          <td>19.02</td>
          <td>76.51</td>
          <td>63.37</td>
          <td>23.89</td>
          <td>34.31</td>
          <td>26.78</td>
      </tr>
      <tr>
          <td>Examples-proportional, $K=2^{18}$</td>
          <td>81.67</td>
          <td>19.07</td>
          <td>78.17</td>
          <td>67.94</td>
          <td>24.57</td>
          <td>35.19</td>
          <td>27.39</td>
      </tr>
      <tr>
          <td>Examples-proportional, $K=2^{19}$</td>
          <td>81.42</td>
          <td>19.24</td>
          <td>79.78</td>
          <td>67.30</td>
          <td>25.21</td>
          <td>36.30</td>
          <td>27.76</td>
      </tr>
      <tr>
          <td>Temperature-scaled, $T=2$</td>
          <td>81.90</td>
          <td>19.28</td>
          <td>79.42</td>
          <td>69.92</td>
          <td>25.42</td>
          <td>36.72</td>
          <td>27.20</td>
      </tr>
  </tbody>
</table>
<p><strong>Key findings on mixing:</strong></p>
<ol>
<li>
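<p><strong>Multi-task training underperforms pre-train-then-fine-tune on most tasks.</strong> No mixing strategy matches the baseline of unsupervised pre-training followed by task-specific fine-tuning across the board. The only cases in the table where a mixture edges past the baseline are marginal: CNN/DM at $T=2$ (19.28 vs. 19.24) and EnRo at $K=2^{19}$ (27.76 vs. 27.65).</p>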
</li>
<li>
<p><strong>Equal mixing is worst.</strong> It dramatically degrades performance, confirming that proportions matter.</p>
</li>
<li>
<p><strong>There exists a task-specific sweet spot for the cap $K$.</strong> Most tasks have an optimal $K$ value; larger or smaller values hurt. The exception is very high-resource tasks (WMT English-French) that always benefit from higher mixing proportions.</p>
</li>
<li>
<p><strong>Temperature scaling at $T=2$ provides the best single compromise.</strong> It achieves reasonable performance across all tasks without requiring per-task tuning of $K$.</p>
</li>
<li>
<p><strong>Multi-task pre-training followed by fine-tuning closes the gap.</strong> When multi-task training is used as pre-training (not as the final training stage), followed by task-specific fine-tuning, performance becomes comparable to unsupervised pre-training alone. This suggests that multi-task exposure during pre-training provides useful early signal without the negative effects of forcing a single model to perform all tasks simultaneously.</p>
</li>
<li>
<p><strong>&ldquo;Leave-one-out&rdquo; training works.</strong> Pre-training on a multi-task mixture that excludes a target task, then fine-tuning on it, produces only slightly worse results. This indicates that multi-task pre-training builds general capabilities that transfer to unseen tasks without dramatic task interference.</p>
</li>
</ol>
<h2 id="data-repetition-degrades-performance">Data repetition degrades performance</h2>
<p>The paper also systematically tests the effect of pre-training data set size by truncating C4 and training over repeated data:</p>
<table>
  <thead>
      <tr>
          <th>Unique tokens</th>
          <th>Repeats</th>
          <th>GLUE</th>
          <th>SQuAD</th>
          <th>SuperGLUE</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Full dataset</td>
          <td>0</td>
          <td>83.28</td>
          <td>80.88</td>
          <td>71.36</td>
      </tr>
      <tr>
          <td>$2^{29}$</td>
          <td>64</td>
          <td>82.87</td>
          <td>80.97</td>
          <td>72.03</td>
      </tr>
      <tr>
          <td>$2^{27}$</td>
          <td>256</td>
          <td>82.62</td>
          <td>79.78</td>
          <td>69.97</td>
      </tr>
      <tr>
          <td>$2^{25}$</td>
          <td>1,024</td>
          <td>79.55</td>
          <td>76.27</td>
          <td>64.76</td>
      </tr>
      <tr>
          <td>$2^{23}$</td>
          <td>4,096</td>
          <td>76.34</td>
          <td>70.92</td>
          <td>59.29</td>
      </tr>
  </tbody>
</table>
<p>Performance degrades as data shrinks, with 64 repeats showing limited effects but 1,024+ repeats causing significant degradation. Training loss curves at high repetition counts are consistent with memorization of the truncated data. The paper recommends using large, diverse pre-training datasets whenever possible.</p>
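<p>The repeat counts in the table follow from a fixed token budget divided by the truncated corpus size. The short check below assumes the baseline pre-training budget is $2^{35}$ tokens ($2^{19}$ steps with 128 sequences of 512 tokens per batch); the batch size is taken from the paper's baseline setup and is not stated elsewhere in these notes.</p>

```python
# Assumed baseline budget: 2**19 steps x 128 sequences x 512 tokens = 2**35 tokens.
BASELINE_TOKENS = 2**19 * 128 * 512


def repeats(unique_tokens):
    """Number of passes over a truncated corpus within the fixed token budget."""
    return BASELINE_TOKENS // unique_tokens


# Reproduce the "Repeats" column for each truncated dataset size.
table = {f"2^{p}": repeats(2**p) for p in (29, 27, 25, 23)}
```

<p>Under that assumption the four truncated sizes give 64, 256, 1,024, and 4,096 passes, matching the table.</p>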
<h2 id="scaling-and-final-configuration">Scaling and final configuration</h2>
<p>The paper compares scaling strategies: more data, larger models, and ensembles. Training a larger model for fewer steps generally outperforms training a smaller model on more data. Ensembles of independently pre-trained and fine-tuned models provide orthogonal gains.</p>
<p>The final T5-11B model combines the best choices from all ablations: encoder-decoder architecture, span corruption objective, C4 pre-training data, multi-task pre-training followed by fine-tuning, and scaling to 11B parameters trained on over 1 trillion tokens. It achieves state-of-the-art results on GLUE (90.3 average), SuperGLUE (88.9, near human performance of 89.8), SQuAD, and CNN/Daily Mail. It does not achieve state-of-the-art on WMT translation tasks, where methods using backtranslation and cross-lingual pre-training retain the lead.</p>
<h2 id="implications-and-limitations">Implications and limitations</h2>
<p>The T5 paper&rsquo;s multi-task mixing findings are its most enduring contribution beyond the model itself. The core lessons: proportions matter enormously (equal mixing fails), examples-proportional mixing with a cap is a reasonable default, temperature scaling provides a single-knob alternative, and multi-task pre-training followed by fine-tuning can match pure unsupervised pre-training.</p>
<p><strong>Limitations:</strong></p>
<ul>
<li>All ablations use the same encoder-decoder architecture. Findings may not transfer to decoder-only models that dominate current practice.</li>
<li>The multi-task mixing experiments treat each task as a separate &ldquo;domain.&rdquo; Interactions between similar tasks (e.g., multiple classification tasks) are not isolated.</li>
<li>The paper does not provide a principled method for choosing $K$ or $T$; both require empirical search.</li>
<li>C4 has known quality issues (templated text, noisy content) that have been addressed in later datasets.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<p><strong>Status: Highly Reproducible.</strong> Code, pre-trained models, and the C4 dataset are all publicly released.</p>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pre-training</td>
          <td>C4 (Colossal Clean Crawled Corpus)</td>
          <td>~750 GB</td>
          <td>Heuristically cleaned Common Crawl</td>
      </tr>
      <tr>
          <td>Downstream</td>
          <td>GLUE, SuperGLUE, SQuAD, CNN/DM, WMT (EnDe, EnFr, EnRo)</td>
          <td>Standard splits</td>
          <td>Text-to-text format</td>
      </tr>
  </tbody>
</table>
<h3 id="models">Models</h3>
<p>Encoder-decoder Transformer. Sizes: Small (60M), Base (220M), Large (770M), 3B, 11B. Baseline uses Base size. SentencePiece vocabulary with 32K tokens. Pre-trained for $2^{19}$ steps, fine-tuned for $2^{18}$ steps on individual tasks.</p>
<h3 id="algorithms">Algorithms</h3>
<p>Multi-task mixing: examples-proportional with cap $K \in \{2^{16}, \ldots, 2^{21}\}$, temperature-scaled with $T \in \{2, 4, 8\}$, and equal mixing. Unsupervised objective: span corruption (mean span length 3, 15% corruption rate). Training with Adafactor optimizer, inverse square root learning rate schedule.</p>
<h3 id="hardware">Hardware</h3>
<p>All models trained using Mesh TensorFlow on TPU slices. T5-11B pre-trained for 1M steps with batch size $2^{11}$ sequences of length 512 (~1 trillion tokens total). Exact TPU pod configurations per experiment not detailed.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/google-research/text-to-text-transfer-transformer">T5 Code</a></td>
          <td>Code</td>
          <td>Apache 2.0</td>
          <td>Official TensorFlow implementation (JAX successor: T5X)</td>
      </tr>
      <tr>
          <td><a href="https://github.com/google-research/text-to-text-transfer-transformer#released-model-checkpoints">T5 Models</a></td>
          <td>Model</td>
          <td>Apache 2.0</td>
          <td>Pre-trained checkpoints (Small through 11B)</td>
      </tr>
      <tr>
          <td><a href="https://www.tensorflow.org/datasets/catalog/c4">C4 Dataset</a></td>
          <td>Dataset</td>
          <td>-</td>
          <td>~750 GB cleaned Common Crawl, via TensorFlow Datasets</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{raffel2020exploring,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Raffel, Colin and Shazeer, Noam and Roberts, Adam and Lee, Katherine and Narang, Sharan and Matena, Michael and Zhou, Yanqi and Li, Wei and Liu, Peter J.}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Journal of Machine Learning Research}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{21}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{140}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{1--67}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2020}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>SlimPajama-DC: Data Combinations for LLM Training</title><link>https://hunterheidenreich.com/notes/natural-language-processing/language-models/slimpajama-dc-data-combinations/</link><pubDate>Wed, 08 Apr 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/natural-language-processing/language-models/slimpajama-dc-data-combinations/</guid><description>Shen et al. study how global deduplication and domain combinations in SlimPajama affect LLM training, finding diversity after dedup is key.</description><content:encoded><![CDATA[<h2 id="an-empirical-study-of-data-domain-combinations">An empirical study of data domain combinations</h2>
<p>This is a <strong>discovery paper</strong> that empirically investigates how different combinations and proportions of data domains affect language model pretraining. Using the SlimPajama dataset (a globally deduplicated, 627B token refinement of RedPajama), the study trains seven 1.3B model configurations with varying domain mixtures to identify which combinations and deduplication strategies produce the best downstream performance.</p>
<h2 id="why-data-combination-strategy-matters">Why data combination strategy matters</h2>
<p>Multi-source pretraining datasets combine data from web crawls, code repositories, books, academic papers, and other sources. Two underexplored questions drive this work: (1) Does deduplication within each source (local) versus across all sources (global) meaningfully affect model quality? (2) When sources are thoroughly deduplicated, how does the combination and proportion of domains affect downstream performance? Most open-source LLM training datasets (RedPajama, The Pile) perform only local deduplication, leaving cross-source redundancy unaddressed.</p>
<h2 id="global-deduplication-and-the-slimpajama-dataset">Global deduplication and the SlimPajama dataset</h2>
<p>SlimPajama applies global MinHashLSH deduplication (Jaccard similarity threshold 0.8, 13-gram signatures) across all seven data sources simultaneously. This reduces RedPajama&rsquo;s 1.2T tokens to 627B tokens, a roughly 48% reduction. The heaviest deduplication hits CommonCrawl and GitHub, which had the most cross-source overlap.</p>
<p>The key processing steps:</p>
<ol>
<li><strong>Low-length document filtering</strong>: Remove documents below a minimum length threshold.</li>
<li><strong>Global deduplication</strong>: MinHashLSH across all sources simultaneously, requiring 64 CPU cores and 1.4TB peak memory. This removes both within-source and between-source duplicates.</li>
</ol>
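<p>The global deduplication step can be illustrated with a small in-memory MinHash and LSH sketch. It is a toy under stated assumptions, not the SlimPajama pipeline: the character-level shingling, the BLAKE2b hash family, the 16 bands of 8 rows, and all function names are choices made here for illustration. Only the 13-gram width and the 0.8 threshold come from the description above.</p>

```python
import hashlib
import re


def shingles(text, n=13):
    """Lowercased character n-grams of the whitespace-normalized text."""
    norm = re.sub(r"\s+", " ", text.lower()).strip()
    if n >= len(norm):
        return {norm}
    return {norm[i:i + n] for i in range(len(norm) - n + 1)}


def minhash(shingle_set, num_perm=128):
    """One minimum per salted hash function; equal sets give equal signatures."""
    signature = []
    for seed in range(num_perm):
        salt = seed.to_bytes(8, "little")
        signature.append(min(
            hashlib.blake2b(s.encode("utf-8"), digest_size=8, salt=salt).digest()
            for s in shingle_set
        ))
    return signature


def estimated_jaccard(sig_a, sig_b):
    """Fraction of signature positions that agree."""
    matches = sum(1 for a, b in zip(sig_a, sig_b) if a == b)
    return matches / len(sig_a)


def dedup_global(docs, threshold=0.8, num_perm=128, bands=16):
    """Keep the first copy of each near-duplicate across ALL sources.

    docs is an iterable of (source, text) pairs. A single shared LSH index
    is what makes the deduplication global instead of per-source.
    """
    rows = num_perm // bands
    buckets = {}  # (band index, band values) -> indices into kept
    kept = []     # (source, text, signature)
    for source, text in docs:
        sig = minhash(shingles(text), num_perm)
        keys = [(b, tuple(sig[b * rows:(b + 1) * rows])) for b in range(bands)]
        candidates = {i for key in keys for i in buckets.get(key, [])}
        if any(estimated_jaccard(sig, kept[i][2]) >= threshold for i in candidates):
            continue  # near-duplicate of a document that was already kept
        kept.append((source, text, sig))
        for key in keys:
            buckets.setdefault(key, []).append(len(kept) - 1)
    return [(source, text) for source, text, _ in kept]
```

<p>Running the same function separately on each source would reproduce local deduplication; passing all sources through one call is what removes between-source duplicates as well.</p>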
<p>The resulting dataset composition:</p>
<table>
  <thead>
      <tr>
          <th>Source</th>
          <th>SlimPajama</th>
          <th>RedPajama</th>
          <th>LLaMA 1</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>CommonCrawl</td>
          <td>52.2% (333B)</td>
          <td>72.6% (878B)</td>
          <td>67.0%</td>
      </tr>
      <tr>
          <td>C4</td>
          <td>26.7% (170B)</td>
          <td>14.4% (175B)</td>
          <td>15.0%</td>
      </tr>
      <tr>
          <td>GitHub</td>
          <td>5.2% (33B)</td>
          <td>4.9% (59B)</td>
          <td>4.5%</td>
      </tr>
      <tr>
          <td>Books</td>
          <td>4.2% (27B)</td>
          <td>2.1% (26B)</td>
          <td>4.5%</td>
      </tr>
      <tr>
          <td>ArXiv</td>
          <td>4.6% (29B)</td>
          <td>2.3% (28B)</td>
          <td>2.5%</td>
      </tr>
      <tr>
          <td>Wikipedia</td>
          <td>3.8% (24B)</td>
          <td>2.0% (24B)</td>
          <td>4.5%</td>
      </tr>
      <tr>
          <td>StackExchange</td>
          <td>3.3% (21B)</td>
          <td>1.7% (20B)</td>
          <td>2.0%</td>
      </tr>
  </tbody>
</table>
<h2 id="seven-domain-combination-configurations">Seven domain combination configurations</h2>
<p>All configurations train 1.3B parameter models on 330B tokens with identical architecture and hyperparameters. The configurations systematically vary domain diversity:</p>
<ul>
<li><strong>DC-1</strong>: CommonCrawl only (single source)</li>
<li><strong>DC-2</strong>: CommonCrawl + C4 (two web sources)</li>
<li><strong>DC-3</strong>: CommonCrawl + C4 with adjusted proportions</li>
<li><strong>DC-4</strong>: Wikipedia + Books + GitHub + ArXiv + StackExchange (no web crawl)</li>
<li><strong>DC-5</strong>: CommonCrawl + C4 + Wikipedia + Books (four sources, no code/academic)</li>
<li><strong>DC-6</strong>: All seven SlimPajama sources (maximum diversity)</li>
<li><strong>DC-7</strong>: RefinedWeb CommonCrawl (external single-source baseline)</li>
</ul>
<p>The experimental design probes: incremental diversity (DC-1 to DC-2 to DC-5 to DC-6), proportion sensitivity (DC-2 vs DC-3), source importance (DC-3 vs DC-4), and specialization vs generalization (individual vs combined).</p>
<h2 id="diversity-after-global-deduplication-drives-performance">Diversity after global deduplication drives performance</h2>
<h3 id="hugging-face-leaderboard-results">Hugging Face leaderboard results</h3>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Average</th>
          <th>ARC</th>
          <th>HellaSwag</th>
          <th>MMLU</th>
          <th>TruthfulQA</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>RedPajama-1.3B</td>
          <td>38.0</td>
          <td>37.2</td>
          <td>55.8</td>
          <td>24.9</td>
          <td>34.3</td>
      </tr>
      <tr>
          <td>DC-1 (CC only)</td>
          <td>38.5</td>
          <td>36.3</td>
          <td>56.0</td>
          <td>27.0</td>
          <td>34.8</td>
      </tr>
      <tr>
          <td>DC-4 (no web)</td>
          <td>37.6</td>
          <td>33.4</td>
          <td>53.3</td>
          <td>26.0</td>
          <td>37.6</td>
      </tr>
      <tr>
          <td>DC-6 (all sources)</td>
          <td>40.0</td>
          <td>33.7</td>
          <td>61.0</td>
          <td>26.9</td>
          <td>38.4</td>
      </tr>
      <tr>
          <td>DC-7 (RefinedWeb)</td>
          <td>41.0</td>
          <td>35.1</td>
          <td>64.7</td>
          <td>26.2</td>
          <td>37.9</td>
      </tr>
  </tbody>
</table>
<p><strong>Key patterns:</strong></p>
<ol>
<li>
<p><strong>More domain diversity improves average performance.</strong> The progression DC-1 (38.5) to DC-2 (38.4) to DC-5 (38.6) to DC-6 (40.0) is roughly flat across the first three steps and then jumps at DC-6, so the gain from diversity shows up mainly once all seven globally deduplicated sources are combined.</p>
</li>
<li>
<p><strong>Global deduplication enables clean combination.</strong> All SlimPajama configurations except DC-4 outperform RedPajama-1.3B (38.0), which uses local deduplication only. The elimination of cross-source overlap means adding sources contributes genuinely new information.</p>
</li>
<li>
<p><strong>Removing web crawl data hurts.</strong> DC-4 (no CommonCrawl/C4) scores lowest (37.6), demonstrating that web text provides essential breadth even when specialized sources are included.</p>
</li>
<li>
<p><strong>Individual domains excel at specific tasks.</strong> DC-1 (CC only) achieves the highest MMLU score in the table (27.0) and the highest ARC score among the DC rows (36.3), though RedPajama-1.3B is higher on ARC (37.2). DC-4 leads on Winogrande. DC-5 leads on WSC273. No single combination dominates all tasks, reinforcing that diversity trades specialization for generalization.</p>
</li>
<li>
<p><strong>Findings transfer to 7B scale.</strong> The best 1.3B configuration insights were applied to a 7B model trained with large batch sizes, achieving 63.4 average accuracy across the extended benchmark suite.</p>
</li>
</ol>
<h3 id="training-loss-patterns">Training loss patterns</h3>
<p>DC-6 (all sources) achieves the lowest training loss among SlimPajama configurations, consistent with the downstream results. DC-4 (no web crawl) shows the highest training loss, confirming that the large, diverse web crawl data is the most important single component.</p>
<h2 id="implications-and-limitations">Implications and limitations</h2>
<p>The central finding is that <strong>diversity matters most after deduplication</strong>. When cross-source redundancy is removed, each additional source contributes genuinely new signal. Without global deduplication, adding sources may just increase redundancy without proportional benefit.</p>
<p><strong>Limitations:</strong></p>
<ul>
<li>Only seven fixed configurations are tested. No systematic search over continuous mixture proportions (contrast with <a href="/notes/natural-language-processing/language-models/doremi-data-mixture-optimization/">DoReMi</a> or <a href="/notes/natural-language-processing/language-models/data-mixing-laws-pretraining/">Data Mixing Laws</a>).</li>
<li>The configurations are not independent: DC-6 includes all sources from DC-1 through DC-5, making it difficult to isolate the contribution of any single addition.</li>
<li>Only 1.3B and 7B scales tested. Whether the diversity benefit continues scaling is unverified.</li>
<li>English-only. Cross-lingual diversity effects are not studied.</li>
<li>The paper is a technical report without formal peer review.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<p><strong>Status: Highly Reproducible.</strong> All 1.3B models and datasets are publicly released under MIT license on HuggingFace.</p>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training</td>
          <td>SlimPajama</td>
          <td>627B tokens</td>
          <td>Globally deduplicated from 1.2T RedPajama</td>
      </tr>
      <tr>
          <td>Training</td>
          <td>RefinedWeb</td>
          <td>600B tokens</td>
          <td>External CC-only baseline</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>HF Leaderboard (ARC, HellaSwag, MMLU, TruthfulQA)</td>
          <td>Standard</td>
          <td>4 benchmarks</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>Extended suite</td>
          <td>12 additional benchmarks</td>
          <td>Zero and few-shot</td>
      </tr>
  </tbody>
</table>
<h3 id="models">Models</h3>
<p>1.3B parameter Cerebras-GPT architecture with ALiBi positional encoding and SwiGLU activation. All configurations trained on 330B tokens. 7B model trained with large batch-size (LBS) strategy on Cerebras 16x CS-2 cluster (80 PFLOP/s in bf16).</p>
<h3 id="hardware">Hardware</h3>
<p>Cerebras 16x CS-2 cluster, 80 PFLOP/s in bf16 mixed precision.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://huggingface.co/MBZUAI-LLM/SlimPajama-DC">SlimPajama-DC Models</a></td>
          <td>Model</td>
          <td>MIT</td>
          <td>All 1.3B DC configurations (select via revision)</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/datasets/MBZUAI-LLM/SlimPajama-627B-DC">SlimPajama-627B-DC Dataset</a></td>
          <td>Dataset</td>
          <td>-</td>
          <td>Source-split version of SlimPajama-627B</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{shen2023slimpajamadc,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{SlimPajama-DC: Understanding Data Combinations for LLM Training}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Shen, Zhiqiang and Tao, Tianhua and Ma, Liqun and Neiswanger, Willie and Liu, Zhengzhong and Wang, Hongyi and Tan, Bowen and Hestness, Joel and Vassilieva, Natalia and Soboleva, Daria and Xing, Eric}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{arXiv preprint arXiv:2309.10818}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Scaling Data-Constrained Language Models</title><link>https://hunterheidenreich.com/notes/natural-language-processing/language-models/scaling-data-constrained-language-models/</link><pubDate>Wed, 08 Apr 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/natural-language-processing/language-models/scaling-data-constrained-language-models/</guid><description>Muennighoff et al. extend Chinchilla scaling laws to repeated data, finding up to 4 epochs cause negligible loss and 16 epochs mark diminishing returns.</description><content:encoded><![CDATA[<h2 id="an-empirical-study-of-scaling-under-data-constraints">An empirical study of scaling under data constraints</h2>
<p>This is a <strong>discovery paper</strong> that systematically investigates what happens when language models are trained for multiple epochs on repeated data. It extends the Chinchilla scaling laws to the data-constrained regime by proposing a new scaling formula that accounts for the diminishing value of repeated tokens, validated across 400+ training runs ranging from 10M to 9B parameters and up to 1500 epochs.</p>
<h2 id="running-out-of-unique-training-data">Running out of unique training data</h2>
<p>The Chinchilla scaling laws assume unlimited unique data: for a given compute budget, there exists an optimal balance of model parameters and training tokens. But extrapolating these laws to larger models implies data requirements that exceed what is available. Villalobos et al. estimated that high-quality English text would be exhausted by 2024 under Chinchilla-optimal scaling. Most prior large language models trained for a single epoch, and some work explicitly warned against data reuse. The Galactica models (trained for 4.25 epochs) showed that multi-epoch training could work, but no systematic study had quantified the tradeoff between repeated data and fresh data, or how to allocate compute optimally when data is finite.</p>
<h2 id="effective-data-with-exponential-decay-for-repetition">Effective data with exponential decay for repetition</h2>
<p>The paper generalizes the Chinchilla scaling law by replacing raw token count $D$ with an effective data term $D'$ that accounts for the diminishing value of repeated tokens:</p>
<p>$$
L(N, D) = \frac{A}{N'^{\alpha}} + \frac{B}{D'^{\beta}} + E
$$</p>
<p>where the effective data is:</p>
<p>$$
D' = U_{D} + U_{D} R_{D}^{*} \left(1 - e^{-R_{D}/R_{D}^{*}}\right)
$$</p>
<p>Here $U_{D}$ is the number of unique tokens, $R_{D}$ is the number of repetitions (epochs minus 1), and $R_{D}^{*}$ is a learned constant representing the &ldquo;half-life&rdquo; of data repetition. When $R_{D} = 0$ (single epoch), $D' = U_{D} = D$ and the formula reduces to standard Chinchilla. When $R_{D} \ll R_{D}^{*}$, repeated data is worth almost the same as fresh data. As $R_{D}$ grows large, the value of repeated tokens decays to zero, and $D'$ saturates at $U_{D}(1 + R_{D}^{*})$, meaning no amount of repetition can substitute for more than $R_{D}^{*}$ epochs&rsquo; worth of fresh data.</p>
<p>A symmetric formula handles excess parameters:</p>
<p>$$
N' = U_{N} + U_{N} R_{N}^{*} \left(1 - e^{-R_{N}/R_{N}^{*}}\right)
$$</p>
<p>where $U_{N}$ is the compute-optimal parameter count for $U_{D}$ unique tokens and $R_{N}$ measures how much the model exceeds that count. The fitted values are $R_{D}^{*} \approx 15.0$ (data repetition half-life at ~16 epochs) and $R_{N}^{*} \approx 5.3$ (excess parameters decay faster than repeated data).</p>
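<p>The effective-data formula is easy to evaluate directly. The sketch below plugs in the fitted constants quoted above; the function names are hypothetical, and the loss coefficients $A$, $B$, $E$, $\alpha$, and $\beta$ are left out because they are not given in these notes.</p>

```python
import math

R_D_STAR = 15.0  # fitted constant for repeated data, as quoted above
R_N_STAR = 5.3   # fitted constant for excess parameters, as quoted above


def effective_amount(unique, repeats, r_star):
    """U + U * R* * (1 - exp(-R / R*)); the same form serves D' and N'."""
    return unique + unique * r_star * (1.0 - math.exp(-repeats / r_star))


def effective_data(unique_tokens, epochs):
    """Effective tokens D' after `epochs` passes, where R_D = epochs - 1."""
    return effective_amount(unique_tokens, epochs - 1, R_D_STAR)


# Effective data per unique token for a few epoch counts.
worth = {epochs: effective_data(1.0, epochs) for epochs in (1, 4, 16, 64)}
```

<p>Under these constants, 4 epochs are worth about 3.7 epochs of fresh data and 16 epochs about 10.5, while the ceiling of $1 + R_{D}^{*} = 16$ is approached only asymptotically.</p>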
<h2 id="experiments-across-400-models">Experiments across 400+ models</h2>
<p><strong>Scale.</strong> Models from 10M to 9B parameters, trained for up to 1500 epochs. Three experimental protocols: fixed unique data (100M, 400M, 1.5B tokens), fixed FLOPs, and parametric fitting across all runs. Training on C4 (English web text) with GPT-2 architecture decoder-only transformers.</p>
<h3 id="resource-allocation-epochs-scale-faster-than-parameters">Resource allocation: epochs scale faster than parameters</h3>
<p>With fixed unique data, results show that more than 50% loss reduction is possible by training beyond one epoch and increasing model size beyond the single-epoch optimum. The data-constrained efficient frontier recommends allocating most additional compute to more epochs rather than more parameters, because excess parameters decay faster ($R_{N}^{*} &lt; R_{D}^{*}$). This contrasts with Chinchilla, which recommends scaling both equally.</p>
<p>A concrete validation: training the data-constrained compute-optimal model for $9.3 \times 10^{21}$ FLOPs with 25B unique tokens, the recommended allocation (27% fewer parameters, more epochs) achieves better loss and downstream performance than the Chinchilla-optimal allocation.</p>
<h3 id="resource-return-the-4-epoch-safe-zone-and-16-epoch-half-life">Resource return: the 4-epoch safe zone and 16-epoch half-life</h3>
<table>
  <thead>
      <tr>
          <th>Epochs</th>
          <th>Loss impact</th>
          <th>Downstream impact</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>1 (baseline)</td>
          <td>Optimal</td>
          <td>Optimal</td>
      </tr>
      <tr>
          <td>Up to 4</td>
          <td>Negligible (+0.5% loss)</td>
          <td>No significant difference</td>
      </tr>
      <tr>
          <td>~16 ($R_{D}^{*}$)</td>
          <td>Diminishing returns begin sharply</td>
          <td>Measurable degradation</td>
      </tr>
      <tr>
          <td>Beyond 16</td>
          <td>Returns decay to near zero</td>
          <td>Significant degradation</td>
      </tr>
      <tr>
          <td>Extreme (44+)</td>
          <td>Training can diverge</td>
          <td>Failure</td>
      </tr>
  </tbody>
</table>
<p>The 8.7B parameter model trained for 4 epochs ($D_{C} = 44$B unique tokens) finishes with only 0.5% higher validation loss than the single-epoch model ($D_{C} = 178$B unique tokens). At the half-life point ($R_{D} = R_{D}^{*}$, roughly 16 epochs), the formula implies that the repeated tokens have together contributed only $1 - 1/e \approx 63\%$ of what the same number of fresh tokens would have, and the marginal value of one more repeated token has fallen to $1/e \approx 37\%$ of a fresh token.</p>
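<p>As a worked step, and assuming the effective-data formula $D' = U_{D} + U_{D} R_{D}^{*}(1 - e^{-R_{D}/R_{D}^{*}})$ used in these notes, the marginal and average value of repeated tokens follow by differentiating with respect to the number of repeated tokens $U_{D} R_{D}$ and by dividing the repeated contribution by that same count:</p>
<p>$$
\frac{\partial D'}{\partial (U_{D} R_{D})} = e^{-R_{D}/R_{D}^{*}}, \qquad \frac{D' - U_{D}}{U_{D} R_{D}} = \frac{R_{D}^{*}}{R_{D}} \left(1 - e^{-R_{D}/R_{D}^{*}}\right)
$$</p>
<p>Setting $R_{D} = R_{D}^{*}$ gives $1/e$ for the first expression and $1 - 1/e$ for the second.</p>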
<h3 id="complementary-strategies-code-augmentation-and-filtering">Complementary strategies: code augmentation and filtering</h3>
<p>When data is limited, two strategies can extend the effective dataset:</p>
<p><strong>Code augmentation.</strong> Mixing Python code from The Stack with natural language data. Up to 50% code (42B tokens) shows no degradation on natural language benchmarks, effectively providing a 2x increase in useful training data. Some tasks (WebNLG generation, bAbI reasoning) actually improve with code, possibly because code trains long-range state-tracking capabilities.</p>
<p><strong>Filtering relaxation.</strong> Perplexity filtering (keeping the 25% lowest-perplexity samples) is effective on noisy datasets, but deduplication filtering does not improve downstream performance (though it may reduce memorization). The recommendation: reserve aggressive filtering for noisy data sources; for clean datasets, more data through reduced filtering is better than less data through strict filtering.</p>
<p><strong>Combined strategy</strong>: doubling available data with code and then repeating for 4 epochs yields 8x more training tokens with performance expected to match 8x more unique data.</p>
<h2 id="key-findings-and-limitations">Key findings and limitations</h2>
<p><strong>Key findings:</strong></p>
<ul>
<li>Multi-epoch training is beneficial, not harmful, up to moderate repetition counts.</li>
<li>The data-constrained scaling law accurately predicts loss under repetition using an exponential decay formulation.</li>
<li>Compute should be allocated to epochs faster than parameters when data is constrained.</li>
<li>Code augmentation and selective filtering extend effective data without quality degradation.</li>
</ul>
<p><strong>Limitations:</strong></p>
<ul>
<li>All experiments use the GPT-2 transformer architecture; applicability to other architectures or modalities is untested.</li>
<li>Only the entire dataset is repeated uniformly. Selectively repeating subsets (e.g., high-value data for more epochs) is not modeled.</li>
<li>Hyperparameter sensitivity (learning rate, dropout) to epoch count is unexplored. Higher learning rates may cause earlier onset of diminishing returns.</li>
<li>Focused on English text. Cross-lingual augmentation effects are not studied.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<p><strong>Status: Highly Reproducible.</strong> Code, models, datasets, and hyperparameters are all publicly released under Apache 2.0.</p>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training</td>
          <td>C4 (English)</td>
          <td>Varies by experiment</td>
          <td>Fixed unique data: 100M, 400M, 1.5B tokens</td>
      </tr>
      <tr>
          <td>Code augmentation</td>
          <td>The Stack (Python)</td>
          <td>Up to 42B tokens</td>
          <td>Mixed with natural language</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>19 NL tasks</td>
          <td>Standard splits</td>
          <td>Zero to five-shot, 114 scores per model</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<p>Data-constrained scaling law: $D' = U_{D} + U_{D} R_{D}^{*}(1 - e^{-R_{D}/R_{D}^{*}})$ with $R_{D}^{*} \approx 15.0$, $R_{N}^{*} \approx 5.3$. Fitted using the methodology of Hoffmann et al. (2022) adapted for the repetition terms. 400+ training runs used for fitting.</p>
<h3 id="models">Models</h3>
<p>GPT-2 architecture decoder-only transformers with GPT-2 tokenizer. Sizes: 10M to 8.7B parameters. Cosine learning rate schedule (max 2e-4, decay to 2e-5), Adam optimizer ($\beta_2 = 0.999$), dropout 0.1, weight decay 0.1, gradient clipping at 1.0. bfloat16 precision. Trained using Megatron-DeepSpeed.</p>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Data-Constrained Optimal</th>
          <th>Chinchilla Optimal</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Validation loss (9.3e21 FLOPs, 25B unique)</td>
          <td>Lower</td>
          <td>Higher</td>
          <td>27% fewer parameters</td>
      </tr>
      <tr>
          <td>Downstream (4 epochs vs 1)</td>
          <td>No significant difference</td>
          <td>Baseline</td>
          <td>8.7B params, 44B unique tokens</td>
      </tr>
      <tr>
          <td>Code augmentation (50% code)</td>
          <td>No NL degradation</td>
          <td>Baseline</td>
          <td>Some tasks improve</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Trained on the LUMI supercomputer (Finland) using AMD Instinct MI250X GPUs with data, tensor, and pipeline parallelism. Up to 256 GPUs (64 nodes) per run, with up to 2,200 nodes (~8,800 GPUs) used in parallel across all concurrent runs. Total compute: approximately 3 million GPU hours. The cluster runs on 100% renewable hydroelectric energy.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/huggingface/datablations">datablations</a></td>
          <td>Code + Models + Data</td>
          <td>Apache 2.0</td>
          <td>All 400+ models, datasets, and training code</td>
      </tr>
      <tr>
          <td><a href="https://github.com/TurkuNLP/Megatron-DeepSpeed">Megatron-DeepSpeed fork</a></td>
          <td>Code</td>
          <td>-</td>
          <td>Training framework adapted for AMD ROCm</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{muennighoff2023scaling,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Scaling Data-Constrained Language Models}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Muennighoff, Niklas and Rush, Alexander M. and Barak, Boaz and Le Scao, Teven and Piktus, Aleksandra and Tazi, Nouamane and Pyysalo, Sampo and Wolf, Thomas and Raffel, Colin}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Advances in Neural Information Processing Systems}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{36}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>DoReMi: Optimizing Data Mixtures for LM Pretraining</title><link>https://hunterheidenreich.com/notes/natural-language-processing/language-models/doremi-data-mixture-optimization/</link><pubDate>Wed, 08 Apr 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/natural-language-processing/language-models/doremi-data-mixture-optimization/</guid><description>DoReMi uses a small proxy model with distributionally robust optimization to learn domain weights that speed up large-scale language model pretraining by 2.6x.</description><content:encoded><![CDATA[<h2 id="a-method-for-automatic-domain-reweighting">A method for automatic domain reweighting</h2>
<p>This is a <strong>method paper</strong> that introduces Domain Reweighting with Minimax Optimization (DoReMi), an algorithm for automatically tuning the mixture proportions of pretraining data domains. Rather than relying on heuristics or expensive downstream-task-based tuning, DoReMi uses a small proxy model trained with <a href="https://en.wikipedia.org/wiki/Robust_optimization">group distributionally robust optimization (Group DRO)</a> to produce domain weights that transfer to much larger models.</p>
<h2 id="why-data-mixture-proportions-matter">Why data mixture proportions matter</h2>
<p>Language model pretraining datasets combine text from many domains: web crawls, Wikipedia, books, code, academic papers, and others. The mixture proportions (how much of each domain to include) significantly affect downstream performance, but existing approaches either set them by hand (<a href="https://en.wikipedia.org/wiki/The_Pile_(dataset)">The Pile</a> uses heuristic weights) or tune them against downstream tasks (GLaM/PaLM), which is expensive and risks overfitting to a specific evaluation set. No principled, task-agnostic method existed for determining mixture proportions.</p>
<h2 id="minimax-optimization-over-domain-excess-loss">Minimax optimization over domain excess loss</h2>
<p>DoReMi&rsquo;s core insight is to frame data mixture optimization as a minimax problem: find domain weights that minimize the worst-case excess loss across all domains. The algorithm has three steps.</p>
<p><strong>Step 1</strong>: Train a small reference model (280M parameters) on some default domain weights $\alpha_{\text{ref}}$ (e.g., proportional to raw token count).</p>
<p><strong>Step 2</strong>: Train a small proxy model $p_{\theta}$ using Group DRO, which solves the minimax objective:</p>
<p>$$
\min_{\theta} \max_{\alpha \in \Delta^{k}} \sum_{i=1}^{k} \alpha_{i} \cdot \left[ \frac{1}{\sum_{x \in D_{i}} |x|} \sum_{x \in D_{i}} \left( \ell_{\theta}(x) - \ell_{\text{ref}}(x) \right) \right]
$$</p>
<p>where $\ell_{\theta}(x) = -\log p_{\theta}(x)$ and $\ell_{\text{ref}}(x) = -\log p_{\text{ref}}(x)$. The excess loss $\ell_{\theta}(x) - \ell_{\text{ref}}(x)$ measures how much headroom the proxy has to improve on each example relative to the reference. The inner maximization upweights domains with high excess loss via exponentiated gradient ascent, while the outer minimization trains the proxy on those upweighted domains.</p>
<p>At each training step, the domain weights update as:</p>
<p>$$
\alpha_{t}' \leftarrow \alpha_{t-1} \exp(\eta \lambda_{t})
$$</p>
<p>where $\lambda_{t}[i]$ is the excess loss of domain $i$ (each token&rsquo;s excess is clipped at zero before averaging), $\eta$ is the step size, and the product is taken elementwise. The update is followed by renormalization and smoothing with the uniform distribution $u$: $\alpha_{t} \leftarrow (1-c)\frac{\alpha_{t}'}{\sum_{i} \alpha_{t}'[i]} + cu$, with $c = 10^{-3}$.</p>
<p>The final domain weights are the average over all training steps: $\bar{\alpha} = \frac{1}{T}\sum_{t=1}^{T} \alpha_{t}$.</p>
<p><strong>Step 3</strong>: Resample data according to $\bar{\alpha}$ and train the full-scale model using standard procedures.</p>
<p><strong>Iterated DoReMi</strong> extends this by running multiple rounds, using the previous round&rsquo;s optimized weights as the next round&rsquo;s reference weights. This converges within 3 rounds on the GLaM dataset.</p>
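<p>The per-step domain weight update described above can be sketched in a few lines of plain Python. This is an illustrative sketch under stated assumptions, not the released implementation: the function names are hypothetical, the per-token losses are made-up numbers, and the proxy model update (the outer minimization) is omitted entirely.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import math


def domain_excess_losses(proxy_losses, ref_losses):
    """Per-domain excess loss, clipping each token's excess at zero.

    Each argument is a list with one entry per domain; an entry holds the
    per-token losses of the current batch that fall in that domain.
    """
    lam = []
    for proxy, ref in zip(proxy_losses, ref_losses):
        clipped = [max(p - r, 0.0) for p, r in zip(proxy, ref)]
        lam.append(sum(clipped) / len(clipped) if clipped else 0.0)
    return lam


def update_domain_weights(alpha, lam, eta=1.0, c=1e-3):
    """One exponentiated-gradient step, then renormalize and smooth."""
    k = len(alpha)
    raised = [a * math.exp(eta * x) for a, x in zip(alpha, lam)]
    total = sum(raised)
    return [(1.0 - c) * r / total + c / k for r in raised]


def average_weights(history):
    """Final weights: the mean of the per-step weights."""
    steps = len(history)
    return [sum(step[i] for step in history) / steps
            for i in range(len(history[0]))]


# Hypothetical per-token losses for three domains (two tokens each).
proxy = [[2.0, 2.2], [3.1, 2.9], [1.0, 1.1]]
ref = [[2.0, 2.1], [2.5, 2.4], [1.2, 1.3]]

alpha = [1 / 3, 1 / 3, 1 / 3]
history = []
for _ in range(3):
    alpha = update_domain_weights(alpha, domain_excess_losses(proxy, ref))
    history.append(alpha)
final = average_weights(history)
print([round(w, 3) for w in final])
</code></pre></div>
<p>In this toy setup the second domain has the largest excess loss, so its weight grows at every step, while the third domain (where the proxy already beats the reference) is clipped to zero excess and shrinks after renormalization.</p>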
<h2 id="experiments-across-the-pile-and-glam-datasets">Experiments across The Pile and GLaM datasets</h2>
<p><strong>Datasets.</strong> The Pile (22 domains, 800GB) and the GLaM dataset (8 domains, also used for PaLM). On The Pile, baseline weights come from the dataset defaults. On GLaM, baseline weights are uniform, with downstream-tuned oracle weights available for comparison.</p>
<p><strong>Setup.</strong> Transformer decoder-only LMs trained with next-token prediction. All models use batch size 512 and sequence length 1024. Proxy and reference models are 280M parameters. Main models are 8B parameters (30x larger). Training runs: 200K steps (Pile) or 300K steps (GLaM). The domain weight optimization cost (training two 280M models) is 8% of the compute for the 8B main model.</p>
<p><strong>Evaluation.</strong> Per-domain held-out perplexity and one-shot generative accuracy on five tasks: TriviaQA, NaturalQuestions, WebQuestions, SQuADv2, and LAMBADA.</p>
<h3 id="key-domain-weight-shifts">Key domain weight shifts</h3>
<p>On The Pile, DoReMi (280M) dramatically upweights diverse web text (Pile-CC: 0.112 to 0.606) while downweighting specialized domains like ArXiv (0.105 to 0.004), PubMed Central (0.107 to 0.005), and StackExchange (0.093 to 0.015). Smaller, underrepresented domains like YouTubeSubtitles and PhilPapers receive proportionally large increases.</p>
<h3 id="scaling-behavior">Scaling behavior</h3>
<p>DoReMi was tested with matched proxy/main model sizes (280M through 1B) and with varying proxy sizes (70M through 1B) feeding into an 8B main model.</p>
<table>
  <thead>
      <tr>
          <th>Configuration</th>
          <th>Speedup to baseline accuracy</th>
          <th>Downstream improvement</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>DoReMi (280M to 280M)</td>
          <td>4x</td>
          <td>+2% avg accuracy</td>
      </tr>
      <tr>
          <td>DoReMi (280M to 8B)</td>
          <td>2.6x</td>
          <td>+6.5% avg accuracy</td>
      </tr>
      <tr>
          <td>DoReMi (150M to 8B)</td>
          <td>~2x</td>
          <td>Significant</td>
      </tr>
      <tr>
          <td>DoReMi (1B to 8B)</td>
          <td>~2x</td>
          <td>Significant</td>
      </tr>
  </tbody>
</table>
<p>Improvements are consistent across all tested model scales (280M to 1B matched), with no sign of diminishing returns at larger sizes.</p>
<h2 id="perplexity-improves-everywhere-even-on-downweighted-domains">Perplexity improves everywhere, even on downweighted domains</h2>
<p>The most striking finding is that DoReMi improves perplexity on all 22 domains in The Pile, including domains it downweights. The proposed explanation: the lowest-entropy domains need few samples to learn (they&rsquo;re statistically simple), while the highest-entropy domains have token distributions close to the uniform initialization and so also need few samples. Reallocating weight to medium-entropy domains generates positive transfer that lifts all domains.</p>
<p>On The Pile, DoReMi reaches the baseline&rsquo;s downstream accuracy in 75K steps versus 200K for the baseline (2.6x speedup) and achieves a 6.5% absolute improvement in average one-shot accuracy at 200K steps.</p>
<p>On the GLaM dataset, iterated DoReMi (round 2) matches the performance of domain weights that were tuned directly on downstream task performance, despite having no knowledge of downstream tasks. Domain weights converge within 3 iterations.</p>
<h3 id="ablations">Ablations</h3>
<p>Using only the proxy model&rsquo;s loss (prefer hardest domains) or only the negative reference loss (prefer easiest domains) both underperform the full excess loss formulation. Both components are necessary: the excess loss identifies domains where the proxy has room to improve relative to what is learnable.</p>
<p>The proxy model itself typically underperforms the main model trained on its weights, and this gap grows at larger proxy scales. A 1B proxy model underperforms the 1B baseline, yet its domain weights still improve 1B main model training by over 2x. This suggests the domain weight signal is robust even when the proxy model itself is not well-trained.</p>
<h3 id="limitations">Limitations</h3>
<p>The domain weight landscape may have multiple local optima: a 280M proxy puts most weight on Pile-CC, while a 1B proxy favors OpenWebText2 instead. Both configurations improve over baseline, but the optimal weights are not unique.</p>
<p>The granularity of &ldquo;domains&rdquo; matters. DoReMi works better with more domains (22 on The Pile versus 8 on GLaM). Domains are defined by data provenance, which is coarse-grained. Fine-grained domain definitions (e.g., via clustering) could improve results but also risk DRO putting all weight on a small set of worst-case examples.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pretraining</td>
          <td>The Pile</td>
          <td>800 GB, 22 domains</td>
          <td>Default heuristic weights as baseline</td>
      </tr>
      <tr>
          <td>Pretraining</td>
          <td>GLaM dataset</td>
          <td>8 domains</td>
          <td>Uniform weights as baseline; downstream-tuned oracle available</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>TriviaQA, NaturalQuestions, WebQuestions, SQuADv2, LAMBADA</td>
          <td>Standard splits</td>
          <td>One-shot generative evaluation</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<p>Group DRO with exponentiated gradient ascent for domain weight updates. Step size $\eta = 1$, smoothing $c = 10^{-3}$. Per-token excess loss clipped at zero. Domain weights averaged over all training steps. Iterated DoReMi converges when $\lVert \bar{\alpha} - \alpha_{\text{ref}} \rVert_{\infty} &lt; 10^{-3}$.</p>
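<p>The stopping rule for iterated DoReMi compares the optimized weights of a round with that round&rsquo;s reference weights in the sup norm. A minimal sketch of that check follows; the helper name and the example weight vectors are hypothetical.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">def sup_norm_gap(weights, ref_weights):
    """Largest absolute per-domain difference between two weight vectors."""
    return max(abs(w - r) for w, r in zip(weights, ref_weights))


TOLERANCE = 1e-3  # threshold quoted in the text

previous = [0.50, 0.30, 0.20]        # hypothetical reference weights
current = [0.5004, 0.2998, 0.1998]   # hypothetical optimized weights
gap = sup_norm_gap(current, previous)
print(f"gap = {gap:.4f}, tolerance = {TOLERANCE}")
</code></pre></div>
<p>A round would count as converged once the gap falls below the tolerance; otherwise the optimized weights become the next round&rsquo;s reference weights.</p>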
<h3 id="models">Models</h3>
<p>Vanilla Transformer decoder-only models with 256K vocabulary. Sizes: 70M (3 layers), 150M (6 layers), 280M (12 layers), 510M (12 layers), 760M (12 layers), 1B (16 layers), 8B (32 layers). All use 64-dim attention heads except 8B (128-dim).</p>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>DoReMi (280M to 8B)</th>
          <th>Baseline (8B)</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Avg one-shot accuracy</td>
          <td>+6.5% over baseline</td>
          <td>Reference</td>
          <td>5 generative tasks</td>
      </tr>
      <tr>
          <td>Worst-case log-perplexity</td>
          <td>1.46</td>
          <td>1.71</td>
          <td>Across 22 Pile domains</td>
      </tr>
      <tr>
          <td>Avg log-perplexity</td>
          <td>1.40</td>
          <td>1.64</td>
          <td>Across 22 Pile domains</td>
      </tr>
      <tr>
          <td>Domains beating baseline</td>
          <td>22/22</td>
          <td>0/22</td>
          <td>Per-domain perplexity</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Proxy and reference models (under 1B) trained on TPUv3. Models at 1B and 8B trained on TPUv4. Domain weight optimization (two 280M runs) costs 8% of 8B training FLOPs.</p>
<hr>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{xie2023doremi,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{DoReMi: Optimizing Data Mixtures Speeds Up Language Model Pretraining}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Xie, Sang Michael and Pham, Hieu and Dong, Xuanyi and Du, Nan and Liu, Hanxiao and Lu, Yifeng and Liang, Percy and Le, Quoc V. and Ma, Tengyu and Yu, Adams Wei}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Advances in Neural Information Processing Systems}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{36}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Data Mixing Laws for LM Pretraining Optimization</title><link>https://hunterheidenreich.com/notes/natural-language-processing/language-models/data-mixing-laws-pretraining/</link><pubDate>Wed, 08 Apr 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/natural-language-processing/language-models/data-mixing-laws-pretraining/</guid><description>Ye et al. discover that LM loss follows an exponential law over domain mixture proportions, enabling cheap prediction and optimization of data mixtures.</description><content:encoded><![CDATA[<h2 id="an-empirical-discovery-of-predictable-mixture-loss-relationships">An empirical discovery of predictable mixture-loss relationships</h2>
<p>This is a <strong>discovery paper</strong> that identifies a quantitative, functional relationship between pretraining data mixture proportions and language model loss. The key finding is that domain-specific validation loss follows an exponential law over the linear combination of training domain proportions, and this law composes with standard scaling laws to enable cheap prediction of large-model performance under arbitrary mixtures.</p>
<h2 id="the-missing-quantitative-link-between-data-mixtures-and-performance">The missing quantitative link between data mixtures and performance</h2>
<p>Pretraining data for large language models combines text from many domains (web, code, academic, books, etc.), and mixture proportions significantly affect model quality. Existing approaches either set proportions by hand without disclosed criteria (LLaMA, Baichuan) or use algorithmic methods like <a href="/notes/natural-language-processing/language-models/doremi-data-mixture-optimization/">DoReMi</a> that optimize qualitatively but cannot predict the quantitative effect of a specific mixture before training. Scaling laws exist for model size and data quantity, but no equivalent existed for mixture proportions. This paper fills that gap.</p>
<h2 id="the-exponential-data-mixing-law">The exponential data mixing law</h2>
<p>The core finding: for a model of fixed size trained for a fixed number of steps, the validation loss on domain $i$ as a function of the training mixture proportions $r_{1 \dots M}$ follows:</p>
<p>$$
L_{i}(r_{1 \dots M}) = c_{i} + k_{i} \exp\left(\sum_{j=1}^{M} t_{ij} r_{j}\right)
$$</p>
<p>where $c_{i}$, $k_{i}$, and $t_{ij}$ are fitted parameters. The constant $c_{i}$ represents the irreducible loss (not affected by mixture changes). The interaction coefficients $t_{ij}$ capture how training domain $j$ affects validation loss on domain $i$: negative $t_{ij}$ means domain $j$ helps domain $i$, positive means it hurts.</p>
<p>This was discovered progressively:</p>
<ol>
<li><strong>Two domains</strong>: Log-reducible-loss is linear in domain proportion (univariate exponential).</li>
<li><strong>Three domains</strong>: The exponential generalizes to a linear combination over all domain proportions (Eq. above), outperforming alternatives with comparable parameter count.</li>
<li><strong>General validation</strong>: For a validation set composed of $K$ domains with proportions $s_{1 \dots K}$, the overall loss is:</li>
</ol>
<p>$$
L(r_{1 \dots M}) = \sum_{i=1}^{K} s_{i} \left[ c_{i} + k_{i} \exp\left(\sum_{j=1}^{M} t_{ij} r_{j}\right) \right]
$$</p>
<p>When the validation set composition is unknown, implicit domain aggregation treats $s_{i}$ as learnable parameters. Setting the number of implicit domains larger than the true number works well and is robust to overestimation.</p>
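<p>Once the coefficients are fitted, evaluating the law and searching over candidate mixtures is cheap. The sketch below evaluates the per-domain and aggregated losses for two domains; every coefficient value is invented for illustration, and the coarse grid search stands in for whatever optimizer one would actually use.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import math


def domain_loss(c, k, t_row, mixture):
    """L_i(r) = c_i + k_i * exp(sum_j t_ij * r_j) for one validation domain."""
    return c + k * math.exp(sum(t * r for t, r in zip(t_row, mixture)))


def overall_loss(s, cs, ks, t_matrix, mixture):
    """Validation loss as the s-weighted sum of per-domain losses."""
    return sum(
        weight * domain_loss(c, k, row, mixture)
        for weight, c, k, row in zip(s, cs, ks, t_matrix)
    )


# Hypothetical fitted coefficients for two domains.
cs = [1.5, 2.0]
ks = [1.0, 0.8]
t_matrix = [
    [-2.0, 0.1],   # domain 0: helped by its own data, slightly hurt by domain 1
    [-0.3, -1.5],  # domain 1: helped by both training domains
]
s = [0.5, 0.5]     # validation set composition

# Coarse grid over two-domain mixtures (r_0, r_1) with r_0 + r_1 = 1.
candidates = [(i / 10, 1 - i / 10) for i in range(11)]
best = min(candidates, key=lambda r: overall_loss(s, cs, ks, t_matrix, r))
print(best, round(overall_loss(s, cs, ks, t_matrix, best), 4))
</code></pre></div>
<p>A negative $t_{ij}$ lowers the exponent as $r_{j}$ grows, which is how the sketch encodes facilitation; a positive entry encodes conflict.</p>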
<h3 id="domain-interaction-patterns">Domain interaction patterns</h3>
<p>Visualizing the fitted $t_{ij}$ coefficients across 5 coarse Pile domains reveals three relationship types: most domain pairs are <strong>unrelated</strong> (sparse interaction matrix where each domain&rsquo;s loss is dominated by its own training proportion), some show <strong>facilitation</strong> (e.g., dialogue data helps internet text), and some show <strong>conflict</strong> (e.g., symbolic data hurts prose). This sparsity explains why the law can be fitted with fewer samples than the quadratic parameter count would suggest.</p>
<h2 id="nested-scaling-pipeline-for-cheap-prediction">Nested scaling pipeline for cheap prediction</h2>
<p>Fitting data mixing laws directly at target scale is too expensive (requires many full training runs at different mixtures). The paper proposes nesting three scaling laws:</p>
<p><strong>Step 1</strong>: For each mixture $r_{i}$ and each small model size $N_{j}$, train for $S_{0}$ steps. Fit a <a href="https://en.wikipedia.org/wiki/Power_law">power law</a> $L(S) = E_{1} + B/S^{\beta}$ over steps to extrapolate to the target step count $S_{\text{target}}$.</p>
<p><strong>Step 2</strong>: With the step-extrapolated losses for each mixture, fit a power law $L(N) = E_{2} + A/N^{\alpha}$ over model sizes to extrapolate to the target model size $N_{\text{target}}$.</p>
<p><strong>Step 3</strong>: With the predicted losses at $(N_{\text{target}}, S_{\text{target}})$ for all sampled mixtures, fit the data mixing law and search for the optimal mixture.</p>
<p>This pipeline requires only training small models (70M to 410M) for short runs (30B tokens) to predict performance of a 1B model trained for 100B tokens.</p>
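<p>Steps 1 and 2 both reduce to fitting a saturating power law and extrapolating it. The sketch below shows one simple way to do that for the step law $L(S) = E_{1} + B/S^{\beta}$: a grid search over $\beta$ with closed-form least squares for the other two parameters. This is a swapped-in fitting method chosen for brevity, not the paper&rsquo;s procedure, and the data points are synthetic.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">def fit_power_law(steps, losses, betas):
    """Fit L(S) = E + B / S**beta by grid search over beta.

    For a fixed beta the model is linear in x = S**(-beta), so E and B
    follow from ordinary least squares in closed form.
    """
    n = len(steps)
    mean_y = sum(losses) / n
    fits = []
    for beta in betas:
        xs = [step ** (-beta) for step in steps]
        mean_x = sum(xs) / n
        sxx = sum((x - mean_x) ** 2 for x in xs)
        sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, losses))
        b = sxy / sxx
        e = mean_y - b * mean_x
        sse = sum((e + b * x - y) ** 2 for x, y in zip(xs, losses))
        fits.append((sse, e, b, beta))
    _, e, b, beta = min(fits)  # keep the beta with the smallest squared error
    return e, b, beta


def predict(e, b, beta, step):
    """Extrapolate the fitted law to a new step count."""
    return e + b / step ** beta


# Synthetic, noise-free losses generated from E=2.0, B=50.0, beta=0.5.
steps = [1000, 2000, 5000, 10000, 20000, 30000]
losses = [2.0 + 50.0 / step ** 0.5 for step in steps]
betas = [i / 100 for i in range(10, 101)]

e, b, beta = fit_power_law(steps, losses, betas)
predicted = predict(e, b, beta, 100000)
print(round(e, 3), round(b, 3), beta, round(predicted, 4))
</code></pre></div>
<p>The same routine could be reused for the model-size law by passing parameter counts in place of step counts; with real, noisy losses a robust objective would be the safer choice.</p>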
<h3 id="mixture-sampling-strategy">Mixture sampling strategy</h3>
<p>To get informative samples efficiently, the paper uses double-diminishing proportions: for each domain, enumerate proportions by halving from the maximum available. This distributes losses evenly across the exponential law&rsquo;s range. From 40 candidate mixtures trained at the smallest scale (70M), 20 are selected based on which subset minimizes data mixing law fitting error.</p>
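<p>The halving scheme for a single domain is easy to picture in code. The sketch below only generates the per-domain candidate grid; the helper name and the number of levels are assumptions, and how candidates from different domains are combined into full mixtures is not specified here.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">def halving_grid(max_proportion, levels):
    """Candidate proportions max, max/2, max/4, ... for one domain."""
    return [max_proportion / 2 ** i for i in range(levels)]


# Hypothetical example: a domain that could fill the whole mixture.
print(halving_grid(1.0, 5))
</code></pre></div>
<p>Because the law is exponential in the proportions, spacing the candidates geometrically rather than linearly spreads the resulting losses more evenly.</p>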
<h2 id="experiments-on-redpajama-and-continual-pretraining">Experiments on RedPajama and continual pretraining</h2>
<p><strong>Main experiment.</strong> Models trained on RedPajama, validated on the Pile (mimicking the common scenario where validation data comes from a different distribution than training). Small models: 70M, 160M, 305M, 410M trained for 30B tokens. Target: 1B model for 100B tokens.</p>
<p>The optimized mixture dramatically redistributes weight compared to RedPajama defaults:</p>
<table>
  <thead>
      <tr>
          <th>Domain</th>
          <th>Default</th>
          <th>Optimized</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>CommonCrawl</td>
          <td>0.670</td>
          <td>0.125</td>
      </tr>
      <tr>
          <td>C4</td>
          <td>0.150</td>
          <td>0.250</td>
      </tr>
      <tr>
          <td>GitHub</td>
          <td>0.045</td>
          <td>0.141</td>
      </tr>
      <tr>
          <td>ArXiv</td>
          <td>0.045</td>
          <td>0.250</td>
      </tr>
      <tr>
          <td>Books</td>
          <td>0.045</td>
          <td>0.094</td>
      </tr>
      <tr>
          <td>StackExchange</td>
          <td>0.025</td>
          <td>0.125</td>
      </tr>
      <tr>
          <td>Wikipedia</td>
          <td>0.020</td>
          <td>0.016</td>
      </tr>
  </tbody>
</table>
<p>The optimized mixture reaches the default mixture&rsquo;s final performance in 73% of the training steps and eventually achieves performance equivalent to 48% more training on the default mixture.</p>
<p><strong>Comparison to DoReMi and DoGE.</strong> Data mixing laws outperform both: the predicted-optimal mixture achieves lower validation loss than DoReMi and DoGE (both universal and OOD settings) for 1B models trained for 100B tokens on RedPajama.</p>
<p><strong>Continual pretraining.</strong> The law extends to continual pretraining (Pythia-70M on Pile + Python code). It accurately predicts the critical mixture proportion that avoids <a href="https://en.wikipedia.org/wiki/Catastrophic_interference">catastrophic forgetting</a> on the original domain while improving the target domain. This suggests data mixing laws could guide dynamic data schedules across multi-stage pretraining.</p>
<h2 id="implications-and-limitations">Implications and limitations</h2>
<p>The data mixing law provides a predictive framework rather than just an optimization algorithm. Key implications:</p>
<ul>
<li>The interaction coefficients $t_{ij}$ make domain relationships quantitatively observable before full-scale training, identifying facilitation and conflict pairs.</li>
<li>The nested pipeline&rsquo;s cost is dominated by the small-model training runs (40 mixtures at 70M scale), which is orders of magnitude cheaper than even a single target-scale run.</li>
<li>The continual pretraining application opens the door to optimizing dynamic data schedules, where mixture proportions change across training stages.</li>
</ul>
<p><strong>Limitations</strong>: The &ldquo;domain&rdquo; concept remains loosely defined (provenance-based). The nested scaling laws introduce compounding errors at each step, and predictions tend to slightly underestimate actual loss. The number of required fitting samples, while subquadratic in practice due to sparsity, still scales with the number of domains. No theoretical justification for the exponential form is provided; it is a purely empirical finding.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training (pilot)</td>
          <td>The Pile (GitHub, Pile-CC, Books3)</td>
          <td>30B tokens</td>
          <td>2-domain and 3-domain experiments</td>
      </tr>
      <tr>
          <td>Training (main)</td>
          <td>RedPajama</td>
          <td>100B tokens</td>
          <td>7 domains</td>
      </tr>
      <tr>
          <td>Validation</td>
          <td>The Pile validation set</td>
          <td>Standard split</td>
          <td>Out-of-distribution relative to RedPajama</td>
      </tr>
      <tr>
          <td>Continual pretraining</td>
          <td>Pile + Python code</td>
          <td>10B tokens</td>
          <td>Pythia-70M base model</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<p>Data mixing law: $L_{i}(r_{1 \dots M}) = c_{i} + k_{i} \exp(\sum_{j} t_{ij} r_{j})$. Fitted via AdaBoost Regressor on sampled mixtures. Step scaling law: $L(S) = E_{1} + B/S^{\beta}$. Model size scaling law: $L(N) = E_{2} + A/N^{\alpha}$. Both fitted via Huber loss minimization with LBFGS. Decomposed Chinchilla-style (separate fits for stability). 40 candidate mixtures sampled via double-diminishing proportions, 20 selected for the final pipeline.</p>
<h3 id="models">Models</h3>
<p>Transformer decoder-only LMs. Pilot: 70M, 160M. Main pipeline: 70M, 160M, 305M, 410M (for fitting), 1B (target). Batch size: 1M tokens. Cosine learning rate decay with 2K step warmup, decaying to 0.1x at 100K steps.</p>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Optimized Mixture</th>
          <th>Default Mixture</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Steps to match default final loss</td>
          <td>73K (73%)</td>
          <td>100K (100%)</td>
          <td>27% training reduction</td>
      </tr>
      <tr>
          <td>Equivalent extra training</td>
          <td>+48%</td>
          <td>Baseline</td>
          <td>Estimated via step scaling law</td>
      </tr>
      <tr>
          <td>Validation loss (1B, 100B)</td>
          <td>Lowest</td>
          <td>Higher than optimized</td>
          <td>Also beats DoReMi and DoGE</td>
      </tr>
  </tbody>
</table>
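<p>One way to read the &ldquo;equivalent extra training&rdquo; row is as an inversion of the step scaling law $L(S) = E_{1} + B/S^{\beta}$: solve for the number of steps at which the default mixture would reach the loss of the optimized mixture. The sketch below assumes this reading, and every parameter value in it is a hypothetical placeholder, not a number from the paper.</p>

```python
def step_law_loss(steps, e1, b, beta):
    """Step scaling law L(S) = E1 + B / S**beta."""
    return e1 + b / steps ** beta

def steps_to_reach(loss, e1, b, beta):
    """Invert the law: the step count S at which L(S) equals `loss`."""
    return (b / (loss - e1)) ** (1.0 / beta)

# Hypothetical fitted parameters for the default mixture.
E1, B, BETA = 2.0, 50.0, 0.4

# Hypothetical loss reached by a better mixture after 100K steps.
better_loss = 2.42

# Steps the default mixture would need to match that loss, and the relative overhead.
default_steps = steps_to_reach(better_loss, E1, B, BETA)
extra_training = default_steps / 100_000 - 1.0
print(f"{extra_training:+.1%}")
```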
<h3 id="hardware">Hardware</h3>
<p>8 A100 GPUs. Training times per 30B-token run: 3.5 hours (70M), 8 hours (160M), 16 hours (305M), 21 hours (410M).</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://pile.eleuther.ai/">The Pile</a></td>
          <td>Dataset</td>
          <td>MIT</td>
          <td>Pilot and validation data</td>
      </tr>
      <tr>
          <td><a href="https://github.com/togethercomputer/RedPajama-Data">RedPajama</a></td>
          <td>Dataset</td>
          <td>Apache 2.0</td>
          <td>Main training data</td>
      </tr>
      <tr>
          <td><a href="https://github.com/EleutherAI/pythia">Pythia Suite</a></td>
          <td>Model</td>
          <td>Apache 2.0</td>
          <td>Model architecture configs; Pythia-70M checkpoint for continual pretraining</td>
      </tr>
  </tbody>
</table>
<p><strong>Reproducibility status: Partially Reproducible.</strong> Datasets and base model checkpoints are public. No official code release for the data mixing law fitting pipeline, mixture sampling, or the nested scaling law prediction workflow.</p>
<hr>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{ye2025datamixinglaws,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Data Mixing Laws: Optimizing Data Mixtures by Predicting Language Modeling Performance}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Ye, Jiasheng and Liu, Peiju and Sun, Tianxiang and Zhan, Jun and Zhou, Yunhua and Qiu, Xipeng}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{International Conference on Learning Representations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2025}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Grammar VAE: Generating Valid Molecules via CFGs</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/latent-space/grammar-variational-autoencoder/</link><pubDate>Sat, 28 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/latent-space/grammar-variational-autoencoder/</guid><description>The Grammar VAE encodes and decodes molecular parse trees from context-free grammars, guaranteeing syntactically valid SMILES outputs during generation.</description><content:encoded><![CDATA[<h2 id="a-grammar-constrained-vae-for-discrete-data-generation">A Grammar-Constrained VAE for Discrete Data Generation</h2>
<p>This is a <strong>Method</strong> paper that introduces the Grammar Variational Autoencoder (GVAE), a variant of the <a href="/notes/machine-learning/generative-models/autoencoding-variational-bayes/">variational autoencoder</a> that operates directly on parse trees from context-free grammars (CFGs) rather than on raw character sequences. The primary contribution is a decoding mechanism that uses a stack and grammar-derived masks to restrict the output at every timestep to only syntactically valid production rules. This guarantees that every decoded output is a valid string under the grammar, addressing a fundamental limitation of character-level VAEs when applied to structured discrete data such as <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> molecular strings and arithmetic expressions.</p>
<h2 id="why-character-level-vaes-fail-on-structured-discrete-data">Why Character-Level VAEs Fail on Structured Discrete Data</h2>
<p>Generative models for continuous data (images, audio) had achieved impressive results by 2017, but generating structured discrete data remained difficult. The key challenge is that string representations of molecules and mathematical expressions are brittle: small perturbations to a character sequence often produce invalid outputs. <a href="/notes/chemistry/molecular-design/generation/latent-space/automatic-chemical-design-vae/">Gomez-Bombarelli et al. (2016)</a> demonstrated a character-level VAE (CVAE) for SMILES strings that could encode molecules into a continuous latent space and decode them back, enabling latent-space optimization for molecular design. However, the CVAE frequently decoded latent points into strings that were not valid SMILES, particularly when exploring regions of latent space far from training data.</p>
<p>The fundamental issue is that character-level decoders must implicitly learn the syntactic rules of the target language from data alone. For SMILES, this includes matching parentheses, valid atom types, proper bonding, and ring closure notation. The GVAE addresses this by giving the decoder explicit knowledge of the grammar, so it can focus entirely on learning the semantic structure of the data.</p>
<h2 id="core-innovation-stack-based-grammar-masking-in-the-decoder">Core Innovation: Stack-Based Grammar Masking in the Decoder</h2>
<p>The GVAE encodes and decodes sequences of production rules from a context-free grammar rather than sequences of characters.</p>
<p><strong>Encoding.</strong> Given an input string (e.g., a SMILES molecule), the encoder first parses it into a parse tree using the CFG, then performs a left-to-right pre-order traversal of the tree to extract an ordered sequence of production rules. Each rule is represented as a one-hot vector of dimension $K$ (total number of production rules in the grammar). The resulting $T(\mathbf{X}) \times K$ matrix is processed by a convolutional neural network to produce the mean and variance of a Gaussian posterior $q_{\phi}(\mathbf{z} \mid \mathbf{X})$.</p>
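<p>The encoding step can be illustrated on a toy grammar. The four-rule grammar and the nested-tuple tree format below are assumptions chosen for brevity (the real SMILES grammar is much larger), and the parser that produces the tree is not shown.</p>

```python
# Toy grammar: each rule is (left-hand side, right-hand side symbols).
RULES = [
    ("S", ["S", "+", "T"]),  # rule 0
    ("S", ["T"]),            # rule 1
    ("T", ["x"]),            # rule 2
    ("T", ["1"]),            # rule 3
]

def preorder_rules(tree):
    """Flatten a parse tree (rule_index, child_subtrees) by pre-order traversal."""
    rule, children = tree
    sequence = [rule]
    for child in children:  # children are listed left to right
        sequence.extend(preorder_rules(child))
    return sequence

def one_hot(sequence, k):
    """Return a T x K matrix with a single 1 per row marking the rule used."""
    return [[1 if j == rule else 0 for j in range(k)] for rule in sequence]

# Parse tree of "x+1": S -> S + T, then S -> T, T -> x, and finally T -> 1.
tree = (0, [(1, [(2, [])]), (3, [])])
rule_sequence = preorder_rules(tree)
X = one_hot(rule_sequence, len(RULES))
print(rule_sequence)
```

<p>The matrix <code>X</code> plays the role of the $T(\mathbf{X}) \times K$ input to the convolutional encoder.</p>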
<p><strong>Decoding with grammar masks.</strong> The decoder maps a latent vector $\mathbf{z}$ through an RNN to produce a matrix of logits $\mathbf{F} \in \mathbb{R}^{T_{max} \times K}$. The key innovation is a last-in first-out (LIFO) stack that tracks the current parsing state. At each timestep $t$, the decoder:</p>
<ol>
<li>Pops the top non-terminal $\alpha$ from the stack</li>
<li>Applies a fixed binary mask $\mathbf{m}_{\alpha} \in \{0, 1\}^K$ that zeros out all production rules whose left-hand side is not $\alpha$</li>
<li>Samples a production rule from the masked softmax distribution:</li>
</ol>
<p>$$
p(\mathbf{x}_{t} = k \mid \alpha, \mathbf{z}) = \frac{m_{\alpha,k} \exp(f_{tk})}{\sum_{j=1}^{K} m_{\alpha,j} \exp(f_{tj})}
$$</p>
<ol start="4">
<li>Pushes the right-hand-side non-terminals of the selected rule onto the stack (right-to-left, so the leftmost is on top)</li>
</ol>
<p>This process continues until the stack is empty or $T_{max}$ timesteps are reached. Because the mask restricts selection to only those rules applicable to the current non-terminal, every generated sequence of production rules is guaranteed to be a valid derivation under the grammar.</p>
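<p>The four decoding steps above can be sketched on a toy grammar. This is an illustrative stand-in under stated assumptions, not the authors&rsquo; implementation: the grammar is hypothetical and the logits are hard-coded, whereas in the GVAE they come from the RNN decoder.</p>

```python
import math
import random

RULES = [
    ("S", ["S", "+", "T"]),  # rule 0
    ("S", ["T"]),            # rule 1
    ("T", ["x"]),            # rule 2
    ("T", ["1"]),            # rule 3
]
NONTERMINALS = {"S", "T"}
# MASKS[alpha][k] is 1 when rule k has left-hand side alpha, else 0.
MASKS = {a: [1 if lhs == a else 0 for lhs, _ in RULES] for a in NONTERMINALS}

def masked_softmax(logits, mask):
    """Softmax restricted to the entries where mask is 1."""
    m = max(f for f, keep in zip(logits, mask) if keep)
    weights = [math.exp(f - m) if keep else 0.0 for f, keep in zip(logits, mask)]
    total = sum(weights)
    return [w / total for w in weights]

def decode(logits, t_max, rng, start="S"):
    """Sample a rule sequence; returns (rules, stack_emptied)."""
    stack, rules = [start], []
    for t in range(t_max):
        if not stack:
            break
        alpha = stack.pop()                                   # step 1: pop
        probs = masked_softmax(logits[t], MASKS[alpha])       # step 2: mask
        k = rng.choices(range(len(RULES)), weights=probs)[0]  # step 3: sample
        rules.append(k)
        for symbol in reversed(RULES[k][1]):                  # step 4: push
            if symbol in NONTERMINALS:
                stack.append(symbol)
    return rules, not stack

def rules_to_string(rules):
    """Apply rules as a leftmost derivation starting from S."""
    symbols = ["S"]
    for k in rules:
        i = next(j for j, s in enumerate(symbols) if s in NONTERMINALS)
        symbols[i:i + 1] = RULES[k][1]
    return "".join(symbols)

# Logits that strongly prefer rule 1 (S -> T) and rule 2 (T -> x) at every step.
logits = [[-50.0, 50.0, 50.0, -50.0] for _ in range(15)]
rules, complete = decode(logits, t_max=15, rng=random.Random(0))
print(rules, complete, rules_to_string(rules))
```

<p>Pushing right-hand-side symbols in reverse keeps the leftmost non-terminal on top of the stack, which is why the sampled rules can be replayed as a leftmost derivation.</p>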
<p><strong>Training.</strong> The model is trained by maximizing the ELBO:</p>
<p>$$
\mathcal{L}(\phi, \theta; \mathbf{X}) = \mathbb{E}_{q_{\phi}(\mathbf{z} \mid \mathbf{X})} \left[ \log p_{\theta}(\mathbf{X}, \mathbf{z}) - \log q_{\phi}(\mathbf{z} \mid \mathbf{X}) \right]
$$</p>
<p>where the likelihood factorizes as:</p>
<p>$$
p(\mathbf{X} \mid \mathbf{z}) = \prod_{t=1}^{T(\mathbf{X})} p(\mathbf{x}_{t} \mid \mathbf{z})
$$</p>
<p>During training, the masks at each timestep are determined by the ground-truth production rule sequence, so no stack simulation is needed. The stack-based decoding is only required at generation time.</p>
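<p>A small sketch of why no stack is needed during training: the non-terminal at each step is simply the left-hand side of the ground-truth rule, so the mask can be looked up directly. The four-rule grammar is a hypothetical toy example, and the logits would come from the decoder in practice.</p>

```python
import math

# Left-hand side of each production rule in a toy four-rule grammar.
RULES_LHS = ["S", "S", "T", "T"]

def sequence_log_likelihood(logits, rule_sequence):
    """Sum over t of log p(x_t | z) under the masked softmax."""
    total = 0.0
    for f_t, k in zip(logits, rule_sequence):
        alpha = RULES_LHS[k]  # non-terminal implied by the ground-truth rule
        allowed = [j for j, lhs in enumerate(RULES_LHS) if lhs == alpha]
        # Log of the masked normalizer, computed stably.
        m = max(f_t[j] for j in allowed)
        log_z = m + math.log(sum(math.exp(f_t[j] - m) for j in allowed))
        total += f_t[k] - log_z
    return total

# With uniform logits, each step chooses between the two rules sharing a left-hand side.
uniform_logits = [[0.0, 0.0, 0.0, 0.0] for _ in range(4)]
log_likelihood = sequence_log_likelihood(uniform_logits, [0, 1, 2, 3])
print(log_likelihood)
```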
<p><strong>Syntactic vs. semantic validity.</strong> The grammar guarantees syntactic validity but not semantic validity. The GVAE can still produce chemically implausible molecules (e.g., an oxygen atom with three bonds) because such constraints are not context-free. SMILES ring-bond digit matching is also not context-free, so the grammar cannot enforce it. Additionally, sequences that have not emptied the stack by $T_{max}$ are marked invalid.</p>
<h2 id="experiments-on-symbolic-regression-and-molecular-optimization">Experiments on Symbolic Regression and Molecular Optimization</h2>
<p>The authors evaluate the GVAE on two domains: arithmetic expressions and molecules. Both use Bayesian optimization (BO) over the learned latent space.</p>
<p><strong>Setup.</strong> After training each VAE, the authors encode training data into latent vectors and train a sparse Gaussian process (SGP) with 500 inducing points to predict properties from latent representations. They then run batch BO with expected improvement, selecting 50 candidates per iteration.</p>
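<p>For reference, the closed-form expected improvement under a Gaussian predictive distribution can be sketched as below. This is the standard formula for maximization, not code from the paper; the candidate means and standard deviations are hypothetical stand-ins for sparse GP predictions, and batch selection is not shown.</p>

```python
import math

def expected_improvement(mu, sigma, best):
    """EI for maximization when the prediction is N(mu, sigma**2)."""
    if sigma == 0.0:
        return max(mu - best, 0.0)
    z = (mu - best) / sigma
    pdf = math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)
    cdf = 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))
    return (mu - best) * cdf + sigma * pdf

# Hypothetical (mean, std) predictions for three latent points; best score so far is 1.0.
predictions = [(0.9, 0.05), (0.7, 0.6), (1.1, 0.1)]
scores = [expected_improvement(mu, sigma, best=1.0) for mu, sigma in predictions]
ranked = sorted(range(len(predictions)), key=lambda i: scores[i], reverse=True)
print(ranked)
```

<p>The second term rewards uncertainty, so a candidate with a lower mean but a large standard deviation can outrank a confident one.</p>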
<h3 id="arithmetic-expressions">Arithmetic Expressions</h3>
<ul>
<li><strong>Data</strong>: 100,000 randomly generated univariate expressions from a simple grammar (3 binary operators, 2 unary operators, 3 constants), each with at most 15 production rules</li>
<li><strong>Target</strong>: Find an expression minimizing $\log(1 + \text{MSE})$ against the true function $1/3 + x + \sin(x \cdot x)$</li>
<li><strong>BO iterations</strong>: 5, averaged over 10 repetitions</li>
</ul>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>Fraction Valid</th>
          <th>Average Score (lower is better)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>GVAE</td>
          <td>0.99 +/- 0.01</td>
          <td>3.47 +/- 0.24</td>
      </tr>
      <tr>
          <td>CVAE</td>
          <td>0.86 +/- 0.06</td>
          <td>4.75 +/- 0.25</td>
      </tr>
  </tbody>
</table>
<p>The GVAE&rsquo;s best expression ($x/1 + \sin(3) + \sin(x \cdot x)$, score 0.04) recovers the true function up to a small constant offset ($\sin(3) \approx 0.14$ in place of $1/3$), while the CVAE&rsquo;s best ($x \cdot 1 + \sin(3) + \sin(3/1)$, score 0.39) misses the sinusoidal component entirely, replacing $\sin(x \cdot x)$ with a constant.</p>
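<p>The two scores quoted above can be checked by evaluating $\log(1 + \text{MSE})$ directly. The evaluation grid (1000 evenly spaced inputs on $[-10, 10]$) is an assumption here. The GVAE score does not depend on it, because that expression differs from the target only by a constant.</p>

```python
import math

def score(candidate, target, xs):
    """log(1 + MSE) between a candidate expression and the target over inputs xs."""
    mse = sum((candidate(x) - target(x)) ** 2 for x in xs) / len(xs)
    return math.log(1.0 + mse)

def target(x):
    return 1.0 / 3.0 + x + math.sin(x * x)

def gvae_best(x):
    return x / 1.0 + math.sin(3.0) + math.sin(x * x)

def cvae_best(x):
    return x * 1.0 + math.sin(3.0) + math.sin(3.0 / 1.0)

# Assumed evaluation grid: 1000 evenly spaced inputs on [-10, 10].
xs = [-10.0 + 20.0 * i / 999 for i in range(1000)]
gvae_score = score(gvae_best, target, xs)
cvae_score = score(cvae_best, target, xs)
print(round(gvae_score, 2), round(cvae_score, 2))
```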
<h3 id="molecular-optimization">Molecular Optimization</h3>
<ul>
<li><strong>Data</strong>: 250,000 SMILES strings from the ZINC database</li>
<li><strong>Target</strong>: Maximize penalized logP (the octanol-water partition coefficient, penalized for large rings and poor synthetic accessibility)</li>
<li><strong>BO iterations</strong>: 10, averaged over 5 trials</li>
</ul>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>Fraction Valid</th>
          <th>Average Score (higher is better)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>GVAE</td>
          <td>0.31 +/- 0.07</td>
          <td>-9.57 +/- 1.77</td>
      </tr>
      <tr>
          <td>CVAE</td>
          <td>0.17 +/- 0.05</td>
          <td>-54.66 +/- 2.66</td>
      </tr>
  </tbody>
</table>
<p>The GVAE produces roughly twice as many valid molecules as the CVAE and finds molecules with substantially better penalized logP scores (best: 2.94 vs. 1.98).</p>
<h3 id="latent-space-quality">Latent Space Quality</h3>
<p>Interpolation experiments show that the GVAE produces valid outputs at every intermediate point when linearly interpolating between two encoded expressions, while the CVAE passes through invalid strings. Grid searches around encoded molecules in the GVAE latent space show smooth transitions where neighboring points differ by single atoms.</p>
<h3 id="predictive-performance">Predictive Performance</h3>
<p>Sparse GP models trained on GVAE latent features achieve better test RMSE and log-likelihood than those trained on CVAE features for both expressions and molecules:</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>GVAE (Expressions)</th>
          <th>CVAE (Expressions)</th>
          <th>GVAE (Molecules)</th>
          <th>CVAE (Molecules)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Test LL</td>
          <td>-1.320 +/- 0.001</td>
          <td>-1.397 +/- 0.003</td>
          <td>-1.739 +/- 0.004</td>
          <td>-1.812 +/- 0.004</td>
      </tr>
      <tr>
          <td>Test RMSE</td>
          <td>0.884 +/- 0.002</td>
          <td>0.975 +/- 0.004</td>
          <td>1.404 +/- 0.006</td>
          <td>1.504 +/- 0.006</td>
      </tr>
  </tbody>
</table>
<h3 id="reconstruction-and-prior-sampling">Reconstruction and Prior Sampling</h3>
<p>On held-out molecules, the GVAE achieves 53.7% reconstruction accuracy vs. 44.6% for the CVAE. When sampling from the prior $p(\mathbf{z}) = \mathcal{N}(0, \mathbf{I})$, 7.2% of GVAE samples are valid molecules vs. 0.7% for the CVAE.</p>
<h2 id="key-findings-limitations-and-impact">Key Findings, Limitations, and Impact</h2>
<p><strong>Key findings.</strong> Incorporating grammar structure into the VAE decoder consistently improves validity rates, latent space smoothness, downstream predictive performance, and Bayesian optimization outcomes across both domains. The approach is general: any domain with a context-free grammar can benefit.</p>
<p><strong>Limitations acknowledged by the authors.</strong></p>
<ul>
<li>The GVAE guarantees syntactic but not semantic validity. For molecules, invalid ring-bond patterns and chemically implausible structures can still be generated.</li>
<li>The GVAE&rsquo;s molecular validity rate during BO (31%) is substantially higher than the CVAE&rsquo;s (17%), but most decoded molecules are still invalid, largely due to non-context-free constraints in SMILES.</li>
<li>The approach requires a context-free grammar for the target domain, which limits applicability to well-defined formal languages.</li>
<li>Sequences that do not complete parsing within $T_{max}$ timesteps are discarded as invalid.</li>
</ul>
<p><strong>Impact.</strong> The GVAE was an influential early contribution to constrained molecular generation. It directly inspired the Syntax-Directed VAE (SD-VAE) by Dai et al. (2018), which uses attribute grammars for tighter semantic constraints, and contributed to the broader movement toward structured molecular generation methods including graph-based approaches. The paper demonstrated that encoding domain knowledge into the decoder architecture is more effective than relying on the model to learn structural constraints from data alone.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training (expressions)</td>
          <td>Generated arithmetic expressions</td>
          <td>100,000</td>
          <td>Up to 15 production rules each</td>
      </tr>
      <tr>
          <td>Training (molecules)</td>
          <td>ZINC database subset</td>
          <td>250,000 SMILES</td>
          <td>Same subset as <a href="/notes/chemistry/molecular-design/generation/latent-space/automatic-chemical-design-vae/">Gomez-Bombarelli et al. (2016)</a></td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Encoder: 1D convolutional neural network over one-hot rule sequences</li>
<li>Decoder: RNN with stack-based grammar masking</li>
<li>Latent space: 56 dimensions (molecules), isotropic Gaussian prior</li>
<li>Property predictor: Sparse Gaussian process with 500 inducing points</li>
<li>Optimization: Batch Bayesian optimization with expected improvement, 50 candidates per iteration, Kriging Believer for batch selection</li>
</ul>
<h3 id="models">Models</h3>
<p>Architecture details follow <a href="/notes/chemistry/molecular-design/generation/latent-space/automatic-chemical-design-vae/">Gomez-Bombarelli et al. (2016)</a> with modifications for grammar-based encoding/decoding. Specific layer sizes and hyperparameters are described in the supplementary material.</p>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>GVAE</th>
          <th>CVAE</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Fraction valid (expressions)</td>
          <td>0.99</td>
          <td>0.86</td>
          <td>During BO</td>
      </tr>
      <tr>
          <td>Fraction valid (molecules)</td>
          <td>0.31</td>
          <td>0.17</td>
          <td>During BO</td>
      </tr>
      <tr>
          <td>Best penalized logP</td>
          <td>2.94</td>
          <td>1.98</td>
          <td>Best molecule found</td>
      </tr>
      <tr>
          <td>Reconstruction accuracy</td>
          <td>53.7%</td>
          <td>44.6%</td>
          <td>On held-out molecules</td>
      </tr>
      <tr>
          <td>Prior validity</td>
          <td>7.2%</td>
          <td>0.7%</td>
          <td>Sampling from N(0,I)</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Not specified in the paper.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/mkusner/grammarVAE">grammarVAE</a></td>
          <td>Code</td>
          <td>Not specified</td>
          <td>Official implementation</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Kusner, M. J., Paige, B., &amp; Hernández-Lobato, J. M. (2017). Grammar Variational Autoencoder. <em>Proceedings of the 34th International Conference on Machine Learning (ICML)</em>, 1945-1954.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{kusner2017grammar,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Grammar Variational Autoencoder}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Kusner, Matt J. and Paige, Brooks and Hern{\&#39;a}ndez-Lobato, Jos{\&#39;e} Miguel}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Proceedings of the 34th International Conference on Machine Learning}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{1945--1954}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2017}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{PMLR}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Coscientist: Autonomous Chemistry with LLM Agents</title><link>https://hunterheidenreich.com/notes/chemistry/llm-applications/autonomous-chemical-research-coscientist/</link><pubDate>Sat, 28 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/llm-applications/autonomous-chemical-research-coscientist/</guid><description>Coscientist uses GPT-4 to autonomously design, plan, and execute chemical experiments including Pd-catalysed cross-coupling optimization.</description><content:encoded><![CDATA[<h2 id="an-llm-powered-agent-for-autonomous-chemical-experimentation">An LLM-Powered Agent for Autonomous Chemical Experimentation</h2>
<p>This is a <strong>Method</strong> paper that introduces Coscientist, an AI system driven by GPT-4 that autonomously designs, plans, and performs complex chemical experiments. The primary contribution is a modular multi-LLM agent architecture that integrates internet search, documentation retrieval, code execution, and robotic experimentation APIs into a unified system capable of end-to-end experimental chemistry with minimal human intervention.</p>
<h2 id="bridging-llm-capabilities-and-laboratory-automation">Bridging LLM Capabilities and Laboratory Automation</h2>
<p>Transformer-based large language models had demonstrated strong capabilities in natural language processing, biology, chemistry, and code generation by early 2023. Simultaneously, laboratory automation had progressed with autonomous reaction discovery, automated flow systems, and mobile robotic platforms. However, these two threads remained largely separate: LLMs could reason about chemistry in text, but could not act on that reasoning by controlling physical experiments.</p>
<p>The gap this work addresses is the integration of LLM reasoning with laboratory automation in a closed-loop system. Prior automated chemistry systems relied on traditional optimization algorithms or narrow AI components. The question was whether GPT-4&rsquo;s general reasoning capabilities could be combined with tool access to produce a system that autonomously designs experiments, writes instrument code, executes reactions, and interprets results, all from natural language prompts.</p>
<p>This work was developed independently and in parallel with other autonomous agent efforts (AutoGPT, BabyAGI, LangChain), with <a href="/notes/chemistry/llm-applications/chemcrow-augmenting-llms-chemistry-tools/">ChemCrow</a> serving as another chemistry-specific example.</p>
<h2 id="a-modular-multi-llm-architecture-with-tool-access">A Modular Multi-LLM Architecture with Tool Access</h2>
<p>The core innovation is Coscientist&rsquo;s modular architecture, centered on a &ldquo;Planner&rdquo; module (a GPT-4 chat completion instance) that orchestrates four command types:</p>
<ol>
<li><strong>GOOGLE</strong>: A Web Searcher module (itself an LLM) that transforms prompts into search queries, browses results, and funnels answers back to the Planner.</li>
<li><strong>PYTHON</strong>: A Code Execution module running in an isolated Docker container for calculations and data analysis, with no LLM dependency.</li>
<li><strong>DOCUMENTATION</strong>: A Docs Searcher module that retrieves and summarizes technical documentation (e.g., Opentrons Python API, Emerald Cloud Lab Symbolic Lab Language) using ada embeddings and distance-based vector search.</li>
<li><strong>EXPERIMENT</strong>: An Automation module that executes generated code on laboratory hardware or provides synthetic procedures.</li>
</ol>
<p>The system prompts are engineered in a modular fashion, with the Planner receiving initial user input and command outputs as messages. The Planner can iteratively call commands, fix software errors, and refine its approach. This design allows natural language instructions (e.g., &ldquo;perform multiple Suzuki reactions&rdquo;) to be translated into complete experimental protocols.</p>
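<p>The control flow described above can be sketched as a simple dispatch loop. Everything here is hypothetical: the full Coscientist code and prompts were not released, so the message format, the command parsing, and the stubbed planner and tools are illustrative assumptions, with the planner stub standing in for a GPT-4 call.</p>

```python
from typing import Callable, Dict, List

Message = Dict[str, str]

def run_planner(planner: Callable[[List[Message]], str],
                tools: Dict[str, Callable[[str], str]],
                user_prompt: str,
                max_turns: int = 10) -> List[Message]:
    """Alternate between planner replies and tool outputs until no command is issued."""
    messages = [{"role": "user", "content": user_prompt}]
    for _ in range(max_turns):
        reply = planner(messages)
        messages.append({"role": "assistant", "content": reply})
        # Assumed convention: the first word of a reply names the command.
        command, _, argument = reply.partition(" ")
        tool = tools.get(command)
        if tool is None:
            break  # no recognised command, so treat the reply as the final answer
        try:
            output = tool(argument)
        except Exception as exc:
            # Feed the error back so the planner can revise its command.
            output = f"ERROR: {exc}"
        messages.append({"role": "user", "content": output})
    return messages

# Stubs standing in for the LLM planner and the four command modules.
scripted_replies = iter(["DOCUMENTATION heater-shaker", "EXPERIMENT run protocol", "Finished."])
stub_tools = {
    "GOOGLE": lambda query: "search results",
    "PYTHON": lambda code: "code output",
    "DOCUMENTATION": lambda query: "relevant documentation section",
    "EXPERIMENT": lambda code: "protocol executed",
}
history = run_planner(lambda msgs: next(scripted_replies), stub_tools, "perform a Suzuki reaction")
print(len(history))
```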
<p>For documentation retrieval, all sections of the OT-2 API documentation were embedded using OpenAI&rsquo;s ada model, and relevant sections are retrieved via cosine similarity search. For the Emerald Cloud Lab, the system learned to program in a symbolic lab language (SLL) that was completely unknown to GPT-4 at training time, demonstrating effective in-context learning from supplied documentation.</p>
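<p>Retrieval by cosine similarity can be sketched as follows. The section titles and the three-dimensional vectors are toy placeholders; in the described system the vectors would be embeddings of the OT-2 documentation sections and of the query.</p>

```python
import math

def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm

def top_k(query_vector, sections, k=2):
    """Return the titles of the k sections most similar to the query."""
    ranked = sorted(sections,
                    key=lambda item: cosine_similarity(query_vector, item[1]),
                    reverse=True)
    return [title for title, _ in ranked[:k]]

# Hypothetical documentation sections with toy embedding vectors.
sections = [
    ("Heater-shaker module", [0.9, 0.1, 0.0]),
    ("Pipette transfer", [0.1, 0.9, 0.1]),
    ("Labware loading", [0.0, 0.2, 0.9]),
]
query = [0.8, 0.3, 0.0]  # toy embedding of a heater-shaker question
retrieved = top_k(query, sections, k=2)
print(retrieved)
```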
<h2 id="six-tasks-demonstrating-autonomous-chemistry-capabilities">Six Tasks Demonstrating Autonomous Chemistry Capabilities</h2>
<p>The paper evaluates Coscientist across six tasks of increasing complexity.</p>
<h3 id="task-1-chemical-synthesis-planning">Task 1: Chemical Synthesis Planning</h3>
<p>A benchmark of seven compounds was used to compare synthesis planning across models (GPT-4, GPT-3.5, Claude 1.3, Falcon-40B-Instruct) with and without web search. Outputs were scored on a 1-5 scale:</p>
<table>
  <thead>
      <tr>
          <th>Score</th>
          <th>Meaning</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>5</td>
          <td>Very detailed and chemically accurate procedure</td>
      </tr>
      <tr>
          <td>4</td>
          <td>Detailed and accurate but without reagent quantities</td>
      </tr>
      <tr>
          <td>3</td>
          <td>Correct chemistry but no step-by-step procedure</td>
      </tr>
      <tr>
          <td>2</td>
          <td>Extremely vague or unfeasible</td>
      </tr>
      <tr>
          <td>1</td>
          <td>Incorrect or failure to follow instructions</td>
      </tr>
  </tbody>
</table>
<p>The GPT-4-powered Web Searcher achieved maximum scores for acetaminophen, aspirin, nitroaniline, and phenolphthalein. It was the only approach to achieve acceptable scores (3+) for ibuprofen, which all non-browsing models synthesized incorrectly. These results highlight the importance of grounding LLMs in retrieved sources to avoid hallucinations.</p>
<h3 id="task-2-documentation-search">Task 2: Documentation Search</h3>
<p>The system correctly identified relevant ECL functions from documentation and generated valid SLL code that was successfully executed at ECL, including an <a href="https://en.wikipedia.org/wiki/High-performance_liquid_chromatography">HPLC</a> experiment on a caffeine standard sample.</p>
<h3 id="task-3-cloud-laboratory-execution">Task 3: Cloud Laboratory Execution</h3>
<p>Using prompt-to-function and prompt-to-SLL pipelines, Coscientist generated executable code for the Emerald Cloud Lab. It also searched a catalogue of 1,110 model samples to identify relevant stock solutions from simple search terms.</p>
<h3 id="task-4-liquid-handler-control">Task 4: Liquid Handler Control</h3>
<p>Using the Opentrons OT-2, Coscientist translated natural language prompts (e.g., &ldquo;colour every other line with one colour of your choice,&rdquo; &ldquo;draw a red cross&rdquo;) into accurate liquid handling protocols.</p>
<h3 id="task-5-integrated-multi-module-experiment">Task 5: Integrated Multi-Module Experiment</h3>
<p>The most complex demonstration combined web search, code execution, documentation retrieval, and hardware control to design and execute <a href="https://en.wikipedia.org/wiki/Suzuki_reaction">Suzuki-Miyaura</a> and <a href="https://en.wikipedia.org/wiki/Sonogashira_coupling">Sonogashira</a> <a href="https://en.wikipedia.org/wiki/Cross-coupling_reaction">cross-coupling</a> reactions. Coscientist:</p>
<ul>
<li>Searched the internet for reaction conditions and stoichiometries</li>
<li>Selected correct coupling partners (never misassigning <a href="https://en.wikipedia.org/wiki/Phenylboronic_acid">phenylboronic acid</a> to the Sonogashira reaction)</li>
<li>Calculated reagent volumes and wrote OT-2 protocols</li>
<li>Self-corrected when using an incorrect heater-shaker method by consulting documentation</li>
<li>Successfully produced target products confirmed by <a href="https://en.wikipedia.org/wiki/Gas_chromatography%E2%80%93mass_spectrometry">GC-MS</a> analysis (biphenyl at 9.53 min for Suzuki, diphenylacetylene at 12.92 min for Sonogashira)</li>
</ul>
<h3 id="task-6-reaction-optimization">Task 6: Reaction Optimization</h3>
<p>Coscientist was tested on two fully mapped reaction datasets:</p>
<ol>
<li><strong>Suzuki reaction flow dataset</strong> (Perera et al.): varying ligands, reagents/bases, and solvents</li>
<li><strong><a href="https://en.wikipedia.org/wiki/Buchwald%E2%80%93Hartwig_amination">Buchwald-Hartwig</a> C-N coupling dataset</strong> (Doyle et al.): varying ligands, additives, and bases</li>
</ol>
<p>Performance was evaluated using a normalized advantage metric:</p>
<p>$$\text{Normalized Advantage} = \frac{\text{yield}_i - \overline{\text{yield}}}{\text{yield}_{\max} - \overline{\text{yield}}}$$</p>
<p>A value of 1 indicates that the maximum yield was reached, 0 indicates a yield equal to the dataset average (the expected result of random selection), and negative values indicate worse-than-average yields. The normalized maximum advantage (NMA) tracks the best normalized advantage achieved up to each iteration.</p>
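<p>Both metrics are short to compute. The sketch below uses hypothetical yields and dataset statistics purely to illustrate the formula and the running maximum.</p>

```python
def normalized_advantage(yield_i, mean_yield, max_yield):
    """(yield_i - mean) / (max - mean), using the dataset's mean and maximum yield."""
    return (yield_i - mean_yield) / (max_yield - mean_yield)

def normalized_max_advantage(yields, mean_yield, max_yield):
    """Best normalized advantage seen up to and including each iteration."""
    best = float("-inf")
    trace = []
    for y in yields:
        best = max(best, normalized_advantage(y, mean_yield, max_yield))
        trace.append(best)
    return trace

# Hypothetical dataset statistics (percent yield) and one optimization run.
MEAN_YIELD, MAX_YIELD = 40.0, 100.0
observed = [25.0, 55.0, 40.0, 100.0]
advantages = [normalized_advantage(y, MEAN_YIELD, MAX_YIELD) for y in observed]
nma = normalized_max_advantage(observed, MEAN_YIELD, MAX_YIELD)
print(advantages, nma)
```

<p>Because NMA is a running maximum, it never decreases, which makes it easier to compare convergence across methods than the per-iteration advantage.</p>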
<p>Key findings from the optimization experiments:</p>
<ul>
<li>GPT-4 with prior information (10 random data points) produced better initial guesses than GPT-4 without prior information</li>
<li>Both GPT-4 approaches converged to similar NMA values at the limit</li>
<li>Both GPT-4 approaches outperformed standard <a href="https://en.wikipedia.org/wiki/Bayesian_optimization">Bayesian optimization</a> in NMA and normalized advantage</li>
<li>GPT-3.5 largely failed due to inability to output correct JSON schemas</li>
<li>On the Buchwald-Hartwig dataset, GPT-4 performed comparably whether given compound names or <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> strings, and could reason about electronic properties from SMILES representations</li>
</ul>
<p>All experiments used a maximum of 20 iterations (5.2% and 6.9% of the total reaction space for the two datasets).</p>
<h2 id="demonstrated-versatility-with-safety-considerations">Demonstrated Versatility with Safety Considerations</h2>
<p>Coscientist demonstrated that GPT-4, when equipped with appropriate tool access, can autonomously handle the full experimental chemistry workflow from literature search to reaction execution and data interpretation. The system showed chemical reasoning capabilities, including selecting appropriate reagents, providing justifications for choices based on reactivity and selectivity, and using experimental data to guide subsequent iterations.</p>
<p>Several limitations are acknowledged:</p>
<ul>
<li>The experimental setup was not yet fully automated (plates were moved manually between instruments), though no human decision-making was involved</li>
<li>GPT-3.5 consistently underperformed due to inability to follow formatting instructions</li>
<li>The synthesis planning evaluation scale is inherently subjective</li>
<li>It is unclear whether GPT-4&rsquo;s training data contained information from the optimization datasets</li>
<li>The comparison with Bayesian optimization may reflect different exploration/exploitation balances rather than pure capability differences</li>
</ul>
<p>The authors raise safety concerns about dual-use potential and note that full code and prompts were withheld pending development of US AI regulations. A simplified implementation was released for reproducibility purposes.</p>
<p>Future directions include extending the system with reaction databases (Reaxys, SciFinder), implementing advanced prompting strategies (ReAct, Chain of Thought, Tree of Thoughts), and developing automated quality control for cloud laboratory experiments.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Synthesis benchmark</td>
          <td>7 compound set</td>
          <td>7 compounds</td>
          <td>Acetaminophen, aspirin, ibuprofen, nitroaniline, etc.</td>
      </tr>
      <tr>
          <td>Optimization</td>
          <td>Perera et al. Suzuki flow dataset</td>
          <td>Fully mapped condition space</td>
          <td>Varying ligands, bases, solvents</td>
      </tr>
      <tr>
          <td>Optimization</td>
          <td>Doyle Buchwald-Hartwig dataset</td>
          <td>Fully mapped condition space</td>
          <td>Varying ligands, additives, bases</td>
      </tr>
      <tr>
          <td>Reagent selection</td>
          <td><a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> compound database</td>
          <td>Not specified</td>
          <td>Used for computational experiments</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Planner</strong>: GPT-4 chat completion with modular system prompts</li>
<li><strong>Web Searcher</strong>: GPT-4 or GPT-3.5-turbo for query generation and result parsing</li>
<li><strong>Documentation embedding</strong>: OpenAI ada model with distance-based vector search</li>
<li><strong>Code execution</strong>: Isolated Docker container (no LLM dependency)</li>
<li><strong>Baseline</strong>: Bayesian optimization with varying initial sample sizes (1-10)</li>
</ul>
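<p>The documentation-embedding step can be pictured as a nearest-neighbor lookup over section vectors. The sketch below is illustrative only: the three-dimensional vectors, the section names, and the choice of cosine distance are assumptions (the notes say only "distance-based"), and real embeddings would come from the embedding model, not hand-written lists.</p>

```python
import math


def cosine_distance(a, b):
    """1 - cosine similarity; smaller means the vectors point in closer directions."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / (norm_a * norm_b)


def nearest_sections(query_vec, index, k=2):
    """Return the k section names whose embeddings are closest to the query."""
    ranked = sorted(index, key=lambda name: cosine_distance(query_vec, index[name]))
    return ranked[:k]


# Hypothetical toy index: section name mapped to a made-up embedding.
toy_index = {
    "heater_shaker": [1.0, 0.0, 0.0],
    "pipetting": [0.0, 1.0, 0.0],
    "labware": [0.9, 0.1, 0.0],
}
```

<p>Calling <code>nearest_sections([1.0, 0.0, 0.0], toy_index)</code> ranks the sections by distance to the query, so the planner would receive only the most relevant documentation excerpts.</p>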
<h3 id="models">Models</h3>
<ul>
<li>GPT-4 (primary)</li>
<li>GPT-3.5-turbo (baseline)</li>
<li>Claude 1.3 (baseline for synthesis planning)</li>
<li>Falcon-40B-Instruct (baseline for synthesis planning)</li>
<li>OpenAI ada (for documentation embedding)</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Context</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Synthesis score (1-5)</td>
          <td>7-compound benchmark</td>
          <td>Subjective expert grading</td>
      </tr>
      <tr>
          <td>Normalized advantage</td>
          <td>Optimization tasks</td>
          <td>Measures improvement over random</td>
      </tr>
      <tr>
          <td>NMA</td>
          <td>Optimization tasks</td>
          <td>Maximum advantage achieved through iteration N</td>
      </tr>
      <tr>
          <td>GC-MS confirmation</td>
          <td>Cross-coupling reactions</td>
          <td>Product formation verified experimentally</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li>Opentrons OT-2 liquid handler with heater-shaker module</li>
<li>UV-Vis plate reader</li>
<li>Emerald Cloud Lab (cloud-based automation)</li>
<li>Computational requirements not specified (relies on OpenAI API calls)</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/gomesgroup/coscientist">gomesgroup/coscientist</a></td>
          <td>Code</td>
          <td>Apache-2.0 with Commons Clause</td>
          <td>Simplified implementation; full code withheld for safety</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Boiko, D. A., MacKnight, R., Kline, B. &amp; Gomes, G. (2023). Autonomous chemical research with large language models. <em>Nature</em>, 624(7992), 570-578. <a href="https://doi.org/10.1038/s41586-023-06792-0">https://doi.org/10.1038/s41586-023-06792-0</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{boiko2023autonomous,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Autonomous chemical research with large language models}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Boiko, Daniil A. and MacKnight, Robert and Kline, Ben and Gomes, Gabriel dos Passos}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Nature}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{624}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{7992}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{570--578}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Springer Nature}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1038/s41586-023-06792-0}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>VAE for Automatic Chemical Design (2018 Seminal)</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/latent-space/automatic-chemical-design-vae/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/latent-space/automatic-chemical-design-vae/</guid><description>A variational autoencoder maps SMILES strings to a continuous latent space, enabling gradient-based optimization for molecular design and generation.</description><content:encoded><![CDATA[<h2 id="a-foundational-method-for-continuous-molecular-representation">A Foundational Method for Continuous Molecular Representation</h2>
<p>This is a <strong>Method</strong> paper that introduces a variational autoencoder (VAE) framework for mapping discrete molecular representations (<a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> strings) into a continuous latent space. The primary contribution is demonstrating that this continuous representation enables three key capabilities: (1) automatic generation of novel molecules by decoding random or perturbed latent vectors, (2) smooth interpolation between molecules in latent space, and (3) gradient-based optimization of molecular properties using a jointly trained property predictor. This work is widely regarded as one of the earliest and most influential applications of deep generative models to molecular design.</p>
<h2 id="the-challenge-of-searching-discrete-chemical-space">The Challenge of Searching Discrete Chemical Space</h2>
<p>Molecular design is fundamentally an optimization problem: identify molecules that maximize some set of desirable properties. The search space is enormous (estimated $10^{23}$ to $10^{60}$ drug-like molecules) and discrete, making systematic exploration difficult. Prior approaches fell into two categories, each with significant limitations:</p>
<ol>
<li><strong>Virtual screening</strong> over fixed libraries: effective but monolithic, costly to enumerate, and requiring hand-crafted rules to avoid impractical chemistries.</li>
<li><strong>Discrete local search</strong> (e.g., genetic algorithms): requires manual specification of mutation and crossover heuristics, and cannot leverage gradient information to guide the search.</li>
</ol>
<p>The core insight is that mapping molecules into a continuous vector space sidesteps these problems entirely. In a continuous space, new compounds can be generated by vector perturbation (no hand-crafted mutation rules), optimization can follow property gradients (enabling larger and more directed jumps), and large unlabeled chemical databases can be leveraged through unsupervised representation learning.</p>
<h2 id="a-vae-architecture-for-smiles-strings-with-joint-property-prediction">A VAE Architecture for SMILES Strings with Joint Property Prediction</h2>
<p>The architecture consists of three coupled neural networks trained jointly:</p>
<ol>
<li>
<p><strong>Encoder</strong>: Converts SMILES character strings into fixed-dimensional continuous vectors (the latent representation). Uses three 1D convolutional layers followed by a fully connected layer. For ZINC molecules, the latent space has 196 dimensions; for <a href="/notes/chemistry/datasets/qm9/">QM9</a>, 156 dimensions.</p>
</li>
<li>
<p><strong>Decoder</strong>: Converts latent vectors back into SMILES strings character by character using three layers of gated recurrent units (GRUs). The output is stochastic, as each character is sampled from a probability distribution over the SMILES alphabet.</p>
</li>
<li>
<p><strong>Property Predictor</strong>: A multilayer perceptron that predicts molecular properties directly from the latent representation. Joint training with the autoencoder reconstruction loss organizes the latent space so that molecules with similar properties cluster together.</p>
</li>
</ol>
<h3 id="the-vae-objective">The VAE Objective</h3>
<p>The model uses the <a href="/notes/machine-learning/generative-models/autoencoding-variational-bayes/">variational autoencoder framework of Kingma and Welling</a>. The training objective combines three terms:</p>
<p>$$\mathcal{L} = \mathcal{L}_{recon} + \beta \cdot D_{KL}(q(z|x) \parallel p(z)) + \lambda \cdot \mathcal{L}_{prop}$$</p>
<p>where $\mathcal{L}_{recon}$ is the reconstruction loss (cross-entropy over SMILES characters), $D_{KL}$ is the KL divergence regularizer that encourages the latent distribution $q(z|x)$ to match a standard Gaussian prior $p(z)$, and $\mathcal{L}_{prop}$ is the property prediction regression loss. Both the variational loss and the property prediction loss are annealed in using a sigmoid schedule after 29 epochs over a total of 120 epochs of training.</p>
<p>The KL regularization is critical: it forces the decoder to handle a wider variety of latent points, preventing &ldquo;dead areas&rdquo; in latent space that would decode to invalid molecules.</p>
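<p>A minimal sketch of how the three terms could combine under sigmoid annealing. The function names, the <code>slope</code> parameter, and centering the sigmoid at epoch 29 are illustrative assumptions, not the authors' code. The KL term uses the standard closed form for a diagonal Gaussian posterior against a standard normal prior.</p>

```python
import math


def sigmoid_weight(epoch, start_epoch=29, slope=1.0):
    """Annealing weight in (0, 1) that ramps up around start_epoch (assumed form)."""
    return 1.0 / (1.0 + math.exp(-slope * (epoch - start_epoch)))


def gaussian_kl(mu, log_var):
    """KL(N(mu, diag(exp(log_var))) || N(0, I)), summed over latent dimensions."""
    return 0.5 * sum(m * m + math.exp(lv) - 1.0 - lv for m, lv in zip(mu, log_var))


def total_loss(recon, mu, log_var, prop, epoch, beta=1.0, lam=1.0):
    """L = L_recon + w * (beta * KL + lam * L_prop), with w annealed by epoch."""
    w = sigmoid_weight(epoch)
    return recon + w * (beta * gaussian_kl(mu, log_var) + lam * prop)
```

<p>Early in training the weight is close to zero, so the model behaves like a plain autoencoder; late in training the weight approaches one and the full objective applies.</p>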
<h3 id="gradient-based-optimization">Gradient-Based Optimization</h3>
<p>After training, a Gaussian process (GP) surrogate model is fit on top of the latent representations to predict the target property. Optimization proceeds by:</p>
<ol>
<li>Encoding a seed molecule into the latent space</li>
<li>Using the GP model to define a smooth property surface over the latent space</li>
<li>Optimizing the latent vector $z$ to maximize the predicted property via gradient ascent</li>
<li>Decoding the optimized $z$ back into a SMILES string</li>
</ol>
<p>The objective used for demonstration is $5 \times \text{QED} - \text{SAS}$, balancing drug-likeness (QED) against synthetic accessibility (SAS).</p>
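<p>The optimization loop can be sketched on a toy problem. Here a concave quadratic stands in for the GP-predicted property surface, which is an assumption made purely for illustration; the encoder, decoder, and real GP are omitted, and all function names are hypothetical.</p>

```python
def surrogate(z, peak):
    """Toy stand-in for the GP posterior mean: a concave bowl centered at peak."""
    return -sum((zi - pi) ** 2 for zi, pi in zip(z, peak))


def surrogate_grad(z, peak):
    """Analytic gradient of the toy surrogate with respect to z."""
    return [-2.0 * (zi - pi) for zi, pi in zip(z, peak)]


def gradient_ascent(z0, peak, lr=0.1, steps=100):
    """Move the latent vector uphill on the surrogate; returns the optimized z."""
    z = list(z0)
    for _ in range(steps):
        g = surrogate_grad(z, peak)
        z = [zi + lr * gi for zi, gi in zip(z, g)]
    return z


def demo_objective(qed, sas):
    """Demonstration objective from the text: 5 * QED - SAS."""
    return 5.0 * qed - sas
```

<p>In the real pipeline the optimized vector would then be passed through the decoder, and the decoded SMILES string scored with the actual QED and SAS values.</p>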
<h2 id="experiments-on-zinc-and-qm9-datasets">Experiments on ZINC and QM9 Datasets</h2>
<p>Two autoencoder systems were trained:</p>
<ul>
<li><strong>ZINC</strong>: 250,000 drug-like molecules from the ZINC database, with a 196-dimensional latent space. Properties predicted: logP, QED, SAS.</li>
<li><strong>QM9</strong>: 108,000 molecules with fewer than 9 heavy atoms, with a 156-dimensional latent space. Properties predicted: HOMO energy, LUMO energy, electronic spatial extent ($\langle R^2 \rangle$).</li>
</ul>
<h3 id="latent-space-quality">Latent Space Quality</h3>
<p>The encoded latent dimensions follow approximately normal distributions as enforced by the variational regularizer. Decoding is stochastic: sampling the same latent point multiple times yields different SMILES strings, with the most frequent decoding tending to be closest to the original point in latent space. Decoding validity rates are 73-79% for points near known molecules but only 4% for randomly selected latent points.</p>
<p>Spherical interpolation (slerp) between molecules in latent space produces smooth structural transitions, accounting for the geometry of high-dimensional Gaussian distributions where linear interpolation would pass through low-probability regions.</p>
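<p>A minimal sketch of spherical interpolation between two latent vectors, written with the standard slerp formula. The fallback to linear interpolation for nearly parallel vectors is a common numerical safeguard added here, not something the paper specifies.</p>

```python
import math


def slerp(z0, z1, t):
    """Spherical interpolation between latent vectors z0 and z1 at fraction t in [0, 1]."""
    norm0 = math.sqrt(sum(a * a for a in z0))
    norm1 = math.sqrt(sum(b * b for b in z1))
    cos_omega = sum(a * b for a, b in zip(z0, z1)) / (norm0 * norm1)
    cos_omega = max(-1.0, min(1.0, cos_omega))  # guard against rounding error
    omega = math.acos(cos_omega)
    if math.sin(omega) < 1e-8:
        # Nearly parallel (or anti-parallel) vectors: fall back to linear interpolation.
        return [(1.0 - t) * a + t * b for a, b in zip(z0, z1)]
    w0 = math.sin((1.0 - t) * omega) / math.sin(omega)
    w1 = math.sin(t * omega) / math.sin(omega)
    return [w0 * a + w1 * b for a, b in zip(z0, z1)]
```

<p>For two orthogonal unit vectors, the slerp midpoint stays on the unit sphere, whereas the linear midpoint has norm of about 0.71. In high dimensions, Gaussian samples concentrate near a shell of fixed radius, so keeping the norm roughly constant keeps the path in a region the decoder has seen.</p>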
<h3 id="molecular-generation-comparison">Molecular Generation Comparison</h3>
<table>
  <thead>
      <tr>
          <th>Source</th>
          <th>Dataset</th>
          <th>Samples</th>
          <th>logP</th>
          <th>SAS</th>
          <th>QED</th>
          <th>% in ZINC</th>
          <th>% in eMolecules</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Data</td>
          <td>ZINC</td>
          <td>249k</td>
          <td>2.46 (1.43)</td>
          <td>3.05 (0.83)</td>
          <td>0.73 (0.14)</td>
          <td>100</td>
          <td>12.9</td>
      </tr>
      <tr>
          <td>GA</td>
          <td>ZINC</td>
          <td>5303</td>
          <td>2.84 (1.86)</td>
          <td>3.80 (1.01)</td>
          <td>0.57 (0.20)</td>
          <td>6.5</td>
          <td>4.8</td>
      </tr>
      <tr>
          <td>VAE</td>
          <td>ZINC</td>
          <td>8728</td>
          <td>2.67 (1.46)</td>
          <td>3.18 (0.86)</td>
          <td>0.70 (0.14)</td>
          <td>5.8</td>
          <td>7.0</td>
      </tr>
      <tr>
          <td>Data</td>
          <td>QM9</td>
          <td>134k</td>
          <td>0.30 (1.00)</td>
          <td>4.25 (0.94)</td>
          <td>0.48 (0.07)</td>
          <td>0.0</td>
          <td>8.6</td>
      </tr>
      <tr>
          <td>GA</td>
          <td>QM9</td>
          <td>5470</td>
          <td>0.96 (1.53)</td>
          <td>4.47 (1.01)</td>
          <td>0.53 (0.13)</td>
          <td>0.018</td>
          <td>3.8</td>
      </tr>
      <tr>
          <td>VAE</td>
          <td>QM9</td>
          <td>2839</td>
          <td>0.30 (0.97)</td>
          <td>4.34 (0.98)</td>
          <td>0.47 (0.08)</td>
          <td>0.0</td>
          <td>8.9</td>
      </tr>
  </tbody>
</table>
<p>The VAE generates molecules whose property distributions closely match the training data, outperforming a genetic algorithm baseline that biases toward higher chemical complexity and decreased drug-likeness. Only 5.8% of VAE-generated ZINC molecules were found in the original ZINC database, indicating genuine novelty.</p>
<h3 id="property-prediction">Property Prediction</h3>
<table>
  <thead>
      <tr>
          <th>Dataset/Property</th>
          <th>Mean Baseline</th>
          <th>ECFP</th>
          <th>Graph Conv.</th>
          <th>1-hot SMILES</th>
          <th>Encoder Only</th>
          <th>VAE</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ZINC/logP</td>
          <td>1.14</td>
          <td>0.38</td>
          <td>0.05</td>
          <td>0.16</td>
          <td>0.13</td>
          <td>0.15</td>
      </tr>
      <tr>
          <td>ZINC/QED</td>
          <td>0.112</td>
          <td>0.045</td>
          <td>0.017</td>
          <td>0.041</td>
          <td>0.037</td>
          <td>0.054</td>
      </tr>
      <tr>
          <td>QM9/HOMO (eV)</td>
          <td>0.44</td>
          <td>0.20</td>
          <td>0.12</td>
          <td>0.12</td>
          <td>0.13</td>
          <td>0.16</td>
      </tr>
      <tr>
          <td>QM9/LUMO (eV)</td>
          <td>1.05</td>
          <td>0.20</td>
          <td>0.15</td>
          <td>0.11</td>
          <td>0.14</td>
          <td>0.16</td>
      </tr>
      <tr>
          <td>QM9/Gap (eV)</td>
          <td>1.07</td>
          <td>0.30</td>
          <td>0.18</td>
          <td>0.16</td>
          <td>0.18</td>
          <td>0.21</td>
      </tr>
  </tbody>
</table>
<p>Values are mean absolute errors (lower is better). The VAE latent representation achieves property prediction accuracy comparable to graph convolutions for some properties, though graph convolutions generally perform best. The primary purpose of joint training is not to maximize prediction accuracy but to organize the latent space for optimization.</p>
<h3 id="optimization-results">Optimization Results</h3>
<p>Bayesian optimization with a GP model on the jointly trained latent space consistently produces molecules with higher percentile scores on the $5 \times \text{QED} - \text{SAS}$ objective than either random Gaussian search or the genetic algorithm baseline. Starting from molecules in the bottom 10th percentile of the ZINC dataset, the optimizer reliably discovers molecules in regions of high objective value. Training the GP with 1000 molecules (vs. 2000) produces a wider diversity of solutions by optimizing to multiple local optima rather than a single global optimum.</p>
<h2 id="key-findings-limitations-and-legacy">Key Findings, Limitations, and Legacy</h2>
<h3 id="key-findings">Key Findings</h3>
<ul>
<li>A continuous latent representation of molecules enables gradient-based search through chemical space, a qualitatively different approach from discrete enumeration or genetic algorithms.</li>
<li>Joint training with property prediction organizes the latent space by property values, creating smooth gradients that optimization can follow.</li>
<li>The VAE generates novel molecules with realistic property distributions, and the latent space encodes an estimated 7.5 million molecules despite training on only 250,000.</li>
</ul>
<h3 id="acknowledged-limitations">Acknowledged Limitations</h3>
<ul>
<li>The SMILES-based decoder sometimes produces formally valid but chemically undesirable molecules (acid chlorides, anhydrides, cyclopentadienes, aziridines, etc.) because the grammar of valid SMILES does not capture all synthetic or stability constraints.</li>
<li>Character-level SMILES generation is fragile: the decoder must implicitly learn which strings are valid SMILES, making the learning problem harder than necessary.</li>
<li>Decoding validity drops to only 4% for random latent points far from training data, limiting the ability to explore truly novel regions of chemical space.</li>
</ul>
<h3 id="directions-identified">Directions Identified</h3>
<p>The authors point to several extensions that were already underway at the time of publication:</p>
<ul>
<li><strong><a href="/notes/chemistry/molecular-design/generation/latent-space/grammar-variational-autoencoder/">Grammar VAE</a></strong>: Using an explicitly defined SMILES grammar instead of forcing the model to learn one (Kusner et al., 2017).</li>
<li><strong>Graph-based decoders</strong>: Directly outputting molecular graphs to avoid the SMILES validity problem.</li>
<li><strong>Adversarial training</strong>: Using GANs for molecular generation (<a href="/notes/chemistry/molecular-design/generation/rl-tuned/organ-objective-reinforced-gan/">ORGAN, ORGANIC</a>).</li>
<li><strong>LSTM/RNN generators</strong>: Applying recurrent networks directly to SMILES for generation and reaction prediction.</li>
</ul>
<p>This paper has been cited over 2,900 times and launched a large body of follow-up work in VAE-based, GAN-based, and reinforcement learning-based molecular generation.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training</td>
          <td>ZINC (drug-like subset)</td>
          <td>250,000 molecules</td>
          <td>Randomly sampled from ZINC database</td>
      </tr>
      <tr>
          <td>Training</td>
          <td>QM9</td>
          <td>108,000 molecules</td>
          <td>Molecules with fewer than 9 heavy atoms</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>ZINC held-out set</td>
          <td>5,000 molecules</td>
          <td>For latent space analysis</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Encoder</strong>: 3 x 1D convolutional layers (ZINC: filters 9,9,10 with kernels 9,9,11; QM9: filters 2,2,1 with kernels 5,5,4), followed by a fully connected layer</li>
<li><strong>Decoder</strong>: 3 x GRU layers (ZINC: hidden dim 488; QM9: hidden dim 500), trained with teacher forcing</li>
<li><strong>Property Predictor</strong>: 2 fully connected layers of 1000 neurons (dropout 0.20) for prediction; smaller 3-layer MLP of 67 neurons (dropout 0.15) for latent space shaping</li>
<li><strong>Variational loss annealing</strong>: Sigmoid schedule after 29 epochs, total 120 epochs</li>
<li><strong>SMILES validation</strong>: Post-hoc filtering with RDKit; invalid outputs discarded</li>
<li><strong>Optimization</strong>: Gaussian process surrogate model trained on 2000 maximally diverse molecules from latent space</li>
</ul>
<h3 id="models">Models</h3>
<p>Built with Keras and TensorFlow. Latent dimensions: 196 (ZINC), 156 (QM9). SMILES alphabet: 35 characters (ZINC), 22 characters (QM9). Maximum string length: 120 (ZINC), 34 (QM9). Only canonicalized SMILES used for training.</p>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Description</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>logP</td>
          <td>Water-octanol partition coefficient</td>
      </tr>
      <tr>
          <td>QED</td>
          <td>Quantitative Estimation of Drug-likeness (0-1)</td>
      </tr>
      <tr>
          <td>SAS</td>
          <td>Synthetic Accessibility Score</td>
      </tr>
      <tr>
          <td>HOMO/LUMO (eV)</td>
          <td>Frontier orbital energies (QM9)</td>
      </tr>
      <tr>
          <td>Decoding validity</td>
          <td>Fraction of latent points producing valid SMILES</td>
      </tr>
      <tr>
          <td>Novelty</td>
          <td>Fraction of generated molecules not in training set</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Training was performed on the Harvard FAS Odyssey Cluster. Specific GPU types and training times are not reported. The Gaussian process optimization requires only minutes to train on a few thousand molecules.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/aspuru-guzik-group/chemical_vae">chemical_vae</a></td>
          <td>Code</td>
          <td>Apache-2.0</td>
          <td>Official implementation with training scripts and pre-trained models</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Gómez-Bombarelli, R., Wei, J. N., Duvenaud, D., Hernández-Lobato, J. M., Sánchez-Lengeling, B., Sheberla, D., Aguilera-Iparraguirre, J., Hirzel, T. D., Adams, R. P., &amp; Aspuru-Guzik, A. (2018). Automatic Chemical Design Using a Data-Driven Continuous Representation of Molecules. <em>ACS Central Science</em>, 4(2), 268-276. <a href="https://doi.org/10.1021/acscentsci.7b00572">https://doi.org/10.1021/acscentsci.7b00572</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{gomez2018automatic,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Automatic Chemical Design Using a Data-Driven Continuous Representation of Molecules}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{G{\&#39;o}mez-Bombarelli, Rafael and Wei, Jennifer N. and Duvenaud, David and Hern{\&#39;a}ndez-Lobato, Jos{\&#39;e} Miguel and S{\&#39;a}nchez-Lengeling, Benjam{\&#39;i}n and Sheberla, Dennis and Aguilera-Iparraguirre, Jorge and Hirzel, Timothy D. and Adams, Ryan P. and Aspuru-Guzik, Al{\&#39;a}n}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{ACS Central Science}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{4}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{2}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{268--276}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2018}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{American Chemical Society}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1021/acscentsci.7b00572}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>SMI+AIS: Hybridizing SMILES with Environment Tokens</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/notations/smi-ais-hybrid-molecular-representation/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/notations/smi-ais-hybrid-molecular-representation/</guid><description>SMI+AIS hybridizes SMILES with Atom-In-SMILES tokens encoding local chemical environments, improving molecular generation binding affinity and synthesizability.</description><content:encoded><![CDATA[<h2 id="a-hybrid-molecular-representation-combining-smiles-and-chemical-environment-tokens">A Hybrid Molecular Representation Combining SMILES and Chemical-Environment Tokens</h2>
<p>This is a <strong>Method</strong> paper that introduces SMI+AIS(N), a hybrid molecular string representation combining standard <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> tokens with <a href="/notes/chemistry/molecular-representations/notations/atom-in-smiles-tokenization/">Atom-In-SMILES (AIS)</a> tokens. AIS tokens encode local chemical environment information (central atom, ring membership, and neighboring atoms) into a single token. The key contribution is a systematic hybridization strategy that selectively replaces the most frequent SMILES tokens with AIS equivalents, preserving SMILES grammar compatibility while enriching token diversity. The method is validated on molecular structure generation via latent space optimization for drug design.</p>
<h2 id="limitations-of-standard-smiles-for-machine-learning">Limitations of Standard SMILES for Machine Learning</h2>
<p>SMILES is the most widely adopted string-based molecular representation, used in major databases like ZINC and PubChem. Despite this ubiquity, SMILES has several well-known limitations for machine learning applications:</p>
<ol>
<li><strong>Non-unique representations</strong>: The same molecule can be encoded as multiple distinct SMILES strings.</li>
<li><strong>Invalid string generation</strong>: Generative models can produce syntactically invalid SMILES that do not correspond to any molecule.</li>
<li><strong>Limited token diversity</strong>: SMILES tokens map one-to-one to atoms or bonds, so the token vocabulary is restricted to the available atom and bond types.</li>
<li><strong>Insufficient chemical context</strong>: Individual SMILES tokens carry no information about the local chemical environment of an atom.</li>
</ol>
<p>Alternative representations like <a href="/notes/chemistry/molecular-representations/notations/selfies/">SELFIES</a> (guaranteeing validity) and <a href="/notes/chemistry/molecular-representations/notations/inchi-2013/">InChI</a> (guaranteeing uniqueness) address some of these issues but share the same fundamental limitation of low token diversity. The Atom-In-SMILES (AIS) representation (Ucak et al., 2023) enriches tokens with neighboring atom and ring information, but using AIS exclusively produces a large vocabulary with many infrequent tokens that can cause data sparsity problems. The authors aim to find a middle ground: adding chemical context to the most common tokens while keeping the vocabulary manageable.</p>
<h2 id="core-innovation-selective-token-hybridization-with-ais">Core Innovation: Selective Token Hybridization with AIS</h2>
<p>The SMI+AIS(N) representation hybridizes standard SMILES with AIS tokens through a frequency-based selection process:</p>
<h3 id="ais-token-structure">AIS Token Structure</h3>
<p>Each AIS token encodes three pieces of information about an atom, delimited by semicolons:</p>
<p>$$
\lbrack \text{central atom} ; \text{ring info} ; \text{neighbor atoms} \rbrack
$$</p>
<p>For example, the oxygen in a carboxyl group of benzoic acid is represented as <code>[O;!R;C]</code>, meaning: oxygen atom, not in a ring, bonded to carbon. In standard SMILES, this would simply be <code>O</code>.</p>
<h3 id="hybridization-procedure">Hybridization Procedure</h3>
<ol>
<li>Convert all SMILES strings in the <a href="/notes/chemistry/datasets/zinc-22/">ZINC database</a> to their full AIS representations.</li>
<li>Count the frequency of each AIS token across the database.</li>
<li>Select the top-N most frequent AIS tokens to form the hybrid vocabulary.</li>
<li>In the hybrid representation, atoms matching these top-N AIS tokens are written in AIS notation; all other atoms use standard SMILES notation.</li>
</ol>
<p>For benzoic acid, the hybridization produces:</p>
<p>$$
\text{SMI}: \texttt{O=C(O)c1ccccc1}
$$</p>
<p>$$
\text{SMI+AIS}: \texttt{\lbrack O;!R;C\rbrack=\lbrack C;!R;COO\rbrack(\lbrack OH;!R;C\rbrack)c1ccccc1}
$$</p>
<p>The parameter N controls vocabulary size. The authors test N = 50, 100, 150, and 200, finding that N = 100-150 provides the best balance for the ZINC database.</p>
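<p>The frequency-based selection and substitution steps can be sketched as follows. This assumes each molecule is already tokenized into pairs of a SMILES token and its AIS equivalent, with <code>None</code> for non-atom tokens such as bonds, parentheses, and ring digits; generating real AIS tokens requires a cheminformatics toolkit and is out of scope here. The aromatic-carbon AIS strings in the example are illustrative placeholders, and the function names are hypothetical.</p>

```python
from collections import Counter


def top_n_ais_tokens(corpus, n):
    """Return the n most frequent AIS tokens across a corpus of tokenized molecules."""
    counts = Counter(ais for mol in corpus for _, ais in mol if ais is not None)
    return {tok for tok, _ in counts.most_common(n)}


def hybridize(mol, vocab):
    """Write atoms whose AIS token is in vocab as AIS; keep every other token as SMILES."""
    return "".join(ais if ais in vocab else smi for smi, ais in mol)


# Benzoic acid as (SMILES token, AIS token) pairs; ring-carbon AIS strings are placeholders.
benzoic_acid = [
    ("O", "[O;!R;C]"), ("=", None), ("C", "[C;!R;COO]"), ("(", None),
    ("O", "[OH;!R;C]"), (")", None), ("c", "[c;R;CCC]"), ("1", None),
    ("c", "[cH;R;CC]"), ("c", "[cH;R;CC]"), ("c", "[cH;R;CC]"),
    ("c", "[cH;R;CC]"), ("c", "[cH;R;CC]"), ("1", None),
]

# Vocabulary chosen by hand to reproduce the example in the text.
carboxyl_vocab = {"[O;!R;C]", "[C;!R;COO]", "[OH;!R;C]"}
```

<p>With <code>carboxyl_vocab</code>, <code>hybridize(benzoic_acid, carboxyl_vocab)</code> reproduces the SMI+AIS string shown above, and an empty vocabulary gives back plain SMILES. In practice the vocabulary would come from <code>top_n_ais_tokens</code> run over the full database.</p>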
<h3 id="token-frequency-rebalancing">Token Frequency Rebalancing</h3>
<p>A key benefit of hybridization is mitigating the severe token frequency imbalance in standard SMILES. Carbon (C), the most frequent element with ~184 million occurrences in ZINC, is represented by only 16 token types in SMILES. With SMI+AIS(200), carbon is distinguished into 145 token types based on chemical environment, with 74% of carbon occurrences represented by AIS tokens. Less common elements like halogens see minimal change (only 2% AIS representation), which avoids introducing unnecessarily rare tokens.</p>
<table>
  <thead>
      <tr>
          <th>Element</th>
          <th>Frequency</th>
          <th>SMILES Types</th>
          <th>SMI+AIS(100) Types (AIS %)</th>
          <th>SMI+AIS(200) Types (AIS %)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>C</td>
          <td>183,860,954</td>
          <td>16</td>
          <td>78 (73%)</td>
          <td>145 (74%)</td>
      </tr>
      <tr>
          <td>O</td>
          <td>27,270,229</td>
          <td>8</td>
          <td>16 (11%)</td>
          <td>24 (11%)</td>
      </tr>
      <tr>
          <td>N</td>
          <td>26,022,928</td>
          <td>11</td>
          <td>32 (1%)</td>
          <td>46 (10%)</td>
      </tr>
      <tr>
          <td>X (halogens)</td>
          <td>6,137,030</td>
          <td>7</td>
          <td>10 (2%)</td>
          <td>11 (2%)</td>
      </tr>
      <tr>
          <td>S</td>
          <td>4,581,307</td>
          <td>12</td>
          <td>17 (2%)</td>
          <td>24 (2%)</td>
      </tr>
  </tbody>
</table>
<h2 id="latent-space-optimization-for-molecular-generation">Latent Space Optimization for Molecular Generation</h2>
<h3 id="model-architecture">Model Architecture</h3>
<p>The evaluation uses a <a href="/notes/chemistry/molecular-design/generation/latent-space/automatic-chemical-design-vae/">conditional variational autoencoder (CVAE)</a> with:</p>
<ul>
<li><strong>Encoder</strong>: BERT-style architecture with entity and positional embeddings, 4 multi-head attention layers (8 heads each), producing mean and standard deviation vectors in latent space.</li>
<li><strong>Decoder</strong>: 4 stacked gated recurrent unit (GRU) layers that transform sampled latent vectors, together with the conditioning input, back into token sequences.</li>
<li><strong>Training</strong>: 20 epochs on 9 million compounds from the ZINC database (8:1:1 train/valid/test split) under identical conditions for all representations.</li>
</ul>
<h3 id="optimization-setup">Optimization Setup</h3>
<p><a href="https://en.wikipedia.org/wiki/Bayesian_optimization">Bayesian Optimization</a> (BO) via BoTorch is applied to the CVAE <a href="/notes/chemistry/molecular-design/generation/latent-space/">latent space</a>, maximizing a multi-objective function:</p>
<p>$$
\text{Obj} = -\text{BA} - 0.5 \times \text{SA}^2
$$</p>
<p>where BA is binding affinity (docking score from QuickVina 2, lower is stronger) and SA is synthetic accessibility score (from RDKit, lower is more synthesizable). Each BO iteration generates 800 candidate latent vectors. Invalid strings receive a penalty objective value of -100.</p>
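<p>As a minimal sketch of how this objective could be evaluated for one candidate, the helper below applies the formula and the invalid-string penalty. The function and argument names are illustrative assumptions, and the docking and SA values are passed in as plain numbers instead of being computed by QuickVina 2 or RDKit.</p>

```python
INVALID_PENALTY = -100.0  # objective value assigned to strings that do not decode


def objective(binding_affinity, sa_score, is_valid=True):
    """Obj = -BA - 0.5 * SA^2, with a fixed penalty for invalid strings.

    binding_affinity: docking score (more negative means stronger binding)
    sa_score: synthetic accessibility score (lower means easier to make)
    """
    if not is_valid:
        return INVALID_PENALTY
    return -binding_affinity - 0.5 * sa_score ** 2


# A candidate with BA = -10 and SA = 2.0 scores 10 - 0.5 * 4 = 8.0
print(objective(-10.0, 2.0))
# An invalid string is pushed far below any realistic valid candidate
print(objective(0.0, 0.0, is_valid=False))
```

<p>Because SA enters quadratically, a one-point increase in SA at SA = 2 costs 2.5 objective units, which is more than a one-unit gain in docking score recovers.</p>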
<h3 id="protein-targets">Protein Targets</h3>
<p>Four diverse targets were used to assess generalizability:</p>
<ul>
<li><strong>PDK4</strong> (<a href="https://en.wikipedia.org/wiki/Pyruvate_dehydrogenase_kinase">Pyruvate Dehydrogenase Kinase</a> 4): narrow, deep binding pocket</li>
<li><strong>5-HT1B</strong> (<a href="https://en.wikipedia.org/wiki/5-HT1B_receptor">Serotonin Receptor 1B</a>): shallow, open <a href="https://en.wikipedia.org/wiki/G_protein-coupled_receptor">GPCR</a> conformation</li>
<li><strong>PARP1</strong> (<a href="https://en.wikipedia.org/wiki/PARP1">Poly ADP-ribose Polymerase 1</a>): small, flexible molecule binding site</li>
<li><strong>CK1d</strong> (<a href="https://en.wikipedia.org/wiki/Casein_kinase_1">Casein Kinase I</a> Delta): broad, accessible conformation</li>
</ul>
<p>Protein structures were obtained from the <a href="https://en.wikipedia.org/wiki/Protein_Data_Bank">Protein Data Bank</a> (PDB IDs: 4V26, 4IAQ, 6I8M, 4TN6). Each optimization was run 10 times independently from the same 5 initial compounds selected from BindingDB.</p>
<h3 id="key-results">Key Results</h3>
<p>SMI+AIS(100) consistently achieved the highest objective values across protein targets.</p>
<p><strong>PDK4 Optimization</strong> (Top-1 results over 10 independent runs):</p>
<ul>
<li>SMI+AIS(100) achieved approximately 12% improvement over standard SMILES and 28% improvement over SELFIES based on median Top-1 objective values.</li>
<li>Generated structures exhibited BA scores between -10 and -9 and SA scores between 2.0 and 2.3.</li>
<li>Molecular weights clustered around 400 amu, consistent with the CVAE conditioning.</li>
</ul>
<p><strong>Validity Ratios</strong>: Standard SMILES produced approximately 40% valid structures. SMI+AIS representations showed significant improvement as N increased, though SMI+AIS(200) showed slight saturation, likely from insufficiently trained infrequent tokens.</p>
<p><strong>SELFIES</strong>: Despite achieving the highest validity ratio, SELFIES failed to generate chemically meaningful structures with desirable BA and SA scores. The authors attribute this to SELFIES grammar where token meaning is highly context-dependent, causing minor latent space variations to produce large structural changes.</p>
<p><strong>Cross-target consistency</strong>: Improvements were observed across all four protein targets, with slight variation (5-HT1B showed smaller differences between SMI and SMI+AIS(100) for Top-1, while other targets showed significant improvements).</p>
<h2 id="improved-molecular-generation-through-chemical-context-enrichment">Improved Molecular Generation Through Chemical Context Enrichment</h2>
<p>The SMI+AIS(N) representation achieves consistent improvements in molecular generation quality compared to both standard SMILES and SELFIES. The core findings are:</p>
<ol>
<li><strong>Binding affinity improvement</strong>: Approximately 7% improvement over standard SMILES for the PDK4 target.</li>
<li><strong>Synthesizability improvement</strong>: Approximately 6% improvement in synthesizability over standard SMILES, measured by the SA score (lower is more synthesizable).</li>
<li><strong>Target independence</strong>: Performance gains transfer across four structurally diverse protein targets.</li>
<li><strong>Preserved structural motifs</strong>: The generative model retains chemically meaningful fragments (e.g., acetamide and <a href="https://en.wikipedia.org/wiki/Piperidine">piperidine</a>) from initial compounds without explicit fragment constraints.</li>
</ol>
<h3 id="limitations">Limitations</h3>
<p>The authors acknowledge several limitations:</p>
<ul>
<li><strong>Stereochemistry</strong>: SMI+AIS inherits the limited stereochemistry handling of standard SMILES.</li>
<li><strong>Evaluation scope</strong>: Only molecular generation was tested; property prediction and other ML tasks remain unexplored.</li>
<li><strong>Compute constraints</strong>: Limited computing power and time are the stated reason the study was restricted to molecular generation.</li>
<li><strong>Single optimization strategy</strong>: Only latent space optimization with Bayesian optimization was evaluated; other generative approaches were not compared.</li>
</ul>
<h3 id="future-directions">Future Directions</h3>
<p>The authors suggest extending SMI+AIS to diverse benchmarking tests including molecular property prediction, experimental validation, and broader applications of chemical language models.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training/Vocab</td>
          <td>ZINC Database</td>
          <td>9M compounds</td>
          <td>Canonicalized, deduplicated, split 8:1:1</td>
      </tr>
      <tr>
          <td>Binding targets</td>
          <td>BindingDB</td>
          <td>5 initial compounds per target</td>
          <td>Selected for each protein target</td>
      </tr>
      <tr>
          <td>Protein structures</td>
          <td>PDB</td>
          <td>4 structures</td>
          <td>IDs: 4V26, 4IAQ, 6I8M, 4TN6</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Tokenization</strong>: AIS token frequency counting on full ZINC database, top-N selection</li>
<li><strong>Generative model</strong>: Conditional VAE with BERT encoder (4 layers, 8 heads) and GRU decoder (4 layers)</li>
<li><strong>Optimization</strong>: Bayesian Optimization via BoTorch (800 candidates per iteration)</li>
<li><strong>Docking</strong>: QuickVina 2 with 25 Å pocket size, 10 docking simulations per ligand</li>
<li><strong>SA scoring</strong>: RDKit SA score</li>
<li><strong>Training</strong>: 20 epochs for all representations under identical conditions</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li>CVAE architecture details in supplementary (Fig. S9, Tables S2, S4)</li>
<li>No pre-trained weights released</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>SMI+AIS(100) vs SMILES</th>
          <th>SMI+AIS(100) vs SELFIES</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Median Top-1 Obj. Value</td>
          <td>+12%</td>
          <td>+28%</td>
          <td>PDK4 target</td>
      </tr>
      <tr>
          <td>Validity Ratio</td>
          <td>Higher than ~40% (SMILES)</td>
          <td>Lower than SELFIES</td>
          <td>SMI+AIS improves with N</td>
      </tr>
      <tr>
          <td>BA (binding affinity)</td>
          <td>~7% improvement</td>
          <td>Substantial</td>
          <td>Lower (more negative) is better</td>
      </tr>
      <tr>
          <td>SA (synthesizability)</td>
          <td>~6% improvement</td>
          <td>Substantial</td>
          <td>Lower is more synthesizable</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Hardware details are not specified in the main text. Optimization wall times are reported in supplementary Table S5.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/herim-han/AIS-Drug-Opt">AIS-Drug-Opt</a></td>
          <td>Code</td>
          <td>Not specified</td>
          <td>Source code and datasets for reproduction</td>
      </tr>
  </tbody>
</table>
<p><strong>Reproducibility Status</strong>: Partially Reproducible. Code and processed data are publicly available on GitHub, but no pre-trained model weights are released, the license is unspecified, and hardware requirements are not documented in the main text.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Han, H., Yeom, M. S., &amp; Choi, S. (2025). Hybridization of SMILES and chemical-environment-aware tokens to improve performance of molecular structure generation. <em>Scientific Reports</em>, 15, 16892. <a href="https://doi.org/10.1038/s41598-025-01890-7">https://doi.org/10.1038/s41598-025-01890-7</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{han2025hybridization,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Hybridization of SMILES and chemical-environment-aware tokens to improve performance of molecular structure generation}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Han, Herim and Yeom, Min Sun and Choi, Sunghwan}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Scientific Reports}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{15}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{16892}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2025}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Springer Nature}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1038/s41598-025-01890-7}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>PASITHEA: Gradient-Based Molecular Design via Dreaming</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/latent-space/deep-molecular-dreaming-pasithea/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/latent-space/deep-molecular-dreaming-pasithea/</guid><description>PASITHEA applies inceptionism to molecular design, using gradient-based optimization on SELFIES representations to generate molecules with target properties.</description><content:encoded><![CDATA[<h2 id="inceptionism-applied-to-molecular-inverse-design">Inceptionism Applied to Molecular Inverse Design</h2>
<p>This is a <strong>Method</strong> paper that introduces PASITHEA, a gradient-based approach to de-novo molecular design inspired by inceptionism (deep dreaming) techniques from computer vision. The core contribution is a direct optimization framework that modifies molecular structures by backpropagating through a trained property-prediction network, with the molecular input (rather than weights) serving as the optimizable variable. PASITHEA is enabled by SELFIES, a surjective molecular string representation that guarantees 100% validity of generated molecules.</p>
<h2 id="the-need-for-direct-gradient-based-molecular-optimization">The Need for Direct Gradient-Based Molecular Optimization</h2>
<p>Existing inverse molecular design methods, including variational autoencoders (VAEs), generative adversarial networks (GANs), reinforcement learning (RL), and genetic algorithms (GAs), share a common characteristic: they optimize molecules indirectly. VAEs and GANs learn distributions and scan latent spaces. RL agents learn policies from environmental rewards. GAs iteratively apply mutations and selections. None of these approaches directly maximize an objective function in a gradient-based manner with respect to the molecular representation itself.</p>
<p>This indirection has several consequences. VAE-based methods require learning a latent space, and the optimization happens in that space rather than directly on molecular structures. RL and GA methods require expensive function evaluations for each candidate molecule. The authors identify an opportunity to exploit gradients more directly by reversing the learning process of a neural network trained to predict molecular properties, thereby sidestepping latent spaces, policies, and population-based search entirely.</p>
<p>A second motivation is interpretability. By operating directly on the molecular representation (rather than a learned latent space), PASITHEA can reveal what a regression network has learned about structure-property relationships, a capability the authors frame as analogous to how deep dreaming reveals what image classifiers have learned about visual features.</p>
<h2 id="core-innovation-inverting-regression-networks-on-selfies">Core Innovation: Inverting Regression Networks on SELFIES</h2>
<p>PASITHEA&rsquo;s key insight is a two-phase training procedure that repurposes the standard neural network training loop for molecule generation.</p>
<p><strong>Phase 1: Prediction training.</strong> A fully connected neural network is trained to predict a real-valued chemical property (logP) from one-hot encoded SELFIES strings. The standard feedforward and backpropagation process updates the network weights to minimize mean squared error between predicted and ground-truth property values:</p>
<p>$$
\min_{\theta} \frac{1}{N} \sum_{i=1}^{N} (f_{\theta}(\mathbf{x}_i) - y_i)^2
$$</p>
<p>where $f_{\theta}$ is the neural network with parameters $\theta$, $\mathbf{x}_i$ is the one-hot encoded SELFIES input, and $y_i$ is the target logP value.</p>
<p><strong>Phase 2: Inverse training (deep dreaming).</strong> The network weights $\theta$ are frozen. For a given input molecule $\mathbf{x}$ and a desired target property value $y_{\text{target}}$, the gradients are computed with respect to the input representation rather than the weights:</p>
<p>$$
\mathbf{x} \leftarrow \mathbf{x} - \eta \nabla_{\mathbf{x}} \mathcal{L}(f_{\theta}(\mathbf{x}), y_{\text{target}})
$$</p>
<p>Here $\eta$ is the step size and $\mathcal{L}$ is the loss between the predicted and target property. This gradient descent on the input incrementally modifies the one-hot encoding of the molecular string, transforming it toward a structure whose predicted property matches the target value. At each step, the argmax function converts the now continuous-valued encoding back to a discrete SELFIES string, which always maps to a valid molecular graph due to the surjective property of SELFIES.</p>
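<p>The update rule can be illustrated with a toy example. The sketch below is not the PASITHEA network: it assumes a hypothetical linear property model over a three-symbol alphabet so that the input gradient of a squared-error loss can be written by hand, whereas the paper backpropagates through a trained fully connected network. The per-symbol weights, learning rate, and step count are made-up values.</p>

```python
ALPHABET = ["[C]", "[N]", "[O]"]
# Hypothetical per-symbol contribution to the predicted property (frozen "weights")
WEIGHTS = {"[C]": 0.5, "[N]": -0.5, "[O]": -0.3}


def one_hot(tokens):
    """Encode a token list as rows of 0/1 floats over ALPHABET."""
    return [[1.0 if s == t else 0.0 for s in ALPHABET] for t in tokens]


def predict(x):
    """Toy frozen model: a linear function of the (possibly continuous) input."""
    return sum(WEIGHTS[s] * v for row in x for s, v in zip(ALPHABET, row))


def decode(x):
    """Argmax per position maps the continuous encoding back to discrete symbols."""
    return "".join(ALPHABET[max(range(len(row)), key=row.__getitem__)] for row in x)


def dream(x, y_target, lr=0.1, steps=50):
    """Gradient descent on the input for the loss (f(x) - y_target)^2."""
    x = [row[:] for row in x]  # copy so the caller's encoding is not modified
    for _ in range(steps):
        residual = predict(x) - y_target
        for row in x:
            for j, s in enumerate(ALPHABET):
                # d/dx of (f(x) - y)^2 for a linear f is 2 * residual * weight
                row[j] -= lr * 2.0 * residual * WEIGHTS[s]
    return x


start = one_hot(["[N]", "[N]", "[O]"])
dreamed = dream(start, y_target=1.5)
print(decode(start), "->", decode(dreamed))
```

<p>In this toy setting the model parameters never change. Only the input encoding moves, and the decoded string changes once another symbol overtakes the original one at some position.</p>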
<p><strong>The role of SELFIES.</strong> The surjective mapping from strings to molecular graphs is essential. With SMILES, intermediate strings during optimization can become syntactically invalid (e.g., an unclosed ring like &ldquo;CCCC1CCCCC&rdquo;), producing no valid molecule. SELFIES enforces constraints that guarantee every string maps to a valid molecular graph, making the continuous gradient-based optimization feasible.</p>
<p><strong>Input noise injection.</strong> Because inverse training transforms a one-hot encoding from binary values to real numbers, the discrete-to-continuous transition can cause convergence problems. The authors address this by initializing the input with noise: every zero in the one-hot encoding is replaced by a random number in $[0, k]$, where $k$ is a hyperparameter between 0.5 and 0.95. This smooths the optimization landscape and enables incremental molecular modifications rather than abrupt changes.</p>
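<p>A minimal sketch of this initialization, assuming the encoding is stored as rows of floats and that the noise is drawn uniformly (the function name and the uniform distribution are assumptions made for illustration):</p>

```python
import random


def inject_noise(one_hot_rows, k=0.9, rng=None):
    """Replace every zero with a random number in [0, k]; keep the ones as they are."""
    rng = rng or random.Random()
    return [
        [v if v == 1.0 else rng.uniform(0.0, k) for v in row]
        for row in one_hot_rows
    ]


rows = [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]
noisy = inject_noise(rows, k=0.9, rng=random.Random(0))
for row in noisy:
    print([round(v, 3) for v in row])
```

<p>Since $k$ stays below 1, the original symbol remains the argmax at every position, so the starting molecule is unchanged while the gap to competing symbols shrinks.</p>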
<h2 id="experimental-setup-on-qm9-with-logp-optimization">Experimental Setup on QM9 with LogP Optimization</h2>
<h3 id="dataset-and-property">Dataset and Property</h3>
<p>The experiments use a random subset of 10,000 molecules from the <a href="/notes/chemistry/datasets/qm9/">QM9</a> dataset. The target property is the logarithm of the partition coefficient (logP), computed using RDKit. LogP measures lipophilicity, an important drug-likeness indicator that follows an approximately normal distribution in QM9 and has a nearly continuous range, making it suitable for gradient-based optimization.</p>
<h3 id="network-architecture">Network Architecture</h3>
<p>PASITHEA uses a fully connected neural network with four layers, each containing 500 nodes with ReLU activation. The loss function is mean squared error. Data is split 85%/15% for training/testing. The prediction model trains for approximately 1,500 epochs with an Adam optimizer and a learning rate of $1 \times 10^{-6}$.</p>
<p>For inverse training, the authors select a noise upper-bound of 0.9 and a learning rate of 0.01, chosen from hyperparameter tuning experiments that evaluate the percentage of molecules optimized toward the target property.</p>
<h3 id="optimization-targets">Optimization Targets</h3>
<p>Two extreme logP targets are used: $+6$ (high lipophilicity) and $-6$ (low lipophilicity). These values exceed the range of logP values in the QM9 dataset (minimum: $-2.19$, maximum: $3.08$), testing whether the model can extrapolate beyond the training distribution.</p>
<h2 id="distribution-shifts-and-interpretable-molecular-transformations">Distribution Shifts and Interpretable Molecular Transformations</h2>
<h3 id="distribution-level-results">Distribution-Level Results</h3>
<p>Applying deep dreaming to the full set of 10,000 molecules produces a clear shift in the logP distribution:</p>
<table>
  <thead>
      <tr>
          <th>Statistic</th>
          <th>QM9 Original</th>
          <th>Optimized (target +6)</th>
          <th>Optimized (target -6)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Mean logP</td>
          <td>0.3909</td>
          <td>1.8172</td>
          <td>-0.3360</td>
      </tr>
      <tr>
          <td>Min logP</td>
          <td>-2.1903</td>
          <td>-0.8240</td>
          <td>-2.452</td>
      </tr>
      <tr>
          <td>Max logP</td>
          <td>3.0786</td>
          <td>4.2442</td>
          <td>0.9018</td>
      </tr>
  </tbody>
</table>
<p>The optimized distributions extend beyond the original dataset&rsquo;s property range. The right-shifted distribution (target +6) produces molecules with logP values up to 4.24, exceeding the original maximum of 3.08. The left-shifted distribution (target -6) reaches -2.45, below the original minimum. This indicates that PASITHEA can generate molecules with properties outside the training data bounds.</p>
<p>Additionally, 97.2% of the generated molecules do not exist in the original training set, indicating that the network is not memorizing data but rather using structural features to guide optimization. Some generated molecules contain more heavy atoms than the QM9 maximum of 9, since the SELFIES string length allows for larger structures.</p>
<h3 id="molecule-level-interpretability">Molecule-Level Interpretability</h3>
<p>The stepwise molecular transformations reveal interpretable &ldquo;strategies&rdquo; the network employs:</p>
<ol>
<li>
<p><strong>Nitrogen appendage</strong>: When optimizing for lower logP, the network repeatedly appends nitrogen atoms to the molecule. The authors observe this as a consistent pattern across multiple test molecules, reflecting the known relationship between nitrogen content and reduced lipophilicity.</p>
</li>
<li>
<p><strong>Length modulation</strong>: When optimizing for higher logP, the network tends to increase molecular chain length (e.g., extending a carbon chain). When optimizing for lower logP, it shortens chains. This captures the intuition that larger, more carbon-heavy molecules tend to be more lipophilic.</p>
</li>
<li>
<p><strong>Bond order changes</strong>: The network replaces single bonds with double or triple bonds during optimization, demonstrating an understanding of the relationship between bonding patterns and logP.</p>
</li>
<li>
<p><strong>Consistency across trials</strong>: Because the input initialization includes random noise, repeated trials with the same molecule produce different transformation sequences. Despite this stochasticity, the network applies consistent strategies across trials (e.g., always shortening chains for negative optimization), validating that it has learned genuine structure-property relationships.</p>
</li>
</ol>
<h3 id="thermodynamic-stability">Thermodynamic Stability</h3>
<p>The authors assess thermodynamic stability by computing heats of formation using MOPAC2016 at the PM7 level of theory. Some optimization trajectories move toward thermodynamically stable molecules (negative heats of formation), while others produce less stable structures. The authors acknowledge this limitation and propose multi-objective optimization incorporating stability as a future direction.</p>
<h3 id="comparison-to-vaes">Comparison to VAEs</h3>
<p>The key distinction from VAEs is where gradient computation occurs. In VAEs, a latent space is learned through encoding and decoding, and property optimization happens in that latent space. In PASITHEA, gradients are computed directly with respect to the molecular representation (SELFIES one-hot encoding). The authors argue this makes the approach more interpretable, since we can probe what the network learned about molecular structure without the &ldquo;detour&rdquo; through a latent space.</p>
<h3 id="limitations">Limitations</h3>
<p>The authors are forthright about the preliminary nature of these results:</p>
<ul>
<li>The method is demonstrated only on a small subset of QM9 with a single, computationally inexpensive property (logP).</li>
<li>The simple four-layer architecture may not scale to larger molecular spaces or more complex properties.</li>
<li>Generated molecules are not always thermodynamically stable, requiring additional optimization objectives.</li>
<li>The approach has not been benchmarked against established methods (VAEs, GANs, RL) on standard generative benchmarks.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training/Evaluation</td>
          <td>QM9 (random subset)</td>
          <td>10,000 molecules</td>
          <td>logP values computed via RDKit</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Prediction training</strong>: 4-layer fully connected NN, 500 nodes/layer, ReLU activation, MSE loss, Adam optimizer, LR $1 \times 10^{-6}$, ~1,500 epochs, 85/15 train/test split</li>
<li><strong>Inverse training</strong>: Frozen weights, Adam optimizer, LR 0.01, noise upper-bound 0.9, logP targets of +6 and -6</li>
<li><strong>Heats of formation</strong>: MOPAC2016, PM7 level, geometry optimization with eigenvector following (EF)</li>
</ul>
<h3 id="models">Models</h3>
<p>The architecture is a simple 4-layer MLP. No pre-trained weights are distributed, but the full code is available.</p>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Value</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Novel molecules</td>
          <td>97.2%</td>
          <td>Generated molecules not in training set</td>
      </tr>
      <tr>
          <td>Max logP (target +6)</td>
          <td>4.2442</td>
          <td>Exceeds QM9 max of 3.0786</td>
      </tr>
      <tr>
          <td>Min logP (target -6)</td>
          <td>-2.452</td>
          <td>Below QM9 min of -2.1903</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Not specified in the paper.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/aspuru-guzik-group/Pasithea">Pasithea</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official implementation</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Shen, C., Krenn, M., Eppel, S., &amp; Aspuru-Guzik, A. (2021). Deep molecular dreaming: inverse machine learning for de-novo molecular design and interpretability with surjective representations. <em>Machine Learning: Science and Technology</em>, 2(3), 03LT02. <a href="https://doi.org/10.1088/2632-2153/ac09d6">https://doi.org/10.1088/2632-2153/ac09d6</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{shen2021deep,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Deep molecular dreaming: inverse machine learning for de-novo molecular design and interpretability with surjective representations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Shen, Cynthia and Krenn, Mario and Eppel, Sagi and Aspuru-Guzik, Al{\&#39;a}n}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Machine Learning: Science and Technology}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{2}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{3}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{03LT02}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2021}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{IOP Publishing}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1088/2632-2153/ac09d6}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Link-INVENT: RL-Driven Molecular Linker Generation</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/rl-tuned/link-invent-generative-linker-design/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/rl-tuned/link-invent-generative-linker-design/</guid><description>Link-INVENT extends REINVENT for molecular linker design using RNN-based generation and reinforcement learning with flexible multi-parameter scoring.</description><content:encoded><![CDATA[<h2 id="a-method-for-generative-linker-design-with-reinforcement-learning">A Method for Generative Linker Design with Reinforcement Learning</h2>
<p>Link-INVENT is a <strong>Method</strong> paper that introduces a generative model for molecular linker design built on the <a href="/notes/chemistry/molecular-design/generation/rl-tuned/reinvent-deep-rl-molecular-design/">REINVENT</a> de novo design platform. The primary contribution is an encoder-decoder recurrent neural network (RNN) architecture that generates SMILES-based linkers connecting two molecular subunits, combined with a flexible multi-parameter optimization (MPO) scoring function and reinforcement learning (RL) to steer generation toward desired properties. Link-INVENT targets three practical drug discovery tasks: fragment linking, scaffold hopping, and <a href="https://en.wikipedia.org/wiki/Proteolysis_targeting_chimera">proteolysis targeting chimera</a> (PROTAC) design.</p>
<h2 id="why-linker-design-needs-flexible-multi-parameter-optimization">Why Linker Design Needs Flexible Multi-Parameter Optimization</h2>
<p>Generating suitable chemical linkers between molecular subunits is a central challenge in <a href="https://en.wikipedia.org/wiki/Fragment-based_lead_discovery">fragment-based drug discovery</a> (FBDD), scaffold hopping, and PROTAC design. Traditional computational approaches rely on database searches, which restricts proposed linkers to those already in the pre-defined collection. Recent deep learning methods (DeLinker, SyntaLinker, 3DLinker, DiffLinker) can generate novel linkers but offer limited support for optimizing specific physicochemical properties. Users can typically control only linker length and a few properties like hydrogen-bond donor count.</p>
<p>The key gaps that Link-INVENT addresses are:</p>
<ol>
<li><strong>Conditioning on both subunits</strong>: Prior RNN-based approaches (SAMOA) generate linkers conditioned only on the SMILES sequence seen so far, which may not account for the second molecular subunit. Link-INVENT conditions on both warheads simultaneously.</li>
<li><strong>Flexible scoring</strong>: Existing DL-based linker design tools lack the ability to define tailored MPO objectives. Link-INVENT inherits <a href="/notes/chemistry/molecular-design/generation/rl-tuned/reinvent4-generative-molecule-design/">REINVENT 4&rsquo;s</a> full scoring infrastructure and adds linker-specific properties.</li>
<li><strong>Generalizability</strong>: A single trained prior handles fragment linking, scaffold hopping, and PROTAC tasks without retraining.</li>
</ol>
<h2 id="core-innovation-conditional-linker-generation-with-augmented-likelihood-rl">Core Innovation: Conditional Linker Generation with Augmented Likelihood RL</h2>
<p>Link-INVENT&rsquo;s architecture is an encoder-decoder RNN adapted from the Lib-INVENT library design model. The encoder processes a pair of warheads (molecular subunits with defined exit vectors), and the decoder generates a linker token by token, yielding a connected molecule in SMILES format. The model uses three hidden layers of 512 LSTM cells with an embedding size of 256.</p>
<h3 id="training">Training</h3>
<p>The prior is trained on ChEMBL v27 data processed through reaction-based slicing to generate (linker, warheads pair, full molecule) tuples. <a href="/notes/chemistry/molecular-representations/notations/randomized-smiles-generative-models/">SMILES randomization</a> augments the training data at each epoch, improving chemical space generalizability. The prior is trained by maximizing the likelihood of generating a linker conditioned on the input warhead pair, with teacher forcing for stability.</p>
<h3 id="multi-parameter-optimization-via-rl">Multi-Parameter Optimization via RL</h3>
<p>The scoring function $S(x)$ is a weighted geometric mean of individual component scores:</p>
<p>$$
S(x) = \left(\prod_{i=1}^{n} C_{i}(x)^{w_{i}}\right)^{\frac{1}{\sum_{i=1}^{n} w_{i}}}
$$</p>
<p>where $x$ is a sampled linked molecule, $C_{i}(x)$ is the score for the $i$-th component, and $w_{i}$ is its weight.</p>
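<p>The weighted geometric mean is short enough to sketch directly. The function name below is an assumption, and component scores are taken to lie in $[0, 1]$ with positive weights.</p>

```python
import math


def mpo_score(component_scores, weights):
    """Weighted geometric mean of component scores, each expected in [0, 1]."""
    if len(component_scores) != len(weights) or not weights:
        raise ValueError("need exactly one weight per component")
    product = math.prod(c ** w for c, w in zip(component_scores, weights))
    return product ** (1.0 / sum(weights))


# Two equally weighted components: sqrt(0.9 * 0.4) is about 0.6
print(mpo_score([0.9, 0.4], [1.0, 1.0]))
# A single failed component drives the whole score to zero
print(mpo_score([0.9, 0.0], [1.0, 1.0]))
```

<p>Unlike a weighted sum, the geometric mean lets one component that scores zero veto the molecule, which is why it suits objectives where every criterion must be met at least partially.</p>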
<p>The agent (initialized as a copy of the prior) is updated via the Difference of Augmented and Posterior likelihoods (DAP) loss. The <a href="/notes/chemistry/molecular-design/generation/rl-tuned/augmented-hill-climb-rl-molecule-generation/">augmented log likelihood</a> is:</p>
<p>$$
\log \pi_{\text{augmented}} = \log \pi_{\text{prior}} + \sigma \cdot S(x)
$$</p>
<p>where $\pi$ denotes a policy (token sampling probabilities conditioned on the sequence so far) and $\sigma$ is a scalar factor. The loss function is:</p>
<p>$$
J(\theta) = \left(\log \pi_{\text{augmented}} - \log \pi_{\text{agent}}\right)^{2}
$$</p>
<p>Minimizing $J(\theta)$ steers the agent to generate molecules that satisfy the scoring function while remaining anchored to the prior&rsquo;s chemical space.</p>
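<p>The two equations above can be traced numerically for a single sampled sequence. The sketch below is illustrative only: the log-likelihoods are made-up numbers, and the value of $\sigma$ is an assumption, not a setting reported in this summary.</p>

```python
def dap_loss(log_p_prior, log_p_agent, score, sigma):
    """Squared difference between augmented and agent log-likelihoods."""
    log_p_augmented = log_p_prior + sigma * score
    return (log_p_augmented - log_p_agent) ** 2

# Hypothetical sequence log-likelihoods; sigma = 120 is an illustrative value.
high = dap_loss(log_p_prior=-30.0, log_p_agent=-30.0, score=0.9, sigma=120.0)
low = dap_loss(log_p_prior=-30.0, log_p_agent=-30.0, score=0.1, sigma=120.0)
print(high, low)
```

<p>When the agent still equals the prior, a high-scoring molecule produces a much larger loss than a low-scoring one, so the gradient step raises the agent's likelihood of the high-scoring molecule more strongly.</p>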
<h3 id="diversity-filters">Diversity Filters</h3>
<p>Link-INVENT uses Diversity Filters (DFs) to balance exploration and exploitation. Buckets of limited size track unique <a href="/notes/chemistry/molecular-design/generation/rl-tuned/memory-assisted-rl-diverse-molecular-design/">Bemis-Murcko scaffolds</a>. When a bucket is full, further sampling of that scaffold receives a score of zero, encouraging the agent to explore diverse chemical space regions.</p>
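<p>The bucket mechanism can be sketched as a small class. This is a simplified stand-in, not the REINVENT implementation: the class and method names are hypothetical, and scaffold strings are assumed to be computed upstream (for example with a cheminformatics toolkit). The default bucket size of 25 matches the setting listed under Reproducibility Details.</p>

```python
from collections import defaultdict

class ScaffoldDiversityFilter:
    """Zero out scores for scaffolds whose bucket is already full."""

    def __init__(self, bucket_size=25):
        self.bucket_size = bucket_size
        self.counts = defaultdict(int)  # scaffold -> molecules stored so far

    def apply(self, scaffold, score):
        if self.counts[scaffold] >= self.bucket_size:
            return 0.0  # bucket full: no reward for this scaffold anymore
        self.counts[scaffold] += 1
        return score

# Tiny bucket so the effect is visible after two molecules.
diversity_filter = ScaffoldDiversityFilter(bucket_size=2)
scores = [diversity_filter.apply("c1ccccc1", 0.7) for _ in range(3)]
print(scores)
```

<p>Once a scaffold stops earning reward, the agent gains nothing from resampling it, which pushes sampling toward scaffolds with open buckets.</p>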
<h3 id="linker-specific-scoring-components">Linker-Specific Scoring Components</h3>
<p>New scoring components provide direct control over linker properties:</p>
<ul>
<li><strong>Linker effective length</strong>: number of bonds between attachment atoms</li>
<li><strong>Linker maximum graph length</strong>: bonds in the longest graph traversal path</li>
<li><strong>Linker length ratio</strong>: effective length divided by maximum graph length (controls branching)</li>
<li><strong>Linker ratio of rotatable bonds</strong>: rotatable bonds over total bonds (controls flexibility)</li>
<li><strong>Linker number of rings</strong>: controls linearity vs. cyclicity</li>
<li><strong>Linker number of HBDs</strong>: hydrogen-bond donors in the linker itself</li>
</ul>
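<p>The three length-based components can be illustrated on a toy linker graph. The sketch below uses a plain adjacency dictionary instead of a cheminformatics toolkit and makes two assumptions: the maximum graph length is interpreted as the largest shortest-path distance between any two linker atoms, and the ratio is expressed as a percentage (which matches thresholds such as 70 used in the experiments). All names are hypothetical.</p>

```python
from collections import deque

def bond_distances(adjacency, start):
    """Breadth-first search: bonds from `start` to every reachable atom."""
    distances = {start: 0}
    queue = deque([start])
    while queue:
        atom = queue.popleft()
        for neighbor in adjacency[atom]:
            if neighbor not in distances:
                distances[neighbor] = distances[atom] + 1
                queue.append(neighbor)
    return distances

def linker_length_ratio(adjacency, attach_a, attach_b):
    """Return (effective length, maximum graph length, ratio in percent)."""
    effective = bond_distances(adjacency, attach_a)[attach_b]
    longest = max(
        max(bond_distances(adjacency, atom).values()) for atom in adjacency
    )
    return effective, longest, 100.0 * effective / longest

# Hypothetical linker: chain 0-1-2-3 with a two-atom branch 1-4-5.
# Atoms 0 and 3 are the attachment atoms.
linker = {0: [1], 1: [0, 2, 4], 2: [1, 3], 3: [2], 4: [1, 5], 5: [4]}
print(linker_length_ratio(linker, 0, 3))
```

<p>An unbranched chain gives a ratio of 100, and side branches that extend beyond the attachment-to-attachment path lower it, which is why the ratio serves as a handle on branching.</p>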
<h2 id="experimental-evaluation-across-three-drug-discovery-tasks">Experimental Evaluation Across Three Drug Discovery Tasks</h2>
<p>Link-INVENT was evaluated through four experiments across three drug discovery applications, all using the same pre-trained prior.</p>
<h3 id="illustrative-example-two-benzene-rings">Illustrative Example: Two Benzene Rings</h3>
<p>A simple experiment linked two benzene rings with the objectives of limiting HBDs and requiring exactly one ring in the linker. Over 20 epochs, the agent learned to satisfy both objectives, demonstrating the basic RL-guided generation process.</p>
<h3 id="experiment-1a-fragment-linking-ck2-alpha-inhibitors">Experiment 1a: Fragment Linking (CK2 alpha Inhibitors)</h3>
<p>Based on the <a href="https://en.wikipedia.org/wiki/Casein_kinase_2">casein kinase 2</a> (CK2 alpha) fragment linking campaign by Fusco and Brear et al., Link-INVENT was tasked with linking two fragment hits while retaining the Lys68 hydrogen-bond interaction via a DockStream docking constraint (Glide/LigPrep backend). The scoring function also enforced linker length ratio &gt;= 70 and linker MW &lt;= 200 Da.</p>
<p>Over 100 epochs in triplicate, the agent generated molecules with gradually improving docking scores. Key results:</p>
<ul>
<li>Docking score distributions across triplicates were nearly identical, demonstrating reproducibility</li>
<li>Some generated molecules achieved more favorable docking scores than the reference ligand CAM4066 (-15.20 kcal/mol)</li>
<li>More than 5000 unique Bemis-Murcko scaffolds were generated, with minimal overlap across replicates</li>
<li>Binding pose analysis showed the generated linker closely resembled the ground-truth linker, retaining the Lys68 interaction</li>
</ul>
<h3 id="experiment-1b-comparison-fragment-linking-impdh-inhibitors">Experiment 1b: Comparison Fragment Linking (IMPDH Inhibitors)</h3>
<p>Using the IMPDH inhibitor fragment linking case study from Trapero et al., this experiment applied core-constrained docking (fragment pose within 0.3 Å of the reference) and compared results to DeLinker and SyntaLinker. The scoring function enforced linker effective length in [3, 5], length ratio &gt;= 70, and linker MW &lt;= 150 Da.</p>
<p>Link-INVENT generated 8960 SMILES across 70 epochs (comparable to DeLinker&rsquo;s 9000 molecular graphs). Results:</p>
<ul>
<li>Link-INVENT generated molecules with more favorable docking scores than the reference ligand across triplicate runs</li>
<li>None of the 20 DeLinker example molecules and only one of the 3 SyntaLinker example molecules (the recovered reference itself) docked as well as or better than the reference</li>
<li>Approximately 3000 unique Bemis-Murcko scaffolds were generated from 5000 total molecules</li>
<li>Link-INVENT&rsquo;s advantage comes from including docking explicitly as a learning objective rather than applying it post hoc</li>
</ul>
<h3 id="experiment-2-scaffold-hopping-dlk-inhibitor-cns-optimization">Experiment 2: Scaffold Hopping (DLK Inhibitor CNS Optimization)</h3>
<p>Based on Patel et al.&rsquo;s <a href="https://en.wikipedia.org/wiki/MAP3K12">dual leucine zipper kinase</a> (DLK) inhibitor campaign, Link-INVENT generated new scaffold ideas to improve CNS penetration while retaining potency. The scoring function included a Cys193 docking constraint plus CNS-compatible properties (HBDs &lt; 2, tPSA &lt;= 90 Å<sup>2</sup>, 3 &lt;= SlogP &lt;= 4, MW &lt;= 450 Da, 1-2 aromatic rings in linker).</p>
<p>The solution space was significantly narrower than fragment linking. The agent still generated diverse scaffolds with favorable docking scores, though fewer exceeded the reference ligand&rsquo;s score. Binding pose analysis confirmed retained Cys193 interactions and predicted additional Gln195 hydrogen bonds.</p>
<h3 id="experiment-3-protac-design-bcl-2mcl-1-dual-degradation">Experiment 3: PROTAC Design (Bcl-2/Mcl-1 Dual Degradation)</h3>
<p>Three sub-experiments demonstrated linker-specific scoring components for PROTAC design based on Wang et al.&rsquo;s Bcl-2/Mcl-1 dual degradation strategy:</p>
<table>
  <thead>
      <tr>
          <th>Sub-Experiment</th>
          <th>Objective</th>
          <th>Key Finding</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Sub-Exp 1: Linker length</td>
          <td>Generate linkers within specified length intervals [4,6], [7,9], [10,12], [13,15]</td>
          <td>Clear enrichment within target intervals vs. baseline broad distribution</td>
      </tr>
      <tr>
          <td>Sub-Exp 2: Linearity</td>
          <td>Control linear vs. cyclic linkers at fixed length [7,9]</td>
          <td>Baseline ratio ~1:2 linear:cyclic; enforcing linearity or cyclicity achieved strong enrichment</td>
      </tr>
      <tr>
          <td>Sub-Exp 3: Flexibility</td>
          <td>Generate linkers with Low [0,30], Moderate [40,60], or High [70,100] rotatable bond ratios</td>
          <td>Agent learned that rings and sp2 atoms yield rigidity; linear sp3 chains yield flexibility</td>
      </tr>
  </tbody>
</table>
<h2 id="key-findings-and-practical-implications-for-drug-discovery">Key Findings and Practical Implications for Drug Discovery</h2>
<p>Link-INVENT demonstrates several practical advantages for molecular linker design:</p>
<ol>
<li><strong>Single prior, multiple tasks</strong>: The same pre-trained model handles fragment linking, scaffold hopping, and PROTAC design without retraining.</li>
<li><strong>Docking as a learning signal</strong>: Including molecular docking explicitly in the scoring function (via DockStream) during RL yields molecules with more favorable docking scores than approaches that apply docking post hoc.</li>
<li><strong>Implicit 3D awareness</strong>: The docking constraint guides the agent toward 3D structural awareness without explicit 3D coordinate inputs, as demonstrated by the overlap between generated and reference binding poses.</li>
<li><strong>Diverse and reproducible output</strong>: Diversity filters ensure exploration of multiple chemical space regions, and triplicate experiments show consistent docking score distributions with minimal scaffold overlap.</li>
</ol>
<p>Limitations acknowledged by the authors include:</p>
<ul>
<li>The linker flexibility metric (ratio of rotatable bonds) is agnostic to intra-molecular hydrogen bonds and does not account for all rigidity factors</li>
<li>Molecular docking is an approximation that can be exploited (e.g., excessive HBDs achieving favorable scores at the expense of permeability)</li>
<li>The docking-based experiments (1a, 1b, and 2) require a proprietary Schrodinger license for Glide/LigPrep docking</li>
<li>No direct experimental (wet-lab) validation was performed in this study</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Prior training</td>
          <td>ChEMBL v27 (reaction-sliced)</td>
          <td>Not specified</td>
          <td>Filtered for drug-like compounds, then reaction-based slicing with SMIRKS</td>
      </tr>
      <tr>
          <td>Validation</td>
          <td>Held-out Bemis-Murcko scaffolds</td>
          <td>287 scaffolds</td>
          <td>Held out from training set</td>
      </tr>
      <tr>
          <td>SMILES augmentation</td>
          <td>Randomized SMILES per epoch</td>
          <td>Same tuples, different representations</td>
          <td>Improves generalizability</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Architecture</strong>: Encoder-decoder RNN with 3 hidden layers of 512 LSTM cells, embedding size 256</li>
<li><strong>RL loss</strong>: DAP (Difference of Augmented and Posterior likelihoods)</li>
<li><strong>Batch size</strong>: 128 molecules per epoch</li>
<li><strong>Diversity filter</strong>: Bemis-Murcko scaffold buckets of size 25</li>
<li><strong>Score threshold</strong>: 0 (to store all molecules for analysis)</li>
<li><strong>Scoring function</strong>: Weighted geometric mean of component scores</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li>Single pre-trained prior used across all experiments</li>
<li>Agent initialized as copy of prior, updated via RL</li>
<li>Pre-trained prior available in the ReinventCommunity GitHub repository (see Artifacts)</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li>Molecular docking via DockStream with Glide/LigPrep backend</li>
<li>Triplicate runs for all experiments</li>
<li>Metrics: docking scores, unique Bemis-Murcko scaffold counts, binding pose overlap</li>
</ul>
<h3 id="hardware">Hardware</h3>
<p>Hardware specifications are not reported in the paper.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/MolecularAI/Reinvent">REINVENT (Link-INVENT code)</a></td>
          <td>Code</td>
          <td>Apache-2.0</td>
          <td>Main codebase for Link-INVENT</td>
      </tr>
      <tr>
          <td><a href="https://github.com/MolecularAI/ReinventCommunity">ReinventCommunity (data + tutorial)</a></td>
          <td>Code + Data</td>
          <td>MIT</td>
          <td>Training/validation data, reaction SMIRKS, pre-trained prior, Jupyter tutorial</td>
      </tr>
  </tbody>
</table>
<p><strong>Reproducibility status</strong>: Partially Reproducible. Code, training data, and pre-trained prior are publicly available. However, reproducing the docking-based experiments (1a, 1b, and 2) requires a proprietary Schrodinger license for Glide and LigPrep. The PROTAC experiments (Experiment 3) that use only physicochemical scoring are fully reproducible with the open-source code.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Guo, J., Knuth, F., Margreitter, C., Janet, J. P., Papadopoulos, K., Engkvist, O., &amp; Patronov, A. (2023). Link-INVENT: generative linker design with reinforcement learning. <em>Digital Discovery</em>, 2, 392-408. <a href="https://doi.org/10.1039/D2DD00115B">https://doi.org/10.1039/D2DD00115B</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{guo2023link,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Link-INVENT: generative linker design with reinforcement learning}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Guo, Jeff and Knuth, Franziska and Margreitter, Christian and Janet, Jon Paul and Papadopoulos, Kostas and Engkvist, Ola and Patronov, Atanas}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Digital Discovery}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{2}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{2}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{392--408}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Royal Society of Chemistry}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1039/D2DD00115B}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Evolutionary Molecular Design via Deep Learning + GA</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/target-aware/evolutionary-design-deep-learning-genetic-algorithm/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/target-aware/evolutionary-design-deep-learning-genetic-algorithm/</guid><description>Kwon et al. combine an RNN decoder for SMILES reconstruction with a genetic algorithm operating on ECFP fingerprints for goal-directed molecular design.</description><content:encoded><![CDATA[<h2 id="fingerprint-based-evolutionary-molecular-design">Fingerprint-Based Evolutionary Molecular Design</h2>
<p>This is a <strong>Method</strong> paper that introduces an evolutionary design methodology (EDM) for goal-directed molecular optimization. The primary contribution is a framework with four parts: (1) molecules are encoded as <a href="https://en.wikipedia.org/wiki/Chemical_similarity">extended-connectivity fingerprint</a> (ECFP) vectors, (2) a genetic algorithm evolves these fingerprint vectors through mutation and crossover, (3) a recurrent neural network (RNN) decodes the evolved fingerprints back into valid SMILES strings, and (4) a deep neural network (DNN) evaluates molecular fitness. The key advantage over prior evolutionary approaches is that no hand-crafted chemical rules or fragment libraries are needed, as the RNN learns valid molecular reconstruction from data.</p>
<h2 id="challenges-in-evolutionary-molecular-optimization">Challenges in Evolutionary Molecular Optimization</h2>
<p>Evolutionary algorithms for molecular design face two core challenges. First, maintaining chemical validity of evolved molecules is difficult when operating on graph or string representations directly. Prior methods rely on predefined chemical rules and fragment libraries to constrain structural modifications (atom/bond additions, deletions, substitutions), but these introduce bias and risk convergence to local optima. Each new application domain requires specifying new chemical rules, which may not exist for emerging areas. Second, fitness evaluation must be both efficient and accurate. Simple evaluation methods like structural similarity indices or semi-empirical quantum chemistry calculations reduce computational cost but may not capture complex property relationships.</p>
<p>High-throughput computational screening (HTCS) is a common alternative, but it depends on the quality of predefined virtual chemical libraries and often requires multiple iterative enumerations, limiting its ability to explore novel chemical space.</p>
<h2 id="core-innovation-evolving-fingerprints-with-neural-decoding">Core Innovation: Evolving Fingerprints with Neural Decoding</h2>
<p>The key insight is to perform genetic operations in fingerprint space rather than in molecular graph or SMILES string space. The framework comprises three functions, two of which are learned:</p>
<p><strong>Encoding function</strong> $e(\cdot)$: Converts a SMILES string $\mathbf{m}$ into a 5000-dimensional ECFP vector $\mathbf{x}$ using Morgan fingerprints with a neighborhood radius of 6. This is a deterministic hash-based encoding (not learned).</p>
<p><strong>Decoding function</strong> $d(\cdot)$: An RNN with three hidden layers of 500 LSTM units that reconstructs a SMILES string from an ECFP vector. The RNN generates SMILES as a sequence of three-character substrings, conditioning each prediction on the current substring and the input ECFP vector:</p>
<p>$$d(\mathbf{x}) = \mathbf{m}, \quad \text{with each substring predicted via } p(\mathbf{m}_{t+1} \mid \mathbf{m}_{t}, \mathbf{x})$$</p>
<p>The three-character substring approach reduces the ratio of invalid SMILES by imposing additional constraints on subsequent characters.</p>
<p><strong>Property prediction function</strong> $f(\cdot)$: A five-layer DNN with 250 hidden units per layer that predicts molecular properties from ECFP vectors:</p>
<p>$$\mathbf{t} = f(e(\mathbf{m}))$$</p>
<p>The RNN is trained by minimizing cross-entropy loss between the softmax output and the target SMILES string $\mathbf{m}_{i}$, learning the relationship $d(e(\mathbf{m}_{i})) = \mathbf{m}_{i}$. The DNN is trained by minimizing mean squared error between predicted and computed property values. Both use the Adam optimizer with mini-batch size 100, 500 training epochs, and dropout rate 0.5.</p>
<h3 id="genetic-algorithm-operations">Genetic Algorithm Operations</h3>
<p>The GA evolves ECFP vectors using the DEAP library with the following parameters:</p>
<ul>
<li><strong>Population size</strong>: 50</li>
<li><strong>Crossover rate</strong>: 0.7 (uniform crossover, mixing ratio 0.2)</li>
<li><strong>Mutation rate</strong>: 0.3 (Gaussian mutation, $N(0, 0.2^{2})$, applied to 1% of elements)</li>
<li><strong>Selection</strong>: Tournament selection with size 3, top 3 individuals as parents</li>
<li><strong>Termination</strong>: 500 generations or 30 consecutive generations without fitness improvement</li>
</ul>
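<p>The operators listed above can be sketched with the Python standard library alone. This is a plain reimplementation for illustration, not the DEAP calls used by the authors; the function names are hypothetical, and reading the mixing ratio as a per-position swap probability is an assumption.</p>

```python
import random

def uniform_crossover(parent_a, parent_b, mix_ratio=0.2, rng=random):
    """Swap each position between the two parents with probability mix_ratio."""
    child_a, child_b = list(parent_a), list(parent_b)
    for i in range(len(child_a)):
        if mix_ratio > rng.random():
            child_a[i], child_b[i] = child_b[i], child_a[i]
    return child_a, child_b

def gaussian_mutation(vector, sigma=0.2, element_rate=0.01, rng=random):
    """Add N(0, sigma^2) noise to roughly element_rate of the elements."""
    return [
        v + rng.gauss(0.0, sigma) if element_rate > rng.random() else v
        for v in vector
    ]

def tournament_select(population, fitnesses, size=3, rng=random):
    """Return the fittest of `size` randomly drawn individuals."""
    contenders = rng.sample(range(len(population)), size)
    best = max(contenders, key=lambda i: fitnesses[i])
    return population[best]

rng = random.Random(0)
a, b = [0.0] * 10, [1.0] * 10
child_a, child_b = uniform_crossover(a, b, rng=rng)
print(child_a)
print(gaussian_mutation(child_a, element_rate=0.5, rng=rng))
```

<p>In the full procedure, crossover and mutation are each applied to an individual with the listed rates (0.7 and 0.3); the functions above only show what happens once an operator fires.</p>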
<p>The evolutionary loop proceeds as follows: a seed molecule $\mathbf{m}_{0}$ is encoded to $\mathbf{x}_{0}$, mutated to generate a population $\mathbf{P}^{0} = \{\mathbf{z}_{1}, \mathbf{z}_{2}, \ldots, \mathbf{z}_{L}\}$, each vector is decoded via the RNN, validity is checked with RDKit, fitness is evaluated via the DNN, and the top parents produce the next generation through crossover and mutation.</p>
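<p>The control flow of this loop, including the two termination criteria, can be sketched as follows. The decoder, validity check, and fitness function are passed in as stand-in callables (hypothetical placeholders for the RNN, RDKit, and the DNN), and crossover is omitted for brevity, so this is a simplified outline and not the authors' implementation.</p>

```python
import random

def evolve(seed_vector, decode, is_valid, fitness, population_size=50,
           max_generations=500, patience=30, rng=None):
    """Evolve fingerprint vectors; stop after `patience` stale generations."""
    rng = rng or random.Random()

    def mutate(vector):
        # Gaussian noise on about 1% of elements, as in the listed settings.
        return [v + rng.gauss(0.0, 0.2) if 0.01 > rng.random() else v
                for v in vector]

    def score(vector):
        molecule = decode(vector)
        # Invalid decodings get the worst possible fitness.
        return fitness(molecule) if is_valid(molecule) else float("-inf")

    population = [mutate(seed_vector) for _ in range(population_size)]
    best_vector, best_score, stale = seed_vector, score(seed_vector), 0
    for _ in range(max_generations):
        ranked = sorted(population, key=score, reverse=True)
        top_score = score(ranked[0])
        if top_score > best_score:
            best_vector, best_score, stale = ranked[0], top_score, 0
        else:
            stale += 1
            if stale >= patience:
                break
        parents = ranked[:3]  # top three individuals act as parents
        population = [mutate(rng.choice(parents))
                      for _ in range(population_size)]
    return best_vector, best_score

# Stand-ins: identity decoder, always-valid check, distance-to-target fitness.
target = [1.0] * 100
best, best_score = evolve(
    seed_vector=[0.0] * 100,
    decode=lambda v: v,
    is_valid=lambda m: True,
    fitness=lambda m: -sum((a - b) ** 2 for a, b in zip(m, target)),
    rng=random.Random(0),
)
print(round(best_score, 3))
```

<p>Swapping the stand-ins for a trained decoder and property predictor leaves the loop unchanged, which is the practical appeal of evolving in fingerprint space.</p>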
<h2 id="experimental-setup-light-absorbing-wavelength-optimization">Experimental Setup: Light-Absorbing Wavelength Optimization</h2>
<h3 id="training-data-and-deep-learning-performance">Training Data and Deep Learning Performance</h3>
<p>The models were trained on 10,000 to 100,000 molecules randomly sampled from PubChem (molecular weight 200-600 g/mol). Each molecule was labeled with DFT-computed excitation energy ($S_{1}$), <a href="https://en.wikipedia.org/wiki/HOMO_and_LUMO">HOMO, and LUMO</a> energies using B3LYP/6-31G.</p>
<table>
  <thead>
      <tr>
          <th>Training Data</th>
          <th>Validity (%)</th>
          <th>Reconstructability (%)</th>
          <th>$S_{1}$ (R, MAE)</th>
          <th>HOMO (R, MAE)</th>
          <th>LUMO (R, MAE)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>100,000</td>
          <td>88.8</td>
          <td>62.4</td>
          <td>0.977, 0.185 eV</td>
          <td>0.948, 0.168 eV</td>
          <td>0.960, 0.195 eV</td>
      </tr>
      <tr>
          <td>50,000</td>
          <td>86.7</td>
          <td>60.1</td>
          <td>0.973, 0.198 eV</td>
          <td>0.945, 0.172 eV</td>
          <td>0.955, 0.209 eV</td>
      </tr>
      <tr>
          <td>30,000</td>
          <td>85.3</td>
          <td>59.8</td>
          <td>0.930, 0.228 eV</td>
          <td>0.934, 0.191 eV</td>
          <td>0.945, 0.224 eV</td>
      </tr>
      <tr>
          <td>10,000</td>
          <td>83.2</td>
          <td>55.7</td>
          <td>0.913, 0.278 eV</td>
          <td>0.885, 0.244 eV</td>
          <td>0.917, 0.287 eV</td>
      </tr>
  </tbody>
</table>
<p>Validity refers to the proportion of chemically valid SMILES after RDKit inspection. Reconstructability measures how often the RNN reproduces the original molecule from its ECFP, judged by a canonical SMILES match among 10,000 generated strings (62.4% at 100k training samples).</p>
<h3 id="design-task-1-unconstrained-s1-modification">Design Task 1: Unconstrained S1 Modification</h3>
<p>Fifty seed molecules with $S_{1}$ values between 3.8 eV and 4.2 eV were evolved in both increasing and decreasing directions. With 50,000 training samples, $S_{1}$ increased by approximately 60% on average in the increasing direction and showed slightly lower rates of change in the decreasing direction. The asymmetry is attributed to the skewed $S_{1}$ distribution of training data (average $S_{1}$ of 4.3-4.4 eV, higher than the seed median of 4.0 eV). Performance saturated at approximately 50,000 training samples.</p>
<h3 id="design-task-2-s1-modification-with-homolumo-constraints">Design Task 2: S1 Modification with HOMO/LUMO Constraints</h3>
<p>The same 50 seeds were evolved with constraints: $-7.0 \text{ eV} &lt; \text{HOMO} &lt; -5.0 \text{ eV}$ and $\text{LUMO} &lt; 0.0 \text{ eV}$. In the increasing $S_{1}$ direction, constraints suppressed the rate of change because both HOMO and LUMO bounds limit the achievable HOMO-LUMO gap. In the decreasing direction, constraints had minimal effect because LUMO could freely decrease while HOMO had sufficient room to rise within the allowed range.</p>
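<p>One simple way to fold these bounds into a fitness function is shown below. The summary does not state how the authors penalize violations, so rejecting out-of-bounds candidates with the worst possible score is an assumption, and the function name and example values are hypothetical.</p>

```python
def constrained_fitness(s1, homo, lumo, increase=True):
    """Signed S1 objective, rejected when HOMO/LUMO fall outside bounds (eV)."""
    homo_ok = homo > -7.0 and -5.0 > homo
    lumo_ok = 0.0 > lumo
    if not (homo_ok and lumo_ok):
        # Assumption: violating candidates are discarded via the worst score.
        return float("-inf")
    # Maximize S1 when increasing, maximize -S1 when decreasing.
    return s1 if increase else -s1

print(constrained_fitness(4.5, homo=-6.1, lumo=-1.2))  # satisfies both bounds
print(constrained_fitness(5.5, homo=-7.4, lumo=-0.8))  # HOMO below the window
```

<p>Raising $S_{1}$ generally requires a wider HOMO-LUMO gap, which pushes candidates against both bounds at once, consistent with the suppressed rate of change reported for the increasing direction.</p>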
<h3 id="design-task-3-extrapolation-beyond-training-data">Design Task 3: Extrapolation Beyond Training Data</h3>
<p>To generate molecules with $S_{1}$ values below 1.77 eV (outside the training distribution, which had mean $S_{1}$ of 4.91 eV), the authors introduced iterative &ldquo;phases&rdquo;: generate molecules, compute their properties via DFT, retrain the models, and repeat. Starting from the 30 lowest-$S_{1}$ seed molecules with 300 generation runs per phase:</p>
<ul>
<li>Phase 1: Average $S_{1}$ = 2.20 eV, 12 molecules below 1.77 eV</li>
<li>Phase 2: Average $S_{1}$ = 2.22 eV, 37 molecules below 1.77 eV</li>
<li>Phase 3: Average $S_{1}$ = 2.31 eV, 58 molecules below 1.77 eV</li>
</ul>
<p>While the average $S_{1}$ rose slightly across phases, variance decreased (from 1.40 to 1.36), indicating the model concentrated its outputs closer to the target range. This active-learning-like loop demonstrates the framework can extend beyond the training distribution.</p>
<h3 id="design-task-4-guacamol-benchmarks">Design Task 4: GuacaMol Benchmarks</h3>
<p>The method was evaluated on the <a href="/notes/chemistry/molecular-design/generation/evaluation/guacamol-benchmarking-de-novo-molecular-design/">GuacaMol</a> goal-directed benchmark suite using the ChEMBL25 training dataset. The RNN model was retrained with three-character substrings.</p>
<table>
  <thead>
      <tr>
          <th>Benchmark</th>
          <th>Best of Dataset</th>
          <th><a href="/notes/chemistry/molecular-design/generation/autoregressive/lstm-drug-like-molecule-generation/">SMILES LSTM</a></th>
          <th>SMILES GA</th>
          <th><a href="/notes/chemistry/molecular-design/generation/search-based/graph-based-genetic-algorithm-chemical-space/">Graph GA</a></th>
          <th><a href="/notes/chemistry/molecular-design/generation/search-based/graph-based-genetic-algorithm-chemical-space/">Graph MCTS</a></th>
          <th>cRNN</th>
          <th>EDM (ours)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Celecoxib rediscovery</td>
          <td>0.505</td>
          <td>1.000</td>
          <td>0.607</td>
          <td>1.000</td>
          <td>0.378</td>
          <td>1.000</td>
          <td>1.000</td>
      </tr>
      <tr>
          <td>Troglitazone rediscovery</td>
          <td>0.419</td>
          <td>1.000</td>
          <td>0.558</td>
          <td>1.000</td>
          <td>0.312</td>
          <td>1.000</td>
          <td>1.000</td>
      </tr>
      <tr>
          <td>Thiothixene rediscovery</td>
          <td>0.456</td>
          <td>1.000</td>
          <td>0.495</td>
          <td>1.000</td>
          <td>0.308</td>
          <td>1.000</td>
          <td>1.000</td>
      </tr>
      <tr>
          <td>LogP(-1.0)</td>
          <td>1.000</td>
          <td>1.000</td>
          <td>1.000</td>
          <td>1.000</td>
          <td>0.980</td>
          <td>1.000</td>
          <td>1.000</td>
      </tr>
      <tr>
          <td>LogP(8.0)</td>
          <td>1.000</td>
          <td>1.000</td>
          <td>1.000</td>
          <td>1.000</td>
          <td>0.979</td>
          <td>1.000</td>
          <td>1.000</td>
      </tr>
      <tr>
          <td>TPSA(150.0)</td>
          <td>1.000</td>
          <td>1.000</td>
          <td>1.000</td>
          <td>1.000</td>
          <td>1.000</td>
          <td>1.000</td>
          <td>1.000</td>
      </tr>
      <tr>
          <td>CNS MPO</td>
          <td>1.000</td>
          <td>1.000</td>
          <td>1.000</td>
          <td>1.000</td>
          <td>1.000</td>
          <td>1.000</td>
          <td>1.000</td>
      </tr>
      <tr>
          <td>QED</td>
          <td>0.948</td>
          <td>0.948</td>
          <td>0.948</td>
          <td>0.948</td>
          <td>0.944</td>
          <td>0.948</td>
          <td>0.948</td>
      </tr>
  </tbody>
</table>
<p>The EDM achieves maximum scores on all eight tasks, matching the cRNN baseline. The 256 highest-scoring molecules from the ChEMBL25 test set were used as seeds, with 500 SMILES strings generated per seed.</p>
<h2 id="key-findings-and-limitations">Key Findings and Limitations</h2>
<h3 id="results">Results</h3>
<p>The evolutionary design framework successfully evolved seed molecules toward target properties across all four design tasks. The RNN decoder maintained 88.8% chemical validity at 100k training samples, and the DNN property predictor achieved correlation coefficients above 0.94 for $S_{1}$, HOMO, and LUMO prediction. The iterative retraining procedure enabled exploration outside the training data distribution, generating 58 molecules with $S_{1}$ below 1.77 eV after three phases. On GuacaMol benchmarks, the method achieved maximum scores on all eight tasks, matching <a href="/notes/chemistry/molecular-design/generation/autoregressive/lstm-drug-like-molecule-generation/">SMILES LSTM</a>, <a href="/notes/chemistry/molecular-design/generation/search-based/graph-based-genetic-algorithm-chemical-space/">Graph GA</a>, and cRNN baselines.</p>
<h3 id="limitations">Limitations</h3>
<p>Several limitations are worth noting:</p>
<ol>
<li><strong>Reconstructability ceiling</strong>: Only 62.4% of molecules could be reconstructed from their ECFP vectors, meaning the RNN decoder fails to recover the original molecule approximately 38% of the time. This information loss in the ECFP encoding is a fundamental bottleneck.</li>
<li><strong>Data dependence</strong>: Performance is sensitive to the training data distribution. The asymmetric evolution rates for increasing vs. decreasing $S_{1}$ directly reflect the skewed training data.</li>
<li><strong>Structural constraints</strong>: Three heuristic constraints (fused ring sizes, number of fused rings, alkyl chain lengths) were still needed to maintain reasonable molecular structures, partially undermining the claim of a fully data-driven approach.</li>
<li><strong>DFT reliance</strong>: The extrapolation experiment requires DFT calculations in the loop, which are computationally expensive and may limit scalability.</li>
<li><strong>Limited benchmark scope</strong>: Only 8 GuacaMol tasks were tested, and all achieved perfect scores, making it difficult to differentiate from competing methods. The paper does not report on harder multi-objective benchmarks.</li>
</ol>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training/Evaluation</td>
          <td>PubChem random sample</td>
          <td>10,000-100,000 molecules</td>
          <td>MW 200-600 g/mol, labeled with DFT-computed $S_{1}$, HOMO, LUMO</td>
      </tr>
      <tr>
          <td>GuacaMol Benchmark</td>
          <td>ChEMBL25</td>
          <td>Standard split</td>
          <td>Used for retraining RNN; 256 top-scoring seeds</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Genetic algorithm</strong>: DEAP library; population 50, crossover rate 0.7, mutation rate 0.3, tournament size 3</li>
<li><strong>RNN decoder</strong>: 3 hidden layers, 500 LSTM units each, three-character substring generation</li>
<li><strong>DNN predictor</strong>: 5 layers, 250 hidden units, sigmoid activations, linear output</li>
<li><strong>Training</strong>: Adam optimizer, mini-batch 100, 500 epochs, dropout 0.5</li>
</ul>
<h3 id="models">Models</h3>
<p>All neural networks were implemented using Keras with the Theano backend (GPU-accelerated). No pre-trained model weights are publicly available.</p>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>RNN validity</strong>: Proportion of chemically valid SMILES (RDKit check)</li>
<li><strong>Reconstructability</strong>: Fraction of seed molecules recoverable from ECFP (canonical SMILES match in 10,000 generated strings)</li>
<li><strong>DNN accuracy</strong>: Correlation coefficient (R) and MAE via 10-fold cross-validation</li>
<li><strong>Evolutionary performance</strong>: Average rate of $S_{1}$ change across 50 seeds; molecule count in target range</li>
<li><strong>GuacaMol</strong>: Standard rediscovery and property satisfaction benchmarks</li>
</ul>
<h3 id="hardware">Hardware</h3>
<p>The paper does not specify GPU models, training times, or computational requirements for the evolutionary runs. DFT calculations used the Gaussian 09 program suite with B3LYP/6-31G.</p>
<h3 id="artifacts">Artifacts</h3>
<p>No public code repository or pre-trained models are available. The paper is published under a CC-BY 4.0 license as open access in Scientific Reports.</p>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://www.nature.com/articles/s41598-021-96812-8">Paper (Scientific Reports)</a></td>
          <td>Paper</td>
          <td>CC-BY 4.0</td>
          <td>Open access</td>
      </tr>
  </tbody>
</table>
<p><strong>Reproducibility classification</strong>: Partially Reproducible. The method is described in sufficient detail for reimplementation, but no code, trained models, or preprocessed datasets are released. The DFT calculations require Gaussian 09, a commercial software package.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Kwon, Y., Kang, S., Choi, Y.-S., &amp; Kim, I. (2021). Evolutionary design of molecules based on deep learning and a genetic algorithm. <em>Scientific Reports</em>, 11, 17304. <a href="https://doi.org/10.1038/s41598-021-96812-8">https://doi.org/10.1038/s41598-021-96812-8</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{kwon2021evolutionary,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Evolutionary design of molecules based on deep learning and a genetic algorithm}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Kwon, Youngchun and Kang, Seokho and Choi, Youn-Suk and Kim, Inkoo}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Scientific Reports}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{11}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{17304}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2021}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Nature Publishing Group}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1038/s41598-021-96812-8}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Avoiding Failure Modes in Goal-Directed Generation</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/evaluation/avoiding-failure-modes-goal-directed-generation/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/evaluation/avoiding-failure-modes-goal-directed-generation/</guid><description>Langevin et al. show that apparent failure modes in goal-directed molecular generation stem from QSAR model disagreement, not algorithmic flaws.</description><content:encoded><![CDATA[<h2 id="reinterpreting-goal-directed-generation-failures-as-qsar-model-issues">Reinterpreting Goal-Directed Generation Failures as QSAR Model Issues</h2>
<p>This is an <strong>Empirical</strong> study that challenges a widely cited finding about failure modes in goal-directed molecular generation. <a href="/notes/chemistry/molecular-design/generation/evaluation/failure-modes-molecule-generation/">Renz et al. (2019)</a> had shown that when molecules are optimized against a machine learning scoring function, control models trained on the same data distribution assign much lower scores to the generated molecules. This was interpreted as evidence that generation algorithms exploit model-specific biases. Langevin et al. demonstrate that this divergence is already present in the original data distribution and is attributable to disagreement among the QSAR classifiers, not to flaws in the generation algorithms themselves.</p>
<h2 id="why-qsar-model-agreement-matters-for-molecular-generation">Why QSAR Model Agreement Matters for Molecular Generation</h2>
<p>Goal-directed generation uses a scoring function (typically a <a href="https://en.wikipedia.org/wiki/Quantitative_structure%E2%80%93activity_relationship">QSAR</a> model) to guide the design of molecules that maximize predicted activity. In the experimental framework from Renz et al., three Random Forest classifiers are trained: an optimization model $C_{opt}$ on Split 1, a model control $C_{mc}$ on Split 1 with a different random seed, and a data control $C_{dc}$ on Split 2. Each returns a confidence score ($S_{opt}$, $S_{mc}$, $S_{dc}$). The expectation is that molecules with high $S_{opt}$ should also score highly under $S_{mc}$ and $S_{dc}$, since all three models are trained on the same data distribution for the same target.</p>
<p>Renz et al. observed that during optimization, $S_{mc}$ and $S_{dc}$ diverge from $S_{opt}$, reaching substantially lower values. This was interpreted as goal-directed generation exploiting biases unique to the optimization model. The recommendation was to halt generation when control scores stop increasing, requiring a held-out dataset for a control model, which may not be feasible in low-data regimes.</p>
<p>The key insight of Langevin et al. is that nobody had checked whether this score disagreement existed before generation even began. If the classifiers already disagree on high-scoring molecules in the original dataset, the divergence during generation is expected behavior, not evidence of algorithmic failure.</p>
<h2 id="pre-existing-classifier-disagreement-explains-the-divergence">Pre-Existing Classifier Disagreement Explains the Divergence</h2>
<p>The core contribution is showing that the gap between optimization and control scores is a property of the QSAR models, not of the generation algorithms.</p>
<p>The authors introduce a held-out test set (10% of the data, used for neither training split) and augment it via Topliss tree enumeration to produce structural analogs for smoother statistical estimates. On this held-out set, they compute the Mean Average Difference (MAD) between $S_{opt}$ and control scores as a function of $S_{opt}$:</p>
<p>$$
\text{MAD}(x) = \frac{1}{|\{i : S_{opt}(x_i) \geq x\}|} \sum_{S_{opt}(x_i) \geq x} |S_{opt}(x_i) - S_{dc}(x_i)|
$$</p>
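<p>The thresholded average above can be sketched in a few lines of plain Python. The score lists are hypothetical values chosen for illustration, and the function name <code>mad_above_threshold</code> is not taken from the paper or its code.</p>

```python
from statistics import mean


def mad_above_threshold(s_opt, s_control, x):
    """Mean absolute gap between optimization and control scores for
    molecules whose optimization score is at least x."""
    gaps = [abs(o - c) for o, c in zip(s_opt, s_control) if o >= x]
    if not gaps:
        return None  # no molecule reaches the threshold
    return mean(gaps)


# Hypothetical scores for five held-out molecules (illustrative only).
s_opt = [0.10, 0.30, 0.50, 0.70, 0.90]
s_dc = [0.10, 0.20, 0.30, 0.40, 0.50]

for x in (0.0, 0.5, 0.9):
    print(x, mad_above_threshold(s_opt, s_dc, x))
```

<p>In this toy data the gap widens as the threshold rises, which is the qualitative pattern the note describes for the original datasets.</p>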
<p>On the three original datasets (DRD2, EGFR, JAK2), the MAD between $S_{opt}$ and $S_{dc}$ grows substantially with $S_{opt}$, reaching approximately 0.3 for the highest-scoring molecules. For EGFR, even the top molecules (with $S_{opt}$ between 0.5 and 0.6) have $S_{dc}$ below 0.2. This disagreement exists entirely within the original data distribution, before any generative algorithm is applied.</p>
<p>The authors formalize this with tolerance intervals. At each generation time step $t$, the distribution of optimization scores is $P_t[S_{opt}(x)]$. From the held-out set, the conditional distributions $P[S_{dc}(x) | S_{opt}(x)]$ and $P[S_{mc}(x) | S_{opt}(x)]$ are estimated empirically. The expected control scores at time $t$ are then:</p>
<p>$$
\mathbb{E}[S_{dc}] = \int P[S_{dc}(x) | S_{opt}(x)] \cdot P_t[S_{opt}(x)] \, dS_{opt}
$$</p>
<p>By sampling from these distributions, the authors construct 95% tolerance intervals for the expected control scores at each time step. The observed trajectories of $S_{mc}$ and $S_{dc}$ during generation fall within these intervals, demonstrating that the divergence is fully explained by pre-existing classifier disagreement.</p>
<h2 id="experimental-setup-original-reproduction-and-corrected-experiments">Experimental Setup: Original Reproduction and Corrected Experiments</h2>
<h3 id="reproduction-of-renz-et-al">Reproduction of Renz et al.</h3>
<p>The original experimental framework uses three datasets from ChEMBL: <a href="https://en.wikipedia.org/wiki/Dopamine_receptor_D2">DRD2</a> (842 molecules, 59 actives), <a href="https://en.wikipedia.org/wiki/Epidermal_growth_factor_receptor">EGFR</a> (842 molecules, 40 actives), and <a href="https://en.wikipedia.org/wiki/Janus_kinase_2">JAK2</a> (667 molecules, 140 actives). These are small, noisy, and chemically diverse. Three goal-directed generation algorithms are tested:</p>
<table>
  <thead>
      <tr>
          <th>Algorithm</th>
          <th>Type</th>
          <th>Mechanism</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Graph GA</td>
          <td>Genetic algorithm on molecular graphs</td>
          <td>Mutation and crossover of molecular graphs</td>
      </tr>
      <tr>
          <td>SMILES-LSTM</td>
          <td>Recurrent neural network</td>
          <td>Hill-climbing fine-tuning on best molecules</td>
      </tr>
      <tr>
          <td>MSO</td>
          <td>Particle swarm in CDDD latent space</td>
          <td>Molecular Swarm Optimization</td>
      </tr>
  </tbody>
</table>
<p>All algorithms are run for 151 epochs with 10 runs each. The reproduction confirms the findings of Renz et al.: $S_{mc}$ and $S_{dc}$ diverge from $S_{opt}$ during optimization.</p>
<h3 id="tolerance-interval-analysis">Tolerance interval analysis</h3>
<p>The held-out set is augmented using Topliss tree enumeration on phenyl rings, providing structural analogs that are reasonable from a medicinal chemistry perspective. The optimization score range is divided into 25 equal bins, and for each molecule at each time step, 10 samples from the conditional control score distribution are drawn to construct empirical tolerance intervals.</p>
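<p>One plausible reading of this binning-and-sampling procedure is sketched below in plain Python. The 25 bins and 10 samples per molecule come from the description above. The percentile rule, the handling of empty bins, and all function names are assumptions made for illustration, not details confirmed by the paper.</p>

```python
import random


def bin_index(score, n_bins=25):
    """Map a score in [0, 1] to one of n_bins equal-width bins."""
    return min(int(score * n_bins), n_bins - 1)


def control_scores_by_bin(held_out_opt, held_out_ctrl, n_bins=25):
    """Group held-out control scores by the bin of their optimization score."""
    bins = {}
    for s_opt, s_ctrl in zip(held_out_opt, held_out_ctrl):
        bins.setdefault(bin_index(s_opt, n_bins), []).append(s_ctrl)
    return bins


def tolerance_interval(generated_opt, bins, n_samples=10, n_bins=25,
                       coverage=0.95, rng=None):
    """Empirical interval for control scores of one generated population."""
    rng = rng or random.Random(0)
    draws = []
    for s_opt in generated_opt:
        pool = bins.get(bin_index(s_opt, n_bins))
        if not pool:
            continue  # assumption: skip bins with no held-out molecules
        draws.extend(rng.choice(pool) for _ in range(n_samples))
    if not draws:
        return None
    draws.sort()
    tail = (1.0 - coverage) / 2.0
    low = draws[int(tail * (len(draws) - 1))]
    high = draws[int((1.0 - tail) * (len(draws) - 1))]
    return low, high


# Hypothetical held-out scores and one generated population.
bins = control_scores_by_bin([0.53, 0.54, 0.95], [0.2, 0.3, 0.6])
print(tolerance_interval([0.55, 0.55, 0.93], bins))
```

<p>Running this once per generation time step would give a band against which the observed control-score trajectory can be compared.</p>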
<h3 id="corrected-experiments-with-adequate-models">Corrected experiments with adequate models</h3>
<p>To test whether generation algorithms actually exploit biases when the classifiers agree, the authors construct two tasks where optimization and control models correlate well:</p>
<ol>
<li><strong>ALDH1 dataset</strong>: 464 molecules from LIT-PCBA, split using similarity-based pairing to maximize intra-pair chemical similarity. This ensures both splits sample similar chemistry.</li>
<li><strong>Modified JAK2</strong>: The same JAK2 dataset but with Random Forest hyperparameters adjusted (200 trees instead of 100, minimum 3 samples per leaf instead of 1) to reduce overfitting to spurious correlations.</li>
</ol>
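<p>The similarity-based pairing behind the ALDH1 split can be illustrated with a small sketch. The greedy pairing strategy, the set-of-on-bits fingerprint representation, and the function names are assumptions for illustration. The note does not spell out the exact procedure used by the authors.</p>

```python
from itertools import combinations


def tanimoto(a, b):
    """Tanimoto similarity between two fingerprints given as sets of on-bits."""
    union = len(a.union(b))
    return len(a.intersection(b)) / union if union else 0.0


def paired_split(fingerprints):
    """Greedily pair the most similar molecules, then send one member of each
    pair to each split. Returns two lists of molecule indices."""
    pairs = sorted(
        combinations(range(len(fingerprints)), 2),
        key=lambda ij: tanimoto(fingerprints[ij[0]], fingerprints[ij[1]]),
        reverse=True,
    )
    used, split_1, split_2 = set(), [], []
    for i, j in pairs:
        if i in used or j in used:
            continue
        used.update((i, j))
        split_1.append(i)
        split_2.append(j)
    # An unpaired leftover (odd count) goes to split 1 by assumption.
    split_1.extend(k for k in range(len(fingerprints)) if k not in used)
    return split_1, split_2


# Four hypothetical fingerprints forming two obvious near-duplicate pairs.
fps = [{1, 2, 3}, {1, 2, 4}, {7, 8, 9}, {7, 8, 10}]
print(paired_split(fps))
```

<p>Because each pair is divided across the two splits, both splits end up covering similar regions of chemical space, which is the property the corrected task relies on.</p>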
<p>In both cases, $S_{opt}$, $S_{mc}$, and $S_{dc}$ agree well on the held-out test set. The starting population for generation is set to the held-out test set (rather than random ChEMBL molecules) to avoid building in a distribution shift.</p>
<h2 id="findings-no-algorithmic-failure-when-models-agree">Findings: No Algorithmic Failure When Models Agree</h2>
<p>On the corrected experimental setups (ALDH1 and modified JAK2), there is no major divergence between optimization and control scores during generation. The three algorithms produce molecules that score similarly under all three classifiers.</p>
<p>Key findings:</p>
<ol>
<li>
<p><strong>Pre-existing disagreement explains divergence</strong>: On all three original datasets, the divergence between $S_{opt}$ and control scores during generation falls within the tolerance intervals predicted from the initial data distribution alone. The generation algorithms are not exploiting model-specific biases beyond what already exists in the data.</p>
</li>
<li>
<p><strong>Split similarity bias is also pre-existing</strong>: Renz et al. observed that generated molecules are more similar to Split 1 (used to train $C_{opt}$) than to Split 2. The authors show this bias is already present in the top 5% of the held-out set: on EGFR and DRD2, high-scoring molecules are inherently more similar to Split 1.</p>
</li>
<li>
<p><strong>Appropriate model design resolves the issue</strong>: When Random Forest hyperparameters are chosen to avoid overfitting (more trees, higher minimum samples per leaf), or when data splits are constructed to be chemically balanced, the classifiers agree and the generation algorithms behave as expected.</p>
</li>
<li>
<p><strong>Quality problems remain independent</strong>: Even when optimization and control scores align, the generated molecules can still be poor drug candidates (reactive, unsynthesizable, containing unusual fragments). The score divergence issue and the chemical quality issue are separate problems.</p>
</li>
</ol>
<h3 id="limitations-acknowledged-by-the-authors">Limitations acknowledged by the authors</h3>
<ul>
<li>The study focuses on Random Forest classifiers with ECFP fingerprints. The behavior of other model types (e.g., graph neural networks) and descriptor types is not fully explored, though supplementary results show similar patterns with physico-chemical descriptors and Atom-Pair fingerprints.</li>
<li>The corrected ALDH1 task uses a relatively small dataset (464 molecules) with careful split construction. Scaling this approach to larger, more heterogeneous datasets is not demonstrated.</li>
<li>The authors note that their results do not prove generation algorithms never exploit biases; they show that the specific evidence from Renz et al. can be explained without invoking algorithmic failure.</li>
<li>The problem of low-quality generated molecules (poor synthesizability, unusual fragments) remains unresolved and is acknowledged as an open question.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Original tasks</td>
          <td>DRD2, EGFR, JAK2</td>
          <td>842, 842, 667 molecules</td>
          <td>Extracted from ChEMBL; small with few actives</td>
      </tr>
      <tr>
          <td>New task</td>
          <td>ALDH1</td>
          <td>464 molecules (173 with purine substructure)</td>
          <td>Extracted from LIT-PCBA; similarity-based split</td>
      </tr>
      <tr>
          <td>Augmentation</td>
          <td>Topliss tree analogs</td>
          <td>~10x augmentation of held-out set</td>
          <td>Structural analogs via phenyl ring enumeration</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<p>Three goal-directed generation algorithms from the original Renz et al. study:</p>
<ul>
<li><strong>Graph GA</strong>: Genetic algorithm on molecular graphs (Jensen, 2019)</li>
<li><strong>SMILES-LSTM</strong>: Hill-climbing on LSTM-generated SMILES (Segler et al., 2018)</li>
<li><strong>MSO</strong>: Molecular Swarm Optimization in CDDD latent space (Winter et al., 2019)</li>
</ul>
<p>All run for 151 epochs, 10 runs each.</p>
<h3 id="models">Models</h3>
<p>Random Forest classifiers (scikit-learn) with:</p>
<ul>
<li>ECFP fingerprints (radius 2, 1024 bits, RDKit)</li>
<li>Default parameters for original tasks</li>
<li>Modified parameters for JAK2 correction: 200 trees, min 3 samples per leaf</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Purpose</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Mean Average Difference (MAD)</td>
          <td>Measures disagreement between optimization and control scores</td>
          <td>Computed as function of $S_{opt}$ on held-out set</td>
      </tr>
      <tr>
          <td>95% tolerance intervals</td>
          <td>Expected range of control scores given optimization scores</td>
          <td>Empirical, constructed from held-out set</td>
      </tr>
      <tr>
          <td>Tanimoto similarity</td>
          <td>Split bias assessment</td>
          <td>Morgan fingerprints, radius 2, 1024 bits</td>
      </tr>
      <tr>
          <td>ROC-AUC</td>
          <td>Classifier predictive performance</td>
          <td>Used to verify models have comparable accuracy</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Not specified in the paper.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/Sanofi-Public/IDD-papers-avoiding_failure_modes">Code and datasets</a></td>
          <td>Code</td>
          <td>Apache-2.0</td>
          <td>Fork of Renz et al. codebase with modifications</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Langevin, M., Vuilleumier, R., &amp; Bianciotto, M. (2022). Explaining and avoiding failure modes in goal-directed generation of small molecules. <em>Journal of Cheminformatics</em>, 14, 20. <a href="https://doi.org/10.1186/s13321-022-00601-y">https://doi.org/10.1186/s13321-022-00601-y</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{langevin2022explaining,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Explaining and avoiding failure modes in goal-directed generation of small molecules}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Langevin, Maxime and Vuilleumier, Rodolphe and Bianciotto, Marc}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Journal of Cheminformatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{14}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{20}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2022}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Springer}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1186/s13321-022-00601-y}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Augmented Hill-Climb for RL-Based Molecule Design</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/rl-tuned/augmented-hill-climb-rl-molecule-generation/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/rl-tuned/augmented-hill-climb-rl-molecule-generation/</guid><description>Augmented Hill-Climb combines REINVENT and Hill-Climb RL strategies to improve sample efficiency ~45-fold for SMILES-based de novo molecule generation.</description><content:encoded><![CDATA[<h2 id="a-hybrid-rl-strategy-for-de-novo-molecule-generation">A Hybrid RL Strategy for De Novo Molecule Generation</h2>
<p>This is a <strong>Method</strong> paper that proposes Augmented Hill-Climb (AHC), a reinforcement learning strategy for conditioning SMILES-based language models during de novo molecule generation. The primary contribution is a simple hybrid between the <a href="/notes/chemistry/molecular-design/generation/rl-tuned/reinvent-deep-rl-molecular-design/">REINVENT</a> and Hill-Climb (HC) RL strategies that computes the REINVENT loss function only on the top-k highest-scoring molecules per batch (as in HC), thereby removing the counterproductive regularization effect of low-scoring molecules. The authors demonstrate that AHC improves optimization ability ~1.5-fold and sample efficiency ~45-fold compared to REINVENT across docking tasks against four <a href="https://en.wikipedia.org/wiki/G_protein-coupled_receptor">GPCR</a> targets, and that the approach generalizes to transformer architectures.</p>
<h2 id="sample-efficiency-bottleneck-in-rl-guided-molecular-generation">Sample Efficiency Bottleneck in RL-Guided Molecular Generation</h2>
<p>Recurrent neural networks trained on SMILES have become a standard approach for de novo molecule generation, with RL strategies like REINVENT and Hill-Climb achieving top performance on benchmarks such as <a href="/notes/chemistry/molecular-design/generation/evaluation/guacamol-benchmarking-de-novo-molecular-design/">GuacaMol</a> and <a href="/notes/chemistry/molecular-design/generation/evaluation/molecular-sets-moses/">MOSES</a>. However, RL-guided generation can be highly <a href="/notes/chemistry/molecular-design/generation/evaluation/sample-efficiency-de-novo-generation/">sample-inefficient</a>, often requiring $10^5$ or more molecules to optimize complex objectives. This is acceptable for cheap scoring functions (e.g., QSAR models, property calculators) but becomes a practical bottleneck when using computationally expensive scoring functions like molecular docking or computer-aided synthesis planning.</p>
<p>The REINVENT strategy regularizes the agent by computing a loss based on the difference between the agent&rsquo;s policy and an &ldquo;augmented likelihood&rdquo; that combines the prior policy with a scaled reward. When low-scoring molecules are sampled ($R_T \approx 0$), the augmented likelihood reduces to the prior likelihood, pulling the agent back toward the prior policy. This undoes useful learning, especially early in training or when the objective is difficult. Meanwhile, Hill-Climb simply fine-tunes the RNN on the top-k molecules per batch, which is sample-efficient but lacks explicit regularization, leading to mode collapse and generation of invalid SMILES.</p>
<p>Previous work by Neil et al. compared RL strategies but did not clearly quantify sample-efficiency differences, and modifications to the REINVENT loss function by Fialkova et al. showed no significant improvement. The best agent reminder (BAR) mechanism offered modest gains but was originally tested on graph-based models.</p>
<h2 id="core-innovation-filtering-low-scoring-molecules-from-the-reinvent-loss">Core Innovation: Filtering Low-Scoring Molecules from the REINVENT Loss</h2>
<p>Augmented Hill-Climb combines the loss formulation of REINVENT with the top-k selection mechanism of Hill-Climb. The agent samples a batch of molecules, ranks them by reward, and computes the REINVENT loss only on the top-k molecules. This removes the counterproductive regularization caused by low-scoring molecules while retaining the prior-based regularization for high-scoring molecules.</p>
<p>The REINVENT loss defines an augmented likelihood:</p>
<p>$$
\log P_{\mathbb{U}}(A) = \log P_{prior}(A) + \sigma R_T
$$</p>
<p>where $\sigma$ is a scaling coefficient controlling the reward contribution. The agent loss is the squared difference between the augmented likelihood and the agent&rsquo;s log-likelihood:</p>
<p>$$
L(\theta) = \left[\log P_{\mathbb{U}}(A) - \log P_{agent}(A)\right]^2
$$</p>
<p>In standard REINVENT, this loss is computed over all molecules in the batch. When $R_T \approx 0$, the augmented likelihood collapses to the prior likelihood, pushing the agent back toward the prior. AHC avoids this by computing the loss only on the top-k molecules ranked by reward, exactly as Hill-Climb selects molecules for fine-tuning.</p>
<p>The key insight is that high-scoring molecules are still regularized by the prior component of the augmented likelihood ($\log P_{prior}(A)$), preventing catastrophic forgetting. Low-scoring molecules, which would otherwise pull the agent back toward the prior, are simply excluded from the loss computation.</p>
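<p>A minimal numeric sketch of the top-k loss follows, using plain Python floats in place of model outputs. The log-likelihoods and rewards are hypothetical, averaging the per-molecule losses is an assumption, and a real implementation would compute these quantities on differentiable tensors.</p>

```python
def reinvent_loss(prior_ll, agent_ll, reward, sigma):
    """Squared gap between augmented and agent log-likelihoods for one molecule."""
    augmented_ll = prior_ll + sigma * reward
    return (augmented_ll - agent_ll) ** 2


def ahc_loss(prior_lls, agent_lls, rewards, sigma, k):
    """Mean REINVENT loss over the k highest-reward molecules in a batch."""
    ranked = sorted(range(len(rewards)), key=lambda i: rewards[i], reverse=True)
    top_k = ranked[:k]
    losses = [reinvent_loss(prior_lls[i], agent_lls[i], rewards[i], sigma)
              for i in top_k]
    return sum(losses) / len(losses)


# Hypothetical batch of four sampled molecules.
prior_lls = [-20.0, -22.0, -25.0, -21.0]
agent_lls = [-18.0, -23.0, -24.0, -21.5]
rewards = [0.9, 0.0, 0.7, 0.1]

# With zero reward the loss only measures distance to the prior.
print(reinvent_loss(prior_lls[1], agent_lls[1], rewards[1], sigma=60))
# AHC with k=2 keeps only the molecules with rewards 0.9 and 0.7.
print(ahc_loss(prior_lls, agent_lls, rewards, sigma=60, k=2))
```

<p>Setting <code>k</code> to the batch size recovers the standard REINVENT average over all molecules, so the two strategies differ only in which molecules contribute to the loss.</p>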
<h3 id="diversity-filters-to-prevent-mode-collapse">Diversity Filters to Prevent Mode Collapse</h3>
<p>AHC is more susceptible to mode collapse than REINVENT because it focuses learning on high-scoring molecules. The authors address this with diversity filters (DFs) that penalize the reward of molecules similar to previously generated ones. Through a hyperparameter search over 825 configurations on three GuacaMol tasks, they identify an optimal DF configuration (DF2) with:</p>
<ul>
<li>Minimum score threshold of 0.5 (lower than DF1&rsquo;s 0.8)</li>
<li>Linear penalization output mode (softer than binary)</li>
<li>Bin size of 50 (larger than DF1&rsquo;s 25)</li>
<li>Scaffold similarity based on ECFP4 fingerprints</li>
</ul>
<p>The authors find that stricter DFs (lower thresholds, smaller bins) better prevent mode collapse but reduce optimization performance, while more lenient DFs enable better learning of chemotype-reward associations. DF2 represents a compromise.</p>
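<p>To make the threshold and bin-size settings concrete, here is a toy scaffold memory with a linear penalty. The exact penalty formula, the use of exact scaffold keys instead of ECFP4 scaffold similarity, and the class name are all assumptions for illustration and may differ from the filter used in the paper.</p>

```python
class LinearDiversityFilter:
    """Toy scaffold memory with a linear reward penalty (assumed form)."""

    def __init__(self, min_score=0.5, bin_size=50):
        self.min_score = min_score
        self.bin_size = bin_size
        self.counts = {}  # scaffold key -> number of remembered molecules

    def __call__(self, scaffold, score):
        if score >= self.min_score:
            seen = self.counts.get(scaffold, 0)
            self.counts[scaffold] = seen + 1
            # Assumption: penalty grows linearly as the bin fills up.
            return score * max(0.0, 1.0 - seen / self.bin_size)
        return score  # low scorers are neither stored nor penalized


# A small bin makes the decay visible; the scaffold string is hypothetical.
df = LinearDiversityFilter(min_score=0.5, bin_size=4)
for _ in range(5):
    print(df("c1ccccc1", 0.8))
```

<p>A larger bin size slows the decay, which is consistent with the note&rsquo;s description of DF2 as softer than a binary cutoff.</p>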
<h2 id="experimental-setup-docking-tasks-and-benchmark-comparisons">Experimental Setup: Docking Tasks and Benchmark Comparisons</h2>
<p>The evaluation spans five experiments:</p>
<p><strong>Experiment 1</strong>: AHC vs. REINVENT on DRD2 docking over 100 RL updates (6,400 samples), varying $\sigma$ from 30 to 240. RNN trained on the MOSESn dataset (MOSES with neutralized charges, 2.45M molecules).</p>
<p><strong>Experiment 2</strong>: AHC + DF2 vs. REINVENT on four GPCR targets (DRD2, OPRM1, AGTR1, OX1R) over 500 RL updates. Docking performed with Glide-SP after ligand preparation with LigPrep.</p>
<p><strong>Experiment 3</strong>: Diversity filter hyperparameter search (825 configurations) on three GuacaMol tasks (<a href="https://en.wikipedia.org/wiki/Aripiprazole">Aripiprazole</a> similarity, C11H24 isomers, <a href="https://en.wikipedia.org/wiki/Osimertinib">Osimertinib</a> MPO) using the GuacaMol training set (1.27M molecules from ChEMBL24).</p>
<p><strong>Experiment 4</strong>: Benchmark of AHC against REINFORCE, REINVENT (v1 and v2), BAR, and Hill-Climb (with and without KL regularization) on six tasks of varying difficulty:</p>
<table>
  <thead>
      <tr>
          <th>Task</th>
          <th>Difficulty</th>
          <th>Objective</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Heavy atoms</td>
          <td>Easy</td>
          <td>Maximize number of heavy atoms</td>
      </tr>
      <tr>
          <td><a href="https://en.wikipedia.org/wiki/Risperidone">Risperidone</a> similarity</td>
          <td>Easy</td>
          <td>Maximize Tanimoto similarity to Risperidone</td>
      </tr>
      <tr>
          <td>DRD2 activity</td>
          <td>Medium</td>
          <td>Maximize QSAR-predicted DRD2 activity</td>
      </tr>
      <tr>
          <td>DRD2 docking</td>
          <td>Medium</td>
          <td>Minimize Glide-SP docking score</td>
      </tr>
      <tr>
          <td>DRD2-DRD3 dual</td>
          <td>Hard</td>
          <td>Maximize predicted activity against both targets</td>
      </tr>
      <tr>
          <td>DRD2/DRD3 selective</td>
          <td>Hard</td>
          <td>Maximize selective DRD2 activity over DRD3</td>
      </tr>
  </tbody>
</table>
<p><strong>Experiment 5</strong>: AHC vs. REINVENT on transformer (Tr) and gated transformer (GTr) architectures on the same six benchmark tasks. The GTr implements a GRU-style gate in place of residual connections to stabilize RL training.</p>
<h3 id="rnn-and-transformer-architectures">RNN and Transformer Architectures</h3>
<p>Three RNN configurations were used: (1) embedding 128 + 3 GRU layers of 512 (REINVENT v1), (2) embedding 256 + 3 LSTM layers of 512 (REINVENT 2.0), (3) 3 LSTM layers of 512 with dropout 0.2 (GuacaMol). Transformers used 4 encoder layers with hidden dimension 512, 8 attention heads, and feed-forward dimension 1024.</p>
<p>QSAR models for DRD2 and DRD3 activity were random forest classifiers trained on ExCAPE-DB data with GHOST threshold identification for handling class imbalance.</p>
<h2 id="key-findings-45-fold-sample-efficiency-improvement">Key Findings: 45-Fold Sample Efficiency Improvement</h2>
<h3 id="experiment-1-ahc-consistently-outperforms-reinvent">Experiment 1: AHC Consistently Outperforms REINVENT</h3>
<p>AHC improved optimization ability by 1.39-fold over REINVENT averaged across all $\sigma$ values, with maximum optimization of 205% at $\sigma = 240$ (compared to 128% for REINVENT). AHC required ~80 fewer RL steps to match REINVENT&rsquo;s mean docking score at 100 steps. With DF1 applied, the improvement was 1.45-fold.</p>
<p>AHC showed greater sensitivity to $\sigma$, giving practitioners more control over the regularization-optimization trade-off. At $\sigma = 60$ (heavily regularized), AHC still improved 1.47-fold over REINVENT while maintaining property space defined by the MOSESn training set. At higher $\sigma$ values, AHC extrapolated further outside the training distribution, which can be favorable (novel chemical space) or unfavorable (scoring function exploitation, e.g., larger molecules getting better docking scores due to the additive nature of scoring functions).</p>
<h3 id="experiment-2-improvement-across-four-gpcr-targets">Experiment 2: Improvement Across Four GPCR Targets</h3>
<p>Across DRD2, OPRM1, AGTR1, and OX1R, AHC + DF2 required on average 7.4-fold fewer training steps and 45.5-fold fewer samples to reach optimization thresholds. The improvement was largest early in training: 19.8-fold fewer steps to reach 120% optimization, and 71.8-fold fewer samples to first produce a molecule exceeding 160% optimization.</p>
<p>AHC + DF2 surpassed the 80% retrospective precision threshold within 100 RL updates for all targets except the challenging OX1R. DF2 successfully stabilized learning, avoiding the convergence-to-threshold failure mode observed with DF1.</p>
<p>Scaffold analysis showed AHC generates similar chemistry to REINVENT. The top 500 scaffolds produced by REINVENT were also generated by AHC, but typically much sooner.</p>
<h3 id="experiment-4-benchmark-against-all-rl-strategies">Experiment 4: Benchmark Against All RL Strategies</h3>
<p>AHC outperformed all other RL strategies on all six benchmark tasks except maximizing heavy atoms (an extrapolation task of limited practical relevance). AHC was particularly superior during early-stage optimization and for harder objectives (dual activity, selective activity).</p>
<p>Hill-Climb with a smaller batch size (HC*) showed improved early-stage sample efficiency similar to AHC, but rapidly underwent mode collapse. KL regularization did not rescue mode collapse in any case and sometimes worsened performance. BAR performed poorly in most tasks, possibly because the best-agent memory acts as a second regularizer that inhibits learning.</p>
<p>In terms of compute time for the DRD2 docking task, AHC reached 140% optimization in 16 CPU hours vs. 202 CPU hours for <a href="/notes/chemistry/molecular-design/generation/rl-tuned/reinvent4-generative-molecule-design/">REINVENT 2.0</a>. AHC was the only strategy to reach 200% optimization within the allotted time (216 CPU hours). Parallelized over 10 CPUs, the 216 CPU-hour budget corresponds to ~21.6 hours of wall time, making docking-guided generation feasible on local machines.</p>
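<p>As a quick check of the figures above, assuming ideal parallel scaling across CPUs (an idealization, not a claim from the paper):</p>
<p>$$
\frac{216\ \text{CPU hours}}{10\ \text{CPUs}} = 21.6\ \text{hours}, \qquad \frac{202\ \text{CPU hours}}{16\ \text{CPU hours}} \approx 12.6
$$</p>
<p>so reaching 140% optimization on this task takes roughly 12.6 times less compute with AHC than with REINVENT 2.0.</p>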
<h3 id="experiment-5-generalization-to-transformers">Experiment 5: Generalization to Transformers</h3>
<p>AHC outperformed REINVENT on both the standard transformer and the gated transformer architectures. The standard transformer was unstable under RL, readily undergoing mode collapse. The gated transformer (with GRU-style gating replacing residual connections) stabilized RL training. AHC&rsquo;s efficiency gains generalized to both architectures.</p>
<h3 id="limitations">Limitations</h3>
<p>The authors acknowledge several limitations:</p>
<ul>
<li>Chemistry quality evaluation is complicated by the interaction between RL strategy and scoring function suitability. Greater optimization may lead to unreasonable chemistry due to scoring function exploitation rather than the RL strategy itself.</li>
<li>The diversity filter hyperparameter search was conducted on GuacaMol toy tasks, which may not fully transfer to docking-based objectives.</li>
<li>The docking scoring function was system-dependent: DRD2 and OPRM1 were optimized effectively, while AGTR1 and OX1R proved more challenging (especially AGTR1, where the docking algorithm targeted the wrong sub-pocket).</li>
<li>KL regularization proved ineffective for HC and REINFORCE, suggesting it is not a sufficient regularization method in this context.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>RNN pretraining</td>
          <td>MOSESn (MOSES neutralized)</td>
          <td>2,454,087 molecules</td>
          <td>ZINC15 clean leads with neutralized charges</td>
      </tr>
      <tr>
          <td>RNN pretraining</td>
          <td>GuacaMol train</td>
          <td>1,273,104 molecules</td>
          <td>ChEMBL24 with property filters</td>
      </tr>
      <tr>
          <td>QSAR training</td>
          <td>ExCAPE-DB (DRD2)</td>
          <td>4,609 actives / 343,026 inactives</td>
          <td>Random forest with GHOST thresholds</td>
      </tr>
      <tr>
          <td>QSAR training</td>
          <td>ExCAPE-DB (DRD3)</td>
          <td>2,758 actives / 402,524 inactives</td>
          <td>Unique subsets for dual/selective tasks</td>
      </tr>
      <tr>
          <td>DF parameter search</td>
          <td>GuacaMol benchmark tasks</td>
          <td>3 tasks</td>
          <td>825 configurations tested</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>AHC</strong>: REINVENT loss computed on top-k molecules per batch, ranked by reward</li>
<li><strong>Baselines</strong>: REINFORCE, REINVENT (v1, v2), BAR, Hill-Climb, Hill-Climb + KL regularization</li>
<li><strong>Hyperparameters</strong>: Default values from each original publication (listed in Supplementary Table S3)</li>
<li><strong>Docking</strong>: Glide-SP with Schrödinger Protein Preparation Wizard, LigPrep for ligand preparation</li>
</ul>
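<p>The AHC bullet above can be sketched in a few lines of Python. This is a minimal illustration, assuming the usual REINVENT objective in which the agent's sequence log-likelihood is regressed toward the prior log-likelihood plus sigma times the reward. The function name, the default <code>sigma</code>, and the default <code>topk_fraction</code> are illustrative assumptions, not values taken from the paper.</p>

```python
def ahc_loss(prior_ll, agent_ll, rewards, sigma=60.0, topk_fraction=0.5):
    """Mean REINVENT-style loss over the top-k sequences of a batch.

    prior_ll / agent_ll: per-sequence log-likelihoods under the fixed prior
    and the trainable agent; rewards: scores in [0, 1].
    sigma and topk_fraction are illustrative defaults (assumptions).
    """
    if not (len(prior_ll) == len(agent_ll) == len(rewards)):
        raise ValueError("inputs must have the same length")
    k = max(1, int(len(rewards) * topk_fraction))
    # Rank batch indices by reward, highest first, and keep the top k.
    ranked = sorted(range(len(rewards)), key=lambda i: rewards[i], reverse=True)
    kept = ranked[:k]
    losses = []
    for i in kept:
        # Augmented likelihood: prior log-likelihood shifted by the scaled reward.
        augmented = prior_ll[i] + sigma * rewards[i]
        losses.append((augmented - agent_ll[i]) ** 2)
    return sum(losses) / len(losses), kept


# Toy batch of four sequences (made-up numbers).
prior_ll = [-30.0, -25.0, -40.0, -35.0]
agent_ll = [-28.0, -26.0, -38.0, -30.0]
rewards = [0.2, 0.9, 0.1, 0.6]
loss, kept = ahc_loss(prior_ll, agent_ll, rewards)
print(kept, loss)
```

<p>Only the two highest-reward sequences contribute to the loss here. The low-reward half of the batch is ignored rather than pushed down, which is the hill-climb element added to the REINVENT update.</p>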
<h3 id="models">Models</h3>
<ul>
<li><strong>RNNs</strong>: 3 configurations (GRU/LSTM, 512 hidden units, trained 5-10 epochs)</li>
<li><strong>Transformer</strong>: 4 encoder layers, 512 hidden dim, 8 heads, 1024 FFN dim</li>
<li><strong>Gated Transformer</strong>: Same architecture with GRU-style gating replacing residual connections</li>
<li><strong>QSAR</strong>: Random forest classifiers (100 estimators, max depth 15, min leaf 2)</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>AHC + DF2</th>
          <th>REINVENT</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Optimization fold-improvement</td>
          <td>1.45x</td>
          <td>baseline</td>
          <td>DRD2 docking, averaged across sigma values</td>
      </tr>
      <tr>
          <td>Sample efficiency</td>
          <td>45.5x fewer samples</td>
          <td>baseline</td>
          <td>Averaged across 4 GPCR targets</td>
      </tr>
      <tr>
          <td>Step efficiency</td>
          <td>7.4x fewer steps</td>
          <td>baseline</td>
          <td>Averaged across 4 GPCR targets</td>
      </tr>
      <tr>
          <td>CPU hours to 140% (DRD2 docking)</td>
          <td>16h</td>
          <td>202h (REINVENT 2.0)</td>
          <td>AMD Threadripper 1920 + RTX 2060 Super</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li>AMD Threadripper 1920 CPU</li>
<li>Nvidia GeForce RTX 2060 Super GPU</li>
<li>DRD2 docking benchmark: 216 CPU-hour budget, within which only AHC reached 200% optimization (~21.6h wall-clock when parallelized over 10 CPUs)</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/MorganCThomas/SMILES-RNN">SMILES-RNN</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>RNN and transformer generative model code</td>
      </tr>
      <tr>
          <td><a href="https://github.com/MorganCThomas/MolScore">MolScore</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td><a href="/notes/chemistry/molecular-design/generation/evaluation/molscore-scoring-benchmarking-framework/">Scoring function platform</a></td>
      </tr>
      <tr>
          <td><a href="https://doi.org/10.6084/m9.figshare.19591024.v1">Figshare datasets</a></td>
          <td>Dataset</td>
          <td>CC-BY-4.0</td>
          <td>Supporting data (published under same license as paper)</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Thomas, M., O&rsquo;Boyle, N. M., Bender, A., &amp; de Graaf, C. (2022). Augmented Hill-Climb increases reinforcement learning efficiency for language-based de novo molecule generation. <em>Journal of Cheminformatics</em>, 14, 68.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{thomas2022augmented,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Augmented Hill-Climb increases reinforcement learning efficiency for language-based de novo molecule generation}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Thomas, Morgan and O&#39;Boyle, Noel M. and Bender, Andreas and de Graaf, Chris}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Journal of Cheminformatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{14}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{68}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2022}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1186/s13321-022-00646-z}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>PMO: Benchmarking Sample-Efficient Molecular Design</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/evaluation/pmo-sample-efficient-molecular-optimization/</link><pubDate>Wed, 25 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/evaluation/pmo-sample-efficient-molecular-optimization/</guid><description>PMO benchmarks 25 molecular optimization algorithms across 23 tasks under a 10K oracle budget, finding older methods like REINVENT still lead.</description><content:encoded><![CDATA[<h2 id="a-standardized-benchmark-for-molecular-optimization">A Standardized Benchmark for Molecular Optimization</h2>
<p>This is a <strong>Resource</strong> paper that introduces PMO (Practical Molecular Optimization), an open-source benchmark for evaluating molecular optimization algorithms with a focus on sample efficiency. The primary contribution is not a new algorithm but a comprehensive evaluation framework that exposes blind spots in how the field measures progress. By benchmarking 25 methods across 23 oracle functions under a fixed budget of 10,000 oracle calls, the authors provide a standardized protocol for transparent and reproducible comparison of molecular design methods.</p>
<h2 id="the-missing-dimension-oracle-budget-in-molecular-design">The Missing Dimension: Oracle Budget in Molecular Design</h2>
<p>Molecular optimization is central to drug and materials discovery, and the field has seen rapid growth in computational methods. Despite this progress, the authors identify three persistent problems with how methods are evaluated:</p>
<ol>
<li>
<p><strong>Lack of oracle budget control</strong>: Most papers do not report how many candidate molecules were evaluated by the oracle to achieve their results, despite this number spanning orders of magnitude. In practice, the most valuable oracles (wet-lab experiments, high-accuracy simulations) are expensive, making sample efficiency critical.</p>
</li>
<li>
<p><strong>Trivial or self-designed oracles</strong>: Many papers only report on easy objectives like QED or penalized LogP, or introduce custom tasks that make cross-method comparison impossible.</p>
</li>
<li>
<p><strong>Insufficient handling of randomness</strong>: Many algorithms are stochastic, yet existing benchmarks examined no more than five methods and rarely reported variance across independent runs.</p>
</li>
</ol>
<p>Prior benchmarks such as <a href="/notes/chemistry/molecular-design/generation/evaluation/guacamol-benchmarking-de-novo-molecular-design/">GuacaMol</a>, Therapeutics Data Commons (TDC), and Tripp et al.&rsquo;s analysis each suffer from at least one of these issues. PMO addresses all three simultaneously.</p>
<h2 id="the-pmo-benchmark-design">The PMO Benchmark Design</h2>
<p>The core innovation of PMO is its evaluation protocol rather than any single algorithmic contribution. The benchmark enforces three design principles:</p>
<p><strong>Oracle budget constraint</strong>: All methods are limited to 10,000 oracle calls. This is deliberately much smaller than the unconstrained budgets typical in the literature, reflecting the practical reality that experimental evaluations are costly.</p>
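<p>One way to enforce such a budget is to wrap the scoring function in a counter. The sketch below is illustrative rather than the benchmark's own implementation: the class name is hypothetical, and it assumes that repeated molecules are served from a cache without consuming budget.</p>

```python
class BudgetedOracle:
    """Wrap a scoring function and stop once the call budget is spent.

    Assumption: repeated molecules are served from a cache and do not
    consume budget; a benchmark could also choose to count them.
    """

    def __init__(self, score_fn, budget=10_000):
        self.score_fn = score_fn
        self.budget = budget
        self.cache = {}  # molecule string -> score, in first-seen order

    @property
    def calls(self):
        return len(self.cache)

    @property
    def exhausted(self):
        return self.calls >= self.budget

    def __call__(self, molecule):
        if molecule in self.cache:
            return self.cache[molecule]
        if self.exhausted:
            raise RuntimeError("oracle budget exhausted")
        score = self.score_fn(molecule)
        self.cache[molecule] = score
        return score


# Toy stand-in oracle (hypothetical): fraction of carbon characters.
def toy_oracle(smiles):
    return smiles.count("C") / max(1, len(smiles))


oracle = BudgetedOracle(toy_oracle, budget=3)
scores = [oracle(s) for s in ["CCO", "CCN", "CCO", "CCC"]]
print(oracle.calls, oracle.exhausted)
```

<p>Because the cache dictionary preserves insertion order, its values also give the score trajectory in call order, which is what a curve-based metric needs.</p>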
<p><strong>AUC-based metric</strong>: Instead of reporting only the final top-K score, PMO uses the area under the curve (AUC) of top-K average property value versus oracle calls:</p>
<p>$$
\text{AUC Top-}K = \int_{0}^{N} \bar{f}_{K}(n) \, dn
$$</p>
<p>where $\bar{f}_{K}(n)$ is the average property value of the top $K$ molecules found after $n$ oracle calls, and $N = 10{,}000$. The paper uses $K = 10$. This metric rewards methods that reach high property values quickly, not just those that eventually converge given enough budget. All AUC values are min-max scaled to [0, 1].</p>
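<p>A discrete version of this metric can be sketched as follows. The simplifications are assumptions of this sketch, not details from the paper: one curve point per oracle call, normalization by dividing by the budget (valid when scores already lie in [0, 1]), averaging over however many molecules exist while fewer than $K$ have been seen, and holding the final top-$K$ mean for any unused calls.</p>

```python
import heapq


def auc_top_k(scores, k=10, budget=10_000):
    """Normalized area under the running top-k mean versus oracle calls.

    scores: oracle values in call order, each assumed to lie in [0, 1].
    """
    scores = list(scores)[:budget]
    top = []        # min-heap holding the k best scores seen so far
    area = 0.0
    current = 0.0
    for s in scores:
        if len(top) == k:
            heapq.heappushpop(top, s)   # push s, then drop the smallest
        else:
            heapq.heappush(top, s)
        current = sum(top) / len(top)
        area += current
    # A run that stops early keeps its final value for the unused calls.
    area += current * (budget - len(scores))
    return area / budget


late = auc_top_k([0.1, 0.5, 0.3, 0.9], k=2, budget=4)
early = auc_top_k([0.9, 0.5, 0.3, 0.1], k=2, budget=4)
print(late, early)
```

<p>Both toy runs end with the same top-2 molecules, but the run that finds them first receives the higher AUC, which is the behavior the metric is meant to reward.</p>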
<p><strong>Standardized data</strong>: All methods use only the ZINC 250K dataset (approximately 250,000 molecules) whenever a database is required, ensuring a level playing field.</p>
<p>The benchmark includes 23 oracle functions: QED, <a href="https://en.wikipedia.org/wiki/Dopamine_receptor_D2">DRD2</a>, <a href="https://en.wikipedia.org/wiki/GSK-3">GSK3</a>-beta, <a href="https://en.wikipedia.org/wiki/C-Jun_N-terminal_kinase">JNK3</a>, and 19 oracles from <a href="/notes/chemistry/molecular-design/generation/evaluation/guacamol-benchmarking-de-novo-molecular-design/">GuacaMol</a> covering multi-property objectives (MPOs) based on similarity, molecular weight, CLogP, and other pharmaceutically relevant criteria. All oracle scores are normalized to [0, 1].</p>
<h2 id="25-methods-across-nine-algorithm-families">25 Methods Across Nine Algorithm Families</h2>
<p>The benchmark evaluates 25 molecular optimization methods organized along two dimensions: molecular assembly strategy (SMILES, SELFIES, atom-level graphs, fragment-level graphs, synthesis-based) and optimization algorithm (GA, MCTS, BO, VAE, GAN, score-based modeling, hill climbing, RL, gradient ascent). Each method was hyperparameter-tuned on two held-out tasks (zaleplon_mpo and perindopril_mpo) and then evaluated across all 23 oracles for 5 independent runs.</p>
<p>The following table summarizes the top 10 methods by sum of mean AUC Top-10 across all 23 tasks:</p>
<table>
  <thead>
      <tr>
          <th>Rank</th>
          <th>Method</th>
          <th>Assembly</th>
          <th>Sum AUC Top-10</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>1</td>
          <td><a href="/notes/chemistry/molecular-design/generation/rl-tuned/reinvent-deep-rl-molecular-design/">REINVENT</a></td>
          <td>SMILES</td>
          <td>14.196</td>
      </tr>
      <tr>
          <td>2</td>
          <td>Graph GA</td>
          <td>Fragments</td>
          <td>13.751</td>
      </tr>
      <tr>
          <td>3</td>
          <td>SELFIES-REINVENT</td>
          <td>SELFIES</td>
          <td>13.471</td>
      </tr>
      <tr>
          <td>4</td>
          <td>GP BO</td>
          <td>Fragments</td>
          <td>13.156</td>
      </tr>
      <tr>
          <td>5</td>
          <td><a href="/notes/chemistry/molecular-design/generation/search-based/stoned-selfies-chemical-space-exploration/">STONED</a></td>
          <td>SELFIES</td>
          <td>13.024</td>
      </tr>
      <tr>
          <td>6</td>
          <td>LSTM HC</td>
          <td>SMILES</td>
          <td>12.223</td>
      </tr>
      <tr>
          <td>7</td>
          <td>SMILES GA</td>
          <td>SMILES</td>
          <td>12.054</td>
      </tr>
      <tr>
          <td>8</td>
          <td>SynNet</td>
          <td>Synthesis</td>
          <td>11.498</td>
      </tr>
      <tr>
          <td>9</td>
          <td>DoG-Gen</td>
          <td>Synthesis</td>
          <td>11.456</td>
      </tr>
      <tr>
          <td>10</td>
          <td>DST</td>
          <td>Fragments</td>
          <td>10.989</td>
      </tr>
  </tbody>
</table>
<p>The bottom five methods by overall ranking were GFlowNet-AL, Pasithea, JT-VAE, Graph MCTS, and MolDQN.</p>
<p>REINVENT is ranked first across all six metrics considered (AUC Top-1, AUC Top-10, AUC Top-100, Top-1, Top-10, Top-100). Graph GA is consistently second. Both methods were released several years before many of the methods they outperform, yet they are rarely used as baselines in newer work.</p>
<h2 id="key-findings-older-methods-win-and-selfies-offers-limited-advantage">Key Findings: Older Methods Win and SELFIES Offers Limited Advantage</h2>
<p>The benchmark yields several findings with practical implications:</p>
<p><strong>No method solves optimization within realistic budgets.</strong> None of the 25 methods can optimize the included objectives within hundreds of oracle calls (the scale at which experimental evaluations would be feasible), except for trivially easy oracles like QED, DRD2, and osimertinib_mpo.</p>
<p><strong>Older algorithms remain competitive.</strong> REINVENT (2017) and Graph GA (2019) outperform all newer methods tested, including those published at top AI conferences. The absence of standardized benchmarking had obscured this fact.</p>
<p><strong>SMILES versus SELFIES.</strong> <a href="/notes/chemistry/molecular-representations/notations/selfies/">SELFIES</a> was designed to guarantee syntactically valid molecular strings, but head-to-head comparisons show that SELFIES-based variants of language model methods (REINVENT, LSTM HC, VAE) generally do not outperform their <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> counterparts. Modern language models learn SMILES grammar well enough that syntactic invalidity is no longer a practical issue. The one exception is genetic algorithms, where SELFIES-based GAs (<a href="/notes/chemistry/molecular-design/generation/search-based/stoned-selfies-chemical-space-exploration/">STONED</a>) outperform SMILES-based GAs, likely because SELFIES provides more intuitive mutation operations.</p>
<p><strong>Model-based methods need careful design.</strong> Model-based variants (GP BO relative to Graph GA, GFlowNet-AL relative to GFlowNet) do not consistently outperform their model-free counterparts. GP BO outperformed Graph GA in 12 of 23 tasks but underperformed on sum, and GFlowNet-AL underperformed GFlowNet in nearly every task. The bottleneck is the quality of the predictive surrogate model, and naive surrogate integration can actually hurt performance.</p>
<p><strong>Oracle landscape determines method suitability.</strong> Clustering analysis of relative AUC Top-10 scores reveals clear patterns. String-based GAs excel on isomer-type oracles (which are sums of atomic contributions), while RL-based and fragment-based methods perform better on similarity-based MPOs. This suggests there is no single best algorithm, and method selection should be informed by the optimization landscape.</p>
<p><strong>Hyperparameter tuning and multiple runs are essential.</strong> Optimal hyperparameters differed substantially from default values in original papers. For example, REINVENT&rsquo;s performance is highly sensitive to its sigma parameter, and the best value under the constrained-budget setting is much larger than originally suggested. Methods like Graph GA and GP BO also show high variance across runs, underscoring the importance of reporting distributional outcomes rather than single-run results.</p>
<h3 id="limitations">Limitations</h3>
<p>The authors acknowledge several limitations: they cannot exhaustively tune every hyperparameter or include every variant of each method; the conclusion may be biased toward similarity-based oracles (which dominate the 23 tasks); important quantities like synthesizability and diversity are not thoroughly evaluated; and oracle calls from pre-training data in model-based methods are counted against the budget, which may disadvantage methods that could leverage prior data collection. For a follow-up study that adds property filters and diversity requirements to the PMO evaluation, see <a href="/notes/chemistry/molecular-design/generation/evaluation/sample-efficiency-de-novo-generation/">Re-evaluating Sample Efficiency</a>.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Molecule library</td>
          <td>ZINC 250K</td>
          <td>~250,000 molecules</td>
          <td>Used for screening, pre-training generative models, and fragment extraction</td>
      </tr>
      <tr>
          <td>Oracle functions</td>
          <td>TDC / GuacaMol</td>
          <td>23 tasks</td>
          <td>All scores normalized to [0, 1]</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<p>25 molecular optimization methods spanning 9 algorithm families and 5 molecular assembly strategies. Each method was hyperparameter-tuned on 2 held-out tasks (zaleplon_mpo, perindopril_mpo) using 3 independent runs, then evaluated on all 23 tasks with 5 independent runs each.</p>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Description</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>AUC Top-K</td>
          <td>Area under curve of top-K average vs. oracle calls</td>
          <td>Primary metric; K=10; min-max scaled to [0, 1]</td>
      </tr>
      <tr>
          <td>Top-K</td>
          <td>Final top-K average property value at 10K calls</td>
          <td>Secondary metric</td>
      </tr>
      <tr>
          <td>Sum rank</td>
          <td>Sum of AUC Top-10 across all 23 tasks</td>
          <td>Used for overall ranking</td>
      </tr>
  </tbody>
</table>
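<p>The overall ranking can be reproduced from per-run results with a small aggregation. The sketch below assumes a nested dictionary layout and uses made-up numbers; the function name and data are hypothetical.</p>

```python
from statistics import mean


def rank_by_sum_auc(results):
    """results: {method: {task: [AUC Top-10 per independent run]}}.

    Returns (method, total) pairs sorted by the sum over tasks of the
    per-task mean AUC, best first.
    """
    totals = {
        method: sum(mean(runs) for runs in tasks.values())
        for method, tasks in results.items()
    }
    return sorted(totals.items(), key=lambda item: item[1], reverse=True)


# Hypothetical numbers for two methods on two tasks, two runs each.
toy = {
    "method_a": {"task_1": [0.80, 0.90], "task_2": [0.40, 0.60]},
    "method_b": {"task_1": [0.70, 0.70], "task_2": [0.70, 0.90]},
}
ranking = rank_by_sum_auc(toy)
print(ranking)
```

<p>Summing per-task means weights every task equally, so a method that is strong on many similar tasks can outrank one that dominates a few unusual ones.</p>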
<h3 id="hardware">Hardware</h3>
<p>The paper states hardware details are in Appendix C.2. The benchmark runs on standard compute infrastructure and does not require GPUs for most methods. Specific compute requirements vary by method.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/wenhao-gao/mol_opt">mol_opt</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Full benchmark implementation with all 25 methods</td>
      </tr>
      <tr>
          <td><a href="https://figshare.com/articles/dataset/Results_for_practival_molecular_optimization_PMO_benchmark/20123453">Benchmark results</a></td>
          <td>Dataset</td>
          <td>Unknown</td>
          <td>All experimental results from the paper</td>
      </tr>
      <tr>
          <td><a href="https://tdcommons.ai">TDC</a></td>
          <td>Dataset</td>
          <td>MIT</td>
          <td>Oracle functions and evaluation infrastructure</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{gao2022sample,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Sample Efficiency Matters: A Benchmark for Practical Molecular Optimization}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Gao, Wenhao and Fu, Tianfan and Sun, Jimeng and Coley, Connor W.}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Advances in Neural Information Processing Systems}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{35}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{21342--21357}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2022}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Gao, W., Fu, T., Sun, J., &amp; Coley, C. W. (2022). Sample Efficiency Matters: A Benchmark for Practical Molecular Optimization. <em>Advances in Neural Information Processing Systems</em>, 35, 21342-21357. <a href="https://arxiv.org/abs/2206.12411">https://arxiv.org/abs/2206.12411</a></p>
<p><strong>Publication</strong>: NeurIPS 2022</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/wenhao-gao/mol_opt">PMO Benchmark Code (GitHub)</a></li>
<li><a href="https://figshare.com/articles/dataset/Results_for_practival_molecular_optimization_PMO_benchmark/20123453">Benchmark Results (Figshare)</a></li>
<li><a href="https://tdcommons.ai">Therapeutics Data Commons</a></li>
</ul>
]]></content:encoded></item><item><title>MolScore: Scoring and Benchmarking for Drug Design</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/evaluation/molscore-scoring-benchmarking-framework/</link><pubDate>Wed, 25 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/evaluation/molscore-scoring-benchmarking-framework/</guid><description>MolScore provides a unified, open-source Python framework for scoring, evaluating, and benchmarking generative models applied to de novo drug design.</description><content:encoded><![CDATA[<h2 id="a-unified-resource-for-generative-molecular-design">A Unified Resource for Generative Molecular Design</h2>
<p>MolScore is a <strong>Resource</strong> paper that introduces an open-source Python framework for scoring, evaluating, and benchmarking generative models in de novo drug design. The primary contribution is the software itself: a modular, configurable platform that consolidates functionality previously scattered across multiple tools (GuacaMol, MOSES, MolOpt, REINVENT, TDC) into a single package. MolScore provides scoring functions for molecular optimization, evaluation metrics for assessing the quality of generated molecules, and a benchmark mode for standardized comparison of generative models.</p>
<h2 id="the-fragmented-landscape-of-generative-model-evaluation">The Fragmented Landscape of Generative Model Evaluation</h2>
<p>Generative models for molecular design have proliferated rapidly, but evaluating and comparing them remains difficult. Existing benchmarks each address only part of the problem:</p>
<ul>
<li><strong><a href="/notes/chemistry/molecular-design/generation/evaluation/guacamol-benchmarking-de-novo-molecular-design/">GuacaMol</a></strong> provides 20 fixed optimization objectives but cannot separate top-performing models on most tasks, and custom objectives require code modification.</li>
<li><strong><a href="/notes/chemistry/molecular-design/generation/evaluation/molecular-sets-moses/">MOSES</a></strong> focuses on distribution-learning metrics but does not support molecular optimization.</li>
<li><strong><a href="/notes/chemistry/molecular-design/generation/evaluation/pmo-sample-efficient-molecular-optimization/">MolOpt</a></strong> extends benchmark evaluation to 25 generative approaches but lacks evaluation of the quality of generated chemistry.</li>
<li><strong>Docking benchmarks</strong> (<a href="/notes/chemistry/molecular-design/generation/evaluation/smina-docking-benchmark/">smina-docking-benchmark</a>, <a href="/notes/chemistry/molecular-design/generation/evaluation/dockstring-docking-benchmarks-ligand-design/">DOCKSTRING</a>, TDC) test structure-based scoring but often lack proper ligand preparation, leading generative models to exploit non-holistic objectives by generating large or greasy molecules.</li>
<li><strong><a href="/notes/chemistry/molecular-design/generation/rl-tuned/reinvent-deep-rl-molecular-design/">REINVENT</a></strong> provides configurable scoring functions but is tightly coupled to its own generative model architecture.</li>
</ul>
<p>No single tool offered configurable objectives, comprehensive evaluation metrics, generative-model-agnostic design, and graphical user interfaces together. This fragmentation forces practitioners to write custom glue code and makes reproducible comparison across methods difficult.</p>
<h2 id="modular-architecture-for-scoring-evaluation-and-benchmarking">Modular Architecture for Scoring, Evaluation, and Benchmarking</h2>
<p>MolScore is split into two sub-packages:</p>
<h3 id="molscore-molecule-scoring">molscore: Molecule Scoring</h3>
<p>The <code>molscore</code> sub-package handles iterative scoring of <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> generated by any generative model. The workflow for each iteration:</p>
<ol>
<li>Parse and validate SMILES via RDKit, canonicalize, and check intra-batch uniqueness.</li>
<li>Cross-reference against previously generated molecules to reuse cached scores (saving compute for expensive scoring functions like docking).</li>
<li>Run user-specified scoring functions on valid, unique molecules (invalid molecules receive a score of 0).</li>
<li>Transform each score to a 0-1 range using configurable transformation functions (normalize, linear threshold, Gaussian threshold, step threshold).</li>
<li>Aggregate transformed scores into a single desirability score using configurable aggregation (weighted sum, product, geometric mean, arithmetic mean, <a href="https://en.wikipedia.org/wiki/Pareto_front">Pareto front</a>, or auto-weighted variants).</li>
<li>Optionally apply diversity filters to penalize non-diverse molecules, or use any scoring function as a multiplicative filter.</li>
</ol>
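<p>The steps above can be sketched in plain Python. This is a simplified illustration, not MolScore's code: the validity check is a stand-in for RDKit parsing, canonicalization and diversity filters are omitted, and all function names are hypothetical.</p>

```python
import math


def linear_threshold(x, low, high):
    """Map x to [0, 1]: 0 at or below low, 1 at or above high, linear between."""
    return min(1.0, max(0.0, (x - low) / (high - low)))


def geometric_mean(values):
    """Aggregate transformed scores; any zero drives the result to zero."""
    return math.prod(values) ** (1.0 / len(values))


def score_batch(smiles_batch, is_valid, scorers, transforms, cache):
    """Return one desirability score per input string.

    is_valid: stand-in for RDKit parsing (hypothetical callable).
    scorers / transforms: parallel lists of raw scoring functions and
    0-1 transformation functions.
    cache: dict reused across iterations so repeats are not rescored.
    """
    results = []
    for smi in smiles_batch:
        if not is_valid(smi):
            results.append(0.0)      # invalid molecules receive a score of 0
            continue
        if smi not in cache:         # covers intra-batch and earlier repeats
            transformed = [t(f(smi)) for f, t in zip(scorers, transforms)]
            cache[smi] = geometric_mean(transformed)
        results.append(cache[smi])
    return results


# Toy stand-ins (hypothetical): string statistics instead of real descriptors.
calls = []


def length_score(smi):
    calls.append(smi)                # record which molecules were actually scored
    return len(smi)


def carbon_count(smi):
    return smi.count("C")


cache = {}
out = score_batch(
    ["CCO", "", "CCO", "CCCC"],
    is_valid=lambda s: s != "",
    scorers=[length_score, carbon_count],
    transforms=[lambda x: linear_threshold(x, 0, 4)] * 2,
    cache=cache,
)
print(out, calls)
```

<p>The duplicate <code>CCO</code> is scored only once, which is the saving that matters when the scoring function is a docking run rather than a string count.</p>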
<p>The full objective is specified in a single JSON configuration file, with a Streamlit GUI provided for interactive configuration writing. The available scoring functions span:</p>
<table>
  <thead>
      <tr>
          <th>Category</th>
          <th>Examples</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Descriptors</td>
          <td>RDKit descriptors, linker descriptors, penalized logP</td>
      </tr>
      <tr>
          <td>Similarity</td>
          <td>Fingerprint similarity, ROCS, Open3DAlign, substructure matching</td>
      </tr>
      <tr>
          <td>Predictive models</td>
          <td>Scikit-learn models, PIDGINv5 (2,337 ChEMBL31 targets), ChemProp, ADMET-AI</td>
      </tr>
      <tr>
          <td>Docking</td>
          <td>Glide, PLANTS, GOLD, OEDock, Smina, Gnina, Vina, rDock</td>
      </tr>
      <tr>
          <td>Synthesizability</td>
          <td>SA score, RA Score, AiZynthFinder, reaction filters</td>
      </tr>
  </tbody>
</table>
<p>Most scoring functions support multiprocessing, and computationally expensive functions (docking, ligand preparation) can be distributed across compute clusters via Dask.</p>
<h3 id="moleval-molecule-evaluation">moleval: Molecule Evaluation</h3>
<p>The <code>moleval</code> sub-package computes performance metrics on generated molecules relative to reference datasets. It extends the MOSES metric suite with additional intrinsic metrics (sphere exclusion diversity, scaffold uniqueness, functional group and ring system diversity, ZINC20 purchasability via molbloom) and extrinsic metrics (analogue similarity/coverage, functional group and ring system similarity, outlier bits or &ldquo;Silliness&rdquo;).</p>
<h3 id="benchmark-mode">Benchmark Mode</h3>
<p>A <code>MolScoreBenchmark</code> class iterates over a list of JSON configuration files, providing standardized comparison. Pre-built presets reimplement GuacaMol and MolOpt benchmarks, and users can define custom benchmark suites without writing code.</p>
<h2 id="case-studies-5-ht2a-ligand-design-and-fine-tuning-evaluation">Case Studies: 5-HT2A Ligand Design and Fine-Tuning Evaluation</h2>
<p>The authors demonstrate MolScore with a SMILES-based RNN generative model using <a href="/notes/chemistry/molecular-design/generation/rl-tuned/augmented-hill-climb-rl-molecule-generation/">Augmented Hill-Climb</a> for optimization, designing serotonin <a href="https://en.wikipedia.org/wiki/5-HT2A_receptor">5-HT2A</a> receptor ligands across three objective sets of increasing complexity.</p>
<h3 id="first-objective-set-basic-drug-properties">First Objective Set: Basic Drug Properties</h3>
<p>Four objectives combine predicted 5-HT2A activity (via PIDGINv5 random forest models at 1 μM) with synthesizability (RAscore) and/or <a href="https://en.wikipedia.org/wiki/Blood%E2%80%93brain_barrier">BBB</a> permeability property ranges (<a href="https://en.wikipedia.org/wiki/Polar_surface_area">TPSA</a> &lt; 70, HBD &lt; 2, logP 2-4, MW &lt; 400). All objectives were optimized successfully, with diversity filters preventing mode collapse. The single-objective task (5-HT2A activity alone) appeared the most difficult, primarily because the diversity filter penalized similar molecules more heavily on this relatively easy task.</p>
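<p>The BBB property ranges can be expressed as simple predicates over precomputed descriptors. The sketch below is a hard-filter simplification: MolScore itself applies configurable transformation functions before aggregating, and treating the logP bounds as inclusive is an assumption. The dictionary and function names are hypothetical.</p>

```python
# Property windows quoted above for the BBB-permeability objective.
BBB_CRITERIA = {
    "TPSA": lambda v: v < 70,        # topological polar surface area
    "HBD": lambda v: v < 2,          # hydrogen-bond donors
    "logP": lambda v: 2 <= v <= 4,   # inclusive bounds are an assumption
    "MW": lambda v: v < 400,         # molecular weight
}


def bbb_fraction(props):
    """Return the fraction of BBB criteria satisfied by a descriptor dict."""
    passed = [check(props[name]) for name, check in BBB_CRITERIA.items()]
    return sum(passed) / len(passed)


# Made-up descriptor values for two hypothetical molecules.
inside = bbb_fraction({"TPSA": 45.2, "HBD": 1, "logP": 3.1, "MW": 312.4})
outside = bbb_fraction({"TPSA": 95.0, "HBD": 3, "logP": 3.1, "MW": 450.0})
print(inside, outside)
```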
<h3 id="second-objective-set-selectivity">Second Objective Set: Selectivity</h3>
<p>Six objectives incorporate selectivity proxies using PIDGINv5 models for off-target prediction against <a href="https://en.wikipedia.org/wiki/G_protein-coupled_receptor">Class A GPCR</a> membrane receptors (266 models), the <a href="https://en.wikipedia.org/wiki/Dopamine_receptor_D2">D2 dopamine receptor</a>, dopamine receptor family, serotonin receptor subtypes, and combinations. These proved substantially harder: selectivity against dopamine and serotonin receptor families combined was barely improved during optimization. Even with imperfect predictive models, the PIDGINv5 ensemble correctly identified 95 of 126 known selective 5-HT2A ligands. Nearest-neighbor analysis of de novo molecules (<a href="https://en.wikipedia.org/wiki/Jaccard_index">Tanimoto similarity</a> 0.3-0.6) showed they tended to be structurally simpler versions of known selective ligands.</p>
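<p>The nearest-neighbor analysis relies on Tanimoto similarity, which for binary fingerprints is the Jaccard index over the sets of on-bits. A minimal sketch, using made-up bit sets and hypothetical ligand names in place of real fingerprints:</p>

```python
def tanimoto(bits_a, bits_b):
    """Tanimoto (Jaccard) similarity between two sets of on-bit indices."""
    a, b = set(bits_a), set(bits_b)
    union = a.union(b)
    if not union:
        return 0.0  # convention chosen here for two empty fingerprints
    return len(a.intersection(b)) / len(union)


def nearest_neighbor(query_bits, reference):
    """Return (name, similarity) of the most similar reference fingerprint."""
    return max(
        ((name, tanimoto(query_bits, bits)) for name, bits in reference.items()),
        key=lambda pair: pair[1],
    )


# Made-up fingerprints: the query shares half of ligand_a's bits.
reference = {
    "ligand_a": {1, 2, 3, 4, 5, 6, 7, 8},
    "ligand_b": {9, 10},
}
name, similarity = nearest_neighbor({1, 2, 3, 4}, reference)
print(name, similarity)
```

<p>A query whose bits are a strict subset of a reference ligand's bits, as in this toy case, lands in the middle of the similarity range, consistent with the "structurally simpler version" pattern described above.</p>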
<h3 id="third-objective-set-structure-based-docking">Third Objective Set: Structure-Based Docking</h3>
<p>Two objectives use molecular docking via Glide-SP into 5-HT2A (PDB: 6A93) and D2 (PDB: 6CM4) crystal structures with full ligand preparation (LigPrep for stereoisomer/tautomer/protonation state enumeration). The multi-parameter objective combines the docking score, a D155 polar interaction constraint, formal charge, and a limit on consecutive rotatable bonds. Single-target docking scores reached the mean of known ligands within 200 steps, but optimizing for divergent 5-HT2A vs. D2 docking scores was much harder due to binding pocket similarity. Protein-ligand interaction fingerprint analysis (ProLIF) revealed that molecules optimized for selectivity avoided specific binding pocket regions shared between the two receptors.</p>
<h3 id="evaluation-case-study-fine-tuning-epochs">Evaluation Case Study: Fine-Tuning Epochs</h3>
<p>The moleval sub-package was used to track metrics across fine-tuning epochs of a SMILES RNN on A2A receptor ligands, showing that just one or two epochs sufficed to increase similarity to the fine-tuning set, while further epochs reduced novelty and diversity.</p>
<h2 id="configurable-benchmarking-with-practical-drug-design-relevance">Configurable Benchmarking with Practical Drug Design Relevance</h2>
<p>MolScore covers more of these capabilities than any single prior tool, as the feature comparison shows:</p>
<table>
  <thead>
      <tr>
          <th>Feature</th>
          <th>GuacaMol</th>
          <th>MOSES</th>
          <th>MolOpt</th>
          <th>TDC</th>
          <th>REINVENT</th>
          <th>MolScore</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Configurable objectives</td>
          <td>No</td>
          <td>N/A</td>
          <td>No</td>
          <td>No</td>
          <td>Yes</td>
          <td>Yes</td>
      </tr>
      <tr>
          <td>Optimization objectives</td>
          <td>Yes</td>
          <td>No</td>
          <td>Yes</td>
          <td>Yes</td>
          <td>Yes</td>
          <td>Yes</td>
      </tr>
      <tr>
          <td>Evaluation metrics</td>
          <td>Yes</td>
          <td>Yes</td>
          <td>No</td>
          <td>No</td>
          <td>No</td>
          <td>Yes</td>
      </tr>
      <tr>
          <td>Model-agnostic</td>
          <td>Yes</td>
          <td>Yes</td>
          <td>Yes</td>
          <td>Yes</td>
          <td>No</td>
          <td>Yes</td>
      </tr>
      <tr>
          <td>GUI</td>
          <td>No</td>
          <td>No</td>
          <td>No</td>
          <td>No</td>
          <td>Yes</td>
          <td>Yes</td>
      </tr>
  </tbody>
</table>
<p>The framework integrates into any Python-based generative model in three lines of code. Dependency conflicts between scoring function libraries are handled by running conflicting components as local servers from isolated conda environments.</p>
<p>Key limitations acknowledged by the authors include the assumption of conda for environment management, the inherent difficulty of designing non-exploitable objectives, and the limited applicability domain of ligand-based predictive models for out-of-distribution de novo molecules.</p>
<p>Future directions include accepting 3D molecular conformations as inputs, structure interaction fingerprint rescoring, and dynamic configuration files for curriculum learning.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pre-training</td>
          <td>ChEMBL compounds</td>
          <td>Not specified</td>
          <td>Standard ChEMBL training set for SMILES RNN</td>
      </tr>
      <tr>
          <td>Evaluation reference</td>
          <td>5-HT2A ligands from ChEMBL31</td>
          <td>3,771 compounds</td>
          <td>Extracted for score distribution comparison</td>
      </tr>
      <tr>
          <td>Activity models</td>
          <td>PIDGINv5 on ChEMBL31</td>
          <td>2,337 target models</td>
          <td>Random forest classifiers at various concentration thresholds</td>
      </tr>
      <tr>
          <td>Fine-tuning</td>
          <td>A2A receptor ligands</td>
          <td>Not specified</td>
          <td>Used for moleval case study</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<p>The generative model used in case studies is a SMILES-based RNN with Augmented Hill-Climb reinforcement learning. Diversity filters penalize non-diverse molecules during optimization. Score transformation functions (normalize, linear threshold, Gaussian threshold, step threshold) map raw scores to 0-1 range. Aggregation functions (arithmetic mean, weighted sum, product, geometric mean, Pareto front) combine multi-parameter objectives.</p>
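<p>The transformation and aggregation steps can be illustrated with a minimal Python sketch. The function names, signatures, and exact functional forms here are assumptions made for illustration; they are not MolScore&rsquo;s own implementation, and the Gaussian variant shown is only one plausible reading of a &ldquo;Gaussian threshold&rdquo;.</p>

```python
import math


def linear_threshold(x, lower, upper):
    """Map x linearly onto [0, 1]: 0 at or below `lower`, 1 at or above `upper`."""
    if upper == lower:
        raise ValueError("lower and upper must differ")
    return min(1.0, max(0.0, (x - lower) / (upper - lower)))


def step_threshold(x, threshold):
    """Return 1 when x reaches the threshold, otherwise 0."""
    return 1.0 if x >= threshold else 0.0


def gaussian_threshold(x, mu, sigma):
    """Assumed form: full score at or above mu, Gaussian decay below it."""
    if x >= mu:
        return 1.0
    return math.exp(-0.5 * ((x - mu) / sigma) ** 2)


def arithmetic_mean(scores):
    return sum(scores) / len(scores)


def geometric_mean(scores):
    return math.prod(scores) ** (1.0 / len(scores))


def weighted_sum(scores, weights):
    """Weighted average, with the weights normalized to sum to 1."""
    return sum(w * s for w, s in zip(weights, scores)) / sum(weights)


def product(scores):
    return math.prod(scores)
```

<p>The geometric mean and the product both collapse to zero when any single transformed score is zero, which makes them stricter than the arithmetic mean for multi-parameter objectives.</p>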
<h3 id="models">Models</h3>
<p>PIDGINv5 provides 2,337 pre-trained random forest classifiers on ChEMBL31 targets. RAscore provides pre-trained synthesizability prediction. ADMET-AI and ChemProp models are supported via isolated environments. Docking uses GlideSP with LigPrep for ligand preparation in the structure-based case study.</p>
<h3 id="evaluation">Evaluation</h3>
<p>Intrinsic metrics: validity, uniqueness, scaffold uniqueness, internal diversity, sphere exclusion diversity, Solow-Polasky diversity, scaffold diversity, functional group diversity, ring system diversity, MCF and <a href="https://en.wikipedia.org/wiki/Pan-assay_interference_compounds">PAINS</a> filters, ZINC20 purchasability.</p>
<p>Extrinsic metrics: novelty, <a href="/notes/chemistry/molecular-design/generation/evaluation/frechet-chemnet-distance/">FCD</a>, analogue similarity/coverage, functional group similarity, ring system similarity, SNN similarity, fragment similarity, scaffold similarity, outlier bits, Wasserstein distance on LogP/SA/NP/QED/MW.</p>
<h3 id="hardware">Hardware</h3>
<p>Not specified in the paper. Docking-based objectives can be distributed across compute clusters via Dask.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/MorganCThomas/MolScore">MolScore</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Main framework, installable via pip</td>
      </tr>
      <tr>
          <td><a href="https://github.com/MorganCThomas/MolScore_examples">MolScore Examples</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Integration examples with SMILES-RNN, CReM, GraphGA</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Thomas, M., O&rsquo;Boyle, N. M., Bender, A., &amp; de Graaf, C. (2024). MolScore: a scoring, evaluation and benchmarking framework for generative models in de novo drug design. <em>Journal of Cheminformatics</em>, 16(1), 64. <a href="https://doi.org/10.1186/s13321-024-00861-w">https://doi.org/10.1186/s13321-024-00861-w</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{thomas2024molscore,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{MolScore: a scoring, evaluation and benchmarking framework for generative models in de novo drug design}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Thomas, Morgan and O&#39;Boyle, Noel M. and Bender, Andreas and de Graaf, Chris}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Journal of Cheminformatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{16}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{64}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2024}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{BioMed Central}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1186/s13321-024-00861-w}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>GuacaMol: Benchmarking Models for De Novo Molecular Design</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/evaluation/guacamol-benchmarking-de-novo-molecular-design/</link><pubDate>Wed, 25 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/evaluation/guacamol-benchmarking-de-novo-molecular-design/</guid><description>GuacaMol introduces a standardized benchmark suite for evaluating de novo molecular design models across distribution learning and goal-directed optimization.</description><content:encoded><![CDATA[<h2 id="a-standardized-benchmark-for-molecular-design">A Standardized Benchmark for Molecular Design</h2>
<p>GuacaMol is a <strong>Resource</strong> paper. Its primary contribution is a standardized, open-source benchmarking framework for evaluating models for de novo molecular design. The framework defines 5 distribution-learning benchmarks and 20 goal-directed optimization benchmarks, implemented as a Python package. The authors also provide baseline results for several classical and neural generative models, establishing reference performance levels for future comparisons.</p>
<h2 id="the-need-for-consistent-evaluation-in-generative-chemistry">The Need for Consistent Evaluation in Generative Chemistry</h2>
<p>By 2018, deep generative models for molecular design (<a href="/notes/machine-learning/generative-models/autoencoding-variational-bayes/">VAEs</a>, RNNs, <a href="/posts/what-is-a-gan/">GANs</a>) had shown promising results, but the field lacked consistent evaluation standards. Different papers used different tasks, different datasets, and different metrics, making it difficult to compare models or assess real progress. Comparative studies between neural approaches and well-established algorithms like genetic algorithms were rare.</p>
<p>In other areas of machine learning, standardized benchmarks (ImageNet for vision, GLUE for NLP) had driven rapid progress by enabling fair comparisons. The de novo design community lacked an equivalent. Additionally, many existing evaluations focused on easily optimizable properties (logP, QED) that could not differentiate between models, since even simple baselines achieved near-perfect scores on those tasks.</p>
<h2 id="benchmark-design-distribution-learning-and-goal-directed-optimization">Benchmark Design: Distribution Learning and Goal-Directed Optimization</h2>
<p>GuacaMol separates evaluation into two independent dimensions, reflecting the two main use cases of generative models.</p>
<h3 id="distribution-learning-benchmarks">Distribution-Learning Benchmarks</h3>
<p>These five benchmarks assess how well a model learns to generate molecules similar to a training set (a standardized subset of ChEMBL 24):</p>
<ol>
<li><strong>Validity</strong>: Fraction of generated molecules that are chemically valid (parseable by RDKit), measured over 10,000 generated samples.</li>
<li><strong>Uniqueness</strong>: Fraction of unique canonical <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> among 10,000 valid generated molecules.</li>
<li><strong>Novelty</strong>: Fraction of generated molecules not present in the training set, measured over 10,000 unique samples.</li>
<li><strong><a href="/notes/chemistry/molecular-design/generation/evaluation/frechet-chemnet-distance/">Fréchet ChemNet Distance</a> (FCD)</strong>: Measures distributional similarity between generated and reference molecules using hidden representations from ChemNet (trained on biological activity prediction). The FCD score is transformed as:</li>
</ol>
<p>$$S = \exp(-0.2 \cdot \text{FCD})$$</p>
<ol start="5">
<li><strong>KL Divergence</strong>: Compares distributions of nine physicochemical descriptors (BertzCT, MolLogP, MolWt, TPSA, NumHAcceptors, NumHDonors, NumRotatableBonds, NumAliphaticRings, NumAromaticRings) plus maximum nearest-neighbor ECFP4 similarity. The final score aggregates per-descriptor KL divergences:</li>
</ol>
<p>$$S = \frac{1}{k} \sum_{i=1}^{k} \exp(-D_{\text{KL}, i})$$</p>
<p>where $k = 9$ is the number of descriptors.</p>
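<p>The two score transformations above can be sketched in a few lines of Python. This is a minimal illustration, not the GuacaMol code: the histogram-based <code>kl_divergence</code> helper, its smoothing constant <code>eps</code>, and the direction of the divergence are assumptions.</p>

```python
import math


def fcd_score(fcd):
    """Transform a raw Frechet ChemNet Distance into a score in (0, 1]."""
    return math.exp(-0.2 * fcd)


def kl_divergence(p, q, eps=1e-10):
    """Discrete D_KL(P || Q) for two histograms over the same bins.

    Both histograms are normalized first; eps is a hypothetical smoothing
    term that avoids log(0) and division by zero.
    """
    p_total = sum(p)
    q_total = sum(q)
    divergence = 0.0
    for p_i, q_i in zip(p, q):
        p_i = p_i / p_total + eps
        q_i = q_i / q_total + eps
        divergence += p_i * math.log(p_i / q_i)
    return divergence


def kl_score(divergences):
    """Average exp(-D_KL) over the k per-descriptor divergences."""
    return sum(math.exp(-d) for d in divergences) / len(divergences)
```

<p>Both transformations map a distance of zero to a perfect score of 1 and decay toward 0 as the generated distribution drifts from the reference.</p>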
<h3 id="goal-directed-benchmarks">Goal-Directed Benchmarks</h3>
<p>The 20 goal-directed benchmarks evaluate a model&rsquo;s ability to generate molecules that maximize a given scoring function. These span several categories:</p>
<ul>
<li><strong>Rediscovery</strong> (3 tasks): Regenerate a specific target molecule (Celecoxib, Troglitazone, Thiothixene) using Tanimoto similarity on ECFP4 fingerprints.</li>
<li><strong>Similarity</strong> (3 tasks): Generate many molecules similar to a target (Aripiprazole, Albuterol, Mestranol) above a threshold of 0.75.</li>
<li><strong>Isomers</strong> (2 tasks): Generate molecules matching a target molecular formula ($\text{C}_{11}\text{H}_{24}$ and $\text{C}_9\text{H}_{10}\text{N}_2\text{O}_2\text{PF}_2\text{Cl}$).</li>
<li><strong>Median molecules</strong> (2 tasks): Maximize similarity to two reference molecules simultaneously (camphor/menthol and tadalafil/sildenafil).</li>
<li><strong>Multi-property optimization</strong> (7 tasks): Optimize combinations of similarity, physicochemical properties, and structural features for drug-relevant molecules (Osimertinib, Fexofenadine, Ranolazine, Perindopril, Amlodipine, Sitagliptin, Zaleplon).</li>
<li><strong>SMARTS-based</strong> (1 task): Target molecules containing specific substructure patterns with constrained physicochemical properties (Valsartan SMARTS).</li>
<li><strong>Scaffold/decorator hop</strong> (2 tasks): Modify molecular scaffolds while preserving substituent patterns, or vice versa.</li>
</ul>
<p>The benchmark score for most goal-directed tasks combines top-1, top-10, and top-100 molecule scores:</p>
<p>$$S = \frac{1}{3}\left(s_1 + \frac{1}{10}\sum_{i=1}^{10} s_i + \frac{1}{100}\sum_{i=1}^{100} s_i\right)$$</p>
<p>where $s_i$ are molecule scores sorted in decreasing order.</p>
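<p>As a worked illustration of the formula above, the following Python sketch averages the top-1, top-10, and top-100 means. The helper names are hypothetical, and the sketch assumes at least 100 scored molecules are available.</p>

```python
def top_k_mean(scores, k):
    """Mean of the k highest scores."""
    ranked = sorted(scores, reverse=True)[:k]
    if len(ranked) != k:
        raise ValueError("need at least k scores")
    return sum(ranked) / k


def goal_directed_score(scores):
    """Average of the top-1, top-10, and top-100 mean scores."""
    return (top_k_mean(scores, 1) + top_k_mean(scores, 10) + top_k_mean(scores, 100)) / 3.0
```

<p>Under this scoring, a model that finds a single perfect molecule and nothing else useful earns (1 + 0.1 + 0.01) / 3 = 0.37, so a high score requires many good molecules, not just one.</p>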
<h3 id="score-modifiers">Score Modifiers</h3>
<p>Raw molecular properties are transformed via modifier functions to restrict scores to [0, 1]:</p>
<ul>
<li><strong>Gaussian($\mu$, $\sigma$)</strong>: Targets a specific property value</li>
<li><strong>MinGaussian($\mu$, $\sigma$)</strong>: Full score below $\mu$, decreasing above</li>
<li><strong>MaxGaussian($\mu$, $\sigma$)</strong>: Full score above $\mu$, decreasing below</li>
<li><strong>Thresholded($t$)</strong>: Full score above threshold $t$, linear decrease below</li>
</ul>
<p>Multi-property objectives use either arithmetic or geometric means to combine individual scores.</p>
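<p>The four modifiers listed above can be sketched as plain Python functions. This is an illustrative sketch under stated assumptions (an unnormalized Gaussian with peak value 1, and a positive threshold <code>t</code>); the function names are hypothetical and do not mirror the GuacaMol class names.</p>

```python
import math


def gaussian(x, mu, sigma):
    """Score peaks at 1 when x equals mu and decays on both sides."""
    return math.exp(-0.5 * ((x - mu) / sigma) ** 2)


def min_gaussian(x, mu, sigma):
    """Full score at or below mu, Gaussian decay above it."""
    return 1.0 if x <= mu else gaussian(x, mu, sigma)


def max_gaussian(x, mu, sigma):
    """Full score at or above mu, Gaussian decay below it."""
    return 1.0 if x >= mu else gaussian(x, mu, sigma)


def thresholded(x, t):
    """Full score at or above t, linear decrease below (assumes t is positive)."""
    return min(x, t) / t
```

<p>For example, <code>min_gaussian</code> suits properties that should stay small (such as a count of undesirable groups), while <code>max_gaussian</code> suits properties that should exceed a target value.</p>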
<h2 id="baseline-models-and-experimental-setup">Baseline Models and Experimental Setup</h2>
<p>The authors evaluate baseline models spanning different paradigms, six for distribution learning and five for goal-directed optimization:</p>
<p><strong>Distribution-learning baselines:</strong></p>
<ul>
<li><strong>Random sampler</strong>: Samples molecules directly from the dataset, giving an upper reference for KL divergence and FCD and a lower bound of zero for novelty.</li>
<li><strong>SMILES LSTM</strong>: 3-layer LSTM (hidden size 1024) trained to predict next SMILES characters.</li>
<li><strong>Graph MCTS</strong>: Monte Carlo Tree Search building molecules atom-by-atom.</li>
<li><strong>VAE</strong>: Variational autoencoder on SMILES representations.</li>
<li><strong>AAE</strong>: Adversarial autoencoder.</li>
<li><strong><a href="/notes/chemistry/molecular-design/generation/rl-tuned/organ-objective-reinforced-gan/">ORGAN</a></strong>: Objective-reinforced generative adversarial network.</li>
</ul>
<p><strong>Goal-directed baselines:</strong></p>
<ul>
<li><strong>Best of dataset</strong>: Scores all training molecules and returns the best (virtual screening baseline).</li>
<li><strong>SMILES LSTM</strong>: Same model with 20 iterations of hill-climbing (8192 samples per iteration, top 1024 for fine-tuning).</li>
<li><strong>SMILES GA</strong>: Genetic algorithm operating on SMILES strings with grammar-based mutations.</li>
<li><strong>Graph GA</strong>: Genetic algorithm operating on molecular graphs with crossover and mutation.</li>
<li><strong>Graph MCTS</strong>: Monte Carlo Tree Search with 40 simulations per molecule.</li>
</ul>
<p>The training dataset is ChEMBL 24 after filtering: salt removal, charge neutralization, SMILES length cap of 100, element restrictions, and removal of molecules similar (ECFP4 Tanimoto similarity &gt; 0.323) to 10 held-out drug molecules used in benchmarks.</p>
<h3 id="distribution-learning-results">Distribution-Learning Results</h3>
<table>
  <thead>
      <tr>
          <th>Benchmark</th>
          <th>Random</th>
          <th>SMILES LSTM</th>
          <th>Graph MCTS</th>
          <th>AAE</th>
          <th>ORGAN</th>
          <th>VAE</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Validity</td>
          <td>1.000</td>
          <td>0.959</td>
          <td>1.000</td>
          <td>0.822</td>
          <td>0.379</td>
          <td>0.870</td>
      </tr>
      <tr>
          <td>Uniqueness</td>
          <td>0.997</td>
          <td>1.000</td>
          <td>1.000</td>
          <td>1.000</td>
          <td>0.841</td>
          <td>0.999</td>
      </tr>
      <tr>
          <td>Novelty</td>
          <td>0.000</td>
          <td>0.912</td>
          <td>0.994</td>
          <td>0.998</td>
          <td>0.687</td>
          <td>0.974</td>
      </tr>
      <tr>
          <td>KL divergence</td>
          <td>0.998</td>
          <td>0.991</td>
          <td>0.522</td>
          <td>0.886</td>
          <td>0.267</td>
          <td>0.982</td>
      </tr>
      <tr>
          <td>FCD</td>
          <td>0.929</td>
          <td>0.913</td>
          <td>0.015</td>
          <td>0.529</td>
          <td>0.000</td>
          <td>0.863</td>
      </tr>
  </tbody>
</table>
<h3 id="goal-directed-results-selected">Goal-Directed Results (Selected)</h3>
<table>
  <thead>
      <tr>
          <th>Benchmark</th>
          <th>Best of Dataset</th>
          <th>SMILES LSTM</th>
          <th>SMILES GA</th>
          <th>Graph GA</th>
          <th>Graph MCTS</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Celecoxib rediscovery</td>
          <td>0.505</td>
          <td>1.000</td>
          <td>0.732</td>
          <td>1.000</td>
          <td>0.355</td>
      </tr>
      <tr>
          <td>Osimertinib MPO</td>
          <td>0.839</td>
          <td>0.907</td>
          <td>0.886</td>
          <td>0.953</td>
          <td>0.784</td>
      </tr>
      <tr>
          <td>Sitagliptin MPO</td>
          <td>0.509</td>
          <td>0.545</td>
          <td>0.689</td>
          <td>0.891</td>
          <td>0.458</td>
      </tr>
      <tr>
          <td>Scaffold Hop</td>
          <td>0.738</td>
          <td>0.998</td>
          <td>0.885</td>
          <td>1.000</td>
          <td>0.478</td>
      </tr>
      <tr>
          <td><strong>Total (20 tasks)</strong></td>
          <td><strong>12.144</strong></td>
          <td><strong>17.340</strong></td>
          <td><strong>14.396</strong></td>
          <td><strong>17.983</strong></td>
          <td><strong>9.009</strong></td>
      </tr>
  </tbody>
</table>
<h2 id="key-findings-and-limitations">Key Findings and Limitations</h2>
<h3 id="main-findings">Main Findings</h3>
<p>The Graph GA achieves the highest total score across goal-directed benchmarks (17.983), followed closely by the SMILES LSTM (17.340). This result is notable because genetic algorithms are well-established methods, and the LSTM-based neural approach nearly matches their optimization performance.</p>
<p>However, compound quality tells a different story. When examining the top 100 molecules per task through chemical quality filters (SureChEMBL, Glaxo, PAINS rules), 77% of LSTM-generated molecules pass, matching the Best of Dataset baseline. In contrast, Graph GA produces only 40% passing molecules, and Graph MCTS only 22%. This suggests that neural models benefit from pre-training on real molecular distributions, which encodes implicit knowledge about what constitutes a &ldquo;reasonable&rdquo; molecule.</p>
<p><a href="/notes/chemistry/molecular-design/generation/rl-tuned/organ-objective-reinforced-gan/">ORGAN</a> performs poorly across all distribution-learning tasks, with more than half its generated molecules being invalid. This is consistent with mode collapse, a known problem in GAN training.</p>
<p>Simpler generative models (LSTM, VAE) outperform more complex architectures (ORGAN, AAE) on distribution learning. Graph MCTS struggles with both distribution learning and goal-directed optimization, suggesting that single-molecule search trees are less effective than population-based approaches.</p>
<h3 id="limitations">Limitations</h3>
<p>The authors explicitly identify several issues:</p>
<ul>
<li><strong>Compound quality is hard to quantify</strong>: The rule-based filters used are acknowledged as &ldquo;high precision, low recall&rdquo; surrogates. They catch some problematic molecules but cannot encode the full breadth of medicinal chemistry expertise.</li>
<li><strong>Some benchmarks are too easy</strong>: The trivially optimizable tasks (logP, QED, CNS MPO) cannot differentiate between models. All baselines achieve near-perfect scores on these.</li>
<li><strong>Sample efficiency and runtime are not benchmarked</strong>: The framework does not penalize models for requiring excessive scoring function calls.</li>
<li><strong>Synthesis accessibility is not addressed</strong>: Generated molecules may be valid but impractical to synthesize.</li>
</ul>
<h3 id="future-directions">Future Directions</h3>
<p>The authors call for harder benchmark tasks, better compound quality metrics, attention to sample efficiency and runtime constraints, and further development of graph-based neural generative models.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training</td>
          <td>ChEMBL 24 (post-processed)</td>
          <td>~1.6M molecules</td>
          <td>Salt removal, neutralization, SMILES length cap, element restrictions</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>10 held-out drug molecules</td>
          <td>10</td>
          <td>Removed from training set via ECFP4 similarity threshold</td>
      </tr>
      <tr>
          <td>Quality filters</td>
          <td>SureChEMBL, Glaxo, PAINS, in-house rules</td>
          <td>N/A</td>
          <td>Applied via rd_filters</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>SMILES LSTM</strong>: 3-layer LSTM, hidden size 1024; hill-climbing with 20 iterations, 8192 samples per iteration, top 1024 for fine-tuning</li>
<li><strong>Graph GA</strong>: Population of 100, mating pool of 200, crossover + mutation (probability 0.5), 1000 epochs max</li>
<li><strong>SMILES GA</strong>: Population of 300, offspring of 600, SMILES grammar-based mutations, 1000 epochs max</li>
<li><strong>Graph MCTS</strong>: 40 simulations per molecule, 25 children per step, rollout to 60 atoms, starting from CC</li>
</ul>
<h3 id="models">Models</h3>
<p>All baseline implementations are released as open-source code. VAE, AAE, and ORGAN implementations are from the <a href="/notes/chemistry/molecular-design/generation/evaluation/molecular-sets-moses/">MOSES</a> repository.</p>
<h3 id="evaluation">Evaluation</h3>
<p>All distribution-learning benchmarks sample 10,000 molecules. Goal-directed benchmarks use combinations of top-1, top-10, and top-100 scores. Compound quality is assessed via the percentage of top-100 molecules passing chemical filters.</p>
<h3 id="hardware">Hardware</h3>
<p>Hardware requirements are not specified in the paper.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/BenevolentAI/guacamol">GuacaMol</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Benchmarking framework and scoring functions</td>
      </tr>
      <tr>
          <td><a href="https://github.com/BenevolentAI/guacamol_baselines">GuacaMol Baselines</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Baseline model implementations</td>
      </tr>
      <tr>
          <td><a href="https://figshare.com/projects/GuacaMol/56639">ChEMBL dataset</a></td>
          <td>Dataset</td>
          <td>CC-BY-SA 3.0</td>
          <td>Post-processed ChEMBL 24 for benchmarks</td>
      </tr>
      <tr>
          <td><a href="https://github.com/bioinf-jku/FCD">FCD package</a></td>
          <td>Code</td>
          <td>LGPL-3.0</td>
          <td>Fréchet ChemNet Distance implementation</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Brown, N., Fiscato, M., Segler, M. H. S., &amp; Vaucher, A. C. (2019). GuacaMol: Benchmarking Models for De Novo Molecular Design. <em>Journal of Chemical Information and Modeling</em>, 59(3), 1096-1108. <a href="https://doi.org/10.1021/acs.jcim.8b00839">https://doi.org/10.1021/acs.jcim.8b00839</a></p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/BenevolentAI/guacamol">GuacaMol Python package</a></li>
<li><a href="https://github.com/BenevolentAI/guacamol_baselines">GuacaMol baselines</a></li>
<li><a href="https://figshare.com/projects/GuacaMol/56639">Post-processed ChEMBL datasets</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{brown2019guacamol,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{GuacaMol: Benchmarking Models for de Novo Molecular Design}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Brown, Nathan and Fiscato, Marco and Segler, Marwin H. S. and Vaucher, Alain C.}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Journal of Chemical Information and Modeling}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{59}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{3}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{1096--1108}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2019}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{American Chemical Society}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1021/acs.jcim.8b00839}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Graph-Based GA and MCTS Generative Model for Molecules</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/search-based/graph-based-genetic-algorithm-chemical-space/</link><pubDate>Wed, 25 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/search-based/graph-based-genetic-algorithm-chemical-space/</guid><description>Jensen introduces a graph-based genetic algorithm and generative model with MCTS that outperforms ML methods for penalized logP optimization.</description><content:encoded><![CDATA[<h2 id="a-graph-based-approach-to-molecular-optimization">A Graph-Based Approach to Molecular Optimization</h2>
<p>This is a <strong>Method</strong> paper that introduces two graph-based approaches for exploring chemical space: a genetic algorithm (GB-GA) and a generative model combined with <a href="https://en.wikipedia.org/wiki/Monte_Carlo_tree_search">Monte Carlo tree search</a> (GB-GM-MCTS). The primary contribution is demonstrating that these non-ML, graph-based methods can match or exceed the performance of contemporary ML-based generative models for molecular property optimization, while being several orders of magnitude faster. The paper provides open-source implementations built on the RDKit cheminformatics package. The two approaches explore <a href="https://en.wikipedia.org/wiki/Chemical_space">chemical space</a> using direct graph manipulations rather than string-based representations like <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a>.</p>
<h2 id="why-compare-simple-baselines-to-ml-generative-models">Why Compare Simple Baselines to ML Generative Models?</h2>
<p>By 2018, several ML-based generative models for molecules had been published, including VAEs, RNNs, and graph convolutional policy networks. However, these models were rarely compared against traditional optimization approaches such as genetic algorithms. Jensen identifies this gap explicitly: while ML generative model performance had been impressive, the lack of comparison to simpler baselines made it difficult to assess whether the complexity of ML approaches was justified.</p>
<p>A practical barrier to such comparisons was the absence of free, open-source GA implementations for molecular optimization (the existing ACSESS algorithm required proprietary OpenEye toolkits). This paper fills that gap by providing RDKit-based implementations of both the GB-GA and GB-GM-MCTS.</p>
<h2 id="graph-based-crossovers-mutations-and-monte-carlo-tree-search">Graph-Based Crossovers, Mutations, and Monte Carlo Tree Search</h2>
<h3 id="gb-ga-crossovers-and-mutations-on-molecular-graphs">GB-GA: Crossovers and Mutations on Molecular Graphs</h3>
<p>The GB-GA operates directly on molecular graph representations (not string representations like SMILES). It combines ideas from Brown et al. (2004) and the ACSESS algorithm of Virshup et al. (2013).</p>
<p><strong>Crossovers</strong> can occur at two types of positions with equal probability:</p>
<ul>
<li>Non-ring bonds: a molecule is cut at a non-ring bond, and fragments from two parent molecules are recombined</li>
<li>Ring bonds: adjacent bonds or bonds separated by one bond are cut, and fragments are mated using single or double bonds</li>
</ul>
<p><strong>Mutations</strong> include seven operation types, each with specified probabilities:</p>
<ul>
<li>Append atom (15%): adds an atom with a single, double, or triple bond</li>
<li>Insert atom (15%): inserts an atom into an existing bond</li>
<li>Delete atom (14%): removes an atom, reconnecting neighbors</li>
<li>Change atom type (14%): swaps element identity (C, N, O, F, S, Cl, Br)</li>
<li>Change bond order (14%): toggles between single, double, and triple bonds</li>
<li>Delete ring bond (14%): opens a ring</li>
<li>Add ring bond (14%): closes a new ring</li>
</ul>
<p>Molecules with macrocycles (rings of seven or more atoms), allene centers in rings, fewer than five heavy atoms, incorrect valences, or more non-H atoms than the target size are discarded. The target size is sampled from a normal distribution with mean 39.15 and standard deviation 3.50 non-H atoms, calibrated to match the molecules found by Yang et al. (2017).</p>
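<p>The mutation probabilities and the target-size sampling described above can be sketched as follows. This is a minimal illustration and not Jensen&rsquo;s implementation: the operation names are hypothetical labels, and whether the sampled target size is rounded is left open here.</p>

```python
import random

# Probabilities of the seven mutation types, as listed in the text.
MUTATION_WEIGHTS = {
    "append_atom": 0.15,
    "insert_atom": 0.15,
    "delete_atom": 0.14,
    "change_atom_type": 0.14,
    "change_bond_order": 0.14,
    "delete_ring_bond": 0.14,
    "add_ring_bond": 0.14,
}


def choose_mutation(rng):
    """Draw one mutation type according to the listed probabilities."""
    names = list(MUTATION_WEIGHTS)
    weights = list(MUTATION_WEIGHTS.values())
    return rng.choices(names, weights=weights, k=1)[0]


def sample_target_size(rng, mean=39.15, std=3.50):
    """Draw a target size in non-H atoms from the stated normal distribution."""
    return rng.gauss(mean, std)


def exceeds_target(n_heavy_atoms, target_size):
    """True when a molecule has more non-H atoms than the target size."""
    return n_heavy_atoms > target_size
```

<p>Passing a seeded <code>random.Random</code> instance keeps runs reproducible, for example <code>choose_mutation(random.Random(0))</code>.</p>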
<h3 id="gb-gm-mcts-a-probabilistic-growth-model-with-tree-search">GB-GM-MCTS: A Probabilistic Growth Model with Tree Search</h3>
<p>The GB-GM grows molecules one atom at a time, with the choice of bond order and atom type determined probabilistically from a bonding analysis of a reference dataset (the first 1000 molecules from ZINC). Since 63% of atoms in the reference set are ring atoms, ring-creation or ring-insertion mutations are chosen 63% of the time.</p>
<p>The generative model is combined with a <a href="https://en.wikipedia.org/wiki/Monte_Carlo_tree_search">Monte Carlo tree search</a> where:</p>
<ul>
<li>Each node corresponds to an atom addition step</li>
<li>Leaf parallelization uses a maximum of 25 leaf nodes</li>
<li>The exploration factor is $1 / \sqrt{2}$</li>
<li>Rollout terminates if the molecule exceeds the target size</li>
<li>The reward function returns 1 if the predicted $J(\mathbf{m})$ value exceeds the largest value found so far, and 0 otherwise</li>
</ul>
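<p>The selection rule and reward can be illustrated as follows. The exploration factor and the binary reward come from the list above; the exact bonus term $\sqrt{2 \ln N / n}$ is an assumption (the common UCT form), and the dictionary layout for child nodes is hypothetical.</p>

```python
import math

EXPLORATION = 1 / math.sqrt(2)  # exploration factor from the text


def uct_score(total_reward: float, visits: int, parent_visits: int,
              c: float = EXPLORATION) -> float:
    """Mean reward plus an exploration bonus.

    The bonus form sqrt(2 ln N / n) is an assumption (the common UCT choice).
    """
    exploit = total_reward / visits
    explore = math.sqrt(2.0 * math.log(parent_visits) / visits)
    return exploit + c * explore


def binary_reward(score: float, best_so_far: float) -> int:
    """Return 1 if the rollout beats the best J(m) seen so far, else 0."""
    return 1 if score > best_so_far else 0


def select_child(children: list[dict], parent_visits: int) -> dict:
    """Pick the child with the highest UCT score.

    Each child is a dict with 'reward' and 'visits' keys (hypothetical layout).
    """
    return max(
        children,
        key=lambda ch: uct_score(ch["reward"], ch["visits"], parent_visits),
    )
```

<p>With a binary reward, the exploitation term is simply the fraction of rollouts through a node that set a new best score, so rarely visited nodes are favored until they have been sampled a few times.</p>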
<h3 id="the-penalized-logp-objective">The Penalized logP Objective</h3>
<p>Both methods optimize the penalized logP score $J(\mathbf{m})$:</p>
<p>$$
J(\mathbf{m}) = \log P(\mathbf{m}) - \text{SA}(\mathbf{m}) - \text{RingPenalty}(\mathbf{m})
$$</p>
<p>where $\log P(\mathbf{m})$ is the <a href="https://en.wikipedia.org/wiki/Partition_coefficient">octanol-water partition coefficient</a> predicted by RDKit, $\text{SA}(\mathbf{m})$ is a synthetic accessibility score, and $\text{RingPenalty}(\mathbf{m})$ penalizes unrealistically large rings by reducing the score by $\text{RingSize} - 6$ for each oversized ring. Each property is normalized to zero mean and unit standard deviation across the ZINC dataset.</p>
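<p>As a sketch of how the three terms combine, the snippet below assembles $J(\mathbf{m})$ from precomputed property values. It deliberately avoids any toolkit calls: the raw logP, SA score, and ring sizes are assumed to be computed elsewhere, and the <code>Stats</code> class and its mean and standard deviation values are placeholders for the ZINC statistics, which are not given here.</p>

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Stats:
    """Mean and standard deviation of one property over a reference set."""
    mean: float
    std: float

    def normalize(self, value: float) -> float:
        return (value - self.mean) / self.std


def ring_penalty(ring_sizes: list[int]) -> int:
    """Sum RingSize - 6 over rings larger than six atoms, as described above."""
    return sum(size - 6 for size in ring_sizes if size > 6)


def penalized_logp(logp: float, sa: float, ring_sizes: list[int],
                   logp_stats: Stats, sa_stats: Stats, ring_stats: Stats) -> float:
    """J(m) = normalized logP - normalized SA - normalized ring penalty."""
    return (
        logp_stats.normalize(logp)
        - sa_stats.normalize(sa)
        - ring_stats.normalize(ring_penalty(ring_sizes))
    )


# Identity statistics (mean 0, std 1) are placeholders, not the ZINC values.
identity = Stats(mean=0.0, std=1.0)
print(penalized_logp(3.0, 2.0, [6, 8], identity, identity, identity))
```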
<h2 id="experimental-setup-and-comparisons-to-ml-methods">Experimental Setup and Comparisons to ML Methods</h2>
<h3 id="gb-ga-experiments">GB-GA Experiments</h3>
<p>Ten GA simulations were performed with a population size of 20 over 50 generations (1000 $J(\mathbf{m})$ evaluations per run). The initial mating pool was 20 random molecules from the first 1000 molecules in ZINC. Two mutation rates were tested: 50% and 1%.</p>
<h3 id="gb-gm-mcts-experiments">GB-GM-MCTS Experiments</h3>
<p>Ten simulations used ethane as a seed molecule with 1000 tree traversals per run. Additional experiments used 5000 traversals and an adjusted probability of generating $\text{C}=\text{C}-\text{C}$ ring patterns (increased from 62% to 80%).</p>
<h3 id="baselines">Baselines</h3>
<p>Results were compared to those compiled by Yang et al. (2017):</p>
<ul>
<li>ChemTS (RNN + MCTS)</li>
<li>RNN with and without Bayesian optimization</li>
<li><a href="/notes/chemistry/molecular-design/generation/latent-space/automatic-chemical-design-vae/">Continuous VAE (CVAE)</a></li>
<li><a href="/notes/chemistry/molecular-design/generation/latent-space/grammar-variational-autoencoder/">Grammar VAE (GVAE)</a></li>
<li>Graph convolutional policy network (GCPN, from You et al. 2018)</li>
</ul>
<h3 id="key-results">Key Results</h3>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>Average $J(\mathbf{m})$</th>
          <th>Molecules Evaluated</th>
          <th>CPU Time</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>GB-GA (50% mutation)</td>
          <td>6.8 +/- 0.7</td>
          <td>1000</td>
          <td>30 seconds</td>
      </tr>
      <tr>
          <td>GB-GA (1% mutation)</td>
          <td>7.4 +/- 0.9</td>
          <td>1000</td>
          <td>30 seconds</td>
      </tr>
      <tr>
          <td>GB-GM-MCTS (62%)</td>
          <td>2.6 +/- 0.6</td>
          <td>1000</td>
          <td>90 seconds</td>
      </tr>
      <tr>
          <td>GB-GM-MCTS (80%)</td>
          <td>3.4 +/- 0.6</td>
          <td>1000</td>
          <td>90 seconds</td>
      </tr>
      <tr>
          <td>GB-GM-MCTS (80%)</td>
          <td>4.3 +/- 0.6</td>
          <td>5000</td>
          <td>9 minutes</td>
      </tr>
      <tr>
          <td>ChemTS</td>
          <td>4.9 +/- 0.5</td>
          <td>~5000</td>
          <td>2 hours</td>
      </tr>
      <tr>
          <td>ChemTS</td>
          <td>5.6 +/- 0.5</td>
          <td>~20000</td>
          <td>8 hours</td>
      </tr>
      <tr>
          <td>RNN + BO</td>
          <td>4.5 +/- 0.2</td>
          <td>~4000</td>
          <td>8 hours</td>
      </tr>
      <tr>
          <td>Only RNN</td>
          <td>4.8 +/- 0.2</td>
          <td>~20000</td>
          <td>8 hours</td>
      </tr>
      <tr>
          <td>CVAE + BO</td>
          <td>0.0 +/- 0.9</td>
          <td>~100</td>
          <td>8 hours</td>
      </tr>
      <tr>
          <td>GVAE + BO</td>
          <td>0.2 +/- 1.3</td>
          <td>~1000</td>
          <td>8 hours</td>
      </tr>
  </tbody>
</table>
<p>The GB-GA with a 1% mutation rate achieved an average maximum $J(\mathbf{m})$ of 7.4, which is 1.8 units higher than the best ML result (ChemTS at 5.6), while using 20x fewer evaluations and completing in 30 seconds versus 8 hours. The two highest-scoring individual molecules found by GB-GA had $J(\mathbf{m})$ scores of 8.8 and 8.5, exceeding the 7.8-8.0 range found by the GCPN approach. These molecules bore little resemblance to the initial mating pool (<a href="https://en.wikipedia.org/wiki/Jaccard_index">Tanimoto similarities</a> of 0.27 and 0.12 to the most similar ZINC molecules), indicating that the GA traversed a large distance in chemical space in just 50 generations.</p>
<p>The GB-GM-MCTS performed below ChemTS at equal evaluations (4.3 vs. 4.9 at 5000 evaluations) but was roughly 13x faster (9 minutes vs. 2 hours). The MCTS approach tended to extract the dominant hydrophobic structural motif (benzene rings) from the training set, making it more dependent on training set composition than the GA.</p>
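<p>For reference, Tanimoto similarity between two binary fingerprints is the size of the intersection of their on-bits divided by the size of the union. The sketch below works on plain Python sets of bit indices; how the fingerprints are generated is outside its scope, and the helper names are hypothetical.</p>

```python
def tanimoto(bits_a: set[int], bits_b: set[int]) -> float:
    """Tanimoto (Jaccard) similarity between two sets of on-bit indices."""
    union = bits_a.union(bits_b)
    if not union:
        return 0.0  # convention chosen here for two empty fingerprints
    return len(bits_a.intersection(bits_b)) / len(union)


def max_similarity(query: set[int], reference: list[set[int]]) -> float:
    """Similarity of a query to its nearest neighbor in a reference set."""
    return max(tanimoto(query, ref) for ref in reference)


print(tanimoto({1, 2, 3}, {2, 3, 4}))  # 2 shared bits out of 4 distinct bits
```

<p>A nearest-neighbor value such as 0.12 therefore means that even the closest reference molecule shares only a small fraction of fingerprint bits with the query.</p>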
<h2 id="simple-methods-set-a-high-bar-for-molecular-optimization">Simple Methods Set a High Bar for Molecular Optimization</h2>
<p>The central finding is that a simple graph-based genetic algorithm outperforms all tested ML-based generative models on penalized logP optimization in both solution quality and computational efficiency. The GB-GA achieves higher $J(\mathbf{m})$ scores with 1000 evaluations in 30 seconds than ML methods achieve with 20,000 evaluations over 8 hours.</p>
<p>Several additional observations emerge:</p>
<ol>
<li><strong>Chemical space traversal</strong>: The GB-GA can reach high-scoring molecules that are structurally distant from the starting population, with Tanimoto similarity as low as 0.12 to the nearest ZINC molecule.</li>
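<li><strong>Mutation rate matters</strong>: A 1% mutation rate outperformed a 50% rate (7.4 vs. 6.8), suggesting that preserving more of the parental structure inherited through crossover is beneficial for this objective. The 0.6 difference is smaller than either reported standard deviation (0.9 and 0.7), so the effect is suggestive rather than conclusive.</li>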
<li><strong>Training set dependence</strong>: The GB-GM-MCTS is more sensitive to training set composition than the GA. Its preference for benzene-ring-containing molecules (the dominant ZINC motif) limits its ability to discover alternative structural solutions like the long aliphatic chains favored by the GA.</li>
<li><strong>Generalizability caveat</strong>: Jensen explicitly notes that these comparisons cover only one property (penalized logP) and that similar comparisons for other properties are needed before drawing general conclusions.</li>
</ol>
<p>The paper&rsquo;s influence has been substantial: it helped establish the expectation that new molecular generative models should be benchmarked against genetic algorithm baselines, a position subsequently reinforced by Brown et al. (2019) in <a href="/notes/chemistry/molecular-design/generation/evaluation/guacamol-benchmarking-de-novo-molecular-design/">GuacaMol</a> and by <a href="/notes/chemistry/molecular-design/generation/search-based/genetic-algorithms-molecule-generation-baselines/">Tripp and Hernandez-Lobato (2023)</a>.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Initial mating pool / reference set</td>
          <td><a href="/notes/chemistry/datasets/zinc-22/">ZINC</a> (subset)</td>
          <td>First 1000 molecules</td>
          <td>Same subset used in previous studies (Gomez-Bombarelli et al., Yang et al.)</td>
      </tr>
      <tr>
          <td>Target molecule size</td>
          <td>Derived from Yang et al. results</td>
          <td>20 molecules</td>
          <td>Mean 39.15, SD 3.50 non-H atoms</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>GB-GA</strong>: Population size 20, 50 generations, mutation rates of 1% and 50% tested. Crossovers at ring and non-ring bonds with equal probability. Seven mutation types with specified probabilities. Parents are selected from the mating pool based on normalized penalized logP scores.</li>
<li><strong>GB-GM</strong>: Atom-by-atom growth using probabilistic rules derived from ZINC bonding analysis. Ring creation or ring insertion is chosen 63% of the time (matching the ring-atom fraction in ZINC). The probability of generating $\text{C}=\text{C}-\text{C}$ ring patterns is 62% by default, with an 80% variant also tested. Seed molecule: ethane.</li>
<li><strong>MCTS</strong>: Modified from haroldsultan/MCTS Python implementation. Leaf parallelization with max 25 leaf nodes. Exploration factor $1/\sqrt{2}$. Binary reward function (1 if new best, 0 otherwise).</li>
<li><strong>Property calculation</strong>: logP, SA score, and ring penalty all computed via RDKit. Each property normalized to zero mean and unit standard deviation across ZINC.</li>
</ul>
<h3 id="models">Models</h3>
<p>No neural network models are used. The GB-GA and GB-GM are purely algorithmic approaches parameterized by bonding statistics from the ZINC dataset.</p>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>GB-GA (1%)</th>
          <th>Best ML (ChemTS)</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Average max $J(\mathbf{m})$</td>
          <td>7.4 +/- 0.9</td>
          <td>5.6 +/- 0.5</td>
          <td>Over 10 runs</td>
      </tr>
      <tr>
          <td>Single best $J(\mathbf{m})$</td>
          <td>8.8</td>
          <td>~8.0 (GCPN)</td>
          <td>GB-GA vs. You et al.</td>
      </tr>
      <tr>
          <td>Evaluations per run</td>
          <td>1000</td>
          <td>~20,000</td>
          <td>20x fewer for GB-GA</td>
      </tr>
      <tr>
          <td>CPU time per run</td>
          <td>30 seconds</td>
          <td>8 hours</td>
          <td>~960x faster</td>
      </tr>
  </tbody>
</table>
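<p>The ratios quoted in the table follow directly from the reported numbers. A quick arithmetic check, using only values from the table:</p>

```python
SECONDS_PER_HOUR = 3600

ga_seconds = 30                      # GB-GA CPU time per run
chemts_seconds = 8 * SECONDS_PER_HOUR  # ChemTS CPU time per run

speedup = chemts_seconds / ga_seconds
evaluation_ratio = 20_000 / 1_000
score_gap = round(7.4 - 5.6, 1)

print(speedup)           # 960.0
print(evaluation_ratio)  # 20.0
print(score_gap)         # 1.8
```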
<h3 id="hardware">Hardware</h3>
<p>All GB-GA and GB-GM experiments were run on a laptop. No GPU required. The GB-GA completes in 30 seconds per run and the GB-GM-MCTS in 90 seconds (1000 traversals) to 9 minutes (5000 traversals).</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/jensengroup/GB-GA/tree/v0.0">GB-GA (v0.0)</a></td>
          <td>Code</td>
          <td>Not specified</td>
          <td>Graph-based genetic algorithm, RDKit dependency only</td>
      </tr>
      <tr>
          <td><a href="https://github.com/jensengroup/GB-GM/tree/v0.0">GB-GM (v0.0)</a></td>
          <td>Code</td>
          <td>Not specified</td>
          <td>Graph-based generative model + MCTS, RDKit dependency only</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Jensen, J. H. (2019). A graph-based genetic algorithm and generative model/Monte Carlo tree search for the exploration of chemical space. <em>Chemical Science</em>, 10(12), 3567-3572. <a href="https://doi.org/10.1039/c8sc05372c">https://doi.org/10.1039/c8sc05372c</a></p>
<p><strong>Publication</strong>: Chemical Science (Royal Society of Chemistry), 2019</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/jensengroup/GB-GA">GB-GA Code (GitHub)</a></li>
<li><a href="https://github.com/jensengroup/GB-GM">GB-GM Code (GitHub)</a></li>
</ul>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{jensen2019graph,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{A graph-based genetic algorithm and generative model/Monte Carlo tree search for the exploration of chemical space}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Jensen, Jan H.}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Chemical Science}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{10}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{12}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{3567--3572}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2019}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Royal Society of Chemistry}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1039/c8sc05372c}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Failure Modes in Molecule Generation &amp; Optimization</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/evaluation/failure-modes-molecule-generation/</link><pubDate>Wed, 25 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/evaluation/failure-modes-molecule-generation/</guid><description>Renz et al. show trivial models fool distribution-learning metrics and ML scoring functions introduce exploitable biases in goal-directed molecule generation.</description><content:encoded><![CDATA[<h2 id="an-empirical-critique-of-molecular-generation-evaluation">An Empirical Critique of Molecular Generation Evaluation</h2>
<p>This is an <strong>Empirical</strong> paper that critically examines evaluation practices for molecular generative models. Rather than proposing a new generative method, the paper exposes systematic weaknesses in both distribution-learning metrics and goal-directed optimization scoring functions. The primary contributions are: (1) demonstrating that a trivially simple &ldquo;AddCarbon&rdquo; model can achieve near-perfect scores on widely used distribution-learning benchmarks, and (2) introducing an experimental framework with optimization scores and control scores that reveals model-specific and data-specific biases when ML models serve as scoring functions for goal-directed generation.</p>
<h2 id="evaluation-gaps-in-de-novo-molecular-design">Evaluation Gaps in De Novo Molecular Design</h2>
<p>The rapid growth of deep learning methods for molecular generation (RNN-based SMILES generators, VAEs, GANs, graph neural networks) created a need for standardized evaluation. Benchmarking suites like <a href="/notes/chemistry/molecular-design/generation/evaluation/guacamol-benchmarking-de-novo-molecular-design/">GuacaMol</a> and <a href="/notes/chemistry/molecular-design/generation/evaluation/molecular-sets-moses/">MOSES</a> introduced metrics for validity, uniqueness, novelty, KL divergence over molecular properties, and <a href="/notes/chemistry/molecular-design/generation/evaluation/frechet-chemnet-distance/">Frechet ChemNet Distance (FCD)</a>. For goal-directed generation, penalized logP became a common optimization target.</p>
<p>However, these metrics leave significant blind spots. Distribution-learning metrics do not detect whether a model merely copies training molecules with minimal modifications. Goal-directed benchmarks often use scoring functions that fail to capture the full requirements of drug discovery (synthetic feasibility, drug-likeness, absence of reactive substructures). When ML models serve as scoring functions, the problem worsens because generated molecules can exploit artifacts of the learned model rather than exhibiting genuinely desirable properties.</p>
<p>At the time of writing, wet-lab validations of generative models remained scarce, with only a handful of studies (Merk et al., Zhavoronkov et al.) demonstrating in vitro activity for generated compounds. The lack of rigorous evaluation left the field unable to distinguish meaningfully innovative methods from those that simply exploit metric weaknesses.</p>
<h2 id="the-copy-problem-and-control-score-framework">The Copy Problem and Control Score Framework</h2>
<p>The paper introduces two key conceptual contributions.</p>
<h3 id="the-addcarbon-model-for-distribution-learning">The AddCarbon Model for Distribution-Learning</h3>
<p>The AddCarbon model is deliberately trivial: it samples a molecule from the training set, inserts a single carbon atom at a random position in its SMILES string, and returns the result if it produces a valid, novel molecule. This model achieves near-perfect scores across most <a href="/notes/chemistry/molecular-design/generation/evaluation/guacamol-benchmarking-de-novo-molecular-design/">GuacaMol</a> distribution-learning benchmarks:</p>
<table>
  <thead>
      <tr>
          <th>Benchmark</th>
          <th>RS</th>
          <th>LSTM</th>
          <th>GraphMCTS</th>
          <th>AAE</th>
          <th>ORGAN</th>
          <th>VAE</th>
          <th>AddCarbon</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Validity</td>
          <td>1.000</td>
          <td>0.959</td>
          <td>1.000</td>
          <td>0.822</td>
          <td>0.379</td>
          <td>0.870</td>
          <td>1.000</td>
      </tr>
      <tr>
          <td>Uniqueness</td>
          <td>0.997</td>
          <td>1.000</td>
          <td>1.000</td>
          <td>1.000</td>
          <td>0.841</td>
          <td>0.999</td>
          <td>0.999</td>
      </tr>
      <tr>
          <td>Novelty</td>
          <td>0.000</td>
          <td>0.912</td>
          <td>0.994</td>
          <td>0.998</td>
          <td>0.687</td>
          <td>0.974</td>
          <td>1.000</td>
      </tr>
      <tr>
          <td>KL divergence</td>
          <td>0.998</td>
          <td>0.991</td>
          <td>0.522</td>
          <td>0.886</td>
          <td>0.267</td>
          <td>0.982</td>
          <td>0.982</td>
      </tr>
      <tr>
          <td>FCD</td>
          <td>0.929</td>
          <td>0.913</td>
          <td>0.015</td>
          <td>0.529</td>
          <td>0.000</td>
          <td>0.863</td>
          <td>0.871</td>
      </tr>
  </tbody>
</table>
<p>On the FCD metric, the AddCarbon model beats every generative baseline except the LSTM, despite being practically useless. The random sampler (RS), which returns training molecules unchanged, also scores higher on FCD but has zero novelty. This exposes what the authors call the &ldquo;copy problem&rdquo;: current metrics check only for exact matches to training molecules, so minimal edits evade novelty detection. The authors argue that likelihood-based evaluation on hold-out test sets, analogous to standard practice in NLP, would provide a more comprehensive metric.</p>
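<p>The AddCarbon procedure is simple enough to sketch in full. This is an illustration under stated assumptions, not the authors' code: <code>is_valid</code> is a hypothetical callable that would wrap a SMILES parser from a cheminformatics toolkit, and novelty is checked here by raw string comparison, whereas a faithful version would compare canonical SMILES.</p>

```python
import random
from typing import Callable, Optional


def add_carbon(
    training_smiles: list[str],
    is_valid: Callable[[str], bool],
    rng: random.Random,
    max_tries: int = 1000,
) -> Optional[str]:
    """Insert one 'C' at a random position of a random training SMILES.

    Returns the first candidate that is valid and not in the training set,
    or None if max_tries (a hypothetical safety limit) is exhausted.
    """
    known = set(training_smiles)
    for _ in range(max_tries):
        parent = rng.choice(training_smiles)
        pos = rng.randint(0, len(parent))  # inclusive bounds: may append at the end
        candidate = parent[:pos] + "C" + parent[pos:]
        if is_valid(candidate) and candidate not in known:
            return candidate
    return None


# Toy run with a permissive validity check (an assumption for illustration only).
print(add_carbon(["CCO", "c1ccccc1"], lambda s: True, random.Random(0)))
```

<p>Because each output differs from a training molecule by a single atom, exact-match novelty checks pass while the property distribution stays almost identical to the training set.</p>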
<h3 id="control-scores-for-goal-directed-generation">Control Scores for Goal-Directed Generation</h3>
<p>For goal-directed generation, the authors introduce a three-score experimental design:</p>
<ul>
<li><strong>Optimization Score (OS)</strong>: Output of a classifier trained on data split 1, used to guide the molecular optimizer.</li>
<li><strong>Model Control Score (MCS)</strong>: Output of a second classifier trained on split 1 with a different random seed. Divergence between OS and MCS quantifies model-specific biases.</li>
<li><strong>Data Control Score (DCS)</strong>: Output of a classifier trained on data split 2. Divergence between OS and DCS quantifies data-specific biases.</li>
</ul>
<p>This mirrors the training/test split paradigm in supervised learning. If a generator truly produces molecules with the desired bioactivity, the control scores should track the optimization score. Divergence between them indicates the optimizer is exploiting artifacts of the specific model or training data rather than learning generalizable chemical properties.</p>
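<p>The three-score design can be expressed independently of any particular classifier. In the sketch below, <code>train</code> is a hypothetical callable that fits a model on a data split with a given random seed and returns a scoring function; the seed values 0 and 1 are arbitrary placeholders, and the gap definitions (mean OS minus mean control score) are one simple way to quantify divergence, not necessarily the authors' measure.</p>

```python
from typing import Callable, Sequence

Scorer = Callable[[str], float]
Trainer = Callable[[Sequence, int], Scorer]  # (data, random seed) -> scoring function


def build_scores(train: Trainer, split1: Sequence, split2: Sequence) -> dict[str, Scorer]:
    """Train the three scoring functions of the control-score design."""
    return {
        "OS": train(split1, 0),   # guides the optimizer
        "MCS": train(split1, 1),  # same data, different seed
        "DCS": train(split2, 0),  # different data
    }


def score_gaps(scores: dict[str, Scorer], molecules: Sequence[str]) -> dict[str, float]:
    """Mean OS minus mean control score over a batch of generated molecules."""
    def mean(name: str) -> float:
        return sum(scores[name](m) for m in molecules) / len(molecules)

    os_mean = mean("OS")
    return {
        "model_gap": os_mean - mean("MCS"),
        "data_gap": os_mean - mean("DCS"),
    }
```

<p>Only the OS is ever shown to the optimizer. A growing <code>model_gap</code> or <code>data_gap</code> over the course of optimization is the signal that the generator is fitting the scoring function rather than the underlying property.</p>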
<h2 id="experimental-setup-three-targets-three-generators">Experimental Setup: Three Targets, Three Generators</h2>
<h3 id="targets-and-data">Targets and Data</h3>
<p>The authors selected three biological targets from ChEMBL: <a href="https://en.wikipedia.org/wiki/Janus_kinase_2">Janus kinase 2</a> (JAK2), <a href="https://en.wikipedia.org/wiki/Epidermal_growth_factor_receptor">epidermal growth factor receptor</a> (EGFR), and <a href="https://en.wikipedia.org/wiki/Dopamine_receptor_D2">dopamine receptor D2</a> (DRD2). For each target, the data was split into two halves (split 1 and split 2) with balanced active/inactive ratios. Random forest classifiers using binary folded ECFP4 fingerprints (radius 2, size 1024) were trained to produce three scoring functions per target: the OS and MCS on split 1 (different random seeds), and the DCS on split 2.</p>
<h3 id="generators">Generators</h3>
<p>Three molecular generators were evaluated:</p>
<ol>
<li><strong>Graph-based Genetic Algorithm (GA)</strong>: Iteratively applies random mutations and crossovers to a population of molecules, retaining the best in each generation. One of the top performers in GuacaMol.</li>
<li><strong>SMILES-LSTM</strong>: An autoregressive model that generates SMILES character by character, optimized via hill climbing (iteratively sampling, keeping top molecules, fine-tuning). Also a top GuacaMol performer.</li>
<li><strong><a href="https://en.wikipedia.org/wiki/Particle_swarm_optimization">Particle Swarm Optimization</a> (PS)</strong>: Optimizes molecules in the continuous latent space of a SMILES-based sequence-to-sequence model.</li>
</ol>
<p>Each optimizer was run 10 times per target dataset.</p>
<h2 id="score-divergence-and-exploitable-biases">Score Divergence and Exploitable Biases</h2>
<h3 id="optimization-vs-control-score-divergence">Optimization vs. Control Score Divergence</h3>
<p>Across all three targets and all three generators, the OS consistently outpaced both control scores during optimization. The DCS sometimes stagnated or even decreased while the OS continued to climb. This divergence demonstrates that the generators exploit biases in the scoring function rather than discovering genuinely active compounds.</p>
<p>The MCS also diverged from the OS despite being trained on exactly the same data, confirming model-specific biases: the optimization exploits features unique to the particular random forest instance. The larger gap between OS and DCS (compared to OS and MCS) indicates that data-specific biases contribute more to the divergence than model-specific biases.</p>
<h3 id="chemical-space-migration">Chemical Space Migration</h3>
<p>Optimized molecules migrated toward the region of split 1 actives (used to train the OS), as shown by t-SNE embeddings and nearest-neighbor Tanimoto similarity analysis. Optimized molecules had more similar neighbors in split 1 than in split 2, confirming data-specific bias. By the end of optimization, generated molecules occupied different regions of chemical space than known actives when measured by logP and molecular weight, with compounds from the same optimization run forming distinct clusters.</p>
<h3 id="quality-of-generated-molecules">Quality of Generated Molecules</h3>
<p>High-scoring generated molecules frequently contained problematic substructures: reactive dienes, nitrogen-fluorine bonds, long heteroatom chains that are synthetically infeasible, and highly uncommon functional groups. The LSTM optimizer showed a bias toward high molecular weight, low diversity, and high logP values. These molecules would be rejected by medicinal chemists despite their high optimization scores.</p>
<h3 id="key-takeaways">Key Takeaways</h3>
<p>The authors emphasize several practical implications:</p>
<ol>
<li><strong>Early stopping</strong>: Control scores can indicate when further optimization is exploiting biases rather than finding better molecules. Optimization should stop when control scores plateau.</li>
<li><strong>Scoring function iteration</strong>: In practice, generative models are &ldquo;highly adept at exploiting&rdquo; incomplete scoring functions, necessitating several iterations of generation and scoring function refinement.</li>
<li><strong>Synthetic accessibility</strong>: Even high-scoring molecules are useless if they cannot be synthesized. The authors consider this a major challenge for practical adoption.</li>
<li><strong>Likelihood-based evaluation</strong>: For distribution-learning, the authors recommend reporting test-set likelihoods for likelihood-based models, following standard NLP practice.</li>
</ol>
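<p>The early-stopping recommendation can be made concrete with a plateau detector on the control-score trajectory. The paper states only that optimization should stop when control scores plateau; the <code>patience</code> and <code>min_delta</code> parameters below are hypothetical choices for turning that into a rule.</p>

```python
from typing import Optional


def stop_iteration(control_scores: list[float], patience: int = 3,
                   min_delta: float = 0.0) -> Optional[int]:
    """Return the first iteration index at which the control score has failed
    to improve by more than min_delta for `patience` consecutive iterations,
    or None if that never happens."""
    best = float("-inf")
    stale = 0
    for i, score in enumerate(control_scores):
        if score > best + min_delta:
            best = score
            stale = 0
        else:
            stale += 1
            if stale >= patience:
                return i
    return None


# The control score stalls at 0.3 from iteration 2 onward.
print(stop_iteration([0.1, 0.2, 0.3, 0.3, 0.29, 0.3], patience=3))
```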
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Bioactivity data</td>
          <td>ChEMBL (JAK2, EGFR, DRD2)</td>
          <td>See Table S1</td>
          <td>Binary classification tasks, split 50/50</td>
      </tr>
      <tr>
          <td>Distribution-learning</td>
          <td>GuacaMol training set</td>
          <td>Subset of ChEMBL</td>
          <td>Used as starting population for GA and PS</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Scoring function</strong>: Random forest classifier (scikit-learn) on binary ECFP4 fingerprints (size 1024, radius 2, RDKit)</li>
<li><strong>GA</strong>: Graph-based genetic algorithm from Jensen (2019)</li>
<li><strong>LSTM</strong>: SMILES-LSTM with hill climbing, pretrained model from GuacaMol</li>
<li><strong>PS</strong>: Particle swarm optimization in latent space of a sequence-to-sequence model (Winter et al. 2019)</li>
<li>Each optimizer run 10 times per target</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Description</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Optimization Score (OS)</td>
          <td>RF classifier on split 1</td>
          <td>Guides optimization</td>
      </tr>
      <tr>
          <td>Model Control Score (MCS)</td>
          <td>RF on split 1, different seed</td>
          <td>Detects model-specific bias</td>
      </tr>
      <tr>
          <td>Data Control Score (DCS)</td>
          <td>RF on split 2</td>
          <td>Detects data-specific bias</td>
      </tr>
      <tr>
          <td><a href="/notes/chemistry/molecular-design/generation/evaluation/guacamol-benchmarking-de-novo-molecular-design/">GuacaMol</a> metrics</td>
          <td>Validity, uniqueness, novelty, KL div, FCD</td>
          <td>For distribution-learning</td>
      </tr>
  </tbody>
</table>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/ml-jku/mgenerators-failure-modes">ml-jku/mgenerators-failure-modes</a></td>
          <td>Code</td>
          <td>Unknown</td>
          <td>Data, code, and results</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Not specified in the paper.</p>
<hr>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{renz2019failure,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{On failure modes in molecule generation and optimization}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Renz, Philipp and Van Rompaey, Dries and Wegner, J{\&#34;o}rg Kurt and Hochreiter, Sepp and Klambauer, G{\&#34;u}nter}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Drug Discovery Today: Technologies}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{32-33}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{55--63}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2019}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Elsevier}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1016/j.ddtec.2020.09.003}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Renz, P., Van Rompaey, D., Wegner, J. K., Hochreiter, S., &amp; Klambauer, G. (2019). On failure modes in molecule generation and optimization. <em>Drug Discovery Today: Technologies</em>, 32-33, 55-63. <a href="https://doi.org/10.1016/j.ddtec.2020.09.003">https://doi.org/10.1016/j.ddtec.2020.09.003</a></p>
<p><strong>Publication</strong>: Drug Discovery Today: Technologies, Volume 32-33, 2019</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/ml-jku/mgenerators-failure-modes">Code and data (GitHub)</a></li>
</ul>
]]></content:encoded></item><item><title>DOCKSTRING: Docking-Based Benchmarks for Drug Design</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/evaluation/dockstring-docking-benchmarks-ligand-design/</link><pubDate>Wed, 25 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/evaluation/dockstring-docking-benchmarks-ligand-design/</guid><description>DOCKSTRING provides an open-source Python docking package, 15M+ score dataset across 58 targets, and benchmark tasks for ML-driven drug design.</description><content:encoded><![CDATA[<h2 id="a-three-part-resource-for-docking-based-ml-benchmarks">A Three-Part Resource for Docking-Based ML Benchmarks</h2>
<p>DOCKSTRING is a <strong>Resource</strong> paper that delivers three integrated components for benchmarking machine learning models in drug discovery using molecular docking. The primary contributions are: (1) an open-source Python package wrapping <a href="https://en.wikipedia.org/wiki/AutoDock">AutoDock Vina</a> for deterministic docking from <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> strings, (2) a dataset of over 15 million docking scores and poses covering 260,000+ molecules docked against 58 medically relevant protein targets, and (3) a suite of benchmark tasks spanning regression, <a href="https://en.wikipedia.org/wiki/Virtual_screening">virtual screening</a>, and de novo molecular design. The paper additionally provides baseline results across classical and deep learning methods.</p>
<h2 id="why-existing-molecular-benchmarks-fall-short">Why Existing Molecular Benchmarks Fall Short</h2>
<p>ML methods for drug discovery are frequently evaluated using simple physicochemical properties such as penalized logP or QED (quantitative estimate of druglikeness). These properties are computationally cheap and easy to optimize, but they do not depend on the interaction between a candidate compound and a protein target. As a result, strong performance on logP or QED benchmarks does not necessarily translate to strong performance on real drug design tasks.</p>
<p><a href="https://en.wikipedia.org/wiki/Docking_(molecular)">Molecular docking</a> offers a more realistic evaluation objective because docking scores depend on the 3D structure of the ligand-target complex. Docking is routinely used by medicinal chemists to estimate binding affinities during hit discovery and lead optimization. Several prior efforts attempted to bring docking into ML benchmarking, but each had limitations:</p>
<ul>
<li><strong>VirtualFlow and DockStream</strong> require manually prepared target files and domain expertise.</li>
<li><strong>TDC and Cieplinski et al.</strong> provide SMILES-to-score wrappers but lack proper ligand protonation and randomness control, and cover very few targets (one and four, respectively).</li>
<li><strong>DUD-E</strong> is easily overfit by ML models that memorize actives vs. decoys.</li>
<li><strong><a href="/notes/chemistry/molecular-design/generation/evaluation/guacamol-benchmarking-de-novo-molecular-design/">GuacaMol</a> and <a href="/notes/chemistry/molecular-design/generation/evaluation/molecular-sets-moses/">MOSES</a></strong> rely on physicochemical properties or similarity functions that miss 3D structural subtleties.</li>
<li><strong><a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a></strong> compiles experimental datasets but does not support on-the-fly label computation needed for transfer learning or de novo design.</li>
</ul>
<p>DOCKSTRING addresses all of these gaps: it standardizes the docking procedure, automates ligand and target preparation, controls randomness for reproducibility, and provides a large, diverse target set.</p>
<h2 id="core-innovation-standardized-end-to-end-docking-pipeline">Core Innovation: Standardized End-to-End Docking Pipeline</h2>
<p>The key innovation is a fully automated, deterministic docking pipeline that produces reproducible scores from a SMILES string in four lines of Python code. The pipeline consists of three stages:</p>
<p><strong>Target Preparation.</strong> 57 of the 58 protein targets originate from the Directory of Useful Decoys Enhanced (DUD-E). PDB files are standardized with <a href="https://en.wikipedia.org/wiki/Open_Babel">Open Babel</a>, polar hydrogens are added, and conversion to PDBQT format is performed with AutoDock Tools. Search boxes are derived from crystallographic ligands with 12.5 Å padding and a minimum side length of 30 Å. The 58th target (DRD2, <a href="https://en.wikipedia.org/wiki/Dopamine_receptor_D2">dopamine receptor D2</a>) was prepared separately following the same protocol.</p>
<p><strong>Ligand Preparation.</strong> Ligands are protonated at pH 7.4 with Open Babel, embedded into 3D conformations using the ETKDG algorithm in RDKit, refined with the <a href="https://en.wikipedia.org/wiki/Merck_molecular_force_field">MMFF94 force field</a>, and assigned Gasteiger partial charges. Stereochemistry of determined stereocenters is maintained, while undetermined stereocenters are assigned randomly but consistently across runs.</p>
<p><strong>Docking.</strong> AutoDock Vina runs with default exhaustiveness (8), up to 9 binding modes, and an energy range of 3 kcal/mol. The authors verified that fixing the random seed yields docking score variance of less than 0.1 kcal/mol across runs, making the pipeline deterministic for practical purposes.</p>
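<p>The search-box rule from the Target Preparation stage can be sketched in a few lines of plain Python. This is an illustrative reconstruction, not the DOCKSTRING source: the function name <code>search_box</code> is hypothetical, and it assumes the 12.5 Å padding is added on each side of the ligand's bounding box before the 30 Å minimum is enforced.</p>

```python
from typing import Iterable, Tuple

Point = Tuple[float, float, float]


def search_box(ligand_coords: Iterable[Point],
               padding: float = 12.5,
               min_side: float = 30.0) -> Tuple[Point, Point]:
    """Return (center, size) of an axis-aligned box around the ligand atoms.

    Assumption: `padding` (in angstroms) is applied on both sides of each axis,
    then every side is raised to at least `min_side`.
    """
    coords = list(ligand_coords)
    if not coords:
        raise ValueError("need at least one atom coordinate")
    center, size = [], []
    for axis in range(3):
        values = [point[axis] for point in coords]
        low, high = min(values), max(values)
        center.append((low + high) / 2.0)
        size.append(max(high - low + 2.0 * padding, min_side))
    return tuple(center), tuple(size)
```

<p>For a ligand spanning 10 Å along x and 2 Å along y, this yields a 35 Å side on x, while the y side is raised to the 30 Å minimum.</p>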
<p>The three de novo design objective functions, each of which is minimized, add a QED penalty to the docking score $s(l, t)$ of ligand $l$ against target $t$ to enforce druglikeness:</p>
<p>$$
f_{\text{F2}}(l) = s(l, \text{F2}) + 10(1 - \text{QED}(l))
$$</p>
<p>$$
f_{\text{PPAR}}(l) = \max_{t \in \text{PPAR}} s(l, t) + 10(1 - \text{QED}(l))
$$</p>
<p>$$
f_{\text{JAK2}}(l) = s(l, \text{JAK2}) - \min(s(l, \text{LCK}), -8.1) + 10(1 - \text{QED}(l))
$$</p>
<p>The F2 task optimizes binding to a single protease. The Promiscuous <a href="https://en.wikipedia.org/wiki/Peroxisome_proliferator-activated_receptor">PPAR</a> task requires strong binding to three nuclear receptors simultaneously. The Selective <a href="https://en.wikipedia.org/wiki/Janus_kinase_2">JAK2</a> task is adversarial, requiring strong JAK2 binding while avoiding <a href="https://en.wikipedia.org/wiki/Tyrosin-protein_kinase_Lck">LCK</a> binding (two kinases with a score correlation of 0.80).</p>
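<p>The three objectives translate directly into code. The sketch below assumes docking scores arrive as a dictionary keyed by target name and that QED has already been computed. The key names <code>PPARA</code>, <code>PPARD</code>, and <code>PPARG</code> for the three PPAR receptors are assumptions made for illustration, as are the function names. Lower values are better in all three cases.</p>

```python
from typing import Mapping

QED_WEIGHT = 10.0    # weight of the druglikeness penalty in all three objectives
LCK_CLIP = -8.1      # clipping constant for the LCK score in the selective task
PPAR_TARGETS = ("PPARA", "PPARD", "PPARG")  # assumed target keys


def qed_penalty(qed: float) -> float:
    """Penalty term 10 * (1 - QED); zero for a perfectly druglike molecule."""
    return QED_WEIGHT * (1.0 - qed)


def f_f2(scores: Mapping[str, float], qed: float) -> float:
    """Single-target objective: docking score against F2 plus the QED penalty."""
    return scores["F2"] + qed_penalty(qed)


def f_ppar(scores: Mapping[str, float], qed: float) -> float:
    """Promiscuous objective: the worst (highest) PPAR score plus the QED penalty."""
    return max(scores[t] for t in PPAR_TARGETS) + qed_penalty(qed)


def f_jak2(scores: Mapping[str, float], qed: float) -> float:
    """Selective objective: reward JAK2 binding, penalize LCK binding below -8.1."""
    return scores["JAK2"] - min(scores["LCK"], LCK_CLIP) + qed_penalty(qed)
```

<p>Taking the maximum over PPAR scores means the objective only improves when the weakest of the three binders improves. In the selective task, the <code>min</code> clips the LCK term so that an LCK score above -8.1 earns no further credit.</p>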
<h2 id="experimental-setup-regression-virtual-screening-and-de-novo-design">Experimental Setup: Regression, Virtual Screening, and De Novo Design</h2>
<h3 id="dataset-construction">Dataset Construction</h3>
<p>The dataset draws its molecules from ExCAPE-DB (which curates PubChem and ChEMBL bioactivity assays). The authors selected all molecules with active labels against targets having at least 1,000 experimental actives, plus 150,000 inactive-only molecules. After discarding 1.8% of molecules that failed ligand preparation, the final dataset contains 260,155 compounds docked against 58 targets, producing over 15 million docking scores and poses. The dataset required over 500,000 CPU hours to generate.</p>
<p>Cluster analysis using <a href="https://en.wikipedia.org/wiki/DBSCAN">DBSCAN</a> (<a href="https://en.wikipedia.org/wiki/Jaccard_index">Jaccard distance</a> threshold of 0.25 on RDKit fingerprints) found 52,000 clusters, and Bemis-Murcko scaffold decomposition identified 102,000 scaffolds, confirming high molecular diversity. Train/test splitting follows cluster labels to prevent data leakage.</p>
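<p>A cluster-aware split keeps every member of a cluster on the same side of the train/test boundary. The sketch below is one simple way to do this and is not the authors' exact procedure: the function name, the 20% test fraction, and the greedy assignment of shuffled clusters are all assumptions made for illustration.</p>

```python
import random
from collections import defaultdict
from typing import List, Sequence, Tuple


def cluster_split(cluster_labels: Sequence[int],
                  test_fraction: float = 0.2,
                  seed: int = 0) -> Tuple[List[int], List[int]]:
    """Split molecule indices so that no cluster spans both train and test."""
    groups = defaultdict(list)
    for index, label in enumerate(cluster_labels):
        groups[label].append(index)

    labels = sorted(groups)
    random.Random(seed).shuffle(labels)  # seeded for reproducibility

    target_size = test_fraction * len(cluster_labels)
    train, test = [], []
    for label in labels:
        # Fill the test set with whole clusters first, then send the rest to train.
        bucket = test if len(test) < target_size else train
        bucket.extend(groups[label])
    return sorted(train), sorted(test)
```

<p>Because whole clusters are assigned at once, the realized test fraction can overshoot the requested one when clusters are large.</p>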
<h3 id="regression-baselines">Regression Baselines</h3>
<p>Five targets of varying difficulty were selected: <a href="https://en.wikipedia.org/wiki/Poly_(ADP-ribose)_polymerase">PARP1</a> (easy), F2 (easy-medium), KIT (medium), ESR2 (hard), and PGR (hard). Baselines include Ridge, Lasso, XGBoost, exact GP, sparse GP, MPNN, and Attentive FP. The table also reports logP and QED as reference tasks alongside the five docking targets.</p>
<table>
  <thead>
      <tr>
          <th>Target</th>
          <th>Ridge</th>
          <th>Lasso</th>
          <th>XGBoost</th>
          <th>GP (exact)</th>
          <th>GP (sparse)</th>
          <th>MPNN</th>
          <th>Attentive FP</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>logP</td>
          <td>0.640</td>
          <td>0.640</td>
          <td>0.734</td>
          <td>0.707</td>
          <td>0.716</td>
          <td>0.953</td>
          <td>1.000</td>
      </tr>
      <tr>
          <td>QED</td>
          <td>0.519</td>
          <td>0.483</td>
          <td>0.660</td>
          <td>0.640</td>
          <td>0.598</td>
          <td>0.901</td>
          <td>0.981</td>
      </tr>
      <tr>
          <td>ESR2</td>
          <td>0.421</td>
          <td>0.416</td>
          <td>0.497</td>
          <td>0.441</td>
          <td>0.508</td>
          <td>0.506</td>
          <td>0.627</td>
      </tr>
      <tr>
          <td>F2</td>
          <td>0.672</td>
          <td>0.663</td>
          <td>0.688</td>
          <td>0.705</td>
          <td>0.744</td>
          <td>0.798</td>
          <td>0.880</td>
      </tr>
      <tr>
          <td>KIT</td>
          <td>0.604</td>
          <td>0.594</td>
          <td>0.674</td>
          <td>0.637</td>
          <td>0.684</td>
          <td>0.755</td>
          <td>0.806</td>
      </tr>
      <tr>
          <td>PARP1</td>
          <td>0.706</td>
          <td>0.700</td>
          <td>0.723</td>
          <td>0.743</td>
          <td>0.772</td>
          <td>0.815</td>
          <td>0.910</td>
      </tr>
      <tr>
          <td>PGR</td>
          <td>0.242</td>
          <td>0.245</td>
          <td>0.345</td>
          <td>0.291</td>
          <td>0.387</td>
          <td>0.324</td>
          <td>0.678</td>
      </tr>
  </tbody>
</table>
<p>Values are mean $R^2$ over three runs. Attentive FP achieves the best performance on every target but remains well below perfect prediction on the harder targets, confirming that docking score regression is a meaningful benchmark.</p>
<h3 id="virtual-screening-baselines">Virtual Screening Baselines</h3>
<p>Models trained on PARP1, KIT, and PGR docking scores rank all molecules in <a href="/notes/chemistry/datasets/zinc-22/">ZINC20</a> (~1 billion compounds). The top 5,000 predictions are docked, and the enrichment factor (EF) is computed relative to a 0.1 percentile activity threshold.</p>
<table>
  <thead>
      <tr>
          <th>Target</th>
          <th>Threshold</th>
          <th>FSS</th>
          <th>Ridge</th>
          <th>Attentive FP</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>KIT</td>
          <td>-10.7</td>
          <td>239.2</td>
          <td>451.6</td>
          <td>766.5</td>
      </tr>
      <tr>
          <td>PARP1</td>
          <td>-12.1</td>
          <td>313.1</td>
          <td>325.9</td>
          <td>472.2</td>
      </tr>
      <tr>
          <td>PGR</td>
          <td>-10.1</td>
          <td>161.4</td>
          <td>120.5</td>
          <td>461.3</td>
      </tr>
  </tbody>
</table>
<p>The maximum possible EF is 1,000. Attentive FP substantially outperforms fingerprint similarity search (FSS) and Ridge regression across all targets.</p>
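<p>The enrichment factor can be sketched as the hit rate among the docked top predictions divided by the base rate implied by the threshold. In this minimal version the function name is hypothetical, a hit is assumed to be any molecule whose docking score is at or below the threshold, and the base rate of 0.001 corresponds to the 0.1 percentile cutoff.</p>

```python
from typing import Sequence


def enrichment_factor(top_scores: Sequence[float],
                      threshold: float,
                      base_rate: float = 0.001) -> float:
    """Hit rate among the selected molecules relative to the base hit rate.

    Assumption: lower docking scores are better, so a hit is score <= threshold.
    """
    if not top_scores:
        raise ValueError("need at least one docked molecule")
    hits = sum(1 for score in top_scores if score <= threshold)
    return (hits / len(top_scores)) / base_rate
```

<p>If every selected molecule is a hit, the hit rate is 1 and the EF reaches 1 / 0.001 = 1,000, which matches the stated maximum.</p>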
<h3 id="de-novo-design-baselines">De Novo Design Baselines</h3>
<p>Four optimization methods were tested: <a href="/notes/chemistry/molecular-representations/notations/selfies/">SELFIES</a> GA, <a href="/notes/chemistry/molecular-design/generation/search-based/graph-based-genetic-algorithm-chemical-space/">Graph GA</a>, GP-BO with UCB acquisition ($\beta = 10$), and GP-BO with expected improvement (EI), each with a budget of 5,000 objective function evaluations. Without QED penalties, all methods easily surpass the best training set molecules but produce large, lipophilic, non-druglike compounds. With QED penalties, the tasks become substantially harder: GP-BO with EI is the only method that finds 25 molecules better than the training set across all three tasks.</p>
<p>The Selective JAK2 task proved hardest due to the high correlation between JAK2 and LCK scores. Pose analysis of the top de novo molecule revealed a dual binding mode: type V inhibitor behavior in JAK2 (binding distant N- and C-terminal lobe regions) and type I behavior in LCK (hinge-binding), suggesting a plausible selectivity mechanism.</p>
<h2 id="key-findings-and-limitations">Key Findings and Limitations</h2>
<p><strong>Key findings:</strong></p>
<ol>
<li>Docking scores are substantially harder to predict than logP or QED, making them more suitable for benchmarking high-performing ML models. Graph neural networks (Attentive FP) achieve near-perfect $R^2$ on logP but only 0.63-0.91 on docking targets.</li>
<li>In-distribution regression difficulty does not necessarily predict out-of-distribution virtual screening difficulty. PARP1 is easiest for regression, but KIT is easiest for virtual screening.</li>
<li>Adding a QED penalty to de novo design objectives transforms trivially solvable tasks into meaningful benchmarks. The adversarial Selective JAK2 objective, which exploits correlated docking scores, may be an effective way to avoid docking score biases toward large and lipophilic molecules.</li>
<li>Docking scores from related protein targets are highly correlated, supporting the biological meaningfulness of the dataset and enabling multiobjective and transfer learning tasks.</li>
</ol>
<p><strong>Limitations acknowledged by the authors:</strong></p>
<ul>
<li>Docking scores are approximate heuristics. They use static binding sites and force fields with limited calibration for certain metal ions. DOCKSTRING benchmarks should not substitute for rational drug design and experimental validation.</li>
<li>The pipeline relies on AutoDock Vina specifically; other docking programs may produce different rankings.</li>
<li>Top de novo molecules for F2 and Promiscuous PPAR contain conjugated ring structures uncommon in successful drugs.</li>
<li>Platform support is primarily Linux, with noted scoring inconsistencies on macOS.</li>
</ul>
<p><strong>Future directions</strong> mentioned include multiobjective tasks (transfer learning, few-shot learning), improved objective functions for better pharmacokinetic properties and synthetic feasibility, and multifidelity optimization tasks combining docking with more expensive computational methods.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Ligand source</td>
          <td>ExCAPE-DB (PubChem + ChEMBL)</td>
          <td>260,155 molecules</td>
          <td>Actives against 58 targets + 150K inactive-only</td>
      </tr>
      <tr>
          <td>Docking scores</td>
          <td>DOCKSTRING dataset</td>
          <td>15M+ scores and poses</td>
          <td>Full matrix across all molecule-target pairs</td>
      </tr>
      <tr>
          <td>Virtual screening library</td>
          <td>ZINC20</td>
          <td>~1 billion molecules</td>
          <td>Used for out-of-distribution evaluation</td>
      </tr>
      <tr>
          <td>Target structures</td>
          <td>DUD-E + PDB 6CM4 (DRD2)</td>
          <td>58 targets</td>
          <td>Kinases (22), enzymes (12), nuclear receptors (9), proteases (7), GPCRs (5), cytochromes (2), chaperone (1)</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Docking engine</strong>: AutoDock Vina with default exhaustiveness (8), up to 9 binding modes, energy range of 3 kcal/mol</li>
<li><strong>Ligand preparation</strong>: Open Babel (protonation at pH 7.4), RDKit ETKDG (3D embedding), MMFF94 (force field refinement), Gasteiger charges</li>
<li><strong>Regression models</strong>: Ridge, Lasso, XGBoost (hyperparameters via 20-configuration random search with 5-fold CV), exact GP and sparse GP (Tanimoto kernel on fingerprints), MPNN, Attentive FP (DeepChem defaults, 10 epochs)</li>
<li><strong>Optimization</strong>: Graph GA (population 250, offspring 25, mutation rate 0.01), SELFIES GA (same population/offspring settings), GP-BO with UCB ($\beta = 10$) or EI (batch size 5, 1000 offspring, 25 generations per iteration)</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Setting</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>$R^2$ (coefficient of determination)</td>
          <td>Regression</td>
          <td>Cluster-split train/test</td>
      </tr>
      <tr>
          <td>EF (enrichment factor)</td>
          <td>Virtual screening</td>
          <td>Top 5,000 from ZINC20, 0.1 percentile threshold</td>
      </tr>
      <tr>
          <td>Objective value trajectory</td>
          <td>De novo design</td>
          <td>5,000 function evaluation budget</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>The dataset required over 500,000 CPU hours to compute, using the University of Cambridge Research Computing Service (EPSRC and DiRAC funded). Docking a single ligand against a single target takes approximately 15 seconds on 8 CPUs.</p>
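<p>As a rough consistency check, the two figures agree if the 15-second timing is read as the cost of one ligand-target docking run (an assumption) and failed runs are ignored. The full matrix contains</p>
<p>$$
260{,}155 \times 58 = 15{,}088{,}990 \text{ docking runs}
$$</p>
<p>and the corresponding compute is</p>
<p>$$
\frac{15{,}088{,}990 \times 15 \text{ s} \times 8 \text{ CPUs}}{3600 \text{ s/h}} \approx 503{,}000 \text{ CPU hours}
$$</p>
<p>which is in line with the reported total of over 500,000 CPU hours.</p>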
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/dockstring/dockstring">DOCKSTRING Python package</a></td>
          <td>Code</td>
          <td>Apache 2.0</td>
          <td>Wraps AutoDock Vina; available via conda-forge and PyPI</td>
      </tr>
      <tr>
          <td><a href="https://dockstring.github.io">DOCKSTRING dataset</a></td>
          <td>Dataset</td>
          <td>Apache 2.0</td>
          <td>15M+ docking scores and poses for 260K molecules × 58 targets</td>
      </tr>
      <tr>
          <td><a href="https://github.com/dockstring/dockstring">Benchmark baselines</a></td>
          <td>Code</td>
          <td>Apache 2.0</td>
          <td>Regression, virtual screening, and de novo design baseline implementations</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: García-Ortegón, M., Simm, G. N. C., Tripp, A. J., Hernández-Lobato, J. M., Bender, A., &amp; Bacallado, S. (2022). DOCKSTRING: Easy Molecular Docking Yields Better Benchmarks for Ligand Design. <em>Journal of Chemical Information and Modeling</em>, 62(15), 3486-3502. <a href="https://doi.org/10.1021/acs.jcim.1c01334">https://doi.org/10.1021/acs.jcim.1c01334</a></p>
<p><strong>Publication</strong>: Journal of Chemical Information and Modeling, 2022</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://dockstring.github.io">DOCKSTRING Project Page</a></li>
<li><a href="https://github.com/dockstring/dockstring">GitHub Repository</a></li>
</ul>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{garciaortegon2022dockstring,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{{DOCKSTRING}: Easy Molecular Docking Yields Better Benchmarks for Ligand Design}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Garc{\&#39;\i}a-Orteg{\&#39;o}n, Miguel and Simm, Gregor N. C. and Tripp, Austin J. and Hern{\&#39;a}ndez-Lobato, Jos{\&#39;e} Miguel and Bender, Andreas and Bacallado, Sergio}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Journal of Chemical Information and Modeling}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{62}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{15}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{3486--3502}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2022}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{American Chemical Society}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1021/acs.jcim.1c01334}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Neural Scaling of Deep Chemical Models</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/neural-scaling-of-deep-chemical-models/</link><pubDate>Tue, 24 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/neural-scaling-of-deep-chemical-models/</guid><description>Frey et al. discover neural scaling laws for chemical LLMs and GNN interatomic potentials, showing power-law loss improvements with scale.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is a <strong>discovery paper</strong> that identifies empirical neural scaling laws in two distinct domains of chemical deep learning: large language models (LLMs) for generative chemistry and graph neural networks (GNNs) for machine-learned interatomic potentials. The paper also introduces training performance estimation (TPE) as a practical tool for accelerating hyperparameter optimization in these domains.</p>
<h2 id="why-scaling-laws-matter-for-chemistry">Why scaling laws matter for chemistry</h2>
<p>Neural scaling laws, first characterized for NLP models by Kaplan et al. (2020), describe how model loss decreases as a power law with increasing model size, dataset size, or compute:</p>
<p>$$
L(R) = \alpha R^{-\beta}
$$</p>
<p>where $\alpha$ is a coefficient, $\beta$ is the scaling exponent, and $R$ is the resource being scaled (parameters, data, or compute). These relationships have guided resource allocation decisions in NLP and computer vision, but their applicability to scientific deep learning was unknown.</p>
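<p>Because $\log L = \log \alpha - \beta \log R$, a power law appears as a straight line on log-log axes, so $\alpha$ and $\beta$ can be estimated with ordinary least squares on the logged values. The sketch below illustrates that idea only. The function name is hypothetical, and the paper's own fitting procedure may differ.</p>

```python
import math
from typing import Sequence, Tuple


def fit_power_law(resources: Sequence[float],
                  losses: Sequence[float]) -> Tuple[float, float]:
    """Estimate (alpha, beta) in L = alpha * R**(-beta) by log-log least squares."""
    if len(resources) != len(losses) or len(resources) < 2:
        raise ValueError("need at least two (resource, loss) pairs")
    xs = [math.log(r) for r in resources]
    ys = [math.log(l) for l in losses]
    n = len(xs)
    mean_x, mean_y = sum(xs) / n, sum(ys) / n
    sxx = sum((x - mean_x) ** 2 for x in xs)
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    slope = sxy / sxx
    intercept = mean_y - slope * mean_x
    # The slope of the log-log line is -beta; the intercept is log(alpha).
    return math.exp(intercept), -slope
```

<p>Points from a resolution-limited regime bend away from the line, so they are typically excluded before fitting.</p>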
<p>Chemical deep learning differs from standard NLP and vision tasks in several key ways. Physics-based priors (like symmetry constraints) may reduce the need for massive scale. The heterogeneity of chemical space and molecular tasks makes general pre-training more challenging. There are no established default architectures, datasets, or training recipes at large scale for chemistry.</p>
<p>This paper asks: do the same scaling behaviors hold for chemical models, and how do physical priors affect them?</p>
<h2 id="training-performance-estimation-for-efficient-scaling">Training performance estimation for efficient scaling</h2>
<p>Before running expensive scaling experiments, the authors needed a way to efficiently select hyperparameters. They introduced TPE, a generalization of training speed estimation (TSE) to new domains. TSE computes the cumulative training loss over the first $T$ epochs:</p>
<p>$$
\text{TSE} = \sum_{t=1}^{T} \left( \frac{1}{B} \sum_{i=1}^{B} \mathcal{L}\left(f_{\theta(t,i)}(\mathbf{X}_i), \mathbf{y}_i\right) \right)
$$</p>
<p>where $B$ is the number of training steps per epoch, $\mathcal{L}$ is the loss function, and $f_{\theta(t,i)}$ is the network at epoch $t$ and mini-batch $i$. A linear regression then predicts converged loss from early-training TSE:</p>
<p>$$
L = m \times \text{TSE} + b
$$</p>
<p>Using only 20% of the total training budget, TPE achieves $R^2 = 0.98$ and Spearman&rsquo;s $\rho = 1.0$ for ChemGPT on the MOSES dataset. For GNNs, it achieves $R^2 \geq 0.86$ and $\rho \geq 0.92$ across SchNet, PaiNN, and SpookyNet. This enables discarding suboptimal configurations early, saving up to 90% of compute.</p>
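<p>The two TPE steps, summing mean per-epoch training losses and then regressing converged loss on that sum, can be sketched as follows. All names are hypothetical, and the input layout (one list of mini-batch losses per epoch) is an assumption chosen for illustration.</p>

```python
from typing import List, Sequence, Tuple


def training_speed_estimate(batch_losses: Sequence[Sequence[float]],
                            num_epochs: int) -> float:
    """Sum over the first `num_epochs` epochs of the mean mini-batch loss."""
    return sum(sum(epoch) / len(epoch) for epoch in batch_losses[:num_epochs])


def fit_line(xs: Sequence[float], ys: Sequence[float]) -> Tuple[float, float]:
    """Least-squares slope m and intercept b for y = m * x + b."""
    n = len(xs)
    mean_x, mean_y = sum(xs) / n, sum(ys) / n
    sxx = sum((x - mean_x) ** 2 for x in xs)
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    m = sxy / sxx
    return m, mean_y - m * mean_x


def predict_converged_loss(tse: float, m: float, b: float) -> float:
    """Apply the fitted line L = m * TSE + b to a new configuration."""
    return m * tse + b


def rank_configs(tse_values: Sequence[float], m: float, b: float) -> List[int]:
    """Indices of configurations sorted from best to worst predicted loss."""
    predicted = [predict_converged_loss(t, m, b) for t in tse_values]
    return sorted(range(len(predicted)), key=lambda i: predicted[i])
```

<p>The line is fitted on a handful of configurations trained to convergence. Remaining configurations are then trained only for the first few epochs and ranked by predicted loss.</p>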
<h2 id="chemgpt-scaling-chemical-language-models">ChemGPT: scaling chemical language models</h2>
<p>ChemGPT is a GPT-3-style autoregressive transformer for molecular generation. It uses GPT-Neo as its backbone with a SELFIES tokenizer, factorizing the probability of a molecular sequence $x = (s_1, \dots, s_n)$ of tokens as:</p>
<p>$$
p(x) = \prod_{i=1}^{n} p\left(s_i \mid s_1, \dots, s_{i-1}\right)
$$</p>
<p>The authors trained ChemGPT models ranging from ~78K to over 1 billion non-embedding parameters on subsets of PubChem10M (up to ~10 million molecules, or ~300 million tokens). Key findings from the scaling experiments:</p>
<ul>
<li><strong>Pre-training loss monotonically improves</strong> with increasing dataset size up to nearly 10 million molecules, with no saturation observed.</li>
<li><strong>For a fixed data budget</strong>, increasing model size provides monotonic improvements until models reach ~1 billion parameters.</li>
<li><strong>The scaling exponent</strong> $\beta = 0.17 \pm 0.01$ for the largest dataset (after excluding the three largest models from the power-law fit), and $\beta = 0.30 \pm 0.01$ for the next largest dataset.</li>
<li><strong>Resolution-limited regimes</strong> appear where the power-law behavior breaks down, indicating either insufficient data for a given model size or vice versa. These regimes shift depending on the data budget.</li>
</ul>
<p>An interesting observation: for small datasets, large models ($10^7$ parameters and above) still provide notable loss improvements, suggesting that scaling up model size helps even when data is limited.</p>
<h2 id="neural-force-field-scaling-with-gnns">Neural force field scaling with GNNs</h2>
<p>For tasks requiring three-dimensional molecular geometry, the authors studied GNN-based neural force fields (NFFs). These models predict energies $\hat{E} = f_\theta(X)$ and derive forces by differentiation:</p>
<p>$$
\hat{F}_{ij} = -\frac{\partial \hat{E}}{\partial r_{ij}}
$$</p>
<p>Training uses an L1 loss over energies and forces:</p>
<p>$$
\mathcal{L} = \frac{1}{N} \sum_{i=1}^{N} \left[ \alpha_E | E_i - \hat{E}_i | + \alpha_F | \mathbf{F}_i - \hat{\mathbf{F}}_i | \right]
$$</p>
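<p>The relationship between energies, forces, and the loss can be illustrated without a deep learning framework. In the sketch below, central finite differences stand in for the automatic differentiation that neural force fields actually use, coordinates are flattened into a single list, and the force error is taken as the sum of component-wise absolute differences. All three choices, along with the function names, are assumptions made for illustration.</p>

```python
from typing import Callable, List, Sequence


def numerical_forces(energy_fn: Callable[[List[float]], float],
                     coords: Sequence[float],
                     h: float = 1e-4) -> List[float]:
    """Forces as the negative energy gradient, via central differences."""
    forces = []
    for i in range(len(coords)):
        plus, minus = list(coords), list(coords)
        plus[i] += h
        minus[i] -= h
        forces.append(-(energy_fn(plus) - energy_fn(minus)) / (2.0 * h))
    return forces


def l1_energy_force_loss(e_true: Sequence[float],
                         e_pred: Sequence[float],
                         f_true: Sequence[Sequence[float]],
                         f_pred: Sequence[Sequence[float]],
                         alpha_e: float = 1.0,
                         alpha_f: float = 1.0) -> float:
    """Mean over samples of alpha_e * |E - E_hat| + alpha_f * |F - F_hat|_1."""
    total = 0.0
    for e, e_hat, f, f_hat in zip(e_true, e_pred, f_true, f_pred):
        force_error = sum(abs(a - b) for a, b in zip(f, f_hat))
        total += alpha_e * abs(e - e_hat) + alpha_f * force_error
    return total / len(e_true)
```

<p>Setting <code>alpha_e</code> to zero recovers a force-only training objective.</p>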
<p>Four NFF architectures were studied, spanning a range of physical priors:</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Type</th>
          <th>Key Characteristic</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>SchNet</td>
          <td>E(3) invariant</td>
          <td>Continuous filter convolutions</td>
      </tr>
      <tr>
          <td>PaiNN</td>
          <td>E(3) equivariant</td>
          <td>Equivariant message passing</td>
      </tr>
      <tr>
          <td>Allegro</td>
          <td>E(3) equivariant</td>
          <td>Local, learned many-body functions</td>
      </tr>
      <tr>
          <td>SpookyNet</td>
          <td>E(3) equivariant</td>
          <td>Non-local interactions, empirical corrections</td>
      </tr>
  </tbody>
</table>
<p>Model capacity is parameterized as $c = d \times w$ (depth times width). Models were trained on subsets of the ANI-1x dataset (up to 100,000 geometries, corresponding to ~4.5 million force labels).</p>
<p>Key GNN scaling findings:</p>
<ul>
<li><strong>PaiNN shows monotonic loss improvement</strong> with increasing dataset size and strong correlation between converged loss and model capacity (Spearman&rsquo;s $\rho \geq 0.88$).</li>
<li><strong>Equivariant GNNs (PaiNN, Allegro) show better scaling efficiency</strong> than invariant GNNs (SchNet), with larger $\beta$ values.</li>
<li><strong>The scaling exponent for equivariant GNNs</strong> is $\beta = 0.26$, indicating that physics-based equivariance priors provide greater sample efficiency that persists to much larger and more chemically diverse datasets than previously studied.</li>
<li><strong>A transition at $10^4$ datapoints</strong> shows nearly perfect rank correlation between model capacity and converged loss ($\rho \geq 0.93$), suggesting this may be a threshold where models move from memorization to generalization.</li>
</ul>
<h2 id="results-and-practical-implications">Results and practical implications</h2>
<p>The scaling results provide actionable guidance for resource allocation:</p>
<ul>
<li>For <strong>chemical LLMs with large data budgets</strong>, the greatest loss improvements come from scaling up small models (around $10^5$ parameters).</li>
<li>For <strong>small data budgets</strong>, rapid improvements come from scaling medium-sized models ($10^7$ parameters).</li>
<li>For <strong>NFFs</strong>, low-capacity models show diminishing returns with more data, while high-capacity models show rapid improvements with increasing dataset size.</li>
<li><strong>Neither model type has saturated</strong> with respect to model size, dataset size, or compute, suggesting substantial room for improvement with further scaling.</li>
</ul>
<p>The 300-million-parameter ChemGPT trained on 300 million tokens and the PaiNN model with capacity ~1,000 trained on $10^5$ frames achieved the minimum losses in their respective scaling plots, providing concrete targets for practitioners.</p>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<p><strong>Data:</strong></p>
<ul>
<li>PubChem10M (10M SMILES strings, via DeepChem)</li>
<li>MOSES (2M molecules, for TPE validation)</li>
<li>ANI-1x (5M DFT calculations, via Figshare)</li>
<li>Revised MD-17 (10 small organic molecules, 10,000 frames for TPE)</li>
</ul>
<p><strong>Models:</strong></p>
<ul>
<li>ChemGPT: GPT-Neo backbone, 24 layers, widths from 16 to 2,048, sizes from ~78K to ~1.2B non-embedding parameters</li>
<li>SchNet, PaiNN, Allegro, SpookyNet: widths of 16, 64, 256; depths of 2, 3, 4; 5 Angstrom cutoff</li>
</ul>
<p><strong>Training:</strong></p>
<ul>
<li>ChemGPT: AdamW optimizer, learning rate $2 \times 10^{-5}$, batch size 8 per GPU, 10 epochs, cross-entropy loss</li>
<li>GNNs: Adam optimizer, learning rate scheduler (halved after 30 epochs without improvement), early stopping after 50 stagnant epochs, max 1,000 epochs, L1 loss (force-only training)</li>
</ul>
<p><strong>Hardware:</strong></p>
<ul>
<li>NVIDIA Volta V100 GPUs (32 GB), 2 GPUs per node</li>
<li>PyTorch with distributed data parallel (DDP), PyTorch Lightning, LitMatter</li>
</ul>
<p><strong>Code:</strong> <a href="https://github.com/ncfrey/litmatter">LitMatter repository</a></p>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation:</strong> Frey, N.C., Soklaski, R., Axelrod, S. et al. Neural scaling of deep chemical models. <em>Nat Mach Intell</em> <strong>5</strong>, 1297-1305 (2023).</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{frey2023neural,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Neural scaling of deep chemical models}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Frey, Nathan C. and Soklaski, Ryan and Axelrod, Simon and Samsi, Siddharth and G{\&#39;o}mez-Bombarelli, Rafael and Coley, Connor W. and Gadepally, Vijay}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Nature Machine Intelligence}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{5}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{11}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{1297--1305}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Nature Publishing Group}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1038/s42256-023-00740-3}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Tartarus: Realistic Inverse Molecular Design Benchmarks</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/evaluation/tartarus-inverse-molecular-design/</link><pubDate>Mon, 23 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/evaluation/tartarus-inverse-molecular-design/</guid><description>Tartarus provides physics-based benchmark tasks for inverse molecular design spanning materials, drugs, and reactions with algorithm-domain dependencies.</description><content:encoded><![CDATA[<h2 id="a-resource-for-realistic-molecular-design-evaluation">A Resource for Realistic Molecular Design Evaluation</h2>
<p>This is a <strong>Resource</strong> paper. Its primary contribution is Tartarus, a modular benchmarking platform for inverse molecular design that provides physically grounded evaluation tasks across four application domains: organic photovoltaics, organic emitters, protein ligands, and chemical reaction substrates. Each task pairs a curated reference dataset with a computational simulation workflow that evaluates proposed molecular structures using established methods from computational chemistry (<a href="https://en.wikipedia.org/wiki/Force_field_(chemistry)">force fields</a>, semi-empirical quantum chemistry, <a href="https://en.wikipedia.org/wiki/Density_functional_theory">density functional theory</a>, and <a href="https://en.wikipedia.org/wiki/Docking_(molecular)">molecular docking</a>).</p>
<h2 id="the-problem-with-existing-molecular-design-benchmarks">The Problem with Existing Molecular Design Benchmarks</h2>
<p>Inverse molecular design, the challenge of crafting molecules with specific optimal properties, is central to drug, catalyst, and materials discovery. Many algorithms have been proposed for this task, but the benchmarks used to evaluate them have significant limitations:</p>
<ul>
<li><strong>Penalized logP</strong>, one of the most common benchmarks, depends heavily on molecule size and chain composition, limiting its informativeness.</li>
<li><strong>QED maximization</strong> has reached saturation, with numerous models achieving near-perfect scores.</li>
<li><strong><a href="/notes/chemistry/molecular-design/generation/evaluation/guacamol-benchmarking-de-novo-molecular-design/">GuacaMol</a></strong> often yields near-perfect scores across models, obscuring meaningful performance differences. <a href="/notes/chemistry/molecular-design/generation/evaluation/pmo-sample-efficient-molecular-optimization/">Gao et al. (2022)</a> traced this to unlimited property evaluations, with imposed limits revealing much larger disparities.</li>
<li><strong>MOSES</strong> evaluates distribution-matching ability, but the emergence of <a href="/notes/chemistry/molecular-representations/notations/selfies/">SELFIES</a> and simple algorithms has made these tasks relatively straightforward.</li>
<li><strong>Molecular docking</strong> benchmarks are gaining popularity, but tend to favor reactive or unstable molecules and typically cover only drug design.</li>
</ul>
<p>These benchmarks share a common weakness: they rely on cheap, approximate property estimators (often QSAR models or simple heuristics) rather than physics-based simulations. This makes them poor proxies for real molecular design campaigns, where properties must be validated through computational or experimental workflows. Tartarus addresses this by providing benchmark tasks grounded in established simulation methods.</p>
<h2 id="physics-based-simulation-workflows-as-benchmark-oracles">Physics-Based Simulation Workflows as Benchmark Oracles</h2>
<p>The core innovation in Tartarus is the use of computational chemistry simulation pipelines as objective functions for benchmarking. Rather than relying on learned property predictors, each benchmark task runs a full simulation workflow to evaluate proposed molecules:</p>
<ol>
<li><strong>Organic Photovoltaics (OPV)</strong>: Starting from a <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> string, the workflow generates 3D coordinates with Open Babel, performs conformer search with CREST at the GFN-FF level, optimizes geometry at GFN2-xTB, and computes <a href="https://en.wikipedia.org/wiki/HOMO_and_LUMO">HOMO/LUMO</a> energies. Power conversion efficiency (PCE) is estimated via the Scharber model for single-junction <a href="https://en.wikipedia.org/wiki/Organic_solar_cell">organic solar cells</a>. HOMO and LUMO energies are calibrated against DFT results from the Harvard Clean Energy Project Database using <a href="https://en.wikipedia.org/wiki/Theil%E2%80%93Sen_estimator">Theil-Sen regression</a>:</li>
</ol>
<p>$$
E_{\text{HOMO, calibrated}} = E_{\text{HOMO, GFN2-xTB}} \cdot 0.8051 + 2.5377 \text{ eV}
$$</p>
<p>$$
E_{\text{LUMO, calibrated}} = E_{\text{LUMO, GFN2-xTB}} \cdot 0.8788 + 3.7913 \text{ eV}
$$</p>
<ol start="2">
<li>
<p><strong>Organic Emitters (OLED)</strong>: The workflow uses conformer search via CREST, geometry optimization at GFN0-xTB, and TD-DFT single-point calculations at the B3LYP/6-31G* level with PySCF to extract singlet-triplet gaps, <a href="https://en.wikipedia.org/wiki/Oscillator_strength">oscillator strengths</a>, and vertical excitation energies.</p>
</li>
<li>
<p><strong>Protein Ligands</strong>: The workflow generates 3D coordinates, applies structural filters (<a href="https://en.wikipedia.org/wiki/Lipinski%27s_rule_of_five">Lipinski&rsquo;s Rule of Five</a>, reactive moiety checks), and performs molecular docking using QuickVina2 with re-scoring via smina against three protein targets: 1SYH (ionotropic glutamate receptor), 6Y2F (<a href="https://en.wikipedia.org/wiki/3C-like_protease">SARS-CoV-2 main protease</a>), and 4LDE (beta-2 adrenoceptor).</p>
</li>
<li>
<p><strong>Chemical Reaction Substrates</strong>: The workflow models the intramolecular double hydrogen transfer in syn-sesquinorbornenes using the SEAM force field approach at the GFN-FF/GFN2-xTB level to compute activation and reaction energies.</p>
</li>
</ol>
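<p>The HOMO/LUMO calibration in the OPV workflow is a pair of linear maps, which can be sketched as follows. This is a minimal illustration using only the slopes and intercepts quoted above; the function names and the example orbital energies are hypothetical and are not taken from the Tartarus code base.</p>

```python
# Theil-Sen calibration constants quoted in the text (slope, intercept in eV).
HOMO_SLOPE, HOMO_INTERCEPT_EV = 0.8051, 2.5377
LUMO_SLOPE, LUMO_INTERCEPT_EV = 0.8788, 3.7913


def calibrate_homo(e_homo_xtb_ev: float) -> float:
    """Map a raw GFN2-xTB HOMO energy (eV) onto the DFT-calibrated scale."""
    return e_homo_xtb_ev * HOMO_SLOPE + HOMO_INTERCEPT_EV


def calibrate_lumo(e_lumo_xtb_ev: float) -> float:
    """Map a raw GFN2-xTB LUMO energy (eV) onto the DFT-calibrated scale."""
    return e_lumo_xtb_ev * LUMO_SLOPE + LUMO_INTERCEPT_EV


# Hypothetical raw orbital energies, chosen only to exercise the functions.
homo_raw_ev, lumo_raw_ev = -10.0, -8.0
print(calibrate_homo(homo_raw_ev), calibrate_lumo(lumo_raw_ev))
```

<p>Keeping the constants as named module-level values makes it easy to check them against the equations when the calibration is reused elsewhere.</p>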
<p>Each benchmark also includes a curated reference dataset for training generative models and a standardized evaluation protocol: train on 80% of the dataset, use 20% for hyperparameter optimization, then optimize structures starting from the best reference molecule with a constrained budget of 5,000 proposed compounds, a 24-hour runtime cap, and five independent repetitions.</p>
<h2 id="benchmark-tasks-datasets-and-model-comparisons">Benchmark Tasks, Datasets, and Model Comparisons</h2>
<h3 id="models-evaluated">Models Evaluated</h3>
<p>Eight generative models spanning major algorithm families were tested:</p>
<ul>
<li><strong>VAEs</strong>: SMILES-VAE and SELFIES-VAE</li>
<li><strong>Flow models</strong>: MoFlow</li>
<li><strong>Reinforcement learning</strong>: <a href="/notes/chemistry/molecular-design/generation/rl-tuned/reinvent-deep-rl-molecular-design/">REINVENT</a></li>
<li><strong>LSTM-based hill climbing</strong>: SMILES-LSTM-HC and SELFIES-LSTM-HC</li>
<li><strong>Genetic algorithms</strong>: <a href="/notes/chemistry/molecular-design/generation/search-based/graph-based-genetic-algorithm-chemical-space/">GB-GA</a> and JANUS</li>
</ul>
<h3 id="organic-photovoltaics-results">Organic Photovoltaics Results</h3>
<p>The reference dataset (CEP_SUB) contains approximately 25,000 molecules from the Harvard Clean Energy Project Database. Two objectives combine PCE with synthetic accessibility (SAscore):</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>PCE_PCBM - SAscore</th>
          <th>PCE_PCDTBT - SAscore</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Dataset</td>
          <td>7.57</td>
          <td>31.71</td>
      </tr>
      <tr>
          <td>SMILES-VAE</td>
          <td>7.44 +/- 0.28</td>
          <td>10.23 +/- 11.14</td>
      </tr>
      <tr>
          <td>SELFIES-VAE</td>
          <td>7.05 +/- 0.66</td>
          <td>29.24 +/- 0.65</td>
      </tr>
      <tr>
          <td>MoFlow</td>
          <td>7.08 +/- 0.31</td>
          <td>29.81 +/- 0.37</td>
      </tr>
      <tr>
          <td>SMILES-LSTM-HC</td>
          <td>6.69 +/- 0.40</td>
          <td>31.79 +/- 0.15</td>
      </tr>
      <tr>
          <td>SELFIES-LSTM-HC</td>
          <td>7.40 +/- 0.41</td>
          <td>30.71 +/- 1.20</td>
      </tr>
      <tr>
          <td>REINVENT</td>
          <td>7.48 +/- 0.11</td>
          <td>30.47 +/- 0.44</td>
      </tr>
      <tr>
          <td>GB-GA</td>
          <td>7.78 +/- 0.02</td>
          <td>30.24 +/- 0.80</td>
      </tr>
      <tr>
          <td>JANUS</td>
          <td>7.59 +/- 0.14</td>
          <td>31.34 +/- 0.74</td>
      </tr>
  </tbody>
</table>
<p>GB-GA achieves the best score on the first task (7.78), while SMILES-LSTM-HC leads on the second (31.79). Most models can marginally improve PCE but struggle to simultaneously improve PCE and reduce SAscore.</p>
<h3 id="organic-emitters-results">Organic Emitters Results</h3>
<p>The reference dataset (GDB-13_SUB) contains approximately 380,000 molecules filtered for conjugated pi-systems from <a href="/notes/chemistry/datasets/gdb-13/">GDB-13</a>. Three objectives target singlet-triplet gap minimization, oscillator strength maximization, and a combined multi-objective:</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Delta E(S1-T1) (eV)</th>
          <th>Oscillator strength (f12)</th>
          <th>Multi-objective</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Dataset</td>
          <td>0.020</td>
          <td>2.97</td>
          <td>-0.04</td>
      </tr>
      <tr>
          <td>SMILES-VAE</td>
          <td>0.071 +/- 0.003</td>
          <td>0.50 +/- 0.27</td>
          <td>-0.57 +/- 0.33</td>
      </tr>
      <tr>
          <td>SELFIES-VAE</td>
          <td>0.016 +/- 0.001</td>
          <td>0.36 +/- 0.31</td>
          <td>0.17 +/- 0.10</td>
      </tr>
      <tr>
          <td>MoFlow</td>
          <td>0.013 +/- 0.001</td>
          <td>0.81 +/- 0.11</td>
          <td>-0.04 +/- 0.06</td>
      </tr>
      <tr>
          <td>GB-GA</td>
          <td>0.012 +/- 0.002</td>
          <td>2.14 +/- 0.45</td>
          <td>0.07 +/- 0.03</td>
      </tr>
      <tr>
          <td>JANUS</td>
          <td>0.008 +/- 0.001</td>
          <td>2.07 +/- 0.16</td>
          <td>0.02 +/- 0.05</td>
      </tr>
  </tbody>
</table>
<p>Only JANUS, GB-GA, and SELFIES-VAE generate compounds comparable to or improving upon the best training molecules. JANUS achieves the lowest singlet-triplet gap (0.008 eV), while SELFIES-VAE achieves the highest multi-objective fitness (0.17). Some proposed structures contain reactive moieties, likely because stability is not explicitly penalized in the objective functions.</p>
<h3 id="protein-ligand-results">Protein Ligand Results</h3>
<p>The reference dataset contains approximately 152,000 molecules from the DTP Open Compound Collection, filtered for drug-likeness. Docking is performed against three protein targets using both QuickVina2 and smina re-scoring:</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>1SYH (smina)</th>
          <th>6Y2F (smina)</th>
          <th>4LDE (smina)</th>
          <th>Success rate (1SYH)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Dataset</td>
          <td>-10.2</td>
          <td>-8.2</td>
          <td>-13.1</td>
          <td>100.0%</td>
      </tr>
      <tr>
          <td>SMILES-VAE</td>
          <td>-10.4 +/- 0.6</td>
          <td>-8.9 +/- 0.8</td>
          <td>-11.1 +/- 0.4</td>
          <td>12.3%</td>
      </tr>
      <tr>
          <td>SELFIES-VAE</td>
          <td>-10.9 +/- 0.3</td>
          <td>-10.1 +/- 0.4</td>
          <td>-11.9 +/- 0.2</td>
          <td>34.8%</td>
      </tr>
      <tr>
          <td>REINVENT</td>
          <td>-12.1 +/- 0.2</td>
          <td>-11.4 +/- 0.3</td>
          <td>-13.7 +/- 0.5</td>
          <td>77.8%</td>
      </tr>
      <tr>
          <td>GB-GA</td>
          <td>-12.0 +/- 0.2</td>
          <td>-11.0 +/- 0.2</td>
          <td>-13.8 +/- 0.4</td>
          <td>72.6%</td>
      </tr>
      <tr>
          <td>JANUS</td>
          <td>-11.9 +/- 0.2</td>
          <td>-11.9 +/- 0.4</td>
          <td>-13.6 +/- 0.5</td>
          <td>68.4%</td>
      </tr>
  </tbody>
</table>
<p>No single model consistently achieves the best docking score across all three targets. REINVENT leads on 1SYH, JANUS on 6Y2F, and GB-GA on 4LDE. Both VAE models show low success rates for structural filter compliance (12-39%), while REINVENT, GAs, and LSTMs achieve 68-78%.</p>
<h3 id="chemical-reaction-substrates-results">Chemical Reaction Substrates Results</h3>
<p>The reference dataset (SNB-60K) contains approximately 60,000 syn-sesquinorbornene derivatives generated via <a href="/notes/chemistry/molecular-design/generation/search-based/stoned-selfies-chemical-space-exploration/">STONED-SELFIES</a> mutations. Four objectives target activation energy, reaction energy, and two combined metrics:</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Delta E(activation)</th>
          <th>Delta E(reaction)</th>
          <th>Delta E(act) + Delta E(rxn)</th>
          <th>-Delta E(act) + Delta E(rxn)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Dataset</td>
          <td>64.94</td>
          <td>-34.39</td>
          <td>56.48</td>
          <td>-95.25</td>
      </tr>
      <tr>
          <td>SMILES-VAE</td>
          <td>76.81 +/- 0.25</td>
          <td>-10.96 +/- 0.71</td>
          <td>71.01 +/- 0.62</td>
          <td>-90.94 +/- 1.04</td>
      </tr>
      <tr>
          <td>MoFlow</td>
          <td>70.12 +/- 2.13</td>
          <td>-20.21 +/- 4.13</td>
          <td>63.21 +/- 0.69</td>
          <td>-92.82 +/- 3.06</td>
      </tr>
      <tr>
          <td>GB-GA</td>
          <td>56.04 +/- 3.07</td>
          <td>-41.39 +/- 5.76</td>
          <td>45.20 +/- 6.78</td>
          <td>-100.07 +/- 1.35</td>
      </tr>
      <tr>
          <td>JANUS</td>
          <td>47.56 +/- 2.19</td>
          <td>-45.37 +/- 7.90</td>
          <td>39.22 +/- 3.99</td>
          <td>-97.14 +/- 1.13</td>
      </tr>
  </tbody>
</table>
<p>Only JANUS and GB-GA consistently outperform the best reference compounds. Both VAE models fail to surpass the dataset baseline on any objective. JANUS achieves the best single-objective scores for activation energy (47.56) and reaction energy (-45.37), as well as the best score on the Delta E(act) + Delta E(rxn) objective (39.22), while GB-GA has the lowest value on the -Delta E(act) + Delta E(rxn) objective (-100.07).</p>
<h2 id="key-findings-and-limitations">Key Findings and Limitations</h2>
<h3 id="central-finding-algorithm-performance-is-domain-dependent">Central Finding: Algorithm Performance is Domain-Dependent</h3>
<p>The most important result from Tartarus is that no single generative model consistently outperforms the others across all benchmark domains. This has several implications:</p>
<ul>
<li><strong>Genetic algorithms (GB-GA and JANUS) show the most consistently strong performance</strong> across benchmarks, despite being among the simplest approaches and requiring minimal pre-conditioning time (seconds vs. hours for deep models).</li>
<li><strong>VAE-based models (SMILES-VAE and SELFIES-VAE) show the weakest overall performance</strong>, often failing to surpass the best molecules in the reference datasets. Their reliance on the available training data appears to limit their effectiveness.</li>
<li><strong>REINVENT performs competitively on protein ligand tasks</strong> but shows weaker performance on other benchmarks.</li>
<li><strong>Representation matters</strong>: SELFIES-based models generally outperform their SMILES-based counterparts (e.g., SELFIES-VAE vs. SMILES-VAE), consistent with SELFIES providing 100% validity guarantees.</li>
</ul>
<h3 id="timing-analysis">Timing Analysis</h3>
<p>Training time varies dramatically across models. Both VAEs require over 9 hours of GPU training, with estimated CPU-only training times of approximately 25 days. REINVENT and MoFlow train in under 1 hour. Both GAs complete pre-conditioning in seconds and require no GPU.</p>
<h3 id="limitations-acknowledged-by-the-authors">Limitations Acknowledged by the Authors</h3>
<ul>
<li>Benchmark domains covered are not comprehensive and need expansion.</li>
<li>3D generative models are not well supported, as proposed conformers are ignored in favor of simulation-derived geometries.</li>
<li>The chemical reaction substrate benchmark requires specialized geometries (reactant, product, transition state) that most 3D generative models cannot produce.</li>
<li>Results depend heavily on both model hyperparameters and benchmark settings (compute budget, number of evaluations).</li>
<li>Objective functions may need revision when undesired structures are promoted.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>OPV Training</td>
          <td>CEP_SUB (Harvard Clean Energy Project subset)</td>
          <td>~25,000 molecules</td>
          <td>From HIPS/neural-fingerprint repository</td>
      </tr>
      <tr>
          <td>Emitter Training</td>
          <td>GDB-13_SUB (filtered GDB-13)</td>
          <td>~380,000 molecules</td>
          <td>Conjugated pi-system filter applied</td>
      </tr>
      <tr>
          <td>Ligand Training</td>
          <td>DTP Open Compound Collection (filtered)</td>
          <td>~152,000 molecules</td>
          <td>Drug-likeness and structural filters applied</td>
      </tr>
      <tr>
          <td>Reaction Training</td>
          <td>SNB-60K (STONED-SELFIES mutations)</td>
          <td>~60,000 molecules</td>
          <td>Generated from syn-sesquinorbornene core</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<p>All eight algorithms are implemented in the Tartarus repository with configuration files and installation instructions. The evaluation protocol specifies an 80/20 train/validation split, a budget of 5,000 proposed compounds per run, a 24-hour runtime cap, and five independent runs per model.</p>
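<p>The protocol can be sketched as a budget-counting wrapper around an objective plus a seeded 80/20 split. This is a minimal sketch under stated assumptions: the class and function names are hypothetical, every proposed compound (including repeats) is assumed to count against the budget, and the 24-hour runtime cap is not enforced here.</p>

```python
import random

BUDGET = 5_000  # proposed compounds per run, from the protocol
N_RUNS = 5      # independent repetitions, from the protocol


class BudgetExceeded(RuntimeError):
    """Raised when a run proposes more compounds than the budget allows."""


class BudgetedOracle:
    """Wrap a scoring function and count every proposed compound."""

    def __init__(self, score_fn, budget=BUDGET):
        self.score_fn = score_fn
        self.budget = budget
        self.calls = 0

    def __call__(self, smiles):
        if self.calls == self.budget:
            raise BudgetExceeded(f"budget of {self.budget} compounds used up")
        self.calls += 1
        return self.score_fn(smiles)


def split_reference(dataset, train_frac=0.8, seed=0):
    """Shuffle with a fixed seed, then split into train and validation parts."""
    shuffled = list(dataset)
    random.Random(seed).shuffle(shuffled)
    cut = int(train_frac * len(shuffled))
    return shuffled[:cut], shuffled[cut:]


# Toy stand-in objective (string length); a real run would call a simulation workflow.
train, valid = split_reference(["C" * n for n in range(1, 11)])
oracle = BudgetedOracle(len, budget=3)
print(len(train), len(valid), [oracle(s) for s in train[:3]])
```

<p>One fresh oracle per repetition keeps the five runs independent and makes the number of evaluations auditable after the fact.</p>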
<h3 id="models">Models</h3>
<p>Pre-trained model checkpoints are not provided. Training must be performed from scratch using the provided reference datasets and hyperparameter configurations documented in the Supporting Information.</p>
<h3 id="evaluation">Evaluation</h3>
<p>Properties are evaluated through physics-based simulation workflows (not learned surrogates). Each workflow accepts a SMILES string and returns computed properties. Key software dependencies include: Open Babel, CREST, xTB, PySCF, QuickVina2, smina, and RDKit.</p>
<h3 id="hardware">Hardware</h3>
<p>Training and sampling benchmarks were conducted using 24 CPU cores (AMD Rome 7532 @ 2.40 GHz) and a single NVIDIA A100 GPU. Simulations were run on the Beluga, Narval, Niagara, Cedar, and Sherlock supercomputing clusters.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/aspuru-guzik-group/Tartarus">Tartarus GitHub</a></td>
          <td>Code</td>
          <td>Unknown</td>
          <td>Benchmark tasks, simulation workflows, model configs</td>
      </tr>
      <tr>
          <td><a href="https://zenodo.org/badge/latestdoi/444879123">Zenodo Archive</a></td>
          <td>Dataset</td>
          <td>Unknown</td>
          <td>Reference datasets for all four benchmark domains</td>
      </tr>
      <tr>
          <td><a href="https://discord.gg/KypwPXTY2s">Discord Community</a></td>
          <td>Other</td>
          <td>N/A</td>
          <td>Discussion and collaboration channel</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Nigam, A., Pollice, R., Tom, G., Jorner, K., Willes, J., Thiede, L. A., Kundaje, A., &amp; Aspuru-Guzik, A. (2023). Tartarus: A Benchmarking Platform for Realistic And Practical Inverse Molecular Design. <em>Advances in Neural Information Processing Systems 36</em>, 3263-3306.</p>
<p><strong>Publication</strong>: NeurIPS 2023</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/aspuru-guzik-group/Tartarus">Tartarus GitHub Repository</a></li>
<li><a href="https://zenodo.org/badge/latestdoi/444879123">Zenodo Dataset Archive</a></li>
<li><a href="https://discord.gg/KypwPXTY2s">Discord Community</a></li>
</ul>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{nigam2023tartarus,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Tartarus: A Benchmarking Platform for Realistic And Practical Inverse Molecular Design}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Nigam, AkshatKumar and Pollice, Robert and Tom, Gary and Jorner, Kjell and Willes, John and Thiede, Luca A. and Kundaje, Anshul and Aspuru-Guzik, Al{\&#39;a}n}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Advances in Neural Information Processing Systems}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{36}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{3263--3306}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>SMINA Docking Benchmark for De Novo Drug Design Models</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/evaluation/smina-docking-benchmark/</link><pubDate>Mon, 23 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/evaluation/smina-docking-benchmark/</guid><description>A docking-based benchmark for evaluating de novo drug design generative models, using SMINA scoring across eight protein targets from ChEMBL.</description><content:encoded><![CDATA[<h2 id="a-docking-based-benchmark-for-de-novo-drug-design">A Docking-Based Benchmark for De Novo Drug Design</h2>
<p>This is a <strong>Resource</strong> paper. Its primary contribution is a standardized benchmark for evaluating generative models in de novo drug design. Rather than introducing a new generative method, the paper provides a reusable evaluation framework built around molecular docking, a widely used computational proxy for predicting protein-ligand binding. The benchmark uses SMINA (a fork of <a href="https://en.wikipedia.org/wiki/AutoDock">AutoDock Vina</a>) to score generated molecules against eight protein targets, offering a more realistic evaluation than commonly used proxy metrics like logP or QED.</p>
<h2 id="why-existing-benchmarks-fall-short">Why Existing Benchmarks Fall Short</h2>
<p>De novo drug design methods are typically evaluated using simple proxy tasks that do not reflect the complexity of real drug discovery. The octanol-water partition coefficient (logP) can be trivially optimized by producing unrealistic molecules. The QED drug-likeness score suffers from the same issue. Neural network-based bioactivity predictors are similarly exploitable.</p>
<p>As Coley et al. (2020) note: &ldquo;The current evaluations for generative models do not reflect the complexity of real discovery problems.&rdquo;</p>
<p>More realistic evaluation approaches exist in adjacent domains (photovoltaics, excitation energies), where physical calculations are used to both train and evaluate models. Yet de novo drug design has largely relied on the same simplistic proxies. This gap between proxy task performance and real-world utility motivates the development of a docking-based benchmark that, while still a proxy, captures more of the structural complexity involved in protein-ligand interactions.</p>
<h2 id="benchmark-design-smina-docking-with-the-vinardo-scoring-function">Benchmark Design: SMINA Docking with the Vinardo Scoring Function</h2>
<p>The benchmark is defined by three components: (1) docking software that computes a ligand&rsquo;s pose in the binding site, (2) a scoring function that evaluates the pose, and (3) a training set of compounds with precomputed docking scores.</p>
<p>The concrete instantiation uses SMINA v. 2017.11.9 with the Vinardo scoring function:</p>
<p>$$S = -0.045 \cdot G + 0.8 \cdot R - 0.035 \cdot H - 0.6 \cdot B$$</p>
<p>where $S$ is the docking score, $G$ is the gauss term, $R$ is repulsion, $H$ is the hydrophobic term, and $B$ is the non-directional hydrogen bond term. The gauss and repulsion terms measure steric interactions between the ligand and the protein, while the hydrophobic and hydrogen bond terms capture favorable non-covalent contacts.</p>
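<p>The weighted sum above can be written out directly. The sketch below assumes the four term values have already been computed by the docking software; the function name and the example term values are hypothetical and serve only to show how the weights combine.</p>

```python
# Weights of the Vinardo terms as given in the scoring equation.
VINARDO_WEIGHTS = {
    "gauss": -0.045,
    "repulsion": 0.8,
    "hydrophobic": -0.035,
    "hbond": -0.6,
}


def vinardo_score(gauss, repulsion, hydrophobic, hbond):
    """Combine precomputed term values into the docking score S."""
    w = VINARDO_WEIGHTS
    return (
        w["gauss"] * gauss
        + w["repulsion"] * repulsion
        + w["hydrophobic"] * hydrophobic
        + w["hbond"] * hbond
    )


# Hypothetical term values for a single pose.
print(vinardo_score(gauss=100.0, repulsion=1.0, hydrophobic=20.0, hbond=2.0))
```

<p>Only the repulsion weight is positive, so steric clashes raise (worsen) the score while the other three terms lower it.</p>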
<p>The benchmark includes three task variants:</p>
<ol>
<li><strong>Docking Score Function</strong>: Optimize the full Vinardo docking score (lower is better).</li>
<li><strong>Repulsion</strong>: Minimize only the repulsion component, defined as:</li>
</ol>
<p>$$
R(a_1, a_2) = \begin{cases}
d(a_1, a_2)^2 &amp; d(a_1, a_2) &lt; 0 \\
0 &amp; \text{otherwise}
\end{cases}
$$</p>
<p>where $d(a_1, a_2)$ is the inter-atomic distance minus the sum of <a href="https://en.wikipedia.org/wiki/Van_der_Waals_radius">van der Waals radii</a>.</p>
<ol start="3">
<li><strong>Hydrogen Bonding</strong>: Maximize the hydrogen bond term:</li>
</ol>
<p>$$
B(a_1, a_2) = \begin{cases}
0 &amp; (a_1, a_2) \text{ do not form H-bond} \\
1 &amp; d(a_1, a_2) &lt; -0.6 \\
0 &amp; d(a_1, a_2) \geq 0 \\
\frac{d(a_1, a_2)}{-0.6} &amp; \text{otherwise}
\end{cases}
$$</p>
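<p>The two piecewise definitions above translate into a few lines of code. This is a minimal per-atom-pair sketch: the function names are hypothetical, and deciding whether a pair can form a hydrogen bond is assumed to happen elsewhere and is passed in as a flag.</p>

```python
def surface_distance(center_distance, radius_1, radius_2):
    """Inter-atomic distance minus the sum of the van der Waals radii."""
    return center_distance - (radius_1 + radius_2)


def repulsion_term(d):
    """Return d squared when the atoms overlap (negative d), otherwise 0."""
    return min(d, 0.0) ** 2


def hbond_term(d, forms_hbond):
    """Return the non-directional hydrogen bond term for one atom pair."""
    if not forms_hbond:
        return 0.0
    # Clamping d / -0.6 to [0, 1] reproduces the three distance cases:
    # 1 below -0.6, 0 at or above 0, and a linear ramp in between.
    return min(1.0, max(0.0, d / -0.6))


# Hypothetical surface distances, used only to exercise each branch.
for d in (-1.0, -0.3, 0.2):
    print(d, repulsion_term(d), hbond_term(d, True))
```

<p>Expressing both terms with <code>min</code> and <code>max</code> avoids explicit branching while matching the case-by-case definitions.</p>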
<p>Scores are averaged over the top 5 binding poses for stability. Generated compounds are filtered by <a href="https://en.wikipedia.org/wiki/Lipinski%27s_rule_of_five">Lipinski&rsquo;s Rule of Five</a> and a minimum molecular weight of 100. Each model must generate 250 unique molecules per target.</p>
<p>Training data comes from <a href="https://en.wikipedia.org/wiki/ChEMBL">ChEMBL</a>, covering eight drug targets: 5-HT1B, 5-HT2B, ACM2, CYP2D6, ADRB1, MOR, A2A, and D2. Dataset sizes range from 1,082 (ADRB1) to 10,225 (MOR) molecules.</p>
<h2 id="experimental-evaluation-of-three-generative-models">Experimental Evaluation of Three Generative Models</h2>
<h3 id="models-tested">Models Tested</h3>
<p>Three popular generative models were evaluated:</p>
<ul>
<li><strong><a href="/notes/chemistry/molecular-design/generation/latent-space/automatic-chemical-design-vae/">CVAE</a></strong> (Chemical Variational Autoencoder): A VAE operating on <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> strings.</li>
<li><strong><a href="/notes/chemistry/molecular-design/generation/latent-space/grammar-variational-autoencoder/">GVAE</a></strong> (Grammar Variational Autoencoder): Extends CVAE by enforcing grammatical correctness of generated SMILES.</li>
<li><strong><a href="/notes/chemistry/molecular-design/generation/rl-tuned/reinvent-deep-rl-molecular-design/">REINVENT</a></strong>: A recurrent neural network trained first on ChEMBL in a supervised manner, then fine-tuned with reinforcement learning using docking scores as rewards.</li>
</ul>
<p>For CVAE and GVAE, molecules are generated by sampling from the latent space and taking 50 gradient steps to optimize an MLP that predicts the docking score. For REINVENT, a random forest model predicts docking scores from ECFP fingerprints, and the reward combines this prediction with the QED score.</p>
<h3 id="baselines">Baselines</h3>
<p>Two baselines provide context:</p>
<ul>
<li><strong>Training set</strong>: The top 50%, 10%, and 1% of docking scores from the ChEMBL training set.</li>
<li><strong><a href="/notes/chemistry/datasets/zinc-22/">ZINC</a> subset</strong>: A random sample of ~9.2 million drug-like molecules from ZINC, with the same percentile breakdowns.</li>
</ul>
<p>Diversity is measured as the mean <a href="https://en.wikipedia.org/wiki/Jaccard_index">Tanimoto distance</a> (using 1024-bit ECFP with radius 2) between all pairs of generated molecules.</p>
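<p>The diversity metric can be sketched in plain Python as the mean pairwise Tanimoto distance. Fingerprints are represented here as sets of on-bit indices, which is an assumption for illustration: in practice they would be 1024-bit ECFP vectors produced by a cheminformatics toolkit, and the function names below are hypothetical.</p>

```python
from itertools import combinations


def tanimoto_distance(bits_a, bits_b):
    """Return 1 minus the Jaccard (Tanimoto) similarity of two on-bit sets."""
    union = len(bits_a.union(bits_b))
    if union == 0:
        return 0.0  # two empty fingerprints are treated as identical
    return 1.0 - len(bits_a.intersection(bits_b)) / union


def mean_pairwise_diversity(fingerprints):
    """Average the Tanimoto distance over all unordered molecule pairs."""
    pairs = list(combinations(fingerprints, 2))
    if not pairs:
        return 0.0  # fewer than two molecules: no pairs to compare
    return sum(tanimoto_distance(a, b) for a, b in pairs) / len(pairs)


# Hypothetical on-bit sets standing in for three ECFP fingerprints.
toy_fingerprints = [{1, 2, 3}, {2, 3, 4}, {7, 8}]
print(mean_pairwise_diversity(toy_fingerprints))
```

<p>The pairwise loop is quadratic in the number of molecules, which is inexpensive for the 250 molecules generated per target.</p>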
<h3 id="key-results">Key Results</h3>
<table>
  <thead>
      <tr>
          <th>Task</th>
          <th>Model</th>
          <th>5-HT1B Score</th>
          <th>5-HT1B Diversity</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Docking Score</td>
          <td>CVAE</td>
          <td>-4.647</td>
          <td>0.907</td>
      </tr>
      <tr>
          <td>Docking Score</td>
          <td>GVAE</td>
          <td>-4.955</td>
          <td>0.901</td>
      </tr>
      <tr>
          <td>Docking Score</td>
          <td>REINVENT</td>
          <td>-9.774</td>
          <td>0.506</td>
      </tr>
      <tr>
          <td>Docking Score</td>
          <td>ZINC (10%)</td>
          <td>-9.894</td>
          <td>0.862</td>
      </tr>
      <tr>
          <td>Docking Score</td>
          <td>ZINC (1%)</td>
          <td>-10.496</td>
          <td>0.861</td>
      </tr>
      <tr>
          <td>Docking Score</td>
          <td>Train (10%)</td>
          <td>-10.837</td>
          <td>0.749</td>
      </tr>
  </tbody>
</table>
<p>On the full docking score task, where lower scores are better, CVAE and GVAE fail to match even the mean ZINC docking score. REINVENT performs substantially better (e.g., -9.774 on 5-HT1B) but still falls short of the top 10% ZINC scores (-9.894 on 5-HT1B) in most cases. The exception is ACM2, where REINVENT&rsquo;s score (-9.775) is better than the ZINC top-10% score (-8.282).</p>
<p>On the repulsion task, all three models fail to outperform the top 10% ZINC scores. On the hydrogen bonding task (the easiest), GVAE and REINVENT nearly match the top 1% ZINC scores, suggesting that optimizing individual scoring components is more tractable than the full docking score.</p>
<p>A consistent finding across all experiments is that REINVENT generates substantially less diverse molecules than the training set (e.g., 0.506 vs. 0.787 mean Tanimoto distance on 5-HT1B). The t-SNE visualizations show generated molecules clustering in a single dense region, separate from the training data, regardless of optimization target.</p>
<p>The paper also notes a moderately strong correlation between docking scores and molecular weight or the number of rotatable bonds. Generated compounds achieve better docking scores at the same molecular weight after optimization, suggesting the models learn some structural preferences rather than simply exploiting molecular size.</p>
<h2 id="limitations-of-current-generative-models-for-drug-design">Limitations of Current Generative Models for Drug Design</h2>
<p>The main finding is negative: popular generative models for de novo drug design struggle to generate molecules that dock well when trained on realistically sized datasets (1,000 to 10,000 compounds). Even the best-performing model (REINVENT) generally cannot outperform the top 10% of a random ZINC subset on the full docking score task.</p>
<p>The authors acknowledge several limitations:</p>
<ul>
<li><strong>Docking is itself a proxy</strong>: The SMINA docking score is only an approximation of true binding affinity. The fact that even this simpler proxy is challenging should raise concerns about these models&rsquo; readiness for real drug discovery pipelines.</li>
<li><strong>Limited model selection</strong>: Only three models were tested (CVAE, GVAE, REINVENT). The authors note that CVAE and GVAE were not designed for small training sets, and REINVENT may not represent the state of the art in all respects.</li>
<li><strong>ML-based scoring surrogate</strong>: All models use an ML model (MLP or random forest) to predict docking scores during generation, rather than running SMINA directly. This introduces an additional approximation layer.</li>
<li><strong>No similarity constraints</strong>: The benchmark does not constrain the distance between generated and training molecules, so a trivial baseline is to simply return the training set.</li>
</ul>
<p>On a more positive note, the tested models perform well on the simplest subtask (hydrogen bonding), suggesting that optimizing docking scores from limited data is attainable but challenging. The benchmark has already been adopted by other groups, notably Nigam et al. (2021) for evaluating their JANUS genetic algorithm.</p>
<p>Future directions include adding similarity constraints, extending to additional protein targets, and using the benchmark to evaluate newer structure-based generative models that employ equivariant neural networks.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training/Evaluation</td>
          <td>ChEMBL (8 targets)</td>
          <td>1,082-10,225 molecules per target</td>
          <td>90/10 train/test split</td>
      </tr>
      <tr>
          <td>Baseline</td>
          <td>ZINC 15 subset</td>
          <td>~9.2M drug-like molecules</td>
          <td>In-stock, standard reactivity, drug-like</td>
      </tr>
      <tr>
          <td>Protein structures</td>
          <td><a href="https://en.wikipedia.org/wiki/Protein_Data_Bank">Protein Data Bank</a></td>
          <td>8 structures</td>
          <td>Cleaned with the Schrödinger modeling package</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>CVAE/GVAE: Fine-tuned 5 epochs on target data, then 50 gradient steps in latent space to optimize MLP-predicted score</li>
<li>REINVENT: Pretrained on ChEMBL, fine-tuned with RL; reward = random forest prediction * QED score</li>
<li>All docking performed with SMINA v. 2017.11.9 using Vinardo scoring function in score_only mode</li>
<li>Scores averaged over top 5 binding poses</li>
<li>Filtering: Lipinski Rule of Five, minimum molecular weight 100</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Description</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Mean docking score</td>
          <td>Average over 250 generated molecules</td>
          <td>Lower is better for docking score and repulsion</td>
      </tr>
      <tr>
          <td>Diversity</td>
          <td>Mean Tanimoto distance (ECFP, r=2)</td>
          <td>Higher is more diverse</td>
      </tr>
      <tr>
          <td>ZINC percentile baselines</td>
          <td>Top 50%, 10%, 1% from random ZINC subset</td>
          <td>Task considered &ldquo;solved&rdquo; if the generated molecules beat the top 1% of the ZINC subset</td>
      </tr>
  </tbody>
</table>
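<p>The metrics in the table can be sketched in plain Python. This is a minimal illustration and not the benchmark&rsquo;s own code: fingerprints are modeled as sets of on-bit indices in place of real ECFP vectors, and the helper names (<code>tanimoto_distance</code>, <code>diversity</code>, <code>zinc_threshold</code>) and all numbers are assumptions made for the example.</p>
<pre><code class="language-python" data-lang="python">from itertools import combinations
from statistics import mean


def tanimoto_distance(fp_a, fp_b):
    """Return 1 - |intersection| / |union| for two sets of on-bit indices."""
    union = fp_a.union(fp_b)
    if not union:
        return 0.0
    return 1.0 - len(fp_a.intersection(fp_b)) / len(union)


def diversity(fingerprints):
    """Mean pairwise Tanimoto distance over all distinct pairs."""
    return mean(tanimoto_distance(a, b) for a, b in combinations(fingerprints, 2))


def zinc_threshold(zinc_scores, top_fraction):
    """Docking score of the last molecule inside the best `top_fraction`.

    Docking scores are lower-is-better, so an ascending sort puts the
    best molecules first.
    """
    ranked = sorted(zinc_scores)
    cutoff = max(1, int(top_fraction * len(ranked)))
    return ranked[cutoff - 1]


# Toy data (hypothetical numbers, not results from the paper).
generated_scores = [-9.1, -8.7, -8.9, -9.4]
zinc_scores = [-10.0, -9.0, -8.0, -7.0, -6.0, -5.0, -4.0, -3.0, -2.0, -1.0]
fingerprints = [{1, 2, 3}, {2, 3, 4}, {7, 8}]

mean_score = mean(generated_scores)
top10 = zinc_threshold(zinc_scores, 0.10)
print(mean_score, top10, diversity(fingerprints))
</code></pre>
<p>Because docking scores are lower-is-better, the mean generated score would have to fall below the returned threshold to beat that ZINC percentile.</p>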
<h3 id="hardware">Hardware</h3>
<p>Not specified in the paper.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/cieplinski-tobiasz/smina-docking-benchmark">smina-docking-benchmark</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Benchmark code, data, evaluation notebooks</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Cieplinski, T., Danel, T., Podlewska, S., &amp; Jastrzebski, S. (2023). Generative Models Should at Least Be Able to Design Molecules That Dock Well: A New Benchmark. <em>Journal of Chemical Information and Modeling</em>, 63(11), 3238-3247. <a href="https://doi.org/10.1021/acs.jcim.2c01355">https://doi.org/10.1021/acs.jcim.2c01355</a></p>
<p><strong>Publication</strong>: Journal of Chemical Information and Modeling 2023</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/cieplinski-tobiasz/smina-docking-benchmark">GitHub Repository</a></li>
</ul>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{cieplinski2023generative,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Generative Models Should at Least Be Able to Design Molecules That Dock Well: A New Benchmark}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Cieplinski, Tobiasz and Danel, Tomasz and Podlewska, Sabina and Jastrzebski, Stanislaw}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Journal of Chemical Information and Modeling}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{63}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{11}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{3238--3247}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{American Chemical Society}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1021/acs.jcim.2c01355}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Genetic Algorithms as Baselines for Molecule Generation</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/search-based/genetic-algorithms-molecule-generation-baselines/</link><pubDate>Mon, 23 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/search-based/genetic-algorithms-molecule-generation-baselines/</guid><description>Genetic algorithms outperform many deep learning methods for molecule generation. Tripp and Hernández-Lobato propose the GA criterion.</description><content:encoded><![CDATA[<h2 id="a-position-paper-on-molecular-generation-baselines">A Position Paper on Molecular Generation Baselines</h2>
<p>This is a <strong>Position</strong> paper that argues genetic algorithms (GAs) are underused and underappreciated as baselines in the molecular generation community. The primary contribution is empirical evidence that a simple GA implementation (MOL_GA) matches or outperforms many sophisticated deep learning methods on standard benchmarks. The authors propose the &ldquo;GA criterion&rdquo; as a minimum bar for evaluating new molecular generation algorithms.</p>
<h2 id="why-molecular-generation-may-be-easier-than-assumed">Why Molecular Generation May Be Easier Than Assumed</h2>
<p>Drug discovery is fundamentally a molecular generation task, and many machine learning methods have been proposed for it (Du et al., 2022). The problem has many variants, from unconditional generation of novel molecules to directed optimization of specific molecular properties.</p>
<p>The authors observe that generating valid molecules is, in some respects, straightforward. The rules governing molecular validity are well-defined bond constraints that can be checked using standard cheminformatics software like <a href="https://en.wikipedia.org/wiki/RDKit">RDKit</a>. This means new molecules can be generated simply by adding, removing, or substituting fragments of known molecules. When applied iteratively, this is exactly what a genetic algorithm does. Despite this, many papers in the field propose complex deep learning methods without adequately comparing to simple GA baselines.</p>
<h2 id="the-ga-criterion-for-evaluating-new-methods">The GA Criterion for Evaluating New Methods</h2>
<p>The core proposal is the <strong>GA criterion</strong>: new methods in molecular generation should offer some clear advantage over genetic algorithms. This advantage can be:</p>
<ul>
<li><strong>Empirical</strong>: outperforming GAs on relevant benchmarks</li>
<li><strong>Conceptual</strong>: identifying and overcoming a specific limitation of randomly modifying known molecules</li>
</ul>
<p>The authors argue that the current state of molecular generation research reflects poor empirical practices, where comprehensive baseline evaluation is treated as optional rather than essential.</p>
<h2 id="genetic-algorithm-framework-and-benchmark-experiments">Genetic Algorithm Framework and Benchmark Experiments</h2>
<h3 id="how-genetic-algorithms-work-for-molecules">How Genetic Algorithms Work for Molecules</h3>
<p>GAs operate through the following iterative procedure:</p>
<ol>
<li>Start with an initial population $P$ of molecules</li>
<li>Sample a subset $S \subseteq P$ from the population (possibly biased toward better molecules)</li>
<li>Generate new molecules $N$ from $S$ via mutation and crossover operations</li>
<li>Select a new population $P'$ from $P \cup N$ (e.g., keep the highest-scoring molecules)</li>
<li>Set $P \leftarrow P'$ and repeat from step 2</li>
</ol>
<p>The MOL_GA implementation uses:</p>
<ul>
<li><strong>Quantile-based sampling</strong> (step 2): molecules are sampled from the top quantiles of the population using a log-uniform distribution over quantile thresholds:</li>
</ul>
<p>$$
u \sim \mathcal{U}[-3, 0], \quad \epsilon = 10^{u}
$$</p>
<p>A molecule is drawn uniformly from the top $\epsilon$ fraction of the population.</p>
<ul>
<li><strong>Mutation and crossover</strong> (step 3): graph-based operations from <a href="/notes/chemistry/molecular-design/generation/search-based/graph-based-genetic-algorithm-chemical-space/">Jensen (2019)</a>, as implemented in the <a href="/notes/chemistry/molecular-design/generation/evaluation/guacamol-benchmarking-de-novo-molecular-design/">GuacaMol benchmark (Brown et al., 2019)</a></li>
<li><strong>Greedy population selection</strong> (step 4): molecules with the highest scores are retained</li>
</ul>
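<p>The loop and the three MOL_GA choices above can be sketched on a toy problem. This is an illustrative sketch under stated assumptions, not the MOL_GA code: bit tuples stand in for molecules, <code>toy_mutate</code> and <code>toy_crossover</code> are hypothetical stand-ins for the graph-based operators, each child is assumed to come from two sampled parents, and the budget counts offspring only.</p>
<pre><code class="language-python" data-lang="python">import math
import random


def sample_parent(ranked, rng):
    """Quantile-based sampling from a best-first ranked population.

    Draw u ~ U[-3, 0], set eps = 10**u, then pick uniformly from the
    top eps fraction (always at least one molecule).
    """
    eps = 10 ** rng.uniform(-3.0, 0.0)
    top_k = max(1, math.ceil(eps * len(ranked)))
    return rng.choice(ranked[:top_k])


def run_ga(score, mutate, crossover, population, generation_size=5,
           population_size=20, budget=200, seed=0):
    rng = random.Random(seed)
    population = list(population)
    # A real implementation would cache scores; this sketch recomputes them.
    for _ in range(budget // generation_size):
        ranked = sorted(population, key=score, reverse=True)  # best first
        offspring = []
        for _ in range(generation_size):
            child = crossover(sample_parent(ranked, rng),
                              sample_parent(ranked, rng), rng)
            offspring.append(mutate(child, rng))
        # Greedy selection: keep the highest-scoring distinct candidates.
        pool = set(population).union(offspring)
        population = sorted(pool, key=score, reverse=True)[:population_size]
    return population


def toy_score(bits):
    return sum(bits)


def toy_mutate(bits, rng):
    i = rng.randrange(len(bits))
    return bits[:i] + (1 - bits[i],) + bits[i + 1:]


def toy_crossover(a, b, rng):
    cut = rng.randrange(1, len(a))
    return a[:cut] + b[cut:]


init_rng = random.Random(1)
start = [tuple(init_rng.randint(0, 1) for _ in range(12)) for _ in range(20)]
final = run_ga(toy_score, toy_mutate, toy_crossover, start)
print(toy_score(final[0]))
</code></pre>
<p>Since greedy selection always retains the current best candidate, the top score in this sketch can never decrease from one iteration to the next.</p>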
<h3 id="unconditional-generation-on-zinc-250k">Unconditional Generation on ZINC 250K</h3>
<p>The first experiment evaluates unconditional molecule generation, where the task is to produce novel, valid, and unique molecules distinct from a reference set (ZINC 250K). Success is measured by validity, novelty (at 10,000 generated molecules), and uniqueness.</p>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>Paper</th>
          <th>Validity</th>
          <th>Novelty@10k</th>
          <th>Uniqueness</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>JT-VAE</td>
          <td>Jin et al. (2018)</td>
          <td>99.8%</td>
          <td>100%</td>
          <td>100%</td>
      </tr>
      <tr>
          <td>GCPN</td>
          <td>You et al. (2018)</td>
          <td>100%</td>
          <td>100%</td>
          <td>99.97%</td>
      </tr>
      <tr>
          <td><a href="/notes/chemistry/molecular-design/generation/rl-tuned/molecularrnn-graph-generation-optimized-properties/">MolecularRNN</a></td>
          <td>Popova et al. (2019)</td>
          <td>100%</td>
          <td>100%</td>
          <td>99.89%</td>
      </tr>
      <tr>
          <td>Graph NVP</td>
          <td>Madhawa et al. (2019)</td>
          <td>100%</td>
          <td>100%</td>
          <td>94.80%</td>
      </tr>
      <tr>
          <td>Graph AF</td>
          <td>Shi et al. (2020)</td>
          <td>100%</td>
          <td>100%</td>
          <td>99.10%</td>
      </tr>
      <tr>
          <td>MoFlow</td>
          <td>Zang and Wang (2020)</td>
          <td>100%</td>
          <td>100%</td>
          <td>99.99%</td>
      </tr>
      <tr>
          <td>GraphCNF</td>
          <td>Lippe and Gavves (2020)</td>
          <td>96.35%</td>
          <td>99.98%</td>
          <td>99.98%</td>
      </tr>
      <tr>
          <td>Graph DF</td>
          <td>Luo et al. (2021)</td>
          <td>100%</td>
          <td>100%</td>
          <td>99.16%</td>
      </tr>
      <tr>
          <td>ModFlow</td>
          <td>Verma et al. (2022)</td>
          <td>98.1%</td>
          <td>100%</td>
          <td>99.3%</td>
      </tr>
      <tr>
          <td>GraphEBM</td>
          <td>Liu et al. (2021)</td>
          <td>99.96%</td>
          <td>100%</td>
          <td>98.79%</td>
      </tr>
      <tr>
          <td>AddCarbon</td>
          <td>Renz et al. (2019)</td>
          <td>100%</td>
          <td>99.94%</td>
          <td>99.86%</td>
      </tr>
      <tr>
          <td>MOL_GA</td>
          <td>(this paper)</td>
          <td>99.76%</td>
          <td>99.94%</td>
          <td>98.60%</td>
      </tr>
  </tbody>
</table>
<p>Every method scores at or near 100% on all three metrics (the lowest entry in the table is Graph NVP&rsquo;s 94.80% uniqueness), so unconditional molecule generation is not a particularly discriminative benchmark. The authors note that generation speed (molecules per second) is an important dimension missing from these comparisons, and one where simple methods like GAs have a clear advantage.</p>
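<p>For concreteness, the three metrics can be computed as follows. This sketch assumes a common convention (validity over all generated molecules, uniqueness over the valid ones, novelty over the distinct valid ones); the <code>is_valid</code> predicate is a hypothetical placeholder for a real validity check such as parsing with RDKit, and molecules are assumed to be canonicalized strings.</p>
<pre><code class="language-python" data-lang="python">def generation_metrics(generated, reference, is_valid):
    """Validity, uniqueness and novelty for a list of generated molecules.

    Assumes at least one generated molecule is valid.
    """
    valid = [m for m in generated if is_valid(m)]
    unique = set(valid)
    novel = unique.difference(reference)
    return {
        "validity": len(valid) / len(generated),
        "uniqueness": len(unique) / len(valid),
        "novelty": len(novel) / len(unique),
    }


def toy_is_valid(smiles):
    # Hypothetical stand-in: a "?" marks an unparseable string.
    return "?" not in smiles


generated = ["CCO", "CCO", "CCN", "C?C", "CCC"]
reference = {"CCC", "CCCl"}
metrics = generation_metrics(generated, reference, toy_is_valid)
print(metrics)
</code></pre>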
<h3 id="molecule-optimization-on-the-pmo-benchmark">Molecule Optimization on the PMO Benchmark</h3>
<p>The second experiment evaluates directed molecule optimization on the <a href="/notes/chemistry/molecular-design/generation/evaluation/pmo-sample-efficient-molecular-optimization/">Practical Molecular Optimization (PMO) benchmark (Gao et al., 2022)</a>, which measures the ability to find molecules optimizing a scalar objective function $f: \mathcal{M} \to \mathbb{R}$ with a budget of 10,000 evaluations.</p>
<p>A key insight is that previous GA implementations in PMO used large generation sizes ($\approx 100$), which limits the number of improvement iterations. The authors set the generation size to 5, allowing approximately 2,000 iterations of improvement within the same evaluation budget.</p>
<table>
  <thead>
      <tr>
          <th>Task</th>
          <th><a href="/notes/chemistry/molecular-design/generation/rl-tuned/reinvent-deep-rl-molecular-design/">REINVENT</a></th>
          <th>Graph GA</th>
          <th>MOL_GA</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>albuterol_similarity</td>
          <td>0.882 +/- 0.006</td>
          <td>0.838 +/- 0.016</td>
          <td><strong>0.896 +/- 0.035</strong></td>
      </tr>
      <tr>
          <td>amlodipine_mpo</td>
          <td>0.635 +/- 0.035</td>
          <td>0.661 +/- 0.020</td>
          <td><strong>0.688 +/- 0.039</strong></td>
      </tr>
      <tr>
          <td>celecoxib_rediscovery</td>
          <td><strong>0.713 +/- 0.067</strong></td>
          <td>0.630 +/- 0.097</td>
          <td>0.567 +/- 0.083</td>
      </tr>
      <tr>
          <td>drd2</td>
          <td>0.945 +/- 0.007</td>
          <td><strong>0.964 +/- 0.012</strong></td>
          <td>0.936 +/- 0.016</td>
      </tr>
      <tr>
          <td>fexofenadine_mpo</td>
          <td>0.784 +/- 0.006</td>
          <td>0.760 +/- 0.011</td>
          <td><strong>0.825 +/- 0.019</strong></td>
      </tr>
      <tr>
          <td>isomers_c9h10n2o2pf2cl</td>
          <td>0.642 +/- 0.054</td>
          <td>0.719 +/- 0.047</td>
          <td><strong>0.865 +/- 0.012</strong></td>
      </tr>
      <tr>
          <td>sitagliptin_mpo</td>
          <td>0.021 +/- 0.003</td>
          <td>0.433 +/- 0.075</td>
          <td><strong>0.582 +/- 0.040</strong></td>
      </tr>
      <tr>
          <td>zaleplon_mpo</td>
          <td>0.358 +/- 0.062</td>
          <td>0.346 +/- 0.032</td>
          <td><strong>0.519 +/- 0.029</strong></td>
      </tr>
      <tr>
          <td><strong>Sum (23 tasks)</strong></td>
          <td>14.196</td>
          <td>13.751</td>
          <td><strong>14.708</strong></td>
      </tr>
      <tr>
          <td><strong>Rank</strong></td>
          <td>2</td>
          <td>3</td>
          <td><strong>1</strong></td>
      </tr>
  </tbody>
</table>
<p>MOL_GA achieves the highest aggregate score across all 23 PMO tasks, outperforming both the previous best GA (Graph GA) and the previous best overall method (REINVENT). The authors read this result as saying more about how the baselines in PMO were tuned than about MOL_GA being an especially strong method, since MOL_GA is essentially the same algorithm as Graph GA with different hyperparameters.</p>
<h2 id="implications-for-molecular-generation-research">Implications for Molecular Generation Research</h2>
<p>The key findings and arguments are:</p>
<ol>
<li>
<p><strong>GAs match or outperform deep learning methods</strong> on standard molecular generation benchmarks, both for unconditional generation and directed optimization.</p>
</li>
<li>
<p><strong>Hyperparameter choices matter significantly</strong>: MOL_GA&rsquo;s strong performance on PMO comes partly from using a smaller generation size (5 vs. ~100), which allows more iterations of refinement within the same evaluation budget.</p>
</li>
<li>
<p><strong>The GA criterion should be enforced in peer review</strong>: new molecular generation methods should demonstrate a clear advantage over GAs, whether empirical or conceptual.</p>
</li>
<li>
<p><strong>Deep learning methods may implicitly do what GAs do explicitly</strong>: many generative models are trained on datasets of known molecules, so the novel molecules they produce may simply be variants of their training data. The authors consider this an important direction for future investigation.</p>
</li>
<li>
<p><strong>Poor empirical practices are widespread</strong>: the paper argues that many experiments in molecule generation are conducted with an explicit desired outcome (that the novel algorithm is the best), leading to inadequate baseline comparisons.</p>
</li>
</ol>
<p>The authors are careful to note that this result should not be interpreted as GAs being exceptional algorithms. Rather, it is an indication that more complex methods have made surprisingly little progress beyond what simple heuristic search can achieve.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Unconditional generation</td>
          <td>ZINC 250K</td>
          <td>250,000 molecules</td>
          <td>Reference set for novelty evaluation</td>
      </tr>
      <tr>
          <td>Directed optimization</td>
          <td>PMO benchmark</td>
          <td>23 tasks</td>
          <td>10,000 evaluation budget per task</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>GA implementation</strong>: MOL_GA package, using graph-based mutation and crossover from Jensen (2019) via the GuacaMol implementation</li>
<li><strong>Generation size</strong>: 5 molecules per iteration (allowing ~2,000 iterations with 10,000 evaluations)</li>
<li><strong>Population selection</strong>: Greedy (highest-scoring molecules retained)</li>
<li><strong>Sampling</strong>: Quantile-based with log-uniform distribution over quantile thresholds</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Benchmark</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Validity, Novelty@10k, Uniqueness</td>
          <td>ZINC 250K unconditional</td>
          <td>Calculated using <a href="/notes/chemistry/molecular-design/generation/evaluation/molecular-sets-moses/">MOSES package</a></td>
      </tr>
      <tr>
          <td>AUC top-10 scores</td>
          <td>PMO benchmark</td>
          <td>23 optimization tasks with 10,000 evaluation budget</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>The paper does not specify hardware requirements. Given that GAs are computationally lightweight compared to deep learning methods, standard CPU hardware is likely sufficient.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/AustinT/mol_ga">MOL_GA</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Python package for molecular genetic algorithms</td>
      </tr>
      <tr>
          <td><a href="https://pypi.org/project/mol-ga/">MOL_GA on PyPI</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>pip-installable package</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Tripp, A., &amp; Hernández-Lobato, J. M. (2023). Genetic algorithms are strong baselines for molecule generation. <em>arXiv preprint arXiv:2310.09267</em>. <a href="https://arxiv.org/abs/2310.09267">https://arxiv.org/abs/2310.09267</a></p>
<p><strong>Publication</strong>: arXiv preprint, 2023</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/AustinT/mol_ga">MOL_GA Python Package (GitHub)</a></li>
<li><a href="https://pypi.org/project/mol-ga/">MOL_GA on PyPI</a></li>
</ul>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{tripp2023genetic,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Genetic algorithms are strong baselines for molecule generation}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Tripp, Austin and Hern{\&#39;a}ndez-Lobato, Jos{\&#39;e} Miguel}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{arXiv preprint arXiv:2310.09267}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>RetMol: Retrieval-Based Controllable Molecule Generation</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/autoregressive/retmol-retrieval-molecule-generation/</link><pubDate>Sun, 22 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/autoregressive/retmol-retrieval-molecule-generation/</guid><description>RetMol uses retrieval-augmented generation to steer a pre-trained molecular model toward desired properties using only a handful of exemplar molecules.</description><content:encoded><![CDATA[<h2 id="retrieval-augmented-generation-for-molecules">Retrieval-Augmented Generation for Molecules</h2>
<p>This is a <strong>Method</strong> paper that introduces RetMol, a retrieval-based framework for controllable molecule generation. The key idea is to guide a pre-trained generative model using a small set of exemplar molecules that partially satisfy the desired design criteria, retrieved from a task-specific database. The approach requires no task-specific fine-tuning of the generative backbone and works effectively with very few exemplar molecules (as few as 23).</p>
<h2 id="limitations-of-existing-controllable-generation">Limitations of Existing Controllable Generation</h2>
<p>Existing approaches to controllable molecule generation fall into three categories, each with drawbacks:</p>
<ol>
<li><strong>Reinforcement learning (RL)-based methods</strong> require task-specific fine-tuning of the generative model for each new objective</li>
<li><strong>Supervised learning (SL)-based methods</strong> need molecules with desired properties as training data, which may be scarce</li>
<li><strong>Latent optimization-based methods</strong> require training property predictors in the latent space, which is challenging with limited active molecules and incompatible with variable-length latent spaces like those in transformers</li>
</ol>
<p>RetMol addresses all three issues by keeping the generative backbone frozen and using a lightweight, task-agnostic retrieval module that can be applied to new tasks simply by swapping the retrieval database.</p>
<h2 id="the-retmol-framework">The RetMol Framework</h2>
<p>RetMol consists of four components built around a pre-trained encoder-decoder backbone (<a href="/notes/chemistry/molecular-design/generation/autoregressive/chemformer/">Chemformer</a>, a BART variant trained on ZINC):</p>
<h3 id="retrieval-database">Retrieval Database</h3>
<p>A task-specific collection of exemplar molecules that at least partially satisfy the design criteria. The database can be very small (e.g., 23 known inhibitors for the SARS-CoV-2 task) and is dynamically updated during inference with newly generated molecules.</p>
<h3 id="molecule-retriever">Molecule Retriever</h3>
<p>A heuristic-based module that selects the $K$ most relevant exemplar molecules (default $K = 10$). It first constructs a feasible set of molecules satisfying all constraints, then selects those with the best property scores. If too few molecules satisfy all constraints, it progressively relaxes constraints until enough candidates are available.</p>
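<p>A rough sketch of this retrieve-then-relax heuristic is shown below. The relaxation order is an assumption of the sketch (constraints are listed from most to least important and dropped from the end), and the dictionary-based molecules, property names, and thresholds are hypothetical.</p>
<pre><code class="language-python" data-lang="python">import operator


def retrieve(database, constraints, objective, k=10):
    """Pick the k best-scoring molecules that satisfy the active constraints.

    `constraints` is a list of (property_fn, compare, bound) triples.
    If fewer than k molecules are feasible, the last constraint is
    dropped and the search is repeated.
    """
    top = []
    for n_active in range(len(constraints), -1, -1):
        active = constraints[:n_active]
        feasible = [m for m in database
                    if all(compare(prop(m), bound)
                           for prop, compare, bound in active)]
        top = sorted(feasible, key=objective, reverse=True)[:k]
        if len(top) == k:
            break
    return top


# Toy database with made-up property values.
database = [
    {"id": "a", "qed": 0.9, "sa": 3.0, "activity": 0.7},
    {"id": "b", "qed": 0.7, "sa": 5.0, "activity": 0.9},
    {"id": "c", "qed": 0.5, "sa": 2.0, "activity": 0.8},
    {"id": "d", "qed": 0.8, "sa": 3.5, "activity": 0.2},
]
constraints = [
    (operator.itemgetter("qed"), operator.ge, 0.6),
    (operator.itemgetter("sa"), operator.le, 4.0),
]
picked = retrieve(database, constraints, operator.itemgetter("activity"), k=3)
print([m["id"] for m in picked])
</code></pre>
<p>In this toy run only two molecules satisfy both constraints, so the synthetic accessibility bound is dropped and three molecules are returned, ranked by the objective.</p>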
<h3 id="information-fusion-via-cross-attention">Information Fusion via Cross-Attention</h3>
<p>The core trainable component. Retrieved exemplar embeddings are fused with the input molecule embedding using cross-attention:</p>
<p>$$\boldsymbol{e} = f_{\text{CA}}(\boldsymbol{e}_{\text{in}}, \boldsymbol{E}_r; \theta) = \text{Attn}(\text{Query}(\boldsymbol{e}_{\text{in}}), \text{Key}(\boldsymbol{E}_r)) \cdot \text{Value}(\boldsymbol{E}_r)$$</p>
<p>where $\boldsymbol{e}_{\text{in}} = \text{Enc}(x_{\text{in}}) \in \mathbb{R}^{L \times D}$ is the input embedding and $\boldsymbol{E}_r = [\boldsymbol{e}_r^1, \ldots, \boldsymbol{e}_r^K]$ are the retrieved exemplar embeddings. This module adds less than 5% parameter overhead (460K parameters over the 10M base model).</p>
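<p>The fusion equation can be illustrated with a single-head cross-attention in pure Python. This is a sketch under assumptions, not RetMol&rsquo;s implementation: the exemplar embeddings are assumed to be stacked row-wise into one matrix, the logits are scaled by $\sqrt{D}$ as in standard attention (the equation above does not show the scaling), and the projection matrices are set to the identity for the toy example.</p>
<pre><code class="language-python" data-lang="python">import math


def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]


def softmax(row):
    peak = max(row)
    exps = [math.exp(v - peak) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def cross_attention(e_in, e_r, w_q, w_k, w_v):
    """Fuse an L x D input embedding with N x D stacked exemplar rows."""
    q = matmul(e_in, w_q)            # queries come from the input molecule
    k = matmul(e_r, w_k)             # keys come from the exemplars
    v = matmul(e_r, w_v)             # values come from the exemplars
    scale = math.sqrt(len(q[0]))
    k_t = [list(col) for col in zip(*k)]
    logits = matmul(q, k_t)          # L x N similarity scores
    weights = [softmax([x / scale for x in row]) for row in logits]
    return matmul(weights, v)        # L x D fused embedding


identity = [[1.0, 0.0], [0.0, 1.0]]
e_in = [[1.0, 0.0]]                  # L = 1, D = 2
e_r = [[1.0, 0.0], [0.0, 1.0]]       # two exemplar rows
fused = cross_attention(e_in, e_r, identity, identity, identity)
print(fused)
</code></pre>
<p>The output keeps the shape of the input embedding, which is what lets the frozen decoder consume the fused representation unchanged.</p>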
<h3 id="self-supervised-training-nearest-neighbor-prediction">Self-Supervised Training: Nearest Neighbor Prediction</h3>
<p>Rather than reconstructing the input molecule (which would make the retrieval module unnecessary), RetMol trains the fusion module to predict the nearest neighbor of the input:</p>
<p>$$\mathcal{L}(\theta) = \sum_{i=1}^{B} \text{CE}\left(\text{Dec}\left(f_{\text{CA}}(\boldsymbol{e}_{\text{in}}^{(i)}, \boldsymbol{E}_r^{(i)}; \theta)\right), x_{\text{1NN}}^{(i)}\right)$$</p>
<p>The remaining $K - 1$ nearest neighbors serve as the retrieved exemplar molecules. This forces the fusion module to learn how to use exemplar molecules to transform the input toward a related target. Only the fusion module parameters are updated; the encoder and decoder remain frozen.</p>
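<p>The construction of one training example can be sketched as follows. The similarity measure is an assumption here (Tanimoto similarity over fingerprints modeled as sets of on-bit indices), and <code>make_training_example</code> is a hypothetical helper name.</p>
<pre><code class="language-python" data-lang="python">def tanimoto(fp_a, fp_b):
    union = fp_a.union(fp_b)
    return len(fp_a.intersection(fp_b)) / len(union) if union else 0.0


def make_training_example(index, fingerprints, k):
    """Return (input, target, exemplars) as indices into `fingerprints`.

    The nearest neighbor is the decoding target; the remaining k - 1
    neighbors are the retrieved exemplars.
    """
    query = fingerprints[index]
    others = [j for j in range(len(fingerprints)) if j != index]
    neighbours = sorted(others,
                        key=lambda j: tanimoto(query, fingerprints[j]),
                        reverse=True)[:k]
    return index, neighbours[0], neighbours[1:]


fingerprints = [{1, 2, 3, 4}, {1, 2, 3}, {1, 2}, {1}, {9}]
example = make_training_example(0, fingerprints, k=3)
print(example)
</code></pre>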
<h2 id="iterative-refinement-at-inference">Iterative Refinement at Inference</h2>
<p>During inference, RetMol uses an iterative process:</p>
<ol>
<li>Encode the input molecule and retrieved exemplars</li>
<li>Fuse embeddings via cross-attention</li>
<li>Perturb the fused embedding $M$ times with Gaussian noise</li>
<li>Greedily decode $M$ candidate molecules</li>
<li>Replace the input with the best candidate if it improves upon the current score</li>
<li>Add remaining good candidates to the retrieval database</li>
<li>Repeat until convergence or a maximum number of iterations</li>
</ol>
<p>The dynamic update of the retrieval database is critical for extrapolating beyond the initial set of exemplar molecules.</p>
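<p>The inference loop can be sketched with toy components. Everything passed into <code>refine</code> is a hypothetical stand-in (two-dimensional points play the role of molecules, and <code>fuse</code> is a simple average in place of cross-attention). The sketch also runs a fixed number of iterations and adds every remaining candidate to the database, whereas the paper adds only good candidates.</p>
<pre><code class="language-python" data-lang="python">import random


def refine(x_in, database, encode, fuse, decode, score, retrieve,
           n_candidates=8, n_iterations=20, sigma=1.0, seed=0):
    rng = random.Random(seed)
    best = x_in
    database = list(database)
    for _ in range(n_iterations):
        exemplars = retrieve(database)
        # Steps 1-2: encode the input and exemplars, then fuse them.
        fused = fuse(encode(best), [encode(m) for m in exemplars])
        # Steps 3-4: perturb with Gaussian noise and decode each sample.
        candidates = [decode([v + rng.gauss(0.0, sigma) for v in fused])
                      for _ in range(n_candidates)]
        ranked = sorted(candidates, key=score, reverse=True)
        # Step 5: keep the current input unless a candidate scores higher.
        best = max([best, ranked[0]], key=score)
        # Step 6: grow the retrieval database with the other candidates.
        database.extend(ranked[1:])
    return best, database


TARGET = (3.0, 3.0)


def score(point):
    return -sum((a - b) ** 2 for a, b in zip(point, TARGET))


def decode(vec):
    return tuple(round(v, 3) for v in vec)


def fuse(e_in, e_rs):
    centroid = [sum(col) / len(e_rs) for col in zip(*e_rs)]
    return [(a + b) / 2 for a, b in zip(e_in, centroid)]


def retrieve(db):
    return sorted(db, key=score, reverse=True)[:2]


start = (0.0, 0.0)
best, db = refine(start, [(1.0, 1.0), (2.0, 1.5)], list, fuse, decode,
                  score, retrieve)
print(best, score(best))
</code></pre>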
<h2 id="experiments-and-results">Experiments and Results</h2>
<p>RetMol is evaluated on four tasks of increasing difficulty:</p>
<h3 id="qed-optimization-under-similarity-constraint">QED Optimization Under Similarity Constraint</h3>
<p>Goal: generate molecules with QED $\geq$ 0.9 while maintaining <a href="https://en.wikipedia.org/wiki/Tanimoto_coefficient">Tanimoto similarity</a> $\geq$ 0.4 to the input. RetMol achieves a 94.5% success rate, compared to 92.8% for the previous best (QMO).</p>
<h3 id="penalized-logp-optimization">Penalized LogP Optimization</h3>
<p>Goal: maximize penalized <a href="https://en.wikipedia.org/wiki/Octanol-water_partition_coefficient">LogP</a> while maintaining structural similarity. At a similarity threshold of $\delta = 0.4$, RetMol achieves an average improvement of 11.55, compared to 7.71 for QMO.</p>
<h3 id="gsk3beta--jnk3-dual-inhibitor-design"><a href="https://en.wikipedia.org/wiki/GSK-3">GSK3</a>$\beta$ + <a href="https://en.wikipedia.org/wiki/C-Jun_N-terminal_kinase">JNK3</a> Dual Inhibitor Design</h3>
<p>Goal: simultaneously satisfy four constraints (GSK3$\beta$ inhibition $\geq$ 0.5, JNK3 inhibition $\geq$ 0.5, QED $\geq$ 0.6, SA $\leq$ 4). Results:</p>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>Success %</th>
          <th>Novelty</th>
          <th>Diversity</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="/notes/chemistry/molecular-design/generation/rl-tuned/reinvent-deep-rl-molecular-design/">REINVENT</a></td>
          <td>47.9</td>
          <td>0.561</td>
          <td>0.621</td>
      </tr>
      <tr>
          <td>RationaleRL</td>
          <td>74.8</td>
          <td>0.568</td>
          <td>0.701</td>
      </tr>
      <tr>
          <td>MARS</td>
          <td>92.3</td>
          <td>0.824</td>
          <td>0.719</td>
      </tr>
      <tr>
          <td>MolEvol</td>
          <td>93.0</td>
          <td>0.757</td>
          <td>0.681</td>
      </tr>
      <tr>
          <td>RetMol</td>
          <td>96.9</td>
          <td>0.862</td>
          <td>0.732</td>
      </tr>
  </tbody>
</table>
<p>RetMol achieves this without task-specific fine-tuning and requires only 80 iterations compared to MARS&rsquo;s 550.</p>
<h3 id="sars-cov-2-main-protease-inhibitor-optimization"><a href="https://en.wikipedia.org/wiki/3C-like_protease">SARS-CoV-2 Main Protease</a> Inhibitor Optimization</h3>
<p>A real-world task using only 23 known inhibitors as the retrieval database and optimizing 8 weakly binding drugs. Under the milder similarity constraint ($\delta = 0.4$), RetMol achieves an average binding affinity improvement of 2.84 kcal/mol versus 1.67 for Graph GA. Under the stricter constraint ($\delta = 0.6$), RetMol succeeds on 5/8 molecules versus 3/8 for Graph GA.</p>
<h2 id="key-analysis-findings">Key Analysis Findings</h2>
<ul>
<li><strong>Database size</strong>: Strong performance even with 100 molecules, already outperforming baselines on success rate</li>
<li><strong>Database quality</strong>: Molecules satisfying all four constraints give the best results (96.9%), but partial satisfaction still works reasonably (84.7% with two properties)</li>
<li><strong>Training objective</strong>: The nearest neighbor prediction objective outperforms conventional reconstruction on validity (0.902 vs. 0.834) and uniqueness (0.922 vs. 0.665)</li>
<li><strong>Dynamic database update</strong>: Essential for extrapolating beyond the initial retrieval database, generating molecules with property values exceeding the best in the original database</li>
</ul>
<h2 id="limitations">Limitations</h2>
<p>RetMol requires exemplar molecules that at least partially satisfy the design criteria. When such molecules are entirely unavailable, the framework cannot be applied. The method also relies on property predictors (for scoring and retrieval), whose accuracy directly affects generation quality. The iterative refinement process adds computational overhead at inference time, and the results depend on the Chemformer backbone&rsquo;s generation capabilities.</p>
<h2 id="reproducibility">Reproducibility</h2>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/NVlabs/RetMol">NVlabs/RetMol</a></td>
          <td>Code</td>
          <td>NVIDIA Source Code License-NC</td>
          <td>Full training and inference code</td>
      </tr>
      <tr>
          <td><a href="https://github.com/NVlabs/RetMol">NVlabs/RetMol (checkpoints)</a></td>
          <td>Model</td>
          <td>CC BY-NC-SA 4.0</td>
          <td>Pre-trained model checkpoints</td>
      </tr>
  </tbody>
</table>
<p><strong>Data</strong>: ZINC250k and ChEMBL datasets for training. Task-specific retrieval databases constructed from these datasets. COVID-19 task uses 23 known SARS-CoV-2 Mpro inhibitors.</p>
<p><strong>Training</strong>: Information fusion module trained on 4x V100 GPUs (16GB each) for approximately 2 hours. Batch size of 256 per GPU, 50K iterations.</p>
<p><strong>Inference</strong>: Single V100 GPU. Greedy decoding with Gaussian perturbation ($\sigma = 1$) for sampling multiple candidates per iteration.</p>
<p><strong>Backbone</strong>: Chemformer (BART variant) pre-trained on ZINC. Frozen during RetMol training and inference.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Wang, Z., Nie, W., Qiao, Z., Xiao, C., Baraniuk, R. G., &amp; Anandkumar, A. (2023). Retrieval-based Controllable Molecule Generation. <em>Proceedings of the Eleventh International Conference on Learning Representations (ICLR 2023)</em>.</p>
<p><strong>Publication</strong>: International Conference on Learning Representations (ICLR) 2023</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/NVlabs/RetMol">GitHub: NVlabs/RetMol</a></li>
<li><a href="https://openreview.net/forum?id=vDFA1tpuLvk">OpenReview</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{wang2023retrieval,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Retrieval-based Controllable Molecule Generation}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Wang, Zichao and Nie, Weili and Qiao, Zhuoran and Xiao, Chaowei and Baraniuk, Richard G. and Anandkumar, Anima}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{International Conference on Learning Representations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span>=<span style="color:#e6db74">{https://openreview.net/forum?id=vDFA1tpuLvk}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>LIMO: Latent Inceptionism for Targeted Molecule Generation</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/latent-space/limo-latent-inceptionism/</link><pubDate>Sun, 22 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/latent-space/limo-latent-inceptionism/</guid><description>LIMO uses gradient-based optimization through a VAE latent space and stacked property predictor to generate drug-like molecules with high binding affinity.</description><content:encoded><![CDATA[<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Eckmann, P., Sun, K., Zhao, B., Feng, M., Gilson, M. K., &amp; Yu, R. (2022). LIMO: Latent Inceptionism for Targeted Molecule Generation. <em>Proceedings of the 39th International Conference on Machine Learning (ICML 2022)</em>, PMLR 162, 5777&ndash;5792.</p>
<p><strong>Publication</strong>: ICML 2022</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/Rose-STL-Lab/LIMO">GitHub: Rose-STL-Lab/LIMO</a></li>
<li><a href="https://arxiv.org/abs/2206.09010">arXiv: 2206.09010</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{eckmann2022limo,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{LIMO: Latent Inceptionism for Targeted Molecule Generation}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Eckmann, Peter and Sun, Kunyang and Zhao, Bo and Feng, Mudong and Gilson, Michael K and Yu, Rose}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{International Conference on Machine Learning}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{5777--5792}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2022}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">organization</span>=<span style="color:#e6db74">{PMLR}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><h2 id="gradient-based-reverse-optimization-in-molecular-latent-space">Gradient-Based Reverse Optimization in Molecular Latent Space</h2>
<p>This is a <strong>Method</strong> paper that introduces LIMO, a framework for generating molecules with desired properties using gradient-based optimization on a VAE latent space. The key innovation is a stacked architecture where a property predictor operates on the decoded molecular representation rather than directly on the latent space, combined with an inceptionism-like technique that backpropagates through the frozen decoder and predictor to optimize the latent code. This approach is 6-8x faster than RL baselines and 12x faster than sampling-based approaches while producing molecules with higher binding affinities.</p>
<h2 id="slow-property-optimization-in-existing-methods">Slow Property Optimization in Existing Methods</h2>
<p>Generating molecules with high binding affinity to target proteins is a central goal of early drug discovery, but existing computational approaches are slow when optimizing for properties that are expensive to evaluate (such as docking-based binding affinity). RL-based methods require many calls to the property function during training. Sampling-based approaches like MARS need hundreds of iterations. Latent optimization methods that predict properties directly from the latent space suffer from poor prediction accuracy because the mapping from latent space to molecular properties is difficult to learn.</p>
<h2 id="the-limo-framework">The LIMO Framework</h2>
<p>LIMO consists of three components: a VAE for learning a molecular latent space, a property predictor with a novel stacked architecture, and a gradient-based reverse optimization procedure.</p>
<h3 id="selfies-based-vae">SELFIES-Based VAE</h3>
<p>The VAE encodes molecules represented as SELFIES strings into a 1024-dimensional latent space $\mathbf{z} \in \mathbb{R}^m$ and decodes to probability distributions over SELFIES symbols. Since all SELFIES strings correspond to valid molecules, this guarantees 100% chemical validity. The output molecule is obtained by taking the argmax at each position:</p>
<p>$$\hat{x}_i = s_{d_i^*}, \quad d_i^* = \operatorname{argmax}_{d} \{y_{i,1}, \ldots, y_{i,d}\}$$</p>
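<p>A minimal sketch of the argmax decoding step. The three-symbol alphabet, the probability values, and the names <code>ALPHABET</code> and <code>decode_argmax</code> are hypothetical and chosen only for illustration; they are not taken from the LIMO code base.</p>

```python
# Toy SELFIES-like alphabet; the symbols are illustrative only.
ALPHABET = ["[C]", "[O]", "[N]"]


def decode_argmax(probs):
    """Pick the most probable symbol at each sequence position."""
    symbols = []
    for row in probs:
        best = max(range(len(row)), key=lambda d: row[d])
        symbols.append(ALPHABET[best])
    return "".join(symbols)


# One row per position, one column per alphabet symbol.
y = [
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.3, 0.3, 0.4],
]
decoded = decode_argmax(y)
```

<p>Because every symbol sequence is a valid SELFIES string, the decoded output never needs a separate validity check in this setting.</p>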
<p>The VAE uses fully-connected layers (not recurrent), with a 64-dimensional embedding layer, four batch-normalized linear layers (2000-dimensional first layer, 1000-dimensional for the rest) with ReLU activation, and is trained with ELBO loss (0.9 weight on reconstruction, 0.1 on KL divergence).</p>
<h3 id="stacked-property-predictor">Stacked Property Predictor</h3>
<p>The critical architectural choice: the property predictor $g_\theta$ takes the decoded molecular representation $\hat{\mathbf{x}}$ as input rather than the latent code $\mathbf{z}$. The predictor is trained after the VAE is frozen by minimizing MSE on VAE-generated molecules:</p>
<p>$$\ell_0(\theta) = \left\| g_\theta\left(f_{\text{dec}}(\mathbf{z})\right) - \pi\left(f_{\text{dec}}(\mathbf{z})\right) \right\|^2$$</p>
<p>where $\pi$ is the ground-truth property function. This stacking improves prediction accuracy from $r^2 = 0.04$ (predicting from $\mathbf{z}$) to $r^2 = 0.38$ (predicting from $\hat{\mathbf{x}}$) on an unseen test set. The improvement comes because the mapping from molecular space to property is easier to learn than the mapping from latent space to property.</p>
<h3 id="reverse-optimization-inceptionism">Reverse Optimization (Inceptionism)</h3>
<p>After training, the decoder and predictor weights are frozen and $\mathbf{z}$ becomes the trainable parameter. For multiple properties with weights $(w_1, \ldots, w_k)$, the optimization minimizes:</p>
<p>$$\ell_1(\mathbf{z}) = -\sum_{i=1}^{k} w_i \cdot g^i\left(f_{\text{dec}}(\mathbf{z})\right)$$</p>
<p>Since both the decoder and predictor are neural networks, gradients flow through the entire chain, enabling efficient optimization with Adam. This is analogous to the &ldquo;inceptionism&rdquo; (DeepDream) technique from computer vision, where network inputs are optimized to maximize specific outputs.</p>
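<p>The reverse optimization loop can be sketched on a toy problem. Everything below is an assumption made for illustration: the decoder is a single softmax over three symbols, the predictor is a linear score, the weights <code>W</code> and <code>PRED_W</code> are made up, and plain gradient descent stands in for Adam. Only the structure (frozen networks, trainable $\mathbf{z}$, loss equal to the negated prediction) follows the description above.</p>

```python
import math

# Frozen toy decoder: logits = W z, followed by a softmax over 3 symbols.
W = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]
# Frozen toy predictor: a linear score on the decoded probabilities.
PRED_W = [0.0, 1.0, 0.2]


def decode(z):
    logits = [sum(w * v for w, v in zip(row, z)) for row in W]
    top = max(logits)
    exps = [math.exp(v - top) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict(x):
    return sum(w * p for w, p in zip(PRED_W, x))


def loss_and_grad(z):
    """Return l1(z) = -g(f_dec(z)) and its gradient with respect to z."""
    x = decode(z)
    score = predict(x)
    # Softmax backward pass: d(score)/d(logit_k) = x_k * (w_k - score).
    d_logits = [-(x_k * (w_k - score)) for x_k, w_k in zip(x, PRED_W)]
    grad = [
        sum(d_logits[k] * W[k][j] for k in range(len(W)))
        for j in range(len(z))
    ]
    return -score, grad


def reverse_optimize(z, lr=0.5, steps=200):
    """Gradient descent on z only; decoder and predictor stay fixed."""
    for _ in range(steps):
        _, grad = loss_and_grad(z)
        z = [z_j - lr * g_j for z_j, g_j in zip(z, grad)]
    return z


z0 = [0.0, 0.0]
z_opt = reverse_optimize(z0)
score_before = predict(decode(z0))
score_after = predict(decode(z_opt))
```

<p>In this toy setting the optimized latent code shifts probability mass toward the symbol the predictor rewards most, which mirrors how the network input rather than the network weights is updated.</p>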
<h3 id="substructure-constrained-optimization">Substructure-Constrained Optimization</h3>
<p>For lead optimization, LIMO can fix a molecular substructure during optimization by adding a regularization term:</p>
<p>$$\ell_2(\mathbf{z}) = \lambda \sum_{i=1}^{n} \sum_{j=1}^{d} \left(M_{i,j} \cdot \left(f_{\text{dec}}(\mathbf{z})_{i,j} - (\hat{\mathbf{x}}_{\text{start}})_{i,j}\right)\right)^2$$</p>
<p>where $M$ is a binary mask specifying which SELFIES positions must remain unchanged and $\lambda = 1000$. This capability is enabled by the intermediate decoded representation, which most VAE-based methods lack.</p>
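<p>A small sketch of the masked penalty $\ell_2$, assuming the decoded output and the starting molecule are stored as nested lists of per-position symbol probabilities. The function name <code>substructure_penalty</code> and the example arrays are hypothetical; only $\lambda = 1000$ comes from the text.</p>

```python
LAMBDA = 1000.0  # regularization weight reported in the text


def substructure_penalty(decoded, start, mask, lam=LAMBDA):
    """l2 = lam * sum_ij (M_ij * (decoded_ij - start_ij))^2."""
    total = 0.0
    for dec_row, start_row, mask_row in zip(decoded, start, mask):
        for d, s, m in zip(dec_row, start_row, mask_row):
            total += (m * (d - s)) ** 2
    return lam * total


# Two positions over a three-symbol alphabet; position 0 is held fixed.
x_start = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
mask = [[1, 1, 1], [0, 0, 0]]
x_same = [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]   # only the free position changed
x_drift = [[0.5, 0.5, 0.0], [0.0, 1.0, 0.0]]  # the fixed position drifted

penalty_same = substructure_penalty(x_same, x_start, mask)
penalty_drift = substructure_penalty(x_drift, x_start, mask)
```

<p>Changes at unmasked positions cost nothing, while any drift at a masked position is penalized heavily because of the large weight.</p>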
<h2 id="experiments-and-results">Experiments and Results</h2>
<h3 id="benchmark-tasks-qed-and-penalized-logp">Benchmark Tasks (QED and Penalized LogP)</h3>
<p>LIMO achieves results competitive with deep generative and RL-based models in 1 hour, compared to 8-24 hours for baselines. Top QED score: 0.947 (maximum possible: 0.948). Top penalized LogP: 10.5 (among length-limited models, comparable to MolDQN&rsquo;s 11.8).</p>
<p>The ablation study (&ldquo;LIMO on z&rdquo;) supports the stacked predictor architecture: predicting from $\hat{\mathbf{x}}$ yields a top p-logP of 10.5 versus 6.52 when predicting directly from $\mathbf{z}$.</p>
<h3 id="binding-affinity-maximization">Binding Affinity Maximization</h3>
<p>Binding affinity maximization is the paper&rsquo;s primary contribution. LIMO generates molecules with substantially higher computed binding affinities (lower $K_D$) than baselines against two protein targets:</p>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>ESR1 best $K_D$ (nM)</th>
          <th>ACAA1 best $K_D$ (nM)</th>
          <th>Time (hrs)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>GCPN</td>
          <td>6.4</td>
          <td>75</td>
          <td>6</td>
      </tr>
      <tr>
          <td>MolDQN</td>
          <td>373</td>
          <td>240</td>
          <td>6</td>
      </tr>
      <tr>
          <td>MARS</td>
          <td>17</td>
          <td>163</td>
          <td>6</td>
      </tr>
      <tr>
          <td>GraphDF</td>
          <td>25</td>
          <td>370</td>
          <td>12</td>
      </tr>
      <tr>
          <td>LIMO</td>
          <td>0.72</td>
          <td>37</td>
          <td>1</td>
      </tr>
  </tbody>
</table>
<p>For ESR1, LIMO&rsquo;s best molecule has a docking-derived $K_D$ of 0.72 nM, roughly 9x lower than the next-best method (GCPN at 6.4 nM). When corroborated with more rigorous absolute binding free energy (ABFE) calculations, one LIMO compound achieved a predicted $K_D$ of $6 \times 10^{-14}$ M (0.00006 nM), far exceeding the affinities of the approved drugs tamoxifen ($K_D$ = 1.5 nM) and raloxifene ($K_D$ = 0.03 nM).</p>
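<p>The unit conversion and the ratio between the LIMO and GCPN docking values can be checked with a few lines of arithmetic. The helper name <code>molar_to_nanomolar</code> is hypothetical; the numbers are the ones quoted in this section.</p>

```python
NM_PER_M = 1e9  # nanomolar units per molar unit


def molar_to_nanomolar(kd_molar):
    """Convert a dissociation constant from M to nM."""
    return kd_molar * NM_PER_M


# ABFE-predicted K_D for the best LIMO compound, quoted in molar units.
abfe_kd_nm = molar_to_nanomolar(6e-14)

# Docking-derived K_D values for ESR1 (nM): GCPN versus LIMO.
fold_vs_gcpn = 6.4 / 0.72
```

<p>Lower $K_D$ means tighter binding, so a larger ratio here corresponds to a larger computed advantage for LIMO.</p>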
<h3 id="multi-objective-optimization">Multi-Objective Optimization</h3>
<p>Single-objective optimization produces molecules with high affinity but problematic structures (polyenes, large rings). Multi-objective optimization simultaneously targeting binding affinity, QED ($&gt;$ 0.4), and SA ($&lt;$ 5.5) produces drug-like, synthesizable molecules that still have nanomolar binding affinities. Generated molecules satisfy Lipinski&rsquo;s rule of 5 with zero PAINS alerts.</p>
<h2 id="limitations">Limitations</h2>
<p>The LIMO property predictor achieves only moderate prediction accuracy ($r^2$ = 0.38), meaning the optimization relies on the gradient direction being correct rather than on absolute predictions being accurate. AutoDock-GPU docking scores do not correlate well with the more accurate ABFE results, a known limitation of docking. The fully-connected VAE architecture limits molecular diversity compared to recurrent or attention-based alternatives, although an LSTM decoder produced a maximum QED of only 0.3. The greedy fine-tuning step (replacing carbons with heteroatoms) is a heuristic rather than a learned procedure.</p>
<h2 id="reproducibility">Reproducibility</h2>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/Rose-STL-Lab/LIMO">Rose-STL-Lab/LIMO</a></td>
          <td>Code</td>
          <td>UC San Diego Custom (non-commercial)</td>
          <td>Full training, optimization, and evaluation code</td>
      </tr>
  </tbody>
</table>
<p><strong>Data</strong>: ZINC250k dataset for optimization tasks. MOSES dataset for random generation evaluation. Binding affinities computed with AutoDock-GPU.</p>
<p><strong>Hardware</strong>: Two GTX 1080 Ti GPUs (one for PyTorch, one for AutoDock-GPU), 4 CPU cores, 32 GB memory.</p>
<p><strong>Training</strong>: VAE trained for 18 epochs with learning rate 0.0001. Property predictor uses 3 layers of 1000 units, trained for 5 epochs. Reverse optimization uses learning rate 0.1 for 10 epochs.</p>
<p><strong>Targets</strong>: Human estrogen receptor (ESR1, PDB 1ERR) and human peroxisomal acetyl-CoA acyl transferase 1 (ACAA1, PDB 2IIK).</p>
]]></content:encoded></item><item><title>MolGen: Molecular Generation with Chemical Feedback</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/autoregressive/molgen-molecular-generation-chemical-feedback/</link><pubDate>Fri, 20 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/autoregressive/molgen-molecular-generation-chemical-feedback/</guid><description>MolGen pre-trains on SELFIES molecules and uses chemical feedback to align generated molecules with real-world chemical preferences across domains.</description><content:encoded><![CDATA[<h2 id="a-selfies-based-method-for-molecular-generation">A SELFIES-Based Method for Molecular Generation</h2>
<p>This is a <strong>Method</strong> paper that introduces MolGen, a pre-trained molecular language model for generating molecules with desired chemical properties. The primary contribution is a three-part framework: (1) pre-training on 100M+ molecular SELFIES to learn structural and grammatical knowledge, (2) domain-agnostic molecular prefix tuning for cross-domain knowledge transfer, and (3) a chemical feedback paradigm that aligns the model&rsquo;s generative probabilities with real-world chemical preferences. MolGen is the first language model pre-trained on SELFIES rather than SMILES, which guarantees 100% syntactic validity of generated molecules.</p>
<h2 id="challenges-in-language-model-based-molecule-generation">Challenges in Language Model-Based Molecule Generation</h2>
<p>Generating novel molecules with desirable properties is a central task in drug discovery and chemical design. The molecular space is estimated at $10^{33}$ possible structures, making exhaustive search impractical. Prior deep generative approaches face several limitations:</p>
<ol>
<li><strong>Syntactic invalidity</strong>: <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a>-based language models frequently generate strings that do not correspond to valid molecular graphs. A single random mutation of a SMILES string has only a 9.9% chance of remaining valid.</li>
<li><strong>Narrow domain focus</strong>: Most existing models focus exclusively on synthetic molecules and neglect <a href="https://en.wikipedia.org/wiki/Natural_product">natural products</a>, which have distinct structural complexity and scaffold diversity.</li>
<li><strong>Molecular hallucinations</strong>: Generated molecules may satisfy chemical structural rules yet fail to exhibit anticipated chemical activity in practical applications. The authors formally define this as molecules that &ldquo;comply with chemical structural rules, yet fail to exhibit practical utility or the anticipated properties.&rdquo;</li>
<li><strong>Limited optimization signals</strong>: Existing approaches rely on reinforcement learning (high variance), fixed-dimensional latent spaces, or expert-provided generation rules, all of which impede efficient exploration of chemical space.</li>
</ol>
<h2 id="core-innovation-pre-training-with-selfies-and-chemical-feedback">Core Innovation: Pre-training with SELFIES and Chemical Feedback</h2>
<p>MolGen&rsquo;s novelty rests on three interconnected components.</p>
<h3 id="selfies-based-pre-training">SELFIES-Based Pre-training</h3>
<p>MolGen uses <a href="/notes/chemistry/molecular-representations/notations/selfies/">SELFIES</a> (Self-Referencing Embedded Strings) instead of SMILES. SELFIES guarantees that every possible combination of symbols in the alphabet corresponds to a chemically valid molecular graph. The model uses a compact vocabulary of 185 tokens.</p>
<p>The first pre-training stage uses a BART-style encoder-decoder. Tokens from a SELFIES string $S = \{s_1, \ldots, s_l\}$ are randomly replaced with [MASK], then the corrupted input is encoded bidirectionally and decoded left-to-right. The reconstruction loss is:</p>
<p>$$
\mathcal{L}_{\text{ce}}(S) = -\sum_{j=1}^{l} \sum_{s} p_{\text{true}}(s \mid S, S_{&lt; j}) \log p_{\theta}(s \mid S, S_{&lt; j}; \theta)
$$</p>
<p>where $S_{&lt; j}$ denotes the partial sequence $\{s_0, \ldots, s_{j-1}\}$ and $p_{\text{true}}$ is the one-hot distribution under standard maximum likelihood estimation.</p>
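<p>With a one-hot $p_{\text{true}}$, the inner sum over symbols keeps only the true token, so the loss reduces to a sum of negative log-probabilities. A minimal sketch, assuming the decoder outputs are already available as per-step probability lists (the values and the name <code>reconstruction_loss</code> are hypothetical):</p>

```python
import math


def reconstruction_loss(step_probs, target_ids):
    """Sum of negative log-probabilities assigned to the true tokens."""
    return -sum(math.log(probs[t]) for probs, t in zip(step_probs, target_ids))


# Hypothetical decoder outputs over a 3-token vocabulary at 2 positions.
step_probs = [[0.5, 0.25, 0.25], [0.1, 0.8, 0.1]]
target_ids = [0, 1]
loss = reconstruction_loss(step_probs, target_ids)
```

<p>The loss is zero only when the model places all probability on the true token at every position.</p>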
<h3 id="domain-agnostic-molecular-prefix-tuning">Domain-Agnostic Molecular Prefix Tuning</h3>
<p>The second pre-training stage introduces shared prefix vectors $P_k, P_v \in \mathbb{R}^{m \times d}$ prepended to the keys and values of multi-head attention at each layer. Unlike conventional prefix tuning that freezes model parameters, MolGen updates the entire model. The attention output becomes:</p>
<p>$$
\text{head} = \text{Attn}\left(xW_q, [P_k, XW_k], [P_v, XW_v]\right)
$$</p>
<p>This decomposes into a linear interpolation between prefix attention and standard attention:</p>
<p>$$
\text{head} = \lambda(x) \cdot \text{Attn}(xW_q, P_k, P_v) + (1 - \lambda(x)) \cdot \text{Attn}(xW_q, XW_k, XW_v)
$$</p>
<p>where $\lambda(x)$ is a scalar representing the sum of normalized attention weights on the prefixes. The prefixes are trained simultaneously across synthetic and natural product domains, acting as a domain instructor.</p>
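<p>The interpolation identity can be verified numerically for a single query. This is a sketch under simplifying assumptions: one attention head, no scaling factor, and small made-up vectors standing in for $xW_q$, $XW_k$, $XW_v$, $P_k$, and $P_v$. The helper names are hypothetical.</p>

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def attn(q, keys, values):
    """Single-query softmax attention; the usual scaling factor is omitted."""
    exps = [math.exp(dot(q, k)) for k in keys]
    total = sum(exps)
    return [
        sum(e * v[i] for e, v in zip(exps, values)) / total
        for i in range(len(values[0]))
    ]


def prefix_weight(q, prefix_keys, keys):
    """lambda(x): share of the softmax mass that lands on the prefix keys."""
    p = sum(math.exp(dot(q, k)) for k in prefix_keys)
    s = sum(math.exp(dot(q, k)) for k in keys)
    return p / (p + s)


q = [0.5, -1.0]                             # stands in for x W_q
P_k = [[1.0, 0.0], [0.0, 1.0]]              # prefix keys
P_v = [[1.0, 2.0], [3.0, 4.0]]              # prefix values
K = [[0.2, 0.1], [-0.3, 0.4], [1.0, 1.0]]   # stands in for X W_k
V = [[0.0, 1.0], [1.0, 0.0], [2.0, 2.0]]    # stands in for X W_v

# Attention over the concatenated [prefix, sequence] keys and values.
joint = attn(q, P_k + K, P_v + V)

# The same output written as an interpolation of two separate attentions.
lam = prefix_weight(q, P_k, K)
prefix_part = attn(q, P_k, P_v)
standard_part = attn(q, K, V)
mixed = [lam * a + (1.0 - lam) * b for a, b in zip(prefix_part, standard_part)]
```

<p>The two results agree because the softmax normalizer of the joint attention splits into a prefix part and a sequence part, and $\lambda(x)$ is the prefix share of that normalizer.</p>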
<h3 id="chemical-feedback-paradigm">Chemical Feedback Paradigm</h3>
<p>To address molecular hallucinations, MolGen aligns the model&rsquo;s probabilistic rankings with chemical preference rankings. Given a molecule $S$ and a set of candidate outputs $\mathcal{S}^*$ with distinct property scores $\text{Ps}(\cdot)$, the model should satisfy:</p>
<p>$$
p_{\text{true}}(S_i \mid S) &gt; p_{\text{true}}(S_j \mid S), \quad \forall S_i, S_j \in \mathcal{S}^*, \text{Ps}(S_i) &gt; \text{Ps}(S_j)
$$</p>
<p>This is enforced via a rank loss:</p>
<p>$$
\mathcal{L}_{\text{rank}}(S) = \sum_{i} \sum_{j &gt; i} \max\left(0, f(S_j) - f(S_i) + \gamma_{ij}\right)
$$</p>
<p>where the candidates are indexed in descending order of property score, $\gamma_{ij} = (j - i) \cdot \gamma$ is a margin scaled by rank difference, and $f(S) = \sum_{t=1}^{l} \log p_{\theta}(s_t \mid S, S_{&lt; t}; \theta)$ is the estimated log-probability. The overall training objective combines cross-entropy and rank loss:</p>
<p>$$
\mathcal{L} = \mathcal{L}_{\text{ce}} + \alpha \mathcal{L}_{\text{rank}}
$$</p>
<p>Label smoothing is applied to the target distribution in $\mathcal{L}_{\text{ce}}$, allocating probability mass $\beta$ to non-target tokens to maintain generative diversity.</p>
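<p>The combined objective can be sketched as follows. Several details are assumptions made for illustration: candidates are passed in already sorted from best to worst property score, the smoothing mass $\beta$ is spread uniformly over non-target tokens, and the values of $\gamma$ and $\beta$ and all function names are hypothetical. The weight $\alpha = 3$ is the value used in the ablation reported in this note.</p>

```python
import math


def rank_loss(log_probs, gamma):
    """Pairwise margin loss over candidates sorted best property score first."""
    total = 0.0
    n = len(log_probs)
    for i in range(n):
        for j in range(i + 1, n):
            margin = (j - i) * gamma
            total += max(0.0, log_probs[j] - log_probs[i] + margin)
    return total


def smoothed_cross_entropy(step_probs, target_ids, beta):
    """Cross-entropy against targets with beta spread over non-target tokens."""
    total = 0.0
    for probs, t in zip(step_probs, target_ids):
        off = beta / (len(probs) - 1)
        for k, p in enumerate(probs):
            weight = 1.0 - beta if k == t else off
            total -= weight * math.log(p)
    return total


def total_loss(step_probs, target_ids, cand_log_probs, alpha, beta, gamma):
    """L = L_ce + alpha * L_rank."""
    ce = smoothed_cross_entropy(step_probs, target_ids, beta)
    return ce + alpha * rank_loss(cand_log_probs, gamma)


# f(S_i) for three candidates, ordered from best to worst property score.
aligned = [-1.0, -2.0, -3.0]      # model ranking agrees with the property ranking
misaligned = [-3.0, -2.0, -1.0]   # model ranking is reversed

step_probs = [[0.5, 0.25, 0.25], [0.1, 0.8, 0.1]]
target_ids = [0, 1]
loss_value = total_loss(step_probs, target_ids, misaligned,
                        alpha=3.0, beta=0.1, gamma=0.5)
```

<p>When the model already ranks candidates in the same order as their property scores, and by at least the margin, the rank term vanishes and only the cross-entropy term remains.</p>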
<h2 id="experiments-across-distribution-learning-and-property-optimization">Experiments Across Distribution Learning and Property Optimization</h2>
<h3 id="datasets">Datasets</h3>
<ul>
<li><strong>Stage 1 pre-training</strong>: 100M+ unlabeled molecules from ZINC-15 (molecular weight $\leq$ 500 Da, LogP $\leq$ 5)</li>
<li><strong>Stage 2 pre-training</strong>: 2.22M molecules spanning synthetic (ZINC, MOSES) and natural product (NPASS, 30,926 compounds) domains</li>
<li><strong>Downstream evaluation</strong>: MOSES synthetic dataset, ZINC250K, and natural product molecules</li>
</ul>
<h3 id="molecular-distribution-learning">Molecular Distribution Learning</h3>
<p>MolGen generates 10,000 synthetic and 80,000 natural product molecules, evaluated on seven metrics (Validity, Fragment similarity, Scaffold similarity, SNN, Internal Diversity, <a href="/notes/chemistry/molecular-design/generation/evaluation/frechet-chemnet-distance/">FCD</a>, and Novelty). Baselines include AAE, <a href="/notes/chemistry/molecular-design/generation/latent-space/latentgan-de-novo-molecular-generation/">LatentGAN</a>, CharRNN, VAE, JT-VAE, LIMO, and <a href="/notes/chemistry/molecular-design/generation/autoregressive/chemformer/">Chemformer</a>.</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Validity</th>
          <th>Frag</th>
          <th>Scaf</th>
          <th>SNN</th>
          <th>IntDiv</th>
          <th>FCD</th>
          <th>Novelty</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Chemformer</td>
          <td>.9843</td>
          <td>.9889</td>
          <td>.9248</td>
          <td>.5622</td>
          <td>.8553</td>
          <td>.0061</td>
          <td>.9581</td>
      </tr>
      <tr>
          <td>MolGen</td>
          <td>1.000</td>
          <td>.9999</td>
          <td>.9999</td>
          <td>.9996</td>
          <td>.8567</td>
          <td>.0015</td>
          <td>1.000</td>
      </tr>
  </tbody>
</table>
<p>On synthetic molecules, MolGen achieves 100% validity, near-perfect fragment and scaffold similarity, and the lowest FCD (0.0015). For natural products, MolGen achieves FCD of 0.6519 compared to Chemformer&rsquo;s 0.8346.</p>
<h3 id="targeted-molecule-discovery">Targeted Molecule Discovery</h3>
<p>For penalized logP maximization (top-3 scores):</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>1st</th>
          <th>2nd</th>
          <th>3rd</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MARS (no length limit)</td>
          <td>44.99</td>
          <td>44.32</td>
          <td>43.81</td>
      </tr>
      <tr>
          <td>MolGen (no length limit)</td>
          <td>80.30</td>
          <td>74.70</td>
          <td>69.85</td>
      </tr>
      <tr>
          <td>MolGen (length-limited)</td>
          <td>30.51</td>
          <td>28.98</td>
          <td>28.95</td>
      </tr>
  </tbody>
</table>
<p>For QED maximization, MolGen reaches the maximum score of 0.948 for each of its top-3 molecules.</p>
<h3 id="molecular-docking">Molecular Docking</h3>
<p>MolGen optimizes binding affinity for two protein targets (<a href="https://en.wikipedia.org/wiki/Estrogen_receptor_alpha">ESR1</a> and ACAA1), measured by <a href="https://en.wikipedia.org/wiki/Dissociation_constant">dissociation constant</a> $K_D$ (lower is better):</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>ESR1 1st</th>
          <th>ESR1 2nd</th>
          <th>ESR1 3rd</th>
          <th>ACAA1 1st</th>
          <th>ACAA1 2nd</th>
          <th>ACAA1 3rd</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>LIMO</td>
          <td>0.72</td>
          <td>0.89</td>
          <td>1.4</td>
          <td>37</td>
          <td>37</td>
          <td>41</td>
      </tr>
      <tr>
          <td>MolGen</td>
          <td>0.13</td>
          <td>0.35</td>
          <td>0.47</td>
          <td>3.36</td>
          <td>3.98</td>
          <td>8.50</td>
      </tr>
  </tbody>
</table>
<p>MolGen achieves the lowest dissociation constants across both targets. Optimization of the 1,000 worst-affinity molecules yields 96.7% relative improvement for ESR1 and 70.4% for ACAA1.</p>
<h3 id="constrained-molecular-optimization">Constrained Molecular Optimization</h3>
<p>Optimizing the 800 ZINC250K molecules with the lowest p-logP scores under Tanimoto similarity constraints (mean improvement, with standard deviation in parentheses):</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>$\delta = 0.6$</th>
          <th>$\delta = 0.4$</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="/notes/chemistry/molecular-design/generation/autoregressive/retmol-retrieval-molecule-generation/">RetMol</a></td>
          <td>3.78 (3.29)</td>
          <td>11.55 (11.27)</td>
      </tr>
      <tr>
          <td>MolGen</td>
          <td>12.08 (0.82)</td>
          <td>12.35 (1.21)</td>
      </tr>
  </tbody>
</table>
<p>MolGen achieves the highest mean improvement with the lowest standard deviation under both constraints.</p>
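<p>The similarity constraint $\delta$ in the table above is a threshold on Tanimoto similarity between the source and the optimized molecule. A minimal sketch, assuming fingerprints are represented as sets of on-bit indices; the fingerprint type, the example bit sets, and the function names are hypothetical and not taken from the paper.</p>

```python
def tanimoto(bits_a, bits_b):
    """Tanimoto similarity of two fingerprints given as sets of on-bit indices."""
    union = bits_a.union(bits_b)
    if not union:
        return 1.0
    return len(bits_a.intersection(bits_b)) / len(union)


def satisfies_constraint(bits_src, bits_new, delta):
    """True when the optimized molecule stays within the similarity threshold."""
    return tanimoto(bits_src, bits_new) >= delta


source = {1, 4, 7, 9}
close_candidate = {1, 4, 7}          # shares 3 of 4 bits with the source
far_candidate = {1, 20, 21, 22}      # shares only 1 bit with the source

sim_close = tanimoto(source, close_candidate)
keep_close = satisfies_constraint(source, close_candidate, delta=0.6)
keep_far = satisfies_constraint(source, far_candidate, delta=0.4)
```

<p>A larger $\delta$ forces the optimized molecule to stay closer to its starting point, which typically leaves less room for property improvement.</p>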
<h3 id="ablation-studies">Ablation Studies</h3>
<ul>
<li><strong>Chemical feedback</strong>: Without it, the model generates molecules with property scores similar to initial molecules. With it ($\alpha = 3$), property scores increase progressively across generation rounds.</li>
<li><strong>Prefix tuning</strong>: Removing prefix tuning reduces constrained optimization improvement by 0.45 at $\delta = 0.6$ and 2.12 at $\delta = 0.4$.</li>
<li><strong>Label smoothing</strong>: Enhances diversity of generated molecules as measured by Internal Diversity.</li>
<li><strong>Substructure attention</strong>: MolGen focuses attention on chemically meaningful functional groups (fluoro, phenyl, hydroxyl), while SMILES-based PLMs scatter attention across syntactic tokens. The Substructure Attention Level (SAL) metric confirms MolGen&rsquo;s superior focus.</li>
</ul>
<h2 id="key-findings-limitations-and-future-directions">Key Findings, Limitations, and Future Directions</h2>
<h3 id="key-findings">Key Findings</h3>
<ol>
<li>SELFIES pre-training guarantees 100% molecular validity, eliminating the need for external valency checks.</li>
<li>Domain-agnostic prefix tuning enables effective knowledge transfer between synthetic and natural product domains.</li>
<li>The chemical feedback paradigm aligns model outputs with chemical preferences without requiring external annotated data or reference databases.</li>
<li>MolGen achieves the best or competitive results across all evaluated tasks: distribution learning, targeted molecule discovery, constrained optimization, and molecular docking.</li>
</ol>
<h3 id="limitations">Limitations</h3>
<p>The authors acknowledge several limitations:</p>
<ul>
<li><strong>Computational cost</strong>: Training and fine-tuning on large datasets is computationally intensive.</li>
<li><strong>Model interpretability</strong>: The transformer architecture makes it difficult to understand the explicit rationale behind the model&rsquo;s decisions.</li>
<li><strong>Single-target optimization only</strong>: The chemical feedback paradigm handles single-target optimization; multiple conflicting objectives could create ambiguous optimization trajectories.</li>
<li><strong>Task specificity</strong>: MolGen is designed for 2D molecular generation; 3D conformation information is not incorporated.</li>
<li><strong>Reaction prediction</strong>: When applied to reaction prediction (an off-target task), MolGen achieves only 71.4% accuracy on 39,990 reaction samples.</li>
</ul>
<h3 id="future-directions">Future Directions</h3>
<p>The authors suggest applying MolGen to retrosynthesis and reaction prediction, exploring multimodal pre-training, and incorporating additional knowledge sources.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Stage 1 pre-training</td>
          <td>ZINC-15</td>
          <td>100M+ molecules</td>
          <td>MW $\leq$ 500 Da, LogP $\leq$ 5</td>
      </tr>
      <tr>
          <td>Stage 2 pre-training</td>
          <td>ZINC + <a href="/notes/chemistry/molecular-design/generation/evaluation/molecular-sets-moses/">MOSES</a> + NPASS</td>
          <td>2.22M molecules</td>
          <td>Synthetic and natural product domains</td>
      </tr>
      <tr>
          <td>Distribution learning (synthetic)</td>
          <td><a href="/notes/chemistry/molecular-design/generation/evaluation/molecular-sets-moses/">MOSES</a></td>
          <td>~1.9M molecules</td>
          <td>Standard benchmark split</td>
      </tr>
      <tr>
          <td>Distribution learning (natural)</td>
          <td>NPASS</td>
          <td>30,926 compounds</td>
          <td>30,126 train / 800 test</td>
      </tr>
      <tr>
          <td>Constrained optimization</td>
          <td>ZINC250K</td>
          <td>800 molecules</td>
          <td>Lowest p-logP scores</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Architecture</strong>: BART-based encoder-decoder with SELFIES vocabulary (185 tokens)</li>
<li><strong>Prefix length</strong>: 5 tunable vectors per layer</li>
<li><strong>Optimizer</strong>: LAMB (pre-training), AdamW (fine-tuning)</li>
<li><strong>Pre-training</strong>: 600M steps with linear warm-up (180,000 steps) followed by linear decay</li>
<li><strong>Rank loss weight</strong> ($\alpha$): Recommended values of 3 or 5</li>
<li><strong>Candidate generation</strong>: 30 candidates per molecule (synthetic), 8 candidates (natural products)</li>
</ul>
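<p>The pre-training schedule listed above (linear warm-up followed by linear decay) can be sketched as a function of the step count. The peak learning rate is a hypothetical placeholder, and decaying all the way to zero at the final step is an assumption; the warm-up length and total step count are the values given in the list.</p>

```python
WARMUP_STEPS = 180_000       # from the list above
TOTAL_STEPS = 600_000_000    # from the list above
PEAK_LR = 1e-4               # hypothetical placeholder, not from the paper


def lr_at(step, peak_lr=PEAK_LR, warmup=WARMUP_STEPS, total=TOTAL_STEPS):
    """Linear warm-up to peak_lr, then linear decay to zero at the last step."""
    ramp_up = step / warmup
    ramp_down = (total - step) / (total - warmup)
    return peak_lr * max(0.0, min(ramp_up, ramp_down))


lr_start = lr_at(0)
lr_mid_warmup = lr_at(WARMUP_STEPS // 2)
lr_peak = lr_at(WARMUP_STEPS)
lr_end = lr_at(TOTAL_STEPS)
```

<p>Taking the minimum of the two ramps gives a single expression that covers both phases without an explicit branch.</p>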
<h3 id="models">Models</h3>
<p>MolGen is publicly available on Hugging Face. The model uses a vocabulary of 185 SELFIES tokens and is comparable in size to Chemformer-large.</p>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Domain</th>
          <th>MolGen</th>
          <th>Best Baseline</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="/notes/chemistry/molecular-design/generation/evaluation/frechet-chemnet-distance/">FCD</a> (lower is better)</td>
          <td>Synthetic</td>
          <td>0.0015</td>
          <td>0.0061 (<a href="/notes/chemistry/molecular-design/generation/autoregressive/chemformer/">Chemformer</a>)</td>
          <td>Distribution learning</td>
      </tr>
      <tr>
          <td>p-logP top-1 (no limit)</td>
          <td>Synthetic</td>
          <td>80.30</td>
          <td>44.99 (MARS)</td>
          <td>Targeted discovery</td>
      </tr>
      <tr>
          <td>QED top-1</td>
          <td>Synthetic</td>
          <td>0.948</td>
          <td>0.948 (several)</td>
          <td>Tied at maximum</td>
      </tr>
      <tr>
          <td>ESR1 $K_D$ top-1</td>
          <td>Docking</td>
          <td>0.13</td>
          <td>0.72 (LIMO)</td>
          <td>Binding affinity</td>
      </tr>
      <tr>
          <td>p-logP improvement ($\delta=0.4$)</td>
          <td>Synthetic</td>
          <td>12.35 (1.21)</td>
          <td>11.55 (11.27) (RetMol)</td>
          <td>Constrained optimization</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li>6 NVIDIA V100 GPUs</li>
<li>Pre-training batch size: 256 molecules per GPU</li>
<li>Fine-tuning batch size: 6 (synthetic and natural product)</li>
<li>Training: 100 epochs for fine-tuning tasks</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/zjunlp/MolGen">zjunlp/MolGen</a></td>
          <td>Code</td>
          <td>Unknown</td>
          <td>Official PyTorch implementation</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/zjunlp">zjunlp/MolGen-large</a></td>
          <td>Model</td>
          <td>Unknown</td>
          <td>Pre-trained weights on Hugging Face</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Fang, Y., Zhang, N., Chen, Z., Guo, L., Fan, X., &amp; Chen, H. (2024). Domain-Agnostic Molecular Generation with Chemical Feedback. <em>Proceedings of the Twelfth International Conference on Learning Representations (ICLR 2024)</em>.</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/zjunlp/MolGen">GitHub: zjunlp/MolGen</a></li>
<li><a href="https://huggingface.co/zjunlp">Hugging Face Models</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{fang2024domain,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Domain-Agnostic Molecular Generation with Chemical Feedback}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Fang, Yin and Zhang, Ningyu and Chen, Zhuo and Guo, Lingbing and Fan, Xiaohui and Chen, Huajun}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{The Twelfth International Conference on Learning Representations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2024}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span>=<span style="color:#e6db74">{https://openreview.net/forum?id=9rPyHyjfwP}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Kabsch-Horn Cookbook: Differentiable Alignment</title><link>https://hunterheidenreich.com/projects/kabsch-horn-cookbook/</link><pubDate>Fri, 20 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/projects/kabsch-horn-cookbook/</guid><description>Differentiable Kabsch (SVD) and Horn (quaternion) alignment for NumPy, PyTorch, JAX, TensorFlow, and MLX with gradient-safe SVD.</description><content:encoded><![CDATA[<h2 id="overview">Overview</h2>
<p>Aligning two sets of corresponding points means finding the optimal rotation (and optionally translation and scale) that maps one set onto the other. It is a fundamental operation across scientific computing, appearing in molecular dynamics (superimposing protein conformations), robotics (sensor registration), and computer vision (shape matching). The two dominant algorithm families are the Kabsch (SVD-based) method and the Horn (quaternion-based) method.</p>
<p>The <strong>Kabsch-Horn Cookbook</strong> is a Python library that implements both algorithm families across five numerical frameworks: NumPy, PyTorch, JAX, TensorFlow, and MLX. Every backend shares the same API and supports per-point weights and arbitrary batch dimensions; the Kabsch path also accepts N-dimensional point sets on every backend except MLX, which is limited to 3D inputs. The PyTorch, JAX, TensorFlow, and MLX backends are fully differentiable, with custom autograd rules that bypass the numerically unstable gradient of the standard SVD near degenerate singular values.</p>
<h2 id="features">Features</h2>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Kabsch</strong>: SVD-based optimal rotation for rigid alignment</li>
<li><strong>Kabsch-Umeyama</strong>: Kabsch with an additional optimal scaling factor $c$, solving $Q \approx cRP + t$</li>
<li><strong>Horn</strong>: Quaternion-based optimal rotation via the eigendecomposition of a $4 \times 4$ key matrix</li>
<li><strong>Horn + Scale</strong>: Horn&rsquo;s method extended with optimal isotropic scaling</li>
<li><strong>RMSD Wrappers</strong>: Convenience functions that return RMSD directly alongside the alignment parameters</li>
</ul>
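<p>As a rough illustration of the Horn entry above, the following NumPy sketch builds the $4 \times 4$ key matrix from the cross-covariance of the centered points and reads the rotation off the eigenvector with the largest eigenvalue. The function name <code>horn_rotation</code> and its signature are hypothetical and written for this example; they are not the cookbook&rsquo;s API, and the sketch omits weights, batching, and scaling.</p>

```python
import numpy as np

def horn_rotation(P, Q):
    """Rotation R (3x3) minimizing sum ||Q_i - R P_i||^2 after centering (hypothetical helper)."""
    P = np.asarray(P, dtype=float)
    Q = np.asarray(Q, dtype=float)
    Pc = P - P.mean(axis=0)
    Qc = Q - Q.mean(axis=0)
    M = Pc.T @ Qc  # M[a, b] = sum_i Pc_i[a] * Qc_i[b]
    (Sxx, Sxy, Sxz), (Syx, Syy, Syz), (Szx, Szy, Szz) = M
    # Symmetric 4x4 key matrix of Horn's quaternion method
    N = np.array([
        [Sxx + Syy + Szz, Syz - Szy, Szx - Sxz, Sxy - Syx],
        [Syz - Szy, Sxx - Syy - Szz, Sxy + Syx, Szx + Sxz],
        [Szx - Sxz, Sxy + Syx, -Sxx + Syy - Szz, Syz + Szy],
        [Sxy - Syx, Szx + Sxz, Syz + Szy, -Sxx - Syy + Szz],
    ])
    # eigh returns eigenvalues in ascending order, so the last column
    # is the unit quaternion (w, x, y, z) for the largest eigenvalue
    _, vecs = np.linalg.eigh(N)
    w, x, y, z = vecs[:, -1]
    # Convert the unit quaternion to a rotation matrix
    return np.array([
        [w*w + x*x - y*y - z*z, 2*(x*y - w*z), 2*(x*z + w*y)],
        [2*(y*x + w*z), w*w - x*x + y*y - z*z, 2*(y*z - w*x)],
        [2*(z*x - w*y), 2*(z*y + w*x), w*w - x*x - y*y + z*z],
    ])
```

<p>Because a unit quaternion always maps to a proper rotation, this route needs no explicit reflection check, in contrast to the SVD-based Kabsch path.</p>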
<h3 id="framework-support">Framework Support</h3>
<table>
  <thead>
      <tr>
          <th>Framework</th>
          <th style="text-align: center">Differentiable</th>
          <th style="text-align: center">Compile/JIT</th>
          <th>Versions</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>NumPy</td>
          <td style="text-align: center"></td>
          <td style="text-align: center"></td>
          <td>1.24+</td>
      </tr>
      <tr>
          <td>PyTorch</td>
          <td style="text-align: center">Yes</td>
          <td style="text-align: center"><code>torch.compile</code></td>
          <td>2.0+</td>
      </tr>
      <tr>
          <td>JAX</td>
          <td style="text-align: center">Yes</td>
          <td style="text-align: center"><code>jax.jit</code></td>
          <td>0.4+</td>
      </tr>
      <tr>
          <td>TensorFlow</td>
          <td style="text-align: center">Yes</td>
          <td style="text-align: center"></td>
          <td>2.13+</td>
      </tr>
      <tr>
          <td>MLX</td>
          <td style="text-align: center">Yes</td>
          <td style="text-align: center"></td>
          <td>0.1+</td>
      </tr>
  </tbody>
</table>
<p><code>torch.compile</code> and <code>jax.jit</code> are the tested compile/JIT paths. MLX supports 3D inputs only; the Kabsch (SVD) path is N-dimensional on the other four backends.</p>
<h3 id="numerical-robustness">Numerical Robustness</h3>
<p>Standard SVD and eigendecomposition backward passes produce <code>NaN</code> gradients when singular values collide or are near-zero. The library provides custom autograd primitives to handle these cases:</p>
<ul>
<li><strong>SafeSVD</strong> (PyTorch, JAX, TF, MLX): Custom backward pass that clamps the singular value gap, preventing division-by-zero in the gradient</li>
<li><strong>SafeEigh</strong> (PyTorch, JAX, TF, MLX): Analogous safe backward for the symmetric eigendecomposition used in Horn&rsquo;s method</li>
<li><strong>Per-point weights</strong>: Weighted centroids and weighted cross-covariance for mass-weighted or confidence-weighted alignment</li>
<li><strong>Batch dimensions</strong>: All functions broadcast over leading batch dimensions without explicit loops</li>
<li><strong>Mixed-dtype promotion</strong>: Inputs are promoted to a common floating-point dtype automatically</li>
</ul>
<h3 id="testing">Testing</h3>
<p>The test suite uses Hypothesis-based property testing across 13 modules covering:</p>
<ul>
<li>Round-trip correctness (align then compare)</li>
<li>Gradient finiteness and correctness (finite-difference checks)</li>
<li>Reflection handling (proper vs. improper rotations)</li>
<li>Weighted alignment consistency</li>
<li>Batch broadcasting</li>
<li>4 differentiable backends $\times$ 4 precisions (float32, float64, and where supported, float16, bfloat16)</li>
</ul>
<h2 id="usage">Usage</h2>
<p>This is a reference cookbook, so you can copy the framework folder you need from <code>src/kabsch_horn/&lt;framework&gt;/</code> directly into your project (the code has no runtime dependencies beyond the framework itself). To depend on it instead, install a pinned version from GitHub:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bash" data-lang="bash"><span style="display:flex;"><span>pip install <span style="color:#e6db74">&#34;git+https://github.com/hunter-heidenreich/Kabsch-Cookbook.git@v0.4.1&#34;</span>
</span></span></code></pre></div><p>Basic alignment with NumPy:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> numpy <span style="color:#66d9ef">as</span> np
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> kabsch_horn <span style="color:#f92672">import</span> numpy <span style="color:#66d9ef">as</span> kh
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Two sets of corresponding 3D points</span>
</span></span><span style="display:flex;"><span>P <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>random<span style="color:#f92672">.</span>randn(<span style="color:#ae81ff">100</span>, <span style="color:#ae81ff">3</span>)
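</span></span><span style="display:flex;"><span>R_true <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>linalg<span style="color:#f92672">.</span>qr(np<span style="color:#f92672">.</span>random<span style="color:#f92672">.</span>randn(<span style="color:#ae81ff">3</span>, <span style="color:#ae81ff">3</span>))[<span style="color:#ae81ff">0</span>]  <span style="color:#75715e"># random orthogonal matrix (det may be -1)</span>
</span></span><span style="display:flex;"><span>R_true <span style="color:#f92672">*=</span> np<span style="color:#f92672">.</span>sign(np<span style="color:#f92672">.</span>linalg<span style="color:#f92672">.</span>det(R_true))  <span style="color:#75715e"># negate if needed so det = +1 (works in 3D)</span>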
</span></span><span style="display:flex;"><span>Q <span style="color:#f92672">=</span> (P <span style="color:#f92672">@</span> R_true<span style="color:#f92672">.</span>T) <span style="color:#f92672">+</span> np<span style="color:#f92672">.</span>random<span style="color:#f92672">.</span>randn(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">3</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>R, t, rmsd <span style="color:#f92672">=</span> kh<span style="color:#f92672">.</span>kabsch(P, Q)
</span></span><span style="display:flex;"><span>aligned <span style="color:#f92672">=</span> P <span style="color:#f92672">@</span> R<span style="color:#f92672">.</span>T <span style="color:#f92672">+</span> t
</span></span></code></pre></div><p>RMSD loss for training in PyTorch:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> torch
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> kabsch_horn <span style="color:#f92672">import</span> pytorch <span style="color:#66d9ef">as</span> kh
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>pred_coords <span style="color:#f92672">=</span> model(input_features)   <span style="color:#75715e"># (B, N, 3), requires_grad=True</span>
</span></span><span style="display:flex;"><span>target_coords <span style="color:#f92672">=</span> batch[<span style="color:#e6db74">&#34;target&#34;</span>]       <span style="color:#75715e"># (B, N, 3)</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>rmsd <span style="color:#f92672">=</span> kh<span style="color:#f92672">.</span>kabsch_rmsd(pred_coords, target_coords)  <span style="color:#75715e"># (B,)</span>
</span></span><span style="display:flex;"><span>loss <span style="color:#f92672">=</span> rmsd<span style="color:#f92672">.</span>mean()
</span></span><span style="display:flex;"><span>loss<span style="color:#f92672">.</span>backward()  <span style="color:#75715e"># safe gradients via SafeSVD</span>
</span></span></code></pre></div><p>For the full API reference and additional examples, see the <a href="https://hunter-heidenreich.github.io/Kabsch-Cookbook/">documentation site</a>.</p>
<h2 id="results">Results</h2>
<h3 id="gradient-stability">Gradient Stability</h3>
<p>The standard SVD backward pass computes terms of the form $\frac{1}{\sigma_i^2 - \sigma_j^2}$, which diverge as two singular values approach each other. In molecular alignment this happens frequently: planar molecules, symmetric structures, and noisy coordinates can all produce near-degenerate singular values. The SafeSVD primitive floors the magnitude of that denominator at the dtype&rsquo;s machine epsilon (<code>finfo(dtype).eps</code>), producing finite (if slightly biased) gradients in these edge cases. Property-based tests confirm that gradients remain finite across thousands of random rotations, scales, and noise levels for all four differentiable backends.</p>
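<p>To make the flooring idea concrete, here is a minimal NumPy sketch of the coefficient matrix $F_{ij} = 1 / (\sigma_i^2 - \sigma_j^2)$ with the magnitude of each denominator floored at machine epsilon. The helper name <code>safe_inv_gap</code> is an assumption made for this illustration; it shows the clamping step only and is not the library&rsquo;s actual backward pass.</p>

```python
import numpy as np

def safe_inv_gap(sigma, dtype=np.float64):
    """F[i, j] = 1 / (s_i^2 - s_j^2), with |denominator| floored at eps (hypothetical helper)."""
    s2 = np.asarray(sigma, dtype=dtype) ** 2
    diff = s2[:, None] - s2[None, :]
    eps = np.finfo(dtype).eps
    # Keep the sign of the gap, but never let its magnitude drop below eps
    sign = np.where(diff >= 0, 1.0, -1.0).astype(dtype)
    safe = sign * np.maximum(np.abs(diff), eps)
    F = 1.0 / safe
    # The diagonal is not used by the SVD gradient formula
    np.fill_diagonal(F, 0.0)
    return F

# Two equal singular values: the unclamped formula would divide by zero
F = safe_inv_gap([2.0, 2.0, 1.0])
```

<p>Well-separated singular values are left untouched by the floor, so the bias only enters for pairs whose squared gap falls below epsilon.</p>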
<h3 id="framework-parity">Framework Parity</h3>
<p>All five backends produce numerically equivalent results (up to floating-point tolerance) on the same inputs. The shared API means switching from NumPy prototyping to PyTorch training requires changing only the import path.</p>
<h2 id="related-work">Related Work</h2>
<p>This project builds on the foundational alignment algorithms described in these papers:</p>
<ul>
<li><a href="/notes/biology/computational-biology/kabsch-algorithm/">Kabsch (1976)</a>: the original SVD-based rotation alignment</li>
<li><a href="/notes/biology/computational-biology/arun-svd-point-fitting/">Arun et al. (1987)</a>: SVD formulation for 3D point set fitting</li>
<li><a href="/notes/biology/computational-biology/horn-absolute-orientation/">Horn (1987)</a>: quaternion-based closed-form absolute orientation</li>
<li><a href="/notes/biology/computational-biology/horn-orthonormal-matrices/">Horn et al. (1988)</a>: orthonormal matrix (polar decomposition) approach</li>
<li><a href="/notes/biology/computational-biology/umeyama-similarity-transformation/">Umeyama (1991)</a>: extension to include optimal scaling</li>
</ul>
<p>For a detailed walkthrough of the Kabsch algorithm with code examples, see the companion blog post: <a href="/posts/kabsch-algorithm/">The Kabsch Algorithm</a>.</p>
]]></content:encoded></item><item><title>Umeyama's Method: Corrected SVD for Point Alignment</title><link>https://hunterheidenreich.com/notes/biology/computational-biology/umeyama-similarity-transformation/</link><pubDate>Mon, 16 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/biology/computational-biology/umeyama-similarity-transformation/</guid><description>Umeyama (1991) fixes the SVD-based point set alignment method to always produce proper rotations, jointly solving for rotation, translation, and scale.</description><content:encoded><![CDATA[<h2 id="fixing-the-reflection-problem-in-svd-based-alignment">Fixing the Reflection Problem in SVD-Based Alignment</h2>
<p>This <strong>Method</strong> paper addresses a specific failure mode in prior SVD-based solutions to the point set registration problem. Both <a href="/notes/biology/computational-biology/arun-svd-point-fitting/">Arun et al. (1987)</a> and <a href="/notes/biology/computational-biology/horn-orthonormal-matrices/">Horn, Hilden, and Negahdaripour (1988)</a> presented SVD-based methods for finding the optimal rotation between two point patterns. (Note: this is a different paper from <a href="/notes/biology/computational-biology/horn-absolute-orientation/">Horn&rsquo;s 1987 quaternion method</a>, which does not suffer from this issue.) These SVD-based methods can produce a reflection ($\det(R) = -1$) instead of a proper rotation when the data is severely corrupted. Umeyama provides a corrected formulation that always yields a proper rotation matrix.</p>
<h2 id="the-similarity-transformation-problem">The Similarity Transformation Problem</h2>
<p>Given two point sets $\{\mathbf{x}_i\}$ and $\{\mathbf{y}_i\}$ ($i = 1, \ldots, n$) in $m$-dimensional space, find the similarity transformation parameters (rotation $R$, translation $\mathbf{t}$, and scale $c$) minimizing the mean squared error:</p>
<p>$$
e^2(R, \mathbf{t}, c) = \frac{1}{n} \sum_{i=1}^{n} \lVert \mathbf{y}_i - (cR\mathbf{x}_i + \mathbf{t}) \rVert^2
$$</p>
<p>This generalizes the <a href="/notes/biology/computational-biology/kabsch-algorithm/">Kabsch problem</a> (rotation only) and the <a href="/notes/biology/computational-biology/horn-absolute-orientation/">absolute orientation problem</a> (rotation + translation + scale) to arbitrary dimensions $m$.</p>
<h2 id="the-core-lemma-corrected-svd-rotation">The Core Lemma: Corrected SVD Rotation</h2>
<p>The key contribution is a lemma for finding the rotation $R$ minimizing $\lVert A - RB \rVert^2$. Given the SVD of $AB^T = UDV^T$ (with $d_1 \geq d_2 \geq \cdots \geq d_m \geq 0$), define the correction matrix:</p>
<p>$$
S = \begin{cases} I &amp; \text{if } \det(AB^T) \geq 0 \\ \operatorname{diag}(1, 1, \ldots, 1, -1) &amp; \text{if } \det(AB^T) &lt; 0 \end{cases}
$$</p>
<p>The minimum value is:</p>
<p>$$
\min_{R} \lVert A - RB \rVert^2 = \lVert A \rVert^2 + \lVert B \rVert^2 - 2\operatorname{tr}(DS)
$$</p>
<p>When $\operatorname{rank}(AB^T) \geq m - 1$, the optimal rotation is uniquely determined as:</p>
<p>$$
R = USV^T
$$</p>
<p>The critical insight is that when $\det(AB^T) = 0$ (i.e., $\operatorname{rank}(AB^T) = m - 1$), the matrix $S$ must instead be chosen based on $\det(U)\det(V)$:</p>
<p>$$
S = \begin{cases} I &amp; \text{if } \det(U)\det(V) = 1 \\ \operatorname{diag}(1, 1, \ldots, 1, -1) &amp; \text{if } \det(U)\det(V) = -1 \end{cases}
$$</p>
<p>This handles the degenerate case where the sign of $\det(AB^T)$ is unreliable.</p>
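<p>The lemma can be checked numerically. The sketch below, which assumes NumPy and uses the hypothetical helper name <code>corrected_rotation</code>, builds $R = USV^T$ and compares the achieved residual against the predicted minimum $\lVert A \rVert^2 + \lVert B \rVert^2 - 2\operatorname{tr}(DS)$ for a case where $A$ is a mirror image of $B$. It selects $S$ from $\det(U)\det(V)$ in all cases, which agrees with the sign of $\det(AB^T)$ whenever that determinant is nonzero.</p>

```python
import numpy as np

def corrected_rotation(A, B):
    """Proper rotation R = U S V^T for min ||A - R B||^2, plus the predicted minimum."""
    U, d, Vt = np.linalg.svd(A @ B.T)  # singular values d are in descending order
    S = np.eye(len(d))
    # det(U) det(V) matches the sign of det(A B^T) when it is nonzero
    # and stays well defined in the rank m - 1 case
    if np.linalg.det(U) * np.linalg.det(Vt) < 0:
        S[-1, -1] = -1.0
    R = U @ S @ Vt
    predicted = np.sum(A**2) + np.sum(B**2) - 2.0 * np.sum(d * np.diag(S))
    return R, predicted

rng = np.random.default_rng(1)
B = rng.normal(size=(3, 10))
A = np.diag([1.0, 1.0, -1.0]) @ B  # A is B reflected through the xy-plane
R, predicted = corrected_rotation(A, B)
residual = np.sum((A - R @ B) ** 2)
```

<p>In this example the uncorrected product $UV^T$ would be a reflection with zero residual; the corrected $R$ is a proper rotation whose residual matches the predicted minimum.</p>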
<h2 id="complete-similarity-transformation-solution">Complete Similarity Transformation Solution</h2>
<p>Umeyama derives the full solution using centered coordinates and the covariance matrix $\Sigma_{xy} = \frac{1}{n} \sum_i (\mathbf{y}_i - \boldsymbol{\mu}_y)(\mathbf{x}_i - \boldsymbol{\mu}_x)^T$.</p>
<p>Given the SVD $\Sigma_{xy} = UDV^T$:</p>
<p><strong>Rotation</strong>:</p>
<p>$$
R = USV^T
$$</p>
<p><strong>Scale</strong>:</p>
<p>$$
c = \frac{1}{\sigma_x^2} \operatorname{tr}(DS)
$$</p>
<p><strong>Translation</strong>:</p>
<p>$$
\mathbf{t} = \boldsymbol{\mu}_y - cR\boldsymbol{\mu}_x
$$</p>
<p><strong>Minimum error</strong>:</p>
<p>$$
\varepsilon^2 = \sigma_y^2 - \frac{\operatorname{tr}(DS)^2}{\sigma_x^2}
$$</p>
<p>where $\sigma_x^2$ and $\sigma_y^2$ are the variances of the respective point sets around their centroids.</p>
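<p>The closed-form solution above translates almost line for line into NumPy. The following is a minimal sketch under stated assumptions: points are stored as rows, the function name <code>umeyama</code> and its return order are choices made for this example, and $S$ is selected from $\det(U)\det(V)$, which covers both the full-rank and the rank $m - 1$ cases.</p>

```python
import numpy as np

def umeyama(X, Y):
    """Estimate (c, R, t, err) minimizing mean ||y_i - (c R x_i + t)||^2.

    X, Y: arrays of shape (n, m) holding corresponding points as rows.
    """
    X = np.asarray(X, dtype=float)
    Y = np.asarray(Y, dtype=float)
    n, m = X.shape
    mu_x, mu_y = X.mean(axis=0), Y.mean(axis=0)
    Xc, Yc = X - mu_x, Y - mu_y
    var_x = np.sum(Xc**2) / n   # sigma_x^2
    var_y = np.sum(Yc**2) / n   # sigma_y^2
    cov = Yc.T @ Xc / n         # Sigma_xy, shape (m, m)
    U, d, Vt = np.linalg.svd(cov)
    s = np.ones(m)              # diagonal of the correction matrix S
    if np.linalg.det(U) * np.linalg.det(Vt) < 0:
        s[-1] = -1.0
    R = U @ np.diag(s) @ Vt
    trace_DS = np.sum(d * s)
    c = trace_DS / var_x
    t = mu_y - c * (R @ mu_x)
    err = var_y - trace_DS**2 / var_x
    return c, R, t, err
```

<p>For noiseless inputs generated as $\mathbf{y}_i = cR\mathbf{x}_i + \mathbf{t}$, the returned error should be zero up to floating-point tolerance.</p>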
<h2 id="why-prior-methods-fail">Why Prior Methods Fail</h2>
<p>The methods of Arun et al. and Horn et al. use $R = UV^T$ directly from the SVD. This works when $\det(UV^T) = 1$ (proper rotation). When $\det(UV^T) = -1$, these methods either produce a reflection or apply an ad hoc correction (flipping the sign of the last column of $U$). Umeyama shows that the correct fix depends on $\det(\Sigma_{xy})$:</p>
<ul>
<li>If $\det(\Sigma_{xy}) \geq 0$: set $S = I$, so $R = UV^T$</li>
<li>If $\det(\Sigma_{xy}) &lt; 0$: set $S = \operatorname{diag}(1, \ldots, 1, -1)$, flipping the last singular value&rsquo;s contribution</li>
</ul>
<p>This distinction matters because corrupted data can make $\det(UV^T) = -1$ even when the true transformation is a proper rotation. Simply flipping a column of $U$ does not always yield the correct least-squares solution.</p>
<h2 id="generality">Generality</h2>
<p>The formulation works for any dimension $m$, covering both 2D and 3D registration problems. The proof uses Lagrange multipliers with explicit enforcement of both orthogonality ($R^T R = I$) and the proper rotation constraint ($\det(R) = 1$), which prior methods enforced only partially.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Umeyama, S. (1991). Least-squares estimation of transformation parameters between two point patterns. <em>IEEE Transactions on Pattern Analysis and Machine Intelligence</em>, 13(4), 376-380. <a href="https://doi.org/10.1109/34.88573">https://doi.org/10.1109/34.88573</a></p>
<p><strong>Publication</strong>: IEEE TPAMI, 1991</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="/posts/kabsch-algorithm/">Kabsch Algorithm: NumPy, PyTorch, TensorFlow, and JAX</a> (tutorial with implementations including the Kabsch-Umeyama scaling extension)</li>
<li><a href="/projects/kabsch-horn-cookbook/">Kabsch-Horn Cookbook</a> (a differentiable, gradient-safe implementation of Kabsch, Horn, and Umeyama alignment across NumPy, PyTorch, JAX, TensorFlow, and MLX)</li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{umeyama1991least,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Least-squares estimation of transformation parameters between two point patterns}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Umeyama, Shinji}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{IEEE Transactions on Pattern Analysis and Machine Intelligence}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{13}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{4}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{376--380}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{1991}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{IEEE}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1109/34.88573}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Horn et al.: Absolute Orientation Using Orthonormal Matrices</title><link>https://hunterheidenreich.com/notes/biology/computational-biology/horn-orthonormal-matrices/</link><pubDate>Mon, 16 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/biology/computational-biology/horn-orthonormal-matrices/</guid><description>Horn, Hilden, and Negahdaripour (1988) solve absolute orientation using matrix square roots, providing an orthonormal matrix alternative to quaternions.</description><content:encoded><![CDATA[<h2 id="a-matrix-based-companion-to-the-quaternion-method">A Matrix-Based Companion to the Quaternion Method</h2>
<p>This <strong>Method</strong> paper presents a closed-form solution to the absolute orientation problem using $3 \times 3$ orthonormal matrices directly, complementing <a href="/notes/biology/computational-biology/horn-absolute-orientation/">Horn&rsquo;s earlier quaternion-based solution</a> (1987). The authors note that while quaternions are more elegant, orthonormal matrices are more widely used in photogrammetry, graphics, and robotics. The solution relies on the polar decomposition of the cross-covariance matrix via its matrix square root.</p>
<p>The paper also compares two approaches: (1) directly finding the best-fit orthonormal matrix (the main result), and (2) finding an unconstrained best-fit linear transformation and then projecting it onto the nearest orthonormal matrix. These give different results, and only the first approach has the desired symmetry property.</p>
<h2 id="the-rotation-via-polar-decomposition">The Rotation via Polar Decomposition</h2>
<p>As in the quaternion paper, the problem reduces to finding the orthonormal matrix $R$ maximizing $\operatorname{Tr}(R^T M)$, where $M = \sum_{i=1}^{n} \mathbf{r}'_{r,i} (\mathbf{r}'_{l,i})^T$ is the cross-covariance matrix of the centered point sets.</p>
<p>The key insight is the polar decomposition: any matrix $M$ can be written as:</p>
<p>$$
M = U S
$$</p>
<p>where $U$ is orthonormal and $S = (M^T M)^{1/2}$ is positive semidefinite. When $M$ is nonsingular:</p>
<p>$$
U = M (M^T M)^{-1/2}
$$</p>
<p>The matrix square root $(M^T M)^{1/2}$ is computed via eigendecomposition. If $M^T M$ has eigenvalues $\lambda_1, \lambda_2, \lambda_3$ and eigenvectors $\hat{\mathbf{u}}_1, \hat{\mathbf{u}}_2, \hat{\mathbf{u}}_3$:</p>
<p>$$
(M^T M)^{1/2} = \sqrt{\lambda_1} \, \hat{\mathbf{u}}_1 \hat{\mathbf{u}}_1^T + \sqrt{\lambda_2} \, \hat{\mathbf{u}}_2 \hat{\mathbf{u}}_2^T + \sqrt{\lambda_3} \, \hat{\mathbf{u}}_3 \hat{\mathbf{u}}_3^T
$$</p>
<p>The sign of $\det(U)$ equals the sign of $\det(M)$, so $U$ is a proper rotation when $\det(M) &gt; 0$ and a reflection when $\det(M) &lt; 0$.</p>
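<p>The construction above can be sketched in a few lines of NumPy. This is an illustrative version for nonsingular $M$ only; the helper name <code>polar_rotation</code> is an assumption for this example, and the coplanar (singular) case discussed next needs separate handling.</p>

```python
import numpy as np

def polar_rotation(M):
    """Orthonormal factor U = M (M^T M)^(-1/2) for a nonsingular square M."""
    # Eigendecomposition of the symmetric positive definite matrix M^T M
    lam, vecs = np.linalg.eigh(M.T @ M)
    # (M^T M)^(-1/2) = sum_k lam_k^(-1/2) u_k u_k^T; dividing scales column k
    inv_sqrt = (vecs / np.sqrt(lam)) @ vecs.T
    return M @ inv_sqrt

rng = np.random.default_rng(4)
M = rng.normal(size=(3, 3))
U = polar_rotation(M)
S = U.T @ M  # the symmetric positive semidefinite factor
```

<p>The result can be compared with the product of the two orthogonal SVD factors of $M$, which yields the same orthonormal matrix.</p>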
<h2 id="handling-the-coplanar-case">Handling the Coplanar Case</h2>
<p>When one set of measurements is coplanar, $M$ is singular ($\operatorname{rank}(M) = 2$) and one eigenvalue of $M^T M$ is zero. The matrix square root still exists (positive semidefinite rather than positive definite), but $S$ is no longer invertible.</p>
<p>In this case, $U$ is determined only for two of its three columns. The third column (corresponding to the zero eigenvalue) is fixed by the orthonormality constraint, with a sign ambiguity resolved by requiring $\det(U) = +1$ (proper rotation).</p>
<h2 id="the-nearest-orthonormal-matrix-alternative-approach">The Nearest Orthonormal Matrix (Alternative Approach)</h2>
<p>The paper also derives a closed-form solution for finding the orthonormal matrix nearest to an arbitrary matrix $A$ (minimizing $\lVert A - R \rVert^2$). This uses the same polar decomposition machinery: if $A = U_A S_A$, then $U_A$ is the nearest orthonormal matrix.</p>
<p>This approach (find unconstrained best-fit transform, then project to nearest orthonormal matrix) was used by some earlier methods. Horn et al. show it gives a different result from the direct least-squares solution and lacks the symmetry property: the inverse transformation from right-to-left is generally not the exact inverse of the left-to-right solution.</p>
<h2 id="relationship-to-other-methods">Relationship to Other Methods</h2>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>Rotation representation</th>
          <th>Core computation</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="/notes/biology/computational-biology/kabsch-algorithm/">Kabsch (1976)</a></td>
          <td>Orthogonal matrix</td>
          <td>Eigendecomposition of $\tilde{R}R$ ($3 \times 3$)</td>
      </tr>
      <tr>
          <td><a href="/notes/biology/computational-biology/horn-absolute-orientation/">Horn (1987)</a></td>
          <td>Unit quaternion</td>
          <td>Eigenvector of $N$ ($4 \times 4$)</td>
      </tr>
      <tr>
          <td>Horn et al. (1988)</td>
          <td>Orthonormal matrix</td>
          <td>Square root of $M^T M$ ($3 \times 3$)</td>
      </tr>
      <tr>
          <td><a href="/notes/biology/computational-biology/arun-svd-point-fitting/">Arun et al. (1987)</a></td>
          <td>Orthonormal matrix</td>
          <td>SVD of $H$ ($3 \times 3$)</td>
      </tr>
  </tbody>
</table>
<p>The polar decomposition approach (this paper) and the SVD approach (<a href="/notes/biology/computational-biology/arun-svd-point-fitting/">Arun et al.</a>) are closely related: the SVD $M = U \Lambda V^T$ gives the polar decomposition as $M = (UV^T)(V \Lambda V^T)$ where $UV^T$ is the orthonormal factor and $V \Lambda V^T$ is the positive semidefinite factor. Both methods can produce reflections under noisy data, which <a href="/notes/biology/computational-biology/umeyama-similarity-transformation/">Umeyama (1991)</a> later addressed.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Horn, B. K. P., Hilden, H. M., &amp; Negahdaripour, S. (1988). Closed-form solution of absolute orientation using orthonormal matrices. <em>Journal of the Optical Society of America A</em>, 5(7), 1127-1135. <a href="https://doi.org/10.1364/josaa.5.001127">https://doi.org/10.1364/josaa.5.001127</a></p>
<p><strong>Publication</strong>: Journal of the Optical Society of America A, 1988</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="/posts/kabsch-algorithm/">Kabsch Algorithm: NumPy, PyTorch, TensorFlow, and JAX</a> (tutorial with differentiable implementations)</li>
<li><a href="/projects/kabsch-horn-cookbook/">Kabsch-Horn Cookbook</a> (a differentiable, gradient-safe implementation of Kabsch, Horn, and Umeyama alignment across NumPy, PyTorch, JAX, TensorFlow, and MLX)</li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{horn1988closed,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Closed-form solution of absolute orientation using orthonormal matrices}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Horn, Berthold K. P. and Hilden, Hugh M. and Negahdaripour, Shahriar}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Journal of the Optical Society of America A}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{5}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{7}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{1127--1135}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{1988}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Optica Publishing Group}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1364/josaa.5.001127}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Arun et al.: SVD-Based Least-Squares Fitting of 3D Points</title><link>https://hunterheidenreich.com/notes/biology/computational-biology/arun-svd-point-fitting/</link><pubDate>Mon, 16 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/biology/computational-biology/arun-svd-point-fitting/</guid><description>Arun, Huang, and Blostein (1987) introduce an SVD-based algorithm for least-squares rotation and translation between two 3D point sets.</description><content:encoded><![CDATA[<h2 id="svd-for-3d-point-set-registration">SVD for 3D Point Set Registration</h2>
<p>This <strong>Method</strong> paper presents a concise algorithm for finding the least-squares rotation and translation between two 3D point sets using the singular value decomposition (SVD) of a $3 \times 3$ cross-covariance matrix. The approach is closely related to the earlier <a href="/notes/biology/computational-biology/kabsch-algorithm/">Kabsch algorithm</a> (1976), which used eigendecomposition, and was developed independently of <a href="/notes/biology/computational-biology/horn-absolute-orientation/">Horn&rsquo;s quaternion method</a> (1987). The paper also identifies a reflection degeneracy that <a href="/notes/biology/computational-biology/umeyama-similarity-transformation/">Umeyama</a> later provided a complete fix for.</p>
<h2 id="problem-formulation">Problem Formulation</h2>
<p>Given two 3D point sets $\{p_i\}$ and $\{p'_i\}$ ($i = 1, \ldots, N$) related by:</p>
<p>$$
p'_i = R p_i + T + N_i
$$</p>
<p>where $R$ is a rotation matrix, $T$ is a translation vector, and $N_i$ is noise, find $\hat{R}$ and $\hat{T}$ minimizing:</p>
<p>$$
\Sigma^2 = \sum_{i=1}^{N} \lVert p&rsquo;_i - (R p_i + T) \rVert^2
$$</p>
<h2 id="decoupling-translation-and-rotation">Decoupling Translation and Rotation</h2>
<p>The translation is eliminated by centering both point sets at their centroids $p$ and $p&rsquo;$. Defining centered coordinates $q_i = p_i - p$ and $q&rsquo;_i = p&rsquo;_i - p&rsquo;$, the problem reduces to:</p>
<p>$$
\Sigma^2 = \sum_{i=1}^{N} \lVert q&rsquo;_i - R q_i \rVert^2
$$</p>
<p>Once $\hat{R}$ is found, the translation follows as $\hat{T} = p&rsquo; - \hat{R} p$.</p>
<h2 id="the-svd-algorithm">The SVD Algorithm</h2>
<p>The algorithm proceeds in five steps:</p>
<ol>
<li>Center both point sets by subtracting centroids</li>
<li>Compute the $3 \times 3$ cross-covariance matrix: $H = \sum_{i=1}^{N} q_i q&rsquo;^t_i$</li>
<li>Compute the SVD: $H = U \Lambda V^t$</li>
<li>Form the candidate rotation: $X = V U^t$</li>
<li>Check $\det(X)$: if $+1$, then $\hat{R} = X$; if $-1$, the result is a reflection</li>
</ol>
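<p>The five steps above can be sketched in NumPy as follows. This is a minimal illustration, not the authors&rsquo; code: the function name <code>arun_svd</code> and the row-per-point array layout are assumptions made here, and the function only reports the determinant instead of handling the reflection case.</p>

```python
import numpy as np

def arun_svd(P, P_prime):
    """Sketch of the five steps; P and P_prime are (N, 3) arrays of corresponding points."""
    # Step 1: center both sets at their centroids.
    p_bar = P.mean(axis=0)
    p_bar_prime = P_prime.mean(axis=0)
    Q = P - p_bar
    Q_prime = P_prime - p_bar_prime
    # Step 2: H = sum_i q_i q'_i^t, a 3x3 matrix.
    H = Q.T @ Q_prime
    # Step 3: SVD, H = U diag(s) V^t (NumPy returns V^t, not V).
    U, s, Vt = np.linalg.svd(H)
    # Step 4: candidate rotation X = V U^t.
    X = Vt.T @ U.T
    # Step 5: det(X) near +1 means a rotation, near -1 a reflection.
    det_X = np.linalg.det(X)
    # Translation recovered from the centroids: T = p' - X p.
    T = p_bar_prime - X @ p_bar
    return X, T, det_X
```

<p>For noiseless, non-coplanar input the returned determinant should be $+1$ and <code>X</code> should match the true rotation up to floating-point error.</p>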
<p>The key insight is that minimizing $\Sigma^2$ is equivalent to maximizing $\operatorname{Trace}(RH)$. Using a lemma based on the Cauchy-Schwarz inequality, Arun et al. show that $X = VU^t$ maximizes this trace over all orthonormal matrices.</p>
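<p>The step from $\Sigma^2$ to the trace can be written out explicitly. The following is a sketch using the definitions above; it uses $(Rq_i)^t (Rq_i) = q_i^t q_i$ for orthonormal $R$, and $B$ denotes an arbitrary orthonormal matrix introduced here only for the argument.</p>
<p>$$
\begin{aligned}
\Sigma^2 &amp;= \sum_{i=1}^{N} \left( q'^t_i q'_i + q^t_i q_i - 2 \, q'^t_i R q_i \right) \\
\sum_{i=1}^{N} q'^t_i R q_i &amp;= \operatorname{Trace}\left( \sum_{i=1}^{N} R q_i q'^t_i \right) = \operatorname{Trace}(RH) \\
XH &amp;= V U^t U \Lambda V^t = V \Lambda V^t
\end{aligned}
$$</p>
<p>Only the cross term depends on $R$, so minimizing $\Sigma^2$ amounts to maximizing $\operatorname{Trace}(RH)$. Because $V \Lambda V^t$ is symmetric with non-negative eigenvalues, the lemma gives $\operatorname{Trace}(XH) \ge \operatorname{Trace}(BXH)$ for every orthonormal $B$, and every orthonormal matrix can be written as $BX$ for some such $B$.</p>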
<h2 id="the-reflection-problem">The Reflection Problem</h2>
<p>When $\det(VU^t) = -1$, the SVD produces a reflection rather than a proper rotation. Arun et al. analyze three cases:</p>
<p><strong>Noiseless, non-coplanar points</strong>: The SVD always gives a proper rotation ($\det = +1$). No issue arises.</p>
<p><strong>Coplanar points</strong> (including $N = 3$): One singular value of $H$ is zero. Both a rotation and a reflection achieve $\Sigma^2 = 0$. The fix is to flip the sign of the column of $V$ corresponding to the zero singular value:</p>
<p>$$
V&rsquo; = [v_1, v_2, -v_3], \quad X&rsquo; = V&rsquo; U^t
$$</p>
<p><strong>Noisy, non-coplanar points with $\det = -1$</strong>: The paper acknowledges this case cannot be handled by the algorithm. The reflection genuinely minimizes $\Sigma^2$ over all orthonormal matrices, meaning no rotation achieves a lower error. The authors suggest this only occurs with very large noise and recommend RANSAC-like approaches.</p>
<p>This last case is precisely what <a href="/notes/biology/computational-biology/umeyama-similarity-transformation/">Umeyama (1991)</a> later resolved with a corrected formulation using a sign matrix $S$ conditioned on $\det(\Sigma_{xy})$.</p>
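<p>The coplanar fix $V&rsquo; = [v_1, v_2, -v_3]$ can be sketched as below. The helper name <code>rotation_with_flip</code> is hypothetical. The sketch relies on NumPy returning singular values in descending order, so the last column of $V$ is the one paired with the zero (or smallest) singular value. Applying the same flip to noisy, non-coplanar data goes beyond what Arun et al. claim and is closer in spirit to the later correction by Umeyama.</p>

```python
import numpy as np

def rotation_with_flip(H):
    """Return V U^t, flipping the last column of V when det(V U^t) is -1."""
    U, s, Vt = np.linalg.svd(H)   # s is sorted in descending order
    V = Vt.T.copy()
    X = V @ U.T
    if np.linalg.det(X) < 0:
        V[:, 2] = -V[:, 2]        # V' = [v_1, v_2, -v_3]
        X = V @ U.T               # X' = V' U^t
    return X
```

<p>For exactly coplanar, noiseless points the result is a proper rotation that still achieves zero residual, whichever sign the SVD routine happened to choose for the third singular vectors.</p>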
<h2 id="computational-comparison">Computational Comparison</h2>
<p>The paper includes VAX 11/780 benchmarks comparing three methods:</p>
<table>
  <thead>
      <tr>
          <th>Points</th>
          <th>SVD (ms)</th>
          <th>Quaternion (ms)</th>
          <th>Iterative (ms)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>3</td>
          <td>54.6</td>
          <td>26.6</td>
          <td>126.8</td>
      </tr>
      <tr>
          <td>11</td>
          <td>37.0</td>
          <td>41.0</td>
          <td>105.2</td>
      </tr>
      <tr>
          <td>30</td>
          <td>44.2</td>
          <td>48.3</td>
          <td>111.0</td>
      </tr>
  </tbody>
</table>
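<p>The SVD and quaternion methods have comparable speed, and both are at least twice as fast as the iterative approach in every row. In these benchmarks SVD is slower than the quaternion method at 3 points but slightly faster at 11 and 30 points. Both closed-form methods decompose a fixed-size matrix regardless of $N$ ($3 \times 3$ for SVD, $4 \times 4$ for the quaternion method), so only the accumulation of the sums grows with the number of points.</p>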
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Arun, K. S., Huang, T. S., &amp; Blostein, S. D. (1987). Least-Squares Fitting of Two 3-D Point Sets. <em>IEEE Transactions on Pattern Analysis and Machine Intelligence</em>, PAMI-9(5), 698-700. <a href="https://doi.org/10.1109/TPAMI.1987.4767965">https://doi.org/10.1109/TPAMI.1987.4767965</a></p>
<p><strong>Publication</strong>: IEEE TPAMI, 1987</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="/posts/kabsch-algorithm/">Kabsch Algorithm: NumPy, PyTorch, TensorFlow, and JAX</a> (tutorial with differentiable implementations)</li>
<li><a href="/projects/kabsch-horn-cookbook/">Kabsch-Horn Cookbook</a> (a differentiable, gradient-safe implementation of Kabsch, Horn, and Umeyama alignment across NumPy, PyTorch, JAX, TensorFlow, and MLX)</li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{arun1987least,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Least-Squares Fitting of Two 3-D Point Sets}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Arun, K. S. and Huang, T. S. and Blostein, S. D.}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{IEEE Transactions on Pattern Analysis and Machine Intelligence}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{PAMI-9}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{5}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{698--700}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{1987}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{IEEE}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1109/TPAMI.1987.4767965}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Kabsch Algorithm: Optimal Rotation for Point Set Alignment</title><link>https://hunterheidenreich.com/notes/biology/computational-biology/kabsch-algorithm/</link><pubDate>Sun, 15 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/biology/computational-biology/kabsch-algorithm/</guid><description>Kabsch (1976) derives a closed-form solution for the optimal rotation aligning two weighted vector sets by minimizing squared deviations.</description><content:encoded><![CDATA[<h2 id="a-closed-form-solution-for-optimal-rotation">A Closed-Form Solution for Optimal Rotation</h2>
<p>This short communication is a <strong>Method</strong> paper: it gives a direct, analytical solution to a constrained optimization problem. Given two sets of vectors, Kabsch derives the orthogonal matrix (rotation) that best superimposes one set onto the other by minimizing a weighted sum of squared deviations. Prior approaches either solved an unconstrained problem and factorized the result (Diamond, 1976) or used iterative methods (McLachlan, 1972). Kabsch shows that a direct, non-iterative solution exists despite the non-linear nature of the orthogonality constraint.</p>
<h2 id="the-superposition-problem">The Superposition Problem</h2>
<p>The core problem arises frequently in crystallography and structural biology: given two sets of corresponding points (e.g., atomic coordinates from a known structure and experimentally measured coordinates), find the rigid rotation that best aligns them. Translations can be removed by centering both point sets at the origin, leaving only the rotational component.</p>
<p>Formally, given vector sets $\mathbf{x}_n$ and $\mathbf{y}_n$ ($n = 1, 2, \ldots, N$) with weights $w_n$, find the orthogonal matrix $\mathsf{U}$ minimizing:</p>
<p>$$
E = \frac{1}{2} \sum_{n} w_n (\mathsf{U} \mathbf{x}_n - \mathbf{y}_n)^2
$$</p>
<p>subject to orthogonality: $\tilde{\mathsf{U}} \mathsf{U} = \mathsf{I}$.</p>
<h2 id="derivation-via-lagrange-multipliers">Derivation via Lagrange Multipliers</h2>
<p>Kabsch introduces a symmetric matrix $\mathsf{L}$ of Lagrange multipliers to enforce orthogonality, forming the Lagrangian:</p>
<p>$$
G = E + \frac{1}{2} \sum_{i,j} l_{ij} \left( \sum_{k} u_{ki} u_{kj} - \delta_{ij} \right)
$$</p>
<p>Setting $\partial G / \partial u_{ij} = 0$ and defining two key matrices:</p>
<p>$$
r_{ij} = \sum_{n} w_n \, y_{ni} \, x_{nj} \qquad s_{ij} = \sum_{n} w_n \, x_{ni} \, x_{nj}
$$</p>
<p>where $\mathsf{R} = (r_{ij})$ is the weighted cross-covariance matrix and $\mathsf{S} = (s_{ij})$ is the weighted auto-covariance matrix, the stationarity condition becomes:</p>
<p>$$
\mathsf{U} \cdot (\mathsf{S} + \mathsf{L}) = \mathsf{R}
$$</p>
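<p>The stationarity condition can be checked component by component. The following is a sketch using only the definitions above; the summation indices $a$, $b$, and $k$ are introduced here for the calculation, and the second line uses the symmetry of $\mathsf{L}$.</p>
<p>$$
\begin{aligned}
\frac{\partial E}{\partial u_{ij}} &amp;= \sum_{n} w_n \left( \sum_{k} u_{ik} x_{nk} - y_{ni} \right) x_{nj} = \sum_{k} u_{ik} s_{kj} - r_{ij} \\
\frac{\partial}{\partial u_{ij}} \, \frac{1}{2} \sum_{a,b} l_{ab} \left( \sum_{k} u_{ka} u_{kb} - \delta_{ab} \right) &amp;= \sum_{k} u_{ik} l_{kj} \\
\frac{\partial G}{\partial u_{ij}} &amp;= \sum_{k} u_{ik} \left( s_{kj} + l_{kj} \right) - r_{ij} = 0
\end{aligned}
$$</p>
<p>The last line is the matrix equation $\mathsf{U} \cdot (\mathsf{S} + \mathsf{L}) = \mathsf{R}$ written in components.</p>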
<h2 id="eigendecomposition-solution">Eigendecomposition Solution</h2>
<p>The key insight is that multiplying both sides by their transposes eliminates the unknown $\mathsf{U}$:</p>
<p>$$
(\mathsf{S} + \mathsf{L})(\mathsf{S} + \mathsf{L}) = \tilde{\mathsf{R}} \mathsf{R}
$$</p>
<p>Since $\tilde{\mathsf{R}} \mathsf{R}$ is symmetric and, for non-degenerate point sets, positive definite, it has positive eigenvalues $\mu_k$ and eigenvectors $\mathbf{a}_k$. The matrix $\mathsf{S} + \mathsf{L}$ shares the same eigenvectors with eigenvalues $\sqrt{\mu_k}$.</p>
<p>From the eigenvectors $\mathbf{a}_k$, a second set of unit vectors $\mathbf{b}_k$ is defined:</p>
<p>$$
\mathbf{b}_k = \frac{1}{\sqrt{\mu_k}} \mathsf{R} \, \mathbf{a}_k
$$</p>
<p>The optimal rotation matrix is then constructed directly:</p>
<p>$$
u_{ij} = \sum_{k} b_{ki} \, a_{kj}
$$</p>
<h2 id="handling-degeneracies-and-generalizations">Handling Degeneracies and Generalizations</h2>
<p>Kabsch addresses two extensions:</p>
<ol>
<li>
<p><strong>Planar point sets</strong>: When all vectors lie in a plane, one eigenvalue of $\tilde{\mathsf{R}} \mathsf{R}$ is zero, so $\mathbf{b}_3$ cannot be computed from the formula above. The third vector of each set is instead obtained via cross products: $\mathbf{a}_3 = \mathbf{a}_1 \times \mathbf{a}_2$ and $\mathbf{b}_3 = \mathbf{b}_1 \times \mathbf{b}_2$.</p>
</li>
<li>
<p><strong>General metric constraints</strong>: The orthogonality constraint $\tilde{\mathsf{U}} \mathsf{U} = \mathsf{I}$ can be replaced by $\tilde{\mathsf{U}} \mathsf{U} = \mathsf{M}$ for any symmetric positive definite $\mathsf{M}$. By finding any specific solution $\mathsf{B}$ and transforming the input vectors as $\mathbf{x}&rsquo;_n = \mathsf{B} \mathbf{x}_n$, the problem reduces back to the standard orthogonal case.</p>
</li>
</ol>
<p>The method generalizes naturally to vector spaces of arbitrary dimension.</p>
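<p>The eigendecomposition construction and the planar fallback can be sketched together in NumPy. This is an illustration under stated assumptions, not Kabsch&rsquo;s original code: the name <code>kabsch_eig</code>, the relative tolerance <code>tol</code>, and the row-per-vector array layout are choices made here, and the inputs are assumed to be already centered and not collinear.</p>

```python
import numpy as np

def kabsch_eig(X, Y, w=None, tol=1e-10):
    """Rotation U minimizing sum_n w_n |U x_n - y_n|^2 for centered (N, 3) arrays X, Y."""
    if w is None:
        w = np.ones(len(X))
    # r_ij = sum_n w_n y_ni x_nj
    R = (Y * w[:, None]).T @ X
    # Eigenpairs of the symmetric matrix R~R; eigh returns ascending eigenvalues.
    mu, A = np.linalg.eigh(R.T @ R)
    mu, A = mu[::-1], A[:, ::-1].copy()   # largest eigenvalue first
    B = np.zeros((3, 3))
    for k in range(3):
        if mu[k] > tol * mu[0]:
            # b_k = R a_k / sqrt(mu_k)
            B[:, k] = R @ A[:, k] / np.sqrt(mu[k])
    if mu[2] <= tol * mu[0]:
        # Planar case: a_3 = a_1 x a_2 and b_3 = b_1 x b_2.
        A[:, 2] = np.cross(A[:, 0], A[:, 1])
        B[:, 2] = np.cross(B[:, 0], B[:, 1])
    # u_ij = sum_k b_ki a_kj, i.e. U = sum_k b_k a_k^T
    return B @ A.T
```

<p>Modern implementations usually replace the eigendecomposition of $\tilde{\mathsf{R}} \mathsf{R}$ with an SVD of $\mathsf{R}$, which avoids squaring the condition number; the sketch keeps the eigenvector route to mirror the derivation.</p>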
<h2 id="legacy-and-impact">Legacy and Impact</h2>
<p>This two-page communication became one of the most cited papers in structural biology. The &ldquo;Kabsch algorithm&rdquo; (or &ldquo;Kabsch rotation&rdquo;) is the standard method for computing the root-mean-square deviation (RMSD) between two molecular structures after optimal superposition. It underpins structure comparison tools across crystallography, NMR spectroscopy, cryo-EM, and computational chemistry.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Kabsch, W. (1976). A solution for the best rotation to relate two sets of vectors. <em>Acta Crystallographica Section A</em>, 32(5), 922-923. <a href="https://doi.org/10.1107/s0567739476001873">https://doi.org/10.1107/s0567739476001873</a></p>
<p><strong>Publication</strong>: Acta Crystallographica Section A, 1976</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="/posts/kabsch-algorithm/">Kabsch Algorithm: NumPy, PyTorch, TensorFlow, and JAX</a> (tutorial with differentiable implementations)</li>
<li><a href="/projects/kabsch-horn-cookbook/">Kabsch-Horn Cookbook</a> (a differentiable, gradient-safe implementation of Kabsch, Horn, and Umeyama alignment across NumPy, PyTorch, JAX, TensorFlow, and MLX)</li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{kabsch1976solution,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{A solution for the best rotation to relate two sets of vectors}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Kabsch, Wolfgang}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Acta Crystallographica Section A: Crystal Physics, Diffraction, Theoretical and General Crystallography}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{32}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{5}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{922--923}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{1976}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{International Union of Crystallography}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1107/s0567739476001873}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Horn's Method: Absolute Orientation via Unit Quaternions</title><link>https://hunterheidenreich.com/notes/biology/computational-biology/horn-absolute-orientation/</link><pubDate>Sun, 15 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/biology/computational-biology/horn-absolute-orientation/</guid><description>Horn (1987) presents a closed-form quaternion solution for absolute orientation, finding optimal rotation, translation, and scale between two point sets.</description><content:encoded><![CDATA[<h2 id="a-quaternion-approach-to-point-set-registration">A Quaternion Approach to Point Set Registration</h2>
<p>This <strong>Method</strong> paper presents a closed-form solution to the absolute orientation problem: given corresponding points measured in two different coordinate systems, find the optimal rotation, translation, and scale that maps one set onto the other. While the <a href="/notes/biology/computational-biology/kabsch-algorithm/">Kabsch algorithm</a> (1976) solved the rotation subproblem via eigendecomposition of $\tilde{\mathsf{R}}\mathsf{R}$, Horn&rsquo;s approach uses unit quaternions to represent rotation, reducing the problem to finding the eigenvector of a $4 \times 4$ symmetric matrix associated with its largest eigenvalue.</p>
<h2 id="the-absolute-orientation-problem">The Absolute Orientation Problem</h2>
<p>Given $n$ point pairs $\{\mathbf{r}_{l,i}\}$ and $\{\mathbf{r}_{r,i}\}$ measured in &ldquo;left&rdquo; and &ldquo;right&rdquo; coordinate systems, find the transformation:</p>
<p>$$
\mathbf{r}_r = s \, R(\mathbf{r}_l) + \mathbf{r}_0
$$</p>
<p>where $s$ is a scale factor, $R$ is a rotation, and $\mathbf{r}_0$ is a translation, minimizing the sum of squared residual errors:</p>
<p>$$
\sum_{i=1}^{n} \lVert \mathbf{r}_{r,i} - s \, R(\mathbf{r}_{l,i}) - \mathbf{r}_0 \rVert^2
$$</p>
<p>Prior methods either used iterative numerical procedures or selectively discarded constraints (e.g., Thompson&rsquo;s and Schut&rsquo;s three-point methods). Horn derives a direct solution that uses all available information from all points simultaneously.</p>
<h2 id="decoupling-translation-scale-and-rotation">Decoupling Translation, Scale, and Rotation</h2>
<p>Horn shows that the three components of the transformation can be solved sequentially.</p>
<p><strong>Translation</strong>: After centering both point sets at their centroids ($\bar{\mathbf{r}}_l$ and $\bar{\mathbf{r}}_r$), the optimal translation is:</p>
<p>$$
\mathbf{r}_0 = \bar{\mathbf{r}}_r - s \, R(\bar{\mathbf{r}}_l)
$$</p>
<p><strong>Scale</strong>: Horn derives three formulations (asymmetric left, asymmetric right, and symmetric). The symmetric version, which ensures the inverse transformation yields the reciprocal scale, is:</p>
<p>$$
s = \left( \frac{\sum_{i=1}^{n} \lVert \mathbf{r}&rsquo;_{r,i} \rVert^2}{\sum_{i=1}^{n} \lVert \mathbf{r}&rsquo;_{l,i} \rVert^2} \right)^{1/2}
$$</p>
<p>the ratio of root-mean-square deviations from the respective centroids.</p>
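<p>A small standard-library sketch of the symmetric scale follows. The helper names <code>centered</code> and <code>symmetric_scale</code> are hypothetical, and points are assumed to be plain 3-tuples. The formula needs no knowledge of the rotation, and swapping the two point sets returns the reciprocal scale.</p>

```python
import math

def centered(points):
    """Subtract the centroid from a list of 3D points given as tuples."""
    n = len(points)
    c = [sum(p[k] for p in points) / n for k in range(3)]
    return [tuple(p[k] - c[k] for k in range(3)) for p in points]

def symmetric_scale(left, right):
    """s = sqrt(sum |r'_r|^2 / sum |r'_l|^2), using centered coordinates."""
    num = sum(x * x for p in centered(right) for x in p)
    den = sum(x * x for p in centered(left) for x in p)
    return math.sqrt(num / den)
```

<p>For example, if the right-hand points are the left-hand points scaled by 2.5, rotated, and shifted, <code>symmetric_scale(left, right)</code> recovers 2.5 and <code>symmetric_scale(right, left)</code> recovers 0.4, up to rounding.</p>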
<p><strong>Rotation</strong>: After removing translation and scale, the remaining problem is to find the rotation $R$ that maximizes:</p>
<p>$$
\sum_{i=1}^{n} \mathbf{r}&rsquo;_{r,i} \cdot R(\mathbf{r}&rsquo;_{l,i})
$$</p>
<h2 id="the-quaternion-eigenvector-solution">The Quaternion Eigenvector Solution</h2>
<p>Horn represents rotation using unit quaternions $\dot{q} = q_0 + i q_x + j q_y + k q_z$ with $\lVert \dot{q} \rVert = 1$. A rotation acts on a vector (represented as a purely imaginary quaternion $\dot{r}$) via the composite product:</p>
<p>$$
\dot{r}&rsquo; = \dot{q} \, \dot{r} \, \dot{q}^*
$$</p>
<p>Using the $4 \times 4$ matrix representations of quaternion products, the objective function becomes a quadratic form:</p>
<p>$$
\dot{q}^T N \dot{q}
$$</p>
<p>where $N$ is a real symmetric $4 \times 4$ matrix whose elements are combinations of the sums of products $S_{xx}, S_{xy}, \ldots, S_{zz}$ from the $3 \times 3$ cross-covariance matrix $M = \sum_i \mathbf{r}&rsquo;_{l,i} \mathbf{r}&rsquo;^T_{r,i}$:</p>
<p>$$
N = \begin{bmatrix} (S_{xx} + S_{yy} + S_{zz}) &amp; S_{yz} - S_{zy} &amp; S_{zx} - S_{xz} &amp; S_{xy} - S_{yx} \\ S_{yz} - S_{zy} &amp; (S_{xx} - S_{yy} - S_{zz}) &amp; S_{xy} + S_{yx} &amp; S_{zx} + S_{xz} \\ S_{zx} - S_{xz} &amp; S_{xy} + S_{yx} &amp; (-S_{xx} + S_{yy} - S_{zz}) &amp; S_{yz} + S_{zy} \\ S_{xy} - S_{yx} &amp; S_{zx} + S_{xz} &amp; S_{yz} + S_{zy} &amp; (-S_{xx} - S_{yy} + S_{zz}) \end{bmatrix}
$$</p>
<p>The trace of $N$ is always zero. The unit quaternion maximizing $\dot{q}^T N \dot{q}$ is the eigenvector corresponding to the most positive eigenvalue of $N$.</p>
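<p>The eigenvector solution can be sketched in NumPy as below. The function name <code>horn_rotation</code> and the row-per-point layout are assumptions of this sketch, both inputs are assumed to be already centered (and rescaled if a scale is present), and the final conversion is the standard unit-quaternion to rotation-matrix formula.</p>

```python
import numpy as np

def horn_rotation(L, Rr):
    """Rotation matrix taking centered left points L to centered right points Rr, both (n, 3)."""
    M = L.T @ Rr   # M = sum_i r'_l r'_r^T, so M[0, 1] is S_xy, and so on
    (Sxx, Sxy, Sxz), (Syx, Syy, Syz), (Szx, Szy, Szz) = M
    N = np.array([
        [Sxx + Syy + Szz, Syz - Szy, Szx - Sxz, Sxy - Syx],
        [Syz - Szy, Sxx - Syy - Szz, Sxy + Syx, Szx + Sxz],
        [Szx - Sxz, Sxy + Syx, -Sxx + Syy - Szz, Syz + Szy],
        [Sxy - Syx, Szx + Sxz, Syz + Szy, -Sxx - Syy + Szz],
    ])
    # eigh sorts eigenvalues in ascending order, so the last column
    # belongs to the most positive eigenvalue.
    _, vecs = np.linalg.eigh(N)
    q0, qx, qy, qz = vecs[:, -1]
    # Unit quaternion to rotation matrix; q and -q give the same matrix.
    return np.array([
        [q0*q0 + qx*qx - qy*qy - qz*qz, 2*(qx*qy - q0*qz), 2*(qx*qz + q0*qy)],
        [2*(qy*qx + q0*qz), q0*q0 - qx*qx + qy*qy - qz*qz, 2*(qy*qz - q0*qx)],
        [2*(qz*qx - q0*qy), 2*(qz*qy + q0*qx), q0*q0 - qx*qx - qy*qy + qz*qz],
    ])
```

<p>Because a general symmetric eigensolver is used, the sketch sidesteps the quartic entirely; Horn&rsquo;s closed-form route through the characteristic polynomial is what the next section describes.</p>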
<h2 id="the-characteristic-polynomial">The Characteristic Polynomial</h2>
<p>The eigenvalues satisfy a quartic $\lambda^4 + c_3 \lambda^3 + c_2 \lambda^2 + c_1 \lambda + c_0 = 0$ where:</p>
<ul>
<li>$c_3 = 0$ (trace of $N$ is zero, so the four roots sum to zero)</li>
<li>$c_2 = -2 \operatorname{Tr}(M^T M)$ (always negative, guaranteeing both positive and negative roots)</li>
<li>$c_1 = -8 \det(M)$</li>
<li>$c_0 = \det(N)$</li>
</ul>
<p>When points are coplanar (including the common case of exactly three points), $\det(M) = 0$, so $c_1 = 0$ and the quartic reduces to a biquadratic solvable in closed form.</p>
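<p>The coefficient identities above can be checked numerically. The sketch below builds $N$ from $M$ in block form (trace in the corner, the antisymmetric part of $M$ along the first row and column, and $M + M^T - \operatorname{Tr}(M) I$ in the lower-right block), which gives the same entries as the matrix written out earlier. The helper names are hypothetical.</p>

```python
import numpy as np

def horn_N(M):
    """4x4 symmetric matrix N built from the 3x3 matrix M of sums of products."""
    delta = np.array([M[1, 2] - M[2, 1], M[2, 0] - M[0, 2], M[0, 1] - M[1, 0]])
    N = np.empty((4, 4))
    N[0, 0] = np.trace(M)
    N[0, 1:] = delta
    N[1:, 0] = delta
    N[1:, 1:] = M + M.T - np.trace(M) * np.eye(3)
    return N

def quartic_coefficients(M):
    """Return (c3, c2, c1, c0) of the monic characteristic polynomial of N."""
    roots = np.linalg.eigvalsh(horn_N(M))   # real eigenvalues of symmetric N
    coeffs = np.poly(roots)                 # [1, c3, c2, c1, c0]
    return coeffs[1:]
```

<p>Comparing the returned values against $0$, $-2 \operatorname{Tr}(M^T M)$, $-8 \det(M)$, and $\det(N)$ for a random $M$ is a quick sanity check on any hand-coded $N$.</p>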
<h2 id="coplanar-points-and-the-three-point-case">Coplanar Points and the Three-Point Case</h2>
<p>For coplanar measurements, the quartic simplifies to $\lambda^4 + c_2 \lambda^2 + c_0 = 0$, yielding:</p>
<p>$$
\lambda_m = \left[ \frac{1}{2} \left( (c_2^2 - 4c_0)^{1/2} - c_2 \right) \right]^{1/2}
$$</p>
<p>Horn also provides a geometric interpretation for the coplanar case: first rotate one plane into the other (about their line of intersection), then solve a 2D least-squares rotation within the shared plane.</p>
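<p>The closed form for $\lambda_m$ is short enough to write directly. This is a sketch with a hypothetical function name; it assumes the biquadratic has real roots, which holds in the coplanar case because the four eigenvalues of $N$ then come in $\pm$ pairs.</p>

```python
import math

def largest_root_biquadratic(c2, c0):
    """Largest root of lambda^4 + c2 * lambda^2 + c0 = 0, assuming real roots."""
    return math.sqrt(0.5 * (math.sqrt(c2 * c2 - 4.0 * c0) - c2))
```

<p>As a check, the polynomial with roots $\pm 1$ and $\pm 3$ is $\lambda^4 - 10 \lambda^2 + 9$, and the formula returns 3 for $c_2 = -10$, $c_0 = 9$.</p>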
<h2 id="comparison-with-the-kabsch-algorithm">Comparison with the Kabsch Algorithm</h2>
<p>Both methods solve the same underlying optimization problem but approach it differently:</p>
<table>
  <thead>
      <tr>
          <th>Aspect</th>
          <th>Kabsch (1976)</th>
          <th>Horn (1987)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Rotation representation</td>
          <td>Orthogonal matrix</td>
          <td>Unit quaternion</td>
      </tr>
      <tr>
          <td>Core computation</td>
          <td>SVD or eigendecomposition of $\tilde{R}R$ ($3 \times 3$)</td>
          <td>Eigenvector of $N$ ($4 \times 4$)</td>
      </tr>
      <tr>
          <td>Scale estimation</td>
          <td>Not addressed</td>
          <td>Three formulations (including symmetric)</td>
      </tr>
      <tr>
          <td>Constraint enforcement</td>
          <td>Lagrange multipliers</td>
          <td>Unit quaternion norm</td>
      </tr>
      <tr>
          <td>Symmetry guarantee</td>
          <td>Not addressed</td>
          <td>Proven for symmetric scale</td>
      </tr>
      <tr>
          <td>Degenerate cases</td>
          <td>Cross-product fallback</td>
          <td>Biquadratic closed form</td>
      </tr>
  </tbody>
</table>
<p>Horn emphasizes a symmetry property: the inverse transformation should yield exactly the inverse parameters. This holds automatically for the quaternion rotation but requires a specific (symmetric) choice of scale formula.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Horn, B. K. P. (1987). Closed-form solution of absolute orientation using unit quaternions. <em>Journal of the Optical Society of America A</em>, 4(4), 629-642. <a href="https://doi.org/10.1364/JOSAA.4.000629">https://doi.org/10.1364/JOSAA.4.000629</a></p>
<p><strong>Publication</strong>: Journal of the Optical Society of America A, 1987</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="/posts/kabsch-algorithm/">Kabsch Algorithm: NumPy, PyTorch, TensorFlow, and JAX</a> (tutorial with differentiable implementations of the related SVD-based method)</li>
<li><a href="/projects/kabsch-horn-cookbook/">Kabsch-Horn Cookbook</a> (a differentiable, gradient-safe implementation of Kabsch, Horn, and Umeyama alignment across NumPy, PyTorch, JAX, TensorFlow, and MLX)</li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{horn1987closed,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Closed-form solution of absolute orientation using unit quaternions}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Horn, Berthold K. P.}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Journal of the Optical Society of America A}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{4}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{4}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{629--642}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{1987}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Optica Publishing Group}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1364/josaa.4.000629}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Consistency Models: Fast One-Step Diffusion Generation</title><link>https://hunterheidenreich.com/notes/machine-learning/generative-models/consistency-models/</link><pubDate>Sun, 15 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/generative-models/consistency-models/</guid><description>Consistency models enable one-step generation by learning to map any point on a diffusion ODE trajectory to its origin, achieving FID 3.55 on CIFAR-10.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is a <strong>Method</strong> paper. It proposes consistency models, a new class of generative models designed for fast one-step (or few-step) generation. The models can be trained either by distilling pretrained diffusion models (consistency distillation) or as standalone generative models from scratch (consistency training). The paper provides theoretical analysis of both training modes and achieves FID 3.55 on CIFAR-10 for single-step non-adversarial generation (state of the art at the time of publication).</p>
<h2 id="the-slow-sampling-problem-in-diffusion">The Slow Sampling Problem in Diffusion</h2>
<p>Diffusion models produce high-quality samples but require iterating through many denoising steps (often tens to hundreds), making generation slow compared to GANs or VAEs. Previous approaches to speed up sampling include faster ODE/SDE solvers (DDIM, DPM-Solver) and progressive distillation. These either still require multiple steps or depend on a complex multi-stage distillation pipeline. The goal is a model that can generate high-quality samples in a single forward pass while optionally allowing more steps for better quality.</p>
<h2 id="core-innovation-the-self-consistency-property">Core Innovation: The Self-Consistency Property</h2>
<p>The key idea builds on the Probability Flow (PF) ODE from the score-based SDE framework. The PF ODE describes a deterministic trajectory that converts noise into data, governed by the learned score function. For the VE-SDE parameterization used by EDM (Karras et al., 2022), this takes the form:</p>
<p>$$\frac{d\mathbf{x}_t}{dt} = -t \, s_\phi(\mathbf{x}_t, t)$$</p>
<p>where $s_\phi$ is a pretrained score model. A <strong>consistency function</strong> $f(\mathbf{x}_t, t)$ maps any point on an ODE trajectory to the trajectory&rsquo;s origin $\mathbf{x}_\epsilon$. The defining property is self-consistency:</p>
<p>$$f(\mathbf{x}_t, t) = f(\mathbf{x}_{t&rsquo;}, t&rsquo;) \quad \text{for all } t, t&rsquo; \in [\epsilon, T]$$</p>
<p>for any points $\mathbf{x}_t$ and $\mathbf{x}_{t&rsquo;}$ on the same PF ODE trajectory.</p>
<p><strong>Parameterization.</strong> The model enforces the boundary condition $f(\mathbf{x}_\epsilon, \epsilon) = \mathbf{x}_\epsilon$ using skip connections:</p>
<p>$$f_\theta(\mathbf{x}, t) = c_{\text{skip}}(t) \, \mathbf{x} + c_{\text{out}}(t) \, F_\theta(\mathbf{x}, t)$$</p>
<p>where $c_{\text{skip}}(\epsilon) = 1$ and $c_{\text{out}}(\epsilon) = 0$, ensuring the boundary condition is satisfied by construction.</p>
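<p>One concrete choice of coefficients that satisfies these boundary conditions is sketched below. The specific formulas and the constants <code>SIGMA_DATA = 0.5</code> and <code>EPS = 0.002</code> are assumptions of this sketch (an EDM-style choice shifted by $\epsilon$); consult the paper for the exact values it uses. <code>F</code> stands in for the network and can be any callable.</p>

```python
import math

SIGMA_DATA = 0.5   # assumed data standard deviation
EPS = 0.002        # assumed smallest time epsilon

def c_skip(t, sigma_data=SIGMA_DATA, eps=EPS):
    """Skip coefficient; equals 1 at t = eps."""
    return sigma_data**2 / ((t - eps)**2 + sigma_data**2)

def c_out(t, sigma_data=SIGMA_DATA, eps=EPS):
    """Output coefficient; equals 0 at t = eps."""
    return sigma_data * (t - eps) / math.sqrt(sigma_data**2 + t**2)

def f_theta(F, x, t):
    """Consistency model output for a scalar x, given a network-like callable F(x, t)."""
    return c_skip(t) * x + c_out(t) * F(x, t)
```

<p>At $t = \epsilon$ the network output is multiplied by zero, so <code>f_theta(F, x, EPS)</code> returns <code>x</code> regardless of what <code>F</code> computes.</p>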
<p><strong>Consistency Distillation (CD).</strong> Given a pretrained diffusion model, CD trains a consistency model by enforcing self-consistency between adjacent timesteps:</p>
<p>$$\mathcal{L}_{\text{CD}}^N(\theta, \theta^-; \phi) = \mathbb{E}\left[\lambda(t_n) \, d\!\left(f_\theta(\mathbf{x}_{t_{n+1}}, t_{n+1}), \, f_{\theta^-}(\hat{\mathbf{x}}_{t_n}^\phi, t_n)\right)\right]$$</p>
<p>where $\hat{\mathbf{x}}_{t_n}^\phi$ is obtained by running one step of the ODE solver using the pretrained score model, $\theta^-$ is an exponential moving average (EMA) of $\theta$, and $d(\cdot, \cdot)$ is a distance metric. The use of a target network $\theta^-$ (updated via EMA) parallels techniques from deep Q-learning and momentum contrastive learning.</p>
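<p>The quantity inside the expectation can be illustrated on a one-dimensional toy problem. Everything here is an assumption made for illustration: the data distribution is taken to be a point mass at zero, so $p_t = \mathcal{N}(0, t^2)$ and the exact score is $-x / t^2$; the solver is a single Euler step; $d$ is squared $\ell_2$; and $\lambda(t_n) = 1$. In this toy setting the PF ODE trajectories are straight lines through the origin, so the exact consistency function is $x \, \epsilon / t$ and the pairwise distance vanishes.</p>

```python
EPS = 0.002  # assumed smallest time epsilon

def pf_ode_drift(x, t, score):
    """dx/dt = -t * score(x, t), the PF ODE from the text."""
    return -t * score(x, t)

def euler_step(x, t_from, t_to, score):
    """One Euler step of the PF ODE from t_from to t_to."""
    return x + (t_to - t_from) * pf_ode_drift(x, t_from, score)

def toy_score(x, t):
    """Exact score of N(0, t^2), standing in for the pretrained model."""
    return -x / (t * t)

def toy_consistency_fn(x, t):
    """Exact consistency function for the toy ODE, whose trajectories are x_t = c * t."""
    return x * EPS / t

def cd_pair_distance(x_next, t_next, t_cur, f_online, f_target, score):
    """Squared-l2 distance between f(x_{t_{n+1}}, t_{n+1}) and f^-(x_hat_{t_n}, t_n)."""
    x_hat = euler_step(x_next, t_next, t_cur, score)
    return (f_online(x_next, t_next) - f_target(x_hat, t_cur)) ** 2
```

<p>Passing a function that is not self-consistent, such as the identity in <code>x</code>, yields a strictly positive distance, which is the signal the distillation loss uses to train $\theta$.</p>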
<p><strong>Consistency Training (CT).</strong> CT eliminates the need for a pretrained diffusion model. It replaces the ODE solver step with a score estimate derived from the denoising score matching identity:</p>
<p>$$\nabla_{\mathbf{x}_t} \log p_t(\mathbf{x}_t) = \mathbb{E}\left[\frac{\mathbf{x} - \mathbf{x}_t}{t^2} \,\middle|\, \mathbf{x}_t\right]$$</p>
<p>Because this identity lets us estimate the score from noisy data alone (without a pretrained model), we can compute the ODE update directly from training samples. This allows training directly on data pairs $(\mathbf{x}, \mathbf{x} + t\mathbf{z})$ where $\mathbf{z} \sim \mathcal{N}(0, I)$.</p>
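<p>Why the data pairs share the same noise $\mathbf{z}$ can be seen with a few lines of scalar arithmetic. The sketch below (hypothetical function names, scalar data for simplicity) plugs the single-sample estimate $(\mathbf{x} - \mathbf{x}_t) / t^2$ into one Euler step of the PF ODE, starting from $\mathbf{x} + t_{n+1} \mathbf{z}$, and lands on $\mathbf{x} + t_n \mathbf{z}$.</p>

```python
def single_sample_score(x_t, x, t):
    """One-sample estimate of the score from the identity above: (x - x_t) / t^2."""
    return (x - x_t) / (t * t)

def ct_pair(x, z, t_cur, t_next):
    """The two noisy points CT compares: (x + t_next * z, x + t_cur * z)."""
    return x + t_next * z, x + t_cur * z

def euler_with_estimate(x, z, t_cur, t_next):
    """Euler step of dx/dt = -t * score from t_next down to t_cur, using the estimate."""
    x_next = x + t_next * z
    drift = -t_next * single_sample_score(x_next, x, t_next)   # simplifies to z
    return x_next + (t_cur - t_next) * drift
```

<p>The second element of <code>ct_pair</code> therefore plays the role that the solver output $\hat{\mathbf{x}}_{t_n}^\phi$ plays in distillation, without any pretrained model.</p>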
<p><strong>Theoretical guarantee.</strong> If CD achieves zero loss, the consistency model error is bounded by $O((\Delta t)^p)$ where $\Delta t$ is the maximum timestep gap and $p$ is the order of the ODE solver.</p>
<h2 id="experiments-and-benchmarks">Experiments and Benchmarks</h2>
<p><strong>Datasets:</strong> CIFAR-10 (32x32), ImageNet 64x64, LSUN Bedroom 256x256, LSUN Cat 256x256.</p>
<p><strong>Architecture:</strong> All models use the NCSN++/EDM architecture. CD distills from pretrained EDM models.</p>
<p><strong>Key results for consistency distillation (CD):</strong></p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>Steps</th>
          <th>FID</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>CIFAR-10</td>
          <td>1</td>
          <td>3.55</td>
      </tr>
      <tr>
          <td>CIFAR-10</td>
          <td>2</td>
          <td>2.93</td>
      </tr>
      <tr>
          <td>ImageNet 64x64</td>
          <td>1</td>
          <td>6.20</td>
      </tr>
      <tr>
          <td>ImageNet 64x64</td>
          <td>2</td>
          <td>4.70</td>
      </tr>
      <tr>
          <td>LSUN Bedroom 256</td>
          <td>1</td>
          <td>7.80</td>
      </tr>
      <tr>
          <td>LSUN Bedroom 256</td>
          <td>2</td>
          <td>5.22</td>
      </tr>
      <tr>
          <td>LSUN Cat 256</td>
          <td>1</td>
          <td>11.0</td>
      </tr>
      <tr>
          <td>LSUN Cat 256</td>
          <td>2</td>
          <td>8.84</td>
      </tr>
  </tbody>
</table>
<p>CD outperforms progressive distillation (PD) across all datasets and sampling steps, with the exception of single-step generation on Bedroom 256x256 where CD with $\ell_2$ slightly underperforms PD with $\ell_2$.</p>
<p><strong>Key results for consistency training (CT):</strong></p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>Steps</th>
          <th>FID</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>CIFAR-10</td>
          <td>1</td>
          <td>8.70</td>
      </tr>
      <tr>
          <td>CIFAR-10</td>
          <td>2</td>
          <td>5.83</td>
      </tr>
      <tr>
          <td>ImageNet 64x64</td>
          <td>1</td>
          <td>13.0</td>
      </tr>
      <tr>
          <td>ImageNet 64x64</td>
          <td>2</td>
          <td>11.1</td>
      </tr>
      <tr>
          <td>LSUN Bedroom 256</td>
          <td>1</td>
          <td>16.0</td>
      </tr>
      <tr>
          <td>LSUN Cat 256</td>
          <td>1</td>
          <td>20.7</td>
      </tr>
  </tbody>
</table>
<p>CT outperforms existing single-step non-adversarial models (VAEs, normalizing flows), e.g., improving over DC-VAE&rsquo;s FID of 17.90 on CIFAR-10. Samples from CT share structural similarity with EDM samples from the same initial noise, suggesting CT does not suffer from mode collapse.</p>
<p><strong>Zero-shot editing:</strong> Consistency models support colorization, super-resolution, inpainting, stroke-guided generation, interpolation, and denoising at test time without task-specific training, by modifying the multi-step sampling algorithm.</p>
<h2 id="findings-and-limitations">Findings and Limitations</h2>
<ul>
<li>Consistency distillation achieved state-of-the-art FID for one-step generation at the time of publication (3.55 on CIFAR-10, 6.20 on ImageNet 64x64).</li>
<li>Multi-step sampling provides a smooth quality-compute tradeoff: more steps yield better FID.</li>
<li>CT produces competitive results without any pretrained diffusion model, making consistency models a standalone generative model family.</li>
<li>The LPIPS distance metric $d(\cdot, \cdot)$ generally outperforms $\ell_1$ and $\ell_2$ for training consistency models.</li>
<li>At higher resolutions (LSUN 256x256), the gap between CD/CT and full EDM sampling widens.</li>
<li>CT currently underperforms CD, suggesting room for improvement in the standalone training paradigm.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Primary benchmark</td>
          <td>CIFAR-10</td>
          <td>32x32, 50K train</td>
          <td>FID on 50K samples</td>
      </tr>
      <tr>
          <td>Scaling benchmark</td>
          <td>ImageNet 64x64</td>
          <td>64x64, 1.28M</td>
          <td>Unconditional generation</td>
      </tr>
      <tr>
          <td>High-res benchmark</td>
          <td>LSUN Bedroom, Cat</td>
          <td>256x256</td>
          <td>Unconditional generation</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>ODE solver for CD</strong>: Euler and Heun (2nd order) solvers on the empirical PF ODE</li>
<li><strong>EMA for target network</strong>: Decay rate $\mu$ scheduled as a function of training step</li>
<li><strong>Schedule functions</strong>: $N$ (number of discretization steps) and $\mu$ (EMA rate) increase over training following specific schedules (see Appendix C of the paper)</li>
<li><strong>Distance metric</strong>: LPIPS performs best; $\ell_2$ and $\ell_1$ also evaluated</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li><strong>Architecture</strong>: NCSN++/EDM architecture from Karras et al. (2022)</li>
<li><strong>CD teacher</strong>: Pretrained EDM models</li>
<li><strong>Parameterization</strong>: Skip-connection formulation with $c_{\text{skip}}(t)$ and $c_{\text{out}}(t)$ from EDM</li>
</ul>
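<p>The skip-connection parameterization can be sketched in a few lines. This is a minimal illustration, not the official implementation: the functional forms of $c_{\text{skip}}$ and $c_{\text{out}}$ and the constants $\sigma_{\text{data}} = 0.5$ and $\epsilon = 0.002$ are assumptions here and should be checked against the paper, and <code>toy_network</code> is a hypothetical stand-in for the U-Net. The point it shows is that the boundary condition $f(x, \epsilon) = x$ holds for any network, because $c_{\text{skip}}(\epsilon) = 1$ and $c_{\text{out}}(\epsilon) = 0$.</p>

```python
import numpy as np

SIGMA_DATA = 0.5   # assumed data standard deviation (EDM-style preconditioning)
EPS = 0.002        # assumed smallest time; the boundary condition is enforced here


def c_skip(t):
    """Weight on the noisy input; equals 1 at t = EPS."""
    return SIGMA_DATA**2 / ((t - EPS) ** 2 + SIGMA_DATA**2)


def c_out(t):
    """Weight on the network output; equals 0 at t = EPS."""
    return SIGMA_DATA * (t - EPS) / np.sqrt(SIGMA_DATA**2 + t**2)


def consistency_fn(network, x, t):
    """f(x, t) = c_skip(t) * x + c_out(t) * F(x, t) for a free-form network F."""
    return c_skip(t) * x + c_out(t) * network(x, t)


def toy_network(x, t):
    """Hypothetical stand-in for the U-Net; any function of (x, t) works."""
    return np.tanh(x) * t


x = np.array([0.3, -1.2, 2.0])
at_boundary = consistency_fn(toy_network, x, EPS)   # should reproduce x
away_from_boundary = consistency_fn(toy_network, x, 1.0)
```

<p>Because the constraint is built into the parameterization, no extra loss term is needed to enforce it during training.</p>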
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Dataset</th>
          <th>CD 1-step</th>
          <th>CT 1-step</th>
          <th>EDM (full)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>FID</td>
          <td>CIFAR-10</td>
          <td>3.55</td>
          <td>8.70</td>
          <td>2.04</td>
      </tr>
      <tr>
          <td>FID</td>
          <td>ImageNet 64</td>
          <td>6.20</td>
          <td>13.0</td>
          <td>2.44</td>
      </tr>
      <tr>
          <td>FID</td>
          <td>LSUN Bedroom</td>
          <td>7.80</td>
          <td>16.0</td>
          <td>3.57</td>
      </tr>
      <tr>
          <td>FID</td>
          <td>LSUN Cat</td>
          <td>11.0</td>
          <td>20.7</td>
          <td>6.69</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li>Training details follow EDM conventions</li>
<li>CD and CT use the same batch sizes and learning rate schedules as EDM training</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/openai/consistency_models">openai/consistency_models</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official implementation with pretrained checkpoints</td>
      </tr>
  </tbody>
</table>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Song, Y., Dhariwal, P., Chen, M., &amp; Sutskever, I. (2023). Consistency Models. <em>ICML 2023</em>. <a href="https://arxiv.org/abs/2303.01469">https://arxiv.org/abs/2303.01469</a></p>
<p><strong>Publication</strong>: ICML 2023</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{song2023consistency,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>     = <span style="color:#e6db74">{Consistency Models}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>    = <span style="color:#e6db74">{Song, Yang and Dhariwal, Prafulla and Chen, Mark and Sutskever, Ilya}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{International Conference on Machine Learning}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>    = <span style="color:#e6db74">{202}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>      = <span style="color:#e6db74">{2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span>       = <span style="color:#e6db74">{https://arxiv.org/abs/2303.01469}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/openai/consistency_models">GitHub Repository</a></li>
<li><a href="/notes/machine-learning/generative-models/score-based-generative-modeling-sde/">Score-Based Generative Modeling with SDEs</a></li>
</ul>
]]></content:encoded></item><item><title>Score Matching and Denoising Autoencoders: A Connection</title><link>https://hunterheidenreich.com/notes/machine-learning/generative-models/score-matching-denoising-autoencoders/</link><pubDate>Sun, 21 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/generative-models/score-matching-denoising-autoencoders/</guid><description>Theoretical paper proving the equivalence between training Denoising Autoencoders and performing Score Matching on a Parzen density estimator.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is a <strong>Theory Paper</strong>.</p>
<p>Its primary contribution is a formal mathematical derivation connecting two previously distinct techniques: Score Matching (SM) and Denoising Autoencoders (DAE). It provides the &ldquo;why&rdquo; behind the empirical success of DAEs by grounding them in the probabilistic framework of energy-based models. It relies on proofs and equivalence relations (e.g., $J_{ESMq_{\sigma}} \sim J_{DSMq_{\sigma}}$).</p>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>The paper bridges a gap between two successful but disconnected approaches in unsupervised learning:</p>
<ol>
<li><strong>Denoising Autoencoders (DAE):</strong> Empirically successful for pre-training deep networks. They previously lacked a clear probabilistic interpretation.</li>
<li><strong>Score Matching (SM):</strong> A theoretically sound method for estimating unnormalized density models that avoids the partition function problem but requires computing expensive second derivatives.</li>
</ol>
<p>By connecting them, the authors aim to define a proper probabilistic model for DAEs (allowing sampling/ranking) and find a simpler way to apply score matching that avoids second derivatives.</p>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The core novelty is the <strong>Denoising Score Matching (DSM)</strong> framework and the proof of its equivalence to DAEs. Key contributions include:</p>
<ul>
<li><strong>Equivalence Proof:</strong> Showing that training a DAE with Gaussian noise is equivalent to matching the score of a model against a non-parametric Parzen density estimator of the data.</li>
<li><strong>Denoising Score Matching ($J_{DSM}$):</strong> A new objective that learns a score function by trying to denoise corrupted samples. This avoids the explicit second derivatives required by standard Implicit Score Matching ($J_{ISM}$).</li>
<li><strong>Explicit Energy Function:</strong> Deriving the specific energy function $E(x;\theta)$ that corresponds to the standard sigmoid DAE architecture.</li>
<li><strong>Justification for Tied Weights:</strong> Providing a theoretical justification for tying encoder and decoder weights, which arises naturally from differentiating the energy function.</li>
</ul>
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The validation in this theoretical paper is purely mathematical and focuses on formal proofs:</p>
<ul>
<li><strong>Derivation of Equivalence:</strong> The paper formally proves the chain of equivalences:
$$J_{ISMq_{\sigma}} \sim J_{ESMq_{\sigma}} \sim J_{DSMq_{\sigma}} \sim J_{DAE\sigma}$$
where $q_{\sigma}$ is the Parzen density estimate.</li>
<li><strong>Appendix Proof:</strong> A detailed proof is provided to show that Explicit Score Matching ($J_{ESM}$) on the Parzen density is equivalent to the proposed Denoising Score Matching ($J_{DSM}$) objective.</li>
</ul>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>
<ul>
<li><strong>Theoretical Unification:</strong> DAE training is formally equivalent to Score Matching on a smoothed data distribution ($q_{\sigma}$).</li>
<li><strong>New Training Objective:</strong> The $J_{DSM}$ objective offers a computationally efficient way to perform score matching (no Hessian required) by using a denoising objective.</li>
<li><strong>Probabilistic Interpretation:</strong> DAEs can now be understood as Energy-Based Models (EBMs), allowing for operations like sampling (via Hybrid Monte Carlo) and likelihood ranking, which were previously ill-defined for standard autoencoders.</li>
<li><strong>Regularization Insight:</strong> The smoothing kernel width $\sigma$ in the Parzen estimator corresponds to the noise level in the DAE. This suggests that DAEs are learning a regularized version of the score, which may explain their robustness.</li>
<li><strong>Connection to Regularized Score Matching:</strong> The paper notes that Kingma and LeCun (2010) independently proposed a regularized score matching criterion $J_{ISMreg}$ derived by approximating $J_{ISMq_{\sigma}}$. The four $q_{\sigma}$-based objectives in this work (including the DAE objective) can be seen as approximation-free forms of regularized score matching, with the additional advantage that $J_{DSMq_{\sigma}}$ does not require second derivatives.</li>
</ul>
<hr>
<h2 id="key-concepts-explained">Key Concepts Explained</h2>
<h3 id="1-score-and-score-matching">1. &ldquo;Score&rdquo; and &ldquo;Score Matching&rdquo;</h3>
<p><strong>What does &ldquo;score&rdquo; actually mean?</strong></p>
<p>In this paper (and probabilistic modeling generally), the <strong>score</strong> is the gradient of the log-density with respect to the <em>data vector</em> $x$.</p>
<ul>
<li><strong>Definition:</strong> $\psi(x) = \nabla_x \log p(x)$.</li>
<li><strong>Intuition:</strong> It is a vector field pointing in the direction in which the log-density increases most steeply. Crucially, calculating the score avoids the intractable partition function $Z$: writing $p(x) = \tilde{p}(x)/Z$ gives $\nabla_x \log p(x) = \nabla_x \log \tilde{p}(x) - \nabla_x \log Z = \nabla_x \log \tilde{p}(x)$, because the constant $Z$ vanishes upon differentiation.</li>
</ul>
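<p>A small numerical check makes the cancellation of $Z$ concrete. This is a sketch under simple assumptions: the density is an unnormalized standard Gaussian, <code>scale</code> is a hypothetical parameter playing the role of $1/Z$, and the gradient is estimated by central finite differences.</p>

```python
import numpy as np


def log_unnormalized(x, scale):
    """log of scale * exp(-||x||^2 / 2); `scale` plays the role of 1/Z."""
    return np.log(scale) - 0.5 * np.dot(x, x)


def numerical_score(log_density, x, h=1e-5):
    """Central finite-difference estimate of the gradient of log_density at x."""
    grad = np.zeros_like(x)
    for i in range(x.size):
        step = np.zeros_like(x)
        step[i] = h
        grad[i] = (log_density(x + step) - log_density(x - step)) / (2 * h)
    return grad


x = np.array([0.5, -1.0, 2.0])
score_a = numerical_score(lambda v: log_unnormalized(v, 1.0), x)
score_b = numerical_score(lambda v: log_unnormalized(v, 123.0), x)
```

<p>Both estimates agree with each other and with the analytic score $-x$, regardless of the scale factor.</p>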
<p><strong>What is Score Matching?</strong></p>
<p>Score Matching is a training objective for unnormalized models. It minimizes the expected squared Euclidean distance between the model&rsquo;s score $\psi(x;\theta)$ and the data&rsquo;s true score $\nabla_x \log q(x)$, with the expectation taken over the data distribution $q$.</p>
<h3 id="2-the-parzen-density-estimator">2. The Parzen Density Estimator</h3>
<p><strong>What is it?</strong></p>
<p>It is a non-parametric method for estimating a probability density function from finite data. It places a smooth kernel (here, a Gaussian) centered at every data point in the training set $D_n$.</p>
<ul>
<li><strong>Formula:</strong> $q_{\sigma}(\tilde{x}) = \frac{1}{n} \sum_{t=1}^n \mathcal{N}(\tilde{x}; x^{(t)}, \sigma^2 I)$.</li>
</ul>
<p><strong>Why smooth the data?</strong></p>
<ol>
<li>
<p><strong>To define the score:</strong> The empirical data distribution is a set of Dirac deltas (spikes). The gradient (score) of a Dirac delta is undefined. Smoothing creates a differentiable surface, allowing a valid target score $\nabla_{\tilde{x}} \log q_{\sigma}(\tilde{x})$ to be computed.</p>
</li>
<li>
<p><strong>To model corruption:</strong> Sampling from the Parzen estimator with Gaussian kernels is the same as picking a clean training point $x$ uniformly at random and adding Gaussian noise, which is the exact corruption procedure used in Denoising Autoencoders.</p>
</li>
</ol>
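<p>The target score of the Parzen estimator has a closed form: a responsibility-weighted average of the vectors pointing from $\tilde{x}$ to each training point, scaled by $1/\sigma^2$. The sketch below computes it for a toy dataset; the function names and the random data are illustrative assumptions, not part of the paper.</p>

```python
import numpy as np


def parzen_log_density(x_tilde, data, sigma):
    """log q_sigma(x_tilde) for an isotropic Gaussian kernel on each row of `data`."""
    n, d = data.shape
    sq_dist = np.sum((data - x_tilde) ** 2, axis=1)
    log_kernels = -sq_dist / (2 * sigma**2) - 0.5 * d * np.log(2 * np.pi * sigma**2)
    m = np.max(log_kernels)  # log-sum-exp shift for numerical stability
    return m + np.log(np.sum(np.exp(log_kernels - m))) - np.log(n)


def parzen_score(x_tilde, data, sigma):
    """Gradient of log q_sigma: responsibility-weighted pull toward each data point."""
    sq_dist = np.sum((data - x_tilde) ** 2, axis=1)
    logits = -sq_dist / (2 * sigma**2)
    weights = np.exp(logits - np.max(logits))
    weights /= np.sum(weights)
    return weights @ (data - x_tilde) / sigma**2


rng = np.random.default_rng(0)
data = rng.normal(size=(5, 2))      # five toy training points in 2-D
x_tilde = np.array([0.1, -0.4])
score = parzen_score(x_tilde, data, sigma=0.7)
```

<p>With a single training point the weights collapse to one and the score reduces to $\frac{1}{\sigma^2}(x - \tilde{x})$, the restoration vector that appears in Denoising Score Matching.</p>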
<h3 id="3-why-avoiding-second-derivatives-matters">3. Why avoiding second derivatives matters</h3>
<p>Standard <strong>Implicit Score Matching (ISM)</strong> eliminates the need for the unknown data score, but introduces a new cost: it requires computing the trace of the Hessian (the sum of second partial derivatives) of the log-density.</p>
<ul>
<li><strong>The Cost:</strong> For high-dimensional data (like images) and deep networks, computing second derivatives of the log-density is computationally expensive.</li>
<li>This paper shows that <strong>Denoising Score Matching (DSM)</strong> allows you to bypass Hessian computation entirely. By using the Parzen target, the objective simplifies to matching a first-order vector, making it scalable to deep neural networks.</li>
</ul>
<h3 id="4-the-equivalence-chain---why-each-step">4. The equivalence chain - why each step?</h3>
<p>The chain $J_{ISMq_{\sigma}} \sim J_{ESMq_{\sigma}} \sim J_{DSMq_{\sigma}} \sim J_{DAE\sigma}$ connects the concepts.</p>
<ul>
<li>
<p><strong>$J_{ISMq_{\sigma}} \sim J_{ESMq_{\sigma}}$ (Implicit $\to$ Explicit):</strong>
<strong>Why:</strong> Integration by parts, following Hyvärinen&rsquo;s original proof (2005). Expanding the square in the explicit objective leaves a cross term $\mathbb{E}_{q_{\sigma}}[\langle \psi, \nabla \log q_{\sigma} \rangle] = \int \langle \psi, \nabla q_{\sigma} \rangle$ that depends on the unknown data score. Integration by parts moves the derivative from the density $q_{\sigma}$ onto $\psi$, and the boundary term vanishes because $q_{\sigma}$ decays to zero at infinity (Hyvärinen&rsquo;s regularity condition). The two objectives therefore differ only by a constant independent of $\theta$, and the unknown data score is replaced by a computable term involving only the model&rsquo;s score and its Jacobian.</p>
</li>
<li>
<p><strong>$J_{ESMq_{\sigma}} \sim J_{DSMq_{\sigma}}$ (Explicit $\to$ Denoising):</strong>
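<strong>Why:</strong> The Parzen density is a mixture of Gaussians centered on the training points, so its score can be targeted one component at a time. When $x$ is perturbed to $\tilde{x}$ by Gaussian noise $\epsilon \sim \mathcal{N}(0, \sigma^2 I)$, the gradient of the conditional log-density $\log q_{\sigma}(\tilde{x}|x)$ with respect to $\tilde{x}$ points back to the clean point and is exactly $\frac{1}{\sigma^2}(x - \tilde{x})$. Minimizing the error against the true score of $q_{\sigma}$ is equivalent, up to a constant independent of $\theta$, to minimizing the error against this restoration vector.</p>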
</li>
<li>
<p><strong>$J_{DSMq_{\sigma}} \sim J_{DAE\sigma}$ (Denoising $\to$ Autoencoder):</strong>
<strong>Why:</strong> Algebraic substitution. If you define the model&rsquo;s score as the scaled displacement from the corrupted input to its reconstruction, $\psi(\tilde{x};\theta) = \frac{1}{\sigma^2}(x^r - \tilde{x})$, the score matching loss $J_{DSM}$ becomes proportional to the standard autoencoder squared loss $\lVert x^r - x \rVert^2$.</p>
</li>
</ul>
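<p>The last substitution can be written out in one line. As a sketch, assuming the model score is parameterized as $\psi(\tilde{x};\theta) = \frac{1}{\sigma^2}(x^r - \tilde{x})$ and the Gaussian target is $\frac{1}{\sigma^2}(x - \tilde{x})$, the corrupted input cancels inside the norm:</p>
<p>$$\frac{1}{2} \left\lVert \frac{1}{\sigma^2}(x^r - \tilde{x}) - \frac{1}{\sigma^2}(x - \tilde{x}) \right\rVert^2 = \frac{1}{2\sigma^4} \lVert x^r - x \rVert^2$$</p>
<p>Taking the expectation over $q_{\sigma}(x, \tilde{x})$ on both sides gives $J_{DSMq_{\sigma}} = \frac{1}{2\sigma^4} J_{DAE\sigma}$, so the two objectives share the same minimizers for any fixed $\sigma &gt; 0$.</p>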
<h3 id="5-energy-based-models-ebms-connection">5. Energy-Based Models (EBMs) connection</h3>
<p><strong>What is an EBM?</strong></p>
<p>An EBM defines a probability distribution via an energy function $E(x;\theta)$, where $p(x;\theta) \propto e^{-E(x;\theta)}$.</p>
<p><strong>Why standard autoencoders lack probabilistic interpretation:</strong></p>
<p>A standard autoencoder acts as a deterministic map $x \to x^r$, providing a reconstruction error. It lacks a normalization constant or a defined density function to support sampling or probability queries.</p>
<p><strong>What does this enable?</strong></p>
<p>By proving the equivalence, the DAE is formally defined as an EBM. This enables:</p>
<ol>
<li><strong>Sampling:</strong> Using MCMC methods (like Hybrid Monte Carlo) to generate new data from the DAE.</li>
<li><strong>Ranking:</strong> Calculating the energy of inputs to determine which are more &ldquo;likely&rdquo; or &ldquo;normal&rdquo; (useful for anomaly detection).</li>
</ol>
<h3 id="6-the-specific-energy-function-form">6. The specific energy function form</h3>
<p>The function is:</p>
<p>$$E(x; W, b, c) = - \frac{1}{\sigma^2} \left( \langle c, x \rangle - \frac{1}{2}\lVert x \rVert^2 + \sum_{j=1}^{d_h} \text{softplus}(\langle W_j, x \rangle + b_j) \right)$$</p>
<p><strong>Why does it have that specific form?</strong></p>
<p>It was derived via integration to ensure its derivative matches the DAE architecture. The authors worked backward from the DAE&rsquo;s reconstruction function (sigmoid + linear) to find the scalar field that generates it.</p>
<p><strong>Where does the quadratic term come from?</strong></p>
<p>The score (negative energy gradient) needs to look like $\psi(x) \propto c - x + W^T\text{sigmoid}(Wx + b)$.</p>
<ul>
<li>The term $-x$ in the score arises because $\nabla_x(-\frac{1}{2}\lVert x \rVert^2) = -x$. Including $-\frac{1}{2}\lVert x \rVert^2$ inside the parenthesized sum of the energy produces this linear term after differentiation.</li>
</ul>
<p><strong>How does differentiating it recover the DAE reconstruction?</strong></p>
<ul>
<li>$\nabla_x \sum_j \text{softplus}(\langle W_j, x \rangle + b_j) = W^T \text{sigmoid}(Wx + b) = W^T h$ (the encoder output $h$ passed through the tied decoder weights).</li>
<li>$\nabla_x \langle c, x \rangle = c$ (the decoder bias).</li>
<li>$\nabla_x (-\frac{1}{2}\lVert x \rVert^2) = -x$ (the input subtraction).</li>
<li>Result: $-\nabla_x E = \frac{1}{\sigma^2}(c + W^T h - x) = \frac{1}{\sigma^2}(x^r - x)$.</li>
</ul>
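<p>The gradient identity can be checked numerically. The sketch below evaluates the energy from the formula above on random parameters (the dimensions, seed, and helper names are arbitrary assumptions) and compares a finite-difference estimate of $-\nabla_x E$ with $\frac{1}{\sigma^2}(x^r - x)$.</p>

```python
import numpy as np


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


def softplus(a):
    return np.logaddexp(0.0, a)  # log(1 + exp(a)), computed stably


def energy(x, W, b, c, sigma):
    """E(x; W, b, c) from the formula above; rows of W are the vectors W_j."""
    inner = c @ x - 0.5 * x @ x + np.sum(softplus(W @ x + b))
    return -inner / sigma**2


def reconstruction(x, W, b, c):
    """Tied-weight DAE output x^r = W^T sigmoid(W x + b) + c."""
    return W.T @ sigmoid(W @ x + b) + c


rng = np.random.default_rng(0)
d, d_h, sigma = 4, 3, 0.5
W = rng.normal(size=(d_h, d))
b = rng.normal(size=d_h)
c = rng.normal(size=d)
x = rng.normal(size=d)

# Central finite differences for -grad E, one coordinate at a time.
h = 1e-5
neg_grad = np.array([
    -(energy(x + h * e, W, b, c, sigma) - energy(x - h * e, W, b, c, sigma)) / (2 * h)
    for e in np.eye(d)
])
expected = (reconstruction(x, W, b, c) - x) / sigma**2
```

<p>The two vectors agree up to finite-difference error, which is the sense in which the DAE reconstruction is the score of this energy-based model.</p>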
<h3 id="7-tied-weights-justification">7. &ldquo;Tied weights&rdquo; justification</h3>
<p><strong>What does it mean for weights to be &ldquo;tied&rdquo;?</strong></p>
<p>The decoder matrix is the transpose of the encoder matrix ($W^T$).</p>
<p><strong>Why is this theoretically justified?</strong></p>
<p>Because the reconstruction function is interpreted as the <strong>gradient</strong> of an energy function. A vector field can only be the gradient of a scalar field if its Jacobian is symmetric.</p>
<ul>
<li>In the DAE energy derivative, the softplus term contributes $W^T \text{sigmoid}(Wx + b)$, whose Jacobian $W^T \text{diag}(h \odot (1 - h)) W$ with $h = \text{sigmoid}(Wx + b)$ is symmetric. If the decoder used a separate matrix $U$, the resulting vector field would not be a valid gradient of any scalar energy function (unless $U = W^T$).</li>
<li>Therefore, for a DAE to correspond to a valid probabilistic Energy-Based Model, the weights <em>must</em> be tied.</li>
</ul>
<p><strong>The necessity of tied weights:</strong></p>
<p>Within this parametrization, tied weights are a mathematical necessity: a separate decoder matrix $U \neq W^T$ would make the reconstruction function an invalid gradient of any scalar energy, breaking the EBM correspondence.</p>
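<p>The symmetry argument is easy to see numerically. As an illustrative sketch, the code below forms the Jacobian of the decoder output with respect to the input, once with tied weights and once with a hypothetical untied decoder matrix <code>U</code> drawn at random, and measures how far each is from being symmetric. The $-x$ term of the score is omitted because its Jacobian, $-I$, is symmetric in both cases.</p>

```python
import numpy as np


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


def decoder_jacobian(x, W, b, U):
    """Jacobian with respect to x of the map U sigmoid(W x + b) + c (c drops out)."""
    s = sigmoid(W @ x + b)
    return U @ np.diag(s * (1.0 - s)) @ W


def asymmetry(J):
    """Largest absolute entry of J - J^T; zero for a symmetric matrix."""
    return np.max(np.abs(J - J.T))


rng = np.random.default_rng(0)
d, d_h = 4, 3
W = rng.normal(size=(d_h, d))
b = rng.normal(size=d_h)
U = rng.normal(size=(d, d_h))   # hypothetical untied decoder
x = rng.normal(size=d)

tied_gap = asymmetry(decoder_jacobian(x, W, b, W.T))
untied_gap = asymmetry(decoder_jacobian(x, W, b, U))
```

<p>The tied Jacobian is symmetric up to floating-point rounding, while a generic untied decoder gives a clearly asymmetric one, so no scalar energy can have it as a gradient.</p>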
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<p>Since this is a theoretical paper, the &ldquo;reproducibility&rdquo; lies in the mathematical formulations derived.</p>
<h3 id="data">Data</h3>
<ul>
<li><strong>Input Data ($D_n$):</strong> The theory assumes a set of training examples $D_n = \{x^{(1)}, \ldots, x^{(n)}\}$ drawn from an unknown true pdf $q(x)$.</li>
<li><strong>Parzen Density Estimate ($q_{\sigma}$):</strong> The theoretical targets are derived from a kernel-smoothed empirical distribution:
$$q_{\sigma}(\tilde{x}) = \frac{1}{n} \sum_{t=1}^n q_{\sigma}(\tilde{x}|x^{(t)})$$
where the kernel is an isotropic Gaussian of variance $\sigma^2$.</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<p><strong>1. Denoising Score Matching (DSM) Objective</strong></p>
<p>The paper proposes this objective as a tractable alternative to standard score matching. It minimizes the expected squared distance between the model score and the gradient of the log of the conditional noise density $q_{\sigma}(\tilde{x}|x)$:</p>
<p>$$J_{DSMq_{\sigma}}(\theta) = \mathbb{E}_{q_{\sigma}(x,\tilde{x})} \left[ \frac{1}{2} \left\lVert \psi(\tilde{x};\theta) - \frac{\partial \log q_{\sigma}(\tilde{x}|x)}{\partial \tilde{x}} \right\rVert^2 \right]$$</p>
<p>For Gaussian noise, the target score is simply $\frac{1}{\sigma^2}(x - \tilde{x})$.</p>
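<p>A Monte Carlo estimate of this objective needs only first-order quantities. The sketch below is a minimal illustration under stated assumptions: <code>dsm_loss</code>, <code>score_fn</code>, and <code>n_noise</code> are hypothetical names, and the dataset is a single point so that the exact score of $q_{\sigma}$ is known in closed form.</p>

```python
import numpy as np


def dsm_loss(score_fn, data, sigma, rng, n_noise=64):
    """Monte Carlo estimate of J_DSM for Gaussian corruption of each row of `data`."""
    total = 0.0
    for _ in range(n_noise):
        noise = rng.normal(scale=sigma, size=data.shape)
        x_tilde = data + noise
        target = (data - x_tilde) / sigma**2      # gradient of log q(x_tilde | x)
        residual = score_fn(x_tilde) - target
        total += 0.5 * np.mean(np.sum(residual**2, axis=1))
    return total / n_noise


rng = np.random.default_rng(0)
sigma = 0.3
x0 = np.array([[1.0, -2.0]])                      # a single training point

# With one training point, q_sigma is a single Gaussian centered at x0, so its
# exact score is (x0 - x_tilde) / sigma^2 and the DSM loss should vanish.
exact_score = lambda x_tilde: (x0 - x_tilde) / sigma**2
zero_score = lambda x_tilde: np.zeros_like(x_tilde)

loss_exact = dsm_loss(exact_score, x0, sigma, rng)
loss_zero = dsm_loss(zero_score, x0, sigma, rng)
```

<p>The exact score attains zero loss, while the all-zeros score is penalized by roughly $d / (2\sigma^2)$ in $d$ dimensions. No Hessian or trace term is evaluated anywhere.</p>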
<p><strong>2. Equivalence Chain</strong></p>
<p>The central result connects four objectives:</p>
<p>$$J_{ISMq_{\sigma}} \sim J_{ESMq_{\sigma}} \sim J_{DSMq_{\sigma}} \sim J_{DAE\sigma}$$</p>
<p>This implies that minimizing the DAE reconstruction error is equivalent to minimizing a score matching objective on the smoothed density $q_{\sigma}$.</p>
<h3 id="models">Models</h3>
<p><strong>1. The Denoising Autoencoder (DAE)</strong></p>
<ul>
<li><strong>Corruption:</strong> Additive isotropic Gaussian noise $\tilde{x} = x + \epsilon, \epsilon \sim \mathcal{N}(0, \sigma^2 I)$.</li>
<li><strong>Encoder:</strong> $h = \text{sigmoid}(W\tilde{x} + b)$.</li>
<li><strong>Decoder:</strong> $x^r = W^T h + c$ (Tied weights $W$).</li>
<li><strong>Loss:</strong> Squared reconstruction error $\lVert x^r - x \rVert^2$. (The equivalence with DSM introduces a $\frac{1}{2\sigma^4}$ scaling factor.)</li>
</ul>
<p><strong>2. The Corresponding Energy Function</strong></p>
<p>To make the DAE equivalent to Score Matching, the underlying Energy-Based Model $p(x;\theta) \propto e^{-E(x;\theta)}$ must have the following energy function:</p>
<p>$$E(x; W, b, c) = - \frac{1}{\sigma^2} \left( \langle c, x \rangle - \frac{1}{2}\lVert x \rVert^2 + \sum_{j=1}^{d_h} \text{softplus}(\langle W_j, x \rangle + b_j) \right)$$</p>
<p>Note the scaling by $1/\sigma^2$ and the quadratic term $\lVert x \rVert^2$.</p>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Metric:</strong> Theoretical Equivalence ($\sim$).</li>
<li><strong>Condition:</strong> The equivalence holds provided $\sigma &gt; 0$ and the density $q_{\sigma}$ is differentiable and vanishes at infinity (Hyvärinen&rsquo;s 2005 regularity condition for Implicit Score Matching).</li>
</ul>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Vincent, P. (2011). A Connection Between Score Matching and Denoising Autoencoders. <em>Neural Computation</em>, 23(7), 1661-1674. <a href="https://doi.org/10.1162/NECO_a_00142">https://doi.org/10.1162/NECO_a_00142</a></p>
<p><strong>Publication</strong>: Neural Computation 2011</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{vincentConnectionScoreMatching2011,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{A {{Connection Between Score Matching}} and {{Denoising Autoencoders}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Vincent, Pascal}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2011</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = jul,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Neural Computation}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{23}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{7}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{1661--1674}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1162/NECO_a_00142}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://www.iro.umontreal.ca/~vincentp/Publications/smdae_techreport.pdf">Official PDF</a></li>
</ul>
]]></content:encoded></item><item><title>Rectified Flow: Learning to Generate and Transfer Data</title><link>https://hunterheidenreich.com/notes/machine-learning/generative-models/rectified-flow/</link><pubDate>Sun, 21 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/generative-models/rectified-flow/</guid><description>A unified ODE-based framework for generative modeling and domain transfer that learns straight paths for fast 1-step generation.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is primarily a <strong>Method</strong> paper, with a significant <strong>Theory</strong> component.</p>
<ul>
<li><strong>Method</strong>: It proposes &ldquo;Rectified Flow,&rdquo; a novel generative framework that learns ordinary differential equations (ODEs) to transport distributions via straight paths. It introduces the &ldquo;Reflow&rdquo; algorithm to iteratively straighten these paths.</li>
<li><strong>Theory</strong>: It provides rigorous proofs connecting the method to Optimal Transport, showing that the rectification process yields a coupling with non-increasing convex transport costs and that recursive reflow reduces the curvature of trajectories.</li>
</ul>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>The work addresses two main challenges in unsupervised learning: generative modeling (generating data from noise) and domain transfer (mapping between two observed distributions).</p>
<ul>
<li><strong>Inefficiency of ODE/SDE Models</strong>: Continuous-time models (like Score-based Generative Models and DDPMs) require simulating diffusions over many steps, resulting in high computational costs during inference.</li>
<li><strong>Complexity of GANs</strong>: GANs provide fast one-step generation but suffer from training instability and mode collapse.</li>
<li><strong>Disconnection</strong>: Generative modeling and domain transfer are often treated as separate tasks requiring different techniques.</li>
</ul>
<p>The authors aim to unify these tasks into a single &ldquo;transport mapping&rdquo; problem while bridging the gap between high-quality continuous models and fast one-step models.</p>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The core novelty is the <strong>Rectified Flow</strong> framework and the <strong>Reflow</strong> procedure.</p>
<ul>
<li><strong>Straight-Line ODEs</strong>: Rectified Flow learns an ODE drift $v$ to follow the straight line connecting data pairs $(X_0, X_1)$, providing an alternative to diffusion models that rely on stochastic paths or specific forward processes. This is achieved via a simple least-squares optimization problem.</li>
<li><strong>Reflow (Iterative Straightening)</strong>: The authors introduce a recursive training procedure where a new flow is trained on the data pairs $(Z_0, Z_1)$ generated by the previous flow. Theoretical analysis shows this reduces the &ldquo;transport cost&rdquo; and straightens the trajectories, allowing for accurate 1-step simulation (effectively converting the ODE into a one-step model).</li>
<li><strong>Unified Framework</strong>: The method uses the exact same algorithm for generation ($\pi_0$ is Gaussian) and domain transfer ($\pi_0$ is a source dataset), removing the need for adversarial losses or cycle-consistency constraints.</li>
</ul>
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The authors validated the method across image generation, translation, and domain adaptation tasks.</p>
<ul>
<li><strong>Unconditional Image Generation</strong>:
<ul>
<li><strong>Dataset</strong>: CIFAR-10 ($32\times32$).</li>
<li><strong>Baselines</strong>: Compared against GANs (StyleGAN2, TDPM), Diffusion/SDE Models (VP SDE, sub-VP SDE, VE SDE), ODE methods (VP ODE, sub-VP ODE, VE ODE), and distilled methods (DDIM Distillation).</li>
<li><strong>High-Res</strong>: Validated on LSUN Bedroom/Church, CelebA-HQ, and AFHQ ($256\times256$).</li>
</ul>
</li>
<li><strong>Image-to-Image Translation</strong>:
<ul>
<li><strong>Datasets</strong>: AFHQ (Cat $\leftrightarrow$ Dog/Wild), MetFace $\leftrightarrow$ CelebA-HQ.</li>
<li><strong>Setup</strong>: Transferring styles while preserving semantic identity (using a classifier-based feature mapping metric).</li>
</ul>
</li>
<li><strong>Domain Adaptation</strong>:
<ul>
<li><strong>Datasets</strong>: DomainNet, Office-Home.</li>
<li><strong>Metric</strong>: Classification accuracy on the transferred testing data.</li>
</ul>
</li>
</ul>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>
<ul>
<li><strong>Superior 1-Step Generation</strong>: On CIFAR-10 with a single Euler step (as of ICLR 2023), the distilled 2-Rectified Flow achieved an FID of <strong>4.85</strong>, beating the best one-step U-Net model TDPM (FID 8.91, a truncated diffusion model using a GAN). The distilled 3-Rectified Flow reached a Recall of <strong>0.51</strong>, beating the GAN baseline StyleGAN2+ADA (Recall 0.49).</li>
<li><strong>Straightening Effect</strong>: The &ldquo;Reflow&rdquo; procedure was empirically shown to reduce both the straightness measure and transport costs, validating the theoretical claims. Straightness is measured as $S(Z) = \mathbb{E}\left[\int_0^1 \lVert \dot{Z}_t - (Z_1 - Z_0) \rVert^2 \, dt\right]$ (zero means perfectly straight); the transport cost is $\mathbb{E}[c(Z_1 - Z_0)]$ for a convex cost $c$, and the theory guarantees that Reflow does not increase it for any convex $c$.</li>
<li><strong>High-Quality Transfer</strong>: The model successfully performed image translation (e.g., Cat to Wild Animal) without paired data or cycle-consistency losses.</li>
<li><strong>Strong Full-Simulation Results</strong>: With RK45 adaptive ODE solving, 1-Rectified Flow achieves FID 2.58 and Recall 0.57 on CIFAR-10 (Table 1a), the best among ODE methods and comparable to fully simulated SDEs (VP SDE: FID 2.55).</li>
<li><strong>Fast Simulation</strong>: After reflow, the method tolerates extremely coarse time discretization (e.g., $N=1$) without significant quality loss, addressing the slow inference of standard ODE models.</li>
<li><strong>Domain Adaptation</strong>: On Office-Home, Rectified Flow achieves 69.2% accuracy, outperforming Deep CORAL (68.7%) and other baselines. On DomainNet, it achieves 41.4%, comparable to Deep CORAL (41.5%) and MLDG (41.2%).</li>
</ul>
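<p>The straightness measure $S(Z)$ can be estimated from discretized trajectories. The sketch below is an illustrative assumption about how one might compute it, using finite differences for $\dot{Z}_t$ and a Riemann sum for the integral; the array layout and the toy paths are not from the paper.</p>

```python
import numpy as np


def straightness(trajectory):
    """Discretized S(Z) for trajectories of shape (n_steps + 1, batch, dim) on [0, 1]."""
    n_steps = trajectory.shape[0] - 1
    dt = 1.0 / n_steps
    velocity = np.diff(trajectory, axis=0) / dt           # finite-difference dZ/dt
    chord = trajectory[-1] - trajectory[0]                # Z_1 - Z_0
    sq_dev = np.sum((velocity - chord) ** 2, axis=-1)     # shape (n_steps, batch)
    return np.mean(np.sum(sq_dev * dt, axis=0))           # integrate, then average


ts = np.linspace(0.0, 1.0, 101)[:, None, None]
z0 = np.array([[0.0, 0.0], [1.0, -1.0]])
z1 = np.array([[2.0, 1.0], [0.0, 3.0]])

straight = (1.0 - ts) * z0 + ts * z1                      # straight-line paths
bend = np.sin(np.pi * ts) * np.array([0.0, 1.0])          # bump that vanishes at t = 0, 1
curved = straight + bend

s_straight = straightness(straight)
s_curved = straightness(curved)
```

<p>Straight-line paths score essentially zero, while the curved paths with the same endpoints score close to $\pi^2 / 2$, the integral of the squared derivative of the added bump.</p>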
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The paper utilizes several standard computer vision benchmarks.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size/Resolution</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Generation</td>
          <td><strong>CIFAR-10</strong></td>
          <td>32x32</td>
          <td>Standard split</td>
      </tr>
      <tr>
          <td>Generation</td>
          <td><strong>LSUN</strong> (Bedroom, Church)</td>
          <td>256x256</td>
          <td>High-res evaluation</td>
      </tr>
      <tr>
          <td>Generation</td>
          <td><strong>CelebA-HQ</strong></td>
          <td>256x256</td>
          <td>High-res evaluation</td>
      </tr>
      <tr>
          <td>Gen/Transfer</td>
          <td><strong>AFHQ</strong> (Cat, Dog, Wild)</td>
          <td>512x512</td>
          <td>256x256 for generation, 512x512 for transfer</td>
      </tr>
      <tr>
          <td>Transfer</td>
          <td><strong>MetFace</strong></td>
          <td>1024x1024</td>
          <td>Resized to 512x512 for experiments</td>
      </tr>
      <tr>
          <td>Adaptation</td>
          <td><strong>DomainNet</strong></td>
          <td>Mixed</td>
          <td>345 categories, 6 domains</td>
      </tr>
      <tr>
          <td>Adaptation</td>
          <td><strong>Office-Home</strong></td>
          <td>Mixed</td>
          <td>65 categories, 4 domains</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>
<p><strong>Objective Function</strong>:
The drift $v(X_t, t)$ is trained by minimizing a least-squares regression objective:
$$\min_{v} \int_{0}^{1} \mathbb{E}\left[\lVert (X_1 - X_0) - v(X_t, t) \rVert^2\right] dt$$
where $X_t = tX_1 + (1-t)X_0$ is the linear interpolation.</p>
</li>
<li>
<p><strong>Reflow Procedure</strong>:
Iteratively updates the flow. Let $Z^k$ be the $k$-th rectified flow.</p>
<ol>
<li>Generate 4 million data pairs $(Z_0^k, Z_1^k)$ by simulating the current flow.</li>
<li>Fine-tune the $k$-rectified flow model for 300,000 steps on these pairs to obtain the $(k+1)$-rectified flow.</li>
</ol>
</li>
<li>
<p><strong>Distillation</strong>:
For 1-step distillation ($k=1$, where $k$ here counts the sampling steps of the distilled model), the L2 loss is replaced with the LPIPS perceptual similarity, which empirically yields better image quality. For multi-step distillation ($k \geq 2$), training samples $t$ from $\{0, 1/k, \ldots, (k-1)/k\}$ rather than from the full $[0, 1]$ interval.</p>
</li>
<li>
<p><strong>ODE Solver</strong>:</p>
<ul>
<li>Training: Analytical linear interpolation.</li>
<li>Inference: Euler method (constant step size $1/N$) or RK45 (adaptive).</li>
</ul>
</li>
</ul>
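<p>The objective, the Euler sampler, and the reflow pairing described above can be sketched on a one-dimensional toy problem. This is a minimal illustration and not the authors&rsquo; implementation: the function names (<code>interpolate</code>, <code>rectified_flow_loss</code>, <code>euler_sample</code>), the shifted-Gaussian coupling, and the constant drift are all assumptions chosen so that the result can be checked by hand.</p>

```python
import random

def interpolate(x0, x1, t):
    # X_t = t * X_1 + (1 - t) * X_0
    return t * x1 + (1.0 - t) * x0

def rectified_flow_loss(v, pairs, n_time_samples=8, seed=0):
    """Monte Carlo estimate of E[((X_1 - X_0) - v(X_t, t))^2] with t ~ U[0, 1]."""
    rng = random.Random(seed)
    total, count = 0.0, 0
    for x0, x1 in pairs:
        for _ in range(n_time_samples):
            t = rng.random()
            residual = (x1 - x0) - v(interpolate(x0, x1, t), t)
            total += residual ** 2
            count += 1
    return total / count

def euler_sample(v, z0, n_steps):
    """Integrate dZ/dt = v(Z, t) from t = 0 to t = 1 with step size 1 / N."""
    z, dt = z0, 1.0 / n_steps
    for i in range(n_steps):
        z = z + dt * v(z, i * dt)
    return z

def shift(x, t):
    # Hypothetical drift: a constant velocity of 2 (a perfectly straight flow)
    return 2.0

# Toy coupling (assumption): every target is its source shifted by 2,
# so the constant drift fits the regression target exactly.
rng = random.Random(1)
pairs = [(x0, x0 + 2.0) for x0 in (rng.gauss(0.0, 1.0) for _ in range(100))]

loss = rectified_flow_loss(shift, pairs)
one_step = euler_sample(shift, 0.5, n_steps=1)

# Reflow data: pair each start point with the endpoint of the simulated flow.
reflow_pairs = [(z0, euler_sample(shift, z0, n_steps=10)) for z0, _ in pairs]
```

<p>Because the toy flow is already straight, a single Euler step lands on the same endpoint as ten steps, which is the property that reflow and distillation aim to approach for learned flows.</p>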
<h3 id="models">Models</h3>
<ul>
<li>
<p><strong>Architecture</strong>:</p>
<ul>
<li>Uses the <strong>DDPM++ U-Net</strong> architecture (from Song et al., 2020) across experiments. Implementation is modified from the open-source code of Song et al.</li>
</ul>
</li>
<li>
<p><strong>Optimization</strong>:</p>
<ul>
<li><strong>Optimizer</strong>: Adam (CIFAR-10) or AdamW (Transfer/Adaptation).</li>
<li><strong>Hyperparameters</strong>:
<ul>
<li>LR: $2 \times 10^{-4}$ (CIFAR), Grid search for transfer.</li>
<li>EMA: 0.999999 (CIFAR), 0.9999 (Transfer).</li>
<li>Batch Size: 4 (Transfer), 16 (Domain Adaptation).</li>
<li>Dropout: 0.15 (CIFAR), 0.1 (Transfer).</li>
</ul>
</li>
</ul>
</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Value (CIFAR-10, N=1)</th>
          <th>Baseline (Best 1-step)</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>FID</strong></td>
          <td><strong>4.85</strong> (2-Rectified + Distill)</td>
          <td>8.91 (TDPM)</td>
          <td>Lower is better</td>
      </tr>
      <tr>
          <td><strong>Recall</strong></td>
          <td><strong>0.51</strong> (3-Rectified + Distill)</td>
          <td>0.49 (StyleGAN2+ADA)</td>
          <td>Higher is better</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>The paper does not specify GPU models or training times. The DDPM++ U-Net architecture used in the experiments typically requires multi-GPU setups for training on high-resolution datasets.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/gnobitab/RectifiedFlow">RectifiedFlow (GitHub)</a></td>
          <td>Code</td>
          <td>Unknown</td>
          <td>Official PyTorch implementation with CIFAR-10 and high-res training code, plus pre-trained checkpoints</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Liu, X., Gong, C., &amp; Liu, Q. (2023). Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow. <em>International Conference on Learning Representations (ICLR)</em>. <a href="https://openreview.net/forum?id=XVjTT1nw5z">https://openreview.net/forum?id=XVjTT1nw5z</a></p>
<p><strong>Publication</strong>: ICLR 2023</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{liuFlowStraightFast2023,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Flow {{Straight}} and {{Fast}}: {{Learning}} to {{Generate}} and {{Transfer Data}} with {{Rectified Flow}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{International Conference on Learning Representations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Liu, Xingchao and Gong, Chengyue and Liu, Qiang}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2023</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span> = <span style="color:#e6db74">{https://openreview.net/forum?id=XVjTT1nw5z}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/gnobitab/RectifiedFlow">Official Code Repository</a></li>
<li><a href="https://openreview.net/forum?id=XVjTT1nw5z">OpenReview Page</a></li>
</ul>
]]></content:encoded></item><item><title>Neural ODEs: Continuous-Depth Deep Learning Models</title><link>https://hunterheidenreich.com/notes/machine-learning/generative-models/neural-odes/</link><pubDate>Sun, 21 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/generative-models/neural-odes/</guid><description>Introduces ODE-Nets, a continuous-depth neural network model parameterized by ODEs, enabling constant memory backpropagation and adaptive computation.</description><content:encoded><![CDATA[<blockquote>
<p><strong>Key Prerequisites</strong>: For the initial value problem to have a unique solution, the neural network $f(h(t), t, \theta)$ parameterizing the dynamics must be <a href="https://en.wikipedia.org/wiki/Lipschitz_continuity">Lipschitz continuous</a> in $h$ and continuous in $t$. Under these conditions the <a href="https://en.wikipedia.org/wiki/Picard%E2%80%93Lindel%C3%B6f_theorem">Picard-Lindelöf theorem</a> applies, so trajectories cannot cross and the backward pass is well defined.</p></blockquote>
<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is primarily a <strong>Method</strong> paper, with a strong secondary <strong>Theory</strong> component.</p>
<ul>
<li><strong>Method</strong>: It proposes a novel family of deep neural network models where the derivative of the hidden state is parameterized by a neural network. It provides specific algorithms (Algorithm 1) for training these models scalably.</li>
<li><strong>Theory</strong>: It derives the adjoint sensitivity method for backpropagating through black-box ODE solvers and proves the &ldquo;Instantaneous Change of Variables&rdquo; theorem (Theorem 1) for continuous normalizing flows.</li>
</ul>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>The authors aim to address limitations in discrete deep learning architectures:</p>
<ul>
<li><strong>Discrete vs. Continuous</strong>: Existing models like Residual Networks build transformations by composing discrete steps, which can be seen as an Euler discretization of a continuous transformation. The authors investigate the limit as step sizes go to zero.</li>
<li><strong>Memory Efficiency</strong>: Backpropagating through deep discrete networks requires storing intermediate activations, leading to linear memory cost in terms of depth, which is a major bottleneck.</li>
<li><strong>Irregular Data</strong>: Recurrent Neural Networks (RNNs) struggle with data arriving at arbitrary times, typically requiring discretization into fixed bins.</li>
<li><strong>Normalizing Flow Costs</strong>: Standard normalizing flows have a bottleneck in computing the determinant of the Jacobian, which is computationally expensive ($O(D^3)$).</li>
</ul>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The core contribution is the <strong>Neural ODE</strong> formulation:
$$\frac{dh(t)}{dt} = f(h(t), t, \theta)$$
where the output is computed using a black-box differential equation solver.</p>
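<p>The link between residual networks and this ODE can be illustrated with a fixed-step Euler solver, where each step plays the role of one residual block. The sketch below is an assumption-laden toy: the dynamics $dh/dt = -h$ and the name <code>resnet_forward</code> are hypothetical and chosen only because the exact solution $h(1) = e^{-1}$ is known.</p>

```python
import math

def f(h, t):
    # Hypothetical dynamics with a known solution: dh/dt = -h
    return -h

def resnet_forward(h0, n_layers, t1=1.0):
    """Apply n_layers residual updates h + dt * f(h, t), i.e. Euler steps of size dt."""
    h, dt = h0, t1 / n_layers
    for k in range(n_layers):
        h = h + dt * f(h, k * dt)
    return h

exact = math.exp(-1.0)  # h(1) for h(0) = 1
# Error of the discrete network against the continuous limit, per depth
errors = {n: abs(resnet_forward(1.0, n) - exact) for n in (6, 60, 600)}
```

<p>As the number of layers grows and the step size shrinks, the discrete output approaches the ODE solution, which is the limit the Neural ODE formulation takes directly.</p>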
<p>Key technical innovations include:</p>
<ol>
<li><strong>Adjoint Sensitivity Method for Backprop</strong>: The authors treat the solver as a black box and compute gradients by solving a second, augmented ODE backwards in time. This allows for <strong>constant memory cost</strong> regardless of depth.</li>
<li><strong>Adaptive Computation</strong>: The model uses modern ODE solvers that adapt evaluation steps based on error tolerance, allowing the model to trade precision for speed explicitly.</li>
<li><strong>Continuous Normalizing Flows (CNF)</strong>: By moving to continuous time, the change of variables formula simplifies from a log-determinant (cubic cost) to a trace operation (linear cost), enabling scalable generative modeling.</li>
<li><strong>Latent ODEs</strong>: A generative time-series model that represents time-series as latent trajectories determined by a local initial state and global shared dynamics, handling irregular sampling naturally.</li>
</ol>
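<p>The precision-for-speed trade-off in adaptive computation can be sketched with a small adaptive solver that counts function evaluations (NFE). This is an illustrative Heun-Euler scheme written for this note, not the solver used in the paper; the function name <code>adaptive_heun</code>, the step-size controller constants, and the test dynamics $dy/dt = -y$ are assumptions.</p>

```python
import math

def adaptive_heun(f, y, t0, t1, tol):
    """Integrate dy/dt = f(t, y) with an embedded Euler/Heun pair; return (y, NFE)."""
    t, h, nfe = t0, (t1 - t0) / 10, 0
    while t1 - t > 1e-12:
        h = min(h, t1 - t)
        k1 = f(t, y)
        k2 = f(t + h, y + h * k1)
        nfe += 2
        err = abs(h * (k2 - k1) / 2)  # difference between the Euler and Heun updates
        if tol >= err:
            # Accept the step using the more accurate Heun update
            t, y = t + h, y + h * (k1 + k2) / 2
        # Shrink or grow the step based on the local error estimate
        h *= min(2.0, max(0.2, 0.9 * (tol / (err + 1e-16)) ** 0.5))
    return y, nfe

def decay(t, y):
    # Hypothetical dynamics: dy/dt = -y, so y(1) = exp(-1) for y(0) = 1
    return -y

y_loose, nfe_loose = adaptive_heun(decay, 1.0, 0.0, 1.0, tol=1e-3)
y_tight, nfe_tight = adaptive_heun(decay, 1.0, 0.0, 1.0, tol=1e-6)
```

<p>A looser tolerance finishes with far fewer evaluations at the cost of a larger final error, which is the knob the paper exposes at test time.</p>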
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The authors validated the method across three distinct domains:</p>
<ol>
<li><strong>Supervised Learning (MNIST)</strong>:
<ul>
<li>Compared <strong>ODE-Net</strong> against a standard <strong>ResNet</strong> and a Runge-Kutta network (<strong>RK-Net</strong>).</li>
<li>Measured test error, parameter count, and memory usage.</li>
<li>Analyzed the trade-off between numerical precision (tolerance) and speed (NFE).</li>
</ul>
</li>
<li><strong>Continuous Normalizing Flows (Generative)</strong>:
<ul>
<li>Compared CNF against standard Normalizing Flows (NF) on density matching and maximum likelihood estimation tasks using toy 2D datasets (Two Circles, Two Moons, and other target distributions).</li>
<li>Evaluated training loss (KL divergence) and maximum likelihood estimation.</li>
</ul>
</li>
<li><strong>Time-Series Modeling (Latent ODE)</strong>:
<ul>
<li>Tested on a dataset of bi-directional spirals with irregular timestamps and Gaussian noise.</li>
<li>Compared Latent ODEs against an RNN baseline on predictive RMSE. A second RNN variant with time-difference concatenation was also trained.</li>
</ul>
</li>
</ol>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>
<ul>
<li><strong>Efficiency</strong>: ODE-Nets achieved roughly equivalent accuracy to ResNets on MNIST (0.42% vs 0.41% error) but with <strong>constant memory cost</strong> ($O(1)$) compared to ResNet&rsquo;s linear cost ($O(L)$).</li>
<li><strong>Adaptive Depth</strong>: The number of function evaluations (NFE) in ODE-Nets increases with training epoch, suggesting the model adapts its complexity as it learns. The backward pass NFE is roughly half the forward pass NFE, indicating that the adjoint method is also more computationally efficient than direct backpropagation through the integrator.</li>
<li><strong>Generative Performance</strong>: Continuous Normalizing Flows (CNF) achieved lower KL divergence loss than standard Normalizing Flows (NF), trained with only 10,000 iterations (Adam) compared to 500,000 iterations (RMSprop) for NF. Note that the two models used different optimizers, so the comparison is not fully controlled. CNF can also expand capacity by increasing width ($M$) without architectural constraints.</li>
<li><strong>Irregular Time-Series</strong>: Latent ODEs significantly outperformed RNNs across all observation counts on irregular spiral data. The advantage is most pronounced with sparse observations (0.1642 vs 0.3937 RMSE at 30 obs), and the model learns interpretable latent trajectories that switch direction smoothly.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<ul>
<li><strong>MNIST</strong>: Standard handwritten digit dataset used for supervised learning benchmarks.</li>
<li><strong>Toy 2D Densities</strong>: &ldquo;Two Circles&rdquo; and &ldquo;Two Moons&rdquo; distributions used for visualizing normalizing flows.</li>
<li><strong>Bi-directional Spirals</strong>: A generated dataset of 1,000 2D spirals (half clockwise, half counter-clockwise). Each spiral is sampled at 100 equally-spaced timesteps with added Gaussian noise. For training, each spiral is then subsampled without replacement to $n \in \{30, 50, 100\}$ irregularly-spaced observations, simulating realistic missing data.</li>
</ul>
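<p>The spiral subsampling described above can be sketched as follows. The spiral parameterization (an Archimedean spiral with made-up radius constants), the noise level, and the helper names <code>make_spiral</code> and <code>subsample</code> are assumptions for illustration; only the 100 equally spaced timesteps and the subsampling to 30, 50, or 100 points without replacement come from the description.</p>

```python
import math
import random

def make_spiral(clockwise, n_points=100, noise_std=0.1, rng=None):
    """Sample a noisy 2D Archimedean spiral at equally spaced timesteps."""
    rng = rng or random.Random(0)
    direction = -1.0 if clockwise else 1.0
    points = []
    for i in range(n_points):
        theta = 6.0 * math.pi * i / (n_points - 1)
        radius = 0.5 + 0.3 * theta  # hypothetical spiral parameters
        x = radius * math.cos(direction * theta) + rng.gauss(0.0, noise_std)
        y = radius * math.sin(direction * theta) + rng.gauss(0.0, noise_std)
        points.append((i, x, y))  # (timestep index, x, y)
    return points

def subsample(points, n_obs, rng):
    """Keep n_obs observations chosen without replacement, in time order."""
    keep = sorted(rng.sample(range(len(points)), n_obs))
    return [points[i] for i in keep]

rng = random.Random(42)
spiral = make_spiral(clockwise=True, rng=rng)
observations = {n: subsample(spiral, n, rng) for n in (30, 50, 100)}
```

<p>Sorting the kept indices preserves temporal order, so the gaps between consecutive observations are irregular while the underlying time grid stays fixed.</p>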
<h3 id="algorithms">Algorithms</h3>
<p><strong>1. Adjoint Sensitivity Method (Backpropagation)</strong></p>
<p>To optimize the parameters of the ODE-Net, the authors use the adjoint sensitivity method to compute gradients. Standard backpropagation would require storing the activations at every step of the ODE solver, incurring a high memory cost that scales linearly with the number of steps.</p>
<p>Instead, this method treats the ODE solver as a &ldquo;black box&rdquo; and computes gradients by solving a second, <strong>augmented ODE</strong> backwards in time from the final time $t_1$ to the initial time $t_0$.</p>
<p>The augmented state contains three components that are solved simultaneously:</p>
<ol>
<li><strong>The State</strong>: The original hidden state $z(t)$, which is reconstructed backwards.</li>
<li><strong>The Adjoint</strong>: The sensitivity of the loss with respect to the state, $a(t) = \partial L / \partial z(t)$.</li>
<li><strong>The Gradient</strong>: The accumulating gradients with respect to parameters, $\partial L / \partial \theta$.</li>
</ol>
<p>The dynamics of this augmented system are defined as:
$$\frac{d}{dt}\begin{bmatrix} z(t) \\ a(t) \\ \partial L/\partial \theta \end{bmatrix} = \begin{bmatrix} f(z(t), t, \theta) \\ -a(t)^T \frac{\partial f}{\partial z} \\ -a(t)^T \frac{\partial f}{\partial \theta} \end{bmatrix}$$</p>
<p>Using this approach, the vector-Jacobian products (e.g., $a(t)^T \frac{\partial f}{\partial z}$) are evaluated efficiently using automatic differentiation.</p>
<blockquote>
<p><strong>Why:</strong> Reconstructing $z(t)$ backwards avoids storing the forward pass, enabling <strong>constant memory cost</strong> ($O(1)$) regardless of depth.</p>
<p><strong>Origin:</strong> Adapted from Pontryagin&rsquo;s maximum principle (1962) for optimal control.</p></blockquote>
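<p>The augmented system can be checked by hand on a scalar linear ODE, where the gradient has a closed form. The sketch below assumes the hypothetical dynamics $dz/dt = \theta z$ with loss $L = z(t_1)$, for which $\partial L / \partial \theta = (t_1 - t_0)\, z_0\, e^{\theta (t_1 - t_0)}$. The hand-written <code>rk4</code> helper stands in for a black-box solver and is not part of any library.</p>

```python
import math

def rk4(f, y, t0, t1, n_steps):
    """Integrate dy/dt = f(t, y) from t0 to t1 (t1 may be smaller than t0)."""
    h = (t1 - t0) / n_steps
    t = t0
    for _ in range(n_steps):
        k1 = f(t, y)
        k2 = f(t + h / 2, [yi + h / 2 * ki for yi, ki in zip(y, k1)])
        k3 = f(t + h / 2, [yi + h / 2 * ki for yi, ki in zip(y, k2)])
        k4 = f(t + h, [yi + h * ki for yi, ki in zip(y, k3)])
        y = [yi + h / 6 * (a + 2 * b + 2 * c + d)
             for yi, a, b, c, d in zip(y, k1, k2, k3, k4)]
        t += h
    return y

theta, z0, t0, t1 = 0.7, 1.5, 0.0, 1.0

# Forward pass: dz/dt = theta * z, loss L = z(t1). Only z(t1) is kept.
(z1,) = rk4(lambda t, y: [theta * y[0]], [z0], t0, t1, 100)

def augmented(t, s):
    """Dynamics of the augmented state [z, a, dL/dtheta]."""
    z, a, _ = s
    df_dz = theta      # partial f / partial z
    df_dtheta = z      # partial f / partial theta
    return [theta * z, -a * df_dz, -a * df_dtheta]

# Backward pass from t1 to t0: a(t1) = dL/dz(t1) = 1, gradient starts at 0.
z_rec, a0, grad_theta = rk4(augmented, [z1, 1.0, 0.0], t1, t0, 100)

exact = (t1 - t0) * z0 * math.exp(theta * (t1 - t0))
```

<p>The backward solve recovers the initial state and the parameter gradient together, without storing any intermediate forward states.</p>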
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> torch
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> torch.nn <span style="color:#66d9ef">as</span> nn
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> torchdiffeq <span style="color:#f92672">import</span> odeint_adjoint
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">class</span> <span style="color:#a6e22e">ODEFunc</span>(nn<span style="color:#f92672">.</span>Module):
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">__init__</span>(self, dim):
</span></span><span style="display:flex;"><span>        super(ODEFunc, self)<span style="color:#f92672">.</span><span style="color:#a6e22e">__init__</span>()
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>net <span style="color:#f92672">=</span> nn<span style="color:#f92672">.</span>Sequential(
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(dim, <span style="color:#ae81ff">50</span>),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Tanh(),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(<span style="color:#ae81ff">50</span>, dim),
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">forward</span>(self, t, y):
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Defines dy/dt = f(y, t)</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> self<span style="color:#f92672">.</span>net(y)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Usage with adjoint method for O(1) memory backprop</span>
</span></span><span style="display:flex;"><span>func <span style="color:#f92672">=</span> ODEFunc(dim<span style="color:#f92672">=</span><span style="color:#ae81ff">2</span>)
</span></span><span style="display:flex;"><span>y0 <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>tensor([[<span style="color:#ae81ff">1.</span>, <span style="color:#ae81ff">0.</span>]]) <span style="color:#75715e"># Initial state</span>
</span></span><span style="display:flex;"><span>t <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>linspace(<span style="color:#ae81ff">0.</span>, <span style="color:#ae81ff">1.</span>, <span style="color:#ae81ff">10</span>) <span style="color:#75715e"># Time points to solve for</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># &#39;odeint_adjoint&#39; automatically handles the augmented state backward pass</span>
</span></span><span style="display:flex;"><span>out <span style="color:#f92672">=</span> odeint_adjoint(func, y0, t, method<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;dopri5&#39;</span>)
</span></span></code></pre></div><p><strong>2. Instantaneous Change of Variables (CNF)</strong></p>
<p>For generative modeling, the authors introduce <strong>Continuous Normalizing Flows (CNF)</strong>. In discrete normalizing flows, the probability density of a transformed variable is calculated using the change of variables theorem, which requires computing the log-determinant of the Jacobian: $\log p(z_1) = \log p(z_0) - \log |\det \frac{\partial z_1}{\partial z_0}|$. This operation is computationally expensive ($O(D^3)$) and often restricts model architectures to ensure the Jacobian is easy to compute (e.g., triangular).</p>
<p>Moving to continuous time simplifies this requirement. The paper proves that if the transformation is defined by an ODE, the change in log-probability follows a differential equation determined by the <strong>trace</strong> of the Jacobian:
$$\frac{\partial \log p(z(t))}{\partial t} = -\text{tr}\left( \frac{\partial f}{\partial z(t)} \right)$$</p>
<p>The total change in log-density is obtained by integrating this value over time.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">get_trace</span>(y, f):
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    Computes trace of Jacobian df/dy.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    For high dimensions, use Hutchinson&#39;s trace estimator (approximate).
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    &#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>    tr <span style="color:#f92672">=</span> <span style="color:#ae81ff">0.</span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">for</span> i <span style="color:#f92672">in</span> range(y<span style="color:#f92672">.</span>size(<span style="color:#ae81ff">1</span>)):
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Gradients of f&#39;s i-th component w.r.t y&#39;s i-th component</span>
</span></span><span style="display:flex;"><span>        tr <span style="color:#f92672">+=</span> torch<span style="color:#f92672">.</span>autograd<span style="color:#f92672">.</span>grad(f[:, i]<span style="color:#f92672">.</span>sum(), y, create_graph<span style="color:#f92672">=</span><span style="color:#66d9ef">True</span>)[<span style="color:#ae81ff">0</span>][:, i]
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> tr
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># In the ODE function:</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># d(log_p)/dt = -trace(df/dy)</span>
</span></span></code></pre></div><blockquote>
<p><strong>Why:</strong> The trace operator has <strong>linear cost</strong> ($O(D)$), whereas the determinant has cubic cost ($O(D^3)$). This allows for unrestricted, &ldquo;wide&rdquo; architectures that are automatically bijective.</p>
<p><strong>Origin:</strong> This is the &ldquo;Instantaneous Change of Variables&rdquo; theorem (Theorem 1), derived in Appendix A of the paper.</p></blockquote>
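<p>The trace formula can be verified against the ordinary change of variables on a one-dimensional flow with a closed-form solution. The sketch assumes the hypothetical dynamics $dz/dt = \tanh(z)$, for which $\sinh(z(T)) = \sinh(z_0)\, e^{T}$ and therefore $\log \left| \frac{\partial z(T)}{\partial z_0} \right| = \log\cosh(z_0) + T - \log\cosh(z(T))$. The <code>rk4_step</code> helper is a hand-written stand-in for a solver.</p>

```python
import math

def rk4_step(f, y, h):
    """One classical Runge-Kutta step for an autonomous system dy/dt = f(y)."""
    k1 = f(y)
    k2 = f([yi + h / 2 * ki for yi, ki in zip(y, k1)])
    k3 = f([yi + h / 2 * ki for yi, ki in zip(y, k2)])
    k4 = f([yi + h * ki for yi, ki in zip(y, k3)])
    return [yi + h / 6 * (a + 2 * b + 2 * c + d)
            for yi, a, b, c, d in zip(y, k1, k2, k3, k4)]

def augmented_dynamics(state):
    z, _ = state
    dz_dt = math.tanh(z)              # f(z)
    trace = 1.0 - math.tanh(z) ** 2   # df/dz, the trace in one dimension
    return [dz_dt, -trace]            # [dz/dt, d(log p)/dt]

z0, T, n_steps = 0.8, 1.0, 200
state = [z0, 0.0]  # second entry accumulates log p(z(t)) - log p(z(0))
for _ in range(n_steps):
    state = rk4_step(augmented_dynamics, state, T / n_steps)
z1, delta_logp = state

# Discrete change of variables from the closed-form flow map
z1_exact = math.asinh(math.sinh(z0) * math.exp(T))
log_det = math.log(math.cosh(z0)) + T - math.log(math.cosh(z1_exact))
```

<p>Integrating the negative trace alongside the state gives the same log-density change as the log-determinant of the full flow map, here $\Delta \log p = -\log \left| \frac{\partial z(T)}{\partial z_0} \right|$.</p>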
<h3 id="models">Models</h3>
<p><strong>ODE-Net (MNIST Classification)</strong>:</p>
<ul>
<li><strong>Input</strong>: Downsamples input twice.</li>
<li><strong>Core</strong>: 6 standard residual blocks replaced by a single <strong>ODESolve</strong> module.</li>
<li><strong>Output</strong>: Global average pooling + Fully connected layer.</li>
<li><strong>Solver</strong>: Implicit Adams method.</li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#66d9ef">class</span> <span style="color:#a6e22e">ODEBlock</span>(nn<span style="color:#f92672">.</span>Module):
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">__init__</span>(self, odefunc):
</span></span><span style="display:flex;"><span>        super(ODEBlock, self)<span style="color:#f92672">.</span><span style="color:#a6e22e">__init__</span>()
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>odefunc <span style="color:#f92672">=</span> odefunc
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>integration_time <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>tensor([<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">1</span>])<span style="color:#f92672">.</span>float()
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">forward</span>(self, x):
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>integration_time <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>integration_time<span style="color:#f92672">.</span>type_as(x)
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Returns [x(t0), x(t1)]; we only want final state x(t1)</span>
</span></span><span style="display:flex;"><span>        out <span style="color:#f92672">=</span> odeint_adjoint(self<span style="color:#f92672">.</span>odefunc, x, self<span style="color:#f92672">.</span>integration_time)
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> out[<span style="color:#ae81ff">1</span>]
</span></span><span style="display:flex;"><span>
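</span></span><span style="display:flex;"><span><span style="color:#66d9ef">class</span> <span style="color:#a6e22e">ConvODEFunc</span>(nn<span style="color:#f92672">.</span>Module):
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Convolutional dynamics: keeps the (B, C, H, W) shape, unlike the Linear-based ODEFunc above</span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">__init__</span>(self, channels):
</span></span><span style="display:flex;"><span>        super(ConvODEFunc, self)<span style="color:#f92672">.</span><span style="color:#a6e22e">__init__</span>()
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>net <span style="color:#f92672">=</span> nn<span style="color:#f92672">.</span>Sequential(
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Conv2d(channels, channels, <span style="color:#ae81ff">3</span>, padding<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Tanh(),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Conv2d(channels, channels, <span style="color:#ae81ff">3</span>, padding<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>),
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">forward</span>(self, t, x):
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> self<span style="color:#f92672">.</span>net(x)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># ResNet-like architecture with ODE block</span>
</span></span><span style="display:flex;"><span>model <span style="color:#f92672">=</span> nn<span style="color:#f92672">.</span>Sequential(
</span></span><span style="display:flex;"><span>    nn<span style="color:#f92672">.</span>Conv2d(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">64</span>, <span style="color:#ae81ff">3</span>, <span style="color:#ae81ff">1</span>),
</span></span><span style="display:flex;"><span>    nn<span style="color:#f92672">.</span>ReLU(inplace<span style="color:#f92672">=</span><span style="color:#66d9ef">True</span>),
</span></span><span style="display:flex;"><span>    ODEBlock(ConvODEFunc(<span style="color:#ae81ff">64</span>)), <span style="color:#75715e"># Continuous-depth layer replacement</span>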
</span></span><span style="display:flex;"><span>    nn<span style="color:#f92672">.</span>BatchNorm2d(<span style="color:#ae81ff">64</span>),
</span></span><span style="display:flex;"><span>    nn<span style="color:#f92672">.</span>AdaptiveAvgPool2d((<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">1</span>)),
</span></span><span style="display:flex;"><span>    nn<span style="color:#f92672">.</span>Flatten(),
</span></span><span style="display:flex;"><span>    nn<span style="color:#f92672">.</span>Linear(<span style="color:#ae81ff">64</span>, <span style="color:#ae81ff">10</span>)
</span></span><span style="display:flex;"><span>)
</span></span></code></pre></div><p><strong>Latent ODE (Time-Series)</strong>:</p>
<ul>
<li><strong>Encoder</strong>: RNN with 25 hidden units processing data backwards to produce $q(z_0|x)$. It runs backwards so the final RNN state summarizes the entire sequence at $t_0$, parameterizing the initial latent state $z_0$ for the forward-running ODE.</li>
<li><strong>Latent Space</strong>: 4-dimensional latent state $z_0$.</li>
<li><strong>Dynamics ($f$)</strong>: Neural network with one hidden layer of 20 units.</li>
<li><strong>Decoder</strong>: Neural network with one hidden layer of 20 units computing $p(x_{t_i}|z_{t_i})$.</li>
<li><strong>Likelihood</strong>: Gaussian log-likelihood for the spiral reconstruction task. The paper also describes an optional Poisson process likelihood $\lambda(z(t))$ for event-time data (e.g., medical records), but this is not used in the spiral experiment.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Experiment</th>
          <th>Metric</th>
          <th>Baseline (ResNet/RNN)</th>
          <th>ODE Model</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MNIST</td>
          <td>Test Error</td>
          <td>0.41%</td>
          <td>0.42%</td>
      </tr>
      <tr>
          <td>MNIST</td>
          <td>Parameters</td>
          <td>0.60 M</td>
          <td>0.22 M</td>
      </tr>
      <tr>
          <td>MNIST</td>
          <td>Memory</td>
          <td>$O(L)$</td>
          <td>$O(1)$</td>
      </tr>
      <tr>
          <td>Spirals (30 obs)</td>
          <td>RMSE</td>
          <td>0.3937</td>
          <td><strong>0.1642</strong></td>
      </tr>
      <tr>
          <td>Spirals (50 obs)</td>
          <td>RMSE</td>
          <td>0.3202</td>
          <td><strong>0.1502</strong></td>
      </tr>
      <tr>
          <td>Spirals (100 obs)</td>
          <td>RMSE</td>
          <td>0.1813</td>
          <td><strong>0.1346</strong></td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Implementation</strong>: Hidden state dynamics evaluated on GPU using <strong>TensorFlow</strong>.</li>
<li><strong>Solvers</strong>: Fortran ODE solvers (LSODE, VODE) from <code>scipy.integrate</code> were used for the actual integration.</li>
<li><strong>Interface</strong>: Python&rsquo;s <code>autograd</code> framework bridged the TensorFlow dynamics and SciPy solvers.</li>
<li><strong>Note</strong>: While the original paper used TensorFlow/SciPy, the authors later released <code>torchdiffeq</code> (PyTorch), which has become the standard implementation for this architecture. The code samples above reflect this modern standard.</li>
</ul>
<h3 id="limitations">Limitations</h3>
<p>The paper identifies several practical limitations of Neural ODEs:</p>
<ul>
<li><strong>Minibatching</strong>: Batching requires concatenating states of each batch element into a combined ODE of dimension $D \times K$. Controlling error on all batch elements together can require more evaluations than solving each system individually, though in practice this overhead was not substantial.</li>
<li><strong>Tolerance tuning</strong>: Users must choose error tolerances for both the forward and reverse passes. The paper used 1.5e-8 for sequence modeling, 1e-3 for classification, and 1e-5 for density estimation.</li>
<li><strong>Backward trajectory reconstruction</strong>: Running the dynamics backwards to reconstruct the forward state trajectory can introduce extra numerical error if the reconstructed trajectory diverges from the original. Checkpointing (storing intermediate states) can address this, though the authors did not find it necessary in practice.</li>
<li><strong>Uniqueness requirements</strong>: The neural network $f$ must be Lipschitz continuous (e.g., using tanh or ReLU activations with finite weights) to guarantee a unique solution via Picard&rsquo;s existence theorem.</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/rtqichen/torchdiffeq">torchdiffeq</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official PyTorch implementation with GPU-based ODE solvers</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Chen, R. T. Q., Rubanova, Y., Bettencourt, J., &amp; Duvenaud, D. (2018). Neural ordinary differential equations. <em>Proceedings of the 32nd International Conference on Neural Information Processing Systems</em>, 6572-6583.</p>
<p><strong>Publication</strong>: NeurIPS 2018</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{chen2018neural,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Neural ordinary differential equations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Chen, Ricky T. Q. and Rubanova, Yulia and Bettencourt, Jesse and Duvenaud, David}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Proceedings of the 32nd International Conference on Neural Information Processing Systems}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{6572--6583}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2018}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/rtqichen/torchdiffeq">Official PyTorch Implementation</a></li>
</ul>
]]></content:encoded></item><item><title>Flow Matching for Generative Modeling: Scalable CNFs</title><link>https://hunterheidenreich.com/notes/machine-learning/generative-models/flow-matching-for-generative-modeling/</link><pubDate>Sun, 21 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/generative-models/flow-matching-for-generative-modeling/</guid><description>A simulation-free framework for training Continuous Normalizing Flows using Conditional Flow Matching and Optimal Transport paths.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is primarily a <strong>Method</strong> paper, as it introduces &ldquo;Flow Matching&rdquo; (FM), a novel simulation-free paradigm for training Continuous Normalizing Flows (CNFs) at scale. It is supported by a strong <strong>Theory</strong> basis, providing formal theorems that allow the intractable marginal vector field regression to be solved via a tractable conditional objective. It also touches on <strong>Systematization</strong> by showing that existing diffusion paths are specific instances of the proposed Gaussian probability path framework.</p>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>The paper aims to overcome the scaling limitations of Continuous Normalizing Flows (CNFs).</p>
<ul>
<li><strong>Problem</strong>: Standard Maximum Likelihood training for CNFs requires expensive numerical ODE simulations during training, which scales poorly. Existing simulation-free methods often involve intractable integrals or result in biased gradients.</li>
<li><strong>Gap</strong>: Diffusion models scale well, yet they are restricted to specific, curved probability paths (e.g., VP, VE) that can result in slow sampling and long training times.</li>
<li><strong>Goal</strong>: To develop an efficient, simulation-free training method for CNFs that supports arbitrary probability paths, specifically allowing for straighter, more efficient trajectories like those from Optimal Transport.</li>
</ul>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The core novelty is <strong>Flow Matching (FM)</strong> and specifically the <strong>Conditional Flow Matching (CFM)</strong> objective.</p>
<ul>
<li><strong>Direct Vector Field Regression</strong>: The model regresses a target vector field $u_t$ that generates a desired probability path $p_t$.</li>
<li><strong>Conditional Flow Matching (CFM)</strong>: The authors prove that regressing the vector field of <em>conditional</em> paths (e.g., $p_t(x|x_1)$ given a single data point) yields the same gradients as regressing the intractable marginal vector field. This bypasses the need to know the marginal score or vector field.</li>
<li><strong>Optimal Transport Paths</strong>: The framework enables the use of <strong>Optimal Transport (OT)</strong> displacement interpolation for probability paths. OT paths are straight lines with constant speed, leading to faster training and easier sampling.</li>
</ul>
<p><strong>Concurrent work note</strong>: Rectified Flow (Liu et al., 2023) and Stochastic Interpolants (Albergo &amp; Vanden-Eijnden, 2023) were published concurrently at ICLR 2023 with structurally similar contributions under different names. All three independently propose simulation-free training of continuous flows via direct vector field regression; the differences lie in the specific interpolation schemes, theoretical framing, and experimental focus.</p>
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<ul>
<li><strong>Domains</strong>: 2D Checkerboard data, CIFAR-10, and ImageNet at resolutions $32 \times 32$, $64 \times 64$, and $128 \times 128$.</li>
<li><strong>Task</strong>: Unconditional generative modeling (density estimation and sample quality) and conditional super-resolution ($64 \times 64 \to 256 \times 256$).</li>
<li><strong>Baselines</strong>: Compared against Diffusion-based methods on the same architecture (U-Net): DDPM, Score Matching (SM), and ScoreFlow.</li>
<li><strong>Ablations</strong>: Specifically compared <strong>FM with Diffusion paths</strong> vs. <strong>FM with Optimal Transport (OT) paths</strong> to isolate the benefit of the training objective vs. the path choice.</li>
</ul>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>
<ul>
<li><strong>Outperforms diffusion baselines</strong>: FM-OT consistently outperforms all diffusion-based methods (DDPM, Score Matching, ScoreFlow) in both Likelihood (NLL) and Sample Quality (FID) across CIFAR-10 and ImageNet, using the same U-Net architecture and training budget. Selected rows from Table 1 (NLL in bits per dimension, BPD; lower is better for all three metrics; &ldquo;FM w/ OT&rdquo; and &ldquo;FM w/ Diffusion&rdquo; refer to FM trained with OT paths and Diffusion paths respectively):</li>
</ul>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>Method</th>
          <th>NLL (BPD) ↓</th>
          <th>FID ↓</th>
          <th>NFE ↓</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>CIFAR-10</td>
          <td>DDPM</td>
          <td>3.12</td>
          <td>7.48</td>
          <td>274</td>
      </tr>
      <tr>
          <td>CIFAR-10</td>
          <td>FM w/ OT</td>
          <td><strong>2.99</strong></td>
          <td><strong>6.35</strong></td>
          <td><strong>142</strong></td>
      </tr>
      <tr>
          <td>ImageNet 64×64</td>
          <td>ScoreFlow</td>
          <td>3.36</td>
          <td>24.95</td>
          <td>601</td>
      </tr>
      <tr>
          <td>ImageNet 64×64</td>
          <td>FM w/ OT</td>
          <td><strong>3.31</strong></td>
          <td><strong>14.45</strong></td>
          <td><strong>138</strong></td>
      </tr>
  </tbody>
</table>
<ul>
<li><strong>Training stability</strong>: FM with diffusion paths (FM w/ Diffusion) is itself a more stable alternative to diffusion training than DDPM and Score Matching, as shown by training curves in the paper (Figure 5), even before switching to OT paths. The OT path then provides further gains.</li>
<li><strong>Sampling speed</strong>: The straight trajectories of OT paths allow accurate sampling with significantly fewer function evaluations (NFE) compared to diffusion paths.</li>
<li><strong>Generality</strong>: Diffusion paths are specific instances of the Gaussian probability paths supported by FM. OT paths are another instance within the same framework, one that yields straighter trajectories and cheaper sampling.</li>
<li><strong>Downstream adoption</strong>: Flow matching has been adopted beyond image generation. <a href="/notes/biology/computational-biology/dynamicflow/">DynamicFlow</a> uses it as the generative backbone for simultaneously generating ligand molecules and transforming protein pockets, extending flow matching to structure-based drug design.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<ul>
<li><strong>Datasets</strong>: CIFAR-10, ImageNet ($32 \times 32$, $64 \times 64$, $128 \times 128$).</li>
<li><strong>Preprocessing</strong>:
<ul>
<li>Images are center-cropped and resized.</li>
<li>For $32 \times 32$ and $64 \times 64$, the preprocessing follows Chrabaszcz et al. (2017).</li>
<li>Data is transformed via $\varphi(y) = 2^7(y+1)$ mapping $[-1, 1]$ pixel values to $[0, 256]$ for BPD computation.</li>
</ul>
</li>
</ul>
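<p>The effect of the $\varphi$ rescaling on the reported numbers can be illustrated with a few lines of arithmetic. This is a minimal sketch, not the paper&rsquo;s evaluation code: it assumes the model log-density is given in nats over pixels scaled to $[-1, 1]$, it ignores dequantization, and the function name <code>bits_per_dim</code> is hypothetical.</p>

```python
import math

def bits_per_dim(log_density_nats, dim, scale=128.0):
    """Convert a log-density on [-1, 1]-scaled pixels to BPD on the [0, 256] scale.

    phi(y) = scale * (y + 1) stretches every coordinate by `scale`, so the
    density in the stretched space is divided by scale**dim.
    """
    log_density_rescaled = log_density_nats - dim * math.log(scale)
    return -log_density_rescaled / (dim * math.log(2.0))

dim = 3 * 32 * 32  # dimensionality of a 32x32 RGB image
print(bits_per_dim(0.0, dim))
```

<p>Because $\varphi$ stretches each of the $d$ coordinates by $2^7$, the conversion adds 7 bits per dimension to the value computed on the $[-1, 1]$ scale.</p>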
<h3 id="algorithms">Algorithms</h3>
<p><strong>1. Conditional Flow Matching (CFM) Objective</strong></p>
<p>The practical training objective used is the CFM loss, which bypasses intractable marginalization:</p>
<p>$$\mathcal{L}_{CFM}(\theta) = \mathbb{E}_{t, q(x_1), p(x_0)} | v_t(\psi_t(x_0)) - u_t(\psi_t(x_0) | x_1) |^2$$</p>
<p>Where $t \sim \mathcal{U}[0,1]$, $x_1 \sim q(x_1)$ (data), and $x_0 \sim p(x_0)$ (noise).</p>
<p><strong>2. Optimal Transport (OT) Probability Path</strong></p>
<p>The authors recommend the OT path for efficiency.</p>
<ul>
<li><strong>Mean/Std Schedule</strong>: $\mu_t(x) = t x_1$ and $\sigma_t(x) = 1 - (1 - \sigma_{min})t$.</li>
<li><strong>Conditional Flow Map</strong>: $\psi_t(x) = (1 - (1 - \sigma_{min})t)x + t x_1$.</li>
<li><strong>Target Vector Field</strong>: The closed-form regression target for OT is:
$$u_t(x|x_1) = \frac{x_1 - (1 - \sigma_{min})x}{1 - (1 - \sigma_{min})t}$$</li>
</ul>
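<p>The CFM objective with the OT path can be sketched in a few lines. The example below is a scalar, standard-library illustration under stated assumptions: the names <code>ot_path_sample</code> and <code>cfm_loss</code>, the two-point toy dataset, and the zero-valued placeholder for the network are all hypothetical and are not taken from the paper.</p>

```python
import random

def ot_path_sample(x0, x1, t, sigma_min=1e-5):
    """Return (x_t, target) for one scalar noise/data pair on the OT path."""
    x_t = (1.0 - (1.0 - sigma_min) * t) * x0 + t * x1   # psi_t(x0)
    target = x1 - (1.0 - sigma_min) * x0                # u_t(psi_t(x0) | x1)
    return x_t, target

def cfm_loss(v_theta, data, n_samples=1000, sigma_min=1e-5, seed=0):
    """Monte Carlo estimate of the CFM loss for a callable v_theta(t, x)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        t = rng.random()            # t ~ U[0, 1)
        x1 = rng.choice(data)       # x1 ~ q (empirical data)
        x0 = rng.gauss(0.0, 1.0)    # x0 ~ N(0, 1)
        x_t, target = ot_path_sample(x0, x1, t, sigma_min)
        total += (v_theta(t, x_t) - target) ** 2
    return total / n_samples

toy_data = [-2.0, 2.0]                       # hypothetical 1D "dataset"
zero_field = lambda t, x: 0.0                # placeholder for a network
print(cfm_loss(zero_field, toy_data))
```

<p>Evaluating the closed-form target at $x = \psi_t(x_0)$ cancels the denominator, so the regression target reduces to $x_1 - (1 - \sigma_{min})x_0$, which is constant in $t$. This is why the OT trajectories are straight lines.</p>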
<p><strong>3. Sampling</strong></p>
<p>Sampling is performed by solving the ODE $\frac{d}{dt}\phi_t(x) = v_t(\phi_t(x))$ from $t=0$ to $t=1$ using the learned vector field $v_t$.</p>
<ul>
<li><strong>Solver</strong>: <code>dopri5</code> (adaptive) is used for robust evaluation. Fixed-step solvers (Euler, Midpoint) are used for low-NFE efficiency tests.</li>
</ul>
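<p>A fixed-step Euler sampler is short enough to sketch directly. In the example below the closed-form conditional OT field for a single target $x_1$ stands in for a trained network (an assumption made for illustration; real sampling integrates the learned marginal field $v_t$), and the helper names are hypothetical.</p>

```python
def euler_sample(v, x0, n_steps=10):
    """Integrate dx/dt = v(t, x) from t=0 to t=1 with fixed-step Euler."""
    x, h = x0, 1.0 / n_steps
    for k in range(n_steps):
        x = x + h * v(k * h, x)
    return x

def make_conditional_ot_field(x1, sigma_min=1e-5):
    """Closed-form OT field u_t(x | x1); stands in for a trained v_t."""
    def field(t, x):
        return (x1 - (1.0 - sigma_min) * x) / (1.0 - (1.0 - sigma_min) * t)
    return field

x0 = 0.7                                   # a fixed "noise" draw
field = make_conditional_ot_field(x1=2.0)
print(euler_sample(field, x0, n_steps=4))  # close to sigma_min * x0 + x1
```

<p>Along a conditional OT trajectory the velocity is constant, so Euler steps incur no discretization error in this toy case, regardless of the step count. A learned marginal field is only approximately straight, which is why NFE still matters in practice.</p>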
<h3 id="models">Models</h3>
<ul>
<li><strong>Architecture</strong>: U-Net architecture from Dhariwal &amp; Nichol (2021) is used for all image experiments.</li>
<li><strong>Toy Data</strong>: 5-layer MLP with 512 neurons per layer.</li>
<li><strong>Hyperparameters</strong>:
<ul>
<li>Optimizer: Adam ($\beta_1=0.9, \beta_2=0.999$, weight decay=0.0).</li>
<li>Learning Rate: Polynomial decay or constant (see Table 3 in paper).</li>
<li>$\sigma_{min}$: Set to a small value (e.g., $10^{-5}$).</li>
</ul>
</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Metrics</strong>:
<ul>
<li><strong>NLL (BPD)</strong>: Computed using the continuous change of variables formula, with the divergence estimated via the Hutchinson trace estimator to avoid the cost of computing the exact Jacobian trace (roughly $O(d^2)$ with automatic differentiation).</li>
<li><strong>FID</strong>: Frechet Inception Distance for sample quality.</li>
<li><strong>NFE</strong>: Number of Function Evaluations required by the solver.</li>
</ul>
</li>
<li><strong>Likelihood Computation</strong>: Requires solving an augmented ODE to track the log-density change:
$$\frac{d}{dt} \begin{bmatrix} \phi_t(x) \\ f(t) \end{bmatrix} = \begin{bmatrix} v_t(\phi_t(x)) \\ -\text{div}(v_t(\phi_t(x))) \end{bmatrix}$$</li>
</ul>
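<p>The Hutchinson estimator replaces the exact divergence with an average of $\epsilon^\top J \epsilon$ over random probe vectors $\epsilon$. The sketch below is an illustration under assumptions: it uses Rademacher probes and approximates the Jacobian-vector product with central finite differences, whereas a real implementation would use automatic differentiation. The toy linear field and all names are hypothetical.</p>

```python
import random

def hutchinson_divergence(v, x, n_probes=64, h=1e-4, seed=0):
    """Estimate div v(x) = tr(dv/dx) as the mean of eps^T J eps.

    eps has i.i.d. +/-1 (Rademacher) entries; J eps is approximated by a
    central finite difference instead of automatic differentiation.
    """
    rng = random.Random(seed)
    d = len(x)
    total = 0.0
    for _ in range(n_probes):
        eps = [rng.choice((-1.0, 1.0)) for _ in range(d)]
        plus = v([xi + h * e for xi, e in zip(x, eps)])
        minus = v([xi - h * e for xi, e in zip(x, eps)])
        jvp = [(p - m) / (2.0 * h) for p, m in zip(plus, minus)]
        total += sum(e * j for e, j in zip(eps, jvp))
    return total / n_probes

# Toy linear field v(x) = A x, whose exact divergence is trace(A) = 1 - 3 = -2.
A = [[1.0, 0.5], [0.25, -3.0]]
linear_field = lambda x: [sum(a * xi for a, xi in zip(row, x)) for row in A]
print(hutchinson_divergence(linear_field, [0.3, -1.2]))
```

<p>The estimate is unbiased but noisy: the off-diagonal Jacobian entries contribute zero-mean terms that shrink as the number of probes grows. In the augmented ODE, the negative of this estimate would drive $f(t)$.</p>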
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>CIFAR-10</strong>: 2 GPUs.</li>
<li><strong>ImageNet-32</strong>: 4 GPUs.</li>
<li><strong>ImageNet-64</strong>: 16 GPUs.</li>
<li><strong>ImageNet-128</strong>: 32 GPUs.</li>
<li><strong>Precision</strong>: Full 32-bit for CIFAR/IM-32; 16-bit mixed precision for IM-64/128.</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/facebookresearch/flow_matching">flow_matching (PyTorch library)</a></td>
          <td>Code</td>
          <td>CC BY-NC 4.0</td>
          <td>Later official library from Meta; not the original experiment code</td>
      </tr>
  </tbody>
</table>
<p>The paper does not release the original training code or model weights used in the experiments. The <code>facebookresearch/flow_matching</code> library was released later as a general-purpose PyTorch implementation of flow matching algorithms. Standard benchmark datasets (CIFAR-10, ImageNet) are publicly available.</p>
<hr>
<h2 id="theoretical-notes-why-cfm-works">Theoretical Notes: Why CFM Works</h2>
<p>The paper relies on three key theorems to make training tractable.</p>
<p><strong>Theorem 1 (Marginal Generation)</strong>:</p>
<p>Marginalizing conditional vector fields $u_t(x|x_1)$ yields the correct marginal vector field $u_t(x)$ that generates the marginal probability path $p_t(x)$.</p>
<p>$$u_t(x) = \int u_t(x|x_1) \frac{p_t(x|x_1)q(x_1)}{p_t(x)} dx_1$$</p>
<blockquote>
<p><strong>Understanding the Proof:</strong></p>
<p>To understand why this theorem holds, we have to look at the <strong>Continuity Equation</strong>, which is the fundamental partial differential equation (PDE) that links a probability density path $p_t$ to a vector field $u_t$.</p>
<p>A vector field $u_t$ is said to &ldquo;generate&rdquo; a probability path $p_t$ if and only if they satisfy the continuity equation:</p>
<p>$$\frac{\partial p_t(x)}{\partial t} + \nabla \cdot (p_t(x) u_t(x)) = 0$$</p>
<p>The proof of Theorem 1 relies on substituting the definitions of the marginal path and vector field into this equation to see if they balance out.</p>
<p><strong>Step-by-Step Proof:</strong></p>
<ol>
<li>
<p><strong>Start with the time derivative of the marginal path</strong>: We begin by differentiating the marginal probability path $p_t(x)$ with respect to time. By definition, the marginal path is the integral of the conditional paths over the data distribution:
$$\frac{\partial p_t(x)}{\partial t} = \frac{\partial}{\partial t} \int p_t(x|x_1) q(x_1) dx_1$$</p>
</li>
<li>
<p><strong>Swap derivative and integral</strong>: Assuming standard regularity conditions (Leibniz Rule), we can move the time derivative inside the integral:
$$\frac{\partial p_t(x)}{\partial t} = \int \frac{\partial p_t(x|x_1)}{\partial t} q(x_1) dx_1$$</p>
</li>
<li>
<p><strong>Apply the Conditional Continuity Equation</strong>: This is the critical step. We know that the conditional vector field $u_t(x|x_1)$ generates the conditional path $p_t(x|x_1)$. Therefore, for every single sample $x_1$, the pair satisfies the continuity equation:
$$\frac{\partial p_t(x|x_1)}{\partial t} = -\nabla \cdot (p_t(x|x_1) u_t(x|x_1))$$</p>
<p>Substituting this into our integral gives:
$$\frac{\partial p_t(x)}{\partial t} = -\int \nabla \cdot (p_t(x|x_1) u_t(x|x_1)) q(x_1) dx_1$$</p>
</li>
<li>
<p><strong>Pull the Divergence out</strong>: Since the divergence operator ($\nabla \cdot$) acts on $x$ and the integral is over $x_1$, we can pull the divergence operator outside the integral (by linearity):
$$\frac{\partial p_t(x)}{\partial t} = -\nabla \cdot \left( \int p_t(x|x_1) u_t(x|x_1) q(x_1) dx_1 \right)$$</p>
</li>
<li>
<p><strong>Match with the Marginal Vector Field Definition</strong>: Now, look at the term inside the parentheses. The paper defines the marginal vector field $u_t(x)$ specifically to make this term simpler. Rearranging the definition of $u_t(x)$ provided in the theorem:
$$p_t(x) u_t(x) = \int p_t(x|x_1) u_t(x|x_1) q(x_1) dx_1$$</p>
<p>Substitute $p_t(x) u_t(x)$ back into our equation from Step 4:
$$\frac{\partial p_t(x)}{\partial t} = -\nabla \cdot (p_t(x) u_t(x))$$</p>
</li>
</ol>
<p><strong>Conclusion</strong>: We have just shown that $\frac{\partial p_t(x)}{\partial t} + \nabla \cdot (p_t(x) u_t(x)) = 0$. This is exactly the continuity equation. Because the marginal path and the aggregated marginal vector field satisfy this equation, the vector field is proven to generate the path.</p></blockquote>
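<p>Theorem 1 can be checked numerically in one dimension. The sketch below assumes a hypothetical data distribution with two equally likely points and the OT conditional path with a deliberately large $\sigma_{min}$, builds the marginal field as the weighted average in the theorem, and evaluates the continuity equation residual with central finite differences. All names and constants are illustrative.</p>

```python
import math

SIGMA_MIN = 0.1                      # large value chosen for a smooth toy example
DATA = [(-1.0, 0.5), (1.0, 0.5)]     # (x1, q(x1)) pairs: two equally likely points

def sigma(t):
    return 1.0 - (1.0 - SIGMA_MIN) * t

def cond_density(x, t, x1):
    """p_t(x | x1) = N(x; t * x1, sigma(t)^2)."""
    s = sigma(t)
    return math.exp(-0.5 * ((x - t * x1) / s) ** 2) / (s * math.sqrt(2.0 * math.pi))

def cond_field(x, t, x1):
    """OT conditional field u_t(x | x1)."""
    return (x1 - (1.0 - SIGMA_MIN) * x) / sigma(t)

def marginal_density(x, t):
    return sum(q * cond_density(x, t, x1) for x1, q in DATA)

def marginal_flux(x, t):
    """p_t(x) u_t(x): the conditional fields weighted by p_t(x | x1) q(x1)."""
    return sum(q * cond_density(x, t, x1) * cond_field(x, t, x1) for x1, q in DATA)

def marginal_field(x, t):
    """u_t(x) from Theorem 1: a posterior-weighted average of conditional fields."""
    return marginal_flux(x, t) / marginal_density(x, t)

def continuity_residual(x, t, h=1e-4):
    """Central-difference estimate of dp/dt + d(p u)/dx, which should vanish."""
    dp_dt = (marginal_density(x, t + h) - marginal_density(x, t - h)) / (2.0 * h)
    dflux_dx = (marginal_flux(x + h, t) - marginal_flux(x - h, t)) / (2.0 * h)
    return dp_dt + dflux_dx

print(marginal_field(0.2, 0.5), continuity_residual(0.2, 0.5))
```

<p>The residual is zero up to finite-difference error, matching Step 5 of the proof. The marginal field itself needs the full data distribution to evaluate, which is what makes it intractable for real datasets.</p>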
<p><strong>Theorem 2 (Gradient Equivalence)</strong>:</p>
<p>The intractable Flow Matching objective $\mathcal{L}_{FM}$ (which requires $u_t(x)$) has the <strong>same gradients</strong> as the tractable Conditional Flow Matching objective $\mathcal{L}_{CFM}$.</p>
<p>$$\nabla_\theta \mathcal{L}_{FM}(\theta) = \nabla_\theta \mathcal{L}_{CFM}(\theta)$$</p>
<p>This allows the model to learn the marginal vector field by only seeing conditional sample paths.</p>
<blockquote>
<p><strong>Understanding the Proof:</strong></p>
<p>The reason Theorem 2 holds is that the &ldquo;Conditional Flow Matching&rdquo; (CFM) objective is essentially an unbiased estimator of the &ldquo;Flow Matching&rdquo; (FM) objective (up to a constant). When we average over all the conditional data points $x_1$, the &ldquo;cross-term&rdquo; in the loss function aligns perfectly with the marginal vector field.</p>
<p><strong>1. Expand the Loss Functions</strong></p>
<p>First, let&rsquo;s look at the squared error in both objectives. Recall that $v_t$ is our neural network (parameterized by $\theta$), $u_t$ is the intractable marginal target, and $u_t(x|x_1)$ is the tractable conditional target.</p>
<p>Expanding the squared norms:</p>
<ul>
<li>
<p><strong>FM Objective</strong>:
$$\mathcal{L}_{FM}(\theta) = \mathbb{E}_{t, p_t(x)} \left[ |v_t(x)|^2 - 2v_t(x) \cdot u_t(x) + |u_t(x)|^2 \right]$$</p>
</li>
<li>
<p><strong>CFM Objective</strong>:
$$\mathcal{L}_{CFM}(\theta) = \mathbb{E}_{t, q(x_1), p_t(x|x_1)} \left[ |v_t(x)|^2 - 2v_t(x) \cdot u_t(x|x_1) + |u_t(x|x_1)|^2 \right]$$</p>
</li>
</ul>
<p><strong>Key Insight</strong>: When we take the gradient $\nabla_\theta$, the last term in both equations disappears because the targets ($u_t$) are independent of the network weights $\theta$. We only need to show that the expectations of the first two terms match.</p>
<p><strong>2. Matching the First Term ($|v_t(x)|^2$)</strong></p>
<p>This part is straightforward. The expectation of $|v_t(x)|^2$ is the same in both cases because of how the marginal density $p_t(x)$ is defined.</p>
<ul>
<li><strong>FM</strong>: averages over $p_t(x)$.</li>
<li><strong>CFM</strong>: averages over $p_t(x|x_1)q(x_1)$.</li>
</ul>
<p>Since $p_t(x) = \int p_t(x|x_1) q(x_1) dx_1$ (by definition), averaging over the joint distribution is mathematically identical to averaging over the marginal $p_t(x)$.</p>
<p><strong>3. Matching the Cross Term (The &ldquo;Trick&rdquo;)</strong></p>
<p>This is the critical part of the proof. We need to show that the interaction between the network and the marginal field equals the interaction between the network and the conditional field.</p>
<p><strong>The Goal</strong>: Show $\mathbb{E}_{t, p_t(x)} [v_t(x) \cdot u_t(x)] = \mathbb{E}_{t, q(x_1), p_t(x|x_1)} [v_t(x) \cdot u_t(x|x_1)]$.</p>
<p><strong>The Proof</strong>:</p>
<ol>
<li>
<p>Start with the <strong>FM cross-term</strong> (marginal):
$$\mathbb{E}_{t, p_t(x)} [v_t(x) \cdot u_t(x)]$$</p>
</li>
<li>
<p>Substitute the definition of the marginal vector field $u_t(x)$ derived in <strong>Theorem 1</strong>:
$$u_t(x) = \int u_t(x|x_1) \frac{p_t(x|x_1) q(x_1)}{p_t(x)} dx_1$$</p>
</li>
<li>
<p>Plug this into the expectation, written out as an integral. The $p_t(x)$ factors cancel:
$$\mathbb{E}_{t, p_t(x)} [v_t(x) \cdot u_t(x)] = \int_t \int_x p_t(x) v_t(x) \cdot \left[ \int_{x_1} u_t(x|x_1) \frac{p_t(x|x_1) q(x_1)}{p_t(x)} dx_1 \right] dx dt$$</p>
</li>
<li>
<p>This simplifies to:
$$= \int_t \int_x \int_{x_1} v_t(x) \cdot u_t(x|x_1) p_t(x|x_1) q(x_1) dx_1 dx dt$$</p>
</li>
<li>
<p>This is exactly the definition of the expectation in the <strong>CFM objective</strong>:
$$= \mathbb{E}_{t, q(x_1), p_t(x|x_1)} [v_t(x) \cdot u_t(x|x_1)]$$</p>
</li>
</ol>
<p><strong>Conclusion</strong>: Because the expectations of all terms involving $\theta$ are identical, the gradients must be identical.</p>
<p>Intuitively, this works like <strong>Denoising Score Matching</strong> or <strong>Stochastic Gradient Descent</strong>: even though each individual conditional vector field $u_t(x|x_1)$ points to a specific data point $x_1$ (which may differ from the true marginal direction), the <em>average</em> of all these pulls equals the true marginal vector field $u_t(x)$.</p></blockquote>
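<p>The gradient equivalence can also be observed numerically. The sketch below is a toy check under assumptions, not the paper&rsquo;s experiment: it fixes a single time, uses a hypothetical two-point data distribution with the OT conditional path, takes a one-parameter model $v(x) = \theta x$, and evaluates both losses by quadrature on a grid.</p>

```python
import math

SIGMA_MIN, T = 0.1, 0.5              # toy values picked for illustration
DATA = [(-1.0, 0.3), (2.0, 0.7)]     # (x1, q(x1)) pairs
S = 1.0 - (1.0 - SIGMA_MIN) * T      # sigma_t at the fixed time T
GRID = [-8.0 + 0.01 * i for i in range(1601)]   # quadrature nodes on [-8, 8]
DX = 0.01

def cond_density(x, x1):
    return math.exp(-0.5 * ((x - T * x1) / S) ** 2) / (S * math.sqrt(2.0 * math.pi))

def cond_field(x, x1):
    return (x1 - (1.0 - SIGMA_MIN) * x) / S

def marginal_field(x):
    weights = [q * cond_density(x, x1) for x1, q in DATA]
    return sum(w * cond_field(x, x1) for w, (x1, _) in zip(weights, DATA)) / sum(weights)

def fm_loss(theta):
    """L_FM at time T for the one-parameter model v(x) = theta * x."""
    return DX * sum(
        sum(q * cond_density(x, x1) for x1, q in DATA) * (theta * x - marginal_field(x)) ** 2
        for x in GRID)

def cfm_loss(theta):
    """L_CFM at time T for the same model."""
    return DX * sum(
        q * cond_density(x, x1) * (theta * x - cond_field(x, x1)) ** 2
        for x in GRID for x1, q in DATA)

def grad(loss, theta, h=1e-3):
    return (loss(theta + h) - loss(theta - h)) / (2.0 * h)

print(fm_loss(0.5), cfm_loss(0.5))              # the loss values differ ...
print(grad(fm_loss, 0.5), grad(cfm_loss, 0.5))  # ... the gradients agree
```

<p>The two losses differ by a constant that does not depend on $\theta$ (the average spread of the conditional targets around the marginal field), so their gradients coincide.</p>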
<p><strong>Theorem 3 (Gaussian Conditional VFs)</strong>:</p>
<p>For any Gaussian probability path $p_t(x|x_1) = \mathcal{N}(x | \mu_t(x_1), \sigma_t(x_1)^2 I)$ with affine flow map $\psi_t(x) = \sigma_t(x_1) x + \mu_t(x_1)$, the unique vector field that defines $\psi_t$ (and therefore generates the path) is available in closed form:</p>
<p>$$u_t(x|x_1) = \frac{\sigma'_t(x_1)}{\sigma_t(x_1)}(x - \mu_t(x_1)) + \mu'_t(x_1)$$</p>
<p>This theorem allows explicitly defining targets for both Diffusion (curved) and Optimal Transport (straight) paths.</p>
<blockquote>
<p><strong>Understanding the Proof:</strong></p>
<p>The derivation of Theorem 3 comes from the direct relationship between a flow map $\psi_t$ and its generating vector field. Because we chose a specific, simple path (Gaussian), we can invert the flow map to find the vector field in closed form.</p>
<p><strong>1. Define the Flow Map $\psi_t$</strong></p>
<p>We start by defining the conditional probability path as a Gaussian:</p>
<p>$$p_t(x|x_1) = \mathcal{N}(x | \mu_t(x_1), \sigma_t(x_1)^2 I)$$</p>
<p>The simplest way to &ldquo;push&rdquo; a standard normal distribution (noise) $p_0 = \mathcal{N}(0, I)$ to this Gaussian is using an affine transformation (scaling and shifting). We define the flow map $\psi_t$ as:</p>
<p>$$\psi_t(x_0) = \sigma_t(x_1) x_0 + \mu_t(x_1)$$</p>
<p>This map takes a noise sample $x_0$ and transforms it into a sample $x$ at time $t$.</p>
<p><strong>2. The Definition of a Generating Vector Field</strong></p>
<p>By definition, a vector field $u_t$ generates a flow $\psi_t$ if the vector field describes the instantaneous velocity of the flow at any point. Mathematically:</p>
<p>$$u_t(\psi_t(x_0)) = \frac{d}{dt}\psi_t(x_0)$$</p>
<p>Let $x = \psi_t(x_0)$ be the position of the particle at time $t$. We want to find $u_t(x)$.</p>
<p><strong>3. Invert the Flow Map</strong></p>
<p>To find $u_t(x)$, we must express the equation in terms of $x$ rather than $x_0$. Since our flow map is a simple affine transformation (multiply and add), it is easily invertible (assuming $\sigma_t(x_1) \neq 0$):</p>
<p>$$x_0 = \frac{x - \mu_t(x_1)}{\sigma_t(x_1)}$$</p>
<p>We will call this inverse map $\psi_t^{-1}(x)$.</p>
<p><strong>4. Differentiate the Flow Map</strong></p>
<p>Now we calculate the right side of our definition equation (velocity): $\frac{d}{dt}\psi_t(x_0)$.</p>
<p>Taking the time derivative of $\psi_t(x_0) = \sigma_t(x_1) x_0 + \mu_t(x_1)$:</p>
<p>$$\frac{d}{dt}\psi_t(x_0) = \sigma'_t(x_1) x_0 + \mu'_t(x_1)$$</p>
<p>(Note: $\sigma'_t$ and $\mu'_t$ denote time derivatives.)</p>
<p><strong>5. Substitute and Solve</strong></p>
<p>Now we combine everything. We know $u_t(\psi_t(x_0)) = \frac{d}{dt}\psi_t(x_0)$.</p>
<p>Substitute the result from Step 4 into this equation:</p>
<p>$$u_t(\psi_t(x_0)) = \sigma'_t(x_1) x_0 + \mu'_t(x_1)$$</p>
<p>This expresses the vector field in terms of the initial point $x_0$. We must express it in terms of the current point $x$. So, we plug in the inverse formula for $x_0$ derived in Step 3:</p>
<p>$$u_t(x|x_1) = \sigma'_t(x_1) \frac{x - \mu_t(x_1)}{\sigma_t(x_1)} + \mu'_t(x_1)$$</p>
<p>Rearranging terms gives the final closed form:</p>
<p>$$u_t(x|x_1) = \frac{\sigma'_t(x_1)}{\sigma_t(x_1)}(x - \mu_t(x_1)) + \mu'_t(x_1)$$</p>
<p><strong>Why is this useful?</strong></p>
<p>This formula means that as long as you can define a mean schedule $\mu_t(x_1)$ and a standard deviation schedule $\sigma_t(x_1)$ (which is easy to do for both Diffusion and Optimal Transport), you immediately get the exact vector field target $u_t(x|x_1)$ needed to train your neural network, bypassing complex ODE solving or score matching approximations.</p></blockquote>
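<p>As a worked check that uses only the schedules listed in the Algorithms section, take the OT choice $\mu_t(x_1) = t x_1$ and $\sigma_t(x_1) = 1 - (1 - \sigma_{min})t$. Their time derivatives are $\frac{d}{dt}\mu_t(x_1) = x_1$ and $\frac{d}{dt}\sigma_t(x_1) = -(1 - \sigma_{min})$, and substituting them into the Theorem 3 formula recovers the OT regression target:</p>
<p>$$u_t(x|x_1) = \frac{-(1 - \sigma_{min})}{1 - (1 - \sigma_{min})t}(x - t x_1) + x_1 = \frac{x_1 - (1 - \sigma_{min})x}{1 - (1 - \sigma_{min})t}$$</p>
<p>The second equality follows from placing both terms over the common denominator: the $t x_1$ contributions cancel, leaving $x_1 - (1 - \sigma_{min})x$ in the numerator.</p>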
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Lipman, Y., Chen, R. T. Q., Ben-Hamu, H., Nickel, M., &amp; Le, M. (2023). Flow Matching for Generative Modeling. <em>International Conference on Learning Representations (ICLR)</em>.</p>
<p><strong>Publication</strong>: ICLR 2023</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{lipmanFlowMatchingGenerative2023,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Flow Matching for Generative Modeling}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Lipman, Yaron and Chen, Ricky T. Q. and Ben-Hamu, Heli and Nickel, Maximilian and Le, Matt}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{International Conference on Learning Representations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2023}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://arxiv.org/abs/2210.02747">ArXiv</a></li>
</ul>
]]></content:encoded></item><item><title>Building Normalizing Flows with Stochastic Interpolants</title><link>https://hunterheidenreich.com/notes/machine-learning/generative-models/stochastic-interpolants/</link><pubDate>Sun, 21 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/generative-models/stochastic-interpolants/</guid><description>A continuous-time normalizing flow using stochastic interpolants and quadratic loss to bypass costly ODE backpropagation.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is primarily a <strong>Method</strong> paper, with significant <strong>Theory</strong> contributions.</p>
<p>The authors propose a specific algorithm (&ldquo;InterFlow&rdquo;) for constructing generative models based on continuous-time normalizing flows. The work is characterized by the derivation of a new training objective (a simple quadratic loss) that bypasses the computational bottlenecks of previous methods. It includes prominent baseline comparisons against continuous flow methods (FFJORD, OT-Flow) and diffusion models. The theoretical component establishes the validity of the interpolant density satisfying the continuity equation (a conservation law governing how probability mass flows) and bounds the Wasserstein-2 distance (a measure of transport cost between distributions, penalizing squared displacement) of the transport.</p>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>The primary motivation is to overcome the computational inefficiency of training Continuous Normalizing Flows (CNFs) using Maximum Likelihood Estimation (MLE). Standard CNF training requires backpropagating through numerical ODE solvers, which is costly and limits scalability.</p>
<p>Additionally, while score-based diffusion models (SDEs) have achieved high sample quality, they theoretically require infinite time integration and rely on specific noise schedules. The authors aim to establish a method that works strictly with Probability Flow ODEs on finite time intervals, retaining the flexibility to connect arbitrary densities without the complexity of SDEs or the cost of standard ODE adjoint methods.</p>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The core novelty is the <strong>Stochastic Interpolant</strong> framework:</p>
<ul>
<li><strong>Explicit Interpolant Construction</strong>: The method defines a time-dependent interpolant $x_t = I_t(x_0, x_1)$ (e.g., trigonometric interpolation) that connects samples from the base density $\rho_0$ and target $\rho_1$.</li>
<li><strong>Simulation-Free Training</strong>: The velocity field $v_t(x)$ of the probability flow is learned by minimizing a simple quadratic objective: $G(\hat{v}) = \mathbb{E}[|\hat{v}_t(x_t)|^2 - 2\partial_t I_t(x_0, x_1) \cdot \hat{v}_t(x_t)]$. Because $\partial_t I_t$ is known analytically from the interpolant definition, the expectation can be estimated by sampling $(x_0, x_1, t)$ directly. This avoids ODE integration during training (it is still required at inference).</li>
<li><strong>Decoupling Path and Optimization</strong>: The choice of path (interpolant) is separated from the optimization of the velocity field. MLE methods couple the path and objective.</li>
<li><strong>Connection to Score-Based Models</strong>: The authors show that for Gaussian base densities and trigonometric interpolants, the learned velocity field is explicitly related to the score function $\nabla \log \rho_t$, providing a theoretical bridge between CNFs and diffusion models.</li>
</ul>
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The authors performed validation across synthetic, tabular, and image domains:</p>
<ul>
<li><strong>2D Density Estimation</strong>: Benchmarked on &ldquo;Checkerboard&rdquo;, &ldquo;8 Gaussians&rdquo;, and anisotropic curved densities to visualize mode coverage and transport smoothness.</li>
<li><strong>High-Dimensional Tabular Data</strong>: Evaluated on standard benchmarks (POWER, GAS, HEPMASS, MINIBOONE, BSDS300) comparing Negative Log Likelihood (NLL) against FFJORD, OT-Flow, and others.</li>
<li><strong>Image Generation</strong>: Trained models on CIFAR-10 ($32 \times 32$), ImageNet ($32 \times 32$), and Oxford Flowers ($128 \times 128$) to test scalability.</li>
<li><strong>Ablations</strong>: Investigated optimizing the interpolant path itself (e.g., learning Fourier coefficients for the path) to approach optimal transport and minimize path length.</li>
</ul>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>
<ul>
<li><strong>Performance</strong>: The method matches or surpasses conventional ODE flows (like FFJORD) in terms of NLL while being significantly cheaper to train.</li>
<li><strong>Efficiency</strong>: The training cost per epoch is constant (simulation-free), whereas MLE-based ODE methods see growing costs as the dynamics become more complex.</li>
<li><strong>Scalability</strong>: The method successfully scales to $128 \times 128$ resolution on a single GPU, a resolution that prior ab-initio ODE flows had not demonstrated.</li>
<li><strong>Flexibility</strong>: The framework can connect <em>any</em> two arbitrary densities (e.g., connecting two different complex 2D distributions) without needing one to be Gaussian.</li>
<li><strong>Optimal Transport</strong>: For a fixed interpolant, minimizing $G(\hat{v})$ over the velocity field recovers the velocity for that specific path. Additionally optimizing over the interpolant family yields a solution to the Benamou-Brenier optimal transport problem.</li>
<li><strong>Limitations</strong>: The authors acknowledge that image FID scores trail dedicated diffusion models, noting that InterFlow was not optimized with standard training tricks such as exponential moving averages, truncation, or learning rate warm-ups. The framework&rsquo;s sample quality could likely improve with these additions.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<ul>
<li><strong>Tabular Datasets</strong>: POWER (6D), GAS (8D), HEPMASS (21D), MINIBOONE (43D), BSDS300 (63D).
<ul>
<li>Training points range from ~30k (MINIBOONE) to ~1.6M (POWER).</li>
</ul>
</li>
<li><strong>Image Datasets</strong>:
<ul>
<li>CIFAR-10 ($32 \times 32$, 50k training points).</li>
<li>ImageNet ($32 \times 32$, ~1.28M training points).</li>
<li>Oxford Flowers ($128 \times 128$, ~315k training points).</li>
</ul>
</li>
<li><strong>Time Sampling</strong>: Time $t$ is sampled from a Beta distribution during training (reweighting) to focus learning near the target.</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Interpolant</strong>: The primary interpolant used is trigonometric: $I_t(x_0, x_1) = \cos(\frac{\pi t}{2})x_0 + \sin(\frac{\pi t}{2})x_1$.
<ul>
<li>Alternative linear interpolant: $I_t = a_t x_0 + b_t x_1$.</li>
</ul>
</li>
<li><strong>Loss Function</strong>:
$$G(\hat{v}) = \mathbb{E}_{t, x_0, x_1}[|\hat{v}_t(x_t)|^2 - 2\partial_t I_t(x_0, x_1) \cdot \hat{v}_t(x_t)]$$
<ul>
<li>The expectation is amenable to empirical estimation using batches of $x_0, x_1, t$.</li>
</ul>
</li>
<li><strong>Sampling</strong>: Numerical integration using Dormand-Prince (Runge-Kutta 4/5).</li>
<li><strong>Optimization</strong>: SGD/Adam variants.</li>
</ul>
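<p>The interpolant and the quadratic loss fit in a short scalar sketch. The example below is illustrative only: it samples $t$ uniformly for simplicity (the notes above mention a Beta distribution), uses a hypothetical two-point target density, and all function names are assumptions.</p>

```python
import math
import random

def interpolant(x0, x1, t):
    """Trigonometric interpolant I_t(x0, x1)."""
    return math.cos(0.5 * math.pi * t) * x0 + math.sin(0.5 * math.pi * t) * x1

def interpolant_velocity(x0, x1, t):
    """Analytic time derivative of I_t(x0, x1)."""
    return 0.5 * math.pi * (-math.sin(0.5 * math.pi * t) * x0
                            + math.cos(0.5 * math.pi * t) * x1)

def quadratic_loss(v_hat, sample_x0, sample_x1, n_samples=2000, seed=0):
    """Monte Carlo estimate of G(v_hat) with t ~ U[0, 1) (uniform for simplicity)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        t = rng.random()
        x0, x1 = sample_x0(rng), sample_x1(rng)
        v = v_hat(t, interpolant(x0, x1, t))
        total += v * v - 2.0 * interpolant_velocity(x0, x1, t) * v
    return total / n_samples

base = lambda rng: rng.gauss(0.0, 1.0)            # rho_0: standard normal
target = lambda rng: rng.choice([-2.0, 2.0])      # rho_1: hypothetical two-point target
print(quadratic_loss(lambda t, x: 0.0, base, target))   # the zero field gives G = 0
```

<p>Unlike a squared-error loss, $G$ is not bounded below by zero: the zero field scores exactly 0, and fields that align with $\partial_t I_t$ on average drive the value negative.</p>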
<h3 id="models">Models</h3>
<ul>
<li><strong>Tabular Architectures</strong>:
<ul>
<li>Feed-forward networks with 4-5 hidden layers.</li>
<li>Hidden widths: 512 (POWER, GAS, HEPMASS, MINIBOONE) or 1024 (BSDS300).</li>
<li>Activation: ReLU (general) or ELU (BSDS300).</li>
</ul>
</li>
<li><strong>Image Architectures</strong>:
<ul>
<li>U-Net based on the DDPM implementation.</li>
<li>Hidden dimension: 256.</li>
<li>Sinusoidal time embeddings used.</li>
</ul>
</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Metrics</strong>: Negative Log Likelihood (NLL) in nats (tabular) or bits per dim (images), Frechet Inception Distance (FID) for images.</li>
<li><strong>Baselines</strong>: FFJORD, Glow, Real NVP, OT-Flow, ScoreFlow, DDPM.</li>
</ul>
<p><strong>Tabular NLL</strong> (nats, lower is better; Table 2 Left):</p>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>POWER</th>
          <th>GAS</th>
          <th>HEPMASS</th>
          <th>MINIBOONE</th>
          <th>BSDS300</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MADE</td>
          <td>3.08</td>
          <td>-3.56</td>
          <td>20.98</td>
          <td>15.59</td>
          <td>-148.85</td>
      </tr>
      <tr>
          <td>Real NVP</td>
          <td>-0.17</td>
          <td>-8.33</td>
          <td>18.71</td>
          <td>13.55</td>
          <td>-153.28</td>
      </tr>
      <tr>
          <td>Glow</td>
          <td>-0.17</td>
          <td>-8.15</td>
          <td>18.92</td>
          <td>11.35</td>
          <td>-155.07</td>
      </tr>
      <tr>
          <td>CPF</td>
          <td>-0.52</td>
          <td>-10.36</td>
          <td>16.93</td>
          <td>10.58</td>
          <td>-154.99</td>
      </tr>
      <tr>
          <td>NSP</td>
          <td>-0.64</td>
          <td>-13.09</td>
          <td>14.75</td>
          <td>9.67</td>
          <td>-157.54</td>
      </tr>
      <tr>
          <td>FFJORD</td>
          <td>-0.46</td>
          <td>-8.59</td>
          <td>14.92</td>
          <td>10.43</td>
          <td>-157.40</td>
      </tr>
      <tr>
          <td>OT-Flow</td>
          <td>-0.30</td>
          <td>-9.20</td>
          <td>17.32</td>
          <td>10.55</td>
          <td>-154.20</td>
      </tr>
      <tr>
          <td><strong>Ours</strong></td>
          <td><strong>-0.57</strong></td>
          <td><strong>-12.35</strong></td>
          <td><strong>14.85</strong></td>
          <td><strong>10.42</strong></td>
          <td><strong>-156.22</strong></td>
      </tr>
  </tbody>
</table>
<p><strong>Image Generation NLL and FID</strong> (Table 2 Right; NLL in bits per dim, lower is better):</p>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>CIFAR-10 NLL</th>
          <th>CIFAR-10 FID</th>
          <th>ImageNet-32 NLL</th>
          <th>ImageNet-32 FID</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>FFJORD</td>
          <td>3.40</td>
          <td>-</td>
          <td>-</td>
          <td>-</td>
      </tr>
      <tr>
          <td>Glow</td>
          <td>3.35</td>
          <td>-</td>
          <td>4.09</td>
          <td>-</td>
      </tr>
      <tr>
          <td>DDPM</td>
          <td>≤3.75</td>
          <td>3.17</td>
          <td>-</td>
          <td>-</td>
      </tr>
      <tr>
          <td>DDPM++ (Song et al., 2021)</td>
          <td>≤3.37</td>
          <td>2.90</td>
          <td>-</td>
          <td>-</td>
      </tr>
      <tr>
          <td>ScoreSDE (Song et al., 2021)</td>
          <td>2.99</td>
          <td>2.92</td>
          <td>-</td>
          <td>-</td>
      </tr>
      <tr>
          <td>VDM</td>
          <td>≤2.65</td>
          <td>7.41</td>
          <td>≤3.72</td>
          <td>-</td>
      </tr>
      <tr>
          <td>Soft Truncation</td>
          <td>2.88</td>
          <td>3.45</td>
          <td>3.85</td>
          <td>8.42</td>
      </tr>
      <tr>
          <td>ScoreFlow</td>
          <td>2.81</td>
          <td>5.40</td>
          <td>3.76</td>
          <td>10.18</td>
      </tr>
      <tr>
          <td><strong>Ours</strong></td>
          <td><strong>2.99</strong></td>
          <td><strong>10.27</strong></td>
          <td><strong>3.48</strong></td>
          <td><strong>8.49</strong></td>
      </tr>
  </tbody>
</table>
<p>Note: DDPM++ is from Song et al. (2021), the same work as ScoreSDE (it is the architecture optimized for VP/sub-VP SDEs). InterFlow matches ScoreSDE on CIFAR-10 NLL (2.99 bits per dim) while being simulation-free. FID is weaker than dedicated image models (10.27 vs 2.92 for ScoreSDE), reflecting the paper&rsquo;s primary focus on tractable likelihood rather than sample quality.</p>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Compute</strong>: All models were trained on a single NVIDIA A100 GPU.</li>
<li><strong>Training Time</strong>:
<ul>
<li>Tabular: $10^5$ steps.</li>
<li>Images: $1.5 \times 10^5$ to $6 \times 10^5$ steps.</li>
<li>Speedup: Demonstrated ~400x speedup compared to FFJORD on the MINIBOONE dataset.</li>
</ul>
</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>lucidrains/denoising-diffusion-pytorch (link defunct)</td>
          <td>Code</td>
          <td>MIT</td>
          <td>Base U-Net architecture used for image experiments; original GitHub account no longer available</td>
      </tr>
  </tbody>
</table>
<p>No official code release accompanies this paper. All tabular datasets (POWER, GAS, HEPMASS, MINIBOONE, BSDS300) are publicly available from prior work. CIFAR-10 and ImageNet are standard public benchmarks. Oxford Flowers 102 is also publicly available. Hyperparameters and architectures are fully specified in Tables 3 and 4 of the paper.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Albergo, M. S., &amp; Vanden-Eijnden, E. (2023). Building Normalizing Flows with Stochastic Interpolants. <em>The Eleventh International Conference on Learning Representations</em>.</p>
<p><strong>Publication</strong>: ICLR 2023</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{albergoBuildingNormalizingFlows2022,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Building {{Normalizing Flows}} with {{Stochastic Interpolants}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{The {{Eleventh International Conference}} on {{Learning Representations}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Albergo, Michael Samuel and {Vanden-Eijnden}, Eric}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2023</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span> = <span style="color:#e6db74">{https://openreview.net/forum?id=li7qeBbCR1t}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://openreview.net/forum?id=li7qeBbCR1t">OpenReview</a></li>
<li><a href="https://arxiv.org/abs/2209.15571">arXiv</a></li>
</ul>
]]></content:encoded></item><item><title>A Convexity Principle for Interacting Gases (McCann 1997)</title><link>https://hunterheidenreich.com/notes/machine-learning/generative-models/convexity-principle-interacting-gases/</link><pubDate>Sun, 21 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/generative-models/convexity-principle-interacting-gases/</guid><description>Introduces displacement interpolation to prove ground state uniqueness via optimal transport, with mathematical tools later used in generative modeling.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is a <a href="/notes/interdisciplinary/research-methods/ai-physical-sciences-paper-taxonomy/">Theory</a> paper. It relies entirely on formal mathematical derivation to establish existence and uniqueness properties for energy functionals. It introduces a new mathematical structure (displacement interpolation) to analyze the geometry of probability measures.</p>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>The paper addresses the uniqueness of stationary configurations (ground states) for a gas model where particles interact via attractive forces while resisting compression.</p>
<p>The total energy functional $E(\rho)$ includes an interaction term $G(\rho)$ that lacks convexity under standard linear interpolation ($(1-t)\rho + t\rho'$), making it difficult to prove that a unique minimizer exists. Standard convexity tools and rearrangement inequalities are also insufficient for cases without specific symmetries (like spherical symmetry) or when convexity of the potential fails.</p>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The core novelty is the introduction of <strong>Displacement Interpolation</strong>.</p>
<ul>
<li><strong>New Interpolant</strong>: The paper defines an interpolant $\rho_t$ by moving mass along the gradient of a convex potential $\psi$ (transport map).</li>
<li><strong>Displacement Convexity</strong>: It proves that the internal energy $U(\rho)$ and potential energy $G(\rho)$ become convex functions of $t$ along this displacement path. This is a property specific to displacement interpolation.</li>
<li><strong>Generalization</strong>: This framework generalizes the classical <strong>Brunn-Minkowski inequality</strong> from sets to measures.</li>
</ul>
<h3 id="theoretical-framework">Theoretical Framework</h3>
<h4 id="mathematical-setup">Mathematical Setup</h4>
<p><strong>Probability Measures</strong></p>
<p>The gas state is represented by absolutely continuous probability measures $\rho \in \mathcal{P}_{ac}(\mathbb{R}^d)$ with finite second moments.</p>
<p><strong>Energy Functional</strong></p>
<p>The gas model is defined by the total energy functional $E(\rho)$:
$$E(\rho) := \underbrace{\int_{\mathbb{R}^d} A(\rho(x))dx}_{\text{Internal Energy } U(\rho)} + \underbrace{\frac{1}{2} \iint d\rho(x)V(x-y)d\rho(y)}_{\text{Potential Energy } G(\rho)}$$</p>
<h4 id="key-construction-displacement-interpolation">Key Construction: Displacement Interpolation</h4>
<p>The core theoretical tool is the construction of the interpolant $\rho_t$ between two probability measures $\rho$ and $\rho'$:</p>
<ol>
<li><strong>Transport Map</strong>: By Brenier&rsquo;s theorem, there exists a convex function $\psi$ such that $(\nabla\psi)_{\#}\rho = \rho'$ (push-forward).</li>
<li><strong>Interpolation</strong>: The interpolant at time $t \in [0,1]$ is defined as the push-forward of $\rho$ under the linear interpolation of the identity and the transport map:
$$\rho_t := [(1-t)\text{id} + t\nabla\psi]_{\#}\rho$$</li>
</ol>
<p>This is the &ldquo;displacement interpolation&rdquo; where mass moves along straight lines from initial to final positions.</p>
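<p>A one-dimensional sketch makes the construction concrete. It relies on the standard fact that in 1D the monotone rearrangement is the optimal (Brenier) map, so on equal-size empirical samples the map pairs the $k$-th smallest point of $\rho$ with the $k$-th smallest point of $\rho'$. The Gaussian endpoints and the function name are assumptions chosen for illustration.</p>

```python
import numpy as np


def displacement_interpolate_1d(x, y, t):
    """Samples of rho_t given equal-size 1D samples x ~ rho and y ~ rho'.

    Sorting both sample sets realizes the monotone (optimal) pairing in 1D.
    """
    xs, ys = np.sort(x), np.sort(y)
    return (1.0 - t) * xs + t * ys  # each particle moves along a straight line


rng = np.random.default_rng(0)
x = rng.normal(loc=-2.0, scale=0.5, size=10_000)  # samples of rho
y = rng.normal(loc=3.0, scale=1.5, size=10_000)   # samples of rho'
mid = displacement_interpolate_1d(x, y, 0.5)

# Displacement interpolation transports a single bump across the line;
# the linear mixture (1 - t) rho + t rho' would instead be bimodal.
print(mid.mean(), mid.std())
```

<p>For these Gaussian endpoints the midpoint samples should have mean near $0.5$ and standard deviation near $1.0$, the averages of the endpoint means and standard deviations.</p>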
<h4 id="assumptions-for-uniqueness">Assumptions for Uniqueness</h4>
<p>The main existence and uniqueness theorem (Theorem 3.1) requires one condition on the interaction potential, two conditions on the equation of state, and one regularity condition:</p>
<ol>
<li><strong>Interaction</strong>: $V(x)$ is strictly convex.</li>
<li><strong>(P1) Equation of State</strong>: $P(\rho) / \rho^{(d-1)/d}$ is non-decreasing. This is equivalent to convexity of $U$ under mass-preserving dilations, and is satisfied by polytropic gases $P(\rho) = \rho^q$ with $q &gt; 1$.</li>
<li><strong>(P2) Growth Condition</strong>: $P(\rho) \cdot \rho^{-2}$ is not integrable at $\infty$. This ensures the energy minimizer has no singular part with respect to Lebesgue measure.</li>
<li><strong>Regularity</strong>: $\rho \in \mathcal{P}_{ac}(\mathbb{R}^d)$ (absolutely continuous probability measures).</li>
</ol>
<h4 id="main-results">Main Results</h4>
<p><strong>Theorem 2.2</strong> (Displacement Convexity of Internal Energy): Under condition (A1) (that $\lambda^d A(\lambda^{-d})$ is convex non-increasing on $(0, \infty)$ with $A(0) = 0$, ensuring internal energy decreases as the gas dilates), the internal energy $U(\rho)$ is convex along displacement interpolation paths. Strict convexity follows unless $\nabla^2\psi(x) = I$ holds $\rho$-a.e., i.e., $\rho'$ is a translate of $\rho$.</p>
<p><strong>Theorem 3.1</strong> (Existence and Uniqueness of Ground State): For any equation of state satisfying (P1) and (P2) with a strictly convex interaction potential $V$, the total energy $E(\rho)$ attains a unique minimizer up to translation. The minimizer can be taken to be even, meaning $\rho_g(x) = \rho_g(-x)$.</p>
<p><strong>Theorem 3.3</strong> (Uniqueness for Spherically Symmetric Potentials): When the strict convexity of $V(x)$ is relaxed to spherical symmetry (with $V$ not constant), uniqueness up to translation still holds provided (P1) holds strictly. This extends the main result to cases like Coulomb-type interactions.</p>
<p><strong>Lemma 3.2</strong>: A decomposition lemma for convex functions. Let $\psi$ and $\phi$ be convex on $\mathbb{R}^d$, and let $\Omega \subset \mathbb{R}^d$ be an open convex set on which both are finite. Suppose $\phi$ is differentiable on $\Omega$ with a locally Lipschitz gradient. If their Aleksandrov second derivatives agree almost everywhere on $\Omega$, then $\psi - \phi$ is convex on $\Omega$. This underpins the proof of Theorem 3.3.</p>
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The validation consists entirely of rigorous mathematical proofs:</p>
<ul>
<li><strong>Convexity Proofs</strong>: Deriving inequalities to show $E(\rho_t) \le (1-t)E(\rho) + tE(\rho')$.</li>
<li><strong>Existence/Uniqueness</strong>: Using the new convexity principle to prove that the energy minimizer is unique up to translation.</li>
</ul>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>
<ul>
<li><strong>Uniqueness of Ground State</strong>: For equations of state satisfying specific monotonicity conditions (e.g., polytropic gases), the energy minimizing state is unique up to translation.</li>
<li><strong>Brunn-Minkowski Extension</strong>: The internal energy convexity implies the Brunn-Minkowski inequality as a special case ($A(\rho) = -\rho^{(d-1)/d}$).</li>
<li><strong>Norm Concavity</strong>: The functional $\|\rho_t\|_q^{-p/d}$ is shown to be concave along the interpolation path for Hölder conjugates $p, q$ ($1/p + 1/q = 1$) with $q \geq (d-1)/d$.</li>
</ul>
<h3 id="relevance-to-machine-learning">Relevance to Machine Learning</h3>
<p>This 1997 paper establishes the mathematical foundations of displacement convexity in optimal transport theory, which underpins several modern generative modeling techniques. The displacement interpolation framework introduced here is used in:</p>
<ul>
<li><strong>Flow Matching</strong>: Uses optimal transport probability paths (straight-line interpolations with constant speed) to generate samples. See the <a href="../flow-matching-for-generative-modeling/">Flow Matching note</a> for details on how OT paths differ from diffusion paths.</li>
<li><strong>Wasserstein GANs</strong>: Use the Wasserstein distance (optimal transport metric) for training stability.</li>
<li><strong>Continuous Normalizing Flows</strong>: Use OT-inspired transport maps for probability density transformation.</li>
</ul>
<p>McCann&rsquo;s convexity principle proves that energy functionals become convex along displacement paths, a mathematical structure that underpins the geometry used in flow matching and optimal transport-based generative modeling.</p>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: McCann, R. J. (1997). A Convexity Principle for Interacting Gases. <em>Advances in Mathematics</em>, 128(1), 153-179. <a href="https://doi.org/10.1006/aima.1997.1634">https://doi.org/10.1006/aima.1997.1634</a></p>
<p><strong>Publication</strong>: Advances in Mathematics 1997</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{mccannConvexityPrincipleInteracting1997,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{A {{Convexity Principle}} for {{Interacting Gases}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{McCann, Robert J.}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">1997</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = jun,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Advances in Mathematics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{128}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{153--179}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">issn</span> = <span style="color:#e6db74">{00018708}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1006/aima.1997.1634}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">urldate</span> = <span style="color:#e6db74">{2025-12-21}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Second-Order Langevin Equation for Field Simulations</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-simulation/classical-methods/second-order-langevin-1987/</link><pubDate>Sun, 14 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-simulation/classical-methods/second-order-langevin-1987/</guid><description>Hyperbolic Algorithm adds second-order derivatives to Langevin dynamics, reducing systematic errors to O(ε²) for lattice field simulations.</description><content:encoded><![CDATA[<h2 id="contribution-and-paper-type">Contribution and Paper Type</h2>
<p>This is a <strong>Methodological Paper</strong> ($\Psi_{\text{Method}}$). It proposes a novel stochastic algorithm, the Hyperbolic Algorithm (HA), and validates its superior efficiency against the existing Langevin Algorithm (LA) through formal error analysis and numerical simulation. It contains significant theoretical derivation (Liouville dynamics) that serves primarily to justify the algorithmic performance claims.</p>
<h2 id="motivation-and-gaps-in-prior-work">Motivation and Gaps in Prior Work</h2>
<p>The standard Langevin Algorithm (LA) for numerical simulation of Euclidean field theories suffers from efficiency bottlenecks. The simplest Euler discretization of the LA introduces systematic errors of $O(\epsilon)$ (where $\epsilon$ is the step size). To maintain accuracy, $\epsilon$ must be kept small, which increases the sweep-sweep correlation time (autocorrelation time) measured in updates, making simulations computationally expensive.</p>
<h2 id="core-novelty-second-order-dynamics">Core Novelty: Second-Order Dynamics</h2>
<p>The core contribution is the introduction of a <strong>second-order derivative in fictitious time</strong> to the stochastic equation. This converts the parabolic Langevin equation into a hyperbolic equation:</p>
<p>$$
\begin{aligned}
\frac{\partial^{2}\phi}{\partial t^{2}}+\gamma\frac{\partial\phi}{\partial t}=-\frac{\partial S}{\partial\phi}+\eta
\end{aligned}
$$</p>
<h3 id="equation-comparison">Equation Comparison</h3>
<p>The key difference from the standard (first-order) Langevin equation:</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Equation Type</th>
          <th style="text-align: left">Formula</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>Hyperbolic (Second Order)</strong></td>
          <td style="text-align: left">$$\frac{\partial^{2}\phi}{\partial t^{2}}+\gamma\frac{\partial\phi}{\partial t}=-\frac{\partial S}{\partial\phi}+\eta$$</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Langevin (First Order)</strong></td>
          <td style="text-align: left">$$\frac{\partial\phi}{\partial t}=-\frac{\partial S}{\partial\phi}+\eta$$</td>
      </tr>
  </tbody>
</table>
<p>The standard Langevin equation corresponds to the overdamped limit where the acceleration term is absent. Physically, the Hyperbolic equation can be viewed as microcanonical equations of motion with an added friction term.</p>
<h3 id="key-innovations">Key Innovations</h3>
<ul>
<li><strong>Higher Order Accuracy</strong>: The simplest discretization of this equation leads to systematic errors of only $O(\epsilon^2)$ compared to $O(\epsilon)$ for LA.</li>
<li><strong>Tunable Damping</strong>: The addition of the damping parameter $\gamma$ allows tuning to minimize autocorrelation tails.</li>
<li><strong>Uniform Evolution</strong>: The method evolves structures of different wavelengths more uniformly than LA due to the specific dissipation structure.</li>
</ul>
<h2 id="methodology-and-experiments">Methodology and Experiments</h2>
<p>The author validated the method using the <strong>XY Model</strong> on 2D lattices.</p>
<ul>
<li><strong>System</strong>: Euclidean action $S = -\sum_{x,\mu} \cos(\theta_{x+\mu} - \theta_x)$.</li>
<li><strong>Setup</strong>:
<ul>
<li>Lattice sizes: $15^2$ (helical boundary conditions) and $30^2$.</li>
<li>$\beta$ range: 0.9 to 1.2 (crossing the critical point $\approx 1.0$).</li>
<li>Run length: &gt;100,000 updates in equilibrium.</li>
</ul>
</li>
<li><strong>Metrics</strong>:
<ul>
<li><strong>Autocorrelation time ($\tau$)</strong>: Defined as the number of updates for the time-correlation function to drop to 10% of its initial value.</li>
<li><strong>Systematic Error</strong>: Measured via deviation of average action from Monte Carlo values.</li>
</ul>
</li>
</ul>
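<p>The 10% criterion for the autocorrelation time can be sketched as follows. The paper's exact estimator is not specified in this note, so the normalization, the <code>max_lag</code> cutoff, and the synthetic AR(1) series standing in for a measured observable are assumptions.</p>

```python
import numpy as np


def autocorrelation_time(series, threshold=0.1, max_lag=1000):
    """Smallest lag at which the normalized autocorrelation falls to `threshold`.

    Returns None if the threshold is not reached within `max_lag` updates.
    """
    x = np.asarray(series, dtype=float)
    x = x - x.mean()
    var = np.mean(x * x)
    for lag in range(1, min(max_lag, len(x) - 1) + 1):
        corr = np.mean(x[:-lag] * x[lag:]) / var
        if corr <= threshold:
            return lag
    return None


# Synthetic AR(1) series standing in for a measured observable (not paper data).
rng = np.random.default_rng(0)
rho, n = 0.9, 200_000
noise = rng.standard_normal(n)
obs = np.empty(n)
obs[0] = noise[0]
for k in range(1, n):
    obs[k] = rho * obs[k - 1] + noise[k]

# For an ideal AR(1) process, 0.9**lag first falls below 0.1 at lag 22.
tau = autocorrelation_time(obs)
```

<p>Comparing two algorithms at matched systematic error then reduces to comparing their <code>tau</code> values for the same observable.</p>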
<h2 id="results-and-conclusions">Results and Conclusions</h2>
<ul>
<li><strong>Efficiency</strong>: At equal systematic error, the Hyperbolic Algorithm (HA) has significantly shorter sweep-sweep correlation times than the LA.</li>
<li><strong>Error Scaling</strong>: Numerical results confirmed that HA step size $\epsilon_H = 0.1$ yields systematic errors comparable to LA step size $\epsilon_L \approx 0.008$ ($O(\epsilon^2)$ vs $O(\epsilon)$ scaling).</li>
<li><strong>Speedup</strong>: In the disordered phase, HA is roughly $\epsilon_H / \epsilon_L$ times faster (approximately a factor of 12.5 for $\epsilon_H = 0.1$, $\epsilon_L = 0.008$). In the ordered phase, efficiency gains increase with distance scale, reaching factors of 20 or more for long-range correlations.</li>
<li><strong>Optimal Damping</strong>: For the XY model, the optimal damping parameter was found to be $\gamma \approx 0.4$.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="algorithms">Algorithms</h3>
<p><strong>1. The Hyperbolic Algorithm (HA)</strong></p>
<p>The discretized update equations for scalar fields are:</p>
<p>$$
\begin{aligned}
\pi_{t+\epsilon} - \pi_{t} &amp;= -\epsilon\gamma\pi_{t} - \epsilon\frac{\partial S}{\partial\phi_{t}} + \sqrt{2\epsilon\gamma/\beta}\xi_{t} \\
\phi_{t+\epsilon} - \phi_{t} &amp;= \epsilon\pi_{t+\epsilon}
\end{aligned}
$$</p>
<ul>
<li><strong>Variables</strong>: $\phi$ is the field, $\pi$ is the conjugate momentum ($\dot{\phi}$).</li>
<li><strong>Parameters</strong>: $\epsilon$ (step size), $\gamma$ (damping constant).</li>
<li><strong>Noise</strong>: $\xi$ is Gaussian noise with $\langle\xi_x \xi_y\rangle = \delta_{x,y}$.</li>
<li><strong>Storage</strong>: Requires storing both $\phi$ and $\pi$ vectors.</li>
</ul>
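<p>The scalar-field update can be sketched directly from the two equations above. The XY-model gradient below assumes periodic boundaries for simplicity (the paper's $15^2$ lattice uses helical boundary conditions), and the function names, lattice initialization, and the choice $\beta = 1$ are illustrative assumptions.</p>

```python
import numpy as np


def xy_action_gradient(theta):
    """dS/dtheta for S = -sum_{x,mu} cos(theta_{x+mu} - theta_x).

    Periodic boundaries are an assumption of this sketch.
    """
    grad = np.zeros_like(theta)
    for axis in range(theta.ndim):
        grad += np.sin(theta - np.roll(theta, -1, axis=axis))  # forward neighbor
        grad += np.sin(theta - np.roll(theta, 1, axis=axis))   # backward neighbor
    return grad


def ha_step(phi, pi, grad_S, eps, gamma, beta, rng):
    """One Hyperbolic Algorithm update: momentum first, then field."""
    xi = rng.standard_normal(phi.shape)  # unit-variance Gaussian noise per site
    pi_new = (pi - eps * gamma * pi - eps * grad_S(phi)
              + np.sqrt(2.0 * eps * gamma / beta) * xi)
    phi_new = phi + eps * pi_new  # uses the freshly updated momentum
    return phi_new, pi_new


rng = np.random.default_rng(0)
theta = rng.uniform(0.0, 2.0 * np.pi, size=(30, 30))
pi = np.zeros_like(theta)
for _ in range(100):
    theta, pi = ha_step(theta, pi, xy_action_gradient,
                        eps=0.1, gamma=0.4, beta=1.0, rng=rng)
```

<p>Updating $\phi$ with the new momentum $\pi_{t+\epsilon}$ is the detail that distinguishes this scheme from a plain Euler step on both variables. Setting $\gamma = 0$ removes friction and noise and leaves a deterministic leapfrog-style update.</p>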
<p><strong>2. Non-Abelian Generalization</strong></p>
<p>For Lie group elements $U$ with generators $T^a$:</p>
<p>$$
\begin{aligned}
\pi_{t+\epsilon}^a - \pi_{t}^a &amp;= -\epsilon\gamma\pi_{t}^a - \epsilon\delta^a S[U_t] + \sqrt{2\epsilon\gamma/\beta}\xi_{t}^a \\
U_{t+\epsilon} &amp;= e^{i\epsilon\pi_{t+\epsilon}^a T^a} U_t
\end{aligned}
$$</p>
<h3 id="theoretical-proof-of-oepsilon2-accuracy">Theoretical Proof of $O(\epsilon^2)$ Accuracy</h3>
<p>The derivation relies on the generalized Liouville equation for the probability distribution $P[\phi, \pi; t]$.</p>
<ol>
<li><strong>Transition Probability</strong>: The transition $W$ for one iteration is defined.</li>
<li><strong>Effective Liouville Operator</strong>: The evolution is written as $P(t+\epsilon) = \exp(\epsilon L_{\text{eff}}) P(t)$.</li>
<li><strong>Baker-Hausdorff Expansion</strong>: Using normal ordering of operators, the equilibrium distribution $P_{\text{eq}}$ is derived through $O(\epsilon^2)$:</li>
</ol>
<p>$$
\begin{aligned}
P_{\text{eq}} &amp;= \exp\left\lbrace-\frac{1}{2}\beta_{1}\sum_{x}\pi_{x}^{2} - \beta S[\phi] + \frac{1}{2}\epsilon\beta\sum_{x}\pi_{x}S_{x} + \epsilon^{2}G + O(\epsilon^3)\right\rbrace
\end{aligned}
$$</p>
<p>where $\beta_1 = \beta\left(1 - \frac{1}{2}\epsilon\gamma\right)$, $S_x \equiv \partial S/\partial\phi_x$, and $G$ collects the $O(\epsilon^2)$ corrections.</p>
<ol start="4">
<li><strong>Effective Action</strong>: Integrating out $\pi$ yields the effective action for $\phi$:</li>
</ol>
<p>$$
\begin{aligned}
S_{\text{eff}}[\phi] &amp;= S[\phi] - \frac{1}{8}\epsilon^2 \sum_x S_x^2 + \dots
\end{aligned}
$$</p>
<p>The absence of $O(\epsilon)$ terms proves the higher-order accuracy.</p>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Model</strong>: XY Model (2D)</li>
<li><strong>Hamiltonian</strong>: $H = \frac{1}{2}\sum \pi^2 + S[\phi]$ where $S = -\sum \cos(\Delta \theta)$.</li>
<li><strong>Observables</strong>:
<ul>
<li>$\Gamma_n = \cos(\theta_{m+n} - \theta_m)$ (averaged over lattice $m$).</li>
</ul>
</li>
<li><strong>Comparisons</strong>:
<ul>
<li><strong>LA Step</strong>: $\epsilon_L$ between about $0.005$ and $0.02$.</li>
<li><strong>HA Step</strong>: $\epsilon_H$ between about $0.1$ and $0.2$.</li>
<li><strong>Equivalence</strong>: $\epsilon_H = 0.1$ matches error of $\epsilon_L \approx 0.008$.</li>
</ul>
</li>
</ul>
<hr>
<h2 id="terminology-note">Terminology Note</h2>
<p>The naming conventions in this paper differ from those commonly used in molecular dynamics (MD). The following table provides a cross-field mapping:</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Concept</th>
          <th style="text-align: left"><strong>Field Theory (This Paper)</strong></th>
          <th style="text-align: left"><strong>Molecular Dynamics</strong></th>
          <th style="text-align: left"><strong>Mathematics</strong></th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>Equation 1</strong></td>
          <td style="text-align: left">&ldquo;Langevin Equation&rdquo;</td>
          <td style="text-align: left">Brownian Dynamics (BD)</td>
          <td style="text-align: left">Overdamped Langevin</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Equation 2</strong></td>
          <td style="text-align: left">&ldquo;Hyperbolic Equation&rdquo;</td>
          <td style="text-align: left">Langevin Dynamics (LD)</td>
          <td style="text-align: left">Underdamped Langevin</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Integrator 1</strong></td>
          <td style="text-align: left">Euler Discretization</td>
          <td style="text-align: left">Euler Integrator</td>
          <td style="text-align: left">Euler-Maruyama</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Integrator 2</strong></td>
          <td style="text-align: left">Hyperbolic Algorithm (HA)</td>
          <td style="text-align: left">Velocity Verlet / Leapfrog</td>
          <td style="text-align: left">Quasi-Symplectic Splitting</td>
      </tr>
  </tbody>
</table>
<p><strong>Key insight</strong>: The paper&rsquo;s &ldquo;Hyperbolic Algorithm&rdquo; is mathematically equivalent to Langevin Dynamics with a Leapfrog/Verlet integrator, commonly used in MD. The baseline &ldquo;Langevin Algorithm&rdquo; corresponds to Brownian Dynamics. The term &ldquo;Langevin equation&rdquo; is overloaded: field theorists often use it for overdamped dynamics (no inertia), while chemists assume it includes momentum ($F=ma$).</p>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Horowitz, A. M. (1987). The Second Order Langevin Equation and Numerical Simulations. <em>Nuclear Physics B</em>, 280, 510-522. <a href="https://doi.org/10.1016/0550-3213(87)90159-3">https://doi.org/10.1016/0550-3213(87)90159-3</a></p>
<p><strong>Publication</strong>: Nuclear Physics B 1987</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{horowitzSecondOrderLangevin1987,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{The Second Order {{Langevin}} Equation and Numerical Simulations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Horowitz, Alan M.}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">1987</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = jan,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Nuclear Physics B}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{280}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{510--522}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">issn</span> = <span style="color:#e6db74">{05503213}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1016/0550-3213(87)90159-3}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Mixture Density Networks: Modeling Multimodal Distributions</title><link>https://hunterheidenreich.com/notes/machine-learning/generative-models/mixture-density-networks/</link><pubDate>Sun, 14 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/generative-models/mixture-density-networks/</guid><description>A 1994 technical report introducing Mixture Density Networks (MDNs) to model arbitrary conditional probability distributions using neural networks.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is a <strong>Method</strong> paper.</p>
<p>It identifies a specific failure mode in existing neural network methodologies (least-squares regression on multi-valued inverse problems) and proposes a novel architecture (combining MLPs with Mixture Models) to solve it. It derives the mathematical framework for training this architecture via standard back-propagation and validates it against the established baseline.</p>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>Standard neural networks trained with sum-of-squares (MSE) or cross-entropy error functions approximate the <strong>conditional average</strong> of the target data, $\langle t|x \rangle$.</p>
<p>While optimal for single-valued functions or classification, this produces completely erroneous results for <strong>inverse problems</strong> where the mapping is multi-valued (one input has multiple valid outputs). For example, in robot inverse kinematics, &ldquo;elbow-up&rdquo; and &ldquo;elbow-down&rdquo; configurations can achieve the same hand position. An MSE-trained network will average these two valid angles, resulting in an invalid configuration (the paper shows this produces end-effector positions at the outer boundary of the accessible region, corresponding to $\theta_2 = \pi$).</p>

<figure class="post-figure center ">
    <img src="/img/notes/single-gaussian-mse-prediction.webp"
         alt="Single Gaussian MSE prediction averaging multimodal distribution"
         title="Single Gaussian MSE prediction averaging multimodal distribution"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">MSE-trained networks predict the mean, which averages across modes and produces invalid outputs for inverse problems.</figcaption>
    
</figure>

<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The paper introduces the <strong>Mixture Density Network (MDN)</strong>.</p>
<p>The neural network predicts the <strong>parameters</strong> (mixing coefficients, means, and variances) of a kernel mixture distribution (typically Gaussian).</p>

<figure class="post-figure center ">
    <img src="/img/notes/gaussian-mixture-mdn-prediction.webp"
         alt="Gaussian mixture model prediction capturing multimodal distribution"
         title="Gaussian mixture model prediction capturing multimodal distribution"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">MDNs predict mixture parameters to capture the full conditional probability density, representing all modes.</figcaption>
    
</figure>

<p>Key technical contributions include:</p>
<ol>
<li><strong>Architecture</strong>: Mapping network outputs to mixture parameters using specific activation functions to satisfy constraints (Softmax for priors $\alpha$, Exponential for variances $\sigma$).</li>
<li><strong>Training</strong>: Deriving the error function as the negative log-likelihood of the mixture model.</li>
<li><strong>Optimization</strong>: Deriving the exact derivatives (gradients) of the error with respect to network outputs, allowing the mixture model parameters to be learned via standard back-propagation.</li>
</ol>
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>Bishop validated the method on two tasks, comparing an MDN against a standard MLP trained with least-squares:</p>
<ol>
<li><strong>Toy Inverse Problem</strong>: A sinusoidal mapping $x = t + 0.3\sin(2\pi t) + \epsilon$. The forward problem ($t \to x$) is single-valued, but the inverse ($x \to t$) is multi-valued.</li>
<li><strong>Robot Kinematics</strong>: A 2-link robot arm simulation. The task is to map end-effector Cartesian coordinates $(x_1, x_2)$ back to joint angles $(\theta_1, \theta_2)$.</li>
</ol>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>
<ul>
<li><strong>Toy Problem</strong>: The standard least-squares network failed completely, drawing a smooth curve through the average of the multiple branches, which did not correspond to valid data. The MDN correctly modeled the tri-modal density and discontinuous jumps in the most probable solution.</li>
<li><strong>Robot Kinematics</strong>: The MDN reduced the RMS positioning error by an order of magnitude compared to the standard network (0.0053 vs 0.0578).</li>
<li><strong>Generality</strong>: The paper concludes that MDNs provide a complete description of the conditional probability density, allowing users to calculate any statistic (mean, mode, variance) needed for the application.</li>
</ul>
<h2 id="extracting-predictions">Extracting Predictions</h2>
<p>Once trained, the MDN outputs a full conditional density $p(t|x)$, from which several useful statistics can be derived:</p>
<ul>
<li><strong>Conditional mean</strong>: $\langle t|x \rangle = \sum_i \alpha_i(x) \mu_i(x)$, equivalent to the standard least-squares network output.</li>
<li><strong>Conditional variance</strong>: $s^2(x) = \sum_i \alpha_i(x) \left\{ \sigma_i(x)^2 + \left\| \mu_i(x) - \sum_j \alpha_j(x) \mu_j(x) \right\|^2 \right\}$, which is input-dependent (more general than the constant-variance least-squares assumption).</li>
<li><strong>Most probable branch</strong>: Select the kernel $i$ with the largest mixing coefficient $\alpha_i(x)$, then use its center $\mu_i$ as the prediction. This yields a discontinuous but accurate mapping for multi-valued problems.</li>
</ul>
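<p>The three statistics above can be sketched in a few lines for a scalar target. This is a minimal illustration, not code from the paper: the function name <code>mdn_statistics</code> and the example mixture values are hypothetical, and the inputs are assumed to be the already-mapped mixture parameters at a single input $x$.</p>

```python
def mdn_statistics(alpha, mu, sigma):
    """Summary statistics of a 1-D Gaussian mixture at a single input x.

    alpha, mu, sigma are equal-length lists: mixing coefficients (summing
    to 1), kernel centers, and kernel widths.
    """
    # Conditional mean: sum_i alpha_i * mu_i
    mean = sum(a * m for a, m in zip(alpha, mu))
    # Conditional variance: sum_i alpha_i * (sigma_i^2 + (mu_i - mean)^2)
    variance = sum(a * (s ** 2 + (m - mean) ** 2)
                   for a, m, s in zip(alpha, mu, sigma))
    # Most probable branch: center of the kernel with the largest alpha_i
    best = max(range(len(alpha)), key=lambda i: alpha[i])
    return mean, variance, mu[best]


# Hypothetical bimodal output: two well-separated branches
alpha = [0.6, 0.4]
mu = [0.2, 0.8]
sigma = [0.05, 0.05]
mean, variance, branch = mdn_statistics(alpha, mu, sigma)
print(mean, variance, branch)
```

<p>In this made-up example the conditional mean lands between the two branches, while the most probable branch returns the center of the dominant kernel, which is the behavior that makes it the better choice for multi-valued problems.</p>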
<h2 id="limitations">Limitations</h2>
<ul>
<li><strong>Model order selection</strong>: The number of mixture components $m$ must be chosen in advance. The paper acknowledges this as an open problem and suggests cross-validation or Bayesian model comparison as potential approaches.</li>
<li><strong>Computational overhead</strong>: The number of network outputs grows as $(c + 2) \times m$, where $c$ is the target dimensionality. For high-dimensional targets or many kernels, this can become significant.</li>
<li><strong>Isotropic kernels</strong>: The paper uses a single variance parameter $\sigma_i$ per kernel (shared across target dimensions), which assumes isotropic covariance. The paper notes this can be generalized to full covariance matrices at the cost of additional parameters.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p><strong>1. Toy Inverse Problem</strong></p>
<ul>
<li><strong>Function</strong>: $x = t + 0.3\sin(2\pi t) + \epsilon$</li>
<li><strong>Noise</strong>: $\epsilon \sim U(-0.1, 0.1)$</li>
<li><strong>Sampling</strong>: 1,000 points generated by sampling $t$ at equal intervals in the range $(0, 1)$.</li>
<li><strong>Task</strong>: Inverse mapping (predict $t$ given $x$).</li>
</ul>
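<p>A minimal sketch of how the toy dataset described above could be generated, using only the Python standard library. The function name, the seed, and the choice of interval midpoints for the equally spaced $t$ values are assumptions made for illustration.</p>

```python
import math
import random


def make_toy_inverse_data(n=1000, seed=0):
    """Return (x, t) lists for x = t + 0.3*sin(2*pi*t) + eps."""
    rng = random.Random(seed)
    # Equally spaced t inside the open interval (0, 1); midpoints are an assumption
    t = [(i + 0.5) / n for i in range(n)]
    # Uniform noise on (-0.1, 0.1)
    x = [ti + 0.3 * math.sin(2 * math.pi * ti) + rng.uniform(-0.1, 0.1)
         for ti in t]
    return x, t


x, t = make_toy_inverse_data()

# The noise-free forward map is not monotonic in t, so a single x can
# correspond to several t values: the inverse problem is multi-valued.
clean = [ti + 0.3 * math.sin(2 * math.pi * ti) for ti in t]
decreasing_steps = sum(1 for a, b in zip(clean, clean[1:]) if a > b)
print(decreasing_steps)
```

<p>To train the inverse model, <code>x</code> is used as the input and <code>t</code> as the target.</p>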
<p><strong>2. Robot Kinematics</strong></p>
<ul>
<li><strong>System</strong>: 2-link arm with lengths $L_1=0.8, L_2=0.2$.</li>
<li><strong>Forward Kinematics</strong>:
<ul>
<li>$x_1 = L_1 \cos(\theta_1) - L_2 \cos(\theta_1 + \theta_2)$</li>
<li>$x_2 = L_1 \sin(\theta_1) - L_2 \sin(\theta_1 + \theta_2)$</li>
</ul>
</li>
<li><strong>Constraints</strong>: $\theta_1 \in (0.3, 1.2)$, $\theta_2 \in (\pi/2, 3\pi/2)$.</li>
<li><strong>Dataset</strong>: 1,000 training points, 1,000 test points.</li>
</ul>
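<p>The forward kinematics and angle ranges above translate directly into code. The sketch below assumes the joint angles are drawn uniformly within the stated constraints (the notes do not specify the sampling distribution), and the function names and seeds are hypothetical.</p>

```python
import math
import random

L1, L2 = 0.8, 0.2  # link lengths from the notes


def forward_kinematics(theta1, theta2):
    """Map joint angles to end-effector coordinates (x1, x2)."""
    x1 = L1 * math.cos(theta1) - L2 * math.cos(theta1 + theta2)
    x2 = L1 * math.sin(theta1) - L2 * math.sin(theta1 + theta2)
    return x1, x2


def sample_dataset(n, seed):
    """Draw n (position, angles) pairs; uniform sampling is an assumption."""
    rng = random.Random(seed)
    data = []
    for _ in range(n):
        theta1 = rng.uniform(0.3, 1.2)
        theta2 = rng.uniform(math.pi / 2, 3 * math.pi / 2)
        data.append((forward_kinematics(theta1, theta2), (theta1, theta2)))
    return data


train = sample_dataset(1000, seed=0)
test = sample_dataset(1000, seed=1)

# With theta2 = pi the arm is fully extended, so the hand sits at radius L1 + L2.
x1, x2 = forward_kinematics(0.7, math.pi)
print(math.hypot(x1, x2))
```

<p>The last check illustrates the remark about least-squares averaging: $\theta_2 = \pi$ places the hand on the outer boundary of the reachable region.</p>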
<h3 id="algorithms">Algorithms</h3>
<p><strong>Mixture Model Definition</strong></p>
<p>The conditional density is defined as:</p>
<p>$$p(t|x) = \sum_{i=1}^{m} \alpha_i(x) \phi_i(t|x)$$</p>
<p>where the kernels $\phi_i$ are Gaussians with centers $\mu_i(x)$ and variances $\sigma_i(x)$.</p>
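<p>For reference, a Gaussian kernel of this kind for a $c$-dimensional target, with one shared width per kernel (the isotropic case noted under Limitations), can be written as follows. Here $\sigma_i(x)$ enters as a width, so $\sigma_i(x)^2$ plays the role of the variance:</p>
<p>$$\phi_i(t|x) = \frac{1}{(2\pi)^{c/2} \sigma_i(x)^c} \exp\left\{ -\frac{\left\| t - \mu_i(x) \right\|^2}{2 \sigma_i(x)^2} \right\}$$</p>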
<p><strong>Network Output Mappings</strong></p>
<p>If the network produces raw outputs $z$, they are mapped to parameters as follows to satisfy probability constraints:</p>
<ul>
<li><strong>Mixing Coefficients ($\alpha$)</strong>: Softmax. $\alpha_i = \frac{\exp(z_i^\alpha)}{\sum_j \exp(z_j^\alpha)}$</li>
<li><strong>Variances ($\sigma$)</strong>: Exponential. $\sigma_i = \exp(z_i^\sigma)$</li>
<li><strong>Means ($\mu$)</strong>: Linear/Identity. $\mu_{ik} = z_{ik}^\mu$</li>
</ul>
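<p>The three mappings above can be sketched as a small function that turns a flat vector of raw network outputs into mixture parameters. The output layout (mixing logits first, then log-widths, then means) and the name <code>split_mdn_outputs</code> are assumptions for illustration; the paper does not prescribe an ordering.</p>

```python
import math


def split_mdn_outputs(z, m, c):
    """Map a flat list of (c + 2) * m raw outputs to mixture parameters.

    Assumed layout (hypothetical): m mixing logits, then m log-widths,
    then m * c means stored kernel by kernel.
    """
    if len(z) != (c + 2) * m:
        raise ValueError("expected (c + 2) * m raw outputs")
    z_alpha, z_sigma, z_mu = z[:m], z[m:2 * m], z[2 * m:]
    # Softmax (shifted by the max for stability): positive and sums to 1
    shift = max(z_alpha)
    exps = [math.exp(v - shift) for v in z_alpha]
    total = sum(exps)
    alpha = [e / total for e in exps]
    # Exponential keeps every width strictly positive
    sigma = [math.exp(v) for v in z_sigma]
    # Means are used as-is, grouped into m rows of c values
    mu = [z_mu[i * c:(i + 1) * c] for i in range(m)]
    return alpha, sigma, mu


# Robot-arm sizing from the notes: m = 2 kernels, c = 2 targets, 8 outputs
raw = [0.0, 1.0, -1.0, 0.5, 0.1, 0.2, 0.3, 0.4]
alpha, sigma, mu = split_mdn_outputs(raw, m=2, c=2)
print(alpha, sigma, mu)
```

<p>The length check encodes the $(c + 2) \times m$ output count: 9 outputs for the toy problem ($c=1$, $m=3$) and 8 for the robot arm ($c=2$, $m=2$).</p>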
<p><strong>Loss Function</strong></p>
<p>Negative Log Likelihood:</p>
<p>$$E^q = - \ln \left\{ \sum_{i=1}^{m} \alpha_i(x^q) \phi_i(t^q|x^q) \right\}$$</p>
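<p>A hedged sketch of this loss for a scalar target, written in plain Python so the arithmetic is visible. It also checks numerically that the gradient with respect to the mixing logits equals $\alpha_i - \pi_i$, where $\pi_i = \alpha_i \phi_i / \sum_j \alpha_j \phi_j$ is the posterior responsibility of kernel $i$; this identity follows from differentiating the loss through the softmax. All function names and the example values are hypothetical.</p>

```python
import math


def softmax(z):
    shift = max(z)
    exps = [math.exp(v - shift) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def mdn_nll(z_alpha, z_sigma, mu, t):
    """Negative log-likelihood of scalar t under a 1-D Gaussian mixture."""
    alpha = softmax(z_alpha)
    log_terms = []
    for a, zs, m in zip(alpha, z_sigma, mu):
        sigma = math.exp(zs)
        # log N(t | mu_i, sigma_i^2), using log(sigma_i) = zs
        log_phi = -0.5 * math.log(2 * math.pi) - zs - 0.5 * ((t - m) / sigma) ** 2
        log_terms.append(math.log(a) + log_phi)
    # log-sum-exp avoids underflow when every kernel is far from t
    top = max(log_terms)
    return -(top + math.log(sum(math.exp(v - top) for v in log_terms)))


def posteriors(z_alpha, z_sigma, mu, t):
    """pi_i = alpha_i * phi_i / sum_j alpha_j * phi_j (shared constants cancel)."""
    alpha = softmax(z_alpha)
    joint = [a * math.exp(-0.5 * ((t - m) / math.exp(zs)) ** 2) / math.exp(zs)
             for a, zs, m in zip(alpha, z_sigma, mu)]
    total = sum(joint)
    return [j / total for j in joint]


z_alpha, z_sigma, mu, t = [0.2, -0.4, 1.0], [-1.0, -0.5, -1.5], [0.1, 0.5, 0.9], 0.45
analytic = [a - p for a, p in zip(softmax(z_alpha), posteriors(z_alpha, z_sigma, mu, t))]

# Central finite differences with respect to each mixing logit
eps = 1e-6
numeric = []
for i in range(len(z_alpha)):
    up, down = list(z_alpha), list(z_alpha)
    up[i] += eps
    down[i] -= eps
    numeric.append((mdn_nll(up, z_sigma, mu, t) - mdn_nll(down, z_sigma, mu, t)) / (2 * eps))
print(analytic, numeric)
```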
<h3 id="models">Models</h3>
<p><strong>1. Toy Problem Configuration</strong></p>
<ul>
<li><strong>Structure</strong>: MLP with 1 input ($x$), 1 hidden layer.</li>
<li><strong>Hidden Units</strong>: 20 units (tanh activation).</li>
<li><strong>Outputs</strong>: 9 units.
<ul>
<li>$m=3$ Gaussian kernels.</li>
<li>Parameters per kernel: 1 $\alpha$, 1 $\sigma$, 1 $\mu$. Total = $3 \times 3 = 9$.</li>
</ul>
</li>
<li><strong>Training</strong>: 1,000 cycles of BFGS.</li>
</ul>
<p><strong>2. Robot Kinematics Configuration (Least-Squares Baseline)</strong></p>
<ul>
<li><strong>Structure</strong>: MLP with 2 inputs ($x_1, x_2$), 2 linear outputs ($\theta_1, \theta_2$).</li>
<li><strong>Hidden Units</strong>: Best result with 20 units (tanh activation), tested with 5, 10, 15, 20, 25, 30.</li>
<li><strong>Training</strong>: 3,000 cycles of BFGS.</li>
</ul>
<p><strong>3. Robot Kinematics Configuration (MDN)</strong></p>
<ul>
<li><strong>Structure</strong>: MLP with 2 inputs ($x_1, x_2$).</li>
<li><strong>Hidden Units</strong>: 10 units (tanh activation).</li>
<li><strong>Outputs</strong>: 8 units.
<ul>
<li>$m=2$ Gaussian kernels.</li>
<li>Target dimension $c=2$ (predicting $\theta_1, \theta_2$).</li>
<li>Parameters per kernel: 1 $\alpha$ + 1 $\sigma$ (common variance) + 2 $\mu$ (means for $\theta_1, \theta_2$).</li>
<li>Total = $2 \times (1 + 1 + 2) = 8$.</li>
</ul>
</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p><strong>Metric</strong>: RMS Euclidean distance between the desired end-effector position and the achieved position (calculated by plugging predicted angles back into forward kinematics).</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Hidden Units</th>
          <th>Kernels</th>
          <th>RMS Error</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Least Squares</td>
          <td>20</td>
          <td>N/A</td>
          <td>0.0578</td>
      </tr>
      <tr>
          <td>MDN</td>
          <td>10</td>
          <td>2</td>
          <td>0.0053</td>
      </tr>
  </tbody>
</table>
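<p>The metric above can be sketched as follows: predicted joint angles are pushed back through the forward kinematics and compared with the desired hand positions. The function names and the two example angle pairs are hypothetical; only the link lengths and kinematic equations come from the notes.</p>

```python
import math

L1, L2 = 0.8, 0.2  # link lengths from the notes


def forward_kinematics(theta1, theta2):
    x1 = L1 * math.cos(theta1) - L2 * math.cos(theta1 + theta2)
    x2 = L1 * math.sin(theta1) - L2 * math.sin(theta1 + theta2)
    return x1, x2


def rms_position_error(predicted_angles, target_positions):
    """RMS Euclidean distance between achieved and desired hand positions."""
    squared = []
    for (theta1, theta2), (tx1, tx2) in zip(predicted_angles, target_positions):
        ax1, ax2 = forward_kinematics(theta1, theta2)
        squared.append((ax1 - tx1) ** 2 + (ax2 - tx2) ** 2)
    return math.sqrt(sum(squared) / len(squared))


# Hypothetical check: exact angles give zero error, perturbed angles do not
angles = [(0.5, 2.0), (1.0, 4.0)]
targets = [forward_kinematics(a, b) for a, b in angles]
exact = rms_position_error(angles, targets)
perturbed = rms_position_error([(a + 0.05, b) for a, b in angles], targets)
print(exact, perturbed)
```

<p>Measuring error in position space, as opposed to angle space, means any valid branch of the inverse mapping scores well, which is what makes the comparison fair for a multi-valued problem.</p>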
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Bishop, C. M. (1994). Mixture Density Networks. <em>Neural Computing Research Group Report: NCRG/94/004</em>, Aston University.</p>
<p><strong>Publication</strong>: Neural Computing Research Group Technical Report 1994</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@techreport</span>{bishopMixtureDensityNetworks1994,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Mixture {{Density Networks}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Bishop, Christopher M.}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">1994</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = feb,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{NCRG/94/004}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">institution</span> = <span style="color:#e6db74">{Aston University}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Embedded-Atom Method User Guide: Voter's 1994 Chapter</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-simulation/classical-methods/embedded-atom-method-voter-1994/</link><pubDate>Sun, 14 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-simulation/classical-methods/embedded-atom-method-voter-1994/</guid><description>Comprehensive user guide for the Embedded-Atom Method (EAM), covering theory, potential fitting, and applications to intermetallics.</description><content:encoded><![CDATA[<h2 id="contribution-systematizing-the-embedded-atom-method">Contribution: Systematizing the Embedded-Atom Method</h2>
<p>This is a <strong>Systematization</strong> paper (specifically a handbook chapter) with a strong secondary <strong>Method</strong> projection.</p>
<p>Its primary goal is to serve as a &ldquo;users&rsquo; guide&rdquo; to the Embedded-Atom Method (EAM). The text organizes existing knowledge:</p>
<ul>
<li>It traces the physical origins of EAM from Density Functional Theory (DFT) and Effective Medium Theory.</li>
<li>It synthesizes &ldquo;closely related methods&rdquo; (Second Moment Approximation, Glue Model), showing they are mathematically equivalent or very similar to EAM.</li>
<li>It provides a pedagogical, step-by-step methodology for fitting potentials to experimental data.</li>
</ul>
<h2 id="motivation-bridging-the-gap-between-dft-and-pair-potentials">Motivation: Bridging the Gap Between DFT and Pair Potentials</h2>
<p>The primary motivation is to bridge the gap between accurate, expensive electronic structure calculations and fast, inaccurate pair potentials.</p>
<ul>
<li><strong>Computational Efficiency</strong>: First-principles methods scale as $O(N^3)$ or worse, limiting simulations to $&lt;100$ atoms (in 1994). Pair potentials scale as $O(N)$ but fail to capture the essential many-body physics of metals.</li>
<li><strong>Physical Accuracy</strong>: Simple pair potentials cannot accurately model metallic defects; they predict zero Cauchy pressure ($C_{12} - C_{44} = 0$) and equate vacancy formation energy to cohesive energy, both of which are incorrect for transition metals.</li>
<li><strong>Practical Utility</strong>: There was a need for a clear guide on how to construct and apply these potentials for large-scale simulations ($10^6+$ atoms) of fracture and defects.</li>
</ul>
<h2 id="novelty-a-unified-framework-and-robust-fitting-recipe">Novelty: A Unified Framework and Robust Fitting Recipe</h2>
<p>As a review chapter, the novelty lies in the synthesis and the specific, reproducible recipe for potential construction. Central to this synthesis is the core EAM energy functional:</p>
<p>$$E_{\text{tot}} = \sum_i \left( F(\bar{\rho}_i) + \frac{1}{2} \sum_{j \neq i} \phi(r_{ij}) \right)$$</p>
<p>where the total energy $E_{\text{tot}}$ depends on embedding an atom $i$ into a local background electron density $\bar{\rho}_i = \sum_{j \neq i} \rho(r_{ij})$, plus a repulsive pair interaction $\phi(r_{ij})$.</p>
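<p>The energy functional above can be sketched as a direct double loop over atoms. This is an illustration under stated assumptions, not the chapter&rsquo;s implementation: the function name <code>eam_energy</code> and the three placeholder component functions are hypothetical and were chosen only to make the example runnable.</p>

```python
import math


def eam_energy(positions, embed, density, pair):
    """Total EAM energy: sum_i [ F(rho_bar_i) + 0.5 * sum_{j != i} phi(r_ij) ]."""
    total = 0.0
    for i, p_i in enumerate(positions):
        rho_bar = 0.0
        pair_sum = 0.0
        for j, p_j in enumerate(positions):
            if i == j:
                continue
            r = math.dist(p_i, p_j)
            rho_bar += density(r)  # background density contributed by neighbor j
            pair_sum += pair(r)    # pair interaction with neighbor j
        total += embed(rho_bar) + 0.5 * pair_sum
    return total


# Placeholder component functions (hypothetical, not fitted to anything)
def density(r):
    return math.exp(-2.0 * r)


def pair(r):
    return math.exp(-4.0 * (r - 1.0))


def embed(rho_bar):
    return -math.sqrt(rho_bar)  # square-root form, as in second-moment models


# Dimer check: E = 2 * F(rho(r)) + phi(r)
r = 1.2
dimer = eam_energy([(0.0, 0.0, 0.0), (r, 0.0, 0.0)], embed, density, pair)
print(dimer)
```

<p>This naive loop costs $O(N^2)$. Because the component functions are cut off at a finite range, practical codes typically restrict the inner loop to nearby atoms, which is how the linear scaling discussed in the chapter is reached.</p>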
<ul>
<li><strong>Unified Framework</strong>: It explicitly maps the &ldquo;Second Moment Approximation&rdquo; (Tight Binding) and the &ldquo;Glue Model&rdquo; onto the fundamental EAM framework above, clarifying that they differ primarily in terminology or specific functional choices (e.g., square root embedding functions).</li>
<li><strong>Cross-Potential Fitting Recipe</strong>: It details a robust method for fitting alloy potentials (specifically Ni-Al-B) by using &ldquo;transformation invariance&rdquo;, scaling the density and shifting the embedding function to fit alloy properties without disturbing pure element fits.</li>
<li><strong>Specific Parameters</strong>: It publishes optimized potential parameters for Ni, Al, and B that accurately reproduce properties like the Boron interstitial preference in $\text{Ni}_3\text{Al}$.</li>
</ul>
<h2 id="validation-computational-benchmarks-and-simulations">Validation: Computational Benchmarks and Simulations</h2>
<p>The &ldquo;experiments&rdquo; described are computational validations and simulations using the fitted Ni-Al-B potential:</p>
<ol>
<li>
<p><strong>Potential Fitting</strong>:</p>
<ul>
<li>Pure elements (Ni, Al) were fitted to elastic constants, vacancy formation energies, and diatomic data. The Ni fit achieved $\chi_{\text{rms}} = 0.75\%$ and Al achieved $\chi_{\text{rms}} = 3.85\%$.</li>
<li>Boron was fitted using hypothetical crystal structures (fcc, bcc) calculated via LMTO (Linear Muffin-Tin Orbital) since experimental data for fcc B does not exist.</li>
</ul>
</li>
<li>
<p><strong>Molecular Statics (Validation)</strong>:</p>
<ul>
<li><strong>Surface Relaxation</strong>: Demonstrated that EAM captures the oscillatory relaxation of atomic layers near a free surface, a many-body effect that pair potentials fail to capture.</li>
<li><strong>Defect Energetics</strong>: Calculated formation energies for Boron interstitials in $\text{Ni}_3\text{Al}$. Found the 6Ni-octahedral site is most stable ($-4.59$ eV relative to an isolated B atom and unperturbed crystal), followed by the 4Ni-2Al octahedral site ($-3.65$ eV) and the 3Ni-1Al tetrahedral site ($-2.99$ eV), consistent with channeling experiments.</li>
</ul>
</li>
<li>
<p><strong>Molecular Dynamics (Application)</strong>:</p>
<ul>
<li><strong>Grain Boundary (GB) Cleavage</strong>: Simulated the fracture of a (210) tilt grain boundary in $\text{Ni}_3\text{Al}$ at a strain rate of $5 \times 10^{10}$ s$^{-1}$.</li>
<li><strong>Comparison</strong>: Compared pure $\text{Ni}_3\text{Al}$ boundaries vs. those doped with Boron and substitutional Nickel.</li>
</ul>
</li>
</ol>
<h2 id="key-outcomes-eam-efficiency-and-boron-strengthening">Key Outcomes: EAM Efficiency and Boron Strengthening</h2>
<ul>
<li><strong>EAM Efficiency</strong>: Confirmed that EAM scales linearly with atom count ($N$), requiring only 2-5 times the computational work of pair potentials.</li>
<li><strong>Boron Strengthening Mechanism</strong>: The simulations suggested that Boron segregates to grain boundaries and, specifically when co-segregated with Ni, significantly increases cohesion.
<ul>
<li>The maximum stress for the enriched boundary was approximately 22 GPa, compared to approximately 19 GPa for the clean boundary.</li>
<li>The B-doped boundary required approximately 44% more work to cleave than the undoped boundary.</li>
<li>The fracture mode shifted from cleaving along the GB to failure in the bulk.</li>
</ul>
</li>
<li><strong>Grain Boundary Segregation</strong>: Molecular statics calculations found B interstitial energies at the GB as low as $-6.9$ eV, compared to $-4.59$ eV in the bulk, consistent with experimental observations of boron segregation to grain boundaries.</li>
<li><strong>Limitations</strong>: The author concludes that while EAM is excellent for metals, it lacks the angular dependence required for strongly covalent materials (like $\text{MoSi}_2$) or directional bonding.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<p>The chapter provides nearly all details required to implement the described potential from scratch.</p>
<h3 id="data">Data</h3>
<ul>
<li><strong>Experimental/Reference Data</strong>: Used for fitting the cost function $\chi_{\text{rms}}$.
<ul>
<li><strong>Pure Elements</strong>: Lattice constants ($a_0$), cohesive energy ($E_{\text{coh}}$), bulk modulus ($B$), elastic constants ($C_{11}, C_{12}, C_{44}$), vacancy formation energy ($E_{\text{vac}}^f$), and diatomic bond length/strength ($R_e, D_e$).</li>
<li><strong>Alloys</strong>: Heat of solution and defect energies (APB, SISF) for $\text{Ni}_3\text{Al}$.</li>
<li><strong>Hypothetical Data</strong>: LMTO first-principles data used for unobserved phases (e.g., fcc Boron, B2 NiB) to constrain the fit.</li>
</ul>
</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Component Functions</strong>:
<ul>
<li><strong>Pair Potential $\phi(r)$</strong>: Morse potential form:
$$\phi(r) = D_M \left\{1 - \exp[-\alpha_M(r - R_M)]\right\}^2 - D_M$$</li>
<li><strong>Density Function $\rho(r)$</strong>: Modified hydrogenic 4s orbital:
$$\rho(r) = r^6(e^{-\beta r} + 2^9 e^{-2\beta r})$$</li>
<li><strong>Embedding Function $F(\bar{\rho})$</strong>: Derived numerically to force the crystal energy to match the &ldquo;Universal Energy Relation&rdquo; (Rose et al.) as a function of lattice constant.</li>
</ul>
</li>
<li><strong>Fitting Strategy</strong>:
<ul>
<li><strong>Smooth Cutoff</strong>: A polynomial smoothing function ($h_{\text{smooth}}$) applied at $r_{\text{cut}}$ to ensure continuous derivatives.</li>
<li><strong>Simplex Algorithm</strong>: Used to optimize parameters ($D_M, R_M, \alpha_M, \beta, r_{\text{cut}}$).</li>
<li><strong>Alloy Invariance</strong>: Used transformations $F'(\rho) = F(\rho) + g\rho$ and $\rho'(r) = s\rho(r)$ to fit cross-potentials without altering pure-element properties.</li>
</ul>
</li>
</ul>
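<p>The component functions and the invariance transformations above can be checked numerically on a two-atom system. Several things in this sketch are assumptions: $R_M$ and $\beta$ are placeholders (the notes do not quote them), the square-root embedding stands in for the numerically derived $F$, and the compensating changes ($\phi'(r) = \phi(r) - 2g\rho(r)$ for the shift, and evaluating $F$ at $\bar{\rho}/s$ for the scaling) are the standard EAM pairings, which the notes do not spell out.</p>

```python
import math

# Ni values quoted in the notes
D_M, ALPHA_M = 1.5335, 1.7728
# Placeholders, NOT taken from the chapter: R_M and beta are not quoted in the notes
R_M, BETA = 2.2, 3.6


def morse(r):
    """phi(r) = D_M * (1 - exp(-alpha_M * (r - R_M)))^2 - D_M"""
    return D_M * (1.0 - math.exp(-ALPHA_M * (r - R_M))) ** 2 - D_M


def density(r):
    """rho(r) = r^6 * (exp(-beta * r) + 2^9 * exp(-2 * beta * r))"""
    return r ** 6 * (math.exp(-BETA * r) + 2 ** 9 * math.exp(-2.0 * BETA * r))


def embed(rho_bar):
    """Hypothetical square-root embedding, standing in for the fitted F."""
    return -math.sqrt(rho_bar)


def dimer_energy(r, F, rho, phi):
    """EAM energy of two atoms at separation r: 2 * F(rho(r)) + phi(r)."""
    return 2.0 * F(rho(r)) + phi(r)


g, s = 0.37, 2.5  # arbitrary transformation parameters
reference = dimer_energy(2.5, embed, density, morse)

# F'(rho) = F(rho) + g * rho, compensated by phi'(r) = phi(r) - 2 * g * rho(r)
shifted = dimer_energy(
    2.5,
    lambda x: embed(x) + g * x,
    density,
    lambda r: morse(r) - 2.0 * g * density(r),
)

# rho'(r) = s * rho(r), compensated by evaluating F at rho_bar / s
scaled = dimer_energy(2.5, lambda x: embed(x / s), lambda r: s * density(r), morse)
print(reference, shifted, scaled)
```

<p>Both transformed potentials give the same energy as the reference, which is why $g$ and $s$ are free to be tuned against alloy data without disturbing the pure-element fits.</p>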
<h3 id="models">Models</h3>
<ul>
<li><strong>Parameters</strong>: The text provides the exact optimized parameters for the Ni-Al-B potential in <strong>Table 2</strong> (Pure elements) and <strong>Table 5</strong> (Cross-potentials).
<ul>
<li>Example Ni parameters: $D_M=1.5335$ eV, $\alpha_M=1.7728$ Å$^{-1}$, $r_{\text{cut}}=4.7895$ Å.</li>
</ul>
</li>
</ul>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>1994 Context</strong>: Mentions that simulations of $10^6$ atoms were possible on the &ldquo;fastest computers available&rdquo;.</li>
<li><strong>Scaling</strong>: Explicitly notes computational work scales as $O(N)$, roughly 2-5x slower than pair potentials.</li>
</ul>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Voter, A. F. (1994). Chapter 4: The Embedded-Atom Method. In <em>Intermetallic Compounds: Vol. 1, Principles</em>, edited by J. H. Westbrook and R. L. Fleischer. John Wiley &amp; Sons Ltd.</p>
<p><strong>Publication</strong>: Intermetallic Compounds: Vol. 1, Principles (1994)</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@incollection</span>{voterEmbeddedAtomMethod1994,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{The Embedded-Atom Method}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Voter, Arthur F.}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{Intermetallic Compounds: Vol. 1, Principles}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">editor</span> = <span style="color:#e6db74">{Westbrook, J. H. and Fleischer, R. L.}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{1994}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{John Wiley &amp; Sons Ltd}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{77--90}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">chapter</span> = <span style="color:#e6db74">{4}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://www.ctcms.nist.gov/potentials/">NIST Interatomic Potentials Repository</a> (Modern repository often hosting EAM files)</li>
<li><a href="/notes/chemistry/molecular-simulation/classical-methods/embedded-atom-method/">Original EAM Paper (1984)</a></li>
<li><a href="/notes/chemistry/molecular-simulation/classical-methods/embedded-atom-method-review-1993/">EAM Review (1993)</a></li>
</ul>
]]></content:encoded></item><item><title>Importance Weighted Autoencoders: Beyond the Standard VAE</title><link>https://hunterheidenreich.com/posts/importance-weighted-autoencoders/</link><pubDate>Wed, 05 Nov 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/importance-weighted-autoencoders/</guid><description>The key difference between multi-sample VAEs and IWAEs: how log-of-averages creates a tighter bound on log-likelihood.</description><content:encoded><![CDATA[<p>If you&rsquo;ve worked with Variational Autoencoders (VAEs), you&rsquo;ve almost certainly used the standard $\mathcal{L}_1$ objective, or ELBO. It&rsquo;s trained by taking <em>one</em> sample ($k=1$) from the recognition network to calculate the loss.</p>
<p>A natural question follows: &ldquo;What if I use more samples? Won&rsquo;t that make it better?&rdquo;</p>
<p>Using more samples improves performance when paired with the correct objective function. Averaging the loss over $k$ samples yields minimal gains. Changing the objective itself is where the real gain comes from. This post explores the difference between a &ldquo;multi-sample VAE&rdquo; and the <strong>Importance Weighted Autoencoder (IWAE)</strong>, a model that uses the <em>same architecture</em> as a VAE but is trained with a different objective that optimizes a tighter bound on the log-likelihood.</p>
<p>All ideas here are based on the fantastic paper: <a href="https://arxiv.org/abs/1509.00519">&ldquo;Importance Weighted Autoencoders&rdquo;</a> by Burda, Grosse, and Salakhutdinov.</p>
<h2 id="the-two-ways-to-use-k-samples">The Two Ways to Use $k$ Samples</h2>
<p>Let&rsquo;s say we have our encoder $q(h|x)$ and decoder $p(x,h)$. We decide to use $k=5$ samples instead of $k=1$. We have two main options for how to calculate our loss.</p>
<h3 id="option-1-the-multi-sample-vae-the-naive-way">Option 1: The &ldquo;Multi-Sample VAE&rdquo; (The Naive Way)</h3>
<p>This is the most straightforward idea. For each input $x$ in our batch:</p>
<ol>
<li>Draw 5 samples ($h_1, \ldots, h_5$) from $q(h|x)$.</li>
<li>Calculate the standard VAE $\mathcal{L}_1$ loss for <em>each</em> sample.</li>
<li>Average these 5 losses together.</li>
</ol>
<p>This is an <strong>average of logs</strong>. As the IWAE paper shows experimentally, this approach gives you a more stable gradient, but the final performance (in terms of log-likelihood) is &ldquo;only slightly&rdquo; better. You&rsquo;re paying a 5x computational cost for a marginal gain because you&rsquo;re still optimizing the same &ldquo;loose&rdquo; $\mathcal{L}_1$ bound.</p>
<h3 id="option-2-the-importance-weighted-autoencoder-iwae-the-right-way">Option 2: The Importance Weighted Autoencoder (IWAE) (The Right Way)</h3>
<p>The IWAE takes a different approach. For each input $x$:</p>
<ol>
<li>Draw 5 samples ($h_1, \ldots, h_5$) from $q(h|x)$.</li>
<li>Calculate an &ldquo;importance weight&rdquo; $w_i$ for each sample.</li>
<li>Average these 5 <em>weights</em> together.</li>
<li>Take the <em>logarithm</em> of that average.</li>
</ol>
<p>This is a <strong>log of an average</strong>, and the difference matters.</p>
<h2 id="the-math-average-of-logs-vs-log-of-averages">The Math: Average-of-Logs vs. Log-of-Averages</h2>
<p>Let&rsquo;s make this concrete. The standard VAE $\mathcal{L}_1$ objective is:</p>
<p>$$
\mathcal{L}_1(x) = \mathbb{E} _{h\sim q(h|x)} \left[ \log \frac{p(x,h)}{q(h|x)} \right]
$$</p>
<p>A <strong>multi-sample VAE</strong> simply gets a better estimate of this same value:</p>
<p>$$
\mathcal{L} _{\text{VAE}, k}(x) \approx  \frac{1}{k} \sum _{i=1}^{k} \log w_i \quad \text{where} \quad w_i = \frac{p(x,h_i)}{q(h_i|x)}
$$</p>
<p>The <strong>IWAE</strong> objective, $\mathcal{L}_k$, is fundamentally different:</p>
<p>$$
\mathcal{L} _k (x) = \mathbb{E} _{h_1..h_k \sim q(h|x)} \left[ \log \left( \frac{1}{k} \sum _{i=1}^{k} \frac{p(x,h_i)}{q(h_i|x)} \right) \right]
$$</p>
<p>In practice, we estimate this with a single Monte Carlo sample (of $k$ latents):</p>
<p>$$
\mathcal{L} _k (x) \approx \log \left( \frac{1}{k} \sum _{i=1}^{k} w_i \right)
$$</p>
<p>Because the logarithm is a concave function, Jensen&rsquo;s Inequality tells us that the &ldquo;log of an average&rdquo; is <em>always</em> greater than or equal to the &ldquo;average of logs.&rdquo;</p>
<p>$$
\mathcal{L}_k(x) \ge \mathcal{L}_1(x)
$$</p>
<p>This means the IWAE optimizes a lower bound on the true log-likelihood that is <strong>at least as tight</strong> as the VAE&rsquo;s, and the paper shows that the bound never gets looser as $k$ grows.</p>
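<p>The two estimators can be compared numerically with nothing but the standard library. In this sketch the log importance weights are made-up numbers standing in for $\log p(x,h_i) - \log q(h_i|x)$, and the function names are hypothetical.</p>

```python
import math
import random


def average_of_logs(log_w):
    """Multi-sample VAE estimate: (1/k) * sum_i log w_i."""
    return sum(log_w) / len(log_w)


def log_of_average(log_w):
    """IWAE estimate: log((1/k) * sum_i w_i), computed with log-sum-exp."""
    top = max(log_w)
    return top + math.log(sum(math.exp(v - top) for v in log_w)) - math.log(len(log_w))


# Hypothetical log importance weights for k = 5 samples
rng = random.Random(0)
log_w = [rng.gauss(-90.0, 3.0) for _ in range(5)]

vae_estimate = average_of_logs(log_w)
iwae_estimate = log_of_average(log_w)
print(vae_estimate, iwae_estimate)
```

<p>Working in log space matters here: the weights themselves are tiny, so subtracting the maximum before exponentiating keeps the sum well-conditioned. The two estimates coincide only when all $k$ weights are equal.</p>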
<h2 id="why-does-this-log-of-average-matter">Why Does This &ldquo;Log-of-Average&rdquo; Matter?</h2>
<p>This mathematical property provides two practical benefits.</p>
<h3 id="1-better-density-estimation">1. Better Density Estimation</h3>
<p>Because $\mathcal{L}_k$ is a tighter bound on the true $\log p(x)$, optimizing it pushes the model toward a better generative distribution. The paper shows that IWAEs achieve &ldquo;significantly higher log-likelihoods&rdquo; than VAEs.</p>
<h3 id="2-richer-latent-representations">2. Richer Latent Representations</h3>
<p>This is the most interesting part. The standard VAE $\mathcal{L}_1$ objective &ldquo;harshly penalizes&rdquo; the model if its <em>one</em> sample $h$ is a poor explanation for $x$. This pressure forces the recognition network $q(h|x)$ to be &ldquo;overly simplified&rdquo; to avoid bad samples, which can lead to many latent dimensions becoming inactive (the paper reports the number of &ldquo;active units&rdquo; per model).</p>
<p>The IWAE objective is more flexible. It only needs <em>one</em> of the $k$ samples to be good. This &ldquo;increased flexibility&rdquo; allows the model to learn far more complex posterior distributions and &ldquo;richer latent space representations.&rdquo; The paper&rsquo;s experiments confirm this, showing IWAEs learn to use many more &ldquo;active units&rdquo; in their latent space.</p>
<h2 id="what-this-looks-like-in-code-pytorch">What This Looks Like in Code (PyTorch)</h2>
<p>The implementation difference makes this crystal clear.</p>
<p>First, the &ldquo;k-sample&rdquo; trick: for a batch <code>x</code> of shape <code>[B, D]</code> and <code>k=5</code> samples, we repeat <code>x</code> to get <code>x_repeated</code> of shape <code>[B*k, D]</code>. We do all our forward passes on this large tensor.</p>
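<p>One way to build the repeated batch is sketched below, assuming PyTorch is available. The tensor sizes are arbitrary, and the choice of <code>repeat_interleave</code> (as opposed to <code>repeat</code>) is an assumption about layout: it keeps the $k$ copies of each input contiguous, so per-sample quantities of shape <code>[B*k]</code> can later be regrouped with <code>view(B, k)</code>.</p>

```python
import torch

B, D, k = 4, 784, 5
x = torch.rand(B, D)

# Each row is repeated k times in a row: [x0, x0, ..., x1, x1, ...]
x_repeated = x.repeat_interleave(k, dim=0)  # shape [B*k, D]

# Regrouping recovers one block of k identical copies per original input
grouped = x_repeated.view(B, k, D)
print(x_repeated.shape, grouped.shape)

# By contrast, x.repeat(k, 1) tiles the whole batch: [x0, x1, ..., x0, x1, ...]
x_tiled = x.repeat(k, 1)  # also [B*k, D], but it regroups as view(k, B, D)
```

<p>Either layout works as long as the later reshape matches it; mixing them silently pairs samples with the wrong inputs.</p>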
<h3 id="vae-multi-sample-k--1-loss">VAE (Multi-Sample, k &gt; 1) Loss</h3>
<p>Here, we can still use the analytical KL divergence, which is a big simplification.</p>
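<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> torch
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> torch.nn.functional <span style="color:#66d9ef">as</span> F
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># x_repeated has shape [B*k, 784]</span>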
</span></span><span style="display:flex;"><span><span style="color:#75715e"># mu, logvar have shape [B*k, latent_dim]</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># recon_x has shape [B*k, 784]</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># recon_loss_all shape: [B*k]</span>
</span></span><span style="display:flex;"><span>recon_loss_all <span style="color:#f92672">=</span> F<span style="color:#f92672">.</span>binary_cross_entropy(recon_x, x_repeated, reduction<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;none&#39;</span>)<span style="color:#f92672">.</span>sum(dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># kl_loss_all shape: [B*k]</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># We use the simple, analytical KL term!</span>
</span></span><span style="display:flex;"><span>kl_loss_all <span style="color:#f92672">=</span> <span style="color:#f92672">-</span><span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>sum(<span style="color:#ae81ff">1</span> <span style="color:#f92672">+</span> logvar <span style="color:#f92672">-</span> mu<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">-</span> logvar<span style="color:#f92672">.</span>exp(), dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># total_loss_all shape: [B*k]</span>
</span></span><span style="display:flex;"><span>total_loss_all <span style="color:#f92672">=</span> recon_loss_all <span style="color:#f92672">+</span> kl_loss_all
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># --- The Key Step ---</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Just average all B*k losses. This is the &#34;average of logs&#34;.</span>
</span></span><span style="display:flex;"><span>loss <span style="color:#f92672">=</span> total_loss_all<span style="color:#f92672">.</span>mean()
</span></span></code></pre></div><h3 id="iwae-k--1-loss">IWAE (k &gt; 1) Loss</h3>
<p>Here, we must compute the exact log-probabilities of the <em>specific samples</em> we drew.</p>
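<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> math
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> torch
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Helper function to compute log-prob of a sample from a Gaussian</span>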
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">log_prob_gaussian</span>(sample, mu, logvar):
</span></span><span style="display:flex;"><span>    const <span style="color:#f92672">=</span> <span style="color:#f92672">-</span><span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> sample<span style="color:#f92672">.</span>shape[<span style="color:#f92672">-</span><span style="color:#ae81ff">1</span>] <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>log(<span style="color:#ae81ff">2</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>tensor(math<span style="color:#f92672">.</span>pi))
</span></span><span style="display:flex;"><span>    log_det <span style="color:#f92672">=</span> <span style="color:#f92672">-</span><span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>sum(logvar, dim<span style="color:#f92672">=-</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>    log_exp <span style="color:#f92672">=</span> <span style="color:#f92672">-</span><span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>sum((sample <span style="color:#f92672">-</span> mu)<span style="color:#f92672">**</span><span style="color:#ae81ff">2</span> <span style="color:#f92672">/</span> torch<span style="color:#f92672">.</span>exp(logvar), dim<span style="color:#f92672">=-</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> const <span style="color:#f92672">+</span> log_det <span style="color:#f92672">+</span> log_exp
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># --- Get the 3 log-prob components ---</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># x_repeated, recon_x, z_samples, mu_repeated, logvar_repeated</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># all have a first dimension of [B*k]</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># 1. log p(x|h_i): Log-Reconstruction Probability</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># log_p_x_given_h shape: [B*k]</span>
</span></span><span style="display:flex;"><span>log_p_x_given_h <span style="color:#f92672">=</span> <span style="color:#f92672">-</span>F<span style="color:#f92672">.</span>binary_cross_entropy(recon_x, x_repeated, reduction<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;none&#39;</span>)<span style="color:#f92672">.</span>sum(dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># 2. log p(h_i): Log-Prior Probability (under N(0, I))</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># log_p_h shape: [B*k]</span>
</span></span><span style="display:flex;"><span>log_p_h <span style="color:#f92672">=</span> log_prob_gaussian(z_samples, torch<span style="color:#f92672">.</span>zeros_like(z_samples), torch<span style="color:#f92672">.</span>zeros_like(z_samples)) <span style="color:#75715e"># mu=0, logvar=0 (tensors, since torch.sum/torch.exp reject Python floats)</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># 3. log q(h_i|x): Log-Encoder Probability</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># log_q_h_given_x shape: [B*k]</span>
</span></span><span style="display:flex;"><span>log_q_h_given_x <span style="color:#f92672">=</span> log_prob_gaussian(z_samples, mu_repeated, logvar_repeated)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># --- The Key Step ---</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Combine to get the log-importance-weight</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># log_w shape: [B*k]</span>
</span></span><span style="display:flex;"><span>log_w <span style="color:#f92672">=</span> log_p_x_given_h <span style="color:#f92672">+</span> log_p_h <span style="color:#f92672">-</span> log_q_h_given_x
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Reshape to [B, k] to group samples by their original input</span>
</span></span><span style="display:flex;"><span>log_w_matrix <span style="color:#f92672">=</span> log_w<span style="color:#f92672">.</span>view(B, k) <span style="color:#75715e"># B is original batch size</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># --- Apply the IWAE Objective (Log-Sum-Exp Trick) ---</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># This is the &#34;log of the average&#34;</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># log( (1/k) * sum(exp(log_w)) ) = logsumexp(log_w) - log(k)</span>
</span></span><span style="display:flex;"><span>log_iwae_bound_per_x <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>logsumexp(log_w_matrix, dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>) <span style="color:#f92672">-</span> math<span style="color:#f92672">.</span>log(k)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># The objective is to MAXIMIZE this bound, so the loss is its negative</span>
</span></span><span style="display:flex;"><span>loss <span style="color:#f92672">=</span> <span style="color:#f92672">-</span>log_iwae_bound_per_x<span style="color:#f92672">.</span>mean()
</span></span></code></pre></div><h3 id="the-critical-implementation-detail">The Critical Implementation Detail</h3>
<p>Notice the key difference in the final step:</p>
<ul>
<li><strong>VAE</strong>: <code>loss = total_loss_all.mean()</code> (average of individual losses)</li>
<li><strong>IWAE</strong>: <code>loss = -(torch.logsumexp(log_w_matrix, dim=1) - math.log(k)).mean()</code> (log of averaged weights)</li>
</ul>
<p>This seemingly small change implements the fundamental mathematical difference between optimizing an &ldquo;average of logs&rdquo; versus a &ldquo;log of averages.&rdquo;</p>
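<p>To make the two reductions concrete, here is a minimal pure-Python sketch (no PyTorch) of the log-mean-exp computation that <code>torch.logsumexp(...) - math.log(k)</code> performs. The helper name <code>log_mean_exp</code> and the log-weight values are hypothetical, chosen so that a naive <code>exp</code> would underflow to zero.</p>

```python
import math

def log_mean_exp(log_w):
    """Stable log((1/k) * sum(exp(log_w))) for a list of k log-weights."""
    m = max(log_w)  # subtract the max so the largest exponent is exp(0)
    total = sum(math.exp(v - m) for v in log_w)
    return m + math.log(total) - math.log(len(log_w))

log_w = [-1000.0, -1001.0, -1002.0]  # hypothetical log importance weights
iwae_term = log_mean_exp(log_w)      # "log of the average"
vae_term = sum(log_w) / len(log_w)   # "average of logs"
```

<p>The <code>- log(k)</code> term is a constant, so dropping it shifts the reported loss without changing the gradients. Keeping it makes the value comparable across different $k$.</p>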
<h2 id="when-to-use-each-approach">When to Use Each Approach</h2>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>When to Use</th>
          <th>Key Benefit</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>VAE ($k=1$)</strong></td>
          <td>Your <strong>default baseline</strong>. It&rsquo;s fast, simple, and often &ldquo;good enough&rdquo; for many tasks.</td>
          <td>Speed and simplicity.</td>
      </tr>
      <tr>
          <td><strong>Multi-Sample VAE ($k&gt;1$)</strong></td>
          <td>When you want slightly more stable gradients but aren&rsquo;t ready for the full IWAE complexity.</td>
          <td>Marginal improvement with minimal code changes.</td>
      </tr>
      <tr>
          <td><strong>IWAE ($k&gt;1$)</strong></td>
          <td>When your baseline VAE is <strong>insufficient</strong>. Specifically, if you need:<br>1. The best possible log-likelihood.<br>2. To activate more latent dimensions or learn richer representations.</td>
          <td>Better performance and richer latents, at the cost of compute (scales linearly with $k$).</td>
      </tr>
  </tbody>
</table>
<h2 id="the-computational-trade-off">The Computational Trade-off</h2>
<p>Both approaches scale linearly with $k$. If you use $k=5$ samples, you&rsquo;re doing roughly 5x the computation. The question is whether you get 5x the benefit.</p>
<p>For multi-sample VAEs, the answer is usually &ldquo;no&rdquo;. You get more stable gradients but only marginal performance improvements.</p>
<p>For IWAEs, the answer is often &ldquo;yes&rdquo;. You get meaningfully better log-likelihoods and richer latent representations that can be worth the computational cost.</p>
<h2 id="conclusion">Conclusion</h2>
<p>If you are already paying for $k &gt; 1$ samples in your VAE, consider switching to the IWAE objective so that the extra compute tightens the bound instead of only reducing gradient variance.</p>
<p>The mathematical insight is simple but powerful: Jensen&rsquo;s Inequality tells us that the &ldquo;log of an average&rdquo; is always greater than or equal to the &ldquo;average of logs.&rdquo; By optimizing this tighter bound, IWAEs achieve better density estimation and learn richer latent representations than standard VAEs.</p>
<p>The implementation requires evaluating the log-densities of the specific samples you drew, in place of the analytic KL term. The result is a tighter training objective, and often a better model, using the exact same architecture.</p>
<p><strong>Want to dive deeper?</strong> Check out the <a href="https://arxiv.org/abs/1509.00519">original IWAE paper</a> for experimental results and theoretical analysis, or explore my <a href="/posts/modern-variational-autoencoder-in-pytorch/">VAE tutorial</a> for hands-on implementation details.</p>
]]></content:encoded></item><item><title>Importance Weighted Autoencoders (IWAE) for Tighter Bounds</title><link>https://hunterheidenreich.com/notes/machine-learning/generative-models/importance-weighted-autoencoders/</link><pubDate>Wed, 05 Nov 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/generative-models/importance-weighted-autoencoders/</guid><description>Summary of Burda, Grosse &amp; Salakhutdinov's ICLR 2016 paper introducing Importance Weighted Autoencoders for tighter variational bounds</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is a <strong>Method</strong> paper that introduces the <strong>Importance Weighted Autoencoder (IWAE)</strong>, a generative model that shares the same architecture as the Variational Autoencoder (VAE) but uses a different, tighter objective function. The key innovation is using importance weighting to derive a strictly tighter log-likelihood lower bound than the standard VAE objective.</p>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>The standard VAE has several limitations that motivated this work:</p>
<ul>
<li><strong>Strong assumptions</strong>: VAEs typically assume the posterior distribution is simple (e.g., approximately factorial) and that its parameters can be easily approximated from observations.</li>
<li><strong>Simplified representations</strong>: The VAE objective can force models to learn overly simplified representations that underutilize the network&rsquo;s full modeling capacity.</li>
<li><strong>Harsh penalization</strong>: The VAE objective harshly penalizes approximate posterior samples that are poor explanations for the data, which can be overly restrictive.</li>
<li><strong>Inactive units</strong>: VAEs tend to learn latent spaces with effective dimensions far below their capacity, where many latent units are ignored (a phenomenon later termed <strong>posterior collapse</strong>, where the approximate posterior collapses to the prior and conveys no information). The authors wanted to investigate whether a new objective could address this issue.</li>
</ul>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The core novelty is the <strong>IWAE objective function</strong>, denoted as $\mathcal{L}_{k}$.</p>
<ul>
<li>
<p><strong>VAE ($\mathcal{L}_{1}$ Bound)</strong>: The standard VAE maximizes $\mathcal{L}(x)=\mathbb{E}_{q(h|x)}[\log\frac{p(x,h)}{q(h|x)}]$. This is equivalent to the new bound when $k=1$.</p>
</li>
<li>
<p><strong>IWAE ($\mathcal{L}_{k}$ Bound)</strong>: The IWAE maximizes a tighter bound that uses $k$ samples drawn from the recognition model $q(h|x)$:</p>
</li>
</ul>
<p>$$\mathcal{L}_{k}(x)=\mathbb{E}_{h_{1},\ldots,h_{k}\sim q(h|x)}\left[\log\frac{1}{k}\sum_{i=1}^{k}\frac{p(x,h_{i})}{q(h_{i}|x)}\right]$$</p>
<ul>
<li>
<p><strong>Tighter Bound</strong>: The authors prove that the bound never loosens as $k$ grows ($\log p(x) \geq \mathcal{L}_{k+1} \geq \mathcal{L}_{k}$), so it is always at least as tight as the VAE bound $\mathcal{L}_{1}$. Moreover, if $p(x,h)/q(h|x)$ is bounded, $\mathcal{L}_{k}$ approaches the true log-likelihood $\log p(x)$ as $k$ approaches infinity.</p>
</li>
<li>
<p><strong>Increased Flexibility</strong>: Using multiple samples gives the IWAE additional flexibility to learn generative models whose posterior distributions are complex and violate the VAE&rsquo;s simplifying assumptions.</p>
</li>
</ul>
<h3 id="key-concept-averaging-inside-vs-outside-the-log">Key Concept: Averaging Inside vs. Outside the Log</h3>
<p>A crucial distinction exists between how VAE and IWAE use $k$ samples. It explains why increasing $k$ tightens the bound in IWAE, whereas in a VAE it only reduces the variance of the estimate.</p>















<figure class="post-figure center ">
    <img src="/img/notes/variational-autoencoder-vae-vs-importance-weighted-autoencoder-iwae.webp"
         alt="Flowchart comparing VAE and IWAE computation: VAE takes the log of each weight then averages (average of logs). IWAE averages the weights first then takes the log (log of average)"
         title="Flowchart comparing VAE and IWAE computation: VAE takes the log of each weight then averages (average of logs). IWAE averages the weights first then takes the log (log of average)"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">VAE vs IWAE computation flow. The key difference is where the log operation occurs: VAE computes log(w_i) for each sample then averages. IWAE averages the weights first then applies log to the result.</figcaption>
    
</figure>

<p><strong>VAE (Average of Logs):</strong></p>
<p>For a VAE, the $k$-sample average is a Monte Carlo estimate of the ELBO:</p>
<p>$$\frac{1}{k} \sum_{i=1}^k \log w_i \approx \mathbb{E}_{h \sim q(h|x)}\left[ \log \frac{p(x,h)}{q(h|x)} \right] = \text{ELBO}$$</p>
<p>where $w_i = p(x, h_i) / q(h_i | x)$. Increasing $k$ here only reduces the variance of the gradient estimator; the model still targets the same ELBO bound, so performance gains saturate quickly.</p>
<p><strong>IWAE (Log of Average):</strong></p>
<p>IWAE performs the averaging <em>inside</em> the logarithm:</p>
<p>$$\mathbb{E}\left[ \log \left( \frac{1}{k} \sum_{i=1}^k w_i \right) \right] = \mathcal{L}_k$$</p>
<p>By Jensen&rsquo;s Inequality ($\log(\mathbb{E}[X]) \geq \mathbb{E}[\log(X)]$, since $\log$ is concave), this bound is guaranteed to be at least as tight as the VAE bound. Each increase in $k$ defines a new lower bound on the log-likelihood that is at least as tight as the previous one ($\mathcal{L}_{k+1} \geq \mathcal{L}_k$).</p>
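<p>The monotonicity in $k$ can be checked numerically on a fixed set of weights. The sketch below is an illustration, not code from the paper: it averages $\log(\text{mean})$ over all size-$m$ subsets of a hypothetical weight list, in the spirit of the subset-averaging argument behind $\mathcal{L}_{k} \geq \mathcal{L}_{m}$ for $k \geq m$. The function name <code>avg_log_mean</code> and the weight values are made up for the example.</p>

```python
import math
from itertools import combinations

def avg_log_mean(weights, m):
    """Average of log(mean(subset)) over all size-m subsets of weights."""
    subsets = list(combinations(weights, m))
    return sum(math.log(sum(s) / m) for s in subsets) / len(subsets)

w = [0.02, 0.5, 1.3, 4.0]  # hypothetical importance weights w_i
bounds = [avg_log_mean(w, m) for m in range(1, len(w) + 1)]
# bounds[0] is the "average of logs"; bounds[-1] is the "log of the average".
```

<p>With $m=1$ this reduces to the average of logs (the VAE-style estimate), and with $m$ equal to the full list length it is the log of the average. The values in between never decrease as $m$ grows.</p>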
<p><strong>Why This Matters for Gradients:</strong></p>
<p>In IWAE, the gradient weights are normalized importance weights $\tilde{w}_i = w_i / \sum_j w_j$. This means &ldquo;bad&rdquo; samples (those with low $w_i$) contribute very little to the gradient update since they vanish from the weighted sum. VAE uses unweighted samples, so a single sample with extremely low probability produces a massive negative log value that can dominate the loss and harshly penalize the model. IWAE&rsquo;s formulation allows the model to focus learning on the samples that explain the data well.</p>
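<p>A small sketch of the normalized weights $\tilde{w}_i$, computed as a softmax over log-weights for numerical stability. The log-weight values are hypothetical. The third sample stands in for a &ldquo;bad&rdquo; sample that explains the data poorly.</p>

```python
import math

def normalized_weights(log_w):
    """Softmax over log-weights: w_i / sum_j w_j, computed stably."""
    m = max(log_w)
    exps = [math.exp(v - m) for v in log_w]
    total = sum(exps)
    return [e / total for e in exps]

log_w = [-90.0, -92.0, -140.0]  # hypothetical; the last sample fits x poorly
w_tilde = normalized_weights(log_w)     # IWAE gradient weights
mean_log_w = sum(log_w) / len(log_w)    # VAE-style average of logs
```

<p>Under these assumptions the bad sample receives a negligible normalized weight, while in the average of logs its value of $-140$ pulls the estimate well below the other two samples.</p>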
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The authors compared VAE and IWAE on density estimation tasks using the MNIST and Omniglot datasets. They evaluated two main network architectures: one with a single stochastic layer and another with two stochastic layers. The models were trained with varying numbers of importance samples ($k \in \{1, 5, 50\}$) to observe the effect on performance and latent space utilization. The primary metrics for evaluation were the test log-likelihood (estimated using 5000 samples) and the number of &ldquo;active&rdquo; latent units, which quantifies the richness of the learned representations.</p>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>















<figure class="post-figure center ">
    <img src="/img/notes/iwae-vs-vae-active-latent-units-comparison.webp"
         alt="Bar chart comparing active latent units between VAE and IWAE across different k values on MNIST and Omniglot datasets"
         title="Bar chart comparing active latent units between VAE and IWAE across different k values on MNIST and Omniglot datasets"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Active latent units for VAE vs IWAE (1 stochastic layer). VAE active units remain flat. IWAE increases with k. Data from Table 1 of Burda et al. (2016).</figcaption>
    
</figure>

<ul>
<li>
<p><strong>Better Performance</strong>: IWAE achieved higher log-likelihoods than VAEs across all configurations. On MNIST with two stochastic layers and $k=50$, IWAE reached $-82.90$ nats compared to $-84.78$ for VAE. On Omniglot, the best IWAE achieved $-103.38$ nats versus $-106.30$ for VAE. IWAE performance improved consistently with increasing $k$, while VAE performance benefited only slightly from using more samples ($k&gt;1$).</p>
</li>
<li>
<p><strong>Richer Representations</strong>: In all experiments with $k&gt;1$, IWAE learned more active latent dimensions than VAE, suggesting richer latent representations.</p>
</li>
<li>
<p><strong>Objective Drives Representation</strong>: The authors found that latent dimension inactivation is driven by the objective function. They demonstrated this through an &ldquo;objective swap&rdquo; experiment:</p>
</li>
</ul>















<figure class="post-figure center ">
    <img src="/img/notes/iwae-objective-swap-experiment.webp"
         alt="Bar charts showing the objective swap experiment results with active units and NLL changes when switching between VAE and IWAE objectives"
         title="Bar charts showing the objective swap experiment results with active units and NLL changes when switching between VAE and IWAE objectives"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Objective swap experiment on MNIST (1 stochastic layer). Switching a trained VAE to the IWAE objective improves both metrics. Switching IWAE to VAE degrades them. Data from Table 2 of Burda et al. (2016).</figcaption>
    
</figure>

<p>This experiment provides evidence that the objective function itself influences latent utilization:</p>
<ul>
<li><strong>VAE → IWAE</strong>: A converged VAE model, when fine-tuned with the IWAE objective ($k=50$), gained 3 active units (19 → 22) and improved test NLL from 86.76 to 84.88.</li>
<li><strong>IWAE → VAE</strong>: A converged IWAE model fine-tuned with the VAE objective lost 2 active units (25 → 23) and worsened test NLL from 84.78 to 86.02.</li>
</ul>
<p>These results suggest that inactivation of latent dimensions is driven largely by the objective function, not only by optimization dynamics, initialization, or architecture. The authors note that optimization also plays a role, as the swap results do not exactly match training from scratch.</p>
<ul>
<li>
<p><strong>Comparison to Other Models</strong>: On MNIST, the best IWAE ($-82.90$ nats) outperformed deep belief networks ($-84.55$ nats) and deep autoregressive networks ($-84.13$ nats), though DRAW ($-80.97$ nats), which exploits spatial structure, achieved better results. On Omniglot, the best IWAE ($-103.38$ nats) fell slightly behind RBMs trained with persistent contrastive divergence ($-100.46$ nats).</p>
</li>
<li>
<p><strong>Conclusion</strong>: IWAEs learn richer latent representations and achieve better generative performance than VAEs with equivalent architectures and training time.</p>
</li>
</ul>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<ul>
<li><strong>MNIST</strong>: $28 \times 28$ binarized handwritten digits (60,000 training / 10,000 test).</li>
<li><strong>Omniglot</strong>: $28 \times 28$ binarized handwritten characters from various alphabets (24,345 training / 8,070 test).</li>
<li><strong>Binarization</strong>: Dynamic sampling where binary values are sampled with expectations equal to the real pixel intensities (following Salakhutdinov &amp; Murray, 2008).</li>
<li><strong>Fixed Binarization</strong>: Results on a fixed binarization of MNIST (Larochelle, 2011) confirm that IWAE outperforms VAE across preprocessing methods, although models trained on the fixed binarization overfit notably more than those trained with dynamic sampling.</li>
</ul>
<h3 id="models">Models</h3>
<p>Two main network architectures were tested:</p>
<ol>
<li>One stochastic layer (50 units) with two deterministic layers (200 units each).</li>
<li>Two stochastic layers (100 and 50 units). Between x and h1 were two deterministic layers with 200 units each. Between h1 and h2 were two deterministic layers with 100 units each.</li>
</ol>
<ul>
<li><strong>Activations</strong>: <code>tanh</code> for deterministic layers; <code>exp</code> applied to variance predictions to ensure positivity.</li>
<li><strong>Distributions</strong>: Gaussian latent layers with diagonal covariance; Bernoulli observation layer.</li>
<li><strong>Initialization</strong>: Glorot &amp; Bengio (2010) heuristic.</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Optimizer</strong>: Adam ($\beta_1=0.9$, $\beta_2=0.999$, $\epsilon=10^{-4}$).</li>
<li><strong>Batch Size</strong>: 20.</li>
<li><strong>Learning Rate Schedule</strong>: Annealed rate of $0.001 \cdot 10^{-i/7}$ for $3^i$ epochs (where $i = 0, \ldots, 7$), totaling 3,280 passes over the data.</li>
<li><strong>Variance Control</strong>: A common concern with importance sampling is high variance. The authors prove that the Mean Absolute Deviation of their estimator is bounded by $2 + 2\delta$, where $\delta = \log p(x) - \mathcal{L}_k$ is the gap between the bound and the true log-likelihood. The tighter the bound, the smaller this guaranteed deviation.</li>
<li><strong>Computational trick</strong>: In the basic IWAE implementation, both forward and backward passes must be done independently for each of the $k$ samples, so the cost scales linearly with $k$. However, the authors describe an optional optimization: stochastically approximate the gradient sum by sampling a single $\epsilon_i$ proportional to its normalized weight $\tilde{w}_i$, then computing only that one backward pass. This reduces the cost to $k$ forward passes and one backward pass. Since the backward pass costs roughly twice the forward pass, this yields approximately a 3x speedup for large $k$ at the cost of increased gradient variance.</li>
</ul>
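<p>The schedule and the cost argument above are easy to sanity-check with a few lines of arithmetic. This is a sketch under the stated numbers only. The function names <code>lr_schedule</code> and <code>speedup</code> are hypothetical, and the cost model assumes a backward pass costs exactly twice a forward pass.</p>

```python
def lr_schedule(base_lr=0.001, stages=8):
    """(learning_rate, epochs) per stage: base_lr * 10**(-i/7) for 3**i epochs."""
    return [(base_lr * 10 ** (-i / 7), 3 ** i) for i in range(stages)]

schedule = lr_schedule()
total_epochs = sum(epochs for _, epochs in schedule)

def speedup(k, forward=1.0, backward=2.0):
    """Cost of k forward+backward passes divided by k forward + 1 backward."""
    return k * (forward + backward) / (k * forward + backward)

ratio_k50 = speedup(50)  # approaches 3 as k grows
```

<p>The eight stages sum to 3,280 epochs, and the final stage runs at one tenth of the initial learning rate. The speedup ratio starts at 1 for $k=1$ and approaches 3 from below as $k$ grows.</p>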
<p><strong>Relationship to Reweighted Wake-Sleep (RWS):</strong> Both IWAE and Reweighted Wake-Sleep (Bornschein &amp; Bengio, 2015) use importance-weighted samples and have closely related generative model updates. The key difference is that IWAE derives a single unified lower bound $\mathcal{L}_k$ and uses the reparameterization trick to train the recognition network jointly. RWS instead uses separate wake and sleep phases for the recognition network, which are not derived from $\mathcal{L}_k$.</p>
<h3 id="evaluation">Evaluation</h3>
<ol>
<li><strong>Test Log-Likelihood</strong>: Primary measure of generative performance, estimated as the mean of $\mathcal{L}_{5000}$ (5000 samples) on the test set.</li>
<li><strong>Active Units</strong>: To quantify latent space richness, the authors measured &ldquo;active&rdquo; latent dimensions. A unit $u$ was defined as active if its activity statistic $A_{u}=\text{Cov}_{x}(\mathbb{E}_{u\sim q(u|x)}[u])$ exceeded $10^{-2}$. The $10^{-2}$ threshold is justified by a bimodal distribution of the log activity statistic, showing clear separation between active and inactive units.</li>
</ol>
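<p>The active-unit count can be sketched as follows, assuming the posterior means $\mathbb{E}_{q(u|x)}[u]$ have already been collected into a list of rows, one per datapoint. The function name, the toy data, and the use of the population variance are assumptions made for illustration. The official implementation may differ.</p>

```python
from statistics import pvariance

def count_active_units(posterior_means, threshold=1e-2):
    """posterior_means[n][u] holds E_q[u | x_n]. A unit counts as active
    when the variance of that mean across datapoints exceeds threshold."""
    num_units = len(posterior_means[0])
    stats = [pvariance([row[u] for row in posterior_means])
             for u in range(num_units)]
    return sum(1 for a in stats if a > threshold), stats

# Hypothetical posterior means for 4 datapoints and 3 latent units.
means = [
    [-1.0, 0.0, 0.01],
    [1.0, 0.0, -0.01],
    [-1.0, 0.0, 0.01],
    [1.0, 0.0, -0.01],
]
active, stats = count_active_units(means)
```

<p>In this toy data only the first unit responds to the input. The second is constant and the third varies by far less than the $10^{-2}$ threshold, so both would be counted as inactive.</p>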
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Hardware</strong>: GPU-based implementation using mini-batch replication to parallelize the $k$ samples. Specific GPU type and training times are not reported.</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/yburda/iwae">yburda/iwae</a></td>
          <td>Code</td>
          <td>Unknown</td>
          <td>Official Theano implementation for MNIST and Omniglot</td>
      </tr>
  </tbody>
</table>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Burda, Y., Grosse, R., &amp; Salakhutdinov, R. (2016). Importance Weighted Autoencoders. <em>International Conference on Learning Representations (ICLR) 2016</em>. <a href="https://arxiv.org/abs/1509.00519">https://arxiv.org/abs/1509.00519</a></p>
<p><strong>Publication</strong>: ICLR 2016</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{burda2016importance,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Importance Weighted Autoencoders}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Yuri Burda and Roger Grosse and Ruslan Salakhutdinov}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{International Conference on Learning Representations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2016}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span>=<span style="color:#e6db74">{https://arxiv.org/abs/1509.00519}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://arxiv.org/abs/1509.00519">ArXiv</a></li>
</ul>
]]></content:encoded></item><item><title>Auto-Encoding Variational Bayes: VAE Paper Summary</title><link>https://hunterheidenreich.com/notes/machine-learning/generative-models/autoencoding-variational-bayes/</link><pubDate>Wed, 05 Nov 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/generative-models/autoencoding-variational-bayes/</guid><description>Summary of Kingma &amp; Welling's 2013 VAE paper introducing the reparameterization trick and variational autoencoders.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is a <strong>Method</strong> paper that introduces a generative mechanism (the VAE) and an optimization technique (the reparameterization trick), with formal theoretical derivation. The method, called the Auto-Encoding VB (AEVB) algorithm, leads to what we now know as the <strong>variational auto-encoder (VAE)</strong> when neural networks are used as the recognition model.</p>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>The authors address two central intractabilities in directed probabilistic models with continuous latent variables:</p>















<figure class="post-figure center ">
    <img src="/img/notes/autoencoding-variational-bayes-figure-1-model-diagram.webp"
         alt="VAE graphical model showing latent variable z, observed variable x, and parameters phi and theta"
         title="VAE graphical model showing latent variable z, observed variable x, and parameters phi and theta"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Figure 1 from the paper: The directed graphical model. Solid lines denote the generative model $p_\theta(z)p_\theta(x|z)$, dashed lines denote the variational approximation $q_\phi(z|x)$. The variational parameters $\phi$ are learned jointly with the generative parameters $\theta$.</figcaption>
    
</figure>

<ol>
<li>
<p><strong>Intractable Posteriors</strong>: In models with continuous latent variables (like those with non-linear hidden layers), the true posterior $p_{\theta}(z|x)$ cannot be calculated analytically, preventing the use of standard EM algorithms.</p>
</li>
<li>
<p><strong>Large Datasets</strong>: Sampling-based solutions like Monte Carlo EM (MCEM) require expensive sampling loops per datapoint. This makes them too slow for large datasets where batch optimization is too costly and efficient minibatch updates are required.</p>
</li>
</ol>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<h3 id="the-reparameterization-trick-sgvb-estimator">The Reparameterization Trick (SGVB Estimator)</h3>
<p>The core innovation is the <strong>Stochastic Gradient Variational Bayes (SGVB)</strong> estimator. The authors solve the high variance of standard gradient estimation by &ldquo;reparameterizing&rdquo; the random variable $\tilde{z}$.</p>
<p>They express $z$ as a deterministic function of the input $x$ and an auxiliary noise variable $\epsilon$:</p>
<p>$$\tilde{z} = g_{\phi}(\epsilon, x) \quad \text{with} \quad \epsilon \sim p(\epsilon)$$</p>















<figure class="post-figure center ">
    <img src="/img/notes/variational-autoencoder-reparameterization-trick.webp"
         alt="Comparison of standard stochastic node vs reparameterization trick showing gradient flow"
         title="Comparison of standard stochastic node vs reparameterization trick showing gradient flow"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The reparameterization trick. (A) Standard stochastic nodes block gradient flow during backpropagation. (B) By expressing $z = \mu + \sigma \odot \epsilon$ with external noise $\epsilon \sim \mathcal{N}(0,1)$, gradients can flow through the deterministic path to the parameters $\phi$.</figcaption>
    
</figure>

<ul>
<li><strong>Mechanism</strong>: For a Gaussian posterior, $z = \mu + \sigma \odot \epsilon$ where $\epsilon \sim \mathcal{N}(0, I)$.</li>
<li><strong>Impact</strong>: This makes the Monte Carlo estimate differentiable with respect to the variational parameters $\phi$, allowing the variational lower bound to be optimized via standard stochastic gradient ascent (like SGD or Adagrad).</li>
</ul>
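<p>A minimal pure-Python sketch of the Gaussian case, with hypothetical encoder outputs <code>mu</code> and <code>sigma</code>. Because the noise is drawn outside the function, $z$ is a deterministic function of the parameters, and a finite-difference check recovers the derivatives one would expect ($\partial z / \partial \mu = 1$, $\partial z / \partial \sigma = \epsilon$).</p>

```python
import random

def reparameterize(mu, sigma, eps):
    """z = mu + sigma * eps, elementwise; deterministic once eps is fixed."""
    return [m + s * e for m, s, e in zip(mu, sigma, eps)]

rng = random.Random(0)
mu, sigma = [0.5, -1.0], [0.1, 2.0]      # hypothetical encoder outputs
eps = [rng.gauss(0.0, 1.0) for _ in mu]  # noise sampled outside the model
z = reparameterize(mu, sigma, eps)

# Finite-difference check on the first coordinate.
h = 1e-6
dz_dmu = (reparameterize([mu[0] + h, mu[1]], sigma, eps)[0] - z[0]) / h
dz_dsigma = (reparameterize(mu, [sigma[0] + h, sigma[1]], eps)[0] - z[0]) / h
```

<p>In an autodiff framework the same structure lets gradients pass through <code>mu</code> and <code>sigma</code> while treating <code>eps</code> as a constant input.</p>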
<h3 id="the-aevb-algorithm-the-vae">The AEVB Algorithm (The VAE)</h3>
<p>The <strong>Auto-Encoding VB (AEVB)</strong> algorithm amortizes inference by learning a global recognition model (encoder) $q_{\phi}(z|x)$ jointly with the generative model (decoder) $p_{\theta}(x|z)$.</p>
<p><strong>Objective Function</strong>: Maximize the variational lower bound $\mathcal{L}(\theta, \phi; x^{(i)})$:</p>
<p>$$\mathcal{L} \simeq -D_{KL}(q_\phi(z|x^{(i)}) \,\|\, p_\theta(z)) + \frac{1}{L} \sum_{l=1}^L \log p_\theta(x^{(i)}|z^{(i,l)})$$</p>
<ul>
<li><strong>First Term (Regularizer)</strong>: Forces the approximate posterior to match the prior (integrated analytically for Gaussians).</li>
<li><strong>Second Term (Reconstruction Error)</strong>: The expected negative reconstruction error (estimated via sampling).</li>
</ul>
<p>This mirrors the standard auto-encoder objective, adding a variational regularizer.</p>
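<p>For a diagonal Gaussian posterior and a standard normal prior, the KL regularizer has the closed form $-\frac{1}{2}\sum_j \left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)$. The sketch below evaluates it in plain Python. The function name and the log-variance parameterization are assumptions for illustration.</p>

```python
import math

def kl_to_standard_normal(mu, logvar):
    """KL( N(mu, diag(exp(logvar))) || N(0, I) ) in closed form."""
    return -0.5 * sum(1 + lv - m * m - math.exp(lv)
                      for m, lv in zip(mu, logvar))

kl_at_prior = kl_to_standard_normal([0.0, 0.0], [0.0, 0.0])  # q equals the prior
kl_shifted = kl_to_standard_normal([1.0], [0.0])             # unit mean shift
kl_narrow = kl_to_standard_normal([0.0], [math.log(0.25)])   # variance 0.25
```

<p>The term is zero exactly when the posterior matches the prior and positive otherwise, which is what lets it act as a regularizer on the encoder.</p>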
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The method was benchmarked against the <strong>Wake-Sleep</strong> algorithm and <strong>Monte Carlo EM (MCEM)</strong> using the <strong>MNIST</strong> (digits) and <strong>Frey Face</strong> (continuous faces) datasets.</p>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>
<ul>
<li>
<p><strong>Efficiency</strong>: AEVB converged faster and reached a better lower bound than Wake-Sleep (Figure 2). It scaled efficiently to the full MNIST dataset. MCEM&rsquo;s per-datapoint sampling cost made it impractical at full dataset scale, so comparisons were limited to small subsets (Figure 3).</p>
</li>
<li>
<p><strong>Regularization</strong>: The KL-divergence term had a regularizing effect: adding latent dimensions ($N_z$) did not lead to overfitting.</p>
</li>
<li>
<p><strong>Manifold Learning</strong>: The model successfully learned smooth 2D latent manifolds (visualized in Appendix A), grouping similar digits/faces together.</p>
</li>
</ul>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p><strong>Evaluation Data</strong>: For the marginal likelihood comparison (Figure 3), the paper used MNIST with $N_{\text{train}} = 1000$ and $N_{\text{train}} = 50000$ to compare data efficiency (marginal log-likelihood vs. training samples seen) across algorithms. A smaller network (100 hidden units, 3 latent variables) was used for this comparison because the marginal likelihood estimator only works reliably in low-dimensional latent spaces.</p>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Algorithm</strong>: Stochastic gradient ascent with <strong>Adagrad</strong> (global stepsizes chosen from $\{0.01, 0.02, 0.1\}$).</li>
<li><strong>Regularization</strong>: The objective included a weight decay term corresponding to a prior $p(\theta)=\mathcal{N}(0,I)$.</li>
<li><strong>Minibatches</strong>: Size $M=100$ with $L=1$ sample per datapoint.</li>
<li><strong>Initialization</strong>: Parameters sampled from $\mathcal{N}(0, 0.01)$.</li>
</ul>
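<p>For readers unfamiliar with Adagrad, the update can be sketched as below. This is a generic textbook form written for ascent, under the assumption of a per-parameter accumulator and a small stabilizing constant; the function name and the toy objective are hypothetical and are not taken from the paper.</p>

```python
import math

def adagrad_ascent_step(params, grads, accum, stepsize=0.01, tiny=1e-8):
    """One Adagrad update that moves params uphill along grads."""
    new_params, new_accum = [], []
    for p, g, a in zip(params, grads, accum):
        a = a + g * g  # running sum of squared gradients, per parameter
        new_accum.append(a)
        new_params.append(p + stepsize * g / (math.sqrt(a) + tiny))
    return new_params, new_accum

# Toy objective f(theta) = -(theta - 3)^2, which is maximized at theta = 3.
theta, accum = [0.0], [0.0]
for _ in range(2000):
    grad = [-2.0 * (theta[0] - 3.0)]
    theta, accum = adagrad_ascent_step(theta, grad, accum, stepsize=0.1)
```

<p>The global stepsize only sets the scale; the accumulated squared gradients shrink the effective step for each parameter over time.</p>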
<h3 id="models">Models</h3>
<p>The original VAE used simple Multi-Layered Perceptrons (MLPs):</p>
<ul>
<li><strong>Symmetry</strong>: The encoder and decoder were symmetric, having an equal number of hidden units.</li>
<li><strong>Hidden Units</strong>: 500 units for MNIST, 200 for Frey Face (to prevent overfitting on the smaller dataset).</li>
<li><strong>Activations</strong>: <strong>Tanh</strong> activation functions for the hidden layers.</li>
<li><strong>Latent Space</strong>: Experimented with $N_z$ ranging from 2 to 200.</li>
<li><strong>Outputs</strong>:
<ul>
<li><em>MNIST</em>: <strong>Bernoulli</strong> MLP (sigmoid output).</li>
<li><em>Frey Face</em>: <strong>Gaussian</strong> MLP, with means constrained to $(0,1)$ via sigmoid.</li>
</ul>
</li>
<li><strong>Encoder Architecture</strong>: For the Gaussian encoder, the mean $\mu$ and log-variance $\log(\sigma^2)$ are linear outputs from the shared hidden layer (they share the hidden layer weights and have separate output weights).</li>
<li><strong>Log-Variance</strong>: The encoder predicted $\log(\sigma^2)$ for numerical stability.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p>The paper distinguishes between two metrics:</p>
<ul>
<li><strong>Variational Lower Bound</strong>: Used as the training objective (what the model optimizes).</li>
<li><strong>Marginal Likelihood</strong>: Used for final evaluation (Figure 3). The true marginal likelihood $p_\theta(x)$ was estimated using an Importance Sampling estimator, as detailed in Appendix D: posterior samples are drawn via Hybrid Monte Carlo (HMC), a density estimator $q(z)$ is fit to them, and fresh posterior samples $z^{(l)}$ are plugged into $p_{\theta}(x^{(i)}) \simeq \left(\frac{1}{L}\sum_{l=1}^{L} \frac{q(z^{(l)})}{p_\theta(z^{(l)})\,p_\theta(x^{(i)}|z^{(l)})}\right)^{-1}$.</li>
</ul>
<p>This distinction is critical: the training metric (lower bound) differs from the evaluation metric (estimated marginal likelihood).</p>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Hardware</strong>: Trained on a standard Intel Xeon CPU (approx. 40 GFLOPS); no GPUs were used.</li>
<li><strong>Training Time</strong>: Approximately 20-40 minutes per million training samples.</li>
</ul>
<h3 id="key-implementation-details-from-appendices">Key Implementation Details from Appendices</h3>
<ul>
<li><strong>Appendix A</strong>: Visualizations of 2D latent manifolds learned for MNIST and Frey Face datasets.</li>
<li><strong>Appendix B</strong>: Closed-form solution for the KL divergence of two Gaussians, essential for implementing the efficient version of the estimator (Equation 10).</li>
<li><strong>Appendix C</strong>: Exact MLP equations, with tanh hidden layers and separate output specifications for <strong>Bernoulli MLPs</strong> (binary data) and <strong>Gaussian MLPs</strong> (real-valued data).</li>
<li><strong>Appendix D</strong>: Marginal likelihood estimation protocol using Hybrid Monte Carlo (HMC) and importance sampling for evaluation (Figure 3).</li>
</ul>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Diederik P. Kingma and Max Welling. &ldquo;Auto-Encoding Variational Bayes.&rdquo; arXiv:1312.6114 [stat.ML], 2013. <a href="https://doi.org/10.48550/arXiv.1312.6114">https://doi.org/10.48550/arXiv.1312.6114</a></p>
<p><strong>Publication</strong>: ICLR 2014 (arXiv preprint December 2013)</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@misc</span>{kingma2022autoencodingvariationalbayes,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Auto-Encoding Variational Bayes}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Diederik P Kingma and Max Welling}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2013}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">eprint</span>=<span style="color:#e6db74">{1312.6114}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">archivePrefix</span>=<span style="color:#e6db74">{arXiv}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">primaryClass</span>=<span style="color:#e6db74">{stat.ML}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">url</span>=<span style="color:#e6db74">{https://arxiv.org/abs/1312.6114}</span>,
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://en.wikipedia.org/wiki/Variational_autoencoder">Wikipedia: Variational Autoencoder</a> - General overview</li>
<li><a href="https://openreview.net/forum?id=33X9fd2-9FyZd">OpenReview</a> - Original peer review with author responses</li>
<li><a href="/posts/modern-variational-autoencoder-in-pytorch/">Modern VAE in PyTorch</a> - Implementation tutorial on this site</li>
</ul>
]]></content:encoded></item><item><title>The Müller-Brown Potential: A 2D Benchmark Surface</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-simulation/classical-methods/muller-brown-1979/</link><pubDate>Mon, 08 Sep 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-simulation/classical-methods/muller-brown-1979/</guid><description>The Müller-Brown potential is a classic 2D benchmark for testing optimization algorithms and molecular dynamics methods.</description><content:encoded><![CDATA[<h2 id="overview">Overview</h2>
<p>The Müller-Brown potential is a classic benchmark system in computational chemistry: a two-dimensional analytical surface used to evaluate optimization and path-finding algorithms. Klaus Müller and Leo D. Brown introduced it in 1979 as a test system for their constrained simplex optimization algorithm. The function captures the essential topology of chemical reaction landscapes while remaining cheap to evaluate.</p>
<p><strong>Origin</strong>: Müller, K., &amp; Brown, L. D. (1979). Location of saddle points and minimum energy paths by a constrained simplex optimization procedure. <em>Theoretica Chimica Acta</em>, 53, 75-93. The potential is introduced in footnote 7 (p. 79) as a two-parametric model surface for testing the constrained simplex procedures.</p>
<h2 id="mathematical-definition">Mathematical Definition</h2>
<p>The Müller-Brown potential combines four two-dimensional Gaussian functions:</p>
<p>$$V(x,y) = \sum_{k=1}^{4} A_k \exp\left[a_k(x-x_k^0)^2 + b_k(x-x_k^0)(y-y_k^0) + c_k(y-y_k^0)^2\right]$$</p>
<p>Each Gaussian contributes a different &ldquo;bump&rdquo; or &ldquo;well&rdquo; to the landscape. The parameters control amplitude ($A_k$), width, orientation, and center position.</p>
<h3 id="standard-parameters">Standard Parameters</h3>
<p>The canonical parameter values that define the Müller-Brown surface are:</p>
<table>
  <thead>
      <tr>
          <th>k</th>
          <th>$A_k$</th>
          <th>$a_k$</th>
          <th>$b_k$</th>
          <th>$c_k$</th>
          <th>$x_k^0$</th>
          <th>$y_k^0$</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>1</td>
          <td>-200</td>
          <td>-1</td>
          <td>0</td>
          <td>-10</td>
          <td>1</td>
          <td>0</td>
      </tr>
      <tr>
          <td>2</td>
          <td>-100</td>
          <td>-1</td>
          <td>0</td>
          <td>-10</td>
          <td>0</td>
          <td>0.5</td>
      </tr>
      <tr>
          <td>3</td>
          <td>-170</td>
          <td>-6.5</td>
          <td>11</td>
          <td>-6.5</td>
          <td>-0.5</td>
          <td>1.5</td>
      </tr>
      <tr>
          <td>4</td>
          <td>15</td>
          <td>0.7</td>
          <td>0.6</td>
          <td>0.7</td>
          <td>-1</td>
          <td>1</td>
      </tr>
  </tbody>
</table>
<p>The first three terms have negative amplitudes and negative-definite exponents, so each carves a localized energy well. The fourth has a positive amplitude and a positive-definite exponent, so it rises away from its center at $(-1, 1)$ and acts as a confining barrier. The cross-term $b_k$ in the third Gaussian creates the tilted orientation that gives the surface its characteristic curved pathways.</p>
<h3 id="analytical-gradients-forces">Analytical Gradients (Forces)</h3>
<p>Optimizing paths or running molecular dynamics on this surface requires the spatial derivatives (the negatives of the force components), which have a simple closed form. Writing $G_k(x,y)$ for the argument of the $k$-th exponential, the partial derivatives with respect to $x$ and $y$ are:</p>
<p>$$ \frac{\partial V}{\partial x} = \sum_{k=1}^4 A_k \exp[G_k(x,y)] \cdot \left[ 2a_k(x-x_k^0) + b_k(y-y_k^0) \right] $$</p>
<p>$$ \frac{\partial V}{\partial y} = \sum_{k=1}^4 A_k \exp[G_k(x,y)] \cdot \left[ b_k(x-x_k^0) + 2c_k(y-y_k^0) \right] $$</p>
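<p>A dependency-free sketch of these formulas is shown below, with the parameter rows copied from the table above. The function names are illustrative, and the finite-difference helper is included only as a sanity check on the analytical derivatives; a production version would more likely be vectorized in NumPy or PyTorch.</p>

```python
import math

# Rows of (A, a, b, c, x0, y0), copied from the standard parameter table.
PARAMS = (
    (-200.0, -1.0, 0.0, -10.0, 1.0, 0.0),
    (-100.0, -1.0, 0.0, -10.0, 0.0, 0.5),
    (-170.0, -6.5, 11.0, -6.5, -0.5, 1.5),
    (15.0, 0.7, 0.6, 0.7, -1.0, 1.0),
)

def potential(x, y):
    """V(x, y): sum of the four exponential terms."""
    total = 0.0
    for A, a, b, c, x0, y0 in PARAMS:
        dx, dy = x - x0, y - y0
        total += A * math.exp(a * dx * dx + b * dx * dy + c * dy * dy)
    return total

def gradient(x, y):
    """Analytical (dV/dx, dV/dy); the force is the negative of this."""
    gx = gy = 0.0
    for A, a, b, c, x0, y0 in PARAMS:
        dx, dy = x - x0, y - y0
        weight = A * math.exp(a * dx * dx + b * dx * dy + c * dy * dy)
        gx += weight * (2.0 * a * dx + b * dy)
        gy += weight * (b * dx + 2.0 * c * dy)
    return gx, gy

def finite_difference_gradient(x, y, h=1e-6):
    """Central-difference reference for checking the analytical formulas."""
    gx = (potential(x + h, y) - potential(x - h, y)) / (2.0 * h)
    gy = (potential(x, y + h) - potential(x, y - h)) / (2.0 * h)
    return gx, gy

analytic = gradient(0.3, 0.7)
numeric = finite_difference_gradient(0.3, 0.7)
```

<p>Comparing <code>analytic</code> with <code>numeric</code> at a few arbitrary points is a cheap way to catch sign or index mistakes in a hand-derived gradient.</p>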
<h2 id="energy-landscape">Energy Landscape</h2>
<p>This simple formula creates a surprisingly rich topography with exactly the features needed to challenge optimization algorithms:</p>
<table>
  <thead>
      <tr>
          <th><strong>Stationary Point</strong></th>
          <th><strong>Coordinates</strong></th>
          <th><strong>Energy</strong></th>
          <th><strong>Type</strong></th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MA (Reactant)</td>
          <td>(-0.558, 1.442)</td>
          <td>-146.70</td>
          <td>Deep minimum</td>
      </tr>
      <tr>
          <td>MC (Intermediate)</td>
          <td>(-0.050, 0.467)</td>
          <td>-80.77</td>
          <td>Shallow minimum</td>
      </tr>
      <tr>
          <td>MB (Product)</td>
          <td>(0.623, 0.028)</td>
          <td>-108.17</td>
          <td>Medium minimum</td>
      </tr>
      <tr>
          <td>S1</td>
          <td>(-0.822, 0.624)</td>
          <td>-40.67</td>
          <td>First saddle point</td>
      </tr>
      <tr>
          <td>S2</td>
          <td>(0.212, 0.293)</td>
          <td>-72.25</td>
          <td>Second saddle point</td>
      </tr>
  </tbody>
</table>
<p>All values from Table 1 of Müller &amp; Brown (1979).</p>















<figure class="post-figure center ">
    <img src="/img/muller-brown/muller-brown-potential-surface.webp"
         alt="Müller-Brown Potential Energy Surface showing the three minima (dark blue regions) and two saddle points"
         title="Müller-Brown Potential Energy Surface showing the three minima (dark blue regions) and two saddle points"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The Müller-Brown potential energy surface showing the three minima (dark blue regions) and two saddle points.</figcaption>
    
</figure>

<h3 id="key-challenge-curved-reaction-pathways">Key Challenge: Curved Reaction Pathways</h3>
<p>The path from the deep reactant minimum (MA) to the product minimum (MB) follows a curved two-step pathway:</p>
<ol>
<li><strong>MA → S1 → MC</strong>: First transition over the higher, rate-limiting barrier (S1) into an intermediate basin</li>
<li><strong>MC → S2 → MB</strong>: Second transition over a much lower barrier (S2) to the product</li>
</ol>
<p>This curved pathway breaks linear interpolation methods. Algorithms that draw a straight line from reactant to product miss both the intermediate minimum and the correct transition states, climbing over much higher energy regions instead.</p>
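<p>The failure of straight-line interpolation can be checked numerically. The sketch below samples the energy along the segment from MA to MB and compares its maximum with the S1 energy from the table; the helper name <code>energy_along_line</code> and the choice of 201 sample points are arbitrary assumptions.</p>

```python
import math

# Rows of (A, a, b, c, x0, y0), copied from the standard parameter table.
PARAMS = (
    (-200.0, -1.0, 0.0, -10.0, 1.0, 0.0),
    (-100.0, -1.0, 0.0, -10.0, 0.0, 0.5),
    (-170.0, -6.5, 11.0, -6.5, -0.5, 1.5),
    (15.0, 0.7, 0.6, 0.7, -1.0, 1.0),
)

def potential(x, y):
    """V(x, y): sum of the four exponential terms."""
    total = 0.0
    for A, a, b, c, x0, y0 in PARAMS:
        dx, dy = x - x0, y - y0
        total += A * math.exp(a * dx * dx + b * dx * dy + c * dy * dy)
    return total

def energy_along_line(start, end, num_points=201):
    """Energies at evenly spaced points on the straight segment start -> end."""
    energies = []
    for i in range(num_points):
        t = i / (num_points - 1)
        x = start[0] + t * (end[0] - start[0])
        y = start[1] + t * (end[1] - start[1])
        energies.append(potential(x, y))
    return energies

MA, MB = (-0.558, 1.442), (0.623, 0.028)
S1_ENERGY = -40.67  # highest saddle on the true minimum energy path

line_max = max(energy_along_line(MA, MB))
excess = line_max - S1_ENERGY  # how far the shortcut overshoots the real barrier
```

<p>A positive <code>excess</code> means the straight shortcut crosses terrain higher than the rate-limiting saddle, so a barrier read off the interpolated path would be an overestimate.</p>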
<h2 id="why-it-works-as-a-benchmark">Why It Works as a Benchmark</h2>
<p>The Müller-Brown potential has served as a computational chemistry benchmark for over four decades because of four key characteristics:</p>
<p><strong>Low dimensionality</strong>: As a 2D surface, it permits complete visualization of the landscape, clearly revealing why specific algorithms succeed or fail.</p>
<p><strong>Analytical form</strong>: Energy and gradient calculations cost virtually nothing, enabling exhaustive testing impossible with quantum mechanical surfaces.</p>
<p><strong>Non-trivial topology</strong>: The curved minimum energy path and shallow intermediate minimum challenge sophisticated methods while remaining manageable.</p>
<p><strong>Known ground truth</strong>: All minima and saddle points are precisely known, providing unambiguous success metrics.</p>
<h3 id="contrast-with-other-benchmarks">Contrast with Other Benchmarks</h3>
<p>The Müller-Brown potential tests different capabilities than other classic potentials. The Lennard-Jones potential, with its single energy minimum, serves as the standard benchmark for equilibrium properties. By contrast, Müller-Brown explicitly models reactive landscapes: its multiple minima and connecting barriers create an evaluation environment for algorithms designed to discover transition states and reaction paths.</p>
<h2 id="historical-applications">Historical Applications</h2>
<p>The potential has evolved with the field&rsquo;s changing focus:</p>
<p><strong>1980s-1990s</strong>: Testing path-finding methods like Nudged Elastic Band (NEB), which creates discrete representations of reaction pathways and optimizes them to find minimum energy paths.</p>
<p><strong>2000s-2010s</strong>: Validating Transition Path Sampling (TPS) methods that harvest statistical ensembles of reactive trajectories.</p>
<p><strong>2020s</strong>: Benchmarking machine learning models and generative approaches that learn to sample transition paths or approximate potential energy surfaces.</p>
<h2 id="modern-applications-in-machine-learning">Modern Applications in Machine Learning</h2>
<p>The rise of machine learning has given the Müller-Brown potential renewed purpose. Modern <strong>Machine Learning Interatomic Potentials (MLIPs)</strong> aim to bridge the gap between quantum mechanical accuracy and classical force field efficiency by training flexible models on expensive quantum chemistry data.</p>
<p>The Müller-Brown potential provides an ideal benchmarking solution: an exactly known potential energy surface that can generate unlimited, noise-free training data. This enables researchers to ask fundamental questions:</p>
<ul>
<li>How well does a given architecture learn complex, curved surfaces?</li>
<li>How many training points are needed for acceptable accuracy?</li>
<li>How does the model behave when extrapolating beyond training data?</li>
<li>Can it correctly identify minima and saddle points?</li>
</ul>
<p>The potential serves as a consistent benchmark for measuring the learning capacity of AI models.</p>
<h2 id="extensions-and-variants">Extensions and Variants</h2>
<h3 id="higher-dimensional-extensions">Higher-Dimensional Extensions</h3>
<p>The canonical Müller-Brown potential can be extended beyond two dimensions to create more challenging test cases:</p>
<p><strong>Harmonic constraints</strong>: Add quadratic wells in orthogonal dimensions while preserving the complex 2D landscape:</p>
<p>$$V_{5D}(x_1, x_2, x_3, x_4, x_5) = V(x_1, x_3) + \kappa(x_2^2 + x_4^2 + x_5^2)$$</p>
<p><strong>Collective variables (CVs)</strong>: Collective variables are low-dimensional coordinates that capture the most important degrees of freedom in a high-dimensional system. By defining CVs that mix multiple dimensions, the original surface can be embedded in higher-dimensional spaces. For instance, the active 2D coordinates $x$ and $y$ can be projected as linear combinations of $N$ arbitrary degrees of freedom ($q_i$):</p>
<p>$$ x = \sum_{i=1}^N w_{x,i} q_i \quad \text{and} \quad y = \sum_{i=1}^N w_{y,i} q_i $$</p>
<p>This constructs a complex, high-dimensional problem where an algorithm must learn to isolate the relevant active subspace (the CVs) before it can effectively optimize the topology.</p>
<p>These extensions enable systematic testing of algorithm scaling with dimensionality while maintaining known ground truth in the active subspace.</p>
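<p>Both extensions are thin wrappers around the 2D function, as the sketch below suggests. The names <code>potential_5d</code> and <code>potential_embedded</code>, the default $\kappa = 1$, and the weight vectors are assumptions chosen for illustration.</p>

```python
import math

# Rows of (A, a, b, c, x0, y0), copied from the standard parameter table.
PARAMS = (
    (-200.0, -1.0, 0.0, -10.0, 1.0, 0.0),
    (-100.0, -1.0, 0.0, -10.0, 0.0, 0.5),
    (-170.0, -6.5, 11.0, -6.5, -0.5, 1.5),
    (15.0, 0.7, 0.6, 0.7, -1.0, 1.0),
)

def potential(x, y):
    """Canonical 2D Müller-Brown energy."""
    total = 0.0
    for A, a, b, c, x0, y0 in PARAMS:
        dx, dy = x - x0, y - y0
        total += A * math.exp(a * dx * dx + b * dx * dy + c * dy * dy)
    return total

def potential_5d(q, kappa=1.0):
    """V(x1, x3) + kappa * (x2^2 + x4^2 + x5^2): harmonic wells in the extra dimensions."""
    x1, x2, x3, x4, x5 = q
    return potential(x1, x3) + kappa * (x2 * x2 + x4 * x4 + x5 * x5)

def potential_embedded(q, w_x, w_y):
    """Evaluate V on the collective variables x = w_x . q and y = w_y . q."""
    x = sum(w * qi for w, qi in zip(w_x, q))
    y = sum(w * qi for w, qi in zip(w_y, q))
    return potential(x, y)

# Example: a 4D configuration whose CVs mix pairs of coordinates equally.
q = (-0.3, -0.258, 0.7, 0.742)
energy = potential_embedded(q, w_x=(1.0, 1.0, 0.0, 0.0), w_y=(0.0, 0.0, 1.0, 1.0))
```

<p>In the embedded variant, any motion of <code>q</code> orthogonal to both weight vectors leaves the energy unchanged, which is exactly the redundancy an algorithm has to discover.</p>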
<h2 id="limitations">Limitations</h2>
<p>Despite its utility, the Müller-Brown potential has fundamental limitations as a proxy for physical systems:</p>
<ul>
<li><strong>Lack of Realistic Scaling</strong>: As a two-dimensional analytical model, it cannot by itself reproduce the high-dimensional scaling behavior of many-body atomic systems.</li>
<li><strong>No Entropic Effects</strong>: In real chemical systems, entropic contributions strongly shape the free-energy landscape. The Müller-Brown potential specifies the energy exactly but lacks the thermal and entropic complexity of solvent or macromolecular environments.</li>
<li><strong>Simplified Topology</strong>: While non-trivial compared to single wells, its global topology remains simpler than that of real ab initio potential energy surfaces, which can feature bifurcations, multi-state crossings, and non-adiabatic couplings.</li>
</ul>
<h2 id="implementation-considerations">Implementation Considerations</h2>
<p>Modern implementations typically focus on:</p>
<ul>
<li><strong>Vectorized calculations</strong> for batch processing</li>
<li><strong>Analytical derivatives</strong> for gradient-based methods</li>
<li><strong>JIT compilation</strong> for performance optimization</li>
<li><strong>Automatic differentiation</strong> compatibility for machine learning frameworks</li>
</ul>
<p>The analytical nature of the potential makes it ideal for testing both classical optimization methods and modern machine learning approaches.</p>
<h2 id="resources-and-visualizations">Resources and Visualizations</h2>
<ul>
<li><a href="/muller-brown-optimized">Interactive Müller-Brown Potential Energy Surface</a> - Local visualization tool</li>
<li><a href="https://www.wolframcloud.com/objects/demonstrations/TrajectoriesOnTheMullerBrownPotentialEnergySurface-source.nb">Müller-Brown Potential Visualization (Wolfram)</a> - External Wolfram demonstration</li>
<li><a href="/posts/muller-brown-in-pytorch/">Implementing the Müller-Brown Potential in PyTorch</a> - Detailed implementation guide with performance analysis</li>
</ul>
<h2 id="related-systems">Related Systems</h2>
<p>The Müller-Brown potential belongs to a family of analytical benchmark systems used in computational chemistry. Other notable examples include:</p>
<ul>
<li><strong>Lennard-Jones potential</strong>: Single-minimum benchmark for equilibrium properties</li>
<li><strong>Double-well potentials</strong>: Simple models for bistable systems</li>
<li><strong>Eckart barrier</strong>: One-dimensional tunneling benchmark</li>
<li><strong>Wolfe-Quapp potential</strong>: Two-dimensional quartic model surface, often used to study valley-ridge inflection points</li>
</ul>
<h2 id="conclusion">Conclusion</h2>
<p>The Müller-Brown potential demonstrates how a well-designed benchmark can evolve with a field. It was born of 1970s computational constraints, when quantum chemistry calculations were too expensive for routine algorithm testing. Its topology makes naive linear-interpolation approaches fail, yet each evaluation costs almost nothing, which is why it remains a widely used benchmark today.</p>
<p>In the machine learning era it provides a controlled environment for developing methods that are ultimately aimed at complex, realistic molecular systems. Its path from practical surrogate model to machine learning benchmark shows the continued relevance of foundational analytical test cases in computational science.</p>
]]></content:encoded></item><item><title>Müller-Brown Potential: A PyTorch ML Testbed</title><link>https://hunterheidenreich.com/projects/muller-brown-pytorch/</link><pubDate>Wed, 27 Aug 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/projects/muller-brown-pytorch/</guid><description>A PyTorch testbed for the Müller-Brown potential: BAOAB Langevin dynamics, torch.compile analytical forces, and a statistical-mechanics validation suite.</description><content:encoded><![CDATA[<h2 id="overview">Overview</h2>
<p>This project implements the classic 2D Müller-Brown potential in PyTorch as a ground-truth testbed for machine-learning-in-molecular-dynamics (ML-MD) work. The potential is a <code>torch.nn.Module</code> that computes forces two ways: a hand-derived analytical gradient (the default, compiled with <code>torch.compile</code>) and <code>torch.autograd.grad</code> (a reference the analytical path is checked against). On an Apple M1 Max, the analytical kernel runs about 4x faster than autograd (3-7x depending on batch size; 100 warm-up iterations, then the median of 5 runs of 1000), because it skips autograd&rsquo;s graph construction inside the force loop.</p>
<p>The energy is deliberately left uncompiled so that second derivatives (the Hessian via autograd) keep working, since <code>torch.compile</code> does not support double-backward; the force, the hot path, is the compiled function.</p>
<h2 id="features">Features</h2>
<ul>
<li><strong>Dual force kernels</strong>: a hand-derived analytical gradient (compiled) for fast simulation, and an autograd mode for differentiation and as the correctness reference the analytical path is tested against.</li>
<li><strong>BAOAB Langevin integrator</strong>: the BAOAB splitting scheme (Leimkuhler &amp; Matthews, 2013), which solves the friction-plus-noise step exactly and samples the canonical distribution accurately (exactly so for a harmonic oscillator).</li>
<li><strong>Device-agnostic</strong>: potential, forces, and simulation are plain PyTorch tensor operations that run on CPU or CUDA; the included benchmark measures CPU.</li>
<li><strong>Modular architecture</strong>: physics (<code>MuellerBrownPotential</code>), numerics (<code>LangevinSimulator</code>), visualization, and HDF5 I/O are separated, with a CLI orchestrating demo, single-run, batch, and plot-regeneration modes.</li>
</ul>
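<p>For orientation, one BAOAB step can be sketched in plain Python as below. This is not the project&rsquo;s <code>LangevinSimulator</code>; the function <code>baoab_step</code>, its signature, and the harmonic test force are assumptions, and the force is recomputed at the start of each step for brevity where a real integrator would cache it.</p>

```python
import math
import random

def baoab_step(x, v, force, dt, gamma, kT, mass=1.0, rng=random):
    """One BAOAB step: half kick (B), half drift (A), exact friction-plus-noise
    update (O), half drift (A), half kick (B)."""
    f = force(x)
    v = [vi + 0.5 * dt * fi / mass for vi, fi in zip(v, f)]   # B
    x = [xi + 0.5 * dt * vi for xi, vi in zip(x, v)]          # A
    c1 = math.exp(-gamma * dt)                                 # O: exact OU solution
    c2 = math.sqrt(kT / mass * (1.0 - c1 * c1))
    v = [c1 * vi + c2 * rng.gauss(0.0, 1.0) for vi in v]
    x = [xi + 0.5 * dt * vi for xi, vi in zip(x, v)]          # A
    f = force(x)
    v = [vi + 0.5 * dt * fi / mass for vi, fi in zip(v, f)]   # B
    return x, v

def harmonic_force(x):
    """Force of V(x) = x^2 / 2 per coordinate (unit spring constant)."""
    return [-xi for xi in x]

def total_energy(x, v):
    return 0.5 * sum(vi * vi for vi in v) + 0.5 * sum(xi * xi for xi in x)

# Frictionless limit (gamma = 0): the scheme reduces to velocity Verlet.
x, v = [1.0], [0.0]
for _ in range(1000):
    x, v = baoab_step(x, v, harmonic_force, dt=0.01, gamma=0.0, kT=1.0)
nve_drift = abs(total_energy(x, v) - 0.5)

# Thermal run: for this oscillator the mean of x^2 should approach kT.
rng = random.Random(0)
x, v, acc, num_steps = [0.0], [0.0], 0.0, 20000
for _ in range(num_steps):
    x, v = baoab_step(x, v, harmonic_force, dt=0.1, gamma=1.0, kT=1.0, rng=rng)
    acc += x[0] * x[0]
mean_x2 = acc / num_steps
```

<p>The two checks mirror the kinds of tests described for the project: energy conservation without friction, and a configurational average that can be compared with its known thermal value.</p>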
<h2 id="usage">Usage</h2>
<p>The package installs editable with <code>uv sync</code> and imports as a normal package:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">from</span> muller_brown <span style="color:#f92672">import</span> MuellerBrownPotential, LangevinSimulator
</span></span></code></pre></div><p>It provides a fast, differentiable Müller-Brown potential and a Langevin sampler for testing ML-MD algorithms against a known-exact surface.</p>
<h2 id="results">Results</h2>
<h3 id="architecture">Architecture</h3>
<ul>
<li><strong>Physics module</strong>: the energy surface is a <code>torch.nn.Module</code> with the potential parameters held as registered buffers, so device and dtype move with the module.</li>
<li><strong>Analytical force kernel</strong>: the analytical Jacobian is implemented directly and compiled with <code>torch.compile(dynamic=True)</code>, bypassing autograd-graph construction during long simulations.</li>
<li><strong>Vectorized execution</strong>: kernel operations are vectorized over particles, so an ensemble runs in roughly the same wall time as a single particle (the per-step cost is dominated by the fixed force call and noise draw).</li>
<li><strong>Device-agnostic</strong>: all operations move to CUDA via native tensor handling; the benchmark and tests run on CPU.</li>
</ul>
<h3 id="performance">Performance</h3>
<p>A force-throughput benchmark (analytical vs autograd) across batch sizes from 2 to roughly 50,000 particles, on an Apple M1 Max:</p>
<ul>
<li>The analytical kernel is about 4x faster than autograd (3-7x across batch sizes).</li>
<li>Per-particle force time drops below 1 microsecond at large batch sizes.</li>
<li>Throughput rises with batch size and saturates for large ensembles.</li>
</ul>
<h3 id="validation">Validation</h3>
<p>The sampler is checked against statistical mechanics, not just run:</p>
<ul>
<li><strong>Deterministic tests</strong>: the documented minima and saddles have the correct Hessian signatures; the analytical force matches <code>torch.autograd.grad</code>; energy is conserved in the frictionless (NVE) limit; <code>float32</code> matches <code>float64</code>; and HDF5 round-trips preserve the data.</li>
<li><strong>Statistical tests</strong>: the sampler reproduces equipartition, the harmonic-oscillator distributions, and the Müller-Brown Boltzmann mean energy against a grid-integrated reference; a separate convergence study confirms the integrator&rsquo;s kinetic-temperature bias vanishes as the timestep squared.</li>
</ul>
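<p>The Hessian-signature check can be illustrated without PyTorch. The sketch below differentiates the analytical gradient by central differences and classifies each documented stationary point by the signs of the two Hessian eigenvalues. It is a stand-in under stated assumptions (finite differences instead of autograd, hypothetical function names, coordinates rounded to three decimals), not the project&rsquo;s test code.</p>

```python
import math

# Rows of (A, a, b, c, x0, y0): the standard Müller-Brown parameters.
PARAMS = (
    (-200.0, -1.0, 0.0, -10.0, 1.0, 0.0),
    (-100.0, -1.0, 0.0, -10.0, 0.0, 0.5),
    (-170.0, -6.5, 11.0, -6.5, -0.5, 1.5),
    (15.0, 0.7, 0.6, 0.7, -1.0, 1.0),
)

def gradient(x, y):
    """Analytical (dV/dx, dV/dy)."""
    gx = gy = 0.0
    for A, a, b, c, x0, y0 in PARAMS:
        dx, dy = x - x0, y - y0
        weight = A * math.exp(a * dx * dx + b * dx * dy + c * dy * dy)
        gx += weight * (2.0 * a * dx + b * dy)
        gy += weight * (b * dx + 2.0 * c * dy)
    return gx, gy

def hessian(x, y, h=1e-5):
    """Central differences of the analytical gradient: (Vxx, Vxy, Vyy)."""
    gx_plus, gy_plus = gradient(x + h, y)
    gx_minus, gy_minus = gradient(x - h, y)
    _, gy_up = gradient(x, y + h)
    _, gy_down = gradient(x, y - h)
    vxx = (gx_plus - gx_minus) / (2.0 * h)
    vxy = (gy_plus - gy_minus) / (2.0 * h)
    vyy = (gy_up - gy_down) / (2.0 * h)
    return vxx, vxy, vyy

def classify(x, y):
    """Label a stationary point by the signs of its two Hessian eigenvalues."""
    vxx, vxy, vyy = hessian(x, y)
    mean = 0.5 * (vxx + vyy)
    radius = math.hypot(0.5 * (vxx - vyy), vxy)
    low, high = mean - radius, mean + radius  # eigenvalues of a symmetric 2x2
    if low > 0.0:
        return "minimum"
    if high < 0.0:
        return "maximum"
    return "saddle"

POINTS = {
    "MA": (-0.558, 1.442), "MC": (-0.050, 0.467), "MB": (0.623, 0.028),
    "S1": (-0.822, 0.624), "S2": (0.212, 0.293),
}
signatures = {name: classify(*xy) for name, xy in POINTS.items()}
```

<p>A first-order saddle on a 2D surface has exactly one negative eigenvalue, so the three minima and two saddles should separate cleanly.</p>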
<h3 id="molecular-dynamics">Molecular Dynamics</h3>
<p>Langevin simulations on the surface show particle motion within energy basins, thermal fluctuations around the minima, and barrier-crossing transitions between wells, visualized as trajectories on the potential surface.</p>
<h2 id="simulation-videos">Simulation Videos</h2>
<p>These videos demonstrate Langevin dynamics simulations on the Müller-Brown potential surface:</p>
<p><div style="position: relative; padding-bottom: 56.25%; height: 0; overflow: hidden;">
      <iframe allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share; fullscreen" loading="eager" referrerpolicy="strict-origin-when-cross-origin" src="https://www.youtube-nocookie.com/embed/woVM90qXUQs?autoplay=0&amp;controls=1&amp;end=0&amp;loop=0&amp;mute=0&amp;start=0" style="position: absolute; top: 0; left: 0; width: 100%; height: 100%; border:0;" title="YouTube video"></iframe>
    </div>

<strong>A Basin Dynamics</strong>: Particle motion and thermal fluctuations around the A minimum.</p>
<p><div style="position: relative; padding-bottom: 56.25%; height: 0; overflow: hidden;">
      <iframe allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share; fullscreen" loading="eager" referrerpolicy="strict-origin-when-cross-origin" src="https://www.youtube-nocookie.com/embed/gdAHme07bGs?autoplay=0&amp;controls=1&amp;end=0&amp;loop=0&amp;mute=0&amp;start=0" style="position: absolute; top: 0; left: 0; width: 100%; height: 100%; border:0;" title="YouTube video"></iframe>
    </div>

<strong>B Basin Dynamics</strong>: Exploration of the B minimum energy well.</p>
<p><div style="position: relative; padding-bottom: 56.25%; height: 0; overflow: hidden;">
      <iframe allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share; fullscreen" loading="eager" referrerpolicy="strict-origin-when-cross-origin" src="https://www.youtube-nocookie.com/embed/dVFe_4KZbps?autoplay=0&amp;controls=1&amp;end=0&amp;loop=0&amp;mute=0&amp;start=0" style="position: absolute; top: 0; left: 0; width: 100%; height: 100%; border:0;" title="YouTube video"></iframe>
    </div>

<strong>Transition Path</strong>: Particle transitioning between energy basins, demonstrating barrier crossing.</p>
<h2 id="related-work">Related Work</h2>
<p>This implementation is documented in detail in:</p>
<ul>
<li><a href="/posts/muller-brown-in-pytorch/">Implementing the Müller-Brown Potential in PyTorch</a></li>
<li><a href="/videos/muller-brown-basin-ma-simulation/">Basin A Simulation</a></li>
<li><a href="/videos/muller-brown-basin-mb-simulation/">Basin B Simulation</a></li>
<li><a href="/videos/muller-brown-transition-simulation/">Transition Path Simulation</a></li>
</ul>
]]></content:encoded></item><item><title>Implementing the Müller-Brown Potential in PyTorch</title><link>https://hunterheidenreich.com/posts/muller-brown-in-pytorch/</link><pubDate>Wed, 27 Aug 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/muller-brown-in-pytorch/</guid><description>Guide to implementing the Müller-Brown potential in PyTorch, comparing analytical vs automatic differentiation with performance analysis.</description><content:encoded><![CDATA[<h2 id="introduction">Introduction</h2>
<p>The Müller-Brown potential reads, in hindsight, like an adversarial example for optimization algorithms.</p>
<p>Designed in 1979 to break naive path-finding methods, this deceptively simple 2D surface features deep minima, high barriers, and tricky saddle points. For nearly five decades, it has served as a ground-truth benchmark for computational chemistry.</p>
<p>Today, it finds new life as a testbed for machine learning. In the 1970s, chemists struggled to find transition states, saddle points where standard gradient descent fails catastrophically. Modern machine learning engineers face a strikingly similar challenge: escaping saddle points in high-dimensional loss landscapes. The Müller-Brown potential was the original stress test for these algorithms, and it remains a perfect, low-cost sandbox for benchmarking modern optimizers.</p>
<p>Whether you&rsquo;re training neural network potentials or benchmarking reinforcement learning agents for exploration, the Müller-Brown potential offers a fast, noise-free, and mathematically exact environment.</p>
<p>In this guide, we&rsquo;ll implement it in PyTorch, work through the engineering trade-off between <strong>analytical derivatives</strong> and <strong>Autograd</strong>, and measure the roughly <strong>4x</strong> force-evaluation speedup of the compiled analytical kernel over the Autograd reference.</p>
<h2 id="the-problem-finding-saddle-points-in-the-1970s">The Problem: Finding Saddle Points in the 1970s</h2>
<p>In the 1970s, finding energy minima was straightforward: follow the gradient downhill. But finding transition states proved much more challenging. These saddle points are maxima along the reaction coordinate but minima in all other directions, like standing at the top of a mountain pass.</p>















<figure class="post-figure center ">
    <img src="/img/muller-brown/muller-brown-saddle.webp"
         alt="Diagram showing gradient descent getting stuck at a saddle point, where the surface curves up in one direction and down in another"
         title="Diagram showing gradient descent getting stuck at a saddle point, where the surface curves up in one direction and down in another"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The saddle point problem: standard gradient descent sees a minimum in one direction but a maximum in another. Modern optimizers like SGD and Adam face this same challenge in high-dimensional loss landscapes.</figcaption>
    
</figure>

<p>Standard minimizers, whether 1970s simplex methods or modern first-order optimizers like SGD and Adam, are designed to drive a function downhill and nothing else. Start them near a saddle point and they slide into the nearest valley. Exactly at the saddle the gradient vanishes, and everywhere around it the downhill direction leads away, so the gradient offers no guidance toward the saddle itself. Specialized algorithms were needed to navigate this mixed landscape of ups and downs.</p>
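<p>This behavior is easy to reproduce on the Müller-Brown surface itself. The sketch below starts plain gradient descent a short distance from the lower saddle S2, using the standard parameters and the stationary-point coordinates usually quoted for this potential. The step size, step count, and starting offset are arbitrary assumptions, and <code>descend</code> is a hypothetical helper.</p>

```python
import math

# Rows of (A, a, b, c, x0, y0): the standard Müller-Brown parameters.
PARAMS = (
    (-200.0, -1.0, 0.0, -10.0, 1.0, 0.0),
    (-100.0, -1.0, 0.0, -10.0, 0.0, 0.5),
    (-170.0, -6.5, 11.0, -6.5, -0.5, 1.5),
    (15.0, 0.7, 0.6, 0.7, -1.0, 1.0),
)

def potential(x, y):
    """V(x, y): sum of the four exponential terms."""
    total = 0.0
    for A, a, b, c, x0, y0 in PARAMS:
        dx, dy = x - x0, y - y0
        total += A * math.exp(a * dx * dx + b * dx * dy + c * dy * dy)
    return total

def gradient(x, y):
    """Analytical (dV/dx, dV/dy)."""
    gx = gy = 0.0
    for A, a, b, c, x0, y0 in PARAMS:
        dx, dy = x - x0, y - y0
        weight = A * math.exp(a * dx * dx + b * dx * dy + c * dy * dy)
        gx += weight * (2.0 * a * dx + b * dy)
        gy += weight * (b * dx + 2.0 * c * dy)
    return gx, gy

def descend(x, y, step=1e-4, num_steps=5000):
    """Plain gradient descent with a fixed step size."""
    for _ in range(num_steps):
        gx, gy = gradient(x, y)
        x, y = x - step * gx, y - step * gy
    return x, y

S2 = (0.212, 0.293)        # lower saddle, energy about -72.25
S2_ENERGY = -72.25
start = (S2[0] + 0.02, S2[1] - 0.02)  # small nudge off the saddle

end = descend(*start)
final_energy = potential(*end)
final_gradient_norm = math.hypot(*gradient(*end))
```

<p>The run ends at a point with vanishing gradient and an energy well below the saddle: the optimizer has found a minimum and retained nothing about the transition state it started next to.</p>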
<p>The computational reality made this worse. Early quantum chemistry programs like ATMOL and Gaussian made energy calculations possible, but each computation was expensive. Gradients required even more resources, and second derivatives were rarely computed.</p>
<p>This created a catch-22: sophisticated algorithms were needed to find saddle points, but researchers couldn&rsquo;t afford to test them on real molecular systems. Every calculation represented a major investment of time and computational resources.</p>
<h2 id="müller-and-browns-solution">Müller and Brown&rsquo;s Solution</h2>
<p>Müller and Brown&rsquo;s insight<sup id="fnref:1"><a href="#fn:1" class="footnote-ref" role="doc-noteref">1</a></sup> was to create a simple analytical test function that captured the essential difficulties of real chemical systems without the computational cost. Their potential offered three key advantages:</p>
<ul>
<li><strong>Negligible computational cost</strong> - Evaluate millions of points instantly</li>
<li><strong>Analytical derivatives</strong> - Exact gradients and Hessians available immediately</li>
<li><strong>Realistic challenges</strong> - Multiple minima, saddle points, and curved pathways</li>
</ul>
<p>The clever part was the deliberate design to break naive approaches. Early methods often assumed linear paths between reactants and products. The Müller-Brown potential has a curved minimum energy path that punishes this assumption. Try to take shortcuts, and algorithms climb over high-energy barriers.</p>
<h2 id="the-mathematical-foundation">The Mathematical Foundation</h2>
<p>The Müller-Brown potential combines four two-dimensional Gaussian functions:</p>
<p>$$V(x,y) = \sum_{k=1}^{4} A_k \exp\left[a_k(x-x_k^0)^2 + b_k(x-x_k^0)(y-y_k^0) + c_k(y-y_k^0)^2\right]$$</p>
<p>Each Gaussian contributes a different &ldquo;bump&rdquo; or &ldquo;well&rdquo; to the landscape. The parameters control its amplitude ($A_k$), its width and orientation ($a_k$, $b_k$, $c_k$), and its center position ($x_k^0$, $y_k^0$).</p>
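<p>Because every term is an exponential of a quadratic form, the analytical gradient follows from the chain rule. As a worked sketch, write $E_k$ for the $k$-th exponential factor (a shorthand introduced here for brevity, not notation from the original paper):</p>
<p>$$E_k = \exp\left[a_k(x-x_k^0)^2 + b_k(x-x_k^0)(y-y_k^0) + c_k(y-y_k^0)^2\right]$$</p>
<p>$$\frac{\partial V}{\partial x} = \sum_{k=1}^{4} A_k E_k \left[2a_k(x-x_k^0) + b_k(y-y_k^0)\right],\quad \frac{\partial V}{\partial y} = \sum_{k=1}^{4} A_k E_k \left[b_k(x-x_k^0) + 2c_k(y-y_k^0)\right]$$</p>
<p>The force is the negative of this gradient, $\mathbf{F} = -\nabla V$, and it reuses the same four exponentials that the energy needs.</p>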
<h3 id="the-standard-parameters">The Standard Parameters</h3>
<p>The specific parameter values that define the canonical Müller-Brown surface are:</p>
<table>
  <thead>
      <tr>
          <th>k</th>
          <th>$A_k$</th>
          <th>$a_k$</th>
          <th>$b_k$</th>
          <th>$c_k$</th>
          <th>$x_k^0$</th>
          <th>$y_k^0$</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>1</td>
          <td>-200</td>
          <td>-1</td>
          <td>0</td>
          <td>-10</td>
          <td>1</td>
          <td>0</td>
      </tr>
      <tr>
          <td>2</td>
          <td>-100</td>
          <td>-1</td>
          <td>0</td>
          <td>-10</td>
          <td>0</td>
          <td>0.5</td>
      </tr>
      <tr>
          <td>3</td>
          <td>-170</td>
          <td>-6.5</td>
          <td>11</td>
          <td>-6.5</td>
          <td>-0.5</td>
          <td>1.5</td>
      </tr>
      <tr>
          <td>4</td>
          <td>15</td>
          <td>0.7</td>
          <td>0.6</td>
          <td>0.7</td>
          <td>-1</td>
          <td>1</td>
      </tr>
  </tbody>
</table>
<p>Notice that the first three terms have negative amplitudes (creating energy wells), while the fourth has a positive amplitude (creating a barrier). The cross-term $b_k$ in the third Gaussian creates the tilted orientation that gives the surface its characteristic curved pathways.</p>
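<p>To make the role of each term concrete, here is a minimal pure-Python sketch that evaluates the four contributions separately. It uses only the standard library rather than PyTorch, and the helper names <code>term_contributions</code> and <code>muller_brown</code> are illustrative choices, not part of any library.</p>

```python
import math

# Canonical parameters copied from the table above, one tuple per quantity.
A = (-200.0, -100.0, -170.0, 15.0)
a = (-1.0, -1.0, -6.5, 0.7)
b = (0.0, 0.0, 11.0, 0.6)
c = (-10.0, -10.0, -6.5, 0.7)
X0 = (1.0, 0.0, -0.5, -1.0)
Y0 = (0.0, 0.5, 1.5, 1.0)


def term_contributions(x: float, y: float) -> list[float]:
    """Return the four individual terms A_k * exp(...) at the point (x, y)."""
    terms = []
    for A_k, a_k, b_k, c_k, x_k, y_k in zip(A, a, b, c, X0, Y0):
        dx, dy = x - x_k, y - y_k
        terms.append(A_k * math.exp(a_k * dx * dx + b_k * dx * dy + c_k * dy * dy))
    return terms


def muller_brown(x: float, y: float) -> float:
    """Total potential energy V(x, y): the sum of the four terms."""
    return sum(term_contributions(x, y))


if __name__ == "__main__":
    # Illustrative probe point; any (x, y) works.
    for k, value in enumerate(term_contributions(-0.558, 1.442), start=1):
        print(f"term {k}: {value:+.4f}")
    print(f"V = {muller_brown(-0.558, 1.442):.2f}")
```

<p>At the probe point used here, the third term supplies almost all of the well depth while the fourth term pushes the energy back up. The printed total should land close to the deepest minimum of the surface.</p>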
<p><a href="/muller-brown-optimized"><strong>View Interactive Müller-Brown Potential Energy Surface →</strong></a></p>
<h3 id="the-resulting-landscape">The Resulting Landscape</h3>
<figure class="post-figure center ">
    <img src="/img/muller-brown/muller-brown-potential-surface.webp"
         alt="Müller-Brown Potential Energy Surface showing the three minima (dark blue regions) and two saddle points"
         title="Müller-Brown Potential Energy Surface showing the three minima (dark blue regions) and two saddle points"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The Müller-Brown potential energy surface showing the three minima (dark blue regions) and two saddle points.</figcaption>
    
</figure>

<p>This simple formula creates a surprisingly rich topography with exactly the features needed to challenge optimization algorithms:</p>
<table>
  <thead>
      <tr>
          <th><strong>Stationary Point</strong></th>
          <th><strong>Coordinates</strong></th>
          <th><strong>Energy</strong></th>
          <th><strong>Type</strong></th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MA (Reactant)</td>
          <td>(-0.558, 1.442)</td>
          <td>-146.70</td>
          <td>Deep minimum</td>
      </tr>
      <tr>
          <td>MC (Intermediate)</td>
          <td>(-0.050, 0.467)</td>
          <td>-80.77</td>
          <td>Shallow minimum</td>
      </tr>
      <tr>
          <td>MB (Product)</td>
          <td>(0.623, 0.028)</td>
          <td>-108.17</td>
          <td>Medium minimum</td>
      </tr>
      <tr>
          <td>S1</td>
          <td>(-0.822, 0.624)</td>
          <td>-40.66</td>
          <td>First saddle point</td>
      </tr>
      <tr>
          <td>S2</td>
          <td>(0.212, 0.293)</td>
          <td>-72.25</td>
          <td>Second saddle point</td>
      </tr>
  </tbody>
</table>
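<p>The &ldquo;Type&rdquo; column can be checked numerically: a minimum has two positive Hessian eigenvalues, while a first-order saddle has exactly one negative eigenvalue. The sketch below assumes a central finite-difference Hessian with step <code>h=1e-4</code> (an illustrative choice) and uses the rounded coordinates from the table. The names <code>hessian_eigenvalues</code> and <code>classify</code> are hypothetical helpers.</p>

```python
import math

# (A, a, b, c, x0, y0) for each of the four terms, from the standard parameters.
PARAMS = [
    (-200.0, -1.0, 0.0, -10.0, 1.0, 0.0),
    (-100.0, -1.0, 0.0, -10.0, 0.0, 0.5),
    (-170.0, -6.5, 11.0, -6.5, -0.5, 1.5),
    (15.0, 0.7, 0.6, 0.7, -1.0, 1.0),
]

# Rounded stationary-point coordinates from the table above.
POINTS = {
    "MA": (-0.558, 1.442),
    "MC": (-0.050, 0.467),
    "MB": (0.623, 0.028),
    "S1": (-0.822, 0.624),
    "S2": (0.212, 0.293),
}


def V(x: float, y: float) -> float:
    """Müller-Brown potential energy."""
    return sum(
        A * math.exp(a * (x - x0) ** 2 + b * (x - x0) * (y - y0) + c * (y - y0) ** 2)
        for A, a, b, c, x0, y0 in PARAMS
    )


def hessian_eigenvalues(x: float, y: float, h: float = 1e-4) -> tuple[float, float]:
    """Central-difference Hessian, then closed-form eigenvalues of the 2x2 matrix."""
    vxx = (V(x + h, y) - 2.0 * V(x, y) + V(x - h, y)) / h**2
    vyy = (V(x, y + h) - 2.0 * V(x, y) + V(x, y - h)) / h**2
    vxy = (V(x + h, y + h) - V(x + h, y - h)
           - V(x - h, y + h) + V(x - h, y - h)) / (4.0 * h**2)
    mean = 0.5 * (vxx + vyy)
    radius = math.hypot(0.5 * (vxx - vyy), vxy)
    return mean - radius, mean + radius  # (smaller, larger)


def classify(x: float, y: float) -> str:
    """Label a stationary point by the signs of its Hessian eigenvalues."""
    smaller, larger = hessian_eigenvalues(x, y)
    if smaller > 0:
        return "minimum"
    if larger > 0:
        return "saddle"  # one negative and one positive eigenvalue
    return "maximum"


if __name__ == "__main__":
    for name, (x, y) in POINTS.items():
        print(name, classify(x, y), hessian_eigenvalues(x, y))
```

<p>The eigenvector belonging to the negative eigenvalue at a saddle is the local direction of the reaction coordinate, which is the quantity that saddle-search methods try to follow uphill.</p>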
<h3 id="the-key-challenge-curved-pathways">The Key Challenge: Curved Pathways</h3>
<p>The path from the deep reactant minimum (MA) to the product minimum (MB) doesn&rsquo;t go directly over a single barrier. Instead, it follows a curved route:</p>
<ol>
<li><strong>MA → S1 → MC</strong>: First transition over the higher, rate-limiting barrier (S1) into an intermediate basin</li>
<li><strong>MC → S2 → MB</strong>: Second transition over a much lower barrier (S2) to the product</li>
</ol>
<p>This two-step pathway breaks linear interpolation methods. Algorithms that draw a straight line from reactant to product miss both the intermediate minimum and the correct transition states, climbing over much higher energy regions instead.</p>
<figure class="post-figure center ">
    <img src="/img/muller-brown/naive-versus-minimum-path.webp"
         alt="Two-panel comparison showing naive linear interpolation versus minimum energy path. Left panel shows the contour map with both paths overlaid. Right panel shows the energy profile along each path, revealing the naive path hits a far higher barrier."
         title="Two-panel comparison showing naive linear interpolation versus minimum energy path. Left panel shows the contour map with both paths overlaid. Right panel shows the energy profile along each path, revealing the naive path hits a far higher barrier."
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Why naive optimization fails: The left panel shows a straight-line path (red dashed) versus the true minimum energy path (green solid) on the potential surface. The right panel reveals the energetic cost. The naive path climbs a barrier roughly 53 reduced units higher that the curved path avoids entirely. This is the &lsquo;adversarial&rsquo; nature of the Müller-Brown surface.</figcaption>
    
</figure>

<p>The energy profile comparison makes the failure mode concrete. A naive optimizer following the red dashed path would encounter a barrier roughly <strong>53 reduced units higher</strong> than necessary (the naive summit sits at about +13 while the true path tops out near -41, the S1 saddle). The green minimum energy path navigates through the valleys, passing through the intermediate basin MC and crossing only the low-lying saddle points S1 and S2.</p>
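<p>The straight-line failure can be reproduced in a few lines. This is a minimal sketch, assuming a uniform sampling of 201 points along the segment from MA to MB (the sample count and the helper name <code>straight_line_profile</code> are illustrative) and taking the S1 energy from the stationary-point table:</p>

```python
import math

# (A, a, b, c, x0, y0) for each of the four terms, from the standard parameters.
PARAMS = [
    (-200.0, -1.0, 0.0, -10.0, 1.0, 0.0),
    (-100.0, -1.0, 0.0, -10.0, 0.0, 0.5),
    (-170.0, -6.5, 11.0, -6.5, -0.5, 1.5),
    (15.0, 0.7, 0.6, 0.7, -1.0, 1.0),
]

MA = (-0.558, 1.442)   # reactant minimum (tabulated, rounded)
MB = (0.623, 0.028)    # product minimum (tabulated, rounded)
S1_ENERGY = -40.66     # highest point on the true minimum energy path (tabulated)


def V(x: float, y: float) -> float:
    """Müller-Brown potential energy."""
    return sum(
        A * math.exp(a * (x - x0) ** 2 + b * (x - x0) * (y - y0) + c * (y - y0) ** 2)
        for A, a, b, c, x0, y0 in PARAMS
    )


def straight_line_profile(start, end, n_points: int = 201) -> list[float]:
    """Energies at evenly spaced points on the straight segment start -> end."""
    profile = []
    for i in range(n_points):
        t = i / (n_points - 1)
        x = (1.0 - t) * start[0] + t * end[0]
        y = (1.0 - t) * start[1] + t * end[1]
        profile.append(V(x, y))
    return profile


profile = straight_line_profile(MA, MB)
naive_summit = max(profile)

if __name__ == "__main__":
    print(f"straight-line summit energy: {naive_summit:.2f}")
    print(f"excess over the S1 saddle:   {naive_summit - S1_ENERGY:.2f}")
```

<p>The summit found this way depends slightly on the sampling density, so treat the printed excess as an estimate of the penalty for ignoring the curved valley.</p>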
<h2 id="why-it-works-as-a-benchmark">Why It Works as a Benchmark</h2>
<p>The Müller-Brown potential has served as a computational chemistry benchmark for over four decades because of four key characteristics:</p>
<p><strong>Low dimensionality</strong>: As a 2D surface, you can visualize the entire landscape and see exactly why algorithms succeed or fail.</p>
<p><strong>Analytical form</strong>: Energy and gradient calculations cost virtually nothing, enabling exhaustive testing impossible with quantum mechanical surfaces.</p>
<p><strong>Non-trivial topology</strong>: The curved minimum energy path and shallow intermediate minimum challenge sophisticated methods while remaining manageable.</p>
<p><strong>Known ground truth</strong>: All minima and saddle points are precisely known, providing unambiguous success metrics.</p>
<figure class="post-figure center ">
    <img src="/img/muller-brown/muller-brown-basins-of-attraction.webp"
         alt="Basins of attraction map showing which regions of the Müller-Brown surface lead to each minimum under gradient descent. Blue region flows to MA, green to MC, and yellow to MB."
         title="Basins of attraction map showing which regions of the Müller-Brown surface lead to each minimum under gradient descent. Blue region flows to MA, green to MC, and yellow to MB."
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Basins of attraction: the &lsquo;optimization map&rsquo; of the Müller-Brown surface. Each color indicates which minimum a gradient descent optimizer will reach from that starting point. Saddle points (red x) sit precisely at the basin boundaries, the unstable equilibria that algorithms must navigate to find reaction paths.</figcaption>
    
</figure>

<p>This basin map reveals why the Müller-Brown potential is such an effective benchmark. Standard gradient descent from any point in the blue region inevitably falls into the deep MA minimum; from green, into MC; from yellow, into MB. The saddle points S1 and S2 lie exactly on the boundaries between basins, so an infinitesimal perturbation of an optimizer sitting at one of them sends it tumbling into one valley or the other. Finding these saddle points requires algorithms that can identify these boundary regions and stabilize on them.</p>
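<p>A basin label for a single start point can be computed with plain gradient descent on the analytical gradient. The sketch below is one way to do it, not the code behind the figure: the step size <code>1e-4</code>, the 5000 iterations, and the helper names <code>descend</code> and <code>basin_of</code> are all assumptions chosen so that the update stays stable near the stiffest minimum.</p>

```python
import math

# (A, a, b, c, x0, y0) for each of the four terms, from the standard parameters.
PARAMS = [
    (-200.0, -1.0, 0.0, -10.0, 1.0, 0.0),
    (-100.0, -1.0, 0.0, -10.0, 0.0, 0.5),
    (-170.0, -6.5, 11.0, -6.5, -0.5, 1.5),
    (15.0, 0.7, 0.6, 0.7, -1.0, 1.0),
]

# Tabulated (rounded) minima used to label where a descent ends up.
MINIMA = {"MA": (-0.558, 1.442), "MC": (-0.050, 0.467), "MB": (0.623, 0.028)}


def gradient(x: float, y: float) -> tuple[float, float]:
    """Analytical gradient (dV/dx, dV/dy) obtained from the chain rule."""
    gx = gy = 0.0
    for A, a, b, c, x0, y0 in PARAMS:
        dx, dy = x - x0, y - y0
        e = A * math.exp(a * dx * dx + b * dx * dy + c * dy * dy)
        gx += e * (2.0 * a * dx + b * dy)
        gy += e * (b * dx + 2.0 * c * dy)
    return gx, gy


def descend(x: float, y: float, step: float = 1e-4, n_steps: int = 5000):
    """Fixed-step gradient descent; step and n_steps are illustrative choices."""
    for _ in range(n_steps):
        gx, gy = gradient(x, y)
        x, y = x - step * gx, y - step * gy
    return x, y


def basin_of(x: float, y: float) -> str:
    """Name of the tabulated minimum closest to where descent from (x, y) ends."""
    end = descend(x, y)
    return min(MINIMA, key=lambda name: math.dist(MINIMA[name], end))


if __name__ == "__main__":
    for start in [(-0.7, 1.5), (-0.1, 0.5), (0.7, 0.1)]:
        print(start, "ends in", basin_of(*start))
```

<p>Repeating <code>basin_of</code> over a grid of start points would give a coarse version of the map. Starts far outside the plotted region may need a smaller step, because the fourth term grows rapidly away from its center.</p>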
<p>What makes this particularly valuable is the contrast with other classic potentials. While the Lennard-Jones potential<sup id="fnref:2"><a href="#fn:2" class="footnote-ref" role="doc-noteref">2</a></sup> serves as the benchmark for equilibrium properties with its single energy minimum, Müller-Brown explicitly models reactive landscapes. Its multiple minima and connecting barriers make it the testing ground for algorithms that find reaction paths, the methods that reveal how chemistry actually happens.</p>
<h3 id="applications-across-decades">Applications Across Decades</h3>
<p>The potential has evolved with the field&rsquo;s changing focus:</p>
<p><strong>1980s-1990s</strong>: Testing path-finding methods like Nudged Elastic Band (NEB)<sup id="fnref:3"><a href="#fn:3" class="footnote-ref" role="doc-noteref">3</a></sup>, which creates discrete representations of reaction pathways and optimizes them to find minimum energy paths.</p>
<p><strong>2000s-2010s</strong>: Validating Transition Path Sampling (TPS) methods<sup id="fnref:4"><a href="#fn:4" class="footnote-ref" role="doc-noteref">4</a></sup> that harvest statistical ensembles of reactive trajectories.</p>
<p><strong>2020s</strong>: Benchmarking machine learning models and generative approaches that learn to sample transition paths or approximate potential energy surfaces.</p>
<h2 id="modern-applications-in-machine-learning">Modern Applications in Machine Learning</h2>
<p>The rise of machine learning has given the Müller-Brown potential renewed purpose. Modern <strong>Machine Learning Interatomic Potentials (MLIPs)</strong><sup id="fnref:5"><a href="#fn:5" class="footnote-ref" role="doc-noteref">5</a></sup><sup id="fnref:6"><a href="#fn:6" class="footnote-ref" role="doc-noteref">6</a></sup> aim to bridge the gap between quantum mechanical accuracy and classical force field efficiency by training flexible models on expensive quantum chemistry data.</p>
<p>This creates a benchmarking challenge: with countless ML architectures available, how do you objectively compare them? The Müller-Brown potential provides an ideal solution, an exactly known potential energy surface that can generate unlimited, noise-free training data.</p>
<p>This enables researchers to ask fundamental questions:</p>
<ul>
<li>How well does a given architecture learn complex, curved surfaces?</li>
<li>How many training points are needed for acceptable accuracy?</li>
<li>How does the model behave when extrapolating beyond training data?</li>
<li>Can it correctly identify minima and saddle points?</li>
</ul>
<figure class="post-figure center ">
    <img src="/img/muller-brown/muller-brown-ml-benchmark.webp"
         alt="Three-panel comparison showing the analytical Müller-Brown surface (left), a neural network at epoch 50 with high error (middle), and a converged neural network at epoch 1000 (right)"
         title="Three-panel comparison showing the analytical Müller-Brown surface (left), a neural network at epoch 50 with high error (middle), and a converged neural network at epoch 1000 (right)"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Visualizing the ML benchmark: A comparison of the analytical ground truth (left) versus a neural network potential at early (middle) and late (right) stages of training. Notice how the model in early training quickly grasps the deep minima regions but struggles significantly with the complex topography of the saddle points and energy barriers: the curved pathways are smoothed into simpler shapes. This illustrates why the Müller-Brown surface remains a challenging test case for modern architectures.</figcaption>
    
</figure>

<p>The potential has evolved from a simple model system into a <strong>reference benchmark</strong>: a fixed, exactly known surface against which a model&rsquo;s learning capacity is measured. Because the training data is noise-free, any prediction error is due to model limitations, not data quality.</p>
<p>Beyond static benchmarking, a PyTorch implementation enables <strong>differentiable simulation</strong>. Because the potential, forces, and integrator are all differentiable tensor operations, gradients can be backpropagated <em>through time</em> (via the trajectory) to optimize force field parameters or control policies directly. This capability connects classical molecular simulation to modern gradient-based machine learning in a single computational graph.</p>
<p>These benchmarking principles extend beyond abstract test cases. Real molecular dynamics simulations, such as those studying <a href="/posts/adatom-cu-diffusion/">adatom diffusion on metal surfaces</a>, face similar challenges in understanding energy landscapes and transition pathways. The Müller-Brown potential provides a controlled environment for developing methods that eventually tackle these complex realistic systems.</p>
<h2 id="extension-to-higher-dimensions">Extension to Higher Dimensions</h2>
<p>The canonical Müller-Brown potential can be extended beyond two dimensions to create more challenging test cases that better reflect real molecular systems. This extensibility demonstrates why it remains such an effective template for computational method development.</p>
<h3 id="why-higher-dimensions-matter">Why Higher Dimensions Matter</h3>
<p>Real molecules have dozens or hundreds of degrees of freedom. Understanding how algorithms scale with dimensionality is crucial for practical applications. Higher-dimensional extensions allow researchers to systematically test:</p>
<ul>
<li><strong>Algorithm scaling</strong> - Does performance degrade gracefully as dimensions increase?</li>
<li><strong>Model robustness</strong> - Do machine learning approaches maintain accuracy in high-dimensional spaces?</li>
<li><strong>Parallel efficiency</strong> - Can massively parallel methods exploit additional dimensions effectively?</li>
</ul>
<h3 id="extension-approaches">Extension Approaches</h3>
<p><strong>Harmonic constraints</strong>: Add quadratic wells in orthogonal dimensions while preserving the complex 2D landscape<sup id="fnref:7"><a href="#fn:7" class="footnote-ref" role="doc-noteref">7</a></sup>:</p>
<p>$$V_{5D}(x_1, x_2, x_3, x_4, x_5) = V(x_1, x_3) + \kappa(x_2^2 + x_4^2 + x_5^2)$$</p>
<p>The parameter $\kappa$ controls constraint strength: small values create nearly flat directions that test algorithmic efficiency.</p>
<p><strong>Collective variables</strong>: Define new coordinates that mix multiple dimensions<sup id="fnref:8"><a href="#fn:8" class="footnote-ref" role="doc-noteref">8</a></sup>:</p>
<p>$$\tilde{x} = \sqrt{x_1^2 + x_2^2 + \epsilon x_5^2},\quad \tilde{y} = \sqrt{x_3^2 + x_4^2}$$</p>
<p>where $\epsilon \ll 1$. The 5D potential becomes $V_{5D}(x_1, x_2, x_3, x_4, x_5) = V(\tilde{x}, \tilde{y})$, embedding the original surface in a higher-dimensional space.</p>
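<p>Both extensions are thin wrappers around the 2D function. The following sketch assumes illustrative defaults of <code>kappa=1.0</code> and <code>eps=1e-3</code>, which are not values taken from the cited works, and the function names are hypothetical:</p>

```python
import math

# (A, a, b, c, x0, y0) for each of the four terms, from the standard parameters.
PARAMS = [
    (-200.0, -1.0, 0.0, -10.0, 1.0, 0.0),
    (-100.0, -1.0, 0.0, -10.0, 0.0, 0.5),
    (-170.0, -6.5, 11.0, -6.5, -0.5, 1.5),
    (15.0, 0.7, 0.6, 0.7, -1.0, 1.0),
]


def V(x: float, y: float) -> float:
    """2D Müller-Brown potential energy."""
    return sum(
        A * math.exp(a * (x - x0) ** 2 + b * (x - x0) * (y - y0) + c * (y - y0) ** 2)
        for A, a, b, c, x0, y0 in PARAMS
    )


def v5d_harmonic(x, kappa: float = 1.0) -> float:
    """V(x1, x3) plus harmonic wells of strength kappa on x2, x4, x5.

    kappa=1.0 is an assumed default for illustration only.
    """
    x1, x2, x3, x4, x5 = x
    return V(x1, x3) + kappa * (x2**2 + x4**2 + x5**2)


def v5d_collective(x, eps: float = 1e-3) -> float:
    """V evaluated on the collective variables (x_tilde, y_tilde).

    eps=1e-3 is an assumed small value standing in for "eps much less than 1".
    """
    x1, x2, x3, x4, x5 = x
    x_tilde = math.sqrt(x1**2 + x2**2 + eps * x5**2)
    y_tilde = math.sqrt(x3**2 + x4**2)
    return V(x_tilde, y_tilde)


if __name__ == "__main__":
    on_plane = (0.623, 0.0, 0.028, 0.0, 0.0)   # MB embedded in 5D
    off_plane = (0.623, 0.1, 0.028, -0.2, 0.3)
    print(v5d_harmonic(on_plane), v5d_harmonic(off_plane))
    print(v5d_collective(on_plane))
```

<p>One design consequence is worth noting: the collective variables are square roots and therefore non-negative, so this embedding only samples the quadrant of the original surface with $\tilde{x} \geq 0$ and $\tilde{y} \geq 0$, whereas the harmonic variant keeps the full 2D landscape.</p>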
<h3 id="value-for-algorithm-development">Value for Algorithm Development</h3>
<p>This extensibility makes the Müller-Brown potential ideal for systematic testing:</p>
<ul>
<li><strong>Progressive complexity</strong>: Debug on 2D, then scale to higher dimensions</li>
<li><strong>Ground truth preservation</strong>: Known minima and saddle points remain in the active subspace</li>
<li><strong>Realistic challenges</strong>: Captures the &ldquo;needle in a haystack&rdquo; problem of transition state finding while maintaining analytical tractability</li>
</ul>
<p>These extensions transform a simple 2D benchmark into a scalable testbed for modern computational methods, probing specific challenges in high-dimensional optimization.</p>
<h2 id="implementation-in-pytorch">Implementation in PyTorch</h2>
<p>Now let&rsquo;s implement the Müller-Brown potential in PyTorch. A practical implementation needs to handle batch processing, support both analytical and automatic differentiation, and be optimized for performance.</p>
<h3 id="core-implementation">Core Implementation</h3>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> torch
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> torch.nn <span style="color:#66d9ef">as</span> nn
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> torch <span style="color:#f92672">import</span> Tensor
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">class</span> <span style="color:#a6e22e">MuellerBrownPotential</span>(nn<span style="color:#f92672">.</span>Module):
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;Müller-Brown potential with a torch.compile-accelerated force kernel.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">__init__</span>(
</span></span><span style="display:flex;"><span>        self,
</span></span><span style="display:flex;"><span>        device: str <span style="color:#f92672">|</span> torch<span style="color:#f92672">.</span>device <span style="color:#f92672">=</span> <span style="color:#e6db74">&#34;cpu&#34;</span>,
</span></span><span style="display:flex;"><span>        dtype: torch<span style="color:#f92672">.</span>dtype <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>float64,
</span></span><span style="display:flex;"><span>        use_autograd: bool <span style="color:#f92672">=</span> <span style="color:#66d9ef">False</span>
</span></span><span style="display:flex;"><span>    ):
</span></span><span style="display:flex;"><span>        super()<span style="color:#f92672">.</span><span style="color:#a6e22e">__init__</span>()
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>use_autograd <span style="color:#f92672">=</span> use_autograd
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Standard Müller-Brown parameters</span>
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>register_buffer(
</span></span><span style="display:flex;"><span>            <span style="color:#e6db74">&#34;A&#34;</span>, torch<span style="color:#f92672">.</span>tensor([<span style="color:#f92672">-</span><span style="color:#ae81ff">200.0</span>, <span style="color:#f92672">-</span><span style="color:#ae81ff">100.0</span>, <span style="color:#f92672">-</span><span style="color:#ae81ff">170.0</span>, <span style="color:#ae81ff">15.0</span>],
</span></span><span style="display:flex;"><span>                             device<span style="color:#f92672">=</span>device, dtype<span style="color:#f92672">=</span>dtype)
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>register_buffer(
</span></span><span style="display:flex;"><span>            <span style="color:#e6db74">&#34;a&#34;</span>, torch<span style="color:#f92672">.</span>tensor([<span style="color:#f92672">-</span><span style="color:#ae81ff">1.0</span>, <span style="color:#f92672">-</span><span style="color:#ae81ff">1.0</span>, <span style="color:#f92672">-</span><span style="color:#ae81ff">6.5</span>, <span style="color:#ae81ff">0.7</span>],
</span></span><span style="display:flex;"><span>                             device<span style="color:#f92672">=</span>device, dtype<span style="color:#f92672">=</span>dtype)
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>register_buffer(
</span></span><span style="display:flex;"><span>            <span style="color:#e6db74">&#34;b&#34;</span>, torch<span style="color:#f92672">.</span>tensor([<span style="color:#ae81ff">0.0</span>, <span style="color:#ae81ff">0.0</span>, <span style="color:#ae81ff">11.0</span>, <span style="color:#ae81ff">0.6</span>],
</span></span><span style="display:flex;"><span>                             device<span style="color:#f92672">=</span>device, dtype<span style="color:#f92672">=</span>dtype)
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>register_buffer(
</span></span><span style="display:flex;"><span>            <span style="color:#e6db74">&#34;c&#34;</span>, torch<span style="color:#f92672">.</span>tensor([<span style="color:#f92672">-</span><span style="color:#ae81ff">10.0</span>, <span style="color:#f92672">-</span><span style="color:#ae81ff">10.0</span>, <span style="color:#f92672">-</span><span style="color:#ae81ff">6.5</span>, <span style="color:#ae81ff">0.7</span>],
</span></span><span style="display:flex;"><span>                             device<span style="color:#f92672">=</span>device, dtype<span style="color:#f92672">=</span>dtype)
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>register_buffer(
</span></span><span style="display:flex;"><span>            <span style="color:#e6db74">&#34;x_centers&#34;</span>, torch<span style="color:#f92672">.</span>tensor([<span style="color:#ae81ff">1.0</span>, <span style="color:#ae81ff">0.0</span>, <span style="color:#f92672">-</span><span style="color:#ae81ff">0.5</span>, <span style="color:#f92672">-</span><span style="color:#ae81ff">1.0</span>],
</span></span><span style="display:flex;"><span>                                    device<span style="color:#f92672">=</span>device, dtype<span style="color:#f92672">=</span>dtype)
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>register_buffer(
</span></span><span style="display:flex;"><span>            <span style="color:#e6db74">&#34;y_centers&#34;</span>, torch<span style="color:#f92672">.</span>tensor([<span style="color:#ae81ff">0.0</span>, <span style="color:#ae81ff">0.5</span>, <span style="color:#ae81ff">1.5</span>, <span style="color:#ae81ff">1.0</span>],
</span></span><span style="display:flex;"><span>                                    device<span style="color:#f92672">=</span>device, dtype<span style="color:#f92672">=</span>dtype)
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">forward</span>(self, coordinates: Tensor) <span style="color:#f92672">-&gt;</span> Tensor:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Compute potential energy.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> _calculate_potential(
</span></span><span style="display:flex;"><span>            coordinates, self<span style="color:#f92672">.</span>A, self<span style="color:#f92672">.</span>a, self<span style="color:#f92672">.</span>b, self<span style="color:#f92672">.</span>c,
</span></span><span style="display:flex;"><span>            self<span style="color:#f92672">.</span>x_centers, self<span style="color:#f92672">.</span>y_centers
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">force</span>(self, coordinates: Tensor) <span style="color:#f92672">-&gt;</span> Tensor:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Compute forces (negative gradient).&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">if</span> self<span style="color:#f92672">.</span>use_autograd:
</span></span><span style="display:flex;"><span>            coordinates <span style="color:#f92672">=</span> coordinates<span style="color:#f92672">.</span>requires_grad_(<span style="color:#66d9ef">True</span>)
</span></span><span style="display:flex;"><span>            potential <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>forward(coordinates)
</span></span><span style="display:flex;"><span>            grad <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>autograd<span style="color:#f92672">.</span>grad(potential<span style="color:#f92672">.</span>sum(), coordinates)[<span style="color:#ae81ff">0</span>]
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">return</span> <span style="color:#f92672">-</span>grad
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">else</span>:
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">return</span> _calculate_force(
</span></span><span style="display:flex;"><span>                coordinates, self<span style="color:#f92672">.</span>A, self<span style="color:#f92672">.</span>a, self<span style="color:#f92672">.</span>b, self<span style="color:#f92672">.</span>c,
</span></span><span style="display:flex;"><span>                self<span style="color:#f92672">.</span>x_centers, self<span style="color:#f92672">.</span>y_centers
</span></span><span style="display:flex;"><span>            )
</span></span></code></pre></div><p>The implementation stores its parameters with <code>register_buffer</code>, a subtle but important PyTorch best practice. Buffers move with the module when you call <code>.to(device)</code>, whereas plain tensor attributes stay behind, a common pitfall that leads to frustrating device mismatch errors. Buffers are also saved in the module&rsquo;s <code>state_dict</code> and are synchronized by <strong>DistributedDataParallel (DDP)</strong> broadcasting, which prevents silent failures when scaling training to multi-GPU clusters. The <code>use_autograd</code> flag switches between analytical and automatic differentiation.</p>
<h3 id="compiling-the-force-kernel">Compiling the Force Kernel</h3>
<p>Two decisions make the force path fast while keeping the rest flexible:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">_calculate_potential</span>(coordinates: Tensor, A: Tensor, a: Tensor,
</span></span><span style="display:flex;"><span>                        b: Tensor, c: Tensor, x_centers: Tensor,
</span></span><span style="display:flex;"><span>                        y_centers: Tensor) <span style="color:#f92672">-&gt;</span> Tensor:
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;Energy. Left eager (uncompiled) so autograd second derivatives,
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    e.g. the Hessian, keep working: torch.compile does not support
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    double-backward, and energy is computed per save, not per step.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>    coords <span style="color:#f92672">=</span> coordinates<span style="color:#f92672">.</span>view(<span style="color:#f92672">-</span><span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">2</span>)
</span></span><span style="display:flex;"><span>    x, y <span style="color:#f92672">=</span> coords[:, <span style="color:#ae81ff">0</span>], coords[:, <span style="color:#ae81ff">1</span>]
</span></span><span style="display:flex;"><span>    dx <span style="color:#f92672">=</span> x<span style="color:#f92672">.</span>unsqueeze(<span style="color:#f92672">-</span><span style="color:#ae81ff">1</span>) <span style="color:#f92672">-</span> x_centers
</span></span><span style="display:flex;"><span>    dy <span style="color:#f92672">=</span> y<span style="color:#f92672">.</span>unsqueeze(<span style="color:#f92672">-</span><span style="color:#ae81ff">1</span>) <span style="color:#f92672">-</span> y_centers
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    potential <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>sum(
</span></span><span style="display:flex;"><span>        A <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>exp(a <span style="color:#f92672">*</span> dx<span style="color:#f92672">**</span><span style="color:#ae81ff">2</span> <span style="color:#f92672">+</span> b <span style="color:#f92672">*</span> dx <span style="color:#f92672">*</span> dy <span style="color:#f92672">+</span> c <span style="color:#f92672">*</span> dy<span style="color:#f92672">**</span><span style="color:#ae81ff">2</span>),
</span></span><span style="display:flex;"><span>        dim<span style="color:#f92672">=-</span><span style="color:#ae81ff">1</span>
</span></span><span style="display:flex;"><span>    )
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> potential<span style="color:#f92672">.</span>view(coordinates<span style="color:#f92672">.</span>shape[:<span style="color:#f92672">-</span><span style="color:#ae81ff">1</span>])
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#a6e22e">@torch.compile</span>(dynamic<span style="color:#f92672">=</span><span style="color:#66d9ef">True</span>)
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">_calculate_force</span>(coordinates: Tensor, A: Tensor, a: Tensor,
</span></span><span style="display:flex;"><span>                    b: Tensor, c: Tensor, x_centers: Tensor,
</span></span><span style="display:flex;"><span>                    y_centers: Tensor) <span style="color:#f92672">-&gt;</span> Tensor:
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;Forces (negative gradient). Compiled: this is the hot path,
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    called once per simulation step.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>    coords <span style="color:#f92672">=</span> coordinates<span style="color:#f92672">.</span>view(<span style="color:#f92672">-</span><span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">2</span>)
</span></span><span style="display:flex;"><span>    x, y <span style="color:#f92672">=</span> coords[:, <span style="color:#ae81ff">0</span>], coords[:, <span style="color:#ae81ff">1</span>]
</span></span><span style="display:flex;"><span>    dx <span style="color:#f92672">=</span> x<span style="color:#f92672">.</span>unsqueeze(<span style="color:#f92672">-</span><span style="color:#ae81ff">1</span>) <span style="color:#f92672">-</span> x_centers
</span></span><span style="display:flex;"><span>    dy <span style="color:#f92672">=</span> y<span style="color:#f92672">.</span>unsqueeze(<span style="color:#f92672">-</span><span style="color:#ae81ff">1</span>) <span style="color:#f92672">-</span> y_centers
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    exp_terms <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>exp(a <span style="color:#f92672">*</span> dx<span style="color:#f92672">**</span><span style="color:#ae81ff">2</span> <span style="color:#f92672">+</span> b <span style="color:#f92672">*</span> dx <span style="color:#f92672">*</span> dy <span style="color:#f92672">+</span> c <span style="color:#f92672">*</span> dy<span style="color:#f92672">**</span><span style="color:#ae81ff">2</span>)
</span></span><span style="display:flex;"><span>    A_exp <span style="color:#f92672">=</span> A <span style="color:#f92672">*</span> exp_terms
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    grad_x <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>sum(A_exp <span style="color:#f92672">*</span> (<span style="color:#ae81ff">2</span> <span style="color:#f92672">*</span> a <span style="color:#f92672">*</span> dx <span style="color:#f92672">+</span> b <span style="color:#f92672">*</span> dy), dim<span style="color:#f92672">=-</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>    grad_y <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>sum(A_exp <span style="color:#f92672">*</span> (b <span style="color:#f92672">*</span> dx <span style="color:#f92672">+</span> <span style="color:#ae81ff">2</span> <span style="color:#f92672">*</span> c <span style="color:#f92672">*</span> dy), dim<span style="color:#f92672">=-</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    forces <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>stack([<span style="color:#f92672">-</span>grad_x, <span style="color:#f92672">-</span>grad_y], dim<span style="color:#f92672">=-</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>    new_shape <span style="color:#f92672">=</span> list(coordinates<span style="color:#f92672">.</span>shape[:<span style="color:#f92672">-</span><span style="color:#ae81ff">1</span>]) <span style="color:#f92672">+</span> [<span style="color:#ae81ff">2</span>]
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> forces<span style="color:#f92672">.</span>view(new_shape)
</span></span></code></pre></div><p>The force kernel is decorated with <code>@torch.compile(dynamic=True)</code>, which traces it once and hands it to TorchInductor to fuse the pointwise operations (the exponentials and polynomial terms) and cut Python&rsquo;s dispatch overhead in the inner loop that runs millions of times per simulation. The <code>dynamic=True</code> flag keeps a single compiled trace valid across particle counts, so changing the batch size does not trigger a recompile. The energy is left eager on purpose: <code>torch.compile</code> does not support double-backward, so leaving <code>forward</code> uncompiled keeps autograd second derivatives (the Hessian) available, and since the energy is only evaluated when an observable is saved, it is not on the hot path anyway.</p>
<h3 id="performance-analytical-vs-automatic-differentiation">Performance: Analytical vs. Automatic Differentiation</h3>
<p>A key design decision is whether to use analytical derivatives or automatic differentiation. I benchmarked both on an Apple M1 Max (CPU, PyTorch 2.x). To get a stable measurement, each configuration runs 100 warm-up iterations, and I report the median wall-clock time over 5 runs of 1000 iterations, which filters out operating-system jitter.</p>















<figure class="post-figure center ">
    <img src="/img/muller-brown/muller-brown-throughput-analysis.webp"
         alt="Throughput comparison showing the analytical force kernel outperforming autograd across batch sizes"
         title="Throughput comparison showing the analytical force kernel outperforming autograd across batch sizes"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Force-evaluation throughput, analytical vs autograd, across batch sizes. The analytical kernel is about 4x faster (3-7x depending on batch size).</figcaption>
    
</figure>
















<figure class="post-figure center ">
    <img src="/img/muller-brown/muller-brown-time-per-particle.webp"
         alt="Per-particle computation time showing analytical derivatives maintain sub-microsecond performance for large systems"
         title="Per-particle computation time showing analytical derivatives maintain sub-microsecond performance for large systems"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Per-particle computation time. Analytical derivatives maintain sub-microsecond performance for large systems.</figcaption>
    
</figure>

<p>The speedup comes from bypassing the computational graph that PyTorch&rsquo;s autograd engine builds. With a hand-derived analytical gradient, we skip that machinery entirely. By contrast, every <code>autograd.grad()</code> call must:</p>
<ol>
<li><strong>Build a tape</strong> of operations during the forward pass</li>
<li><strong>Traverse the graph</strong> backward to compute gradients</li>
<li><strong>Allocate intermediate tensors</strong> for each node</li>
</ol>
<p>For iterative workloads like molecular dynamics (millions of force evaluations per trajectory), eliminating this overhead is critical. The analytical kernel computes forces directly in a single fused operation: no graph, no tape, no intermediate allocations.</p>
<p><strong>When to use each approach:</strong></p>
<ul>
<li><strong>Analytical</strong>: Best for production molecular dynamics where forces are computed millions of times. The speedup directly reduces wall-clock simulation time.</li>
<li><strong>Autograd</strong>: Better for prototyping, machine learning training loops, or when implementing new potentials where correctness verification is paramount. The convenience and guaranteed accuracy often outweigh performance costs during development.</li>
</ul>
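<p>Whichever path is used in production, the hand-derived forces are worth checking against an independent reference. The sketch below is a minimal, standard-library-only illustration: it uses central finite differences as a stand-in for the autograd reference, and the helper names (<code>energy</code>, <code>force</code>, <code>force_fd</code>) and the test point are assumptions made for this example, not code from the repository. The parameters are the standard Müller-Brown constants.</p>

```python
import math

# Standard Müller-Brown parameters (Müller and Brown, 1979).
A = (-200.0, -100.0, -170.0, 15.0)
a = (-1.0, -1.0, -6.5, 0.7)
b = (0.0, 0.0, 11.0, 0.6)
c = (-10.0, -10.0, -6.5, 0.7)
X0 = (1.0, 0.0, -0.5, -1.0)
Y0 = (0.0, 0.5, 1.5, 1.0)


def energy(x, y):
    """Potential energy: sum of the four exponential terms."""
    total = 0.0
    for Ak, ak, bk, ck, xk, yk in zip(A, a, b, c, X0, Y0):
        dx, dy = x - xk, y - yk
        total += Ak * math.exp(ak * dx**2 + bk * dx * dy + ck * dy**2)
    return total


def force(x, y):
    """Analytical force (negative gradient), term by term."""
    fx = fy = 0.0
    for Ak, ak, bk, ck, xk, yk in zip(A, a, b, c, X0, Y0):
        dx, dy = x - xk, y - yk
        e = Ak * math.exp(ak * dx**2 + bk * dx * dy + ck * dy**2)
        fx -= e * (2 * ak * dx + bk * dy)
        fy -= e * (bk * dx + 2 * ck * dy)
    return fx, fy


def force_fd(x, y, h=1e-5):
    """Central finite-difference force, used only as an independent reference."""
    fx = -(energy(x + h, y) - energy(x - h, y)) / (2 * h)
    fy = -(energy(x, y + h) - energy(x, y - h)) / (2 * h)
    return fx, fy


point = (-0.3, 0.9)  # arbitrary test point (hypothetical), not a stationary point
analytic = force(*point)
reference = force_fd(*point)
max_abs_error = max(abs(p - q) for p, q in zip(analytic, reference))
print(f"analytic={analytic}, reference={reference}, max error={max_abs_error:.2e}")
```

<p>A check like this belongs in a test suite and not on the hot path: it costs four extra energy evaluations per point and is only accurate to the finite-difference truncation error.</p>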
<h3 id="molecular-dynamics-simulations">Molecular Dynamics Simulations</h3>
<p>To demonstrate the PyTorch implementation in action, I performed Langevin dynamics simulations in different energy basins. These simulations reveal how particles behave when confined to different regions of the potential energy surface.</p>
<h4 id="simulation-parameters">Simulation Parameters</h4>
<p>Each simulation ran for 3600 steps with a step size of 0.01 time units, a friction coefficient of 1.0, and a temperature of 25.0 in reduced units. Each trajectory started from the equilibrium position (the local minimum) of its basin and shows the characteristic thermal fluctuations around that minimum.</p>
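<p>To make these parameters concrete, here is a minimal pure-Python sketch of a BAOAB Langevin integrator for a single particle. It is an illustration under stated assumptions and not the repository&rsquo;s sampler: unit mass, $k_B = 1$ (so the thermal energy equals the temperature, 25.0), zero initial velocity, and the function name <code>baoab</code> and its arguments are all hypothetical choices for this example.</p>

```python
import math
import random

# Standard Müller-Brown parameters (Müller and Brown, 1979).
A = (-200.0, -100.0, -170.0, 15.0)
a = (-1.0, -1.0, -6.5, 0.7)
b = (0.0, 0.0, 11.0, 0.6)
c = (-10.0, -10.0, -6.5, 0.7)
X0 = (1.0, 0.0, -0.5, -1.0)
Y0 = (0.0, 0.5, 1.5, 1.0)


def force(x, y):
    """Analytical force (negative gradient of the potential)."""
    fx = fy = 0.0
    for Ak, ak, bk, ck, xk, yk in zip(A, a, b, c, X0, Y0):
        dx, dy = x - xk, y - yk
        e = Ak * math.exp(ak * dx**2 + bk * dx * dy + ck * dy**2)
        fx -= e * (2 * ak * dx + bk * dy)
        fy -= e * (bk * dx + 2 * ck * dy)
    return fx, fy


def baoab(x, y, n_steps=3600, dt=0.01, gamma=1.0, kT=25.0, mass=1.0, seed=0):
    """Integrate Langevin dynamics with the BAOAB splitting; return positions."""
    rng = random.Random(seed)
    vx = vy = 0.0  # assumption: particle starts at rest
    c1 = math.exp(-gamma * dt)                    # velocity decay over one step
    c2 = math.sqrt(kT / mass * (1.0 - c1 * c1))   # matching noise amplitude
    fx, fy = force(x, y)
    trajectory = [(x, y)]
    for _ in range(n_steps):
        # B: half-step velocity kick from the current force
        vx += 0.5 * dt * fx / mass
        vy += 0.5 * dt * fy / mass
        # A: half-step position drift
        x += 0.5 * dt * vx
        y += 0.5 * dt * vy
        # O: exact Ornstein-Uhlenbeck update (friction plus thermal noise)
        vx = c1 * vx + c2 * rng.gauss(0.0, 1.0)
        vy = c1 * vy + c2 * rng.gauss(0.0, 1.0)
        # A: second half-step position drift
        x += 0.5 * dt * vx
        y += 0.5 * dt * vy
        # B: second half-step kick using the force at the new position
        fx, fy = force(x, y)
        vx += 0.5 * dt * fx / mass
        vy += 0.5 * dt * fy / mass
        trajectory.append((x, y))
    return trajectory


# Start from the MA minimum, as in the Basin MA run described in this section.
trajectory = baoab(-0.558, 1.442)
print(len(trajectory), trajectory[-1])
```

<p>The force is evaluated once per step and reused for the first kick of the next step, which is why the force kernel dominates the cost of a long trajectory.</p>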
<h4 id="basin-ma-deep-reactant-minimum">Basin MA: Deep Reactant Minimum</h4>
<p>The deepest energy well (-146.70 in reduced units) shows highly constrained motion due to the steep energy barriers surrounding it.</p>















<figure class="post-figure center ">
    <img src="/img/muller-brown/muller-brown-basin-ma-position-distributions.webp"
         alt="Position distributions in Basin MA showing tight confinement"
         title="Position distributions in Basin MA showing tight confinement"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Position distributions in Basin MA. The particle remains tightly confined around (-0.558, 1.442) due to the deep potential well.</figcaption>
    
</figure>
















<figure class="post-figure center ">
    <img src="/img/muller-brown/muller-brown-basin-ma-time-series.webp"
         alt="Time evolution of coordinates in Basin MA"
         title="Time evolution of coordinates in Basin MA"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Time series showing small-amplitude oscillations around the equilibrium position. The deep well severely restricts thermal motion.</figcaption>
    
</figure>
















<figure class="post-figure center ">
    <img src="/img/muller-brown/muller-brown-basin-ma-trajectory.webp"
         alt="Trajectory overlaid on potential surface for Basin MA"
         title="Trajectory overlaid on potential surface for Basin MA"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Trajectory visualization showing the particle&rsquo;s motion confined to a small region around the minimum. High energy barriers prevent escape on this time scale.</figcaption>
    
</figure>

<div style="position: relative; padding-bottom: 56.25%; height: 0; overflow: hidden;">
      <iframe allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share; fullscreen" loading="eager" referrerpolicy="strict-origin-when-cross-origin" src="https://www.youtube-nocookie.com/embed/woVM90qXUQs?autoplay=0&amp;controls=1&amp;end=0&amp;loop=0&amp;mute=0&amp;start=0" style="position: absolute; top: 0; left: 0; width: 100%; height: 100%; border:0;" title="YouTube video"></iframe>
    </div>

<h4 id="basin-mb-product-minimum">Basin MB: Product Minimum</h4>
<p>The product minimum (-108.17 in reduced units) shows behavior intermediate between that of the deep reactant well (MA) and the shallow intermediate basin (MC).</p>















<figure class="post-figure center ">
    <img src="/img/muller-brown/muller-brown-basin-mb-position-distributions.webp"
         alt="Position distributions in Basin MB showing moderate confinement"
         title="Position distributions in Basin MB showing moderate confinement"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Position distributions in Basin MB. The particle shows moderate thermal motion around (0.623, 0.028), with confinement intermediate between that of the deep MA basin and the shallow MC basin.</figcaption>
    
</figure>
















<figure class="post-figure center ">
    <img src="/img/muller-brown/muller-brown-basin-mb-time-series.webp"
         alt="Time evolution showing moderate amplitude fluctuations in Basin MB"
         title="Time evolution showing moderate amplitude fluctuations in Basin MB"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Time series demonstrating moderate amplitude fluctuations. The particle explores a region larger than MA but more constrained than the shallow MC basin.</figcaption>
    
</figure>
















<figure class="post-figure center ">
    <img src="/img/muller-brown/muller-brown-basin-mb-trajectory.webp"
         alt="Basin MB trajectory showing balanced exploration"
         title="Basin MB trajectory showing balanced exploration"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The trajectory shows balanced thermal exploration within the product basin. The moderate well depth allows reasonable sampling while maintaining basin confinement.</figcaption>
    
</figure>

<div style="position: relative; padding-bottom: 56.25%; height: 0; overflow: hidden;">
      <iframe allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share; fullscreen" loading="eager" referrerpolicy="strict-origin-when-cross-origin" src="https://www.youtube-nocookie.com/embed/gdAHme07bGs?autoplay=0&amp;controls=1&amp;end=0&amp;loop=0&amp;mute=0&amp;start=0" style="position: absolute; top: 0; left: 0; width: 100%; height: 100%; border:0;" title="YouTube video"></iframe>
    </div>

<h4 id="transition-example">Transition Example</h4>
<p>Running a longer simulation demonstrates transitions between basins:</p>















<figure class="post-figure center ">
    <img src="/img/muller-brown/muller-brown-transition-trajectory.webp"
         alt="Transition trajectory between basins"
         title="Transition trajectory between basins"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The trajectory illustrates the particle&rsquo;s movement between the different basins, highlighting the energy barriers and pathways involved.</figcaption>
    
</figure>
















<figure class="post-figure center ">
    <img src="/img/muller-brown/muller-brown-transition-time-series.webp"
         alt="Time evolution of positions during transition"
         title="Time evolution of positions during transition"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Time series showing the evolution of the particle&rsquo;s position during the transition between basins.</figcaption>
    
</figure>
















<figure class="post-figure center ">
    <img src="/img/muller-brown/muller-brown-transition-position-distributions.webp"
         alt="Position distributions during the transition between basins"
         title="Position distributions during the transition between basins"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Position distributions over the transition run, showing which regions of the energy landscape the particle visits.</figcaption>
    
</figure>

<div style="position: relative; padding-bottom: 56.25%; height: 0; overflow: hidden;">
      <iframe allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share; fullscreen" loading="eager" referrerpolicy="strict-origin-when-cross-origin" src="https://www.youtube-nocookie.com/embed/dVFe_4KZbps?autoplay=0&amp;controls=1&amp;end=0&amp;loop=0&amp;mute=0&amp;start=0" style="position: absolute; top: 0; left: 0; width: 100%; height: 100%; border:0;" title="YouTube video"></iframe>
    </div>

<h3 id="integration-with-modern-workflows">Integration with Modern Workflows</h3>
<p>The PyTorch implementation integrates naturally with machine learning workflows. It can serve as a component in larger computational graphs, enable gradient-based optimization, and bridge classical molecular simulation with deep learning approaches.</p>
<p>For readers interested in broader molecular ML applications, this implementation pairs well with other molecular representation methods. The <a href="/posts/molecular-descriptor-coulomb-matrix/">Coulomb matrix approach</a> offers complementary perspectives on encoding molecular structure, while the <a href="/posts/kabsch-algorithm/#the-math">Kabsch algorithm</a> provides essential tools for structural alignment.</p>
<p>The complete implementation is available on <a href="https://github.com/hunter-heidenreich/Muller-Brown-Potential">GitHub</a><sup id="fnref:9"><a href="#fn:9" class="footnote-ref" role="doc-noteref">9</a></sup><sup id="fnref:10"><a href="#fn:10" class="footnote-ref" role="doc-noteref">10</a></sup><sup id="fnref:11"><a href="#fn:11" class="footnote-ref" role="doc-noteref">11</a></sup>, including benchmarking scripts, visualization tools, the test suite, and examples for optimization and molecular dynamics.</p>
<h2 id="conclusion">Conclusion</h2>
<p>The Müller-Brown potential exemplifies how a well-designed benchmark can evolve with a field. Born from 1970s computational constraints, it provided a simple way to test algorithms when quantum chemistry calculations were expensive. Its design (simple enough to compute instantly, complex enough to break naive approaches) made it invaluable for algorithm development.</p>
<p>Today, it serves new purposes in the machine learning era. This PyTorch implementation pairs a hand-derived analytical force kernel (compiled with <code>torch.compile</code>) with an autograd reference, a BAOAB Langevin sampler, and a test suite that checks the sampler against the canonical distribution. The analytical kernel&rsquo;s roughly 4x advantage matters for intensive simulations, while PyTorch&rsquo;s flexibility enables integration with neural network potentials and enhanced sampling methods.</p>
<p>The potential&rsquo;s evolution from practical necessity to pedagogical tool to machine learning benchmark demonstrates the value of foundational test cases. As computational chemistry continues evolving, reliable standards like the Müller-Brown potential become even more important for rigorous method development and comparison.</p>
<p>For the complete implementation with benchmarking scripts, the BAOAB Langevin simulator, and visualization tools, see the <a href="https://github.com/hunter-heidenreich/Muller-Brown-Potential">GitHub repository</a>. The full project with architecture details and performance results is at the <a href="/projects/muller-brown-pytorch/">Müller-Brown Potential: A PyTorch ML Testbed project page</a>.</p>
<div class="footnotes" role="doc-endnotes">
<hr>
<ol>
<li id="fn:1">
<p>Müller, K., &amp; Brown, L. D. (1979). Location of saddle points and minimum energy paths by a constrained simplex optimization procedure. <em>Theoretica Chimica Acta</em>, 53, 75-93. <a href="https://link.springer.com/article/10.1007/BF00547608">https://link.springer.com/article/10.1007/BF00547608</a>&#160;<a href="#fnref:1" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:2">
<p>Lennard-Jones, J. E. (1931). Cohesion. <em>Proceedings of the Physical Society</em>, 43(5), 461-482. <a href="https://doi.org/10.1088/0959-5309/43/5/301">https://doi.org/10.1088/0959-5309/43/5/301</a>&#160;<a href="#fnref:2" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:3">
<p>Henkelman, G., &amp; Jónsson, H. (2000). Improved tangent estimate in the nudged elastic band method for finding minimum energy paths and saddle points. <em>Journal of Chemical Physics</em>, 113(22), 9901-9904. <a href="https://doi.org/10.1063/1.1329672">https://doi.org/10.1063/1.1329672</a>&#160;<a href="#fnref:3" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:4">
<p>Dellago, C., Bolhuis, P. G., Csajka, F. S., &amp; Chandler, D. (1998). Transition path sampling and the calculation of rate constants. <em>Journal of Chemical Physics</em>, 108(5), 1964-1977. <a href="https://doi.org/10.1063/1.475562">https://doi.org/10.1063/1.475562</a>&#160;<a href="#fnref:4" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:5">
<p>Behler, J., &amp; Parrinello, M. (2007). Generalized neural-network representation of high-dimensional potential-energy surfaces. <em>Physical Review Letters</em>, 98(14), 146401. <a href="https://doi.org/10.1103/PhysRevLett.98.146401">https://doi.org/10.1103/PhysRevLett.98.146401</a>&#160;<a href="#fnref:5" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:6">
<p>Smith, J. S., Isayev, O., &amp; Roitberg, A. E. (2017). ANI-1: an extensible neural network potential with DFT accuracy at force field computational cost. <em>Chemical Science</em>, 8(4), 3192-3203. <a href="https://doi.org/10.1039/C6SC05720A">https://doi.org/10.1039/C6SC05720A</a>&#160;<a href="#fnref:6" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:7">
<p>Sipka, M., Dietschreit, J. C. B., Grajciar, L., &amp; Gómez-Bombarelli, R. (2023). Differentiable simulations for enhanced sampling of rare events. In <em>International Conference on Machine Learning</em> (pp. 31990-32007). PMLR.&#160;<a href="#fnref:7" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:8">
<p>Sun, L., Vandermause, J., Batzner, S., Xie, Y., Clark, D., Chen, W., &amp; Kozinsky, B. (2022). Multitask machine learning of collective variables for enhanced sampling of rare events. <em>Journal of Chemical Theory and Computation</em>, 18(4), 2341-2353.&#160;<a href="#fnref:8" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:9">
<p>LED-Molecular Repository - Original implementation of the Müller-Brown potential. <a href="https://github.com/cselab/LED-Molecular">https://github.com/cselab/LED-Molecular</a>&#160;<a href="#fnref:9" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:10">
<p>Vlachas, P. R., Zavadlav, J., Praprotnik, M., &amp; Koumoutsakos, P. (2022). Accelerated simulations of molecular systems through learning of effective dynamics. <em>Journal of Chemical Theory and Computation</em>, 18(1), 538-549. <a href="https://pubs.acs.org/doi/10.1021/acs.jctc.1c00809">https://pubs.acs.org/doi/10.1021/acs.jctc.1c00809</a>&#160;<a href="#fnref:10" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:11">
<p>Vlachas, P. R., Arampatzis, G., Uhler, C., &amp; Koumoutsakos, P. (2022). Multiscale simulations of complex systems by learning their effective dynamics. <em>Nature Machine Intelligence</em>. <a href="https://www.nature.com/articles/s42256-022-00464-w">https://www.nature.com/articles/s42256-022-00464-w</a>&#160;<a href="#fnref:11" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
</ol>
</div>
]]></content:encoded></item><item><title>Efficient DFT Hamiltonian Prediction via Adaptive Sparsity</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-simulation/ml-potentials/efficient-dft-hamiltonian-predicton-sphnet/</link><pubDate>Sat, 23 Aug 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-simulation/ml-potentials/efficient-dft-hamiltonian-predicton-sphnet/</guid><description>Luo et al. introduce SPHNet, using adaptive sparsity to achieve up to 7x speedup in SE(3)-equivariant Hamiltonian prediction.</description><content:encoded><![CDATA[<h2 id="core-innovation-adaptive-sparsity-in-se3-networks">Core Innovation: Adaptive Sparsity in SE(3) Networks</h2>
<p>This is a <strong>methodological paper</strong> introducing a novel architecture and training curriculum to solve efficiency bottlenecks in Geometric Deep Learning. It directly tackles the primary computational bottleneck in modern SE(3)-equivariant graph neural networks (the tensor product operation) and proposes a generalizable solution through adaptive network sparsification.</p>
<h2 id="the-computational-bottleneck-in-dft-hamiltonian-prediction">The Computational Bottleneck in DFT Hamiltonian Prediction</h2>
<p>SE(3)-equivariant networks are accurate but unscalable for DFT Hamiltonian prediction due to two key bottlenecks:</p>
<ul>
<li><strong>Atom Scaling</strong>: Tensor Product (TP) operations grow quadratically with atoms ($N^2$).</li>
<li><strong>Basis Set Scaling</strong>: Computational complexity grows with the sixth power of the angular momentum order ($L^6$). Larger basis sets (e.g., def2-TZVP) require higher orders ($L=6$), making them prohibitively slow.</li>
</ul>
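<p>As a rough illustration of the basis-set term, assume the tensor-product cost $C$ is proportional to $L^6$ with an unchanged prefactor (a simplification made only for this estimate). Moving from $L = 4$ (def2-SVP) to $L = 6$ (def2-TZVP) then multiplies that cost by roughly an order of magnitude:</p>

$$\frac{C(L=6)}{C(L=4)} = \left(\frac{6}{4}\right)^{6} = \frac{729}{64} \approx 11.4$$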
<p>Existing SE(3)-equivariant models cannot handle large molecules (40-100 atoms) with high-quality basis sets, limiting their practical applicability in computational chemistry.</p>
<h2 id="sphnet-architecture-and-the-three-phase-sparsity-scheduler">SPHNet Architecture and the Three-Phase Sparsity Scheduler</h2>
<p><strong>SPHNet</strong> introduces <strong>Adaptive Sparsity</strong> to prune redundant computations through two learned gates and a training scheduler:</p>
<ol>
<li><strong>Sparse Pair Gate</strong>: Learns which atom pairs to include in message passing, adapting the interaction graph based on importance.</li>
<li><strong>Sparse TP Gate</strong>: Filters which spherical harmonic triplets $(l_1, l_2, l_3)$ are computed in tensor product operations, pruning higher-order combinations that contribute less to accuracy.</li>
<li><strong>Three-Phase Sparsity Scheduler</strong>: A training curriculum (Random → Adaptive → Fixed) that enables stable convergence to high-performing sparse subnetworks.</li>
</ol>
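<p>The three phases listed above can be sketched as a selection rule over learned importance scores. This is a hedged, simplified illustration and not the SPHNet implementation: the function name <code>select_active</code>, the phase strings, and the example scores are hypothetical, and the sketch assumes each gate keeps a fixed fraction of candidates (atom pairs or $(l_1, l_2, l_3)$ triplets) per step.</p>

```python
import random


def select_active(scores, sparsity, phase, rng, frozen=None):
    """Return sorted indices of the candidates kept at the given sparsity rate.

    scores   -- learned importance weights, one per candidate (hypothetical)
    sparsity -- fraction of candidates to prune, e.g. 0.7
    phase    -- "random", "adaptive", or "fixed"
    frozen   -- selection saved at the end of the adaptive phase (fixed phase only)
    """
    n_keep = max(1, round(len(scores) * (1.0 - sparsity)))
    if phase == "random":
        # Random phase: ignore the scores so every candidate gets trained.
        return sorted(rng.sample(range(len(scores)), n_keep))
    if phase == "adaptive":
        # Adaptive phase: keep the highest-scoring candidates.
        ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
        return sorted(ranked[:n_keep])
    if phase == "fixed":
        # Fixed phase: reuse one frozen selection, so the sparse graph is static.
        if frozen is None:
            raise ValueError("fixed phase needs a frozen selection")
        return frozen
    raise ValueError(f"unknown phase: {phase}")


rng = random.Random(0)
scores = [0.9, 0.1, 0.5, 0.7, 0.3, 0.2, 0.8, 0.4, 0.6, 0.05]
random_pick = select_active(scores, 0.7, "random", rng)
adaptive_pick = select_active(scores, 0.7, "adaptive", rng)
fixed_pick = select_active(scores, 0.7, "fixed", rng, frozen=adaptive_pick)
print(random_pick, adaptive_pick, fixed_pick)
```

<p>Freezing the selection in the last phase is what removes the per-step cost of rebuilding the sparse interaction set.</p>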
<p>Key insight: The Sparse Pair Gate learns to preserve long-range interactions (16-25 Angstrom) at higher rates than short-range ones. Short-range pairs are abundant and easier to learn, while rare long-range interactions require more samples for accurate representation, making them more critical to retain.</p>
<h2 id="benchmarks-and-ablation-studies">Benchmarks and Ablation Studies</h2>
<p>The authors evaluated SPHNet on three datasets (MD17, QH9, and PubChemQH) with varying molecule sizes and basis set complexities. Baselines include SchNOrb, PhiSNet, QHNet, and WANet. SchNOrb and PhiSNet results are limited to MD17, as those models are designed for trajectory datasets. WANet was not open-sourced, so only partial metrics from its paper are reported.</p>
<h3 id="evaluation-metrics">Evaluation Metrics</h3>
<ul>
<li><strong>Hamiltonian MAE ($H$)</strong>: Mean absolute error between predicted and DFT-computed Hamiltonian matrices, in Hartrees ($E_h$)</li>
<li><strong>Occupied Orbital Energy MAE ($\epsilon$)</strong>: Mean absolute error of all occupied molecular orbital energies derived from the predicted Hamiltonian</li>
<li><strong>Orbital Coefficient Similarity ($\psi$)</strong>: Cosine similarity of occupied molecular orbital coefficients between predicted and reference wavefunctions</li>
</ul>
<h3 id="ablation-studies">Ablation Studies</h3>
<p><strong>Sparse Gates</strong> (on PubChemQH):</p>
<table>
  <thead>
      <tr>
          <th>Configuration</th>
          <th>$H$ [$10^{-6} E_h$] $\downarrow$</th>
          <th>Memory [GB] $\downarrow$</th>
          <th>Speedup $\uparrow$</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Both gates</td>
          <td>97.31</td>
          <td>5.62</td>
          <td>7.09x</td>
      </tr>
      <tr>
          <td>Pair Gate only</td>
          <td>87.70</td>
          <td>6.98</td>
          <td>2.73x</td>
      </tr>
      <tr>
          <td>TP Gate only</td>
          <td>94.31</td>
          <td>8.04</td>
          <td>3.98x</td>
      </tr>
      <tr>
          <td>Neither gate</td>
          <td>86.35</td>
          <td>10.91</td>
          <td>1.73x</td>
      </tr>
  </tbody>
</table>
<p>Comparing the full model with the variant that lacks each gate, the Sparse Pair Gate contributes a 78% speedup with a 30% memory reduction, and the Sparse TP Gate (pruning 70% of combinations) yields a 160% speedup. Both gates together achieve the highest speedup, though accuracy decreases slightly compared to no gating.</p>
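<p>The percentages quoted for each gate can be reproduced from the table if each one is read as the full model compared against the variant missing that gate. That reading is an assumption made here, and the dictionary keys below are just labels for the table rows:</p>

```python
# Speedup and memory values copied from the sparse-gate ablation table.
speedup = {"both": 7.09, "pair_only": 2.73, "tp_only": 3.98, "neither": 1.73}
memory_gb = {"both": 5.62, "pair_only": 6.98, "tp_only": 8.04, "neither": 10.91}

# Gain from adding the Pair Gate on top of the TP Gate.
pair_gate_gain = speedup["both"] / speedup["tp_only"] - 1.0
# Gain from adding the TP Gate on top of the Pair Gate.
tp_gate_gain = speedup["both"] / speedup["pair_only"] - 1.0
# Memory saved by adding the Pair Gate on top of the TP Gate.
pair_gate_memory_cut = 1.0 - memory_gb["both"] / memory_gb["tp_only"]

print(f"{pair_gate_gain:.0%} {tp_gate_gain:.0%} {pair_gate_memory_cut:.0%}")
# prints: 78% 160% 30%
```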
<p><strong>Three-Phase Scheduler</strong>: Removing the random phase causes convergence to local optima ($112.68 \pm 10.75$ vs $97.31 \pm 0.52$). Removing the adaptive phase increases variance and lowers accuracy ($122.79 \pm 19.02$). Removing the fixed phase has minimal accuracy impact but reduces speedup from 7.09x to 5.45x due to dynamic graph overhead.</p>
<p><strong>Sparsity Rate</strong>: The critical sparsity threshold scales with system complexity: 30% for MD17 (small molecules), 40% for QH9 (medium), and 70% for PubChemQH (large). Beyond the threshold, MAE increases sharply. Computational cost decreases approximately linearly with sparsity rate.</p>
<h3 id="transferability-to-other-models">Transferability to Other Models</h3>
<p>To demonstrate the speedup is architecture-agnostic, the authors applied the Sparse Pair Gate and Sparse TP Gate to the QHNet baseline on PubChemQH:</p>
<table>
  <thead>
      <tr>
          <th>Configuration</th>
          <th>$H$ [$10^{-6} E_h$] $\downarrow$</th>
          <th>Memory [GB] $\downarrow$</th>
          <th>Speedup $\uparrow$</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>QHNet baseline</td>
          <td>123.74</td>
          <td>22.50</td>
          <td>1.00x</td>
      </tr>
      <tr>
          <td>+ TP Gate</td>
          <td>128.16</td>
          <td>12.68</td>
          <td>2.04x</td>
      </tr>
      <tr>
          <td>+ Pair Gate</td>
          <td>126.27</td>
          <td>10.07</td>
          <td>1.66x</td>
      </tr>
      <tr>
          <td>+ Both gates</td>
          <td>128.89</td>
          <td>8.46</td>
          <td>3.30x</td>
      </tr>
  </tbody>
</table>
<p>The gates reduced QHNet&rsquo;s memory by 62% and improved speed by 3.3x with a modest accuracy trade-off, confirming that they are portable modules applicable to other SE(3)-equivariant architectures.</p>
<h2 id="performance-results">Performance Results</h2>
<h3 id="qh9-134k-molecules-leq-20-atoms">QH9 (134k molecules, $\leq$ 20 atoms)</h3>
<p>SPHNet achieves 3.3x to 4.0x speedup over QHNet across all four QH9 splits, with improved Hamiltonian MAE and orbital energy MAE. Memory drops to 0.23 GB/sample (33% of QHNet&rsquo;s 0.70 GB). On the stable-iid split, Hamiltonian MAE improves from 76.31 to 45.48 ($10^{-6} E_h$).</p>
<h3 id="pubchemqh-50k-molecules-40-100-atoms">PubChemQH (50k molecules, 40-100 atoms)</h3>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>$H$ [$10^{-6} E_h$] $\downarrow$</th>
          <th>$\epsilon$ [$E_h$] $\downarrow$</th>
          <th>$\psi$ [$10^{-2}$] $\uparrow$</th>
          <th>Memory [GB] $\downarrow$</th>
          <th>Speedup $\uparrow$</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>QHNet</td>
          <td>123.74</td>
          <td>3.33</td>
          <td>2.32</td>
          <td>22.5</td>
          <td>1.0x</td>
      </tr>
      <tr>
          <td>WANet</td>
          <td>99.98</td>
          <td><strong>1.17</strong></td>
          <td><strong>3.13</strong></td>
          <td>15.0</td>
          <td>2.4x</td>
      </tr>
      <tr>
          <td>SPHNet</td>
          <td><strong>97.31</strong></td>
          <td>2.16</td>
          <td>2.97</td>
          <td><strong>5.62</strong></td>
          <td><strong>7.1x</strong></td>
      </tr>
  </tbody>
</table>
<p>SPHNet achieves the best Hamiltonian MAE and efficiency, though WANet outperforms on orbital energy MAE and coefficient similarity. The higher speedup on PubChemQH (vs QH9) reflects greater computational redundancy in larger systems with higher-order basis sets ($L_{max} = 6$ for def2-TZVP vs $L_{max} = 4$ for def2-SVP).</p>
<h3 id="md17-small-molecule-trajectories">MD17 (Small Molecule Trajectories)</h3>
<p>SPHNet achieves accuracy comparable to QHNet and PhiSNet on four MD17 molecules (water, ethanol, malondialdehyde, uracil; 3-12 atoms). MD17 represents a simpler task where baseline models already perform well, leaving limited room for improvement. For water (3 atoms), the number of interaction combinations is inherently small, limiting the benefit of adaptive sparsification.</p>
<h3 id="scaling-limit">Scaling Limit</h3>
<p>SPHNet can train on systems with approximately 3000 atomic orbitals on a single A6000 GPU; the QHNet baseline runs out of memory at approximately 1800 orbitals. Memory consumption scales more favorably as molecule size increases.</p>
<h3 id="key-findings">Key Findings</h3>
<ul>
<li><strong>Adaptive sparsity scales with system complexity</strong>: The method is most effective for large systems where redundancy is high. For small molecules (e.g., water with only 3 atoms), every interaction is critical, so pruning hurts accuracy and yields negligible speedup.</li>
<li><strong>Long-range pair preservation</strong>: The Sparse Pair Gate selects long-range pairs (16-25 Angstrom) at higher rates than short-range ones. Short-range pairs are numerous and easier to learn, while rare long-range interactions are harder to represent and thus more critical to retain.</li>
<li><strong>Generalizable components</strong>: The sparsification techniques are portable modules, demonstrated by successful integration into QHNet with 3.3x speedup.</li>
<li><strong>Architecture ablation</strong>: Removing one Vectorial Node Interaction block or Spherical Node Interaction block significantly hurts accuracy, confirming the importance of the progressive order-increase design. Removing one Pair Construction block has less impact, suggesting room for further speedup.</li>
</ul>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/microsoft/SPHNet">SPHNet (GitHub)</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official implementation; archived by Microsoft (Dec 2025), read-only</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/datasets/EperLuo/PubChemQH">PubChemQH (Hugging Face)</a></td>
          <td>Dataset</td>
          <td>MIT</td>
          <td>50k molecules, 40-100 atoms, def2-TZVP basis</td>
      </tr>
  </tbody>
</table>
<p>No pre-trained model weights are provided. MD17 and QH9 are publicly available community datasets. Training requires 4x NVIDIA A100 (80GB) GPUs; benchmarking uses a single NVIDIA RTX A6000 (46GB).</p>
<h3 id="data">Data</h3>
<p>The experiments evaluated SPHNet on three datasets with different molecular sizes and basis set complexities. All datasets use DFT calculations as ground truth, with MD17 using the PBE exchange-correlation functional and QH9/PubChemQH using B3LYP.</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>Molecules</th>
          <th>Molecule Size</th>
          <th>Basis Set</th>
          <th>$L_{max}$</th>
          <th>Functional</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MD17</td>
          <td>4 systems</td>
          <td>3-12 atoms (water, ethanol, malondialdehyde, uracil)</td>
          <td>def2-SVP</td>
          <td>4</td>
          <td>PBE</td>
      </tr>
      <tr>
          <td>QH9</td>
          <td>134k</td>
          <td>$\leq$ 20 atoms (Stable/Dynamic splits)</td>
          <td>def2-SVP</td>
          <td>4</td>
          <td>B3LYP</td>
      </tr>
      <tr>
          <td>PubChemQH</td>
          <td>50k</td>
          <td>40-100 atoms</td>
          <td>def2-TZVP</td>
          <td>6</td>
          <td>B3LYP</td>
      </tr>
  </tbody>
</table>
<p><strong>Data Availability</strong>:</p>
<ul>
<li><strong>MD17 &amp; QH9</strong>: Publicly available</li>
<li><strong>PubChemQH</strong>: Publicly available on Hugging Face (<a href="https://huggingface.co/datasets/EperLuo/PubChemQH">EperLuo/PubChemQH</a>)</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<p><strong>Loss Function</strong>:</p>
<p>The model learns the <strong>residual</strong> $\Delta H$:</p>
<p>$$
\begin{aligned}
\Delta H &amp;= H_{\text{ref}} - H_{\text{init}} \\
\mathcal{L} &amp;= \text{MAE}(H_{\text{ref}}, H_{\text{pred}}) + \text{MSE}(H_{\text{ref}}, H_{\text{pred}})
\end{aligned}
$$</p>
<p>where $H_{\text{init}}$ is a computationally inexpensive initial guess computed via PySCF.</p>
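<p>A minimal sketch of this objective in plain Python, operating on flattened matrix entries. The helper name <code>mae_plus_mse</code> and the toy values are illustrative, and the sketch assumes the network outputs the residual so that $H_{\text{pred}} = H_{\text{init}} + \Delta H_{\text{pred}}$; the official implementation may organize this differently.</p>

```python
def mae_plus_mse(h_ref, h_pred):
    """Sum of mean absolute error and mean squared error over matrix entries."""
    diffs = [r - p for r, p in zip(h_ref, h_pred)]
    mae = sum(abs(d) for d in diffs) / len(diffs)
    mse = sum(d * d for d in diffs) / len(diffs)
    return mae + mse


# Toy flattened 2x2 Hamiltonian blocks (hypothetical values).
h_init = [1.0, 0.2, 0.2, 0.5]         # cheap initial guess
h_ref = [1.1, 0.1, 0.1, 0.7]          # DFT reference
delta_pred = [0.1, -0.1, -0.1, 0.1]   # assumed network output: predicted residual

# Assumption: the prediction is the initial guess plus the predicted residual.
h_pred = [i + d for i, d in zip(h_init, delta_pred)]
loss = mae_plus_mse(h_ref, h_pred)
```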
<p><strong>Hyperparameters</strong>:</p>
<table>
  <thead>
      <tr>
          <th>Parameter</th>
          <th>PubChemQH</th>
          <th>QH9</th>
          <th>MD17</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Batch Size</td>
          <td>8</td>
          <td>32</td>
          <td>10 (uracil: 5)</td>
      </tr>
      <tr>
          <td>Training Steps</td>
          <td>300k</td>
          <td>260k</td>
          <td>200k</td>
      </tr>
      <tr>
          <td>Warmup Steps</td>
          <td>1k</td>
          <td>1k</td>
          <td>1k</td>
      </tr>
      <tr>
          <td>Learning Rate</td>
          <td>1e-3</td>
          <td>1e-3</td>
          <td>5e-4</td>
      </tr>
      <tr>
          <td>Sparsity Rate</td>
          <td>0.7</td>
          <td>0.4</td>
          <td>0.1-0.3</td>
      </tr>
      <tr>
          <td>TSS Epoch $t$</td>
          <td>3</td>
          <td>3</td>
          <td>3</td>
      </tr>
  </tbody>
</table>
<p><strong>Sparse Pair Gate</strong>: Adapts the interaction graph. It concatenates zero-order features and inner products of atom pairs, then passes them through a linear layer $F_p$ with sigmoid activation to learn a weight $W_p^{ij}$ for every pair. Pairs are kept only if selected by the scheduler ($U_p^{TSS}$). The overhead comes primarily from the linear layer $F_p$.</p>
<p><strong>Sparse TP Gate</strong>: Filters triplets $(l_1, l_2, l_3)$ inside the TP operation. Higher-order combinations are more likely to be pruned. Complexity: $\mathcal{O}(L^3)$.</p>
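<p>To give a feel for how the number of candidate triplets grows with $L_{max}$, the sketch below enumerates the combinations allowed by the standard angular-momentum triangle rule for tensor products. The function name <code>valid_triplets</code> is hypothetical, and the exact set of paths that SPHNet considers before pruning may differ from this enumeration.</p>

```python
def valid_triplets(l_max):
    """Enumerate (l1, l2, l3), all orders at most l_max, where l3 lies
    between |l1 - l2| and l1 + l2 (the tensor-product triangle rule)."""
    out = []
    for l1 in range(l_max + 1):
        for l2 in range(l_max + 1):
            lowest = abs(l1 - l2)
            highest = min(l1 + l2, l_max)
            for l3 in range(lowest, highest + 1):
                out.append((l1, l2, l3))
    return out


n_svp = len(valid_triplets(4))   # L_max = 4 (def2-SVP in this note)
n_tzvp = len(valid_triplets(6))  # L_max = 6 (def2-TZVP in this note)
```

<p>The count grows quickly with $L_{max}$, which is consistent with the larger speedups reported for the def2-TZVP dataset.</p>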
<p><strong>Three-Phase Sparsity Scheduler</strong>: Training curriculum designed to optimize the sparse gates effectively:</p>
<ul>
<li><strong>Phase 1 (Random)</strong>: Each element is kept at random with probability $1-k$, so that all gate weights receive unbiased updates. Complexity: $\mathcal{O}(|U|)$.</li>
<li><strong>Phase 2 (Adaptive)</strong>: Keeps the top $(1-k)$ fraction of elements, ranked by learned weight magnitude. Complexity: $\mathcal{O}(|U|\log|U|)$.</li>
<li><strong>Phase 3 (Fixed)</strong>: Freezes the connectivity mask for maximum inference speed. No overhead.</li>
</ul>
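<p>The three phases can be sketched as a single mask-selection function. This is an illustration under assumptions, not the official code: the name <code>select_mask</code>, the string phase labels, and the rule for switching between phases are all hypothetical.</p>

```python
import random


def select_mask(weights, k, phase, rng, frozen=None):
    """Return a 0/1 keep-mask over candidate elements for sparsity rate k."""
    n = len(weights)
    if phase == "random":
        # Phase 1: keep each element independently with probability 1 - k.
        return [1 if rng.random() >= k else 0 for _ in range(n)]
    if phase == "adaptive":
        # Phase 2: keep the top (1 - k) fraction by learned weight (sort-based).
        n_keep = round((1 - k) * n)
        order = sorted(range(n), key=lambda i: weights[i], reverse=True)
        keep = set(order[:n_keep])
        return [1 if i in keep else 0 for i in range(n)]
    if phase == "fixed":
        # Phase 3: reuse a previously computed mask; no selection cost.
        return list(frozen)
    raise ValueError(phase)


rng = random.Random(0)
w = [0.9, 0.1, 0.5, 0.7, 0.3]  # toy learned gate weights
random_mask = select_mask(w, k=0.4, phase="random", rng=rng)
adaptive = select_mask(w, k=0.4, phase="adaptive", rng=rng)
fixed = select_mask(w, k=0.4, phase="fixed", rng=rng, frozen=adaptive)
```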
<p><strong>Weight Initialization</strong>: The learnable sparsity weights ($W$) are initialized as an all-ones vector.</p>
<h3 id="models">Models</h3>
<p>The model predicts the Hamiltonian matrix $H$ from atomic numbers $Z$ and coordinates $r$.</p>
<p><strong>Inputs</strong>: Atomic numbers ($Z$) and 3D coordinates.</p>
<p><strong>Backbone Structure</strong>:</p>
<ol>
<li><strong>Vectorial Node Interaction (x4)</strong>: Uses long-short range message passing. Extracts vectorial representations ($l=1$) without high-order TPs to save cost.</li>
<li><strong>Spherical Node Interaction (x2)</strong>: Projects features to high-order spherical harmonics (up to $L_{max}$). The first block increases the maximum order from 0 to $L_{max}$ without the Sparse Pair Gate; the second block applies the <strong>Sparse Pair Gate</strong> to filter node pairs.</li>
<li><strong>Pair Construction Block (x2)</strong>: Splits into <strong>Diagonal</strong> (self-interaction) and <strong>Non-Diagonal</strong> (cross-interaction) blocks. Both use the <strong>Sparse TP Gate</strong> to prune cross-order combinations $(l_1, l_2, l_3)$. The Non-Diagonal blocks also use the <strong>Sparse Pair Gate</strong> to filter atom pairs. The two Pair Construction blocks receive representations from the two Spherical Node Interaction blocks respectively, and their outputs are summed.</li>
<li><strong>Expansion Block</strong>: Reconstructs the full Hamiltonian matrix from the sparse irreducible representations, exploiting symmetry ($H_{ji} = H_{ij}^T$) to halve computations.</li>
</ol>
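<p>The symmetry used by the Expansion Block can be illustrated with a small pure-Python sketch. It assumes blocks are stored only for the upper triangle of atom pairs (diagonal included) and fills the rest by transposition; the names <code>expand_symmetric</code> and <code>transpose</code> and the toy block shapes are hypothetical.</p>

```python
def transpose(block):
    """Transpose a matrix stored as a list of rows."""
    return [list(row) for row in zip(*block)]


def expand_symmetric(upper_blocks, n_atoms):
    """Fill every (i, j) atom-pair block from upper-triangle blocks only."""
    full = {}
    for (i, j), block in upper_blocks.items():
        full[(i, j)] = block
        if i != j:
            full[(j, i)] = transpose(block)  # H_ji = H_ij^T
    assert len(full) == n_atoms * n_atoms
    return full


# Toy system: atom 0 has 1 orbital, atom 1 has 2 orbitals.
upper = {
    (0, 0): [[1.0]],
    (0, 1): [[0.1, 0.2]],
    (1, 1): [[2.0, 0.3], [0.3, 3.0]],
}
full = expand_symmetric(upper, n_atoms=2)
```

<p>Only three of the four blocks are computed here; for $N$ atoms the saving approaches half of the off-diagonal work.</p>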
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Training</strong>: 4x NVIDIA A100 (80GB)</li>
<li><strong>Benchmarking</strong>: Single NVIDIA RTX A6000 (46GB)</li>
</ul>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Luo, E., Wei, X., Huang, L., Li, Y., Yang, H., Xia, Z., Wang, Z., Liu, C., Shao, B., &amp; Zhang, J. (2025). Efficient and Scalable Density Functional Theory Hamiltonian Prediction through Adaptive Sparsity. <em>Proceedings of the 42nd International Conference on Machine Learning</em>, PMLR 267:41368&ndash;41390.</p>
<p><strong>Publication</strong>: ICML 2025</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{luo2025efficient,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Efficient and Scalable Density Functional Theory Hamiltonian Prediction through Adaptive Sparsity}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Luo, Erpai and Wei, Xinran and Huang, Lin and Li, Yunyang and Yang, Han and Xia, Zaishuo and Wang, Zun and Liu, Chang and Shao, Bin and Zhang, Jia}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Proceedings of the 42nd International Conference on Machine Learning}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{41368--41390}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2025}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{267}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">series</span>=<span style="color:#e6db74">{Proceedings of Machine Learning Research}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{PMLR}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://icml.cc/virtual/2025/poster/45656">ICML 2025 poster page</a></li>
<li><a href="https://openreview.net/forum?id=K3lykWhXON">OpenReview forum</a></li>
<li><a href="https://openreview.net/pdf?id=K3lykWhXON">PDF on OpenReview</a></li>
<li><a href="https://github.com/microsoft/SPHNet">GitHub Repository</a> <em>(Note: The official repository was archived by Microsoft in December 2025. It is available for reference but no longer actively maintained.)</em></li>
</ul>
]]></content:encoded></item><item><title>Contrastive Learning for Variational Autoencoder Priors</title><link>https://hunterheidenreich.com/notes/machine-learning/generative-models/contrastive-learning-for-vae-priors/</link><pubDate>Sun, 17 Aug 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/generative-models/contrastive-learning-for-vae-priors/</guid><description>Aneja et al.'s NeurIPS 2021 paper introducing Noise Contrastive Priors (NCPs) to address VAE's 'prior hole' problem with energy-based priors.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is a <strong>method paper</strong> that introduces a training approach for Variational Autoencoders (VAEs) to address fundamental limitations in their generative quality through improved prior learning.</p>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>The work is motivated by a critical limitation in Variational Autoencoders known as the <strong>&ldquo;prior hole&rdquo; problem</strong>, where the prior distribution p(z) fails to match the aggregate approximate posterior q(z). This mismatch leads to areas in the latent space with high density under the prior that don&rsquo;t map to realistic data samples, resulting in poor generative quality compared to GANs and other generative models.</p>
<figure class="post-figure center ">
    <img src="/img/notes/vae-prior-hole-problem-illustrated.webp"
         alt="Visualization of the VAE prior hole problem showing a ring-shaped aggregate posterior q(z) with an empty center, while the standard Gaussian prior p(z) has highest density at the center where no data exists"
         title="Visualization of the VAE prior hole problem showing a ring-shaped aggregate posterior q(z) with an empty center, while the standard Gaussian prior p(z) has highest density at the center where no data exists"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The &lsquo;prior hole&rsquo; problem: the standard Gaussian prior (red dashed contours) assigns highest probability to the center, but the aggregate posterior (blue dots) forms a ring with no data in that region.</figcaption>
    
</figure>

<p>The figure above illustrates this mismatch. The blue dots represent where a trained encoder actually places data in the latent space (the aggregate posterior $q(z)$), which often forms complex, non-Gaussian shapes. The red dashed contours show the standard Gaussian prior $p(z) = \mathcal{N}(0, I)$, which assumes data is centered at the origin. When generating new samples, we draw from this prior, making it likely to sample from the empty &ldquo;hole&rdquo; where the decoder has never seen training data, producing unrealistic outputs.</p>
<p>A natural question arises: the prior $p(z)$ is used for <em>sampling</em> at inference time, so why does learning a better prior also improve <em>likelihood</em> (NLL)? The answer lies in the VAE objective. VAEs maximize the Evidence Lower Bound (ELBO):</p>
<p>$$ \log p(x) \geq \mathcal{L}_{\text{ELBO}}(x) = \underbrace{\mathbb{E}_{q(z|x)}[\log p(x|z)]}_{\text{Reconstruction}} - \underbrace{\text{KL}(q(z|x) \parallel p(z))}_{\text{Regularization}} $$</p>
<p>The KL divergence term penalizes the mismatch between each data point&rsquo;s approximate posterior $q(z|x)$ and the prior $p(z)$. When the prior is a simple Gaussian but the aggregate posterior forms a complex shape (as in the figure above), this KL term remains unnecessarily high for every data point.</p>
<p>By replacing the simple prior with a learned $p_{\text{NCP}}(z)$ that matches the aggregate posterior, the KL penalty decreases, tightening the ELBO and improving NLL. The learned prior thus provides a <strong>unified solution</strong>: better likelihood during training (tighter bound) and better sampling at inference (no &ldquo;holes&rdquo;).</p>
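<p>One way to make this step explicit is a standard decomposition of the data-averaged KL term (a known identity, not specific to this paper). Here $q(z) = \mathbb{E}_{p_{\text{data}}(x)}[q(z|x)]$ is the aggregate posterior and $I_q(x; z)$ is the mutual information between $x$ and $z$ under the joint $q(z|x)\,p_{\text{data}}(x)$:</p>
<p>$$ \mathbb{E}_{p_{\text{data}}(x)}\left[\text{KL}(q(z|x) \parallel p(z))\right] = I_q(x; z) + \text{KL}(q(z) \parallel p(z)) $$</p>
<p>Only the second term depends on the prior, and it vanishes when $p(z) = q(z)$. With the encoder frozen, moving the prior toward the aggregate posterior therefore lowers the average KL penalty without touching the reconstruction term.</p>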
<p>The OpenReview discussion contains a significant theoretical debate regarding the paper&rsquo;s core premise. Reviewers argued that the &ldquo;prior hole&rdquo; problem is actually a failure of the posterior to match the prior, or a failure of the encoder. The authors defended their approach by noting that even with a perfect posterior, a simple Normal prior might fail because the decoder lacks capacity to map a simple distribution to complex data without dropping modes. This justifies fixing the prior by making it learned and complex.</p>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The authors propose an <strong>energy-based model (EBM) prior</strong> that is trained using <strong>Noise Contrastive Estimation (NCE)</strong>, which they term a <strong>Noise Contrastive Prior (NCP)</strong>. The key innovations are:</p>
<ul>
<li><strong>Two-Stage Training Process</strong>: First, a standard VAE is trained with a simple base prior. Then, the VAE weights are frozen and a binary classifier learns to distinguish between samples from the aggregate posterior q(z) and the base prior p(z).</li>
<li><strong>Reweighting Strategy</strong>: The core idea is to reweight a base prior distribution p(z) with a learned reweighting factor r(z) to make the resulting prior $p_{\text{NCP}}(z)$ better match the aggregate posterior q(z).</li>
<li><strong>NCE for EBM Training</strong>: The method frames EBM training as a binary classification task to avoid computationally expensive MCMC sampling.</li>
<li><strong>Scalability to Hierarchical Models</strong>: For hierarchical VAEs with multiple latent groups, the NCP approach can be applied independently and in parallel to each group&rsquo;s conditional prior.</li>
</ul>
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The method was evaluated on several standard image generation benchmarks:</p>
<ul>
<li><strong>MNIST</strong> (dynamically binarized): Likelihood evaluation on a controlled, small-latent-space task</li>
<li><strong>CIFAR-10</strong>: Standard computer vision benchmark for generative modeling</li>
<li><strong>CelebA 64x64</strong>: Applied to both standard VAE architectures and more advanced VAEs with GMM priors (RAE model)</li>
<li><strong>CelebA HQ 256x256</strong>: High-resolution face generation task</li>
</ul>
<p>The hierarchical NVAE models used 30 latent groups for CIFAR-10 and CelebA-64, 20 groups for CelebA-HQ-256, and 10 groups of $4 \times 4$ latent variables for MNIST (deliberately small to enable reliable partition function estimation). The experiments compared FID scores, likelihood metrics, and qualitative sample quality between baseline VAEs and NCP-enhanced versions, with particular focus on hierarchical VAEs (NVAE).</p>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>
<p>The proposed NCP method demonstrated improvements in generative quality across evaluated datasets, with modest gains on standard VAEs and particularly large gains on hierarchical models like NVAE:</p>
<ul>
<li><strong>CelebA-64</strong>: NCP improved FID scores from 48.12 to 41.28 for standard VAEs, and from 40.95 to 39.00 for RAE models with GMM priors.</li>
<li><strong>Hierarchical Models (NVAE)</strong>: The impact was particularly pronounced on hierarchical VAEs:
<ul>
<li><strong>CIFAR-10</strong>: FID improved from 51.71 to 24.08</li>
<li><strong>CelebA-64</strong>: FID improved from 13.48 to 5.25, making it competitive with GANs</li>
<li><strong>CelebA HQ 256x256</strong>: FID reduced from 40.26 to 24.79</li>
</ul>
</li>
<li><strong>Likelihood Performance</strong>: On MNIST, NCP-VAE achieved 78.10 nats NLL vs. baseline NVAE&rsquo;s 78.67 nats</li>
</ul>
<p>On CIFAR-10 and CelebA-HQ-256, the concurrent VAEBM method (which forms an EBM on the data space rather than the latent space) outperforms NCP-VAE. However, the authors argue the two approaches are complementary: NCP-VAE targets the latent space while VAEBM operates in data space, and combining them could yield further gains. NCP-VAE also has the advantage of applicability to discrete data (e.g., binarized MNIST) and simpler setup since it only requires training binary classifiers rather than MCMC-based training and sampling.</p>
<p>The key conclusions are that <strong>two-stage training with noise contrastive estimation</strong> provides an effective framework for learning expressive energy-based priors that addresses the prior hole problem while scaling efficiently to hierarchical models.</p>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Artifact</th>
          <th style="text-align: left">Type</th>
          <th style="text-align: left">License</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><a href="https://drive.google.com/drive/folders/15tCGruQcSdm2G4yLkUpKvGASluSZPIBD">Code (Google Drive)</a></td>
          <td style="text-align: left">Code</td>
          <td style="text-align: left">Unknown</td>
          <td style="text-align: left">Official implementation; hosted on Google Drive (may become inaccessible)</td>
      </tr>
      <tr>
          <td style="text-align: left"><a href="https://openreview.net/forum?id=LcSfRundgwI">OpenReview</a></td>
          <td style="text-align: left">Other</td>
          <td style="text-align: left">N/A</td>
          <td style="text-align: left">Reviews, author responses, and supplementary material</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<h4 id="the-reweighting-mechanism">The Reweighting Mechanism</h4>
<p>The core innovation is defining the NCP prior as $p_{\text{NCP}}(z) \propto p(z)r(z)$. The reweighting factor $r(z)$ is derived from the binary classifier $D(z)$ using the <strong>likelihood ratio trick</strong>:</p>
<p>$$ r(z) \approx \frac{D(z)}{1 - D(z)} $$</p>
<p>Here, $D(z)$ is the sigmoid output of the trained discriminator, representing the probability that sample $z$ came from the aggregate posterior $q(z)$ (&ldquo;real&rdquo;). For an optimal discriminator $D^*(z)$, this ratio exactly equals $\frac{q(z)}{p(z)}$, allowing the model to approximate the density ratio without explicit density estimation.</p>
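<p>The identity can be checked numerically on a 1D toy problem. In the sketch below the bimodal $q(z)$ is a hypothetical stand-in for an aggregate posterior, and the discriminator is the closed-form Bayes-optimal classifier for equal numbers of samples from $q$ and $p$, not a trained network.</p>

```python
import math


def normal_pdf(z, mean, std):
    return math.exp(-0.5 * ((z - mean) / std) ** 2) / (std * math.sqrt(2 * math.pi))


def q(z):
    # Hypothetical bimodal "aggregate posterior" for illustration.
    return 0.5 * normal_pdf(z, -1.5, 0.5) + 0.5 * normal_pdf(z, 1.5, 0.5)


def p(z):
    # Base prior N(0, 1).
    return normal_pdf(z, 0.0, 1.0)


def optimal_discriminator(z):
    # Bayes-optimal probability that z came from q rather than p.
    return q(z) / (q(z) + p(z))


def reweighting(z):
    d = optimal_discriminator(z)
    return d / (1.0 - d)


# r(z) recovers q(z) / p(z): large near the modes of q, small in the "hole".
ratios = {z: (reweighting(z), q(z) / p(z)) for z in (-2.0, 0.0, 1.5)}
```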
<figure class="post-figure center ">
    <img src="/img/notes/ncp-vae-reweighting-the-prior-posterior.webp"
         alt="Visualization of the NCP reweighting mechanism showing three 1D distributions: q(z) the complex bimodal aggregate posterior, p(z) the simple Gaussian prior, and r(z) the learned reweighting factor that transforms p(z) to match q(z)"
         title="Visualization of the NCP reweighting mechanism showing three 1D distributions: q(z) the complex bimodal aggregate posterior, p(z) the simple Gaussian prior, and r(z) the learned reweighting factor that transforms p(z) to match q(z)"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The reweighting mechanism: the learned factor $r(z)$ (bottom) reweights the simple Gaussian prior $p(z)$ (middle) to approximate the complex aggregate posterior $q(z)$ (top). Where $q(z)$ has high density but $p(z)$ is low, $r(z)$ compensates with high values.</figcaption>
    
</figure>

<h4 id="hierarchical-architecture-strategy">Hierarchical Architecture Strategy</h4>
<p>For hierarchical models (like NVAE), the method trains $K$ binary classifiers in parallel (one for each latent group). Crucially, to ensure efficiency, the classifiers reuse the <strong>context feature</strong> $c(z_{&lt;k})$ extracted by the frozen VAE&rsquo;s prior network. This architectural choice provides significant computational savings.</p>
<h4 id="test-time-sampling-inference">Test-Time Sampling (Inference)</h4>
<p>Since $p_{\text{NCP}}(z)$ is an energy-based model defined only up to a normalizing constant, it cannot be sampled directly. The paper employs two methods to generate approximate samples:</p>
<ul>
<li><strong>Sampling-Importance-Resampling (SIR):</strong> Used for most results. It draws $M$ samples (e.g., $M=5000$) from the base prior $p(z)$ and resamples them based on weights $w^{(m)} = r(z^{(m)})$.</li>
<li><strong>Langevin Dynamics (LD):</strong> An iterative refinement method using the gradient of the energy function $E(z) = -\log r(z) - \log p(z)$.</li>
</ul>
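<p>A minimal sketch of the SIR step for a 1D latent, using only the Python standard library. The function name <code>sir_sample</code> and the hard-threshold <code>toy_r</code> are hypothetical; in the paper the weights come from the trained classifier and the default is $M=5000$ proposals.</p>

```python
import random


def sir_sample(r, sample_base, m, rng):
    """Draw one approximate sample from p(z) * r(z) via SIR."""
    proposals = [sample_base(rng) for _ in range(m)]
    weights = [r(z) for z in proposals]
    # Resample one proposal with probability proportional to its weight.
    return rng.choices(proposals, weights=weights, k=1)[0]


def toy_r(z):
    # Toy reweighting factor that carves a "hole" around the origin.
    return 1.0 if abs(z) > 1.0 else 0.0


def sample_base(rng):
    # Base prior N(0, 1).
    return rng.gauss(0.0, 1.0)


rng = random.Random(0)
samples = [sir_sample(toy_r, sample_base, 500, rng) for _ in range(20)]
```

<p>Every returned sample lies outside the hole, because proposals with zero weight are never selected.</p>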
<h3 id="models">Models</h3>
<h4 id="decoder-architecture">Decoder Architecture</h4>
<p>For RGB datasets (CIFAR-10, CelebA), the output likelihood must be changed from <strong>Discretized Logistic</strong> (standard NVAE) to a <strong>Normal distribution</strong>. The authors note this change alone led to &ldquo;significant improvements in the base model performance.&rdquo; Using the standard NVAE decoder will result in a weaker baseline than reported.</p>
<h4 id="discriminator-architecture">Discriminator Architecture</h4>
<p>The binary classifier uses a ResNet-style architecture with <strong>Squeeze-and-Excitation (SE)</strong> blocks:</p>
<ul>
<li><strong>Activation:</strong> Swish</li>
<li><strong>Normalization:</strong> Batch Normalization</li>
<li><strong>Optimization:</strong> Adam with Cosine Annealing (learning rate: $10^{-3} \to 10^{-7}$)</li>
</ul>
<p>The SE blocks help the model focus on channel-wise feature recalibration, which is important for distinguishing subtle differences between prior and aggregate posterior in high-dimensional latent spaces.</p>
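<p>For reference, the standard cosine annealing formula with the endpoints listed above can be written as follows. This is the textbook schedule without restarts; whether the paper anneals per step or per epoch is not stated here, so <code>total_steps</code> is an assumption.</p>

```python
import math


def cosine_lr(step, total_steps, lr_max=1e-3, lr_min=1e-7):
    """Cosine annealing from lr_max at step 0 to lr_min at total_steps."""
    progress = step / total_steps
    return lr_min + 0.5 * (lr_max - lr_min) * (1.0 + math.cos(math.pi * progress))


schedule = [cosine_lr(s, 100) for s in (0, 50, 100)]
```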
<h3 id="hardware">Hardware</h3>
<p>The main paper is vague on training time, but the OpenReview rebuttal explicitly lists hardware costs:</p>
<ul>
<li><strong>Hardware:</strong> NVIDIA Tesla V100 (32GB) GPUs</li>
<li><strong>Per-Discriminator Training:</strong> ~13 hours for 100 epochs</li>
<li><strong>Parallelization:</strong> Because latent groups are independent, all discriminators can train in parallel</li>
<li><strong>Total Cost (CelebA-64):</strong> ~8.1 GPU-days</li>
<li><strong>Comparison:</strong> The authors argue this is efficient compared to VDVAE, which requires ~560 GPU-days</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<h4 id="inference-speed-vs-quality-trade-off">Inference Speed vs. Quality Trade-off</h4>
<p>Reviewers flagged that SIR sampling can be prohibitively slow. The authors clarified the exact trade-off:</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Proposal Samples ($M$)</th>
          <th style="text-align: left">Time per Image</th>
          <th style="text-align: left">FID (CelebA-64)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left">5,000 (paper default)</td>
          <td style="text-align: left">~10.11 seconds</td>
          <td style="text-align: left">5.25</td>
      </tr>
      <tr>
          <td style="text-align: left">500 (practical)</td>
          <td style="text-align: left">~1.25 seconds</td>
          <td style="text-align: left">6.76</td>
      </tr>
  </tbody>
</table>
<p>The quality gain from 500 to 5,000 samples is modest (FID difference of 1.51) while inference time increases roughly 8x, suggesting $M=500$ may be a practical default.</p>
<h4 id="hyperparameters">Hyperparameters</h4>
<ul>
<li><strong>FID Calculation:</strong> 50,000 samples</li>
<li><strong>SIR Proposals:</strong> 5,000 samples (paper default) or 500 (practical)</li>
<li><strong>MNIST:</strong> Dynamically binarized version used for likelihood evaluation</li>
<li><strong>Optimizers:</strong> The study largely adopts hyperparameters from baseline papers (e.g., Lawson et al. for MNIST, Ghosh et al. for RAE)</li>
</ul>
<h4 id="debugging-benchmark-25-gaussians">Debugging Benchmark: 25-Gaussians</h4>
<p>The supplement provides a toy experiment ideal for verifying a new implementation before running on expensive image datasets:</p>
<ul>
<li><strong>Setup:</strong> Synthetic dataset of 25 2D-Gaussians arranged on a grid</li>
<li><strong>Target log-likelihood (higher is better):</strong> ~-0.954 nats (NCP) vs. ~-2.753 nats (Vanilla VAE)</li>
<li><strong>Success Criterion:</strong> Samples should avoid low-density regions between grid points. A standard VAE will generate samples in these &ldquo;prior holes,&rdquo; while a working NCP implementation should cleanly remove these artifacts.</li>
</ul>
<h4 id="implementation-warnings">Implementation Warnings</h4>
<ul>
<li><strong>SIR Failure Mode:</strong> If the learned prior $p_{\text{NCP}}$ deviates too far from the base prior, SIR sampling collapses (low effective sample size). The paper shows a strong correlation between the NCE classification loss and the effective sample size (Fig. 5(b)), indicating that SIR reliability depends on how well the base prior matches the aggregate posterior.</li>
<li><strong>Temperature Scaling:</strong> The qualitative images in the paper use reduced temperature for improved visual sharpness (Section 5.3). The FID tables do not specify a temperature, so results may or may not use $T=1.0$.</li>
</ul>
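<p>A quick way to monitor the SIR failure mode is to compute the effective sample size of the importance weights. The sketch below uses the common Kish estimator, which is an assumption here; the paper&rsquo;s exact ESS definition may differ.</p>

```python
def effective_sample_size(weights):
    """Kish effective sample size: (sum w)^2 / sum(w^2)."""
    total = sum(weights)
    return total * total / sum(w * w for w in weights)


ess_uniform = effective_sample_size([1.0, 1.0, 1.0, 1.0])    # all proposals useful
ess_collapsed = effective_sample_size([1.0, 0.0, 0.0, 0.0])  # one proposal dominates
```

<p>An ESS close to 1 with thousands of proposals signals that resampling keeps returning the same few points.</p>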
<h3 id="data">Data</h3>
<p>The method was evaluated on several standard image generation benchmarks:</p>
<ul>
<li><strong>MNIST</strong> (dynamically binarized): Likelihood evaluation on a controlled, small-latent-space task</li>
<li><strong>CIFAR-10</strong>: Standard computer vision benchmark for generative modeling (32x32 RGB images)</li>
<li><strong>CelebA 64x64</strong>: Face generation task with moderate resolution</li>
<li><strong>CelebA HQ 256x256</strong>: High-resolution face generation task</li>
</ul>
<p>All datasets use standard train/test splits from the computer vision literature.</p>
<h4 id="additional-metrics">Additional Metrics</h4>
<p>Beyond FID and NLL, the paper uses:</p>
<ul>
<li><strong>Effective Sample Size (ESS):</strong> Validates reliability of the SIR sampling procedure</li>
<li><strong>Maximum Mean Discrepancy (MMD):</strong> Measures distance between aggregate posterior and NCP prior distributions</li>
</ul>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Aneja, J., Schwing, A. G., Kautz, J., &amp; Vahdat, A. (2021). A contrastive learning approach for training variational autoencoder priors. <em>Advances in Neural Information Processing Systems</em>, 34, 29604-29616. <a href="https://proceedings.neurips.cc/paper/2021/hash/0496604c1d80f66fbeb963c12e570a26-Abstract.html">https://proceedings.neurips.cc/paper/2021/hash/0496604c1d80f66fbeb963c12e570a26-Abstract.html</a></p>
<p><strong>Publication</strong>: NeurIPS 2021</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{aneja2021contrastive,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{A Contrastive Learning Approach for Training Variational Autoencoder Priors}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Aneja, Jyoti and Schwing, Alexander G and Kautz, Jan and Vahdat, Arash}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Advances in Neural Information Processing Systems}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{34}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{29604--29616}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2021}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://openreview.net/forum?id=LcSfRundgwI">OpenReview Discussion</a></li>
<li><a href="https://drive.google.com/drive/folders/15tCGruQcSdm2G4yLkUpKvGASluSZPIBD">Code Repository</a> (Google Drive; link may become inaccessible over time)</li>
</ul>
]]></content:encoded></item><item><title>Vectorized Word2Vec in Pure PyTorch</title><link>https://hunterheidenreich.com/projects/modern-word2vec/</link><pubDate>Sat, 16 Aug 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/projects/modern-word2vec/</guid><description>A from-scratch PyTorch Word2Vec implementation with vectorized Hierarchical Softmax, Negative Sampling, and torch.compile support.</description><content:encoded><![CDATA[<h2 id="overview">Overview</h2>
<p>Word2Vec is often treated as a &ldquo;solved problem&rdquo; or a black box inside libraries like Gensim. This project deconstructs the algorithm to treat it as a <strong>systems engineering challenge</strong>.</p>
<p>I built a ground-up, typed, and compiled PyTorch implementation that bridges the gap between the original C code&rsquo;s efficiency and modern GPU acceleration. The core innovation lies in <strong>&ldquo;tensorizing the tree&rdquo;</strong>, converting the pointer-chasing logic of Hierarchical Softmax into dense, vectorized operations compatible with <code>torch.compile</code>.</p>
<h2 id="features">Features</h2>
<h3 id="1-vectorized-hierarchical-softmax">1. Vectorized Hierarchical Softmax</h3>
<p>Classically, Hierarchical Softmax involves traversing a binary Huffman tree. While efficient on a CPU, this approach creates divergent execution paths on GPUs.</p>
<ul>
<li><strong>The Solution:</strong> I implemented a &ldquo;pre-computed path&rdquo; strategy. The tree traversal for every vocabulary word is flattened into fixed-size tensors (<code>word_path_indices</code>, <code>word_codes_tensor</code>) padded to the maximum depth.</li>
<li><strong>The Result:</strong> The forward pass becomes a massive, masked batch dot-product against internal node embeddings, allowing the GPU to crunch the probability tree without branching logic.</li>
</ul>
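<p>The pre-computed path idea can be sketched roughly as follows. This is a minimal illustration, not the project&rsquo;s actual code: the tensor names <code>word_path_indices</code> and <code>word_codes_tensor</code> come from the description above, while <code>hs_nll</code>, <code>path_mask</code>, <code>node_emb</code>, the toy three-word tree, and the convention that <code>sigmoid(logit)</code> is the probability of taking branch 1 are all assumptions made for the example.</p>

```python
import torch
import torch.nn.functional as F


def hs_nll(hidden, targets, node_emb, word_path_indices, word_codes_tensor, path_mask):
    """Per-example negative log-likelihood under hierarchical softmax.

    hidden: (B, D) input-side vectors; targets: (B,) target word ids.
    node_emb: (num_internal_nodes, D) internal-node embeddings.
    word_path_indices / word_codes_tensor / path_mask: (V, max_depth).
    """
    paths = word_path_indices[targets]   # (B, L) internal-node ids on each path
    codes = word_codes_tensor[targets]   # (B, L) branch taken at each node (0/1)
    mask = path_mask[targets]            # (B, L) 1 for real steps, 0 for padding
    nodes = node_emb[paths]              # (B, L, D) gathered node embeddings
    logits = torch.einsum("bld,bd->bl", nodes, hidden)
    # Assumed convention: sigmoid(logit) is the probability of taking branch 1.
    step_nll = F.binary_cross_entropy_with_logits(logits, codes, reduction="none")
    # Padded steps are zeroed out instead of being skipped with a branch.
    return (step_nll * mask).sum(dim=1)


# Hypothetical 3-word tree: word 0 has code [0]; word 1 has [1, 0]; word 2 has [1, 1].
word_path_indices = torch.tensor([[0, 0], [0, 1], [0, 1]])
word_codes_tensor = torch.tensor([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0]])
path_mask = torch.tensor([[1.0, 0.0], [1.0, 1.0], [1.0, 1.0]])

torch.manual_seed(0)
node_emb = torch.randn(2, 8)                 # two internal nodes
hidden = torch.randn(1, 8).expand(3, 8)      # same input vector for every target
nll = hs_nll(hidden, torch.arange(3), node_emb,
             word_path_indices, word_codes_tensor, path_mask)
print(torch.exp(-nll).sum())  # word probabilities should sum to roughly 1
```

<p>Summing <code>exp(-nll)</code> over every word for a fixed input is a cheap sanity check: if the paths, codes, and mask are consistent with a valid binary tree, the probabilities should add up to one.</p>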
<h3 id="2-infinite-streaming--sliding-windows">2. Infinite Streaming &amp; Sliding Windows</h3>
<p>To handle datasets larger than RAM (e.g., Wikipedia/CommonCrawl), I built a custom <code>IterableDataset</code> that performs a true single-pass read.</p>
<ul>
<li><strong>Efficient Windowing:</strong> It uses a <code>collections.deque</code> buffer to slide over the token stream, generating training pairs each time a new token reaches the center of the window.</li>
<li><strong>Zipfian Subsampling:</strong> A probabilistic rejection-sampling layer downsamples frequent words (like &ldquo;the&rdquo; or &ldquo;of&rdquo;) on the fly, following the subsampling distribution from the original Mikolov et al. paper.</li>
</ul>
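<p>A minimal sketch of the two ideas above, using only the standard library. The function names (<code>keep_probability</code>, <code>stream_pairs</code>), the choice to subsample before windowing, and the edge handling at the start and end of the stream are assumptions for illustration; the keep probability uses the discard rule $1 - \sqrt{t / f(w)}$ from the Mikolov et al. paper, with the threshold <code>t</code> left as a parameter.</p>

```python
import math
import random
from collections import deque


def keep_probability(freq, t=1e-5):
    # Paper's rule: discard with probability 1 - sqrt(t / f), so keep with sqrt(t / f).
    return min(1.0, math.sqrt(t / freq))


def stream_pairs(tokens, window, freqs=None, t=1e-5, rng=random):
    """Yield (center, context) pairs in a single pass over an iterable of tokens."""
    buf = deque(maxlen=2 * window + 1)

    def pairs_for(pos):
        center = buf[pos]
        for j, ctx in enumerate(buf):
            if j != pos and abs(j - pos) <= window:
                yield center, ctx

    seen = 0
    for tok in tokens:
        # Optional subsampling: frequent tokens are dropped before windowing.
        if freqs is not None and rng.random() >= keep_probability(freqs[tok], t):
            continue
        buf.append(tok)
        seen += 1
        if seen > window:
            # The token `window` steps behind the newest one now has its full right context.
            yield from pairs_for(len(buf) - 1 - window)

    # Flush: the last `window` tokens never received a full right context.
    start = seen - len(buf)
    for c in range(max(0, seen - window), seen):
        yield from pairs_for(c - start)


pairs = list(stream_pairs("a b c d e".split(), window=1))
print(pairs)
```

<p>Because the buffer never holds more than <code>2 * window + 1</code> tokens, memory use stays constant regardless of corpus size.</p>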
<h3 id="3-modern-tooling">3. Modern Tooling</h3>
<p>This project uses a strict &ldquo;software 2.0&rdquo; stack:</p>
<ul>
<li><strong>Dependency Management</strong>: Built with <code>uv</code> for deterministic, fast environment resolution.</li>
<li><strong>Compilation</strong>: Fully compatible with <code>torch.compile</code> (PyTorch 2.0+), allowing for graph fusion of the custom loss functions.</li>
</ul>
<h2 id="usage">Usage</h2>
<p>The library installs from source (clone the repo, then <code>pip install -e .</code>) and exposes a typed Python API (<code>SkipGramModel</code>, <code>CBOWModel</code>, <code>Trainer</code>, <code>Word2VecDataset</code>) alongside <code>word2vec-train</code> and <code>word2vec-query</code> CLIs, with GPU acceleration. Trained embeddings export to <code>.npy</code> for use with Gensim or other tooling.</p>
<h2 id="results">Results</h2>
<ul>
<li><strong>Correct embeddings</strong>: the produced vectors pass qualitative semantic-similarity checks (e.g., analogical reasoning), which is consistent with the tensorized tree producing the same geometry as sequential traversal.</li>
<li><strong>Branch-free GPU execution</strong>: the batched Huffman-tree path turns hierarchical-softmax tree traversal into dense, masked tensor operations, removing the divergent branching that slows naive implementations on GPUs.</li>
<li><strong>Runs on larger-than-RAM corpora</strong>: the streaming <code>IterableDataset</code> with Zipfian subsampling processes Wikipedia/CommonCrawl-scale text in a single pass without loading the corpus into memory.</li>
<li><strong><code>torch.compile</code>-compatible</strong>: the custom loss functions are written to fuse under <code>torch.compile</code> (PyTorch 2.0+).</li>
</ul>
<h2 id="related-work">Related Work</h2>
<p>This project connects to related NLP work on this site:</p>
<ul>
<li><a href="/posts/intro-to-word-embeddings/">An Introduction to Word Embeddings</a>: conceptual background on the representations this library produces</li>
<li><a href="/research/word-company-vicinity/">Word Company Vicinity</a>: research applying word vector semantics to company names</li>
<li><a href="/research/semantic-network-induction/">Semantic Network Induction</a>: research on inducing semantic graphs from embedding spaces</li>
</ul>
]]></content:encoded></item><item><title>Modern PyTorch VAEs: A Detailed Implementation Guide</title><link>https://hunterheidenreich.com/posts/modern-variational-autoencoder-in-pytorch/</link><pubDate>Sun, 03 Mar 2024 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/modern-variational-autoencoder-in-pytorch/</guid><description>Complete PyTorch VAE tutorial: Copy-paste code, ELBO derivation, KL annealing, and stable softplus parameterization.</description><content:encoded><![CDATA[<h2 id="what-is-a-variational-autoencoder">What is a Variational Autoencoder?</h2>
<p>A Variational Autoencoder (VAE) is a type of <strong>generative model</strong>, meaning its primary purpose is to learn the underlying structure of a dataset so it can generate new, similar data.</p>
<p>Whether the data is images, raw audio clips, or 2D graphs of drug-like molecules, a VAE aims to capture the essential features that define the data distribution. Once trained, it should be able to create entirely new samples that resemble the training data without simply copying specific examples.</p>
<p>Introduced by Kingma and Welling in 2013 (<a href="/notes/machine-learning/generative-models/autoencoding-variational-bayes/">Auto-Encoding Variational Bayes</a>, <a href="https://arxiv.org/abs/1312.6114">Paper</a>), VAEs are used for:</p>
<ul>
<li><strong>Generation</strong>: Creating new data (images, music, text).</li>
<li><strong>Dimensionality Reduction</strong>: Compressing data into a much smaller, meaningful representation (a &ldquo;latent space&rdquo;).</li>
<li><strong>Imputation</strong>: Filling in missing or corrupted data (e.g., denoising images).</li>
</ul>
<p>Importantly, they aim to provide a structured and continuous latent space, which allows for smooth interpolation between data points and meaningful manipulations of generated samples (think: optimization).</p>
<h2 id="tldr-the-complete-pytorch-implementation">TL;DR: The Complete PyTorch Implementation</h2>
<p>For those who just want the code, here is a complete, modern VAE implementation in PyTorch. It features <strong>softplus standard deviation parameterization</strong> for numerical stability and a <strong>custom training step</strong> that handles the ELBO loss correctly.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> torch
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> torch.nn <span style="color:#66d9ef">as</span> nn
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> torch.nn.functional <span style="color:#66d9ef">as</span> F
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> dataclasses <span style="color:#f92672">import</span> dataclass
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#a6e22e">@dataclass</span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">class</span> <span style="color:#a6e22e">VAEOutput</span>:
</span></span><span style="display:flex;"><span>    z: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    mu: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    std: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    x_recon: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    loss: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    loss_recon: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    loss_kl: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">class</span> <span style="color:#a6e22e">VAE</span>(nn<span style="color:#f92672">.</span>Module):
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">__init__</span>(self, input_dim<span style="color:#f92672">=</span><span style="color:#ae81ff">784</span>, hidden_dim<span style="color:#f92672">=</span><span style="color:#ae81ff">512</span>, latent_dim<span style="color:#f92672">=</span><span style="color:#ae81ff">16</span>):
</span></span><span style="display:flex;"><span>        super()<span style="color:#f92672">.</span><span style="color:#a6e22e">__init__</span>()
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>encoder <span style="color:#f92672">=</span> nn<span style="color:#f92672">.</span>Sequential(
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(input_dim, hidden_dim),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Tanh(),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(hidden_dim, hidden_dim),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Tanh()
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>fc_mu <span style="color:#f92672">=</span> nn<span style="color:#f92672">.</span>Linear(hidden_dim, latent_dim)
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>fc_std <span style="color:#f92672">=</span> nn<span style="color:#f92672">.</span>Linear(hidden_dim, latent_dim)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>decoder <span style="color:#f92672">=</span> nn<span style="color:#f92672">.</span>Sequential(
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(latent_dim, hidden_dim),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Tanh(),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(hidden_dim, hidden_dim),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Tanh(),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(hidden_dim, input_dim)
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">encode</span>(self, x):
</span></span><span style="display:flex;"><span>        h <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>encoder(x)
</span></span><span style="display:flex;"><span>        mu <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>fc_mu(h)
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Softplus + epsilon for stable std deviation</span>
</span></span><span style="display:flex;"><span>        std <span style="color:#f92672">=</span> F<span style="color:#f92672">.</span>softplus(self<span style="color:#f92672">.</span>fc_std(h)) <span style="color:#f92672">+</span> <span style="color:#ae81ff">1e-6</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> mu, std
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">reparameterize</span>(self, mu, std):
</span></span><span style="display:flex;"><span>        eps <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>randn_like(std)
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> mu <span style="color:#f92672">+</span> eps <span style="color:#f92672">*</span> std
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">decode</span>(self, z):
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> self<span style="color:#f92672">.</span>decoder(z)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">forward</span>(self, x, kl_weight<span style="color:#f92672">=</span><span style="color:#ae81ff">1.0</span>):
</span></span><span style="display:flex;"><span>        mu, std <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>encode(x)
</span></span><span style="display:flex;"><span>        z <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>reparameterize(mu, std)
</span></span><span style="display:flex;"><span>        x_recon <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>decode(z)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># 1. Reconstruction Loss (Binary Cross Entropy for MNIST)</span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Sum over features, mean over batch</span>
</span></span><span style="display:flex;"><span>        recon_loss <span style="color:#f92672">=</span> F<span style="color:#f92672">.</span>binary_cross_entropy_with_logits(x_recon, x, reduction<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;none&#39;</span>)<span style="color:#f92672">.</span>sum(dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>)<span style="color:#f92672">.</span>mean()
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># 2. KL Divergence</span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Analytic KL for Normal distributions</span>
</span></span><span style="display:flex;"><span>        kl_loss <span style="color:#f92672">=</span> <span style="color:#f92672">-</span><span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>sum(<span style="color:#ae81ff">1</span> <span style="color:#f92672">+</span> torch<span style="color:#f92672">.</span>log(std<span style="color:#f92672">**</span><span style="color:#ae81ff">2</span>) <span style="color:#f92672">-</span> mu<span style="color:#f92672">**</span><span style="color:#ae81ff">2</span> <span style="color:#f92672">-</span> std<span style="color:#f92672">**</span><span style="color:#ae81ff">2</span>, dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>)<span style="color:#f92672">.</span>mean()
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># 3. Total Loss (negative ELBO, with the KL term scaled by kl_weight)</span>
</span></span><span style="display:flex;"><span>        loss <span style="color:#f92672">=</span> recon_loss <span style="color:#f92672">+</span> (kl_weight <span style="color:#f92672">*</span> kl_loss)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> VAEOutput(z, mu, std, x_recon, loss, recon_loss, kl_loss)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># --- Training Loop Example ---</span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">train_step</span>(model, batch, optimizer, kl_weight<span style="color:#f92672">=</span><span style="color:#ae81ff">1.0</span>):
</span></span><span style="display:flex;"><span>    model<span style="color:#f92672">.</span>train()
</span></span><span style="display:flex;"><span>    optimizer<span style="color:#f92672">.</span>zero_grad()
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Forward pass</span>
</span></span><span style="display:flex;"><span>    output <span style="color:#f92672">=</span> model(batch, kl_weight)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Backward pass</span>
</span></span><span style="display:flex;"><span>    output<span style="color:#f92672">.</span>loss<span style="color:#f92672">.</span>backward()
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Gradient clipping (recommended)</span>
</span></span><span style="display:flex;"><span>    torch<span style="color:#f92672">.</span>nn<span style="color:#f92672">.</span>utils<span style="color:#f92672">.</span>clip_grad_norm_(model<span style="color:#f92672">.</span>parameters(), max_norm<span style="color:#f92672">=</span><span style="color:#ae81ff">1.0</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    optimizer<span style="color:#f92672">.</span>step()
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> output<span style="color:#f92672">.</span>loss<span style="color:#f92672">.</span>item()
</span></span></code></pre></div><h3 id="the-core-idea-learning-to-generate">The Core Idea: Learning to Generate</h3>
<p>The VAE is built on a key assumption: our complex, high-dimensional data (like a $28 \times 28$ pixel image, $\mathbf{x}$) is actually <em>generated</em> by some simpler, low-dimensional, unobserved variable (a &ldquo;latent&rdquo; variable, $\mathbf{z}$).</p>
<blockquote>
<p><strong>A Physical Metaphor: Water Molecules and Phase Diagrams</strong></p>
<p>Consider a glass of water. At the microscopic level, you have more than $10^{24}$ $\text{H}_2\text{O}$ molecules bouncing around in an incredibly high-dimensional space. Each molecule has a position, a velocity, and interactions with its neighbors, which makes the full system computationally intractable to track directly. Yet we can describe the <em>macroscopic behavior</em> of all these molecules using just two simple variables: <strong>temperature</strong> and <strong>pressure</strong>. These two dimensions create a &ldquo;phase diagram&rdquo; that tells us whether our water will be ice, liquid, or vapor. The temperature and pressure are &ldquo;latent variables&rdquo; that capture the essential physics governing this complex molecular dance.</p></blockquote>















<figure class="post-figure center ">
    <img src="/img/vae-tut/phase-diagram.webp"
         alt="Water phase diagram showing solid, liquid, and gas phases as functions of temperature and pressure"
         title="Water phase diagram showing solid, liquid, and gas phases as functions of temperature and pressure"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">A water phase diagram: Complex molecular behavior reduced to two simple variables (temperature and pressure). This illustrates how high-dimensional systems can often be understood through low-dimensional latent representations.</figcaption>
    
</figure>

<p>A VAE makes the same assumption: complex data (like images) emerges from simpler underlying factors. A handwritten digit might be generated by latent factors like &ldquo;pen thickness,&rdquo; &ldquo;writing angle,&rdquo; &ldquo;digit style,&rdquo; and &ldquo;size,&rdquo; a much simpler description than tracking all 784 pixel values independently.</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/hypothetical-mnist-factors.webp"
         alt="Hypothetical illustration of MNIST digits generated from latent factors like pen thickness, angle, style, and size"
         title="Hypothetical illustration of MNIST digits generated from latent factors like pen thickness, angle, style, and size"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">A hypothetical illustration showing how MNIST digits could be generated from a few latent factors like pen thickness, writing angle, digit style, and size.</figcaption>
    
</figure>

<p>The VAE learns two functions: one that maps from complex data ($\mathbf{x}$) to these descriptive factors ($\mathbf{z}$), and another that maps from these factors back to the data. It accomplishes this with two main components, typically implemented as neural networks:</p>
<p><strong>1. The Encoder (Recognition Model)</strong></p>
<p>This network takes a complex data point $\mathbf{x}$ (an image) and determines the &ldquo;knob settings&rdquo; $\mathbf{z}$ that could explain or generate it. This allows us to <em>compress</em> or <em>understand</em> the data.</p>
<p>$$q_{\phi}(\mathbf{z} | \mathbf{x})$$</p>
<p>It&rsquo;s like examining a container of molecules and summarizing their complex arrangement into key parameters like temperature and pressure.</p>
<p>Crucially, the encoder outputs the <em>parameters</em> of a probability distribution (a simple Gaussian) that describes $\mathbf{z}$.</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/encoding-diagram.webp"
         alt="Diagram mapping MNIST five to a Gaussian distribution in latent space with mean and standard deviation"
         title="Diagram mapping MNIST five to a Gaussian distribution in latent space with mean and standard deviation"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The Encoder maps an input image (e.g., an MNIST digit &lsquo;5&rsquo;) to a Gaussian distribution in latent space, characterized by a mean vector and a standard deviation vector.</figcaption>
    
</figure>

<p>For each input $\mathbf{x}$, the encoder network outputs:</p>
<ul>
<li>A vector of means, $\mathbf{\mu}$</li>
<li>A vector of standard deviations, $\mathbf{\sigma}$</li>
</ul>
<p>These parameters define our approximation $q_{\phi}(\mathbf{z} | \mathbf{x}) = \mathcal{N}(\mathbf{z} \mid \mathbf{\mu}, \mathbf{\sigma}^2\mathbf{I})$. We then <em>sample</em> from this distribution to get the $\mathbf{z}$ that we feed to the decoder. This probabilistic step, combined with the KL regularizer introduced below, is what encourages the latent space to be continuous and structured: similar inputs map to nearby regions in latent space, enabling smooth interpolation and generation.</p>
<p><strong>2. The Decoder (Generative Model)</strong></p>
<p>This network learns the &ldquo;generative process.&rdquo; It takes a simple latent vector $\mathbf{z}$ and reconstructs the complex data $\mathbf{x}$. This allows us to <em>generate</em> new data by feeding it a random $\mathbf{z}$ and observing what image $\mathbf{x}$ it produces.</p>
<p>$$p_{\theta}(\mathbf{x} | \mathbf{z})$$</p>
<p>The decoder reverses the encoder: it takes the simple latent representation and &ldquo;paints&rdquo; the full, complex image from it. It&rsquo;s like taking temperature and pressure values and producing a detailed arrangement of water molecules consistent with those conditions. During training, the goal is to reproduce the original input as closely as possible.</p>
<p>After training, we have two networks that can be used for a variety of purposes:</p>
<ul>
<li><strong>Generation</strong>: If the latent space is well-structured, we can sample random $\mathbf{z}$ vectors from a simple distribution (like a standard normal) and feed them into the Decoder to generate new images. This is particularly useful when searching for data points with desired properties, as in drug discovery, where we might want to generate molecules with specific characteristics.</li>
<li><strong>Compression</strong>: The Encoder can compress complex data into a low-dimensional latent space, which can be useful for visualization or as a feature extractor for other tasks.</li>
</ul>
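<p>As a rough sketch of the generation use case: sample $\mathbf{z}$ from the standard normal prior, run it through the decoder, and squash the logits with a sigmoid to get pixel intensities. The <code>decoder</code> below is an untrained stand-in with the same shape as the one in the TL;DR code, and <code>generate</code> is a hypothetical helper, so the outputs here are noise; with a trained decoder the same call would produce digit-like images.</p>

```python
import torch
import torch.nn as nn

latent_dim, hidden_dim, input_dim = 16, 512, 784

# Untrained stand-in for a trained decoder that outputs per-pixel logits.
decoder = nn.Sequential(
    nn.Linear(latent_dim, hidden_dim),
    nn.Tanh(),
    nn.Linear(hidden_dim, input_dim),
)


@torch.no_grad()
def generate(decoder, num_samples, latent_dim):
    z = torch.randn(num_samples, latent_dim)   # z ~ N(0, I), the prior
    logits = decoder(z)
    return torch.sigmoid(logits)               # Bernoulli means in [0, 1]


torch.manual_seed(0)
images = generate(decoder, num_samples=8, latent_dim=latent_dim).view(-1, 28, 28)
print(images.shape)
```

<p>For compression, the analogous move is to run the encoder and keep the mean vector $\mathbf{\mu}$ as the low-dimensional feature, skipping the sampling step.</p>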
<h3 id="the-variational-problem">The &ldquo;Variational&rdquo; Problem</h3>
<p>Calculating the <em>true</em> distribution of latent variables $p_{\theta}(\mathbf{z}|\mathbf{x})$ (the posterior) is mathematically intractable.</p>
<p>This intractability arises from Bayes&rsquo; theorem:</p>
<p>$$p_{\theta}(\mathbf{z} | \mathbf{x}) = \frac{p_{\theta}(\mathbf{x} | \mathbf{z}) p_{\theta}(\mathbf{z})}{p_{\theta}(\mathbf{x})}$$</p>
<p>Breaking down each component:</p>
<ul>
<li>$p_{\theta}(\mathbf{x} | \mathbf{z})$ is our decoder, which is straightforward to compute given our likelihood model.</li>
<li>$p_{\theta}(\mathbf{z})$ is our prior over latent variables, typically a simple distribution like a standard normal, making it easy to compute.</li>
<li>$p_{\theta}(\mathbf{x})$ is the marginal likelihood of the data. And here lies the problem. It requires integrating over all possible latent variables that could have generated $\mathbf{x}$:
$$p_{\theta}(\mathbf{x}) = \int p_{\theta}(\mathbf{x} | \mathbf{z}) p_{\theta}(\mathbf{z}) d\mathbf{z}$$
It is the normalization factor that ensures the posterior is a valid probability distribution (i.e., integrates to 1 over all $\mathbf{z}$).</li>
</ul>
<p>This integral is intractable because it involves integrating over a high-dimensional latent space with a complex likelihood function. No closed-form solution exists, and numerical integration is computationally prohibitive.</p>
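<p>To make the integral concrete, here is a toy one-dimensional model where it <em>can</em> be computed, so a naive Monte Carlo estimate can be checked against the exact answer. The model ($\mathbf{z} \sim \mathcal{N}(0, 1)$, $\mathbf{x} | \mathbf{z} \sim \mathcal{N}(\mathbf{z}, 1)$) and the helper names are assumptions chosen for illustration; a real VAE decoder is a neural network, so no such closed form is available there.</p>

```python
import math
import random


def normal_pdf(x, mean, var):
    return math.exp(-0.5 * (x - mean) ** 2 / var) / math.sqrt(2 * math.pi * var)


def mc_marginal(x, num_samples, rng):
    # p(x) = E_{z ~ p(z)}[p(x | z)], estimated by averaging over prior samples.
    total = 0.0
    for _ in range(num_samples):
        z = rng.gauss(0.0, 1.0)          # z ~ N(0, 1)
        total += normal_pdf(x, z, 1.0)   # p(x | z) = N(x; z, 1)
    return total / num_samples


rng = random.Random(0)
x = 1.0
estimate = mc_marginal(x, 100_000, rng)
# Closed form exists only because this toy model is linear-Gaussian: p(x) = N(0, 2).
exact = normal_pdf(x, 0.0, 2.0)
print(estimate, exact)
```

<p>In one dimension, prior samples land near the region that explains $\mathbf{x}$ often enough for the average to converge. In a high-dimensional latent space, almost every prior sample contributes a negligible likelihood, which is one way to see why this brute-force approach does not scale.</p>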
<p>This is where the &ldquo;variational&rdquo; approach provides the solution. We approximate the true posterior by learning an encoder, $q_{\phi}(\mathbf{z} | \mathbf{x})$, that serves as a variational approximation to this intractable true distribution. The VAE&rsquo;s training process optimizes this approximation to be as accurate as possible, pushing this learned distribution closer to the true posterior.</p>
<h3 id="the-vae-objective-a-balancing-act">The VAE Objective: A Balancing Act</h3>
<p>To get these two networks (parameterized by $\theta$ and $\phi$) to work together, we train them jointly with a special loss function. This objective has two parts that balance two different goals:</p>
<h4 id="1-reconstruction-loss">1. Reconstruction Loss</h4>
<p>$$E_{q_{\phi}(\mathbf{z} | \mathbf{x})}[\log p_{\theta}(\mathbf{x} | \mathbf{z})]$$</p>
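<p>This term asks: &ldquo;How well can we reconstruct our original image?&rdquo; It forces the VAE to be good at its job. Since it is an expected log-likelihood, we want it to be large; in practice we minimize its negative, which is the reconstruction loss. The process goes:</p>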
<ol>
<li>Take an input point $\mathbf{x}$.</li>
<li>Use the <strong>Encoder</strong> to get its latent representation $\mathbf{z} \sim q_{\phi}(\mathbf{z} | \mathbf{x})$.</li>
<li>Use the <strong>Decoder</strong> to generate a new image $\mathbf{x}'$ from $\mathbf{z}$, $\mathbf{x}' \sim p_{\theta}(\mathbf{x} | \mathbf{z})$.</li>
<li>Compare $\mathbf{x}$ and $\mathbf{x}'$.</li>
</ol>
<p>The reconstruction loss measures the difference between the original and the reconstructed image.</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/reconstruction-loss-graphic.webp"
         alt="Graphic illustrating the reconstruction loss between original and reconstructed images"
         title="Graphic illustrating the reconstruction loss between original and reconstructed images"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The Reconstruction Loss measures how closely the Decoder&rsquo;s output matches the original input image.</figcaption>
    
</figure>

<ul>
<li><strong>For continuous inputs</strong> (like general images), this is often Mean Squared Error (MSE).</li>
<li><strong>For inputs in $[0, 1]$</strong> (like MNIST pixel intensities, which after <code>ToTensor()</code> are continuous values in $[0, 1]$), we use Binary Cross-Entropy (BCE). We treat each pixel as an independent Bernoulli variable whose target is its intensity. The decoder outputs the <em>logits</em> for each pixel, and the BCE-with-logits loss (e.g., <code>F.binary_cross_entropy_with_logits</code>) is the numerically stable way to compute the negative log-likelihood.</li>
<li><strong>More generally</strong>, you can output parameters of a desired output distribution. What if you wanted a mixture of Gaussians? The decoder could output the means, variances, and mixture weights, and you could compute the negative log-likelihood accordingly.</li>
</ul>
<p>This loss pushes the encoder to produce useful $\mathbf{z}$ vectors and pushes the decoder to learn how to interpret them accurately.</p>
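<p>The first two options can be compared side by side. The sketch below uses random stand-in tensors (<code>x</code>, <code>decoder_out</code>) in place of real data and a real decoder, and checks two things: that BCE-with-logits matches the Bernoulli negative log-likelihood written out by hand, and that summed squared error equals a unit-variance Gaussian negative log-likelihood up to a factor of one half and an additive constant.</p>

```python
import math
import torch
import torch.nn.functional as F
from torch.distributions import Normal

torch.manual_seed(0)
x = torch.rand(4, 784)              # stand-in batch with values in [0, 1]
decoder_out = torch.randn(4, 784)   # stand-in for raw decoder outputs

# Bernoulli likelihood: decoder outputs are interpreted as logits.
bce = F.binary_cross_entropy_with_logits(decoder_out, x, reduction="none").sum(dim=1).mean()
p = torch.sigmoid(decoder_out)
bce_manual = -(x * torch.log(p) + (1 - x) * torch.log(1 - p)).sum(dim=1).mean()

# Gaussian likelihood with unit variance: decoder outputs are interpreted as means.
sse = F.mse_loss(decoder_out, x, reduction="none").sum(dim=1).mean()
gauss_nll = -Normal(decoder_out, 1.0).log_prob(x).sum(dim=1).mean()
const = 0.5 * 784 * math.log(2 * math.pi)   # normalizer, independent of the model

print(bce.item(), bce_manual.item())
print(gauss_nll.item(), (0.5 * sse + const).item())
```

<p>Both losses sum over features and average over the batch, matching the reduction used in the TL;DR code, so their scale stays comparable to the KL term.</p>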
<h4 id="2-the-kl-divergence-the-regularizer">2. The KL Divergence (The Regularizer)</h4>
<p>$$D_{KL}(q_{\phi}(\mathbf{z} | \mathbf{x}) || p_{\theta}(\mathbf{z}))$$</p>
<p>On its own, the reconstruction loss might &ldquo;cheat.&rdquo; The encoder could learn to map every image to a different, specific point in the latent space, essentially &ldquo;memorizing&rdquo; the data. While this minimizes reconstruction error, it creates a meaningless latent space that fails at generation.</p>
<p>The KL divergence term fixes this. It&rsquo;s a regularizer that forces the latent space to be organized and smooth.</p>
<p>We force the encoder&rsquo;s output, $q_{\phi}(\mathbf{z} | \mathbf{x})$, to be close to a simple, predefined <em>prior distribution</em>, $p_{\theta}(\mathbf{z})$. This prior is almost always a standard normal distribution because it is mathematically convenient, easy to sample from, and encourages a well-behaved latent space.</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/kl-loss-graphic.webp"
         alt="Graphic illustrating the KL divergence between the encoder&#39;s output distribution and the prior distribution"
         title="Graphic illustrating the KL divergence between the encoder&#39;s output distribution and the prior distribution"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The KL Divergence measures how much the Encoder&rsquo;s output distribution diverges from the simple prior distribution.</figcaption>
    
</figure>

<p>This regularization term acts as a penalty, measuring how much the encoder&rsquo;s output distribution diverges from the simple prior. By minimizing this KL divergence, we encourage the model to:</p>
<ul>
<li><strong>Avoid overfitting</strong> by preventing the encoder from memorizing specific locations for each input</li>
<li><strong>Create meaningful clusters</strong> where similar inputs map to nearby regions in the latent space</li>
<li><strong>Maintain continuity</strong> so that points close together in latent space (like different variations of the digit &ldquo;7&rdquo;) decode into visually similar outputs</li>
</ul>
<p>This smooth, structured latent space is what enables generation: we can sample random points from our prior distribution and decode them into realistic new data.</p>
<p>Ultimately, the optimizer finds a balance between these two objectives: reconstructing the data well while keeping the latent space organized and regularized.</p>
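<p>For a diagonal Gaussian encoder and a standard normal prior, the KL term has a closed form, which is what the TL;DR code uses. As a quick check under those assumptions, the sketch below compares that formula with <code>torch.distributions.kl_divergence</code> on random stand-in values for <code>mu</code> and <code>std</code>.</p>

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
mu = torch.randn(4, 16)            # stand-in encoder means
std = torch.rand(4, 16) + 0.1      # stand-in encoder standard deviations (> 0)

# Closed form: KL(N(mu, std^2) || N(0, 1)), summed over latent dimensions.
kl_closed = -0.5 * torch.sum(1 + torch.log(std**2) - mu**2 - std**2, dim=1)

# Same quantity via torch.distributions, dimension by dimension.
prior = Normal(torch.zeros_like(mu), torch.ones_like(std))
kl_lib = kl_divergence(Normal(mu, std), prior).sum(dim=1)

# The penalty vanishes when the encoder output equals the prior.
kl_zero = kl_divergence(prior, prior).sum(dim=1)

print(kl_closed, kl_lib, kl_zero)
```

<p>The penalty grows as the means drift away from zero or the standard deviations drift away from one, which is the pressure that keeps encodings from scattering to isolated points.</p>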
<h3 id="the-reparameterization-trick-making-it-all-trainable">The Reparameterization Trick: Making it All Trainable</h3>
<p>We have a problem. The training process requires sampling:</p>
<ol>
<li>Encoder produces $\mathbf{\mu}$ and $\mathbf{\sigma}$.</li>
<li>We <strong>sample</strong> $\mathbf{z} \sim \mathcal{N}(\mathbf{\mu}, \mathbf{\sigma}^2\mathbf{I})$.</li>
<li>Decoder uses $\mathbf{z}$ to reconstruct $\mathbf{x}'$.</li>
<li>We calculate the loss.</li>
</ol>
<p>The &ldquo;sampling&rdquo; step is a random, non-differentiable operation. We can&rsquo;t backpropagate the reconstruction loss from the decoder <em>through</em> this random node to update the encoder&rsquo;s weights.</p>
<p>The <strong>reparameterization trick</strong> makes the sampling process differentiable. Instead of sampling $\mathbf{z}$ directly, we draw a random noise vector and compute $\mathbf{z}$ as a deterministic function of $\mathbf{\mu}$, $\mathbf{\sigma}$, and that noise:</p>
<ol>
<li>Sample a random noise vector $\mathbf{\epsilon}$ from a simple, fixed distribution (e.g., the standard normal $\mathcal{N}(\mathbf{0}, \mathbf{I})$).</li>
<li>Compute $\mathbf{z}$ as: $\mathbf{z} = \mathbf{\mu} + \mathbf{\sigma} \odot \mathbf{\epsilon}$</li>
</ol>
<p>This simple change moves the randomness &ldquo;outside&rdquo; the network. The gradient can now flow deterministically from $\mathbf{z}$ back through the $\mathbf{\mu}$ and $\mathbf{\sigma}$ nodes to the encoder network. This is the key engineering insight that allows us to train the entire model end-to-end with standard backpropagation.</p>
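<p>The difference is easy to check with autograd. The sketch below contrasts direct sampling with the reparameterized form on a toy two-dimensional example. The tensors and the squared-norm stand-in for a loss are assumptions made for illustration; this is not the training code.</p>

```python
import torch
from torch.distributions import Normal

torch.manual_seed(0)

# Stand-ins for encoder outputs; requires_grad marks them as trainable quantities
mu = torch.tensor([0.5, -1.0], requires_grad=True)
log_std = torch.tensor([0.0, -0.5], requires_grad=True)
std = log_std.exp()

# Direct sampling: .sample() does not track gradients, so the result is
# detached from mu and std and nothing can be backpropagated through it
z_direct = Normal(mu, std).sample()
print(z_direct.requires_grad)  # False

# Reparameterized sampling: the noise is drawn independently of the parameters,
# and z is an ordinary differentiable function of mu and std
epsilon = torch.randn_like(std)
z = mu + std * epsilon
print(z.requires_grad)  # True

# A toy scalar "loss" on z now produces gradients for both encoder outputs
z.pow(2).sum().backward()
print(mu.grad, log_std.grad)
```

<p>PyTorch distributions expose the same idea through <code>rsample()</code>, which draws a reparameterized sample where the distribution supports it.</p>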
<h3 id="where-does-this-objective-come-from-the-math">Where Does This Objective Come From? (The Math)</h3>
<p>This two-part loss function is derived directly from the goal of maximizing the marginal likelihood of the data, $\log p_{\theta}(\mathbf{x})$.</p>
<p>For a single data point $\mathbf{x}^{(i)}$, we can write:
$$\log p_{\theta}(\mathbf{x}^{(i)}) = D_{KL}(q_\phi(\mathbf{z} | \mathbf{x}^{(i)}) || p_{\theta}(\mathbf{z} | \mathbf{x}^{(i)})) + \mathcal{L}(\theta, \phi; \mathbf{x}^{(i)})$$</p>
<ul>
<li>The first term is the KL divergence between our encoder&rsquo;s approximation and the (intractable) true posterior. This is non-negative, and unfortunately we cannot compute it.</li>
<li>The second term, $\mathcal{L}$, is the Variational Lower Bound (also known as the Evidence Lower Bound, or ELBO). Since the KL term is $\ge 0$, we know that $\log p_{\theta}(\mathbf{x}^{(i)}) \ge \mathcal{L}$.</li>
</ul>
<p>By maximizing this lower bound $\mathcal{L}$, we push up the &ldquo;floor&rdquo; on the true likelihood of our data. This is a problem we can solve.</p>
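<p>For readers who want the intermediate step, the decomposition follows from Bayes&rsquo; rule, $p_{\theta}(\mathbf{z} | \mathbf{x}) = p_{\theta}(\mathbf{x}, \mathbf{z}) / p_{\theta}(\mathbf{x})$. As a brief sketch (dropping the $(i)$ superscript for readability), substitute it into the definition of the KL divergence and pull $\log p_{\theta}(\mathbf{x})$ out of the expectation, since it does not depend on $\mathbf{z}$:</p>
<p>$$D_{KL}(q_\phi(\mathbf{z} | \mathbf{x}) || p_{\theta}(\mathbf{z} | \mathbf{x})) = E_{q_{\phi}(\mathbf{z} | \mathbf{x})}[\log q_\phi(\mathbf{z} | \mathbf{x}) - \log p_{\theta}(\mathbf{x}, \mathbf{z})] + \log p_{\theta}(\mathbf{x})$$</p>
<p>Rearranging gives the identity, with the lower bound identified as the remaining expectation:</p>
<p>$$\log p_{\theta}(\mathbf{x}) = D_{KL}(q_\phi(\mathbf{z} | \mathbf{x}) || p_{\theta}(\mathbf{z} | \mathbf{x})) + \underbrace{E_{q_{\phi}(\mathbf{z} | \mathbf{x})}[\log p_{\theta}(\mathbf{x}, \mathbf{z}) - \log q_\phi(\mathbf{z} | \mathbf{x})]}_{\mathcal{L}(\theta, \phi; \mathbf{x})}$$</p>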
<p>When we expand this $\mathcal{L}$ term, we get our famous two-part objective:</p>
<p>$$\mathcal{L}(\theta, \phi; \mathbf{x}^{(i)}) = E_{q_{\phi}(\mathbf{z} | \mathbf{x}^{(i)})}[\log p_{\theta}(\mathbf{x}^{(i)} | \mathbf{z})] - D_{KL}(q_\phi(\mathbf{z} | \mathbf{x}^{(i)}) || p_{\theta}(\mathbf{z}))$$</p>
<ul>
<li><strong>Term 1:</strong> The expected log-likelihood of reconstructing $\mathbf{x}^{(i)}$ from $\mathbf{z}$. Maximizing this is the same as minimizing the Reconstruction Loss.</li>
<li><strong>Term 2:</strong> The negative KL divergence between our encoder and the simple prior. Maximizing this is the same as minimizing the KL Divergence Loss.</li>
</ul>
<p>Thus, the VAE&rsquo;s objective balances these two critical goals: faithfully reconstructing the data while maintaining a simple, regularized latent structure that is useful for generation.</p>
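<p>The two-term form takes only one step to obtain. As a sketch, start from the standard definition of the bound, $\mathcal{L} = E_{q_{\phi}(\mathbf{z} | \mathbf{x})}[\log p_{\theta}(\mathbf{x}, \mathbf{z}) - \log q_\phi(\mathbf{z} | \mathbf{x})]$, factor the joint as $p_{\theta}(\mathbf{x}, \mathbf{z}) = p_{\theta}(\mathbf{x} | \mathbf{z}) p_{\theta}(\mathbf{z})$, and group the terms that involve only $\mathbf{z}$:</p>
<p>$$\mathcal{L} = E_{q_{\phi}(\mathbf{z} | \mathbf{x})}[\log p_{\theta}(\mathbf{x} | \mathbf{z})] + E_{q_{\phi}(\mathbf{z} | \mathbf{x})}[\log p_{\theta}(\mathbf{z}) - \log q_\phi(\mathbf{z} | \mathbf{x})] = E_{q_{\phi}(\mathbf{z} | \mathbf{x})}[\log p_{\theta}(\mathbf{x} | \mathbf{z})] - D_{KL}(q_\phi(\mathbf{z} | \mathbf{x}) || p_{\theta}(\mathbf{z}))$$</p>
<p>The second expectation is, by definition, the negative KL divergence between the encoder distribution and the prior.</p>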
<h3 id="from-elbo-to-practical-loss">From ELBO to Practical Loss</h3>
<p>Remember, our goal is to <strong>maximize</strong> the ELBO:</p>
<p>$$\mathcal{L}(\theta, \phi; \mathbf{x}) = E_{q_{\phi}(\mathbf{z} | \mathbf{x})}[\log p_{\theta}(\mathbf{x} | \mathbf{z})] - D_{KL}(q_\phi(\mathbf{z} | \mathbf{x}) || p_{\theta}(\mathbf{z}))$$</p>
<p>Since deep learning libraries are built to <strong>minimize</strong> a loss function, we simply flip the sign and <strong>minimize the negative ELBO ($-\mathcal{L}$)</strong>.</p>
<p>This gives us our final, practical loss function:</p>
<p>$$\text{Loss} = -\mathcal{L} = -E_{q_{\phi}(\mathbf{z} | \mathbf{x})}[\log p_{\theta}(\mathbf{x} | \mathbf{z})] + D_{KL}(q_\phi(\mathbf{z} | \mathbf{x}) || p_{\theta}(\mathbf{z}))$$</p>
<p>This is the function you actually implement. Minimizing this loss achieves both of our goals:</p>
<ol>
<li>It <strong>minimizes the Reconstruction Loss</strong> (which is the same as maximizing the log-likelihood).</li>
<li>It <strong>minimizes the KL Divergence</strong>, forcing the encoder to match the prior.</li>
</ol>
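<p>As a rough illustration of how this loss can be computed for a batch, the sketch below makes three assumptions that the formula itself does not fix: a Bernoulli likelihood over pixels (so the reconstruction term becomes binary cross-entropy on logits), a standard normal prior, and an encoder that outputs log-variance. The function name <code>negative_elbo</code> and the random tensors are hypothetical stand-ins, and the expectation is approximated with a single latent sample per input.</p>

```python
import torch
import torch.nn.functional as F


def negative_elbo(
    x: torch.Tensor,
    x_logits: torch.Tensor,
    mu: torch.Tensor,
    logvar: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Single-sample estimate of -ELBO, averaged over the batch."""
    batch_size = x.shape[0]

    # -E_q[log p(x|z)] under a Bernoulli likelihood: BCE summed over pixels
    loss_recon = (
        F.binary_cross_entropy_with_logits(x_logits, x, reduction="sum") / batch_size
    )

    # KL(q(z|x) || N(0, I)) in closed form for a diagonal Gaussian
    loss_kl = 0.5 * torch.sum(mu.pow(2) + logvar.exp() - 1.0 - logvar) / batch_size

    return loss_recon + loss_kl, loss_recon, loss_kl


torch.manual_seed(0)
x = torch.rand(4, 1, 28, 28)          # stand-in for images scaled to [0, 1]
x_logits = torch.randn(4, 1, 28, 28)  # stand-in for decoder outputs
mu = torch.randn(4, 8)                # stand-in for encoder means
logvar = torch.randn(4, 8)            # stand-in for encoder log-variances

loss, loss_recon, loss_kl = negative_elbo(x, x_logits, mu, logvar)
print(loss.item(), loss_recon.item(), loss_kl.item())
```

<p>Summing over pixels and latent dimensions before averaging over the batch keeps the two terms on the scale of the per-example ELBO. Averaging over pixels instead would implicitly reweight the reconstruction term relative to the KL term.</p>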
<h2 id="modern-pytorch-vae-implementation">Modern PyTorch VAE Implementation</h2>
<p>Now that we understand the VAE architecture and objective, let&rsquo;s implement a modern VAE in PyTorch. I&rsquo;ll focus primarily on the model and loss function here, though the full code is available <a href="https://github.com/hunter-heidenreich/vae">on GitHub</a>.</p>
<p>My implementation consists of a configuration <code>dataclass</code>, an output <code>dataclass</code>, and a VAE class extending <code>nn.Module</code>.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#e6db74">&#34;&#34;&#34;Variational Autoencoder (VAE) model implementation.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> dataclasses <span style="color:#f92672">import</span> dataclass
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> torch
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> torch.nn <span style="color:#66d9ef">as</span> nn
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> torch.nn.functional <span style="color:#66d9ef">as</span> F
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">get_activation</span>(activation: str) <span style="color:#f92672">-&gt;</span> nn<span style="color:#f92672">.</span>Module:
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;Get activation function by name.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>    activation_lower <span style="color:#f92672">=</span> activation<span style="color:#f92672">.</span>lower()
</span></span><span style="display:flex;"><span>    ACTIVATION_MAP <span style="color:#f92672">=</span> {
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;relu&#34;</span>: nn<span style="color:#f92672">.</span>ReLU(),
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;tanh&#34;</span>: nn<span style="color:#f92672">.</span>Tanh(),
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;sigmoid&#34;</span>: nn<span style="color:#f92672">.</span>Sigmoid(),
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;leaky_relu&#34;</span>: nn<span style="color:#f92672">.</span>LeakyReLU(),
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;elu&#34;</span>: nn<span style="color:#f92672">.</span>ELU(),
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;gelu&#34;</span>: nn<span style="color:#f92672">.</span>GELU(),
</span></span><span style="display:flex;"><span>    }
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">if</span> activation_lower <span style="color:#f92672">not</span> <span style="color:#f92672">in</span> ACTIVATION_MAP:
</span></span><span style="display:flex;"><span>        supported <span style="color:#f92672">=</span> <span style="color:#e6db74">&#34;, &#34;</span><span style="color:#f92672">.</span>join(ACTIVATION_MAP<span style="color:#f92672">.</span>keys())
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">raise</span> <span style="color:#a6e22e">ValueError</span>(
</span></span><span style="display:flex;"><span>            <span style="color:#e6db74">f</span><span style="color:#e6db74">&#34;Unsupported activation &#39;</span><span style="color:#e6db74">{</span>activation<span style="color:#e6db74">}</span><span style="color:#e6db74">&#39;. Supported: </span><span style="color:#e6db74">{</span>supported<span style="color:#e6db74">}</span><span style="color:#e6db74">&#34;</span>
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> ACTIVATION_MAP[activation_lower]
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#a6e22e">@dataclass</span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">class</span> <span style="color:#a6e22e">VAEConfig</span>:
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;VAE model configuration specifying architecture and behavior.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    hidden_dim: int
</span></span><span style="display:flex;"><span>    latent_dim: int
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    input_shape: tuple[int, int, int] <span style="color:#f92672">=</span> (<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">28</span>, <span style="color:#ae81ff">28</span>)  <span style="color:#75715e"># Default: MNIST</span>
</span></span><span style="display:flex;"><span>    activation: str <span style="color:#f92672">=</span> <span style="color:#e6db74">&#34;tanh&#34;</span>  <span style="color:#75715e"># Default: tanh, what was used in the original VAE paper</span>
</span></span><span style="display:flex;"><span>    use_softplus_std: bool <span style="color:#f92672">=</span> <span style="color:#66d9ef">False</span>  <span style="color:#75715e"># Whether to use softplus for std parameterization</span>
</span></span><span style="display:flex;"><span>    n_samples: int <span style="color:#f92672">=</span> <span style="color:#ae81ff">1</span>  <span style="color:#75715e"># Number of latent samples per input during training</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#a6e22e">@dataclass</span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">class</span> <span style="color:#a6e22e">VAEOutput</span>:
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;VAE forward pass output containing all relevant tensors and optional losses.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    x_logits: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    z: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    mu: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    std: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    x_recon: torch<span style="color:#f92672">.</span>Tensor <span style="color:#f92672">|</span> <span style="color:#66d9ef">None</span> <span style="color:#f92672">=</span> <span style="color:#66d9ef">None</span>
</span></span><span style="display:flex;"><span>    loss: torch<span style="color:#f92672">.</span>Tensor <span style="color:#f92672">|</span> <span style="color:#66d9ef">None</span> <span style="color:#f92672">=</span> <span style="color:#66d9ef">None</span>
</span></span><span style="display:flex;"><span>    loss_recon: torch<span style="color:#f92672">.</span>Tensor <span style="color:#f92672">|</span> <span style="color:#66d9ef">None</span> <span style="color:#f92672">=</span> <span style="color:#66d9ef">None</span>
</span></span><span style="display:flex;"><span>    loss_kl: torch<span style="color:#f92672">.</span>Tensor <span style="color:#f92672">|</span> <span style="color:#66d9ef">None</span> <span style="color:#f92672">=</span> <span style="color:#66d9ef">None</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">class</span> <span style="color:#a6e22e">VAE</span>(nn<span style="color:#f92672">.</span>Module):
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;Variational Autoencoder with support for deterministic and probabilistic reconstruction.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    DEFAULT_EPS <span style="color:#f92672">=</span> <span style="color:#ae81ff">1e-8</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">__init__</span>(self, config: VAEConfig) <span style="color:#f92672">-&gt;</span> <span style="color:#66d9ef">None</span>:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Initialize VAE with given configuration.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">        Args:
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">            config: VAE configuration specifying architecture and behavior
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">        &#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        super()<span style="color:#f92672">.</span><span style="color:#a6e22e">__init__</span>()
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>config <span style="color:#f92672">=</span> config
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Build encoder: input -&gt; hidden -&gt; latent parameters (mu, sigma)</span>
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>encoder <span style="color:#f92672">=</span> nn<span style="color:#f92672">.</span>Sequential(
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Flatten(),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(
</span></span><span style="display:flex;"><span>                int(torch<span style="color:#f92672">.</span>prod(torch<span style="color:#f92672">.</span>tensor(config<span style="color:#f92672">.</span>input_shape))), config<span style="color:#f92672">.</span>hidden_dim
</span></span><span style="display:flex;"><span>            ),
</span></span><span style="display:flex;"><span>            get_activation(config<span style="color:#f92672">.</span>activation),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(config<span style="color:#f92672">.</span>hidden_dim, config<span style="color:#f92672">.</span>latent_dim <span style="color:#f92672">*</span> <span style="color:#ae81ff">2</span>),
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Build decoder: latent -&gt; hidden -&gt; reconstructed input</span>
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>decoder <span style="color:#f92672">=</span> nn<span style="color:#f92672">.</span>Sequential(
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(config<span style="color:#f92672">.</span>latent_dim, config<span style="color:#f92672">.</span>hidden_dim),
</span></span><span style="display:flex;"><span>            get_activation(config<span style="color:#f92672">.</span>activation),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(
</span></span><span style="display:flex;"><span>                config<span style="color:#f92672">.</span>hidden_dim, int(torch<span style="color:#f92672">.</span>prod(torch<span style="color:#f92672">.</span>tensor(config<span style="color:#f92672">.</span>input_shape)))
</span></span><span style="display:flex;"><span>            ),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Unflatten(<span style="color:#ae81ff">1</span>, config<span style="color:#f92672">.</span>input_shape),
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">encode</span>(self, x: torch<span style="color:#f92672">.</span>Tensor) <span style="color:#f92672">-&gt;</span> tuple[torch<span style="color:#f92672">.</span>Tensor, torch<span style="color:#f92672">.</span>Tensor]:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Encode input to latent distribution parameters.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        encoder_output <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>encoder(x)
</span></span><span style="display:flex;"><span>        mu, sigma <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>chunk(encoder_output, <span style="color:#ae81ff">2</span>, dim<span style="color:#f92672">=-</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> mu, sigma
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">decode</span>(self, z: torch<span style="color:#f92672">.</span>Tensor) <span style="color:#f92672">-&gt;</span> torch<span style="color:#f92672">.</span>Tensor:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Decode latent representation to reconstruction logits.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> self<span style="color:#f92672">.</span>decoder(z)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">reparameterize</span>(self, mu: torch<span style="color:#f92672">.</span>Tensor, std: torch<span style="color:#f92672">.</span>Tensor) <span style="color:#f92672">-&gt;</span> torch<span style="color:#f92672">.</span>Tensor:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Apply reparameterization trick for differentiable sampling.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        epsilon <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>randn_like(std)
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> mu <span style="color:#f92672">+</span> std <span style="color:#f92672">*</span> epsilon
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">forward</span>(
</span></span><span style="display:flex;"><span>        self,
</span></span><span style="display:flex;"><span>        x: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>        compute_loss: bool <span style="color:#f92672">=</span> <span style="color:#66d9ef">True</span>,
</span></span><span style="display:flex;"><span>        reconstruct: bool <span style="color:#f92672">=</span> <span style="color:#66d9ef">False</span>,
</span></span><span style="display:flex;"><span>        eps: float <span style="color:#f92672">=</span> DEFAULT_EPS,
</span></span><span style="display:flex;"><span>    ) <span style="color:#f92672">-&gt;</span> VAEOutput:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Forward pass through the VAE.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">        Args:
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">            x: Input tensor of shape (batch_size, *input_shape)
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">            compute_loss: Whether to compute VAE loss components
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">            reconstruct: Whether to also return sigmoid-activated reconstructions (x_recon)
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">            eps: Small epsilon value for numerical stability
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">        Returns:
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">            VAEOutput containing all relevant tensors and optionally computed losses
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">        &#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Prepare input for multiple sampling if needed</span>
</span></span><span style="display:flex;"><span>        x_expanded <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>_expand_for_sampling(x) <span style="color:#66d9ef">if</span> self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>n_samples <span style="color:#f92672">&gt;</span> <span style="color:#ae81ff">1</span> <span style="color:#66d9ef">else</span> x
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Encode and sample from latent space</span>
</span></span><span style="display:flex;"><span>        mu, sigma <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>encode(x)
</span></span><span style="display:flex;"><span>        std <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>_sigma_to_std(sigma, eps<span style="color:#f92672">=</span>eps)
</span></span><span style="display:flex;"><span>        mu_expanded, std_expanded <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>_expand_latent_params(mu, std)
</span></span><span style="display:flex;"><span>        z <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>reparameterize(mu_expanded, std_expanded)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Decode latent samples</span>
</span></span><span style="display:flex;"><span>        x_logits <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>decode(z)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Create output object</span>
</span></span><span style="display:flex;"><span>        output <span style="color:#f92672">=</span> VAEOutput(
</span></span><span style="display:flex;"><span>            x_logits<span style="color:#f92672">=</span>x_logits,
</span></span><span style="display:flex;"><span>            z<span style="color:#f92672">=</span>z,
</span></span><span style="display:flex;"><span>            mu<span style="color:#f92672">=</span>mu,
</span></span><span style="display:flex;"><span>            std<span style="color:#f92672">=</span>std,
</span></span><span style="display:flex;"><span>            x_recon<span style="color:#f92672">=</span>torch<span style="color:#f92672">.</span>sigmoid(x_logits) <span style="color:#66d9ef">if</span> reconstruct <span style="color:#66d9ef">else</span> <span style="color:#66d9ef">None</span>,
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Compute losses if requested</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">if</span> compute_loss:
</span></span><span style="display:flex;"><span>            loss, loss_recon, loss_kl <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>_compute_loss(
</span></span><span style="display:flex;"><span>                x_expanded, x_logits, mu, sigma, std
</span></span><span style="display:flex;"><span>            )
</span></span><span style="display:flex;"><span>            output<span style="color:#f92672">.</span>loss <span style="color:#f92672">=</span> loss
</span></span><span style="display:flex;"><span>            output<span style="color:#f92672">.</span>loss_recon <span style="color:#f92672">=</span> loss_recon
</span></span><span style="display:flex;"><span>            output<span style="color:#f92672">.</span>loss_kl <span style="color:#f92672">=</span> loss_kl
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> output
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># ==================== Helper Methods ====================</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">_sigma_to_std</span>(
</span></span><span style="display:flex;"><span>        self, sigma: torch<span style="color:#f92672">.</span>Tensor, eps: float <span style="color:#f92672">=</span> DEFAULT_EPS
</span></span><span style="display:flex;"><span>    ) <span style="color:#f92672">-&gt;</span> torch<span style="color:#f92672">.</span>Tensor:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Convert sigma parameter to standard deviation.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">if</span> self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>use_softplus_std:
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">return</span> F<span style="color:#f92672">.</span>softplus(sigma) <span style="color:#f92672">+</span> eps
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">else</span>:
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">return</span> torch<span style="color:#f92672">.</span>exp(<span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> sigma)  <span style="color:#75715e"># sigma represents log-variance</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">_expand_for_sampling</span>(self, x: torch<span style="color:#f92672">.</span>Tensor) <span style="color:#f92672">-&gt;</span> torch<span style="color:#f92672">.</span>Tensor:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Expand input tensor for multiple sampling.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        shape_dims <span style="color:#f92672">=</span> [<span style="color:#ae81ff">1</span>] <span style="color:#f92672">*</span> len(self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>input_shape)
</span></span><span style="display:flex;"><span>        x_expanded <span style="color:#f92672">=</span> x<span style="color:#f92672">.</span>unsqueeze(<span style="color:#ae81ff">1</span>)<span style="color:#f92672">.</span>repeat(<span style="color:#ae81ff">1</span>, self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>n_samples, <span style="color:#f92672">*</span>shape_dims)
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> x_expanded<span style="color:#f92672">.</span>view(<span style="color:#f92672">-</span><span style="color:#ae81ff">1</span>, <span style="color:#f92672">*</span>self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>input_shape)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">_expand_latent_params</span>(
</span></span><span style="display:flex;"><span>        self, mu: torch<span style="color:#f92672">.</span>Tensor, std: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    ) <span style="color:#f92672">-&gt;</span> tuple[torch<span style="color:#f92672">.</span>Tensor, torch<span style="color:#f92672">.</span>Tensor]:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Expand latent parameters for multiple sampling.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">if</span> self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>n_samples <span style="color:#f92672">==</span> <span style="color:#ae81ff">1</span>:
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">return</span> mu, std
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        mu_expanded <span style="color:#f92672">=</span> (
</span></span><span style="display:flex;"><span>            mu<span style="color:#f92672">.</span>unsqueeze(<span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>            <span style="color:#f92672">.</span>repeat(<span style="color:#ae81ff">1</span>, self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>n_samples, <span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>            <span style="color:#f92672">.</span>view(<span style="color:#f92672">-</span><span style="color:#ae81ff">1</span>, self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>latent_dim)
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>        std_expanded <span style="color:#f92672">=</span> (
</span></span><span style="display:flex;"><span>            std<span style="color:#f92672">.</span>unsqueeze(<span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>            <span style="color:#f92672">.</span>repeat(<span style="color:#ae81ff">1</span>, self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>n_samples, <span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>            <span style="color:#f92672">.</span>view(<span style="color:#f92672">-</span><span style="color:#ae81ff">1</span>, self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>latent_dim)
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> mu_expanded, std_expanded
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># ==================== Loss Computation ====================</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">_compute_loss</span>(
</span></span><span style="display:flex;"><span>        self,
</span></span><span style="display:flex;"><span>        x: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>        x_logits: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>        mu: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>        sigma: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>        std: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>    ) <span style="color:#f92672">-&gt;</span> tuple[torch<span style="color:#f92672">.</span>Tensor, torch<span style="color:#f92672">.</span>Tensor, torch<span style="color:#f92672">.</span>Tensor]:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Compute VAE loss components for deterministic reconstruction.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        loss_recon <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>_compute_reconstruction_loss(x, x_logits)
</span></span><span style="display:flex;"><span>        loss_kl <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>_compute_kl_loss(mu, sigma, std)
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> loss_recon <span style="color:#f92672">+</span> loss_kl, loss_recon, loss_kl
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">_compute_reconstruction_loss</span>(
</span></span><span style="display:flex;"><span>        self, x: torch<span style="color:#f92672">.</span>Tensor, x_logits: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    ) <span style="color:#f92672">-&gt;</span> torch<span style="color:#f92672">.</span>Tensor:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Compute reconstruction loss using binary cross-entropy.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> F<span style="color:#f92672">.</span>binary_cross_entropy_with_logits(
</span></span><span style="display:flex;"><span>            x_logits, x, reduction<span style="color:#f92672">=</span><span style="color:#e6db74">&#34;sum&#34;</span>
</span></span><span style="display:flex;"><span>        ) <span style="color:#f92672">/</span> x<span style="color:#f92672">.</span>size(<span style="color:#ae81ff">0</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">_compute_kl_loss</span>(
</span></span><span style="display:flex;"><span>        self,
</span></span><span style="display:flex;"><span>        mu: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>        sigma: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>        std: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>        eps: float <span style="color:#f92672">=</span> DEFAULT_EPS,
</span></span><span style="display:flex;"><span>    ) <span style="color:#f92672">-&gt;</span> torch<span style="color:#f92672">.</span>Tensor:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Compute KL divergence between latent distribution and standard normal prior.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Analytical KL: KL(N(μ,σ²) || N(0,1)) = 0.5 * Σ(μ² + σ² - 1 - log(σ²))</span>
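</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">if</span> self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>use_softplus_std <span style="color:#f92672">or</span> self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>bound_std <span style="color:#f92672">is</span> <span style="color:#f92672">not</span> <span style="color:#66d9ef">None</span>:
</span></span><span style="display:flex;"><span>            <span style="color:#75715e"># sigma is the raw network output here (softplus or bounded case), so use std (σ) directly</span>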
</span></span><span style="display:flex;"><span>            kl_per_sample <span style="color:#f92672">=</span> <span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>sum(
</span></span><span style="display:flex;"><span>                mu<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">+</span> std<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">-</span> <span style="color:#ae81ff">1</span> <span style="color:#f92672">-</span> torch<span style="color:#f92672">.</span>log(std<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">+</span> eps), dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>
</span></span><span style="display:flex;"><span>            )
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">else</span>:
</span></span><span style="display:flex;"><span>            <span style="color:#75715e"># sigma represents log-variance parameterization: log(σ²)</span>
</span></span><span style="display:flex;"><span>            kl_per_sample <span style="color:#f92672">=</span> <span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>sum(mu<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">+</span> sigma<span style="color:#f92672">.</span>exp() <span style="color:#f92672">-</span> <span style="color:#ae81ff">1</span> <span style="color:#f92672">-</span> sigma, dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> kl_per_sample<span style="color:#f92672">.</span>mean()
</span></span></code></pre></div><h3 id="loss-scaling">Loss Scaling</h3>
<p>Both components of the VAE loss should be summed over their own dimensions (data dimensions for the reconstruction term, latent dimensions for the KL term) and then averaged over the batch. A common mistake is using the <code>reduction=&quot;mean&quot;</code> option in PyTorch loss functions, which averages over every element in the tensor, including the data dimensions.</p>
<ul>
<li>The <strong>KL Divergence</strong> (<code>loss_kl</code>) is a penalization term. Each dimension of the latent space has the potential to add complexity and deviate from the prior. As you increase the latent dimensionality, you typically see the KL loss increase in magnitude. That&rsquo;s the cost of having a more expressive latent space.</li>
<li>The <strong>Reconstruction Loss</strong> (<code>loss_recon</code>) measures how well the model reconstructs the input data, and it should scale with input dimensionality (this can bias the model toward better reconstruction for higher-dimensional data).</li>
</ul>
<p>In the case of MNIST, if we used <code>reduction=&quot;mean&quot;</code> for BCE, it would be averaged over all $784 \times \text{batch size}$ pixels, making it 784 times smaller than the per-sample sum and tiny compared to the KL loss. The KL term would then dominate, and the model would learn to ignore the input, potentially leading to posterior collapse.</p>
<p>While modern optimizers can handle a variety of scenarios and you can still learn effective models with imperfect scaling, the original VAE paper used the scaling described above, and I recommend following that convention.</p>
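<p>To make the scale difference concrete, here is a small pure-Python sketch (no PyTorch needed) that computes BCE on a toy batch both ways. The random batch and the helper names <code>bce_sum_per_sample</code> and <code>bce_mean_all</code> are illustrative assumptions and are not part of the implementation above.</p>

```python
import math
import random


def bce_elementwise(logit: float, target: float) -> float:
    """Binary cross-entropy for one logit/target pair, in a numerically stable form."""
    # max(l, 0) - l * t + log(1 + exp(-|l|))
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def bce_sum_per_sample(logits, targets) -> float:
    """Sum over data dimensions, then average over the batch."""
    total = sum(
        bce_elementwise(l, t)
        for row_l, row_t in zip(logits, targets)
        for l, t in zip(row_l, row_t)
    )
    return total / len(logits)


def bce_mean_all(logits, targets) -> float:
    """Average over every element, which is what reduction="mean" does."""
    n_elements = sum(len(row) for row in logits)
    return bce_sum_per_sample(logits, targets) * len(logits) / n_elements


random.seed(0)
batch_size, n_pixels = 4, 784  # toy batch with MNIST-sized inputs
logits = [[random.gauss(0.0, 1.0) for _ in range(n_pixels)] for _ in range(batch_size)]
targets = [[float(random.random() > 0.5) for _ in range(n_pixels)] for _ in range(batch_size)]

per_sample = bce_sum_per_sample(logits, targets)
per_element = bce_mean_all(logits, targets)
print(per_sample / per_element)  # the ratio is the number of pixels, up to float error
```

<p>The two reductions differ by exactly the input dimensionality, so switching between them silently rescales the reconstruction term relative to the KL term.</p>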
<h3 id="mitigating-posterior-collapse-kl-annealingwarmup">Mitigating Posterior Collapse: KL Annealing/Warmup</h3>
<p>One common issue in training VAEs, especially with powerful decoders (like RNNs or deep CNNs), is <strong>posterior collapse</strong>. This happens when the KL term dominates the loss early in training. The model quickly learns to just output the prior distribution ($q(z|x) \approx p(z)$) to drive the KL loss to zero, effectively ignoring the latent code $z$. The decoder then acts as a powerful unconditional model (autoregressive, in the RNN case) that generates data without using the latent input.</p>
<p>To prevent this, we often use <strong>KL Annealing</strong> (or Warmup). We introduce a weight $\beta$ for the KL term that starts at 0 and slowly increases to 1 over the first $N$ steps or epochs.</p>
<p>$$ \mathcal{L} = \mathcal{L}_{recon} + \beta \cdot D_{KL} $$</p>
<p>This allows the model to focus purely on reconstruction first (using the full latent capacity), and then slowly adds the regularization pressure.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#75715e"># Simple Linear Annealing Scheduler</span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">get_kl_weight</span>(step, total_steps, max_val<span style="color:#f92672">=</span><span style="color:#ae81ff">1.0</span>):
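</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">if</span> total_steps <span style="color:#f92672">&lt;=</span> <span style="color:#ae81ff">0</span>:  <span style="color:#75715e"># no warmup requested: use the full weight from the start</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> max_val
</span></span><span style="display:flex;"><span>    val <span style="color:#f92672">=</span> (step <span style="color:#f92672">/</span> total_steps) <span style="color:#f92672">*</span> max_val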
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> min(max(val, <span style="color:#ae81ff">0.0</span>), max_val)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># In your training loop:</span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">for</span> epoch <span style="color:#f92672">in</span> range(epochs):
</span></span><span style="display:flex;"><span>    beta <span style="color:#f92672">=</span> get_kl_weight(epoch, warmup_epochs)
</span></span><span style="display:flex;"><span>    loss <span style="color:#f92672">=</span> recon_loss <span style="color:#f92672">+</span> beta <span style="color:#f92672">*</span> kl_loss
</span></span></code></pre></div><h4 id="parameterizing-standard-deviation">Parameterizing Standard Deviation</h4>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">_sigma_to_std</span>(self, sigma: torch<span style="color:#f92672">.</span>Tensor, eps: float <span style="color:#f92672">=</span> DEFAULT_EPS) <span style="color:#f92672">-&gt;</span> torch<span style="color:#f92672">.</span>Tensor:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Convert sigma parameter to standard deviation.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">if</span> self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>bound_std <span style="color:#f92672">is</span> <span style="color:#f92672">not</span> <span style="color:#66d9ef">None</span>:
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">return</span> torch<span style="color:#f92672">.</span>sigmoid(sigma) <span style="color:#f92672">*</span> self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>bound_std <span style="color:#f92672">+</span> eps
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">elif</span> self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>use_softplus_std:
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">return</span> F<span style="color:#f92672">.</span>softplus(sigma) <span style="color:#f92672">+</span> eps
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">else</span>:
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">return</span> torch<span style="color:#f92672">.</span>exp(<span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> sigma)  <span style="color:#75715e"># sigma represents log-variance</span>
</span></span></code></pre></div><p>Parameterizing the mean of the latent distribution is straightforward since $\mu \in \mathbb{R}$. However, the standard deviation $\sigma$ must be strictly positive (as must the variance $\sigma^2$). This type of constrained optimization is challenging for neural networks.</p>
<p><strong>Log-Variance</strong>
One common approach is to have the network output the <strong>log-variance</strong> ($\log \sigma^2$). This is what the original VAE paper did. The idea is to allow the network to output any real number and treat that value as the log-variance, $s = \log \sigma^2$. We can then compute the standard deviation as $\sigma = \exp(0.5 s)$, which is always positive.</p>
<p>The KL divergence formula simplifies nicely with this parameterization:</p>
<p>$$
\text{KL}( \mathcal{N}(\mu, \sigma^2) || \mathcal{N}(0, 1) ) = \frac{1}{2} \sum_{i=1}^d (\mu_i^2 + \sigma_i^2 - 1 - \log \sigma_i^2)
$$</p>
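<p>Substituting $s_i = \log \sigma_i^2$, so that $\sigma_i^2 = e^{s_i}$, removes the logarithm entirely and leaves an expression in the raw network output alone:</p>
<p>$$
\text{KL}( \mathcal{N}(\mu, \sigma^2) || \mathcal{N}(0, 1) ) = \frac{1}{2} \sum_{i=1}^d (\mu_i^2 + e^{s_i} - 1 - s_i)
$$</p>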
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>sum(mu<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">+</span> s<span style="color:#f92672">.</span>exp() <span style="color:#f92672">-</span> <span style="color:#ae81ff">1</span> <span style="color:#f92672">-</span> s, dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>)
</span></span></code></pre></div><p><strong>Softplus Standard Deviation</strong>
An alternative is to have the network output $\sigma$ directly. This must be handled with care to ensure positivity. Strictly positive activations like <code>softplus</code> are required. Activations like <code>ReLU</code> can output zero, leading to numerical instability (during training) and deterministic behavior (during sampling). Additionally, adding a small epsilon value ensures numerical stability by preventing $\sigma$ from being exactly zero.</p>
<p>The KL divergence formula becomes slightly more complex:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>sum(
</span></span><span style="display:flex;"><span>    mu<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">+</span> std<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">-</span> <span style="color:#ae81ff">1</span> <span style="color:#f92672">-</span> torch<span style="color:#f92672">.</span>log(std<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">+</span> eps), dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>
</span></span><span style="display:flex;"><span>)
</span></span></code></pre></div><p><strong>Bounded Standard Deviation</strong>
Another option is to bound the standard deviation to a maximum value using a <code>sigmoid</code> transformation (or similar). This replaces mapping to $(0, \infty)$ with mapping to $(0, \text{bound})$. This helps prevent extremely high variance values that might destabilize training, while limiting the expressiveness of the latent distribution. Like with <code>softplus</code>, adding a small epsilon ensures numerical stability by preventing $\sigma$ from being exactly zero or approaching it too closely.</p>
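<p>As a rough side-by-side, the three mappings from a raw network output to a standard deviation can be sketched in plain Python. This is a scalar illustration under stated assumptions: <code>EPS</code> stands in for <code>DEFAULT_EPS</code> (whose value is not shown in this post), and the function names are hypothetical. It also checks that the two KL expressions agree when $s = \log \sigma^2$.</p>

```python
import math

EPS = 1e-6  # assumed value; the post refers to DEFAULT_EPS without showing it


def sigmoid(x: float) -> float:
    """Logistic function, written to avoid overflow for large |x|."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def softplus(x: float) -> float:
    """log(1 + exp(x)), written to avoid overflow for large x."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def std_log_variance(s: float) -> float:
    return math.exp(0.5 * s)  # s is interpreted as log(sigma^2)


def std_softplus(s: float, eps: float = EPS) -> float:
    return softplus(s) + eps  # maps to (eps, infinity)


def std_bounded(s: float, bound: float, eps: float = EPS) -> float:
    return sigmoid(s) * bound + eps  # maps to (eps, bound + eps)


def kl_from_log_variance(mu: float, s: float) -> float:
    return 0.5 * (mu**2 + math.exp(s) - 1.0 - s)


def kl_from_std(mu: float, std: float, eps: float = EPS) -> float:
    return 0.5 * (mu**2 + std**2 - 1.0 - math.log(std**2 + eps))


for raw in (-5.0, 0.0, 5.0):
    print(raw, std_log_variance(raw), std_softplus(raw), std_bounded(raw, bound=10.0))
```

<p>All three outputs stay strictly positive for any real input; only the bounded variant also has a ceiling.</p>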
<p><strong>Gradient Behavior</strong>
All parameterizations can work well in practice, but they have different gradient behaviors. Think of $g(s)$ as a transformation function from the network output $s$ to the proper domain of $\sigma$ (or $\sigma^2$); in the log-variance case, $g(s) = \exp(s)$, while in the softplus case, $g(s) = \text{softplus}(s) + \epsilon$.</p>
<p>The gradient of the loss with respect to these outputs can be written using the chain rule:</p>
<p>$$
\frac{\partial \mathcal{L}}{\partial s} = \frac{\partial \mathcal{L}}{\partial g(s)} \cdot \frac{\partial g(s)}{\partial s}
$$</p>
<p>where $\frac{\partial g(s)}{\partial s}$ is the derivative of the transformation function.</p>
<p>We need to guard against two pathological cases:</p>
<ul>
<li>$\frac{\partial g(s)}{\partial s} \rightarrow 0$: This leads to vanishing gradients, making it hard for the network to learn.</li>
<li>$\frac{\partial g(s)}{\partial s} \rightarrow \infty$: This leads to exploding gradients, causing instability during training and potentially divergence.</li>
</ul>
<p>The log-variance parameterization, with its exponential transformation that is its own derivative, exhibits both issues at extreme values. If $s \rightarrow -\infty$, then $\sigma \rightarrow 0$ and the gradient vanishes. If $s \rightarrow \infty$, then $\sigma \rightarrow \infty$ and the gradient explodes. Since the interval $(0, 1)$ is mapped to $(-\infty, 0)$ in log-space, it&rsquo;s much more difficult for the network to drive $\sigma$ to small values. In practice, exploding gradients at high values have been more problematic in my experience. Gradient clipping, learning rate scheduling, and clamping the log-variance output to a maximum value can help mitigate this.</p>
<p>What about softplus? The derivative of <code>softplus</code> is the <code>sigmoid</code> function, which smoothly maps $(-\infty, \infty)$ to $(0, 1)$. Gradients through the transformation are therefore bounded by unity, preventing explosion (barring explosion from other parts of the network). However, as $s \rightarrow -\infty$, the derivative approaches zero, leading to vanishing gradients. The added epsilon keeps $\sigma$ from reaching zero but leaves this derivative unchanged, so learning can still slow down when the network pushes $s$ far into the negative range.</p>
<p>For bounded standard deviation, the derivative of the <code>sigmoid</code> function is also bounded, preventing exploding gradients. (The gradient of <code>sigmoid</code> is defined in terms of itself: $\text{sig}'(x) = \text{sig}(x)(1 - \text{sig}(x))$; its maximum value is $0.25$ at $x=0$, so the derivative of the bounded transform is at most $0.25 \times \text{bound}$.) Because the sigmoid saturates on both sides, the gradient also vanishes as $\sigma$ approaches either $0$ or the bound.</p>
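<p>The derivative claims above can be checked numerically. The sketch below evaluates $\frac{\partial g(s)}{\partial s}$ for the three transformations at a few values of $s$ and compares one of them against a finite difference. It uses scalar Python only; the bound of 10 and all function names are illustrative assumptions.</p>

```python
import math


def sigmoid(x: float) -> float:
    """Logistic function, written to avoid overflow for large |x|."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def softplus(x: float) -> float:
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def d_log_variance(s: float) -> float:
    return math.exp(s)  # g(s) = exp(s) is its own derivative


def d_softplus(s: float) -> float:
    return sigmoid(s)  # d/ds softplus(s) = sigmoid(s)


def d_bounded(s: float, bound: float) -> float:
    sig = sigmoid(s)
    return bound * sig * (1.0 - sig)  # d/ds [bound * sigmoid(s)]


def central_difference(f, s: float, h: float = 1e-5) -> float:
    """Finite-difference estimate of f'(s), used as an independent check."""
    return (f(s + h) - f(s - h)) / (2.0 * h)


BOUND = 10.0  # illustrative bound
for s in (-10.0, -2.0, 0.0, 2.0, 10.0):
    print(
        f"s={s:+6.1f}  exp: {d_log_variance(s):12.5f}  "
        f"softplus: {d_softplus(s):.5f}  bounded: {d_bounded(s, BOUND):.5f}"
    )
```

<p>The exponential derivative spans many orders of magnitude across this range, the softplus derivative stays below one, and the bounded derivative peaks at $s = 0$ and shrinks toward zero on both sides.</p>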















<figure class="post-figure center ">
    <img src="/img/vae-tut/gradient_behaviors.webp"
         alt="Graph comparing gradient behaviors of log-variance, softplus, and bounded standard deviation parameterizations"
         title="Graph comparing gradient behaviors of log-variance, softplus, and bounded standard deviation parameterizations"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Gradient behaviors of different standard deviation parameterizations: Log-Variance (exponential), Softplus, and Bounded Standard Deviation (sigmoid). Each has unique characteristics affecting training stability.</figcaption>
    
</figure>

<h2 id="experiments">Experiments</h2>
<h3 id="2d-mnist-vae-with-different-std-dev-parameterizations">2D MNIST VAE with Different Std. Dev. Parameterizations</h3>
<p>First, let&rsquo;s run an experiment that is close to what was done in the original VAE paper. We&rsquo;ll use MNIST as our dataset, a simple feedforward architecture with <code>tanh</code> activations, and the log-variance parameterization for the latent distribution.</p>
<p>Some of the differences from the original paper include:</p>
<ul>
<li>Using a hidden size of 512 (the original used 500)</li>
<li>Using the AdamW optimizer (the original used vanilla Adagrad)</li>
<li>Applying a similar amount of weight decay, though its effect differs because AdamW decouples weight decay from the gradient update</li>
<li>Focusing primarily on 2D latent spaces (for now)</li>
</ul>
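<p>As a sanity check on model size, the parameter count can be worked out by hand. The sketch assumes one 512-unit hidden layer in each of the encoder and decoder, with two separate linear heads for the mean and the standard deviation parameter; the exact layer layout is an assumption on my part, and <code>linear_params</code> is a hypothetical helper.</p>

```python
def linear_params(n_in: int, n_out: int) -> int:
    """Weights plus biases of one fully connected layer."""
    return n_in * n_out + n_out


input_dim, hidden_dim, latent_dim = 784, 512, 2

encoder = linear_params(input_dim, hidden_dim)      # input -> hidden
heads = 2 * linear_params(hidden_dim, latent_dim)   # hidden -> mu, hidden -> sigma
decoder = linear_params(latent_dim, hidden_dim) + linear_params(hidden_dim, input_dim)

total = encoder + heads + decoder
print(total)  # 807700
```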
<p>This results in a network with 807,700 parameters.
I train each model for 150 epochs at most and highlight the best based on the reconstruction loss on the test set.
Just for fun, I sweep across different standard deviation parameterizations and learning rate warmup strategies.</p>
<table>
  <thead>
      <tr>
          <th>Std. Dev. Param</th>
          <th>Warmup Steps</th>
          <th>Test Recon. Loss</th>
          <th>Test KL Loss</th>
          <th>Test Total Loss</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Log-Variance</td>
          <td>0</td>
          <td>140.88</td>
          <td>6.96</td>
          <td>147.84</td>
      </tr>
      <tr>
          <td>Log-Variance</td>
          <td>600</td>
          <td>141.41</td>
          <td>6.63</td>
          <td>148.04</td>
      </tr>
      <tr>
          <td>Softplus</td>
          <td>0</td>
          <td>141.51</td>
          <td>6.56</td>
          <td>148.07</td>
      </tr>
      <tr>
          <td>Softplus</td>
          <td>600</td>
          <td>140.37</td>
          <td>6.67</td>
          <td>147.04</td>
      </tr>
      <tr>
          <td>Bounded Std. Dev. (10)</td>
          <td>0</td>
          <td>140.96</td>
          <td>6.82</td>
          <td>147.79</td>
      </tr>
      <tr>
          <td>Bounded Std. Dev. (10)</td>
          <td>600</td>
          <td>141.78</td>
          <td>6.68</td>
          <td>148.45</td>
      </tr>
  </tbody>
</table>
<p>From this summary table, all three parameterizations work well. The differences in final loss values are quite small. This could be due to the simplicity of the dataset and model architecture, further amplified by forcing the network to compress images into a very low-dimensional latent space (2D).</p>
<p>Since the softplus parameterization with learning rate warmup achieved the best reconstruction loss, let&rsquo;s visualize some of its training dynamics and results more closely.</p>
<h4 id="loss-dynamics">Loss Dynamics</h4>
<p>To understand the VAE&rsquo;s behavior, we must look at the ELBO and its two components: the Reconstruction Loss and the KL Divergence.</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-elbo_epochs.webp"
         alt="Plot showing training and testing ELBO across 150 epochs for the softplus parameterization with learning rate warmup"
         title="Plot showing training and testing ELBO across 150 epochs for the softplus parameterization with learning rate warmup"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Total ELBO: Training and testing ELBO across 150 epochs.</figcaption>
    
</figure>
















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-reconstruction_loss_epochs.webp"
         alt="Plot showing training and testing reconstruction loss across 150 epochs for the softplus parameterization with learning rate warmup"
         title="Plot showing training and testing reconstruction loss across 150 epochs for the softplus parameterization with learning rate warmup"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Reconstruction Loss: Training and testing reconstruction loss across 150 epochs.</figcaption>
    
</figure>
















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-kl_loss_epochs.webp"
         alt="Plot showing training and testing KL divergence loss across 150 epochs for the softplus parameterization with learning rate warmup"
         title="Plot showing training and testing KL divergence loss across 150 epochs for the softplus parameterization with learning rate warmup"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">KL Divergence: Training and testing KL divergence loss across 150 epochs.</figcaption>
    
</figure>

<p>These plots reveal a clear narrative:</p>
<ol>
<li><strong>Rapid Initial Learning:</strong> Performance skyrockets in the first ~15 epochs.</li>
<li><strong>Overfitting:</strong> The <strong>Reconstruction Loss</strong> (middle) flatlines for the test set while continuing to improve for training, a classic sign of memorization.</li>
<li><strong>The Balancing Act:</strong> The <strong>KL Divergence</strong> (bottom) initially rises (&ldquo;The Cost of Learning&rdquo;) as the model stretches the latent space to encode digits, then saturates.</li>
<li><strong>Equilibrium:</strong> The total <strong>ELBO</strong> (top) improves slowly, driven by the model finding the optimal trade-off between reconstruction and regularization. Notice that the test and train KL curves track each other closely: a sign of good regularization!</li>
</ol>
<p><strong>Visualizing the VAE Trade-Off: BCE vs. KL</strong></p>
<p>While the line plots visualize progress over time, they miss the evolving <em>relationship</em> between our two competing objectives.</p>
<p>A VAE is fundamentally a multi-objective optimization problem. We want to:</p>
<ol>
<li>Minimize Reconstruction Loss (BCE)</li>
<li>Minimize KL Divergence</li>
</ol>
<p>Combining them as the ELBO is common and effective, though it can mask some of the underlying dynamics.</p>
<p>These two goals are in direct conflict. To get perfect reconstruction (BCE = 0), the encoder would need to &ldquo;memorize&rdquo; each input, mapping it to a unique, precise point in latent space. This would cause the KL divergence to skyrocket, as these specific, &ldquo;pointy&rdquo; distributions are nothing like our smooth <code>N(0, 1)</code> prior.</p>
<p>Conversely, to get perfect KL divergence (KL = 0), the encoder must <em>always</em> output <code>N(0, 1)</code>, regardless of the input. This perfectly matches the prior. Since the latent code $\mathbf{z}$ now contains zero information about the input $\mathbf{x}$, the decoder can only learn to output the &ldquo;average&rdquo; image, resulting in terrible reconstruction.</p>
<p>The training process is a search for the best compromise.</p>
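<p>The two extremes can be put into numbers with the closed-form KL for a single latent dimension. The specific values of $\mu$ and $\sigma$ below are made up for illustration, and <code>kl_to_standard_normal</code> is a hypothetical helper name.</p>

```python
import math


def kl_to_standard_normal(mu: float, std: float) -> float:
    """Per-dimension KL(N(mu, std^2) || N(0, 1)) in closed form."""
    var = std * std
    return 0.5 * (mu * mu + var - 1.0 - math.log(var))


# Encoder ignores the input and outputs the prior: no KL cost at all.
collapsed = kl_to_standard_normal(mu=0.0, std=1.0)

# Encoder "memorizes": a very narrow distribution placed away from the origin.
pointy = kl_to_standard_normal(mu=2.0, std=0.01)

print(f"collapsed posterior: {collapsed:.3f} nats")
print(f"pointy posterior:    {pointy:.3f} nats")
```

<p>Shrinking $\sigma$ further keeps raising the cost through the $-\log \sigma^2$ term, which is why near-deterministic encodings are expensive under the prior.</p>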















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-loss_scatter_epochs.webp"
         alt="Scatter plot of Test BCE vs KL Divergence, showing the training path from epoch 0 to 150"
         title="Scatter plot of Test BCE vs KL Divergence, showing the training path from epoch 0 to 150"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The training path on the Test set, plotting Reconstruction Loss (BCE) vs. KL Divergence. The model&rsquo;s journey clearly shows the trade-off between these two objectives.</figcaption>
    
</figure>

<p>This plot shows the test set&rsquo;s BCE (y-axis) vs. KL Divergence (x-axis) at every evaluation step. The color gradient from cool (blue) to warm (red) represents the training progress from Epoch 0 to 150.</p>
<p>Here&rsquo;s how to interpret this training path:</p>
<ol>
<li>
<p><strong>The Start (Green Diamond, ~Epoch 0):</strong> The model starts at the top-left.</p>
<ul>
<li><strong>High BCE (Reconstruction):</strong> The decoder is random and hasn&rsquo;t learned to reconstruct anything. Reconstruction is terrible.</li>
<li><strong>Low KL Divergence:</strong> The <em>encoder</em> is also untrained. With typical small-weight initialization, its outputs sit near zero, so each $q_{\phi}(\mathbf{z} | \mathbf{x})$ has a mean close to $0$ and a standard deviation of order $1$, regardless of the input. That happens to be close to the prior $p_{\theta}(\mathbf{z})$, so the KL penalty is low. The model isn&rsquo;t encoding any useful information yet, so it&rsquo;s not paying a high price for it.</li>
</ul>
</li>
<li>
<p><strong>Phase 1: The Initial Plunge (Blue Path):</strong> The path moves almost <em>straight down</em>.</p>
<ul>
<li><strong>BCE Plummets:</strong> The model&rsquo;s first and easiest task is to learn to reconstruct <em>something</em>. The optimizer finds massive, easy gains by making the decoder output &ldquo;blurry digits&rdquo; to replace the initial noise.</li>
<li><strong>KL Stays Low:</strong> The model achieves this huge reconstruction win without needing to learn a very complex latent space. It&rsquo;s the &ldquo;low-hanging fruit&rdquo; of training.</li>
</ul>
</li>
<li>
<p><strong>Phase 2: The Trade-Off (The &ldquo;Elbow&rdquo;):</strong> The path stops dropping vertically and starts moving to the <em>right and down</em>.</p>
<ul>
<li><strong>&ldquo;Spending&rdquo; KL to &ldquo;Buy&rdquo; Reconstruction:</strong> This is the true VAE trade-off in action. The easy wins are gone. To make the reconstructions sharper and more accurate (lowering BCE further), the model must now learn a more complex, informative latent representation.</li>
<li>It &ldquo;stretches&rdquo; the latent distributions $q_{\phi}(\mathbf{z} | \mathbf{x})$ to encode more details about each specific digit. This &ldquo;stretching&rdquo; moves it further from the simple <code>N(0, 1)</code> prior, and the KL divergence (the &ldquo;cost&rdquo;) goes up.</li>
</ul>
</li>
<li>
<p><strong>The End Game (Red Path &amp; Star):</strong> The path settles in the bottom-right corner.</p>
<ul>
<li><strong>Finding the &ldquo;Elbow&rdquo;:</strong> The model finds an equilibrium. It has pushed the KL divergence as high as it&rsquo;s &ldquo;worth&rdquo; for the reconstruction gains it gets. Trying to get even better reconstruction (moving further down) would cost an enormous, disproportionate amount in KL divergence (moving far to the right), and the total loss would increase.</li>
<li><strong>Best Recon (Orange Star):</strong> The best reconstruction model (Epoch 118) is found right at this &ldquo;elbow,&rdquo; representing the best-found balance point on the trade-off frontier.</li>
</ul>
</li>
</ol>
<p>This single plot visualizes the entire training dynamic as a journey along the <strong>Pareto frontier</strong>: the set of optimal solutions where you can&rsquo;t improve one objective (BCE) without worsening the other (KL).</p>
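<p>The definition of a Pareto frontier translates directly into a dominance check. The sketch below filters a list of (KL, BCE) pairs down to the non-dominated ones; the points are made-up toy values, not measurements from this experiment, and <code>pareto_front</code> is a hypothetical helper.</p>

```python
def pareto_front(points):
    """Return the points not dominated by any other point (both objectives minimized)."""
    front = []
    for i, (kl_i, bce_i) in enumerate(points):
        dominated = any(
            # Point j dominates point i if it is no worse in both objectives
            # and strictly better in at least one.
            (kl_j <= kl_i and bce_j <= bce_i) and (kl_j < kl_i or bce_j < bce_i)
            for j, (kl_j, bce_j) in enumerate(points)
            if j != i
        )
        if not dominated:
            front.append((kl_i, bce_i))
    return sorted(front)


# Made-up (KL, BCE) pairs purely for illustration.
toy_points = [(1, 9), (2, 6), (3, 7), (4, 3), (5, 2), (6, 4)]
print(pareto_front(toy_points))
```

<p>Here <code>(3, 7)</code> and <code>(6, 4)</code> drop out because another point is at least as good in both objectives; the remaining points trace the trade-off curve.</p>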
<h4 id="generative-performance">Generative Performance</h4>
<p>Let&rsquo;s take a look at how well this model can decode samples.</p>
<p><strong>Reconstruction Performance</strong></p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-reconstructions.webp"
         alt="Grid of original and reconstructed MNIST images from the test set using the trained VAE model"
         title="Grid of original and reconstructed MNIST images from the test set using the trained VAE model"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Original (top row) vs. Reconstructed (bottom row) MNIST images from the test set using the trained VAE model.</figcaption>
    
</figure>

<p>Immediately, we see a couple of key points:</p>
<ul>
<li>Reconstructions are quite blurry compared to the originals. This is expected given the low capacity of the model and the extreme compression into a 2D latent space. General structure is typically preserved, while fine details are lost.</li>
<li>The network struggles with 4s and 9s, often mixing them up or producing ambiguous shapes. This is a common failure mode in MNIST models due to the similarity of these digits.</li>
</ul>
<p><strong>Sampling from the Prior</strong></p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-samples.webp"
         alt="Grid of MNIST-like images generated by sampling from the prior distribution using the trained VAE model"
         title="Grid of MNIST-like images generated by sampling from the prior distribution using the trained VAE model"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">MNIST-like images generated by sampling from the prior distribution using the trained VAE model.</figcaption>
    
</figure>

<p>If we sample from the prior <code>N(0, I)</code> and decode those samples, we get a variety of digit-like images. Almost all digits appear in this random sample, so the latent space gives us a fairly rich representation of the data. Again, we see the characteristic blurriness.</p>
<p><strong>Sweeping the Latent Space</strong></p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-generation_interpolation.webp"
         alt="Grid of images generated by sweeping across the 2D latent space of the trained VAE model"
         title="Grid of images generated by sweeping across the 2D latent space of the trained VAE model"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Images generated by sweeping across the 2D latent space of the trained VAE model.</figcaption>
    
</figure>

<p>We can select two points at random (here, two zeros), embed them into our latent space, and then walk between the two embeddings to interpolate between the data points.
Here, we see a walk that takes us from a zero that is askew to one that is more upright.</p>
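<p>The walk itself can be sketched as a straight line between two latent codes. The snippet below is a minimal NumPy sketch: <code>interpolate_latents</code> is a hypothetical helper, and the two latent codes are made-up stand-ins for the encoder means of two zeros.</p>

```python
import numpy as np


def interpolate_latents(z_start, z_end, num_steps=8):
    """Linearly interpolate between two latent codes, endpoints included."""
    z_start = np.asarray(z_start, dtype=float)
    z_end = np.asarray(z_end, dtype=float)
    alphas = np.linspace(0.0, 1.0, num_steps)[:, None]  # shape (num_steps, 1)
    return (1.0 - alphas) * z_start + alphas * z_end


# Hypothetical posterior means of two encoded zeros (made-up values).
z_a = np.array([-1.5, 0.5])
z_b = np.array([-0.5, 1.5])
path = interpolate_latents(z_a, z_b, num_steps=5)
# Each row of `path` would be passed to the decoder to render one frame.
```

<p>Linear interpolation is the simplest choice here; other schemes (such as spherical interpolation) are sometimes used in higher-dimensional latent spaces.</p>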















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-generation_latent_sweep.webp"
         alt="2D latent sweep, varying one dimension at a time while holding the other constant"
         title="2D latent sweep, varying one dimension at a time while holding the other constant"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">2D latent sweep, varying one dimension at a time while holding the other constant.</figcaption>
    
</figure>

<p>Finally, we can also sweep each latent dimension independently to see how they affect the generated images.</p>
<ol>
<li>Sweeping <code>z_1</code> (top row), we see a 5 become an 8 and then a 9. The slant shifts from left to right as we sweep the dimension.</li>
<li>Sweeping <code>z_2</code> (bottom row), we see a 4 become a 9 and then an 8. Then it becomes a 3, a 2, some nonsense, and a 6.</li>
</ol>
<p>So clearly each latent dimension is encoding some high-level features of the digits, and we can manipulate those features by moving in latent space.</p>
<h4 id="inspecting-the-latent-space">Inspecting the Latent Space</h4>
<p>What does the actual latent space look like?</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-latent_combined.webp"
         alt="2D latent space visualization with points colored by their true digit labels"
         title="2D latent space visualization with points colored by their true digit labels"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">2D latent space visualization with points colored by their true digit labels (left) and 2D heatmap of latent space density (right).</figcaption>
    
</figure>

<p>Even without class information, the network organizes the latent space to encode digit structure effectively. It also becomes immediately apparent why the model so often confuses 4s and 9s: their region of the latent space is a dense mixture of the two.</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-latent_marginals.webp"
         alt="1D histograms of each latent dimension compared to the standard normal distribution"
         title="1D histograms of each latent dimension compared to the standard normal distribution"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">1D histograms of each latent dimension compared to the standard normal distribution.</figcaption>
    
</figure>

<p>We can also look at the marginal distributions of each latent dimension to see how well they match the prior <code>N(0, 1)</code>. Here, <code>z_1</code> is closer to the prior than <code>z_2</code>. <code>z_2</code> exhibits a bimodal marginal distribution, indicating that the encoder is using this dimension to separate two distinct clusters of data.</p>
<p>We also might want to understand how the log-variance of the latent distributions behaves.</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-logvar_combined.webp"
         alt="2D latent space visualization with log-variance values and 1D histograms of log-variance for each latent dimension"
         title="2D latent space visualization with log-variance values and 1D histograms of log-variance for each latent dimension"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">2D latent space visualization with log-variance values with respect to digit class (left) and 2D heatmap of log-variance magnitude (right).</figcaption>
    
</figure>

<p>For the most part, we see similar concentration. Some digits are more concentrated than others, though in general the difference is slight.</p>
<h3 id="beyond-2d-higher-dimensional-latent-spaces">Beyond 2D: Higher-Dimensional Latent Spaces</h3>
<p>What happens as we increase the latent dimensionality? Beyond two dimensions, we must apply dimensionality reduction to visualize the latent space, which gives us only an approximate sense of how it is organized.</p>
<table>
  <thead>
      <tr>
          <th>Latent Dimensionality</th>
          <th>Test Recon. Loss</th>
          <th>Test KL Loss</th>
          <th>Test Total Loss</th>
          <th>KL per Dim</th>
          <th>Active Dims (KL &gt; 0.1)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>2</td>
          <td>140.37</td>
          <td>6.67</td>
          <td>147.04</td>
          <td>3.34</td>
          <td>2</td>
      </tr>
      <tr>
          <td>4</td>
          <td>114.94</td>
          <td>10.61</td>
          <td>125.56</td>
          <td>2.65</td>
          <td>4</td>
      </tr>
      <tr>
          <td>8</td>
          <td>89.46</td>
          <td>16.84</td>
          <td>106.31</td>
          <td>2.11</td>
          <td>8</td>
      </tr>
      <tr>
          <td>16</td>
          <td>76.57</td>
          <td>23.65</td>
          <td>100.21</td>
          <td>1.48</td>
          <td>16</td>
      </tr>
      <tr>
          <td>32</td>
          <td>74.65</td>
          <td>25.59</td>
          <td>100.25</td>
          <td>0.80</td>
          <td>24</td>
      </tr>
  </tbody>
</table>
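<p>The &ldquo;KL per Dim&rdquo; column is consistent with dividing the total KL by the latent dimensionality (for example, 25.59 / 32 is roughly 0.80). A per-dimension KL and an active-dimension count can be sketched as follows. This is a minimal NumPy sketch assuming a diagonal Gaussian posterior and a standard normal prior; the helper names and the synthetic encoder outputs are assumptions, and the exact procedure used to produce the table may differ.</p>

```python
import numpy as np


def kl_per_dimension(mu, logvar):
    """Average KL(q(z|x) || N(0, 1)) for each latent dimension.

    mu, logvar: arrays of shape (num_examples, latent_dim).
    """
    kl = 0.5 * (mu**2 + np.exp(logvar) - 1.0 - logvar)
    return kl.mean(axis=0)


def count_active_dims(mu, logvar, threshold=0.1):
    """Count dimensions whose average KL exceeds the threshold."""
    return int(np.sum(kl_per_dimension(mu, logvar) > threshold))


# Synthetic encoder outputs (assumed, for illustration): two informative
# dimensions and one collapsed dimension that exactly matches the prior.
rng = np.random.default_rng(0)
n = 2000
mu = np.zeros((n, 3))
mu[:, :2] = rng.standard_normal((n, 2))
logvar = np.zeros((n, 3))
logvar[:, :2] = -2.0

kl = kl_per_dimension(mu, logvar)
active = count_active_dims(mu, logvar)
```

<p>A dimension whose posterior equals the prior for every input contributes zero KL and carries no information about the input, which is what an &ldquo;inactive&rdquo; dimension means here.</p>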
<p>As we double the dimensionality, we see a dominant trend at first:</p>
<ul>
<li>The reconstruction loss goes down</li>
<li>The KL loss goes up</li>
<li>The KL loss per dimension goes down</li>
</ul>
<p>Something odd happens when we jump from 16 to 32 latent dimensions: only 24 of the 32 dimensions remain active, so some of our latent dimensions become degenerate and stop encoding useful information. The total test loss also stays essentially flat (100.21 vs. 100.25).
This could be an indication that we need to choose our hyperparameters a little more carefully. Perhaps we need a different architecture. Or maybe there is an intrinsic limit to the dimensionality needed for this dataset, past which it&rsquo;s not really helpful to keep scaling the latent dimension.</p>
<h4 id="training-dynamics">Training Dynamics</h4>















<figure class="post-figure center ">
    <img src="/img/vae-tut/loss_scatter_epochs.webp"
         alt="Scatter plot of Test BCE vs KL Divergence for different latent dimensionalities, showing training paths from epoch 0 to 150"
         title="Scatter plot of Test BCE vs KL Divergence for different latent dimensionalities, showing training paths from epoch 0 to 150"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The training paths on the Test set for different latent dimensionalities, plotting Reconstruction Loss (BCE) vs. KL Divergence. Each path shows the model&rsquo;s journey, clearly illustrating the trade-off between these two objectives.</figcaption>
    
</figure>

<p>The training dynamics show the battle between reconstruction and KL divergence for different latent dimensionalities. As we increase the latent dimensionality, the oscillation in the KL divergence becomes more pronounced. Particularly chaotic is the $D=16$ case, which struggles to find a stable equilibrium. By the time we expand to $D=32$, the KL penalty seems to overpower the ability to encode information in the latent space, leading to many inactive dimensions. The drop in KL complexity has staircase-like steps without clearly gaining reconstruction ability.</p>
<h4 id="reconstruction-and-generation">Reconstruction and Generation</h4>
<p>As we increase the latent dimensionality, the reconstruction quality improves significantly.</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/reconstructions.webp"
         alt="Grid of original and reconstructed MNIST images from the test set using trained VAE models with different latent dimensionalities"
         title="Grid of original and reconstructed MNIST images from the test set using trained VAE models with different latent dimensionalities"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Original (top row) vs. Reconstructed (bottom row) MNIST images from the test set using trained VAE models with different latent dimensionalities.</figcaption>
    
</figure>

<p>As we increase the dimensionality, we see the increase in quality we&rsquo;d expect given the reduction in BCE reconstruction loss. In the jump to 4D, we&rsquo;re able to better resolve the differences between 4s and 9s. Images become much sharper by the time we hit 16 dimensions. The differences between 16 and 32 dimensions, however, are marginal.</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/samples.webp"
         alt="Grid of MNIST-like images generated by sampling from the prior distribution using trained VAE models with different latent dimensionalities"
         title="Grid of MNIST-like images generated by sampling from the prior distribution using trained VAE models with different latent dimensionalities"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">MNIST-like images generated by sampling from the prior distribution using trained VAE models with different latent dimensionalities.</figcaption>
    
</figure>

<p>Sampling quality also improves with latent dimensionality. Images are sharper as we increase the dimensionality. However, the space seems to get sparser as we increase to the largest dimensionalities, which makes sense given the size and nature of our dataset.</p>
<h4 id="latent-space-visualizations">Latent Space Visualizations</h4>















<figure class="post-figure center ">
    <img src="/img/vae-tut/latent_combined.webp"
         alt="2D PCA projections of higher-dimensional latent spaces colored by their true digit labels"
         title="2D PCA projections of higher-dimensional latent spaces colored by their true digit labels"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">2D PCA projections of higher-dimensional latent spaces colored by their true digit labels.</figcaption>
    
</figure>

<p>The challenge with visualizing higher-dimensional latent spaces is that we must reduce them to 2D. A two-component PCA projection captures a shrinking share of the total variance as the latent dimensionality grows. The 4D and 8D plots suggest increasingly better separation of the digit classes. However, the 16D and 32D plots show only 10-20% of the variance and give a misleading impression of overlap.</p>
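<p>How much variance a 2D projection keeps can be checked directly. The following is a minimal NumPy sketch; <code>pca_explained_variance_ratio</code> is a hypothetical helper, and the isotropic Gaussian codes are a toy stand-in for real latent means, not the encodings used in the plots.</p>

```python
import numpy as np


def pca_explained_variance_ratio(Z, n_components=2):
    """Fraction of total variance captured by the top principal components."""
    Zc = Z - Z.mean(axis=0)
    # Squared singular values of the centered data are proportional
    # to the variances along the principal axes.
    s = np.linalg.svd(Zc, compute_uv=False)
    var = s**2
    return float(var[:n_components].sum() / var.sum())


rng = np.random.default_rng(0)

# Toy latent codes: isotropic Gaussians in 16-D and 2-D.
ratio_16d = pca_explained_variance_ratio(rng.standard_normal((5000, 16)))
ratio_2d = pca_explained_variance_ratio(rng.standard_normal((5000, 2)))
```

<p>For codes that spread evenly across 16 dimensions, two components can only hold a small fraction of the variance, so apparent overlap in the projection says little about overlap in the full space.</p>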
<h2 id="conclusion">Conclusion</h2>
<p>In this tutorial, we&rsquo;ve journeyed from the core theory of Variational Autoencoders to a practical, modern PyTorch implementation and a series of experiments on the MNIST dataset. Our findings highlight several key takeaways for practitioners:</p>
<ol>
<li>
<p><strong>The VAE is a Balancing Act:</strong> The fundamental tension between reconstruction fidelity and latent space regularization is the core of the VAE. Our visualization of the BCE vs. KL loss trade-off clearly showed training as a search for an optimal point on this Pareto frontier, where improving one objective necessarily means sacrificing the other.</p>
</li>
<li>
<p><strong>Latent Dimensionality is a Critical Hyperparameter:</strong> Increasing the latent dimension consistently improved reconstruction quality with diminishing returns. As we saw in the jump from 16 to 32 dimensions, too much capacity can lead to &ldquo;inactive&rdquo; dimensions, where the KL penalty overpowers the model&rsquo;s ability to encode useful information. This demonstrates that choosing the right latent size is crucial for both performance and efficiency.</p>
</li>
<li>
<p><strong>VAEs Learn Meaningful Unsupervised Representations:</strong> Without any labels, our VAE successfully organized the latent space, clustering similar digits and enabling smooth interpolations. This underscores the power of VAEs for unsupervised learning, dimensionality reduction, and discovering the underlying structure in complex data.</p>
</li>
<li>
<p><strong>Implementation Details Matter:</strong> While different standard deviation parameterizations yielded similar results on this simple problem, understanding their gradient behaviors is key for tackling more complex datasets where training stability can be a major challenge. Proper loss scaling is similarly crucial to prevent one term from dominating the other and leading to issues like posterior collapse.</p>
</li>
</ol>
<p>While the classic VAE produces characteristically blurry reconstructions, it remains a foundational generative model. The principles we&rsquo;ve explored here (the ELBO, the reparameterization trick, and the trade-off between reconstruction and regularization) are central to many more advanced generative models used today.</p>
<p><strong>Questions or feedback?</strong> Feel free to reach out. I&rsquo;d love to hear about your experiences with VAE experiments!</p>
]]></content:encoded></item><item><title>Kabsch Algorithm: NumPy, PyTorch, TensorFlow, and JAX</title><link>https://hunterheidenreich.com/posts/kabsch-algorithm/</link><pubDate>Tue, 03 Oct 2023 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/kabsch-algorithm/</guid><description>Learn about the Kabsch algorithm for optimal point alignment with implementations in NumPy, PyTorch, TensorFlow, and JAX for ML applications.</description><content:encoded><![CDATA[<h2 id="what-is-the-kabsch-algorithm">What is the Kabsch Algorithm?</h2>
<p>In computer vision and scientific computing, a common problem arises: given two sets of paired points, what is the optimal rigid-body transformation that aligns them? The Kabsch algorithm provides a neat closed-form solution.</p>















<figure class="post-figure center ">
    <img src="/img/scientific-computing/kabsch-alignment-before-and-after.webp"
         alt="Visualization of two point sets before and after Kabsch alignment"
         title="Visualization of two point sets before and after Kabsch alignment"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The Kabsch algorithm optimally rotates and translates the blue points to align with the red points.</figcaption>
    
</figure>

<p>What are some concrete situations where this crops up?</p>
<ul>
<li><strong>Molecular Dynamics</strong>: Your points are a set of atoms (with physically relevant types), and you want to compare two molecular conformations. Are they the same structure with minor noise or rotation? Or are they different conformations, like a different folding of a protein? This is especially helpful when applying generative models to chemical structures. For example, if you are building a <a href="/notes/chemistry/molecular-simulation/ml-potentials/denoise-vae/">3D Molecular VAE</a> in PyTorch or working with <a href="/notes/machine-learning/generative-models/flow-matching-for-generative-modeling/">Flow Matching models</a>, Kabsch alignment ensures your generative loss function remains rotationally invariant.</li>
<li><strong>Computer Vision</strong>: You have two point clouds from 3D scans of an object taken from different angles. You want to align them to reconstruct the full shape. Or perhaps you&rsquo;re generating 3D shapes from 2D images and need to compare the generated shape to a ground truth scan. Anytime a 3D system is represented as a point cloud, the Kabsch algorithm can help with alignment.</li>
</ul>
<p>Of course, existing libraries implement this algorithm. However, I often find it beneficial to implement algorithms from scratch to build intuition. Furthermore, modern machine learning applications require automatic differentiation, so we will also implement the algorithm in PyTorch, TensorFlow, and JAX.</p>
<p>Below, we&rsquo;ll cover the math behind the Kabsch algorithm (and its scaling variant, the <strong>Kabsch-Umeyama</strong> algorithm) and provide complete, differentiable implementations in <strong>NumPy</strong>, <strong>PyTorch</strong>, <strong>TensorFlow</strong>, and <strong>JAX</strong>, demonstrating both single-pair and batched computations for ML applications.</p>
<h2 id="the-math">The Math</h2>















<figure class="post-figure center ">
    <img src="/img/scientific-computing/kabsch-algorithm-basic-animation.webp"
         alt="Animation showing the iterative steps of centroid alignment and rotation"
         title="Animation showing the iterative steps of centroid alignment and rotation"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Visualizing the alignment process: first centering the datasets, then finding the optimal rotation.</figcaption>
    
</figure>

<p>Let&rsquo;s say we have two sets of paired points,
$P=\{\mathbf{p}_i\}_{i=1}^N \in \mathbb{R}^{N \times D}$ and $Q=\{\mathbf{q}_i\}_{i=1}^N \in \mathbb{R}^{N \times D}$
(where $D$ is the dimensionality and $N$ is the number of points).
We want to find a translation vector $\mathbf{t}$ and rotation matrix $R$ to transform $P$ to align with $Q$.</p>
<p>The optimization problem is:</p>
<p>$$
\min_{\mathbf{t}, \ R} \mathcal{L}(\mathbf{t}, R) = \frac{1}{2} \sum_{i=1}^N \lVert \mathbf{q}_i - (R\mathbf{p}_i + \mathbf{t}) \rVert^2
$$</p>
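<p>We denote the minimizers by $\mathbf{t}^\ast \in \mathbb{R}^D$ and $R^\ast \in \mathbb{R}^{D \times D}$, the optimal translation and rotation. Here $R$ is constrained to be a proper rotation: orthogonal with $\det(R) = +1$.</p>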
<p>Often we use a weighted version with weights $w_i$ (e.g., atomic masses in molecular dynamics):</p>
<p>$$
\min_{\mathbf{t}, \ R} \mathcal{L}(\mathbf{t}, R) = \frac{1}{2} \sum_{i=1}^N w_i \lVert \mathbf{q}_i - (R\mathbf{p}_i + \mathbf{t}) \rVert^2
$$</p>
<h3 id="the-translation">The Translation</h3>
<p>The translation and rotation are coupled, but they separate cleanly once we work in centroid-centered coordinates. Compute the centroids (averages) of both point sets:</p>
<p>$$
\bar{\mathbf{p}} = \frac{1}{N} \sum_{i=1}^N \mathbf{p}_i \quad \text{and} \quad \bar{\mathbf{q}} = \frac{1}{N} \sum_{i=1}^N \mathbf{q}_i
$$</p>
<p>For any fixed rotation $R$, the translation that minimizes $\mathcal{L}$ is found by setting $\partial \mathcal{L} / \partial \mathbf{t} = 0$. It maps the rotated source centroid onto the target centroid:</p>
<p>$$
\mathbf{t} = \bar{\mathbf{q}} - R\bar{\mathbf{p}}
$$</p>
<p>A tempting shortcut is to write $\mathbf{t} = \bar{\mathbf{q}} - \bar{\mathbf{p}}$, but that is only correct when $R = I$. In general the translation depends on the rotation, so we compute it <em>after</em> solving for $R$. Substituting this optimal $\mathbf{t}$ back into the objective cancels the centroids and leaves a rotation-only problem in the centered coordinates $\mathbf{p}_i^\prime = \mathbf{p}_i - \bar{\mathbf{p}}$ and $\mathbf{q}_i^\prime = \mathbf{q}_i - \bar{\mathbf{q}}$:</p>
<p>$$
\mathcal{L}(R) = \frac{1}{2} \sum_{i=1}^N \lVert \mathbf{q}_i^\prime - R\mathbf{p}_i^\prime \rVert^2
$$</p>
<p>which is what the next section solves.</p>
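<p>For readers who want the intermediate steps, here is a short sketch of the derivation in the unweighted case. Setting the gradient with respect to $\mathbf{t}$ to zero gives</p>
<p>$$
\frac{\partial \mathcal{L}}{\partial \mathbf{t}} = -\sum_{i=1}^N \left( \mathbf{q}_i - R\mathbf{p}_i - \mathbf{t} \right) = -N \left( \bar{\mathbf{q}} - R\bar{\mathbf{p}} - \mathbf{t} \right) = 0 \quad \Longrightarrow \quad \mathbf{t} = \bar{\mathbf{q}} - R\bar{\mathbf{p}}
$$</p>
<p>Substituting this $\mathbf{t}$ into each residual removes the centroids:</p>
<p>$$
\mathbf{q}_i - \left( R\mathbf{p}_i + \bar{\mathbf{q}} - R\bar{\mathbf{p}} \right) = (\mathbf{q}_i - \bar{\mathbf{q}}) - R(\mathbf{p}_i - \bar{\mathbf{p}}) = \mathbf{q}_i^\prime - R\mathbf{p}_i^\prime
$$</p>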
<h3 id="the-rotation-matrix">The Rotation Matrix</h3>
<p>We now minimize $\mathcal{L}(R)$ over rotations, using the centered points $\mathbf{p}_i^\prime$ and $\mathbf{q}_i^\prime$ from above. Compute the cross-covariance matrix between the centered sets:</p>
<p>$$
C = P^{\prime T} Q^\prime = \sum_{i=1}^N \mathbf{p}_i^{\prime} \mathbf{q}_i^{\prime T} \in \mathbb{R}^{D \times D}
$$</p>
<p>This is a fairly lightweight operation since $D$ is typically small (e.g., 3 for 3D points), even if $N$ is large.</p>
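<p>As a quick sanity check, the matrix product of the centered point sets equals the sum of per-point outer products. This is a minimal NumPy sketch on random data; the variable names are illustrative.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
P = rng.standard_normal((50, 3))
Q = rng.standard_normal((50, 3))

# Center both sets about their centroids.
P_c = P - P.mean(axis=0)
Q_c = Q - Q.mean(axis=0)

# Matrix form: (D x N) times (N x D) gives a (D x D) matrix.
C_matmul = P_c.T @ Q_c

# Equivalent form: sum over points of the outer product p_i q_i^T.
C_outer = sum(np.outer(p, q) for p, q in zip(P_c, Q_c))
```

<p>The cost of forming $C$ grows linearly in $N$, while the matrix that is handed to the SVD is always only $D \times D$.</p>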
<p>With $C$ in hand, we want to compute its Singular Value Decomposition (SVD):</p>
<p>$$
C = U \Sigma V^T
$$</p>
<p>In general, the SVD is the most expensive step of the algorithm. It scales cubically with $D$ (i.e., $O(D^3)$).
However, since we&rsquo;re usually interested in cases where $D$ is small (e.g., 2D or 3D points), it is cheap in practice.</p>
<p>Next, we check for improper rotations (i.e., reflections) and correct for them where necessary:</p>
<p>$$
d = \text{sign}(\det(V U^T))
$$</p>
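<p>If $d = -1$, we need to flip the sign of the last column of $V$ in the final rotation matrix.</p>
<p>For $D = 3$, let $B = \text{diag}(1, 1, d)$ (in general, the $D \times D$ identity with its last diagonal entry replaced by $d$).
The optimal rotation matrix is then:</p>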
<p>$$
R^\ast = V B U^T
$$</p>
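<p>To see why the correction matters, consider a point set and its mirror image. The sketch below (NumPy, with a hypothetical helper <code>rotation_from_cross_covariance</code>) compares the uncorrected orthogonal matrix against the corrected one.</p>

```python
import numpy as np


def rotation_from_cross_covariance(C, correct_reflection=True):
    """Orthogonal matrix from the SVD of C, optionally forced to det = +1."""
    U, S, Vt = np.linalg.svd(C)
    V = Vt.T
    d = np.sign(np.linalg.det(V @ U.T))
    B = np.eye(C.shape[0])
    if correct_reflection:
        B[-1, -1] = d
    return V @ B @ U.T


rng = np.random.default_rng(0)
P = rng.standard_normal((20, 3))
Q = P * np.array([-1.0, 1.0, 1.0])  # mirror image of P across the yz-plane

P_c = P - P.mean(axis=0)
Q_c = Q - Q.mean(axis=0)
C = P_c.T @ Q_c

R_naive = rotation_from_cross_covariance(C, correct_reflection=False)
R_proper = rotation_from_cross_covariance(C, correct_reflection=True)

det_naive = np.linalg.det(R_naive)    # close to -1: a reflection
det_proper = np.linalg.det(R_proper)  # close to +1: a proper rotation
```

<p>Without the correction, the best orthogonal fit to mirrored data is the mirror itself, which is not a physically valid rigid-body motion. With $B$ in place, the result is the best proper rotation instead.</p>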
<h3 id="summary">Summary</h3>
<p>In a nutshell, the Kabsch algorithm boils down to:</p>
<ol>
<li>Compute centroids of $P$ and $Q$ ($\bar{\mathbf{p}}$ and $\bar{\mathbf{q}}$)</li>
<li>Center both point sets by subtracting centroids: $P^\prime$ and $Q^\prime$</li>
<li>Compute cross-covariance matrix $C = P^{\prime T} Q^\prime$</li>
<li>Compute SVD: $C = U \Sigma V^T$ (<em>expensive step</em>)</li>
<li>Compute $d = \text{sign}(\det(V U^T))$ and $B = \text{diag}(1, 1, d)$</li>
<li>Optimal rotation: $R^\ast = V B U^T$</li>
<li>Optimal translation (using the rotation from step 6): $\mathbf{t}^\ast = \bar{\mathbf{q}} - R^\ast\bar{\mathbf{p}}$</li>
</ol>
<p>The resulting root-mean-square deviation (RMSD) between aligned point sets is</p>
<p>$$
\text{RMSD} = \sqrt{\frac{1}{N} \sum_{i=1}^N \lVert \mathbf{q}_i - (R^\ast\mathbf{p}_i + \mathbf{t}^\ast) \rVert^2}
$$</p>















<figure class="post-figure center ">
    <img src="/img/scientific-computing/kabsch-algorithm-visualized-rmsd.webp"
         alt="Diagram illustrating Root Mean Square Deviation (RMSD) distances"
         title="Diagram illustrating Root Mean Square Deviation (RMSD) distances"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">RMSD measures the root-mean-square distance between corresponding aligned points.</figcaption>
    
</figure>

<p>which is frequently used as a measure of similarity between molecular structures or as a metric in loss functions for ML applications.</p>
<h3 id="the-kabsch-umeyama-algorithm-scaling">The Kabsch-Umeyama Algorithm (Scaling)</h3>
<p>While the standard Kabsch algorithm solves for optimal rotation and translation, the <strong>Kabsch-Umeyama algorithm</strong> extends this by also finding an optimal <strong>scaling factor</strong> $c$. This is essential when aligning structures of different scales, such as a 3D scan versus a ground truth model.</p>
<p><em>(Note: This is sometimes searched for as the &ldquo;Absch-Umeyama algorithm&rdquo; due to typos, but the correct attribution is to Shinji Umeyama based on Wolfgang Kabsch&rsquo;s work.)</em></p>
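<p>The method estimates the transformation $\mathbf{q}_i \approx c R \mathbf{p}_i + \mathbf{t}$. The optimal scale is the sum of the (reflection-corrected) singular values of the cross-covariance divided by the summed squared distances of the source points from their centroid, $c = \operatorname{tr}(\Sigma B) / \sum_{i=1}^N \lVert \mathbf{p}_i^\prime \rVert^2$. Dividing both the numerator and the denominator by $N$ gives the equivalent form with the variance of the source points in the denominator. The translation then becomes $\mathbf{t} = \bar{\mathbf{q}} - c R \bar{\mathbf{p}}$. See the <a href="/notes/biology/computational-biology/umeyama-similarity-transformation/">Umeyama paper notes</a> for the full derivation.</p>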
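<p>A minimal NumPy sketch of the scaled variant follows. The function name <code>kabsch_umeyama_numpy</code> is a hypothetical choice for this illustration, it covers only the unweighted case, and it uses the unnormalized cross-covariance so the denominator is a sum of squared norms.</p>

```python
import numpy as np


def kabsch_umeyama_numpy(P, Q):
    """Estimate scale c, rotation R, translation t with Q ~ c * P @ R.T + t."""
    centroid_P = P.mean(axis=0)
    centroid_Q = Q.mean(axis=0)
    p = P - centroid_P
    q = Q - centroid_Q

    # Unnormalized cross-covariance and its SVD.
    H = p.T @ q
    U, S, Vt = np.linalg.svd(H)

    # Reflection correction: d is +1 for a proper rotation, -1 otherwise.
    d = np.sign(np.linalg.det(Vt.T @ U.T))
    B = np.ones(H.shape[0])
    B[-1] = d

    R = Vt.T @ np.diag(B) @ U.T
    # Sum of corrected singular values over the summed squared norms of centered P.
    c = np.sum(S * B) / np.sum(p**2)
    t = centroid_Q - c * (R @ centroid_P)
    return c, R, t


# Synthetic check: build Q from a known scale, rotation, and translation.
rng = np.random.default_rng(0)
P = rng.standard_normal((100, 3))
alpha = 0.7
R_true = np.array([[np.cos(alpha), -np.sin(alpha), 0.0],
                   [np.sin(alpha), np.cos(alpha), 0.0],
                   [0.0, 0.0, 1.0]])
c_true = 2.5
t_true = np.array([1.0, -2.0, 3.0])
Q = c_true * (P @ R_true.T) + t_true

c_est, R_est, t_est = kabsch_umeyama_numpy(P, Q)
```

<p>On noise-free data like this, the estimates should match the generating parameters up to floating-point error. With noisy data they are least-squares estimates instead.</p>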
<p><strong>A Note on SVD and Automatic Differentiation</strong></p>
<p>While modern frameworks allow us to backpropagate through the Singular Value Decomposition (SVD), it comes with a known stability issue: if the cross-covariance matrix has identical (degenerate) singular values (which can occur if the point clouds are perfectly aligned or have certain symmetries), the SVD gradient contains terms proportional to $1 / (\sigma_i^2 - \sigma_j^2)$ that diverge, causing <code>NaN</code> values during backpropagation. If you plan to use this algorithm as a loss function for a neural network, it is often necessary to add a tiny epsilon to the matrix before computing the SVD, or to utilize an SVD gradient patch. The <a href="/projects/kabsch-horn-cookbook/">Kabsch-Horn Cookbook</a> library provides a SafeSVD primitive that floors the singular-value-gap denominator at machine epsilon in the backward pass, producing finite gradients at degenerate inputs across PyTorch, JAX, TensorFlow, and MLX.</p>
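<p>The source of the instability can be illustrated without any autodiff framework. The sketch below computes the inverse singular-value-gap factors that show up in SVD gradients, with and without a floor on the denominator. It is a NumPy illustration of the general idea only: <code>inverse_gap_matrix</code> is a hypothetical helper and is not the code of SafeSVD or of any framework.</p>

```python
import numpy as np


def inverse_gap_matrix(S, floor=None):
    """Off-diagonal factors 1 / (s_j**2 - s_i**2) that appear in SVD gradients."""
    s2 = np.asarray(S, dtype=float) ** 2
    gap = s2[None, :] - s2[:, None]
    if floor is not None:
        # Keep the sign of each gap, but never let its magnitude
        # drop below `floor`.
        sign = np.where(gap >= 0.0, 1.0, -1.0)
        gap = sign * np.maximum(np.abs(gap), floor)
    np.fill_diagonal(gap, np.inf)  # diagonal terms are excluded (1 / inf == 0)
    with np.errstate(divide="ignore"):
        return 1.0 / gap


# Two identical singular values: the unfloored factor blows up.
S_degenerate = np.array([2.0, 1.0, 1.0])
raw = inverse_gap_matrix(S_degenerate)
safe = inverse_gap_matrix(S_degenerate, floor=np.finfo(float).eps)
```

<p>The floored factors are finite but can still be very large, so this kind of patch prevents <code>NaN</code> values without making the gradient at a degenerate point especially meaningful.</p>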
<h2 id="implementation">Implementation</h2>
<p>Let&rsquo;s implement the algorithm in different frameworks. Note that for simplicity, the following implementations cover the <strong>unweighted</strong> Kabsch algorithm. If your application (like molecular dynamics) requires weights (e.g., atomic masses), the <a href="/projects/kabsch-horn-cookbook/">Kabsch-Horn Cookbook</a> library provides per-point weighted alignment out of the box.</p>
<h3 id="numpy">NumPy</h3>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> numpy <span style="color:#66d9ef">as</span> np
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">kabsch_numpy</span>(P, Q):
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    Computes the optimal rotation and translation to align two sets of points (P -&gt; Q),
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    and their RMSD.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    :param P: A Nx3 matrix of points
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    :param Q: A Nx3 matrix of points
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    :return: A tuple containing the optimal rotation matrix, the optimal
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">             translation vector, and the RMSD.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    &#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">assert</span> P<span style="color:#f92672">.</span>shape <span style="color:#f92672">==</span> Q<span style="color:#f92672">.</span>shape, <span style="color:#e6db74">&#34;Matrix dimensions must match&#34;</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Compute centroids</span>
</span></span><span style="display:flex;"><span>    centroid_P <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>mean(P, axis<span style="color:#f92672">=</span><span style="color:#ae81ff">0</span>)
</span></span><span style="display:flex;"><span>    centroid_Q <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>mean(Q, axis<span style="color:#f92672">=</span><span style="color:#ae81ff">0</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Center the points</span>
</span></span><span style="display:flex;"><span>    p <span style="color:#f92672">=</span> P <span style="color:#f92672">-</span> centroid_P
</span></span><span style="display:flex;"><span>    q <span style="color:#f92672">=</span> Q <span style="color:#f92672">-</span> centroid_Q
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Compute the cross-covariance matrix</span>
</span></span><span style="display:flex;"><span>    H <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>dot(p<span style="color:#f92672">.</span>T, q)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># SVD</span>
</span></span><span style="display:flex;"><span>    U, S, Vt <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>linalg<span style="color:#f92672">.</span>svd(H)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Correct for reflections so that det(R) = +1</span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">if</span> np<span style="color:#f92672">.</span>linalg<span style="color:#f92672">.</span>det(np<span style="color:#f92672">.</span>dot(Vt<span style="color:#f92672">.</span>T, U<span style="color:#f92672">.</span>T)) <span style="color:#f92672">&lt;</span> <span style="color:#ae81ff">0.0</span>:
</span></span><span style="display:flex;"><span>        Vt[<span style="color:#f92672">-</span><span style="color:#ae81ff">1</span>, :] <span style="color:#f92672">*=</span> <span style="color:#f92672">-</span><span style="color:#ae81ff">1.0</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Optimal rotation</span>
</span></span><span style="display:flex;"><span>    R <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>dot(Vt<span style="color:#f92672">.</span>T, U<span style="color:#f92672">.</span>T)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Optimal translation (depends on R, so computed after it)</span>
</span></span><span style="display:flex;"><span>    t <span style="color:#f92672">=</span> centroid_Q <span style="color:#f92672">-</span> np<span style="color:#f92672">.</span>dot(R, centroid_P)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># RMSD</span>
</span></span><span style="display:flex;"><span>    rmsd <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>sqrt(np<span style="color:#f92672">.</span>sum(np<span style="color:#f92672">.</span>square(np<span style="color:#f92672">.</span>dot(p, R<span style="color:#f92672">.</span>T) <span style="color:#f92672">-</span> q)) <span style="color:#f92672">/</span> P<span style="color:#f92672">.</span>shape[<span style="color:#ae81ff">0</span>])
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> R, t, rmsd
</span></span></code></pre></div><p>Here&rsquo;s a quick test to verify correctness:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">test_numpy</span>():
</span></span><span style="display:flex;"><span>    np<span style="color:#f92672">.</span>random<span style="color:#f92672">.</span>seed(<span style="color:#ae81ff">12345</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    P <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>random<span style="color:#f92672">.</span>randn(<span style="color:#ae81ff">100</span>, <span style="color:#ae81ff">3</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    alpha <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>random<span style="color:#f92672">.</span>rand() <span style="color:#f92672">*</span> <span style="color:#ae81ff">2</span> <span style="color:#f92672">*</span> np<span style="color:#f92672">.</span>pi
</span></span><span style="display:flex;"><span>    R <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>array([[np<span style="color:#f92672">.</span>cos(alpha), <span style="color:#f92672">-</span>np<span style="color:#f92672">.</span>sin(alpha), <span style="color:#ae81ff">0</span>],
</span></span><span style="display:flex;"><span>                    [np<span style="color:#f92672">.</span>sin(alpha), np<span style="color:#f92672">.</span>cos(alpha), <span style="color:#ae81ff">0</span>],
</span></span><span style="display:flex;"><span>                    [<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">1</span>]])
</span></span><span style="display:flex;"><span>    t <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>random<span style="color:#f92672">.</span>randn(<span style="color:#ae81ff">3</span>) <span style="color:#f92672">*</span> <span style="color:#ae81ff">10</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    Q <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>dot(P, R<span style="color:#f92672">.</span>T) <span style="color:#f92672">+</span> t
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    R_opt, t_opt, rmsd <span style="color:#f92672">=</span> kabsch_numpy(P, Q)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    print(<span style="color:#e6db74">&#39;RMSD: </span><span style="color:#e6db74">{}</span><span style="color:#e6db74">&#39;</span><span style="color:#f92672">.</span>format(rmsd))
</span></span><span style="display:flex;"><span>    print(<span style="color:#e6db74">&#39;R:</span><span style="color:#ae81ff">\n</span><span style="color:#e6db74">{}</span><span style="color:#e6db74">&#39;</span><span style="color:#f92672">.</span>format(R))
</span></span><span style="display:flex;"><span>    print(<span style="color:#e6db74">&#39;R_opt:</span><span style="color:#ae81ff">\n</span><span style="color:#e6db74">{}</span><span style="color:#e6db74">&#39;</span><span style="color:#f92672">.</span>format(R_opt))
</span></span><span style="display:flex;"><span>    print(<span style="color:#e6db74">&#39;t:</span><span style="color:#ae81ff">\n</span><span style="color:#e6db74">{}</span><span style="color:#e6db74">&#39;</span><span style="color:#f92672">.</span>format(t))
</span></span><span style="display:flex;"><span>    print(<span style="color:#e6db74">&#39;t_opt:</span><span style="color:#ae81ff">\n</span><span style="color:#e6db74">{}</span><span style="color:#e6db74">&#39;</span><span style="color:#f92672">.</span>format(t_opt))
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    l2_t <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>linalg<span style="color:#f92672">.</span>norm(t <span style="color:#f92672">-</span> t_opt)
</span></span><span style="display:flex;"><span>    l2_R <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>linalg<span style="color:#f92672">.</span>norm(R <span style="color:#f92672">-</span> R_opt)
</span></span><span style="display:flex;"><span>    print(<span style="color:#e6db74">&#39;l2_t: </span><span style="color:#e6db74">{}</span><span style="color:#e6db74">&#39;</span><span style="color:#f92672">.</span>format(l2_t))
</span></span><span style="display:flex;"><span>    print(<span style="color:#e6db74">&#39;l2_R: </span><span style="color:#e6db74">{}</span><span style="color:#e6db74">&#39;</span><span style="color:#f92672">.</span>format(l2_R))
</span></span></code></pre></div><p>Running this test shows the algorithm correctly recovers the rotation and translation:</p>
<pre><code>RMSD: 3.2111501877699246e-15
R:
[[-0.8475392 -0.5307328  0.       ]
 [ 0.5307328 -0.8475392  0.       ]
 [ 0.         0.         1.       ]]
R_opt:
[[-8.47539198e-01 -5.30732803e-01 -2.95434260e-16]
 [ 5.30732803e-01 -8.47539198e-01  2.92859649e-16]
 [ 0.00000000e+00 -2.77555756e-16  1.00000000e+00]]
t:
[ 5.99726796  1.50078468 -3.34633977]
t_opt:
[ 5.99726796  1.50078468 -3.34633977]
l2_t: 2.7012892057857038e-15
l2_R: 8.028174304721057e-16
</code></pre>
<p>Both the rotation and the translation are recovered to within floating-point precision (the residuals <code>l2_t</code> and <code>l2_R</code> are on the order of <code>1e-15</code>).</p>
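<p>The printed output above can also be turned into automated checks. The following is a minimal sketch, not part of the original test: <code>kabsch_compact</code> is a hypothetical, condensed restatement of the NumPy steps (it applies the reflection correction via <code>np.sign</code> instead of an in-place row flip), and the mirrored point cloud is an assumed extra case that exercises the reflection correction, which a pure rotation-plus-translation test never triggers.</p>

```python
import numpy as np


def kabsch_compact(P, Q):
    """Hypothetical condensed Kabsch (P -> Q) for Nx3 arrays; returns R, t, rmsd."""
    centroid_P, centroid_Q = P.mean(axis=0), Q.mean(axis=0)
    p, q = P - centroid_P, Q - centroid_Q
    U, S, Vt = np.linalg.svd(p.T @ q)
    # Reflection correction: negate the last column of V when det(V U^T) is -1
    d = np.sign(np.linalg.det(Vt.T @ U.T))
    R = (Vt.T * np.array([1.0, 1.0, d])) @ U.T
    t = centroid_Q - R @ centroid_P
    rmsd = np.sqrt(np.sum(np.square(p @ R.T - q)) / P.shape[0])
    return R, t, rmsd


rng = np.random.default_rng(0)
P = rng.standard_normal((50, 3))

# Case 1: a known rigid transform should be recovered up to round-off
alpha = 0.7
R_true = np.array([[np.cos(alpha), -np.sin(alpha), 0.0],
                   [np.sin(alpha), np.cos(alpha), 0.0],
                   [0.0, 0.0, 1.0]])
t_true = np.array([1.0, -2.0, 3.0])
R_opt, t_opt, rmsd = kabsch_compact(P, P @ R_true.T + t_true)
assert np.allclose(R_opt, R_true)
assert np.allclose(t_opt, t_true)
assert np.isclose(rmsd, 0.0, atol=1e-10)

# Case 2: a mirrored target cannot be matched by any proper rotation,
# so the result must still have det(R) = +1 and a clearly nonzero RMSD
R_m, t_m, rmsd_m = kabsch_compact(P, P * np.array([1.0, 1.0, -1.0]))
assert np.isclose(np.linalg.det(R_m), 1.0)
assert rmsd_m > 1e-3
```

<p>Asserting on <code>det(R)</code> in the mirrored case guards against a regression in which the correction step is dropped and an improper rotation (a reflection) is returned.</p>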
<p>For batch processing, the same steps are applied independently along a leading batch dimension:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">kabsch_numpy_batched</span>(P, Q):
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    Computes the optimal rotation and translation to align two sets of points (P -&gt; Q),
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    and their RMSD.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    :param P: A BxNx3 matrix of points
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    :param Q: A BxNx3 matrix of points
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    :return: A tuple containing the optimal rotation matrix, the optimal
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">             translation vector, and the RMSD.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    &#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">assert</span> P<span style="color:#f92672">.</span>shape <span style="color:#f92672">==</span> Q<span style="color:#f92672">.</span>shape, <span style="color:#e6db74">&#34;Matrix dimensions must match&#34;</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Compute centroids</span>
</span></span><span style="display:flex;"><span>    centroid_P <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>mean(P, axis<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>, keepdims<span style="color:#f92672">=</span><span style="color:#66d9ef">True</span>)  <span style="color:#75715e"># Bx1x3</span>
</span></span><span style="display:flex;"><span>    centroid_Q <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>mean(Q, axis<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>, keepdims<span style="color:#f92672">=</span><span style="color:#66d9ef">True</span>)  <span style="color:#75715e"># Bx1x3</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Center the points</span>
</span></span><span style="display:flex;"><span>    p <span style="color:#f92672">=</span> P <span style="color:#f92672">-</span> centroid_P  <span style="color:#75715e"># BxNx3</span>
</span></span><span style="display:flex;"><span>    q <span style="color:#f92672">=</span> Q <span style="color:#f92672">-</span> centroid_Q  <span style="color:#75715e"># BxNx3</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Compute the covariance matrix</span>
</span></span><span style="display:flex;"><span>    H <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>matmul(p<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">2</span>, <span style="color:#ae81ff">1</span>), q)  <span style="color:#75715e"># Bx3x3</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># SVD</span>
</span></span><span style="display:flex;"><span>    U, S, Vt <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>linalg<span style="color:#f92672">.</span>svd(H)  <span style="color:#75715e"># Bx3x3</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Validate right-handed coordinate system</span>
</span></span><span style="display:flex;"><span>    d <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>linalg<span style="color:#f92672">.</span>det(np<span style="color:#f92672">.</span>matmul(Vt<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">2</span>, <span style="color:#ae81ff">1</span>), U<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">2</span>, <span style="color:#ae81ff">1</span>)))
</span></span><span style="display:flex;"><span>    flip <span style="color:#f92672">=</span> d <span style="color:#f92672">&lt;</span> <span style="color:#ae81ff">0.0</span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">if</span> flip<span style="color:#f92672">.</span>any():
</span></span><span style="display:flex;"><span>        Vt[flip, <span style="color:#f92672">-</span><span style="color:#ae81ff">1</span>, :] <span style="color:#f92672">*=</span> <span style="color:#f92672">-</span><span style="color:#ae81ff">1.0</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Optimal rotation</span>
</span></span><span style="display:flex;"><span>    R <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>matmul(Vt<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">2</span>, <span style="color:#ae81ff">1</span>), U<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">2</span>, <span style="color:#ae81ff">1</span>))  <span style="color:#75715e"># Bx3x3</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Optimal translation (depends on R, so computed after it)</span>
</span></span><span style="display:flex;"><span>    t <span style="color:#f92672">=</span> centroid_Q<span style="color:#f92672">.</span>squeeze(<span style="color:#ae81ff">1</span>) <span style="color:#f92672">-</span> np<span style="color:#f92672">.</span>matmul(centroid_P, R<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">2</span>, <span style="color:#ae81ff">1</span>))<span style="color:#f92672">.</span>squeeze(<span style="color:#ae81ff">1</span>)  <span style="color:#75715e"># Bx3</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># RMSD</span>
</span></span><span style="display:flex;"><span>    rmsd <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>sqrt(np<span style="color:#f92672">.</span>sum(np<span style="color:#f92672">.</span>square(np<span style="color:#f92672">.</span>matmul(p, R<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">2</span>, <span style="color:#ae81ff">1</span>)) <span style="color:#f92672">-</span> q), axis<span style="color:#f92672">=</span>(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">2</span>)) <span style="color:#f92672">/</span> P<span style="color:#f92672">.</span>shape[<span style="color:#ae81ff">1</span>])
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> R, t, rmsd
</span></span></code></pre></div><h3 id="pytorch">PyTorch</h3>


<p><details >
  <summary markdown="span">📝 Important Update (February 15, 2026)</summary>
  <strong>Bug Fix Notice:</strong> The PyTorch implementation has been updated to use the &ldquo;B-matrix&rdquo; broadcasting approach. This eliminates in-place tensor modification (which breaks <code>autograd</code>) and data-dependent control flow (which breaks <code>torch.compile</code> and <code>torch.vmap</code>).
</details></p>
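<p>To see why the B-matrix form is a drop-in replacement for the row flip, the two can be compared directly. The sketch below is an illustration in NumPy rather than PyTorch (to keep it dependency-light); the helper names <code>rotation_row_flip</code> and <code>rotation_b_matrix</code> are hypothetical, and the random matrices stand in for real covariance matrices with positive and negative determinant.</p>

```python
import numpy as np


def rotation_row_flip(U, Vt):
    """Branching variant: negate the last row of a copy of Vt when det is -1."""
    Vt = Vt.copy()
    if np.linalg.det(Vt.T @ U.T) > 0.0:
        return Vt.T @ U.T
    Vt[-1, :] *= -1.0
    return Vt.T @ U.T


def rotation_b_matrix(U, Vt):
    """Branch-free variant: scale the columns of Vt.T by diag(1, 1, sign(d))."""
    d = np.linalg.det(Vt.T @ U.T)
    B_diag = np.array([1.0, 1.0, np.sign(d)])
    return (Vt.T * B_diag[None, :]) @ U.T


rng = np.random.default_rng(1)
results = []
for target_sign in (1.0, -1.0):
    # Force the determinant of the stand-in covariance matrix to the target sign
    H = rng.standard_normal((3, 3))
    H[:, 0] *= np.sign(np.linalg.det(H)) * target_sign
    U, S, Vt = np.linalg.svd(H)
    results.append((rotation_row_flip(U, Vt), rotation_b_matrix(U, Vt)))

for R_flip, R_b in results:
    assert np.allclose(R_flip, R_b)              # both variants agree
    assert np.isclose(np.linalg.det(R_b), 1.0)   # proper rotation in both cases
```

<p>Scaling the columns of <code>Vt.T</code> by <code>[1, 1, sign(d)]</code> is the same as negating the last row of <code>Vt</code>, but it needs neither an <code>if</code> statement nor a write into an existing array.</p>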

<p>The PyTorch implementation now applies the reflection correction by broadcasting a sign vector, which keeps every operation out-of-place and free of data-dependent branches:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> torch
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">kabsch_torch</span>(P, Q):
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    Computes the optimal rotation and translation to align two sets of points (P -&gt; Q),
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    and their RMSD.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    :param P: A Nx3 matrix of points
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    :param Q: A Nx3 matrix of points
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    :return: A tuple containing the optimal rotation matrix, the optimal
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">             translation vector, and the RMSD.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    &#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">assert</span> P<span style="color:#f92672">.</span>shape <span style="color:#f92672">==</span> Q<span style="color:#f92672">.</span>shape, <span style="color:#e6db74">&#34;Matrix dimensions must match&#34;</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Compute centroids</span>
</span></span><span style="display:flex;"><span>    centroid_P <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>mean(P, dim<span style="color:#f92672">=</span><span style="color:#ae81ff">0</span>)
</span></span><span style="display:flex;"><span>    centroid_Q <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>mean(Q, dim<span style="color:#f92672">=</span><span style="color:#ae81ff">0</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Center the points</span>
</span></span><span style="display:flex;"><span>    p <span style="color:#f92672">=</span> P <span style="color:#f92672">-</span> centroid_P
</span></span><span style="display:flex;"><span>    q <span style="color:#f92672">=</span> Q <span style="color:#f92672">-</span> centroid_Q
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Compute the covariance matrix</span>
</span></span><span style="display:flex;"><span>    H <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>matmul(p<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">1</span>), q)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># SVD</span>
</span></span><span style="display:flex;"><span>    U, S, Vt <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>linalg<span style="color:#f92672">.</span>svd(H)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># 1. Calculate determinant</span>
</span></span><span style="display:flex;"><span>    d <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>det(torch<span style="color:#f92672">.</span>matmul(Vt<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">1</span>), U<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">1</span>)))
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># 2. Build diagonal B tensor without in-place mutation</span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># torch.stack builds B_diag out-of-place, so no tensor in the autograd graph is mutated</span>
</span></span><span style="display:flex;"><span>    B_diag <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>stack([torch<span style="color:#f92672">.</span>tensor(<span style="color:#ae81ff">1.0</span>, device<span style="color:#f92672">=</span>d<span style="color:#f92672">.</span>device, dtype<span style="color:#f92672">=</span>d<span style="color:#f92672">.</span>dtype),
</span></span><span style="display:flex;"><span>                          torch<span style="color:#f92672">.</span>tensor(<span style="color:#ae81ff">1.0</span>, device<span style="color:#f92672">=</span>d<span style="color:#f92672">.</span>device, dtype<span style="color:#f92672">=</span>d<span style="color:#f92672">.</span>dtype),
</span></span><span style="display:flex;"><span>                          torch<span style="color:#f92672">.</span>sign(d)])
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># 3. Scale columns of Vt.T via broadcasting, then multiply by U^T</span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Vt.T: (3, 3). B_diag: (3) -&gt; B_diag[None, :]: (1, 3)</span>
</span></span><span style="display:flex;"><span>    R <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>matmul(Vt<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">1</span>) <span style="color:#f92672">*</span> B_diag[<span style="color:#66d9ef">None</span>, :], U<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">1</span>))
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Optimal translation (depends on R, so computed after it)</span>
</span></span><span style="display:flex;"><span>    t <span style="color:#f92672">=</span> centroid_Q <span style="color:#f92672">-</span> centroid_P <span style="color:#f92672">@</span> R<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># RMSD</span>
</span></span><span style="display:flex;"><span>    rmsd <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>sqrt(torch<span style="color:#f92672">.</span>sum(torch<span style="color:#f92672">.</span>square(torch<span style="color:#f92672">.</span>matmul(p, R<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">1</span>)) <span style="color:#f92672">-</span> q)) <span style="color:#f92672">/</span> P<span style="color:#f92672">.</span>shape[<span style="color:#ae81ff">0</span>])
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> R, t, rmsd
</span></span></code></pre></div><p>And our batched version:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">kabsch_torch_batched</span>(P, Q):
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    Computes the optimal rotation and translation to align two sets of points (P -&gt; Q),
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    and their RMSD, in a batched manner.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    :param P: A BxNx3 matrix of points
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    :param Q: A BxNx3 matrix of points
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    :return: A tuple containing the optimal rotation matrix, the optimal
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">             translation vector, and the RMSD.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    &#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">assert</span> P<span style="color:#f92672">.</span>shape <span style="color:#f92672">==</span> Q<span style="color:#f92672">.</span>shape, <span style="color:#e6db74">&#34;Matrix dimensions must match&#34;</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Compute centroids</span>
</span></span><span style="display:flex;"><span>    centroid_P <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>mean(P, dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>, keepdim<span style="color:#f92672">=</span><span style="color:#66d9ef">True</span>)  <span style="color:#75715e"># Bx1x3</span>
</span></span><span style="display:flex;"><span>    centroid_Q <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>mean(Q, dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>, keepdim<span style="color:#f92672">=</span><span style="color:#66d9ef">True</span>)  <span style="color:#75715e"># Bx1x3</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Center the points</span>
</span></span><span style="display:flex;"><span>    p <span style="color:#f92672">=</span> P <span style="color:#f92672">-</span> centroid_P  <span style="color:#75715e"># BxNx3</span>
</span></span><span style="display:flex;"><span>    q <span style="color:#f92672">=</span> Q <span style="color:#f92672">-</span> centroid_Q  <span style="color:#75715e"># BxNx3</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Compute the covariance matrix</span>
</span></span><span style="display:flex;"><span>    H <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>matmul(p<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">2</span>), q)  <span style="color:#75715e"># Bx3x3</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># SVD</span>
</span></span><span style="display:flex;"><span>    U, S, Vt <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>linalg<span style="color:#f92672">.</span>svd(H)  <span style="color:#75715e"># Bx3x3</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># 1. Calculate batched determinant</span>
</span></span><span style="display:flex;"><span>    d <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>det(torch<span style="color:#f92672">.</span>matmul(Vt<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">2</span>), U<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">2</span>)))  <span style="color:#75715e"># B</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># 2. Build batched B_diag without in-place mutation or control flow</span>
</span></span><span style="display:flex;"><span>    ones <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>ones_like(d)
</span></span><span style="display:flex;"><span>    B_diag <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>stack([ones, ones, torch<span style="color:#f92672">.</span>sign(d)], dim<span style="color:#f92672">=-</span><span style="color:#ae81ff">1</span>)  <span style="color:#75715e"># Bx3</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># 3. Scale columns of Vt.T and multiply</span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Vt.T: (B, 3, 3). B_diag: (B, 3). B_diag[:, None, :]: (B, 1, 3).</span>
</span></span><span style="display:flex;"><span>    R <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>matmul(Vt<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">2</span>) <span style="color:#f92672">*</span> B_diag[:, <span style="color:#66d9ef">None</span>, :], U<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">2</span>))
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Optimal translation (depends on R, so computed after it)</span>
</span></span><span style="display:flex;"><span>    t <span style="color:#f92672">=</span> centroid_Q<span style="color:#f92672">.</span>squeeze(<span style="color:#ae81ff">1</span>) <span style="color:#f92672">-</span> torch<span style="color:#f92672">.</span>matmul(centroid_P, R<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">2</span>))<span style="color:#f92672">.</span>squeeze(<span style="color:#ae81ff">1</span>)  <span style="color:#75715e"># Bx3</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># RMSD</span>
</span></span><span style="display:flex;"><span>    rmsd <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>sqrt(torch<span style="color:#f92672">.</span>sum(torch<span style="color:#f92672">.</span>square(torch<span style="color:#f92672">.</span>matmul(p, R<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">2</span>)) <span style="color:#f92672">-</span> q), dim<span style="color:#f92672">=</span>(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">2</span>)) <span style="color:#f92672">/</span> P<span style="color:#f92672">.</span>shape[<span style="color:#ae81ff">1</span>])
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> R, t, rmsd
</span></span></code></pre></div><h3 id="tensorflow">TensorFlow</h3>
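<p>TensorFlow&rsquo;s <code>tf.linalg.svd</code> returns the singular values first, followed by <code>U</code> and <code>V</code> (not $V^T$), so <code>V</code> is used without a transpose when forming the rotation. Because TensorFlow tensors are immutable and the function may be compiled (e.g., via <code>@tf.function</code>), we avoid in-place updates and explicit conditional branching by constructing a correction matrix $B$ and broadcasting it.</p>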
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> tensorflow <span style="color:#66d9ef">as</span> tf
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">kabsch_tensorflow</span>(P, Q):
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    Computes the optimal rotation and translation to align two sets of points (P -&gt; Q),
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    and their RMSD.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    :param P: A Nx3 matrix of points
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    :param Q: A Nx3 matrix of points
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    :return: A tuple containing the optimal rotation matrix, the optimal
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">             translation vector, and the RMSD.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    &#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>    P <span style="color:#f92672">=</span> tf<span style="color:#f92672">.</span>convert_to_tensor(P, dtype<span style="color:#f92672">=</span>tf<span style="color:#f92672">.</span>float32)
</span></span><span style="display:flex;"><span>    Q <span style="color:#f92672">=</span> tf<span style="color:#f92672">.</span>convert_to_tensor(Q, dtype<span style="color:#f92672">=</span>tf<span style="color:#f92672">.</span>float32)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">assert</span> P<span style="color:#f92672">.</span>shape <span style="color:#f92672">==</span> Q<span style="color:#f92672">.</span>shape, <span style="color:#e6db74">&#34;Matrix dimensions must match&#34;</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Compute centroids</span>
</span></span><span style="display:flex;"><span>    centroid_P <span style="color:#f92672">=</span> tf<span style="color:#f92672">.</span>reduce_mean(P, axis<span style="color:#f92672">=</span><span style="color:#ae81ff">0</span>)
</span></span><span style="display:flex;"><span>    centroid_Q <span style="color:#f92672">=</span> tf<span style="color:#f92672">.</span>reduce_mean(Q, axis<span style="color:#f92672">=</span><span style="color:#ae81ff">0</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Center the points</span>
</span></span><span style="display:flex;"><span>    p <span style="color:#f92672">=</span> P <span style="color:#f92672">-</span> centroid_P
</span></span><span style="display:flex;"><span>    q <span style="color:#f92672">=</span> Q <span style="color:#f92672">-</span> centroid_Q
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Compute the covariance matrix</span>
</span></span><span style="display:flex;"><span>    H <span style="color:#f92672">=</span> tf<span style="color:#f92672">.</span>matmul(tf<span style="color:#f92672">.</span>transpose(p), q)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># SVD</span>
</span></span><span style="display:flex;"><span>    S, U, V <span style="color:#f92672">=</span> tf<span style="color:#f92672">.</span>linalg<span style="color:#f92672">.</span>svd(H)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># 1. Calculate determinant</span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Note: V in TF SVD is V, not V^T.</span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># R = V * U^T. Det(R) = Det(V * U^T)</span>
</span></span><span style="display:flex;"><span>    d <span style="color:#f92672">=</span> tf<span style="color:#f92672">.</span>linalg<span style="color:#f92672">.</span>det(tf<span style="color:#f92672">.</span>matmul(V, tf<span style="color:#f92672">.</span>transpose(U)))
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># 2. Build diagonal B tensor: [1.0, 1.0, sign(d)]</span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Use static shape 3 if possible, or infer from D. Assuming D=3 here.</span>
</span></span><span style="display:flex;"><span>    B_diag <span style="color:#f92672">=</span> tf<span style="color:#f92672">.</span>stack([<span style="color:#ae81ff">1.0</span>, <span style="color:#ae81ff">1.0</span>, tf<span style="color:#f92672">.</span>sign(d)])
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># 3. Scale columns of V via broadcasting (V * B_diag), then multiply by U^T</span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># V is DxD, B_diag is D. V * B_diag[None, :] multiplies each column j by B_diag[j]</span>
</span></span><span style="display:flex;"><span>    R <span style="color:#f92672">=</span> tf<span style="color:#f92672">.</span>matmul(V <span style="color:#f92672">*</span> B_diag[<span style="color:#66d9ef">None</span>, :], tf<span style="color:#f92672">.</span>transpose(U))
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Optimal translation (depends on R, so computed after it)</span>
</span></span><span style="display:flex;"><span>    t <span style="color:#f92672">=</span> centroid_Q <span style="color:#f92672">-</span> tf<span style="color:#f92672">.</span>linalg<span style="color:#f92672">.</span>matvec(R, centroid_P)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># RMSD</span>
</span></span><span style="display:flex;"><span>    rmsd <span style="color:#f92672">=</span> tf<span style="color:#f92672">.</span>sqrt(tf<span style="color:#f92672">.</span>reduce_sum(tf<span style="color:#f92672">.</span>square(tf<span style="color:#f92672">.</span>matmul(p, tf<span style="color:#f92672">.</span>transpose(R)) <span style="color:#f92672">-</span> q)) <span style="color:#f92672">/</span> P<span style="color:#f92672">.</span>shape[<span style="color:#ae81ff">0</span>])
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> R, t, rmsd
</span></span></code></pre></div><p>and a batched version:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">kabsch_tensorflow_batched</span>(P, Q):
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    Computes the optimal rotation and translation to align two sets of points (P -&gt; Q),
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    and their RMSD.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    :param P: A BxNx3 tensor of points
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    :param Q: A BxNx3 tensor of points
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    :return: A tuple containing the optimal rotation matrix, the optimal
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">             translation vector, and the RMSD.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    &#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>    P <span style="color:#f92672">=</span> tf<span style="color:#f92672">.</span>convert_to_tensor(P, dtype<span style="color:#f92672">=</span>tf<span style="color:#f92672">.</span>float32)
</span></span><span style="display:flex;"><span>    Q <span style="color:#f92672">=</span> tf<span style="color:#f92672">.</span>convert_to_tensor(Q, dtype<span style="color:#f92672">=</span>tf<span style="color:#f92672">.</span>float32)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">assert</span> P<span style="color:#f92672">.</span>shape <span style="color:#f92672">==</span> Q<span style="color:#f92672">.</span>shape, <span style="color:#e6db74">&#34;Matrix dimensions must match&#34;</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Compute centroids</span>
</span></span><span style="display:flex;"><span>    centroid_P <span style="color:#f92672">=</span> tf<span style="color:#f92672">.</span>reduce_mean(P, axis<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>, keepdims<span style="color:#f92672">=</span><span style="color:#66d9ef">True</span>)
</span></span><span style="display:flex;"><span>    centroid_Q <span style="color:#f92672">=</span> tf<span style="color:#f92672">.</span>reduce_mean(Q, axis<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>, keepdims<span style="color:#f92672">=</span><span style="color:#66d9ef">True</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Center the points</span>
</span></span><span style="display:flex;"><span>    p <span style="color:#f92672">=</span> P <span style="color:#f92672">-</span> centroid_P
</span></span><span style="display:flex;"><span>    q <span style="color:#f92672">=</span> Q <span style="color:#f92672">-</span> centroid_Q
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Compute the covariance matrix</span>
</span></span><span style="display:flex;"><span>    H <span style="color:#f92672">=</span> tf<span style="color:#f92672">.</span>matmul(tf<span style="color:#f92672">.</span>transpose(p, perm<span style="color:#f92672">=</span>[<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">2</span>, <span style="color:#ae81ff">1</span>]), q)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># SVD</span>
</span></span><span style="display:flex;"><span>    S, U, V <span style="color:#f92672">=</span> tf<span style="color:#f92672">.</span>linalg<span style="color:#f92672">.</span>svd(H)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># 1. Calculate batched determinant</span>
</span></span><span style="display:flex;"><span>    d <span style="color:#f92672">=</span> tf<span style="color:#f92672">.</span>linalg<span style="color:#f92672">.</span>det(tf<span style="color:#f92672">.</span>matmul(V, tf<span style="color:#f92672">.</span>transpose(U, perm<span style="color:#f92672">=</span>[<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">2</span>, <span style="color:#ae81ff">1</span>])))
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># 2. Build batched B_diag: shape (B, 3)</span>
</span></span><span style="display:flex;"><span>    ones <span style="color:#f92672">=</span> tf<span style="color:#f92672">.</span>ones_like(d)
</span></span><span style="display:flex;"><span>    B_diag <span style="color:#f92672">=</span> tf<span style="color:#f92672">.</span>stack([ones, ones, tf<span style="color:#f92672">.</span>sign(d)], axis<span style="color:#f92672">=-</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># 3. Scale columns of V (Broadcasting adds the middle dimension)</span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># V: (B, 3, 3), B_diag: (B, 3) -&gt; B_diag[:, None, :]: (B, 1, 3)</span>
</span></span><span style="display:flex;"><span>    R <span style="color:#f92672">=</span> tf<span style="color:#f92672">.</span>matmul(V <span style="color:#f92672">*</span> B_diag[:, <span style="color:#66d9ef">None</span>, :], tf<span style="color:#f92672">.</span>transpose(U, perm<span style="color:#f92672">=</span>[<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">2</span>, <span style="color:#ae81ff">1</span>]))
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Optimal translation (depends on R, so computed after it)</span>
</span></span><span style="display:flex;"><span>    t <span style="color:#f92672">=</span> tf<span style="color:#f92672">.</span>squeeze(centroid_Q, axis<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>) <span style="color:#f92672">-</span> tf<span style="color:#f92672">.</span>linalg<span style="color:#f92672">.</span>matvec(R, tf<span style="color:#f92672">.</span>squeeze(centroid_P, axis<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>))  <span style="color:#75715e"># Bx3</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># RMSD</span>
</span></span><span style="display:flex;"><span>    rmsd <span style="color:#f92672">=</span> tf<span style="color:#f92672">.</span>sqrt(tf<span style="color:#f92672">.</span>reduce_sum(tf<span style="color:#f92672">.</span>square(tf<span style="color:#f92672">.</span>matmul(p, tf<span style="color:#f92672">.</span>transpose(R, perm<span style="color:#f92672">=</span>[<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">2</span>, <span style="color:#ae81ff">1</span>])) <span style="color:#f92672">-</span> q), axis<span style="color:#f92672">=</span>(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">2</span>)) <span style="color:#f92672">/</span> P<span style="color:#f92672">.</span>shape[<span style="color:#ae81ff">1</span>])
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> R, t, rmsd
</span></span></code></pre></div><h3 id="jax">JAX</h3>
<p>The JAX implementation closely mirrors NumPy, replacing <code>np</code> with <code>jnp</code>. However, we again avoid <code>if</code> statements that branch on array values (which fail when traced under <code>jax.jit</code>) and in-place assignment (JAX arrays are immutable) by using the broadcasting B-matrix approach.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> jax.numpy <span style="color:#66d9ef">as</span> jnp
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">kabsch_jax</span>(P, Q):
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    Computes the optimal rotation and translation to align two sets of points (P -&gt; Q),
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    and their RMSD.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    :param P: A Nx3 matrix of points
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    :param Q: A Nx3 matrix of points
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    :return: A tuple containing the optimal rotation matrix, the optimal
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">             translation vector, and the RMSD.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    &#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>    P <span style="color:#f92672">=</span> jnp<span style="color:#f92672">.</span>array(P)
</span></span><span style="display:flex;"><span>    Q <span style="color:#f92672">=</span> jnp<span style="color:#f92672">.</span>array(Q)
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">assert</span> P<span style="color:#f92672">.</span>shape <span style="color:#f92672">==</span> Q<span style="color:#f92672">.</span>shape, <span style="color:#e6db74">&#34;Matrix dimensions must match&#34;</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Compute centroids</span>
</span></span><span style="display:flex;"><span>    centroid_P <span style="color:#f92672">=</span> jnp<span style="color:#f92672">.</span>mean(P, axis<span style="color:#f92672">=</span><span style="color:#ae81ff">0</span>)
</span></span><span style="display:flex;"><span>    centroid_Q <span style="color:#f92672">=</span> jnp<span style="color:#f92672">.</span>mean(Q, axis<span style="color:#f92672">=</span><span style="color:#ae81ff">0</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Center the points</span>
</span></span><span style="display:flex;"><span>    p <span style="color:#f92672">=</span> P <span style="color:#f92672">-</span> centroid_P
</span></span><span style="display:flex;"><span>    q <span style="color:#f92672">=</span> Q <span style="color:#f92672">-</span> centroid_Q
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Compute the covariance matrix</span>
</span></span><span style="display:flex;"><span>    H <span style="color:#f92672">=</span> jnp<span style="color:#f92672">.</span>dot(p<span style="color:#f92672">.</span>T, q)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># SVD</span>
</span></span><span style="display:flex;"><span>    U, S, Vt <span style="color:#f92672">=</span> jnp<span style="color:#f92672">.</span>linalg<span style="color:#f92672">.</span>svd(H)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># 1. Calculate determinant</span>
</span></span><span style="display:flex;"><span>    d <span style="color:#f92672">=</span> jnp<span style="color:#f92672">.</span>linalg<span style="color:#f92672">.</span>det(jnp<span style="color:#f92672">.</span>dot(Vt<span style="color:#f92672">.</span>T, U<span style="color:#f92672">.</span>T))
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># 2. Build diagonal B array</span>
</span></span><span style="display:flex;"><span>    B_diag <span style="color:#f92672">=</span> jnp<span style="color:#f92672">.</span>array([<span style="color:#ae81ff">1.0</span>, <span style="color:#ae81ff">1.0</span>, jnp<span style="color:#f92672">.</span>sign(d)])
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># 3. Scale columns of Vt.T and multiply by U.T</span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Vt.T is V.</span>
</span></span><span style="display:flex;"><span>    R <span style="color:#f92672">=</span> jnp<span style="color:#f92672">.</span>dot(Vt<span style="color:#f92672">.</span>T <span style="color:#f92672">*</span> B_diag[<span style="color:#66d9ef">None</span>, :], U<span style="color:#f92672">.</span>T)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Optimal translation (depends on R, so computed after it)</span>
</span></span><span style="display:flex;"><span>    t <span style="color:#f92672">=</span> centroid_Q <span style="color:#f92672">-</span> jnp<span style="color:#f92672">.</span>dot(R, centroid_P)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># RMSD</span>
</span></span><span style="display:flex;"><span>    rmsd <span style="color:#f92672">=</span> jnp<span style="color:#f92672">.</span>sqrt(jnp<span style="color:#f92672">.</span>sum(jnp<span style="color:#f92672">.</span>square(jnp<span style="color:#f92672">.</span>dot(p, R<span style="color:#f92672">.</span>T) <span style="color:#f92672">-</span> q)) <span style="color:#f92672">/</span> P<span style="color:#f92672">.</span>shape[<span style="color:#ae81ff">0</span>])
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> R, t, rmsd
</span></span></code></pre></div><p>and batched:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">kabsch_jax_batched</span>(P, Q):
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    Computes the optimal rotation and translation to align two sets of points (P -&gt; Q),
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    and their RMSD.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    :param P: A BxNx3 matrix of points
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    :param Q: A BxNx3 matrix of points
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    :return: A tuple containing the optimal rotation matrix, the optimal
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">             translation vector, and the RMSD.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    &#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>    P <span style="color:#f92672">=</span> jnp<span style="color:#f92672">.</span>array(P)
</span></span><span style="display:flex;"><span>    Q <span style="color:#f92672">=</span> jnp<span style="color:#f92672">.</span>array(Q)
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">assert</span> P<span style="color:#f92672">.</span>shape <span style="color:#f92672">==</span> Q<span style="color:#f92672">.</span>shape, <span style="color:#e6db74">&#34;Matrix dimensions must match&#34;</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Compute centroids</span>
</span></span><span style="display:flex;"><span>    centroid_P <span style="color:#f92672">=</span> jnp<span style="color:#f92672">.</span>mean(P, axis<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>, keepdims<span style="color:#f92672">=</span><span style="color:#66d9ef">True</span>)  <span style="color:#75715e"># Bx1x3</span>
</span></span><span style="display:flex;"><span>    centroid_Q <span style="color:#f92672">=</span> jnp<span style="color:#f92672">.</span>mean(Q, axis<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>, keepdims<span style="color:#f92672">=</span><span style="color:#66d9ef">True</span>)  <span style="color:#75715e"># Bx1x3</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Center the points</span>
</span></span><span style="display:flex;"><span>    p <span style="color:#f92672">=</span> P <span style="color:#f92672">-</span> centroid_P  <span style="color:#75715e"># BxNx3</span>
</span></span><span style="display:flex;"><span>    q <span style="color:#f92672">=</span> Q <span style="color:#f92672">-</span> centroid_Q  <span style="color:#75715e"># BxNx3</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Compute the covariance matrix</span>
</span></span><span style="display:flex;"><span>    H <span style="color:#f92672">=</span> jnp<span style="color:#f92672">.</span>matmul(p<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">2</span>, <span style="color:#ae81ff">1</span>), q)  <span style="color:#75715e"># Bx3x3</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># SVD</span>
</span></span><span style="display:flex;"><span>    U, S, Vt <span style="color:#f92672">=</span> jnp<span style="color:#f92672">.</span>linalg<span style="color:#f92672">.</span>svd(H)  <span style="color:#75715e"># Bx3x3</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># 1. Calculate batched determinant</span>
</span></span><span style="display:flex;"><span>    d <span style="color:#f92672">=</span> jnp<span style="color:#f92672">.</span>linalg<span style="color:#f92672">.</span>det(jnp<span style="color:#f92672">.</span>matmul(Vt<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">2</span>, <span style="color:#ae81ff">1</span>), U<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">2</span>, <span style="color:#ae81ff">1</span>)))
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># 2. Build batched B_diag</span>
</span></span><span style="display:flex;"><span>    ones <span style="color:#f92672">=</span> jnp<span style="color:#f92672">.</span>ones_like(d)
</span></span><span style="display:flex;"><span>    B_diag <span style="color:#f92672">=</span> jnp<span style="color:#f92672">.</span>stack([ones, ones, jnp<span style="color:#f92672">.</span>sign(d)], axis<span style="color:#f92672">=-</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># 3. Scale columns of Vt.T and multiply by U.T</span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Vt.T: (B, 3, 3). B_diag: (B, 3).</span>
</span></span><span style="display:flex;"><span>    R <span style="color:#f92672">=</span> jnp<span style="color:#f92672">.</span>matmul(Vt<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">2</span>, <span style="color:#ae81ff">1</span>) <span style="color:#f92672">*</span> B_diag[:, <span style="color:#66d9ef">None</span>, :], U<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">2</span>, <span style="color:#ae81ff">1</span>))
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Optimal translation (depends on R, so computed after it)</span>
</span></span><span style="display:flex;"><span>    t <span style="color:#f92672">=</span> centroid_Q<span style="color:#f92672">.</span>squeeze(<span style="color:#ae81ff">1</span>) <span style="color:#f92672">-</span> jnp<span style="color:#f92672">.</span>matmul(centroid_P, R<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">2</span>, <span style="color:#ae81ff">1</span>))<span style="color:#f92672">.</span>squeeze(<span style="color:#ae81ff">1</span>)  <span style="color:#75715e"># Bx3</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># RMSD</span>
</span></span><span style="display:flex;"><span>    rmsd <span style="color:#f92672">=</span> jnp<span style="color:#f92672">.</span>sqrt(jnp<span style="color:#f92672">.</span>sum(jnp<span style="color:#f92672">.</span>square(jnp<span style="color:#f92672">.</span>matmul(p, R<span style="color:#f92672">.</span>transpose(<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">2</span>, <span style="color:#ae81ff">1</span>)) <span style="color:#f92672">-</span> q), axis<span style="color:#f92672">=</span>(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">2</span>)) <span style="color:#f92672">/</span> P<span style="color:#f92672">.</span>shape[<span style="color:#ae81ff">1</span>])
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> R, t, rmsd
</span></span></code></pre></div>
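<p>A quick sanity check can catch sign or transpose mistakes in any of the implementations above. The sketch below is written in plain NumPy so it runs without a deep learning framework. <code>kabsch_reference</code> is a hypothetical helper that repeats the same steps as the functions above; the random seed, point count, and test transform are arbitrary assumptions. It checks that a known rigid motion is recovered and that a mirrored target still produces a proper rotation.</p>

```python
import numpy as np


def kabsch_reference(P, Q):
    """NumPy reference (hypothetical helper): R, t, rmsd with Q ~ P @ R.T + t."""
    centroid_P = P.mean(axis=0)
    centroid_Q = Q.mean(axis=0)
    p = P - centroid_P
    q = Q - centroid_Q
    U, S, Vt = np.linalg.svd(p.T @ q)           # covariance H = p^T q
    d = np.linalg.det(Vt.T @ U.T)               # -1 signals a reflection
    B_diag = np.array([1.0, 1.0, np.sign(d)])
    R = (Vt.T * B_diag[None, :]) @ U.T          # R = V B U^T
    t = centroid_Q - R @ centroid_P
    rmsd = np.sqrt(np.sum((p @ R.T - q) ** 2) / P.shape[0])
    return R, t, rmsd


rng = np.random.default_rng(0)
P = rng.normal(size=(50, 3))

# Random proper rotation: orthogonalize a Gaussian matrix, then flip one
# column if needed so that det(R_true) = +1.
R_true, _ = np.linalg.qr(rng.normal(size=(3, 3)))
R_true[:, 0] *= np.sign(np.linalg.det(R_true))
t_true = np.array([1.0, -2.0, 0.5])
Q = P @ R_true.T + t_true

R, t, rmsd = kabsch_reference(P, Q)
recovers_motion = np.allclose(R, R_true) and np.allclose(t, t_true)

# Mirror the target: the best orthogonal map is now a reflection, but the
# B_diag correction should still return a proper rotation (det = +1).
Q_mirror = Q * np.array([1.0, 1.0, -1.0])
R_m, t_m, rmsd_m = kabsch_reference(P, Q_mirror)
det_mirror = np.linalg.det(R_m)

print(recovers_motion, rmsd, det_mirror, rmsd_m)
```

<p>The same two checks (recovery of a known transform, and a determinant of +1 on mirrored input) can be applied to the TensorFlow and JAX functions by converting their outputs to NumPy arrays.</p>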














<figure class="post-figure center ">
    <img src="/img/scientific-computing/kabsch-animated-protein-conformational-alignment-analysis.webp"
         alt="Animation of a protein structure being aligned using the Kabsch algorithm"
         title="Animation of a protein structure being aligned using the Kabsch algorithm"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Real-world application: Aligning protein conformations to analyze structural changes.</figcaption>
    
</figure>

<h2 id="extensions">Extensions</h2>
<p>The Kabsch algorithm has several important extensions that go beyond the formulation covered here:</p>
<ul>
<li><strong>Quaternion Form</strong>: The rotation can instead be solved for as a unit quaternion (Horn 1987), which always yields a proper rotation without a separate reflection correction and can offer better numerical stability in applications requiring high precision.</li>
<li><strong>Iterative Versions</strong>: Variants that are more robust to noise and scale better to large point sets. They can also be advantageous when computational resources are limited.</li>
<li><strong>Weighted Kabsch</strong>: Extensions that incorporate point weights (e.g., atomic masses in molecular dynamics). While SciPy provides a <a href="https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.transform.Rotation.align_vectors.html#scipy.spatial.transform.Rotation.align_vectors">weighted version</a>, it lacks batch processing capabilities.</li>
<li><strong>The Umeyama Algorithm</strong>: If your point sets are rotated, translated, and scaled differently, the Umeyama algorithm is the direct extension of Kabsch. It solves the same optimization problem but introduces a scaling factor $c$, finding the optimal alignment for $Q \approx c R P + t$.</li>
</ul>
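<p>As an illustration of the weighted variant, here is a minimal NumPy sketch. It assumes non-negative per-point weights (for example, atomic masses) and swaps the plain means and covariance for their weighted counterparts. The function name <code>kabsch_weighted</code> and the test data are hypothetical, and this is not SciPy's implementation.</p>

```python
import numpy as np


def kabsch_weighted(P, Q, w):
    """Weighted Kabsch sketch: minimizes sum_i w_i * ||R p_i + t - q_i||^2."""
    P = np.asarray(P, dtype=float)
    Q = np.asarray(Q, dtype=float)
    w = np.asarray(w, dtype=float)
    w = w / w.sum()                          # normalize weights to sum to 1

    centroid_P = w @ P                       # weighted centroids
    centroid_Q = w @ Q
    p = P - centroid_P
    q = Q - centroid_Q

    H = p.T @ (w[:, None] * q)               # weighted covariance
    U, S, Vt = np.linalg.svd(H)
    d = np.linalg.det(Vt.T @ U.T)
    B_diag = np.array([1.0, 1.0, np.sign(d)])
    R = (Vt.T * B_diag[None, :]) @ U.T

    t = centroid_Q - R @ centroid_P
    rmsd = np.sqrt(np.sum(w[:, None] * (p @ R.T - q) ** 2))  # weighted RMSD
    return R, t, rmsd


theta = 0.7
R_true = np.array([[np.cos(theta), -np.sin(theta), 0.0],
                   [np.sin(theta), np.cos(theta), 0.0],
                   [0.0, 0.0, 1.0]])
t_true = np.array([0.3, -1.0, 2.0])

rng = np.random.default_rng(1)
P = rng.normal(size=(20, 3))
Q = P @ R_true.T + t_true
Q[0] += 5.0                                  # corrupt one correspondence

w = np.ones(20)
w[0] = 0.0                                   # give the corrupted point zero weight
R, t, rmsd = kabsch_weighted(P, Q, w)
print(np.allclose(R, R_true), np.allclose(t, t_true), rmsd)
```

<p>With uniform weights this reduces to the unweighted algorithm. Setting a weight to zero removes that point from both the centroids and the covariance, which is why the corrupted correspondence does not affect the fit here.</p>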
<p>Several of these extensions are implemented in the <a href="/projects/kabsch-horn-cookbook/">Kabsch-Horn Cookbook</a> library, which provides differentiable Kabsch, Horn, and Umeyama alignment across NumPy, PyTorch, JAX, TensorFlow, and MLX.</p>
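<p>To make the Umeyama extension concrete, the following NumPy sketch adds the scale factor $c$ to the same SVD pipeline, so that $Q \approx c R P + t$. It is an illustration under the conventions used in this post ($H = p^T q$, $R = V B U^T$), not the API of the library mentioned above. The function name <code>umeyama</code> and the test values are assumptions.</p>

```python
import numpy as np


def umeyama(P, Q):
    """Similarity alignment sketch: returns c, R, t with Q ~ c * P @ R.T + t."""
    P = np.asarray(P, dtype=float)
    Q = np.asarray(Q, dtype=float)

    centroid_P = P.mean(axis=0)
    centroid_Q = Q.mean(axis=0)
    p = P - centroid_P
    q = Q - centroid_Q

    H = p.T @ q
    U, S, Vt = np.linalg.svd(H)
    d = np.linalg.det(Vt.T @ U.T)
    B_diag = np.array([1.0, 1.0, np.sign(d)])
    R = (Vt.T * B_diag[None, :]) @ U.T

    # Optimal scale: trace(diag(S) B) divided by the total squared spread of p.
    c = np.sum(S * B_diag) / np.sum(p ** 2)
    t = centroid_Q - c * (R @ centroid_P)
    return c, R, t


theta = 1.1
R_true = np.array([[1.0, 0.0, 0.0],
                   [0.0, np.cos(theta), -np.sin(theta)],
                   [0.0, np.sin(theta), np.cos(theta)]])
t_true = np.array([-4.0, 0.5, 1.5])
c_true = 2.5

rng = np.random.default_rng(2)
P = rng.normal(size=(30, 3))
Q = c_true * (P @ R_true.T) + t_true

c, R, t = umeyama(P, Q)
print(c, np.allclose(R, R_true), np.allclose(t, t_true))
```

<p>The rotation is computed exactly as in Kabsch; only the scale and the translation change. Fixing $c = 1$ recovers the original algorithm.</p>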
<h2 id="further-reading">Further Reading</h2>
<ul>
<li><a href="https://en.wikipedia.org/wiki/Kabsch_algorithm">Wikipedia, Kabsch Algorithm</a></li>
<li><a href="https://zalo.github.io/blog/kabsch/">Zalo on Kabsch</a>: An interactive shape matching demo.</li>
</ul>
<h3 id="original-papers">Original Papers</h3>
<ul>
<li><strong>[Kabsch 1976]</strong> Kabsch, W. (1976). &ldquo;A solution for the best rotation to relate two sets of vectors.&rdquo; <em>Acta Crystallographica Section A</em>, 32(5), 922-923. <a href="https://doi.org/10.1107/S0567739476001873">DOI: 10.1107/S0567739476001873</a>
<em>The original paper: a closed-form, non-iterative optimal-rotation solution derived via Lagrange multipliers and eigendecomposition of $\tilde{R}R$ (the SVD reformulation came later; see Arun et al. 1987).</em> See also: <a href="/notes/biology/computational-biology/kabsch-algorithm/">paper notes</a>.</li>
<li><strong>[Kabsch 1978]</strong> Kabsch, W. (1978). &ldquo;A discussion of the solution for the best rotation to relate two sets of vectors.&rdquo; <em>Acta Crystallographica Section A</em>, 34(5), 827-828. <a href="https://doi.org/10.1107/S0567739478001680">DOI: 10.1107/S0567739478001680</a>
<em>The follow-up paper correcting for improper rotations (reflections).</em></li>
<li><strong>[Arun et al. 1987]</strong> Arun, K. S., Huang, T. S., &amp; Blostein, S. D. (1987). &ldquo;Least-Squares Fitting of Two 3-D Point Sets.&rdquo; <em>IEEE Transactions on Pattern Analysis and Machine Intelligence</em>, PAMI-9(5), 698-700. <a href="https://doi.org/10.1109/TPAMI.1987.4767965">DOI: 10.1109/TPAMI.1987.4767965</a>
<em>The first SVD-based formulation for 3D point set alignment.</em> See also: <a href="/notes/biology/computational-biology/arun-svd-point-fitting/">paper notes</a>.</li>
<li><strong>[Horn 1987]</strong> Horn, B. K. P. (1987). &ldquo;Closed-form solution of absolute orientation using unit quaternions.&rdquo; <em>Journal of the Optical Society of America A</em>, 4(4), 629-642. <a href="https://doi.org/10.1364/JOSAA.4.000629">DOI: 10.1364/JOSAA.4.000629</a>
<em>An alternative quaternion-based closed-form solution that also handles scale.</em> See also: <a href="/notes/biology/computational-biology/horn-absolute-orientation/">paper notes</a>.</li>
<li><strong>[Horn et al. 1988]</strong> Horn, B. K. P., Hilden, H. M., &amp; Negahdaripour, S. (1988). &ldquo;Closed-form solution of absolute orientation using orthonormal matrices.&rdquo; <em>Journal of the Optical Society of America A</em>, 5(7), 1127-1135. <a href="https://doi.org/10.1364/JOSAA.5.001127">DOI: 10.1364/JOSAA.5.001127</a>
<em>The matrix square root (polar decomposition) approach to the same problem.</em> See also: <a href="/notes/biology/computational-biology/horn-orthonormal-matrices/">paper notes</a>.</li>
<li><strong>[Umeyama 1991]</strong> Umeyama, S. (1991). &ldquo;Least-squares estimation of transformation parameters between two point patterns.&rdquo; <em>IEEE Transactions on Pattern Analysis and Machine Intelligence</em>, 13(4), 376-380. <a href="https://doi.org/10.1109/34.88573">DOI: 10.1109/34.88573</a>
<em>The extension of the algorithm to include optimal scaling in addition to rotation and translation.</em> See also: <a href="/notes/biology/computational-biology/umeyama-similarity-transformation/">paper notes</a>.</li>
</ul>
]]></content:encoded></item><item><title>IQCRNN: Certified Stability for Neural Networks</title><link>https://hunterheidenreich.com/projects/iqcrnn-pytorch/</link><pubDate>Wed, 11 May 2022 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/projects/iqcrnn-pytorch/</guid><description>PyTorch IQCRNN enforcing stability guarantees on RNNs via Integral Quadratic Constraints and semidefinite programming.</description><content:encoded><![CDATA[<p>This project is a PyTorch re-implementation of <strong>IQCRNN</strong>, a method that enforces strict stability guarantees on Recurrent Neural Networks used in control systems.</p>
<h2 id="overview">Overview</h2>
<p>Standard Reinforcement Learning agents can behave unpredictably in unseen states. This approach forces the agent&rsquo;s weights to satisfy <strong>Integral Quadratic Constraints (IQC)</strong> via a projection step. Effectively, it solves a convex optimization problem (Semidefinite Program) inside the gradient descent loop to ensure the controller never violates Lyapunov stability criteria.</p>
<p>The method bridges classic <strong>Robust Control Theory</strong> (1990s) with <strong>Deep Reinforcement Learning</strong> (2020s), providing mathematical certificates of safety for neural network controllers.</p>
<h2 id="features">Features</h2>
<ul>
<li><strong>Hybrid Optimization:</strong> Interleaved standard Gradient Descent (PyTorch) with Convex Optimization (<code>cvxpy</code> + <code>MOSEK</code>) to project weights onto the &ldquo;safe&rdquo; manifold after each training step.</li>
<li><strong>Complex Constraints:</strong> Implemented the &ldquo;Tilde&rdquo; parametrization from the original paper to convexify the non-convex stability conditions of the RNN dynamics, transforming an intractable problem into a solvable Linear Matrix Inequality (LMI).</li>
<li><strong>Safety-Critical Domains:</strong> Applied the controller across six control systems (cartpole, inverted pendulum, nonlinear pendulum, pendubot, power grid, and vehicle dynamics), including unstable plants where &ldquo;crashing&rdquo; during training is unacceptable.</li>
</ul>
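<p>The interleaving of gradient steps and projections can be sketched in a few lines. This is a toy illustration only: the projection below is a closed-form clip onto a norm ball, standing in for the SDP-based projection that the project solves with <code>cvxpy</code> and <code>MOSEK</code>. The names <code>project_to_ball</code> and <code>projected_gradient_descent</code>, the quadratic loss, and the radius are all hypothetical.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import math

def project_to_ball(w, radius):
    """Euclidean projection onto a norm ball (a stand-in for the SDP projection)."""
    norm = math.sqrt(sum(x * x for x in w))
    if norm > radius:
        return [x * radius / norm for x in w]
    return list(w)

def projected_gradient_descent(grad, w0, radius, lr=0.1, steps=200):
    """Alternate a plain gradient step with a projection back onto the feasible set."""
    w = project_to_ball(w0, radius)
    for _ in range(steps):
        g = grad(w)
        w = [x - lr * gx for x, gx in zip(w, g)]  # unconstrained update
        w = project_to_ball(w, radius)            # restore feasibility
    return w

# Toy loss: squared distance to a target that lies outside the feasible ball.
target = [3.0, 4.0]
grad = lambda w: [2.0 * (x - t) for x, t in zip(w, target)]
w_star = projected_gradient_descent(grad, [0.0, 0.0], radius=1.0)
# w_star settles near [0.6, 0.8], the feasible point closest to the target.
</code></pre></div>
<p>In the actual method the projection is itself an optimization problem, which is where the extra training cost comes from.</p>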
<h2 id="usage">Usage</h2>
<p>The repository includes training scripts for the inverted pendulum and power grid environments, demonstrating the stability guarantees in practice.</p>
<h2 id="results">Results</h2>
<p>This project was a deep dive into the tension between <strong>Safety</strong> and <strong>Speed</strong>.</p>
<ul>
<li><strong>The Bottleneck:</strong> Solving an SDP every few training steps is computationally expensive (interior-point SDP solvers scale steeply, roughly $O(n^6)$ in the matrix dimension). The projection provided mathematical certificates of safety, but it also showed why these methods haven&rsquo;t yet overtaken standard PPO/SAC in production: the &ldquo;safety tax&rdquo; on training time is steep.</li>
<li><strong>The Lesson:</strong> It taught me that &ldquo;theoretical guarantees&rdquo; often come with &ldquo;engineering fine print.&rdquo; If I were to redo this today, I would look into <strong>differentiable convex optimization layers</strong> (like <code>cvxpylayers</code>) to make the projection end-to-end differentiable.</li>
<li><strong>The &ldquo;Rough Edges&rdquo;:</strong> The codebase has artifacts of its research origins (e.g., the <code>reqs.txt</code> dependency dump). Reading a dense control theory paper (Gu et al., 2021) and implementing the math correctly was the primary focus.</li>
</ul>
<h2 id="citation">Citation</h2>
<p>Credit to the original authors:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@misc</span>{gu2021recurrentneuralnetworkcontrollers,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Recurrent Neural Network Controllers Synthesis with Stability Guarantees for Partially Observed Systems}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Fangda Gu and He Yin and Laurent El Ghaoui and Murat Arcak and Peter Seiler and Ming Jin}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2021}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">eprint</span>=<span style="color:#e6db74">{2109.03861}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">archivePrefix</span>=<span style="color:#e6db74">{arXiv}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">primaryClass</span>=<span style="color:#e6db74">{eess.SY}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">url</span>=<span style="color:#e6db74">{https://arxiv.org/abs/2109.03861}</span>,
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><h2 id="related-work">Related Work</h2>
<ul>
<li><a href="/research/deconstructing-recurrence-attention-gating/">Deconstructing Recurrence and Attention Gating</a>: research on recurrent network architectures, providing context for why stability guarantees on RNNs matter</li>
</ul>
]]></content:encoded></item><item><title>EigenNoise: Data-Free Word Vector Initialization</title><link>https://hunterheidenreich.com/research/eigennoise-contrastive-prior/</link><pubDate>Sun, 01 May 2022 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/research/eigennoise-contrastive-prior/</guid><description>Investigation into EigenNoise, a data-free initialization scheme for word vectors that approaches pre-trained model performance after fine-tuning.</description><content:encoded><![CDATA[<h2 id="abstract">Abstract</h2>
<p>We developed EigenNoise, a method to initialize word vectors using <strong>zero pre-training data</strong>. By deriving a co-occurrence matrix solely from the theoretical harmonic structure of language (Zipf&rsquo;s Law), this project demonstrates that we can mathematically synthesize a &ldquo;warm-start&rdquo; for NLP models. This approach challenges the reliance on massive corpora for initialization and offers a competitive alternative for low-resource environments.</p>
<h2 id="key-contributions">Key Contributions</h2>
<ul>
<li>A <strong>data-free initialization scheme</strong>: word vectors derived from a co-occurrence matrix synthesized from independent (Zipfian) frequency statistics, with no pre-training corpus.</li>
<li>Grounds the construction in the <strong>harmonic statistical structure</strong> of language, so the representation follows from first principles rather than from data.</li>
<li>Evaluates with <strong>Minimum Description Length (MDL)</strong> probing, which measures how much task-relevant information a representation encodes and how compactly, rather than raw accuracy.</li>
<li>After fine-tuning, EigenNoise <strong>approaches</strong> the performance of GloVe (trained on Gigaword) despite seeing <strong>no pre-training text</strong>.</li>
</ul>
<h2 id="technical-implementation">Technical Implementation</h2>
<p>The core insight is that &ldquo;noise&rdquo; in language follows a predictable distribution.</p>
<ol>
<li><strong>Modeling</strong>: We model the &ldquo;null hypothesis&rdquo; of text: how words would co-occur if they were statistically independent but still followed a Zipfian rank-frequency distribution. This yields a theoretical co-occurrence matrix $\hat{X}$:</li>
</ol>
<p>$$\hat{X}_{ij} = \frac{2mN}{r_i r_j H_N}$$</p>
<p>Where $r_i$ is the rank of word $i$, $N$ is vocabulary size, $m$ is the context window size, and $H_N$ is the $N$-th harmonic number.</p>
<ol start="2">
<li>
<p><strong>Factorization</strong>: We then solve for the word vectors by performing an <strong>eigen-decomposition</strong> on this matrix, extracting the top $d$ components to form the representation space.</p>
</li>
<li>
<p><strong>Probing</strong>: We validate the resulting vectors with MDL probing on the CoNLL-2003 and TweetEval benchmarks.</p>
</li>
</ol>
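<p>As a rough illustration of the modeling step, the matrix $\hat{X}$ can be built directly from word ranks. This is a minimal pure-Python sketch of the formula above, not the paper&rsquo;s implementation; the function names <code>harmonic_number</code> and <code>cooccurrence_prior</code> and the toy values $N = 4$, $m = 2$ are assumptions chosen for readability.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">def harmonic_number(n):
    """H_n = 1 + 1/2 + ... + 1/n."""
    return sum(1.0 / k for k in range(1, n + 1))

def cooccurrence_prior(vocab_size, window):
    """Build X_hat[i][j] = 2 m N / (r_i r_j H_N), with ranks r = 1..N."""
    h_n = harmonic_number(vocab_size)
    scale = 2.0 * window * vocab_size / h_n
    ranks = range(1, vocab_size + 1)
    return [[scale / (r_i * r_j) for r_j in ranks] for r_i in ranks]

x_hat = cooccurrence_prior(vocab_size=4, window=2)
# The matrix is symmetric, and its largest entry pairs the top-ranked word with itself.
</code></pre></div>
<p>For the factorization step, a symmetric eigen-solver such as <code>numpy.linalg.eigh</code> could then be applied to the resulting matrix.</p>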
<h2 id="why-this-matters">Why This Matters</h2>
<p>This research explores how much structure can emerge from frequency statistics alone, with no text exposure at all. The central finding is that EigenNoise vectors, derived purely from Zipf&rsquo;s Law, approach GloVe&rsquo;s performance after fine-tuning. This is evidence that a significant portion of what we call &ldquo;learned linguistic knowledge&rdquo; may be a consequence of word frequency distributions rather than semantic exposure to real text.</p>
<p>In 2026, small pretrained models are freely available and handle most low-resource initialization needs, so the practical case for data-free initialization is narrower than it was in 2022. The theoretical contribution remains relevant: EigenNoise establishes a clean null hypothesis for what word vectors look like when only frequency information is present. For interpretability researchers trying to disentangle frequency artifacts from genuine semantic content, this baseline has value independent of the initialization use case.</p>
<p>The <strong>MDL probing</strong> methodology applied here also contributes beyond the main result. Unlike task accuracy, MDL measures how much information a representation encodes and how compactly, providing a more principled lens for evaluating representational quality. EigenNoise&rsquo;s co-occurrence prior is grounded directly in the <strong>Independent Frequencies Model (IFM)</strong> introduced in the companion <a href="/research/word-company-vicinity/">Word2Vec factorization paper</a>. Together, the two works form a coherent theoretical line: the IFM characterizes the frequency-driven baseline of embedding space, and EigenNoise operationalizes it as a practical, data-free initialization scheme.</p>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@misc</span>{heidenreich2022eigennoisecontrastivepriorwarmstart,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{EigenNoise: A Contrastive Prior to Warm-Start Representations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Hunter Scott Heidenreich and Jake Ryland Williams}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2022}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">eprint</span>=<span style="color:#e6db74">{2205.04376}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">archivePrefix</span>=<span style="color:#e6db74">{arXiv}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">primaryClass</span>=<span style="color:#e6db74">{cs.CL}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span>=<span style="color:#e6db74">{https://arxiv.org/abs/2205.04376}</span>,
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><h2 id="related-work">Related Work</h2>
<p>For the theoretical foundation underlying EigenNoise&rsquo;s null hypothesis, including the first analytical solution to Word2Vec&rsquo;s softmax objective, see <a href="/research/word-company-vicinity/">Analytical Solution to Word2Vec Softmax &amp; Bias Probing</a>.</p>
]]></content:encoded></item><item><title>5 Axes of Multi-Arm Bandit Problems: A Practical Guide</title><link>https://hunterheidenreich.com/posts/a-roadmap-to-multi-arm-bandit-algorithms/</link><pubDate>Tue, 10 Nov 2020 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/a-roadmap-to-multi-arm-bandit-algorithms/</guid><description>Explore 5 key dimensions of multi-arm bandit problems to help practitioners better navigate the exploration-exploitation tradeoff in ML applications.</description><content:encoded><![CDATA[<h2 id="what-is-a-multi-arm-bandit-problem">What is a Multi-Arm Bandit Problem?</h2>
<p>Multi-arm bandit problems are a fundamental class of sequential decision-making problems in machine learning. They&rsquo;re less complex than a full reinforcement learning problem, but they capture a lot of the essential challenges of learning from interaction.</p>
<p>The name comes from the analogy of a gambler facing multiple slot machines (or &ldquo;one-armed bandits&rdquo;) and trying to figure out which one pays out the most over time. Do you keep playing the machine that&rsquo;s paid out well so far, or try others to see if they&rsquo;re better? This is the exploration-exploitation dilemma at the heart of bandit algorithms.</p>
<p>Bandit algorithms solve this by learning the reward distributions for each arm over time, balancing the need to explore new options with exploiting what they&rsquo;ve learned.</p>
<figure class="post-figure center ">
    <img src="/img/multi-arm-bandits/multi-arm-bandit-conceptual-graphic.webp"
         alt="Illustration of a vintage slot machine with multiple arms, representing the multi-arm bandit problem in machine learning where algorithms must balance exploration and exploitation"
         title="Illustration of a vintage slot machine with multiple arms, representing the multi-arm bandit problem in machine learning where algorithms must balance exploration and exploitation"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Multi-arm bandit algorithms balance exploration and exploitation</figcaption>
    
</figure>

<p>What I&rsquo;ve found helpful is thinking about any bandit problem along five key dimensions. Asking these five questions helps quickly identify which approaches work best for a given problem:</p>
<div style="display: flex; flex-direction: column; gap: 16px; max-width: 700px; margin: 30px auto; font-family: system-ui, -apple-system, sans-serif;">
  <div style="background-color: #5A9BD5; color: white; padding: 20px 24px; border-radius: 10px; box-shadow: 0 4px 6px rgba(0,0,0,0.05);">
    <h3 style="margin: 0 0 8px 0; font-size: 1.25rem; font-weight: 700; color: white;">1. Action Space</h3>
    <p style="margin: 0; font-size: 1.05rem; line-height: 1.5;"><em>What does the problem action space look like?</em> <br>Consider whether your options are finite vs. infinite, or single vs. combinatorial.</p>
  </div>
  <div style="background-color: #55C3BA; color: white; padding: 20px 24px; border-radius: 10px; box-shadow: 0 4px 6px rgba(0,0,0,0.05);">
    <h3 style="margin: 0 0 8px 0; font-size: 1.25rem; font-weight: 700; color: white;">2. Problem Structure</h3>
    <p style="margin: 0; font-size: 1.05rem; line-height: 1.5;"><em>Is there any structure to the problem?</em> <br>Determine whether choosing certain actions provides information about the expected rewards of other actions.</p>
  </div>
  <div style="background-color: #5CC382; color: white; padding: 20px 24px; border-radius: 10px; box-shadow: 0 4px 6px rgba(0,0,0,0.05);">
    <h3 style="margin: 0 0 8px 0; font-size: 1.25rem; font-weight: 700; color: white;">3. External Information</h3>
    <p style="margin: 0; font-size: 1.05rem; line-height: 1.5;"><em>Is there external information that my learner has access to?</em> <br>Assess the availability of contextual information (like user data or environment state) before an action is chosen.</p>
  </div>
  <div style="background-color: #56B14E; color: white; padding: 20px 24px; border-radius: 10px; box-shadow: 0 4px 6px rgba(0,0,0,0.05);">
    <h3 style="margin: 0 0 8px 0; font-size: 1.25rem; font-weight: 700; color: white;">4. Reward Mechanism</h3>
    <p style="margin: 0; font-size: 1.05rem; line-height: 1.5;"><em>How are rewards generated?</em> <br>Identify if the environment's payouts are stochastic (random but consistent), non-stationary (changing over time), or adversarial.</p>
  </div>
  <div style="background-color: #81B653; color: white; padding: 20px 24px; border-radius: 10px; box-shadow: 0 4px 6px rgba(0,0,0,0.05);">
    <h3 style="margin: 0 0 8px 0; font-size: 1.25rem; font-weight: 700; color: white;">5. Learner Feedback</h3>
    <p style="margin: 0; font-size: 1.05rem; line-height: 1.5;"><em>What kind of feedback does my learner receive each round?</em> <br>Clarify if the feedback loop is strict bandit (only seeing the reward for the chosen arm), full, partial, or semi-bandit feedback.</p>
  </div>
</div>
<p>Jump to <strong><a href="#real-examples">Real Examples</a></strong> to see these dimensions in practice.</p>
<hr>
<h2 id="1-what-can-you-do-action-space">1. What Can You Do? (Action Space)</h2>
<p>The first question is simple: what options do you have? Action spaces are generally categorized along two dimensions: size and complexity.</p>
<p><strong>Finite vs. Infinite (Size)</strong></p>
<ul>
<li><strong>Finite (Discretized):</strong> This is the scenario most people picture: the agent chooses among a small number of discrete options. Examples include selecting an arm in a standard 3-arm bandit or testing 5 different website layouts.</li>
<li><strong>Infinite (Continuous):</strong> Sometimes your action has continuous parameters. For example, selecting a bid price anywhere in the range of (0, 1), or adjusting a recommendation algorithm&rsquo;s temperature parameter. This requires different mathematical approaches.</li>
</ul>
<p><strong>Single vs. Combinatorial (Complexity)</strong></p>
<ul>
<li><strong>Single Action:</strong> Selection of exactly 1 action per round (e.g., showing a user <em>one</em> specific ad).</li>
<li><strong>Combinatorial Actions:</strong> Selection of several actions at once, represented as a vector. For example, selecting a subset of edges in a graph to form a path from node $s$ to node $t$. If you are picking 10 movies to populate a Netflix homepage at once, those selections might interact with one another.</li>
</ul>
<h2 id="2-do-your-choices-tell-you-about-other-choices-problem-structure">2. Do Your Choices Tell You About Other Choices? (Problem Structure)</h2>
<p>This is often overlooked but crucial: does trying option A teach you anything about option B?</p>
<ul>
<li><strong>Independent (Unstructured):</strong> Information gained from one action provides <em>zero insight</em> into the expected reward of other actions. For example, knowing the open rate of an email campaign tells you nothing about the click-through rate of a separate social media post.</li>
<li><strong>Correlated (Structured):</strong> Information gained from one action provides <em>valuable hints</em> about similar actions. If you test a 10% discount and see high conversions, that suggests a 15% discount will likely perform well too, helping you map the underlying demand curve.</li>
</ul>
<h2 id="3-what-extra-information-do-you-have-context">3. What Extra Information Do You Have? (Context)</h2>
<p>Real-world problems rarely happen in isolation. The question is: what additional predictive information might help you make better decisions <em>before</em> you pull the lever?</p>
<p>When you incorporate these state variables, you move from a Standard Bandit to a <strong>Contextual Bandit</strong>. This data usually falls into two buckets:</p>
<ul>
<li><strong>User Context:</strong> State information tied directly to the individual, such as age, location, or past purchase history.</li>
<li><strong>Environmental Context:</strong> State information tied to the surrounding conditions at the exact moment of decision, such as time of day, seasonality, or device type (mobile vs. desktop).</li>
</ul>
<h2 id="4-how-do-rewards-work-reward-mechanism">4. How Do Rewards Work? (Reward Mechanism)</h2>
<p>This dimension is about understanding the nature of the feedback you&rsquo;re getting. Are the rewards predictable, changing over time, or actively working against you?</p>
<blockquote>
<p><strong>Stochastic (Stable but Random)</strong>
Each action&rsquo;s rewards are drawn IID (independent and identically distributed) from a fixed distribution, so the underlying mean rewards do not shift over time.
<em>Example: Clinical trials. A drug has a true underlying effectiveness rate, but individual patients respond with random variation around that mean.</em></p></blockquote>
<blockquote>
<p><strong>Non-Stationary (Changing Over Time)</strong>
Reward distributions <em>do</em> shift over time, usually following some underlying rule. This is a realistic relaxation of the stochastic model, but comes at a learning cost.
<em>Example: Stock trading or ad performance. What worked last month might not work today.</em></p></blockquote>
<blockquote>
<p><strong>Adversarial (Actively Working Against You)</strong>
An adversary selects the worst-case rewards for your options, often with full knowledge of your learner&rsquo;s policy. Randomization is key to remaining unpredictable.
<em>Example: Cybersecurity defenses. Attackers actively adapt their strategies to exploit your algorithm.</em></p></blockquote>
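<p>The learning cost of non-stationarity shows up even with a single arm. The sketch below is a deterministic toy example with hypothetical function names and an assumed step size of 0.1. It compares a plain sample average with an exponentially discounted estimate when an arm&rsquo;s payout drops from 1 to 0 halfway through.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">def sample_average(rewards):
    """Weights every past reward equally; well suited to stationary rewards."""
    return sum(rewards) / len(rewards)

def discounted_estimate(rewards, step_size=0.1):
    """Exponentially weighted estimate that gradually forgets old rewards."""
    estimate = 0.0
    for r in rewards:
        estimate += step_size * (r - estimate)
    return estimate

# The arm pays 1.0 for 50 rounds, then its payout drops to 0.0.
rewards = [1.0] * 50 + [0.0] * 50

stale = sample_average(rewards)        # 0.5: still averaging in the old regime
fresh = discounted_estimate(rewards)   # close to 0: tracks the current regime
</code></pre></div>
<p>The discounted estimate adapts faster, but in a truly stationary problem it would be noisier than the sample average, which is the trade-off behind the learning cost mentioned above.</p>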
<h2 id="5-how-much-do-you-learn-from-each-decision-feedback">5. How Much Do You Learn From Each Decision? (Feedback)</h2>
<p>The last dimension is about information flow. How much do you learn each time you make a choice?</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Feedback Type</th>
          <th style="text-align: left">What You Learn</th>
          <th style="text-align: left">Real-World Example</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>Bandit</strong></td>
          <td style="text-align: left">You only observe the reward for the specific action selected. You have no knowledge of what could have been gained from other options.</td>
          <td style="text-align: left">A literal slot machine payout.</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Semi-Bandit</strong></td>
          <td style="text-align: left">Common in combinatorial settings. You see the individual rewards associated with each <em>sub-action</em> you took.</td>
          <td style="text-align: left">Learning exactly which specific edges of a graph &ldquo;dropped&rdquo; or succeeded in a path-finding algorithm.</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Full</strong></td>
          <td style="text-align: left">You see all reward signals for <em>every</em> action, including the ones you didn&rsquo;t take.</td>
          <td style="text-align: left">Analyzing historical stock market data where you can see all alternative outcomes.</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Partial Monitoring</strong></td>
          <td style="text-align: left">You observe a signal that is only indirectly related to the reward, and some rounds may reveal nothing useful.</td>
          <td style="text-align: left">Dynamic pricing: you see whether a customer bought at your posted price, not the most they would have paid. <em>(Note: this setting generalizes the bandit feedback loop and is usually studied separately.)</em></td>
      </tr>
  </tbody>
</table>
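<p>To make the distinction concrete, here is a small sketch of what the first three feedback models would reveal for one round of a combinatorial problem. The edge names, reward values, and the <code>observe</code> helper are hypothetical, and the round reward is assumed to be the sum of the chosen sub-action rewards.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">def observe(all_rewards, chosen, feedback):
    """Return what the learner sees after playing the sub-actions in `chosen`."""
    if feedback == "bandit":
        # One number: the total reward of the chosen action.
        return sum(all_rewards[a] for a in chosen)
    if feedback == "semi-bandit":
        # The reward of each chosen sub-action, but nothing about the others.
        return {a: all_rewards[a] for a in chosen}
    if feedback == "full":
        # Every reward, including those of sub-actions that were not chosen.
        return dict(all_rewards)
    raise ValueError(f"unknown feedback type: {feedback}")

# Hypothetical per-edge rewards for one round of a path-finding problem.
round_rewards = {"A-B": 1.0, "B-D": 0.0, "A-C": 1.0, "C-D": 1.0}
path = ["A-B", "B-D"]

bandit_view = observe(round_rewards, path, "bandit")      # 1.0
semi_view = observe(round_rewards, path, "semi-bandit")   # {"A-B": 1.0, "B-D": 0.0}
full_view = observe(round_rewards, path, "full")          # all four edges
</code></pre></div>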
<h2 id="real-examples">Real Examples</h2>
<p>If you are trying to figure out which bandit algorithm to use for your project, it helps to map out your problem first. Here is how these five dimensions play out in three common real-world scenarios:</p>
<h3 id="1-e-commerce-recommendations">1. E-commerce Recommendations</h3>
<figure class="post-figure center ">
    <img src="/img/multi-arm-bandits/multi-arm-bandit-for-ecommerce.webp"
         alt="Illustration of a multi-arm bandit algorithm applied to e-commerce recommendations, showing a user interacting with a carousel of product recommendations on a website"
         title="Illustration of a multi-arm bandit algorithm applied to e-commerce recommendations, showing a user interacting with a carousel of product recommendations on a website"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The algorithm must populate an entire &lsquo;Recommended for You&rsquo; carousel with multiple items simultaneously.</figcaption>
    
</figure>

<ul>
<li><strong>Action Space:</strong> Combinatorial <em>(Pick multiple products to show at once)</em></li>
<li><strong>Problem Structure:</strong> Structured <em>(Similar products perform similarly)</em></li>
<li><strong>Context:</strong> Available <em>(User history, time of day, device type)</em></li>
<li><strong>Reward Mechanism:</strong> Stochastic <em>(Relatively stable underlying user preferences)</em></li>
<li><strong>Feedback:</strong> Bandit <em>(You only see clicks on the specific items you showed)</em></li>
</ul>
<h3 id="2-online-ad-bidding">2. Online Ad Bidding</h3>
<figure class="post-figure center ">
    <img src="/img/multi-arm-bandits/multi-arm-bandit-for-ad-bidding.webp"
         alt="Illustration of a multi-arm bandit algorithm applied to online ad bidding, showing a marketer adjusting bid prices in real-time auctions for ad placements"
         title="Illustration of a multi-arm bandit algorithm applied to online ad bidding, showing a marketer adjusting bid prices in real-time auctions for ad placements"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The algorithm must learn the optimal price to bid in real-time auctions.</figcaption>
    
</figure>

<ul>
<li><strong>Action Space:</strong> Infinite / Continuous <em>(Bid any amount from 0.01 to 10.00)</em></li>
<li><strong>Problem Structure:</strong> Structured <em>(Similar bid prices yield similar win rates)</em></li>
<li><strong>Context:</strong> Available <em>(User demographics, search terms, ad relevance)</em></li>
<li><strong>Reward Mechanism:</strong> Non-stationary <em>(Market conditions and competitor budgets change constantly)</em></li>
<li><strong>Feedback:</strong> Bandit <em>(You only see the results from your winning bids)</em></li>
</ul>
<h3 id="3-content-personalization">3. Content Personalization</h3>
<p><em>A media site dynamically selects articles or videos based on trending topics and user habits to create a personalized homepage.</em></p>
<ul>
<li><strong>Action Space:</strong> Combinatorial <em>(Select a layout of multiple articles/videos)</em></li>
<li><strong>Problem Structure:</strong> Structured <em>(Content categories and tags have predictable patterns)</em></li>
<li><strong>Context:</strong> Available <em>(User profile, current geographic trends, browsing history)</em></li>
<li><strong>Reward Mechanism:</strong> Non-stationary <em>(User interests and news cycles evolve over time)</em></li>
<li><strong>Feedback:</strong> Semi-Bandit <em>(You see the performance of the individual content pieces within the selected layout)</em></li>
</ul>
<h2 id="the-algorithm-selection-cheat-sheet">The Algorithm Selection Cheat Sheet</h2>
<p>Here is how the five questions translate into algorithm choices:</p>
<p><strong>Step 1: What are my options? (Action Space)</strong></p>
<ul>
<li><strong>Few discrete choices?</strong> $\rightarrow$ Start with standard <strong>UCB</strong> (Upper Confidence Bound) or <strong>Thompson Sampling</strong>.</li>
<li><strong>Continuous parameters?</strong> $\rightarrow$ Look into <strong>Gaussian Process</strong> methods.</li>
</ul>
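<p>For the finite case, a minimal UCB1 sketch looks like the following. It assumes rewards in $[0, 1]$ and uses the common $\sqrt{2 \ln t / n_a}$ exploration bonus; the Bernoulli arms, their means, and the function name <code>ucb1</code> are illustrative assumptions.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import math
import random

def ucb1(pull, n_arms, horizon):
    """Run UCB1 and return how often each arm was played."""
    counts = [0] * n_arms
    means = [0.0] * n_arms
    for t in range(1, horizon + 1):
        if t > n_arms:
            # Optimism: pick the arm with the highest upper confidence bound.
            arm = max(
                range(n_arms),
                key=lambda a: means[a] + math.sqrt(2.0 * math.log(t) / counts[a]),
            )
        else:
            arm = t - 1  # play every arm once to initialize the estimates
        reward = pull(arm)
        counts[arm] += 1
        means[arm] += (reward - means[arm]) / counts[arm]  # incremental mean
    return counts

# Hypothetical Bernoulli arms with success probabilities 0.2, 0.5, and 0.8.
rng = random.Random(0)
true_means = [0.2, 0.5, 0.8]
counts = ucb1(lambda a: float(true_means[a] > rng.random()), 3, 2000)
</code></pre></div>
<p>Over a long enough horizon, most pulls should concentrate on the arm with the highest mean, while the bonus term keeps the other arms from being abandoned entirely.</p>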
<p><strong>Step 2: Do my choices relate to each other? (Structure)</strong></p>
<ul>
<li><strong>Yes?</strong> $\rightarrow$ Use <strong>Linear Bandits</strong> or kernel methods to share information across arms.</li>
<li><strong>No?</strong> $\rightarrow$ Treat each option completely independently.</li>
</ul>
<p><strong>Step 3: What extra information do I have? (Context)</strong></p>
<ul>
<li><strong>Rich context available?</strong> $\rightarrow$ Use contextual bandits (LinUCB is the standard choice here).</li>
<li><strong>No context?</strong> $\rightarrow$ Stick with standard, context-free approaches.</li>
</ul>
<p><strong>Step 4: How do rewards behave? (Mechanism)</strong></p>
<ul>
<li><strong>Stable?</strong> $\rightarrow$ <strong>UCB</strong> and <strong>Thompson Sampling</strong> are your go-to choices.</li>
<li><strong>Changing over time?</strong> $\rightarrow$ Add forgetting factors (discounting) or use sliding-window variants of standard algorithms.</li>
<li><strong>Actively working against me?</strong> $\rightarrow$ You need adversarial approaches, most notably the <strong>EXP3</strong> algorithm.</li>
</ul>
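<p>For the adversarial case, a compact EXP3 sketch is shown below. It follows the usual textbook form (exponential weights mixed with uniform exploration, plus importance-weighted reward estimates) and assumes rewards in $[0, 1]$. The exploration rate <code>gamma</code>, the seed, and the toy environment are assumptions, and the environment here is fixed, not truly adversarial.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import math
import random

def exp3(pull, n_arms, horizon, gamma=0.1, seed=0):
    """Run EXP3 and return the sampling distribution used in the last round."""
    rng = random.Random(seed)
    weights = [1.0] * n_arms
    probs = [1.0 / n_arms] * n_arms
    for _ in range(horizon):
        total = sum(weights)
        # Mix the weight-proportional distribution with uniform exploration.
        probs = [(1.0 - gamma) * w / total + gamma / n_arms for w in weights]
        arm = rng.choices(range(n_arms), weights=probs)[0]
        reward = pull(arm)
        # Importance weighting keeps the reward estimate unbiased under bandit feedback.
        estimate = reward / probs[arm]
        weights[arm] *= math.exp(gamma * estimate / n_arms)
        # Rescale so the weights stay bounded; the distribution is unchanged.
        top = max(weights)
        weights = [w / top for w in weights]
    return probs

# Hypothetical environment: arm 1 always pays 1.0, the others pay nothing.
final_probs = exp3(lambda arm: 1.0 if arm == 1 else 0.0, n_arms=3, horizon=500)
</code></pre></div>
<p>Because every arm keeps at least a <code>gamma / n_arms</code> chance of being played, the learner never becomes fully predictable, which is the point of randomizing against an adversary.</p>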
<p><strong>Step 5: How much do I learn each time? (Feedback)</strong></p>
<ul>
<li><strong>Just my choice?</strong> $\rightarrow$ Standard bandit algorithms.</li>
<li><strong>Everything?</strong> $\rightarrow$ Online gradient descent or multiplicative weights.</li>
<li><strong>Something in between?</strong> $\rightarrow$ Look for specialized algorithms that exploit your specific combinatorial structure.</li>
</ul>
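<p>The five steps above can also be written down as a small lookup. This is only a sketch of the cheat sheet as stated, with a hypothetical <code>BanditProblem</code> record and <code>suggest</code> function; it returns starting points to investigate, not a definitive recommendation.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">from dataclasses import dataclass

@dataclass
class BanditProblem:
    continuous_actions: bool   # Step 1: action space
    structured: bool           # Step 2: do arms share information?
    has_context: bool          # Step 3: side information before acting
    rewards: str               # Step 4: "stochastic", "non-stationary", or "adversarial"
    feedback: str              # Step 5: "bandit", "semi-bandit", or "full"

def suggest(problem):
    """Translate the five answers into the starting points listed in the cheat sheet."""
    hints = []
    hints.append("Gaussian Process methods" if problem.continuous_actions
                 else "UCB or Thompson Sampling")
    if problem.structured:
        hints.append("Linear bandits or kernel methods")
    if problem.has_context:
        hints.append("Contextual bandits (LinUCB)")
    if problem.rewards == "non-stationary":
        hints.append("Discounted or sliding-window variants")
    elif problem.rewards == "adversarial":
        hints.append("EXP3")
    if problem.feedback == "full":
        hints.append("Online gradient descent or multiplicative weights")
    elif problem.feedback == "semi-bandit":
        hints.append("Specialized combinatorial algorithms")
    return hints

# The online ad bidding example: continuous bids, structured, contextual, non-stationary.
ad_bidding = BanditProblem(True, True, True, "non-stationary", "bandit")
hints = suggest(ad_bidding)
</code></pre></div>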
<p>This framework keeps the focus exactly where it needs to be: on what actually matters for the problem at hand.</p>
<h2 id="summary">Summary</h2>
<p>These five questions have helped me navigate bandit problems more systematically:</p>
<ol>
<li><strong>What can you do?</strong> (Action space: few vs many options, single vs multiple choices)</li>
<li><strong>Do choices relate?</strong> (Whether trying one option teaches you about others)</li>
<li><strong>What extra info do you have?</strong> (Context that might improve decisions)</li>
<li><strong>How do rewards work?</strong> (Stable, changing, or adversarial)</li>
<li><strong>How much do you learn?</strong> (Feedback from just your choice vs everything)</li>
</ol>
<p>This framework helps avoid getting overwhelmed by the academic literature and focuses attention on what matters for real problems. Structured problems let you learn faster by sharing information between options, while adversarial settings require completely different approaches.</p>
<h2 id="if-you-want-to-learn-more">If You Want to Learn More</h2>
<ul>
<li><strong><a href="https://tor-lattimore.com/downloads/book/book.pdf">Bandit Algorithms by Lattimore &amp; Szepesvári</a></strong> - The comprehensive textbook (free PDF)</li>
<li><strong><a href="https://arxiv.org/abs/1904.07272">Introduction to Multi-Armed Bandits</a></strong> - Aleksandrs Slivkins&rsquo; survey paper is more accessible</li>
</ul>
<h2 id="get-your-hands-dirty">Get Your Hands Dirty</h2>
<p>Want to see these dimensions in code? I&rsquo;ve implemented the foundational algorithms (Explore-Then-Commit and Follow-The-Leader) in Python. Check out the <strong><a href="https://github.com/hunter-heidenreich/Bandit-Algorithms">Bandit Algorithms Repository</a></strong> to run the simulations yourself.</p>
]]></content:encoded></item><item><title>A Guide to Neuroevolution: NEAT and HyperNEAT</title><link>https://hunterheidenreich.com/posts/neuroevolution-neat-and-hyperneat/</link><pubDate>Wed, 02 Jan 2019 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/neuroevolution-neat-and-hyperneat/</guid><description>Explore the evolution of neural network topologies with NEAT and how HyperNEAT scales this approach using geometric patterns and indirect encoding.</description><content:encoded><![CDATA[<h2 id="automating-neural-architecture-design">Automating Neural Architecture Design</h2>
<p>Designing neural network architectures is typically a manual, iterative process. Researchers experiment with different layer configurations, activation functions, and connection patterns, often guided by intuition and empirical results. Evolution offers an automated alternative to this design process.</p>
<p><a href="https://nn.cs.utexas.edu/downloads/papers/stanley.ec02.pdf">NEAT (NeuroEvolution of Augmenting Topologies)</a>, introduced in 2002, optimizes network weights and evolves the network structure itself, starting from minimal topologies and growing complexity only when beneficial.</p>
<p>NEAT&rsquo;s core innovations solved fundamental problems that had plagued earlier attempts at topology evolution. Its solutions for genetic encoding, structural crossover, and innovation protection remain influential today, especially as neural architecture search and automated ML gain prominence.</p>
<h2 id="the-core-challenges-of-neat">The Core Challenges of NEAT</h2>
<p>Evolving neural network topologies presents several fundamental challenges that NEAT elegantly addressed. Understanding these problems helps explain why NEAT&rsquo;s solutions were so influential.</p>
<h3 id="genetic-encoding-how-to-represent-networks">Genetic Encoding: How to Represent Networks</h3>
<p>Evolutionary algorithms require a genetic representation, a way to encode individuals that enables meaningful selection, mutation, and crossover. For neural networks, this choice is critical.</p>
<p><strong>Direct encoding</strong> explicitly represents each network component. Genes directly correspond to nodes and connections. This approach is intuitive and readable, and it works well for smaller networks.</p>
<p><strong>Indirect encoding</strong> specifies construction rules or processes. These encodings are more compact and can generate highly complex structures from simple rules.</p>
<p>NEAT chose direct encoding with a simple two-part structure: separate gene lists for nodes and connections. This balances simplicity with the flexibility needed for evolutionary operations.</p>















<figure class="post-figure center ">
    <img src="/img/neat_genomes.webp"
         alt="NEAT genome encoding showing node genes and connection genes with innovation numbers"
         title="NEAT genome encoding showing node genes and connection genes with innovation numbers"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">NEAT&rsquo;s direct encoding: node genes (top) and connection genes (bottom) with historical markings</figcaption>
    
</figure>

<p>Connection genes specify the source and target nodes, weight, enabled status, and an innovation number for historical tracking. The set of input and output nodes is fixed; evolution only adds hidden nodes.</p>
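<p>A minimal sketch of this two-part encoding, assuming hypothetical class names (<code>NodeGene</code>, <code>ConnectionGene</code>, <code>Genome</code>) and a placeholder initial weight of 0.0 in place of random initialization:</p>

```python
from dataclasses import dataclass, field


@dataclass
class NodeGene:
    node_id: int
    kind: str  # "input", "hidden", or "output"


@dataclass
class ConnectionGene:
    src: int          # id of the source node
    dst: int          # id of the target node
    weight: float
    enabled: bool
    innovation: int   # historical marking


@dataclass
class Genome:
    nodes: list = field(default_factory=list)
    connections: list = field(default_factory=list)


def minimal_genome(n_inputs, n_outputs):
    """Fully connect inputs to outputs with no hidden nodes."""
    g = Genome()
    for i in range(n_inputs):
        g.nodes.append(NodeGene(i, "input"))
    for j in range(n_outputs):
        g.nodes.append(NodeGene(n_inputs + j, "output"))
    innovation = 0
    for i in range(n_inputs):
        for j in range(n_outputs):
            # Placeholder weight; a real run would draw it at random.
            g.connections.append(
                ConnectionGene(i, n_inputs + j, 0.0, True, innovation))
            innovation += 1
    return g
```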
<h3 id="structural-mutations-growing-complexity">Structural Mutations: Growing Complexity</h3>
<p>NEAT employs two categories of mutations to evolve both weights and structure:</p>
<p><strong>Weight mutations</strong> adjust existing connection strengths using standard perturbation methods, the familiar approach from traditional neuroevolution.</p>
<p><strong>Structural mutations</strong> add new network components:</p>
<ul>
<li><strong>Add connection</strong>: Creates a new link between existing nodes with a random initial weight</li>
<li><strong>Add node</strong>: Splits an existing connection by inserting a new node. The original connection is disabled and two new connections replace it: the connection into the new node gets weight 1.0, and the connection out of it inherits the original weight.</li>
</ul>
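<p>This node-splitting approach minimizes disruption. With one new weight set to 1.0 and the other copied from the original connection, the new node initially passes the signal through nearly unchanged (up to its activation function), giving it time to prove useful before selection pressure acts on it.</p>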
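<p>The add-node mutation can be sketched as below. The weight assignment (1.0 into the new node, the old weight out of it) follows the original NEAT paper; the function name <code>add_node</code> and the way the innovation counter is passed around are assumptions for this example.</p>

```python
from dataclasses import dataclass


@dataclass
class ConnectionGene:
    src: int
    dst: int
    weight: float
    enabled: bool
    innovation: int


def add_node(connections, index, new_node_id, next_innovation):
    """Split connections[index] by inserting node new_node_id.

    Returns the innovation counter after the two new genes are added.
    """
    old = connections[index]
    old.enabled = False  # the original gene stays in the genome, disabled
    # Into the new node: weight 1.0. Out of it: the old weight.
    connections.append(
        ConnectionGene(old.src, new_node_id, 1.0, True, next_innovation))
    connections.append(
        ConnectionGene(new_node_id, old.dst, old.weight, True,
                       next_innovation + 1))
    return next_innovation + 2
```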
<h3 id="solving-the-competing-conventions-problem">Solving the Competing Conventions Problem</h3>
<p>Performing crossover between networks with different structures presents a fundamental challenge. Consider two networks that solve the same problem using different internal organizations. Naive crossover between them typically produces broken offspring.</p>















<figure class="post-figure center ">
    <img src="/img/competing_conventions.webp"
         alt="Two neural networks performing the same function but with different internal structures"
         title="Two neural networks performing the same function but with different internal structures"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The competing conventions problem: identical functions, different implementations</figcaption>
    
</figure>

<p>NEAT&rsquo;s solution draws inspiration from biology through <strong>historical markings</strong>. Each structural innovation (adding a node or connection) receives a unique innovation number, a timestamp of when that change first appeared in the population.</p>
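<p>During crossover, genes with matching innovation numbers are aligned and combined, while genes without a match (called disjoint or excess genes) are inherited from the fitter parent. This mirrors the biological concept of homology and enables meaningful recombination between networks of different sizes and structures.</p>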















<figure class="post-figure center ">
    <img src="/img/neat_crossover.webp"
         alt="Diagram showing how NEAT aligns genes during crossover using innovation numbers"
         title="Diagram showing how NEAT aligns genes during crossover using innovation numbers"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">NEAT crossover using historical markings for gene alignment</figcaption>
    
</figure>
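<p>The alignment step can be sketched as follows. Each parent is reduced to a dictionary from innovation number to weight, which is a simplification of a full connection gene; the rule that unmatched genes come from the fitter parent follows the original NEAT paper, and the function name <code>crossover</code> is hypothetical.</p>

```python
import random


def crossover(fitter, weaker, rng=random):
    """Align two parents by innovation number.

    Each parent is a dict mapping innovation number to a connection weight
    (a simplified stand-in for a full connection gene).
    Matching genes are picked at random from either parent; disjoint and
    excess genes are taken from the fitter parent only.
    """
    child = {}
    for innovation, weight in fitter.items():
        if innovation in weaker:
            child[innovation] = rng.choice([weight, weaker[innovation]])
        else:
            child[innovation] = weight
    return child
```

<p>Because alignment is a dictionary lookup on the innovation number, no topological analysis of the two networks is needed.</p>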

<h3 id="protecting-innovation-through-speciation">Protecting Innovation Through Speciation</h3>
<p>New structural innovations face a harsh reality: they usually perform worse initially. Adding nodes or connections typically decreases performance before optimization can improve the new structure. Without protection, these innovations disappear before realizing their potential.</p>
<p>NEAT addresses this through <strong>speciation</strong>: dividing the population into species based on structural and weight similarity. The historical markings that enable crossover also measure compatibility between individuals.</p>
<p>Crucially, individuals only compete within their species. This gives new structural innovations time to optimize without immediately competing against established, well-tuned networks.</p>
<p><strong>Explicit fitness sharing</strong> enhances this protection: each individual&rsquo;s fitness is divided by the number of members in its species, so a species cannot dominate the population simply by being large, and diversity is maintained for continued exploration.</p>
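<p>A sketch of both ideas, assuming genomes are represented as non-empty dictionaries from innovation number to weight. The distance has the form $c_1 E/N + c_2 D/N + c_3 \bar{W}$ used in the NEAT paper; the coefficient defaults here are illustrative, and the function names are hypothetical.</p>

```python
def compatibility(genes_a, genes_b, c1=1.0, c2=1.0, c3=0.4):
    """Compatibility distance between two non-empty genomes.

    genes_a, genes_b: dicts mapping innovation number to weight.
    E and D count excess and disjoint genes, W is the mean weight
    difference of matching genes, and N is the size of the larger genome.
    """
    matching = genes_a.keys() & genes_b.keys()
    cutoff = min(max(genes_a), max(genes_b))
    unmatched = genes_a.keys() ^ genes_b.keys()
    # Excess genes lie beyond the other parent's highest innovation number.
    excess = sum(1 for i in unmatched if i > cutoff)
    disjoint = len(unmatched) - excess
    n = max(len(genes_a), len(genes_b))
    w = (sum(abs(genes_a[i] - genes_b[i]) for i in matching) / len(matching)
         if matching else 0.0)
    return c1 * excess / n + c2 * disjoint / n + c3 * w


def adjusted_fitness(fitness, species_size):
    """Explicit fitness sharing: divide raw fitness by species size."""
    return fitness / species_size
```

<p>An individual joins a species when its distance to that species&rsquo; representative falls below a chosen threshold, which is a further hyperparameter not shown here.</p>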
<h3 id="complexification-starting-minimal">Complexification: Starting Minimal</h3>
<p>NEAT begins with the simplest possible networks: just input and output nodes connected by random weights. No hidden nodes exist initially. Complexity emerges only when mutations that add structure prove beneficial.</p>
<p>This complexification approach searches small topologies first, so solutions tend to use no more structure than the problem requires. Combined with speciation, it tends to produce compact, well-tuned architectures.</p>
<h2 id="scaling-up-hyperneat">Scaling Up: HyperNEAT</h2>
<p>NEAT evolves networks through direct encoding, where each gene explicitly specifies a node or connection. The genome therefore grows in step with the network, which makes the approach impractical for very large architectures. Evolving networks on the scale of the brain, with trillions of connections, requires indirect encoding.</p>
<p><a href="https://doi.org/10.1162/artl.2009.15.2.15202">HyperNEAT</a> introduces <strong>indirect encoding</strong> through geometric principles. HyperNEAT evolves geometric patterns that generate connections based on spatial relationships. This enables the evolution of large networks with biological regularities: symmetry, repetition, and locality.</p>
<p>The key insight is leveraging Compositional Pattern Producing Networks (CPPNs) to map coordinates to connection weights, exploiting the geometric organization found in natural neural networks.</p>
<h3 id="biological-motivation">Biological Motivation</h3>
<p>The human brain exhibits remarkable organizational principles:</p>
<ul>
<li><strong>Scale</strong>: ~86 billion neurons with ~100 trillion connections</li>
<li><strong>Repetition</strong>: Structural patterns reused across regions</li>
<li><strong>Symmetry</strong>: Mirrored structures like bilateral visual processing</li>
<li><strong>Locality</strong>: Spatial proximity influences connectivity and function</li>
</ul>
<p>HyperNEAT aims to evolve networks that capture these biological regularities, leading to more efficient and interpretable architectures.</p>
<h3 id="compositional-pattern-producing-networks">Compositional Pattern Producing Networks</h3>
<p>CPPNs are the foundation of HyperNEAT&rsquo;s indirect encoding. Think of them as pattern generators that create complex spatial structures from simple coordinate inputs.</p>
<p>DNA exemplifies indirect encoding (roughly 20,000 protein-coding genes specify a brain with trillions of connections). This massive compression ratio suggests that simple rules can generate complex structures through developmental processes.</p>
<p>CPPNs abstract this concept, using compositions of mathematical functions to create patterns in coordinate space. The same genetic program (function composition) can be reused across different locations and scales, just like how developmental genes control pattern formation throughout an organism.</p>















<figure class="post-figure center ">
    <img src="/img/hyperneat_cppns.webp"
         alt="Various symmetric and repetitive patterns created by CPPNs"
         title="Various symmetric and repetitive patterns created by CPPNs"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Complex patterns generated by CPPNs through function composition</figcaption>
    
</figure>

<h3 id="pattern-generation-through-function-composition">Pattern Generation Through Function Composition</h3>
<p>CPPNs generate patterns by composing simple mathematical functions. Key function types include:</p>
<ul>
<li><strong>Gaussian functions</strong>: Create symmetric patterns and gradients</li>
<li><strong>Trigonometric functions</strong>: Generate periodic/repetitive structures</li>
<li><strong>Linear functions</strong>: Produce gradients and asymmetric patterns</li>
<li><strong>Sigmoid functions</strong>: Create sharp transitions and boundaries</li>
</ul>
<p>By combining these functions, CPPNs can encode complex regularities that would require many explicit rules in direct encoding.</p>
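<p>To make the composition idea concrete, here is a hand-written stand-in for an evolved CPPN over 2D coordinates. The particular functions and the frequency 4.0 are arbitrary choices for illustration; an actual CPPN&rsquo;s structure would be discovered by evolution.</p>

```python
import math


def gaussian(x):
    return math.exp(-x * x)


def cppn(x, y):
    """A hand-written stand-in for an evolved CPPN over 2D coordinates."""
    symmetric = gaussian(x)            # mirror-symmetric about x = 0
    periodic = math.sin(4.0 * y)       # repeats along y with period pi / 2
    return math.tanh(symmetric * periodic)
```

<p>Two function nodes are enough to produce a pattern that is both symmetric across the vertical axis and repeating along the other, which is the kind of regularity that would take many explicit genes to describe directly.</p>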
<h3 id="evolution-of-cppns">Evolution of CPPNs</h3>
<p>HyperNEAT uses NEAT to evolve the CPPN structure. This brings several advantages:</p>
<ul>
<li><strong>Complexification</strong>: CPPNs start simple and grow more complex only when beneficial</li>
<li><strong>Historical markings</strong>: Enable proper crossover between different CPPN topologies</li>
<li><strong>Speciation</strong>: Protects innovative CPPN patterns during evolution</li>
</ul>
<p>Additional activation functions beyond standard neural networks are crucial:</p>
<ul>
<li>Gaussian functions for symmetry</li>
<li>Sine/cosine for repetition</li>
<li>Specialized functions for specific geometric patterns</li>
</ul>
<h2 id="the-hyperneat-process">The HyperNEAT Process</h2>
<h3 id="substrates-geometric-organization">Substrates: Geometric Organization</h3>
<p>A <strong>substrate</strong> defines the spatial arrangement of neurons. Substrates embed neurons in geometric space (2D grids, 3D volumes, etc.), providing an alternative to layer-based connectivity rules.</p>
<p>The CPPN maps from coordinates to connection weights:</p>
<p>$$\text{CPPN}(x_1, y_1, x_2, y_2) = w$$</p>
<p>Where $(x_1, y_1)$ and $(x_2, y_2)$ are the coordinates of two neurons, and $w$ determines their connection weight.</p>
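<p>A sketch of how a substrate is turned into a weighted network by querying the CPPN for every pair of neurons. The CPPN here is a hypothetical distance-based function, and the expression threshold of 0.2 is an assumption for this example; <code>grid</code> and <code>build_weights</code> are illustrative names.</p>

```python
import math


def cppn(x1, y1, x2, y2):
    """Hypothetical CPPN: weight decays with distance between neurons."""
    d2 = (x1 - x2) ** 2 + (y1 - y2) ** 2
    return math.exp(-d2)


def grid(n):
    """n-by-n neuron coordinates (n >= 2) spread evenly over [-1, 1] x [-1, 1]."""
    step = 2.0 / (n - 1)
    return [(-1.0 + i * step, -1.0 + j * step)
            for i in range(n) for j in range(n)]


def build_weights(coords, threshold=0.2):
    """Query the CPPN for every neuron pair; keep only strong connections."""
    weights = {}
    for a, (x1, y1) in enumerate(coords):
        for b, (x2, y2) in enumerate(coords):
            w = cppn(x1, y1, x2, y2)
            if abs(w) > threshold:
                weights[(a, b)] = w
    return weights
```

<p>The genome being evolved is only the CPPN; the potentially large weight table is derived from it on demand.</p>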















<figure class="post-figure center ">
    <img src="/img/hyperneat_cppn_basics.webp"
         alt="Diagram showing CPPN taking four coordinate inputs and outputting connection weight"
         title="Diagram showing CPPN taking four coordinate inputs and outputting connection weight"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Basic CPPN architecture mapping coordinates to connection weights</figcaption>
    
</figure>

<p>This geometric approach enables several key properties:</p>
<ul>
<li><strong>Locality</strong>: Nearby neurons tend to have similar connectivity patterns</li>
<li><strong>Symmetry</strong>: Patterns can be mirrored across spatial axes</li>
<li><strong>Repetition</strong>: Periodic functions create repeating motifs</li>
<li><strong>Scalability</strong>: The same pattern can be applied at different resolutions</li>
</ul>
<h3 id="emergent-regularities">Emergent Regularities</h3>
<p>The geometric encoding naturally produces the desired biological patterns:</p>
<p><strong>Symmetry</strong> emerges from symmetric functions. A Gaussian applied to a coordinate returns the same value for $x$ and $-x$, so any connectivity pattern computed from it is mirrored across that axis.</p>
<p><strong>Repetition</strong> arises from periodic functions like sine and cosine. These create repeating connectivity motifs across the substrate.</p>
<p><strong>Locality</strong> results from functions that vary smoothly across space. Nearby coordinates produce similar outputs, leading to local connectivity patterns.</p>
<p><strong>Imperfect regularity</strong> occurs when these patterns are modulated by additional coordinate dependencies, creating biological-like variation within the basic structure.</p>
<h3 id="substrate-configurations">Substrate Configurations</h3>
<p>The choice of substrate geometry critically influences network behavior. Several standard configurations exist:</p>















<figure class="post-figure center ">
    <img src="/img/hyperneat_substrate_configurations.webp"
         alt="Various substrate layouts including grids, 3D arrangements, and circular patterns"
         title="Various substrate layouts including grids, 3D arrangements, and circular patterns"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Common substrate geometries for different problem types</figcaption>
    
</figure>

<p><strong>2D Grid</strong>: Simple planar arrangement, CPPN takes four coordinates $(x_1, y_1, x_2, y_2)$</p>
<p><strong>3D Volume</strong>: Extends to three dimensions, CPPN becomes six-dimensional $(x_1, y_1, z_1, x_2, y_2, z_2)$</p>
<p><strong>Sandwich</strong>: Input layer connects only to output layer, useful for sensory-motor tasks</p>
<p><strong>Circular</strong>: Radial geometry enables rotation-invariant patterns and cyclic behaviors</p>
<p>The substrate must be chosen before evolution begins, making domain knowledge important for success.</p>
<h3 id="exploiting-input-output-geometry">Exploiting Input-Output Geometry</h3>
<p>HyperNEAT exploits the spatial organization of inputs and outputs. For visual tasks, pixel coordinates provide meaningful geometric information. For control problems, sensor and actuator layouts can guide connectivity patterns.</p>















<figure class="post-figure center ">
    <img src="/img/hyperneat_inputs_outputs.webp"
         alt="Visual representation of how HyperNEAT maps spatial input arrangements to output patterns"
         title="Visual representation of how HyperNEAT maps spatial input arrangements to output patterns"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Spatial organization of inputs and outputs enables geometric exploitation</figcaption>
    
</figure>

<p>This spatial awareness allows HyperNEAT to:</p>
<ul>
<li>Develop receptive fields similar to biological vision systems</li>
<li>Create locally connected patterns for spatial processing</li>
<li>Generate symmetric motor control patterns</li>
<li>Scale across different input resolutions</li>
</ul>
<h3 id="resolution-independence">Resolution Independence</h3>
<p>A distinctive advantage of HyperNEAT is <strong>substrate resolution independence</strong>. Because the CPPN is a function of continuous coordinates, a pattern evolved on a low-resolution substrate can be re-queried on a denser grid of the same geometry to produce a higher-resolution network without further evolution, often with much of its behavior preserved.</p>
<p>This property suggests that evolved patterns capture fundamental spatial relationships, providing a key insight for scalable neural architecture design.</p>
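<p>The re-querying step can be sketched as below, using a hypothetical CPPN and a simple two-row substrate (inputs at $y = -1$, outputs at $y = 1$). The names <code>line</code> and <code>weight_matrix</code> are assumptions for this example.</p>

```python
import math


def cppn(x1, y1, x2, y2):
    """Hypothetical evolved pattern, a fixed function of coordinates."""
    return math.exp(-((x1 - x2) ** 2 + (y1 - y2) ** 2)) * math.cos(x1 * x2)


def line(n):
    """n evenly spaced coordinates (n >= 2) on [-1, 1]."""
    return [-1.0 + 2.0 * i / (n - 1) for i in range(n)]


def weight_matrix(n):
    """Weights from an n-neuron input row (y=-1) to an n-neuron output row (y=1)."""
    xs = line(n)
    return [[cppn(a, -1.0, b, 1.0) for b in xs] for a in xs]


low = weight_matrix(3)    # 3 x 3 connections
high = weight_matrix(5)   # 5 x 5 connections from the same CPPN
```

<p>Neurons that sit at the same coordinates in both substrates receive the same weights; the higher-resolution network fills in the positions between them from the same underlying pattern.</p>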
<h2 id="impact-and-future-directions">Impact and Future Directions</h2>
<p>NEAT and HyperNEAT demonstrated that evolution could design neural network topologies and scale them through indirect encoding. The algorithms&rsquo; key insights, exploiting geometry, generating patterns through function composition, and scaling across resolutions, continue to influence modern research.</p>
<p>Extensions like ES-HyperNEAT go further by deriving the placement and density of hidden neurons from the CPPN&rsquo;s own output pattern, so the substrate no longer has to be fixed entirely by hand. As neural architecture search becomes increasingly important, these principles find new applications in hybrid approaches that combine evolutionary pattern generation with gradient-based optimization.</p>
<p>The emphasis on spatial organization and regularity also connects to contemporary work on geometric deep learning and equivariant networks, suggesting that evolution and hand-design converge on similar organizing principles for building structured, efficient neural architectures.</p>
]]></content:encoded></item><item><title>Cartesian Genetic Programming in Julia</title><link>https://hunterheidenreich.com/projects/cgp-julia/</link><pubDate>Sun, 18 Nov 2018 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/projects/cgp-julia/</guid><description>A fork of Dennis Wilson's CGP.jl applying Cartesian Genetic Programming to Atari RL tasks; my work was the Atari experiments, not the core framework.</description><content:encoded><![CDATA[<p>Written in 2018, this was an exploration into <strong>Evolutionary Algorithms</strong> applied to Reinforcement Learning tasks (specifically Atari games). It is a fork of <a href="https://github.com/d9w/CGP.jl">d9w/CGP.jl</a> (Dennis Wilson, Apache 2.0); my work centered on the Atari reinforcement-learning experiments rather than the core CGP framework.</p>
<h2 id="overview">Overview</h2>
<p>Standard Cartesian Genetic Programming (CGP) relies heavily on mutation. The upstream library hybridizes CGP with <strong>NEAT (NeuroEvolution of Augmenting Topologies)</strong> concepts to protect topological innovation through speciation.</p>
<p>My goal in forking it was to evolve graph-based programs that could learn Atari control policies using gradient-free optimization.</p>
<h2 id="features">Features</h2>
<p>The upstream framework provides the CGP machinery this project builds on:</p>
<ul>
<li><strong>Graph-based Crossover:</strong> Crossover operators such as <code>subgraph_crossover</code> and <code>aligned_node_crossover</code> that handle the destructive nature of mating graph structures.</li>
<li><strong>Speciation:</strong> A NEAT-inspired compatibility-distance metric (<code>cgpneat.jl</code>) to maintain population diversity and prevent premature convergence.</li>
<li><strong>Active Gene Tracking:</strong> Differentiates between &ldquo;active&rdquo; nodes (those contributing to output) and &ldquo;junk DNA,&rdquo; focusing mutation on phenotypic changes.</li>
</ul>
<p>My own contribution was the <strong>Atari reinforcement-learning layer</strong> on top of this: experiment variants (<code>action_atari.jl</code>, <code>original_atari.jl</code>, <code>manual_atari.jl</code>, <code>play_atari.jl</code>, <code>param_sweep.jl</code>), custom fitness and scoring functions, early-stopping and completion-percentage logging, multithreading and <code>pmap</code> multiprocessing attempts (reverted to single-thread), and config tuning to match a reference paper&rsquo;s hyperparameters.</p>
<h2 id="usage">Usage</h2>
<p>The library provides a Julia API for defining CGP graphs, configuring evolutionary parameters, and running the evolutionary loop against custom environments.</p>
<h2 id="results">Results</h2>
<p>Looking back, this codebase captures a transitional moment where I was moving from scripting to library design.</p>
<ul>
<li><strong>The Ambition:</strong> Getting CGP graphs to learn Atari policies under the mixed-type regime (RGB-array inputs, scalar action outputs) was an ambitious undertaking for my software engineering skills at the time.</li>
<li><strong>The &ldquo;Legacy&rdquo; Code:</strong> The project relies on the now-deprecated Julia v0.6 and uses <code>eval(parse(...))</code> patterns for configuration (a significant performance anti-pattern in modern Julia).</li>
<li><strong>The Lesson:</strong> It taught me the difficulty of designing genetic operators that respect topological constraints, a lesson that informs my current understanding of optimization in structured spaces.</li>
</ul>
]]></content:encoded></item><item><title>Understanding GANs: From Fundamentals to Objective Functions</title><link>https://hunterheidenreich.com/posts/what-is-a-gan/</link><pubDate>Sat, 18 Aug 2018 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/what-is-a-gan/</guid><description>A complete guide to Generative Adversarial Networks (GANs), covering intuitive explanations, mathematical foundations, and objective functions.</description><content:encoded><![CDATA[<h2 id="understanding-generative-models">Understanding Generative Models</h2>
<p>Modern generative AI is dominated by diffusion models and autoregressive transformers. The adversarial training dynamics and objective functions introduced by <a href="https://arxiv.org/abs/1406.2661">Generative Adversarial Networks</a> (GANs) still inform how the field thinks about loss-function design and training stability today. Before diving into GANs, let&rsquo;s establish what we&rsquo;re trying to accomplish with generative models.</p>
<p><strong>The core goal</strong>: Create a system that can generate new, realistic data that appears to come from the same distribution as our training data.</p>
<p>Think of having a model that can create images, text, or audio that are difficult to distinguish from human-created content. This is what generative modeling aims to achieve.</p>
<h3 id="the-mathematical-foundation">The Mathematical Foundation</h3>
<p>Generative models aim to estimate the probability distribution of real data. If we have parameters $\theta$, we want to find the optimal $\theta^*$ that maximizes the likelihood of observing our real samples:</p>
<p>$$
\theta^* = \arg\max_\theta \prod_{i=1}^{n} p_\theta(x_i)
$$</p>
<p>This is equivalent to minimizing a measure of discrepancy between our estimated distribution and the true data distribution. A common choice is the <a href="https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence">Kullback-Leibler divergence</a>, which is asymmetric and therefore not a true distance. Maximizing the log-likelihood of the training samples is equivalent to minimizing the KL divergence from the empirical data distribution to the model.</p>
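<p>An informal sketch of why the two objectives coincide, writing $\hat{p}_{\text{data}}$ for the empirical distribution of the $n$ training samples (notation introduced here for illustration). Taking logs and dividing by $n$ does not change the maximizer:</p>
<p>$$
\arg\max_\theta \frac{1}{n} \sum_{i=1}^{n} \log p_\theta(x_i) = \arg\max_\theta \mathbb{E}_{x \sim \hat{p}_{\text{data}}}[\log p_\theta(x)] = \arg\min_\theta D_{\text{KL}}(\hat{p}_{\text{data}} \,\|\, p_\theta)
$$</p>
<p>The last step uses $D_{\text{KL}}(\hat{p}_{\text{data}} \,\|\, p_\theta) = \mathbb{E}_{x \sim \hat{p}_{\text{data}}}[\log \hat{p}_{\text{data}}(x)] - \mathbb{E}_{x \sim \hat{p}_{\text{data}}}[\log p_\theta(x)]$, whose first term does not depend on $\theta$.</p>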
<h3 id="two-approaches-to-generative-modeling">Two Approaches to Generative Modeling</h3>
<h4 id="explicit-distribution-models">Explicit Distribution Models</h4>
<p>These models define an explicit probability distribution and refine it through training.</p>
<p><strong>Example</strong>: <a href="https://arxiv.org/abs/1606.05908">Variational Auto-Encoders</a> (VAEs) require:</p>
<ul>
<li>An explicitly assumed prior distribution</li>
<li>A likelihood distribution</li>
<li>A &ldquo;variational approximation&rdquo; of the posterior, which yields a tractable lower bound on the likelihood to optimize</li>
</ul>
<h4 id="implicit-distribution-models">Implicit Distribution Models</h4>
<p>These models learn to produce samples directly, without ever writing down an explicit probability density. GANs exemplify this implicit approach, learning distributions through adversarial competition.</p>















<figure class="post-figure center ">
    <img src="/img/gen_ai_types.webp"
         alt="Types of deep generative models showing taxonomy"
         title="Types of deep generative models showing taxonomy"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption"><strong>Taxonomy of Deep Generative Models</strong>: GANs fall into the implicit density category, learning distributions through adversarial training. <em>Source: NeurIPS 2016 tutorial on Generative Adversarial Networks</em></figcaption>
    
</figure>

<h2 id="the-gan-architecture-a-game-of-deception">The GAN Architecture: A Game of Deception</h2>
<p>Generative Adversarial Networks get their name from three key components:</p>
<ul>
<li><strong>Generative</strong>: They create new data</li>
<li><strong>Adversarial</strong>: Two networks compete against each other</li>
<li><strong>Networks</strong>: Built using neural networks</li>
</ul>
<p>The core innovation is the adversarial setup: two neural networks compete against each other, driving mutual improvement.</p>















<figure class="post-figure center ">
    <img src="/img/GAN-70.webp"
         alt="Diagram showing data flow through a GAN architecture"
         title="Diagram showing data flow through a GAN architecture"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption"><strong>GAN Data Flow</strong>: The generator creates fake samples from random noise, while the discriminator tries to distinguish real from fake data. This adversarial competition drives both networks to improve.</figcaption>
    
</figure>

<h3 id="the-generator-the-forger">The Generator: The Forger</h3>
<p><strong>Role</strong>: Create convincing fake data from random noise</p>
<p>The generator network $G$ learns a mapping function:
$$z \rightarrow G(z) \approx x_{\text{real}}$$</p>
<p>Where:</p>
<ul>
<li>$z$ is a random latent vector (the &ldquo;noise&rdquo;)</li>
<li>$G(z)$ is the generated sample</li>
<li>The goal is making $G(z)$ indistinguishable from real data</li>
</ul>
<p><strong>Key insight</strong>: The latent space $z$ is continuous and the generator is a continuous function, so in a well-trained model small changes in $z$ tend to produce smooth, meaningful changes in the generated output.</p>
<h3 id="the-discriminator-the-detective">The Discriminator: The Detective</h3>
<p><strong>Role</strong>: Distinguish between real and generated samples</p>
<p>The discriminator network $D$ outputs a probability:
$$D(x) = P(x \text{ is real})$$</p>
<ul>
<li>$D(x) \approx 1$ for real samples</li>
<li>$D(x) \approx 0$ for fake samples</li>
</ul>
<p>It functions as an &ldquo;authenticity detector&rdquo; that progressively improves.</p>
<h3 id="the-adversarial-competition">The Adversarial Competition</h3>
<p>This adversarial dynamic drives the training process. The generator and discriminator have <strong>directly opposing objectives</strong>:</p>
<table>
  <thead>
      <tr>
          <th>Generator Goal</th>
          <th>Discriminator Goal</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Fool the discriminator</td>
          <td>Correctly classify all samples</td>
      </tr>
      <tr>
          <td>Maximize $D(G(z))$</td>
          <td>Maximize $D(x_{\text{real}})$ and minimize $D(G(z))$</td>
      </tr>
      <tr>
          <td>&ldquo;Create convincing fakes&rdquo;</td>
          <td>&ldquo;Never be fooled&rdquo;</td>
      </tr>
  </tbody>
</table>
<p>This creates a dynamic where both networks continuously improve:</p>
<ul>
<li>Generator creates better fakes to fool the discriminator</li>
<li>Discriminator becomes better at detecting fakes</li>
<li>The cycle continues until equilibrium</li>
</ul>















<figure class="post-figure center ">
    <img src="/img/GAN-SUMMARY-50.webp"
         alt="Illustration of GAN training process showing adversarial competition"
         title="Illustration of GAN training process showing adversarial competition"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption"><strong>The Adversarial Training Process</strong>: Through competition, both networks improve. The generator learns to create increasingly realistic samples while the discriminator becomes more discerning.</figcaption>
    
</figure>

<h2 id="learning-through-metaphors">Learning Through Metaphors</h2>
<p>Relatable analogies often clarify complex concepts. Here are two metaphors that capture different aspects of how GANs work.</p>
<h3 id="the-art-forger-vs-critic">The Art Forger vs. Critic</h3>
<p><strong>Generator = Art Forger</strong><br>
<strong>Discriminator = Art Critic</strong></p>
<p>A criminal forger tries to create fake masterpieces, while an art critic must identify authentic works. Each interaction teaches both parties:</p>
<ul>
<li>The forger learns what makes art look authentic</li>
<li>The critic develops a keener eye for detecting fakes</li>
<li>Eventually, the forger becomes so skilled that even experts can&rsquo;t tell the difference</li>
</ul>
<p><em>This captures the adversarial nature and continuous improvement aspect of GANs.</em></p>
<h3 id="the-counterfeiter-vs-bank-teller">The Counterfeiter vs. Bank Teller</h3>
<p><strong>Generator = Counterfeiter</strong><br>
<strong>Discriminator = Bank Teller</strong></p>
<p>Day 1: Criminal brings a crayon drawing of a dollar bill. Even a new teller spots this fake.</p>
<p>Day 100: The counterfeiter has learned better techniques. The teller has developed expertise in security features.</p>
<p>Day 1000: The fake money is so convincing that detecting it requires advanced equipment.</p>
<p><em>This illustrates the progressive improvement and escalating sophistication in both networks.</em></p>
<h2 id="the-mathematical-foundation-1">The Mathematical Foundation</h2>
<p>Now let&rsquo;s examine the mathematical framework that makes GANs work. The core of GAN training is solving a <strong>minimax optimization problem</strong>.</p>
<h3 id="the-minimax-objective">The Minimax Objective</h3>
<p>$$
\min_{G} \max_{D} V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]
$$</p>
<p><strong>Breaking this down:</strong></p>
<ul>
<li>$\mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)]$: The expected log-probability for real data.
<ul>
<li><strong>Discriminator&rsquo;s Goal</strong>: Maximize this term to correctly classify real samples.</li>
</ul>
</li>
<li>$\mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$: The expected log-probability for fake data being correctly identified as fake.
<ul>
<li><strong>Discriminator&rsquo;s Goal</strong>: Maximize this term.</li>
<li><strong>Generator&rsquo;s Goal</strong>: Minimize this term to fool the discriminator.</li>
</ul>
</li>
</ul>
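<p>As a numerical illustration of the two terms, the objective can be estimated from discriminator outputs on a batch. This is a sketch: the function name <code>value</code> and the plain-list inputs are assumptions, and outputs are assumed to lie strictly between 0 and 1.</p>

```python
import math


def value(d_real, d_fake):
    """Monte Carlo estimate of V(D, G).

    d_real: discriminator outputs D(x) on real samples.
    d_fake: discriminator outputs D(G(z)) on generated samples.
    """
    real_term = sum(math.log(d) for d in d_real) / len(d_real)
    fake_term = sum(math.log(1.0 - d) for d in d_fake) / len(d_fake)
    return real_term + fake_term
```

<p>A confident, correct discriminator pushes the value toward 0 from below, while a discriminator that outputs 0.5 everywhere gives $-\log 4$.</p>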
<h3 id="why-minimax">Why &ldquo;Minimax&rdquo;?</h3>
<ul>
<li><strong>Discriminator ($D$)</strong>: Tries to <strong>maximize</strong> the objective → Better at distinguishing real from fake.</li>
<li><strong>Generator ($G$)</strong>: Tries to <strong>minimize</strong> the objective → Better at fooling the discriminator.</li>
</ul>
<h3 id="a-practical-challenge-vanishing-gradients">A Practical Challenge: Vanishing Gradients</h3>
<p>The minimax objective presents a practical problem early in training. When the generator is poor, the discriminator can easily distinguish real from fake samples with high confidence ($D(G(z)) \approx 0$). This causes $\log(1 - D(G(z)))$ to saturate and results in vanishing gradients for the generator, which effectively stalls learning.</p>
<p><strong>The Solution</strong>: Practitioners typically train the generator to <strong>maximize</strong> $\log(D(G(z)))$ to provide stronger gradients early in training. This non-saturating heuristic prevents the learning process from stalling.</p>
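<p>To make the saturation concrete, here is a minimal sketch (an illustration, not code from the original GAN paper). It assumes the discriminator is a sigmoid applied to a logit $a$, so $D(G(z)) = \sigma(a)$, and compares the derivative of each generator loss with respect to $a$. The function names are hypothetical.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def saturating_grad(a):
    # d/da log(1 - sigmoid(a)) = -sigmoid(a)
    return -sigmoid(a)

def non_saturating_grad(a):
    # d/da [-log sigmoid(a)] = -(1 - sigmoid(a))
    return -(1.0 - sigmoid(a))

# A confident discriminator early in training: D(G(z)) is close to 0.
a = -5.0
print(abs(saturating_grad(a)))      # small: the minimax loss barely moves G
print(abs(non_saturating_grad(a)))  # close to 1: the heuristic keeps a strong signal
</code></pre></div>
<p>Both losses push $D(G(z))$ upward. They differ only in how much gradient reaches the generator when the discriminator is winning.</p>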
<h3 id="the-training-process">The Training Process</h3>
<p>The beauty of GANs lies in their alternating optimization:</p>
<ol>
<li><strong>Fix $G$, train $D$</strong>: Make the discriminator optimal for the current generator</li>
<li><strong>Fix $D$, train $G$</strong>: Improve the generator against the current discriminator</li>
<li><strong>Repeat</strong>: Continue, ideally until the two networks approach a Nash equilibrium</li>
</ol>
<h3 id="theoretical-goal-nash-equilibrium">Theoretical Goal: Nash Equilibrium</h3>
<p>At convergence, the discriminator outputs $D(x) = 0.5$ for all samples, meaning it can&rsquo;t distinguish between real and fake data. This indicates that $p_{\text{generator}} = p_{\text{data}}$. Our generator has learned the true data distribution.</p>
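<p>The value $0.5$ follows from the form of the optimal discriminator. As shown in the original GAN paper, for a fixed generator with sample distribution $p_g$ (written $p_{\text{generator}}$ above), maximizing $V(D, G)$ pointwise gives:</p>
<p>$$
D^{*}_{G}(x) = \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_g(x)}
$$</p>
<p>When $p_g = p_{\text{data}}$, the ratio is $\frac{1}{2}$ everywhere, and the value function evaluates to $V(D^{*}_{G}, G) = \log\frac{1}{2} + \log\frac{1}{2} = -\log 4$.</p>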
<h2 id="the-evolution-of-objective-functions">The Evolution of Objective Functions</h2>
<p>The objective function is the mathematical heart of any GAN. It defines how we measure the &ldquo;distance&rdquo; between our generated distribution and the real data distribution. This choice profoundly impacts:</p>
<ul>
<li><strong>Training stability</strong>: Some objectives lead to more stable convergence</li>
<li><strong>Sample quality</strong>: Different losses emphasize different aspects of realism</li>
<li><strong>Mode collapse</strong>: The tendency to generate limited variety</li>
<li><strong>Computational efficiency</strong>: Some objectives are faster to compute</li>
</ul>
<p>The original GAN uses Jensen-Shannon Divergence (JSD), but researchers have discovered many alternatives that address specific limitations. Let&rsquo;s explore this evolution.</p>
<h3 id="the-original-gan-jensen-shannon-divergence">The Original GAN: Jensen-Shannon Divergence</h3>
<p>With an optimal discriminator, the original GAN objective minimizes the Jensen-Shannon Divergence between the data and generator distributions:</p>
<p>$$
\text{JSD}(P, Q) = \frac{1}{2} \text{KL}(P \,\|\, M) + \frac{1}{2} \text{KL}(Q \,\|\, M)
$$</p>
<p>Where $M = \frac{1}{2}(P + Q)$ is the average distribution, and $\text{KL}$ is the <a href="https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence">Kullback-Leibler Divergence</a>.</p>
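<p>A small sketch for discrete distributions shows why this divergence can starve the generator of signal. The helper names are hypothetical, and the distributions are toy examples. When two distributions have disjoint supports, JSD equals $\log 2$ no matter how far apart they are.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import math

def kl(p, q):
    # KL(p || q) for discrete distributions; terms with p_i = 0 contribute 0.
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi != 0)

def jsd(p, q):
    m = [0.5 * (pi + qi) for pi, qi in zip(p, q)]
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

same = jsd([0.5, 0.5, 0.0, 0.0], [0.5, 0.5, 0.0, 0.0])
near = jsd([1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0])
far = jsd([1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0])

# near and far are both log(2): JSD cannot tell how far apart disjoint supports are.
print(same, near, far)
</code></pre></div>
<p>Because the value is flat across all non-overlapping configurations, it gives the generator no indication of which direction moves it closer to the data.</p>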
<p><strong>Strengths</strong>: Solid theoretical foundation, introduced adversarial training<br>
<strong>Limitations</strong>: Can suffer from vanishing gradients and mode collapse</p>
<h3 id="wasserstein-gan-wgan">Wasserstein GAN (WGAN)</h3>
<p>The <a href="https://arxiv.org/abs/1701.07875">Wasserstein GAN</a> replaced Jensen-Shannon divergence with the Earth-Mover (Wasserstein) distance, which gives meaningful gradients even when the real and generated distributions do not overlap.</p>
<h4 id="understanding-earth-mover-distance">Understanding Earth-Mover Distance</h4>
<p>The Wasserstein distance, also known as Earth-Mover distance, has an intuitive interpretation:</p>
<blockquote>
<p><strong>Imagine two probability distributions as piles of dirt.</strong> The Earth-Mover distance measures the minimum cost to transform one pile into the other, where cost = mass × distance moved.</p></blockquote>
<p>Mathematically:</p>
<p>$$
W_p(\mu, \nu) = \left( \inf_{\gamma \in \Gamma(\mu, \nu)} \int_{M \times M} d(x, y)^p \, d\gamma(x, y) \right)^{1/p}
$$</p>
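<p>Here $\Gamma(\mu, \nu)$ is the set of joint distributions (transport plans) whose marginals are $\mu$ and $\nu$, and $d$ is the metric on the space $M$. The case $p = 1$ is the Earth-Mover distance. In one dimension, for two samples of equal size, the optimal plan simply matches sorted points, which allows a short sketch (the function name is hypothetical):</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">def wasserstein_1d(xs, ys):
    # Assumes equal-size samples with uniform weights.
    if len(xs) != len(ys):
        raise ValueError("samples must have the same size")
    pairs = zip(sorted(xs), sorted(ys))
    return sum(abs(x - y) for x, y in pairs) / len(xs)

real = [0.0, 0.1, 0.2]
for shift in (1.0, 5.0, 10.0):
    fake = [x + shift for x in real]
    print(shift, wasserstein_1d(real, fake))  # distance tracks the shift
</code></pre></div>
<p>Unlike a divergence that is constant for non-overlapping supports, this distance keeps growing as the generated samples move away from the real ones, so its gradient says which way to move.</p>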
<h4 id="why-earth-mover-distance-matters">Why Earth-Mover Distance Matters</h4>
<table>
  <thead>
      <tr>
          <th>Jensen-Shannon Divergence</th>
          <th>Earth-Mover Distance</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Can be discontinuous</td>
          <td><strong>Always continuous</strong></td>
      </tr>
      <tr>
          <td>May have vanishing gradients</td>
          <td><strong>Meaningful gradients everywhere</strong></td>
      </tr>
      <tr>
          <td>Limited convergence guarantees</td>
          <td><strong>Broader convergence properties</strong></td>
      </tr>
  </tbody>
</table>
<h4 id="wgan-implementation">WGAN Implementation</h4>
<p>Because the infimum over transport plans is intractable to compute directly, WGAN uses the <strong>Kantorovich-Rubinstein duality</strong>:</p>
<ol>
<li><strong>Train a critic function</strong> $f$ to approximate the Wasserstein distance</li>
<li><strong>Constrain the critic</strong> to be Lipschitz (WGAN clips the weights, which enforces a $K$-Lipschitz bound for some constant $K$)</li>
<li><strong>Optimize the generator</strong> to minimize this distance</li>
</ol>
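<p>For reference, the duality these steps rely on is, as stated in the WGAN paper (with $P_r$ the real distribution and $P_g$ the generator distribution):</p>
<p>$$
W_1(P_r, P_g) = \sup_{\|f\|_L \le 1} \mathbb{E}_{x \sim P_r}[f(x)] - \mathbb{E}_{x \sim P_g}[f(x)]
$$</p>
<p>The supremum ranges over all 1-Lipschitz functions $f$. The critic network is a parametric stand-in for $f$, which is why its Lipschitz constant has to be controlled.</p>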

<figure class="post-figure center ">
    <img src="/img/wasserstein.webp"
         alt="WGAN training results showing stable convergence"
         title="WGAN training results showing stable convergence"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption"><strong>WGAN Results</strong>: Demonstrating improved training stability and meaningful loss curves. <em>Source: Wasserstein GAN paper</em></figcaption>
    
</figure>

<h4 id="key-wgan-benefits">Key WGAN Benefits</h4>
<p><strong>Meaningful loss function</strong>: Loss correlates with sample quality<br>
<strong>Improved stability</strong>: Less prone to mode collapse<br>
<strong>Theoretical guarantees</strong>: Solid mathematical foundation<br>
<strong>Better convergence</strong>: Works even when distributions don&rsquo;t overlap</p>
<h3 id="improved-wgan-solving-the-weight-clipping-problem">Improved WGAN: Solving the Weight Clipping Problem</h3>
<p><a href="https://arxiv.org/abs/1704.00028">Improved WGAN</a> (WGAN-GP) addresses a critical flaw in the original WGAN: <strong>weight clipping</strong>.</p>
<h4 id="the-problem-with-weight-clipping">The Problem with Weight Clipping</h4>
<p>Original WGAN clips weights to a small range to enforce the Lipschitz constraint:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#75715e"># Problematic approach</span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">for</span> param <span style="color:#f92672">in</span> critic<span style="color:#f92672">.</span>parameters():
</span></span><span style="display:flex;"><span>    param<span style="color:#f92672">.</span>data<span style="color:#f92672">.</span>clamp_(<span style="color:#f92672">-</span><span style="color:#ae81ff">0.01</span>, <span style="color:#ae81ff">0.01</span>)
</span></span></code></pre></div><p><strong>Issues with clipping</strong>:</p>
<ul>
<li>Forces critic to use extremely simple functions</li>
<li>Pushes weights toward extreme values ($\pm c$)</li>
<li>Can lead to poor gradient flow</li>
<li>Capacity limitations hurt performance</li>
</ul>
<h4 id="the-gradient-penalty-solution">The Gradient Penalty Solution</h4>
<p>WGAN-GP introduces a <strong>gradient penalty term</strong> to constrain the critic:</p>
<p>$$
L = E_{\tilde{x} \sim P_g}[D(\tilde{x})] - E_{x \sim P_r}[D(x)] + \lambda E_{\hat{x}}[(||\nabla_{\hat{x}} D(\hat{x})||_2 - 1)^2]
$$</p>
<p>Where $\hat{x}$ are points sampled uniformly along straight lines between real and generated data points.</p>
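<p>A minimal sketch of the penalty term, assuming a toy critic $D(x) = \tanh(w \cdot x)$ whose input gradient is available in closed form. Real implementations compute this gradient with automatic differentiation. The function names and the toy critic are illustrative assumptions, while $\lambda = 10$ is the value used in the WGAN-GP paper.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import math
import random

def critic_grad_norm(w, x):
    # Toy critic D(x) = tanh(w . x); its input gradient is (1 - tanh^2(w . x)) * w.
    s = sum(wi * xi for wi, xi in zip(w, x))
    scale = 1.0 - math.tanh(s) ** 2
    return abs(scale) * math.sqrt(sum(wi * wi for wi in w))

def gradient_penalty(w, real, fake, lam=10.0, rng=random):
    total = 0.0
    for x, x_tilde in zip(real, fake):
        eps = rng.random()  # uniform in [0, 1)
        # Interpolate between a real point and a generated point.
        x_hat = [eps * a + (1.0 - eps) * b for a, b in zip(x, x_tilde)]
        total += (critic_grad_norm(w, x_hat) - 1.0) ** 2
    return lam * total / len(real)

real = [[1.0, 0.0], [0.5, 0.5]]
fake = [[-1.0, 0.0], [0.0, -0.5]]
print(gradient_penalty([2.0, 0.0], real, fake, rng=random.Random(0)))
</code></pre></div>
<p>The penalty is zero only when the gradient norm equals 1 at every interpolated point, so it pulls the critic toward the Lipschitz constraint without restricting its weights.</p>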
<p><strong>Advantages</strong>:</p>
<ul>
<li>No capacity limitations</li>
<li>Better gradient flow</li>
<li>More stable training</li>
<li>Works across different architectures</li>
</ul>
<h3 id="lsgan-the-power-of-least-squares">LSGAN: The Power of Least Squares</h3>
<p><a href="https://arxiv.org/abs/1611.04076">Least Squares GAN</a> takes a different approach. It replaces the logarithmic loss with <strong>L2 (least squares) loss</strong>.</p>
<h4 id="motivation-beyond-binary-classification">Motivation: Beyond Binary Classification</h4>
<p>Traditional GANs use log loss, which focuses primarily on correct classification:</p>
<ul>
<li>Real sample correctly classified → minimal penalty</li>
<li>Fake sample correctly classified → minimal penalty</li>
<li>Distance from decision boundary ignored</li>
</ul>
<h4 id="l2-loss-distance-matters">L2 Loss: Distance Matters</h4>
<p>LSGAN uses L2 loss, which <strong>penalizes samples by their squared distance from the target label</strong>, even when they fall on the correct side of the decision boundary:</p>
<p>$$
\min_D V_{LSGAN}(D) = \frac{1}{2}E_{x \sim p_{data}(x)}[(D(x) - b)^2] + \frac{1}{2}E_{z \sim p_z(z)}[(D(G(z)) - a)^2]
$$</p>
<p>$$
\min_G V_{LSGAN}(G) = \frac{1}{2}E_{z \sim p_z(z)}[(D(G(z)) - c)^2]
$$</p>
<p>Where typically: $a = 0$ (fake label), $b = c = 1$ (real label)</p>
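<p>The two objectives translate directly into code. This sketch assumes raw (unsquashed) discriminator outputs supplied as plain lists, and the function names are hypothetical.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">def lsgan_d_loss(d_real, d_fake, a=0.0, b=1.0):
    # Discriminator: push D(x) toward b and D(G(z)) toward a.
    real_term = sum((d - b) ** 2 for d in d_real) / len(d_real)
    fake_term = sum((d - a) ** 2 for d in d_fake) / len(d_fake)
    return 0.5 * real_term + 0.5 * fake_term

def lsgan_g_loss(d_fake, c=1.0):
    # Generator: push D(G(z)) toward c.
    return 0.5 * sum((d - c) ** 2 for d in d_fake) / len(d_fake)

# Both fake samples sit on the "real" side, but only one sits at the target.
print(lsgan_g_loss([1.0]))  # 0.0
print(lsgan_g_loss([3.0]))  # 2.0
</code></pre></div>
<p>The second sample would already fool a classifier, yet it still receives a gradient because it lies far from the target value $c$.</p>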
<h4 id="benefits-of-l2-loss">Benefits of L2 Loss</h4>
<table>
  <thead>
      <tr>
          <th>Log Loss</th>
          <th>L2 Loss</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Binary focus</td>
          <td><strong>Distance-aware</strong></td>
      </tr>
      <tr>
          <td>Can saturate</td>
          <td><strong>Informative gradients</strong></td>
      </tr>
      <tr>
          <td>Sharp decision boundary</td>
          <td><strong>Smooth decision regions</strong></td>
      </tr>
  </tbody>
</table>

<figure class="post-figure center ">
    <img src="/img/lsgan-result.webp"
         alt="LSGAN generated samples showing improved quality"
         title="LSGAN generated samples showing improved quality"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption"><strong>LSGAN Results</strong>: Demonstrating improved sample quality through distance-aware loss functions. <em>Source: LSGAN paper</em></figcaption>
    
</figure>

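<p><strong>Key insight</strong>: When the labels satisfy $b - c = 1$ and $b - a = 2$ (for example $a = -1$, $b = 1$, $c = 0$), minimizing the LSGAN objective minimizes the Pearson $\chi^2$ divergence between $p_{\text{data}} + p_g$ and $2p_g$, where $p_g$ is the generator distribution. This provides a smoother optimization landscape than JSD. The 0-1 labels listed above do not satisfy this condition, but they are a common practical choice.</p>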
<h3 id="relaxed-wasserstein-gan-rwgan">Relaxed Wasserstein GAN (RWGAN)</h3>
<p><a href="https://arxiv.org/abs/1705.07164">Relaxed WGAN</a> bridges the gap between WGAN and WGAN-GP, proposing a <strong>general framework</strong> for designing GAN objectives.</p>
<h4 id="key-innovations">Key Innovations</h4>
<p><strong>Asymmetric weight clamping</strong>: RWGAN introduces an asymmetric approach that provides better balance.</p>
<p><strong>Relaxed Wasserstein divergences</strong>: A generalized framework that extends the Wasserstein distance, enabling systematic design of new GAN variants while maintaining theoretical guarantees.</p>
<h4 id="benefits">Benefits</h4>
<ul>
<li>Better convergence properties than standard WGAN</li>
<li>Framework for designing new loss functions and GAN architectures</li>
<li>Competitive performance with other Wasserstein-based methods</li>
</ul>
<p><strong>Key insight</strong>: RWGAN parameterized with KL divergence shows excellent performance while maintaining the theoretical foundations that make Wasserstein GANs attractive.</p>
<h3 id="statistical-distance-approaches">Statistical Distance Approaches</h3>
<p>Several GAN variants focus on minimizing specific statistical distances between distributions.</p>
<h4 id="mcgan-mean-and-covariance-matching">McGAN: Mean and Covariance Matching</h4>
<p><a href="https://arxiv.org/abs/1702.08398">McGAN</a> belongs to the Integral Probability Metric (IPM) family, using <strong>statistical moments</strong> as the distance measure.</p>
<p><strong>Approach</strong>: Match first and second-order statistics:</p>
<ul>
<li><strong>Mean matching</strong>: Align distribution centers</li>
<li><strong>Covariance matching</strong>: Align distribution shapes</li>
</ul>
<p>Moment-matching objectives like this are conceptually related to settings where aligning statistical moments matters, such as matching a generated distribution to a target physical distribution (e.g., molecular conformations). McGAN itself, however, was introduced and demonstrated as an IPM method for image generation.</p>
<p><strong>Limitation</strong>: Relies on weight clipping like original WGAN.</p>
<h4 id="gmmn-maximum-mean-discrepancy">GMMN: Maximum Mean Discrepancy</h4>
<p><a href="https://arxiv.org/abs/1502.02761">Generative Moment Matching Networks</a> eliminate the discriminator entirely, directly minimizing <strong>Maximum Mean Discrepancy (MMD)</strong>.</p>
<p><strong>MMD Intuition</strong>: Compare distributions by their means in a high-dimensional feature space:</p>
<p>$$
\text{MMD}^2(X, Y) = ||E[\phi(x)] - E[\phi(y)]||^2
$$</p>
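<p>Expanding the squared norm with a kernel $k(x, y) = \langle \phi(x), \phi(y) \rangle$ gives $\text{MMD}^2 = E[k(x, x')] - 2E[k(x, y)] + E[k(y, y')]$, which can be estimated from samples without ever forming $\phi$. The sketch below uses the simple biased estimator on scalar samples with a Gaussian kernel. The bandwidth <code>sigma</code> and the function names are assumptions for illustration.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import math

def gaussian_kernel(x, y, sigma=1.0):
    return math.exp(-((x - y) ** 2) / (2.0 * sigma ** 2))

def mean_kernel(xs, ys, sigma):
    total = sum(gaussian_kernel(x, y, sigma) for x in xs for y in ys)
    return total / (len(xs) * len(ys))

def mmd2(xs, ys, sigma=1.0):
    # Biased estimate: E k(x, x') - 2 E k(x, y) + E k(y, y').
    k_xx = mean_kernel(xs, xs, sigma)
    k_xy = mean_kernel(xs, ys, sigma)
    k_yy = mean_kernel(ys, ys, sigma)
    return k_xx - 2.0 * k_xy + k_yy

xs = [0.0, 0.2, 0.4]
print(mmd2(xs, xs))                     # identical samples
print(mmd2(xs, [x + 3.0 for x in xs]))  # shifted samples give a larger value
</code></pre></div>
<p>The double loop over all sample pairs is the source of the quadratic cost listed under the drawbacks.</p>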
<p><strong>Benefits</strong>:</p>
<ul>
<li>Simple, discriminator-free training</li>
<li>Theoretical guarantees</li>
<li>Can incorporate autoencoders for better MMD estimation</li>
</ul>
<p><strong>Drawbacks</strong>:</p>
<ul>
<li>Computationally expensive</li>
<li>Often weaker empirical results</li>
</ul>
<h4 id="mmd-gan-learning-better-kernels">MMD GAN: Learning Better Kernels</h4>
<p><a href="https://arxiv.org/abs/1705.08584">MMD GAN</a> improves GMMN by <strong>learning optimal kernels</strong> adversarially to improve upon fixed Gaussian kernels.</p>
<p><strong>Innovation</strong>: Combine GAN adversarial training with MMD objective for the best of both worlds.</p>
<h3 id="different-distance-metrics">Different Distance Metrics</h3>
<h4 id="cramer-gan-addressing-sample-bias">Cramer GAN: Addressing Sample Bias</h4>
<p><a href="https://arxiv.org/abs/1705.10743">Cramer GAN</a> identifies a critical issue with WGAN: <strong>biased sample gradients</strong>.</p>
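<p><strong>The Problem</strong>: The Cramer GAN paper lists three properties a distance should have for sample-based training. The Wasserstein distance used by WGAN has the first two but lacks the third:</p>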
<ol>
<li><strong>Sum invariance</strong> (satisfied)</li>
<li><strong>Scale sensitivity</strong> (satisfied)</li>
<li><strong>Unbiased sample gradients</strong> (not satisfied)</li>
</ol>
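<p><strong>The Solution</strong>: Use the <strong>Cramer distance</strong>, which satisfies all three properties. For univariate distributions with cumulative distribution functions $F_\mu$ and $F_\nu$:</p>
<p>$$
d_C^2(\mu, \nu) = \int_{-\infty}^{\infty} \left(F_\mu(x) - F_\nu(x)\right)^2 \, dx
$$</p>
<p>For multivariate data, Cramer GAN uses its generalization, the energy distance.</p>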
<p><strong>Benefit</strong>: More reliable gradients lead to better training dynamics.</p>
<h4 id="fisher-gan-chi-square-distance">Fisher GAN: Chi-Square Distance</h4>
<p><a href="https://arxiv.org/abs/1705.09675">Fisher GAN</a> uses a <strong>data-dependent constraint</strong> on the critic&rsquo;s second-order moments (variance).</p>
<p><strong>Key Innovation</strong>: The constraint naturally bounds the critic without manual techniques:</p>
<ul>
<li>No weight clipping needed</li>
<li>No gradient penalties required</li>
<li>Constraint emerges from the objective itself</li>
</ul>
<p><strong>Distance</strong>: Approximates the <strong>Chi-square distance</strong> as critic capacity increases:</p>
<p>$$
\chi^2(P, Q) = \int \frac{(P(x) - Q(x))^2}{Q(x)} dx
$$</p>
<p>The Fisher GAN essentially measures the Mahalanobis distance, which accounts for correlated variables relative to the distribution&rsquo;s centroid. This ensures the generator and critic remain bounded, and as the critic&rsquo;s capacity increases, it estimates the Chi-square distance.</p>
<p><strong>Benefits</strong>:</p>
<ul>
<li>Efficient computation</li>
<li>Training stability</li>
<li>Unconstrained critic capacity</li>
</ul>
<h3 id="beyond-traditional-gans-alternative-approaches">Beyond Traditional GANs: Alternative Approaches</h3>
<p>The following variants explore fundamentally different architectures and training paradigms.</p>
<h4 id="ebgan-energy-based-discrimination">EBGAN: Energy-Based Discrimination</h4>
<p><a href="https://arxiv.org/abs/1609.03126">Energy-Based GAN</a> replaces the discriminator with an <strong>autoencoder</strong>.</p>
<p><strong>Key insight</strong>: Use reconstruction error as the discrimination signal:</p>
<ul>
<li>Good data → Low reconstruction error</li>
<li>Poor data → High reconstruction error</li>
</ul>
<p><strong>Architecture</strong>:</p>
<ol>
<li>Train autoencoder on real data</li>
<li>Generator creates samples</li>
<li>Poor generated samples have high reconstruction loss</li>
<li>This loss drives generator improvement</li>
</ol>
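<p>For reference, the EBGAN paper formalizes these steps with a margin (hinge) loss. Writing $D(\cdot)$ for the autoencoder&rsquo;s reconstruction error (the &ldquo;energy&rdquo;), $m$ for a positive margin, and $[\cdot]^{+} = \max(0, \cdot)$:</p>
<p>$$
L_D(x, z) = D(x) + [m - D(G(z))]^{+}
$$</p>
<p>$$
L_G(z) = D(G(z))
$$</p>
<p>The discriminator lowers the energy of real samples and raises the energy of generated samples until it reaches the margin $m$, while the generator tries to produce samples with low energy.</p>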
<p><strong>Benefits</strong>:</p>
<ul>
<li>Fast and stable training</li>
<li>Robust to hyperparameter changes</li>
<li>No need to balance discriminator/generator</li>
</ul>
<h4 id="began-boundary-equilibrium">BEGAN: Boundary Equilibrium</h4>
<p><a href="https://arxiv.org/abs/1703.10717">BEGAN</a> combines EBGAN&rsquo;s autoencoder approach with WGAN-style loss functions.</p>
<p><strong>Innovation</strong>: Dynamic equilibrium parameter $k_t$ that balances:</p>
<ul>
<li>Real data reconstruction quality</li>
<li>Generated data reconstruction quality</li>
</ul>
<p><strong>Equilibrium equation</strong>:</p>
<p>$$
L_D = L(x) - k_t L(G(z))
$$</p>
<p>$$
k_{t+1} = k_t + \lambda(\gamma L(x) - L(G(z)))
$$</p>
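<p>The two update rules can be sketched as one bookkeeping step per training iteration. The generator loss $L_G = L(G(z))$ comes from the BEGAN paper. The function name, the default values for $\gamma$ and $\lambda$, and the clamp of $k_t$ to $[0, 1]$ are assumptions chosen for illustration.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">def began_step(loss_real, loss_fake, k, gamma=0.5, lambda_k=0.001):
    # loss_real = L(x), loss_fake = L(G(z)): autoencoder reconstruction losses.
    d_loss = loss_real - k * loss_fake
    g_loss = loss_fake
    k_next = k + lambda_k * (gamma * loss_real - loss_fake)
    k_next = min(max(k_next, 0.0), 1.0)  # keep k in [0, 1]
    return d_loss, g_loss, k_next

# Generated samples reconstruct too well relative to gamma * L(x), so k increases
# and the discriminator puts more weight on penalizing them next step.
d_loss, g_loss, k = began_step(loss_real=1.0, loss_fake=0.2, k=0.0)
print(d_loss, g_loss, k)
</code></pre></div>
<p>The update behaves like a proportional controller that steers training toward $L(G(z)) = \gamma L(x)$.</p>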
<h4 id="magan-adaptive-margins">MAGAN: Adaptive Margins</h4>
<p><a href="https://arxiv.org/abs/1704.03817">MAGAN</a> improves EBGAN by making the margin in the hinge loss <strong>adaptive over time</strong>.</p>
<p><strong>Concept</strong>: Start with a large margin, gradually reduce it as training progresses:</p>
<ul>
<li>Early training: Focus on major differences</li>
<li>Later training: Fine-tune subtle details</li>
</ul>
<p><strong>Result</strong>: Better sample quality and training stability.</p>
<h2 id="summary-the-evolution-of-gan-objectives">Summary: The Evolution of GAN Objectives</h2>
<p>The evolution of GAN objective functions reflects the field&rsquo;s progression toward more stable and theoretically grounded training procedures. Each variant addresses specific limitations in earlier approaches.</p>
<h3 id="complete-reference-table">Complete Reference Table</h3>
<table>
  <thead>
      <tr>
          <th><strong>GAN Variant</strong></th>
          <th><strong>Key Innovation</strong></th>
          <th><strong>Main Benefit</strong></th>
          <th><strong>Limitation</strong></th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Original GAN</strong></td>
          <td>Jensen-Shannon divergence</td>
          <td>Foundation of adversarial training</td>
          <td>Vanishing gradients, mode collapse</td>
      </tr>
      <tr>
          <td><strong>WGAN</strong></td>
          <td>Earth-Mover distance</td>
          <td>Meaningful loss, better stability</td>
          <td>Weight clipping issues</td>
      </tr>
      <tr>
          <td><strong>WGAN-GP</strong></td>
          <td>Gradient penalty</td>
          <td>Solves weight clipping problems</td>
          <td>Additional hyperparameter tuning</td>
      </tr>
      <tr>
          <td><strong>LSGAN</strong></td>
          <td>Least squares loss</td>
          <td>Better gradients, less saturation</td>
          <td>May converge to non-optimal points</td>
      </tr>
      <tr>
          <td><strong>RWGAN</strong></td>
          <td>Relaxed Wasserstein framework</td>
          <td>General framework for new designs</td>
          <td>Complex theoretical setup</td>
      </tr>
      <tr>
          <td><strong>McGAN</strong></td>
          <td>Mean/covariance matching</td>
          <td>Simple statistical alignment</td>
          <td>Limited by weight clipping</td>
      </tr>
      <tr>
          <td><strong>GMMN</strong></td>
          <td>Maximum mean discrepancy</td>
          <td>No discriminator needed</td>
          <td>Computationally expensive</td>
      </tr>
      <tr>
          <td><strong>MMD GAN</strong></td>
          <td>Adversarial kernels for MMD</td>
          <td>Improved GMMN performance</td>
          <td>Still computationally heavy</td>
      </tr>
      <tr>
          <td><strong>Cramer GAN</strong></td>
          <td>Cramer distance</td>
          <td>Unbiased sample gradients</td>
          <td>Complex implementation</td>
      </tr>
      <tr>
          <td><strong>Fisher GAN</strong></td>
          <td>Chi-square distance</td>
          <td>Self-constraining critic</td>
          <td>Limited empirical validation</td>
      </tr>
      <tr>
          <td><strong>EBGAN</strong></td>
          <td>Autoencoder discriminator</td>
          <td>Fast, stable training</td>
          <td>Requires careful regularization</td>
      </tr>
      <tr>
          <td><strong>BEGAN</strong></td>
          <td>Boundary equilibrium</td>
          <td>Dynamic training balance</td>
          <td>Additional equilibrium parameter</td>
      </tr>
      <tr>
          <td><strong>MAGAN</strong></td>
          <td>Adaptive margin</td>
          <td>Progressive refinement</td>
          <td>Margin scheduling complexity</td>
      </tr>
  </tbody>
</table>
<h3 id="practical-recommendations">Practical Recommendations</h3>
<p>For practitioners, the choice depends on specific requirements and engineering tradeoffs:</p>
<ul>
<li><strong>WGAN-GP</strong>: Best balance of stability and performance for most applications. However, tuning the gradient penalty coefficient $\lambda$ can be sensitive in practice.</li>
<li><strong>LSGAN</strong>: Simpler implementation with good empirical results.</li>
<li><strong>EBGAN</strong>: Fast experimentation and prototyping.</li>
<li><strong>Original GAN</strong>: Educational purposes and understanding fundamentals.</li>
</ul>
<p><strong>Real-World Impact:</strong> In my work training VLMs on terabyte-scale multimodal data and forecasting chaotic physical systems, these foundational dynamics still matter. Most generation today runs on diffusion models or autoregressive transformers, but the loss-design and training-stability lessons that came out of GAN research carry over. The choice of objective function shapes generation quality, training stability, and compute cost.</p>
<hr>
<p><strong>Acknowledgments</strong>: This post was inspired by the excellent survey &ldquo;<a href="https://arxiv.org/abs/1711.05914">How Generative Adversarial Networks and Their Variants Work: An Overview of GAN</a>&rdquo;.</p>
]]></content:encoded></item><item><title>FFTW Compiler in Haskell</title><link>https://hunterheidenreich.com/projects/fftw-compiler-haskell/</link><pubDate>Thu, 15 Mar 2018 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/projects/fftw-compiler-haskell/</guid><description>Reverse-engineering the genfft logic to generate optimized C kernels for Fast Fourier Transforms using Haskell metaprogramming.</description><content:encoded><![CDATA[<p>Written during my sophomore year, this project was an attempt to look inside the &ldquo;black box&rdquo; of one of the fastest Fourier transform libraries: <strong>FFTW</strong>.</p>
<h2 id="overview">Overview</h2>
<p>I sought to replicate the logic of FFTW&rsquo;s <code>genfft</code>: a metaprogram that generates straight-line, highly optimized C code. The goal was to understand how abstract algebra (group theory) could be translated into efficient machine code through symbolic manipulation.</p>
<h2 id="features">Features</h2>
<p>This was my first deep dive into <strong>functional metaprogramming</strong> and <strong>compiler theory</strong>:</p>
<ul>
<li><strong>Symbolic AST:</strong> Modeled mathematical operations as a Directed Acyclic Graph (DAG) in Haskell (<code>data Node</code>), separating the <em>definition</em> of the math from its <em>execution</em>.</li>
<li><strong>Algebraic Simplification:</strong> Implemented a symbolic optimization pass that pruned operations at compile-time (e.g., eliminating multiplications by $1$, $0$, or $-1$) before code generation.</li>
<li><strong>Monadic State Management:</strong> Used Haskell&rsquo;s <code>State</code> Monad to manage the graph construction and memoization, ensuring common subexpressions (like reusable cosine factors) were calculated only once.</li>
<li><strong>Code Generation:</strong> The system outputted unrolled, straight-line C code (e.g., <code>fftw4.c</code>), mimicking the &ldquo;codelets&rdquo; used by the actual FFTW library.</li>
</ul>
<h2 id="usage">Usage</h2>
<p>The compiler is run via the command line, taking the desired FFT size as input and outputting the optimized C code.</p>
<h2 id="results">Results</h2>
<p>Looking back, this project represents a pivotal moment where I moved from &ldquo;writing programs&rdquo; to &ldquo;writing tools that write programs.&rdquo;</p>
<ul>
<li><strong>The &ldquo;Magic&rdquo;:</strong> It demystified high-performance computing. I learned that speed often comes from unrolling recursion and managing register pressure at compile time, in addition to writing fast loops.</li>
<li><strong>The &ldquo;Rough Edges&rdquo;:</strong> The scheduler (coloring nodes Red/Blue for register allocation) was a heuristic approximation of the optimal Aho-Johnson-Ullman algorithm.</li>
<li><strong>Legacy:</strong> The core lesson that domain-specific compilers can outperform hand-tuned generic code remains relevant to my current work in optimizing scientific computing kernels.</li>
</ul>
]]></content:encoded></item><item><title>Term Schedule Optimizer</title><link>https://hunterheidenreich.com/projects/term-schedule-optimizer/</link><pubDate>Wed, 15 Feb 2017 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/projects/term-schedule-optimizer/</guid><description>A constraint satisfaction solver built to generate conflict-free university schedules from web-scraped course data.</description><content:encoded><![CDATA[<p>A Python-based automation tool I wrote as a freshman to solve the &ldquo;Term Master Schedule&rdquo; problem (and used throughout my undergrad from 2016 to 2020).</p>
<h2 id="overview">Overview</h2>
<p>Manually creating a university schedule involves solving a <strong>Constraint Satisfaction Problem (CSP)</strong> with multiple variables:</p>
<ul>
<li><strong>Hard Constraints:</strong> No time overlaps between classes.</li>
<li><strong>Soft Constraints:</strong> Preferences for &ldquo;no 8 AMs,&rdquo; specific lunch breaks, or maximizing free days.</li>
</ul>
<p>The naive approach (manually checking every possible combination) becomes intractable as the number of courses and sections grows.</p>
<h2 id="features">Features</h2>
<p>I built a script that:</p>
<ol>
<li><strong>Scraped Data:</strong> Parsed the Drexel WebTMS (Term Master Schedule) using <code>lxml</code> to build a localized dataset of course availability.</li>
<li><strong>Solved for X:</strong> Implemented a <strong>recursive backtracking algorithm</strong> to generate every valid schedule combination that satisfied user-defined constraints.</li>
</ol>
<h3 id="the-algorithm">The Algorithm</h3>
<p>The core of this project is a <code>recursive_generator</code> function that implements a backtracking CSP solver. It performs a recursive depth-first search that:</p>
<ol>
<li>Takes a set of variables (courses).</li>
<li>Checks constraints (time overlaps, lunch hours, max classes per day).</li>
<li>Backtracks when a branch fails.</li>
</ol>
<p>It is the same backtracking pattern used in everything from Sudoku solvers to compiler register allocation.</p>
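<p>A minimal sketch of this pattern follows. It is not the original <code>recursive_generator</code>. The data layout, course names, and function names are hypothetical, and only the hard no-overlap constraint is modeled. Each section is represented by the set of (day, hour) slots it occupies.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">def conflicts(section, chosen):
    # Hard constraint: no (day, hour) slot shared with an already chosen section.
    return any(not section["slots"].isdisjoint(other["slots"]) for _, other in chosen)

def generate_schedules(courses, chosen=()):
    # courses: list of (course_name, list_of_sections); picks one section per course.
    if not courses:
        yield [(name, section["id"]) for name, section in chosen]
        return
    name, sections = courses[0]
    for section in sections:
        if conflicts(section, chosen):
            continue  # dead branch: skip it and try the next section
        yield from generate_schedules(courses[1:], chosen + ((name, section),))

courses = [
    ("CS 260", [{"id": "A", "slots": {("Mon", 9), ("Wed", 9)}},
                {"id": "B", "slots": {("Tue", 13)}}]),
    ("MATH 201", [{"id": "A", "slots": {("Mon", 9)}},
                  {"id": "B", "slots": {("Tue", 13), ("Thu", 13)}}]),
]
for schedule in generate_schedules(courses):
    print(schedule)
</code></pre></div>
<p>Soft constraints such as avoiding early classes could be added as extra checks next to <code>conflicts</code>, or applied as a ranking step over the generated schedules.</p>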
<h2 id="usagegameplay">Usage</h2>
<p>The tool is run via the command line, taking a list of desired courses and outputting valid schedule combinations.</p>
<h2 id="results">Results</h2>
<p>This tool saved me (and several friends) hours of planning time each quarter. While the scraping logic was fragile (dependent on 2017 HTML structures), the core logic (a depth-first search through the state space of possible schedules) remains a fundamental algorithmic pattern.</p>
]]></content:encoded></item></channel></rss>