<?xml version="1.0" encoding="utf-8" standalone="yes"?><rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom" xmlns:content="http://purl.org/rss/1.0/modules/content/"><channel><title>Machine-Learning on Hunter Heidenreich | Senior AI Research Scientist</title><link>https://hunterheidenreich.com/tags/machine-learning/</link><description>Recent content in Machine-Learning on Hunter Heidenreich | Senior AI Research Scientist</description><image><title>Hunter Heidenreich | Senior AI Research Scientist</title><url>https://hunterheidenreich.com/img/avatar.webp</url><link>https://hunterheidenreich.com/img/avatar.webp</link></image><generator>Hugo -- 0.147.7</generator><language>en-US</language><copyright>2026 Hunter Heidenreich</copyright><lastBuildDate>Sat, 30 May 2026 00:00:00 +0000</lastBuildDate><atom:link href="https://hunterheidenreich.com/tags/machine-learning/index.xml" rel="self" type="application/rss+xml"/><item><title>MB-nrg: CCSD(T)-Accurate Potentials for Polyalanine</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-simulation/ml-potentials/mb-nrg-polyalanine-ccsdt/</link><pubDate>Sun, 12 Apr 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-simulation/ml-potentials/mb-nrg-polyalanine-ccsdt/</guid><description>MB-nrg decomposes polyalanine into n-mer building blocks fit to DLPNO-CCSD(T) references, reaching coupled-cluster accuracy for gas-phase peptide dynamics.</description><content:encoded><![CDATA[<h2 id="a-modular-mb-nrg-method-for-biomolecular-potentials">A Modular MB-nrg Method for Biomolecular Potentials</h2>
<p>This is a <strong>Method</strong> paper. Zhou and colleagues extend the MB-nrg (many-body energy) formalism to covalently bonded biomolecules and build the first coupled-cluster-accurate potential energy function (PEF) for polyalanine in the gas phase. The contribution has three parts: a generalization of the MB-nrg decomposition from whole-molecule 1-mers to functional-group &ldquo;natural building blocks,&rdquo; a DLPNO-CCSD(T)/aug-cc-pVTZ training protocol driven by parallel-bias metadynamics sampling, and a demonstration that the resulting PEF reproduces alanine dipeptide energetics and AceAla$_9$Nme secondary-structure dynamics more faithfully than the Amber ff14SB and ff19SB force fields.</p>
<h2 id="why-empirical-force-fields-fall-short-for-protein-dynamics">Why Empirical Force Fields Fall Short for Protein Dynamics</h2>
<p>Protein dynamics span femtosecond vibrations to millisecond conformational changes, and capturing them at atomic resolution is central to understanding catalysis, allostery, and ligand binding. Classical force fields such as CHARMM, OPLS, and Amber approximate the potential energy surface with pairwise-additive analytical terms. This functional form struggles with the many-body interactions that shape disordered regions of proteins, including exchange-repulsion, charge transfer, charge penetration, and cooperative hydrogen bonding. Polarizable force fields add induced dipoles but remain empirically parameterized and fail to capture short-range many-body effects from electron-density overlap.</p>
<p>Quantum-mechanical methods avoid these approximations, but <a href="https://en.wikipedia.org/wiki/Coupled_cluster">coupled cluster theory</a> at the CCSD(T) level scales as $\mathcal{O}(N^7)$ with system size, and even DFT remains $\mathcal{O}(N^3)$ to $\mathcal{O}(N^4)$, ruling out direct ab initio molecular dynamics for biomolecules. Fragmentation methods like molecular fractionation with conjugate caps (MFCC) mitigate the cost, but they truncate the many-body expansion at two bodies and miss long-range hydrogen bonding. <a href="/notes/chemistry/molecular-simulation/ml-potentials/dark-side-of-forces/">Machine-learned force fields (MLFFs)</a> reach near-QM accuracy at lower cost, yet they typically train on DFT data (inheriting delocalization errors and poor dispersion), are hard to interpret, and extrapolate unreliably. Existing permutationally invariant polynomial (PIP) approaches scale factorially in the number of atoms, capping direct applicability at roughly ten to fifteen atoms per fragment.</p>
<p>MB-nrg PEFs based on the many-body expansion and PIPs have successfully modeled water, halides in water, carbon dioxide, methane, ammonia, dinitrogen pentoxide, and N-methylacetamide. Extending them to covalently bonded biomolecules requires rethinking what counts as a &ldquo;body.&rdquo;</p>
<h2 id="building-polyalanine-from-functional-group-n-mers">Building Polyalanine from Functional-Group n-mers</h2>
<p>The MB-nrg formalism starts from the many-body expansion of the total energy,</p>
<p>$$
E_N(1, \dots, N) = \sum_{i=1}^{N} \varepsilon^{1\mathrm{B}}(i) + \sum_{i&lt;j}^{N} \varepsilon^{2\mathrm{B}}(i,j) + \sum_{i&lt;j&lt;k}^{N} \varepsilon^{3\mathrm{B}}(i,j,k) + \dots + \varepsilon^{N\mathrm{B}}(1, \dots, N)
$$</p>
<p>where each $n$-body contribution is defined recursively as the $n$-mer energy minus all lower-order terms. The full PEF combines physics-based and data-driven components,</p>
<p>$$
V_{\mathrm{MB\text{-}nrg}} = V_{\mathrm{ML}} + V_{\mathrm{phys}}
$$</p>
<p>with $V_{\mathrm{ML}} = V_{\mathrm{ML}}^{1\mathrm{B}} + V_{\mathrm{ML}}^{2\mathrm{B}} + V_{\mathrm{ML}}^{3\mathrm{B}}$ capturing short-range quantum-mechanical interactions, and $V_{\mathrm{phys}} = V_{\mathrm{elec}} + V_{\mathrm{disp}} + V_{\mathrm{rep}}$ supplying electrostatics, dispersion, and repulsion. Dispersion follows a Tang-Toennies damped $C_6/R^6$ form with XDM-derived coefficients; electrostatics uses a Thole-modified self-consistent polarization model inherited from MB-pol; the repulsion term is a Lennard-Jones $R^{-12}$ contribution borrowed from Amber ff14SB, activated only for non-bonded atom pairs not covered by a PIP.</p>
<p>Each data-driven $n$-body term is expressed as</p>
<p>$$
V_{\mathrm{ML}}^{n\mathrm{B}} = \sum_{\mathrm{M}_1 &lt; \dots &lt; \mathrm{M}_n}^{N} s^{n\mathrm{B}}(\mathrm{M}_1, \dots, \mathrm{M}_n) \, V_{\mathrm{PIP}}^{n\mathrm{B}}(\mathrm{M}_1, \dots, \mathrm{M}_n)
$$</p>
<p>where $V_{\mathrm{PIP}}^{n\mathrm{B}}$ is a permutationally invariant polynomial in Morse-like variables $\xi_{ij} = \exp(-k_{\tau(ij)} R_{ij})$ and $s^{n\mathrm{B}}$ is a switching function.</p>
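<p>As a minimal illustration of what permutational invariance buys, consider a toy A$_2$B fragment. This is chosen for brevity and is not one of the paper&rsquo;s PIPs; the labels $\xi_1$, $\xi_2$, $\xi_3$, and $\eta_l$ are notation introduced for this sketch. Let $\xi_1 = \xi_{\mathrm{A}_1\mathrm{B}}$, $\xi_2 = \xi_{\mathrm{A}_2\mathrm{B}}$, and $\xi_3 = \xi_{\mathrm{A}_1\mathrm{A}_2}$. Swapping the two equivalent A atoms exchanges $\xi_1$ and $\xi_2$, so the nine raw monomials up to degree 2 collapse into six symmetrized terms:</p>
<p>$$
\eta_1 = \xi_1 + \xi_2, \quad \eta_2 = \xi_3, \quad \eta_3 = \xi_1^2 + \xi_2^2, \quad \eta_4 = \xi_1 \xi_2, \quad \eta_5 = (\xi_1 + \xi_2) \, \xi_3, \quad \eta_6 = \xi_3^2
$$</p>
<p>The toy polynomial is then $V_{\mathrm{PIP}} = \sum_l c_l \, \eta_l$, linear in the coefficients $c_l$ and non-linear only through the decay constants $k_{\tau(ij)}$ inside each $\xi_{ij}$.</p>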
<p>The key extension in this paper, building on earlier work on linear alkanes, is to treat functional groups (not whole molecules) as 1-mers. An Ace-capped, Nme-capped polyalanine chain decomposes into three distinct 1-mer types (-CH-, CH$_3$-, -CONH-), five distinct 2-mer types, and six distinct 3-mer types, for 14 unique PIPs that cover every $n$-mer appearing in any AceAla$_n$Nme chain. Cleaving covalent bonds between 1-mers would produce radicals, so the authors cap dangling valences with &ldquo;ghost&rdquo; hydrogen atoms at fixed C-H (1.14 Å) and N-H (1.09 Å) distances. Each $n$-mer energy is then referenced to its own optimized H-capped structure,</p>
<p>$$
E_n(1, \dots, n) = E_n^{\mathrm{H\text{-}capped}}(1, \dots, n) - E_n^{\mathrm{H\text{-}capped,opt}}(1, \dots, n).
$$</p>
<p>In the current implementation, only covalently bonded $n$-mers receive PIPs, the 2-body contribution from a dimer with one intervening 1-mer is folded into the corresponding 3-body term, and non-bonded 1-mers interact through the Lennard-Jones repulsion alone. Crucially, no whole-chain polyalanine data enters any stage of training: every PIP is parameterized on isolated $n$-mer configurations, and the total energy is reconstructed through the many-body expansion.</p>
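<p>To make the bookkeeping concrete, the number of 1-mers and covalently bonded 2-mers and 3-mers can be worked out for a general AceAla$_n$Nme chain. This is a sketch derived from the chemical structure, assuming every methyl group (Ace, Nme, and each side chain) counts as a CH$_3$- 1-mer; the symbols $N_1$, $N_2$, and $N_3$ are introduced here and are not the paper&rsquo;s notation. The 1-mers form a tree with $n + 1$ -CONH- groups (two bonded neighbors each), $n$ -CH- groups (three neighbors each), and $n + 2$ methyls (one neighbor each), so</p>
<p>$$
N_1 = 3n + 3, \qquad N_2 = N_1 - 1 = 3n + 2, \qquad N_3 = (n + 1)\binom{2}{2} + n\binom{3}{2} = 4n + 1
$$</p>
<p>For alanine dipeptide ($n = 1$) this gives 6 1-mers, 5 bonded 2-mers, and 5 bonded 3-mers; for AceAla$_9$Nme it gives 30, 29, and 37. Under these assumptions the number of PIP evaluations grows linearly with chain length, even though only 14 distinct PIPs are ever fit.</p>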
<h2 id="training-on-dlpno-ccsdt-with-metadynamics-sampling">Training on DLPNO-CCSD(T) with Metadynamics Sampling</h2>
<p>Training sets are generated for each of the 14 $n$-mer types using <a href="https://en.wikipedia.org/wiki/Metadynamics">parallel-bias metadynamics (PBMetaD)</a> with partitioned families, biasing heavy-atom bonds, angles, and dihedrals at 300 K, 500 K, and 700 K. The sampling runs in LAMMPS interfaced with PLUMED, using modified OPLS/CM1A and Amber ff14SB force fields. For each $n$-mer, 200,000 candidate configurations are sampled, then reduced to roughly 10,000-20,000 training configurations (and about 1,000 test configurations) through Mini-batch K-means clustering on chemically equivalent pairwise distances. Reference energies are computed at the DLPNO-CCSD(T)/aug-cc-pVTZ level in ORCA.</p>
<p>Each PIP minimizes a weighted, ridge-regularized sum of squared errors,</p>
<p>$$
\chi^2 = \sum_{k \in \mathcal{S}} w_k \left[ V^{n\mathrm{B}}(k) - \varepsilon^{n\mathrm{B}}(k) \right]^2 + \Gamma^2 \sum_l c_l^2
$$</p>
<p>with $\Gamma = 0.0005$ throughout and low-energy bias weights</p>
<p>$$
w_k = \left( \frac{\delta E}{\varepsilon^{n\mathrm{B}}(k) - \varepsilon^{n\mathrm{B}}_{\min} + \delta E} \right)^2.
$$</p>
<p>MB-Fit handles the fit, combining simplex optimization for non-linear parameters $k_{\tau(ij)}$ with ridge regression for the linear coefficients $c_l$.</p>
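<p>For fixed non-linear parameters, the linear step has a closed form. The expression below is the standard weighted ridge solution, written in notation introduced for this note: $A_{kl}$ is the value of the $l$-th symmetrized monomial on configuration $k$, $W = \mathrm{diag}(w_k)$, and $\boldsymbol{\varepsilon}$ collects the reference $n$-body energies. It is a sketch of the technique, not a description of how MB-Fit organizes the computation internally.</p>
<p>$$
\mathbf{c} = \left( A^{\top} W A + \Gamma^2 I \right)^{-1} A^{\top} W \boldsymbol{\varepsilon}
$$</p>
<p>This follows from setting the gradient of $\chi^2$ with respect to $\mathbf{c}$ to zero. It suggests a nested scheme in which the simplex search explores only the $k_{\tau(ij)}$ and the coefficients are re-solved at each trial point, though that workflow is an assumption consistent with the description above.</p>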
<p>Table 1 in the paper reports, for each of the 14 PIPs, the polynomial degree (5 for the smaller -CH- and CH$_3$- 1-mers, 3 for the larger -CONH- 1-mer and for all 2-mers and 3-mers), the number of symmetrized monomials (ranging from 635 for the -CH- and CH$_3$- 1-mers to 2871 for the -CONH-CH-CONH- 3-mer), the training-set size, and RMSDs for the train and test splits. All training RMSDs stay below 0.4 kcal/mol and all test RMSDs below 0.5 kcal/mol, with the smallest errors for the -CH- and CH$_3$- 1-mers (0.05 kcal/mol train, 0.14 kcal/mol test) and the largest test RMSD (0.47 kcal/mol) for the -CONH-CH- 2-mer.</p>
<p>MD validations run in LAMMPS interfaced with MBX and PLUMED. For alanine dipeptide metadynamics, bias potentials on the backbone $\varphi$ and $\psi$ angles are deposited every 500 steps with a 1.0 kJ/mol height and 11.46° width over 10 ns trajectories in the NVT ensemble, using the velocity-Verlet integrator with a 0.5 fs time step. Analogous MetaD runs with Amber ff14SB and ff19SB are performed in Amber23. The longer AceAla$_9$Nme trajectories start from fully extended structures and run in a 100 Å × 100 Å × 100 Å gas-phase box.</p>
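<p>The alanine dipeptide metadynamics settings imply a few derived quantities that are easy to check. This is plain arithmetic on the numbers quoted above; the Gaussian count assumes deposition continues at a fixed stride for the full trajectory.</p>
<p>$$
\frac{10\ \mathrm{ns}}{0.5\ \mathrm{fs}} = 2 \times 10^{7}\ \text{steps}, \qquad \frac{2 \times 10^{7}}{500} = 4 \times 10^{4}\ \text{Gaussians}, \qquad 11.46^{\circ} \times \frac{\pi}{180^{\circ}} \approx 0.2\ \mathrm{rad}
$$</p>
<p>A Gaussian is therefore added every 0.25 ps, and the 1.0 kJ/mol hill height corresponds to about 0.24 kcal/mol.</p>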
<h2 id="ccsdt-energy-landscapes-free-energy-surfaces-and-helix-dynamics">CCSD(T) Energy Landscapes, Free-Energy Surfaces, and Helix Dynamics</h2>
<p><strong>Alanine dipeptide 2D PES.</strong> Alanine dipeptide geometries are optimized on a <a href="https://en.wikipedia.org/wiki/Ramachandran_plot">Ramachandran</a> grid with 10° spacing at the RI-MP2/def2-TZVP level and then evaluated at DLPNO-CCSD(T)/aug-cc-pVTZ. Despite never seeing whole alanine dipeptide in training, MB-nrg closely matches the reference locations and relative energies of four minima ($m_1$ to $m_4$), three maxima ($M_1$ to $M_3$), and one saddle point ($X$). Amber ff14SB and ff19SB capture the minima reasonably but badly overshoot the barriers: at $M_1$, MB-nrg deviates from the reference by only -2.41 kcal/mol, while ff14SB and ff19SB overshoot it by +7.50 and +7.83 kcal/mol. The authors also note that ff19SB incorrectly orders the secondary minima by predicting $m_3$ lower than $m_2$.</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>RMSD, overall (kcal/mol)</th>
          <th>RMSD, $\leq 10$ kcal/mol region (kcal/mol)</th>
          <th>RMSD, $&gt; 10$ kcal/mol region (kcal/mol)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MB-nrg</td>
          <td>1.27</td>
          <td>1.18</td>
          <td>1.59</td>
      </tr>
      <tr>
          <td>Amber ff14SB</td>
          <td>6.33</td>
          <td>5.72</td>
          <td>8.44</td>
      </tr>
      <tr>
          <td>Amber ff19SB</td>
          <td>5.23</td>
          <td>4.79</td>
          <td>6.81</td>
      </tr>
  </tbody>
</table>
<p>The authors attribute MB-nrg&rsquo;s residual high-energy error to terminal methyl groups approaching the backbone in conformations where non-bonded 1-mer interactions are modeled by the simple LJ repulsion rather than an explicit PIP.</p>
<p><strong>Harmonic vibrations.</strong> Normal modes for the $m_1$ and $m_4$ alanine dipeptide conformers, computed by diagonalizing the Hessian, match RI-MP2/def2-TZVP references with mean deviations of 17.41 cm$^{-1}$ and 21.07 cm$^{-1}$ across all 60 modes. The authors acknowledge that some of this discrepancy reflects differences in theoretical levels (MB-nrg is trained to CCSD(T)/aug-cc-pVTZ, while the reference normal modes are computed at RI-MP2/def2-TZVP).</p>
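<p>The mode count follows from the size of the molecule. Alanine dipeptide (C$_6$H$_{12}$N$_2$O$_2$) has 22 atoms, and a non-linear molecule with $N_{\mathrm{atoms}}$ atoms has $3N_{\mathrm{atoms}} - 6$ vibrational modes:</p>
<p>$$
3 \times 22 - 6 = 60
$$</p>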
<p><strong>Free-energy surfaces.</strong> Well-tempered metadynamics at 300 K produces 2D free-energy surfaces over $(\varphi, \psi)$. MB-nrg yields a smoother FES whose extrema line up with the DLPNO-CCSD(T) reference PES. Amber ff14SB and ff19SB remain reasonable near the low-energy $m_1$ and $m_2$ minima but systematically overestimate the barriers near $M_1$, $M_2$, and $M_3$, which the authors argue artificially confines the dipeptide and suppresses conformational transitions.</p>
<p><strong>Secondary structure in AceAla$_9$Nme.</strong> In 600 ps NVT MD starting from a fully extended structure, the <a href="https://en.wikipedia.org/wiki/STRIDE_(algorithm)">STRIDE algorithm</a> tracks residue-level secondary structures. Amber ff14SB and ff19SB collapse into $\alpha$-helices at roughly 40 ps and 80 ps, respectively, with ff19SB remaining especially rigid. MB-nrg takes about 100 ps before helices begin to form and then exhibits continuous oscillations between $3_{10}$- and $\alpha$-helical conformations. Ramachandran plots over the nine alanine residues show MB-nrg exploring the &ldquo;bridge&rdquo; region ($\varphi &lt; 0°$, $-20° \leq \psi \leq 20°$) associated with $3_{10}$-helices and sampling the left-handed $\alpha_L$ region that Amber rarely visits. The authors tie this flexibility to experimental observations of alanine-rich peptides in the gas phase and to similar predictions from GEMS and MACE-OFF.</p>
<h2 id="transferability-without-whole-chain-training-data">Transferability Without Whole-Chain Training Data</h2>
<p>The paper demonstrates that a modular, bottom-up PEF built from functional-group $n$-mers can reach CCSD(T) accuracy for polyalanine in the gas phase without ever training on whole-chain data. Truncating explicit data-driven terms at the 3-body level appears to balance cost and fidelity, with long-range effects handled by many-body polarization in $V_{\mathrm{elec}}$ and by Amber-derived repulsion between distant 1-mers. The 2D PES, harmonic frequencies, free-energy surface, and secondary-structure dynamics each validate a different facet of the model.</p>
<p>The authors are explicit about limitations. The current PEF applies only to gas-phase polyalanine; solvent effects and other amino acids remain open. The Lennard-Jones repulsion for non-bonded 1-mers is a placeholder for eventual 2-body PIPs that should capture short-range interactions during folding. Long-range hydrogen bonding in compact secondary structures (π-helices, $3_{10}$-helices, $\alpha$-helices) may produce non-negligible higher-order many-body contributions that the current 3-body truncation omits. The 2-body contribution from a dimer with one intervening monomer is currently folded into the 3-body term because of steric conflicts between capping hydrogens, and a systematic fix is flagged for future work. The authors position this paper as the first in a series (the &ldquo;I.&rdquo; in the title refers to &ldquo;Polyalanine in the Gas Phase&rdquo;) that will extend MB-nrg to broader biomolecular systems under physiological conditions. The follow-up, <a href="/notes/chemistry/molecular-simulation/ml-potentials/mb-nrg-polyalanine-water/">MB-nrg in Solution: Polyalanine in Water with CCSD(T) PEFs</a>, adds explicit 1-mer/water 2-body PIPs and benchmarks alanine dipeptide solvation.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training</td>
          <td>Per $n$-mer pools from PBMetaD in LAMMPS/PLUMED</td>
          <td>200,000 configurations each, reduced to ~10-20k via Mini-batch K-means</td>
          <td>OPLS/CM1A and Amber ff14SB sampled at 300 K, 500 K, 700 K</td>
      </tr>
      <tr>
          <td>Training labels</td>
          <td>DLPNO-CCSD(T)/aug-cc-pVTZ in ORCA</td>
          <td>14 unique $n$-mer types</td>
          <td>Domain-based local pair natural orbital approximation to canonical CCSD(T)</td>
      </tr>
      <tr>
          <td>Test</td>
          <td>Held-out $n$-mer configurations</td>
          <td>~1,000 per $n$-mer</td>
          <td>Same clustering protocol</td>
      </tr>
      <tr>
          <td>Alanine dipeptide benchmark</td>
          <td>Ramachandran grid at 10° spacing, RI-MP2/def2-TZVP geometries</td>
          <td>1,296 grid points (approximate)</td>
          <td>Single-point energies at DLPNO-CCSD(T)/aug-cc-pVTZ, ff14SB, ff19SB, MB-nrg</td>
      </tr>
      <tr>
          <td>AceAla$_9$Nme dynamics</td>
          <td>600 ps NVT MD from fully extended start</td>
          <td>Single trajectory per model</td>
          <td>STRIDE for secondary-structure assignment</td>
      </tr>
  </tbody>
</table>
<p>Per the Data Availability statement, &ldquo;any data generated and analyzed in this study are available from the authors upon request.&rdquo; No public release is announced in the text.</p>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Many-body expansion of the energy with 1-, 2-, and 3-body data-driven terms.</li>
<li>Permutationally invariant polynomials in Morse-exponential variables $\xi_{ij} = \exp(-k_{\tau(ij)} R_{ij})$, symmetrized over chemically equivalent atoms.</li>
<li>&ldquo;Ghost&rdquo; H-capping at cleaved covalent bonds, with fixed C-H (1.14 Å) and N-H (1.09 Å) bond lengths and per-$n$-mer optimized-structure referencing.</li>
<li>Non-linear parameters fit by simplex minimization, linear coefficients by ridge regression with $\Gamma = 0.0005$.</li>
<li>Low-energy weighting in the loss through $w_k = (\delta E / (\varepsilon^{n\mathrm{B}}(k) - \varepsilon^{n\mathrm{B}}_{\min} + \delta E))^2$.</li>
<li>Tang-Toennies damped dispersion with XDM-derived $C_6$ and damping parameters, Thole-modified many-body polarization, and LJ repulsion borrowed from Amber ff14SB.</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li>14 PIPs total covering three 1-mer types, five 2-mer types, and six 3-mer types. Polynomial degree is 5 for the -CH- and CH$_3$- 1-mers, and 3 for the -CONH- 1-mer together with all 2-mers and 3-mers. Term counts range from 635 (-CH-, CH$_3$-) to 2871 (-CONH-CH-CONH-).</li>
<li>MB-nrg PEF implemented in the MBX code and exercised through LAMMPS and PLUMED.</li>
<li>Training set sizes per $n$-mer range from roughly 12,000 to 47,000 configurations (the -CONH- 1-mer dataset is the largest at 47,438).</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>MB-nrg</th>
          <th>Amber ff14SB</th>
          <th>Amber ff19SB</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>$n$-mer training RMSD</td>
          <td>$\leq 0.35$ kcal/mol</td>
          <td>n/a</td>
          <td>n/a</td>
      </tr>
      <tr>
          <td>$n$-mer test RMSD</td>
          <td>$\leq 0.47$ kcal/mol</td>
          <td>n/a</td>
          <td>n/a</td>
      </tr>
      <tr>
          <td>Alanine dipeptide 2D PES RMSD (overall)</td>
          <td>1.27 kcal/mol</td>
          <td>6.33 kcal/mol</td>
          <td>5.23 kcal/mol</td>
      </tr>
      <tr>
          <td>Same, $\leq 10$ kcal/mol region</td>
          <td>1.18 kcal/mol</td>
          <td>5.72 kcal/mol</td>
          <td>4.79 kcal/mol</td>
      </tr>
      <tr>
          <td>Same, $&gt; 10$ kcal/mol region</td>
          <td>1.59 kcal/mol</td>
          <td>8.44 kcal/mol</td>
          <td>6.81 kcal/mol</td>
      </tr>
      <tr>
          <td>Alanine dipeptide $m_1$ normal-mode mean deviation vs RI-MP2/def2-TZVP</td>
          <td>17.41 cm$^{-1}$</td>
          <td>n/a</td>
          <td>n/a</td>
      </tr>
      <tr>
          <td>Alanine dipeptide $m_4$ normal-mode mean deviation vs RI-MP2/def2-TZVP</td>
          <td>21.07 cm$^{-1}$</td>
          <td>n/a</td>
          <td>n/a</td>
      </tr>
      <tr>
          <td>AceAla$_9$Nme helix-formation onset (from extended start)</td>
          <td>~100 ps ($\alpha$/$3_{10}$ mix)</td>
          <td>~40 ps ($\alpha$)</td>
          <td>~80 ps ($\alpha$)</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>The work was supported by the Air Force Office of Scientific Research (FA9550-20-1-0351) and NSF award 2311260. Computational resources came from the DoD High Performance Computing Modernization Program, the San Diego Supercomputer Center via ACCESS allocation CHE240114, and NERSC (contract DE-AC02-05CH11231, award BES-ERCAP0030920). Specific wall-clock and node-hour figures are not reported in the main text.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Zhou, R., Bull-Vulpe, E. F., Pan, Y., &amp; Paesani, F. (2025). Data-Driven Many-Body Simulations of Biomolecules with CCSD(T) Accuracy: I. Polyalanine in the Gas Phase. <em>ChemRxiv</em>. <a href="https://doi.org/10.26434/chemrxiv-2025-b05k5">https://doi.org/10.26434/chemrxiv-2025-b05k5</a></p>
<p><strong>Publication</strong>: ChemRxiv preprint, 25 March 2025.</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/paesanilab/MBX">MBX software (Paesani group)</a></li>
<li><a href="https://github.com/paesanilab/MB-Fit">MB-Fit (training pipeline)</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@misc</span>{zhou2025data,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Data-Driven Many-Body Simulations of Biomolecules with CCSD(T) Accuracy: I. Polyalanine in the Gas Phase}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Zhou, Ruihan and Bull-Vulpe, Ethan F. and Pan, Yuanhui and Paesani, Francesco}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2025}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.26434/chemrxiv-2025-b05k5}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">howpublished</span>=<span style="color:#e6db74">{\url{https://doi.org/10.26434/chemrxiv-2025-b05k5}}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>MB-nrg in Solution: Polyalanine in Water with CCSD(T) PEFs</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-simulation/ml-potentials/mb-nrg-polyalanine-water/</link><pubDate>Sun, 12 Apr 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-simulation/ml-potentials/mb-nrg-polyalanine-water/</guid><description>Zhou and Paesani extend MB-nrg to peptide-water interactions, training 1-mer-water 2-body PIPs on DLPNO-CCSD(T) and benchmarking alanine dipeptide solvation.</description><content:encoded><![CDATA[<h2 id="extending-mb-nrg-from-gas-phase-polyalanine-to-aqueous-solution">Extending MB-nrg from Gas-Phase Polyalanine to Aqueous Solution</h2>
<p>This is a <strong>Method</strong> paper, the second installment in Zhou and Paesani&rsquo;s MB-nrg-for-biomolecules series. Paper I (covered in <a href="/notes/chemistry/molecular-simulation/ml-potentials/mb-nrg-polyalanine-ccsdt/">MB-nrg: CCSD(T)-Accurate Potentials for Polyalanine</a>) decomposed gas-phase polyalanine into functional-group $n$-mers and fit permutationally invariant polynomials (PIPs) to DLPNO-CCSD(T)/aug-cc-pVTZ reference data. This sequel adds the missing piece: explicit, machine-learned 2-body interactions between every polyalanine functional-group 1-mer and a water molecule, trained on the same <a href="https://en.wikipedia.org/wiki/Coupled_cluster">coupled-cluster</a> reference. The resulting PEF couples the gas-phase intramolecular MB-nrg term, the MB-pol water model, and a new MB-nrg ala-water cross term within a single modular many-body decomposition.</p>
<h2 id="why-empirical-force-fields-struggle-with-hydrated-peptides">Why Empirical Force Fields Struggle with Hydrated Peptides</h2>
<p>Biomolecular function in water emerges from a coupling of intramolecular flexibility with solvent-mediated interactions, including hydrogen-bond networks, cooperative polarization, dispersion, and short-range exchange-repulsion. Empirical force fields such as AMBER, CHARMM, and OPLS approximate the multidimensional PES with pairwise-additive analytical terms whose parameters are tuned to experimental observables or low-level quantum data. The authors note that this functional form leads to systematic errors in predicted conformational ensembles for short peptides and <a href="https://en.wikipedia.org/wiki/Intrinsically_disordered_proteins">intrinsically disordered proteins (IDPs)</a>, with reported overpopulation of polyproline II (pPII) basins and antiparallel $\beta$ regions for alanine residues, plus underrepresentation of the transitional $\beta$ basin compared to experiment.</p>
<p>Polarizable force fields recover dielectric and hydration trends through induced dipoles, but still lean on empirical functional forms and miss short-range quantum effects (charge transfer, charge penetration, exchange-repulsion) that arise from electron-density overlap. <a href="/notes/chemistry/molecular-simulation/ml-potentials/dark-side-of-forces/">Machine-learned force fields</a> like MACE-OFF, GEMS, and FeNNix-Bio1 have improved bio-organic accuracy, but they still depend critically on the diversity and quality of training data, struggle to decompose energies into physically interpretable components, and most rely on DFT references that inherit delocalization errors and incomplete long-range correlation. Local descriptors common to MLFFs also limit treatment of long-range electrostatics and many-body correlations, both essential for biomolecular solvation.</p>
<p>The MB-nrg formalism, originally developed for water and small molecules and recently extended to alkanes and gas-phase polyalanine, offers an alternative: a rigorous many-body expansion (MBE) of the energy combined with both data-driven $n$-body PIPs and physics-based long-range terms. Paper II asks whether this modular gas-phase scaffold can be cleanly extended to aqueous environments by adding only short-range peptide-water 2-body PIPs.</p>
<h2 id="a-modular-mb-nrg-pef-for-polyalanine-in-water">A Modular MB-nrg PEF for Polyalanine in Water</h2>
<p>The MBE writes the total energy of a system of $N$ 1-mers as</p>
<p>$$
E_N(1, \dots, N) = \sum_{i=1}^{N} \varepsilon^{1\mathrm{B}}(i) + \sum_{i&lt;j}^{N} \varepsilon^{2\mathrm{B}}(i,j) + \sum_{i&lt;j&lt;k}^{N} \varepsilon^{3\mathrm{B}}(i,j,k) + \dots + \varepsilon^{N\mathrm{B}}(1, \dots, N)
$$</p>
<p>with each $n$-body term defined recursively as the $n$-mer energy minus all lower-order contributions. The MBE converges quickly for insulating molecular systems with large electronic band gaps (such as water and peptides), so explicit PIP corrections are typically truncated at $n \leq 4$, with higher-order effects absorbed into many-body polarization.</p>
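<p>Written out for the first three orders, the recursion reads as follows, where $E_n$ denotes the energy of the isolated $n$-mer as in the expansion above:</p>
<p>$$
\begin{gathered}
\varepsilon^{1\mathrm{B}}(i) = E_1(i) \\
\varepsilon^{2\mathrm{B}}(i,j) = E_2(i,j) - \varepsilon^{1\mathrm{B}}(i) - \varepsilon^{1\mathrm{B}}(j) \\
\varepsilon^{3\mathrm{B}}(i,j,k) = E_3(i,j,k) - \varepsilon^{2\mathrm{B}}(i,j) - \varepsilon^{2\mathrm{B}}(i,k) - \varepsilon^{2\mathrm{B}}(j,k) - \varepsilon^{1\mathrm{B}}(i) - \varepsilon^{1\mathrm{B}}(j) - \varepsilon^{1\mathrm{B}}(k)
\end{gathered}
$$</p>
<p>For a system of exactly three 1-mers, summing all 1-, 2-, and 3-body terms recovers $E_3$ identically, so any error in a truncated expansion comes only from the orders left out.</p>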
<p>For polyalanine in water, the total potential is partitioned into three modular blocks:</p>
<p>$$
V_{\mathrm{MB\text{-}nrg}}^{\mathrm{tot}} = V_{\mathrm{MB\text{-}nrg}}^{\mathrm{ala}} + V_{\mathrm{MB\text{-}pol}}^{\mathrm{wat}} + V_{\mathrm{MB\text{-}nrg}}^{\mathrm{ala\text{-}wat}}
$$</p>
<p>where $V_{\mathrm{MB\text{-}nrg}}^{\mathrm{ala}}$ is the gas-phase intramolecular polyalanine PEF from Paper I, $V_{\mathrm{MB\text{-}pol}}^{\mathrm{wat}}$ is the MB-pol water model, and $V_{\mathrm{MB\text{-}nrg}}^{\mathrm{ala\text{-}wat}}$ is the new peptide-water cross term. The cross term itself follows the MB-nrg recipe of splitting machine-learned and physics-based contributions:</p>
<p>$$
V_{\mathrm{MB\text{-}nrg}}^{\mathrm{ala\text{-}wat}} = V_{\mathrm{ML}} + V_{\mathrm{phys}}
$$</p>
<p>with $V_{\mathrm{ML}} = V_{\mathrm{ML}}^{2\mathrm{B}}$ (only 2-body PIPs in this implementation) and $V_{\mathrm{phys}} = V_{\mathrm{elec}} + V_{\mathrm{disp}}$. The 2-body machine-learned term sums switched PIPs over every (1-mer, water) dimer:</p>
<p>$$
V_{\mathrm{ML}}^{2\mathrm{B}} = \sum_{i=1}^{N} s^{2\mathrm{B}}(\mathrm{M}_i, \mathrm{WAT}) \, V_{\mathrm{PIP}}^{2\mathrm{B}}(\mathrm{M}_i, \mathrm{WAT})
$$</p>
<p>where $\mathrm{M}_i$ is the $i$-th polyalanine functional-group 1-mer (-CH-, CH$_3$-, or -CONH-), WAT is a water molecule, and $s^{2\mathrm{B}}$ is a cosine switching function</p>
<p>$$
s^{2\mathrm{B}}(x) = \begin{cases} 1 &amp; x &lt; 0 \\ (1 + \cos(\pi x))/2 &amp; 0 \leq x &lt; 1 \\ 0 &amp; 1 \leq x \end{cases}, \quad x = \frac{R - R_{\mathrm{in}}}{R_{\mathrm{out}} - R_{\mathrm{in}}}
$$</p>
<p>that smoothly attenuates the short-range PIP beyond a defined distance to preserve energy conservation in MD. The physics-based block uses a Thole-modified self-consistent polarization model (inherited from MB-pol) for $V_{\mathrm{elec}}$ and a Tang-Toennies damped dispersion sum</p>
<p>$$
V_{\mathrm{disp}} = -\sum_{\substack{\alpha \in 1\text{-mers} \\ \beta \in \mathrm{water}}} f(\mathrm{b}_{\alpha\beta} R_{\alpha\beta}) \, \frac{C_{6, \alpha\beta}}{R_{\alpha\beta}^{6}}
$$</p>
<p>with $C_{6, \alpha\beta}$ coefficients and atomic polarizabilities derived from the exchange-hole dipole moment (XDM) method, and atomic charges fit to reproduce the permanent multipole moments of each $n$-mer&rsquo;s optimized structure.</p>
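<p>Two ingredients of this block can be checked or written out explicitly. For the switch to be continuous at both ends of the window, the cosine argument has to sweep from $0$ to $\pi$ as $x$ goes from 0 to 1, so the middle branch takes the form $(1 + \cos(\pi x))/2$. With that form,</p>
<p>$$
s^{2\mathrm{B}}(0) = 1, \qquad s^{2\mathrm{B}}(1/2) = \tfrac{1}{2}, \qquad \lim_{x \to 1^{-}} s^{2\mathrm{B}}(x) = 0
$$</p>
<p>and the first derivative, $-\tfrac{\pi}{2}\sin(\pi x)$, vanishes at both ends, so neither the energy nor the forces jump as a water molecule crosses $R_{\mathrm{in}}$ or $R_{\mathrm{out}}$. For the damping function, the sketch below assumes the standard sixth-order Tang-Toennies form, which the summary above does not spell out:</p>
<p>$$
f(x) = 1 - e^{-x} \sum_{k=0}^{6} \frac{x^{k}}{k!}, \qquad x = \mathrm{b}_{\alpha\beta} R_{\alpha\beta}
$$</p>
<p>This function goes to 0 as $R_{\alpha\beta} \to 0$ fast enough to cancel the $R_{\alpha\beta}^{-6}$ divergence, and to 1 at large separation, where the undamped $-C_{6, \alpha\beta}/R_{\alpha\beta}^{6}$ tail is recovered.</p>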
<p>The authors stress that explicit 3-body and higher peptide-water PIPs are deliberately omitted in this first implementation; their effects are absorbed into the classical polarization term. They flag that strongly hydrogen-bonded or cooperative configurations may benefit from adding higher-body corrections in future work, following the precedent of MB-pol(2023) for water.</p>
<h2 id="training-set-generation-and-dlpno-ccsdt-reference-data">Training Set Generation and DLPNO-CCSD(T) Reference Data</h2>
<p>Training pools for the three 1-mer-water dimers (CH$_3$-H$_2$O, -CH&ndash;H$_2$O, -CONH&ndash;H$_2$O) extend the <a href="https://en.wikipedia.org/wiki/Metadynamics">parallel-bias metadynamics with partitioned families (PBMetaD+PFs)</a> protocol from Paper I. Covalent boundaries are capped with &ldquo;ghost&rdquo; hydrogens at fixed C-H (1.14 Å) and N-H (1.09 Å) distances to preserve closed-shell character; each 2-body energy is referenced to the corresponding optimized capped 1-mer-water geometry to remove constant offsets.</p>
<p>PBMetaD simulations are run in LAMMPS interfaced with PLUMED, using Amber ff14SB for the alanine 1-mers and TIP4P/2005f for water. Collective variables span all heavy-atom bonds, angles, and dihedrals in each dimer. To target distinct interaction regimes, three separate biased runs apply upper and lower walls on the 1-mer/water center-of-mass distance: 0-4 Å (short-range repulsion), 4-7 Å (mid-range attraction), and 7-10 Å (long-range orientation-dependent interactions). Each dimer yields about 600,000 configurations, reduced to roughly 40,000 training and 2,000 test configurations per type by K-means clustering.</p>
<p>Reference 2-body energies are computed at the DLPNO-CCSD(T)/aug-cc-pVTZ level in ORCA, using the aug-cc-pVTZ/C auxiliary basis, the RIJCOSX approximation, TightSCF, TightPNO, and the PModel pair-selection option. The counterpoise method corrects every 2-body energy for <a href="https://en.wikipedia.org/wiki/Basis_set_superposition_error">basis set superposition error</a>.</p>
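<p>In the standard Boys-Bernardi counterpoise scheme, assumed here to be the variant applied, every fragment energy is evaluated in the full dimer basis. In notation introduced for this note, the subscript names the system and the superscript names the basis, for a 1-mer M and a water W at their geometries in the dimer:</p>
<p>$$
\varepsilon^{2\mathrm{B}}_{\mathrm{CP}}(\mathrm{M}, \mathrm{W}) = E_{\mathrm{MW}}^{\mathrm{MW}} - E_{\mathrm{M}}^{\mathrm{MW}} - E_{\mathrm{W}}^{\mathrm{MW}}
$$</p>
<p>Because each monomer calculation carries ghost basis functions on its partner, the artificial stabilization from borrowed basis functions cancels out of the 2-body energy.</p>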
<p>Each PIP minimizes a weighted, ridge-regularized least-squares objective:</p>
<p>$$
\chi^2 = \sum_{k \in \mathcal{S}} w_k \left[ V^{2\mathrm{B}}(k) - \varepsilon^{2\mathrm{B}}(k) \right]^2 + \Gamma^2 \sum_l c_l^2
$$</p>
<p>with $\Gamma = 0.0005$ throughout. Training weights bias the fit toward low-energy configurations,</p>
<p>$$
w_k = \left( \frac{\delta E}{\varepsilon^{2\mathrm{B}}(k) - \varepsilon_{\mathrm{min}}^{2\mathrm{B}} + \delta E} \right)^2
$$</p>
<p>with $\delta E = 40$ kcal/mol for all 1-mer-water pairs. MB-Fit handles the optimization, combining simplex minimization for non-linear parameters (Morse decay constants) with ridge regression for the linear coefficients.</p>
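<p>The weighting and the regularized objective above can be sketched in a few lines of plain Python. This is a minimal illustration, not MB-Fit code: the function names are hypothetical, and $\varepsilon_{\mathrm{min}}^{2\mathrm{B}}$ is assumed to be the lowest reference energy in the set being scored.</p>

```python
def training_weight(energy, e_min, delta_e=40.0):
    """w_k = (dE / (eps_k - eps_min + dE))**2, energies in kcal/mol."""
    return (delta_e / (energy - e_min + delta_e)) ** 2


def ridge_objective(predicted, reference, coeffs, gamma=0.0005, delta_e=40.0):
    """Weighted squared error plus the Gamma**2 * sum(c_l**2) penalty."""
    # Assumption: the minimum is taken over the reference energies passed in.
    e_min = min(reference)
    fit_term = sum(
        training_weight(ref, e_min, delta_e) * (pred - ref) ** 2
        for pred, ref in zip(predicted, reference)
    )
    penalty = gamma ** 2 * sum(c * c for c in coeffs)
    return fit_term + penalty
```

<p>With $\delta E = 40$ kcal/mol, a configuration 40 kcal/mol above the minimum gets a quarter of the weight of the minimum itself, so high-energy configurations still contribute but do not dominate the fit.</p>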
<p>Table 1 reports the PIP specifications. All three PIPs use polynomial degree 3 with a complete, unscreened basis. The -CH- and CH$_3$- dimers each require 710 symmetrized terms; the chemically richer -CONH- dimer requires 1,267 terms to capture its dipolar character and directional hydrogen bonding. Training-set sizes range from 41,781 to 43,174 configurations.</p>
<table>
  <thead>
      <tr>
          <th>1-mer type</th>
          <th>PIP degree</th>
          <th>PIP terms</th>
          <th>Training configs</th>
          <th>Train RMSD (kcal/mol)</th>
          <th>Test RMSD (kcal/mol)</th>
          <th>Train MAE</th>
          <th>Test MAE</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>-CH-</td>
          <td>3</td>
          <td>710</td>
          <td>43,174</td>
          <td>0.07</td>
          <td>0.08</td>
          <td>0.06</td>
          <td>0.06</td>
      </tr>
      <tr>
          <td>CH$_3$-</td>
          <td>3</td>
          <td>710</td>
          <td>43,172</td>
          <td>0.08</td>
          <td>0.08</td>
          <td>0.05</td>
          <td>0.05</td>
      </tr>
      <tr>
          <td>-CONH-</td>
          <td>3</td>
          <td>1,267</td>
          <td>41,781</td>
          <td>0.18</td>
          <td>0.20</td>
          <td>0.13</td>
          <td>0.16</td>
      </tr>
  </tbody>
</table>
<p>All RMSDs are at or below 0.20 kcal/mol on both train and test splits, well within the conventional 1 kcal/mol threshold for chemical accuracy.</p>
<h2 id="validation-dimer-scans-free-energy-surfaces-and-hydration">Validation: Dimer Scans, Free-Energy Surfaces, and Hydration</h2>
<p>The authors stage three validation studies of increasing complexity, each touching a distinct facet of the new PEF.</p>
<p><strong>Alanine dipeptide-water dimer scans.</strong> One-dimensional scans probe the interaction energy along four hydrogen-bonding coordinates of an alanine dipeptide-water dimer: O$_1$-H$_w$, H$_1$-O$_w$, O$_2$-H$_w$, and H$_2$-O$_w$, where subscripts 1 and 2 mark the acetyl and N-methyl termini. The dipeptide is constrained to four representative <a href="https://en.wikipedia.org/wiki/Ramachandran_plot">Ramachandran</a> conformations: C5 ($\varphi = -150°$, $\psi = 150°$), pPII ($\varphi = -80°$, $\psi = 150°$), C7$_{\mathrm{eq}}$ ($\varphi = -80°$, $\psi = 70°$), and right-handed $\alpha$-helix $\alpha_R$ ($\varphi = -80°$, $\psi = -30°$). MB-nrg closely tracks the DLPNO-CCSD(T)/aug-cc-pVTZ reference curves across all 16 (4 conformation $\times$ 4 site) scans, despite never seeing the full dipeptide-water surface during training. Amber ff14SB/TIP3P and ff19SB/OPC underestimate hydrogen-bond depths and miss curvature near equilibrium, with the ff14SB/TIP3P combination yielding slightly better overall agreement than ff19SB/OPC even though TIP3P is the less accurate water model.</p>
<p>Two specific failure modes of the empirical force fields stand out. In the pPII conformation, both ff14SB and ff19SB predict significantly deeper interaction wells than the reference, overstabilizing several hydrogen bonds. In the H$_2$-O$_w$ scan of the $\alpha_R$ conformation, both empirical FFs exhibit a spurious energy barrier between 2.5 and 4.0 Å, which the authors trace to the simple Lennard-Jones repulsion between the acetyl carbonyl oxygen and water; MB-nrg and DLPNO-CCSD(T) instead show a smoothly decaying profile. The one MB-nrg deviation noted occurs in the C5 H$_1$-O$_w$ scan in the 1.5-2.5 Å range, where MB-nrg predicts a slightly more attractive interaction than the reference. Here the H$_1$-O$_2$ distance is 2.3 Å and water acts simultaneously as acceptor at H$_1$ and donor to O$_2$, a cooperative pattern the authors expect would require explicit 2-mer-water or 3-mer-water terms to fully reproduce.</p>
<p><strong>Free-energy surface in explicit MB-pol water.</strong> Four-walker well-tempered metadynamics (WT-MetaD) simulations explore the conformational landscape of alanine dipeptide as a function of $(\varphi, \psi)$, biasing the central alanine residue&rsquo;s backbone dihedrals every 500 steps with 1.0 kJ/mol Gaussians of 11.46° width. The free-energy section reports 2.5 ns per replica across four parallel walkers (10 ns aggregate, matching the Figure 6 caption); the methods section states 8 ns total, an internal inconsistency in the paper. The MB-nrg FES recovers all major low-energy conformers identified by NMR and prior MP2/DFT studies: a global minimum at $\alpha_R$, additional local minima in C5, $\beta_2$, and $\alpha_L$, and a metastable pPII basin. The C7$_{\mathrm{eq}}$ minimum that dominates the gas-phase Ramachandran surface in Paper I is significantly destabilized in solution, consistent with experiment.</p>
<p>Quantitatively, MB-nrg predicts $\alpha_R$ and $\beta_2$ as isoenergetic global minima, with C5 about 3 kcal/mol higher in free energy. Prior DFT-with-implicit-solvation studies (Mironov et al., Yang and Honig) report C5, $\alpha_R$, and $\beta_2$ as nearly isoenergetic, and the authors note that the discrepancy may reflect the explicit MB-pol water treatment, residual DFT errors in the reference, or both. They plan to benchmark MB-nrg PEFs for diverse polypeptides systematically against both DFT and DLPNO-CCSD(T) data in future work. The Amber FESs over-stabilize pPII relative to C5/$\alpha_R$, contradicting experimental and DFT benchmarks; ff19SB/OPC also exhibits a spurious C7$_{\mathrm{eq}}$ minimum that is absent from MB-nrg.</p>
<p><strong>Hydration radial distribution functions.</strong> Site-site RDFs at 300 K for the same hydrogen-bond contacts (O$_1$-H$_w$, O$_2$-H$_w$, H$_1$-O$_w$, H$_2$-O$_w$) are computed from NVT MD trajectories. All three models reproduce well-defined first-shell peaks near 2.0 Å. For the O-H$_w$ pairs, MB-nrg shows a broader, slightly right-shifted second-shell peak, indicating less rigid water structure beyond the first shell. The amide-hydrogen RDFs are nearly identical between ff14SB/TIP3P and ff19SB/OPC, while MB-nrg reveals subtle first-shell shifts (shorter H$_1$-O$_w$, longer H$_2$-O$_w$) and weaker, less-defined second-shell features near 3.7-3.8 Å that are absent from the empirical force fields and consistent with prior ab initio MD on alanine dipeptide.</p>
<h2 id="a-modular-path-to-chemically-accurate-biomolecular-simulations">A Modular Path to Chemically Accurate Biomolecular Simulations</h2>
<p>Across these benchmarks, the same picture emerges: a modular, bottom-up MB-nrg PEF built from functional-group $n$-mers and trained only on isolated 1-mer-water dimers can reach DLPNO-CCSD(T) accuracy for both energetic and structural observables of alanine dipeptide in explicit water. The decomposition into a gas-phase intramolecular term, an MB-pol water model, and an MB-nrg cross term keeps each piece interpretable and individually replaceable; the gas-phase polyalanine PEF from Paper I drops in unchanged, and the new ala-water PIPs were fit without ever seeing the full alanine dipeptide-water PES.</p>
<p>The authors are explicit about limitations:</p>
<ul>
<li>The cross term currently includes only 2-body PIPs (one 1-mer with one water). Higher-body peptide-water terms ($n &gt; 2$) are folded into the classical polarization, which the authors expect will be inadequate for strongly cooperative configurations such as the C5 H$_1$-O$_w$ scan where one water bridges H$_1$ and O$_2$.</li>
<li>Quantitative differences between the MB-nrg FES and prior implicit-solvation DFT studies (relative depths of $\alpha_R$, $\beta_2$, and C5) remain to be reconciled through systematic benchmarking against higher-level reference data.</li>
<li>Only polyalanine is considered. The framework is designed to generalize to other amino acids and side-chain-water interactions, but sequence- and side-chain-specific PIPs are still to be fit.</li>
<li>No public release of the parameterized PEF or training data is announced; the data availability statement says &ldquo;available from the authors upon request.&rdquo;</li>
</ul>
<p>The paper positions MB-nrg as a transferable, interpretable strategy for chemically accurate biomolecular simulations in solution, with future work aimed at heteropolypeptides and explicit higher-order many-body cross terms.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training pools</td>
          <td>PBMetaD+PFs in LAMMPS/PLUMED</td>
          <td>~600,000 configs per dimer, reduced to ~40,000</td>
          <td>ff14SB for alanine 1-mers, TIP4P/2005f for water; 300 K, 500 K, 700 K</td>
      </tr>
      <tr>
          <td>Distance regimes</td>
          <td>Walls on 1-mer/water COM distance</td>
          <td>0-4, 4-7, and 7-10 Å</td>
          <td>Short-range repulsion, mid-range attraction, long-range orientation</td>
      </tr>
      <tr>
          <td>Training labels</td>
          <td>DLPNO-CCSD(T)/aug-cc-pVTZ in ORCA</td>
          <td>3 unique 1-mer-water dimer types</td>
          <td>RIJCOSX, TightSCF, TightPNO, PModel; counterpoise BSSE correction</td>
      </tr>
      <tr>
          <td>Test sets</td>
          <td>Held-out clustered configs</td>
          <td>~2,000 per dimer</td>
          <td>Same K-means clustering protocol</td>
      </tr>
      <tr>
          <td>Alanine dipeptide-water scans</td>
          <td>1D scans along 4 H-bond coordinates in 4 conformations</td>
          <td>16 scans total</td>
          <td>C5, pPII, C7$_{\mathrm{eq}}$, and $\alpha_R$ conformations</td>
      </tr>
      <tr>
          <td>Alanine dipeptide FES</td>
          <td>WT-MetaD on $\varphi$, $\psi$ in MB-pol water</td>
          <td>4 walkers, 2.5 ns each (10 ns total per the results section and Figure 6 caption; methods section states 8 ns)</td>
          <td>1.0 kJ/mol height, 11.46° width, deposition every 500 steps</td>
      </tr>
      <tr>
          <td>Hydration RDFs</td>
          <td>NVT MD at 300 K</td>
          <td>Single trajectory per model</td>
          <td>Same H-bond sites as the dimer scans</td>
      </tr>
  </tbody>
</table>
<p>Per the data availability statement, &ldquo;any data generated and analyzed in this study, including the MB-nrg PEF, are available from the authors upon request.&rdquo; The MBX engine is publicly available on <a href="https://github.com/paesanilab/MBX">GitHub</a> under a UC Regents custom license that grants free use for educational, research, and non-profit purposes but restricts commercial use. No public release of the new ala-water PIPs is announced in the text.</p>
<h4 id="artifacts-table">Artifacts table</h4>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/paesanilab/MBX">MBX</a></td>
          <td>Code</td>
          <td>UC Regents custom (academic/non-profit only; no SPDX-recognized OSS license)</td>
          <td>C++ many-body potential engine; runs the MB-nrg PEF via LAMMPS and PLUMED</td>
      </tr>
      <tr>
          <td><a href="https://github.com/paesanilab/MB-Fit">MB-Fit</a></td>
          <td>Code</td>
          <td>Check repo</td>
          <td>Training pipeline for PIP fitting; used to fit the new 1-mer-water PIPs</td>
      </tr>
      <tr>
          <td>MB-nrg ala-water PIPs (this paper)</td>
          <td>Model</td>
          <td>Not released</td>
          <td>&ldquo;Available from the authors upon request&rdquo; per the data availability statement</td>
      </tr>
      <tr>
          <td>DLPNO-CCSD(T) training/test sets</td>
          <td>Dataset</td>
          <td>Not released</td>
          <td>Same statement; ~600,000 raw configs per dimer reduced to ~40,000 train + ~2,000 test</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Many-body expansion of the energy partitioned into three modular blocks: $V_{\mathrm{MB\text{-}nrg}}^{\mathrm{ala}} + V_{\mathrm{MB\text{-}pol}}^{\mathrm{wat}} + V_{\mathrm{MB\text{-}nrg}}^{\mathrm{ala\text{-}wat}}$.</li>
<li>Cross term split into $V_{\mathrm{ML}}^{2\mathrm{B}}$ (PIPs over every 1-mer-water dimer) and $V_{\mathrm{phys}} = V_{\mathrm{elec}} + V_{\mathrm{disp}}$.</li>
<li>Permutationally invariant polynomials in Morse-exponential variables $\xi_{ij} = \exp(-k_{\tau(ij)} R_{ij})$, symmetrized over chemically equivalent atoms; same construction as the NMA-water PIPs.</li>
<li>Cosine switching function $s^{2\mathrm{B}}$ smoothly attenuates short-range PIPs between user-defined inner and outer cutoffs.</li>
<li>Dispersion: Tang-Toennies damped $C_6/R^6$ with XDM-derived coefficients and damping parameters.</li>
<li>Electrostatics: modified Thole model with self-consistent induced dipoles for many-body polarization; per-atom charges fit to reproduce permanent multipole moments of each $n$-mer&rsquo;s optimized structure.</li>
<li>Ghost-H capping at cleaved covalent boundaries with fixed C-H (1.14 Å) and N-H (1.09 Å) distances; per-dimer optimized-structure referencing.</li>
<li>Training with simplex minimization for non-linear parameters and ridge regression for linear coefficients via MB-Fit, with low-energy weighting and $\Gamma = 0.0005$, $\delta E = 40$ kcal/mol.</li>
<li>WT-MetaD with four parallel walkers for the alanine dipeptide FES.</li>
</ul>
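<p>Three of the scalar building blocks listed above are compact enough to sketch directly. The code below is illustrative only and is not taken from MBX: the function names are hypothetical, the half-cosine form of the switch is an assumed shape (the list above says only that it is a cosine switch between inner and outer cutoffs), and the damping function is the standard sixth-order Tang-Toennies expression.</p>

```python
import math


def morse_variable(r, k):
    """Morse-exponential PIP variable xi = exp(-k * r) for one atom pair."""
    return math.exp(-k * r)


def cosine_switch(r, r_in, r_out):
    """1 up to r_in, 0 from r_out on, half-cosine in between (assumed form)."""
    if r_in >= r_out:
        raise ValueError("r_in must be smaller than r_out")
    if r_in >= r:
        return 1.0
    if r >= r_out:
        return 0.0
    x = (r - r_in) / (r_out - r_in)
    return 0.5 * (1.0 + math.cos(math.pi * x))


def tang_toennies_f6(r, b):
    """f6(r) = 1 - exp(-b r) * sum_{n=0}^{6} (b r)**n / n!"""
    x = b * r
    series = sum(x ** n / math.factorial(n) for n in range(7))
    return 1.0 - math.exp(-x) * series


def damped_dispersion(r, c6, b):
    """Damped pair dispersion energy: -f6(r) * C6 / r**6."""
    return -tang_toennies_f6(r, b) * c6 / r ** 6
```

<p>The damping factor goes to zero at short range and to one at long range, which removes the $R^{-6}$ divergence while leaving the asymptotic dispersion untouched.</p>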
<h3 id="models">Models</h3>
<ul>
<li>Three new 1-mer-water 2-body PIPs covering -CH-/H$_2$O, CH$_3$-/H$_2$O, and -CONH-/H$_2$O dimers.</li>
<li>All three PIPs use polynomial degree 3 with a complete, unscreened basis (no term screening).</li>
<li>Term counts: 710 for -CH-/H$_2$O and CH$_3$-/H$_2$O, 1,267 for -CONH-/H$_2$O.</li>
<li>Combined with the gas-phase polyalanine MB-nrg PEF from Paper I and the MB-pol water model, exercised through MBX, LAMMPS, and PLUMED.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>MB-nrg</th>
          <th>Amber ff14SB/TIP3P</th>
          <th>Amber ff19SB/OPC</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>-CH-/H$_2$O 2-body train/test RMSD</td>
          <td>0.07 / 0.08 kcal/mol</td>
          <td>n/a</td>
          <td>n/a</td>
      </tr>
      <tr>
          <td>CH$_3$-/H$_2$O 2-body train/test RMSD</td>
          <td>0.08 / 0.08 kcal/mol</td>
          <td>n/a</td>
          <td>n/a</td>
      </tr>
      <tr>
          <td>-CONH-/H$_2$O 2-body train/test RMSD</td>
          <td>0.18 / 0.20 kcal/mol</td>
          <td>n/a</td>
          <td>n/a</td>
      </tr>
      <tr>
          <td>Alanine dipeptide-water 1D scans (qualitative)</td>
          <td>Tracks DLPNO-CCSD(T) curves across 16 scans</td>
          <td>Underestimates H-bond depths; spurious $\alpha_R$ H$_2$-O$_w$ barrier</td>
          <td>Same shape as ff14SB/TIP3P</td>
      </tr>
      <tr>
          <td>Alanine dipeptide FES global minima</td>
          <td>Isoenergetic $\alpha_R$ and $\beta_2$; C5 ~3 kcal/mol higher</td>
          <td>Over-stabilizes pPII</td>
          <td>Over-stabilizes pPII; spurious C7$_{\mathrm{eq}}$ minimum</td>
      </tr>
      <tr>
          <td>O-H$_w$ second shell</td>
          <td>Broader, right-shifted; finer detail consistent with prior AIMD</td>
          <td>Sharper, less detail</td>
          <td>Sharper, less detail</td>
      </tr>
      <tr>
          <td>H-O$_w$ second shell</td>
          <td>Weak features near 3.7-3.8 Å</td>
          <td>Absent</td>
          <td>Absent</td>
      </tr>
  </tbody>
</table>
<p>Quantitative RMSD or KL-divergence values for the FES and RDF benchmarks are not reported in the main text.</p>
<h3 id="hardware">Hardware</h3>
<p>The authors acknowledge support from the Air Force Office of Scientific Research (FA9550-20-1-0351, theoretical development) and NSF (award 2311260, MBX implementation). Computational resources came from the DoD High Performance Computing Modernization Program, the San Diego Supercomputer Center via ACCESS allocation CHE240114, and NERSC (contract DE-AC02-05CH11231, award BES-ERCAP0030920). Specific wall-clock and node-hour figures are not reported in the main text.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Zhou, R., &amp; Paesani, F. (2025). Toward Chemical Accuracy in Biomolecular Simulations through Data-Driven Many-Body Potentials: II. Polyalanine in Water. <em>ChemRxiv</em>. <a href="https://doi.org/10.26434/chemrxiv-2025-j6cwv-v2">https://doi.org/10.26434/chemrxiv-2025-j6cwv-v2</a></p>
<p><strong>Publication</strong>: ChemRxiv preprint (version 2), 10 October 2025.</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/paesanilab/MBX">MBX software (Paesani group)</a></li>
<li><a href="https://github.com/paesanilab/MB-Fit">MB-Fit (training pipeline)</a></li>
<li>Companion paper: <a href="/notes/chemistry/molecular-simulation/ml-potentials/mb-nrg-polyalanine-ccsdt/">MB-nrg: CCSD(T)-Accurate Potentials for Polyalanine</a> (Paper I)</li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{zhou2025toward,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Toward Chemical Accuracy in Biomolecular Simulations through Data-Driven Many-Body Potentials: II. Polyalanine in Water}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Zhou, Ruihan and Paesani, Francesco}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{ChemRxiv}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2025}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.26434/chemrxiv-2025-j6cwv-v2}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>LSTNet: Long- and Short-Term Time Series Network</title><link>https://hunterheidenreich.com/notes/time-series/lstnet-multivariate-time-series/</link><pubDate>Sat, 11 Apr 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/time-series/lstnet-multivariate-time-series/</guid><description>LSTNet combines CNNs, recurrent-skip connections, and autoregressive models to capture both short-term and long-term patterns in multivariate time series.</description><content:encoded><![CDATA[<h2 id="a-deep-learning-framework-for-multivariate-forecasting">A Deep Learning Framework for Multivariate Forecasting</h2>
<p>This is a <strong>Method</strong> paper that introduces the Long- and Short-term Time-series Network (LSTNet), a deep learning architecture specifically designed for multivariate time series forecasting. LSTNet combines convolutional neural networks (CNNs), recurrent neural networks (RNNs) with a novel skip-connection structure, and a traditional autoregressive (AR) component into a unified framework. The architecture targets the challenge of simultaneously capturing both short-term local dependencies and long-term periodic patterns in temporal data.</p>
<h2 id="why-short-term-and-long-term-patterns-need-separate-treatment">Why Short-Term and Long-Term Patterns Need Separate Treatment</h2>
<p>Real-world multivariate time series often exhibit a mixture of repeating patterns at different time scales. Highway traffic, for example, shows daily peaks (morning vs. evening commutes) alongside weekly patterns (weekday vs. weekend behavior). Solar energy output varies with cloud movements on short time scales and with seasonal daylight changes on longer ones. Electricity consumption follows similar daily and weekly cycles.</p>
<p>Traditional autoregressive methods (<a href="https://en.wikipedia.org/wiki/Vector_autoregression">VAR</a>, <a href="https://en.wikipedia.org/wiki/Autoregressive_integrated_moving_average">ARIMA</a>) and <a href="https://en.wikipedia.org/wiki/Gaussian_process">Gaussian Process</a> models struggle to distinguish and jointly model these two kinds of recurring patterns. Standard RNNs, including LSTM and <a href="https://en.wikipedia.org/wiki/Gated_recurrent_unit">GRU</a> variants, theoretically handle long-range dependencies but in practice suffer from <a href="https://en.wikipedia.org/wiki/Vanishing_gradient_problem">gradient vanishing</a> when the period length is large (e.g., 24 time steps for daily patterns at hourly resolution, or 168 time steps for weekly patterns). The authors also identify a scale insensitivity problem: because the output scale of a neural network does not track the scale of its input, such models can fail when the magnitude of the input signal changes in non-periodic ways, such as sudden shifts in electricity consumption due to holidays or weather events.</p>
<h2 id="combining-cnns-recurrent-skip-connections-and-autoregression">Combining CNNs, Recurrent-Skip Connections, and Autoregression</h2>
<p>The LSTNet architecture consists of four main components that work together.</p>
<h3 id="convolutional-component">Convolutional Component</h3>
<p>The first layer applies 1D convolution without pooling across the multivariate input. Each filter has width $\omega$ (in the time dimension) and height $n$ (spanning all variables), producing feature maps that capture short-term local dependency patterns among variables:</p>
<p>$$h_k = \text{RELU}(W_k * X + b_k)$$</p>
<p>where $*$ denotes convolution and the input is zero-padded so each output vector has length $T$. The output is a $d_c \times T$ matrix where $d_c$ is the number of filters.</p>
<h3 id="recurrent-component">Recurrent Component</h3>
<p>The CNN output feeds into a GRU-based recurrent layer that uses RELU (rather than the standard tanh) as the hidden update activation:</p>
<p>$$\begin{aligned}
r_t &amp;= \sigma(x_t W_{xr} + h_{t-1} W_{hr} + b_r) \\
u_t &amp;= \sigma(x_t W_{xu} + h_{t-1} W_{hu} + b_u) \\
c_t &amp;= \text{RELU}(x_t W_{xc} + r_t \odot (h_{t-1} W_{hc}) + b_c) \\
h_t &amp;= (1 - u_t) \odot h_{t-1} + u_t \odot c_t
\end{aligned}$$</p>
<h3 id="recurrent-skip-component">Recurrent-Skip Component</h3>
<p>The key architectural innovation is a recurrent structure with temporal skip connections. Instead of connecting to the immediately preceding hidden state $h_{t-1}$, skip links connect to the hidden state from $p$ steps ago ($h_{t-p}$), where $p$ corresponds to the period length of the data (e.g., $p = 24$ for hourly data with daily periodicity):</p>
<p>$$\begin{aligned}
r_t &amp;= \sigma(x_t W_{xr} + h_{t-p} W_{hr} + b_r) \\
u_t &amp;= \sigma(x_t W_{xu} + h_{t-p} W_{hu} + b_u) \\
c_t &amp;= \text{RELU}(x_t W_{xc} + r_t \odot (h_{t-p} W_{hc}) + b_c) \\
h_t &amp;= (1 - u_t) \odot h_{t-p} + u_t \odot c_t
\end{aligned}$$</p>
<p>This design shortens the effective path length for learning periodic dependencies, making optimization easier. A dense layer combines outputs from both recurrent components:</p>
<p>$$h_t^D = W^R h_t^R + \sum_{i=0}^{p-1} W_i^S h_{t-i}^S + b$$</p>
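<p>The skip recursion is easiest to see with a scalar hidden state. The sketch below is a toy illustration rather than the authors' implementation: the function name and the weight dictionary are hypothetical, and the hidden state is assumed to be zero before the first full period.</p>

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def relu(z):
    return max(0.0, z)


def skip_gru(xs, p, w):
    """Scalar recurrent-skip GRU; returns the hidden state at every step.

    w maps the keys xr, hr, br, xu, hu, bu, xc, hc, bc to scalar weights.
    """
    hs = []
    for t, x in enumerate(xs):
        # h_{t-p}; assumed zero until p steps of history exist.
        h_prev = hs[t - p] if t >= p else 0.0
        r = sigmoid(x * w["xr"] + h_prev * w["hr"] + w["br"])
        u = sigmoid(x * w["xu"] + h_prev * w["hu"] + w["bu"])
        c = relu(x * w["xc"] + r * (h_prev * w["hc"]) + w["bc"])
        hs.append((1.0 - u) * h_prev + u * c)
    return hs
```

<p>Setting <code>p=1</code> recovers the ordinary GRU recursion from the previous section, which makes the only difference explicit: which past hidden state each step reads.</p>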
<h3 id="temporal-attention-alternative">Temporal Attention Alternative</h3>
<p>For datasets without clear periodicity, LSTNet offers an attention-based variant (LSTNet-Attn) as an alternative to the recurrent-skip component. The attention mechanism learns to weight hidden representations across the input window adaptively. The attention weights $\alpha_t \in \mathbb{R}^q$ at time $t$ are computed as:</p>
<p>$$\alpha_t = \text{AttnScore}(H_t^R, h_{t-1}^R)$$</p>
<p>where $H_t^R = [h_{t-q}^R, \dots, h_{t-1}^R]$ stacks the RNN hidden representations column-wise and AttnScore is a similarity function (dot product, cosine, or a parameterized MLP). The weighted context vector and final output are:</p>
<p>$$\begin{aligned}
c_t &amp;= H_t^R \alpha_t \\
h_t^D &amp;= W[c_t; h_{t-1}^R] + b
\end{aligned}$$</p>
<h3 id="autoregressive-component">Autoregressive Component</h3>
<p>To address the scale insensitivity of neural networks, LSTNet adds a classical autoregressive model in parallel:</p>
<p>$$h_{t,i}^L = \sum_{k=0}^{q^{ar}-1} W_k^{ar} y_{t-k,i} + b^{ar}$$</p>
<p>The final prediction integrates both the neural network and AR outputs:</p>
<p>$$\hat{Y}_t = h_t^D + h_t^L$$</p>
<p>This decomposition separates the prediction into a linear part (handling local scale changes) and a non-linear part (capturing recurring patterns).</p>
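<p>A minimal sketch of the linear path, assuming (as the formula's indexing suggests) that the same $q^{ar}$ weights and a single bias are shared across all variables. The function names and the list-of-rows data layout are choices made for this example.</p>

```python
def ar_component(history, weights, bias):
    """Per-variable AR forecast: sum_k weights[k] * y_{t-k,i} + bias.

    history is a list of rows ordered oldest to newest, one value per variable.
    weights[0] multiplies the newest row, weights[1] the one before, and so on.
    """
    q_ar = len(weights)
    if q_ar > len(history):
        raise ValueError("need at least len(weights) rows of history")
    n_vars = len(history[-1])
    return [
        sum(weights[k] * history[-1 - k][i] for k in range(q_ar)) + bias
        for i in range(n_vars)
    ]


def combine(h_dense, h_linear):
    """Final forecast: element-wise sum of the neural and AR outputs."""
    return [d + l for d, l in zip(h_dense, h_linear)]
```

<p>Because the AR path is linear in the recent inputs, doubling the input doubles the weighted-sum part of its output, which is the scale-tracking behavior the nonlinear path lacks.</p>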
<h3 id="objective-function">Objective Function</h3>
<p>LSTNet supports two loss functions, selected via validation performance. The default is the squared (L2) loss:</p>
<p>$$\underset{\Theta}{\text{minimize}} \sum_{t \in \Omega_{\text{Train}}} \left\| Y_t - \hat{Y}_{t-h} \right\|_F^2$$</p>
<p>Motivated by the strong performance of Linear SVR baselines, LSTNet also supports the absolute (L1) loss, which is more robust to anomalies in real time series data:</p>
<p>$$\underset{\Theta}{\text{minimize}} \sum_{t \in \Omega_{\text{Train}}} \sum_{i=0}^{n-1} \left| Y_{t,i} - \hat{Y}_{t-h,i} \right|$$</p>
<p>where $\Theta$ is the full parameter set, $\Omega_{\text{Train}}$ is the set of training time stamps, $\|\cdot\|_F$ is the Frobenius norm, and $h$ is the forecast horizon.</p>
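<p>Both objectives reduce to sums over every (time stamp, variable) entry. The sketch below uses hypothetical function names and assumes targets and forecasts are already aligned row by row, so the horizon offset is handled before the call.</p>

```python
def l2_loss(targets, forecasts):
    """Sum of squared errors over all time stamps and variables."""
    return sum(
        (y - f) ** 2
        for y_row, f_row in zip(targets, forecasts)
        for y, f in zip(y_row, f_row)
    )


def l1_loss(targets, forecasts):
    """Sum of absolute errors over all time stamps and variables."""
    return sum(
        abs(y - f)
        for y_row, f_row in zip(targets, forecasts)
        for y, f in zip(y_row, f_row)
    )
```

<p>A single anomalous entry with error 10 contributes 100 to the squared loss but only 10 to the absolute loss, which is the robustness argument in concrete terms.</p>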
<h2 id="evaluation-on-four-benchmark-datasets">Evaluation on Four Benchmark Datasets</h2>
<h3 id="datasets">Datasets</h3>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>Length</th>
          <th>Variables</th>
          <th>Sample Rate</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Traffic</td>
          <td>17,544</td>
          <td>862</td>
          <td>1 hour</td>
      </tr>
      <tr>
          <td>Solar-Energy</td>
          <td>52,560</td>
          <td>137</td>
          <td>10 minutes</td>
      </tr>
      <tr>
          <td>Electricity</td>
          <td>26,304</td>
          <td>321</td>
          <td>1 hour</td>
      </tr>
      <tr>
          <td>Exchange-Rate</td>
          <td>7,588</td>
          <td>8</td>
          <td>1 day</td>
      </tr>
  </tbody>
</table>
<p>All datasets are split 60/20/20 (train/validation/test) in chronological order. Traffic, Solar-Energy, and Electricity exhibit clear periodic patterns (daily and weekly), while Exchange-Rate shows only short-term local continuity.</p>
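<p>A chronological split keeps every validation and test point later in time than the training data. The sketch below uses integer arithmetic to place the boundaries; the exact rounding used by the authors is not stated here, so the floor-based cut points are an assumption, as is the function name.</p>

```python
def chronological_split(n_steps, train_pct=60, valid_pct=20):
    """Return (train, valid, test) index ranges in time order, no shuffling."""
    train_end = n_steps * train_pct // 100
    valid_end = n_steps * (train_pct + valid_pct) // 100
    return (
        range(0, train_end),
        range(train_end, valid_end),
        range(valid_end, n_steps),
    )
```

<p>For example, <code>chronological_split(17544)</code> partitions the Traffic series into three contiguous blocks covering all 17,544 time steps.</p>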
<h3 id="baselines">Baselines</h3>
<p>The authors compare against seven methods: AR (univariate autoregression), LRidge (VAR with L2 regularization), LSVR (VAR with SVR objective), TRMF (temporal regularized matrix factorization), GP (Gaussian Process), VAR-MLP (hybrid MLP-autoregressive), and RNN-GRU (standard GRU).</p>
<h3 id="metrics">Metrics</h3>
<p>Two evaluation metrics are used:</p>
<ul>
<li><strong>Root Relative Squared Error (RSE)</strong> (lower is better): A scaled RMSE that normalizes by the standard deviation of the test data, so that scores are comparable across datasets regardless of data scale:</li>
</ul>
<p>$$\text{RSE} = \frac{\sqrt{\sum_{(i,t) \in \Omega_{\text{Test}}} (Y_{it} - \hat{Y}_{it})^2}}{\sqrt{\sum_{(i,t) \in \Omega_{\text{Test}}} (Y_{it} - \text{mean}(Y))^2}}$$</p>
<ul>
<li><strong>Empirical Correlation Coefficient (CORR)</strong> (higher is better): The average Pearson correlation between predicted and true time series across all $n$ variables:</li>
</ul>
<p>$$\text{CORR} = \frac{1}{n} \sum_{i=1}^{n} \frac{\sum_t (Y_{it} - \text{mean}(Y_i))(\hat{Y}_{it} - \text{mean}(\hat{Y}_i))}{\sqrt{\sum_t (Y_{it} - \text{mean}(Y_i))^2 \sum_t (\hat{Y}_{it} - \text{mean}(\hat{Y}_i))^2}}$$</p>
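<p>Both metrics follow directly from the formulas above. The sketch below assumes predictions and targets are lists of rows (one row per test time stamp, one column per variable); the function names are chosen for this example, and a variable that is constant over the test window would make its CORR term undefined (division by zero).</p>

```python
import math


def rse(y_true, y_pred):
    """Root relative squared error over all (variable, time) entries."""
    flat_true = [v for row in y_true for v in row]
    flat_pred = [v for row in y_pred for v in row]
    mean_all = sum(flat_true) / len(flat_true)
    num = sum((t - p) ** 2 for t, p in zip(flat_true, flat_pred))
    den = sum((t - mean_all) ** 2 for t in flat_true)
    return math.sqrt(num) / math.sqrt(den)


def corr(y_true, y_pred):
    """Pearson correlation per variable, averaged over the n variables."""
    n_vars = len(y_true[0])
    total = 0.0
    for i in range(n_vars):
        t = [row[i] for row in y_true]
        p = [row[i] for row in y_pred]
        mean_t = sum(t) / len(t)
        mean_p = sum(p) / len(p)
        cov = sum((a - mean_t) * (b - mean_p) for a, b in zip(t, p))
        var_t = sum((a - mean_t) ** 2 for a in t)
        var_p = sum((b - mean_p) ** 2 for b in p)
        total += cov / math.sqrt(var_t * var_p)
    return total / n_vars
```

<p>A forecast that always outputs the global test mean scores an RSE of exactly 1, so values below 1 indicate an improvement over that trivial predictor.</p>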
<h3 id="main-results">Main Results</h3>
<p>The models are evaluated at horizons $h \in \{3, 6, 12, 24\}$, corresponding to 3-24 hours for Traffic and Electricity, 30-240 minutes for Solar-Energy, and 3-24 days for Exchange-Rate.</p>
<p>LSTNet-Skip achieved the best result in 17 out of 32 (dataset, metric, horizon) combinations, and LSTNet-Attn won 7 more. No other method won more than 3. At horizon 24, the best LSTNet variant improved over RNN-GRU by 9.2% RSE on Solar-Energy (LSTNet-Attn), 11.7% on Traffic (LSTNet-Skip), and 22.2% on Electricity (LSTNet-Skip). On the Exchange-Rate dataset, which lacks periodic patterns, LSTNet performed comparably to AR and LRidge, as expected.</p>
<h3 id="ablation-study">Ablation Study</h3>
<p>Removing each component individually revealed:</p>
<ul>
<li><strong>Without AR</strong>: The largest performance drops across most datasets, confirming the AR component&rsquo;s role in handling scale changes. Visualization showed that LSTNet-Skip successfully tracks sudden magnitude shifts in electricity consumption around the 1000th hour, while the model without AR fails.</li>
<li><strong>Without Skip/CNN</strong>: Significant drops on datasets with periodic patterns, though less consistent than removing AR.</li>
<li><strong>Full LSTNet</strong>: The most robust configuration across all datasets and horizons.</li>
</ul>
<p>A simulation experiment with synthetic autoregressive data confirmed that standard RNN-GRU fails to track non-periodic scale changes, while LSTNet with its AR component adapts properly.</p>
<h2 id="robust-performance-through-architectural-complementarity">Robust Performance Through Architectural Complementarity</h2>
<p>LSTNet&rsquo;s main strength is the complementarity of its components. The CNN captures short-term local patterns, the recurrent-skip layer captures long-term periodic dependencies, and the AR component provides robustness to scale changes. On datasets with strong periodicity (Traffic, Solar-Energy, Electricity), the skip connections provide large gains. On datasets without periodicity (Exchange-Rate), the AR component prevents degradation below competitive baselines.</p>
<p>The primary limitation is that the skip length $p$ in the recurrent-skip component must be manually specified or tuned. For datasets with known periodicity (e.g., hourly data with daily cycles), $p$ is straightforward to set. For datasets without clear periodicity, $p$ must be tuned as a hyperparameter, and the attention-based variant (LSTNet-Attn) offers an alternative that avoids this requirement. Future work directions include automatic period detection and incorporating variable-level attribute information into the convolutional layer.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training/Evaluation</td>
          <td>Traffic</td>
          <td>17,544 x 862</td>
          <td>California DoT highway occupancy, hourly, 2015-2016</td>
      </tr>
      <tr>
          <td>Training/Evaluation</td>
          <td>Solar-Energy</td>
          <td>52,560 x 137</td>
          <td>Solar power from 137 PV plants in Alabama, 10-min intervals, 2006</td>
      </tr>
      <tr>
          <td>Training/Evaluation</td>
          <td>Electricity</td>
          <td>26,304 x 321</td>
          <td>kWh consumption for 321 clients, hourly, 2012-2014</td>
      </tr>
      <tr>
          <td>Training/Evaluation</td>
          <td>Exchange-Rate</td>
          <td>7,588 x 8</td>
          <td>Daily exchange rates for 8 countries, 1990-2016</td>
      </tr>
  </tbody>
</table>
<p>All datasets are publicly available via the <a href="https://github.com/laiguokun/LSTNet">GitHub repository</a>.</p>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Optimizer: Adam</li>
<li>Dropout: 0.1 or 0.2 after each layer except input and output</li>
<li>Window size $q$: grid search over $\{2^0, 2^1, \ldots, 2^9\}$</li>
<li>Skip length $p$: set to 24 for Traffic/Electricity; tuned from $2^1$ to $2^6$ for Solar-Energy/Exchange-Rate</li>
<li>Objective: L2 loss (Eq. 7) or L1 loss (Eq. 9), selected via validation</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li>Hidden dimensions (Recurrent/CNN): $\{50, 100, 200\}$</li>
<li>Hidden dimensions (Recurrent-skip): $\{20, 50, 100\}$</li>
<li>AR regularization: $\{0.1, 1, 10\}$</li>
</ul>
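<p>The search ranges listed under Algorithms and Models can be written down as a small configuration grid. This is only a sketch: the dictionary keys are illustrative names, the CNN and recurrent layers are assumed to share one hidden size, and the note does not say whether the authors searched the full Cartesian product.</p>

```python
from itertools import product

# Ranges copied from the lists above; key names are hypothetical.
search_space = {
    "window_q": [2 ** k for k in range(10)],   # 2^0 ... 2^9
    "hidden_cnn_rnn": [50, 100, 200],          # assumed shared by CNN and RNN
    "hidden_skip": [20, 50, 100],
    "ar_regularization": [0.1, 1, 10],
    "dropout": [0.1, 0.2],
    "loss": ["l1", "l2"],                      # chosen on the validation set
}

def iter_configs(space):
    """Yield one dict per point of the Cartesian product of all ranges."""
    keys = list(space)
    for values in product(*(space[key] for key in keys)):
        yield dict(zip(keys, values))

configs = list(iter_configs(search_space))
print(len(configs))  # 1080 candidate configurations
```

<p>The skip length $p$ is left out because it is fixed to 24 on Traffic and Electricity and only tuned on the other two datasets.</p>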
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Dataset (horizon)</th>
          <th>Best LSTNet RSE</th>
          <th>RNN-GRU RSE (baseline)</th>
          <th>Relative improvement</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Solar-Energy (h=24)</td>
          <td>0.4403 (Attn)</td>
          <td>0.4852</td>
          <td>9.3%</td>
      </tr>
      <tr>
          <td>Traffic (h=24)</td>
          <td>0.4973 (Skip)</td>
          <td>0.5633</td>
          <td>11.7%</td>
      </tr>
      <tr>
          <td>Electricity (h=24)</td>
          <td>0.1007 (Skip)</td>
          <td>0.1295</td>
          <td>22.2%</td>
      </tr>
  </tbody>
</table>
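<p>The improvement column is the relative reduction in RSE with respect to the RNN-GRU baseline. A minimal sketch of both quantities follows, assuming the usual root relative squared error definition (model error divided by the error of predicting the mean); the function names are illustrative and not taken from the official repository.</p>

```python
import numpy as np

def rse(y_true, y_pred):
    """Root relative squared error: error relative to a predict-the-mean baseline."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    numerator = np.sqrt(np.sum((y_true - y_pred) ** 2))
    denominator = np.sqrt(np.sum((y_true - y_true.mean()) ** 2))
    return numerator / denominator

def relative_improvement(baseline_rse, model_rse):
    """Percentage reduction in RSE relative to the baseline."""
    return 100.0 * (baseline_rse - model_rse) / baseline_rse

# (best LSTNet RSE, RNN-GRU RSE) pairs from the table above
table = {
    "Solar-Energy": (0.4403, 0.4852),
    "Traffic": (0.4973, 0.5633),
    "Electricity": (0.1007, 0.1295),
}
for name, (lstnet, gru) in table.items():
    print(name, round(relative_improvement(gru, lstnet), 2))
```

<p>Lower RSE is better, and a value of 1 corresponds to a model that is no better than predicting the mean of the test series.</p>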
<h3 id="hardware">Hardware</h3>
<p>Not specified in the paper.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/laiguokun/LSTNet">LSTNet (laiguokun/LSTNet)</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official PyTorch implementation (Python 2.7, PyTorch 0.3.0)</td>
      </tr>
      <tr>
          <td><a href="https://github.com/laiguokun/multivariate-time-series-data">Multivariate Time Series Data (laiguokun/multivariate-time-series-data)</a></td>
          <td>Dataset</td>
          <td>Unknown</td>
          <td>Preprocessed benchmark datasets (Traffic, Solar-Energy, Electricity, Exchange-Rate)</td>
      </tr>
  </tbody>
</table>
<p><strong>Reproducibility status</strong>: Highly Reproducible. Code and all four benchmark datasets are publicly available. Hyperparameter search ranges are fully specified.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Lai, G., Chang, W.-C., Yang, Y., &amp; Liu, H. (2018). Modeling Long- and Short-Term Temporal Patterns with Deep Neural Networks. <em>The 41st International ACM SIGIR Conference on Research &amp; Development in Information Retrieval (SIGIR &lsquo;18)</em>, 95-104. <a href="https://doi.org/10.1145/3209978.3210006">https://doi.org/10.1145/3209978.3210006</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{lai2018modeling,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Modeling Long- and Short-Term Temporal Patterns with Deep Neural Networks}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Lai, Guokun and Chang, Wei-Cheng and Yang, Yiming and Liu, Hanxiao}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{The 41st International ACM SIGIR Conference on Research \&amp; Development in Information Retrieval}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{95--104}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2018}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1145/3209978.3210006}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Atom-Density Representations for Machine Learning</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-simulation/ml-potentials/atom-density-representations-ml/</link><pubDate>Sat, 11 Apr 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-simulation/ml-potentials/atom-density-representations-ml/</guid><description>A unified bra-ket framework connecting SOAP, Behler-Parrinello, and other atom-density representations for ML on molecules and materials.</description><content:encoded><![CDATA[<h2 id="a-unified-theory-of-atom-density-representations">A Unified Theory of Atom-Density Representations</h2>
<p>This is a <strong>Theory</strong> paper that provides a formal, basis-independent framework for constructing structural representations of atomic systems for machine learning. Rather than proposing a new representation, Willatt, Musil, and Ceriotti show that many popular approaches (SOAP power spectra, Behler-Parrinello symmetry functions, $n$-body kernels, and tensorial SOAP) are special cases of a single abstract construction based on smoothed atom densities and <a href="https://en.wikipedia.org/wiki/Haar_measure">Haar integration</a> over symmetry groups.</p>
<h2 id="the-challenge-of-representing-atomic-structures">The Challenge of Representing Atomic Structures</h2>
<p>Machine learning models for predicting molecular and materials properties require input representations that are (1) complete enough to distinguish structurally distinct configurations and (2) invariant to physical symmetries (translations, rotations, and permutations of identical atoms). This has led to a large and growing set of competing approaches: <a href="/posts/molecular-descriptor-coulomb-matrix/">Coulomb matrices</a>, symmetry functions, <a href="https://en.wikipedia.org/wiki/Radial_distribution_function">radial distribution functions</a>, wavelets, invariant polynomials, and many more.</p>
<p>The proliferation of representations makes it difficult to compare them on equal footing or to identify which design choices are fundamental and which are incidental. Internal-coordinate approaches (e.g., Coulomb matrices) are automatically translation- and rotation-invariant but require additional symmetrization over permutations, which can introduce derivative discontinuities when done via sorting. Density-based approaches such as radial distribution functions and SOAP avoid these discontinuities by working with smooth density fields, but their theoretical connections to one another have not been made explicit.</p>
<h2 id="dirac-notation-for-atomic-environments">Dirac Notation for Atomic Environments</h2>
<p>The core innovation is to describe an atomic configuration $\mathcal{A}$ as a ket $|\mathcal{A}\rangle$ in a Hilbert space, formed by placing smooth functions $g(\mathbf{r})$ (typically Gaussians) on each atom and decorating them with orthonormal element kets $|\alpha\rangle$:</p>
<p>$$
\langle \mathbf{r} | \mathcal{A} \rangle = \sum_{i} g(\mathbf{r} - \mathbf{r}_{i}) | \alpha_{i} \rangle
$$</p>
<p>This ket is basis-independent, which is the reason for adopting the Dirac notation. The same abstract object can be projected onto position space, reciprocal space, or a basis of radial functions and spherical harmonics, yielding different concrete representations that all encode the same structural information.</p>
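<p>As a concrete instance, the position-space projection $\langle \mathbf{r} | \mathcal{A} \rangle$ can be evaluated numerically with one density channel per element. The sketch below assumes an unnormalized Gaussian for $g$ with a hypothetical width <code>sigma</code>; the function name and arguments are illustrative.</p>

```python
import numpy as np

def atom_density(points, positions, species, sigma=0.5):
    """Evaluate one smoothed density channel per element at the query points."""
    points = np.asarray(points, dtype=float)        # shape (P, 3)
    positions = np.asarray(positions, dtype=float)  # shape (N, 3)
    channels = {}
    for element in sorted(set(species)):
        mask = np.array([s == element for s in species])
        diff = points[:, None, :] - positions[mask][None, :, :]
        squared = np.sum(diff ** 2, axis=-1)
        # Sum of Gaussians g(r - r_i) over the atoms of this element
        channels[element] = np.exp(-squared / (2.0 * sigma ** 2)).sum(axis=1)
    return channels

positions = [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]
species = ["O", "H"]
density = atom_density([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]], positions, species)
print(density["O"], density["H"])
```

<p>Keeping the elements in separate channels is the numerical counterpart of decorating each Gaussian with an orthonormal element ket $|\alpha\rangle$.</p>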
<h3 id="symmetrization-via-haar-integration">Symmetrization via Haar Integration</h3>
<p>To impose translational invariance, the ket is averaged over the translation group. Averaging the raw density $|\mathcal{A}\rangle$ directly (first order, $\nu = 1$) discards all geometric information and retains only atom counts per element. The solution is to first take tensor products and then average:</p>
<p>$$
\left| \mathcal{A}^{(\nu)} \right\rangle_{\hat{t}} = \int \mathrm{d}\hat{t} \underbrace{\hat{t}|\mathcal{A}\rangle \otimes \hat{t}|\mathcal{A}\rangle \cdots \hat{t}|\mathcal{A}\rangle}_{\nu}
$$</p>
<p>For $\nu = 2$, this yields a translationally-invariant ket that encodes pairwise distance information between atoms, and naturally decomposes into atom-centered contributions:</p>
<p>$$
\left| \mathcal{A}^{(2)} \right\rangle_{\hat{t}} = \sum_{j} |\alpha_{j}\rangle |\mathcal{X}_{j}\rangle
$$</p>
<p>where $|\mathcal{X}_{j}\rangle$ is the environment ket centered on atom $j$, defined with a smooth cutoff function $f_{c}(r_{ij})$ that restricts each environment to a spherical neighborhood (justified by the nearsightedness principle of electronic matter). This decomposition is what justifies the widely used additive kernel between structures (a sum of kernels between environments).</p>
<h3 id="rotational-invariance-and-body-order-correlations">Rotational Invariance and Body-Order Correlations</h3>
<p>The same Haar integration procedure over the $SO(3)$ rotation group produces rotationally invariant representations:</p>
<p>$$
\left| \mathcal{X}_{j}^{(\nu)} \right\rangle_{\hat{R}} = \int \mathrm{d}\hat{R} \prod_{\aleph=1}^{\nu} \otimes \hat{R}|\mathcal{X}_{j}\rangle
$$</p>
<p>The order $\nu$ of the tensor product before symmetrization determines the body-order of correlations captured: $\nu$ corresponds to $(\nu + 1)$-body correlations. The $\nu = 1$ invariant ket retains only radial (distance) information (two-body). The $\nu = 2$ ket encodes three-body correlations (two distances and an angle), and is argued to be sufficient for unique reconstruction of a configuration (up to inversion symmetry), based on extensive numerical experiments. Using nonlinear kernels (tensor products of the symmetrized ket, parameterized by $\zeta$) allows the model to incorporate higher body-order correlations beyond those explicitly in the feature vector.</p>
<h2 id="recovering-soap-symmetry-functions-and-tensorial-extensions">Recovering SOAP, Symmetry Functions, and Tensorial Extensions</h2>
<p>By projecting the abstract invariant kets onto specific basis sets, the authors recover several well-known frameworks as special cases.</p>
<h3 id="behler-parrinello-symmetry-functions">Behler-Parrinello Symmetry Functions</h3>
<p>In the $\delta$-function limit of the atomic density, the $\nu = 1$ and $\nu = 2$ invariant kets in real space directly correspond to the 2-body and 3-body correlation functions. Behler-Parrinello symmetry functions are projections of these correlation functions onto suitable test functions $G$:</p>
<p>$$
\langle \alpha \beta G_{2} | \mathcal{X}_{j} \rangle = \langle \alpha | \alpha_{j} \rangle \int \mathrm{d}r\, G_{2}(r)\, r \left\langle \beta r | \mathcal{X}_{j}^{(1)} \right\rangle_{\hat{R},\, g \to \delta}
$$</p>
<p>where the $g \to \delta$ subscript indicates the limit in which the smoothing function $g$ becomes a Dirac delta.</p>
<h3 id="soap-power-spectrum">SOAP Power Spectrum</h3>
<p>Expanding the environmental ket in a basis of radial functions $R_{n}(r)$ and <a href="https://en.wikipedia.org/wiki/Spherical_harmonics">spherical harmonics</a> $Y_{m}^{l}(\hat{\mathbf{r}})$, the $\nu = 2$ invariant ket is the SOAP power spectrum:</p>
<p>$$
\left\langle \alpha n \alpha' n' l \middle| \mathcal{X}_{j}^{(2)} \right\rangle_{\hat{R}} \propto \frac{1}{\sqrt{2l+1}} \sum_{m} \langle \alpha n l m | \mathcal{X}_{j} \rangle^{\star} \langle \alpha' n' l m | \mathcal{X}_{j} \rangle
$$</p>
<p>This identity shows that the SOAP kernel, which can be expressed as a scalar product between truncated power spectrum vectors, is a natural consequence of the inner product between invariant kets. The $\nu = 3$ case yields the <a href="https://en.wikipedia.org/wiki/Bispectrum">bispectrum</a>, used as a four-body feature vector in both SOAP and Spectral Neighbor Analysis Potentials (SNAP), where its high resolution enables accurate interatomic potentials through linear regression:</p>
<p>$$
\langle \alpha_{1} n_{1} l_{1}, \alpha_{2} n_{2} l_{2}, \alpha n l | \mathcal{X}_{j}^{(3)} \rangle_{\hat{R}} \propto \frac{1}{\sqrt{2l+1}} \sum_{m, m_{1}, m_{2}} \langle \mathcal{X}_{j} | \alpha n l m \rangle \langle \alpha_{1} n_{1} l_{1} m_{1} | \mathcal{X}_{j} \rangle \langle \alpha_{2} n_{2} l_{2} m_{2} | \mathcal{X}_{j} \rangle \langle l_{1} m_{1} l_{2} m_{2} | l m \rangle
$$</p>
<p>where $\langle l_{1} m_{1} l_{2} m_{2} | l m \rangle$ is a <a href="https://en.wikipedia.org/wiki/Clebsch%E2%80%93Gordan_coefficients">Clebsch-Gordan coefficient</a>.</p>
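<p>The power spectrum contraction above is short enough to check numerically. The sketch below assumes a single element channel and random complex expansion coefficients $c_{nlm}$, and uses a random unitary matrix on the $m$ index as a stand-in for a Wigner D-matrix (which is also unitary). All names are illustrative.</p>

```python
import numpy as np

def power_spectrum(coeffs):
    """coeffs[l] has shape (n_max, 2l+1) and holds c_{nlm} for one element channel."""
    return [
        np.einsum("nm,km->nk", np.conj(c), c) / np.sqrt(2 * l + 1)
        for l, c in enumerate(coeffs)
    ]

def random_unitary(dim, rng):
    """Random unitary matrix from the QR decomposition of a complex Gaussian matrix."""
    z = rng.normal(size=(dim, dim)) + 1j * rng.normal(size=(dim, dim))
    q, _ = np.linalg.qr(z)
    return q

rng = np.random.default_rng(0)
n_max, l_max = 3, 2
coeffs = [
    rng.normal(size=(n_max, 2 * l + 1)) + 1j * rng.normal(size=(n_max, 2 * l + 1))
    for l in range(l_max + 1)
]
# Mix the m components within each l block, as a rotation would
rotated = [c @ random_unitary(2 * l + 1, rng).T for l, c in enumerate(coeffs)]

unchanged = all(
    np.allclose(p, p_rot)
    for p, p_rot in zip(power_spectrum(coeffs), power_spectrum(rotated))
)
print(unchanged)
```

<p>The sum over $m$ removes any dependence on how the components within one $l$ block are mixed, which is why the $\nu = 2$ features are rotationally invariant.</p>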
<h3 id="tensorial-soap-lambda-soap">Tensorial SOAP ($\lambda$-SOAP)</h3>
<p>The tensorial extension of SOAP incorporates an angular momentum ket $|\lambda \mu\rangle$ into the tensor product before symmetrization:</p>
<p>$$
\left| \mathcal{X}^{(\nu)} \lambda \mu \right\rangle_{\hat{R}} = \int \mathrm{d}\hat{R}\, \hat{R}|\lambda \mu\rangle \prod_{\aleph=1}^{\nu} \otimes \hat{R}|\mathcal{X}_{j}\rangle
$$</p>
<p>This construction is rotationally invariant in the full product space but covariant in the subspace of atomic environments, enabling models for tensorial properties (e.g., polarizability tensors, chemical shielding).</p>
<h3 id="distributions-vs-sorted-vectors">Distributions vs. Sorted Vectors</h3>
<p>The paper also connects density-based and sorted-vector approaches. Given a set of structural descriptors $\{a_{i}\}$, the sorted vector is equivalent to the inverse cumulative distribution function of the histogram of values. The Euclidean distance between sorted vectors is the $\mathcal{L}^{2}$ norm of the difference between the inverse CDFs, and the $\mathcal{L}^{1}$ norm corresponds to the <a href="https://en.wikipedia.org/wiki/Earth_mover%27s_distance">earth mover&rsquo;s distance</a>. This highlights that different symmetrization strategies encode essentially the same structural information.</p>
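<p>A small numerical illustration of this correspondence, assuming two descriptor sets of equal length so that the sorted vectors sample the two inverse CDFs at the same quantiles. The function name is illustrative.</p>

```python
import numpy as np

def sorted_vector_distances(a, b):
    """Compare two equal-length descriptor sets through their sorted vectors."""
    a_sorted = np.sort(np.asarray(a, dtype=float))
    b_sorted = np.sort(np.asarray(b, dtype=float))
    diff = a_sorted - b_sorted
    euclidean = float(np.sqrt(np.sum(diff ** 2)))
    # Mean absolute difference of the sorted values: the L1 distance between
    # the inverse CDFs, i.e. the one-dimensional earth mover's distance
    earth_mover = float(np.mean(np.abs(diff)))
    return euclidean, earth_mover

print(sorted_vector_distances([3.0, 1.0, 2.0], [2.0, 3.0, 1.0]))  # (0.0, 0.0)
print(sorted_vector_distances([0.0, 0.0], [1.0, 1.0]))
```

<p>Sorting makes both distances independent of the order in which the descriptors are listed, which is the permutation invariance that the density-based construction obtains by summing over atoms.</p>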
<h2 id="generalized-operators-for-tuning-representations">Generalized Operators for Tuning Representations</h2>
<p>The framework becomes especially powerful through the introduction of a linear Hermitian operator $\hat{U}$ that transforms the density ket before symmetrization. This operator must commute with rotations:</p>
<p>$$
\langle \alpha n l m | \hat{U} | \alpha' n' l' m' \rangle = \delta_{ll'} \delta_{mm'} \langle \alpha n l | \hat{U} | \alpha' n' l' \rangle
$$</p>
<p>Several practical modifications to standard representations can be understood as choices of $\hat{U}$:</p>
<h3 id="dimensionality-reduction">Dimensionality Reduction</h3>
<p>A low-rank expansion of $\hat{U}$ via PCA on the spherical-harmonic covariance matrix of environments identifies linearly independent components, enabling compression of the feature vector. For a given $l$, the covariance matrix between spherical expansion coefficients is:</p>
<p>$$
C_{\alpha n \alpha' n'}^{(l)} = \frac{1}{N} \sum_{j} \sum_{m} \langle \alpha n l m | \mathcal{X}_{j} \rangle^{\star} \langle \alpha' n' l m | \mathcal{X}_{j} \rangle = \frac{\sqrt{2l+1}}{N} \sum_{j} \left\langle \alpha n \alpha' n' l \middle| \mathcal{X}_{j}^{(2)} \right\rangle_{\hat{R}}
$$</p>
<p>The eigenvectors of $\mathbf{C}^{(l)}$ provide the mixing coefficients for a compressed representation, retaining only components with significant eigenvalues.</p>
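<p>The compression step can be sketched with a plain eigendecomposition. The code below assumes the expansion coefficients for a fixed $l$ are stored in one complex array whose second axis enumerates the combined $(\alpha, n)$ index, and that the covariance is accumulated as conjugated coefficient times coefficient, matching the power spectrum. Array layout and names are illustrative.</p>

```python
import numpy as np

def compress_channels(coeffs, n_keep):
    """coeffs: complex array of shape (n_env, n_feat, 2l+1) for a fixed l."""
    n_env = coeffs.shape[0]
    # Covariance between (alpha, n) channels, summed over environments and m
    cov = np.einsum("jam,jbm->ab", np.conj(coeffs), coeffs) / n_env
    eigvals, eigvecs = np.linalg.eigh(cov)          # Hermitian input, ascending order
    order = np.argsort(eigvals)[::-1][:n_keep]      # keep the largest eigenvalues
    mixing = eigvecs[:, order]                      # shape (n_feat, n_keep)
    compressed = np.einsum("ak,jam->jkm", mixing, coeffs)
    return eigvals[order], mixing, compressed

rng = np.random.default_rng(0)
coeffs = rng.normal(size=(50, 4, 3)) + 1j * rng.normal(size=(50, 4, 3))
eigvals, mixing, compressed = compress_channels(coeffs, n_keep=2)
print(compressed.shape)  # (50, 2, 3)
```

<p>Because the mixing acts only on the $(\alpha, n)$ index and leaves $m$ untouched, the compressed coefficients transform under rotations exactly like the original ones.</p>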
<h3 id="radial-scaling">Radial Scaling</h3>
<p>In systems with relatively uniform atom density, the overlap kernel is dominated by the region farthest from the center. A radial scaling operator $u(r)$ (diagonal in position space) downweights distant contributions:</p>
<p>$$
\langle \alpha \mathbf{r} | \hat{U} | \mathcal{X}_{j} \rangle = u(r)\, \psi_{\mathcal{X}_{j}}^{\alpha}(\mathbf{r})
$$</p>
<p>This recovers the multi-scale kernels that are known to improve predictions in practice, and connects to the two-body features of Faber et al.</p>
<h3 id="alchemical-kernels">Alchemical Kernels</h3>
<p>An operator that acts only in chemical-element space introduces correlations between different elements. The &ldquo;alchemical&rdquo; projection:</p>
<p>$$
\langle J \mathbf{r} | \mathcal{X}_{j} \rangle = \sum_{\alpha} u_{J\alpha}\, \psi_{\mathcal{X}_{j}}^{\alpha}(\mathbf{r})
$$</p>
<p>reduces the dimensionality from $O(n_{\mathrm{sp}}^{2})$ to $O(d_{J}^{2})$ and has been shown to produce a low-dimensional representation of elemental space that shares similarities with periodic-table groupings.</p>
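<p>In array form the alchemical projection is a single matrix product over the element index. The sketch below uses random numbers for both the density channels and the mixing coefficients $u_{J\alpha}$; in practice the coefficients would be learned or derived from PCA, and all names here are illustrative.</p>

```python
import numpy as np

def alchemical_contract(channels, u):
    """Mix element channels: channels is (n_species, n_points), u is (d_J, n_species)."""
    channels = np.asarray(channels, dtype=float)
    u = np.asarray(u, dtype=float)
    return u @ channels

n_species, d_j, n_points = 4, 2, 5
rng = np.random.default_rng(0)
channels = rng.random((n_species, n_points))   # density of each element on a grid
u = rng.normal(size=(d_j, n_species))          # hypothetical mixing coefficients
reduced = alchemical_contract(channels, u)

# Element-pair blocks in the power spectrum before and after the projection
print(reduced.shape, n_species ** 2, d_j ** 2)  # (2, 5) 16 4
```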
<h3 id="non-factorizable-operators">Non-Factorizable Operators</h3>
<p>For more complex modifications (e.g., distance- and angle-dependent scaling of three-body correlations), the operator must act on the full product space rather than factoring into independent components. The authors show that the three-body scaling function of Faber et al. corresponds to a diagonal non-factorizable operator in the real-space representation.</p>
<h2 id="implications-and-future-directions">Implications and Future Directions</h2>
<p>The main conclusions are:</p>
<ol>
<li>
<p><strong>Unification</strong>: SOAP, Behler-Parrinello symmetry functions, $\lambda$-SOAP, bispectrum descriptors, and sorted-vector approaches all emerge from the same abstract construction, differing only in the choice of basis set, body order $\nu$, and kernel power $\zeta$.</p>
</li>
<li>
<p><strong>Systematic improvability</strong>: The $\hat{U}$ operator framework provides a principled way to tune representations, from simple radial scaling to full alchemical and non-factorizable couplings, with clear connections to existing heuristic modifications.</p>
</li>
<li>
<p><strong>Completeness hierarchy</strong>: The body-order parameter $\nu$ and kernel power $\zeta$ together control the trade-off between completeness and computational cost. The $\nu = 2$ (three-body) representation appears to be sufficient for unique structural identification, while higher orders can be recovered through nonlinear kernels.</p>
</li>
</ol>
<p><strong>Limitations</strong>: The paper is primarily theoretical and does not include extensive numerical benchmarks comparing the different instantiations of the framework. Optimization of the $\hat{U}$ operator (especially in its general form) carries a risk of overfitting that the authors acknowledge but do not resolve. The connection to neural network-based representations (message-passing networks, equivariant architectures) is not explored.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>This is a theoretical paper that does not introduce new benchmark results. No training or test datasets are used. The ethanol molecule in Figure 2 serves as a visualization example for three-body correlation functions.</p>
<h3 id="algorithms">Algorithms</h3>
<p>The paper defines abstract constructions (Haar integration, tensor products, operator transformations) and shows how they reduce to concrete algorithms:</p>
<ul>
<li><strong>SOAP power spectrum</strong>: Expand atom density in radial functions and spherical harmonics, compute $\nu = 2$ invariant ket (Eq. 33).</li>
<li><strong>Alchemical projection</strong>: Contract element channels via learned or PCA-derived mixing coefficients (Eq. 53-55).</li>
<li><strong>Dimensionality reduction</strong>: PCA on the covariance matrix $C_{\alpha n \alpha' n'}^{(l)}$ of spherical expansion coefficients (Eq. 45).</li>
</ul>
<h3 id="models">Models</h3>
<p>No trained models are presented. The framework applies to kernel ridge regression / Gaussian process regression models that use these representations as inputs.</p>
<h3 id="evaluation">Evaluation</h3>
<p>No quantitative benchmarks are reported. The contribution is the theoretical framework itself, connecting and generalizing existing representations.</p>
<h3 id="hardware">Hardware</h3>
<p>Not applicable (theoretical work).</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Willatt, M. J., Musil, F., &amp; Ceriotti, M. (2019). Atom-density representations for machine learning. <em>The Journal of Chemical Physics</em>, 150(15), 154110. <a href="https://doi.org/10.1063/1.5090481">https://doi.org/10.1063/1.5090481</a></p>
<p><strong>Publication</strong>: The Journal of Chemical Physics, 2019</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{willatt2019atom,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Atom-density representations for machine learning}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Willatt, Michael J. and Musil, F{\&#39;e}lix and Ceriotti, Michele}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{The Journal of Chemical Physics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{150}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{15}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{154110}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2019}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1063/1.5090481}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">eprint</span>=<span style="color:#e6db74">{1807.00408}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">archiveprefix</span>=<span style="color:#e6db74">{arXiv}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">primaryclass</span>=<span style="color:#e6db74">{physics.chem-ph}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>T5: Exploring Transfer Learning Limits</title><link>https://hunterheidenreich.com/notes/natural-language-processing/language-models/t5-text-to-text-transfer-transformer/</link><pubDate>Wed, 08 Apr 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/natural-language-processing/language-models/t5-text-to-text-transfer-transformer/</guid><description>Raffel et al. systematically study transfer learning for NLP with a text-to-text framework, ablating architectures, objectives, data, and multi-task mixing.</description><content:encoded><![CDATA[<h2 id="a-systematic-study-of-nlp-transfer-learning">A systematic study of NLP transfer learning</h2>
<p>This is a <strong>systematization paper</strong> that provides a comprehensive empirical survey of transfer learning techniques for NLP. Rather than proposing a single new method, T5 introduces a unified text-to-text framework and uses it as a testbed to systematically compare pre-training objectives, architectures, unlabeled data sources, transfer approaches, and multi-task mixing strategies. The scale of the ablation study (covering dozens of configurations) and the release of C4, pre-trained models, and code make it both a reference guide and a resource.</p>
<h2 id="unifying-nlp-tasks-as-text-to-text">Unifying NLP tasks as text-to-text</h2>
<p>The core design decision is to cast every NLP task as a text-to-text problem: both the input and output are text strings, with a task-specific prefix. Classification, regression, summarization, translation, and question answering all use the same model, loss function (cross-entropy on output tokens), and decoding procedure. This simplicity enables fair comparison across tasks and training strategies.</p>
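<p>To make the casting concrete, the sketch below turns four task types into plain (input, target) string pairs. The prefix strings follow the examples given in the T5 paper, but the task names, field names, and helper function are hypothetical, and rounding the similarity score to the nearest 0.2 is stated here as an assumption about how regression targets become text.</p>

```python
def encode_example(task, fields):
    """Return an (input_text, target_text) pair of plain strings for one example."""
    if task == "translation":
        source = "translate English to German: " + fields["english"]
        target = fields["german"]
    elif task == "summarization":
        source = "summarize: " + fields["article"]
        target = fields["summary"]
    elif task == "acceptability":
        # Classification: the label is emitted as a word
        source = "cola sentence: " + fields["sentence"]
        target = "acceptable" if fields["label"] == 1 else "not acceptable"
    elif task == "similarity":
        # Regression: the score is rounded to a multiple of 0.2 and emitted as a string
        source = "stsb sentence1: " + fields["s1"] + " sentence2: " + fields["s2"]
        target = format(round(fields["score"] * 5) / 5, ".1f")
    else:
        raise ValueError("unknown task: " + task)
    return source, target

print(encode_example("translation", {"english": "That is good.", "german": "Das ist gut."}))
print(encode_example("similarity", {"s1": "A man sings.", "s2": "A man is singing.", "score": 4.57}))
```

<p>Since every task produces the same kind of pair, one cross-entropy loss over target tokens covers all of them.</p>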
<p>The model architecture is a standard encoder-decoder Transformer. The paper finds that this form outperforms decoder-only (language model) and encoder-only (BERT-style) variants in the text-to-text setting, while having similar computational cost to decoder-only models despite twice the parameters (the encoder processes the input only once, then the decoder attends to it).</p>
<h2 id="multi-task-mixing-strategies-and-findings">Multi-task mixing: strategies and findings</h2>
<p>The most thesis-relevant contribution is the systematic ablation of multi-task mixing strategies (Section 3.5.2). When training on multiple tasks simultaneously (which in the text-to-text framework simply means mixing data from different sources), the central question is how to set the proportion of data from each task.</p>
<h3 id="three-mixing-strategies">Three mixing strategies</h3>
<p><strong>Examples-proportional mixing.</strong> Sample in proportion to each dataset&rsquo;s size, with an artificial cap $K$ on the maximum dataset size. Without the cap, the unsupervised pre-training data (orders of magnitude larger) would dominate all batches. The mixing rate for task $m$ is:</p>
<p>$$
r_{m} = \frac{\min(e_{m}, K)}{\sum_{n} \min(e_{n}, K)}
$$</p>
<p>where $e_{m}$ is the number of examples in task $m$&rsquo;s dataset.</p>
<p><strong>Temperature-scaled mixing.</strong> Raise each mixing rate $r_{m}$ to the power $1/T$ and renormalize so the rates sum to 1. At $T=1$ this equals examples-proportional mixing; as $T$ increases, proportions approach equal mixing. The rates $r_{m}$ are computed with a large cap, $K = 2^{21}$, before temperature scaling is applied.</p>
<p><strong>Equal mixing.</strong> Sample uniformly from all tasks. Included as a negative reference: the model overfits on low-resource tasks and underfits on high-resource tasks.</p>
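<p>The three strategies differ only in how dataset sizes are turned into sampling probabilities. A minimal sketch follows; the dataset sizes are made-up numbers chosen to span several orders of magnitude, and the function names are illustrative.</p>

```python
import numpy as np

def examples_proportional(sizes, cap):
    """r_m = min(e_m, K) / sum_n min(e_n, K)."""
    capped = np.minimum(np.asarray(sizes, dtype=float), cap)
    return capped / capped.sum()

def temperature_scaled(sizes, cap=2 ** 21, temperature=2.0):
    """Raise the capped rates to the power 1/T, then renormalize."""
    rates = examples_proportional(sizes, cap)
    scaled = rates ** (1.0 / temperature)
    return scaled / scaled.sum()

def equal_mixing(sizes):
    """Uniform sampling over tasks, regardless of dataset size."""
    return np.full(len(sizes), 1.0 / len(sizes))

# Hypothetical example counts: a small task, a very large task, a mid-sized task
sizes = [1_000, 10_000_000, 500_000]
print(examples_proportional(sizes, cap=2 ** 18))
print(temperature_scaled(sizes, temperature=2.0))
print(equal_mixing(sizes))
```

<p>With these sizes the cap stops the largest dataset from dominating, while raising $T$ moves probability mass toward the smallest task.</p>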
<h3 id="results">Results</h3>
<table>
  <thead>
      <tr>
          <th>Mixing strategy</th>
          <th>GLUE</th>
          <th>CNN/DM</th>
          <th>SQuAD</th>
          <th>SuperGLUE</th>
          <th>EnDe</th>
          <th>EnFr</th>
          <th>EnRo</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Baseline (pre-train/fine-tune)</td>
          <td>83.28</td>
          <td>19.24</td>
          <td>80.88</td>
          <td>71.36</td>
          <td>26.98</td>
          <td>39.82</td>
          <td>27.65</td>
      </tr>
      <tr>
          <td>Equal</td>
          <td>76.13</td>
          <td>19.02</td>
          <td>76.51</td>
          <td>63.37</td>
          <td>23.89</td>
          <td>34.31</td>
          <td>26.78</td>
      </tr>
      <tr>
          <td>Examples-proportional, $K=2^{18}$</td>
          <td>81.67</td>
          <td>19.07</td>
          <td>78.17</td>
          <td>67.94</td>
          <td>24.57</td>
          <td>35.19</td>
          <td>27.39</td>
      </tr>
      <tr>
          <td>Examples-proportional, $K=2^{19}$</td>
          <td>81.42</td>
          <td>19.24</td>
          <td>79.78</td>
          <td>67.30</td>
          <td>25.21</td>
          <td>36.30</td>
          <td>27.76</td>
      </tr>
      <tr>
          <td>Temperature-scaled, $T=2$</td>
          <td>81.90</td>
          <td>19.28</td>
          <td>79.42</td>
          <td>69.92</td>
          <td>25.42</td>
          <td>36.72</td>
          <td>27.20</td>
      </tr>
  </tbody>
</table>
<p><strong>Key findings on mixing:</strong></p>
<ol>
<li>
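<p><strong>Multi-task training underperforms pre-train-then-fine-tune on most tasks.</strong> No mixing strategy matches the baseline of unsupervised pre-training followed by task-specific fine-tuning across the board. The only cells in the table above that reach or exceed the baseline are CNN/DM (19.24 at $K=2^{19}$, 19.28 at $T=2$) and EnRo (27.76 at $K=2^{19}$), all within about 0.1 points.</p>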
</li>
<li>
<p><strong>Equal mixing is worst.</strong> It dramatically degrades performance, confirming that proportions matter.</p>
</li>
<li>
<p><strong>There exists a task-specific sweet spot for the cap $K$.</strong> Most tasks have an optimal $K$ value; larger or smaller values hurt. The exception is very high-resource tasks (WMT English-French) that always benefit from higher mixing proportions.</p>
</li>
<li>
<p><strong>Temperature scaling at $T=2$ provides the best single compromise.</strong> It achieves reasonable performance across all tasks without requiring per-task tuning of $K$.</p>
</li>
<li>
<p><strong>Multi-task pre-training followed by fine-tuning closes the gap.</strong> When multi-task training is used as pre-training (not as the final training stage), followed by task-specific fine-tuning, performance becomes comparable to unsupervised pre-training alone. This suggests that multi-task exposure during pre-training provides useful early signal without the negative effects of forcing a single model to perform all tasks simultaneously.</p>
</li>
<li>
<p><strong>&ldquo;Leave-one-out&rdquo; training works.</strong> Pre-training on a multi-task mixture that excludes a target task, then fine-tuning on it, produces only slightly worse results. This indicates that multi-task pre-training builds general capabilities that transfer to unseen tasks without dramatic task interference.</p>
</li>
</ol>
<h2 id="data-repetition-degrades-performance">Data repetition degrades performance</h2>
<p>The paper also systematically tests the effect of pre-training data set size by truncating C4 and training over repeated data:</p>
<table>
  <thead>
      <tr>
          <th>Unique tokens</th>
          <th>Repeats</th>
          <th>GLUE</th>
          <th>SQuAD</th>
          <th>SuperGLUE</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Full dataset</td>
          <td>0</td>
          <td>83.28</td>
          <td>80.88</td>
          <td>71.36</td>
      </tr>
      <tr>
          <td>$2^{29}$</td>
          <td>64</td>
          <td>82.87</td>
          <td>80.97</td>
          <td>72.03</td>
      </tr>
      <tr>
          <td>$2^{27}$</td>
          <td>256</td>
          <td>82.62</td>
          <td>79.78</td>
          <td>69.97</td>
      </tr>
      <tr>
          <td>$2^{25}$</td>
          <td>1,024</td>
          <td>79.55</td>
          <td>76.27</td>
          <td>64.76</td>
      </tr>
      <tr>
          <td>$2^{23}$</td>
          <td>4,096</td>
          <td>76.34</td>
          <td>70.92</td>
          <td>59.29</td>
      </tr>
  </tbody>
</table>
<p>Performance degrades as data shrinks, with 64 repeats showing limited effects but 1,024+ repeats causing significant degradation. Training loss curves confirm memorization at high repetition counts. The paper recommends using large, diverse pre-training datasets whenever possible.</p>
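<p>The repeat counts in the table follow from dividing the pre-training token budget by the number of unique tokens. The sketch below assumes the baseline budget of $2^{19}$ steps with batches of 128 sequences of 512 tokens, which is taken from the T5 paper's baseline setup and not stated in this note.</p>

```python
TOKENS_PER_STEP = 128 * 512    # assumption: 2^16 tokens per batch
PRETRAIN_STEPS = 2 ** 19       # baseline pre-training steps
TOTAL_TOKENS = TOKENS_PER_STEP * PRETRAIN_STEPS  # 2^35 tokens

def repeats_for(unique_tokens):
    """Number of passes over a truncated corpus during pre-training."""
    return TOTAL_TOKENS // unique_tokens

for exponent in (29, 27, 25, 23):
    print(exponent, repeats_for(2 ** exponent))
```

<p>Each reduction of the corpus by a factor of 4 multiplies the number of passes by 4, so the last row corresponds to seeing every token 4,096 times.</p>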
<h2 id="scaling-and-final-configuration">Scaling and final configuration</h2>
<p>The paper compares scaling strategies: more data, larger models, and ensembles. Training a larger model for fewer steps generally outperforms training a smaller model on more data. Ensembles of independently pre-trained and fine-tuned models provide orthogonal gains.</p>
<p>The final T5-11B model combines the best choices from all ablations: encoder-decoder architecture, span corruption objective, C4 pre-training data, multi-task pre-training followed by fine-tuning, and scaling to 11B parameters trained on over 1 trillion tokens. It achieves state-of-the-art results on GLUE (90.3 average), SuperGLUE (88.9, near human performance of 89.8), SQuAD, and CNN/Daily Mail. It does not achieve state-of-the-art on WMT translation tasks, where methods using backtranslation and cross-lingual pre-training retain the lead.</p>
<h2 id="implications-and-limitations">Implications and limitations</h2>
<p>The T5 paper&rsquo;s multi-task mixing findings are its most enduring contribution beyond the model itself. The core lessons: proportions matter enormously (equal mixing fails), examples-proportional mixing with a cap is a reasonable default, temperature scaling provides a single-knob alternative, and multi-task pre-training followed by fine-tuning can match pure unsupervised pre-training.</p>
<p><strong>Limitations:</strong></p>
<ul>
<li>All ablations use the same encoder-decoder architecture. Findings may not transfer to decoder-only models that dominate current practice.</li>
<li>The multi-task mixing experiments treat each task as a separate &ldquo;domain.&rdquo; Interactions between similar tasks (e.g., multiple classification tasks) are not isolated.</li>
<li>The paper does not provide a principled method for choosing $K$ or $T$; both require empirical search.</li>
<li>C4 has known quality issues (templated text, noisy content) that have been addressed in later datasets.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<p><strong>Status: Highly Reproducible.</strong> Code, pre-trained models, and the C4 dataset are all publicly released.</p>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pre-training</td>
          <td>C4 (Colossal Clean Crawled Corpus)</td>
          <td>~750 GB</td>
          <td>Heuristically cleaned Common Crawl</td>
      </tr>
      <tr>
          <td>Downstream</td>
          <td>GLUE, SuperGLUE, SQuAD, CNN/DM, WMT (EnDe, EnFr, EnRo)</td>
          <td>Standard splits</td>
          <td>Text-to-text format</td>
      </tr>
  </tbody>
</table>
<h3 id="models">Models</h3>
<p>Encoder-decoder Transformer. Sizes: Small (60M), Base (220M), Large (770M), 3B, 11B. The baseline uses the Base size. SentencePiece vocabulary with 32K tokens. Pre-trained for $2^{19}$ steps, fine-tuned for $2^{18}$ steps on individual tasks.</p>
<h3 id="algorithms">Algorithms</h3>
<p>Multi-task mixing: examples-proportional with cap $K \in \{2^{16}, \ldots, 2^{21}\}$, temperature-scaled with $T \in \{2, 4, 8\}$, and equal mixing. Unsupervised objective: span corruption (mean span length 3, 15% corruption rate). Training with Adafactor optimizer, inverse square root learning rate schedule.</p>
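<p>The two mixing schemes can be sketched in a few lines. This is an illustrative sketch, not the released implementation: the function names and dataset sizes are hypothetical, and it assumes the usual definitions, where the examples-proportional rate is $r_m = \min(e_m, K) / \sum_n \min(e_n, K)$ and temperature scaling raises each rate to the power $1/T$ before renormalizing.</p>

```python
def examples_proportional(sizes: dict[str, int], cap: int) -> dict[str, float]:
    """Mixing rate r_m = min(e_m, K) / sum_n min(e_n, K)."""
    capped = {task: min(n, cap) for task, n in sizes.items()}
    total = sum(capped.values())
    return {task: n / total for task, n in capped.items()}


def temperature_scaled(rates: dict[str, float], temperature: float) -> dict[str, float]:
    """Raise each rate to 1/T and renormalize; T = 1 leaves rates unchanged."""
    powered = {task: r ** (1.0 / temperature) for task, r in rates.items()}
    total = sum(powered.values())
    return {task: p / total for task, p in powered.items()}


# Hypothetical dataset sizes (number of examples), for illustration only.
sizes = {"big_task": 2**22, "medium_task": 2**18, "small_task": 2**12}

capped_rates = examples_proportional(sizes, cap=2**19)
flattened = temperature_scaled(examples_proportional(sizes, cap=2**21), temperature=2)
print(capped_rates)
print(flattened)
```

<p>Lowering $K$ or raising $T$ both shift probability mass from the largest task toward the smaller ones; equal mixing is the limit of large $T$.</p>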
<h3 id="hardware">Hardware</h3>
<p>All models trained using Mesh TensorFlow on TPU slices. T5-11B pre-trained for 1M steps with batch size $2^{11}$ sequences of length 512 (~1 trillion tokens total). Exact TPU pod configurations per experiment not detailed.</p>
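<p>The "~1 trillion tokens" figure can be checked directly from the stated steps, batch size, and sequence length. The sketch assumes every sequence is packed to the full 512 tokens, so it is an upper bound on the tokens actually seen:</p>

```python
steps = 1_000_000        # pre-training steps for T5-11B
batch_sequences = 2**11  # sequences per batch
sequence_length = 512    # tokens per sequence (assumed fully packed)

total_tokens = steps * batch_sequences * sequence_length
print(f"{total_tokens:,} tokens (about {total_tokens / 1e12:.2f} trillion)")
```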
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/google-research/text-to-text-transfer-transformer">T5 Code</a></td>
          <td>Code</td>
          <td>Apache 2.0</td>
          <td>Official TensorFlow implementation (JAX successor: T5X)</td>
      </tr>
      <tr>
          <td><a href="https://github.com/google-research/text-to-text-transfer-transformer#released-model-checkpoints">T5 Models</a></td>
          <td>Model</td>
          <td>Apache 2.0</td>
          <td>Pre-trained checkpoints (Small through 11B)</td>
      </tr>
      <tr>
          <td><a href="https://www.tensorflow.org/datasets/catalog/c4">C4 Dataset</a></td>
          <td>Dataset</td>
          <td>-</td>
          <td>~750 GB cleaned Common Crawl, via TensorFlow Datasets</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{raffel2020exploring,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Raffel, Colin and Shazeer, Noam and Roberts, Adam and Lee, Katherine and Narang, Sharan and Matena, Michael and Zhou, Yanqi and Li, Wei and Liu, Peter J.}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Journal of Machine Learning Research}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{21}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{140}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{1--67}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2020}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>SlimPajama-DC: Data Combinations for LLM Training</title><link>https://hunterheidenreich.com/notes/natural-language-processing/language-models/slimpajama-dc-data-combinations/</link><pubDate>Wed, 08 Apr 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/natural-language-processing/language-models/slimpajama-dc-data-combinations/</guid><description>Shen et al. study how global deduplication and domain combinations in SlimPajama affect LLM training, finding diversity after dedup is key.</description><content:encoded><![CDATA[<h2 id="an-empirical-study-of-data-domain-combinations">An empirical study of data domain combinations</h2>
<p>This is a <strong>discovery paper</strong> that empirically investigates how different combinations and proportions of data domains affect language model pretraining. Using the SlimPajama dataset (a globally deduplicated, 627B token refinement of RedPajama), the study trains seven 1.3B model configurations with varying domain mixtures to identify which combinations and deduplication strategies produce the best downstream performance.</p>
<h2 id="why-data-combination-strategy-matters">Why data combination strategy matters</h2>
<p>Multi-source pretraining datasets combine data from web crawls, code repositories, books, academic papers, and other sources. Two underexplored questions drive this work: (1) Does deduplication within each source (local) versus across all sources (global) meaningfully affect model quality? (2) When sources are thoroughly deduplicated, how does the combination and proportion of domains affect downstream performance? Most open-source LLM training datasets (RedPajama, The Pile) perform only local deduplication, leaving cross-source redundancy unaddressed.</p>
<h2 id="global-deduplication-and-the-slimpajama-dataset">Global deduplication and the SlimPajama dataset</h2>
<p>SlimPajama applies global MinHashLSH deduplication (Jaccard similarity threshold 0.8, 13-gram signatures) across all seven data sources simultaneously. This reduces RedPajama&rsquo;s 1.2T tokens to 627B tokens, a roughly 48% reduction. The heaviest deduplication hits CommonCrawl and GitHub, which had the most cross-source overlap.</p>
<p>The key processing steps:</p>
<ol>
<li><strong>Low-length document filtering</strong>: Remove documents below a minimum length threshold.</li>
<li><strong>Global deduplication</strong>: MinHashLSH across all sources simultaneously, requiring 64 CPU cores and 1.4TB peak memory. This removes both within-source and between-source duplicates.</li>
</ol>
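<p>The difference between local and global deduplication can be illustrated on a toy corpus. The sketch below uses exact Jaccard similarity over word n-grams as a stand-in for MinHashLSH, which approximates the same comparison at scale. The helper names and example documents are hypothetical, and the toy run uses 3-grams rather than 13-grams because the documents are short.</p>

```python
def shingles(text: str, n: int = 13) -> set[tuple[str, ...]]:
    """Set of word n-grams; short documents fall back to a single shingle."""
    words = text.lower().split()
    if n >= len(words):
        return {tuple(words)}
    return {tuple(words[i:i + n]) for i in range(len(words) - n + 1)}


def jaccard(a: set, b: set) -> float:
    """Jaccard similarity: size of intersection over size of union."""
    union = a.union(b)
    return len(a.intersection(b)) / len(union) if union else 0.0


def deduplicate(docs: list[str], threshold: float = 0.8, n: int = 13) -> list[str]:
    """Keep a document only if it is not a near-duplicate of one already kept."""
    kept, kept_shingles = [], []
    for doc in docs:
        s = shingles(doc, n)
        if all(threshold > jaccard(s, other) for other in kept_shingles):
            kept.append(doc)
            kept_shingles.append(s)
    return kept


# Toy corpus: the same page appears in two sources (hypothetical data).
sources = {
    "commoncrawl": ["the quick brown fox jumps over the lazy dog", "a page about gardening tips"],
    "c4": ["the quick brown fox jumps over the lazy dog", "notes on sourdough baking"],
}

# Local: deduplicate each source on its own, then concatenate.
local_kept = [doc for docs in sources.values() for doc in deduplicate(docs, n=3)]
# Global: concatenate all sources first, then deduplicate once.
global_kept = deduplicate([doc for docs in sources.values() for doc in docs], n=3)
print(len(local_kept), len(global_kept))  # 4 3
```

<p>Local deduplication keeps both copies of the shared page because it never compares across sources. This pairwise version is quadratic in the number of documents, which is why a hashing scheme is needed at the 1.2T-token scale.</p>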
<p>The resulting dataset composition:</p>
<table>
  <thead>
      <tr>
          <th>Source</th>
          <th>SlimPajama</th>
          <th>RedPajama</th>
          <th>LLaMA 1</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>CommonCrawl</td>
          <td>52.2% (333B)</td>
          <td>72.6% (878B)</td>
          <td>67.0%</td>
      </tr>
      <tr>
          <td>C4</td>
          <td>26.7% (170B)</td>
          <td>14.4% (175B)</td>
          <td>15.0%</td>
      </tr>
      <tr>
          <td>GitHub</td>
          <td>5.2% (33B)</td>
          <td>4.9% (59B)</td>
          <td>4.5%</td>
      </tr>
      <tr>
          <td>Books</td>
          <td>4.2% (27B)</td>
          <td>2.1% (26B)</td>
          <td>4.5%</td>
      </tr>
      <tr>
          <td>ArXiv</td>
          <td>4.6% (29B)</td>
          <td>2.3% (28B)</td>
          <td>2.5%</td>
      </tr>
      <tr>
          <td>Wikipedia</td>
          <td>3.8% (24B)</td>
          <td>2.0% (24B)</td>
          <td>4.5%</td>
      </tr>
      <tr>
          <td>StackExchange</td>
          <td>3.3% (21B)</td>
          <td>1.7% (20B)</td>
          <td>2.0%</td>
      </tr>
  </tbody>
</table>
<h2 id="seven-domain-combination-configurations">Seven domain combination configurations</h2>
<p>All configurations train 1.3B parameter models on 330B tokens with identical architecture and hyperparameters. The configurations systematically vary domain diversity:</p>
<ul>
<li><strong>DC-1</strong>: CommonCrawl only (single source)</li>
<li><strong>DC-2</strong>: CommonCrawl + C4 (two web sources)</li>
<li><strong>DC-3</strong>: CommonCrawl + C4 with adjusted proportions</li>
<li><strong>DC-4</strong>: Wikipedia + Books + GitHub + ArXiv + StackExchange (no web crawl)</li>
<li><strong>DC-5</strong>: CommonCrawl + C4 + Wikipedia + Books (four sources, no code/academic)</li>
<li><strong>DC-6</strong>: All seven SlimPajama sources (maximum diversity)</li>
<li><strong>DC-7</strong>: RefinedWeb CommonCrawl (external single-source baseline)</li>
</ul>
<p>The experimental design probes: incremental diversity (DC-1 to DC-2 to DC-5 to DC-6), proportion sensitivity (DC-2 vs DC-3), source importance (DC-3 vs DC-4), and specialization vs generalization (individual vs combined).</p>
<h2 id="diversity-after-global-deduplication-drives-performance">Diversity after global deduplication drives performance</h2>
<h3 id="hugging-face-leaderboard-results">Hugging Face leaderboard results</h3>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Average</th>
          <th>ARC</th>
          <th>HellaSwag</th>
          <th>MMLU</th>
          <th>TruthfulQA</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>RedPajama-1.3B</td>
          <td>38.0</td>
          <td>37.2</td>
          <td>55.8</td>
          <td>24.9</td>
          <td>34.3</td>
      </tr>
      <tr>
          <td>DC-1 (CC only)</td>
          <td>38.5</td>
          <td>36.3</td>
          <td>56.0</td>
          <td>27.0</td>
          <td>34.8</td>
      </tr>
      <tr>
          <td>DC-4 (no web)</td>
          <td>37.6</td>
          <td>33.4</td>
          <td>53.3</td>
          <td>26.0</td>
          <td>37.6</td>
      </tr>
      <tr>
          <td>DC-6 (all sources)</td>
          <td>40.0</td>
          <td>33.7</td>
          <td>61.0</td>
          <td>26.9</td>
          <td>38.4</td>
      </tr>
      <tr>
          <td>DC-7 (RefinedWeb)</td>
          <td>41.0</td>
          <td>35.1</td>
          <td>64.7</td>
          <td>26.2</td>
          <td>37.9</td>
      </tr>
  </tbody>
</table>
<p><strong>Key patterns:</strong></p>
<ol>
<li>
<p><strong>More domain diversity improves average performance.</strong> The progression DC-1 (38.5) to DC-2 (38.4) to DC-5 (38.6) to DC-6 (40.0) is essentially flat until all seven sources are combined: adding C4, Wikipedia, and Books moves the average by at most 0.1 points relative to DC-1, while the full mixture gains 1.5 points once global deduplication has removed cross-source redundancy.</p>
</li>
<li>
<p><strong>Global deduplication enables clean combination.</strong> All SlimPajama configurations except DC-4 outperform RedPajama-1.3B (38.0), which uses local deduplication only. The elimination of cross-source overlap means adding sources contributes genuinely new information.</p>
</li>
<li>
<p><strong>Removing web crawl data hurts.</strong> DC-4 (no CommonCrawl/C4) scores lowest (37.6), demonstrating that web text provides essential breadth even when specialized sources are included.</p>
</li>
<li>
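<p><strong>Individual domains excel at specific tasks.</strong> DC-1 (CC only) achieves the highest MMLU score (27.0) and the highest ARC score among the SlimPajama configurations shown (36.3), although RedPajama-1.3B is higher on ARC (37.2). On the extended benchmark suite, DC-4 leads on Winogrande and DC-5 leads on WSC273. No single combination dominates all tasks, reinforcing that diversity trades specialization for generalization.</p>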
</li>
<li>
<p><strong>Findings transfer to 7B scale.</strong> Insights from the best 1.3B configurations were applied to a 7B model trained with large batch sizes, which reaches 63.4 average accuracy across the extended benchmark suite.</p>
</li>
</ol>
<h3 id="training-loss-patterns">Training loss patterns</h3>
<p>DC-6 (all sources) achieves the lowest training loss among SlimPajama configurations, consistent with the downstream results. DC-4 (no web crawl) shows the highest training loss, confirming that the large, diverse web crawl data is the most important single component.</p>
<h2 id="implications-and-limitations">Implications and limitations</h2>
<p>The central finding is that <strong>diversity matters most after deduplication</strong>. When cross-source redundancy is removed, each additional source contributes genuinely new signal. Without global deduplication, adding sources may just increase redundancy without proportional benefit.</p>
<p><strong>Limitations:</strong></p>
<ul>
<li>Only seven fixed configurations are tested. No systematic search over continuous mixture proportions (contrast with <a href="/notes/natural-language-processing/language-models/doremi-data-mixture-optimization/">DoReMi</a> or <a href="/notes/natural-language-processing/language-models/data-mixing-laws-pretraining/">Data Mixing Laws</a>).</li>
<li>The configurations are not independent: DC-6 includes all sources from DC-1 through DC-5, making it difficult to isolate the contribution of any single addition.</li>
<li>Only 1.3B and 7B scales tested. Whether the diversity benefit continues scaling is unverified.</li>
<li>English-only. Cross-lingual diversity effects are not studied.</li>
<li>The paper is a technical report without formal peer review.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<p><strong>Status: Highly Reproducible.</strong> All 1.3B models and datasets are publicly released under MIT license on HuggingFace.</p>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training</td>
          <td>SlimPajama</td>
          <td>627B tokens</td>
          <td>Globally deduplicated from 1.2T RedPajama</td>
      </tr>
      <tr>
          <td>Training</td>
          <td>RefinedWeb</td>
          <td>600B tokens</td>
          <td>External CC-only baseline</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>HF Leaderboard (ARC, HellaSwag, MMLU, TruthfulQA)</td>
          <td>Standard</td>
          <td>4 benchmarks</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>Extended suite</td>
          <td>12 additional benchmarks</td>
          <td>Zero and few-shot</td>
      </tr>
  </tbody>
</table>
<h3 id="models">Models</h3>
<p>1.3B parameter Cerebras-GPT architecture with ALiBi positional encoding and SwiGLU activation. All configurations trained on 330B tokens. 7B model trained with large batch-size (LBS) strategy on Cerebras 16x CS-2 cluster (80 PFLOP/s in bf16).</p>
<h3 id="hardware">Hardware</h3>
<p>Cerebras 16x CS-2 cluster, 80 PFLOP/s in bf16 mixed precision.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://huggingface.co/MBZUAI-LLM/SlimPajama-DC">SlimPajama-DC Models</a></td>
          <td>Model</td>
          <td>MIT</td>
          <td>All 1.3B DC configurations (select via revision)</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/datasets/MBZUAI-LLM/SlimPajama-627B-DC">SlimPajama-627B-DC Dataset</a></td>
          <td>Dataset</td>
          <td>-</td>
          <td>Source-split version of SlimPajama-627B</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{shen2023slimpajamadc,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{SlimPajama-DC: Understanding Data Combinations for LLM Training}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Shen, Zhiqiang and Tao, Tianhua and Ma, Liqun and Neiswanger, Willie and Liu, Zhengzhong and Wang, Hongyi and Tan, Bowen and Hestness, Joel and Vassilieva, Natalia and Soboleva, Daria and Xing, Eric}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{arXiv preprint arXiv:2309.10818}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Scaling Data-Constrained Language Models</title><link>https://hunterheidenreich.com/notes/natural-language-processing/language-models/scaling-data-constrained-language-models/</link><pubDate>Wed, 08 Apr 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/natural-language-processing/language-models/scaling-data-constrained-language-models/</guid><description>Muennighoff et al. extend Chinchilla scaling laws to repeated data, finding up to 4 epochs cause negligible loss and 16 epochs mark diminishing returns.</description><content:encoded><![CDATA[<h2 id="an-empirical-study-of-scaling-under-data-constraints">An empirical study of scaling under data constraints</h2>
<p>This is a <strong>discovery paper</strong> that systematically investigates what happens when language models are trained for multiple epochs on repeated data. It extends the Chinchilla scaling laws to the data-constrained regime by proposing a new scaling formula that accounts for the diminishing value of repeated tokens, validated across 400+ training runs ranging from 10M to 9B parameters and up to 1500 epochs.</p>
<h2 id="running-out-of-unique-training-data">Running out of unique training data</h2>
<p>The Chinchilla scaling laws assume unlimited unique data: for a given compute budget, there exists an optimal balance of model parameters and training tokens. But extrapolating these laws to larger models implies data requirements that exceed what is available. Villalobos et al. estimated that high-quality English text would be exhausted by 2024 under Chinchilla-optimal scaling. Most prior large language models trained for a single epoch, and some work explicitly warned against data reuse. The Galactica models (trained for 4.25 epochs) showed that multi-epoch training could work, but no systematic study had quantified the tradeoff between repeated data and fresh data, or how to allocate compute optimally when data is finite.</p>
<h2 id="effective-data-with-exponential-decay-for-repetition">Effective data with exponential decay for repetition</h2>
<p>The paper generalizes the Chinchilla scaling law by replacing raw token count $D$ with an effective data term $D'$ that accounts for the diminishing value of repeated tokens:</p>
<p>$$
L(N, D) = \frac{A}{(N')^{\alpha}} + \frac{B}{(D')^{\beta}} + E
$$</p>
<p>where the effective data is:</p>
<p>$$
D' = U_{D} + U_{D} R_{D}^{*} \left(1 - e^{-R_{D}/R_{D}^{*}}\right)
$$</p>
<p>Here $U_{D}$ is the number of unique tokens, $R_{D}$ is the number of repetitions (epochs minus 1), and $R_{D}^{*}$ is a learned constant representing the &ldquo;half-life&rdquo; of data repetition. When $R_{D} = 0$ (single epoch), $D' = U_{D} = D$ and the formula reduces to standard Chinchilla. When $R_{D} \ll R_{D}^{*}$, repeated data is worth almost the same as fresh data. As $R_{D}$ grows large, the value of repeated tokens decays to zero, and $D'$ saturates at $U_{D}(1 + R_{D}^{*})$, meaning no amount of repetition can substitute for more than $R_{D}^{*}$ epochs&rsquo; worth of fresh data.</p>
<p>A symmetric formula handles excess parameters:</p>
<p>$$
N' = U_{N} + U_{N} R_{N}^{*} \left(1 - e^{-R_{N}/R_{N}^{*}}\right)
$$</p>
<p>where $U_{N}$ is the compute-optimal parameter count for $U_{D}$ unique tokens and $R_{N}$ measures how much the model exceeds that count. The fitted values are $R_{D}^{*} \approx 15.0$ (data repetition half-life at ~16 epochs) and $R_{N}^{*} \approx 5.3$ (excess parameters decay faster than repeated data).</p>
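<p>The effective-data term is simple to evaluate numerically. The sketch below implements only the saturation formula shared by the data and parameter terms, using the fitted data constant quoted above. The function name is illustrative, and the 44B unique-token example is one of the paper's reported settings, used here only as an input.</p>

```python
import math


def effective(unique: float, repetitions: float, half_life: float) -> float:
    """U + U * R_star * (1 - exp(-R / R_star)); same form for data and parameters."""
    return unique + unique * half_life * (1.0 - math.exp(-repetitions / half_life))


R_D_STAR = 15.0  # fitted data-repetition constant quoted in the text

unique_tokens = 44e9
for epochs in (1, 4, 16, 64):
    d_eff = effective(unique_tokens, epochs - 1, R_D_STAR)
    raw = unique_tokens * epochs
    print(f"{epochs:>3} epochs: raw {raw / 1e9:7.0f}B, effective {d_eff / 1e9:7.1f}B")
```

<p>At 4 epochs the effective count is about 3.72 times the unique tokens, close to the raw 4 times, while at very high epoch counts it approaches the ceiling of $1 + 15 = 16$ times the unique tokens.</p>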
<h2 id="experiments-across-400-models">Experiments across 400+ models</h2>
<p><strong>Scale.</strong> Models from 10M to 9B parameters, trained for up to 1500 epochs. Three experimental protocols: fixed unique data (100M, 400M, 1.5B tokens), fixed FLOPs, and parametric fitting across all runs. Training on C4 (English web text) with GPT-2 architecture decoder-only transformers.</p>
<h3 id="resource-allocation-epochs-scale-faster-than-parameters">Resource allocation: epochs scale faster than parameters</h3>
<p>With fixed unique data, results show that more than 50% loss reduction is possible by training beyond one epoch and increasing model size beyond the single-epoch optimum. The data-constrained efficient frontier recommends allocating most additional compute to more epochs rather than more parameters, because excess parameters decay faster ($R_{N}^{*} &lt; R_{D}^{*}$). This contrasts with Chinchilla, which recommends scaling both equally.</p>
<p>A concrete validation: training the data-constrained compute-optimal model for $9.3 \times 10^{21}$ FLOPs with 25B unique tokens, the recommended allocation (27% fewer parameters, more epochs) achieves better loss and downstream performance than the Chinchilla-optimal allocation.</p>
<h3 id="resource-return-the-4-epoch-safe-zone-and-16-epoch-half-life">Resource return: the 4-epoch safe zone and 16-epoch half-life</h3>
<table>
  <thead>
      <tr>
          <th>Epochs</th>
          <th>Loss impact</th>
          <th>Downstream impact</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>1 (baseline)</td>
          <td>Optimal</td>
          <td>Optimal</td>
      </tr>
      <tr>
          <td>Up to 4</td>
          <td>Negligible (+0.5% loss)</td>
          <td>No significant difference</td>
      </tr>
      <tr>
          <td>~16 ($R_{D}^{*}$)</td>
          <td>Diminishing returns begin sharply</td>
          <td>Measurable degradation</td>
      </tr>
      <tr>
          <td>Beyond 16</td>
          <td>Returns decay to near zero</td>
          <td>Significant degradation</td>
      </tr>
      <tr>
          <td>Extreme (44+)</td>
          <td>Training can diverge</td>
          <td>Failure</td>
      </tr>
  </tbody>
</table>
<p>The 8.7B parameter model trained for 4 epochs ($D_{C} = 44$B unique tokens) finishes with only 0.5% higher validation loss than the single-epoch model ($D_{C} = 178$B unique tokens). At the half-life point ($R_{D} = R_{D}^{*}$, about 16 epochs), repeated tokens are worth on average $1 - 1/e \approx 63\%$ of fresh tokens, and the marginal value of one further repetition has fallen to $1/e \approx 37\%$. Beyond this point additional epochs add little.</p>
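<p>Both percentages follow from the effective-data formula evaluated at $R_{D} = R_{D}^{*}$. The first expression is the average value of the repeated tokens relative to the same number of fresh tokens; the second is the marginal value of one more repetition:</p>
<p>$$
\left.\frac{D' - U_{D}}{U_{D} R_{D}}\right|_{R_{D} = R_{D}^{*}} = 1 - e^{-1} \approx 0.63,
\qquad
\left.\frac{1}{U_{D}} \frac{\partial D'}{\partial R_{D}}\right|_{R_{D} = R_{D}^{*}} = e^{-1} \approx 0.37
$$</p>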
<h3 id="complementary-strategies-code-augmentation-and-filtering">Complementary strategies: code augmentation and filtering</h3>
<p>When data is limited, two strategies can extend the effective dataset:</p>
<p><strong>Code augmentation.</strong> Mixing Python code from The Stack with natural language data. Up to 50% code (42B tokens) shows no degradation on natural language benchmarks, effectively providing a 2x increase in useful training data. Some tasks (WebNLG generation, bAbI reasoning) actually improve with code, possibly because code trains long-range state-tracking capabilities.</p>
<p><strong>Filtering relaxation.</strong> Perplexity filtering (keeping the 25% lowest-perplexity samples) is effective on noisy datasets, but deduplication filtering does not improve downstream performance (though it may reduce memorization). The recommendation: reserve aggressive filtering for noisy data sources; for clean datasets, more data through reduced filtering is better than less data through strict filtering.</p>
<p><strong>Combined strategy</strong>: doubling available data with code and then repeating for 4 epochs yields 8x more training tokens with performance expected to match 8x more unique data.</p>
<h2 id="key-findings-and-limitations">Key findings and limitations</h2>
<p><strong>Key findings:</strong></p>
<ul>
<li>Multi-epoch training is beneficial, not harmful, up to moderate repetition counts.</li>
<li>The data-constrained scaling law accurately predicts loss under repetition using an exponential decay formulation.</li>
<li>Compute should be allocated to epochs faster than parameters when data is constrained.</li>
<li>Code augmentation and selective filtering extend effective data without quality degradation.</li>
</ul>
<p><strong>Limitations:</strong></p>
<ul>
<li>All experiments use the GPT-2 transformer architecture; applicability to other architectures or modalities is untested.</li>
<li>Only the entire dataset is repeated uniformly. Selectively repeating subsets (e.g., high-value data for more epochs) is not modeled.</li>
<li>Hyperparameter sensitivity (learning rate, dropout) to epoch count is unexplored. Higher learning rates may cause earlier onset of diminishing returns.</li>
<li>Focused on English text. Cross-lingual augmentation effects are not studied.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<p><strong>Status: Highly Reproducible.</strong> Code, models, datasets, and hyperparameters are all publicly released under Apache 2.0.</p>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training</td>
          <td>C4 (English)</td>
          <td>Varies by experiment</td>
          <td>Fixed unique data: 100M, 400M, 1.5B tokens</td>
      </tr>
      <tr>
          <td>Code augmentation</td>
          <td>The Stack (Python)</td>
          <td>Up to 42B tokens</td>
          <td>Mixed with natural language</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>19 NL tasks</td>
          <td>Standard splits</td>
          <td>Zero to five-shot, 114 scores per model</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<p>Data-constrained scaling law: $D' = U_{D} + U_{D} R_{D}^{*}(1 - e^{-R_{D}/R_{D}^{*}})$ with $R_{D}^{*} \approx 15.0$, $R_{N}^{*} \approx 5.3$. Fitted using the methodology of Hoffmann et al. (2022) adapted for the repetition terms. 400+ training runs used for fitting.</p>
<h3 id="models">Models</h3>
<p>GPT-2 architecture decoder-only transformers with GPT-2 tokenizer. Sizes: 10M to 8.7B parameters. Cosine learning rate schedule (max 2e-4, decay to 2e-5), Adam optimizer ($\beta_2 = 0.999$), dropout 0.1, weight decay 0.1, gradient clipping at 1.0. bfloat16 precision. Trained using Megatron-DeepSpeed.</p>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Data-Constrained Optimal</th>
          <th>Chinchilla Optimal</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Validation loss (9.3e21 FLOPs, 25B unique)</td>
          <td>Lower</td>
          <td>Higher</td>
          <td>27% fewer parameters</td>
      </tr>
      <tr>
          <td>Downstream (4 epochs vs 1)</td>
          <td>No significant difference</td>
          <td>Baseline</td>
          <td>8.7B params, 44B unique tokens</td>
      </tr>
      <tr>
          <td>Code augmentation (50% code)</td>
          <td>No NL degradation</td>
          <td>Baseline</td>
          <td>Some tasks improve</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Trained on the LUMI supercomputer (Finland) using AMD Instinct MI250X GPUs with data, tensor, and pipeline parallelism. Up to 256 GPUs (64 nodes) per run, with up to 2,200 nodes (~8,800 GPUs) used in parallel across all concurrent runs. Total compute: approximately 3 million GPU hours. The cluster runs on 100% renewable hydroelectric energy.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/huggingface/datablations">datablations</a></td>
          <td>Code + Models + Data</td>
          <td>Apache 2.0</td>
          <td>All 400+ models, datasets, and training code</td>
      </tr>
      <tr>
          <td><a href="https://github.com/TurkuNLP/Megatron-DeepSpeed">Megatron-DeepSpeed fork</a></td>
          <td>Code</td>
          <td>-</td>
          <td>Training framework adapted for AMD ROCm</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{muennighoff2023scaling,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Scaling Data-Constrained Language Models}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Muennighoff, Niklas and Rush, Alexander M. and Barak, Boaz and Le Scao, Teven and Piktus, Aleksandra and Tazi, Nouamane and Pyysalo, Sampo and Wolf, Thomas and Raffel, Colin}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Advances in Neural Information Processing Systems}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{36}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>DoReMi: Optimizing Data Mixtures for LM Pretraining</title><link>https://hunterheidenreich.com/notes/natural-language-processing/language-models/doremi-data-mixture-optimization/</link><pubDate>Wed, 08 Apr 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/natural-language-processing/language-models/doremi-data-mixture-optimization/</guid><description>DoReMi uses a small proxy model with distributionally robust optimization to learn domain weights that speed up large-scale language model pretraining by 2.6x.</description><content:encoded><![CDATA[<h2 id="a-method-for-automatic-domain-reweighting">A method for automatic domain reweighting</h2>
<p>This is a <strong>method paper</strong> that introduces Domain Reweighting with Minimax Optimization (DoReMi), an algorithm for automatically tuning the mixture proportions of pretraining data domains. Rather than relying on heuristics or expensive downstream-task-based tuning, DoReMi uses a small proxy model trained with <a href="https://en.wikipedia.org/wiki/Robust_optimization">group distributionally robust optimization (Group DRO)</a> to produce domain weights that transfer to much larger models.</p>
<h2 id="why-data-mixture-proportions-matter">Why data mixture proportions matter</h2>
<p>Language model pretraining datasets combine text from many domains: web crawls, Wikipedia, books, code, academic papers, and others. The mixture proportions (how much of each domain to include) significantly affect downstream performance, but existing approaches either set them by hand (<a href="https://en.wikipedia.org/wiki/The_Pile_(dataset)">The Pile</a> uses heuristic weights) or tune them against downstream tasks (GLaM/PaLM), which is expensive and risks overfitting to a specific evaluation set. No principled, task-agnostic method existed for determining mixture proportions.</p>
<h2 id="minimax-optimization-over-domain-excess-loss">Minimax optimization over domain excess loss</h2>
<p>DoReMi&rsquo;s core insight is to frame data mixture optimization as a minimax problem: find domain weights that minimize the worst-case excess loss across all domains. The algorithm has three steps.</p>
<p><strong>Step 1</strong>: Train a small reference model (280M parameters) on some default domain weights $\alpha_{\text{ref}}$ (e.g., proportional to raw token count).</p>
<p><strong>Step 2</strong>: Train a small proxy model $p_{\theta}$ using Group DRO, which solves the minimax objective:</p>
<p>$$
\min_{\theta} \max_{\alpha \in \Delta^{k}} \sum_{i=1}^{k} \alpha_{i} \cdot \left[ \frac{1}{\sum_{x \in D_{i}} |x|} \sum_{x \in D_{i}} \left( \ell_{\theta}(x) - \ell_{\text{ref}}(x) \right) \right]
$$</p>
<p>where $\ell_{\theta}(x) = -\log p_{\theta}(x)$ and $\ell_{\text{ref}}(x) = -\log p_{\text{ref}}(x)$. The excess loss $\ell_{\theta}(x) - \ell_{\text{ref}}(x)$ measures how much headroom the proxy has to improve on each example relative to the reference. The inner maximization upweights domains with high excess loss via exponentiated gradient ascent, while the outer minimization trains the proxy on those upweighted domains.</p>
<p>At each training step, the domain weights update as:</p>
<p>$$
\alpha_{t}' \leftarrow \alpha_{t-1} \exp(\eta \lambda_{t})
$$</p>
<p>where $\lambda_{t}[i]$ is the excess loss of domain $i$, averaged over its tokens with each per-token excess loss clipped at zero. The update is followed by renormalization and smoothing with a uniform component: $\alpha_{t} \leftarrow (1-c)\frac{\alpha_{t}'}{\sum_{i} \alpha_{t}'[i]} + cu$, where $u$ is the uniform distribution over the $k$ domains and $c = 10^{-3}$.</p>
<p>The final domain weights are the average over all training steps: $\bar{\alpha} = \frac{1}{T}\sum_{t=1}^{T} \alpha_{t}$.</p>
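<p>The weight update and the final averaging can be sketched in a few lines of plain Python. This is a minimal illustration, not the authors&rsquo; implementation: the function names <code>update_domain_weights</code> and <code>average_weights</code> and the toy excess-loss values are made up for the example, and the sketch assumes the excess losses have already been aggregated per domain, so it clips at the domain level for brevity.</p>

```python
import math

def update_domain_weights(alpha_prev, excess_loss, eta=1.0, c=1e-3):
    """One exponentiated-gradient step on the domain weights.

    alpha_prev and excess_loss are per-domain lists (hypothetical inputs).
    """
    k = len(alpha_prev)
    # Clip negative excess losses at zero (domain-level simplification).
    lam = [max(x, 0.0) for x in excess_loss]
    # Multiplicative update: alpha' = alpha_prev * exp(eta * lambda).
    unnorm = [a * math.exp(eta * l) for a, l in zip(alpha_prev, lam)]
    total = sum(unnorm)
    # Renormalize, then mix with the uniform distribution u = 1/k.
    return [(1.0 - c) * w / total + c / k for w in unnorm]

def average_weights(history):
    """Average the per-step weights to obtain the final mixture."""
    steps = len(history)
    return [sum(col) / steps for col in zip(*history)]

# Toy run: three domains, two training steps with made-up excess losses.
alpha = [1 / 3, 1 / 3, 1 / 3]
history = []
for excess in ([0.5, 0.1, -0.2], [0.4, 0.2, -0.1]):
    alpha = update_domain_weights(alpha, excess)
    history.append(alpha)
alpha_bar = average_weights(history)
print(alpha_bar)
```

<p>Because each step renormalizes before mixing in the uniform component, every $\alpha_{t}$ (and therefore $\bar{\alpha}$) stays on the simplex, and the smoothing term keeps every domain&rsquo;s weight strictly positive.</p>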
<p><strong>Step 3</strong>: Resample data according to $\bar{\alpha}$ and train the full-scale model using standard procedures.</p>
<p><strong>Iterated DoReMi</strong> extends this by running multiple rounds, using the previous round&rsquo;s optimized weights as the next round&rsquo;s reference weights. This converges within 3 rounds on the GLaM dataset.</p>
<h2 id="experiments-across-the-pile-and-glam-datasets">Experiments across The Pile and GLaM datasets</h2>
<p><strong>Datasets.</strong> The Pile (22 domains, 800GB) and the GLaM dataset (8 domains, also used for PaLM). On The Pile, baseline weights come from the dataset defaults. On GLaM, baseline weights are uniform, with downstream-tuned oracle weights available for comparison.</p>
<p><strong>Setup.</strong> Transformer decoder-only LMs trained with next-token prediction. All models use batch size 512 and sequence length 1024. Proxy and reference models are 280M parameters. Main models are 8B parameters (roughly 30x larger). Training runs: 200K steps (Pile) or 300K steps (GLaM). The domain weight optimization cost (training two 280M models) is 8% of the compute for the 8B main model.</p>
<p><strong>Evaluation.</strong> Per-domain held-out perplexity and one-shot generative accuracy on five tasks: TriviaQA, NaturalQuestions, WebQuestions, SQuADv2, and LAMBADA.</p>
<h3 id="key-domain-weight-shifts">Key domain weight shifts</h3>
<p>On The Pile, DoReMi (280M) dramatically upweights diverse web text (Pile-CC: 0.112 to 0.606) while downweighting specialized domains like ArXiv (0.105 to 0.004), PubMed Central (0.107 to 0.005), and StackExchange (0.093 to 0.015). Smaller, underrepresented domains like YouTubeSubtitles and PhilPapers receive proportionally large increases.</p>
<h3 id="scaling-behavior">Scaling behavior</h3>
<p>DoReMi was tested with matched proxy/main model sizes (280M through 1B) and with varying proxy sizes (70M through 1B) feeding into an 8B main model.</p>
<table>
  <thead>
      <tr>
          <th>Configuration</th>
          <th>Speedup to baseline accuracy</th>
          <th>Downstream improvement</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>DoReMi (280M to 280M)</td>
          <td>4x</td>
          <td>+2% avg accuracy</td>
      </tr>
      <tr>
          <td>DoReMi (280M to 8B)</td>
          <td>2.6x</td>
          <td>+6.5% avg accuracy</td>
      </tr>
      <tr>
          <td>DoReMi (150M to 8B)</td>
          <td>~2x</td>
          <td>Significant</td>
      </tr>
      <tr>
          <td>DoReMi (1B to 8B)</td>
          <td>~2x</td>
          <td>Significant</td>
      </tr>
  </tbody>
</table>
<p>Improvements are consistent across all tested model scales (280M to 1B matched), with no sign of diminishing returns at larger sizes.</p>
<h2 id="perplexity-improves-everywhere-even-on-downweighted-domains">Perplexity improves everywhere, even on downweighted domains</h2>
<p>The most striking finding is that DoReMi improves perplexity on all 22 domains in The Pile, including domains it downweights. The proposed explanation: the lowest-entropy domains need few samples to learn (they&rsquo;re statistically simple), while the highest-entropy domains have token distributions close to the uniform initialization and so also need few samples. Reallocating weight to medium-entropy domains generates positive transfer that lifts all domains.</p>
<p>On The Pile, DoReMi reaches the baseline&rsquo;s downstream accuracy in 75K steps versus 200K for the baseline (2.6x speedup) and achieves a 6.5% absolute improvement in average one-shot accuracy at 200K steps.</p>
<p>On the GLaM dataset, iterated DoReMi (round 2) matches the performance of domain weights that were tuned directly on downstream task performance, despite having no knowledge of downstream tasks. Domain weights converge within 3 iterations.</p>
<h3 id="ablations">Ablations</h3>
<p>Using only the proxy model&rsquo;s loss (prefer hardest domains) or only the negative reference loss (prefer easiest domains) both underperform the full excess loss formulation. Both components are necessary: the excess loss identifies domains where the proxy has room to improve relative to what is learnable.</p>
<p>The proxy model itself typically underperforms the main model trained on its weights, and this gap grows at larger proxy scales. A 1B proxy model underperforms the 1B baseline, yet its domain weights still improve 1B main model training by over 2x. This suggests the domain weight signal is robust even when the proxy model itself is not well-trained.</p>
<h3 id="limitations">Limitations</h3>
<p>The domain weight landscape may have multiple local optima: a 280M proxy puts most weight on Pile-CC, while a 1B proxy favors OpenWebText2 instead. Both configurations improve over baseline, but the optimal weights are not unique.</p>
<p>The granularity of &ldquo;domains&rdquo; matters. DoReMi works better with more domains (22 on The Pile versus 8 on GLaM). Domains are defined by data provenance, which is coarse-grained. Fine-grained domain definitions (e.g., via clustering) could improve results but also risk DRO putting all weight on a small set of worst-case examples.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pretraining</td>
          <td>The Pile</td>
          <td>800 GB, 22 domains</td>
          <td>Default heuristic weights as baseline</td>
      </tr>
      <tr>
          <td>Pretraining</td>
          <td>GLaM dataset</td>
          <td>8 domains</td>
          <td>Uniform weights as baseline; downstream-tuned oracle available</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>TriviaQA, NaturalQuestions, WebQuestions, SQuADv2, LAMBADA</td>
          <td>Standard splits</td>
          <td>One-shot generative evaluation</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<p>Group DRO with exponentiated gradient ascent for domain weight updates. Step size $\eta = 1$, smoothing $c = 10^{-3}$. Per-token excess loss clipped at zero. Domain weights averaged over all training steps. Iterated DoReMi converges when $\lVert \bar{\alpha} - \alpha_{\text{ref}} \rVert_{\infty} &lt; 10^{-3}$.</p>
<h3 id="models">Models</h3>
<p>Vanilla Transformer decoder-only models with 256K vocabulary. Sizes: 70M (3 layers), 150M (6 layers), 280M (12 layers), 510M (12 layers), 760M (12 layers), 1B (16 layers), 8B (32 layers). All use 64-dim attention heads except 8B (128-dim).</p>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>DoReMi (280M to 8B)</th>
          <th>Baseline (8B)</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Avg one-shot accuracy</td>
          <td>+6.5% over baseline</td>
          <td>Reference</td>
          <td>5 generative tasks</td>
      </tr>
      <tr>
          <td>Worst-case log-perplexity</td>
          <td>1.46</td>
          <td>1.71</td>
          <td>Across 22 Pile domains</td>
      </tr>
      <tr>
          <td>Avg log-perplexity</td>
          <td>1.40</td>
          <td>1.64</td>
          <td>Across 22 Pile domains</td>
      </tr>
      <tr>
          <td>Domains beating baseline</td>
          <td>22/22</td>
          <td>0/22</td>
          <td>Per-domain perplexity</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Proxy and reference models (under 1B) trained on TPUv3. Models at 1B and 8B trained on TPUv4. Domain weight optimization (two 280M runs) costs 8% of 8B training FLOPs.</p>
<hr>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{xie2023doremi,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{DoReMi: Optimizing Data Mixtures Speeds Up Language Model Pretraining}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Xie, Sang Michael and Pham, Hieu and Dong, Xuanyi and Du, Nan and Liu, Hanxiao and Lu, Yifeng and Liang, Percy and Le, Quoc V. and Ma, Tengyu and Yu, Adams Wei}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Advances in Neural Information Processing Systems}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{36}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Data Mixing Laws for LM Pretraining Optimization</title><link>https://hunterheidenreich.com/notes/natural-language-processing/language-models/data-mixing-laws-pretraining/</link><pubDate>Wed, 08 Apr 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/natural-language-processing/language-models/data-mixing-laws-pretraining/</guid><description>Ye et al. discover that LM loss follows an exponential law over domain mixture proportions, enabling cheap prediction and optimization of data mixtures.</description><content:encoded><![CDATA[<h2 id="an-empirical-discovery-of-predictable-mixture-loss-relationships">An empirical discovery of predictable mixture-loss relationships</h2>
<p>This is a <strong>discovery paper</strong> that identifies a quantitative, functional relationship between pretraining data mixture proportions and language model loss. The key finding is that domain-specific validation loss follows an exponential law over the linear combination of training domain proportions, and this law composes with standard scaling laws to enable cheap prediction of large-model performance under arbitrary mixtures.</p>
<h2 id="the-missing-quantitative-link-between-data-mixtures-and-performance">The missing quantitative link between data mixtures and performance</h2>
<p>Pretraining data for large language models combines text from many domains (web, code, academic, books, etc.), and mixture proportions significantly affect model quality. Existing approaches either set proportions by hand without disclosed criteria (LLaMA, Baichuan) or use algorithmic methods like <a href="/notes/natural-language-processing/language-models/doremi-data-mixture-optimization/">DoReMi</a> that optimize qualitatively but cannot predict the quantitative effect of a specific mixture before training. Scaling laws exist for model size and data quantity, but no equivalent existed for mixture proportions. This paper fills that gap.</p>
<h2 id="the-exponential-data-mixing-law">The exponential data mixing law</h2>
<p>The core finding: for a model of fixed size trained for a fixed number of steps, the validation loss on domain $i$ as a function of the training mixture proportions $r_{1 \dots M}$ follows:</p>
<p>$$
L_{i}(r_{1 \dots M}) = c_{i} + k_{i} \exp\left(\sum_{j=1}^{M} t_{ij} r_{j}\right)
$$</p>
<p>where $c_{i}$, $k_{i}$, and $t_{ij}$ are fitted parameters. The constant $c_{i}$ represents the irreducible loss (not affected by mixture changes). The interaction coefficients $t_{ij}$ capture how training domain $j$ affects validation loss on domain $i$: for positive $k_{i}$, a negative $t_{ij}$ means that adding more of domain $j$ lowers the loss on domain $i$, while a positive $t_{ij}$ raises it.</p>
<p>This was discovered progressively:</p>
<ol>
<li><strong>Two domains</strong>: Log-reducible-loss is linear in domain proportion (univariate exponential).</li>
<li><strong>Three domains</strong>: The exponential generalizes to a linear combination over all domain proportions (Eq. above), outperforming alternatives with comparable parameter count.</li>
<li><strong>General validation</strong>: For a validation set composed of $K$ domains with proportions $s_{1 \dots K}$, the overall loss is:</li>
</ol>
<p>$$
L(r_{1 \dots M}) = \sum_{i=1}^{K} s_{i} \left[ c_{i} + k_{i} \exp\left(\sum_{j=1}^{M} t_{ij} r_{j}\right) \right]
$$</p>
<p>When the validation set composition is unknown, implicit domain aggregation treats $s_{i}$ as learnable parameters. Setting the number of implicit domains larger than the true number works well and is robust to overestimation.</p>
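<p>To make the two formulas above concrete, the sketch below evaluates the per-domain law and the weighted overall loss, then scans a coarse grid of mixtures for the minimizer. All parameter values ($c_{i}$, $k_{i}$, $t_{ij}$, $s_{i}$) are invented for illustration, and the grid search is only a stand-in for whatever optimizer one would actually use.</p>

```python
import math

def domain_loss(r, c_i, k_i, t_i):
    """L_i(r) = c_i + k_i * exp(sum_j t_ij * r_j)."""
    return c_i + k_i * math.exp(sum(t * p for t, p in zip(t_i, r)))

def overall_loss(r, s, c, k, t):
    """Validation loss as an s-weighted sum of per-domain losses."""
    return sum(s_i * domain_loss(r, c_i, k_i, t_i)
               for s_i, c_i, k_i, t_i in zip(s, c, k, t))

# Made-up fitted parameters: two training domains, two validation domains.
c = [1.5, 2.0]                    # irreducible losses
k = [1.0, 0.8]                    # scale of the reducible part
t = [[-2.0, 0.1], [0.3, -1.5]]    # interaction coefficients t_ij
s = [0.5, 0.5]                    # validation-set composition

# With two training domains the simplex has one free proportion.
candidates = [(i / 100, 1 - i / 100) for i in range(101)]
best = min(candidates, key=lambda r: overall_loss(r, s, c, k, t))
print(best, overall_loss(best, s, c, k, t))
```

<p>In this toy setup each validation domain is helped mainly by its own training domain (negative diagonal $t_{ii}$), so the best mixture lands in the interior of the simplex instead of at either corner.</p>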
<h3 id="domain-interaction-patterns">Domain interaction patterns</h3>
<p>Visualizing the fitted $t_{ij}$ coefficients across 5 coarse Pile domains reveals three relationship types: most domain pairs are <strong>unrelated</strong> (sparse interaction matrix where each domain&rsquo;s loss is dominated by its own training proportion), some show <strong>facilitation</strong> (e.g., dialogue data helps internet text), and some show <strong>conflict</strong> (e.g., symbolic data hurts prose). This sparsity explains why the law can be fitted with fewer samples than the quadratic parameter count would suggest.</p>
<h2 id="nested-scaling-pipeline-for-cheap-prediction">Nested scaling pipeline for cheap prediction</h2>
<p>Fitting data mixing laws directly at target scale is too expensive (requires many full training runs at different mixtures). The paper proposes nesting three scaling laws:</p>
<p><strong>Step 1</strong>: For each mixture $r_{i}$ and each small model size $N_{j}$, train for $S_{0}$ steps. Fit a <a href="https://en.wikipedia.org/wiki/Power_law">power law</a> $L(S) = E_{1} + B/S^{\beta}$ over steps to extrapolate to the target step count $S_{\text{target}}$.</p>
<p><strong>Step 2</strong>: With the step-extrapolated losses for each mixture, fit a power law $L(N) = E_{2} + A/N^{\alpha}$ over model sizes to extrapolate to the target model size $N_{\text{target}}$.</p>
<p><strong>Step 3</strong>: With the predicted losses at $(N_{\text{target}}, S_{\text{target}})$ for all sampled mixtures, fit the data mixing law and search for the optimal mixture.</p>
<p>This pipeline requires only training small models (70M to 410M) for short runs (30B tokens) to predict performance of a 1B model trained for 100B tokens.</p>
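<p>Steps 1 and 2 both reduce to fitting a law of the form $L(x) = E + B/x^{\beta}$ and extrapolating it. The sketch below shows one simple way to do that with the standard library: scan candidate values of $E$ and, for each, run a least-squares line fit of $\log(L - E)$ against $\log x$. This is a swapped-in fitting procedure chosen for brevity, not the Huber-loss fit used in the paper, and the helper names and synthetic data are assumptions.</p>

```python
import math

def fit_power_law(xs, losses, grid=100):
    """Fit L(x) = E + B / x**beta by scanning E and regressing in log space."""
    best = None
    floor = min(losses)
    u = [math.log(x) for x in xs]
    n = len(u)
    mu = sum(u) / n
    for i in range(int(floor * grid)):
        e = i / grid  # candidate irreducible loss, always below min(losses)
        v = [math.log(l - e) for l in losses]
        mv = sum(v) / n
        slope = (sum((a - mu) * (b - mv) for a, b in zip(u, v))
                 / sum((a - mu) ** 2 for a in u))
        intercept = mv - slope * mu
        err = sum((intercept + slope * a - b) ** 2 for a, b in zip(u, v))
        if best is None or best[0] > err:
            best = (err, e, math.exp(intercept), -slope)
    return best[1:]  # (E, B, beta)

def predict(params, x):
    e, b, beta = params
    return e + b / x ** beta

# Synthetic losses from a known law, observed only at short step counts.
steps = [1000, 2000, 5000, 10000, 30000]
losses = [2.0 + 5.0 / s ** 0.5 for s in steps]
params = fit_power_law(steps, losses)
predicted = predict(params, 100000)  # extrapolate to the target step count
print(params, predicted)
```

<p>The same routine could be reused for Step 2 by passing model sizes in place of step counts; the step-extrapolated losses then become the <code>losses</code> argument.</p>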
<h3 id="mixture-sampling-strategy">Mixture sampling strategy</h3>
<p>To get informative samples efficiently, the paper uses double-diminishing proportions: for each domain, enumerate proportions by halving from the maximum available. This distributes losses evenly across the exponential law&rsquo;s range. From 40 candidate mixtures trained at the smallest scale (70M), 20 are selected based on which subset minimizes data mixing law fitting error.</p>
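<p>As a small illustration of the halving idea, the snippet below enumerates candidate proportions for a single domain. It covers only the per-domain enumeration described above; how the paper combines these values into full mixtures is not reproduced here, and the function name and the example maximum of 0.8 are hypothetical.</p>

```python
def halving_proportions(max_proportion, levels):
    """Candidate proportions for one domain: max, max/2, max/4, ..."""
    return [max_proportion / (2 ** i) for i in range(levels)]

# Example: a domain that can fill at most 80% of the mixture.
print(halving_proportions(0.8, 4))
```

<p>Successive candidates are evenly spaced on a log scale, so more of the samples fall at small proportions than a uniform grid would place there.</p>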
<h2 id="experiments-on-redpajama-and-continual-pretraining">Experiments on RedPajama and continual pretraining</h2>
<p><strong>Main experiment.</strong> Models trained on RedPajama, validated on the Pile (mimicking the common scenario where validation data comes from a different distribution than training). Small models: 70M, 160M, 305M, 410M trained for 30B tokens. Target: 1B model for 100B tokens.</p>
<p>The optimized mixture dramatically redistributes weight compared to RedPajama defaults:</p>
<table>
  <thead>
      <tr>
          <th>Domain</th>
          <th>Default</th>
          <th>Optimized</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>CommonCrawl</td>
          <td>0.670</td>
          <td>0.125</td>
      </tr>
      <tr>
          <td>C4</td>
          <td>0.150</td>
          <td>0.250</td>
      </tr>
      <tr>
          <td>GitHub</td>
          <td>0.045</td>
          <td>0.141</td>
      </tr>
      <tr>
          <td>ArXiv</td>
          <td>0.045</td>
          <td>0.250</td>
      </tr>
      <tr>
          <td>Books</td>
          <td>0.045</td>
          <td>0.094</td>
      </tr>
      <tr>
          <td>StackExchange</td>
          <td>0.025</td>
          <td>0.125</td>
      </tr>
      <tr>
          <td>Wikipedia</td>
          <td>0.020</td>
          <td>0.016</td>
      </tr>
  </tbody>
</table>
<p>The optimized mixture reaches the default mixture&rsquo;s final performance in 73% of the training steps and eventually achieves performance equivalent to 48% more training on the default mixture.</p>
<p><strong>Comparison to DoReMi and DoGE.</strong> Data mixing laws outperform both: the predicted-optimal mixture achieves lower validation loss than DoReMi and DoGE (both universal and OOD settings) for 1B models trained for 100B tokens on RedPajama.</p>
<p><strong>Continual pretraining.</strong> The law extends to continual pretraining (Pythia-70M on Pile + Python code). It accurately predicts the critical mixture proportion that avoids <a href="https://en.wikipedia.org/wiki/Catastrophic_interference">catastrophic forgetting</a> on the original domain while improving the target domain. This suggests data mixing laws could guide dynamic data schedules across multi-stage pretraining.</p>
<h2 id="implications-and-limitations">Implications and limitations</h2>
<p>The data mixing law provides a predictive framework rather than just an optimization algorithm. Key implications:</p>
<ul>
<li>The interaction coefficients $t_{ij}$ make domain relationships quantitatively observable before full-scale training, identifying facilitation and conflict pairs.</li>
<li>The nested pipeline&rsquo;s cost is dominated by the small-model training runs (40 mixtures at 70M scale), which is orders of magnitude cheaper than even a single target-scale run.</li>
<li>The continual pretraining application opens the door to optimizing dynamic data schedules, where mixture proportions change across training stages.</li>
</ul>
<p><strong>Limitations</strong>: The &ldquo;domain&rdquo; concept remains loosely defined (provenance-based). The nested scaling laws introduce compounding errors at each step, and predictions tend to slightly underestimate actual loss. The number of required fitting samples, while subquadratic in practice due to sparsity, still scales with the number of domains. No theoretical justification for the exponential form is provided; it is a purely empirical finding.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training (pilot)</td>
          <td>The Pile (GitHub, Pile-CC, Books3)</td>
          <td>30B tokens</td>
          <td>2-domain and 3-domain experiments</td>
      </tr>
      <tr>
          <td>Training (main)</td>
          <td>RedPajama</td>
          <td>100B tokens</td>
          <td>7 domains</td>
      </tr>
      <tr>
          <td>Validation</td>
          <td>The Pile validation set</td>
          <td>Standard split</td>
          <td>Out-of-distribution relative to RedPajama</td>
      </tr>
      <tr>
          <td>Continual pretraining</td>
          <td>Pile + Python code</td>
          <td>10B tokens</td>
          <td>Pythia-70M base model</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<p>Data mixing law: $L_{i}(r_{1 \dots M}) = c_{i} + k_{i} \exp(\sum_{j} t_{ij} r_{j})$. Fitted via AdaBoost Regressor on sampled mixtures. Step scaling law: $L(S) = E_{1} + B/S^{\beta}$. Model size scaling law: $L(N) = E_{2} + A/N^{\alpha}$. Both fitted via Huber loss minimization with L-BFGS. Decomposed Chinchilla-style (separate fits for stability). 40 candidate mixtures sampled via double-diminishing proportions, 20 selected for the final pipeline.</p>
<h3 id="models">Models</h3>
<p>Transformer decoder-only LMs. Pilot: 70M, 160M. Main pipeline: 70M, 160M, 305M, 410M (for fitting), 1B (target). Batch size: 1M tokens. Cosine learning rate decay with 2K step warmup, decaying to 0.1x at 100K steps.</p>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Optimized Mixture</th>
          <th>Default Mixture</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Steps to match default final loss</td>
          <td>73K (73%)</td>
          <td>100K (100%)</td>
          <td>27% training reduction</td>
      </tr>
      <tr>
          <td>Equivalent extra training</td>
          <td>+48%</td>
          <td>Baseline</td>
          <td>Estimated via step scaling law</td>
      </tr>
      <tr>
          <td>Validation loss (1B, 100B)</td>
          <td>Lowest</td>
          <td>Higher than optimized</td>
          <td>Also beats DoReMi and DoGE</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>8 A100 GPUs. Training times per 30B-token run: 3.5 hours (70M), 8 hours (160M), 16 hours (305M), 21 hours (410M).</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://pile.eleuther.ai/">The Pile</a></td>
          <td>Dataset</td>
          <td>MIT</td>
          <td>Pilot and validation data</td>
      </tr>
      <tr>
          <td><a href="https://github.com/togethercomputer/RedPajama-Data">RedPajama</a></td>
          <td>Dataset</td>
          <td>Apache 2.0</td>
          <td>Main training data</td>
      </tr>
      <tr>
          <td><a href="https://github.com/EleutherAI/pythia">Pythia Suite</a></td>
          <td>Model</td>
          <td>Apache 2.0</td>
          <td>Model architecture configs; Pythia-70M checkpoint for continual pretraining</td>
      </tr>
  </tbody>
</table>
<p><strong>Reproducibility status: Partially Reproducible.</strong> Datasets and base model checkpoints are public. No official code release for the data mixing law fitting pipeline, mixture sampling, or the nested scaling law prediction workflow.</p>
<hr>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{ye2025datamixinglaws,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Data Mixing Laws: Optimizing Data Mixtures by Predicting Language Modeling Performance}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Ye, Jiasheng and Liu, Peiju and Sun, Tianxiang and Zhan, Jun and Zhou, Yunhua and Qiu, Xipeng}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{International Conference on Learning Representations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2025}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>RWKV: Linear-Cost RNN with Transformer Training</title><link>https://hunterheidenreich.com/notes/natural-language-processing/language-models/rwkv-rnn-transformer-architecture/</link><pubDate>Tue, 07 Apr 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/natural-language-processing/language-models/rwkv-rnn-transformer-architecture/</guid><description>RWKV combines parallelizable transformer training with constant-cost RNN inference using linear attention and channel-wise decay.</description><content:encoded><![CDATA[<h2 id="a-new-architecture-bridging-rnns-and-transformers">A New Architecture Bridging RNNs and Transformers</h2>
<p>This is a <strong>Method</strong> paper that introduces RWKV (Receptance Weighted Key Value), a novel sequence model architecture that combines the parallelizable training of Transformers with the efficient $O(Td)$ inference of RNNs. RWKV can be formulated equivalently as either a Transformer (for parallel training) or an RNN (for sequential inference), achieving the lowest computational and memory complexity among comparable architectures while matching Transformer-level performance. The authors scale RWKV to 14 billion parameters, making it the largest dense RNN ever trained at the time of publication.</p>
<h2 id="the-quadratic-cost-of-self-attention">The Quadratic Cost of Self-Attention</h2>
<p>Transformers have become the dominant architecture for NLP, powering models like GPT-3, LLaMA, and Chinchilla. Their self-attention mechanism captures both local and long-range dependencies while supporting parallelized training. However, self-attention scales quadratically with sequence length in both time ($O(T^2d)$) and space ($O(T^2 + Td)$), making it computationally and memory intensive for long sequences and resource-constrained deployment.</p>
<p>RNNs, by contrast, offer linear scaling in memory and computation, but suffer from the vanishing gradient problem and cannot parallelize across the time dimension during training. This limits their scalability and makes them unable to match Transformer performance in practice.</p>
<p>Prior work on efficient Transformers (Reformer, Performer, Linformer, AFT, MEGA) has attempted to reduce this quadratic cost, often at the expense of model expressivity. RWKV aims to combine the best of both worlds: Transformer-grade training efficiency with RNN-grade inference cost, without any approximation to the attention mechanism.</p>
<h2 id="linear-attention-via-channel-wise-decay">Linear Attention via Channel-Wise Decay</h2>
<p>RWKV is built on four core vectors that interact multiplicatively at each timestep:</p>
<ul>
<li><strong>R</strong> (Receptance): receives past information, acting as a gating signal</li>
<li><strong>W</strong> (Weight): a trainable positional weight decay vector</li>
<li><strong>K</strong> (Key): analogous to keys in standard attention</li>
<li><strong>V</strong> (Value): analogous to values in standard attention</li>
</ul>
<p>The architecture consists of stacked residual blocks, each containing a <strong>time-mixing</strong> sub-block and a <strong>channel-mixing</strong> sub-block.</p>
<h3 id="token-shift">Token Shift</h3>
<p>All linear projection vectors are produced by interpolating between the current input $x_t$ and the previous input $x_{t-1}$, creating a token shift mechanism:</p>
<p>$$
r_t = W_r \cdot (\mu_r \odot x_t + (1 - \mu_r) \odot x_{t-1})
$$</p>
<p>$$
k_t = W_k \cdot (\mu_k \odot x_t + (1 - \mu_k) \odot x_{t-1})
$$</p>
<p>$$
v_t = W_v \cdot (\mu_v \odot x_t + (1 - \mu_v) \odot x_{t-1})
$$</p>
<p>where $\mu_r$, $\mu_k$, $\mu_v$ are learnable interpolation parameters. This is implemented efficiently as a simple offset in the temporal dimension.</p>
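<p>A tiny worked example of the receptance projection may help. The sketch uses plain Python lists and made-up two-channel values; <code>token_shift</code> and <code>matvec</code> are illustrative helpers, not names from the RWKV codebase.</p>

```python
def token_shift(x_t, x_prev, mu):
    """Channel-wise interpolation: mu * x_t + (1 - mu) * x_{t-1}."""
    return [m * a + (1.0 - m) * b for m, a, b in zip(mu, x_t, x_prev)]

def matvec(w, v):
    """Dense matrix-vector product for the linear projection."""
    return [sum(wij * vj for wij, vj in zip(row, v)) for row in w]

# Two channels: the first blends both tokens, the second uses only x_t.
x_prev = [1.0, 2.0]
x_t = [3.0, 4.0]
mu_r = [0.5, 1.0]
w_r = [[1.0, 0.0], [0.0, 2.0]]

r_t = matvec(w_r, token_shift(x_t, x_prev, mu_r))
print(r_t)
```

<p>The same pattern yields $k_t$ and $v_t$ by swapping in $\mu_k, W_k$ and $\mu_v, W_v$.</p>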
<h3 id="the-wkv-operator">The WKV Operator</h3>
<p>The core attention-like computation replaces standard dot-product attention with a channel-wise weighted sum using exponential decay:</p>
<p>$$
wkv_t = \frac{\sum_{i=1}^{t-1} e^{-(t-1-i)w + k_i} \odot v_i + e^{u + k_t} \odot v_t}{\sum_{i=1}^{t-1} e^{-(t-1-i)w + k_i} + e^{u + k_t}}
$$</p>
<p>Here $w$ is the channel-wise time decay vector and $u$ is a separate bonus vector that attends specifically to the current token. Unlike AFT where $W$ is a pairwise matrix, RWKV treats $W$ as a channel-wise vector modified by relative position, enabling the recurrent formulation.</p>
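<p>The recurrent formulation follows from noticing that the two sums in the WKV expression can be carried forward as running accumulators: each step multiplies them by $e^{-w}$ and adds the current token&rsquo;s contribution. The sketch below checks this for a single channel by comparing a direct evaluation of the formula with the recurrence. It is a simplified illustration with hypothetical function names and toy inputs, and it omits the numerical-stability handling a practical implementation would need for large exponents.</p>

```python
import math

def wkv_direct(w, u, ks, vs):
    """Evaluate the WKV formula from scratch at every timestep."""
    out = []
    for t in range(len(ks)):
        bonus = math.exp(u + ks[t])        # current-token term
        num = bonus * vs[t]
        den = bonus
        for i in range(t):                 # decayed past tokens
            weight = math.exp(-(t - 1 - i) * w + ks[i])
            num += weight * vs[i]
            den += weight
        out.append(num / den)
    return out

def wkv_recurrent(w, u, ks, vs):
    """Same outputs using a fixed-size state (a, b) per channel."""
    a, b = 0.0, 0.0
    out = []
    for k_t, v_t in zip(ks, vs):
        bonus = math.exp(u + k_t)
        out.append((a + bonus * v_t) / (b + bonus))
        # Decay the state once, then absorb the current token.
        a = math.exp(-w) * a + math.exp(k_t) * v_t
        b = math.exp(-w) * b + math.exp(k_t)
    return out

ks = [0.1, -0.3, 0.5, 0.0]
vs = [1.0, 2.0, -1.0, 0.5]
direct = wkv_direct(0.7, 0.2, ks, vs)
recurrent = wkv_recurrent(0.7, 0.2, ks, vs)
print(direct)
print(recurrent)
```

<p>The direct version does work proportional to $t$ at step $t$, while the recurrent version does constant work per step, which is the source of the linear total cost at inference time.</p>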
<h3 id="output-gating">Output Gating</h3>
<p>The receptance vector gates the WKV output through a sigmoid:</p>
<p>$$
o_t = W_o \cdot (\sigma(r_t) \odot wkv_t)
$$</p>
<p>The channel-mixing block uses a similar gating mechanism with squared ReLU activation:</p>
<p>$$
o'_t = \sigma(r'_t) \odot (W'_v \cdot \max(k'_t, 0)^2)
$$</p>
<h3 id="dual-mode-operation">Dual-Mode Operation</h3>
<p>During <strong>training</strong>, RWKV operates in time-parallel mode. The matrix multiplications ($W_\lambda$ for $\lambda \in \lbrace r, k, v, o \rbrace$) dominate at $O(BTd^2)$ and parallelize identically to standard Transformers. The element-wise WKV computation is $O(BTd)$ and parallelizes along batch and channel dimensions.</p>
<p>During <strong>inference</strong>, RWKV switches to time-sequential mode. Each timestep updates a fixed-size state vector, giving constant $O(d)$ memory and $O(Td)$ total time for generating $T$ tokens, compared to $O(T^2d)$ for standard Transformers.</p>
<h3 id="optimizations">Optimizations</h3>
<p>Three additional design choices improve training:</p>
<ol>
<li><strong>Custom CUDA kernels</strong> for the sequential WKV computation, fusing it into a single kernel on training accelerators</li>
<li><strong>Small init embedding</strong>: initializing the embedding matrix with small values plus an additional LayerNorm, accelerating convergence</li>
<li><strong>Custom initialization</strong>: most weights initialized to zero with no biases, following identity-mapping principles from residual network design</li>
</ol>
<h2 id="scaling-to-14b-parameters-and-benchmark-evaluation">Scaling to 14B Parameters and Benchmark Evaluation</h2>
<h3 id="model-scaling">Model Scaling</h3>
<p>The authors train six RWKV models from 169M to 14B parameters, all for one epoch (330B tokens) on the Pile:</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Layers</th>
          <th>Dimension</th>
          <th>Parameters</th>
          <th>FLOP/Token</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>169M</td>
          <td>12</td>
          <td>768</td>
          <td>$1.69 \times 10^8$</td>
          <td>$2.61 \times 10^8$</td>
      </tr>
      <tr>
          <td>430M</td>
          <td>24</td>
          <td>1024</td>
          <td>$4.30 \times 10^8$</td>
          <td>$7.57 \times 10^8$</td>
      </tr>
      <tr>
          <td>1.5B</td>
          <td>24</td>
          <td>2048</td>
          <td>$1.52 \times 10^9$</td>
          <td>$2.82 \times 10^9$</td>
      </tr>
      <tr>
          <td>3B</td>
          <td>32</td>
          <td>2560</td>
          <td>$2.99 \times 10^9$</td>
          <td>$5.71 \times 10^9$</td>
      </tr>
      <tr>
          <td>7B</td>
          <td>32</td>
          <td>4096</td>
          <td>$7.39 \times 10^9$</td>
          <td>$1.44 \times 10^{10}$</td>
      </tr>
      <tr>
          <td>14B</td>
          <td>40</td>
          <td>5120</td>
          <td>$1.42 \times 10^{10}$</td>
          <td>$2.78 \times 10^{10}$</td>
      </tr>
  </tbody>
</table>
<p>The parameter count follows: $\text{params} = 2VD + 13D^2L + D(11L + 4)$, where $V = 50277$ is vocabulary size, $D$ is model dimension, and $L$ is layers. FLOPs match the standard transformer formula: $\text{FLOP} = 6 \cdot [\text{tokens}] \cdot [\text{params}]$.</p>
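<p>As a quick check, the parameter formula can be evaluated for the six configurations in the table. This is a sketch; the function and dictionary names are illustrative:</p>

```python
VOCAB_SIZE = 50277


def rwkv_params(dim, layers, vocab=VOCAB_SIZE):
    """params = 2VD + 13 D^2 L + D(11L + 4)."""
    return 2 * vocab * dim + 13 * dim * dim * layers + dim * (11 * layers + 4)


# (layers, dimension) pairs copied from the scaling table.
CONFIGS = {
    "169M": (12, 768),
    "430M": (24, 1024),
    "1.5B": (24, 2048),
    "3B": (32, 2560),
    "7B": (32, 4096),
    "14B": (40, 5120),
}

for name, (layers, dim) in CONFIGS.items():
    print(f"{name:>5}: {rwkv_params(dim, layers):.4e}")
```

<p>The results agree with the table to within rounding; for the smallest model the formula gives exactly 169,342,464 parameters.</p>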
<h3 id="scaling-laws">Scaling Laws</h3>
<p>Training 45 RWKV models across varied (dataset, parameters) pairs, the authors find that RWKV follows the same log-log linear scaling law established for Transformers. The linear fit to Pareto-optimal points achieves $r^2 = 0.994$, and extrapolating the fit an additional order of magnitude still yields $r^2 = 0.875$. This contrasts with prior claims that LSTMs do not follow transformer-like scaling.</p>
<h3 id="nlp-benchmarks">NLP Benchmarks</h3>
<p>RWKV is compared against similarly-sized models trained on comparable token budgets: Pythia, OPT, and BLOOM (all FLOP-matched). Results span twelve benchmarks: ARC (Easy/Challenge), BoolQ, COPA, HeadQA, HellaSwag, LAMBADA, OpenBookQA, PIQA, ReCoRD, SciQ, and Winogrande.</p>
<p>RWKV performs competitively with Transformers across all model sizes. On average across benchmarks, RWKV tracks closely with Pythia and outperforms OPT and BLOOM at comparable scales.</p>
<h3 id="long-context-and-extended-finetuning">Long Context and Extended Finetuning</h3>
<p>RWKV can extend its context length after pretraining through progressive finetuning: doubling from 1024 to 2048 (10B tokens), then to 4096 (100B tokens), and finally to 8192 (100B tokens). Each doubling reduces test loss on the Pile, indicating effective use of longer context.</p>
<p>On the Long Range Arena (LRA) benchmark, which tests sequences from 1,000 to 16,000 tokens, RWKV performs second only to S4 across the five datasets.</p>
<h3 id="inference-efficiency">Inference Efficiency</h3>
<p>Benchmarking text generation on CPU (x86) and GPU (NVIDIA A100 80GB) at float32 precision shows that RWKV exhibits linear scaling in generation time, while Transformers scale quadratically. This advantage grows with sequence length: for long outputs, RWKV completes generation substantially faster at equivalent model sizes.</p>
<h2 id="competitive-performance-with-key-caveats">Competitive Performance with Key Caveats</h2>
<p>RWKV demonstrates that RNN-class models can match Transformer performance at scale, while maintaining $O(Td)$ time and $O(d)$ memory during inference. The key findings are:</p>
<ol>
<li><strong>Scaling laws hold</strong>: RWKV follows the same compute-optimal scaling as Transformers ($r^2 = 0.994$), contradicting earlier claims about RNN scaling behavior</li>
<li><strong>Competitive NLP performance</strong>: Across twelve benchmarks, RWKV matches similarly-sized Transformers trained on comparable data</li>
<li><strong>Linear inference cost</strong>: Generation time scales linearly rather than quadratically, with constant memory regardless of sequence length</li>
<li><strong>Context extension</strong>: Progressive finetuning effectively extends the context window post-training</li>
</ol>
<h3 id="limitations">Limitations</h3>
<p>The authors identify two primary limitations:</p>
<p><strong>Information compression</strong>: Linear attention funnels all past information through a single fixed-size state vector. For tasks requiring recall of specific details over very long contexts, this is mechanistically more constrained than full self-attention, which maintains direct access to all previous tokens.</p>
<p><strong>Prompt sensitivity</strong>: RWKV is more sensitive to prompt engineering than standard Transformers. The linear attention mechanism limits how much prompt information carries forward, making the order of information in the prompt particularly important. Reordering prompts improved F1 from 44.2% to 74.8% on one task.</p>
<h3 id="future-directions">Future Directions</h3>
<p>The authors suggest several avenues: applying parallel scan to reduce WKV cost to $O(B \log(T) d)$, extending RWKV to encoder-decoder and multimodal architectures, leveraging hidden states for interpretability and safety, and increasing internal state size to improve long-range recall.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/BlinkDL/RWKV-LM">BlinkDL/RWKV-LM</a></td>
          <td>Code</td>
          <td>Apache-2.0</td>
          <td>Official PyTorch training and inference implementation</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/BlinkDL/rwkv-4-pile-14b">Pre-trained weights (169M to 14B)</a></td>
          <td>Model</td>
          <td>Apache-2.0</td>
          <td>All six Pile-trained sizes on HuggingFace (<code>BlinkDL/rwkv-4-pile-*</code>)</td>
      </tr>
      <tr>
          <td><a href="https://pile.eleuther.ai/">The Pile</a></td>
          <td>Dataset</td>
          <td>Mixed</td>
          <td>825 GiB pretraining corpus; component licenses vary by source</td>
      </tr>
  </tbody>
</table>
<p><strong>Reproducibility classification</strong>: Highly Reproducible. Training code (Apache-2.0), pre-trained weights for all six model sizes, the full training corpus, and complete hyperparameters (Appendix G) are all publicly available. The only missing detail is the specific GPU cluster configuration used for pretraining.</p>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pretraining</td>
          <td>The Pile</td>
          <td>330B tokens</td>
          <td>One full epoch for all model sizes</td>
      </tr>
      <tr>
          <td>Context extension</td>
          <td>The Pile</td>
          <td>210B additional tokens</td>
          <td>Progressive doubling: 1024 to 8192</td>
      </tr>
      <tr>
          <td>NLP evaluation</td>
          <td>ARC, BoolQ, COPA, HeadQA, HellaSwag, LAMBADA, OpenBookQA, PIQA, ReCoRD, SciQ, Winogrande</td>
          <td>Various</td>
          <td>Zero-shot evaluation</td>
      </tr>
      <tr>
          <td>Long-range evaluation</td>
          <td>Long Range Arena (LRA)</td>
          <td>1K-16K tokens</td>
          <td>Five sub-tasks</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Optimizer: Adam ($\beta = (0.9, 0.99)$), no weight decay</li>
<li>Precision: bfloat16</li>
<li>Training context length: 1024 tokens</li>
<li>Learning rate: constant warmup, then exponential decay</li>
<li>Auxiliary loss from PaLM (softmax normalizer regularization)</li>
<li>Batch size: 128 or 256 sequences (dynamically switched)</li>
<li>Training organized into mini-epochs of 40,320 samples each (8,043 mini-epochs per Pile epoch)</li>
</ul>
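<p>The mini-epoch bookkeeping can be checked against the stated token budget. The sketch assumes every sample is a full 1024-token sequence, which the list above implies but does not state outright:</p>

```python
SAMPLES_PER_MINI_EPOCH = 40_320
CONTEXT_LENGTH = 1_024            # tokens per sample (assumed full-length)
MINI_EPOCHS_PER_PILE_EPOCH = 8_043

tokens_per_mini_epoch = SAMPLES_PER_MINI_EPOCH * CONTEXT_LENGTH
tokens_per_epoch = tokens_per_mini_epoch * MINI_EPOCHS_PER_PILE_EPOCH

print(f"tokens per mini-epoch: {tokens_per_mini_epoch:,}")
print(f"tokens per Pile epoch: {tokens_per_epoch:,}")
```

<p>Under that assumption one Pile epoch comes to about 332B tokens, consistent with the roughly 330B tokens quoted for pretraining.</p>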
<h3 id="models">Models</h3>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Init LR</th>
          <th>Warmup Mini-Epochs</th>
          <th>End LR</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>169M</td>
          <td>6e-4</td>
          <td>361</td>
          <td>1e-5</td>
      </tr>
      <tr>
          <td>430M</td>
          <td>4e-4</td>
          <td>411</td>
          <td>1e-5</td>
      </tr>
      <tr>
          <td>1.5B</td>
          <td>3e-4</td>
          <td>443</td>
          <td>1e-5</td>
      </tr>
      <tr>
          <td>3B</td>
          <td>1.5e-4</td>
          <td>451</td>
          <td>1e-5</td>
      </tr>
      <tr>
          <td>7B</td>
          <td>1.5e-4</td>
          <td>465</td>
          <td>1e-5</td>
      </tr>
      <tr>
          <td>14B</td>
          <td>1e-4</td>
          <td>544</td>
          <td>7e-6</td>
      </tr>
  </tbody>
</table>
<p>All pretrained models (169M to 14B) are publicly released on HuggingFace (<code>BlinkDL/rwkv-4-pile-*</code>) under Apache-2.0. Training code is at <a href="https://github.com/BlinkDL/RWKV-LM">BlinkDL/RWKV-LM</a> (Apache-2.0).</p>
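<p>One possible reading of the schedule columns is sketched below. It assumes the learning rate is held at the initial value for the listed warmup mini-epochs and then decays geometrically so that it reaches the end value at the last of the 8,043 mini-epochs. The exact decay shape is an assumption, and <code>learning_rate</code> is a hypothetical helper rather than a function from the RWKV training code:</p>

```python
TOTAL_MINI_EPOCHS = 8_043


def learning_rate(mini_epoch, init_lr, warmup, end_lr, total=TOTAL_MINI_EPOCHS):
    """Constant phase followed by exponential (geometric) decay to end_lr."""
    if mini_epoch <= warmup:
        return init_lr
    # Fraction of the decay phase completed, from 0 to 1.
    progress = (mini_epoch - warmup) / (total - warmup)
    return init_lr * (end_lr / init_lr) ** progress


# Values for the 169M model, copied from the table.
for step in (0, 361, 2_000, 5_000, 8_043):
    print(step, learning_rate(step, init_lr=6e-4, warmup=361, end_lr=1e-5))
```

<p>Under this reading the logarithm of the learning rate falls linearly between the end of the constant phase and the final mini-epoch.</p>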
<h3 id="evaluation">Evaluation</h3>
<ul>
<li>All NLP benchmarks evaluated in zero-shot setting</li>
<li>FLOP-matched comparison against Pythia, OPT, BLOOM</li>
<li>Inference benchmarked on CPU (x86) and GPU (NVIDIA A100 80GB) at float32</li>
</ul>
<h3 id="hardware">Hardware</h3>
<ul>
<li>Inference experiments: NVIDIA A100 80GB GPU</li>
<li>Training hardware details not fully specified; FLOP budgets reported per model</li>
</ul>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Peng, B., Alcaide, E., Anthony, Q., Albalak, A., Arcadinho, S., Biderman, S., &hellip; &amp; Zhu, R.-J. (2023). RWKV: Reinventing RNNs for the Transformer Era. In <em>Findings of the Association for Computational Linguistics: EMNLP 2023</em>, pp. 14048-14077.</p>
<p><strong>Publication</strong>: Findings of EMNLP 2023</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/BlinkDL/RWKV-LM">GitHub Repository (Apache-2.0)</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{peng2023rwkv,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{RWKV: Reinventing RNNs for the Transformer Era}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Peng, Bo and Alcaide, Eric and Anthony, Quentin and Albalak, Alon and Arcadinho, Samuel and Biderman, Stella and Cao, Huanqi and Cheng, Xin and Chung, Michael and Derczynski, Leon and Du, Xingjian and Grella, Matteo and GV, Kranthi Kiran and He, Xuzheng and Hou, Haowen and Kazienko, Przemys{\l}aw and Koco{\&#39;n}, Jan and Kong, Jiaming and Koptyra, Bart{\l}omiej and Lau, Hayden and Lin, Jiaju and Mantri, Krishna Sri Ipsit and Mom, Ferdinand and Saito, Atsushi and Song, Guangyu and Tang, Xiangru and Wind, Johan S. and Wo{\&#39;z}niak, Stanis{\l}aw and Zhang, Zhenyuan and Zhou, Qinghua and Zhu, Jian and Zhu, Rui-Jie}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Findings of the Association for Computational Linguistics: EMNLP 2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{14048--14077}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.18653/v1/2023.findings-emnlp.936}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Liquid-S4: Input-Dependent State-Space Models</title><link>https://hunterheidenreich.com/notes/machine-learning/model-architectures/liquid-s4-state-space-models/</link><pubDate>Tue, 07 Apr 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/model-architectures/liquid-s4-state-space-models/</guid><description>Liquid-S4 combines liquid time-constant networks with structured state-space models, adding input-dependent kernels for long-range sequence modeling.</description><content:encoded><![CDATA[<h2 id="a-method-for-input-adaptive-sequence-modeling">A Method for Input-Adaptive Sequence Modeling</h2>
<p>This is a <strong>Method</strong> paper that introduces Liquid-S4, a new state-space model combining the structured state-space framework (S4) with liquid time-constant (LTC) networks. The primary contribution is an input-dependent state transition mechanism that allows the model to adapt its dynamics based on incoming inputs, while retaining the efficient convolutional kernel computation of S4.</p>
<h2 id="scaling-liquid-networks-to-long-sequences">Scaling Liquid Networks to Long Sequences</h2>
<p>Liquid time-constant (LTC) networks are continuous-time neural networks with input-dependent state transitions, giving them strong generalization and causal modeling properties. However, LTCs rely on ODE solvers that limit their scalability to long sequences. Structured state-space models (S4) solve this scalability problem through HiPPO initialization, diagonal plus low-rank (DPLR) parameterization, and efficient Cauchy kernel computation in the frequency domain, but they use fixed (input-independent) state transitions.</p>
<p>The key question this paper addresses: can the expressivity of LTC networks be combined with the efficiency and scalability of S4 to improve long-range sequence modeling?</p>
<h2 id="the-liquid-kernel-input-dependent-convolutions">The Liquid Kernel: Input-Dependent Convolutions</h2>
<p>The core innovation is a linearized LTC state-space model that replaces the standard SSM dynamics:</p>
<p>$$\dot{x}(t) = \mathbf{A}x(t) + \mathbf{B}u(t)$$</p>
<p>with an input-dependent formulation:</p>
<p>$$\dot{x}(t) = \left[\mathbf{A} + \mathbf{B}u(t)\right]x(t) + \mathbf{B}u(t)$$</p>
<p>where $u(t)$ now modulates the state transition matrix itself. After discretization via the <a href="https://en.wikipedia.org/wiki/Bilinear_transform">bilinear transform</a>, the recurrence becomes:</p>
<p>$$x_{k} = \left(\overline{\mathbf{A}} + \overline{\mathbf{B}}u_{k}\right)x_{k-1} + \overline{\mathbf{B}}u_{k}$$</p>
<p>Unrolling this recurrence reveals that the output $y_{k}$ decomposes into two parts:</p>
<p>$$y = \overline{\mathbf{K}} * u + \overline{\mathbf{K}}_{\text{liquid}} * u_{\text{correlations}}$$</p>
<p>The first term is the standard S4 convolutional kernel $\overline{\mathbf{K}}$, mapping individual input time steps independently. The second term is a new &ldquo;liquid kernel&rdquo; $\overline{\mathbf{K}}_{\text{liquid}}$ that operates on <a href="https://en.wikipedia.org/wiki/Autocorrelation">auto-correlation</a> terms of the input signal (products $u_{i}u_{j}$, $u_{i}u_{j}u_{k}$, etc., up to a chosen order $\mathcal{P}$).</p>
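<p>The decomposition can be verified by hand for a scalar state and three inputs. The sketch below assumes a one-dimensional state, a zero initial state, and scalar stand-ins <code>a</code>, <code>b</code>, <code>c</code> for $\overline{\mathbf{A}}$, $\overline{\mathbf{B}}$, $\overline{\mathbf{C}}$. The function names are illustrative:</p>

```python
def liquid_recurrence(a, b, c, u):
    """Run x_k = (a + b*u_k) * x_{k-1} + b*u_k and emit y_k = c * x_k."""
    x, ys = 0.0, []
    for u_k in u:
        x = (a + b * u_k) * x + b * u_k
        ys.append(c * x)
    return ys


def unrolled_third_output(a, b, c, u):
    """Split y_2 into the standard-kernel part and the correlation part."""
    u0, u1, u2 = u
    # Standard kernel (c*a^2*b, c*a*b, c*b) applied to single inputs.
    linear = c * (a * a * b * u0 + a * b * u1 + b * u2)
    # Liquid kernel applied to products of two and three inputs.
    liquid = c * (
        a * b * b * (u0 * u1 + u0 * u2)
        + b * b * u1 * u2
        + b ** 3 * u0 * u1 * u2
    )
    return linear, liquid


a, b, c = 0.9, 0.5, 2.0
u = [1.0, -2.0, 0.5]
linear, liquid = unrolled_third_output(a, b, c, u)
print(liquid_recurrence(a, b, c, u)[2], linear + liquid)
```

<p>The two printed numbers agree: the third output of the recurrence is the standard convolution term plus terms in $u_0u_1$, $u_0u_2$, $u_1u_2$, and $u_0u_1u_2$.</p>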
<p><strong>Proposition 1</strong> shows that each liquid kernel of order $p$ can be computed from the precomputed S4 kernel via a <a href="https://en.wikipedia.org/wiki/Hadamard_product_(matrices)">Hadamard product</a> with $\overline{\mathbf{B}}^{p-1}$ followed by an anti-diagonal transformation (flip):</p>
<p>$$\overline{\mathbf{K}}_{\text{liquid}=p} = \left[\overline{\mathbf{K}}_{(L-\tilde{L},L)} \odot \overline{\mathbf{B}}_{(L-\tilde{L},L)}^{p-1}\right] * \mathbf{J}_{\tilde{L}}$$</p>
<p>This is the KB (Kernel $\times$ B) mode. The authors also propose a simplified PB (Powers of B) mode that sets the transition matrix $\overline{\mathbf{A}}$ to identity for the correlation terms:</p>
<p>$$\overline{\mathbf{K}}_{\text{liquid}=p} = \overline{\mathbf{C}} \odot \overline{\mathbf{B}}^{p-1}$$</p>
<p>The PB kernel is cheaper to compute and performs equally well or better in practice.</p>
<p>The computational complexity is $\tilde{\mathcal{O}}(N + L + p_{\text{max}}\tilde{L})$, where $N$ is the state size, $L$ the sequence length, $p_{\text{max}}$ the maximum liquid order, and $\tilde{L}$ the liquid kernel length (typically two orders of magnitude smaller than $L$).</p>
<h2 id="benchmarks-across-long-range-sequence-tasks">Benchmarks Across Long-Range Sequence Tasks</h2>
<p>Liquid-S4 is evaluated on four benchmark suites with the PB kernel using the S4-LegS (scaled <a href="https://en.wikipedia.org/wiki/Legendre_polynomials">Legendre</a>) parameterization.</p>
<h3 id="long-range-arena-lra">Long Range Arena (LRA)</h3>
<p>The LRA benchmark contains six tasks with sequence lengths from 1K to 16K. Liquid-S4 achieves state-of-the-art on all six tasks with an average accuracy of 87.32%:</p>
<table>
  <thead>
      <tr>
          <th>Task</th>
          <th>Input Length</th>
          <th>Liquid-S4</th>
          <th>S4-LegS</th>
          <th>Improvement</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ListOps</td>
          <td>2048</td>
          <td>62.75%</td>
          <td>59.60%</td>
          <td>+3.15%</td>
      </tr>
      <tr>
          <td>Text (IMDB)</td>
          <td>2048</td>
          <td>89.02%</td>
          <td>86.82%</td>
          <td>+2.20%</td>
      </tr>
      <tr>
          <td>Retrieval (AAN)</td>
          <td>4000</td>
          <td>91.20%</td>
          <td>90.90%</td>
          <td>+0.30%</td>
      </tr>
      <tr>
          <td>Image (CIFAR)</td>
          <td>1024</td>
          <td>89.50%</td>
          <td>88.65%</td>
          <td>+0.85%</td>
      </tr>
      <tr>
          <td>Pathfinder</td>
          <td>1024</td>
          <td>94.80%</td>
          <td>94.20%</td>
          <td>+0.60%</td>
      </tr>
      <tr>
          <td>Path-X</td>
          <td>16384</td>
          <td>96.66%</td>
          <td>96.35%</td>
          <td>+0.31%</td>
      </tr>
      <tr>
          <td><strong>Average</strong></td>
          <td></td>
          <td><strong>87.32%</strong></td>
          <td><strong>86.09%</strong></td>
          <td><strong>+1.23%</strong></td>
      </tr>
  </tbody>
</table>
<p>Liquid orders $p$ range from 2 to 6 across tasks.</p>
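<p>The averages and per-task gains in the table are plain arithmetic on the listed accuracies. The short sketch below recomputes them; the gains are absolute differences in percentage points, and the dictionary name is illustrative:</p>

```python
# (Liquid-S4, S4-LegS) accuracies in percent, copied from the table.
LRA_ACCURACY = {
    "ListOps": (62.75, 59.60),
    "Text (IMDB)": (89.02, 86.82),
    "Retrieval (AAN)": (91.20, 90.90),
    "Image (CIFAR)": (89.50, 88.65),
    "Pathfinder": (94.80, 94.20),
    "Path-X": (96.66, 96.35),
}

liquid_avg = sum(liquid for liquid, _ in LRA_ACCURACY.values()) / len(LRA_ACCURACY)
s4_avg = sum(s4 for _, s4 in LRA_ACCURACY.values()) / len(LRA_ACCURACY)
gains = {task: round(liquid - s4, 2) for task, (liquid, s4) in LRA_ACCURACY.items()}

print(f"Liquid-S4 average: {liquid_avg:.2f}")
print(f"S4-LegS average:   {s4_avg:.2f}")
print(gains)
```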
<h3 id="bidmc-vital-signs">BIDMC Vital Signs</h3>
<p>On medical time-series regression (heart rate, respiratory rate, <a href="https://en.wikipedia.org/wiki/Oxygen_saturation_(medicine)">SpO2</a> prediction from length-4000 biomarker signals):</p>
<table>
  <thead>
      <tr>
          <th>Task</th>
          <th>Liquid-S4 (RMSE)</th>
          <th>S4-LegS (RMSE)</th>
          <th>Improvement</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Heart Rate</td>
          <td>0.303</td>
          <td>0.332</td>
          <td>8.7%</td>
      </tr>
      <tr>
          <td>Respiratory Rate</td>
          <td>0.158</td>
          <td>0.247</td>
          <td>36.0%</td>
      </tr>
      <tr>
          <td>SpO2</td>
          <td>0.066</td>
          <td>0.090</td>
          <td>26.7%</td>
      </tr>
  </tbody>
</table>
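<p>The improvement column here is a relative RMSE reduction with respect to S4-LegS. The sketch below recomputes it from the two RMSE columns; <code>relative_reduction</code> is a hypothetical helper name:</p>

```python
# (Liquid-S4 RMSE, S4-LegS RMSE), copied from the table.
BIDMC_RMSE = {
    "Heart Rate": (0.303, 0.332),
    "Respiratory Rate": (0.158, 0.247),
    "SpO2": (0.066, 0.090),
}


def relative_reduction(new, baseline):
    """Percentage by which `new` lowers the error relative to `baseline`."""
    return 100.0 * (baseline - new) / baseline


for task, (liquid, s4) in BIDMC_RMSE.items():
    print(f"{task}: {relative_reduction(liquid, s4):.1f}%")
```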
<h3 id="sequential-cifar-scifar">Sequential CIFAR (sCIFAR)</h3>
<p>Liquid-S4 with $p=3$ achieves 92.02% accuracy on 1-D pixel-level image classification, improving over S4-LegS (91.80%).</p>
<h3 id="speech-commands-full-35-labels">Speech Commands (Full 35 Labels)</h3>
<p>On the raw 16kHz speech recognition task, Liquid-S4 achieves 96.78% accuracy with only 224K parameters, reported as a 30% reduction compared to S4&rsquo;s 307K (the two counts as listed differ by about 27%). On the zero-shot 8kHz experiment, performance drops to 90.00% (vs. 91.32% for S4-LegS), which the authors attribute to the liquid kernel&rsquo;s sensitivity to input covariance structure at different sampling rates.</p>
<h2 id="consistent-improvements-with-smaller-models">Consistent Improvements with Smaller Models</h2>
<p>Liquid-S4 achieves state-of-the-art performance on every benchmark evaluated: all six LRA tasks (87.32% average), all three BIDMC vital signs tasks, sCIFAR, and full Speech Commands recognition. The gains are particularly large on tasks where input correlation structure matters (ListOps +3.15%, IMDB +2.20%, respiratory rate RMSE improvement of 36%).</p>
<p>A practical advantage is that Liquid-S4 works well with smaller state sizes (as low as 7 units for some tasks), reducing parameter counts. The PB kernel is recommended over KB for its simplicity and competitive performance. Higher liquid orders ($p$) consistently improve performance, though $p=3$ is recommended as a default.</p>
<p>Limitations include degraded performance in zero-shot frequency transfer (8kHz Speech Commands), suggesting the liquid kernel&rsquo;s input covariance terms may not generalize well across sampling rate changes. The paper also does not compare against non-SSM approaches beyond the LRA benchmark. The causal (unidirectional) configuration works better than bidirectional for Liquid-S4, which may limit applicability to tasks that benefit from bidirectional context.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<p><strong>Classification: Partially Reproducible.</strong> Code and all benchmark datasets are publicly available, with complete hyperparameters documented. No pre-trained weights are released and hardware requirements are not specified.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/raminmh/liquid-s4">raminmh/liquid-s4</a></td>
          <td>Code</td>
          <td>Apache-2.0</td>
          <td>Official PyTorch implementation; fork of the S4 repo with KB and PB kernels added</td>
      </tr>
  </tbody>
</table>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Evaluation</td>
          <td>Long Range Arena (LRA)</td>
          <td>6 tasks, 1K-16K seq length</td>
          <td>ListOps, IMDB, AAN, CIFAR, Pathfinder, Path-X</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>BIDMC Vital Signs</td>
          <td>4000-length biomarker signals</td>
          <td>Heart rate, respiratory rate, SpO2</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>sCIFAR</td>
          <td>1024-length flattened images</td>
          <td>10-class classification</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>Speech Commands</td>
          <td>16kHz raw audio, 35 labels</td>
          <td>Full dataset with zero-shot 8kHz test</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<p>The Liquid-S4 kernel computation builds on the S4 kernel pipeline:</p>
<ol>
<li>Initialize $\mathbf{A}$ with HiPPO (scaled Legendre) matrix in DPLR form</li>
<li>Compute S4 kernel $\overline{\mathbf{K}}$ via Cauchy kernel and iFFT</li>
<li>For each liquid order $p \in \{2, \ldots, \mathcal{P}\}$, compute $\overline{\mathbf{K}}_{\text{liquid}=p}$ using either KB or PB mode</li>
<li>Convolve $\overline{\mathbf{K}}_{\text{liquid}}$ with input correlation vector $u_{\text{correlations}}$</li>
</ol>
<p>The PB kernel mode is used in all reported experiments. The PyKeops package is used for large tensor computations.</p>
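<p>To make the role of the standard kernel in step 2 concrete, the sketch below builds it naively for an already-discretized diagonal state matrix and checks that convolving with it matches the linear recurrence. This swaps in direct matrix powers and an $O(L^2)$ convolution for the Cauchy-kernel and iFFT route, covers only the non-liquid term, and uses hypothetical function names:</p>

```python
def ssm_kernel(A_diag, B, C, length):
    """K[j] = sum_n C[n] * A_diag[n]**j * B[n] for a diagonal state matrix."""
    return [
        sum(c * (a ** j) * b for a, b, c in zip(A_diag, B, C))
        for j in range(length)
    ]


def causal_conv(kernel, u):
    """y[k] = sum over j of kernel[j] * u[k - j], using past inputs only."""
    return [
        sum(kernel[j] * u[k - j] for j in range(k + 1))
        for k in range(len(u))
    ]


def ssm_recurrence(A_diag, B, C, u):
    """x_k = A x_{k-1} + B u_k, y_k = C x_k, starting from a zero state."""
    x = [0.0] * len(A_diag)
    ys = []
    for u_k in u:
        x = [a * x_n + b * u_k for a, x_n, b in zip(A_diag, x, B)]
        ys.append(sum(c * x_n for c, x_n in zip(C, x)))
    return ys


A_diag, B, C = [0.9, 0.5], [1.0, 0.3], [0.2, -1.0]
u = [1.0, 0.0, -2.0, 0.5, 3.0]
kernel = ssm_kernel(A_diag, B, C, len(u))
print(causal_conv(kernel, u))
print(ssm_recurrence(A_diag, B, C, u))
```

<p>Both lines print the same outputs, which is the property that lets S4-style models train as a convolution while remaining a recurrence.</p>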
<h3 id="models">Models</h3>
<table>
  <thead>
      <tr>
          <th>Task</th>
          <th>Depth</th>
          <th>Features</th>
          <th>State Size</th>
          <th>Norm</th>
          <th>LR</th>
          <th>Epochs</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ListOps</td>
          <td>9</td>
          <td>128</td>
          <td>7</td>
          <td>BN</td>
          <td>0.002</td>
          <td>30</td>
      </tr>
      <tr>
          <td>IMDB</td>
          <td>4</td>
          <td>128</td>
          <td>7</td>
          <td>BN</td>
          <td>0.003</td>
          <td>50</td>
      </tr>
      <tr>
          <td>AAN</td>
          <td>6</td>
          <td>256</td>
          <td>64</td>
          <td>BN</td>
          <td>0.005</td>
          <td>20</td>
      </tr>
      <tr>
          <td>CIFAR (LRA)</td>
          <td>6</td>
          <td>512</td>
          <td>512</td>
          <td>LN</td>
          <td>0.01</td>
          <td>200</td>
      </tr>
      <tr>
          <td>Pathfinder</td>
          <td>6</td>
          <td>256</td>
          <td>64</td>
          <td>BN</td>
          <td>0.0004</td>
          <td>200</td>
      </tr>
      <tr>
          <td>Path-X</td>
          <td>6</td>
          <td>320</td>
          <td>64</td>
          <td>BN</td>
          <td>0.001</td>
          <td>60</td>
      </tr>
      <tr>
          <td>Speech Commands</td>
          <td>6</td>
          <td>128</td>
          <td>7</td>
          <td>BN</td>
          <td>0.008</td>
          <td>50</td>
      </tr>
      <tr>
          <td>BIDMC (HR)</td>
          <td>6</td>
          <td>128</td>
          <td>256</td>
          <td>LN</td>
          <td>0.005</td>
          <td>500</td>
      </tr>
      <tr>
          <td>BIDMC (RR)</td>
          <td>6</td>
          <td>128</td>
          <td>256</td>
          <td>LN</td>
          <td>0.01</td>
          <td>500</td>
      </tr>
      <tr>
          <td>BIDMC (SpO2)</td>
          <td>6</td>
          <td>128</td>
          <td>256</td>
          <td>LN</td>
          <td>0.01</td>
          <td>500</td>
      </tr>
      <tr>
          <td>sCIFAR</td>
          <td>6</td>
          <td>512</td>
          <td>512</td>
          <td>LN</td>
          <td>0.01</td>
          <td>200</td>
      </tr>
  </tbody>
</table>
<p>Liquid-S4 generally requires smaller learning rates than S4/S4D. $\Delta t_{\text{max}} = 0.2$ for all experiments; $\Delta t_{\text{min}} \propto 1/\text{seq_length}$.</p>
<h3 id="evaluation">Evaluation</h3>
<p>All results report validation accuracy (except BIDMC, which reports test RMSE). Experiments use 2-3 random seeds with standard deviations reported.</p>
<h3 id="hardware">Hardware</h3>
<p>Not specified in the paper.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Hasani, R., Lechner, M., Wang, T.-H., Chahine, M., Amini, A., &amp; Rus, D. (2022). Liquid Structural State-Space Models. <em>arXiv preprint arXiv:2209.12951</em>.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@misc</span>{hasani2022liquid,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Liquid Structural State-Space Models}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Hasani, Ramin and Lechner, Mathias and Wang, Tsun-Hsuan and Chahine, Makram and Amini, Alexander and Rus, Daniela}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2022}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">eprint</span>=<span style="color:#e6db74">{2209.12951}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">archiveprefix</span>=<span style="color:#e6db74">{arXiv}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">primaryclass</span>=<span style="color:#e6db74">{cs.LG}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Lagrangian Neural Networks for Physics</title><link>https://hunterheidenreich.com/notes/machine-learning/model-architectures/lagrangian-neural-networks/</link><pubDate>Tue, 07 Apr 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/model-architectures/lagrangian-neural-networks/</guid><description>LNNs parameterize arbitrary Lagrangians with neural networks, learning energy-conserving dynamics without requiring canonical coordinates.</description><content:encoded><![CDATA[<h2 id="a-method-for-learning-arbitrary-lagrangians">A Method for Learning Arbitrary Lagrangians</h2>
<p>This is a <strong>Method</strong> paper that introduces Lagrangian Neural Networks (LNNs), a neural network architecture that parameterizes arbitrary Lagrangians to learn energy-conserving dynamics from data. The key contribution is showing that neural networks can learn Lagrangian functions directly, and that the Euler-Lagrange equation can be solved numerically using automatic differentiation to produce physically consistent dynamics. The approach is strictly more general than prior methods: it does not require canonical coordinates (unlike Hamiltonian Neural Networks) and does not restrict the functional form of kinetic energy (unlike Deep Lagrangian Networks).</p>
<h2 id="why-standard-neural-networks-fail-at-conservation-laws">Why Standard Neural Networks Fail at Conservation Laws</h2>
<p>Neural networks struggle to learn fundamental symmetries and conservation laws from data. A standard neural network trained on trajectories of a <a href="https://en.wikipedia.org/wiki/Double_pendulum">double pendulum</a> will gradually dissipate energy over long rollouts, producing physically implausible behavior. This happens because unconstrained function approximators have no inductive bias toward conservation.</p>
<p>Hamiltonian Neural Networks (HNNs) addressed this by learning a Hamiltonian function, which automatically enforces energy conservation. However, the <a href="https://en.wikipedia.org/wiki/Hamiltonian_mechanics">Hamiltonian formalism</a> requires inputs in <a href="https://en.wikipedia.org/wiki/Canonical_coordinates">canonical coordinates</a> $(q, p)$ satisfying strict <a href="https://en.wikipedia.org/wiki/Poisson_bracket">Poisson bracket</a> relations:</p>
<p>$$
p_i \equiv \frac{\partial \mathcal{L}}{\partial \dot{q}_i} \quad \Longleftrightarrow \quad \{q_i, q_j\} = 0, \quad \{p_i, p_j\} = 0, \quad \{q_i, p_j\} = \delta_{ij}
$$</p>
<p>In many real-world settings, the canonical momenta are unknown or difficult to compute. For example, in special relativity the canonical momentum $\dot{q}(1 - \dot{q}^2)^{-3/2}$ is a complex nonlinear function of velocity. Deep Lagrangian Networks (DeLaNs) partially addressed this by learning Lagrangians, but they assumed kinetic energy takes the rigid-body form $T = \dot{q}^T M \dot{q}$, which excludes relativistic and other non-standard systems.</p>
<h2 id="solving-euler-lagrange-for-a-black-box-lagrangian">Solving Euler-Lagrange for a Black-Box Lagrangian</h2>
<p>The core innovation of LNNs is a method for computing accelerations from a neural network that represents an arbitrary Lagrangian $\mathcal{L}(q, \dot{q})$. Starting from the <a href="https://en.wikipedia.org/wiki/Euler%E2%80%93Lagrange_equation">Euler-Lagrange equation</a>:</p>
<p>$$
\frac{d}{dt} \nabla_{\dot{q}} \mathcal{L} = \nabla_{q} \mathcal{L}
$$</p>
<p>The authors expand the time derivative using the chain rule, yielding:</p>
<p>$$
\left(\nabla_{\dot{q}} \nabla_{\dot{q}}^{\top} \mathcal{L}\right) \ddot{q} + \left(\nabla_{q} \nabla_{\dot{q}}^{\top} \mathcal{L}\right) \dot{q} = \nabla_{q} \mathcal{L}
$$</p>
<p>Solving for the accelerations gives:</p>
<p>$$
\ddot{q} = \left(\nabla_{\dot{q}} \nabla_{\dot{q}}^{\top} \mathcal{L}\right)^{-1} \left[ \nabla_{q} \mathcal{L} - \left(\nabla_{q} \nabla_{\dot{q}}^{\top} \mathcal{L}\right) \dot{q} \right]
$$</p>
<p>This requires computing the Hessian of the neural network with respect to $\dot{q}$ and then inverting it (using a pseudoinverse for numerical stability). JAX&rsquo;s automatic differentiation makes this feasible in just a few lines of code, despite the seemingly complex chain of second-order derivatives. The matrix inverse scales as $\mathcal{O}(d^3)$ with the number of coordinates $d$.</p>
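<p>The acceleration formula above can be sketched numerically. The snippet below is a minimal illustration, not the authors' code: it swaps autodiff for central finite differences in NumPy so that it runs without JAX (in JAX the same derivatives would come from autodiff, for example <code>jax.grad</code> and <code>jax.hessian</code>). The helper names <code>finite_diff_jacobian</code> and <code>lnn_acceleration</code>, the step size, and the coupled-oscillator Lagrangian are assumptions made for this example.</p>

```python
import numpy as np

def finite_diff_jacobian(f, x, eps=1e-3):
    """Central-difference Jacobian J[i, j] = d f_i / d x_j (f may return a scalar)."""
    columns = []
    for j in range(x.size):
        step = np.zeros_like(x)
        step[j] = eps
        forward = np.atleast_1d(f(x + step))
        backward = np.atleast_1d(f(x - step))
        columns.append((forward - backward) / (2 * eps))
    return np.stack(columns, axis=1)

def lnn_acceleration(lagrangian, q, q_dot):
    """Solve the expanded Euler-Lagrange equation for the accelerations."""
    def grad_q(x, v):
        # Gradient of the scalar Lagrangian with respect to positions.
        return finite_diff_jacobian(lambda y: lagrangian(y, v), x)[0]

    def grad_v(x, v):
        # Gradient of the scalar Lagrangian with respect to velocities.
        return finite_diff_jacobian(lambda w: lagrangian(x, w), v)[0]

    # Velocity Hessian: entry [i, j] = d^2 L / (d q_dot_i d q_dot_j).
    hess_vv = finite_diff_jacobian(lambda w: grad_v(q, w), q_dot)
    # Mixed matrix: entry [i, j] = d^2 L / (d q_dot_i d q_j).
    mixed_vq = finite_diff_jacobian(lambda y: grad_v(y, q_dot), q)
    # Pseudoinverse, as described in the text, tolerates a singular Hessian.
    return np.linalg.pinv(hess_vv) @ (grad_q(q, q_dot) - mixed_vq @ q_dot)

# Hypothetical test system: two coupled oscillators with L = T - V.
M = np.array([[2.0, 0.5], [0.5, 1.0]])
K = np.array([[3.0, 1.0], [1.0, 2.0]])
coupled_oscillator = lambda q, v: 0.5 * v @ M @ v - 0.5 * q @ K @ q

q = np.array([0.3, -0.2])
q_dot = np.array([0.1, 0.4])
q_ddot = lnn_acceleration(coupled_oscillator, q, q_dot)
```

<p>For this quadratic Lagrangian the result should agree with the analytic accelerations $-M^{-1} K q$. In an LNN the hand-written Lagrangian would be replaced by the neural network.</p>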
<p>A critical implementation detail is the choice of activation function. Because the method takes second-order derivatives of the network, ReLU is unsuitable: its second derivative is zero wherever it is defined, so a ReLU network carries no curvature information for the Hessian. After a hyperparameter search over ReLU$^2$, ReLU$^3$, tanh, sigmoid, and softplus, the authors found that <a href="https://en.wikipedia.org/wiki/Softplus">softplus</a> performed best.</p>
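<p>The point about second derivatives can be checked with a few lines of plain Python. This is only an illustrative sketch: the finite-difference helper and the sample points are choices made here, not part of the paper.</p>

```python
import math

def relu(x):
    return max(0.0, x)

def softplus(x):
    return math.log1p(math.exp(x))

def second_derivative(f, x, eps=1e-3):
    """Central-difference estimate of f''(x)."""
    return (f(x + eps) - 2.0 * f(x) + f(x - eps)) / eps**2

# Compare curvature away from the ReLU kink at x = 0.
curvatures = {
    x: (second_derivative(relu, x), second_derivative(softplus, x))
    for x in (-1.0, 0.5, 2.0)
}
```

<p>The ReLU entries are zero up to rounding, while the softplus entries follow $\sigma(x)(1 - \sigma(x))$, where $\sigma$ is the logistic sigmoid, and so stay strictly positive.</p>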
<p>The authors also developed a custom initialization scheme, using symbolic regression to find initialization variances that maintain well-conditioned gradients through the Hessian computation:</p>
<p>$$
\sigma = \frac{1}{\sqrt{n}} \begin{cases} 2.2 &amp; \text{First layer} \\ 0.58i &amp; \text{Hidden layer } i \\ n &amp; \text{Output layer} \end{cases}
$$</p>
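<p>As a rough sketch, the piecewise rule above can be written as a small function. Two readings are assumed here and should be checked against the paper: $n$ is taken to be the hidden width, and hidden layer $i$ is counted from 1 starting after the first layer. The function name <code>lnn_init_std</code> is hypothetical.</p>

```python
import math

def lnn_init_std(layer, num_layers, width):
    """Standard deviation of the initial weights for a 0-based layer index."""
    if layer == 0:
        scale = 2.2                 # first layer
    elif layer == num_layers - 1:
        scale = float(width)        # output layer
    else:
        scale = 0.58 * layer        # hidden layer i uses 0.58 * i (assumed 1-based)
    return scale / math.sqrt(width)

# Example: the 4-layer, 500-unit MLP used in the experiments.
stds = [lnn_init_std(layer, num_layers=4, width=500) for layer in range(4)]
```

<p>Under these assumptions the output layer receives a much larger standard deviation ($\sqrt{n}$) than the earlier layers.</p>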
<h2 id="extension-to-graphs-and-continuous-systems">Extension to Graphs and Continuous Systems</h2>
<p>LNNs extend naturally to graph-structured and continuous systems via Lagrangian <a href="/notes/machine-learning/model-architectures/relational-inductive-biases-deep-learning-graph-networks/">Graph Networks</a>. For a system with $n$ gridpoints, the total Lagrangian is decomposed into local densities:</p>
<p>$$
\mathcal{L} = \sum_{i=1}^{n} \mathcal{L}_i, \quad \text{where} \quad \mathcal{L}_i = \mathcal{L}_{\text{density}}\left(\{\phi_j, \dot{\phi}_j\}_{j \in \mathcal{I}_i}\right)
$$</p>
<p>Here $\mathcal{I}_i$ defines the neighborhood of node $i$ (e.g., $\{i-1, i, i+1\}$ for a 1D grid). The Lagrangian density is modeled as an MLP. The resulting Hessian matrix is sparse, with non-zero entries only at &ldquo;neighbor of neighbor&rdquo; positions, enabling efficient computation: in 1D, only 5 forward-over-backward autodiff passes are needed, and the tridiagonal inverse runs in linear time.</p>
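<p>The &ldquo;neighbor of neighbor&rdquo; sparsity can be illustrated without any autodiff. The sketch below assumes a periodic 1D grid with three-point neighborhoods and marks which pairs of grid values appear together in at least one local density term, since only those pairs can give a non-zero second derivative of the summed Lagrangian. The function names and the grid size of 8 are arbitrary choices.</p>

```python
def neighborhood(i, n):
    """Indices feeding the local density L_i on a periodic 1D grid."""
    return {(i - 1) % n, i, (i + 1) % n}

def hessian_sparsity(n):
    """pattern[a][b] is True when grid values a and b share some local term L_i."""
    pattern = [[False] * n for _ in range(n)]
    for i in range(n):
        for a in neighborhood(i, n):
            for b in neighborhood(i, n):
                pattern[a][b] = True
    return pattern

pattern = hessian_sparsity(8)
nonzeros_per_row = [sum(row) for row in pattern]
```

<p>Each row has five structurally non-zero entries (offsets $0, \pm 1, \pm 2$), which is consistent with the five autodiff passes quoted for the 1D case.</p>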
<h2 id="experiments-double-pendulum-relativity-and-waves">Experiments: Double Pendulum, Relativity, and Waves</h2>
<p>All models used 4-layer MLPs with 500 hidden units, softplus activations, a decaying learning rate starting at $10^{-3}$, and batch size 32.</p>
<h3 id="double-pendulum">Double Pendulum</h3>
<p>The LNN and baseline achieved similar instantaneous acceleration losses ($7.3 \times 10^{-2}$ vs. $7.4 \times 10^{-2}$). The key difference appeared in long-term energy conservation: averaged over 40 random initial conditions, each rolled out for 100 time steps, the mean energy discrepancy was 8% of the maximum potential energy for the baseline but only 0.4% for the LNN.</p>
<h3 id="relativistic-particle">Relativistic Particle</h3>
<p>For a particle with Lagrangian $\mathcal{L} = ((1 - \dot{q}^2)^{-1/2} - 1) + gq$, the canonical momenta $\dot{q}(1 - \dot{q}^2)^{-3/2}$ are non-trivial. An HNN trained on non-canonical coordinates $(q, \dot{q})$ failed to learn the dynamics. The LNN succeeded using the same non-canonical coordinates, matching the performance of an HNN given the correct canonical coordinates.</p>
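<p>The closed-form momentum quoted above follows from $p = \partial \mathcal{L} / \partial \dot{q}$. A short numerical check, using a finite difference and an arbitrarily chosen velocity of $0.6$ (in units where $c = 1$), is sketched below. The function names and the value of $g$ are illustrative.</p>

```python
import math

def lagrangian(q, q_dot, g=1.0):
    """Relativistic-particle Lagrangian as written in the text (c = 1, mass = 1)."""
    return (1.0 / math.sqrt(1.0 - q_dot**2) - 1.0) + g * q

def canonical_momentum(q, q_dot, eps=1e-6):
    """Central-difference estimate of dL/d(q_dot)."""
    return (lagrangian(q, q_dot + eps) - lagrangian(q, q_dot - eps)) / (2 * eps)

q_dot = 0.6
p_numeric = canonical_momentum(0.0, q_dot)
p_closed_form = q_dot * (1.0 - q_dot**2) ** -1.5
```

<p>The two values agree to within finite-difference error. The momentum grows without bound as $\dot{q} \to 1$, which is why $(q, \dot{q})$ is far from a canonical pair here.</p>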
<h3 id="1d-wave-equation">1D Wave Equation</h3>
<p>The Lagrangian Graph Network learned the wave equation dynamics ($\ddot{\phi} = \frac{\partial^2 \phi}{\partial x^2}$ with $c = 1$) on a 100-gridpoint domain with periodic boundary conditions. The network learned the Lagrangian density corresponding to the continuum form $\mathcal{L} = \int (\dot{\phi}^2 - (\partial \phi / \partial x)^2) dx$, accurately modeling wave propagation and conserving energy across the material.</p>
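<p>For reference, the target dynamics can be written down directly. Discretizing the continuum Lagrangian as $\sum_i \left[\tfrac{1}{2}\dot{\phi}_i^2 - \tfrac{1}{2}\left((\phi_{i+1} - \phi_i)/\Delta x\right)^2\right] \Delta x$ (an overall factor of $\tfrac{1}{2}$ does not change the equations of motion) and applying Euler-Lagrange gives the standard three-point stencil. The NumPy sketch below assumes a unit-length periodic domain; it shows the ground-truth update, not the learned model.</p>

```python
import numpy as np

def wave_acceleration(phi, dx):
    """Discrete Laplacian with periodic boundaries: phi_ddot for c = 1."""
    return (np.roll(phi, -1) - 2.0 * phi + np.roll(phi, 1)) / dx**2

n = 100                       # gridpoints, as in the experiment
dx = 1.0 / n                  # assumed domain length of 1
x = np.arange(n) * dx
phi = np.sin(2 * np.pi * x)   # a single Fourier mode as a test field
phi_ddot = wave_acceleration(phi, dx)
```

<p>For this mode the stencil should reproduce $-(2\pi)^2 \phi$ up to a small discretization error.</p>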
<table>
  <thead>
      <tr>
          <th>Experiment</th>
          <th>Model</th>
          <th>Energy Error (% of max PE)</th>
          <th>Canonical Coords Required</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Double Pendulum</td>
          <td>Baseline</td>
          <td>8%</td>
          <td>N/A</td>
      </tr>
      <tr>
          <td>Double Pendulum</td>
          <td>LNN</td>
          <td>0.4%</td>
          <td>No</td>
      </tr>
      <tr>
          <td>Relativistic Particle</td>
          <td>HNN (non-canonical)</td>
          <td>Failed</td>
          <td>Yes</td>
      </tr>
      <tr>
          <td>Relativistic Particle</td>
          <td>HNN (canonical)</td>
          <td>Succeeded</td>
          <td>Yes</td>
      </tr>
      <tr>
          <td>Relativistic Particle</td>
          <td>LNN</td>
          <td>Succeeded</td>
          <td>No</td>
      </tr>
      <tr>
          <td>1D Wave Equation</td>
          <td>LGN</td>
          <td>Energy conserved</td>
          <td>No</td>
      </tr>
  </tbody>
</table>
<h2 id="findings-and-comparison-to-prior-approaches">Findings and Comparison to Prior Approaches</h2>
<p>LNNs combine several desirable properties that no single prior method offers:</p>
<table>
  <thead>
      <tr>
          <th>Property</th>
          <th>Neural Net</th>
          <th>Neural ODE</th>
          <th>HNN</th>
          <th>DeLaN</th>
          <th>LNN</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Models dynamical systems</td>
          <td>Yes</td>
          <td>Yes</td>
          <td>Yes</td>
          <td>Yes</td>
          <td>Yes</td>
      </tr>
      <tr>
          <td>Learns differential equations</td>
          <td></td>
          <td>Yes</td>
          <td>Yes</td>
          <td>Yes</td>
          <td>Yes</td>
      </tr>
      <tr>
          <td>Learns exact conservation laws</td>
          <td></td>
          <td></td>
          <td>Yes</td>
          <td>Yes</td>
          <td>Yes</td>
      </tr>
      <tr>
          <td>Learns from arbitrary coordinates</td>
          <td>Yes</td>
          <td>Yes</td>
          <td></td>
          <td>Yes</td>
          <td>Yes</td>
      </tr>
      <tr>
          <td>Learns arbitrary Lagrangians</td>
          <td></td>
          <td></td>
          <td></td>
          <td></td>
          <td>Yes</td>
      </tr>
  </tbody>
</table>
<p>The main limitation is computational cost: the Hessian computation and inversion scale as $\mathcal{O}(d^3)$ in the number of coordinates. The Lagrangian Graph Network partially mitigates this for spatially extended systems through the sparsity of the resulting Hessian. The method also assumes access to state derivatives ($\dot{q}$) during training, which may not always be directly available from observations.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training</td>
          <td>Double pendulum</td>
          <td>600,000 random initial conditions</td>
          <td>Simulated with masses and lengths set to 1</td>
      </tr>
      <tr>
          <td>Training</td>
          <td>Relativistic particle</td>
          <td>Random initial conditions and $g$ values</td>
          <td>$c = 1$, mass = 1, uniform potential</td>
      </tr>
      <tr>
          <td>Training</td>
          <td>1D wave equation</td>
          <td>100 gridpoints</td>
          <td>Periodic boundary conditions, $c = 1$</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Forward model: Euler-Lagrange equation solved via Equation 6 using JAX autodiff</li>
<li>Pseudoinverse used for Hessian inversion to handle potential singular matrices</li>
<li>Custom initialization scheme (Equation 16) derived via symbolic regression with Eureqa</li>
<li>Softplus activation selected via hyperparameter search</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li>4-layer MLP with 500 hidden units for all experiments</li>
<li>Softplus activation function</li>
<li>Code: <a href="https://github.com/MilesCranmer/lagrangian_nns">github.com/MilesCranmer/lagrangian_nns</a> (Apache-2.0)</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>LNN</th>
          <th>Baseline</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Acceleration loss (double pendulum)</td>
          <td>$7.3 \times 10^{-2}$</td>
          <td>$7.4 \times 10^{-2}$</td>
          <td>Similar short-term accuracy</td>
      </tr>
      <tr>
          <td>Energy error (double pendulum)</td>
          <td>0.4%</td>
          <td>8%</td>
          <td>Percentage of max potential energy</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Not specified in the paper. JAX-based implementation supports CPU and GPU execution.</p>
<hr>
<p><strong>Reproducibility Status</strong>: Highly Reproducible</p>
<h2 id="artifacts">Artifacts</h2>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/MilesCranmer/lagrangian_nns">lagrangian_nns</a></td>
          <td>Code</td>
          <td>Apache-2.0</td>
          <td>Official JAX implementation with notebooks for all experiments</td>
      </tr>
      <tr>
          <td>Training data</td>
          <td>Dataset</td>
          <td>N/A</td>
          <td>Generated procedurally; simulation code included in repository</td>
      </tr>
      <tr>
          <td>Trained models</td>
          <td>Model</td>
          <td>N/A</td>
          <td>Not provided</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Cranmer, M., Greydanus, S., Hoyer, S., Battaglia, P., Spergel, D., &amp; Ho, S. (2020). Lagrangian Neural Networks. <em>ICLR 2020 Workshop on Integration of Deep Neural Models and Differential Equations</em>. arXiv: <a href="https://arxiv.org/abs/2003.04630">2003.04630</a></p>
<p><strong>Publication</strong>: ICLR 2020 Workshop on Integration of Deep Neural Models and Differential Equations</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@misc</span>{cranmer2020lagrangian,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Lagrangian Neural Networks}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Cranmer, Miles and Greydanus, Sam and Hoyer, Stephan and Battaglia, Peter and Spergel, David and Ho, Shirley}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2020}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">eprint</span>=<span style="color:#e6db74">{2003.04630}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">archiveprefix</span>=<span style="color:#e6db74">{arXiv}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">primaryclass</span>=<span style="color:#e6db74">{cs.LG}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Ewald Message Passing for Molecular Graphs</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-simulation/ml-potentials/ewald-message-passing-molecular-graphs/</link><pubDate>Tue, 07 Apr 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-simulation/ml-potentials/ewald-message-passing-molecular-graphs/</guid><description>Ewald message passing augments GNNs with Fourier-space long-range interactions, improving energy predictions by 10-16% on OC20 and OE62 benchmarks.</description><content:encoded><![CDATA[<h2 id="a-fourier-space-long-range-correction-for-molecular-gnns">A Fourier-Space Long-Range Correction for Molecular GNNs</h2>
<p>This is a <strong>Method</strong> paper that introduces Ewald message passing (Ewald MP), a general framework for incorporating long-range interactions into message passing neural networks (MPNNs) for molecular <a href="/notes/chemistry/molecular-simulation/ml-potentials/learning-smooth-interatomic-potentials/">potential energy surface</a> prediction. The key contribution is a nonlocal Fourier-space message passing scheme, grounded in the classical <a href="https://en.wikipedia.org/wiki/Ewald_summation">Ewald summation</a> technique from computational physics, that complements the short-range message passing of existing GNN architectures.</p>
<h2 id="the-long-range-interaction-problem-in-molecular-gnns">The Long-Range Interaction Problem in Molecular GNNs</h2>
<p>Standard MPNNs for molecular property prediction rely on a spatial distance cutoff to define atomic neighborhoods. While this locality assumption enables favorable scaling with system size and provides a useful inductive bias, it fundamentally limits the model&rsquo;s ability to capture long-range interactions such as electrostatic forces and van der Waals (<a href="https://en.wikipedia.org/wiki/London_dispersion_force">London dispersion</a>) interactions. These interactions decay slowly with distance (e.g., electrostatic energy follows a $1/r$ power law), and truncating them with a distance cutoff can introduce severe artifacts in thermochemical predictions.</p>
<p>This problem is well-known in molecular dynamics, where empirical force fields explicitly separate bonded (short-range) and non-bonded (long-range) energy terms. The Ewald summation technique addresses this by decomposing interactions into a short-range part that converges quickly with a distance cutoff and a long-range part whose Fourier transform converges quickly with a frequency cutoff. The authors propose bringing this same strategy into the GNN paradigm.</p>
<h2 id="from-ewald-summation-to-learnable-fourier-space-messages">From Ewald Summation to Learnable Fourier-Space Messages</h2>
<p>The core insight is a formal analogy between the continuous-filter convolution used in MPNNs and the electrostatic potential computation in Ewald summation. In a standard continuous-filter convolution, the message sum for atom $i$ is:</p>
<p>$$
M_i^{(l+1)} = \sum_{j \in \mathcal{N}(i)} h_j^{(l)} \cdot \Phi^{(l)}(| \mathbf{x}_i - \mathbf{x}_j |)
$$</p>
<p>where $h_j^{(l)}$ are atom embeddings and $\Phi^{(l)}$ is a learned radial filter. Comparing this to the electrostatic potential $V_i^{\text{es}}(\mathbf{x}_i) = \sum_{j \neq i} q_j \cdot \Phi^{\text{es}}(| \mathbf{x}_i - \mathbf{x}_j |)$ reveals a direct correspondence: atom embeddings play the role of partial charges, and learned filters replace the $1/r$ kernel.</p>
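<p>A minimal NumPy sketch of the cutoff-based message sum is given below. It simplifies the filter to a scalar function of distance shared across feature channels, and uses a fixed Gaussian as a stand-in for the learned $\Phi^{(l)}$. Both choices, along with the function name, are assumptions made for illustration.</p>

```python
import numpy as np

def short_range_messages(h, x, radial_filter, cutoff):
    """M_i = sum over neighbors j within the cutoff of h_j * Phi(|x_i - x_j|)."""
    dist = np.linalg.norm(x[:, None, :] - x[None, :, :], axis=-1)   # (N, N)
    is_neighbor = (dist < cutoff) & ~np.eye(len(x), dtype=bool)     # exclude j = i
    weights = np.where(is_neighbor, radial_filter(dist), 0.0)       # (N, N)
    return weights @ h                                              # (N, F)

# Three atoms on a line; the third lies outside the cutoff of the other two.
x = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [5.0, 0.0, 0.0]])
h = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
gaussian_filter = lambda r: np.exp(-r**2)   # hypothetical stand-in for a learned filter
messages = short_range_messages(h, x, gaussian_filter, cutoff=2.0)
```

<p>The isolated third atom receives a zero message, which is the truncation effect that motivates the long-range term.</p>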
<p>Ewald MP decomposes the learned filter into short-range and long-range components. The short-range part is handled by any existing GNN architecture with a distance cutoff. The long-range part is computed as a sum over Fourier frequencies:</p>
<p>$$
M^{\text{lr}}(\mathbf{x}_i) = \sum_{\mathbf{k}} \exp(i \mathbf{k}^T \mathbf{x}_i) \cdot s_{\mathbf{k}} \cdot \hat{\Phi}^{\text{lr}}(| \mathbf{k} |)
$$</p>
<p>where $s_{\mathbf{k}}$ are <strong><a href="https://en.wikipedia.org/wiki/Structure_factor">structure factor</a> embeddings</strong>, computed as:</p>
<p>$$
s_{\mathbf{k}} = \sum_{j \in \mathcal{S}} h_j \exp(-i \mathbf{k}^T \mathbf{x}_j)
$$</p>
<p>These structure factor embeddings are a Fourier-space representation of the atom embedding distribution, and truncating to low frequencies effectively coarse-grains the hidden model state while preserving long-range information. The frequency filters $\hat{\Phi}^{\text{lr}}$ are learned, making the entire scheme data-driven rather than tied to a fixed physical functional form.</p>
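<p>The two Fourier-space sums above map onto two matrix products. The sketch below is an illustrative NumPy version with a fixed Gaussian in place of the learned frequency filter and a small symmetric grid of frequency vectors. The box length, grid size, random inputs, and function names are all assumptions.</p>

```python
import itertools
import numpy as np

def structure_factors(h, x, k_vectors):
    """s_k = sum_j h_j * exp(-i k.x_j); result has shape (N_k, F)."""
    return np.exp(-1j * (k_vectors @ x.T)) @ h

def long_range_messages(h, x, k_vectors, filter_hat):
    """M_lr(x_i) = sum_k exp(i k.x_i) * s_k * filter_hat(|k|); shape (N_at, F)."""
    s_k = structure_factors(h, x, k_vectors)
    weights = filter_hat(np.linalg.norm(k_vectors, axis=1))       # (N_k,)
    return np.exp(1j * (x @ k_vectors.T)) @ (weights[:, None] * s_k)

rng = np.random.default_rng(0)
box_length = 10.0                                   # hypothetical cubic cell
x = rng.uniform(0.0, box_length, size=(6, 3))       # atom positions
h = rng.normal(size=(6, 4))                         # atom embeddings
k_vectors = (2 * np.pi / box_length) * np.array(
    list(itertools.product((-1, 0, 1), repeat=3)), dtype=float
)
filter_hat = lambda k: np.exp(-k**2)                # stand-in for a learned filter
messages = long_range_messages(h, x, k_vectors, filter_hat)
```

<p>Each product costs $\mathcal{O}(N_{\text{at}} N_{\text{k}})$ per feature channel, and no pairwise distance matrix is formed. Because the frequency set here is symmetric under $\mathbf{k} \to -\mathbf{k}$ and the embeddings are real, the imaginary part of the result vanishes up to rounding.</p>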
<p>The method handles both <strong>periodic</strong> systems (where the <a href="https://en.wikipedia.org/wiki/Reciprocal_lattice">reciprocal lattice</a> provides a natural frequency discretization) and <strong>aperiodic</strong> systems (where the Fourier domain is discretized using a cubic voxel grid with SVD-based rotation alignment to preserve rotation invariance). The combined embedding update becomes:</p>
<p>$$
h_i^{(l+1)} = \frac{1}{\sqrt{3}} \left[ h_i^{(l)} + f_{\text{upd}}^{\text{sr}}(M_i^{\text{sr}}) + f_{\text{upd}}^{\text{lr}}(M_i^{\text{lr}}) \right]
$$</p>
<p>The computational complexity is $\mathcal{O}(N_{\text{at}} N_{\text{k}})$, and by fixing the number of frequency vectors $N_{\text{k}}$, linear scaling $\mathcal{O}(N_{\text{at}})$ is achievable.</p>
<h2 id="experiments-across-four-gnn-architectures-and-two-datasets">Experiments Across Four GNN Architectures and Two Datasets</h2>
<p>The authors test Ewald MP as an augmentation on four baseline architectures: <a href="/notes/chemistry/datasets/marcel/">SchNet, PaiNN, DimeNet++, and GemNet-T</a>. Two datasets are used:</p>
<ul>
<li><strong>OC20</strong> (Chanussot et al., 2021): ~265M periodic structures of adsorbate-catalyst systems with DFT-computed energies and forces. The OC20-2M subsplit is used for training.</li>
<li><strong>OE62</strong> (Stuke et al., 2020): ~62,000 large aperiodic organic molecules with DFT-computed energies that include a DFT-D3 dispersion correction for London dispersion interactions.</li>
</ul>
<p>All baselines use a 6 Å distance cutoff and 50 maximum neighbors. The Ewald modification is minimal: the long-range message sum is added as an additional skip connection term in each interaction block. Comparison studies include: (1) increasing the distance cutoff to match the computational cost of Ewald MP, (2) replacing the Ewald block with a SchNet interaction block at increased cutoff, and (3) increasing atom embedding dimensions to match Ewald MP&rsquo;s parameter count.</p>
<h3 id="key-energy-mae-results-on-oe62">Key Energy MAE Results on OE62</h3>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Baseline (meV)</th>
          <th>Ewald MP (meV)</th>
          <th>Improvement</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>SchNet</td>
          <td>133.5</td>
          <td>79.2</td>
          <td>40.7%</td>
      </tr>
      <tr>
          <td>PaiNN</td>
          <td>61.4</td>
          <td>57.9</td>
          <td>5.7%</td>
      </tr>
      <tr>
          <td>DimeNet++</td>
          <td>51.2</td>
          <td>46.5</td>
          <td>9.2%</td>
      </tr>
      <tr>
          <td>GemNet-T</td>
          <td>51.5</td>
          <td>47.4</td>
          <td>8.0%</td>
      </tr>
  </tbody>
</table>
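<p>The improvement column is the relative reduction in MAE with respect to the baseline. A short sketch of the arithmetic, using the values from the table above, is:</p>

```python
oe62_mae_mev = {
    "SchNet": (133.5, 79.2),
    "PaiNN": (61.4, 57.9),
    "DimeNet++": (51.2, 46.5),
    "GemNet-T": (51.5, 47.4),
}

def relative_improvement(baseline, ewald):
    """Percentage reduction in MAE relative to the baseline."""
    return 100.0 * (baseline - ewald) / baseline

improvements = {
    name: round(relative_improvement(baseline, ewald), 1)
    for name, (baseline, ewald) in oe62_mae_mev.items()
}
```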
<h3 id="key-energy-mae-results-on-oc20-averaged-across-test-splits">Key Energy MAE Results on OC20 (Averaged Across Test Splits)</h3>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Baseline (meV)</th>
          <th>Ewald MP (meV)</th>
          <th>Improvement</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>SchNet</td>
          <td>895</td>
          <td>830</td>
          <td>7.3%</td>
      </tr>
      <tr>
          <td>PaiNN</td>
          <td>448</td>
          <td>393</td>
          <td>12.3%</td>
      </tr>
      <tr>
          <td>DimeNet++</td>
          <td>496</td>
          <td>445</td>
          <td>10.4%</td>
      </tr>
      <tr>
          <td>GemNet-T</td>
          <td>346</td>
          <td>307</td>
          <td>11.3%</td>
      </tr>
  </tbody>
</table>
<h2 id="robust-long-range-improvements-and-dispersion-recovery">Robust Long-Range Improvements and Dispersion Recovery</h2>
<p>Ewald MP achieves consistent improvements across all models and both datasets, averaging 16.1% on OE62 and 10.3% on OC20. Several findings stand out:</p>
<ol>
<li>
<p><strong>Robustness</strong>: Unlike the increased-cutoff and SchNet-LR alternatives, Ewald MP never produces detrimental effects in any tested configuration. The increased cutoff setting hurts SchNet and PaiNN on OE62, and the SchNet-LR block fails to improve DimeNet++ and GemNet-T.</p>
</li>
<li>
<p><strong>Long-range specificity</strong>: A binning analysis on OE62 groups molecules by the magnitude of their DFT-D3 dispersion correction. Ewald MP shows an outsize improvement for structures with large long-range energy contributions. It recovers or surpasses a &ldquo;cheating&rdquo; baseline that receives the exact DFT-D3 ground truth as an additional input.</p>
</li>
<li>
<p><strong>Efficiency on periodic systems</strong>: Ewald MP achieves similar relative improvements on OC20 at roughly half the relative computational cost compared to OE62, suggesting periodic structures as a particularly attractive application domain.</p>
</li>
<li>
<p><strong>Force predictions</strong>: Improvements in <a href="/notes/chemistry/molecular-simulation/ml-potentials/dark-side-of-forces/">force MAEs</a> are consistent but small, which is expected since the frequency truncation removes high-frequency contributions to the potential energy surface.</p>
</li>
<li>
<p><strong>Ablation studies</strong>: Results are robust across different frequency cutoffs, voxel resolutions, and filtering strategies, with the non-radial periodic filtering scheme outperforming radial alternatives on out-of-distribution generalization.</p>
</li>
</ol>
<p>Limitations include the current focus on scalar (invariant) embeddings only (PaiNN&rsquo;s equivariant vector embeddings are not augmented), and the potential for a &ldquo;gap&rdquo; of medium-range interactions when $N_{\text{k}}$ is fixed for linear scaling. The authors suggest adapting more efficient Ewald summation variants (e.g., particle mesh Ewald with $\mathcal{O}(N \log N)$ scaling) as future work.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training (periodic)</td>
          <td>OC20-2M</td>
          <td>~2M structures</td>
          <td>Subsplit of OC20; PBC; DFT energies and forces</td>
      </tr>
      <tr>
          <td>Training (aperiodic)</td>
          <td>OE62</td>
          <td>~62,000 molecules</td>
          <td>Large organic molecules; DFT energies with D3 correction</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>OC20-test (4 splits: ID, OOD-ads, OOD-cat, OOD-both)</td>
          <td>Varies</td>
          <td>Evaluated via submission to OC20 evaluation server</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>OE62-val, OE62-test</td>
          <td>~6,000 each</td>
          <td>Direct evaluation</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Ewald message passing is integrated as an additional skip connection term in each interaction block</li>
<li>For periodic systems: non-radial filtering with fixed reciprocal lattice positions ($N_x, N_y, N_z$ hyperparameters)</li>
<li>For aperiodic systems: radial Gaussian basis function filtering with frequency cutoff $c_k$ and voxel resolution $\Delta = 0.2$ Å$^{-1}$</li>
<li>SVD-based coordinate alignment for rotation invariance in the aperiodic case</li>
<li>Bottleneck dimension $N_\downarrow = 16$ (GemNet-T) or $N_\downarrow = 8$ (others)</li>
<li>Update function: dense layer + $N_{\text{hidden}}$ residual layers ($N_{\text{hidden}} = 3$, except PaiNN with $N_{\text{hidden}} = 0$)</li>
</ul>
<h3 id="models">Models</h3>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Embedding Size (OE62)</th>
          <th>Interaction Blocks</th>
          <th>Total Params with Ewald MP (OE62)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>SchNet</td>
          <td>512</td>
          <td>4</td>
          <td>12.2M total</td>
      </tr>
      <tr>
          <td>PaiNN</td>
          <td>512</td>
          <td>4</td>
          <td>15.7M total</td>
      </tr>
      <tr>
          <td>DimeNet++</td>
          <td>256</td>
          <td>3</td>
          <td>4.8M total</td>
      </tr>
      <tr>
          <td>GemNet-T</td>
          <td>256</td>
          <td>3</td>
          <td>16.1M total</td>
      </tr>
  </tbody>
</table>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li>Primary metric: Energy mean absolute error (EMAE) in meV</li>
<li>Secondary metric: Force MAE in meV/Å (OC20 only)</li>
<li>Loss: Linear combination of energy and force MAEs (Eq. 15) with model-specific force multipliers</li>
<li>Optimizer: Adam with weight decay ($\lambda = 0.01$)</li>
</ul>
<h3 id="hardware">Hardware</h3>
<ul>
<li>All runtime measurements on NVIDIA A100 GPUs</li>
<li>Runtimes measured after 50 warmup batches, averaged over 500 batches, minimum of 3 repetitions</li>
<li>Code: <a href="https://github.com/arthurkosmala/EwaldMP">EwaldMP</a> (Hippocratic License 3.0)</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/arthurkosmala/EwaldMP">EwaldMP</a></td>
          <td>Code</td>
          <td>Hippocratic License 3.0 (new files) / MIT (OC20 base)</td>
          <td>Official implementation built on the Open Catalyst Project codebase</td>
      </tr>
      <tr>
          <td><a href="https://github.com/Open-Catalyst-Project/ocp/blob/main/DATASET.md">OC20</a></td>
          <td>Dataset</td>
          <td>CC-BY-4.0</td>
          <td>~265M periodic adsorbate-catalyst structures with DFT energies and forces</td>
      </tr>
      <tr>
          <td><a href="https://doi.org/10.1038/s41597-020-0385-y">OE62</a></td>
          <td>Dataset</td>
          <td>CC-BY-4.0</td>
          <td>~62,000 large organic molecules with DFT energies including D3 correction</td>
      </tr>
  </tbody>
</table>
<p><strong>Reproducibility status</strong>: Highly Reproducible. Source code, both datasets, and detailed hyperparameters (including per-model learning rates, batch sizes, and Ewald-specific settings) are all publicly available. Pre-trained model weights are not provided.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Kosmala, A., Gasteiger, J., Gao, N., &amp; Günnemann, S. (2023). Ewald-based Long-Range Message Passing for Molecular Graphs. In <em>Proceedings of the 40th International Conference on Machine Learning (ICML 2023)</em>.</p>
<p><strong>Publication</strong>: ICML 2023</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{kosmala2023ewald,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Ewald-based Long-Range Message Passing for Molecular Graphs}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Kosmala, Arthur and Gasteiger, Johannes and Gao, Nicholas and G{\&#34;u}nnemann, Stephan}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Proceedings of the 40th International Conference on Machine Learning}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">series</span>=<span style="color:#e6db74">{PMLR}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{202}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Materials Representations for ML Review</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/notations/materials-representations-ml-review/</link><pubDate>Mon, 06 Apr 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/notations/materials-representations-ml-review/</guid><description>Review of representation strategies for encoding solid-state materials as ML inputs, covering structural descriptors, crystal graphs, and generative models.</description><content:encoded><![CDATA[<h2 id="a-systematization-of-material-representations">A Systematization of Material Representations</h2>
<p>This paper is a <strong>Systematization</strong> that organizes and categorizes the strategies researchers use to convert solid-state materials into numerical representations suitable for machine learning models. Rather than proposing a new method, the review provides a structured taxonomy of existing approaches, connecting each to the practical constraints of data availability, computational cost, and prediction targets. It covers structural descriptors, graph-based learned representations, compositional features, transfer learning, and generative models for inverse design.</p>
<h2 id="why-material-representations-matter">Why Material Representations Matter</h2>
<p>Machine learning has enabled rapid property prediction for materials, but every ML pipeline depends on how the material is encoded as a numerical input. The authors identify three guiding principles for effective representations:</p>
<ol>
<li><strong>Similarity preservation</strong>: Similar materials should have similar representations, and dissimilar materials should diverge in representation space.</li>
<li><strong>Domain coverage</strong>: The representation should be constructable for every material in the target domain.</li>
<li><strong>Cost efficiency</strong>: Computing the representation should be cheaper than computing the target property directly (e.g., via <a href="https://en.wikipedia.org/wiki/Density_functional_theory">DFT</a>).</li>
</ol>
<p>In practice, materials scientists face several barriers. Atomistic structures span diverse space groups, supercell sizes, and disorder parameters. Real material performance depends on defects, microstructure, and interfaces. Structural information often requires expensive experimental or computational effort to obtain. Datasets in materials science tend to be small, sparse, and biased toward well-studied systems.</p>
<h2 id="structural-descriptors-local-global-and-topological">Structural Descriptors: Local, Global, and Topological</h2>
<p>The review covers three families of hand-crafted structural descriptors that encode atomic positions and types.</p>
<h3 id="local-descriptors">Local Descriptors</h3>
<p>Local descriptors characterize the environment around each atom. Atom-centered symmetry functions (ACSF), introduced by Behler and Parrinello, define radial and angular functions:</p>
<p>$$
G_{i}^{1} = \sum_{j \neq i}^{\text{neighbors}} e^{-\eta(R_{ij} - R_{s})^{2}} f_{c}(R_{ij})
$$</p>
<p>$$
G_{i}^{2} = 2^{1-\zeta} \sum_{j,k \neq i}^{\text{neighbors}} (1 + \lambda \cos \theta_{ijk})^{\zeta} e^{-\eta(R_{ij}^{2} + R_{ik}^{2} + R_{jk}^{2})} f_{c}(R_{ij}) f_{c}(R_{ik}) f_{c}(R_{jk})
$$</p>
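<p>As a concrete illustration, the radial function $G_i^1$ can be evaluated in a few lines. The cutoff function $f_c$ is not defined above. The sketch assumes the cosine form $f_c(R) = \tfrac{1}{2}\left[\cos(\pi R / R_c) + 1\right]$ for $R \le R_c$ and zero beyond, as commonly used with Behler-Parrinello symmetry functions. The parameter values and distances are arbitrary.</p>

```python
import math

def cosine_cutoff(r, r_c):
    """Assumed cutoff: decays smoothly to zero at r_c and vanishes beyond it."""
    if r >= r_c:
        return 0.0
    return 0.5 * (math.cos(math.pi * r / r_c) + 1.0)

def radial_symmetry_function(distances, eta, r_s, r_c):
    """G^1_i for one atom, given its distances R_ij to all other atoms."""
    return sum(
        math.exp(-eta * (r - r_s) ** 2) * cosine_cutoff(r, r_c)
        for r in distances
    )

# Hypothetical environment: two neighbors inside the cutoff, one outside.
g1 = radial_symmetry_function([1.0, 1.5, 7.0], eta=1.0, r_s=1.0, r_c=6.0)
```

<p>Varying $\eta$ and $R_s$ yields a set of such values that together form a radial fingerprint of the atom&rsquo;s environment.</p>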
<p>The Smooth Overlap of Atomic Positions (SOAP), proposed by Bartók et al., defines atomic neighborhood density as a sum of Gaussians and computes a rotationally invariant kernel through expansion in radial functions and <a href="https://en.wikipedia.org/wiki/Spherical_harmonics">spherical harmonics</a>:</p>
<p>$$
\rho_{i}(\mathbf{r}) = \sum_{j} \exp\left(-\frac{|\mathbf{r} - \mathbf{r}_{ij}|^{2}}{2\sigma^{2}}\right) = \sum_{nlm} c_{nlm} g_{n}(\mathbf{r}) Y_{lm}(\hat{\mathbf{r}})
$$</p>
<p>The power spectrum $\mathbf{p}(\mathbf{r}) \equiv \sum_{m} c_{nlm}(c_{n'lm})^{*}$ serves as a vector descriptor of the local environment. SOAP has seen wide adoption both as a similarity metric and as input to ML models.</p>
<p><a href="https://en.wikipedia.org/wiki/Voronoi_diagram">Voronoi tessellation</a> provides another local approach, segmenting space into cells and extracting features like effective coordination numbers, cell volumes, and neighbor properties.</p>
<h3 id="global-descriptors">Global Descriptors</h3>
<p>Global descriptors encode the full structure. The Coulomb matrix models electrostatic interactions between atoms:</p>
<p>$$
M_{i,j} = \begin{cases} 0.5 Z_{i}^{2.4} &amp; \text{for } i = j \\ \frac{Z_{i}Z_{j}}{|r_{i} - r_{j}|} &amp; \text{for } i \neq j \end{cases}
$$</p>
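<p>A small Python sketch of this matrix, assuming nuclear charges and Cartesian positions are given as plain lists, may make the two cases concrete. It uses the commonly quoted $0.5 Z_{i}^{2.4}$ diagonal; the function name <code>coulomb_matrix</code> is illustrative.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">import math


def coulomb_matrix(charges, positions):
    """Coulomb matrix from nuclear charges Z_i and Cartesian positions r_i."""
    n = len(charges)
    matrix = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            if i == j:
                # Diagonal: 0.5 * Z_i ** 2.4 (self-interaction term).
                matrix[i][j] = 0.5 * charges[i] ** 2.4
            else:
                # Off-diagonal: Coulomb repulsion Z_i * Z_j / |r_i - r_j|.
                distance = math.dist(positions[i], positions[j])
                matrix[i][j] = charges[i] * charges[j] / distance
    return matrix


# Two hydrogen nuclei (Z = 1) placed 2.0 length units apart.
m = coulomb_matrix([1, 1], [(0.0, 0.0, 0.0), (0.0, 0.0, 2.0)])
</code></pre></div>
<p>The result depends on the order in which atoms are listed, which is why practical uses typically sort the rows or work with the eigenvalue spectrum instead.</p>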
<p>Other global methods include partial radial distribution functions (PRDF), the many-body tensor representation (MBTR), and cluster expansions. The Atomic Cluster Expansion (ACE) framework generalizes cluster expansions to continuous environments and has become a foundation for modern deep learning potentials.</p>
<h3 id="topological-descriptors">Topological Descriptors</h3>
<p><a href="https://en.wikipedia.org/wiki/Persistent_homology">Persistent homology</a> from topological data analysis (TDA) identifies geometric features at multiple length scales. Topological descriptors capture pore geometries in porous materials and have outperformed traditional structural descriptors for predicting CO$_{2}$ adsorption in metal-organic frameworks and methane storage in <a href="https://en.wikipedia.org/wiki/Zeolite">zeolites</a>. A caveat is the $O(N^{3})$ worst-case computational cost per filtration.</p>
<h2 id="crystal-graph-neural-networks">Crystal Graph Neural Networks</h2>
<p>Graph neural networks bypass manual feature engineering by learning representations directly from structural data. Materials are converted to graphs $G(V, E)$ where nodes represent atoms and edges connect neighbors within a cutoff radius, with periodic boundary conditions.</p>
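<p>The graph construction can be sketched in a few lines of Python. This is a simplified illustration, not the procedure of any specific model: it assumes a cubic cell of edge length <code>lattice_a</code>, fractional coordinates in $[0, 1)$, and a cutoff no larger than the cell edge, so that only the 27 nearest periodic images need to be searched. The function name <code>periodic_edges</code> is hypothetical.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">import itertools
import math


def periodic_edges(frac_coords, lattice_a, cutoff):
    """Edges (i, j, image_shift, distance) within the cutoff of a cubic cell."""
    edges = []
    shifts = list(itertools.product((-1, 0, 1), repeat=3))
    for i, r_i in enumerate(frac_coords):
        for j, r_j in enumerate(frac_coords):
            for shift in shifts:
                if i == j and shift == (0, 0, 0):
                    continue  # an atom is not its own neighbor
                # Displacement to the periodic image of atom j, in length units.
                delta = [(r_j[k] + shift[k] - r_i[k]) * lattice_a for k in range(3)]
                distance = math.sqrt(sum(d * d for d in delta))
                if cutoff >= distance:
                    edges.append((i, j, shift, distance))
    return edges


# Simple cubic cell with one atom: its only neighbors are its own images.
edges = periodic_edges([(0.0, 0.0, 0.0)], lattice_a=3.0, cutoff=3.1)
</code></pre></div>
<p>For the one-atom simple cubic cell, the six edges all connect the atom to its own periodic images, which is the kind of self-connection a non-periodic molecular graph never contains.</p>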
<p>Key architectures discussed include:</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Key Innovation</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>CGCNN</td>
          <td>Crystal graph convolutions for broad property prediction</td>
      </tr>
      <tr>
          <td>MEGNet</td>
          <td>Materials graph networks with global state attributes</td>
      </tr>
      <tr>
          <td>ALIGNN</td>
          <td>Line graph neural networks incorporating three-body angular features</td>
      </tr>
      <tr>
          <td>Equivariant GNNs</td>
          <td>E(3)-equivariant message passing for tensorial properties</td>
      </tr>
  </tbody>
</table>
<p>The review identifies several limitations. Graph convolutions based on local neighborhoods can fail to capture long-range interactions or periodicity-dependent properties (e.g., lattice parameters, phonon spectra). Strategies to address this include concatenation with hand-tuned descriptors, plane-wave periodic basis modulation, and reciprocal-space features.</p>
<p>A major practical restriction is the requirement for relaxed atomic positions. Graphs built from unrelaxed crystal prototypes lose information about geometric distortions, degrading accuracy. Approaches to mitigate this include data augmentation with perturbed structures, Bayesian optimization of prototypes, and surrogate force-field relaxation.</p>
<p>Equivariant models that introduce higher-order tensors to node and edge features, constrained to transform correctly under E(3) operations, achieve state-of-the-art accuracy and can match structural descriptor performance even in low-data (~100 datapoints) regimes.</p>
<h2 id="compositional-descriptors-without-structure">Compositional Descriptors Without Structure</h2>
<p>When crystal structures are unavailable, representations can be built purely from stoichiometry and tabulated atomic properties (radii, electronegativity, valence electrons). Despite their simplicity, these methods have distinct advantages: zero computational overhead, accessibility to non-experts, and robustness for high-throughput screening.</p>
<p>Key methods include:</p>
<ul>
<li><strong>MagPie</strong>: 145 input features derived from elemental properties</li>
<li><strong>SISSO</strong>: Compressive sensing over algebraic combinations of atomic properties, capable of discovering interpretable descriptors (e.g., a new tolerance factor $\tau$ for perovskite stability)</li>
<li><strong>ElemNet</strong>: Deep neural network using only fractional stoichiometry as input, outperforming MagPie with &gt;3,000 training points</li>
<li><strong>ROOST</strong>: Fully-connected compositional graph with attention-based message passing, achieving strong performance with only hundreds of examples</li>
<li><strong>CrabNet</strong>: Self-attention on element embeddings with fractional encoding, handling dopant-level concentrations via log-scale inputs</li>
</ul>
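<p>To make the idea of a structure-free representation concrete, here is a minimal Python sketch that turns a formula into fractional stoichiometry and then into simple statistics of one tabulated elemental property. It is in the spirit of the elemental-property features listed above, not a reimplementation of any of these methods; the two-element table and the function names are illustrative assumptions.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python"># Pauling electronegativities for two elements; a real featurizer would
# load a full table covering many elemental properties.
ELECTRONEGATIVITY = {"Na": 0.93, "Cl": 3.16}


def fractional_composition(counts):
    """Normalize element counts, e.g. {"Na": 1, "Cl": 1}, to fractions."""
    total = sum(counts.values())
    return {element: count / total for element, count in counts.items()}


def composition_features(counts, table):
    """Fraction-weighted mean and range of one elemental property."""
    fractions = fractional_composition(counts)
    values = [table[element] for element in fractions]
    mean = sum(fractions[element] * table[element] for element in fractions)
    return {"mean": mean, "range": max(values) - min(values)}


features = composition_features({"Na": 1, "Cl": 1}, ELECTRONEGATIVITY)
</code></pre></div>
<p>Because only the formula enters, two polymorphs with the same stoichiometry would receive identical features under this scheme.</p>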
<p>Compositional models cannot distinguish polymorphs and generally underperform structural approaches. They are most valuable when atomistic resolution is unavailable.</p>
<h2 id="defects-surfaces-and-grain-boundaries">Defects, Surfaces, and Grain Boundaries</h2>
<p>The review extends beyond idealized unit cells to practical materials challenges:</p>
<p><strong>Point defects</strong>: Representations of the pristine bulk can predict vacancy formation energies through linear relationships with band structure descriptors. Frey et al. proposed using relative differences between defect and parent structure properties, requiring no DFT on the defect itself.</p>
<p><strong>Surfaces and catalysis</strong>: Binding energy prediction for catalysis requires representations beyond the bulk unit cell. The d-band center for metals and oxygen 2p-band center for metal oxides serve as simple electronic descriptors, following the <a href="https://en.wikipedia.org/wiki/Sabatier_principle">Sabatier principle</a> that optimal catalytic activity requires intermediate binding strength. Graph neural networks trained on the Open Catalyst 2020 dataset (&gt;1 million DFT energies) have enabled broader screening, though errors remain high for certain adsorbates and non-metallic surfaces.</p>
<p><strong>Grain boundaries</strong>: SOAP descriptors computed for atoms near grain boundaries and clustered into local environment classes can predict grain boundary energy, mobility, and shear coupling. This approach provides interpretable structure-property relationships.</p>
<h2 id="transfer-learning-across-representations">Transfer Learning Across Representations</h2>
<p>When target datasets are small, transfer learning leverages representations learned from large, related datasets. The standard procedure involves: (1) pretraining on a large dataset (e.g., all Materials Project formation energies), (2) freezing parameters up to a chosen depth, and (3) either fine-tuning remaining layers or extracting features for a separate model.</p>
<p>Key findings from the review:</p>
<ul>
<li>Transfer learning is most effective when the source dataset is orders of magnitude larger than the target</li>
<li>Physically related tasks transfer better (e.g., Open Catalyst adsorption energies transfer well to new adsorbates, less so to unrelated small molecules)</li>
<li>Earlier neural network layers learn more general representations and transfer better across properties</li>
<li>Multi-depth feature extraction, combining activations from multiple layers, can improve transfer</li>
<li>Predictions from surrogate models can serve as additional descriptors, expanding screening domains by orders of magnitude</li>
</ul>
<h2 id="generative-models-for-crystal-inverse-design">Generative Models for Crystal Inverse Design</h2>
<p>Generative models for solid-state materials face challenges beyond molecular generation: more diverse atomic species, the need to specify both positions and lattice parameters, non-unique definitions (rotations, translations, supercell scaling), and large unit cells (&gt;100 atoms for zeolites and MOFs).</p>
<p>The review traces the progression of approaches:</p>
<ol>
<li><strong>Voxel representations</strong>: Discretize unit cells into volume elements. Early work (iMatGen, Court et al.) demonstrated feasibility but was restricted to specific chemistries or cubic systems.</li>
<li><strong>Continuous coordinate models</strong>: Point cloud and invertible representations allowed broader chemical spaces but lacked symmetry invariances.</li>
<li><strong>Symmetry-aware models</strong>: Crystal Diffusion <a href="/notes/machine-learning/generative-models/autoencoding-variational-bayes/">VAE</a> (CDVAE) uses periodic graphs and SE(3)-equivariant message passing for translationally and rotationally invariant generation, establishing benchmark tasks for the field.</li>
<li><strong>Constrained models for porous materials</strong>: Approaches like SmVAE represent MOFs through their topological building blocks (RFcodes), ensuring all generated structures are physically valid.</li>
</ol>
<h2 id="open-problems-and-future-directions">Open Problems and Future Directions</h2>
<p>The review highlights four high-impact open questions:</p>
<ol>
<li><strong>Local vs. global descriptor trade-offs</strong>: Local descriptors (SOAP) excel for short-range interactions but struggle with long-range physics. Global descriptors model periodicity but lack generality across space groups. Combining local and long-range features could provide more universal models.</li>
<li><strong>Prediction from unrelaxed prototypes</strong>: ML force fields can relax structures at a fraction of DFT cost, potentially expanding screening domains. Key questions remain about required training data scale and generalizability.</li>
<li><strong>Applicability of compositional descriptors</strong>: The performance gap between compositional and structural models may be property-dependent, being smaller for properties like band gap that depend on global features rather than local site energies.</li>
<li><strong>Extensions of generative models</strong>: Diffusion-based architectures have improved on voxel approaches for small unit cells, but extending to microstructure, dimensionality, and surface generation remains open.</li>
</ol>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<p>This paper is a review and does not present new experimental results or release any novel code, data, or models. The paper is open-access (hybrid OA at Annual Reviews) and the arXiv preprint is freely available. The following artifacts table covers key publicly available resources discussed in the review.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://arxiv.org/abs/2301.08813">arXiv preprint (2301.08813)</a></td>
          <td>Other</td>
          <td>arXiv (open access)</td>
          <td>Free preprint version</td>
      </tr>
      <tr>
          <td><a href="https://materialsproject.org">Materials Project</a></td>
          <td>Dataset</td>
          <td>CC-BY-4.0</td>
          <td>DFT energies, band gaps, structures for &gt;100,000 compounds</td>
      </tr>
      <tr>
          <td><a href="https://oqmd.org">OQMD</a></td>
          <td>Dataset</td>
          <td>CC-BY-4.0</td>
          <td>Open Quantum Materials Database, &gt;600,000 DFT entries</td>
      </tr>
      <tr>
          <td><a href="https://github.com/Open-Catalyst-Project/ocp">Open Catalyst 2020 (OC20)</a></td>
          <td>Dataset</td>
          <td>CC-BY-4.0</td>
          <td>&gt;1,000,000 DFT surface adsorption energies</td>
      </tr>
      <tr>
          <td><a href="https://aflowlib.org">AFLOW</a></td>
          <td>Dataset</td>
          <td>Public</td>
          <td>High-throughput ab initio library, &gt;3,000,000 entries</td>
      </tr>
      <tr>
          <td><a href="https://github.com/hackingmaterials/matminer">Matminer</a></td>
          <td>Code</td>
          <td>BSD</td>
          <td>Open-source toolkit for materials data mining and featurization</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<p>The review covers: ACSF, SOAP, Voronoi tessellation, Coulomb matrices, PRDF, MBTR, cluster expansions, ACE, persistent homology, CGCNN, MEGNet, ALIGNN, E(3)-equivariant GNNs, MagPie, SISSO, ElemNet, ROOST, CrabNet, VAE, GAN, and diffusion-based crystal generators.</p>
<h3 id="hardware">Hardware</h3>
<p>No new experiments are conducted. Hardware requirements vary by the referenced methods (DFT calculations require HPC; GNN training typically requires 1-8 GPUs).</p>
<h3 id="reproducibility-status">Reproducibility Status</h3>
<p><strong>Partially Reproducible</strong>: The review paper itself is open-access. All major datasets discussed (Materials Project, OQMD, OC20, AFLOW) are publicly available under permissive licenses. Most referenced model implementations (CGCNN, MEGNet, ALIGNN, ROOST, CDVAE) have open-source code. No novel artifacts are released by the authors.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Damewood, J., Karaguesian, J., Lunger, J. R., Tan, A. R., Xie, M., Peng, J., &amp; Gómez-Bombarelli, R. (2023). Representations of Materials for Machine Learning. <em>Annual Review of Materials Research</em>, 53. <a href="https://doi.org/10.1146/annurev-matsci-080921-085947">https://doi.org/10.1146/annurev-matsci-080921-085947</a></p>
<p><strong>Publication</strong>: Annual Review of Materials Research, 2023</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{damewood2023representations,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Representations of Materials for Machine Learning}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Damewood, James and Karaguesian, Jessica and Lunger, Jaclyn R. and Tan, Aik Rui and Xie, Mingrou and Peng, Jiayu and G{\&#39;o}mez-Bombarelli, Rafael}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Annual Review of Materials Research}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{53}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1146/annurev-matsci-080921-085947}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Transformers and LLMs for Chemistry Drug Discovery</title><link>https://hunterheidenreich.com/notes/chemistry/llm-applications/transformers-llms-chemistry-drug-discovery/</link><pubDate>Sat, 28 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/llm-applications/transformers-llms-chemistry-drug-discovery/</guid><description>Bran and Schwaller review transformer architectures for chemistry, from task-specific SMILES models to multimodal LLMs and chemistry agents.</description><content:encoded><![CDATA[<h2 id="a-systematization-of-transformers-in-chemistry">A Systematization of Transformers in Chemistry</h2>
<p>This book chapter by Bran and Schwaller is a <strong>Systematization</strong> paper that organizes the growing body of work applying transformer architectures to chemistry and drug discovery. Rather than proposing a new method, the authors trace a three-stage evolution: (1) task-specific single-modality models operating on SMILES and reaction strings, (2) multimodal models bridging molecular representations with spectra, synthesis actions, and natural language, and (3) large language models and LLM-powered agents capable of general chemical reasoning.</p>
<h2 id="why-transformers-for-chemistry">Why Transformers for Chemistry?</h2>
<p>The authors motivate the review by drawing analogies between natural language and chemical language. Just as text can be decomposed into subwords and tokens, molecules can be linearized into <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> or <a href="/notes/chemistry/molecular-representations/notations/selfies/">SELFIES</a> strings, and chemical reactions can be encoded as reaction SMILES. This structural parallel enabled direct transfer of transformer architectures, originally designed for machine translation, to chemical prediction tasks.</p>
<p>Several factors accelerated this adoption:</p>
<ul>
<li>The publication of open chemical databases and benchmarks (e.g., <a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a>, Open Reaction Database, Therapeutics Data Commons)</li>
<li>Improvements in compute infrastructure and training algorithms</li>
<li>The success of attention mechanisms at capturing context-dependent relationships, which proved effective for learning chemical grammar and atom-level correspondences</li>
</ul>
<p>The review positions the transformer revolution in chemistry as a natural extension of NLP advances, noting that the gap between chemical and natural language is progressively closing.</p>
<h2 id="molecular-representations-as-language">Molecular Representations as Language</h2>
<p>A key section of the review covers text-based molecular representations that make transformer applications possible:</p>
<ul>
<li><strong>SMILES</strong> (Simplified Molecular Input Line Entry System): The dominant linearization scheme since the 1980s, encoding molecular graphs as character sequences with special symbols for bonds, branches, and rings.</li>
<li><strong>SELFIES</strong> (Self-Referencing Embedded Strings): A newer representation that guarantees every string maps to a valid molecule, addressing the robustness issues of SMILES in generative settings.</li>
<li><strong>Reaction SMILES</strong>: Extends molecular representations to encode full chemical reactions in the format &ldquo;A.B &gt; catalyst.reagent &gt; C.D&rdquo;, enabling reaction prediction as a sequence-to-sequence task.</li>
</ul>
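<p>The reaction SMILES layout described above can be unpacked with plain string operations. The following Python sketch is a simplified illustration that assumes a well-formed string with exactly two separators; the function name <code>parse_reaction_smiles</code> is hypothetical, and the example reaction is chosen here for illustration.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">def parse_reaction_smiles(reaction):
    """Split 'reactants>agents>products' into three lists of SMILES."""
    parts = reaction.split(">")
    if len(parts) != 3:
        raise ValueError("expected exactly two '>' separators")
    # Molecules within each part are separated by '.'; empty parts give [].
    return tuple([s for s in part.split(".") if s] for part in parts)


# Acid-catalyzed esterification: ethanol + acetic acid to ethyl acetate + water.
reactants, agents, products = parse_reaction_smiles(
    "CCO.CC(=O)O>OS(=O)(=O)O>CC(=O)OCC.O"
)
</code></pre></div>
<p>Since the dot also separates the ions of a salt, this naive split would list a counter-ion as its own entry; a cheminformatics toolkit is the safer choice for real data.</p>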
<p>The authors note that while IUPAC names, InChI, and <a href="/notes/chemistry/molecular-representations/notations/deepsmiles-adaptation-for-ml/">DeepSMILES</a> exist as alternatives, SMILES and SELFIES dominate practical applications.</p>
<h2 id="stage-1-task-specific-transformer-models">Stage 1: Task-Specific Transformer Models</h2>
<p>The first stage of transformer adoption focused on clearly defined chemical tasks, with models trained on a single data modality (molecular strings).</p>
<h3 id="chemical-translation-tasks">Chemical Translation Tasks</h3>
<p>The encoder-decoder architecture was directly applied to tasks framed as translation:</p>
<ul>
<li><strong><a href="/notes/chemistry/molecular-design/reaction-prediction/molecular-transformer/">Molecular Transformer</a></strong> (Schwaller et al.): Treated reaction prediction as translation from reactant SMILES to product SMILES, becoming a leading method for forward synthesis prediction.</li>
<li><strong>Retrosynthetic planning</strong>: The reverse task, predicting reactants from products, with iterative application to construct full retrosynthetic trees mapping to commercially available building blocks.</li>
<li><strong><a href="/notes/chemistry/molecular-design/generation/autoregressive/chemformer/">Chemformer</a></strong> (Irwin et al.): A pre-trained model across multiple chemical tasks, offering transferability to new applications with improved performance.</li>
<li><strong>Graph-to-sequence models</strong> (Tu and Coley): Used a custom graph encoder with a transformer decoder, achieving improvements through permutation-invariant molecular graph encoding.</li>
</ul>
<h3 id="representation-learning-and-feature-extraction">Representation Learning and Feature Extraction</h3>
<p>Encoder-only transformers proved valuable for generating molecular and reaction embeddings:</p>
<ul>
<li><strong>Reaction representations</strong> (Wang et al., SMILES-BERT): Trained models to generate reaction vectors that outperformed hand-engineered features on downstream regression tasks.</li>
<li><strong>Reaction classification</strong> (Schwaller et al.): Replaced the decoder with a classification layer to map chemical reactions by class, revealing clustering patterns by reaction type, data source, and molecular properties.</li>
<li><strong>Yield prediction</strong>: Regression heads attached to encoders achieved strong results on high-throughput experimentation datasets.</li>
<li><strong>Protein language models</strong> (Rives et al., ESM): Trained on 250 million protein sequences using unsupervised learning, achieving strong performance on protein property and structure prediction.</li>
<li><strong>RXNMapper</strong> (Schwaller et al.): A notable application where attention weight analysis revealed that transformers internally learn atom-to-atom mappings in chemical reactions, leading to an open-source atom mapping algorithm that outperformed existing approaches.</li>
</ul>
<h2 id="stage-2-multimodal-chemical-models">Stage 2: Multimodal Chemical Models</h2>
<p>The second stage extended transformers beyond molecular strings to incorporate additional data types:</p>
<ul>
<li><strong>Molecular captioning</strong>: Describing molecules in natural language, covering scaffolds, sources, drug interactions, and other features (Edwards et al.).</li>
<li><strong>Bidirectional molecule-text conversion</strong>: Models capable of generating molecules from text queries and performing molecule-to-molecule tasks (Christofidellis et al.).</li>
<li><strong>Experimental procedure prediction</strong>: Generating actionable synthesis steps from reaction SMILES (Vaucher et al.), bridging the gap between retrosynthetic planning and laboratory execution.</li>
<li><strong>Structural elucidation from IR spectra</strong>: Encoding IR spectra as text sequences alongside chemical formulas, then predicting SMILES from these inputs (Alberts et al.), achieving 45% accuracy in structure prediction and surpassing prior approaches for functional group identification.</li>
</ul>
<h2 id="stage-3-large-language-models-and-chemistry-agents">Stage 3: Large Language Models and Chemistry Agents</h2>
<p>The most recent stage builds on foundation models pre-trained on vast text corpora, adapted for chemistry through fine-tuning and in-context learning.</p>
<h3 id="scaling-laws-and-emergent-capabilities">Scaling Laws and Emergent Capabilities</h3>
<p>The authors discuss how model scaling leads to emergent capabilities relevant to chemistry:</p>
<ul>
<li>Below certain compute thresholds, model performance on chemistry tasks appears random.</li>
<li>Above critical sizes, sudden improvements emerge, along with capabilities like chain-of-thought (CoT) reasoning and instruction following.</li>
<li>These emergent abilities enable chemistry tasks that require multi-step reasoning without explicit training on chemical data.</li>
</ul>
<h3 id="llms-as-chemistry-tools">LLMs as Chemistry Tools</h3>
<p>Key applications of LLMs in chemistry include:</p>
<ul>
<li><strong><a href="/notes/chemistry/llm-applications/fine-tuning-gpt3-molecular-properties/">Fine-tuning for low-data chemistry</a></strong> (Jablonka et al.): GPT-3 fine-tuned on limited chemistry datasets performed comparably to, and sometimes exceeded, specialized models with engineered features for tasks like predicting transition wavelengths and phase classification.</li>
<li><strong>In-context learning</strong>: Providing LLMs with a few examples enables prediction on chemistry tasks without any parameter updates, particularly valuable when data is scarce.</li>
<li><strong>Bayesian optimization with LLMs</strong> (Ramos et al.): Using GPT models for uncertainty-calibrated regression, enabling catalyst and molecular optimization directly from synthesis procedures without feature engineering.</li>
<li><strong><a href="/notes/chemistry/molecular-design/generation/autoregressive/3d-chemical-language-models-xyz-cif-pdb/">3D structure generation</a></strong> (Flam-Shepherd and Aspuru-Guzik): Using language models to generate molecular structures with three-dimensional atomic positions in XYZ, CIF, and PDB formats, matching graph-based algorithms while overcoming representation limitations.</li>
</ul>
<h3 id="llm-powered-chemistry-agents">LLM-Powered Chemistry Agents</h3>
<p>The review highlights the agent paradigm as the most impactful recent development:</p>
<ul>
<li><strong>14 LLM use-cases</strong> (Jablonka et al.): A large-scale collaborative effort demonstrating applications from computational tool wrappers to reaction optimization assistants and scientific question answering.</li>
<li><strong><a href="/notes/chemistry/llm-applications/chemcrow-augmenting-llms-chemistry-tools/">ChemCrow</a></strong> (Bran, Cox et al.): An LLM-powered agent equipped with curated computational chemistry tools, capable of planning and executing tasks across drug design, materials design, and synthesis. ChemCrow demonstrated that tool integration can mitigate LLM hallucination by grounding responses in reliable data sources.</li>
<li><strong>Autonomous scientific research</strong> (Boiko et al.): Agent systems designed with a focus on operating cloud laboratories.</li>
</ul>
<p>The agent paradigm offers tool composability through natural language interfaces, allowing users to chain multiple computational tools into custom pipelines.</p>
<h2 id="outlook-and-limitations">Outlook and Limitations</h2>
<p>The authors identify several themes for the future:</p>
<ul>
<li>The three stages represent increasing generality, from task-specific single-modality models to open-ended agents.</li>
<li>Natural language interfaces are progressively closing the gap between chemical and human language.</li>
<li>Tool integration through agents provides grounding that mitigates hallucination, a known limitation of direct LLM application to chemistry.</li>
<li>The review acknowledges that LLMs have a &ldquo;high propensity to generate false and inaccurate content&rdquo; on chemical tasks, making tool-augmented approaches preferable to direct application.</li>
</ul>
<p>The chapter does not provide quantitative benchmarks or systematic comparisons across the methods discussed, as its goal is to organize the landscape rather than evaluate individual methods.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<p>This is a review/survey chapter and does not introduce new models, datasets, or experiments. The reproducibility assessment applies to the referenced works rather than the review itself.</p>
<h3 id="key-referenced-resources">Key Referenced Resources</h3>
<p>Several open-source tools and datasets discussed in the review are publicly available:</p>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/rxn4chemistry/rxnmapper">RXNMapper</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Attention-based atom mapping</td>
      </tr>
      <tr>
          <td><a href="https://github.com/ur-whitelab/chemcrow-public">ChemCrow</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>LLM-powered chemistry agent</td>
      </tr>
      <tr>
          <td><a href="https://moleculenet.org/">MoleculeNet</a></td>
          <td>Dataset</td>
          <td>Various</td>
          <td>Molecular ML benchmarks</td>
      </tr>
      <tr>
          <td><a href="https://open-reaction-database.org/">Open Reaction Database</a></td>
          <td>Dataset</td>
          <td>CC-BY-SA-4.0</td>
          <td>Curated reaction data</td>
      </tr>
      <tr>
          <td><a href="https://tdcommons.ai/">Therapeutics Data Commons</a></td>
          <td>Dataset</td>
          <td>MIT</td>
          <td>Drug discovery ML datasets</td>
      </tr>
  </tbody>
</table>
<h3 id="reproducibility-classification">Reproducibility Classification</h3>
<p><strong>Not applicable</strong> (review paper). Individual referenced works range from Highly Reproducible (open-source models like RXNMapper, ChemCrow) to Partially Reproducible (some models without released code) to Closed (proprietary LLMs like GPT-3/GPT-4 used in fine-tuning studies).</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Bran, A. M., &amp; Schwaller, P. (2024). Transformers and Large Language Models for Chemistry and Drug Discovery. In <em>Drug Development Supported by Informatics</em> (pp. 143-163). Springer Nature Singapore. <a href="https://doi.org/10.1007/978-981-97-4828-0_8">https://doi.org/10.1007/978-981-97-4828-0_8</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@incollection</span>{bran2024transformers,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Transformers and Large Language Models for Chemistry and Drug Discovery}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Bran, Andres M. and Schwaller, Philippe}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Drug Development Supported by Informatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{143--163}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2024}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Springer Nature Singapore}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1007/978-981-97-4828-0_8}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>ReactionT5: Pre-trained T5 for Reaction Prediction</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/reaction-prediction/reactiont5-pretrained-limited-reaction-data/</link><pubDate>Sat, 28 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/reaction-prediction/reactiont5-pretrained-limited-reaction-data/</guid><description>ReactionT5 uses two-stage pretraining on ZINC and the Open Reaction Database to enable competitive reaction and yield prediction with minimal fine-tuning data.</description><content:encoded><![CDATA[<h2 id="a-two-stage-pre-trained-transformer-for-chemical-reactions">A Two-Stage Pre-trained Transformer for Chemical Reactions</h2>
<p>ReactionT5 is a <strong>Method</strong> paper that proposes a T5-based pre-trained model for chemical reaction tasks, specifically product prediction and yield prediction. The primary contribution is a two-stage pretraining pipeline: first on a compound library (ZINC, 23M molecules) to learn molecular representations, then on a large-scale reaction database (the Open Reaction Database, 1.5M reactions) to learn reaction-level patterns. The key result is that this pre-trained model can be fine-tuned with very limited target-domain data (as few as 30 reactions) and still achieve competitive performance against models trained on full datasets.</p>
<h2 id="bridging-the-gap-between-single-molecule-and-multi-molecule-pretraining">Bridging the Gap Between Single-Molecule and Multi-Molecule Pretraining</h2>
<p>While transformer-based models pre-trained on compound libraries (e.g., <a href="/notes/chemistry/molecular-representations/encoders/smiles-bert/">SMILES-BERT</a>, MolGPT) have seen substantial development, most focus on single-molecule inputs and outputs. Pretraining for multi-molecule contexts, such as chemical reactions involving reactants, reagents, catalysts, and products, remains underexplored. T5Chem supports multi-task reaction prediction but focuses on building a single multi-task model rather than investigating the effectiveness of pre-trained models for fine-tuning on limited in-house data.</p>
<p>The authors identify two key gaps:</p>
<ol>
<li>Most pre-trained chemical models do not account for reaction-level interactions between multiple molecules.</li>
<li>In practical settings, target-domain reaction data is often scarce, making transfer learning from large public datasets essential.</li>
</ol>
<h2 id="two-stage-pretraining-with-compound-restoration">Two-Stage Pretraining with Compound Restoration</h2>
<p>The core innovation is a two-stage pretraining procedure built on the <a href="/notes/natural-language-processing/language-models/t5-text-to-text-transfer-transformer/">T5 (text-to-text transfer transformer)</a> architecture:</p>
<p><strong>Stage 1: Compound Pretraining (CompoundT5)</strong>. An initialized T5 model is trained on 23M <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> from the ZINC database using span-masked language modeling. The model learns to predict masked subsequences of SMILES tokens. A SentencePiece unigram tokenizer is trained on this compound library, allowing more compact representations than character-level or atom-level tokenizers. After this stage, new tokens are added to the tokenizer to cover metal atoms and other characters present in the reaction database but absent from ZINC.</p>
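<p>The span-masking objective can be illustrated with a small sketch. This is a simplified stand-in, not the authors' preprocessing: the function name <code>corrupt_spans</code> is hypothetical, the <code>[SENTINEL_k]</code> strings are placeholders for the reserved sentinel tokens a T5 tokenizer would supply, the spans are passed in explicitly instead of being sampled, and the input is split per character where the paper uses a SentencePiece unigram tokenizer.</p>

```python
def corrupt_spans(tokens, spans):
    """Replace each (start, length) span with a sentinel.

    Returns (source, target): the source keeps the unmasked tokens with one
    sentinel per masked span, and the target lists each sentinel followed by
    the tokens it replaced. Spans are assumed not to overlap.
    """
    source, target = [], []
    cursor = 0
    for index, (start, length) in enumerate(sorted(spans)):
        sentinel = "[SENTINEL_{}]".format(index)  # placeholder token name
        source.extend(tokens[cursor:start])
        source.append(sentinel)
        target.append(sentinel)
        target.extend(tokens[start:start + length])
        cursor = start + length
    source.extend(tokens[cursor:])
    return source, target


# Character-level toy tokens for phenyl acetate; one span of length 3 is masked.
tokens = list("CC(=O)Oc1ccccc1")
source, target = corrupt_spans(tokens, [(2, 3)])
print(" ".join(source))
print(" ".join(target))
```

<p>In actual pretraining the span positions would be drawn at random so that roughly 15% of tokens are masked with an average span length of 3, as listed in the reproducibility details.</p>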
<p><strong>Stage 2: Reaction Pretraining (ReactionT5)</strong>. CompoundT5 is further pretrained on 1.5M reactions from the Open Reaction Database (ORD) on both product prediction and yield prediction tasks. Reactions are formulated as text-to-text tasks using special tokens:</p>
<ul>
<li><code>REACTANT:</code>, <code>REAGENT:</code>, and <code>PRODUCT:</code> tokens delimit the role of each molecule in the reaction string.</li>
<li>For product prediction, the model takes reactants and reagents as input and generates product SMILES.</li>
<li>For yield prediction, the model takes the full reaction (including products) and outputs a numerical yield value.</li>
</ul>
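<p>A minimal sketch of how such role-tagged strings might be assembled. The helper <code>format_reaction</code> is hypothetical, and the exact layout (no whitespace between segments, molecules within a role joined by the SMILES <code>.</code> separator) is an assumption, not a detail confirmed by the paper.</p>

```python
def format_reaction(reactants, reagents, products=None):
    """Join SMILES strings into a single role-tagged reaction string.

    Leave products as None to build a product-prediction input; pass a list
    of product SMILES to build the full reaction used for yield prediction.
    """
    parts = [
        "REACTANT:" + ".".join(reactants),
        "REAGENT:" + ".".join(reagents),
    ]
    if products is not None:
        parts.append("PRODUCT:" + ".".join(products))
    return "".join(parts)


# Illustrative esterification: acetic acid + ethanol with sulfuric acid.
product_input = format_reaction(["CC(=O)O", "CCO"], ["OS(=O)(=O)O"])
yield_input = format_reaction(["CC(=O)O", "CCO"], ["OS(=O)(=O)O"], ["CCOC(C)=O"])
print(product_input)
print(yield_input)
```

<p>The same function serves both tasks, which mirrors the text-to-text framing: only the presence of the <code>PRODUCT:</code> segment distinguishes a yield-prediction input from a product-prediction input.</p>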
<p><strong>Compound Restoration</strong>. A notable methodological detail is the handling of uncategorized compounds in the ORD. About 31.8% of ORD reactions contain compounds with unknown roles. Simply discarding these reactions introduces severe product bias (only 447 unique products remain vs. 439,898 with uncategorized data included). The authors develop RestorationT5, a binary classifier built from CompoundT5, that assigns uncategorized compounds to either reactant or reagent roles. This classifier uses a sigmoid output layer and achieves an F1 score of 0.1564 at a threshold of 0.97, outperforming a random forest baseline (F1 = 0.1136). The restored dataset (&ldquo;ORD(restored)&rdquo;) is then used for reaction pretraining.</p>
<p>For yield prediction, the loss function is mean squared error:</p>
<p>$$L = \frac{1}{N} \sum_{i=1}^{N} (y_i - \hat{y}_i)^2$$</p>
<p>where $y_i$ is the true yield (normalized to [0, 1]) and $\hat{y}_i$ is the predicted yield.</p>
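<p>The loss can be sketched in a few lines of plain Python. The function names and the sample yields are illustrative assumptions; the clip-then-scale normalization follows the description in the reproducibility details of this note.</p>

```python
def normalize_yield(percent):
    """Clip a percent yield to [0, 100], then rescale it to [0, 1]."""
    return min(max(percent, 0.0), 100.0) / 100.0


def mse_loss(y_true, y_pred):
    """Mean squared error over paired true and predicted yields."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("inputs must be non-empty and of equal length")
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


# Hypothetical reported yields in percent; 120.0 is clipped to 100.0.
targets = [normalize_yield(v) for v in (95.0, 40.0, 120.0)]
predictions = [0.9, 0.5, 1.0]
loss = mse_loss(targets, predictions)
print(targets, loss)
```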
<h2 id="experimental-setup-product-and-yield-prediction-benchmarks">Experimental Setup: Product and Yield Prediction Benchmarks</h2>
<h3 id="product-prediction">Product Prediction</h3>
<p>The USPTO dataset (479K reactions) is used for evaluation, with standard train/val/test splits (409K/30K/40K). Reactions overlapping with the ORD (18%) are removed during evaluation. Beam search with beam size 10 is used for decoding, and minimum/maximum output length constraints are set based on the training data distribution. Top-k accuracy (k = 1, 2, 3, 5) and invalidity rate are reported.</p>
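<p>The two reported metrics can be sketched as follows. This is an illustration under stated assumptions: the paper's exact invalidity definition is not restated here, so the sketch counts rejected strings over all beam outputs; predictions and targets are assumed to be canonicalized already; and the parenthesis check is a toy stand-in for a real SMILES parser such as RDKit.</p>

```python
def top_k_accuracy(beams, targets, k):
    """Fraction of examples whose target is among the first k beam outputs."""
    hits = sum(1 for candidates, target in zip(beams, targets)
               if target in candidates[:k])
    return hits / len(targets)


def invalidity_rate(beams, is_valid):
    """Fraction of all generated strings rejected by the validity check."""
    total = sum(len(candidates) for candidates in beams)
    rejected = sum(1 for candidates in beams
                   for smiles in candidates if not is_valid(smiles))
    return rejected / total


def balanced_parentheses(smiles):
    """Toy validity check (assumption): only compares parenthesis counts."""
    return smiles.count("(") == smiles.count(")")


# Hypothetical beam outputs for two test reactions, best candidate first.
beams = [["CCO", "CC(O", "CCN"], ["c1ccccc1", "C1CCCCC1", "CC"]]
targets = ["CCO", "C1CCCCC1"]
print(top_k_accuracy(beams, targets, 1))
print(top_k_accuracy(beams, targets, 2))
print(invalidity_rate(beams, balanced_parentheses))
```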
<p>Baselines include Seq-to-seq, WLDN (graph neural network), <a href="/notes/chemistry/molecular-design/reaction-prediction/molecular-transformer/">Molecular Transformer</a>, and T5Chem.</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Train</th>
          <th>Top-1</th>
          <th>Top-2</th>
          <th>Top-3</th>
          <th>Top-5</th>
          <th>Invalidity</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Seq-to-seq</td>
          <td>USPTO</td>
          <td>80.3</td>
          <td>84.7</td>
          <td>86.2</td>
          <td>87.5</td>
          <td>-</td>
      </tr>
      <tr>
          <td>WLDN</td>
          <td>USPTO</td>
          <td>85.6</td>
          <td>90.5</td>
          <td>92.8</td>
          <td>93.4</td>
          <td>-</td>
      </tr>
      <tr>
          <td>Molecular Transformer</td>
          <td>USPTO</td>
          <td>88.8</td>
          <td>92.6</td>
          <td>-</td>
          <td>94.4</td>
          <td>-</td>
      </tr>
      <tr>
          <td>T5Chem</td>
          <td>USPTO</td>
          <td>90.4</td>
          <td>94.2</td>
          <td>-</td>
          <td>96.4</td>
          <td>-</td>
      </tr>
      <tr>
          <td>CompoundT5</td>
          <td>USPTO</td>
          <td>88.0</td>
          <td>92.4</td>
          <td>93.9</td>
          <td>95.0</td>
          <td>7.5</td>
      </tr>
      <tr>
          <td>ReactionT5 (restored ORD)</td>
          <td>USPTO200</td>
          <td>85.5</td>
          <td>91.7</td>
          <td>93.5</td>
          <td>94.9</td>
          <td>12.0</td>
      </tr>
  </tbody>
</table>
<p>A critical finding: ReactionT5 pre-trained on ORD achieves 0% accuracy on USPTO without fine-tuning due to domain mismatch (ORD includes byproducts; USPTO lists only the main product). Fine-tuning on just 200 USPTO reactions with the restored ORD model produces competitive results.</p>
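<p>The few-shot fine-tuning analysis shows a steep early gain followed by saturation. Top-1 accuracy jumps from 9.0% with 10 samples to 80.5% with 30, then climbs gradually to 85.5% with 200:</p>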
<table>
  <thead>
      <tr>
          <th>Samples</th>
          <th>Top-1</th>
          <th>Top-2</th>
          <th>Top-3</th>
          <th>Top-5</th>
          <th>Invalidity</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>10</td>
          <td>9.0</td>
          <td>12.5</td>
          <td>15.3</td>
          <td>19.1</td>
          <td>12.4</td>
      </tr>
      <tr>
          <td>30</td>
          <td>80.5</td>
          <td>87.3</td>
          <td>89.8</td>
          <td>92.0</td>
          <td>17.2</td>
      </tr>
      <tr>
          <td>50</td>
          <td>83.7</td>
          <td>89.9</td>
          <td>92.2</td>
          <td>94.0</td>
          <td>14.8</td>
      </tr>
      <tr>
          <td>100</td>
          <td>85.1</td>
          <td>91.0</td>
          <td>92.8</td>
          <td>94.4</td>
          <td>14.0</td>
      </tr>
      <tr>
          <td>200</td>
          <td>85.5</td>
          <td>91.7</td>
          <td>93.5</td>
          <td>94.9</td>
          <td>12.0</td>
      </tr>
  </tbody>
</table>
<h3 id="yield-prediction">Yield Prediction</h3>
<p>The <a href="https://en.wikipedia.org/wiki/Buchwald%E2%80%93Hartwig_amination">Buchwald-Hartwig</a> C-N cross-coupling dataset (3,955 reactions) is used with random 7:3 splits (repeated 10 times) plus four out-of-sample test sets (Tests 1-4) designed so that similar reactions do not appear in both train and test.</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Random 7:3</th>
          <th>Test 1</th>
          <th>Test 2</th>
          <th>Test 3</th>
          <th>Test 4</th>
          <th>Avg. Tests 1-4</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>DFT</td>
          <td>0.92</td>
          <td>0.80</td>
          <td>0.77</td>
          <td>0.64</td>
          <td>0.54</td>
          <td>0.69</td>
      </tr>
      <tr>
          <td>MFF</td>
          <td>0.927</td>
          <td>0.851</td>
          <td>0.713</td>
          <td>0.635</td>
          <td>0.184</td>
          <td>0.596</td>
      </tr>
      <tr>
          <td>Yield-BERT</td>
          <td>0.951</td>
          <td>0.838</td>
          <td>0.836</td>
          <td>0.738</td>
          <td>0.538</td>
          <td>0.738</td>
      </tr>
      <tr>
          <td>T5Chem</td>
          <td>0.970</td>
          <td>0.811</td>
          <td>0.907</td>
          <td>0.789</td>
          <td>0.627</td>
          <td>0.785</td>
      </tr>
      <tr>
          <td>CompoundT5</td>
          <td>0.971</td>
          <td>0.855</td>
          <td>0.852</td>
          <td>0.712</td>
          <td>0.547</td>
          <td>0.741</td>
      </tr>
      <tr>
          <td>ReactionT5</td>
          <td>0.966</td>
          <td>0.914</td>
          <td>0.940</td>
          <td>0.819</td>
          <td>0.896</td>
          <td>0.892</td>
      </tr>
      <tr>
          <td>ReactionT5 (zero-shot)</td>
          <td>0.904</td>
          <td>0.919</td>
          <td>0.927</td>
          <td>0.847</td>
          <td>0.909</td>
          <td>0.900</td>
      </tr>
  </tbody>
</table>
<p>Fine-tuned ReactionT5 reaches an average $R^2$ of 0.892 across Tests 1-4, above every baseline (T5Chem is next at 0.785), and the zero-shot variant is slightly higher still (0.900). The improvement is most dramatic on Test 4, the hardest split, where ReactionT5 achieves $R^2 = 0.896$ versus T5Chem&rsquo;s 0.627 and Yield-BERT&rsquo;s 0.538.</p>
<p>In a low-data regime (30% train / 70% test), ReactionT5 ($R^2 = 0.927$) substantially outperforms a random forest baseline ($R^2 = 0.853$), and even zero-shot ReactionT5 ($R^2 = 0.898$) exceeds the random forest.</p>
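<p>For reference, a short sketch of the metric and of how the out-of-sample averages in the table are formed. The <code>r2_score</code> helper uses the standard coefficient-of-determination definition; whether the paper computes $R^2$ exactly this way is an assumption. The per-test values are copied from the table above.</p>

```python
def r2_score(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean_true = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


def mean(values):
    return sum(values) / len(values)


# Tests 1-4 values from the yield table.
fine_tuned = [0.914, 0.940, 0.819, 0.896]
zero_shot = [0.919, 0.927, 0.847, 0.909]
print(mean(fine_tuned))  # about 0.892
print(mean(zero_shot))   # about 0.900
```

<p>A perfect predictor gives $R^2 = 1$, and always predicting the mean of the true yields gives $R^2 = 0$, which is why values near 0.9 on held-out reaction classes are notable.</p>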
<h2 id="key-findings-and-limitations">Key Findings and Limitations</h2>
<h3 id="key-findings">Key Findings</h3>
<ol>
<li><strong>Two-stage pretraining is effective</strong>: Compound pretraining followed by reaction pretraining produces models with strong generalization, particularly on out-of-distribution test sets.</li>
<li><strong>Few-shot transfer works</strong>: With as few as 30 fine-tuning reactions, ReactionT5 achieves over 80% Top-1 accuracy on product prediction, competitive with models trained on the full USPTO dataset.</li>
<li><strong>Compound restoration matters</strong>: Restoring uncategorized compounds in the ORD is essential for product prediction. Without restoration, fine-tuning on 200 USPTO reactions yields 0% accuracy; with restoration, the same fine-tuning yields 85.5% Top-1.</li>
<li><strong>Zero-shot yield prediction is surprisingly effective</strong>: ReactionT5 achieves $R^2 = 0.900$ on the out-of-sample yield tests without any task-specific fine-tuning, outperforming all fine-tuned baselines.</li>
</ol>
<h3 id="limitations">Limitations</h3>
<ul>
<li>Product prediction shows a high invalidity rate (12.0% for the best ReactionT5 variant) compared to CompoundT5 (7.5%), suggesting the reaction pretraining may introduce some noise.</li>
<li>The 0% accuracy without fine-tuning on product prediction reveals a significant domain gap between ORD and USPTO annotation conventions (byproducts vs. main products).</li>
<li>The RestorationT5 classifier has low precision (0.0878) despite high recall (0.7212), meaning many compounds are incorrectly assigned roles. The paper does not investigate how this impacts downstream performance.</li>
<li>The paper does not report training times, computational costs, or model sizes, making resource requirements unclear.</li>
<li>Only two downstream tasks (product prediction on USPTO, yield prediction on Buchwald-Hartwig) are evaluated.</li>
</ul>
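<p>As a quick consistency check on the RestorationT5 numbers, the standard definition of F1 as the harmonic mean of precision $P$ and recall $R$ gives:</p>
<p>$$
F_1 = \frac{2PR}{P + R} = \frac{2 \times 0.0878 \times 0.7212}{0.0878 + 0.7212} \approx 0.157
$$</p>
<p>This matches the reported F1 of 0.1564 up to rounding of the inputs, and it shows that the low precision, not the recall, is what holds the score down.</p>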
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Compound pretraining</td>
          <td>ZINC</td>
          <td>22,992,522 compounds</td>
          <td>SMILES canonicalized with <a href="https://en.wikipedia.org/wiki/RDKit">RDKit</a></td>
      </tr>
      <tr>
          <td>Reaction pretraining</td>
          <td>ORD (restored)</td>
          <td>1,505,916 reactions</td>
          <td>Atom mapping removed, compounds canonicalized</td>
      </tr>
      <tr>
          <td>Product prediction eval</td>
          <td>USPTO</td>
          <td>479,035 reactions</td>
          <td>409K/30K/40K train/val/test split</td>
      </tr>
      <tr>
          <td>Yield prediction eval</td>
          <td>Buchwald-Hartwig C-N</td>
          <td>3,955 reactions</td>
          <td>Random 7:3 split (10 repeats) + 4 OOS tests</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Base architecture: T5 (text-to-text transfer transformer)</li>
<li>Tokenizer: SentencePiece unigram, trained on ZINC, extended with special reaction tokens</li>
<li>Compound pretraining: Span-masked language modeling (15% masking rate, average span length 3)</li>
<li>Beam search: size 10 for product prediction</li>
<li>Output length constraints: min/max from training data distribution</li>
<li>Yield normalization: clipped to [0, 100], then scaled to [0, 1]</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li>CompoundT5: T5 pretrained on ZINC</li>
<li>RestorationT5: CompoundT5 fine-tuned for binary classification (reactant vs. reagent)</li>
<li>ReactionT5: CompoundT5 pretrained on ORD for product and yield prediction</li>
<li>Pre-trained weights available on Hugging Face</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Task</th>
          <th>Best Value</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Top-1 accuracy</td>
          <td>Product prediction</td>
          <td>85.5%</td>
          <td>ReactionT5 with 200 fine-tuning reactions</td>
      </tr>
      <tr>
          <td>Top-5 accuracy</td>
          <td>Product prediction</td>
          <td>94.9%</td>
          <td>ReactionT5 with 200 fine-tuning reactions</td>
      </tr>
      <tr>
          <td>$R^2$</td>
          <td>Yield prediction (random)</td>
          <td>0.966</td>
          <td>ReactionT5 fine-tuned</td>
      </tr>
      <tr>
          <td>$R^2$</td>
          <td>Yield prediction (OOS avg.)</td>
          <td>0.900</td>
          <td>ReactionT5 zero-shot</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Not specified in the paper. Training times and GPU requirements are not reported.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/sagawatatsuya/ReactionT5v2">ReactionT5v2 (GitHub)</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official implementation</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/sagawa">ReactionT5 models (Hugging Face)</a></td>
          <td>Model</td>
          <td>MIT</td>
          <td>Pre-trained weights</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Sagawa, T. &amp; Kojima, R. (2023). ReactionT5: a large-scale pre-trained model towards application of limited reaction data. <em>arXiv preprint arXiv:2311.06708</em>.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{sagawa2023reactiont5,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{ReactionT5: a large-scale pre-trained model towards application of limited reaction data}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Sagawa, Tatsuya and Kojima, Ryosuke}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{arXiv preprint arXiv:2311.06708}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.48550/arxiv.2311.06708}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>PharMolixFM: Multi-Modal All-Atom Molecular Models</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-simulation/ml-potentials/pharmolixfm-all-atom-foundation-models/</link><pubDate>Sat, 28 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-simulation/ml-potentials/pharmolixfm-all-atom-foundation-models/</guid><description>PharMolixFM unifies diffusion, flow matching, and Bayesian flow networks for all-atom molecular modeling and generation with task-specific denoising priors.</description><content:encoded><![CDATA[<h2 id="a-unified-framework-for-all-atom-molecular-foundation-models">A Unified Framework for All-Atom Molecular Foundation Models</h2>
<p>PharMolixFM is a <strong>Method</strong> paper that introduces a unified framework for constructing all-atom foundation models for molecular modeling and generation. The primary contribution is the systematic implementation of three multi-modal generative model variants (diffusion, flow matching, and Bayesian flow networks) within a single architecture, along with a task-unifying denoising formulation that enables training on multiple structural biology tasks simultaneously. The framework achieves competitive performance on protein-small-molecule docking and structure-based drug design while providing the first empirical analysis of inference scaling laws for molecular generative models.</p>
<h2 id="challenges-in-multi-modal-atomic-modeling">Challenges in Multi-Modal Atomic Modeling</h2>
<p>Existing all-atom foundation models such as AlphaFold3, RoseTTAFold All-Atom, and ESM-AA face two core challenges that limit their generalization across molecular modeling and generation tasks.</p>
<p>First, atomic data is inherently multi-modal: each atom comprises both a discrete atom type and continuous 3D coordinates. This poses challenges for structure models that need to jointly capture and predict both modalities. Unlike text or image data that exhibit a single modality, molecular structures require generative models that can handle discrete categorical variables (atom types, bond types) and continuous variables (coordinates) simultaneously.</p>
<p>Second, there has been no comprehensive analysis of how different training objectives and sampling strategies impact the performance of all-atom foundation models. Prior work has focused on individual model architectures without systematically comparing generative frameworks or studying how inference-time compute scaling affects prediction quality.</p>
<p>PharMolixFM addresses both challenges by providing a unified framework that implements three state-of-the-art multi-modal generative models and formulates all downstream tasks as a generalized denoising process with task-specific priors.</p>
<h2 id="multi-modal-denoising-with-task-specific-priors">Multi-Modal Denoising with Task-Specific Priors</h2>
<p>The core innovation of PharMolixFM is the formulation of molecular tasks as a generalized denoising process where task-specific priors control which parts of the molecular system are noised during training. The framework decomposes a biomolecular system into $N$ atoms represented as a triplet $\bar{\mathbf{S}}_0 = \langle \mathbf{X}_0, \mathbf{A}_0, \mathbf{E}_0 \rangle$, where $\mathbf{X}_0 \in \mathbb{R}^{N \times 3}$ are atom coordinates, $\mathbf{A}_0 \in \mathbb{Z}^{N \times D_1}$ are one-hot atom types, and $\mathbf{E}_0 \in \mathbb{Z}^{N \times N \times D_2}$ are one-hot bond types.</p>
<p>The generative model estimates the density $p_\theta(\langle \mathbf{X}_0, \mathbf{A}_0, \mathbf{E}_0 \rangle)$ subject to SE(3) invariance:</p>
<p>$$
p_\theta(\langle \mathbf{R}\mathbf{X}_0 + \mathbf{t}, \mathbf{A}_0, \mathbf{E}_0 \rangle) = p_\theta(\langle \mathbf{X}_0, \mathbf{A}_0, \mathbf{E}_0 \rangle)
$$</p>
<p>The variational lower bound is optimized over latent variables $S_1, \ldots, S_T$ obtained by adding independent noise to different modalities and atoms:</p>
<p>$$
q(S_{1:T} \mid S_0) = \prod_{i=1}^{T} \prod_{j=1}^{N} q(\mathbf{X}_{i,j} \mid \mathbf{X}_{0,j}, \sigma_{i,j}^{(\mathbf{X})}) \, q(\mathbf{A}_{i,j} \mid \mathbf{A}_{0,j}, \sigma_{i,j}^{(\mathbf{A})}) \, q(\mathbf{E}_{i,j} \mid \mathbf{E}_{0,j}, \sigma_{i,j}^{(\mathbf{E})})
$$</p>
<p>A key design choice is the noise schedule $\sigma_{i,j}^{(\mathcal{M})} = \frac{i}{T} \cdot \text{fix}_j^{(\mathcal{M})}$, where $\text{fix}_j^{(\mathcal{M})}$ is a scaling factor between 0 and 1 that controls which atoms and modalities receive noise. This &ldquo;Fix&rdquo; mechanism enables multiple training tasks:</p>
<ul>
<li><strong>Docking</strong> ($\text{Fix} = 1$ for protein and molecular graph, $\text{Fix} = 0$ for molecule coordinates): predicts binding pose given known atom/bond types.</li>
<li><strong>Structure-based drug design</strong> ($\text{Fix} = 1$ for protein, $\text{Fix} = 0$ for all molecule properties): generates novel molecules for a given pocket.</li>
<li><strong>Robustness augmentation</strong> ($\text{Fix} = 0.7$ for 15% randomly selected atoms, $\text{Fix} = 0$ for rest): simulates partial structure determination.</li>
</ul>
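<p>A minimal sketch of the schedule as written, $\sigma = \frac{i}{T} \cdot \text{fix}$. The function name, the step count, and the choice of step are illustrative assumptions and are not taken from the paper's code.</p>

```python
def noise_scale(step, total_steps, fix):
    """Noise scale sigma = (step / total_steps) * fix for one atom and modality."""
    return (step / total_steps) * fix


total_steps = 10  # hypothetical T
for fix in (0.0, 0.7, 1.0):
    # Noise scale halfway through the schedule for each Fix value in the text.
    print(fix, noise_scale(5, total_steps, fix))
```

<p>Read literally, the formula assigns zero noise at every step when the factor is 0 and the full linear ramp when it is 1, with 0.7 giving a proportionally damped ramp.</p>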
<h3 id="three-generative-model-variants">Three Generative Model Variants</h3>
<p><strong>Multi-modal diffusion (PharMolixFM-Diff)</strong> uses a Markovian forward process. Continuous coordinates follow Gaussian diffusion while discrete variables use a D3PM categorical transition:</p>
<p>$$
q(\mathbf{X}_{i,j} \mid \mathbf{X}_{0,j}) = \mathcal{N}(\sqrt{\alpha_{i,j}} \, \mathbf{X}_{0,j}, (1 - \alpha_{i,j}) \mathbf{I}), \quad \alpha_{i,j} = \prod_{k=1}^{i}(1 - \sigma_{k,j}^{(\mathbf{X})})
$$</p>
<p>$$
q(\mathbf{A}_{i,j} \mid \mathbf{A}_{0,j}) = \text{Cat}(\mathbf{A}_{0,j} \bar{Q}_{i,j}^{(\mathbf{A})}), \quad Q_{i,j}^{(\mathbf{A})} = (1 - \sigma_{i,j}^{(\mathbf{A})}) \mathbf{I} + \frac{\sigma_{i,j}^{(\mathbf{A})}}{D_1} \mathbb{1}\mathbb{1}^T
$$</p>
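<p>The single-step categorical transition can be made concrete with a small sketch. It builds the uniform-noise matrix $(1 - \sigma)\mathbf{I} + \frac{\sigma}{D}\mathbb{1}\mathbb{1}^T$ from the equation above and applies it to a one-hot vector. The function names and the values of $\sigma$ and $D$ are illustrative assumptions.</p>

```python
def uniform_transition(sigma, num_classes):
    """Build Q = (1 - sigma) * I + (sigma / D) * ones for D classes."""
    off_diagonal = sigma / num_classes
    return [
        [(1.0 - sigma if row == col else 0.0) + off_diagonal
         for col in range(num_classes)]
        for row in range(num_classes)
    ]


def apply_transition(one_hot, matrix):
    """Row vector times matrix: class probabilities after one noising step."""
    size = len(matrix)
    return [sum(one_hot[row] * matrix[row][col] for row in range(size))
            for col in range(size)]


q_matrix = uniform_transition(0.2, 4)  # hypothetical sigma and class count
probabilities = apply_transition([0, 1, 0, 0], q_matrix)
print(probabilities)  # mass stays mostly on the original class
print([sum(row) for row in q_matrix])  # every row sums to 1
```

<p>Each row sums to one because the diagonal keeps $1 - \sigma + \sigma / D$ and the remaining $D - 1$ entries each receive $\sigma / D$, so the matrix is a valid transition kernel for any $\sigma$ in $[0, 1]$.</p>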
<p>The training loss combines coordinate MSE with cross-entropy for discrete variables:</p>
<p>$$
\mathcal{L} = \mathbb{E}_{S_0, i, S_i} \left[ \lambda_i^{(\mathbf{X})} \| \tilde{\mathbf{X}}_0 - \mathbf{X}_0 \|_2^2 + \lambda_i^{(\mathbf{A})} \mathcal{L}_{CE}(\tilde{\mathbf{A}}_0, \mathbf{A}_0) + \lambda_i^{(\mathbf{E})} \mathcal{L}_{CE}(\tilde{\mathbf{E}}_0, \mathbf{E}_0) \right]
$$</p>
<p><strong>Multi-modal flow matching (PharMolixFM-Flow)</strong> constructs a direct mapping between data and prior distributions using conditional vector fields. For coordinates, the conditional flow uses a Gaussian path $q(\mathbf{X}_{i,j} \mid \mathbf{X}_{0,j}) = \mathcal{N}((1 - \sigma_{i,j}^{(\mathbf{X})}) \mathbf{X}_{0,j}, (\sigma_{i,j}^{(\mathbf{X})})^2 \mathbf{I})$, while discrete variables use the same D3PM Markov chain. Sampling proceeds by solving an ODE via Euler integration.</p>
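<p>Euler integration itself is simple to sketch. The example below is generic and hedged: <code>euler_integrate</code> is a hypothetical helper, and the toy velocity field stands in for the learned conditional vector field, which in the real model would be a neural network evaluated on the current noisy structure.</p>

```python
import math


def euler_integrate(velocity, x0, t0, t1, num_steps):
    """Integrate dx/dt = velocity(x, t) from t0 to t1 with fixed Euler steps."""
    dt = (t1 - t0) / num_steps
    x, t = list(x0), t0
    for _ in range(num_steps):
        v = velocity(x, t)
        x = [xi + dt * vi for xi, vi in zip(x, v)]
        t += dt
    return x


def toy_velocity(x, t):
    """Toy field dx/dt = -x (assumption), with exact solution x0 * exp(-t)."""
    return [-xi for xi in x]


result = euler_integrate(toy_velocity, [1.0], 0.0, 1.0, 1000)
print(result[0], math.exp(-1.0))  # the two values agree to about three decimals
```

<p>The number of steps trades accuracy for cost: each step requires one evaluation of the vector field, so fewer steps mean faster but coarser sampling.</p>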
<p><strong>Bayesian flow networks (PharMolixFM-BFN)</strong> perform generative modeling in the parameter space of the data distribution rather than the data space. The Bayesian flow distribution for coordinates is:</p>
<p>$$
p_F(\tilde{\mathbf{X}}_{i,j}^{(\theta)} \mid \mathbf{X}_{0,j}) = \mathcal{N}(\gamma_{i,j} \mathbf{X}_{0,j}, \gamma_{i,j}(1 - \gamma_{i,j}) \mathbf{I}), \quad \gamma_{i,j} = 1 - \alpha^{2(1 - \sigma_{i,j}^{(\mathbf{X})})}
$$</p>
<h3 id="network-architecture">Network Architecture</h3>
<p>The architecture follows PocketXMol with a dual-branch SE(3)-equivariant graph neural network. A protein branch (4-layer GNN with kNN graph) processes pocket atoms, then representations are passed to a molecule branch (6-layer GNN) that captures protein-molecule interactions. Independent prediction heads reconstruct atom coordinates, atom types, and bond types, with additional confidence heads for self-ranking during inference.</p>
<h2 id="docking-and-drug-design-experiments">Docking and Drug Design Experiments</h2>
<h3 id="protein-small-molecule-docking">Protein-Small-Molecule Docking</h3>
<p>PharMolixFM is evaluated on the PoseBusters benchmark (428 protein-small-molecule complexes) using the holo docking setting with a known protein structure and 10 Angstrom binding pocket. The metric is the ratio of predictions with RMSD &lt; 2 Angstrom.</p>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>Self-Ranking (%)</th>
          <th>Oracle-Ranking (%)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>DiffDock</td>
          <td>38.0</td>
          <td>-</td>
      </tr>
      <tr>
          <td>RFAA</td>
          <td>42.0</td>
          <td>-</td>
      </tr>
      <tr>
          <td>Vina</td>
          <td>52.3</td>
          <td>-</td>
      </tr>
      <tr>
          <td>UniMol-Docking V2</td>
          <td>77.6</td>
          <td>-</td>
      </tr>
      <tr>
          <td>SurfDock</td>
          <td>78.0</td>
          <td>-</td>
      </tr>
      <tr>
          <td>AlphaFold3</td>
          <td>90.4</td>
          <td>-</td>
      </tr>
      <tr>
          <td>PocketXMol (50 repeats)</td>
          <td>82.2</td>
          <td>95.3</td>
      </tr>
      <tr>
          <td>PharMolixFM-Diff (50 repeats)</td>
          <td>83.4</td>
          <td>96.0</td>
      </tr>
      <tr>
          <td>PharMolixFM-Flow (50 repeats)</td>
          <td>73.4</td>
          <td>93.7</td>
      </tr>
      <tr>
          <td>PharMolixFM-BFN (50 repeats)</td>
          <td>78.5</td>
          <td>93.5</td>
      </tr>
      <tr>
          <td>PharMolixFM-Diff (500 repeats)</td>
          <td>83.9</td>
          <td>98.1</td>
      </tr>
  </tbody>
</table>
<p>PharMolixFM-Diff achieves the second-best self-ranking result (83.4%), outperforming PocketXMol by 1.2 percentage points but trailing AlphaFold3 (90.4%). The key advantage is inference speed: approximately 4.6 seconds per complex on a single A800 GPU compared to approximately 249.0 seconds for AlphaFold3 (a 54x speedup). Under oracle-ranking with 500 repeats, PharMolixFM-Diff reaches 98.1%, suggesting that better ranking strategies could further improve practical performance.</p>
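<p>The difference between the two ranking columns can be sketched as follows. This is an illustration, not the benchmark code: <code>ranking_success</code> is a hypothetical helper, each sampled pose is assumed to carry a confidence score and an RMSD to the reference pose, and the numbers are made up.</p>

```python
def ranking_success(complexes, threshold=2.0):
    """Return (self_ranking_rate, oracle_ranking_rate) for RMSD below threshold.

    complexes: one list per complex of (confidence, rmsd) pairs, one pair per
    sampled pose.
    """
    self_hits = 0
    oracle_hits = 0
    for poses in complexes:
        # Self-ranking keeps the pose the model itself is most confident in.
        top_confidence_pose = max(poses, key=lambda pose: pose[0])
        # Oracle-ranking keeps the pose closest to the reference structure.
        lowest_rmsd_pose = min(poses, key=lambda pose: pose[1])
        if threshold > top_confidence_pose[1]:
            self_hits += 1
        if threshold > lowest_rmsd_pose[1]:
            oracle_hits += 1
    total = len(complexes)
    return self_hits / total, oracle_hits / total


# Hypothetical samples for two complexes.
complexes = [
    [(0.9, 1.2), (0.4, 3.0)],  # confident pose is also accurate
    [(0.8, 2.5), (0.3, 1.0)],  # accurate pose exists but is ranked low
]
print(ranking_success(complexes))
```

<p>Oracle-ranking is an upper bound that requires the ground-truth pose, so it measures what the sampler can produce, while self-ranking measures what the confidence head can actually select.</p>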
<h3 id="structure-based-drug-design">Structure-Based Drug Design</h3>
<p>Evaluation uses the CrossDocked test set (100 protein pockets, 100 molecules generated per pocket), measuring Vina binding affinity scores and drug-likeness properties (QED and SA).</p>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>Vina Score (Avg/Med)</th>
          <th>QED</th>
          <th>SA</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pocket2Mol</td>
          <td>-5.14 / -4.70</td>
          <td>0.57</td>
          <td>0.76</td>
      </tr>
      <tr>
          <td>TargetDiff</td>
          <td>-5.47 / -6.30</td>
          <td>0.48</td>
          <td>0.58</td>
      </tr>
      <tr>
          <td>DecompDiff</td>
          <td>-5.67 / -6.04</td>
          <td>0.45</td>
          <td>0.61</td>
      </tr>
      <tr>
          <td>MolCRAFT</td>
          <td>-6.61 / -8.14</td>
          <td>0.46</td>
          <td>0.62</td>
      </tr>
      <tr>
          <td>PharMolixFM-Diff</td>
          <td>-6.18 / -6.44</td>
          <td>0.50</td>
          <td>0.73</td>
      </tr>
      <tr>
          <td>PharMolixFM-Flow</td>
          <td>-6.34 / -6.47</td>
          <td>0.49</td>
          <td>0.74</td>
      </tr>
      <tr>
          <td>PharMolixFM-BFN</td>
          <td>-6.38 / -6.45</td>
          <td>0.48</td>
          <td>0.64</td>
      </tr>
  </tbody>
</table>
<p>PharMolixFM achieves a better balance between binding affinity and drug-like properties than TargetDiff, DecompDiff, and MolCRAFT. While MolCRAFT achieves the best Vina scores, the PharMolixFM-Diff and PharMolixFM-Flow variants show higher QED (0.49-0.50 vs. 0.45-0.48) and SA (0.73-0.74 vs. 0.58-0.62) than these three baselines, properties that matter for downstream validation and in-vivo application. Pocket2Mol scores higher still on QED (0.57) and SA (0.76) but has the weakest Vina scores in the table.</p>
<h3 id="inference-scaling-law">Inference Scaling Law</h3>
<p>The paper explores whether inference-time scaling holds for molecular generative models, fitting the relationship:</p>
<p>$$
\text{Acc} = a \log(bR + c) + d
$$</p>
<p>where $R$ is the number of sampling repeats. All three PharMolixFM variants exhibit logarithmic improvement in docking accuracy with increased sampling repeats, analogous to inference scaling laws observed in NLP. Performance plateaus eventually due to distributional differences between training and test sets.</p>
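<p>The functional form is easy to explore numerically. In the sketch below the coefficients are hypothetical placeholders chosen for illustration, not the fitted values from the paper, and <code>scaling_law</code> is an assumed helper name.</p>

```python
import math


def scaling_law(repeats, a, b, c, d):
    """Accuracy model Acc = a * log(b * R + c) + d."""
    return a * math.log(b * repeats + c) + d


# Hypothetical coefficients, for illustration only.
a, b, c, d = 2.0, 1.0, 0.0, 80.0
for repeats in (10, 100, 1000):
    print(repeats, round(scaling_law(repeats, a, b, c, d), 2))
```

<p>With $c = 0$, every tenfold increase in repeats adds the same amount, $a \ln 10$, to the predicted accuracy. That is the sense in which gains are logarithmic: the compute cost per additional point of accuracy grows exponentially.</p>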
<h2 id="competitive-docking-with-faster-inference-but-limited-task-scope">Competitive Docking with Faster Inference, but Limited Task Scope</h2>
<p>PharMolixFM demonstrates that multi-modal generative models can achieve competitive all-atom molecular modeling with substantial inference speed advantages over AlphaFold3. The key findings are:</p>
<ol>
<li><strong>Diffusion outperforms flow matching and BFN</strong> for docking under standard sampling budgets. The stochastic nature of diffusion sampling appears beneficial compared to the deterministic ODE integration of flow matching.</li>
<li><strong>Oracle-ranking reveals untapped potential</strong>: the gap between self-ranking (83.9%) and oracle-ranking (98.1%) at 500 repeats indicates that confidence-based ranking is a bottleneck. Better ranking methods could close the gap with AlphaFold3.</li>
<li><strong>The three variants show similar performance for drug design</strong>, suggesting that model architecture and training data may matter more than the generative framework for generation tasks.</li>
<li><strong>Inference scaling laws hold</strong> for molecular generative models, paralleling findings in NLP.</li>
</ol>
<p>The main limitation is scope. The framework is evaluated on only two tasks (docking and SBDD), and the paper does not address protein structure prediction, protein-protein interactions, or nucleic acid modeling, all of which fall within AlphaFold3&rsquo;s scope. The BFN variant underperforms the diffusion model, which the authors attribute to smaller noise scales at early sampling steps making training less challenging. The paper also does not compare against concurrent work on inference-time scaling for molecular models.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training</td>
          <td>PDBBind, Binding MOAD, CrossDocked2020, PepBDB</td>
          <td>Not specified</td>
          <td>Filtered by PocketXMol criteria</td>
      </tr>
      <tr>
          <td>Docking eval</td>
          <td>PoseBusters benchmark</td>
          <td>428 complexes</td>
          <td>Holo docking with known protein</td>
      </tr>
      <tr>
          <td>SBDD eval</td>
          <td>CrossDocked test set</td>
          <td>100 pockets</td>
          <td>100 molecules per pocket</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Three generative variants: multi-modal diffusion (D3PM), flow matching, Bayesian flow networks</li>
<li>Task-specific noise via Fix mechanism (0, 0.7, or 1.0)</li>
<li>Training tasks selected with equal probability per sample</li>
<li>AdamW optimizer: weight decay 0.001, $\beta_1 = 0.99$, $\beta_2 = 0.999$</li>
<li>Linear warmup to learning rate 0.001 over 1000 steps</li>
<li>180K training steps with batch size 40</li>
</ul>
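<p>The warmup listed above can be written as a step-dependent learning rate. This is a minimal sketch under stated assumptions: the notes only give the peak rate (0.001) and the warmup length (1000 steps), so the function name <code>learning_rate</code> and the constant rate after warmup are assumptions for illustration.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">def learning_rate(step, peak_lr=1e-3, warmup_steps=1000):
    """Linear warmup from 0 to peak_lr; holding it constant afterwards is assumed."""
    scale = min(step / warmup_steps, 1.0)
    return peak_lr * scale


# Rates at the start, mid-warmup, end of warmup, and the final training step.
schedule = [learning_rate(s) for s in (0, 500, 1000, 180000)]
</code></pre></div>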
<h3 id="models">Models</h3>
<ul>
<li>Dual-branch SE(3)-equivariant GNN (protein: 4-layer, molecule: 6-layer)</li>
<li>kNN graph construction for protein and protein-molecule interactions</li>
<li>Independent prediction heads for coordinates, atom types, bond types</li>
<li>Confidence heads for self-ranking during inference</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>PharMolixFM-Diff</th>
          <th>AlphaFold3</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>RMSD &lt; 2 Å self-ranking</td>
          <td>83.4% (50 rep)</td>
          <td>90.4%</td>
          <td>PoseBusters docking</td>
      </tr>
      <tr>
          <td>RMSD &lt; 2 Å oracle-ranking</td>
          <td>98.1% (500 rep)</td>
          <td>-</td>
          <td>PoseBusters docking</td>
      </tr>
      <tr>
          <td>Inference time (per complex)</td>
          <td>~4.6s</td>
          <td>~249.0s</td>
          <td>Single A800 GPU</td>
      </tr>
      <tr>
          <td>Vina score (avg)</td>
          <td>-6.18</td>
          <td>-</td>
          <td>CrossDocked SBDD</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li>Training: 4x 80GB A800 GPUs</li>
<li>Inference benchmarked on single A800 GPU</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/PharMolix/OpenBioMed">OpenBioMed (GitHub)</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official implementation</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Luo, Y., Wang, J., Fan, S., &amp; Nie, Z. (2025). PharMolixFM: All-Atom Foundation Models for Molecular Modeling and Generation. <em>arXiv preprint arXiv:2503.21788</em>.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{luo2025pharmolixfm,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{PharMolixFM: All-Atom Foundation Models for Molecular Modeling and Generation}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Luo, Yizhen and Wang, Jiashuo and Fan, Siqi and Nie, Zaiqing}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{arXiv preprint arXiv:2503.21788}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2025}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>PharmaGPT: Domain-Specific LLMs for Pharma and Chem</title><link>https://hunterheidenreich.com/notes/chemistry/llm-applications/pharmagpt-domain-specific-llms-biopharmaceutical/</link><pubDate>Sat, 28 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/llm-applications/pharmagpt-domain-specific-llms-biopharmaceutical/</guid><description>PharmaGPT introduces 13B and 70B parameter LLMs trained on biopharmaceutical and chemical corpora, outperforming GPT-3.5 and rivaling GPT-4 on pharmacy exams.</description><content:encoded><![CDATA[<h2 id="a-domain-specific-llm-suite-for-biopharmaceuticals-and-chemistry">A Domain-Specific LLM Suite for Biopharmaceuticals and Chemistry</h2>
<p>This is a <strong>Method</strong> paper that introduces PharmaGPT, a suite of domain-specific large language models with 13 billion and 70 billion parameters. The models are built on the LLaMA architecture and undergo continued pretraining on a curated corpus of biopharmaceutical and chemical literature, followed by instruction fine-tuning and reinforcement learning from human feedback (RLHF). The primary contribution is demonstrating that domain-specific continued pretraining on a general-purpose LLM backbone can produce models that outperform much larger general-purpose models on pharmaceutical knowledge tasks, using only a fraction of the parameters.</p>
<h2 id="bridging-the-gap-between-general-purpose-llms-and-specialized-pharmaceutical-knowledge">Bridging the Gap Between General-Purpose LLMs and Specialized Pharmaceutical Knowledge</h2>
<p>General-purpose LLMs like GPT-3.5 and GPT-4 show impressive broad capabilities but often fall short in specialized domains requiring precise terminology, deep domain knowledge, and high accuracy. The biopharmaceutical and chemical sectors present particular challenges: intricate terminologies, specialized regulatory knowledge, and a demand for precision that general models cannot consistently deliver. Most state-of-the-art LLMs are proprietary, English-centric, and lack depth in vertical domains. The authors identify a gap in the availability of domain-specific LLMs for biomedicine and chemistry, particularly multilingual models that can handle both English and Chinese pharmaceutical content.</p>
<h2 id="continued-pretraining-with-domain-specific-data-and-weighted-instruction-tuning">Continued Pretraining with Domain-Specific Data and Weighted Instruction Tuning</h2>
<p>PharmaGPT&rsquo;s core innovation lies in its training pipeline, which adapts the LLaMA backbone through three stages:</p>
<p><strong>Extended Tokenizer</strong>: The authors develop a new tokenizer using <a href="https://en.wikipedia.org/wiki/Byte-pair_encoding">byte-pair encoding (BPE)</a> from SentencePiece, trained on their pretraining data and merged with the LLaMA2 tokenizer. This extends the vocabulary from 32,000 to 55,296 tokens, improving compression efficiency for Chinese text and specialized domain terminology. The embedding and output layers are resized from $V \times H$ to $V^{\prime} \times H$, where $V = 32{,}000$ and $V^{\prime} = 55{,}296$.</p>
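<p>Resizing the embedding table amounts to appending rows for the new tokens while keeping the existing rows. The sketch below uses toy sizes and plain Python lists. How the new rows are initialized is not stated in these notes, so the mean-plus-noise initialization and the helper name <code>extend_embedding</code> are assumptions for illustration.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import random


def extend_embedding(table, new_vocab_size, seed=0):
    """Grow a V x H table (list of rows) to new_vocab_size x H, keeping existing rows."""
    rng = random.Random(seed)
    # Assumed initialization: column-wise mean of the existing rows plus small noise.
    mean_row = [sum(col) / len(table) for col in zip(*table)]
    extra = new_vocab_size - len(table)
    new_rows = [[m + rng.gauss(0.0, 0.02) for m in mean_row] for _ in range(extra)]
    return table + new_rows


# Toy sizes for illustration; the resize described above goes from 32,000 to
# 55,296 rows, which adds 55,296 - 32,000 = 23,296 rows to each resized matrix.
old_table = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9], [1.0, 1.1, 1.2]]
new_table = extend_embedding(old_table, new_vocab_size=7)
</code></pre></div>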
<p><strong>Two-Stage Continued Pretraining</strong>: The models consume 153 billion tokens in Stage 1 (primarily web, news, patents, and papers) and 43 billion tokens in Stage 2 (research reports, exams, books, chats, code, and supervised data). The data distribution shifts between stages to move from general domain knowledge toward specialized biopharmaceutical tasks.</p>
<p><strong>Weighted Instruction Fine-tuning</strong>: Inspired by OpenChat, the authors use a weighted autoregressive objective that zeros out loss on user instruction tokens. The loss function is:</p>
<p>$$\mathcal{L}_{SFT}(\Theta) = \mathbb{E}_{x \sim \mathcal{D}_{SFT}} \left[ -\alpha \sum_{i \in \text{output}} \log p(x_i \mid x_0, x_1, \dots, x_{i-1}; \Theta) \right]$$</p>
<p>where the weight $\alpha$ is set to 1 for expert-curated domain-specific instructions ($\mathcal{D}_{\exp}$) and 0.1 for generic instructions ($\mathcal{D}_{\text{gen}}$). This differential weighting ensures domain-relevant instructions receive higher priority during training.</p>
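<p>A minimal sketch of this objective for a single training example is shown below. The per-token probabilities and the mask are hypothetical values chosen for illustration, and the function name <code>weighted_sft_loss</code> is not from the paper. Only tokens marked as output contribute to the loss, and the whole example is scaled by $\alpha$.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import math


def weighted_sft_loss(token_probs, is_output, alpha):
    """Negative log-likelihood summed over output tokens only, scaled by alpha."""
    nll = -sum(math.log(p) for p, out in zip(token_probs, is_output) if out)
    return alpha * nll


# Hypothetical example: two instruction tokens followed by two response tokens.
probs = [0.9, 0.8, 0.5, 0.25]
mask = [False, False, True, True]

expert_loss = weighted_sft_loss(probs, mask, alpha=1.0)   # expert-curated example
generic_loss = weighted_sft_loss(probs, mask, alpha=0.1)  # generic example
</code></pre></div>
<p>With identical token probabilities, the generic example contributes one tenth of the loss of the expert example, which is how the weighting prioritizes domain data.</p>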
<p><strong>RLHF with PPO</strong>: A reward model is initialized from the pretrained PharmaGPT-70B and enhanced with two MLPs to output a scalar preference score. The reward model is trained with a binary ranking loss:</p>
<p>$$\mathcal{L}_{\text{ranking}} = -\log\left(\sigma\left(r_\theta(x, y_c) - r_\theta(x, y_r)\right)\right)$$</p>
<p>where $r_\theta(x, y_c)$ is the score for the preferred response and $r_\theta(x, y_r)$ is the score for the rejected response. The RLHF dataset consists of 50,000 expert-annotated human preference instructions, with responses drawn from PharmaGPT variants and commercial LLMs (GPT-4, ChatGPT-3.5). <a href="https://en.wikipedia.org/wiki/Proximal_policy_optimization">Proximal Policy Optimization (PPO)</a> is used for the RL training; at each step, the highest-scoring of four generated candidates is selected.</p>
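<p>The ranking loss and the candidate selection step can be sketched as follows. The scores are hypothetical, the helper names are not from the paper, and the reward function in the usage lines is a stand-in (string length) used only to show the selection mechanics. The loss uses the identity $-\log \sigma(d) = \log(1 + e^{-d})$.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import math


def ranking_loss(score_chosen, score_rejected):
    """-log(sigmoid(r_c - r_r)), computed as log(1 + exp(-(r_c - r_r)))."""
    return math.log1p(math.exp(-(score_chosen - score_rejected)))


def best_of(candidates, reward_fn):
    """Return the candidate with the highest reward score."""
    return max(candidates, key=reward_fn)


# Equal scores give log(2); a larger margin for the preferred response lowers the loss.
tie_loss = ranking_loss(1.0, 1.0)
good_loss = ranking_loss(3.0, 0.0)

# Stand-in reward (string length) over four hypothetical candidates.
picked = best_of(["a", "bbb", "cc", "d"], reward_fn=len)
</code></pre></div>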
<h2 id="evaluation-on-pharmacy-licensing-exams-translation-and-mmlu">Evaluation on Pharmacy Licensing Exams, Translation, and MMLU</h2>
<p>The evaluation covers four main benchmarks:</p>
<p><strong>NAPLEX (North American Pharmacist Licensure Examination)</strong>: PharmaGPT is tested across three NAPLEX sections. Results show consistent improvement across model iterations:</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>NAPLEX I</th>
          <th>NAPLEX II</th>
          <th>NAPLEX III</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>PharmaGPT 0.1</td>
          <td>5.0</td>
          <td>2.5</td>
          <td>3.5</td>
      </tr>
      <tr>
          <td>PharmaGPT 0.3</td>
          <td>42.0</td>
          <td>48.0</td>
          <td>46.5</td>
      </tr>
      <tr>
          <td>PharmaGPT 0.5</td>
          <td>57.0</td>
          <td>59.0</td>
          <td>58.0</td>
      </tr>
      <tr>
          <td>PharmaGPT 0.7</td>
          <td>66.0</td>
          <td>68.0</td>
          <td>76.0</td>
      </tr>
  </tbody>
</table>
<p>PharmaGPT 0.7 scores in the 66-76% range across all three NAPLEX sections, well above GPT-3.5-turbo, which scores roughly 50% on each section (estimated from the paper&rsquo;s figures).</p>
<p><strong>Chinese Pharmacist Examination</strong>: PharmaGPT achieves scores in the 70% range across all four exam categories, outperforming both GPT-3.5-turbo and GPT-4 in all categories. This result is notable given GPT-4&rsquo;s much larger scale.</p>
<p><strong>Biomedical Translation</strong>: PharmaGPT 0.7 outperforms GPT-3.5, Claude 3, and Google Translate on biomedical paper translation (English-Chinese), achieving <a href="https://en.wikipedia.org/wiki/BLEU">BLEU</a> scores of 30 (paragraph-level), 18 (sentence-level), and 10 (word-level).</p>
<p><strong>MMLU</strong>: On the general Massive Multitask Language Understanding benchmark, PharmaGPT achieves scores in the 80% range across most biomedical and life science tasks, surpassing GPT-3.5-turbo and performing comparably to GPT-4 in areas such as physiology, health sciences, and biology.</p>
<h2 id="strong-domain-performance-with-smaller-scale-but-limited-reproducibility">Strong Domain Performance with Smaller Scale, but Limited Reproducibility</h2>
<p><strong>Key findings</strong>:</p>
<ul>
<li>Domain-specific continued pretraining enables a 70B parameter model to match or exceed GPT-4 on pharmaceutical knowledge tasks, despite having a fraction of GPT-4&rsquo;s parameters</li>
<li>Iterative post-training (versions 0.1 through 0.7) shows consistent improvement, with the largest gains in the NAPLEX table above occurring between versions 0.1 and 0.3</li>
<li>The two-stage pretraining strategy, shifting from general domain data to more specialized exam and report data, appears effective for building domain expertise</li>
<li>Scaling laws hold within the PharmaGPT family: larger parameter counts consistently produce better performance on both NAPLEX and Chinese pharmaceutical exams</li>
</ul>
<p><strong>Limitations acknowledged by the authors</strong>:</p>
<ul>
<li>Potential biases in the training data</li>
<li>Model dependency on the quality and diversity of input prompts</li>
<li>Challenges in accurately assessing performance on highly specialized tasks without domain expert evaluation</li>
<li>Interpretability concerns for use in sensitive healthcare and pharmaceutical applications</li>
<li>The 3B model is trained from scratch while the 13B and 70B models use LLaMA as a backbone, making direct comparison across model sizes less straightforward</li>
</ul>
<p><strong>Missing details</strong>: The paper does not release model weights, training code, or the proprietary training dataset. No ablation studies isolate the contribution of each training stage (continued pretraining vs. instruction tuning vs. RLHF). The evaluation is limited to multiple-choice exams and translation, without testing on molecular property prediction, reaction prediction, or other computational chemistry tasks common in this domain.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pretraining Stage 1</td>
          <td>Web, News, Patents, Papers</td>
          <td>153B tokens</td>
          <td>Proprietary corpus; not publicly available</td>
      </tr>
      <tr>
          <td>Pretraining Stage 2</td>
          <td>Research Reports, Exams, Books, Chats, Code</td>
          <td>43B tokens</td>
          <td>Proprietary corpus; not publicly available</td>
      </tr>
      <tr>
          <td>Instruction Tuning</td>
          <td>Manually labeled + synthesized data</td>
          <td>Several hundred thousand instructions</td>
          <td>Includes expert Q&amp;A, patent data, ShareGPT</td>
      </tr>
      <tr>
          <td>RLHF</td>
          <td>Human preference annotations</td>
          <td>50,000 annotated instructions</td>
          <td>Expert annotators ranked responses</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>NAPLEX, Chinese Pharmacist Exam, MMLU, MT</td>
          <td>Not specified</td>
          <td>Exam datasets sourced from public exams</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Base architecture</strong>: LLaMA (13B and 70B variants); 3B model trained from scratch</li>
<li><strong>Tokenizer</strong>: Extended BPE tokenizer (55,296 vocab size) merged with LLaMA2 tokenizer</li>
<li><strong>Training objective</strong>: Standard autoregressive LM (pretraining), weighted autoregressive with $\alpha \in \{0.1, 1.0\}$ (SFT), PPO (RLHF)</li>
<li><strong>Reward model</strong>: Initialized from PharmaGPT-70B with two additional MLPs</li>
</ul>
<h3 id="models">Models</h3>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Parameters</th>
          <th>Base</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>PharmaGPT-3B</td>
          <td>3B</td>
          <td>Trained from scratch</td>
          <td>Not evaluated in main results</td>
      </tr>
      <tr>
          <td>PharmaGPT-13B</td>
          <td>13B</td>
          <td>LLaMA-13B</td>
          <td>Post-trained</td>
      </tr>
      <tr>
          <td>PharmaGPT-70B</td>
          <td>70B</td>
          <td>LLaMA-70B</td>
          <td>Primary model; versions 0.1-0.7 reported</td>
      </tr>
  </tbody>
</table>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>PharmaGPT 0.7</th>
          <th>GPT-3.5</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>NAPLEX I</td>
          <td>66%</td>
          <td>~50%</td>
          <td>Estimated from figures</td>
      </tr>
      <tr>
          <td>NAPLEX II</td>
          <td>68%</td>
          <td>~50%</td>
          <td>Estimated from figures</td>
      </tr>
      <tr>
          <td>NAPLEX III</td>
          <td>76%</td>
          <td>~50%</td>
          <td>Estimated from figures</td>
      </tr>
      <tr>
          <td>Chinese Pharmacist Exam</td>
          <td>~70% range</td>
          <td>Lower</td>
          <td>Outperforms GPT-4</td>
      </tr>
      <tr>
          <td>Biomedical Translation (paragraph BLEU)</td>
          <td>30</td>
          <td>27</td>
          <td>English-Chinese</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>The paper does not specify the hardware used for training. Training hyperparameters for the 70B model include tensor parallelism (TP=8) and pipeline parallelism (PP=16) during pretraining, suggesting multi-node GPU training, likely on at least 128 GPUs.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>PharmaGPT models</td>
          <td>Model</td>
          <td>Not released</td>
          <td>No public weights or API access</td>
      </tr>
      <tr>
          <td>Training data</td>
          <td>Dataset</td>
          <td>Proprietary</td>
          <td>PatSnap internal data</td>
      </tr>
      <tr>
          <td>Training code</td>
          <td>Code</td>
          <td>Not released</td>
          <td>No public repository</td>
      </tr>
  </tbody>
</table>
<p><strong>Reproducibility status</strong>: <strong>Closed</strong>. The model weights, training data, and training code are all unavailable to the public. The proprietary nature of both the data pipeline and the models makes independent reproduction infeasible.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Chen, L., Wang, W., Bai, Z., Xu, P., Fang, Y., Fang, J., &hellip; &amp; Tu, C. (2024). PharmaGPT: Domain-Specific Large Language Models for Bio-Pharmaceutical and Chemistry. <em>arXiv preprint arXiv:2406.18045</em>.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{chen2024pharmagpt,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{PharmaGPT: Domain-Specific Large Language Models for Bio-Pharmaceutical and Chemistry}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Chen, Linqing and Wang, Weilei and Bai, Zilong and Xu, Peng and Fang, Yan and Fang, Jie and Wu, Wentao and Zhou, Lizhi and Zhang, Ruiji and Xia, Yubin and Xu, Chaobo and Hu, Ran and Xu, Licong and Cai, Qijun and Hua, Haoran and Sun, Jing and Liu, Jin and Qiu, Tian and Liu, Haowen and Hu, Meng and Li, Xiuwen and Gao, Fei and Wang, Yufu and Tie, Lin and Wang, Chaochao and Lu, Jianping and Sun, Cheng and Wang, Yixin and Yang, Shengjie and Li, Yuancheng and Jin, Lu and Zhang, Lisha and Bian, Fu and Ye, Zhongkai and Pei, Lidong and Tu, Changyang}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{arXiv preprint arXiv:2406.18045}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2024}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.48550/arXiv.2406.18045}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Neural Machine Translation for Reaction Prediction</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/reaction-prediction/nmt-organic-reaction-prediction/</link><pubDate>Sat, 28 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/reaction-prediction/nmt-organic-reaction-prediction/</guid><description>Nam and Kim apply a GRU-based seq2seq model with attention to predict organic reaction products from SMILES, pioneering the NMT approach to chemistry.</description><content:encoded><![CDATA[<h2 id="pioneering-seq2seq-translation-for-reaction-prediction">Pioneering Seq2Seq Translation for Reaction Prediction</h2>
<p>This is a <strong>Method</strong> paper. It introduces the idea of applying neural machine translation (NMT) to organic chemistry reaction prediction by framing product prediction as a sequence-to-sequence translation problem from reactant/reagent <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> to product SMILES. This was one of the earliest works to demonstrate that a data-driven encoder-decoder model could predict reaction products without any hand-coded reaction rules or SMARTS transformations.</p>
<h2 id="limitations-of-existing-reaction-prediction-methods">Limitations of Existing Reaction Prediction Methods</h2>
<p>Prior computational approaches to reaction prediction fell into three categories, each with significant drawbacks:</p>
<ol>
<li>
<p><strong>Rule-based methods</strong> (e.g., CAMEO, EROS) relied on manually encoded reaction rules. They performed well on reactions covered by the rules but required continuous manual encoding as new reaction types were discovered. Many older systems became outdated for this reason.</p>
</li>
<li>
<p><strong>Physical calculation methods</strong> computed energies of transition states from plausible reaction pathways using quantum mechanics. While principled, these approaches carried high computational cost. Simplified approaches (ToyChem, ROBIA) traded accuracy for speed.</p>
</li>
<li>
<p><strong>Machine learning methods</strong> at the time either predicted individual mechanistic steps (requiring tree search for multi-step reactions) or classified reaction types and applied SMARTS transformations to generate products. The classification-based approach of Wei et al. still required manual encoding of SMARTS transformations for new reaction types and struggled with ambiguous reaction classes.</p>
</li>
</ol>
<p>The key gap was the absence of a method that could predict reaction products directly from input molecules, learn from data alone, and generalize to new reaction types without manual rule encoding.</p>
<h2 id="core-innovation-reactions-as-machine-translation">Core Innovation: Reactions as Machine Translation</h2>
<p>The central insight is that SMILES strings can be treated as a language with grammatical specifications. Predicting reaction products then becomes a problem of translating &ldquo;reactant and reagent&rdquo; sentences into &ldquo;product&rdquo; sentences.</p>
<p>The model uses a <a href="https://en.wikipedia.org/wiki/Gated_recurrent_unit">GRU</a>-based encoder-decoder architecture with attention:</p>
<ul>
<li><strong>Encoder</strong>: 3 layers of GRU cells that process the reversed, tokenized SMILES string of reactants and reagents</li>
<li><strong>Decoder</strong>: 3 layers of GRU cells that generate product SMILES tokens autoregressively</li>
<li><strong>Attention mechanism</strong>: allows the decoder to attend to relevant encoder states at each generation step</li>
<li><strong>Embedding dimension</strong>: 600</li>
<li><strong>Vocabulary</strong>: 311 input tokens (reactants/reagents), 180 output tokens (products)</li>
<li><strong>Bucketed sequences</strong>: four bucket sizes handle variable-length inputs and outputs: (54, 54), (70, 60), (90, 65), (150, 80)</li>
</ul>
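<p>Bucketing assigns each reaction to the smallest bucket that can hold both its input and output token sequences, which are then padded to the bucket size. The sketch below shows the selection step only. The function name <code>pick_bucket</code> and the inclusive length bounds are assumptions for illustration.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"># (max input length, max output length), ordered from smallest to largest.
BUCKETS = [(54, 54), (70, 60), (90, 65), (150, 80)]


def pick_bucket(input_len, output_len, buckets=BUCKETS):
    """Return the first (smallest) bucket that fits both sequences, or None."""
    for max_in, max_out in buckets:
        if max_in >= input_len and max_out >= output_len:
            return (max_in, max_out)
    return None  # too long for every bucket; such reactions would be dropped


small = pick_bucket(50, 50)       # fits the first bucket
long_output = pick_bucket(60, 62) # output too long for (70, 60), so (90, 65)
too_long = pick_bucket(151, 10)   # exceeds the largest input size
</code></pre></div>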
<p>The SMILES tokenization uses a <a href="https://en.wikipedia.org/wiki/Parsing_expression_grammar">PEG</a>-based parser that splits SMILES strings into atoms, bonds, branching symbols, and ring closure numbers. Input sequences are reversed before feeding to the encoder, following standard practice in NMT at the time.</p>
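<p>The tokenize-then-reverse step can be illustrated with a small example. The paper uses a PEG-based parser; the sketch below swaps in a simplified regular expression instead, which is an assumption made for brevity and does not cover the full SMILES grammar. The function names are hypothetical.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import re

# Simplified pattern: bracket atoms, two-letter halogens, two-digit ring closures,
# then any other single character (atoms, bonds, branch symbols, ring digits).
SMILES_TOKEN = re.compile(r"\[[^\]]+\]|Br|Cl|%\d{2}|.")


def tokenize(smiles):
    """Split a SMILES string into a list of tokens."""
    return SMILES_TOKEN.findall(smiles)


def encoder_input(smiles):
    """Tokenize, then reverse the token order before feeding the encoder."""
    return tokenize(smiles)[::-1]


tokens = tokenize("CC(=O)Cl")                # ['C', 'C', '(', '=', 'O', ')', 'Cl']
reversed_tokens = encoder_input("CC(=O)Cl")  # ['Cl', ')', 'O', '=', '(', 'C', 'C']
</code></pre></div>
<p>Reversal is applied to tokens rather than characters, so multi-character tokens such as <code>Cl</code> or <code>[N+]</code> stay intact.</p>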
<p>The translation objective finds the product sequence $\mathbf{y}$ that maximizes the conditional probability:</p>
<p>$$p(\mathbf{y} \mid \mathbf{x}) = \prod_{t=1}^{T} p(y_t \mid y_1, \ldots, y_{t-1}, \mathbf{x})$$</p>
<p>where $\mathbf{x}$ is the tokenized reactant/reagent sequence and $T$ is the product sequence length.</p>
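<p>As a small worked example with hypothetical numbers (not from the paper), suppose a three-token product is generated with per-step conditional probabilities 0.9, 0.8, and 0.5. The factorization then gives:</p>
<p>$$p(\mathbf{y} \mid \mathbf{x}) = 0.9 \times 0.8 \times 0.5 = 0.36, \qquad \log p(\mathbf{y} \mid \mathbf{x}) = \log 0.9 + \log 0.8 + \log 0.5 \approx -1.02$$</p>
<p>Working with the sum of log-probabilities is the usual choice in practice, since the product of many per-token probabilities shrinks quickly as $T$ grows.</p>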
<h2 id="training-data-and-experimental-evaluation">Training Data and Experimental Evaluation</h2>
<h3 id="training-sets">Training Sets</h3>
<p>Two training sets were constructed:</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Source</th>
          <th style="text-align: left">Size</th>
          <th style="text-align: left">Description</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left">Patent reactions (&ldquo;real&rdquo;)</td>
          <td style="text-align: left">1,094,235</td>
          <td style="text-align: left">USPTO patent applications (2001-2013), filtered by length</td>
      </tr>
      <tr>
          <td style="text-align: left">Generated reactions (&ldquo;gen&rdquo;)</td>
          <td style="text-align: left">865,118</td>
          <td style="text-align: left">75 reaction types from Wade&rsquo;s organic chemistry textbook, applied to <a href="/notes/chemistry/datasets/gdb-11/">GDB-11</a> molecules (1-10 atoms)</td>
      </tr>
  </tbody>
</table>
<p>The &ldquo;real&rdquo; set was filtered to exclude reactions with reactant/reagent strings longer than 150 characters, product strings longer than 80 characters, or more than four products. The &ldquo;gen&rdquo; set was composed by iterating reaction templates (as SMARTS) over small molecules from GDB-11, covering five substrate types: acid derivatives, alcohols, aldehydes/ketones, alkenes, and alkynes.</p>
<p>Two models were compared: a &ldquo;gen&rdquo; model (trained only on generated reactions) and a &ldquo;real+gen&rdquo; model (trained on both sets).</p>
<h3 id="textbook-problem-evaluation">Textbook Problem Evaluation</h3>
<p>The models were tested on 10 problem sets from Wade&rsquo;s textbook, following the evaluation approach of Wei et al. Each problem set contained 6-15 reactions. Evaluation metrics included the ratio of fully correct predictions and the average <a href="https://en.wikipedia.org/wiki/Jaccard_index">Tanimoto similarity</a> between Morgan fingerprints of predicted and actual products.</p>
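<p>The Tanimoto similarity between two fingerprints is the number of shared on-bits divided by the number of bits set in either. The sketch below takes fingerprints as sets of on-bit indices. The bit sets are hypothetical: a real evaluation would first compute Morgan fingerprints with a cheminformatics toolkit, which is omitted here.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">def tanimoto(bits_a, bits_b):
    """Tanimoto (Jaccard) similarity between two sets of on-bit indices."""
    a, b = set(bits_a), set(bits_b)
    union = a.union(b)
    if not union:
        return 1.0  # assumed convention for two empty fingerprints
    return len(a.intersection(b)) / len(union)


# Hypothetical fingerprints for a predicted and an actual product.
predicted_bits = {1, 2, 3, 4}
actual_bits = {3, 4, 5}

score = tanimoto(predicted_bits, actual_bits)  # 2 shared bits out of 5 total
</code></pre></div>
<p>A score of 1.0 means the fingerprints are identical, so partially correct predictions receive intermediate credit even when the exact-match metric counts them as wrong.</p>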
<p>The &ldquo;real+gen&rdquo; model outperformed the &ldquo;gen&rdquo; model on most problem sets. On problem set 17-44 (aromatic compound reactions, only present in the &ldquo;real&rdquo; training set), the &ldquo;real+gen&rdquo; model correctly answered 4 out of 11 problems while the &ldquo;gen&rdquo; model answered 2. The &ldquo;gen&rdquo; model&rsquo;s ability to correctly predict some aromatic reactions despite never being trained on them suggests the model can extrapolate to unseen reaction patterns.</p>
<p>For <a href="https://en.wikipedia.org/wiki/Diels%E2%80%93Alder_reaction">Diels-Alder reactions</a> (problem set 15-30), neither model achieved fully correct predictions for all problems, though the &ldquo;real+gen&rdquo; model showed better Tanimoto scores, indicating partially correct structural predictions even when the exact product was missed.</p>
<h3 id="scalability-testing">Scalability Testing</h3>
<p>A scalability test used generated reactions with substrate molecules containing 11-16 atoms (larger than the training set molecules with fewer than 11 atoms). Results showed:</p>
<ul>
<li>The &ldquo;real+gen&rdquo; model maintained Tanimoto scores around 0.7 and error rates around 0.4 as substrate atom count increased</li>
<li>The ratio of fully correct predictions decreased as atom count increased, revealing that the recurrent network struggled with longer input sequences</li>
<li>The &ldquo;real+gen&rdquo; model produced fewer invalid SMILES strings than the &ldquo;gen&rdquo; model, likely because training on more reactions improved the decoder&rsquo;s ability to generate syntactically valid SMILES</li>
</ul>
<h3 id="attention-analysis">Attention Analysis</h3>
<p>Visualization of attention weights revealed a limitation: the decoder cells predominantly attended to the first few encoder cells rather than distributing attention across the full input sequence. This means the attention mechanism was not learning meaningful &ldquo;alignment&rdquo; between reactant atoms and product atoms. The authors note that if decoder cells generating tokens for unreactive sites could attend to the corresponding encoder cells (analogous to atom mapping), prediction quality on longer sequences could improve.</p>
<h3 id="token-embedding-analysis">Token Embedding Analysis</h3>
<p>t-SNE visualization of the learned token embeddings showed that encoder and decoder tokens clustered primarily by syntactic similarity rather than chemical properties. The model did not learn chemically meaningful embeddings, which the authors identify as an area for future improvement.</p>
<h2 id="key-findings-limitations-and-impact">Key Findings, Limitations, and Impact</h2>
<h3 id="key-findings">Key Findings</h3>
<ul>
<li>Treating reaction prediction as NMT is viable: the seq2seq model can predict products without any hand-coded rules</li>
<li>Training on real patent data significantly improves prediction over generated data alone</li>
<li>The model can extrapolate to reaction types not seen during training (e.g., the &ldquo;gen&rdquo; model predicting aromatic reactions)</li>
<li>Compared to the fingerprint-based approach of Wei et al., this method performed better on textbook problems and eliminated the need for manual SMARTS encoding</li>
</ul>
<h3 id="limitations">Limitations</h3>
<ul>
<li><strong>Invalid SMILES generation</strong>: the token-by-token generation process can produce syntactically invalid SMILES (e.g., mismatched parentheses), which the authors scored as zero</li>
<li><strong>Sequence length degradation</strong>: prediction accuracy dropped for longer SMILES strings, a known limitation of RNN-based seq2seq models at the time</li>
<li><strong>Poor attention alignment</strong>: attention weights collapsed to the first encoder positions rather than learning meaningful reactant-product correspondences</li>
<li><strong>Chemically naive embeddings</strong>: token embeddings did not capture chemical properties</li>
<li><strong>Multiple reaction pathways</strong>: reactions with competing pathways (e.g., substitution vs. elimination) were difficult for the model to handle</li>
</ul>
<h3 id="historical-significance">Historical Significance</h3>
<p>This paper is historically significant as one of the first (alongside concurrent work) to propose the NMT framing for reaction prediction. This framing was later adopted and refined by the <a href="/notes/chemistry/molecular-design/reaction-prediction/molecular-transformer/">Molecular Transformer</a> (Schwaller et al., 2019), which replaced GRUs with the Transformer architecture and achieved over 90% top-1 accuracy on standard benchmarks. The conceptual contribution of treating SMILES-to-SMILES translation as machine translation became the foundation of an entire subfield.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Purpose</th>
          <th style="text-align: left">Dataset</th>
          <th style="text-align: left">Size</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left">Training (real)</td>
          <td style="text-align: left">USPTO patent reactions</td>
          <td style="text-align: left">1,094,235</td>
          <td style="text-align: left">2001-2013 applications, filtered by length</td>
      </tr>
      <tr>
          <td style="text-align: left">Training (gen)</td>
          <td style="text-align: left">Generated from Wade textbook templates</td>
          <td style="text-align: left">865,118</td>
          <td style="text-align: left">75 reaction types, GDB-11 substrates</td>
      </tr>
      <tr>
          <td style="text-align: left">Testing (textbook)</td>
          <td style="text-align: left">Wade textbook problems</td>
          <td style="text-align: left">~100</td>
          <td style="text-align: left">10 problem sets, 6-15 reactions each</td>
      </tr>
      <tr>
          <td style="text-align: left">Testing (scalability)</td>
          <td style="text-align: left">Generated from <a href="/notes/chemistry/datasets/gdb-17/">GDB-17</a></td>
          <td style="text-align: left">2,400</td>
          <td style="text-align: left">400 per atom count (11-16)</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>GRU-based encoder-decoder with attention mechanism</li>
<li>PEG-based SMILES tokenizer</li>
<li>Input sequence reversal</li>
<li>Bucketed training with four bucket sizes</li>
<li>TensorFlow seq2seq tutorial implementation with default learning rate</li>
</ul>
<h3 id="models">Models</h3>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Parameter</th>
          <th style="text-align: left">Value</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left">GRU layers</td>
          <td style="text-align: left">3</td>
      </tr>
      <tr>
          <td style="text-align: left">Embedding size</td>
          <td style="text-align: left">600</td>
      </tr>
      <tr>
          <td style="text-align: left">Input vocabulary</td>
          <td style="text-align: left">311 tokens</td>
      </tr>
      <tr>
          <td style="text-align: left">Output vocabulary</td>
          <td style="text-align: left">180 tokens</td>
      </tr>
      <tr>
          <td style="text-align: left">Buckets</td>
          <td style="text-align: left">(54,54), (70,60), (90,65), (150,80)</td>
      </tr>
  </tbody>
</table>
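<p>The bucketed training and input reversal listed above can be sketched as follows. This is a minimal illustration, assuming a pair goes into the smallest bucket whose limits it fits; the function names are hypothetical, and the exact boundary handling in the TensorFlow tutorial (padding, start and end symbols) is not reproduced here.</p>

```python
# Bucket sizes (max source length, max target length) from the table above.
BUCKETS = [(54, 54), (70, 60), (90, 65), (150, 80)]


def assign_bucket(source_tokens, target_tokens, buckets=BUCKETS):
    """Return the index of the smallest bucket that fits the pair, or None."""
    for index, (max_source, max_target) in enumerate(buckets):
        if len(source_tokens) <= max_source and len(target_tokens) <= max_target:
            return index
    return None  # longer than the largest bucket; such pairs would be dropped


def prepare_source(source_tokens):
    """Reverse the input token sequence before it is fed to the encoder."""
    return list(reversed(source_tokens))


reactant_tokens = ["C", "C", "=", "O"]        # hypothetical tokenized input
product_tokens = ["C", "C", "(", "O", ")", "O"]  # hypothetical tokenized output
print(assign_bucket(reactant_tokens, product_tokens))
print(prepare_source(reactant_tokens))
```

<p>Grouping sequences of similar length this way limits the amount of padding each batch needs.</p>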
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Metric</th>
          <th style="text-align: left">gen Model</th>
          <th style="text-align: left">real+gen Model</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left">Textbook correct ratio</td>
          <td style="text-align: left">Variable by set</td>
          <td style="text-align: left">Higher on most sets</td>
          <td style="text-align: left">10 problem sets</td>
      </tr>
      <tr>
          <td style="text-align: left">Average Tanimoto similarity</td>
          <td style="text-align: left">Variable</td>
          <td style="text-align: left">~0.7 on scalability test</td>
          <td style="text-align: left">Morgan fingerprint based</td>
      </tr>
      <tr>
          <td style="text-align: left">Invalid SMILES ratio</td>
          <td style="text-align: left">Higher</td>
          <td style="text-align: left">~0.4 on scalability test</td>
          <td style="text-align: left">Decreases with more training data</td>
      </tr>
  </tbody>
</table>
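<p>The two scalability metrics in the table can be illustrated with a short sketch. It assumes fingerprints are represented as sets of on-bit indices and that the average similarity is taken over valid predictions only; the paper may aggregate differently. The <code>fingerprint</code> argument is a hypothetical caller-supplied function (for example, a wrapper around a cheminformatics toolkit) that returns <code>None</code> for unparseable SMILES.</p>

```python
def tanimoto(bits_a, bits_b):
    """Tanimoto (Jaccard) similarity between two sets of on-bit indices."""
    a, b = set(bits_a), set(bits_b)
    union = a.union(b)
    if not union:
        return 1.0  # convention chosen here for two empty fingerprints
    return len(a.intersection(b)) / len(union)


def evaluate(predictions, references, fingerprint):
    """Return (average Tanimoto over valid predictions, invalid SMILES ratio)."""
    scores, invalid = [], 0
    for predicted, reference in zip(predictions, references):
        predicted_bits = fingerprint(predicted)
        if predicted_bits is None:  # prediction could not be parsed
            invalid += 1
            continue
        scores.append(tanimoto(predicted_bits, fingerprint(reference)))
    average = sum(scores) / len(scores) if scores else 0.0
    return average, invalid / len(predictions)


# Toy fingerprints standing in for Morgan fingerprints (illustrative values).
toy = {"CCO": {1, 2, 3}, "CCN": {1, 2, 4}, "C(": None}
print(evaluate(["CCN", "C("], ["CCO", "CCO"], toy.get))
```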
<h3 id="hardware">Hardware</h3>
<p>Not specified in the paper.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Nam, J. &amp; Kim, J. (2016). Linking the Neural Machine Translation and the Prediction of Organic Chemistry Reactions. <em>arXiv preprint</em>, arXiv:1612.09529. <a href="https://arxiv.org/abs/1612.09529">https://arxiv.org/abs/1612.09529</a></p>
<p><strong>Publication</strong>: arXiv preprint 2016</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{nam2016linking,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Linking the Neural Machine Translation and the Prediction of Organic Chemistry Reactions}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Nam, Juno and Kim, Jurae}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{arXiv preprint arXiv:1612.09529}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2016}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.48550/arxiv.1612.09529}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>MolFM: Trimodal Molecular Foundation Pre-training</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/multimodal/molfm-multimodal-molecular-foundation/</link><pubDate>Sat, 28 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/multimodal/molfm-multimodal-molecular-foundation/</guid><description>MolFM fuses molecular graphs, biomedical text, and knowledge graphs via cross-modal attention for joint molecular representation learning.</description><content:encoded><![CDATA[<h2 id="trimodal-pre-training-for-molecular-understanding">Trimodal Pre-training for Molecular Understanding</h2>
<p>MolFM is a <strong>Method</strong> paper that introduces a multimodal molecular foundation model integrating three distinct sources of molecular knowledge: 2D molecular graphs, biomedical text, and knowledge graphs. The primary contribution is a pre-training framework that uses fine-grained cross-modal attention to fuse information across all three modalities, combined with theoretical justification from a deep metric learning perspective. MolFM achieves the best reported results (at time of publication) on cross-modal retrieval, molecule captioning, text-based molecule generation, and molecular property prediction.</p>
<h2 id="why-existing-molecular-models-fall-short">Why Existing Molecular Models Fall Short</h2>
<p>Prior multimodal molecular foundation models operate on at most two modalities (structures and text) and suffer from two key limitations. First, generative approaches like KV-PLM and MolT5 rely on 1D <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> strings, which cannot capture complex topological and spatial molecular properties such as macrocycles. Contrastive approaches like <a href="/notes/chemistry/molecular-representations/multimodal/momu-molecular-multimodal-foundation/">MoMu</a> and MoleculeSTM learn global alignment between molecule graphs and text but overlook fine-grained connections between specific substructures and textual descriptions.</p>
<p>Second, and more fundamentally, no prior model incorporates <a href="https://en.wikipedia.org/wiki/Knowledge_graph">knowledge graphs</a> as a third modality. Knowledge graphs encode global-level relationships among molecules, target ligands, diseases, and other biomedical entities. These relationships capture functional and structural similarity patterns that cannot be learned from individual molecule-text pairs alone. MolFM addresses both gaps by introducing cross-modal attention across all three modalities and providing theoretical guarantees about what the pre-training objectives learn.</p>
<h2 id="cross-modal-attention-and-metric-learning-guarantees">Cross-Modal Attention and Metric Learning Guarantees</h2>
<h3 id="architecture">Architecture</h3>
<p>MolFM uses three pre-trained single-modal encoders:</p>
<ul>
<li><strong>Molecular graph encoder</strong>: A 5-layer GIN (1.8M parameters) initialized from GraphMVP, producing atom-level features $h_{SA}$ and a graph-level feature $h_{SM}$</li>
<li><strong>Text encoder</strong>: A 6-layer transformer (61.8M parameters) initialized from KV-PLM&rsquo;s first 6 layers, producing token features $h_T$</li>
<li><strong>Knowledge graph encoder</strong>: A TransE model (12.6M parameters) trained on the knowledge graph for 500 epochs, producing entity features $h_K$</li>
</ul>
<p>A multimodal encoder (61.8M parameters, 6 transformer layers with cross-attention) fuses the three modalities. The cross-attention uses text token features as queries and the concatenation of atom features and knowledge graph neighbor features as keys and values. For each molecule, the knowledge graph input is the molecule&rsquo;s entity and $N=4$ randomly sampled one-hop neighbors.</p>
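<p>The query and key/value layout described above can be sketched with a single attention head. This is a simplified illustration, assuming no learned projections, one head, and toy feature dimensions; the actual encoder stacks 6 transformer layers.</p>

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def cross_attention(text_tokens, atom_features, neighbor_features):
    """Scaled dot-product attention: text queries over atom and KG features."""
    memory = atom_features + neighbor_features  # keys and values, concatenated
    scale = math.sqrt(len(text_tokens[0]))
    outputs = []
    for query in text_tokens:
        scores = [sum(q * k for q, k in zip(query, key)) / scale for key in memory]
        weights = softmax(scores)
        fused = [sum(w * value[d] for w, value in zip(weights, memory))
                 for d in range(len(query))]
        outputs.append(fused)
    return outputs


text = [[1.0, 0.0], [0.0, 1.0]]               # 2 text tokens, toy dimension 2
atoms = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 atom features
neighbors = [[0.5, 0.5]] * 5                  # molecule entity plus N=4 neighbors
fused = cross_attention(text, atoms, neighbors)
print(len(fused), len(fused[0]))
```

<p>Each text token receives one fused vector, so the output keeps the text sequence length while mixing in structural and knowledge-graph information.</p>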
<h3 id="pre-training-objectives">Pre-training Objectives</h3>
<p>MolFM combines four losses:</p>
<p><strong>Structure-text contrastive (STC)</strong> aligns the global feature spaces of structure and text encoders using a symmetric InfoNCE loss:</p>
<p>$$\mathcal{L}_{stc} = -\frac{1}{2} \left[ \log \frac{\exp(s(z_S, z_T) / \tau)}{\sum_{S' \in B} \exp(s(z_{S'}, z_T) / \tau)} + \log \frac{\exp(s(z_S, z_T) / \tau)}{\sum_{T' \in B} \exp(s(z_S, z_{T'}) / \tau)} \right]$$</p>
<p>where $s(\cdot, \cdot)$ is cosine similarity and $\tau = 0.1$ is a temperature parameter.</p>
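<p>A minimal sketch of the symmetric contrastive loss, assuming the per-pair loss is averaged over a batch in which the structure and text at the same index are the matched pair. Function names are illustrative and not taken from the official code.</p>

```python
import math


def cosine(u, v):
    """Cosine similarity between two vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def log_softmax_at(scores, index):
    """log(exp(scores[index]) / sum_j exp(scores[j])), computed stably."""
    peak = max(scores)
    log_total = peak + math.log(sum(math.exp(s - peak) for s in scores))
    return scores[index] - log_total


def stc_loss(structure_feats, text_feats, tau=0.1):
    """Symmetric InfoNCE averaged over matched (structure, text) pairs."""
    n = len(structure_feats)
    sim = [[cosine(s, t) / tau for t in text_feats] for s in structure_feats]
    total = 0.0
    for i in range(n):
        # First term: text i against every structure in the batch.
        over_structures = log_softmax_at([sim[j][i] for j in range(n)], i)
        # Second term: structure i against every text in the batch.
        over_texts = log_softmax_at(sim[i], i)
        total += -0.5 * (over_structures + over_texts)
    return total / n


aligned = [[1.0, 0.0], [0.0, 1.0]]
print(stc_loss(aligned, aligned))        # matched pairs agree: loss near zero
print(stc_loss(aligned, aligned[::-1]))  # matched pairs disagree: large loss
```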
<p><strong>Cross-modal matching (CMM)</strong> predicts whether a structure-text-knowledge triplet corresponds to the same molecule, using cross-entropy over the multimodal encoder&rsquo;s CLS token:</p>
<p>$$\mathcal{L}_{cmm} = \sum_{(\tilde{S}, \tilde{T}, \tilde{K}) \in \tilde{B}} H\left[y_{cmm}(\tilde{S}, \tilde{T}, \tilde{K}),\; p_{cmm}\left(\mathcal{M}_\theta(h_{\tilde{S}}, h_{\tilde{T}}, h_{\tilde{K}})\right)\right]$$</p>
<p><strong>Masked language modeling (MLM)</strong> predicts masked text tokens conditioned on all three modalities:</p>
<p>$$\mathcal{L}_{mlm} = H\left[y_{mlm}(\hat{T}),\; p_{mlm}\left(\mathcal{M}_\theta(h_S, h_{\hat{T}}, h_K)\right)\right]$$</p>
<p><strong>Knowledge graph embedding (KGE)</strong> regularizes entity embeddings with a max-margin TransE loss:</p>
<p>$$\mathcal{L}_{kge} = \sum_{h \in K} \left[\max(0, d(h,r,t) - d(h,r,\tilde{t}) + \Delta) + \max(0, d(h,r,t) - d(\tilde{h},r,t) + \Delta)\right]$$</p>
<p>where $d(h,r,t) = \| f(h) + g(r) - f(t) \|_2$ and $\Delta = 0.2$.</p>
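<p>For a single triplet, the max-margin objective can be sketched as below. The sketch assumes one corrupted tail and one corrupted head per positive triplet and operates directly on embedding vectors; how negatives are sampled is not specified here.</p>

```python
import math


def distance(head, relation, tail):
    """TransE distance: Euclidean norm of f(h) + g(r) - f(t)."""
    return math.sqrt(sum((h + r - t) ** 2 for h, r, t in zip(head, relation, tail)))


def kge_triplet_loss(head, relation, tail, corrupt_head, corrupt_tail, margin=0.2):
    """Hinge loss pushing the true triplet closer than both corrupted ones."""
    positive = distance(head, relation, tail)
    tail_term = max(0.0, positive - distance(head, relation, corrupt_tail) + margin)
    head_term = max(0.0, positive - distance(corrupt_head, relation, tail) + margin)
    return tail_term + head_term


h, r, t = [0.0, 0.0], [1.0, 0.0], [1.0, 0.0]  # toy embeddings with h + r = t
print(kge_triplet_loss(h, r, t, corrupt_head=[3.0, 0.0], corrupt_tail=[5.0, 0.0]))
```

<p>The loss is zero once every corrupted triplet is at least $\Delta$ farther away than the true one, so well-separated negatives stop contributing gradient.</p>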
<p>The total pre-training loss is:</p>
<p>$$\mathcal{L} = \mathbb{E}_{(S,T,K)}\left[\mathcal{L}_{stc} + \mathcal{L}_{cmm} + \mathcal{L}_{mlm} + \mathcal{L}_{kge}\right]$$</p>
<h3 id="theoretical-justifications">Theoretical Justifications</h3>
<p>The authors provide metric learning interpretations for each objective. For CMM, they show that minimizing the loss amounts to assigning higher scores to matched triplets and lower scores to unmatched ones, which aligns the feature space across all three modalities.</p>
<p>For KGE, two lemmas provide guarantees about structurally and functionally similar molecules:</p>
<p><strong>Lemma 1</strong> (Structural similarity): For a symmetric structural-similarity relation $r_s$, the KGE loss satisfies:</p>
<p>$$\mathcal{L}_{kge}(h, r_s, t) \propto 2\|f(h) - f(t)\| - \mathbb{E}_{\tilde{t}}\|f(h) - f(\tilde{t})\| - \mathbb{E}_{\tilde{h}}\|f(\tilde{h}) - f(t)\|$$</p>
<p>This shows KGE pulls structurally similar molecules closer while pushing dissimilar ones apart.</p>
<p><strong>Lemma 2</strong> (Functional similarity): For molecules $h$ and $t$ that interact with a common entity $o$, the distance between their embeddings is upper-bounded:</p>
<p>$$\|f(h) - f(t)\| \leq \alpha\,\mathbb{E}_{(e_1, r, e_2) \sim \mathcal{I}}\left[\mathcal{L}_{kge}(e_1, r, e_2)\right] + C$$</p>
<p>where $\alpha \approx 1$ and $C \approx 0$. This guarantees that minimizing KGE also brings functionally similar molecules closer in the embedding space.</p>
<h2 id="experiments-across-four-downstream-tasks">Experiments Across Four Downstream Tasks</h2>
<h3 id="pre-training-data">Pre-training Data</h3>
<p>MolFM pre-trains on 15K molecules from <a href="https://en.wikipedia.org/wiki/PubChem">PubChem</a> paired with 37M paragraphs from S2ORC. The knowledge graph contains 49K entities and 3.2M relations, constructed from <a href="https://en.wikipedia.org/wiki/DrugBank">DrugBank</a>, <a href="https://en.wikipedia.org/wiki/BindingDB">BindingDB</a>, and additional public databases with heuristic augmentation.</p>
<h3 id="cross-modal-retrieval">Cross-Modal Retrieval</h3>
<p>Evaluated on PCdes (paragraph-level) in zero-shot and fine-tuning settings. MolFM uses a re-ranking strategy that linearly combines cosine similarity with CMM logits over the top-$k$ retrieved candidates.</p>
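<p>The re-ranking step can be sketched as follows. The values of <code>top_k</code> and <code>weight</code> are illustrative assumptions, not the settings used in the paper, and the function name is hypothetical.</p>

```python
def rerank(candidates, cosine_scores, cmm_logits, top_k=4, weight=0.5):
    """Re-rank the top-k candidates by a linear mix of cosine and CMM scores.

    `cosine_scores` maps every candidate id to its cosine similarity.
    `cmm_logits` only needs entries for candidates that reach the top-k.
    """
    by_cosine = sorted(candidates, key=lambda c: cosine_scores[c], reverse=True)
    head, tail = by_cosine[:top_k], by_cosine[top_k:]
    combined = {c: weight * cosine_scores[c] + (1.0 - weight) * cmm_logits[c]
                for c in head}
    reranked = sorted(head, key=lambda c: combined[c], reverse=True)
    return reranked + tail  # candidates outside the top-k keep their cosine order


cosine_scores = {"a": 0.9, "b": 0.8, "c": 0.1}
cmm_logits = {"a": 0.0, "b": 1.0}
print(rerank(["a", "b", "c"], cosine_scores, cmm_logits, top_k=2))
```

<p>Restricting the matching head to the top-$k$ candidates keeps the cost low, since the multimodal encoder runs only on a short list instead of the full candidate pool.</p>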
<table>
  <thead>
      <tr>
          <th>Mode</th>
          <th>Model</th>
          <th>S-T MRR</th>
          <th>S-T R@1</th>
          <th>S-T R@10</th>
          <th>T-S MRR</th>
          <th>T-S R@1</th>
          <th>T-S R@10</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Zero-shot</td>
          <td>MoMu</td>
          <td>9.89</td>
          <td>5.08</td>
          <td>18.93</td>
          <td>10.33</td>
          <td>4.90</td>
          <td>20.69</td>
      </tr>
      <tr>
          <td>Zero-shot</td>
          <td>MolFM</td>
          <td>21.42</td>
          <td>13.90</td>
          <td>36.21</td>
          <td>23.63</td>
          <td>16.14</td>
          <td>39.54</td>
      </tr>
      <tr>
          <td>Fine-tune</td>
          <td>MoMu</td>
          <td>34.29</td>
          <td>24.47</td>
          <td>53.84</td>
          <td>34.53</td>
          <td>24.87</td>
          <td>54.25</td>
      </tr>
      <tr>
          <td>Fine-tune</td>
          <td>MolFM</td>
          <td>39.56</td>
          <td>29.76</td>
          <td>58.63</td>
          <td>39.34</td>
          <td>29.39</td>
          <td>58.49</td>
      </tr>
  </tbody>
</table>
<p>MolFM achieves 12.13% and 5.04% absolute gains over MoMu under zero-shot and fine-tuning settings, respectively.</p>
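<p>For reference, the MRR and Recall@k columns follow the usual retrieval definitions. A minimal sketch, assuming each query has exactly one correct match whose 1-based rank in the retrieved list is known:</p>

```python
def retrieval_metrics(ranks, ks=(1, 10)):
    """Compute MRR and Recall@k (as percentages) from 1-based ranks."""
    n = len(ranks)
    metrics = {"MRR": 100.0 * sum(1.0 / r for r in ranks) / n}
    for k in ks:
        hits = sum(1 for r in ranks if r <= k)
        metrics[f"R@{k}"] = 100.0 * hits / n
    return metrics


# Hypothetical ranks of the correct match for four queries.
print(retrieval_metrics([1, 2, 4, 20]))
```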
<h3 id="molecule-captioning">Molecule Captioning</h3>
<p>Evaluated on ChEBI-20 using MolT5 decoders. MolFM&rsquo;s structure encoder features are concatenated with the MolT5 encoder outputs.</p>
<table>
  <thead>
      <tr>
          <th>Decoder</th>
          <th>Encoder</th>
          <th>BLEU-4</th>
          <th>ROUGE-L</th>
          <th>METEOR</th>
          <th>Text2Mol</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MolT5-base</td>
          <td>MolT5-base</td>
          <td>0.457</td>
          <td>0.578</td>
          <td>0.569</td>
          <td>0.547</td>
      </tr>
      <tr>
          <td>MolT5-base</td>
          <td>MoMu</td>
          <td>0.462</td>
          <td>0.575</td>
          <td>0.576</td>
          <td>0.558</td>
      </tr>
      <tr>
          <td>MolT5-base</td>
          <td>GraphMVP</td>
          <td>0.491</td>
          <td>0.592</td>
          <td>0.599</td>
          <td>0.570</td>
      </tr>
      <tr>
          <td>MolT5-base</td>
          <td>MolFM</td>
          <td>0.498</td>
          <td>0.594</td>
          <td>0.607</td>
          <td>0.576</td>
      </tr>
  </tbody>
</table>
<h3 id="text-based-molecule-generation">Text-Based Molecule Generation</h3>
<p>Also on ChEBI-20 with MolT5 decoders. MolFM&rsquo;s text features are projected and fed to the decoder.</p>
<table>
  <thead>
      <tr>
          <th>Decoder</th>
          <th>Encoder</th>
          <th>Exact</th>
          <th>Valid</th>
          <th>Morgan FTS</th>
          <th>Text2Mol</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MolT5-base</td>
          <td>MolT5-base</td>
          <td>0.082</td>
          <td>0.786</td>
          <td>0.601</td>
          <td>0.543</td>
      </tr>
      <tr>
          <td>MolT5-base</td>
          <td>MoMu</td>
          <td>0.183</td>
          <td>0.863</td>
          <td>0.678</td>
          <td>0.580</td>
      </tr>
      <tr>
          <td>MolT5-base</td>
          <td>MolFM</td>
          <td>0.210</td>
          <td>0.892</td>
          <td>0.697</td>
          <td>0.583</td>
      </tr>
  </tbody>
</table>
<h3 id="molecular-property-prediction">Molecular Property Prediction</h3>
<p>On <a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a> (8 classification datasets), MolFM concatenates the structure feature and the multimodal encoder&rsquo;s CLS feature to predict properties.</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>BBBP</th>
          <th>Tox21</th>
          <th>ClinTox</th>
          <th>HIV</th>
          <th>BACE</th>
          <th>Avg</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>GraphMVP</td>
          <td>72.4</td>
          <td>74.4</td>
          <td>77.5</td>
          <td>77.0</td>
          <td>81.2</td>
          <td>73.07</td>
      </tr>
      <tr>
          <td>DeepEIK</td>
          <td>72.1</td>
          <td>72.4</td>
          <td>89.7</td>
          <td>75.0</td>
          <td>80.5</td>
          <td>73.27</td>
      </tr>
      <tr>
          <td>MolFM (w/o T+K)</td>
          <td>72.2</td>
          <td>76.6</td>
          <td>78.6</td>
          <td>78.2</td>
          <td>82.6</td>
          <td>73.95</td>
      </tr>
      <tr>
          <td>MolFM (w/ T+K)</td>
          <td>72.9</td>
          <td>77.2</td>
          <td>79.7</td>
          <td>78.8</td>
          <td>83.9</td>
          <td>74.62</td>
      </tr>
  </tbody>
</table>
<p>With multimodal inputs, MolFM averages 74.62% ROC-AUC across all 8 datasets (the table lists five of them), a 1.55-point absolute gain over GraphMVP.</p>
<h3 id="ablation-studies">Ablation Studies</h3>
<p>Zero-shot retrieval ablations reveal that cross-modal attention to atoms and CMM are the most critical components. Removing either causes a sharp drop (approximately 3% on S-T retrieval). Knowledge graph incorporation yields a 1.5% average improvement, with both attention to neighbors and KGE contributing marginally.</p>
<h2 id="key-findings-and-limitations">Key Findings and Limitations</h2>
<p>MolFM demonstrates that incorporating knowledge graphs as a third modality provides consistent improvements across all evaluated tasks. The theoretical analysis connecting pre-training objectives to deep metric learning provides interpretability for why the model works: STC and CMM align representations of the same molecule across modalities, while KGE pulls structurally and functionally similar molecules closer in the embedding space.</p>
<p>The cross-modal attention visualizations show that MolFM learns to associate specific atom substructures with relevant text tokens and knowledge graph entities. For example, the model correctly attends to functional groups mentioned in textual descriptions.</p>
<p>The authors acknowledge several limitations:</p>
<ol>
<li><strong>Data quality</strong>: The pre-training dataset (15K molecules) is small and may introduce biases</li>
<li><strong>Cold-start problem</strong>: MolFM provides limited benefit for newly emerged molecules lacking text and knowledge graph information</li>
<li><strong>Entity scope</strong>: The model focuses on molecules and does not incorporate proteins, genes, or cell lines, which could further improve biomedical understanding</li>
</ol>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pre-training (molecules)</td>
          <td>PubChem</td>
          <td>15K molecules</td>
          <td>Follows MoMu&rsquo;s pre-training data</td>
      </tr>
      <tr>
          <td>Pre-training (text)</td>
          <td>S2ORC</td>
          <td>37M paragraphs</td>
          <td>Biomedical literature paragraphs</td>
      </tr>
      <tr>
          <td>Knowledge graph</td>
          <td>DrugBank, BindingDB, public DBs</td>
          <td>49K entities, 3.2M relations</td>
          <td>Constructed with heuristics from MoCL</td>
      </tr>
      <tr>
          <td>Cross-modal retrieval</td>
          <td>PCdes</td>
          <td>Paragraph-level</td>
          <td>Test split</td>
      </tr>
      <tr>
          <td>Captioning/Generation</td>
          <td>ChEBI-20</td>
          <td>-</td>
          <td>Following MolT5 splits</td>
      </tr>
      <tr>
          <td>Property prediction</td>
          <td>MoleculeNet</td>
          <td>8 datasets</td>
          <td>Classification tasks, ROC-AUC metric</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Optimizer: AdamW with weight decay $1 \times 10^{-4}$</li>
<li>Learning rate: linear warmup to $1 \times 10^{-4}$ over 2,000 iterations, cosine annealing to $1 \times 10^{-5}$</li>
<li>Batch size: 128</li>
<li>Pre-training epochs: 300</li>
<li>Knowledge graph neighbors per molecule: $N = 4$</li>
<li>Temperature: $\tau = 0.1$</li>
<li>Margin: $\Delta = 0.2$</li>
</ul>
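<p>The learning-rate schedule listed above can be written as a small function. This sketch assumes per-iteration updates and takes the total number of iterations as an argument, since that figure is not stated here.</p>

```python
import math

PEAK_LR = 1e-4       # reached at the end of warmup
FINAL_LR = 1e-5      # floor of the cosine annealing phase
WARMUP_STEPS = 2000


def learning_rate(step, total_steps):
    """Linear warmup to PEAK_LR, then cosine annealing down to FINAL_LR."""
    if step < WARMUP_STEPS:
        return PEAK_LR * (step + 1) / WARMUP_STEPS
    progress = (step - WARMUP_STEPS) / max(1, total_steps - WARMUP_STEPS)
    progress = min(1.0, progress)
    return FINAL_LR + 0.5 * (PEAK_LR - FINAL_LR) * (1.0 + math.cos(math.pi * progress))


for step in (0, 1999, 6000, 10000):
    print(step, learning_rate(step, total_steps=10000))  # total_steps is hypothetical
```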
<h3 id="models">Models</h3>
<table>
  <thead>
      <tr>
          <th>Component</th>
          <th>Architecture</th>
          <th>Parameters</th>
          <th>Initialization</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Graph encoder</td>
          <td>5-layer GIN</td>
          <td>1.8M</td>
          <td>GraphMVP</td>
      </tr>
      <tr>
          <td>Text encoder</td>
          <td>6-layer Transformer</td>
          <td>61.8M</td>
          <td>KV-PLM (first 6 layers)</td>
      </tr>
      <tr>
          <td>Knowledge encoder</td>
          <td>TransE</td>
          <td>12.6M</td>
          <td>Trained 500 epochs on KG</td>
      </tr>
      <tr>
          <td>Multimodal encoder</td>
          <td>6-layer Transformer + cross-attention</td>
          <td>61.8M</td>
          <td>KV-PLM (last 6 layers)</td>
      </tr>
      <tr>
          <td><strong>Total</strong></td>
          <td></td>
          <td><strong>~138M</strong></td>
          <td></td>
      </tr>
  </tbody>
</table>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Task</th>
          <th>Metrics</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Cross-modal retrieval</td>
          <td>MRR, Recall@1/5/10</td>
      </tr>
      <tr>
          <td>Molecule captioning</td>
          <td>BLEU-2/4, ROUGE-1/2/L, METEOR, Text2Mol</td>
      </tr>
      <tr>
          <td>Text-to-molecule generation</td>
          <td>BLEU, Exact ratio, Validity, Levenshtein, Fingerprint Tanimoto (MACCS/RDKit/Morgan), Text2Mol</td>
      </tr>
      <tr>
          <td>Property prediction</td>
          <td>ROC-AUC per dataset</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li>4 NVIDIA A100 GPUs for pre-training</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/BioFM/OpenBioMed">OpenBioMed</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official implementation including MolFM</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Luo, Y., Yang, K., Hong, M., Liu, X. Y., &amp; Nie, Z. (2023). MolFM: A Multimodal Molecular Foundation Model. <em>arXiv preprint arXiv:2307.09484</em>.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{luo2023molfm,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{MolFM: A Multimodal Molecular Foundation Model}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Luo, Yizhen and Yang, Kai and Hong, Massimo and Liu, Xing Yi and Nie, Zaiqing}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{arXiv preprint arXiv:2307.09484}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>LlaSMol: Instruction-Tuned LLMs for Chemistry Tasks</title><link>https://hunterheidenreich.com/notes/chemistry/llm-applications/llamsmol-instruction-tuning-chemistry/</link><pubDate>Sat, 28 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/llm-applications/llamsmol-instruction-tuning-chemistry/</guid><description>LlaSMol fine-tunes open-source LLMs on SMolInstruct, a 3.3M-sample chemistry instruction dataset spanning 14 tasks, outperforming GPT-4 on all chemistry tasks.</description><content:encoded><![CDATA[<h2 id="a-resource-for-chemistry-instruction-tuning">A Resource for Chemistry Instruction Tuning</h2>
<p>This is a <strong>Resource</strong> paper that contributes both a large-scale instruction tuning dataset (SMolInstruct) and a family of fine-tuned LLMs (LlaSMol) for chemistry tasks. The primary contribution is SMolInstruct, a dataset of 3.3 million samples across 14 chemistry tasks, paired with systematic experiments showing that instruction-tuned open-source LLMs can substantially outperform GPT-4 and Claude 3 Opus on chemistry benchmarks. The dataset construction methodology, quality control pipeline, and careful data splitting are central to the paper&rsquo;s value.</p>
<h2 id="why-llms-struggle-with-chemistry-tasks">Why LLMs Struggle with Chemistry Tasks</h2>
<p>Prior work demonstrated that general-purpose LLMs perform poorly on chemistry tasks. Guo et al. (2023) found that GPT-4, while outperforming other LLMs, falls far short of task-specific deep learning models, particularly on tasks requiring precise understanding of <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> representations. Fang et al. (2023) attempted instruction tuning with Mol-Instructions, but the resulting models still performed well below task-specific baselines.</p>
<p>These results raised a fundamental question: are LLMs inherently limited for chemistry, or is the problem simply insufficient training data? The authors argue it is the latter. Previous instruction tuning datasets suffered from limited scale (Mol-Instructions had 1.3M samples with fewer task types), lower quality (numerous low-quality molecular descriptions, mislabeled reactants/reagents in reaction data), and suboptimal design choices (using <a href="/notes/chemistry/molecular-representations/notations/selfies/">SELFIES</a> instead of canonical SMILES, inconsistent data splitting that allowed leakage).</p>
<h2 id="smolinstruct-a-comprehensive-chemistry-instruction-dataset">SMolInstruct: A Comprehensive Chemistry Instruction Dataset</h2>
<p>The core innovation is the SMolInstruct dataset, which addresses the limitations of prior datasets through three design principles:</p>
<p><strong>Scale and comprehensiveness.</strong> SMolInstruct contains 3.3M samples across 14 tasks organized into four categories:</p>
<ul>
<li><strong>Name conversion</strong> (4 tasks): <a href="https://en.wikipedia.org/wiki/IUPAC_nomenclature_of_organic_chemistry">IUPAC</a>-to-formula, IUPAC-to-SMILES, SMILES-to-formula, SMILES-to-IUPAC, sourced from <a href="https://en.wikipedia.org/wiki/PubChem">PubChem</a></li>
<li><strong>Property prediction</strong> (6 tasks): ESOL, Lipo, BBBP, ClinTox, HIV, SIDER, sourced from <a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a></li>
<li><strong>Molecule description</strong> (2 tasks): molecule captioning and molecule generation, sourced from <a href="https://en.wikipedia.org/wiki/ChEBI">ChEBI-20</a> and Mol-Instructions</li>
<li><strong>Chemical reactions</strong> (2 tasks): forward synthesis and retrosynthesis, sourced from USPTO-full</li>
</ul>
<p><strong>Quality control.</strong> The authors apply rigorous curation: invalid SMILES are filtered using RDKit, mislabeled reactants/reagents in USPTO-full are corrected by comparing atom mappings with products, low-quality molecular descriptions are removed using pattern-based rules, and duplicates are eliminated.</p>
<p><strong>Careful data splitting.</strong> To prevent data leakage across related tasks (e.g., forward synthesis and retrosynthesis share the same reactions), the authors ensure matched samples across reverse tasks are placed together in either training or evaluation sets. Samples with identical inputs but different outputs are also grouped together to prevent exaggerated performance estimates.</p>
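<p>One way to implement this kind of leakage-aware split is to assign whole groups, not individual samples, to a split. The sketch below uses hash-based assignment, which is an illustrative choice and not the authors&rsquo; procedure; the <code>group</code> field and the evaluation fraction are hypothetical.</p>

```python
import hashlib


def split_of(group_key, eval_fraction=0.1):
    """Deterministically map a group key to 'train' or 'eval' via a hash."""
    digest = hashlib.sha256(group_key.encode("utf-8")).hexdigest()
    bucket = int(digest[:8], 16) / 0x100000000  # uniform value in [0, 1)
    return "eval" if bucket < eval_fraction else "train"


def split_samples(samples, eval_fraction=0.1):
    """Send every sample to the split chosen for its group.

    Samples derived from the same reaction (forward synthesis and
    retrosynthesis) share one group key, so they never straddle splits.
    """
    splits = {"train": [], "eval": []}
    for sample in samples:
        splits[split_of(sample["group"], eval_fraction)].append(sample)
    return splits


samples = [
    {"task": "forward_synthesis", "group": "rxn-1"},
    {"task": "retrosynthesis", "group": "rxn-1"},
    {"task": "forward_synthesis", "group": "rxn-2"},
]
splits = split_samples(samples)
print({name: len(items) for name, items in splits.items()})
```

<p>Handling samples with identical inputs but different outputs would additionally require merging the groups that share an input before hashing.</p>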
<p>Additionally, all SMILES representations are canonicalized, and special tags (e.g., <code>&lt;SMILES&gt;...&lt;/SMILES&gt;</code>) encapsulate different information types within the instruction templates.</p>
<h2 id="experimental-setup-four-base-models-and-comprehensive-baselines">Experimental Setup: Four Base Models and Comprehensive Baselines</h2>
<p>The authors fine-tune four open-source LLMs using LoRA (applied to all attention and FFN linear layers, with rank and alpha both set to 16):</p>
<ul>
<li><strong><a href="/notes/chemistry/llm-applications/galactica-large-language-model-for-science/">Galactica</a> 6.7B</strong>: pretrained on scientific text including chemistry data</li>
<li><strong>Llama 2 7B</strong>: general-purpose LLM</li>
<li><strong>Code Llama 7B</strong>: code-focused variant of Llama 2</li>
<li><strong>Mistral 7B</strong>: general-purpose LLM</li>
</ul>
<p>Training uses 8-bit AdamW with learning rate 1e-4, cosine scheduler, and 3 epochs. Only 0.58% of parameters are fine-tuned (approximately 41.9M parameters). Beam search is used at inference.</p>
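<p>The trainable-parameter figure can be reproduced with a short calculation. The layer shapes, depth, and base parameter count below are assumptions taken from the public Mistral 7B configuration, not from the paper; LoRA with rank $r$ adds $r \cdot (d_{in} + d_{out})$ parameters per adapted linear layer.</p>

```python
RANK = 16
NUM_LAYERS = 32          # assumed number of transformer blocks in Mistral 7B
BASE_PARAMETERS = 7.24e9  # approximate base model size (assumption)

# (input features, output features) of each adapted linear layer per block,
# assumed from the public Mistral 7B configuration.
LINEAR_SHAPES = {
    "q_proj": (4096, 4096),
    "k_proj": (4096, 1024),
    "v_proj": (4096, 1024),
    "o_proj": (4096, 4096),
    "gate_proj": (4096, 14336),
    "up_proj": (4096, 14336),
    "down_proj": (14336, 4096),
}


def lora_parameters(shapes, rank, num_layers):
    """Count LoRA weights: A (rank x in) plus B (out x rank) per adapted layer."""
    per_block = sum(rank * (fan_in + fan_out) for fan_in, fan_out in shapes.values())
    return per_block * num_layers


trainable = lora_parameters(LINEAR_SHAPES, RANK, NUM_LAYERS)
print(trainable)                                      # 41943040
print(round(100.0 * trainable / BASE_PARAMETERS, 2))  # 0.58
```

<p>Under these assumed shapes the count comes to about 41.9M, or roughly 0.58% of the base model, consistent with the figures quoted above.</p>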
<p><strong>Baselines</strong> include:</p>
<ul>
<li>General LLMs without fine-tuning: GPT-4, Claude 3 Opus, and the four base models</li>
<li>Chemistry-specific LLMs: Molinst (Llama 2 tuned on Mol-Instructions), <a href="/notes/chemistry/llm-applications/chemllm-chemical-large-language-model/">ChemLLM</a></li>
<li>Task-specific non-LLM models: <a href="/notes/chemistry/molecular-representations/name-translation/stout-v2/">STOUT</a> for name conversion, Uni-Mol for property prediction, MolT5 for molecule description, RSMILES and <a href="/notes/chemistry/molecular-design/reaction-prediction/molecular-transformer/">Molecular Transformer</a> for reaction prediction</li>
</ul>
<h3 id="main-results">Main Results</h3>
<table>
  <thead>
      <tr>
          <th>Task Category</th>
          <th>Best LlaSMol</th>
          <th>GPT-4</th>
          <th>Improvement</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Name conversion (NC-I2F, EM%)</td>
          <td>87.9 (Mistral)</td>
          <td>8.7</td>
          <td>+79.2</td>
      </tr>
      <tr>
          <td>Name conversion (NC-I2S, EM%)</td>
          <td>70.1 (Mistral)</td>
          <td>3.3</td>
          <td>+66.8</td>
      </tr>
      <tr>
          <td>Property prediction (PP-ESOL, RMSE)</td>
          <td>1.150 (Mistral)</td>
          <td>2.570</td>
          <td>-1.42 (lower is better)</td>
      </tr>
      <tr>
          <td>Property prediction (PP-BBBP, Acc%)</td>
          <td>74.6 (Mistral)</td>
          <td>62.9</td>
          <td>+11.7</td>
      </tr>
      <tr>
          <td>Molecule captioning (<a href="https://en.wikipedia.org/wiki/METEOR">METEOR</a>)</td>
          <td>0.452 (Mistral)</td>
          <td>0.188</td>
          <td>+0.264</td>
      </tr>
      <tr>
          <td>Molecule generation (FTS%)</td>
          <td>61.7 (Mistral)</td>
          <td>42.6</td>
          <td>+19.1</td>
      </tr>
      <tr>
          <td>Forward synthesis (EM%)</td>
          <td>63.3 (Mistral)</td>
          <td>1.6</td>
          <td>+61.7</td>
      </tr>
      <tr>
          <td>Retrosynthesis (EM%)</td>
          <td>32.9 (Mistral)</td>
          <td>0.0</td>
          <td>+32.9</td>
      </tr>
  </tbody>
</table>
<p>The Mistral-based LlaSMol model consistently outperforms all other LLMs and the other LlaSMol variants. It also surpasses task-specific SoTA models on PP-ClinTox (93.1 vs. 92.4) and PP-SIDER (70.7 vs. 70.0), though it has not yet matched SoTA on most other tasks.</p>
<h3 id="ablation-study">Ablation Study</h3>
<p>The ablation study examines three variants:</p>
<ol>
<li>
<p><strong>Without canonicalization</strong>: Performance drops on most tasks, with substantial decreases on forward synthesis (63.3 to 53.7 EM%) and retrosynthesis (32.9 to 23.8 EM%), confirming that canonicalized SMILES reduce learning difficulty.</p>
</li>
<li>
<p><strong>Using SELFIES instead of SMILES</strong>: While SELFIES achieves slightly higher validity (100% vs. 99.7% on some tasks), it results in worse performance overall. SELFIES strings are typically longer than SMILES, making them harder for models to process accurately. This finding contradicts claims from prior work (Fang et al., 2023) that SELFIES should be preferred.</p>
</li>
<li>
<p><strong>Training on Mol-Instructions instead of SMolInstruct</strong>: Using the same base model (Mistral) and identical training settings, the Mol-Instructions-trained model performs drastically worse, achieving near-zero accuracy on name conversion and property prediction tasks, and much lower performance on shared tasks (MC, MG, FS, RS).</p>
</li>
</ol>
<h3 id="additional-analysis">Additional Analysis</h3>
<p>Multi-task training generally outperforms single-task training, with particularly large improvements on PP-ESOL (RMSE 20.616 to 1.150) and molecule generation (FTS 33.1% to 61.7%). Increasing the number of trainable LoRA parameters from 6.8M (0.09%) to 173.0M (2.33%) leads to consistent performance improvements across most tasks, suggesting further gains are possible with more extensive fine-tuning.</p>
<h2 id="key-findings-and-limitations">Key Findings and Limitations</h2>
<p>The paper establishes several findings:</p>
<ol>
<li>
<p><strong>LLMs can perform chemistry tasks effectively</strong> when provided with sufficient high-quality instruction tuning data. This refutes the notion that LLMs are fundamentally limited for chemistry.</p>
</li>
<li>
<p><strong>The choice of base model matters considerably.</strong> Mistral 7B outperforms Llama 2, Code Llama, and Galactica despite identical training, suggesting that general language understanding transfers well to chemistry.</p>
</li>
<li>
<p><strong>Canonical SMILES outperform both non-canonical SMILES and SELFIES</strong> for LLM-based chemistry, a practical recommendation for future work.</p>
</li>
<li>
<p><strong>Dataset quality is more important than model architecture.</strong> The same base model trained on SMolInstruct vastly outperforms the same model trained on Mol-Instructions.</p>
</li>
</ol>
<p>The authors acknowledge several limitations. The evaluation metrics for molecule captioning and generation (METEOR, FTS) measure text similarity rather than chemical correctness. The paper does not evaluate generalization to tasks beyond the 14 training tasks. LlaSMol models do not yet outperform task-specific SoTA models on most tasks, though the gap has narrowed substantially with only 0.58% of parameters fine-tuned.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training</td>
          <td>SMolInstruct</td>
          <td>3.29M samples</td>
          <td>14 tasks, canonical SMILES, publicly available on HuggingFace</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>SMolInstruct test split</td>
          <td>33,061 samples</td>
          <td>Careful splitting to prevent leakage across tasks</td>
      </tr>
      <tr>
          <td>NC tasks</td>
          <td>PubChem</td>
          <td>~300K molecules</td>
          <td>IUPAC names, SMILES, molecular formulas</td>
      </tr>
      <tr>
          <td>PP tasks</td>
          <td>MoleculeNet</td>
          <td>~78K samples</td>
          <td>6 datasets (ESOL, Lipo, BBBP, ClinTox, HIV, SIDER)</td>
      </tr>
      <tr>
          <td>MC/MG tasks</td>
          <td>ChEBI-20 + Mol-Instructions</td>
          <td>~60K samples</td>
          <td>Quality-filtered molecular descriptions</td>
      </tr>
      <tr>
          <td>FS/RS tasks</td>
          <td>USPTO-full</td>
          <td>~1.9M samples</td>
          <td>Cleaned, with corrected reactant/reagent labels</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Fine-tuning</strong>: LoRA with rank=16, alpha=16, applied to all attention and FFN linear layers</li>
<li><strong>Optimizer</strong>: 8-bit AdamW, learning rate 1e-4, cosine scheduler</li>
<li><strong>Training</strong>: 3 epochs, max input length 512 tokens</li>
<li><strong>Inference</strong>: Beam search with beam size = <code>num_return_sequences</code> + 3</li>
</ul>
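<p>As a rough sanity check on the adapter size, the sketch below counts the parameters that rank-16 LoRA adds to every attention and FFN linear layer. The layer shapes are assumed from the public Mistral 7B configuration and the helper name is hypothetical; neither comes from the paper summary itself.</p>

```python
# Sketch: count LoRA parameters for rank-r adapters on every attention and FFN
# linear layer. The layer shapes below are ASSUMED from the public Mistral 7B
# configuration; they are not stated in the summary above.

HIDDEN = 4096      # assumed model width
KV_DIM = 1024      # assumed key/value projection width (grouped-query attention)
FFN = 14336        # assumed feed-forward width
NUM_LAYERS = 32    # assumed number of Transformer blocks

# (input_dim, output_dim) of each adapted linear layer in one block.
LAYER_SHAPES = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, FFN),
    "up_proj": (HIDDEN, FFN),
    "down_proj": (FFN, HIDDEN),
}


def lora_param_count(rank, modules=None, num_layers=NUM_LAYERS):
    """Each adapted layer adds two low-rank factors: A (rank x in) and B (out x rank)."""
    names = modules if modules is not None else list(LAYER_SHAPES)
    per_block = sum(rank * (LAYER_SHAPES[n][0] + LAYER_SHAPES[n][1]) for n in names)
    return per_block * num_layers


total = lora_param_count(rank=16)
print(f"{total:,} LoRA parameters")
```

<p>Under these assumed shapes the count comes to 41,943,040, in line with the 41.9M figure reported for the LlaSMol models. Restricting the same calculation to the query and value projections gives 6,815,744, which is close to the 6.8M setting in the LoRA scaling analysis, though the summary does not state which modules that configuration adapted.</p>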
<h3 id="models">Models</h3>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Base</th>
          <th>Parameters</th>
          <th>LoRA Parameters</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>LlaSMolGalactica</td>
          <td>Galactica 6.7B</td>
          <td>6.7B</td>
          <td>41.9M (0.58%)</td>
      </tr>
      <tr>
          <td>LlaSMolLlama2</td>
          <td>Llama 2 7B</td>
          <td>7B</td>
          <td>41.9M (0.58%)</td>
      </tr>
      <tr>
          <td>LlaSMolCodeLlama</td>
          <td>Code Llama 7B</td>
          <td>7B</td>
          <td>41.9M (0.58%)</td>
      </tr>
      <tr>
          <td>LlaSMolMistral</td>
          <td>Mistral 7B</td>
          <td>7B</td>
          <td>41.9M (0.58%)</td>
      </tr>
  </tbody>
</table>
<p>All models and the dataset are publicly released on HuggingFace.</p>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Task(s)</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Exact Match (EM)</td>
          <td>NC, MG, FS, RS</td>
          <td>Molecular identity comparison via RDKit</td>
      </tr>
      <tr>
          <td>Fingerprint <a href="https://en.wikipedia.org/wiki/Jaccard_index">Tanimoto Similarity</a> (FTS)</td>
          <td>MG, FS, RS</td>
          <td>Morgan fingerprints</td>
      </tr>
      <tr>
          <td>METEOR</td>
          <td>MC</td>
          <td>Text similarity metric</td>
      </tr>
      <tr>
          <td>RMSE</td>
          <td>PP-ESOL, PP-Lipo</td>
          <td>Regression tasks</td>
      </tr>
      <tr>
          <td>Accuracy</td>
          <td>PP-BBBP, PP-ClinTox, PP-HIV, PP-SIDER</td>
          <td>Binary classification</td>
      </tr>
      <tr>
          <td>Validity</td>
          <td>NC-I2S, MG, FS, RS</td>
          <td>Ratio of valid SMILES outputs</td>
      </tr>
  </tbody>
</table>
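<p>To make the FTS metric concrete, here is a minimal sketch of Tanimoto similarity between two binary fingerprints. It assumes each fingerprint is given as the set of its on-bit indices; the function name and toy inputs are illustrative, and the paper's actual pipeline computes Morgan fingerprints with RDKit.</p>

```python
# Sketch: Tanimoto (Jaccard) similarity between two binary fingerprints, each
# represented as the set of its "on" bit indices. The set representation and the
# function name are illustrative assumptions, not the paper's evaluation code.


def tanimoto(bits_a, bits_b):
    """Return |A intersect B| / |A union B|; two empty fingerprints count as identical."""
    a, b = set(bits_a), set(bits_b)
    union = a.union(b)
    if not union:
        return 1.0
    return len(a.intersection(b)) / len(union)


# Toy fingerprints: two shared bits out of four distinct bits overall.
predicted = {3, 17, 42}
reference = {17, 42, 99}
print(tanimoto(predicted, reference))  # 0.5
```

<p>Unlike exact match, this score gives partial credit when a generated molecule shares substructures with the reference, which is why both metrics are reported for the generation and reaction tasks.</p>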
<h3 id="hardware">Hardware</h3>
<p>The paper does not specify exact GPU hardware or training times. Training uses the HuggingFace Transformers library with LoRA, and inference is conducted on the Ohio Supercomputer Center.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/OSU-NLP-Group/LlaSMol">LlaSMol Code</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Training, evaluation, and inference scripts</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/datasets/osunlp/SMolInstruct">SMolInstruct</a></td>
          <td>Dataset</td>
          <td>CC-BY-4.0</td>
          <td>3.3M samples across 14 chemistry tasks</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/osunlp/LlaSMol-Mistral-7B">LlaSMol-Mistral-7B</a></td>
          <td>Model</td>
          <td>CC-BY-4.0</td>
          <td>Best-performing model (LoRA adapters)</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/osunlp/LlaSMol-Galactica-6.7B">LlaSMol-Galactica-6.7B</a></td>
          <td>Model</td>
          <td>CC-BY-4.0</td>
          <td>LoRA adapters for Galactica</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/osunlp/LlaSMol-Llama2-7B">LlaSMol-Llama2-7B</a></td>
          <td>Model</td>
          <td>CC-BY-4.0</td>
          <td>LoRA adapters for Llama 2</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/osunlp/LlaSMol-CodeLlama-7B">LlaSMol-CodeLlama-7B</a></td>
          <td>Model</td>
          <td>CC-BY-4.0</td>
          <td>LoRA adapters for Code Llama</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Yu, B., Baker, F. N., Chen, Z., Ning, X., &amp; Sun, H. (2024). LlaSMol: Advancing large language models for chemistry with a large-scale, comprehensive, high-quality instruction tuning dataset. <em>arXiv preprint arXiv:2402.09391</em>.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{yu2024llamsmol,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{LlaSMol: Advancing Large Language Models for Chemistry with a Large-Scale, Comprehensive, High-Quality Instruction Tuning Dataset}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Yu, Botao and Baker, Frazier N. and Chen, Ziqi and Ning, Xia and Sun, Huan}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{arXiv preprint arXiv:2402.09391}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2024}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Galactica: A Curated Scientific LLM from Meta AI</title><link>https://hunterheidenreich.com/notes/chemistry/llm-applications/galactica-large-language-model-for-science/</link><pubDate>Sat, 28 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/llm-applications/galactica-large-language-model-for-science/</guid><description>Galactica is a 120B parameter LLM trained on 106B tokens of curated scientific text, outperforming GPT-3 on scientific knowledge tasks.</description><content:encoded><![CDATA[<h2 id="a-scientific-language-model-trained-on-curated-knowledge">A Scientific Language Model Trained on Curated Knowledge</h2>
<p>Galactica is a <strong>Resource</strong> contribution: a family of decoder-only Transformer language models (125M to 120B parameters) trained on a curated corpus of 106 billion tokens from scientific papers, reference material, knowledge bases, and other sources. The paper also introduces several specialized tokenization schemes for scientific modalities (<a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a>, amino acid sequences, DNA sequences, LaTeX, citations) and a working memory token (<code>&lt;work&gt;</code>) for step-by-step reasoning. All model weights are open-sourced under the Apache 2.0 license.</p>
<h2 id="information-overload-as-the-motivating-problem">Information Overload as the Motivating Problem</h2>
<p>The volume of scientific literature has grown beyond any individual&rsquo;s capacity to process. An average of 516 papers per day were submitted to arXiv as of May 2022, and databases like <a href="https://en.wikipedia.org/wiki/GenBank">NCBI GenBank</a> contained $1.49 \times 10^{12}$ nucleotide bases as of August 2022. Current search engines point to secondary knowledge layers (Wikipedia, UniProt, PubChem) that require costly human curation, creating a throughput bottleneck.</p>
<p>The authors argue that large language models can serve as a new interface for science by storing, combining, and reasoning about scientific knowledge in weight memory, rather than relying on the traditional store-and-retrieve paradigm. Prior scientific language models (SciBERT, BioLM) were small in scale, while general LLMs (GPT-3, PaLM) were trained on uncurated web data, an inefficient source for scientific tasks.</p>
<h2 id="curated-corpus-and-specialized-tokenization">Curated Corpus and Specialized Tokenization</h2>
<p>The core innovation has two components: a normative approach to dataset curation and a set of specialized tokens for different scientific modalities.</p>
<h3 id="the-galactica-corpus">The Galactica Corpus</h3>
<p>The training corpus consists of 106 billion tokens with a deliberate focus on quality over quantity:</p>
<table>
  <thead>
      <tr>
          <th>Data Source</th>
          <th>Documents</th>
          <th>Tokens</th>
          <th>Token %</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Papers</td>
          <td>48 million</td>
          <td>88 billion</td>
          <td>83.0%</td>
      </tr>
      <tr>
          <td>Code</td>
          <td>2 million</td>
          <td>7 billion</td>
          <td>6.9%</td>
      </tr>
      <tr>
          <td>Reference Material</td>
          <td>8 million</td>
          <td>7 billion</td>
          <td>6.5%</td>
      </tr>
      <tr>
          <td>Knowledge Bases</td>
          <td>2 million</td>
          <td>2 billion</td>
          <td>2.0%</td>
      </tr>
      <tr>
          <td>Filtered CommonCrawl</td>
          <td>0.9 million</td>
          <td>1 billion</td>
          <td>1.0%</td>
      </tr>
      <tr>
          <td>Prompts</td>
          <td>1.3 million</td>
          <td>0.4 billion</td>
          <td>0.3%</td>
      </tr>
      <tr>
          <td>Other</td>
          <td>0.02 million</td>
          <td>0.2 billion</td>
          <td>0.2%</td>
      </tr>
  </tbody>
</table>
<p>Papers come from arXiv (35B tokens), PMC (23B), <a href="https://en.wikipedia.org/wiki/Semantic_Scholar">Semantic Scholar</a> (18B), and PubMed abstracts (5B), among others. Reference material includes Wikipedia (5B tokens), StackExchange (1B), textbooks, and lecture notes. Knowledge bases include <a href="https://en.wikipedia.org/wiki/PubChem">PubChem</a> Compound (2M compounds, 1B tokens), <a href="https://en.wikipedia.org/wiki/UniProt">UniProt</a> (552K reviewed Swiss-Prot proteins, 0.6B tokens), and the <a href="https://en.wikipedia.org/wiki/RefSeq">RefSeq</a> Genome.</p>
<p>All data is processed into a common markdown format. Mathematical LaTeX is preserved where available, and papers are citation-processed with title-based identifiers.</p>
<h3 id="specialized-tokenization">Specialized Tokenization</h3>
<p>Galactica introduces several modality-specific tokenization strategies:</p>
<ol>
<li>
<p><strong>Citations</strong>: Wrapped with <code>[START_REF]</code> and <code>[END_REF]</code> tokens using paper titles as identifiers, enabling the model to predict citations in context.</p>
</li>
<li>
<p><strong>Working Memory (<code>&lt;work&gt;</code>)</strong>: Step-by-step reasoning is wrapped in <code>&lt;work&gt;</code> and <code>&lt;/work&gt;</code> tokens that mimic an internal working memory, allowing the model to perform multi-step computation. This differs from chain-of-thought prompting in that it is learned during pre-training rather than elicited through prompt engineering.</p>
</li>
<li>
<p><strong>SMILES</strong>: Wrapped with <code>[START_SMILES]</code>/<code>[END_SMILES]</code> tokens and character-level tokenization.</p>
</li>
<li>
<p><strong>Amino Acid Sequences</strong>: Wrapped with <code>[START_AMINO]</code>/<code>[END_AMINO]</code> tokens with character-level tokenization (one token per residue).</p>
</li>
<li>
<p><strong>DNA Sequences</strong>: Wrapped with <code>[START_DNA]</code>/<code>[END_DNA]</code> tokens with character-level tokenization (one token per nucleotide base).</p>
</li>
<li>
<p><strong>Mathematics</strong>: ASCII operations split into individual characters; digits split into individual tokens.</p>
</li>
</ol>
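<p>The wrapping scheme for sequence modalities can be sketched as follows. The marker strings are taken from the list above; the helper name and the list-of-strings output are illustrative assumptions and do not reflect Galactica's actual tokenizer implementation.</p>

```python
# Sketch: wrap a sequence in modality marker tokens and split its body into
# single characters. The marker strings come from the list above; the helper
# name and the list-of-strings output are illustrative assumptions.

MARKERS = {
    "smiles": ("[START_SMILES]", "[END_SMILES]"),
    "amino": ("[START_AMINO]", "[END_AMINO]"),
    "dna": ("[START_DNA]", "[END_DNA]"),
}


def wrap_sequence(modality, sequence):
    """Return [start marker, one token per character, end marker]."""
    start, end = MARKERS[modality]
    return [start, *sequence, end]


print(wrap_sequence("smiles", "CCO"))
print(wrap_sequence("dna", "ACGT"))
```

<p>Character-level splitting keeps the sequence vocabulary tiny at the cost of longer token sequences, since a multi-character atom symbol is spread over several tokens.</p>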
<h3 id="prompt-pre-training">Prompt Pre-Training</h3>
<p>Rather than using instruction tuning as a separate fine-tuning stage, Galactica includes task-specific prompts (358 million tokens total) directly in pre-training alongside the general corpus. This includes question answering, entity extraction, summarization, dialog, and chemical property prediction prompts. The authors frame this as occupying a middle ground between pure self-supervised pre-training and instruction tuning, providing task signal without degrading general capability.</p>
<h2 id="architecture-training-and-evaluation-setup">Architecture, Training, and Evaluation Setup</h2>
<h3 id="architecture">Architecture</h3>
<p>Galactica uses a standard decoder-only Transformer with several modifications:</p>
<ul>
<li>GeLU activations</li>
<li>2048-token context window</li>
<li>No biases in dense kernels or layer norms</li>
<li>Learned positional embeddings</li>
<li>50K BPE vocabulary</li>
</ul>
<p>Five model sizes were trained:</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Parameters</th>
          <th>Layers</th>
          <th>$d_{\text{model}}$</th>
          <th>Heads</th>
          <th>Batch Size</th>
          <th>Max LR</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>GAL 125M</td>
          <td>125M</td>
          <td>12</td>
          <td>768</td>
          <td>12</td>
          <td>0.5M</td>
          <td>$6 \times 10^{-4}$</td>
      </tr>
      <tr>
          <td>GAL 1.3B</td>
          <td>1.3B</td>
          <td>24</td>
          <td>2,048</td>
          <td>32</td>
          <td>1.0M</td>
          <td>$2 \times 10^{-4}$</td>
      </tr>
      <tr>
          <td>GAL 6.7B</td>
          <td>6.7B</td>
          <td>32</td>
          <td>4,096</td>
          <td>32</td>
          <td>2.0M</td>
          <td>$1.2 \times 10^{-4}$</td>
      </tr>
      <tr>
          <td>GAL 30B</td>
          <td>30.0B</td>
          <td>48</td>
          <td>7,168</td>
          <td>56</td>
          <td>2.0M</td>
          <td>$1 \times 10^{-4}$</td>
      </tr>
      <tr>
          <td>GAL 120B</td>
          <td>120.0B</td>
          <td>96</td>
          <td>10,240</td>
          <td>80</td>
          <td>2.0M</td>
          <td>$0.7 \times 10^{-5}$</td>
      </tr>
  </tbody>
</table>
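<p>The layer counts and widths in the table can be cross-checked against the nominal model sizes with a common rule of thumb. The sketch below assumes roughly 12 &times; layers &times; d_model squared weights in the Transformer blocks (which presumes a feed-forward layer four times the model width) and ignores embeddings; both the rule and the function name are assumptions, so the results are approximations only.</p>

```python
# Sketch: rough parameter estimate for a decoder-only Transformer using the
# common rule of thumb of about 12 * layers * d_model^2 weights in the blocks
# (4*d^2 for attention, 8*d^2 for a 4x-wide feed-forward layer). The rule and
# the 4x feed-forward width are assumptions; the paper's exact counts may differ.

CONFIGS = {
    "GAL 125M": (12, 768),
    "GAL 1.3B": (24, 2048),
    "GAL 6.7B": (32, 4096),
    "GAL 30B": (48, 7168),
    "GAL 120B": (96, 10240),
}


def approx_block_params(num_layers, d_model):
    """Approximate non-embedding parameter count."""
    return 12 * num_layers * d_model ** 2


for name, (layers, width) in CONFIGS.items():
    estimate = approx_block_params(layers, width) / 1e9
    print(f"{name}: ~{estimate:.2f}B non-embedding parameters")
```

<p>For the largest configuration the estimate is about 120.8B, close to the nominal 120B. The smaller models come out below their nominal sizes because embedding parameters, which this estimate leaves out, make up a larger share of the total.</p>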
<p>Training used AdamW with $\beta_1 = 0.9$, $\beta_2 = 0.95$, weight decay of 0.1, gradient clipping at 1.0, and linear learning rate decay to 10% of peak value. Dropout and attention dropout were set to $p = 0.1$.</p>
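<p>The decay schedule described above can be sketched as a simple function of the training step. Warmup is omitted and the step counts are placeholders; both are assumptions, since the text only gives the decay shape and the 10% floor.</p>

```python
# Sketch: linear learning-rate decay from a peak value down to 10% of that
# peak. Warmup is omitted and the step counts are placeholders; both are
# assumptions, since the summary above only states the decay shape and floor.


def linear_decay_lr(step, total_steps, peak_lr, final_fraction=0.1):
    """Interpolate linearly from peak_lr (step 0) to final_fraction * peak_lr."""
    progress = min(max(step, 0), total_steps) / total_steps
    return peak_lr * (1.0 - (1.0 - final_fraction) * progress)


peak = 6e-4  # Max LR listed for GAL 125M in the table above
for step in (0, 500, 1000):
    print(step, linear_decay_lr(step, total_steps=1000, peak_lr=peak))
```

<p>Keeping a non-zero floor means the final steps still apply meaningful updates, which matters when training runs over several epochs of the same corpus.</p>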
<h3 id="training-on-repeated-tokens">Training on Repeated Tokens</h3>
<p>Models were trained for 450 billion tokens, approximately 4.25 epochs of the 106-billion-token corpus. Validation loss continued to fall through four epochs for all model sizes, with the 120B model only beginning to overfit at the start of the fifth epoch. This is notable because it challenges the prevailing view that repeated tokens are harmful for LLM training. Performance on out-of-domain BIG-bench tasks also continued to improve through training, suggesting that the repeated epochs did not harm downstream generalization.</p>
<h3 id="key-evaluation-results">Key Evaluation Results</h3>
<p><strong>Knowledge Probes</strong>: On LaTeX equation prediction across 434 equations from chemistry, physics, mathematics, statistics, and economics, GAL 120B achieved 68.2% accuracy versus GPT-3&rsquo;s 49.0% (zero-shot). On chemical reactions, GAL 120B scored 43.1% versus GPT-3&rsquo;s 35.1%.</p>
<p><strong>Mathematical Reasoning</strong>: With the <code>&lt;work&gt;</code> token, GAL 120B achieved 41.3% on mathematical MMLU (average across abstract algebra, elementary mathematics, high school mathematics, college mathematics, and formal logic), compared to Chinchilla&rsquo;s 35.7% (5-shot). On the MATH benchmark, GAL 120B scored 20.4% (5-shot chain-of-thought) versus PaLM 540B&rsquo;s 8.8%.</p>
<p><strong>Scientific QA</strong>: Galactica set state-of-the-art results on PubMedQA (77.6%) and MedMCQA dev (52.9%), outperforming prior fine-tuned models (72.2% and 41.0% respectively).</p>
<p><strong>Citation Prediction</strong>: GAL 120B achieved 51.9% accuracy on PWC Citations and 69.1% on Extended Citations, outperforming both sparse (ElasticSearch) and dense (Contriever) retrieval baselines.</p>
<p><strong>BIG-bench (57 tasks)</strong>: Despite training only on scientific data, GAL 120B (48.7% weighted accuracy) outperformed OPT 175B (43.4%) and BLOOM 176B (42.6%) on primarily non-scientific tasks.</p>
<p><strong><a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a> Classification</strong>: Using SMILES in natural language prompts with weak supervision, GAL 120B achieved an average ROC-AUC of 0.690 across six MoleculeNet classification benchmarks (BACE, BBBP, ClinTox, HIV, SIDER, Tox21). This lagged the specialist Uni-Mol model (0.770), which uses 3D molecular information and 10x more molecules.</p>
<p><strong><a href="https://en.wikipedia.org/wiki/IUPAC_nomenclature_of_organic_chemistry">IUPAC</a> Name Prediction</strong>: GAL 120B achieved 39.2% accuracy on predicting IUPAC names from SMILES in a self-supervised setting, with attention visualization showing the model attends to chemically relevant functional groups (e.g., attending to the $\text{-NH}_2$ group when predicting &ldquo;amino&rdquo;).</p>
<p><strong>Protein Function Prediction</strong>: GAL 120B achieved a ROUGE-L of 0.252 on generating free-form protein function descriptions from amino acid sequences, and an $F_1$ of 48.7% on protein keyword prediction from the UniProt general validation set.</p>
<p><strong>Bias and Toxicity</strong>: On CrowS-Pairs, GAL 120B scored 60.5% (closer to ideal 50%) versus OPT 175B&rsquo;s 69.5%. On StereoSet, GAL 120B achieved an ICAT score of 65.6 versus OPT&rsquo;s 60.0 and GPT-3&rsquo;s 60.8. Toxicity rates on RealToxicityPrompts were substantially lower than comparison models.</p>
<h2 id="findings-limitations-and-future-directions">Findings, Limitations, and Future Directions</h2>
<h3 id="key-findings">Key Findings</h3>
<ol>
<li>
<p><strong>Curated data enables repeated training</strong>: The curated scientific corpus allows training for multiple epochs without overfitting, contrary to prevailing assumptions about repeated token degradation.</p>
</li>
<li>
<p><strong>Scientific LLMs generalize beyond science</strong>: Despite training only on scientific text, Galactica outperforms general LLMs on non-scientific BIG-bench tasks, suggesting data quality matters more than data breadth.</p>
</li>
<li>
<p><strong>Weight memory can outperform retrieval</strong>: For citation prediction, Galactica&rsquo;s weight memory outperforms traditional sparse and dense retrieval methods, demonstrating the context-associative power of language models.</p>
</li>
<li>
<p><strong>Multi-modal learning via text</strong>: SMILES and protein sequences can be learned alongside natural language in a single model, and the model attends to chemically interpretable features.</p>
</li>
</ol>
<h3 id="limitations">Limitations</h3>
<p>The authors acknowledge several limitations:</p>
<ul>
<li><strong>Corpus constraints</strong>: Restricted to open-access papers; much scientific knowledge in closed-access papers and textbooks is excluded. Only 2M of 110M PubChem compounds and 0.5M of 227M UniProt sequences were included.</li>
<li><strong>Corpus vs. prompt effects</strong>: The paper does not disentangle whether performance gains come from the scientific corpus or from the prompt pre-training strategy.</li>
<li><strong>Citation bias</strong>: The model still shows bias toward predicting more popular papers, though this decreases with scale.</li>
<li><strong>No geometry</strong>: SMILES-based representations lack 3D geometric information, limiting chemical understanding.</li>
<li><strong>Hallucination</strong>: Title-based citation identifiers are more prone to hallucination at smaller scales, though accuracy improves with scale.</li>
<li><strong>No instruction tuning comparison</strong>: The paper does not compare prompt pre-training against instruction tuning as a follow-up step.</li>
</ul>
<h3 id="future-directions">Future Directions</h3>
<p>The paper identifies retrieval augmentation, extending to images, larger context windows, mixture-of-denoising training objectives, and more diverse <code>&lt;work&gt;</code> reasoning examples as promising directions.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training</td>
          <td>Galactica Corpus</td>
          <td>106B tokens</td>
          <td>Papers (83%), code (6.9%), reference material (6.5%), knowledge bases (2%), CommonCrawl (1%), prompts (0.3%)</td>
      </tr>
      <tr>
          <td>Training (Molecules)</td>
          <td>PubChem Compound subset</td>
          <td>2M compounds (of 110M available)</td>
          <td>Character-level SMILES tokenization</td>
      </tr>
      <tr>
          <td>Training (Proteins)</td>
          <td>Swiss-Prot (UniProt)</td>
          <td>552K reviewed sequences (of 227M available)</td>
          <td>Character-level amino acid tokenization</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>LaTeX Equations</td>
          <td>434 equations</td>
          <td>Chemistry, physics, math, stats, economics</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>MMLU, MATH</td>
          <td>Standard benchmarks</td>
          <td>Out-of-domain evaluation</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>PubMedQA, MedMCQA, BioASQ</td>
          <td>Standard biomedical QA</td>
          <td>In-domain (training prompts included)</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>MoleculeNet (6 tasks)</td>
          <td>Standard molecular benchmarks</td>
          <td>BACE, BBBP, ClinTox, HIV, SIDER, Tox21</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>BIG-bench (57 tasks)</td>
          <td>Standard NLP benchmark</td>
          <td>Out-of-domain, non-scientific</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Decoder-only Transformer with GeLU activations, no biases</li>
<li>AdamW optimizer: $\beta_1 = 0.9$, $\beta_2 = 0.95$, weight decay 0.1</li>
<li>Gradient clipping at global norm 1.0</li>
<li>Linear LR decay to 10% of peak</li>
<li>Dropout: $p = 0.1$ (attention and residual)</li>
<li><a href="https://en.wikipedia.org/wiki/Byte-pair_encoding">BPE</a> vocabulary: 50K tokens from 2% corpus sample</li>
<li>Training: 450B tokens (~4.25 epochs)</li>
</ul>
<h3 id="models">Models</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/paperswithcode/galai">Galactica models (galai)</a></td>
          <td>Code + Model</td>
          <td>Apache-2.0</td>
          <td>Official implementation with 125M, 1.3B, 6.7B, 30B, 120B checkpoints</td>
      </tr>
  </tbody>
</table>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>GAL 120B</th>
          <th>Best Baseline</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>LaTeX Equations (zero-shot)</td>
          <td>68.2%</td>
          <td>GPT-3: 49.0%</td>
          <td>434 equations across 5 domains</td>
      </tr>
      <tr>
          <td>Math MMLU (<code>&lt;work&gt;</code>)</td>
          <td>41.3%</td>
          <td>Chinchilla (5-shot): 35.7%</td>
          <td>Average over 5 math subjects</td>
      </tr>
      <tr>
          <td>MATH (5-shot CoT)</td>
          <td>20.4%</td>
          <td>PaLM 540B: 8.8%</td>
          <td>Minerva 540B (fine-tuned): 33.6%</td>
      </tr>
      <tr>
          <td>PubMedQA</td>
          <td>77.6%</td>
          <td>Prior SOTA: 72.2%</td>
          <td>In-domain</td>
      </tr>
      <tr>
          <td>MedMCQA dev</td>
          <td>52.9%</td>
          <td>Prior SOTA: 41.0%</td>
          <td>In-domain</td>
      </tr>
      <tr>
          <td>BIG-bench (weighted)</td>
          <td>48.7%</td>
          <td>OPT 175B: 43.4%</td>
          <td>57 non-scientific tasks</td>
      </tr>
      <tr>
          <td>MoleculeNet ROC-AUC (avg)</td>
          <td>0.690</td>
          <td>Uni-Mol (3D): 0.770</td>
          <td>Weak supervision vs. direct fine-tuning</td>
      </tr>
      <tr>
          <td>CrowS-Pairs (closer to 50% = less biased)</td>
          <td>60.5%</td>
          <td>OPT 175B: 69.5%</td>
          <td>Ideal: 50%</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li>120B model training: 128 NVIDIA A100 80GB nodes</li>
<li>120B model inference: single NVIDIA A100 node</li>
<li>Training library: metaseq (Meta AI)</li>
</ul>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Taylor, R., Kardas, M., Cucurull, G., Scialom, T., Hartshorn, A., Saravia, E., Poulton, A., Kerkez, V., &amp; Stojnic, R. (2022). Galactica: A Large Language Model for Science. <em>arXiv preprint arXiv:2211.09085</em>.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{taylor2022galactica,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Galactica: A Large Language Model for Science}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Taylor, Ross and Kardas, Marcin and Cucurull, Guillem and Scialom, Thomas and Hartshorn, Anthony and Saravia, Elvis and Poulton, Andrew and Kerkez, Viktor and Stojnic, Robert}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{arXiv preprint arXiv:2211.09085}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2022}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.48550/arxiv.2211.09085}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Fine-Tuning GPT-3 for Predictive Chemistry Tasks</title><link>https://hunterheidenreich.com/notes/chemistry/llm-applications/leveraging-llms-predictive-chemistry/</link><pubDate>Sat, 28 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/llm-applications/leveraging-llms-predictive-chemistry/</guid><description>Fine-tuned GPT-3 matches or outperforms specialized ML models on molecular, materials, and reaction property prediction, especially in low-data regimes.</description><content:encoded><![CDATA[<h2 id="gpt-3-as-a-general-purpose-chemistry-predictor">GPT-3 as a General-Purpose Chemistry Predictor</h2>
<p>This is an <strong>Empirical</strong> paper that systematically benchmarks fine-tuned GPT-3 against dedicated machine learning models across 15 chemistry and materials science prediction tasks. The primary contribution is demonstrating that a general-purpose large language model, with no chemistry-specific architecture or featurization, can match or outperform specialized ML approaches, particularly when training data is limited. The paper also demonstrates inverse molecular design through simple prompt inversion.</p>
<h2 id="why-general-purpose-llms-for-chemistry">Why General-Purpose LLMs for Chemistry</h2>
<p>Machine learning in chemistry typically requires domain-specific feature engineering: molecular fingerprints, graph neural network architectures, or hand-crafted descriptors tailored to each application. Developing these approaches demands specialized expertise and significant effort for each new problem. The small datasets common in experimental chemistry further complicate matters, as many sophisticated ML approaches require large training sets to learn meaningful representations.</p>
<p>Large language models like GPT-3, trained on vast internet text corpora, had shown surprising capability at tasks they were not explicitly trained for. The key question motivating this work was whether these general-purpose models could also answer scientific questions for which we lack answers, given that most chemistry problems can be represented in text form. For example: &ldquo;If I change the metal in my <a href="https://en.wikipedia.org/wiki/Metal%E2%80%93organic_framework">metal-organic framework</a>, will it be stable in water?&rdquo;</p>
<p>Prior chemical language models (e.g., <a href="/notes/chemistry/molecular-design/property-prediction/transformer-cnn-qsar-modeling/">Transformer-CNN</a>, <a href="/notes/chemistry/molecular-design/property-prediction/regression-transformer/">Regression Transformer</a>, <a href="/notes/chemistry/molecular-representations/encoders/selformer/">SELFormer</a>) were pre-trained on chemistry-specific corpora. In contrast, this work investigates models trained primarily on general internet text, examining whether the implicit chemical knowledge encoded during pre-training, combined with task-specific fine-tuning, can substitute for explicit chemical featurization.</p>
<h2 id="language-interfaced-fine-tuning-for-chemistry">Language-Interfaced Fine-Tuning for Chemistry</h2>
<p>The core innovation is &ldquo;language-interfaced fine-tuning&rdquo; (LIFT): reformulating chemistry prediction tasks as natural language question-answering. Training examples take the form of question-completion pairs, where questions describe the chemical system in text and completions provide the target property. For example:</p>
<ul>
<li><strong>Classification</strong>: &ldquo;What is the phase of Co1Cu1Fe1Ni1V1?&rdquo; with completion &ldquo;0&rdquo; (multi-phase)</li>
<li><strong>Regression</strong>: Property values are rounded to a fixed precision, converting continuous prediction into a text generation problem</li>
<li><strong>Inverse design</strong>: Questions and completions are simply swapped, asking &ldquo;What is a molecule with property X?&rdquo; and expecting a <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> string as completion</li>
</ul>
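<p>The question-completion format above can be sketched in a few lines. This is a minimal illustration, not the authors&rsquo; code: the function name <code>make_pair</code>, the question templates, and the azobenzene SMILES with a 350 nm target are hypothetical placeholders chosen only to show how swapping prompt and completion turns a predictor into an inverse-design task.</p>

```python
def make_pair(representation, target, property_name, inverse=False):
    """Build one (prompt, completion) pair. Templates are illustrative only."""
    if inverse:
        # Inverse design: the property is the prompt, the structure is the completion.
        prompt = f"What is a molecule with {property_name} {target}?"
        completion = representation
    else:
        # Forward prediction: the structure is the prompt, the property is the completion.
        prompt = f"What is the {property_name} of {representation}?"
        completion = str(target)
    return {"prompt": prompt, "completion": completion}


# Classification example taken from the text (0 = multi-phase).
forward = make_pair("Co1Cu1Fe1Ni1V1", 0, "phase")

# Placeholder molecule and target value, not data from the paper.
inverse = make_pair(
    "C1=CC=C(C=C1)N=NC2=CC=CC=C2", "350 nm", "transition wavelength", inverse=True
)

print(forward)
print(inverse)
```

<p>Because both directions share one data format, the same fine-tuning pipeline can serve prediction and generation without any architectural change.</p>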
<p>The fine-tuning uses OpenAI&rsquo;s API with the smallest <code>ada</code> variant of GPT-3, with uniform hyperparameters across all tasks (8 epochs, learning rate multiplier of 0.02). No optimization of prompt structure, tokenization, or training schedule was performed, making the approach deliberately simple.</p>
<p>For regression, since language models generate discrete tokens rather than continuous values, the authors round target values to a fixed precision (e.g., 1% for Henry coefficients). This converts regression into a form of classification over numeric strings, with the assumption that GPT-3 can interpolate between these discretized values.</p>
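<p>A rough sketch of this discretization step follows. It assumes rounding to a fixed number of decimal places; the paper&rsquo;s exact rounding rule (for instance, a relative precision such as 1%) may differ, and the helper names and target values are hypothetical.</p>

```python
def to_completion(value, decimals=2):
    """Render a continuous target as a fixed-precision string for the completion."""
    return f"{value:.{decimals}f}"


def from_completion(text):
    """Parse a generated completion back into a float; None if it is not numeric."""
    try:
        return float(text.strip())
    except ValueError:
        return None


targets = [-3.14159, 0.5, 12.0]  # arbitrary example values
completions = [to_completion(t) for t in targets]
recovered = [from_completion(c) for c in completions]

# Rounding to two decimals introduces an error of at most 0.005 per value.
max_error = max(abs(t - r) for t, r in zip(targets, recovered))
print(completions, max_error)
```

<p>The parse step matters in practice because a language model can emit a completion that is not a number at all, which has to be handled separately from an inaccurate prediction.</p>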
<p>The approach also extends to open-source models. The authors demonstrate that GPT-J-6B can be fine-tuned using parameter-efficient techniques (LoRA, 8-bit quantization) on consumer hardware, and provide the <code>chemlift</code> Python package for this purpose.</p>
<h2 id="benchmarks-across-molecules-materials-and-reactions">Benchmarks Across Molecules, Materials, and Reactions</h2>
<h3 id="datasets-and-tasks">Datasets and Tasks</h3>
<p>The evaluation spans three chemical domains with 15 total benchmarks:</p>
<p><strong>Molecules:</strong></p>
<ul>
<li><a href="https://en.wikipedia.org/wiki/Photoswitch">Photoswitch</a> transition wavelength prediction (2022)</li>
<li>Free energy of solvation (FreeSolv, 2014)</li>
<li>Aqueous solubility (ESOL, 2004)</li>
<li>Lipophilicity (ChEMBL, 2012)</li>
<li><a href="https://en.wikipedia.org/wiki/HOMO_and_LUMO">HOMO-LUMO gap</a> (QMugs, 2022)</li>
<li><a href="https://en.wikipedia.org/wiki/Organic_solar_cell">Organic photovoltaic</a> power conversion efficiency (2018)</li>
</ul>
<p><strong>Materials:</strong></p>
<ul>
<li>Coarse-grained surfactant adsorption free energy (2021)</li>
<li>CO2 and CH4 <a href="https://en.wikipedia.org/wiki/Henry%27s_law">Henry coefficients</a> in MOFs (2020)</li>
<li>MOF heat capacity (2022)</li>
<li><a href="https://en.wikipedia.org/wiki/High-entropy_alloy">High-entropy alloy</a> phase prediction (2020)</li>
<li><a href="https://en.wikipedia.org/wiki/Amorphous_metal">Bulk metallic glass</a> formation ability (2006)</li>
<li>Metallic behavior prediction (2018)</li>
</ul>
<p><strong>Reactions:</strong></p>
<ul>
<li>C-N cross-coupling yield (<a href="https://en.wikipedia.org/wiki/Buchwald%E2%80%93Hartwig_amination">Buchwald-Hartwig</a>, 2018)</li>
<li>C-C cross-coupling yield (<a href="https://en.wikipedia.org/wiki/Suzuki_reaction">Suzuki</a>, 2022)</li>
</ul>
<h3 id="baselines">Baselines</h3>
<p>The baselines include both traditional ML and deep learning approaches:</p>
<ul>
<li><strong>Non-DL</strong>: XGBoost with molecular descriptors/fragprints, Gaussian Process Regression (GPR), random forests, n-Gram models, Automatminer, differential reaction fingerprints (DRFP)</li>
<li><strong>Deep learning</strong>: MolCLR, ModNet, CrabNet, TabPFN</li>
</ul>
<h3 id="data-efficiency-analysis">Data Efficiency Analysis</h3>
<p>To compare data efficiency, the authors fit power law curves to learning curves for all models and measure the &ldquo;data efficiency factor&rdquo;: how much more (or less) data the best baseline needs to match GPT-3&rsquo;s performance in the low-data regime.</p>
<table>
  <thead>
      <tr>
          <th>Domain</th>
          <th>Benchmark</th>
          <th>Data Efficiency vs. Non-DL</th>
          <th>Data Efficiency vs. DL</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Molecules</td>
          <td>Photoswitch wavelength</td>
          <td>1.1x (n-Gram)</td>
          <td>1.2x (TabPFN)</td>
      </tr>
      <tr>
          <td>Molecules</td>
          <td>Solvation free energy</td>
          <td>3.1x (GPR)</td>
          <td>1.3x (TabPFN)</td>
      </tr>
      <tr>
          <td>Molecules</td>
          <td>Solubility</td>
          <td>1.0x (XGBoost)</td>
          <td>0.002x (MolCLR)</td>
      </tr>
      <tr>
          <td>Molecules</td>
          <td>Lipophilicity</td>
          <td>3.43x (GPR)</td>
          <td>0.97x (TabPFN)</td>
      </tr>
      <tr>
          <td>Molecules</td>
          <td>HOMO-LUMO gap</td>
          <td>4.3x (XGBoost)</td>
          <td>0.62x (TabPFN)</td>
      </tr>
      <tr>
          <td>Materials</td>
          <td>HEA phase</td>
          <td>24x (RF)</td>
          <td>9.0x (CrabNet)</td>
      </tr>
      <tr>
          <td>Materials</td>
          <td>CO2 Henry coeff.</td>
          <td>0.40x (XGBoost)</td>
          <td>12x (TabPFN)</td>
      </tr>
      <tr>
          <td>Reactions</td>
          <td>C-N cross-coupling</td>
          <td>2.9x (DRFP)</td>
          <td>-</td>
      </tr>
  </tbody>
</table>
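<p>Values &gt;1 indicate that the baseline needs more data than GPT-3 to reach the same performance, i.e., GPT-3 is more data-efficient; values &lt;1 indicate the reverse. For the HEA phase prediction task, GPT-3 fine-tuned on only about 50 examples achieved accuracy comparable to a random forest model trained on 1,126 data points.</p>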
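<p>One way to compute such a factor is sketched below. It assumes a plain power law, error = a * n^(-b), fitted by least squares in log-log space, which is a simplification and not necessarily the functional form used in the paper. The function names and the synthetic learning curves are hypothetical.</p>

```python
import math


def fit_power_law(sizes, errors):
    """Least-squares fit of error = a * n**(-b) in log-log space; returns (a, b)."""
    xs = [math.log(n) for n in sizes]
    ys = [math.log(e) for e in errors]
    x_mean = sum(xs) / len(xs)
    y_mean = sum(ys) / len(ys)
    covariance = sum((x - x_mean) * (y - y_mean) for x, y in zip(xs, ys))
    variance = sum((x - x_mean) ** 2 for x in xs)
    slope = covariance / variance
    intercept = y_mean - slope * x_mean
    return math.exp(intercept), -slope


def points_needed(a, b, target_error):
    """Invert error = a * n**(-b) to get the training set size n."""
    return (a / target_error) ** (1.0 / b)


def data_efficiency_factor(llm_curve, baseline_curve, n_llm):
    """How many times more data the baseline needs to match the LLM at n_llm points."""
    a_llm, b_llm = fit_power_law(*llm_curve)
    a_base, b_base = fit_power_law(*baseline_curve)
    target = a_llm * n_llm ** (-b_llm)
    return points_needed(a_base, b_base, target) / n_llm


# Synthetic learning curves (not data from the paper).
sizes = [10, 20, 50, 100, 200]
llm_errors = [2.0 * n ** -0.5 for n in sizes]
baseline_errors = [4.0 * n ** -0.5 for n in sizes]

factor = data_efficiency_factor((sizes, llm_errors), (sizes, baseline_errors), n_llm=50)
print(round(factor, 3))
```

<p>With these synthetic curves the baseline needs four times as much data as the LLM, so the factor is 4. A factor below 1 would mean the baseline is the more data-efficient model.</p>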
<h3 id="representation-sensitivity">Representation Sensitivity</h3>
<p>An important finding is that GPT-3 performs well regardless of molecular representation format. The authors tested IUPAC names, SMILES, and <a href="/notes/chemistry/molecular-representations/notations/selfies/">SELFIES</a>, finding good results across all representations. IUPAC names often produced the best performance, which is notable because it makes the approach accessible to non-specialists who can simply use chemical names rather than learning specialized encodings.</p>
<h3 id="inverse-design">Inverse Design</h3>
<p>For inverse design, the authors fine-tuned GPT-3 with reversed question-completion pairs. On photoswitches:</p>
<ul>
<li>Generated molecules include both training set members and novel structures (some not in PubChem)</li>
<li>Transition wavelengths matched target values within about 10% mean absolute percentage error (validated using the GPR model from Griffiths et al.)</li>
<li>A temperature parameter controls the diversity-validity tradeoff: low temperatures produce training set copies, high temperatures produce diverse but potentially invalid structures</li>
<li>Across all temperatures, generated molecules showed low synthetic accessibility (SA) scores (lower scores indicate easier synthesis), suggesting that they are synthesizable</li>
</ul>
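<p>The temperature effect can be illustrated with the standard temperature-scaled softmax used in language model sampling. The logits below are arbitrary toy values, and the function name is hypothetical; this is a sketch of the general mechanism, not of the OpenAI API internals.</p>

```python
import math


def softmax_with_temperature(logits, temperature):
    """Convert logits to sampling probabilities after dividing by the temperature."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)  # subtract the maximum for numerical stability
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.0]  # toy scores for three candidate next tokens

cold = softmax_with_temperature(logits, 0.1)  # nearly deterministic
hot = softmax_with_temperature(logits, 2.0)   # much flatter distribution

print([round(p, 4) for p in cold])
print([round(p, 4) for p in hot])
```

<p>At low temperature almost all probability mass sits on the most likely token, which is consistent with the model reproducing training molecules. At high temperature less likely tokens are sampled more often, which increases diversity but also the chance of an invalid string. A temperature of exactly zero would need separate handling (greedy decoding) because of the division.</p>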
<p>The authors also demonstrated iterative inverse design for HOMO-LUMO gap optimization: starting from QMugs data, they iteratively fine-tuned GPT-3 to generate molecules with progressively larger bandgaps (&gt;5 eV), successfully shifting the distribution over four generations. This worked even when extrapolating beyond the training distribution (e.g., training only on molecules with gaps &lt;3.5 eV, then generating molecules with gaps &gt;4.0 eV).</p>
<h3 id="coarse-grained-polymer-design">Coarse-Grained Polymer Design</h3>
<p>A striking test involved coarse-grained dispersant polymers with four monomer types and chain lengths of 16-48 units. GPT-3 had no prior knowledge of these abstract representations, yet it outperformed dedicated models for adsorption free energy prediction and successfully performed inverse design, generating monomer sequences with a mean percentage error of about 22% for the desired property.</p>
<h2 id="key-findings-and-limitations">Key Findings and Limitations</h2>
<h3 id="key-findings">Key Findings</h3>
<ol>
<li>
<p><strong>Low-data advantage</strong>: Fine-tuned GPT-3 consistently shows the largest advantages over conventional ML in low-data regimes (tens to hundreds of data points), which is precisely where experimental chemistry datasets typically fall.</p>
</li>
<li>
<p><strong>Representation agnostic</strong>: The model works with IUPAC names, SMILES, SELFIES, and even invented abstract representations, removing the need for chemistry-specific tokenization.</p>
</li>
<li>
<p><strong>No feature engineering</strong>: The approach requires no domain-specific descriptors, fingerprints, or architectural modifications, making it accessible to researchers without ML expertise.</p>
</li>
<li>
<p><strong>Bidirectional design</strong>: Inverse design is achieved by simply reversing the question format, with no architectural changes or separate generative model needed.</p>
</li>
<li>
<p><strong>Extrapolation capability</strong>: The model can generate molecules with properties outside the training distribution, as demonstrated by the HOMO-LUMO gap extrapolation experiments.</p>
</li>
</ol>
<h3 id="limitations">Limitations</h3>
<ul>
<li>In the <strong>high-data regime</strong>, conventional ML models with chemistry-specific features often catch up to or surpass GPT-3, as the inductive biases encoded in GPT-3 become less necessary with sufficient data.</li>
<li><strong>Regression</strong> is inherently limited by the discretization of continuous values into tokens. This requires more data than classification and introduces quantization error.</li>
<li>The approach relies on the <strong>OpenAI API</strong>, introducing cost and reproducibility concerns (model versions may change). The authors partially address this by providing open-source alternatives via <code>chemlift</code>.</li>
<li>The authors acknowledge that <strong>identified correlations may not represent causal relationships</strong>. GPT-3 finding predictive patterns does not guarantee that the patterns are chemically meaningful.</li>
<li>No optimization of prompts, tokenization, or hyperparameters was performed, suggesting room for improvement but also making it difficult to assess the ceiling of this approach.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>All datasets are publicly available and were obtained from published benchmarks.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Classification</td>
          <td>HEA phase (Pei et al.)</td>
          <td>1,252 alloys</td>
          <td>Single-phase vs. multi-phase</td>
      </tr>
      <tr>
          <td>Regression</td>
          <td>FreeSolv</td>
          <td>643 molecules</td>
          <td>Hydration free energies</td>
      </tr>
      <tr>
          <td>Regression</td>
          <td>ESOL</td>
          <td>1,128 molecules</td>
          <td>Aqueous solubility</td>
      </tr>
      <tr>
          <td>Regression</td>
          <td>QMugs</td>
          <td>665,000 molecules</td>
          <td>HOMO-LUMO gaps via GFN2-xTB</td>
      </tr>
      <tr>
          <td>Classification</td>
          <td>Lipophilicity (ChEMBL)</td>
          <td>Varies</td>
          <td>LogP classification</td>
      </tr>
      <tr>
          <td>Classification</td>
          <td>OPV PCE</td>
          <td>Varies</td>
          <td>Organic photovoltaic efficiency</td>
      </tr>
      <tr>
          <td>Regression</td>
          <td>MOF Henry coefficients</td>
          <td>Varies</td>
          <td>CO2/CH4 adsorption</td>
      </tr>
      <tr>
          <td>Inverse design</td>
          <td>Photoswitches (Griffiths et al.)</td>
          <td>392 molecules</td>
          <td>Transition wavelengths</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Fine-tuning via OpenAI API: 8 epochs, learning rate multiplier 0.02</li>
<li>GPT-3 <code>ada</code> variant (smallest model) used for all main results</li>
<li>In-context learning also tested with larger GPT-3 models and GPT-4</li>
<li>Open-source alternative: GPT-J-6B with LoRA + 8-bit quantization</li>
<li>Learning curves fit to power laws $-a \exp(-bx + c)$ for data efficiency comparison</li>
<li>Validity checked using RDKit via <a href="/notes/chemistry/molecular-design/generation/evaluation/guacamol-benchmarking-de-novo-molecular-design/">GuacaMol</a>&rsquo;s <code>is_valid</code> method</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li>GPT-3 ada (OpenAI API, proprietary)</li>
<li>GPT-J-6B (open-source, fine-tunable on consumer hardware)</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Task</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Accuracy</td>
          <td>HEA phase</td>
          <td>Classification</td>
      </tr>
      <tr>
          <td>$F_1$ macro</td>
          <td>All classification tasks</td>
          <td>Class-balanced</td>
      </tr>
      <tr>
          <td>Cohen&rsquo;s $\kappa$</td>
          <td>Classification</td>
          <td>Used for learning curve thresholds</td>
      </tr>
      <tr>
          <td>MAE / MAPE</td>
          <td>Regression, inverse design</td>
          <td>Property prediction accuracy</td>
      </tr>
      <tr>
          <td>Validity rate</td>
          <td>Inverse design</td>
          <td>Fraction of parseable SMILES</td>
      </tr>
      <tr>
          <td>Frechet ChemNet distance</td>
          <td>Inverse design</td>
          <td>Distribution similarity</td>
      </tr>
      <tr>
          <td>SA score</td>
          <td>Inverse design</td>
          <td>Synthetic accessibility</td>
      </tr>
  </tbody>
</table>
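<p>For reference, the two less familiar classification metrics in the table can be computed as follows. This is a dependency-free sketch using the standard definitions of macro-averaged F1 and Cohen&rsquo;s kappa; the labels are toy values, and the authors&rsquo; evaluation code may rely on a library implementation instead.</p>

```python
from collections import Counter


def macro_f1(y_true, y_pred):
    """Unweighted mean of per-class F1 scores."""
    labels = sorted(set(y_true) | set(y_pred))
    scores = []
    for label in labels:
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == label and p == label)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != label and p == label)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == label and p != label)
        denominator = 2 * tp + fp + fn
        scores.append(2 * tp / denominator if denominator else 0.0)
    return sum(scores) / len(scores)


def cohens_kappa(y_true, y_pred):
    """Agreement corrected for the agreement expected by chance."""
    n = len(y_true)
    observed = sum(1 for t, p in zip(y_true, y_pred) if t == p) / n
    true_counts = Counter(y_true)
    pred_counts = Counter(y_pred)
    expected = sum(true_counts[k] * pred_counts[k] for k in true_counts) / (n * n)
    return (observed - expected) / (1 - expected)


y_true = [0, 0, 1, 1]  # toy labels
y_pred = [0, 1, 1, 1]

f1_value = macro_f1(y_true, y_pred)
kappa_value = cohens_kappa(y_true, y_pred)
print(round(f1_value, 4), round(kappa_value, 4))
```

<p>Here accuracy is 0.75, yet kappa is only 0.5 because half of the agreement is expected by chance. This is why a chance-corrected metric is a reasonable choice for setting learning-curve thresholds.</p>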
<h3 id="hardware">Hardware</h3>
<ul>
<li>Fine-tuning via OpenAI API (cloud compute, not user-specified)</li>
<li>Open-source experiments: consumer GPU hardware with 8-bit quantization</li>
<li>Quantum chemistry validation: GFN2-xTB for HOMO-LUMO calculations</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/kjappelbaum/gptchem">gptchem</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>All experiments with OpenAI API</td>
      </tr>
      <tr>
          <td><a href="https://github.com/lamalab-org/chemlift">chemlift</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Open-source LLM fine-tuning support</td>
      </tr>
      <tr>
          <td><a href="https://doi.org/10.5281/zenodo.7806672">Zenodo (gptchem)</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Archived release</td>
      </tr>
      <tr>
          <td><a href="https://doi.org/10.5281/zenodo.10233422">Zenodo (chemlift)</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Archived release</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Jablonka, K. M., Schwaller, P., Ortega-Guerrero, A., &amp; Smit, B. (2024). Leveraging large language models for predictive chemistry. <em>Nature Machine Intelligence</em>, 6(2), 161-169. <a href="https://doi.org/10.1038/s42256-023-00788-1">https://doi.org/10.1038/s42256-023-00788-1</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{jablonka2024leveraging,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Leveraging large language models for predictive chemistry}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Jablonka, Kevin Maik and Schwaller, Philippe and Ortega-Guerrero, Andres and Smit, Berend}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Nature Machine Intelligence}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{6}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{2}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{161--169}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2024}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Nature Publishing Group}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1038/s42256-023-00788-1}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>DrugAssist: Interactive LLM Molecule Optimization</title><link>https://hunterheidenreich.com/notes/chemistry/llm-applications/drugassist-llm-molecule-optimization/</link><pubDate>Sat, 28 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/llm-applications/drugassist-llm-molecule-optimization/</guid><description>DrugAssist fine-tunes Llama2-7B-Chat for interactive molecule optimization via natural language dialogue, releasing the MolOpt-Instructions dataset.</description><content:encoded><![CDATA[<h2 id="an-interactive-llm-for-molecule-optimization">An Interactive LLM for Molecule Optimization</h2>
<p>DrugAssist is a <strong>Method</strong> paper that proposes an interactive molecule optimization model built by fine-tuning Llama2-7B-Chat with LoRA on a newly constructed instruction dataset. The primary contribution is twofold: (1) the MolOpt-Instructions dataset containing over one million molecule pairs with six molecular properties and three optimization task categories, and (2) a dialogue-based molecule optimization system that allows domain experts to iteratively refine molecular modifications through multi-turn natural language conversations.</p>
<h2 id="why-interactive-molecule-optimization-matters">Why Interactive Molecule Optimization Matters</h2>
<p>Molecule optimization is a core step in the drug discovery pipeline, where lead compounds must be modified to improve specific pharmacological properties while maintaining structural similarity. Existing approaches fall into sequence-based methods (treating <a href="/notes/chemistry/molecular-representations/">SMILES</a> optimization as machine translation) and graph-based methods (graph-to-graph translation), but they share a critical limitation: they are non-interactive. These models learn patterns from chemical structure data without incorporating expert feedback.</p>
<p>The drug discovery process is inherently iterative and requires integrating domain expertise. Medicinal chemists typically refine candidates through repeated cycles of suggestion, evaluation, and adjustment. Prior LLM-based approaches like <a href="/notes/chemistry/llm-applications/chatdrug-conversational-drug-editing/">ChatDrug</a> relied on prompt engineering with general-purpose models (GPT-3.5-turbo) rather than fine-tuning, limiting their optimization accuracy. Additionally, most existing molecule optimization benchmarks focus on single-property optimization with vague objectives (e.g., &ldquo;maximize QED&rdquo;), while real-world drug design requires optimizing property values within specific ranges across multiple properties simultaneously.</p>
<h2 id="instruction-based-fine-tuning-with-molopt-instructions">Instruction-Based Fine-Tuning with MolOpt-Instructions</h2>
<p>The core innovation has two components: the MolOpt-Instructions dataset construction pipeline and the multi-task instruction tuning strategy.</p>
<h3 id="dataset-construction">Dataset Construction</h3>
<p>MolOpt-Instructions is built from one million molecules randomly sampled from the <a href="/notes/chemistry/datasets/zinc-22/">ZINC database</a>. The construction workflow uses mmpdb (an open-source Matched Molecular Pair platform) to generate structurally similar molecule pairs through <a href="https://en.wikipedia.org/wiki/Matched_molecular_pair_analysis">Matched Molecular Pair Analysis (MMPA)</a>. Pairs are filtered to satisfy two criteria: <a href="https://en.wikipedia.org/wiki/Jaccard_index">Tanimoto similarity</a> greater than 0.65 and <a href="https://en.wikipedia.org/wiki/Partition_coefficient">logP</a> difference greater than 2.5. Six properties are then computed for each molecule: Solubility, BBBP, and <a href="https://en.wikipedia.org/wiki/KCNH2">hERG</a> inhibition with Tencent&rsquo;s iDrug platform, and QED, hydrogen bond donor count, and hydrogen bond acceptor count with RDKit. The final dataset contains 1,029,949 unique pairs covering 1,595,839 unique molecules, with a mean similarity of 0.69 and a mean logP difference of 2.82.</p>
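<p>The pair filter can be sketched as below. Fingerprints are represented as plain sets of on-bit indices so the example has no dependencies; in practice they would come from a cheminformatics toolkit. The function names, toy fingerprints, and logP values are hypothetical, and only the two thresholds (0.65 and 2.5) come from the text.</p>

```python
def tanimoto(bits_a, bits_b):
    """Tanimoto (Jaccard) similarity between two sets of on-bit indices."""
    union_size = len(bits_a.union(bits_b))
    if union_size == 0:
        return 0.0
    return len(bits_a.intersection(bits_b)) / union_size


def keep_pair(bits_a, bits_b, logp_a, logp_b, min_similarity=0.65, min_logp_diff=2.5):
    """Keep pairs that are structurally similar but differ strongly in logP."""
    similar = tanimoto(bits_a, bits_b) > min_similarity
    different = abs(logp_a - logp_b) > min_logp_diff
    return similar and different


# Toy fingerprints: 6 shared bits out of 8 distinct bits gives a similarity of 0.75.
fp_a = {1, 2, 3, 4, 5, 6, 7}
fp_b = {1, 2, 3, 4, 5, 6, 8}

print(tanimoto(fp_a, fp_b))
print(keep_pair(fp_a, fp_b, logp_a=1.0, logp_b=4.0))
print(keep_pair(fp_a, fp_b, logp_a=1.0, logp_b=2.0))
```

<p>Requiring high similarity together with a large property difference selects pairs in which a small structural edit causes a meaningful change, which is the kind of example an optimization model should learn from.</p>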
<p>Three categories of optimization tasks are defined:</p>
<ul>
<li><strong>Loose</strong>: Increase or decrease a given property value (no threshold)</li>
<li><strong>Strict</strong>: Increase or decrease by at least a specified threshold</li>
<li><strong>Range</strong>: Optimize the property value to fall within a given interval</li>
</ul>
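<p>The three task categories translate into three different success checks. The sketch below is one plausible reading of the definitions above; the function name, the argument names, and the choice of inclusive versus exclusive boundaries are assumptions rather than details taken from the paper.</p>

```python
def is_success(source, optimized, task, direction=1, threshold=0.0, interval=None):
    """Check an optimized property value against a loose, strict, or range objective.

    direction is +1 when the property should increase and -1 when it should decrease.
    """
    change = direction * (optimized - source)
    if task == "loose":
        # Any change in the requested direction counts.
        return change > 0
    if task == "strict":
        # The change must reach the specified threshold.
        return change >= threshold
    if task == "range":
        # The final value must land inside the target interval.
        low, high = interval
        return low <= optimized <= high
    raise ValueError(f"unknown task: {task}")


print(is_success(0.50, 0.55, "loose"))
print(is_success(0.50, 0.55, "strict", threshold=0.1))
print(is_success(0.50, 0.70, "range", interval=(0.6, 0.8)))
```

<p>The same candidate can pass the loose check and fail the strict one, which is why success rates under strict criteria are generally lower.</p>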
<p>Instruction templates are generated with ChatGPT assistance and manually refined. To ensure balance, source and target molecules are swapped for some pairs to maintain a roughly 1:1 ratio of property increases to decreases.</p>
<p>Murcko scaffold analysis confirms chemical diversity: the average number of molecules per scaffold is 2.95, and over 93.7% of scaffolds contain no more than five molecules.</p>
<h3 id="multi-task-instruction-tuning">Multi-Task Instruction Tuning</h3>
<p>DrugAssist is obtained by fine-tuning Llama2-7B-Chat with LoRA (rank 64, alpha 128). To prevent catastrophic forgetting of general language capabilities, the training data combines MolOpt-Instructions with the Stanford Alpaca dataset (52k instruction-following examples, replicated 5x to balance the mixture). The training objective is the negative log-likelihood of the response tokens:</p>
<p>$$L(R; \boldsymbol{\theta}) = -\sum_{u_i \in R} \log \Phi(u_i \mid u_{&lt;i}, I)$$</p>
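<p>where $I$ is the instruction, $R$ is the response, $u_i$ is the $i$-th response token, $u_{&lt;i}$ denotes the response tokens preceding it, and $\Phi$ is the model&rsquo;s conditional probability with parameters $\boldsymbol{\theta}$.</p>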
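<p>A small numerical sketch of this objective: the loss sums the negative log-probability of response tokens only, while instruction tokens serve as context and are masked out. The per-token probabilities and the mask below are made-up values for illustration, and the function name is hypothetical.</p>

```python
import math


def response_nll(token_probs, is_response):
    """Sum of -log p over response tokens; instruction tokens are masked out."""
    return -sum(math.log(p) for p, keep in zip(token_probs, is_response) if keep)


# Model probability assigned to each correct token (toy values).
token_probs = [0.9, 0.8, 0.5, 0.25]
# The first two tokens belong to the instruction, the last two to the response.
is_response = [False, False, True, True]

loss = response_nll(token_probs, is_response)
print(round(loss, 4))  # -(log 0.5 + log 0.25) = log 8
```

<p>Masking the instruction means the model is not penalized for failing to predict the user&rsquo;s prompt, only for failing to produce the desired answer.</p>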
<p>Training runs for 10 epochs with batch size 512, using AdamW ($\beta = (0.9, 0.999)$), learning rate 1e-4, 3% warm-up steps with cosine decay, and no weight decay. The data is split 90/5/5 for train/validation/test.</p>
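<p>The learning rate schedule can be sketched as follows. The text specifies a 3% warm-up and cosine decay with a peak of 1e-4; the linear shape of the warm-up and the decay all the way to zero are assumptions made here for illustration, and the function name is hypothetical.</p>

```python
import math


def learning_rate(step, total_steps, peak_lr=1e-4, warmup_fraction=0.03):
    """Linear warm-up to peak_lr, then cosine decay to zero (assumed shapes)."""
    warmup_steps = max(1, round(total_steps * warmup_fraction))
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    decay_steps = max(1, total_steps - warmup_steps)
    progress = (step - warmup_steps) / decay_steps
    return 0.5 * peak_lr * (1.0 + math.cos(math.pi * progress))


total = 1000  # arbitrary number of optimizer steps for the example
for step in (0, 29, 30, 515, 1000):
    print(step, learning_rate(step, total))
```

<p>With 1,000 steps the first 30 are warm-up, the peak is reached at the end of warm-up, and the rate is half the peak midway through the decay phase.</p>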
<h2 id="experimental-setup-and-multi-property-optimization-results">Experimental Setup and Multi-Property Optimization Results</h2>
<h3 id="comparison-with-traditional-approaches">Comparison with Traditional Approaches</h3>
<p>DrugAssist is compared against Mol-Seq2Seq and Mol-Transformer (He et al., 2021) on simultaneous Solubility and BBBP optimization with range constraints. The evaluation prompt asks the model to generate an optimized molecule with solubility within a given range and BBBP category changed from one level to another.</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Solubility</th>
          <th>BBBP</th>
          <th>Both</th>
          <th>Valid Rate</th>
          <th>Similarity</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Mol-Seq2Seq</td>
          <td>0.46</td>
          <td>0.55</td>
          <td>0.35</td>
          <td>0.76</td>
          <td>0.61</td>
      </tr>
      <tr>
          <td>Mol-Transformer</td>
          <td>0.70</td>
          <td>0.78</td>
          <td>0.59</td>
          <td>0.96</td>
          <td>0.70</td>
      </tr>
      <tr>
          <td>DrugAssist</td>
          <td>0.74</td>
          <td>0.80</td>
          <td>0.62</td>
          <td>0.98</td>
          <td>0.69</td>
      </tr>
  </tbody>
</table>
<p>DrugAssist achieves the highest success rates in both single-property and multi-property optimization while maintaining high validity (0.98) and comparable structural similarity (0.69).</p>
<h3 id="comparison-with-llms">Comparison with LLMs</h3>
<p>DrugAssist is compared against Llama2-7B-Chat, GPT-3.5-turbo (via ChatDrug), and BioMedGPT-LM-7B on 16 tasks covering all three optimization categories. These comparisons use multi-turn dialogues following the ChatDrug protocol: if the model&rsquo;s output fails to meet requirements, a database-retrieved molecule meeting the criteria and similar to the model&rsquo;s output is provided as a hint for iterative refinement.</p>
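<p>The hint-based protocol can be sketched as a short loop. Everything here is a stand-in: <code>propose</code> represents a call to the language model, <code>meets_requirements</code> the property check, and <code>retrieve_hint</code> the database lookup. The toy implementations operate on strings only so that the example runs without a model, and the turn limit is an arbitrary choice.</p>

```python
def optimize_with_hints(molecule, propose, meets_requirements, retrieve_hint, max_turns=3):
    """Ask for a candidate; on failure, feed back a retrieved hint and retry."""
    hint = None
    for turn in range(1, max_turns + 1):
        candidate = propose(molecule, hint)
        if meets_requirements(candidate):
            return candidate, turn
        # The hint is a database molecule that satisfies the criteria and is
        # similar to the failed candidate.
        hint = retrieve_hint(candidate)
    return None, max_turns


# Toy stand-ins (hypothetical, not real model or database calls).
def propose(molecule, hint):
    return hint if hint is not None else molecule + "C"


def meets_requirements(candidate):
    return candidate.endswith("O")


def retrieve_hint(candidate):
    return candidate + "O"


result = optimize_with_hints("CC", propose, meets_requirements, retrieve_hint)
print(result)
```

<p>In this toy run the first proposal fails, the retrieved hint is fed back, and the second turn succeeds. A real model would use the hint as guidance rather than copy it.</p>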
<p>Selected results on single-property tasks (each cell reports the success rate under loose / strict criteria):</p>
<table>
  <thead>
      <tr>
          <th>Task</th>
          <th>Llama2-7B-Chat</th>
          <th>GPT-3.5-turbo</th>
          <th>BioMedGPT-LM</th>
          <th>DrugAssist</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>QED+</td>
          <td>0.17 / 0.16</td>
          <td>0.15 / 0.15</td>
          <td>0.15 / 0.09</td>
          <td>0.76 / 0.63</td>
      </tr>
      <tr>
          <td>Acceptor+</td>
          <td>0.08 / 0.08</td>
          <td>0.04 / 0.06</td>
          <td>0.18 / 0.13</td>
          <td>0.71 / 0.67</td>
      </tr>
      <tr>
          <td>Donor+</td>
          <td>0.15 / 0.08</td>
          <td>0.10 / 0.04</td>
          <td>0.17 / 0.09</td>
          <td>0.72 / 0.76</td>
      </tr>
      <tr>
          <td>Solubility+</td>
          <td>0.36 / 0.20</td>
          <td>0.16 / 0.05</td>
          <td>0.18 / 0.09</td>
          <td>0.80 / 0.41</td>
      </tr>
      <tr>
          <td>BBBP+</td>
          <td>0.19 / 0.14</td>
          <td>0.10 / 0.10</td>
          <td>0.16 / 0.07</td>
          <td>0.82 / 0.61</td>
      </tr>
      <tr>
          <td>hERG-</td>
          <td>0.39 / 0.31</td>
          <td>0.13 / 0.15</td>
          <td>0.13 / 0.12</td>
          <td>0.71 / 0.67</td>
      </tr>
  </tbody>
</table>
<p>Multi-property tasks:</p>
<table>
  <thead>
      <tr>
          <th>Task</th>
          <th>Llama2-7B-Chat</th>
          <th>GPT-3.5-turbo</th>
          <th>BioMedGPT-LM</th>
          <th>DrugAssist</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Sol+ &amp; Acc+</td>
          <td>0.15 / 0.04</td>
          <td>0.09 / 0.02</td>
          <td>0.10 / 0.07</td>
          <td>0.50 / 0.27</td>
      </tr>
      <tr>
          <td>QED+ &amp; BBBP+</td>
          <td>0.14 / 0.09</td>
          <td>0.09 / 0.06</td>
          <td>0.16 / 0.11</td>
          <td>0.65 / 0.41</td>
      </tr>
  </tbody>
</table>
<p>DrugAssist outperforms all baselines across every task. BioMedGPT-LM frequently misunderstands the task, generating guidance text rather than molecules. GPT-3.5-turbo achieves high validity but often outputs the input molecule unchanged.</p>
<h2 id="transferability-iterative-refinement-and-limitations">Transferability, Iterative Refinement, and Limitations</h2>
<h3 id="key-findings">Key Findings</h3>
<p><strong>Zero-shot transferability</strong>: Although DrugAssist trains on single-property optimization data, it successfully handles multi-property optimization requests at inference time. In a case study, the model simultaneously increased both BBBP and QED by at least 0.1 while maintaining structural similarity, without any multi-property training examples.</p>
<p><strong>Few-shot generalization</strong>: DrugAssist optimizes properties not seen during training (e.g., logP) when provided with a few in-context examples of successful optimizations, a capability that traditional sequence-based or graph-based models cannot achieve without retraining.</p>
<p><strong>Iterative optimization</strong>: When an initial optimization fails to meet requirements, DrugAssist can incorporate feedback (a database-retrieved hint molecule) and modify different functional groups in a second attempt to produce a compliant molecule.</p>
<h3 id="limitations">Limitations</h3>
<p>The authors acknowledge that DrugAssist has a lower success rate on the most challenging tasks, such as solubility optimization under strict criteria (0.41 success rate vs. 0.80 under loose criteria). The model also relies on iDrug for property prediction of Solubility, BBBP, and hERG inhibition, so its measured optimization quality is bounded by the accuracy of these property predictors. The evaluation uses only 500 test molecules for LLM comparisons, which is a relatively small evaluation set. The paper does not report statistical significance tests or confidence intervals for any results.</p>
<h3 id="future-directions">Future Directions</h3>
<p>The authors plan to improve multimodal data handling to reduce hallucination problems and to further enhance DrugAssist&rsquo;s interactive capabilities for better understanding of user needs and feedback.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training</td>
          <td>MolOpt-Instructions</td>
          <td>1,029,949 molecule pairs</td>
          <td>Sourced from ZINC via mmpdb; 6 properties</td>
      </tr>
      <tr>
          <td>Training (auxiliary)</td>
          <td>Stanford Alpaca</td>
          <td>52k instructions (5x replicated)</td>
          <td>Mitigates catastrophic forgetting</td>
      </tr>
      <tr>
          <td>Evaluation (traditional)</td>
          <td>From He et al. (2021)</td>
          <td>Not specified</td>
          <td>Multi-property optimization test</td>
      </tr>
      <tr>
          <td>Evaluation (LLM)</td>
          <td>ZINC subset</td>
          <td>500 molecules</td>
          <td>Randomly selected</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Base model</strong>: Llama2-7B-Chat</li>
<li><strong>Fine-tuning</strong>: LoRA with rank 64, alpha 128</li>
<li><strong>Optimizer</strong>: AdamW, $\beta = (0.9, 0.999)$, lr = 1e-4, no weight decay</li>
<li><strong>Schedule</strong>: 3% warm-up, cosine decay</li>
<li><strong>Epochs</strong>: 10</li>
<li><strong>Batch size</strong>: 512</li>
<li><strong>Property calculation</strong>: iDrug (Solubility, BBBP, hERG); RDKit (H-bond donors/acceptors, QED)</li>
<li><strong>Molecular pairs</strong>: mmpdb for Matched Molecular Pair Analysis</li>
</ul>
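<p>The schedule listed above (3% warm-up, cosine decay, peak learning rate 1e-4) can be sketched as a plain function of the optimizer step. This is a minimal sketch under stated assumptions: the warm-up is taken to be linear and the cosine is taken to decay to zero, neither of which is specified in these notes, and <code>total</code> is a hypothetical step count.</p>

```python
import math


def lr_at_step(step: int, total_steps: int, peak_lr: float = 1e-4,
               warmup_frac: float = 0.03) -> float:
    """Linear warm-up to peak_lr, then cosine decay to zero (assumed shape)."""
    warmup_steps = max(1, round(total_steps * warmup_frac))
    if step < warmup_steps:
        # Linear ramp that reaches peak_lr on the last warm-up step.
        return peak_lr * (step + 1) / warmup_steps
    # Cosine decay over the remaining steps.
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * peak_lr * (1.0 + math.cos(math.pi * progress))


total = 1000  # hypothetical number of optimizer steps
for s in (0, 29, 30, 515, 999):
    print(s, f"{lr_at_step(s, total):.2e}")
```

<p>With 1000 steps the first 30 form the warm-up, after which the rate falls smoothly toward zero by the final step.</p>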
<h3 id="models">Models</h3>
<ul>
<li>Fine-tuned Llama2-7B-Chat with LoRA adapters</li>
<li>No pre-trained weights released (code and data available)</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Description</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Success rate</td>
          <td>Fraction of molecules meeting optimization criteria</td>
      </tr>
      <tr>
          <td>Valid rate</td>
          <td>Fraction of generated SMILES that parse as valid molecules</td>
      </tr>
      <tr>
          <td>Similarity</td>
          <td>Tanimoto similarity between input and optimized molecules</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li>8 NVIDIA Tesla A100-SXM4-40GB GPUs</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/blazerye/DrugAssist">DrugAssist Code</a></td>
          <td>Code</td>
          <td>Not specified</td>
          <td>Training and inference code</td>
      </tr>
      <tr>
          <td><a href="https://github.com/blazerye/DrugAssist">MolOpt-Instructions</a></td>
          <td>Dataset</td>
          <td>Not specified</td>
          <td>1M+ molecule pairs, 6 properties</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Ye, G., Cai, X., Lai, H., Wang, X., Huang, J., Wang, L., Liu, W., &amp; Zeng, X. (2024). DrugAssist: A Large Language Model for Molecule Optimization. <em>Briefings in Bioinformatics</em>, 26(1), bbae693.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{ye2024drugassist,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{DrugAssist: A Large Language Model for Molecule Optimization}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Ye, Geyan and Cai, Xibao and Lai, Houtim and Wang, Xing and Huang, Junhong and Wang, Longyue and Liu, Wei and Zeng, Xiangxiang}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Briefings in Bioinformatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{26}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{bbae693}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2024}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1093/bib/bbae693}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Coscientist: Autonomous Chemistry with LLM Agents</title><link>https://hunterheidenreich.com/notes/chemistry/llm-applications/autonomous-chemical-research-coscientist/</link><pubDate>Sat, 28 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/llm-applications/autonomous-chemical-research-coscientist/</guid><description>Coscientist uses GPT-4 to autonomously design, plan, and execute chemical experiments including Pd-catalysed cross-coupling optimization.</description><content:encoded><![CDATA[<h2 id="an-llm-powered-agent-for-autonomous-chemical-experimentation">An LLM-Powered Agent for Autonomous Chemical Experimentation</h2>
<p>This is a <strong>Method</strong> paper that introduces Coscientist, an AI system driven by GPT-4 that autonomously designs, plans, and performs complex chemical experiments. The primary contribution is a modular multi-LLM agent architecture that integrates internet search, documentation retrieval, code execution, and robotic experimentation APIs into a unified system capable of end-to-end experimental chemistry with minimal human intervention.</p>
<h2 id="bridging-llm-capabilities-and-laboratory-automation">Bridging LLM Capabilities and Laboratory Automation</h2>
<p>Transformer-based large language models had demonstrated strong capabilities in natural language processing, biology, chemistry, and code generation by early 2023. Simultaneously, laboratory automation had progressed with autonomous reaction discovery, automated flow systems, and mobile robotic platforms. However, these two threads remained largely separate: LLMs could reason about chemistry in text, but could not act on that reasoning by controlling physical experiments.</p>
<p>The gap this work addresses is the integration of LLM reasoning with laboratory automation in a closed-loop system. Prior automated chemistry systems relied on traditional optimization algorithms or narrow AI components. The question was whether GPT-4&rsquo;s general reasoning capabilities could be combined with tool access to produce a system that autonomously designs experiments, writes instrument code, executes reactions, and interprets results, all from natural language prompts.</p>
<p>This work was developed independently and in parallel with other autonomous agent efforts (AutoGPT, BabyAGI, LangChain), with <a href="/notes/chemistry/llm-applications/chemcrow-augmenting-llms-chemistry-tools/">ChemCrow</a> serving as another chemistry-specific example.</p>
<h2 id="a-modular-multi-llm-architecture-with-tool-access">A Modular Multi-LLM Architecture with Tool Access</h2>
<p>The core innovation is Coscientist&rsquo;s modular architecture, centered on a &ldquo;Planner&rdquo; module (a GPT-4 chat completion instance) that orchestrates four command types:</p>
<ol>
<li><strong>GOOGLE</strong>: A Web Searcher module (itself an LLM) that transforms prompts into search queries, browses results, and funnels answers back to the Planner.</li>
<li><strong>PYTHON</strong>: A Code Execution module running in an isolated Docker container for calculations and data analysis, with no LLM dependency.</li>
<li><strong>DOCUMENTATION</strong>: A Docs Searcher module that retrieves and summarizes technical documentation (e.g., Opentrons Python API, Emerald Cloud Lab Symbolic Lab Language) using ada embeddings and distance-based vector search.</li>
<li><strong>EXPERIMENT</strong>: An Automation module that executes generated code on laboratory hardware or provides synthetic procedures.</li>
</ol>
<p>The system prompts are engineered in a modular fashion, with the Planner receiving initial user input and command outputs as messages. The Planner can iteratively call commands, fix software errors, and refine its approach. This design allows natural language instructions (e.g., &ldquo;perform multiple Suzuki reactions&rdquo;) to be translated into complete experimental protocols.</p>
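<p>The command routing described above can be illustrated with a small dispatcher. This is a hypothetical sketch, not the authors&rsquo; implementation: the four handler functions are stubs, and the &ldquo;COMMAND payload&rdquo; message format is an assumption made for illustration.</p>

```python
from typing import Callable, Dict, Tuple


# Hypothetical stand-ins for the four modules. Real ones would call an LLM,
# a sandboxed interpreter, a vector store, or laboratory hardware.
def web_search(arg: str) -> str:
    return f"[search results for: {arg}]"


def run_python(arg: str) -> str:
    return f"[stdout of sandboxed code: {arg}]"


def search_docs(arg: str) -> str:
    return f"[documentation sections matching: {arg}]"


def run_experiment(arg: str) -> str:
    return f"[hardware run log for: {arg}]"


HANDLERS: Dict[str, Callable[[str], str]] = {
    "GOOGLE": web_search,
    "PYTHON": run_python,
    "DOCUMENTATION": search_docs,
    "EXPERIMENT": run_experiment,
}


def parse_command(planner_output: str) -> Tuple[str, str]:
    """Split 'COMMAND payload' into its two parts (assumed message format)."""
    command, _, payload = planner_output.strip().partition(" ")
    return command, payload


def dispatch(planner_output: str) -> str:
    """Route one Planner message to its module and return the observation."""
    command, payload = parse_command(planner_output)
    handler = HANDLERS.get(command)
    if handler is None:
        # Return the error as an observation so the Planner can correct itself.
        return f"ERROR: unknown command {command!r}"
    return handler(payload)


print(dispatch("DOCUMENTATION heater-shaker module"))
```

<p>In a full loop, each returned observation would be appended to the Planner&rsquo;s message history before the next completion call, which is what allows iterative error fixing.</p>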
<p>For documentation retrieval, all sections of the OT-2 API documentation were embedded using OpenAI&rsquo;s ada model, and relevant sections are retrieved via cosine similarity search. For the Emerald Cloud Lab, the system learned to program in a symbolic lab language (SLL) that was completely unknown to GPT-4 at training time, demonstrating effective in-context learning from supplied documentation.</p>
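<p>The retrieval step can be sketched as a ranking of documentation sections by cosine similarity to a query vector. The three-dimensional vectors and section names below are toy placeholders standing in for real embeddings. The sketch makes no call to any embedding API.</p>

```python
import math
from typing import Dict, List, Sequence, Tuple


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def top_k_sections(query_vec: Sequence[float],
                   section_vecs: Dict[str, Sequence[float]],
                   k: int = 2) -> List[Tuple[str, float]]:
    """Return the k section names most similar to the query, best first."""
    scored = [(name, cosine_similarity(query_vec, vec))
              for name, vec in section_vecs.items()]
    scored.sort(key=lambda item: item[1], reverse=True)
    return scored[:k]


# Toy vectors (hypothetical); real embeddings have far more dimensions.
sections = {
    "heater_shaker": [0.9, 0.1, 0.0],
    "pipetting": [0.1, 0.9, 0.1],
    "labware": [0.0, 0.2, 0.9],
}
query = [0.8, 0.2, 0.1]
print(top_k_sections(query, sections, k=2))
```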
<h2 id="six-tasks-demonstrating-autonomous-chemistry-capabilities">Six Tasks Demonstrating Autonomous Chemistry Capabilities</h2>
<p>The paper evaluates Coscientist across six tasks of increasing complexity.</p>
<h3 id="task-1-chemical-synthesis-planning">Task 1: Chemical Synthesis Planning</h3>
<p>A benchmark of seven compounds was used to compare synthesis planning across models (GPT-4, GPT-3.5, Claude 1.3, Falcon-40B-Instruct) with and without web search. Outputs were scored on a 1-5 scale:</p>
<table>
  <thead>
      <tr>
          <th>Score</th>
          <th>Meaning</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>5</td>
          <td>Very detailed and chemically accurate procedure</td>
      </tr>
      <tr>
          <td>4</td>
          <td>Detailed and accurate but without reagent quantities</td>
      </tr>
      <tr>
          <td>3</td>
          <td>Correct chemistry but no step-by-step procedure</td>
      </tr>
      <tr>
          <td>2</td>
          <td>Extremely vague or unfeasible</td>
      </tr>
      <tr>
          <td>1</td>
          <td>Incorrect or failure to follow instructions</td>
      </tr>
  </tbody>
</table>
<p>The GPT-4-powered Web Searcher achieved maximum scores for acetaminophen, aspirin, nitroaniline, and phenolphthalein. It was the only approach to achieve acceptable scores (3+) for ibuprofen, which all non-browsing models synthesized incorrectly. These results highlight the importance of grounding LLMs to avoid hallucinations.</p>
<h3 id="task-2-documentation-search">Task 2: Documentation Search</h3>
<p>The system correctly identified relevant ECL functions from documentation and generated valid SLL code that was successfully executed at ECL, including an <a href="https://en.wikipedia.org/wiki/High-performance_liquid_chromatography">HPLC</a> experiment on a caffeine standard sample.</p>
<h3 id="task-3-cloud-laboratory-execution">Task 3: Cloud Laboratory Execution</h3>
<p>Using prompt-to-function and prompt-to-SLL pipelines, Coscientist generated executable code for the Emerald Cloud Lab. It also searched a catalogue of 1,110 model samples to identify relevant stock solutions from simple search terms.</p>
<h3 id="task-4-liquid-handler-control">Task 4: Liquid Handler Control</h3>
<p>Using the Opentrons OT-2, Coscientist translated natural language prompts (e.g., &ldquo;colour every other line with one colour of your choice,&rdquo; &ldquo;draw a red cross&rdquo;) into accurate liquid handling protocols.</p>
<h3 id="task-5-integrated-multi-module-experiment">Task 5: Integrated Multi-Module Experiment</h3>
<p>The most complex demonstration combined web search, code execution, documentation retrieval, and hardware control to design and execute <a href="https://en.wikipedia.org/wiki/Suzuki_reaction">Suzuki-Miyaura</a> and <a href="https://en.wikipedia.org/wiki/Sonogashira_coupling">Sonogashira</a> <a href="https://en.wikipedia.org/wiki/Cross-coupling_reaction">cross-coupling</a> reactions. Coscientist:</p>
<ul>
<li>Searched the internet for reaction conditions and stoichiometries</li>
<li>Selected correct coupling partners (never misassigning <a href="https://en.wikipedia.org/wiki/Phenylboronic_acid">phenylboronic acid</a> to Sonogashira)</li>
<li>Calculated reagent volumes and wrote OT-2 protocols</li>
<li>Self-corrected when using an incorrect heater-shaker method by consulting documentation</li>
<li>Successfully produced target products confirmed by <a href="https://en.wikipedia.org/wiki/Gas_chromatography%E2%80%93mass_spectrometry">GC-MS</a> analysis (biphenyl at 9.53 min for Suzuki, diphenylacetylene at 12.92 min for Sonogashira)</li>
</ul>
<h3 id="task-6-reaction-optimization">Task 6: Reaction Optimization</h3>
<p>Coscientist was tested on two fully mapped reaction datasets:</p>
<ol>
<li><strong>Suzuki reaction flow dataset</strong> (Perera et al.): varying ligands, reagents/bases, and solvents</li>
<li><strong><a href="https://en.wikipedia.org/wiki/Buchwald%E2%80%93Hartwig_amination">Buchwald-Hartwig</a> C-N coupling dataset</strong> (Doyle et al.): varying ligands, additives, and bases</li>
</ol>
<p>Performance was evaluated using a normalized advantage metric:</p>
<p>$$\text{Normalized Advantage} = \frac{\text{yield}_i - \overline{\text{yield}}}{\text{yield}_{\max} - \overline{\text{yield}}}$$</p>
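<p>Here $\text{yield}_i$ is the yield obtained at iteration $i$, $\overline{\text{yield}}$ is the mean yield over the dataset, and $\text{yield}_{\max}$ is the best yield in the dataset. A value of 1 indicates that the maximum yield was reached, 0 indicates a yield equal to the dataset mean (the expected result of random selection), and negative values indicate worse-than-random performance. The normalized maximum advantage (NMA) tracks the best result achieved up to each iteration.</p>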
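<p>Both metrics are short to compute once the dataset mean and maximum are known. The sketch below follows the formula above. The yields and dataset statistics are hypothetical values chosen for illustration.</p>

```python
from typing import List, Sequence


def normalized_advantage(yield_i: float, mean_yield: float, max_yield: float) -> float:
    """(yield_i - mean) / (max - mean): 1 at the best yield, 0 at the mean."""
    return (yield_i - mean_yield) / (max_yield - mean_yield)


def normalized_maximum_advantage(yields: Sequence[float], mean_yield: float,
                                 max_yield: float) -> List[float]:
    """Running best normalized advantage after each iteration."""
    best = float("-inf")
    trajectory = []
    for y in yields:
        best = max(best, normalized_advantage(y, mean_yield, max_yield))
        trajectory.append(best)
    return trajectory


# Hypothetical yields (%) from five iterations on a dataset with mean 40, max 100.
observed = [25.0, 55.0, 40.0, 85.0, 70.0]
print(normalized_maximum_advantage(observed, mean_yield=40.0, max_yield=100.0))
```

<p>Because NMA is a running maximum, it never decreases, which is why it is useful for comparing how quickly different optimizers approach the best available conditions.</p>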
<p>Key findings from the optimization experiments:</p>
<ul>
<li>GPT-4 with prior information (10 random data points) produced better initial guesses than GPT-4 without prior information</li>
<li>Both GPT-4 approaches converged to similar NMA values at the limit</li>
<li>Both GPT-4 approaches outperformed standard <a href="https://en.wikipedia.org/wiki/Bayesian_optimization">Bayesian optimization</a> in NMA and normalized advantage</li>
<li>GPT-3.5 largely failed due to inability to output correct JSON schemas</li>
<li>On the Buchwald-Hartwig dataset, GPT-4 performed comparably whether given compound names or <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> strings, and could reason about electronic properties from SMILES representations</li>
</ul>
<p>All experiments used a maximum of 20 iterations (5.2% and 6.9% of the total reaction space for the two datasets).</p>
<h2 id="demonstrated-versatility-with-safety-considerations">Demonstrated Versatility with Safety Considerations</h2>
<p>Coscientist demonstrated that GPT-4, when equipped with appropriate tool access, can autonomously handle the full experimental chemistry workflow from literature search to reaction execution and data interpretation. The system showed chemical reasoning capabilities, including selecting appropriate reagents, providing justifications for choices based on reactivity and selectivity, and using experimental data to guide subsequent iterations.</p>
<p>Several limitations are acknowledged:</p>
<ul>
<li>The experimental setup was not yet fully automated (plates were moved manually between instruments), though no human decision-making was involved</li>
<li>GPT-3.5 consistently underperformed due to inability to follow formatting instructions</li>
<li>The synthesis planning evaluation scale is inherently subjective</li>
<li>It is unclear whether GPT-4&rsquo;s training data contained information from the optimization datasets</li>
<li>The comparison with Bayesian optimization may reflect different exploration/exploitation balances rather than pure capability differences</li>
</ul>
<p>The authors raise safety concerns about dual-use potential and note that full code and prompts were withheld pending development of US AI regulations. A simplified implementation was released for reproducibility purposes.</p>
<p>Future directions include extending the system with reaction databases (Reaxys, SciFinder), implementing advanced prompting strategies (ReAct, Chain of Thought, Tree of Thoughts), and developing automated quality control for cloud laboratory experiments.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Synthesis benchmark</td>
          <td>7 compound set</td>
          <td>7 compounds</td>
          <td>Acetaminophen, aspirin, ibuprofen, nitroaniline, etc.</td>
      </tr>
      <tr>
          <td>Optimization</td>
          <td>Perera et al. Suzuki flow dataset</td>
          <td>Fully mapped condition space</td>
          <td>Varying ligands, bases, solvents</td>
      </tr>
      <tr>
          <td>Optimization</td>
          <td>Doyle Buchwald-Hartwig dataset</td>
          <td>Fully mapped condition space</td>
          <td>Varying ligands, additives, bases</td>
      </tr>
      <tr>
          <td>Reagent selection</td>
          <td><a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> compound database</td>
          <td>Not specified</td>
          <td>Used for computational experiments</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Planner</strong>: GPT-4 chat completion with modular system prompts</li>
<li><strong>Web Searcher</strong>: GPT-4 or GPT-3.5-turbo for query generation and result parsing</li>
<li><strong>Documentation embedding</strong>: OpenAI ada model with distance-based vector search</li>
<li><strong>Code execution</strong>: Isolated Docker container (no LLM dependency)</li>
<li><strong>Baseline</strong>: Bayesian optimization with varying initial sample sizes (1-10)</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li>GPT-4 (primary)</li>
<li>GPT-3.5-turbo (baseline)</li>
<li>Claude 1.3 (baseline for synthesis planning)</li>
<li>Falcon-40B-Instruct (baseline for synthesis planning)</li>
<li>OpenAI ada (for documentation embedding)</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Context</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Synthesis score (1-5)</td>
          <td>7-compound benchmark</td>
          <td>Subjective expert grading</td>
      </tr>
      <tr>
          <td>Normalized advantage</td>
          <td>Optimization tasks</td>
          <td>Measures improvement over random</td>
      </tr>
      <tr>
          <td>NMA</td>
          <td>Optimization tasks</td>
          <td>Maximum advantage achieved through iteration N</td>
      </tr>
      <tr>
          <td>GC-MS confirmation</td>
          <td>Cross-coupling reactions</td>
          <td>Product formation verified experimentally</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li>Opentrons OT-2 liquid handler with heater-shaker module</li>
<li>UV-Vis plate reader</li>
<li>Emerald Cloud Lab (cloud-based automation)</li>
<li>Computational requirements not specified (relies on OpenAI API calls)</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/gomesgroup/coscientist">gomesgroup/coscientist</a></td>
          <td>Code</td>
          <td>Apache-2.0 with Commons Clause</td>
          <td>Simplified implementation; full code withheld for safety</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Boiko, D. A., MacKnight, R., Kline, B. &amp; Gomes, G. (2023). Autonomous chemical research with large language models. <em>Nature</em>, 624(7992), 570-578. <a href="https://doi.org/10.1038/s41586-023-06792-0">https://doi.org/10.1038/s41586-023-06792-0</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{boiko2023autonomous,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Autonomous chemical research with large language models}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Boiko, Daniil A. and MacKnight, Robert and Kline, Ben and Gomes, Gabriel dos Passos}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Nature}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{624}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{7992}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{570--578}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Springer Nature}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1038/s41586-023-06792-0}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>ChemGE: Molecule Generation via Grammatical Evolution</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/search-based/chemge-grammatical-evolution-molecule-generation/</link><pubDate>Sat, 28 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/search-based/chemge-grammatical-evolution-molecule-generation/</guid><description>ChemGE applies grammatical evolution to SMILES strings for population-based de novo molecule generation with inherent parallelism and diversity.</description><content:encoded><![CDATA[<h2 id="grammatical-evolution-for-de-novo-molecular-design">Grammatical Evolution for De Novo Molecular Design</h2>
<p>This is a <strong>Method</strong> paper that introduces ChemGE, a population-based molecular generation approach built on grammatical evolution. Rather than using deep neural networks, ChemGE evolves populations of <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> strings through a context-free grammar, enabling concurrent evaluation by multiple molecular simulators and producing diverse molecular libraries. The method represents an alternative paradigm for de novo drug design: evolutionary optimization over formal grammars rather than learned latent spaces or autoregressive neural models.</p>
<h2 id="limitations-of-sequential-deep-learning-generators">Limitations of Sequential Deep Learning Generators</h2>
<p>At the time of publication, the dominant approaches to de novo molecular generation included Bayesian optimization over VAE latent spaces (<a href="/notes/chemistry/molecular-design/generation/latent-space/automatic-chemical-design-vae/">CVAE</a>, <a href="/notes/chemistry/molecular-design/generation/latent-space/grammar-variational-autoencoder/">GVAE</a>), reinforcement learning with recurrent neural networks (<a href="/notes/chemistry/molecular-design/generation/rl-tuned/organ-objective-reinforced-gan/">ORGAN</a>, <a href="/notes/chemistry/molecular-design/generation/rl-tuned/reinvent-deep-rl-molecular-design/">REINVENT</a>), sequential Monte Carlo search, and Monte Carlo tree search (ChemTS). These methods share two practical limitations:</p>
<ol>
<li>
<p><strong>Simulation concurrency</strong>: Most methods generate one molecule at a time, making it difficult to run multiple molecular simulations (e.g., <a href="https://en.wikipedia.org/wiki/Molecular_docking">docking</a>) in parallel. This wastes computational resources in high-throughput virtual screening settings.</p>
</li>
<li>
<p><strong>Molecular diversity</strong>: Deep learning generators tend to exploit narrow regions of chemical space. Deep reinforcement learning methods in particular often generate very similar molecules, requiring special countermeasures to maintain diversity. Since drug discovery is a multi-stage pipeline, limited diversity reduces survival rates in downstream <a href="https://en.wikipedia.org/wiki/ADME">ADMET</a> screening.</p>
</li>
</ol>
<p>ChemGE addresses both problems by maintaining a large population of molecules that are evolved and evaluated concurrently.</p>
<h2 id="core-innovation-chromosome-to-smiles-mapping-via-grammar-rules">Core Innovation: Chromosome-to-SMILES Mapping via Grammar Rules</h2>
<p>ChemGE encodes each molecule as a chromosome: a sequence of $N$ integers that deterministically maps to a SMILES string through a context-free grammar. The mapping process works as follows:</p>
<ol>
<li>Start with the grammar&rsquo;s start symbol</li>
<li>At each step $k$, look up the $k$-th integer $c = C[k]$ from the chromosome</li>
<li>Identify the leftmost non-terminal symbol and count its $r$ applicable production rules</li>
<li>Apply the $((c \bmod r) + 1)$-th rule</li>
<li>Repeat until no non-terminal symbols remain or the chromosome is exhausted</li>
</ol>
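<p>The mapping steps above can be sketched with a toy grammar. The three-rule grammar below is a hypothetical illustration and is far smaller than the OpenSMILES subset used in the paper. Returning <code>None</code> when the chromosome runs out before the derivation finishes is an assumption about how incomplete derivations are handled.</p>

```python
from typing import Dict, List, Optional, Sequence

# Toy context-free grammar (hypothetical). Keys are non-terminals and every
# other symbol is a terminal.
GRAMMAR: Dict[str, List[List[str]]] = {
    "smiles": [["chain"]],
    "chain": [["atom"], ["atom", "chain"]],
    "atom": [["C"], ["N"], ["O"]],
}


def chromosome_to_smiles(chromosome: Sequence[int],
                         grammar: Dict[str, List[List[str]]] = GRAMMAR,
                         start: str = "smiles") -> Optional[str]:
    """Expand the leftmost non-terminal once per integer in the chromosome."""
    symbols = [start]
    for c in chromosome:
        # Locate the leftmost non-terminal symbol, if any remain.
        idx = next((i for i, s in enumerate(symbols) if s in grammar), None)
        if idx is None:
            break  # derivation finished; remaining integers are unused
        rules = grammar[symbols[idx]]
        # 0-based index of the ((c mod r) + 1)-th rule.
        chosen = rules[c % len(rules)]
        symbols[idx:idx + 1] = chosen
    if any(s in grammar for s in symbols):
        return None  # chromosome exhausted before all non-terminals were expanded
    return "".join(symbols)


print(chromosome_to_smiles([7, 5, 3, 4, 8]))
```

<p>Because the rule index is taken modulo the number of applicable rules, any integer sequence selects a rule at every step, so a point mutation always yields another decodable chromosome (although the resulting string may still be an invalid molecule).</p>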
<p>The context-free grammar is a subset of the OpenSMILES specification, defined formally as $G = (V, \Sigma, R, S)$ where $V$ is the set of non-terminal symbols, $\Sigma$ is the set of terminal symbols, $R$ is the set of production rules, and $S$ is the start symbol.</p>
<p>Evolution follows the $(\mu + \lambda)$ evolution strategy:</p>
<ol>
<li>Create $\lambda$ new chromosomes by drawing random chromosomes from the population and mutating one integer at a random position</li>
<li>Translate each chromosome to a SMILES string and evaluate fitness (e.g., docking score). Invalid molecules receive fitness $-\infty$</li>
<li>Select the top $\mu$ molecules from the merged pool of $\mu + \lambda$ candidates</li>
</ol>
<p>The authors did not use crossover, as it did not improve performance. Diversity is inherently maintained because a large fraction of molecules are mutated in each generation.</p>
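<p>A minimal sketch of the mutation-only $(\mu + \lambda)$ loop follows. The fitness function is a toy stand-in for decoding a chromosome and scoring the molecule, the integer range 0-255 is an assumption, and duplicates are simply dropped from the merged pool where the paper instead assigns them fitness $-\infty$.</p>

```python
import random
from typing import Callable, List, Tuple

Chromosome = Tuple[int, ...]


def mutate(chrom: Chromosome, rng: random.Random, max_value: int = 255) -> Chromosome:
    """Replace one integer at a random position (mutation only, no crossover)."""
    pos = rng.randrange(len(chrom))
    child = list(chrom)
    child[pos] = rng.randint(0, max_value)
    return tuple(child)


def evolve(population: List[Chromosome], fitness: Callable[[Chromosome], float],
           mu: int, lam: int, generations: int, rng: random.Random) -> List[Chromosome]:
    """(mu + lambda) strategy: parents compete with their offspring."""
    for _ in range(generations):
        # Step 1: create lambda offspring from randomly drawn parents.
        offspring = [mutate(rng.choice(population), rng) for _ in range(lam)]
        # Step 2: merge parents and offspring, dropping exact duplicates.
        merged = list(dict.fromkeys(population + offspring))
        # Step 3: keep the top mu by fitness. The fitness calls on the
        # offspring are independent and could run in parallel.
        merged.sort(key=fitness, reverse=True)
        population = merged[:mu]
    return population


def toy_fitness(chrom: Chromosome) -> float:
    """Toy objective (hypothetical): prefer chromosomes whose sum is near 500."""
    return -abs(sum(chrom) - 500)


rng = random.Random(0)
initial = [tuple(rng.randint(0, 255) for _ in range(8)) for _ in range(10)]
final = evolve(initial, toy_fitness, mu=10, lam=20, generations=30, rng=rng)
print(max(map(toy_fitness, initial)), max(map(toy_fitness, final)))
```

<p>Since parents remain in the selection pool, the best fitness in the population can never decrease from one generation to the next.</p>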
<h2 id="experimental-setup-and-benchmark-comparisons">Experimental Setup and Benchmark Comparisons</h2>
<h3 id="druglikeness-score-benchmark">Druglikeness Score Benchmark</h3>
<p>The first experiment optimized the penalized logP score $J^{\log P}$, an indicator of druglikeness defined as:</p>
<p>$$
J^{\log P}(m) = \log P(m) - \text{SA}(m) - \text{ring-penalty}(m)
$$</p>
<p>where $\log P(m)$ is the <a href="https://en.wikipedia.org/wiki/Octanol-water_partition_coefficient">octanol-water partition coefficient</a>, $\text{SA}(m)$ is the synthetic accessibility score, and ring-penalty$(m)$ penalizes carbon rings larger than size 6. All terms are normalized to zero mean and unit standard deviation. Initial populations were randomly sampled from the ZINC database (35 million compounds), with fitness set to $-\infty$ for molecules with molecular weight above 500 or duplicate structures.</p>
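<p>The score can be sketched as a difference of z-scored terms. In this sketch the raw $\log P$ and SA values are assumed to be computed elsewhere (for example with a cheminformatics toolkit), the ring penalty is assumed to be the excess size of the largest ring beyond six atoms, and the normalization statistics are hypothetical placeholders.</p>

```python
from typing import Dict, Sequence, Tuple

Stats = Dict[str, Tuple[float, float]]  # term name -> (mean, std)


def zscore(value: float, mean: float, std: float) -> float:
    """Normalize a raw term to zero mean and unit standard deviation."""
    return (value - mean) / std


def ring_penalty(ring_sizes: Sequence[int]) -> int:
    """Excess size of the largest ring beyond 6 atoms (assumed penalty form)."""
    largest = max(ring_sizes, default=0)
    return max(0, largest - 6)


def penalized_logp(logp: float, sa: float, ring_sizes: Sequence[int],
                   stats: Stats) -> float:
    """J = z(logP) - z(SA) - z(ring penalty)."""
    return (
        zscore(logp, *stats["logp"])
        - zscore(sa, *stats["sa"])
        - zscore(ring_penalty(ring_sizes), *stats["ring"])
    )


# Hypothetical normalization statistics; real values come from a reference set.
example_stats = {"logp": (2.5, 1.4), "sa": (3.0, 0.8), "ring": (0.05, 0.3)}
print(penalized_logp(logp=3.2, sa=2.4, ring_sizes=[6, 5], stats=example_stats))
```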
<p>ChemGE was compared against CVAE, GVAE, and ChemTS across population sizes $(\mu, \lambda) \in \{(10, 20), (100, 200), (1000, 2000), (10000, 20000)\}$.</p>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>2h</th>
          <th>4h</th>
          <th>6h</th>
          <th>8h</th>
          <th>Mol/Min</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ChemGE (10, 20)</td>
          <td>4.46 +/- 0.34</td>
          <td>4.46 +/- 0.34</td>
          <td>4.46 +/- 0.34</td>
          <td>4.46 +/- 0.34</td>
          <td>14.5</td>
      </tr>
      <tr>
          <td>ChemGE (100, 200)</td>
          <td>5.17 +/- 0.26</td>
          <td>5.17 +/- 0.26</td>
          <td>5.17 +/- 0.26</td>
          <td>5.17 +/- 0.26</td>
          <td>135</td>
      </tr>
      <tr>
          <td>ChemGE (1000, 2000)</td>
          <td>4.45 +/- 0.24</td>
          <td>5.32 +/- 0.43</td>
          <td>5.73 +/- 0.33</td>
          <td>5.88 +/- 0.34</td>
          <td>527</td>
      </tr>
      <tr>
          <td>ChemGE (10000, 20000)</td>
          <td>4.20 +/- 0.33</td>
          <td>4.28 +/- 0.28</td>
          <td>4.40 +/- 0.27</td>
          <td>4.53 +/- 0.26</td>
          <td>555</td>
      </tr>
      <tr>
          <td>CVAE</td>
          <td>-30.18 +/- 26.91</td>
          <td>-1.39 +/- 2.24</td>
          <td>-0.61 +/- 1.08</td>
          <td>-0.006 +/- 0.92</td>
          <td>0.14</td>
      </tr>
      <tr>
          <td>GVAE</td>
          <td>-4.34 +/- 3.14</td>
          <td>-1.29 +/- 1.67</td>
          <td>-0.17 +/- 0.96</td>
          <td>0.25 +/- 1.31</td>
          <td>1.38</td>
      </tr>
      <tr>
          <td>ChemTS</td>
          <td>4.91 +/- 0.38</td>
          <td>5.41 +/- 0.51</td>
          <td>5.49 +/- 0.44</td>
          <td>5.58 +/- 0.50</td>
          <td>40.89</td>
      </tr>
  </tbody>
</table>
<p>At $(\mu, \lambda) = (1000, 2000)$, ChemGE achieved the highest 8-hour score (5.88) and generated 527 unique molecules per minute, roughly 13x the throughput of ChemTS and 3700x that of CVAE. The small population (10, 20) converged prematurely with insufficient diversity, while the overly large population (10000, 20000) could not run enough generations to optimize effectively.</p>
<h3 id="docking-experiment-with-thymidine-kinase">Docking Experiment with Thymidine Kinase</h3>
<p>The second experiment applied ChemGE to generate molecules with high predicted binding affinity for <a href="https://en.wikipedia.org/wiki/Thymidine_kinase">thymidine kinase</a> (KITH), a well-known antiviral drug target. The authors used rDock for docking simulation, taking the best intermolecular score $S_{\text{inter}}$ from three runs with different initial conformations. Fitness was defined as $-S_{\text{inter}}$ (lower scores indicate higher affinity). The protein structure was taken from PDB ID 2B8T.</p>
<p>With 32 parallel cores and $(\mu, \lambda) = (32, 64)$, ChemGE completed 1000 generations in approximately 26 hours, generating 9466 molecules total. Among these, 349 molecules achieved intermolecular scores better than the best known inhibitor in the DUD-E database.</p>
<h3 id="diversity-analysis">Diversity Analysis</h3>
<p>Molecular diversity was measured using internal diversity based on Morgan fingerprints:</p>
<p>$$
I(A) = \frac{1}{|A|^2} \sum_{(x,y) \in A \times A} T_d(x, y)
$$</p>
<p>where $T_d(x, y) = 1 - \frac{|x \cap y|}{|x \cup y|}$ is the <a href="https://en.wikipedia.org/wiki/Jaccard_index#Tanimoto_similarity_and_distance">Tanimoto distance</a>.</p>
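<p>The internal diversity formula can be computed directly once each fingerprint is represented as a set of on-bit indices. The sets below are toy values, and producing real Morgan fingerprints would require a cheminformatics toolkit, which this sketch does not use.</p>

```python
from itertools import product
from typing import FrozenSet, Sequence


def tanimoto_distance(x: FrozenSet[int], y: FrozenSet[int]) -> float:
    """1 - |x intersect y| / |x union y|, with two empty sets at distance 0."""
    union = x.union(y)
    if not union:
        return 0.0
    return 1.0 - len(x.intersection(y)) / len(union)


def internal_diversity(fingerprints: Sequence[FrozenSet[int]]) -> float:
    """Mean Tanimoto distance over all ordered pairs, self-pairs included."""
    n = len(fingerprints)
    total = sum(tanimoto_distance(x, y)
                for x, y in product(fingerprints, repeat=2))
    return total / (n * n)


# Toy fingerprints (hypothetical on-bit indices).
fps = [frozenset({1, 2, 3}), frozenset({2, 3, 4}), frozenset({7, 8})]
print(round(internal_diversity(fps), 4))
```

<p>Because the sum runs over $A \times A$, the self-pairs contribute zero distance, so the value for a finite set is slightly lower than the mean over distinct pairs.</p>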
<p>The 349 &ldquo;ChemGE-active&rdquo; molecules (those scoring better than the best known inhibitor) had an internal diversity of 0.55, compared to 0.46 for known inhibitors and 0.65 for the whole ZINC database. This is a substantial improvement over known actives, achieved without any explicit diversity-promoting mechanism.</p>
<p>ISOMAP visualizations showed that ChemGE populations migrated away from known inhibitors over generations, ultimately occupying a completely different region of chemical space by generation 1000. This suggests ChemGE discovered a novel structural class of potential binders.</p>
<h2 id="high-throughput-and-diversity-without-deep-learning">High Throughput and Diversity Without Deep Learning</h2>
<p>ChemGE demonstrates several notable findings:</p>
<ol>
<li>
<p><strong>Deep learning is not required</strong> for competitive de novo molecular generation. Grammatical evolution over SMILES achieves higher throughput and comparable or better optimization scores than VAE- and RNN-based methods.</p>
</li>
<li>
<p><strong>Population size matters significantly</strong>. Too small a population leads to premature convergence. Too large a population leaves too few generations for optimization within the time budget. The $(\mu, \lambda) = (1000, 2000)$ setting provided the best balance.</p>
</li>
<li>
<p><strong>Inherent diversity</strong> is a key advantage of evolutionary methods. Without any explicit diversity loss or penalty, the ChemGE-active molecules reach an internal diversity of 0.55, which exceeds that of known inhibitors (0.46) and approaches that of the ZINC database (0.65).</p>
</li>
<li>
<p><strong>Parallel evaluation</strong> is naturally supported. Each generation produces $\lambda$ independent molecules that can be evaluated by separate docking simulators simultaneously.</p>
</li>
</ol>
<p>The authors acknowledge several limitations. Synthetic routes and ADMET properties were not evaluated for the generated molecules. The docking scores, while favorable, require confirmation through biological assays. The authors also note that incorporating probabilistic or neural models into the evolutionary process might further improve performance.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Initial population</td>
          <td>ZINC</td>
          <td>~35M compounds</td>
          <td>Randomly sampled starting molecules</td>
      </tr>
      <tr>
          <td>Docking target</td>
          <td>PDB 2B8T</td>
          <td>1 structure</td>
          <td>Thymidine kinase-ligand complex</td>
      </tr>
      <tr>
          <td>Baseline actives</td>
          <td>DUD-E (KITH)</td>
          <td>57 inhibitors</td>
          <td>Known thymidine kinase inhibitors</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Grammatical evolution with $(\mu + \lambda)$ evolution strategy</li>
<li>Mutation only (no crossover)</li>
<li>Context-free grammar subset of OpenSMILES specification</li>
<li>Chromosome length: $N$ integers per molecule</li>
<li>Fitness set to $-\infty$ for invalid SMILES, MW &gt; 500, or duplicate molecules</li>
</ul>
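<p>The algorithm list above can be illustrated with a minimal sketch. It is not the ChemGE implementation: the two-rule grammar, the oxygen-counting fitness, the small population sizes, and the global duplicate set are all assumptions made for illustration. It shows only how integer chromosomes select production rules by modulo and how a $(\mu + \lambda)$ strategy with mutation alone keeps the best $\mu$ individuals.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">import math
import random

# Toy grammar (assumption): NOT the OpenSMILES subset used by ChemGE.
GRAMMAR = {
    "smiles": [["atom"], ["atom", "smiles"]],
    "atom": [["C"], ["N"], ["O"]],
}


def decode(chromosome, start="smiles"):
    """Expand the leftmost non-terminal using rule index = gene % number of options."""
    stack, out, i = [start], [], 0
    while stack:
        symbol = stack.pop(0)
        if symbol not in GRAMMAR:
            out.append(symbol)  # terminal symbol: emit it
            continue
        if i == len(chromosome):
            return None  # ran out of genes: treat the string as invalid
        options = GRAMMAR[symbol]
        stack = list(options[chromosome[i] % len(options)]) + stack
        i += 1
    return "".join(out)


def mutate(chromosome, rng):
    """Point mutation: overwrite one randomly chosen gene."""
    child = list(chromosome)
    child[rng.randrange(len(child))] = rng.randrange(256)
    return child


def evolve(mu=4, lam=8, n_genes=6, generations=20, seed=0):
    rng = random.Random(seed)
    seen = set()

    def evaluate(chromosome):
        smiles = decode(chromosome)
        if smiles is None or smiles in seen:
            return -math.inf  # invalid or duplicate molecule
        seen.add(smiles)
        return float(smiles.count("O"))  # toy score standing in for the real objective

    init = [[rng.randrange(256) for _ in range(n_genes)] for _ in range(mu)]
    population = [(evaluate(c), c) for c in init]
    for _ in range(generations):
        children = []
        for _ in range(lam):
            parent = rng.choice(population)[1]
            child = mutate(parent, rng)
            children.append((evaluate(child), child))
        # (mu + lambda): parents and children compete, the best mu survive.
        merged = sorted(population + children, key=lambda fc: fc[0], reverse=True)
        population = merged[:mu]
    return population


best_fitness, best_chromosome = evolve()[0]
print(best_fitness, decode(best_chromosome))
</code></pre></div>
<p>Because parents stay in the selection pool, the best fitness in the population never decreases from one generation to the next.</p>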
<h3 id="models">Models</h3>
<p>No neural network models are used. ChemGE is purely evolutionary.</p>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Value</th>
          <th>Baseline</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Max $J^{\log P}$ (8h)</td>
          <td>5.88 ± 0.34</td>
          <td>ChemTS: 5.58 ± 0.50</td>
          <td>ChemGE (1000, 2000)</td>
      </tr>
      <tr>
          <td>Molecules/min</td>
          <td>527</td>
          <td>ChemTS: 40.89</td>
          <td>~13x throughput improvement</td>
      </tr>
      <tr>
          <td>Docking hits</td>
          <td>349</td>
          <td>Best DUD-E inhibitor</td>
          <td>Molecules with better $S_{\text{inter}}$</td>
      </tr>
      <tr>
          <td>Internal diversity</td>
          <td>0.55</td>
          <td>Known inhibitors: 0.46</td>
          <td>Morgan fingerprint Tanimoto distance</td>
      </tr>
  </tbody>
</table>
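<p>The internal diversity row can be read as an average pairwise Tanimoto distance. The sketch below assumes that definition (the mean over all unordered pairs) and uses plain Python sets of on-bit indices as stand-ins for Morgan fingerprints. The three example fingerprints are invented.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">from itertools import combinations


def tanimoto_distance(a, b):
    """1 - (shared on-bits / total on-bits) for two sets of bit indices."""
    union = a.union(b)
    if not union:
        return 0.0  # two empty fingerprints are treated as identical
    return 1.0 - len(a.intersection(b)) / len(union)


def internal_diversity(fingerprints):
    """Mean Tanimoto distance over all unordered pairs (assumed definition)."""
    pairs = list(combinations(fingerprints, 2))
    return sum(tanimoto_distance(a, b) for a, b in pairs) / len(pairs)


# Invented fingerprints: the first two overlap, the third is disjoint.
fps = [{1, 2, 3}, {2, 3, 4}, {7, 8}]
print(round(internal_diversity(fps), 4))
</code></pre></div>
<p>With real molecules, the sets would be replaced by fingerprint bit vectors from a cheminformatics toolkit such as RDKit.</p>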
<h3 id="hardware">Hardware</h3>
<ul>
<li>CPU: Intel Xeon E5-2630 v3 (benchmark experiments, single core)</li>
<li>Docking: 32 cores in parallel (thymidine kinase experiment, ~26 hours for 1000 generations)</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/tsudalab/ChemGE">ChemGE</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official implementation</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Yoshikawa, N., Terayama, K., Sumita, M., Homma, T., Oono, K., &amp; Tsuda, K. (2018). Population-based de novo molecule generation, using grammatical evolution. <em>Chemistry Letters</em>, 47(11), 1431-1434. <a href="https://doi.org/10.1246/cl.180665">https://doi.org/10.1246/cl.180665</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{yoshikawa2018chemge,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Population-based De Novo Molecule Generation, Using Grammatical Evolution}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Yoshikawa, Naruki and Terayama, Kei and Sumita, Masato and Homma, Teruki and Oono, Kenta and Tsuda, Koji}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Chemistry Letters}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{47}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{11}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{1431--1434}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2018}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Oxford University Press}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1246/cl.180665}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>ChemCrow: Augmenting LLMs with 18 Chemistry Tools</title><link>https://hunterheidenreich.com/notes/chemistry/llm-applications/chemcrow-augmenting-llms-chemistry-tools/</link><pubDate>Sat, 28 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/llm-applications/chemcrow-augmenting-llms-chemistry-tools/</guid><description>ChemCrow integrates 18 expert-designed chemistry tools with GPT-4 to enable autonomous synthesis planning, drug discovery, and materials design tasks.</description><content:encoded><![CDATA[<h2 id="an-llm-powered-chemistry-agent">An LLM-Powered Chemistry Agent</h2>
<p>This is a <strong>Method</strong> paper that introduces ChemCrow, an LLM chemistry agent that augments GPT-4 with 18 expert-designed tools to accomplish tasks across organic synthesis, drug discovery, and materials design. Rather than relying on the LLM&rsquo;s internal knowledge (which is often inaccurate for chemistry), ChemCrow uses the LLM as a reasoning engine that iteratively calls specialized tools to gather information, plan actions, and execute experiments. The system successfully planned and executed real-world chemical syntheses on a robotic platform, demonstrating one of the first chemistry-related LLM agent interactions with the physical world.</p>
<h2 id="bridging-llm-reasoning-and-chemical-expertise">Bridging LLM Reasoning and Chemical Expertise</h2>
<p>Large language models have transformed many domains, but they struggle with chemistry-specific problems. GPT-4 cannot reliably perform basic operations like multiplying large numbers, converting <a href="https://en.wikipedia.org/wiki/IUPAC_nomenclature_of_chemistry">IUPAC names</a> to molecular structures, or predicting reaction outcomes. These limitations stem from the models&rsquo; token-prediction design, which does not encode chemical reasoning or factual chemical knowledge reliably.</p>
<p>Meanwhile, the chemistry community has developed numerous specialized computational tools for reaction prediction, <a href="/notes/chemistry/molecular-design/reaction-prediction/">retrosynthesis</a> planning, molecular property prediction, and de novo molecular generation. These tools exist in isolated environments with steep learning curves, making them difficult for experimental chemists to integrate and use together. The gap between LLM reasoning capabilities and specialized chemistry tools presents an opportunity: augmenting LLMs with these tools could compensate for the models&rsquo; chemical knowledge deficiencies while providing a natural language interface to specialized computational chemistry capabilities.</p>
<h2 id="tool-augmented-reasoning-via-react">Tool-Augmented Reasoning via ReAct</h2>
<p>ChemCrow builds on the ReAct (Reasoning and Acting) framework, in which the LLM follows an iterative loop of Thought, Action, Action Input, and Observation. At each step, the model reasons about the current state of the task, selects an appropriate tool, provides input, pauses while the tool executes, and then incorporates the observation before deciding on the next step. This continues until the final answer is reached.</p>
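<p>The loop can be sketched in a few lines. This is a simplified illustration, not ChemCrow's LangChain-based code: <code>run_agent</code>, the dictionary format returned by the model, the scripted stand-in for the LLM, and the one-entry weight lookup are all hypothetical.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">def run_agent(llm, tools, task, max_steps=8):
    """Drive a Thought / Action / Action Input / Observation loop.

    llm(transcript) must return either {"final": answer} or a dict with
    the keys "thought", "action" (a tool name), and "action_input".
    """
    transcript = ["Task: " + task]
    for _ in range(max_steps):
        step = llm("\n".join(transcript))
        if "final" in step:
            return step["final"], transcript
        # Pause the model, run the selected tool, then record the observation.
        observation = tools[step["action"]](step["action_input"])
        transcript += [
            "Thought: " + step["thought"],
            "Action: " + step["action"],
            "Action Input: " + step["action_input"],
            "Observation: " + str(observation),
        ]
    raise RuntimeError("no final answer within max_steps")


def scripted_llm(transcript):
    """Hypothetical stand-in for GPT-4: call one tool, then answer."""
    if "Observation:" not in transcript:
        return {
            "thought": "I need the molecular weight.",
            "action": "weight",
            "action_input": "CCO",
        }
    return {"final": "Looked up weight: " + transcript.rsplit("Observation: ", 1)[1]}


# Toy tool with a single hard-coded entry (ethanol).
tools = {"weight": lambda smiles: {"CCO": 46.07}.get(smiles, "unknown")}
answer, log = run_agent(scripted_llm, tools, "What is the molecular weight of ethanol?")
print(answer)
</code></pre></div>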
<p>The system integrates 18 tools organized into four categories:</p>
<p><strong>General tools</strong> include web search (via SerpAPI), literature search (using paper-qa with OpenAI embeddings and FAISS), a Python REPL for arbitrary code execution, and a human interaction interface.</p>
<p><strong>Molecule tools</strong> cover Name2SMILES (converting molecule names to <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> via Chem-Space, PubChem, and OPSIN), SMILES2Price (checking purchasability via molbloom and ZINC20), Name2CAS (CAS number lookup via PubChem), molecular Similarity (<a href="https://en.wikipedia.org/wiki/Jaccard_index">Tanimoto similarity</a> with ECFP2 fingerprints), ModifyMol (local chemical space exploration via SynSpace), PatentCheck (bloom filter patent lookup via molbloom), FuncGroups (functional group identification via SMARTS patterns), and SMILES2Weight (molecular weight calculation via RDKit).</p>
<p><strong>Safety tools</strong> include ControlledChemicalCheck (screening against chemical weapons lists from <a href="https://en.wikipedia.org/wiki/Organisation_for_the_Prohibition_of_Chemical_Weapons">OPCW</a> and the Australia Group), ExplosiveCheck (GHS explosive classification via PubChem), and SafetySummary (comprehensive safety overview from PubChem data).</p>
<p><strong>Chemical reaction tools</strong> include NameRXN (reaction classification via NextMove Software), ReactionPredict (product prediction via IBM&rsquo;s RXN4Chemistry API using the <a href="/notes/chemistry/molecular-design/reaction-prediction/molecular-transformer/">Molecular Transformer</a>), ReactionPlanner (multi-step synthesis planning via RXN4Chemistry), and ReactionExecute (direct synthesis execution on IBM&rsquo;s RoboRXN robotic platform).</p>
<p>A key design feature is that safety checks are automatically invoked before synthesis execution. If a molecule is flagged as a controlled chemical or precursor, execution stops immediately.</p>
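<p>One way to picture this gating is a wrapper that runs the check before the execution call is ever made. The function names, the return format, and the placeholder blocklist entry below are hypothetical. The real system screens against the OPCW and Australia Group lists.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">def guarded_execute(target, is_controlled, execute):
    """Run the safety check first; call execute only if the target passes."""
    if is_controlled(target):
        return {"executed": False, "reason": "flagged as controlled chemical or precursor"}
    return {"executed": True, "result": execute(target)}


blocklist = {"FLAGGED_EXAMPLE"}  # placeholder identifier, not a real list entry
calls = []


def fake_robot(target):
    """Stand-in for the synthesis platform; records what was submitted."""
    calls.append(target)
    return "submitted " + target


ok = guarded_execute("CCO", blocklist.__contains__, fake_robot)
blocked = guarded_execute("FLAGGED_EXAMPLE", blocklist.__contains__, fake_robot)
print(ok["executed"], blocked["executed"], calls)
</code></pre></div>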
<h2 id="experimental-validation-and-evaluation">Experimental Validation and Evaluation</h2>
<h3 id="autonomous-synthesis">Autonomous Synthesis</h3>
<p>ChemCrow autonomously planned and executed four real-world syntheses on the IBM RoboRXN cloud-connected robotic platform:</p>
<ul>
<li><strong><a href="https://en.wikipedia.org/wiki/DEET">DEET</a></strong> (insect repellent), from the prompt &ldquo;Plan and execute the synthesis of an insect repellent&rdquo;</li>
<li><strong>Three <a href="https://en.wikipedia.org/wiki/Thiourea">thiourea</a> <a href="https://en.wikipedia.org/wiki/Organocatalysis">organocatalysts</a></strong> (Schreiner&rsquo;s, Ricci&rsquo;s, and Takemoto&rsquo;s catalysts), from a prompt asking to find and synthesize a thiourea organocatalyst that accelerates the <a href="https://en.wikipedia.org/wiki/Diels%E2%80%93Alder_reaction">Diels-Alder reaction</a></li>
</ul>
<p>All four syntheses yielded the anticipated compounds. ChemCrow demonstrated the ability to autonomously adapt synthesis procedures when the RoboRXN platform flagged issues (such as insufficient solvent or invalid purification actions), iteratively modifying the procedure until it was valid.</p>
<h3 id="novel-chromophore-discovery">Novel Chromophore Discovery</h3>
<p>In a human-AI collaboration scenario, ChemCrow was instructed to train a machine learning model to screen candidate <a href="https://en.wikipedia.org/wiki/Chromophore">chromophores</a>. The system loaded and cleaned data from a chromophore database, trained and evaluated a random forest model, and suggested a molecule with a target absorption maximum of 369 nm. The proposed molecule was subsequently synthesized and characterized, revealing a measured absorption maximum of 336 nm, confirming the discovery of a new chromophore.</p>
<h3 id="expert-vs-llm-evaluation">Expert vs. LLM Evaluation</h3>
<p>The evaluation used 14 use cases spanning synthesis planning, molecular design, and chemical logic. Both ChemCrow and standalone GPT-4 (without tools) were evaluated by:</p>
<ol>
<li><strong>Expert human evaluators</strong> (n=4): Assessed correctness of chemistry, quality of reasoning, and degree of task completion</li>
<li><strong>EvaluatorGPT</strong>: An LLM evaluator prompted to assess responses</li>
</ol>
<p>Key findings from the evaluation:</p>
<table>
  <thead>
      <tr>
          <th>Evaluator</th>
          <th>Preferred System</th>
          <th>Reasoning</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Human experts</td>
          <td>ChemCrow</td>
          <td>Better chemical accuracy and task completeness, especially on complex tasks</td>
      </tr>
      <tr>
          <td>EvaluatorGPT</td>
          <td>GPT-4</td>
          <td>Favored fluent, complete-sounding responses despite factual errors</td>
      </tr>
  </tbody>
</table>
<p>Human experts preferred ChemCrow across most tasks, with the exception of very simple tasks where GPT-4 could answer from memorized training data (e.g., synthesis of well-known molecules like paracetamol). GPT-4 without tools consistently produced hallucinations that appeared convincing but were factually incorrect upon expert inspection.</p>
<p>An important finding is that LLM-based evaluation (EvaluatorGPT) cannot replace expert human assessment for scientific tasks. The LLM evaluator lacks the domain knowledge needed to distinguish fluent but incorrect answers from accurate ones, rendering it unsuitable for benchmarking factuality in chemistry.</p>
<h2 id="key-findings-and-limitations">Key Findings and Limitations</h2>
<p>ChemCrow demonstrates that augmenting LLMs with expert-designed tools transforms them from &ldquo;hyperconfident, typically wrong information sources&rdquo; into reasoning engines that can gather and act on accurate chemical information. The system lowers the barrier for non-experts to access computational chemistry tools through natural language while serving as an assistant to expert chemists.</p>
<p>Several limitations are acknowledged:</p>
<ul>
<li><strong>Tool dependency</strong>: ChemCrow&rsquo;s performance is bounded by the quality and coverage of its tools. Improved synthesis engines would directly improve synthesis planning capabilities.</li>
<li><strong>Reasoning failures</strong>: Tools become useless if the LLM&rsquo;s reasoning about when and how to use them is flawed, or if garbage inputs are provided.</li>
<li><strong>Reproducibility</strong>: The API-based approach to closed-source LLMs (GPT-4) limits reproducibility of individual results. The authors note that open-source models could address this, potentially at the cost of reasoning quality.</li>
<li><strong>Evaluation scope</strong>: The 14 evaluation tasks, while diverse, represent a limited test set. Standardized benchmarks for LLM-based chemistry tools did not exist at the time of publication.</li>
<li><strong>Safety considerations</strong>: While safety tools prevent execution of controlled chemical syntheses, risks remain from inaccurate reasoning or tool outputs leading to suboptimal conclusions.</li>
</ul>
<p>The authors emphasize that ChemCrow&rsquo;s modular design allows easy extension with new tools, and that future integration of image-processing tools, additional language-based tools, and other capabilities could substantially enhance the system.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Chromophore screening</td>
          <td>DB for chromophore (Joung et al.)</td>
          <td>Not specified</td>
          <td>Used for training random forest model</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>14 expert-designed tasks</td>
          <td>14 tasks</td>
          <td>Spanning synthesis, molecular design, and chemical logic</td>
      </tr>
      <tr>
          <td>Chemical safety</td>
          <td>OPCW Schedules 1-3, Australia Group lists</td>
          <td>Not specified</td>
          <td>Used for controlled chemical screening</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>LLM</strong>: GPT-4 with temperature 0.1</li>
<li><strong>Framework</strong>: LangChain for tool integration</li>
<li><strong>Reasoning</strong>: ReAct (Reasoning + Acting) framework with chain-of-thought prompting</li>
<li><strong>Synthesis planning</strong>: IBM RXN4Chemistry API (Molecular Transformer-based)</li>
<li><strong>Molecule similarity</strong>: Tanimoto similarity with ECFP2 fingerprints via RDKit</li>
<li><strong>Chemical space exploration</strong>: SynSpace with 50 robust medicinal chemistry reactions</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li>GPT-4 (OpenAI, closed-source) for reasoning</li>
<li>Random forest for chromophore screening (trained on the fly)</li>
<li>Molecular Transformer via RXN4Chemistry API for reaction prediction and retrosynthesis</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Human evaluation</strong>: 4 expert chemists rated responses on chemistry correctness, reasoning quality, and task completion</li>
<li><strong>LLM evaluation</strong>: EvaluatorGPT assessed responses (found unreliable for factuality)</li>
<li><strong>Experimental validation</strong>: 4 syntheses on RoboRXN platform, 1 novel chromophore characterization</li>
</ul>
<h3 id="hardware">Hardware</h3>
<p>Hardware requirements are not specified in the paper. The system relies primarily on API calls to GPT-4 and RXN4Chemistry, so local compute requirements are minimal.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/ur-whitelab/chemcrow-public">chemcrow-public</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Open-source implementation with 12 of 18 tools</td>
      </tr>
      <tr>
          <td><a href="https://github.com/ur-whitelab/chemcrow-runs">chemcrow-runs</a></td>
          <td>Data</td>
          <td>Not specified</td>
          <td>All experiment outputs and evaluation data</td>
      </tr>
      <tr>
          <td><a href="https://doi.org/10.5281/zenodo.10884639">Zenodo release (code)</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Archived release v0.3.24</td>
      </tr>
      <tr>
          <td><a href="https://doi.org/10.5281/zenodo.10884645">Zenodo release (runs)</a></td>
          <td>Data</td>
          <td>Not specified</td>
          <td>Archived experiment runs</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Bran, A. M., Cox, S., Schilter, O., Baldassari, C., White, A. D., &amp; Schwaller, P. (2024). Augmenting large language models with chemistry tools. <em>Nature Machine Intelligence</em>, 6(5), 525-535. <a href="https://doi.org/10.1038/s42256-024-00832-8">https://doi.org/10.1038/s42256-024-00832-8</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{bran2024augmenting,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Augmenting large language models with chemistry tools}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Bran, Andres M. and Cox, Sam and Schilter, Oliver and Baldassari, Carlo and White, Andrew D. and Schwaller, Philippe}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Nature Machine Intelligence}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{6}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{5}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{525--535}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2024}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Nature Publishing Group}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1038/s42256-024-00832-8}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>ChatDrug: Conversational Drug Editing with ChatGPT</title><link>https://hunterheidenreich.com/notes/chemistry/llm-applications/chatdrug-conversational-drug-editing/</link><pubDate>Sat, 28 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/llm-applications/chatdrug-conversational-drug-editing/</guid><description>ChatDrug uses ChatGPT with retrieval and domain feedback for drug editing across small molecules, peptides, and proteins on 39 tasks.</description><content:encoded><![CDATA[<h2 id="a-framework-for-conversational-drug-editing-with-llms">A Framework for Conversational Drug Editing with LLMs</h2>
<p>ChatDrug is a <strong>Method</strong> paper that introduces a parameter-free framework for drug editing using conversational large language models (specifically ChatGPT/GPT-3.5). The primary contribution is a three-module pipeline that combines prompt engineering, retrieval-augmented domain feedback, and iterative conversation to perform text-guided editing of small molecules, peptides, and proteins. The paper also establishes a benchmark of 39 drug editing tasks spanning these three drug types.</p>
<h2 id="bridging-conversational-ai-and-drug-discovery">Bridging Conversational AI and Drug Discovery</h2>
<p>Drug editing (also called <a href="https://en.wikipedia.org/wiki/Hit_to_lead">lead optimization</a> or protein design) is a critical step in the drug discovery pipeline where molecular substructures are modified to achieve desired properties. Traditional approaches rely on domain experts for manual editing, which can be subjective and biased. Recent multi-modal approaches like MoleculeSTM and ProteinDT have started exploring text-guided drug editing, but they are domain-specific (limited to one drug type) and lack conversational capabilities for iterative refinement.</p>
<p>The authors identify three properties of conversational LLMs that make them suitable for drug discovery: (1) pretraining on comprehensive knowledge bases covering drug-related concepts, (2) strong few-shot adaptation and generalization abilities, and (3) interactive communication enabling iterative feedback incorporation. However, directly applying LLMs to drug editing yields suboptimal results because the models do not fully utilize prior domain knowledge. ChatDrug addresses this gap through structured retrieval and feedback mechanisms.</p>
<h2 id="three-module-pipeline-pdds-redf-and-conversation">Three-Module Pipeline: PDDS, ReDF, and Conversation</h2>
<p>ChatDrug consists of three modules that operate sequentially without any parameter learning.</p>
<h3 id="pdds-module-prompt-design-for-domain-specific">PDDS Module (Prompt Design for Domain-Specific)</h3>
<p>The PDDS module constructs domain-specific prompts for ChatGPT. Given an input drug $\pmb{x}_{\text{in}}$ and a text prompt $\pmb{x}_t$ describing the desired property change, the goal is:</p>
<p>$$
\pmb{x}_{\text{out}} = \text{ChatDrug}(\pmb{x}_{\text{in}}, \pmb{x}_t)
$$</p>
<p>The prompts are designed around high-level property descriptions (e.g., &ldquo;more soluble in water&rdquo;) rather than exact substructure replacements. The authors argue that ChatDrug is better suited to &ldquo;fuzzy searching&rdquo; (property-based editing with non-deterministic answers) than to &ldquo;exact searching&rdquo; (precise substructure replacement that experts can do directly).</p>
<h3 id="redf-module-retrieval-and-domain-feedback">ReDF Module (Retrieval and Domain Feedback)</h3>
<p>The ReDF module retrieves structurally similar examples from a domain-specific database and injects them into the conversation as demonstrations. For an input drug $\pmb{x}_{\text{in}}$, a candidate drug $\tilde{\pmb{x}}$ that failed the desired property change, and a retrieval database, ReDF returns:</p>
<p>$$
\pmb{x}_R = \text{ReDF}(\pmb{x}_{\text{in}}, \tilde{\pmb{x}}; \pmb{x}_t) = \underset{\pmb{x}'_R \in \text{RetrievalDB}}{\arg\max} \langle \tilde{\pmb{x}}, \pmb{x}'_R \rangle \wedge D(\pmb{x}_{\text{in}}, \pmb{x}'_R; \pmb{x}_t)
$$</p>
<p>where $D(\cdot, \cdot; \cdot) \in \{\text{True}, \text{False}\}$ is a domain feedback function checking whether the retrieved drug satisfies the desired property change, and $\langle \tilde{\pmb{x}}, \pmb{x}'_R \rangle$ is a similarity function (<a href="https://en.wikipedia.org/wiki/Jaccard_index">Tanimoto similarity</a> for small molecules, <a href="https://en.wikipedia.org/wiki/Levenshtein_distance">Levenshtein distance</a> for peptides and proteins).</p>
<p>The retrieved example $\pmb{x}_R$ is injected into the prompt as: &ldquo;Your provided sequence [$\tilde{\pmb{x}}$] is not correct. We find a sequence [$\pmb{x}_R$] which is correct and similar to the molecule you provided. Can you give me a new molecule?&rdquo;</p>
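<p>The retrieval step can be sketched as a filter followed by an argmax. This is an illustration under stated assumptions, not the authors' code. The lysine-count feedback rule and the four-entry database are invented. The text names Levenshtein distance as the similarity function for sequences, and the sketch negates it so that closer sequences score higher, which is also an assumption.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">from operator import gt


def levenshtein(a, b):
    """Edit distance between two strings (insert, delete, substitute)."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            cost = 0 if ca == cb else 1
            curr.append(min(prev[j] + 1, curr[j - 1] + 1, prev[j - 1] + cost))
        prev = curr
    return prev[-1]


def redf(x_in, x_tilde, database, feedback, similarity):
    """Return the entry most similar to the failed candidate x_tilde among
    those for which feedback(x_in, entry) is True, or None if none passes."""
    passing = [x for x in database if feedback(x_in, x)]
    if not passing:
        return None
    return max(passing, key=lambda x: similarity(x_tilde, x))


def more_lysine(x_in, x):
    """Invented stand-in for D: the entry has more K residues than the input."""
    return gt(x.count("K"), x_in.count("K"))


def neg_edit_distance(a, b):
    """Similarity as negated edit distance (assumption)."""
    return -levenshtein(a, b)


database = ["KAGG", "KKKK", "AAGG", "AKGA"]
x_r = redf("AAGA", "AAGG", database, more_lysine, neg_edit_distance)
print(x_r)
</code></pre></div>
<p>Here <code>AAGG</code> is the closest entry to the failed candidate but is filtered out because it fails the feedback check, so the nearest passing entry is returned instead.</p>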
<h3 id="conversation-module">Conversation Module</h3>
<p>The conversation module enables iterative refinement over $C$ rounds. At each round $c$, if the edited drug $\pmb{x}_c$ does not satisfy the evaluation condition, ChatDrug retrieves a new example via ReDF using $\tilde{\pmb{x}} = \pmb{x}_c$ and continues the conversation. This aligns with the iterative nature of real drug discovery workflows.</p>
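<p>A minimal sketch of this loop follows. The three callables and their scripted behaviour are hypothetical stand-ins for the LLM call, the ReDF lookup, and the evaluation check. The feedback message is abbreviated from the prompt quoted above.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">def chatdrug_loop(x_in, propose, retrieve, satisfies, rounds=2):
    """Ask for an edit, then refine it for at most `rounds` feedback rounds.

    propose(history) returns a candidate, retrieve(x_in, candidate) returns a
    demonstration, and satisfies(x_in, candidate) is the evaluation check.
    """
    history = ["Edit [" + x_in + "] as requested."]
    candidate = propose(history)  # zero-shot answer
    for _ in range(rounds):
        if satisfies(x_in, candidate):
            break
        demo = retrieve(x_in, candidate)
        history.append(
            "[" + candidate + "] is not correct. [" + str(demo) + "] is correct and similar."
        )
        candidate = propose(history)
    return candidate, history


# Scripted stand-ins (assumptions): the first proposal fails, the second passes.
def propose(history):
    return "AAGG" if len(history) == 1 else "KAGG"


def retrieve(x_in, candidate):
    return "KAGA"


def satisfies(x_in, candidate):
    return "K" in candidate


final, history = chatdrug_loop("AAGA", propose, retrieve, satisfies)
print(final, len(history))
</code></pre></div>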
<h2 id="experiments-across-39-drug-editing-tasks">Experiments Across 39 Drug Editing Tasks</h2>
<h3 id="task-design">Task Design</h3>
<p>The benchmark includes 39 tasks across three drug types:</p>
<ul>
<li><strong>Small molecules</strong> (28 tasks): 16 single-objective (tasks 101-108, each with loose and strict thresholds) and 12 multi-objective tasks (tasks 201-206, each with two thresholds). Properties include solubility (<a href="https://en.wikipedia.org/wiki/Partition_coefficient">LogP</a>), drug-likeness (QED), permeability (<a href="https://en.wikipedia.org/wiki/Polar_surface_area">tPSA</a>), and <a href="https://en.wikipedia.org/wiki/Hydrogen_bond">hydrogen bond</a> acceptors/donors.</li>
<li><strong>Peptides</strong> (9 tasks): 6 single-objective and 3 multi-objective tasks for editing <a href="https://en.wikipedia.org/wiki/Major_histocompatibility_complex">peptide-MHC binding</a> affinity across different <a href="https://en.wikipedia.org/wiki/Human_leukocyte_antigen">HLA allele</a> types.</li>
<li><strong>Proteins</strong> (2 tasks): Editing protein sequences to increase <a href="https://en.wikipedia.org/wiki/Alpha_helix">alpha-helix</a> or <a href="https://en.wikipedia.org/wiki/Beta_sheet">beta-strand</a> secondary structures.</li>
</ul>
<h3 id="baselines">Baselines</h3>
<p>For small molecules, baselines include Random, PCA, High-Variance, and GS-Mutate (all based on MegaMolBART), plus MoleculeSTM with <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> and Graph representations. For peptides and proteins, random mutation baselines with 1-3 mutated positions are used.</p>
<h3 id="main-results">Main Results</h3>
<p>ChatDrug achieves the best performance on 33 out of 39 tasks. Key results for small molecule editing (hit ratio):</p>
<table>
  <thead>
      <tr>
          <th>Task</th>
          <th>Property</th>
          <th>ChatDrug (loose)</th>
          <th>Best Baseline (loose)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>101</td>
          <td>More soluble</td>
          <td>94.13</td>
          <td>67.86 (MoleculeSTM-Graph)</td>
      </tr>
      <tr>
          <td>102</td>
          <td>Less soluble</td>
          <td>96.86</td>
          <td>64.79 (MoleculeSTM-Graph)</td>
      </tr>
      <tr>
          <td>106</td>
          <td>Lower permeability</td>
          <td>77.35</td>
          <td>34.13 (MoleculeSTM-SMILES)</td>
      </tr>
      <tr>
          <td>107</td>
          <td>More HBA</td>
          <td>95.35</td>
          <td>54.01 (MoleculeSTM-SMILES)</td>
      </tr>
      <tr>
          <td>108</td>
          <td>More HBD</td>
          <td>96.54</td>
          <td>60.97 (MoleculeSTM-Graph)</td>
      </tr>
  </tbody>
</table>
<p>ChatDrug underperforms MoleculeSTM variants on single-objective tasks 104 (less like a drug) and 105 (higher permeability), and on most multi-objective tasks involving permeability (205).</p>
<p>For peptide editing, ChatDrug achieves 41-69% hit ratios compared to 0.4-14.4% for random mutation baselines. For protein editing, ChatDrug reaches 34.79% and 51.38% hit ratios on helix and strand tasks respectively, compared to 26.90% and 21.44% for the best random mutation baseline.</p>
<h3 id="ablation-studies">Ablation Studies</h3>
<p><strong>Conversation rounds</strong>: Performance increases with more rounds, converging around $C = 2$. For example, on task 101 (loose threshold), zero-shot achieves 78.26%, $C = 1$ reaches 89.56%, and $C = 2$ reaches 93.37%. Ablations use a single seed, so these values differ slightly from the five-seed main results.</p>
<p><strong>ReDF threshold</strong>: Using a stricter threshold in the domain feedback function $D$ (matching the evaluation threshold) yields substantially higher performance than using a loose threshold. For example, on task 107 with strict evaluation, the strict-threshold ReDF achieves 72.60% vs. 14.96% for the loose-threshold ReDF.</p>
<p><strong>Similarity analysis</strong>: Retrieved molecules $\pmb{x}_R$ tend to have lower similarity to input molecules than the intermediate outputs $\pmb{x}_1$, yet they have higher hit ratios. This suggests the ReDF module explores the chemical space effectively, and the conversation module balances similarity preservation with property optimization.</p>
<p><strong>Knowledge extraction</strong>: ChatDrug can articulate domain-specific reasoning for its edits (e.g., summarizing rules for increasing water solubility by introducing polar functional groups), though the extracted knowledge shows some redundancy.</p>
<h2 id="limitations-and-future-directions">Limitations and Future Directions</h2>
<p>ChatDrug demonstrates that conversational LLMs can serve as useful tools for drug editing, achieving strong results across diverse drug types with a parameter-free approach. The framework exhibits open vocabulary and compositional properties, allowing it to handle novel drug concepts and multi-objective tasks through natural language.</p>
<p>The authors acknowledge two main limitations. First, ChatDrug struggles with understanding complex 3D drug geometries, which would require deeper geometric modeling. Second, the framework requires multiple conversation rounds to achieve strong performance, adding computational cost through repeated API calls. The authors suggest that knowledge summarization capabilities of LLMs could help reduce this cost.</p>
<p>The evaluation relies entirely on computational oracles (RDKit for small molecules, MHCflurry 2.0 for peptides, ProteinCLAP for proteins) rather than wet-lab validation. The hit ratio metric also excludes invalid outputs from the denominator, so the effective success rate on all attempted edits may be lower than reported.</p>
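<p>The effect of the denominator choice is easy to see with a small worked example. The outcome counts below are invented for illustration and are not taken from the paper.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">def hit_ratios(outcomes):
    """outcomes holds one of "hit", "miss", or "invalid" per attempted edit.

    Returns (ratio over valid outputs, ratio over all attempts).
    Assumes at least one valid output.
    """
    hits = outcomes.count("hit")
    valid = len(outcomes) - outcomes.count("invalid")
    return hits / valid, hits / len(outcomes)


# Invented counts: 80 hits, 10 misses, 10 invalid outputs out of 100 attempts.
outcomes = ["hit"] * 80 + ["miss"] * 10 + ["invalid"] * 10
reported, effective = hit_ratios(outcomes)
print(round(reported, 3), effective)
</code></pre></div>
<p>In this made-up case the valid-only ratio is about 0.889, while the ratio over all attempts is 0.8.</p>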
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Small molecule inputs</td>
          <td><a href="/notes/chemistry/datasets/zinc-22/">ZINC</a></td>
          <td>200 molecules</td>
          <td>Sampled SMILES strings</td>
      </tr>
      <tr>
          <td>Small molecule retrieval DB</td>
          <td>ZINC</td>
          <td>10K molecules</td>
          <td>For ReDF similarity search</td>
      </tr>
      <tr>
          <td>Peptide inputs</td>
          <td>Peptide-MHC binding dataset</td>
          <td>500 peptides per task</td>
          <td>From 30 common MHC alleles</td>
      </tr>
      <tr>
          <td>Peptide retrieval DB</td>
          <td>Experimental binding data</td>
          <td>Varies by allele</td>
          <td>Target allele experimental data</td>
      </tr>
      <tr>
          <td>Protein inputs</td>
          <td>TAPE test set</td>
          <td>Varies</td>
          <td>Secondary structure prediction test data</td>
      </tr>
      <tr>
          <td>Protein retrieval DB</td>
          <td>TAPE training set</td>
          <td>Varies</td>
          <td>Secondary structure prediction training data</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>GPT-3.5-turbo via OpenAI ChatCompletion API, temperature=0, frequency_penalty=0.2</li>
<li>System prompt: &ldquo;You are an expert in the field of molecular chemistry.&rdquo;</li>
<li>$C = 2$ conversation rounds for main results</li>
<li>5 random seeds (0-4) for small molecule main results, seed 0 for ablations</li>
</ul>
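<p>The settings above map onto a chat-completion request roughly as follows. This is a minimal sketch that only assembles the request payload as a plain dictionary; the helper name <code>build_request</code> and the example user prompt are hypothetical, and no API call is made.</p>

```python
def build_request(user_prompt: str) -> dict:
    """Assemble a chat-completion payload with the settings listed above."""
    return {
        "model": "gpt-3.5-turbo",
        "temperature": 0,
        "frequency_penalty": 0.2,
        "messages": [
            {"role": "system",
             "content": "You are an expert in the field of molecular chemistry."},
            {"role": "user", "content": user_prompt},
        ],
    }

# Hypothetical editing prompt, for illustration only.
request = build_request("Can you make molecule CCO more soluble in water?")
```

<p>With temperature fixed at 0 the sampling is as deterministic as the API allows, which is consistent with the use of fixed seeds for the surrounding pipeline.</p>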
<h3 id="models">Models</h3>
<ul>
<li>ChatGPT (GPT-3.5-turbo): used as-is, no fine-tuning</li>
<li>MHCflurry 2.0: pseudo-oracle for peptide binding affinity evaluation</li>
<li>ProteinCLAP-EBM-NCE from ProteinDT: protein secondary structure prediction</li>
<li>ESMFold: protein folding for visualization</li>
<li>RDKit: molecular property calculations for small molecules</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Description</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Hit Ratio</td>
          <td>Fraction of valid edits satisfying property requirements</td>
          <td>Invalid sequences excluded from denominator</td>
      </tr>
  </tbody>
</table>
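<p>The effect of excluding invalid outputs from the denominator can be illustrated with a small sketch. The function name, the outcome labels, and the example counts are assumptions made for illustration; they are not taken from the paper or its code.</p>

```python
def hit_ratio(outcomes, count_invalid=False):
    """outcomes: list of "hit", "miss", or "invalid" labels, one per attempted edit."""
    hits = sum(1 for o in outcomes if o == "hit")
    valid = sum(1 for o in outcomes if o != "invalid")
    denominator = len(outcomes) if count_invalid else valid
    return hits / denominator if denominator else 0.0

# Made-up outcomes for ten attempted edits (not data from the paper).
outcomes = ["hit"] * 4 + ["miss"] * 4 + ["invalid"] * 2
reported = hit_ratio(outcomes)                       # 4 hits over 8 valid edits
effective = hit_ratio(outcomes, count_invalid=True)  # 4 hits over 10 attempted edits
```

<p>The gap between the two numbers grows with the fraction of invalid generations, which is why the reported hit ratio is best read as an upper bound on the success rate per attempted edit.</p>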
<h3 id="hardware">Hardware</h3>
<p>All experiments were conducted on a single NVIDIA RTX A6000 GPU (used only for peptide and protein evaluation). Total OpenAI API cost was less than $100.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/chao1224/ChatDrug">ChatDrug GitHub</a></td>
          <td>Code</td>
          <td>Not specified</td>
          <td>Official implementation</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Liu, S., Wang, J., Yang, Y., Wang, C., Liu, L., Guo, H., &amp; Xiao, C. (2024). Conversational Drug Editing Using Retrieval and Domain Feedback. <em>ICLR 2024</em>.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{liu2024chatdrug,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Conversational Drug Editing Using Retrieval and Domain Feedback}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Liu, Shengchao and Wang, Jiongxiao and Yang, Yijin and Wang, Chengpeng and Liu, Ling and Guo, Hongyu and Xiao, Chaowei}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{International Conference on Learning Representations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2024}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>BioT5: Cross-Modal Integration of Biology and Chemistry</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/multimodal/biot5-cross-modal-biology/</link><pubDate>Sat, 28 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/multimodal/biot5-cross-modal-biology/</guid><description>BioT5 is a T5-based pretraining framework that jointly models molecules, proteins, and natural language using SELFIES for robust molecular generation.</description><content:encoded><![CDATA[<h2 id="a-unified-pretraining-framework-for-molecules-proteins-and-text">A Unified Pretraining Framework for Molecules, Proteins, and Text</h2>
<p>BioT5 is a <strong>Method</strong> paper that introduces a comprehensive <a href="/notes/natural-language-processing/language-models/t5-text-to-text-transfer-transformer/">T5</a>-based pretraining framework for cross-modal integration of molecules, proteins, and natural language. The primary contribution is a multi-task pretraining approach that uses <a href="/notes/chemistry/molecular-representations/notations/selfies/">SELFIES</a> (instead of <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a>) for 100% valid molecular representations, separate tokenization for each modality, and a combination of masked language modeling and translation objectives to connect structured biological data with unstructured scientific text. After fine-tuning, BioT5 (252M parameters) achieves state-of-the-art performance on 10 out of 15 downstream tasks spanning molecule property prediction, protein property prediction, drug-target interaction, protein-protein interaction, molecule captioning, and text-based molecule generation.</p>
<h2 id="bridging-the-gap-between-molecular-sequences-and-scientific-knowledge">Bridging the Gap Between Molecular Sequences and Scientific Knowledge</h2>
<p>Prior cross-modal models in computational biology face three recurring challenges. First, models like MolT5 and MolXPT rely on SMILES to represent molecules, but SMILES strings are syntactically fragile: random perturbations or model-generated sequences frequently produce invalid molecular structures. Edwards et al. (2022) and Li et al. (2023) both highlight this validity problem as a bottleneck for text-to-molecule generation. Second, the contextual information surrounding molecular and protein names in scientific literature (e.g., mentions in <a href="https://en.wikipedia.org/wiki/PubMed">PubMed</a> abstracts that describe properties, interactions, and experimental results) remains underutilized. Most models either ignore this context or treat it identically to structured database entries. Third, existing approaches like MolT5 and <a href="/notes/chemistry/llm-applications/galactica-large-language-model-for-science/">Galactica</a> share a single tokenizer and embedding space across molecules, proteins, and text. This leads to chemically incorrect tokenization: the bromine atom &ldquo;Br&rdquo; in SMILES gets split into &ldquo;B&rdquo; (boron) and &ldquo;r&rdquo;, producing erroneous downstream predictions.</p>
<p>BioT5 addresses all three issues simultaneously by adopting SELFIES for molecular representation, extracting entity-linked contextual knowledge from PubMed, and employing separate vocabularies for each modality.</p>
<h2 id="selfies-separate-tokenization-and-multi-task-pretraining">SELFIES, Separate Tokenization, and Multi-Task Pretraining</h2>
<p>The core innovations of BioT5 center on three design decisions:</p>
<h3 id="selfies-for-robust-molecular-representation">SELFIES for Robust Molecular Representation</h3>
<p>BioT5 replaces SMILES with SELFIES (Self-referencing Embedded Strings) for all molecular representations. Every permutation of symbols within the SELFIES alphabet generates a chemically valid molecular structure, guaranteeing 100% validity in generation tasks. Molecules from ZINC20 are converted from SMILES to SELFIES during data preprocessing.</p>
<h3 id="modality-specific-tokenization">Modality-Specific Tokenization</h3>
<p>Rather than sharing a single SentencePiece vocabulary across modalities, BioT5 maintains three separate dictionaries:</p>
<ul>
<li><strong>Molecules</strong>: Each SELFIES token corresponds to a chemically meaningful atom group enclosed in brackets (e.g., <code>[C]</code>, <code>[=C]</code>, <code>[Br]</code>).</li>
<li><strong>Proteins</strong>: Amino acids are prefixed with a special <code>&lt;p&gt;</code> token to distinguish them from text characters (e.g., <code>&lt;p&gt;M</code>, <code>&lt;p&gt;K</code>, <code>&lt;p&gt;R</code>).</li>
<li><strong>Text</strong>: The standard T5 vocabulary is retained.</li>
</ul>
<p>This prevents semantic conflation across modalities. The total vocabulary size is 35,073, and the model comprises 252M parameters using the T5-v1.1-base architecture.</p>
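<p>The difference between symbol-level and naive tokenization can be sketched in a few lines. This is an illustration under stated assumptions: the regular expression is a simple bracket matcher rather than BioT5's actual tokenizer, and the character-level split merely stands in for the failure mode of a shared text vocabulary. Protein handling is omitted.</p>

```python
import re

def selfies_tokens(selfies: str) -> list:
    """Split a SELFIES string into its bracketed symbols."""
    return re.findall(r"\[[^\]]*\]", selfies)

def char_tokens(text: str) -> list:
    """Naive character-level split, shown for contrast."""
    return list(text)

mol = "[C][C][Br]"            # bromoethane written as SELFIES
print(selfies_tokens(mol))    # bracketed symbols stay intact
print(char_tokens("CCBr"))    # the SMILES "Br" falls apart into "B" and "r"
```

<p>Because each bracketed symbol is a vocabulary entry of its own, bromine can never be confused with boron followed by a stray character.</p>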
<h3 id="multi-task-pretraining-objectives">Multi-Task Pretraining Objectives</h3>
<p>BioT5 uses six pretraining tasks organized into three categories:</p>
<ol>
<li><strong>Single-modal T5 objective</strong>: Standard span corruption and recovery applied independently to molecule SELFIES (task 1), protein <a href="https://en.wikipedia.org/wiki/FASTA_format">FASTA</a> (task 2), and general text from C4 (task 3).</li>
<li><strong>Wrapped text T5 objective</strong> (task 4): Applied to PubMed articles in which molecule names are replaced with their SELFIES strings and gene names are followed by the corresponding protein FASTA sequences, using BERN2 for named entity recognition and entity linking.</li>
<li><strong>Bidirectional translation</strong> (tasks 5 and 6): Molecule SELFIES to text description and vice versa (using 339K pairs from <a href="https://en.wikipedia.org/wiki/PubChem">PubChem</a>), and protein FASTA to text description and vice versa (using 569K pairs from <a href="https://en.wikipedia.org/wiki/UniProt">Swiss-Prot</a>).</li>
</ol>
<p>The translation direction is randomly sampled with probability 0.5 for each example. For downstream tasks, BioT5 uses prompt-based fine-tuning to cast all tasks into a sequence generation format, reducing the gap between pretraining and fine-tuning.</p>
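<p>The direction sampling for the translation tasks can be sketched as follows. The function name and the example pair are hypothetical, and any prompt wrapping used by the actual model is left out.</p>

```python
import random

def make_translation_example(entity: str, description: str, rng: random.Random):
    """Return (source, target), choosing the direction with probability 0.5."""
    if rng.random() < 0.5:
        return entity, description      # sequence to text
    return description, entity          # text to sequence

rng = random.Random(0)
pair = ("[C][C][O]", "Ethanol is a primary alcohol.")  # illustrative pair
examples = [make_translation_example(*pair, rng) for _ in range(1000)]
forward = sum(1 for src, _ in examples if src == pair[0])
```

<p>Over many examples roughly half of the pairs end up in each direction, so a single dataset of pairs supervises both captioning and generation.</p>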
<h2 id="evaluation-across-15-downstream-tasks">Evaluation Across 15 Downstream Tasks</h2>
<p>BioT5 is evaluated on 15 tasks organized into three categories: single-instance prediction, multi-instance prediction, and cross-modal generation.</p>
<h3 id="molecule-property-prediction-moleculenet">Molecule Property Prediction (MoleculeNet)</h3>
<p>BioT5 is evaluated on six binary classification tasks from <a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a> using scaffold splitting: BBBP, Tox21, ClinTox, HIV, BACE, and SIDER. Results are averaged over three random runs.</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>GEM</th>
          <th>MolXPT</th>
          <th>BioT5</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>BBBP</td>
          <td>72.4</td>
          <td>80.0</td>
          <td>77.7</td>
      </tr>
      <tr>
          <td>Tox21</td>
          <td>78.1</td>
          <td>77.1</td>
          <td>77.9</td>
      </tr>
      <tr>
          <td>ClinTox</td>
          <td>90.1</td>
          <td>95.3</td>
          <td>95.4</td>
      </tr>
      <tr>
          <td>HIV</td>
          <td>80.6</td>
          <td>78.1</td>
          <td><strong>81.0</strong></td>
      </tr>
      <tr>
          <td>BACE</td>
          <td>85.6</td>
          <td>88.4</td>
          <td><strong>89.4</strong></td>
      </tr>
      <tr>
          <td>SIDER</td>
          <td>67.2</td>
          <td>71.7</td>
          <td><strong>73.2</strong></td>
      </tr>
      <tr>
          <td><strong>Avg</strong></td>
          <td>79.0</td>
          <td>81.9</td>
          <td><strong>82.4</strong></td>
      </tr>
  </tbody>
</table>
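<p>BioT5 achieves the best average AUROC (82.4) across the six datasets, ahead of the GNN-based GEM (79.0) and the language model baseline MolXPT (81.9). Among the models shown, it leads on ClinTox, HIV, BACE, and SIDER, while MolXPT is stronger on BBBP and GEM on Tox21.</p>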
<h3 id="protein-property-prediction-peer-benchmark">Protein Property Prediction (PEER Benchmark)</h3>
<p>On the PEER benchmark, BioT5 is evaluated on protein solubility and subcellular localization prediction:</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Params</th>
          <th>Solubility (Acc)</th>
          <th>Localization (Acc)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ESM-1b</td>
          <td>652.4M</td>
          <td>70.23</td>
          <td><strong>92.40</strong></td>
      </tr>
      <tr>
          <td>ProtBert</td>
          <td>419.9M</td>
          <td>68.15</td>
          <td>91.32</td>
      </tr>
      <tr>
          <td>BioT5</td>
          <td>252.1M</td>
          <td><strong>74.65</strong></td>
          <td>91.69</td>
      </tr>
  </tbody>
</table>
<p>BioT5 achieves the best solubility prediction accuracy (74.65%) despite having roughly 1.7x fewer parameters than ProtBert and 2.6x fewer than ESM-1b, both dedicated protein language models.</p>
<h3 id="drug-target-interaction-prediction">Drug-Target Interaction Prediction</h3>
<p>BioT5 is evaluated on three DTI datasets (BioSNAP, Human, BindingDB) with five random runs:</p>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>BioSNAP AUROC</th>
          <th>Human AUROC</th>
          <th>BindingDB AUROC</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>DrugBAN</td>
          <td>0.903</td>
          <td>0.982</td>
          <td>0.960</td>
      </tr>
      <tr>
          <td>BioT5</td>
          <td><strong>0.937</strong></td>
          <td><strong>0.989</strong></td>
          <td><strong>0.963</strong></td>
      </tr>
  </tbody>
</table>
<p>BioT5 consistently outperforms DrugBAN and other specialized DTI models across all three datasets.</p>
<h3 id="molecule-captioning-and-text-based-molecule-generation">Molecule Captioning and Text-Based Molecule Generation</h3>
<p>On the ChEBI-20 dataset, BioT5 outperforms all baselines in molecule captioning:</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Params</th>
          <th>BLEU-4</th>
          <th>METEOR</th>
          <th>Text2Mol</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MolT5-large</td>
          <td>783M</td>
          <td>0.508</td>
          <td>0.614</td>
          <td>0.582</td>
      </tr>
      <tr>
          <td>MolXPT</td>
          <td>350M</td>
          <td>0.505</td>
          <td>0.626</td>
          <td>0.594</td>
      </tr>
      <tr>
          <td>BioT5</td>
          <td>252M</td>
          <td><strong>0.556</strong></td>
          <td><strong>0.656</strong></td>
          <td><strong>0.603</strong></td>
      </tr>
  </tbody>
</table>
<p>For text-based molecule generation, BioT5 achieves an exact match score of 0.413 (vs. 0.311 for MolT5-large) while maintaining 100% validity, compared to 90.5% for MolT5-large. This demonstrates the direct benefit of SELFIES: every generated sequence is a valid molecule.</p>
<h3 id="protein-protein-interaction-prediction">Protein-Protein Interaction Prediction</h3>
<p>On the PEER PPI benchmarks (Yeast and Human), BioT5 achieves competitive results, outperforming fully fine-tuned ProtBert and ESM-1b on the Yeast dataset (64.89% vs. 63.72% for ProtBert) and placing second on Human (86.22% vs. 88.06% for ESM-1b with frozen weights).</p>
<h2 id="key-findings-limitations-and-future-directions">Key Findings, Limitations, and Future Directions</h2>
<p>BioT5 demonstrates that integrating molecular, protein, and textual modalities within a single pretraining framework yields consistent improvements across diverse biological tasks. Three factors drive BioT5&rsquo;s performance: (1) SELFIES guarantees 100% molecular validity in generation tasks, eliminating a persistent failure mode of SMILES-based models; (2) separate tokenization preserves the semantic integrity of each modality; (3) wrapped text pretraining on PubMed provides contextual biological knowledge that pure sequence models miss.</p>
<p>The authors acknowledge several limitations. BioT5 requires full-parameter fine-tuning for each downstream task because instruction-tuning does not generalize across tasks, and combining datasets via instructions causes data leakage (the authors note overlaps between BindingDB training data and BioSNAP/Human test sets). The model only handles sequence-format bio-entities and does not incorporate 2D or 3D structural information. Additional biological modalities such as DNA/RNA sequences and cell-level data are also left for future work.</p>
<p>The authors also note risks: BioT5 could potentially be misused to generate dangerous molecules, and it may fail to generate effective therapeutic molecules or produce compounds with adverse side effects.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pretraining (molecules)</td>
          <td>ZINC20</td>
          <td>~300M molecules</td>
          <td>Converted from SMILES to SELFIES</td>
      </tr>
      <tr>
          <td>Pretraining (proteins)</td>
          <td><a href="https://en.wikipedia.org/wiki/UniProt">UniRef50</a></td>
          <td>27M proteins</td>
          <td>Filtered by length</td>
      </tr>
      <tr>
          <td>Pretraining (text)</td>
          <td>C4</td>
          <td>Large</td>
          <td>Standard T5 corpus</td>
      </tr>
      <tr>
          <td>Pretraining (wrapped text)</td>
          <td>PubMed</td>
          <td>33M articles</td>
          <td>Entity linking via BERN2</td>
      </tr>
      <tr>
          <td>Pretraining (molecule-text pairs)</td>
          <td>PubChem</td>
          <td>339K pairs</td>
          <td>Excludes ChEBI-20 molecules</td>
      </tr>
      <tr>
          <td>Pretraining (protein-text pairs)</td>
          <td>Swiss-Prot</td>
          <td>569K pairs</td>
          <td>High-quality annotations</td>
      </tr>
      <tr>
          <td>Evaluation (molecular properties)</td>
          <td>MoleculeNet</td>
          <td>6 datasets</td>
          <td>Scaffold splitting</td>
      </tr>
      <tr>
          <td>Evaluation (protein properties)</td>
          <td>PEER</td>
          <td>2 tasks</td>
          <td>Solubility and localization</td>
      </tr>
      <tr>
          <td>Evaluation (DTI)</td>
          <td>BioSNAP, Human, BindingDB</td>
          <td>3 datasets</td>
          <td>Binary classification</td>
      </tr>
      <tr>
          <td>Evaluation (PPI)</td>
          <td>Yeast, Human</td>
          <td>2 datasets</td>
          <td>From PEER benchmark</td>
      </tr>
      <tr>
          <td>Evaluation (generation)</td>
          <td>ChEBI-20</td>
          <td>33K pairs</td>
          <td>Molecule captioning and text-to-molecule</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Architecture: T5-v1.1-base (encoder-decoder transformer)</li>
<li>Optimizer: AdamW with RMS scaling</li>
<li>Learning rate: cosine annealing, base $1 \times 10^{-2}$, minimum $1 \times 10^{-5}$</li>
<li>Warmup steps: 10,000</li>
<li>Dropout: 0.0</li>
<li>Maximum input length: 512 tokens</li>
<li>Pretraining steps: 350K</li>
<li>Batch size: 96 per GPU (6 data types per batch)</li>
<li>Prompt-based fine-tuning for all downstream tasks</li>
</ul>
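<p>The learning-rate settings above can be turned into a schedule function. This is a sketch under assumptions: the linear shape of the warmup and the choice to anneal over the steps remaining after warmup are not stated in the list, and the constant names are illustrative.</p>

```python
import math

BASE_LR = 1e-2
MIN_LR = 1e-5
WARMUP_STEPS = 10_000
TOTAL_STEPS = 350_000

def learning_rate(step: int) -> float:
    """Linear warmup (assumed) followed by cosine annealing to MIN_LR."""
    if step < WARMUP_STEPS:
        return BASE_LR * step / WARMUP_STEPS
    progress = (step - WARMUP_STEPS) / (TOTAL_STEPS - WARMUP_STEPS)
    progress = min(progress, 1.0)       # hold at MIN_LR past the last step
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return MIN_LR + (BASE_LR - MIN_LR) * cosine
```

<p>Under these assumptions the rate peaks at the base value when warmup ends and reaches the minimum exactly at the final pretraining step.</p>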
<h3 id="models">Models</h3>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Parameters</th>
          <th>Vocabulary Size</th>
          <th>Architecture</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>BioT5</td>
          <td>252M</td>
          <td>35,073</td>
          <td>T5-v1.1-base</td>
      </tr>
  </tbody>
</table>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li>Molecule property prediction: AUROC on 6 MoleculeNet tasks (scaffold split, 3 runs)</li>
<li>Protein property prediction: accuracy on PEER benchmark (3 runs)</li>
<li>Drug-target interaction: AUROC, AUPRC, accuracy on 3 DTI datasets (5 runs)</li>
<li>Protein-protein interaction: accuracy on 2 PPI datasets (3 runs)</li>
<li>Molecule captioning: BLEU, ROUGE, METEOR, Text2Mol on ChEBI-20</li>
<li>Text-based molecule generation: BLEU, exact match, fingerprint similarities, FCD, validity on ChEBI-20</li>
</ul>
<h3 id="hardware">Hardware</h3>
<ul>
<li>8x NVIDIA A100 80GB GPUs for pretraining</li>
<li>Codebase: nanoT5</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/QizhiPei/BioT5">BioT5 Code</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official implementation</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Pei, Q., Zhang, W., Zhu, J., Wu, K., Gao, K., Wu, L., Xia, Y., &amp; Yan, R. (2023). BioT5: Enriching Cross-modal Integration in Biology with Chemical Knowledge and Natural Language Associations. <em>Proceedings of the 2023 Conference on Empirical Methods in Natural Language Processing</em>, 1102-1123. <a href="https://doi.org/10.18653/v1/2023.emnlp-main.70">https://doi.org/10.18653/v1/2023.emnlp-main.70</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{pei2023biot5,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{BioT5: Enriching Cross-modal Integration in Biology with Chemical Knowledge and Natural Language Associations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Pei, Qizhi and Zhang, Wei and Zhu, Jinhua and Wu, Kehan and Gao, Kaiyuan and Wu, Lijun and Xia, Yingce and Yan, Rui}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Proceedings of the 2023 Conference on Empirical Methods in Natural Language Processing}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{1102--1123}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Association for Computational Linguistics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.18653/v1/2023.emnlp-main.70}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>MTL-BERT: Multitask BERT for Property Prediction</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/property-prediction/mtl-bert-multitask-smiles-enumeration/</link><pubDate>Fri, 27 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/property-prediction/mtl-bert-multitask-smiles-enumeration/</guid><description>MTL-BERT combines BERT pretraining, multitask learning, and SMILES enumeration for molecular property prediction across 60 drug discovery datasets.</description><content:encoded><![CDATA[<h2 id="a-multitask-bert-framework-for-molecular-property-prediction">A Multitask BERT Framework for Molecular Property Prediction</h2>
<p>MTL-BERT is a <strong>Method</strong> paper that introduces a multitask learning framework built on BERT for predicting molecular properties from <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES strings</a>. The primary contribution is the combination of three strategies to address data scarcity in drug discovery: (1) masked token pretraining on 1.7 million unlabeled molecules from <a href="https://en.wikipedia.org/wiki/ChEMBL">ChEMBL</a>, (2) multitask fine-tuning across 60 property prediction datasets simultaneously, and (3) <a href="/notes/chemistry/molecular-representations/notations/randomized-smiles-generative-models/">SMILES enumeration</a> as a data augmentation technique applied during pretraining, fine-tuning, and inference. The model achieves strong performance across 60 <a href="https://en.wikipedia.org/wiki/ADME">ADMET</a> and molecular property datasets (44 classification and 16 regression), outperforming baselines including GNNs, XGBoost with molecular fingerprints, and prior <a href="/notes/chemistry/molecular-representations/encoders/smiles-bert/">SMILES-BERT</a> approaches.</p>
<h2 id="data-scarcity-in-molecular-property-prediction">Data Scarcity in Molecular Property Prediction</h2>
<p>Deep learning methods for molecular property prediction face a fundamental tension: they require large amounts of labeled data to learn effectively, but labeled bioactivity data is scarce due to the cost and time of laboratory experiments. Approaches available at the time of publication each addressed only part of this problem. Graph neural networks (GNNs) learn from molecular graphs but are typically shallow (2-3 layers) and prone to overfitting on small datasets. The original SMILES-BERT model applied masked language modeling to SMILES strings but fine-tuned separately for each task, missing opportunities to share information across related properties. Fixed molecular representations like <a href="/notes/chemistry/molecular-representations/encoders/cddd-translation-molecular-descriptors/">CDDD</a> (continuous and data-driven descriptors) cannot be further optimized for specific downstream tasks.</p>
<p>The authors identify three specific gaps: (1) single-task fine-tuning wastes the correlations between related ADMET properties (e.g., <a href="https://en.wikipedia.org/wiki/Lipophilicity">lipophilicity</a> relates to many ADMET endpoints), (2) using only canonical SMILES limits the model&rsquo;s ability to learn robust molecular features, and (3) no prior work had combined pretraining, multitask learning, and SMILES enumeration into a unified framework.</p>
<h2 id="three-strategies-combined-pretraining-multitask-learning-and-smiles-enumeration">Three Strategies Combined: Pretraining, Multitask Learning, and SMILES Enumeration</h2>
<p>The core innovation of MTL-BERT is the synergistic combination of three strategies in a single pipeline.</p>
<h3 id="masked-smiles-pretraining">Masked SMILES Pretraining</h3>
<p>Following the BERT paradigm, MTL-BERT pretrains on 1.7 million unlabeled molecules from ChEMBL using a masked token recovery task. For each SMILES string, 15% of tokens are randomly selected. Of these selected tokens, 80% are replaced with a [MASK] token, 10% are replaced with a random token, and 10% remain unchanged. The loss is computed only at masked positions. Unlike the original BERT, MTL-BERT omits the next-sentence prediction task since there is no sequential relationship between SMILES strings (following the RoBERTa finding that this task is unnecessary).</p>
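<p>The 15% selection with its 80/10/10 split can be sketched as below. One assumption to note: each token is selected independently with probability 0.15, so the selected fraction is 15% in expectation rather than exactly. The function and variable names are illustrative, not the authors' code.</p>

```python
import random

MASK = "[MASK]"

def mask_tokens(tokens, vocab, rng, select_prob=0.15):
    """Return (corrupted tokens, positions selected for the recovery loss)."""
    corrupted = list(tokens)
    selected = []
    for i, _ in enumerate(tokens):
        if rng.random() >= select_prob:
            continue                          # token not selected, left untouched
        selected.append(i)
        roll = rng.random()
        if roll < 0.8:
            corrupted[i] = MASK               # 80%: replace with the mask token
        elif roll < 0.9:
            corrupted[i] = rng.choice(vocab)  # 10%: replace with a random token
        # remaining 10%: keep the original token
    return corrupted, selected

rng = random.Random(0)
tokens = list("CC(=O)Nc1ccccc1")
corrupted, selected = mask_tokens(tokens, sorted(set(tokens)), rng)
```

<p>Keeping some selected tokens unchanged or randomized means the model cannot rely on the mask token alone to know which positions it will be asked to recover.</p>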
<p>SMILES strings are tokenized with a regular expression that captures multi-character tokens (e.g., Si, Br, Cl) and common SMILES syntax. The model uses positional encoding to capture token order.</p>
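<p>A regular-expression tokenizer of this kind might look like the sketch below. The pattern is an illustrative assumption covering bracket atoms, the two-letter symbols mentioned above, ring-closure labels, and common syntax characters; it is not the authors' exact expression.</p>

```python
import re

# Illustrative pattern (an assumption, not the authors' exact expression):
# bracket atoms, two-letter elements, ring-closure labels, single atoms, syntax.
SMILES_PATTERN = re.compile(
    r"\[[^\]]+\]|Br|Cl|Si|%\d{2}|[A-Za-z]|\d|[()=#+\-\\/.:@*~]"
)

def tokenize_smiles(smiles: str) -> list:
    tokens = SMILES_PATTERN.findall(smiles)
    # Guard against silently dropping characters the pattern does not cover.
    if "".join(tokens) != smiles:
        raise ValueError(f"untokenizable characters in {smiles!r}")
    return tokens

tokens = tokenize_smiles("CC(=O)Nc1ccc(Br)cc1")
```

<p>Listing the two-letter alternatives before the single-letter fallback is what keeps <code>Br</code> and <code>Cl</code> from being split into separate characters.</p>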
<h3 id="transformer-architecture">Transformer Architecture</h3>
<p>The model uses a standard Transformer encoder with multihead self-attention. The scaled dot-product attention computes:</p>
<p>$$\mathbf{O}_h = \text{softmax}\left(\frac{\mathbf{Q}_h \mathbf{K}_h^T}{\sqrt{d_k}}\right) \mathbf{V}_h$$</p>
<p>where $\mathbf{Q}_h$, $\mathbf{K}_h$, and $\mathbf{V}_h$ are the query, key, and value matrices for head $h$, and $d_k$ is the per-head key dimension; dividing by $\sqrt{d_k}$ keeps the dot products from growing with that dimension. The outputs from all heads are concatenated and projected. Each attention sublayer is followed by a position-wise feedforward network with GELU activation, layer normalization, and residual connections.</p>
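<p>For a single head, the attention formula can be written out in plain Python. This is a dependency-free sketch on toy inputs meant to mirror the equation term by term; a real implementation would use batched tensor operations.</p>

```python
import math

def softmax(row):
    peak = max(row)                           # subtract the max for stability
    exps = [math.exp(x - peak) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention_head(Q, K, V):
    """Scaled dot-product attention for one head; Q, K, V are lists of rows."""
    d_k = len(K[0])
    scale = math.sqrt(d_k)
    output = []
    for q in Q:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / scale for k in K]
        weights = softmax(scores)
        output.append([
            sum(w * v[j] for w, v in zip(weights, V))
            for j in range(len(V[0]))
        ])
    return output

# Toy inputs: two tokens, d_k = 2.
Q = [[1.0, 0.0], [0.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0]]
O = attention_head(Q, K, V)
```

<p>Each output row is a weighted average of the value rows, with weights that favor the keys most aligned with the query.</p>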
<p>Three model sizes were compared:</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Layers</th>
          <th>Heads</th>
          <th>Embedding Size</th>
          <th>FFN Size</th>
          <th>Recovery Accuracy</th>
          <th>Fine-tuning Performance</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MTL-BERT_SMALL</td>
          <td>4</td>
          <td>4</td>
          <td>128</td>
          <td>512</td>
          <td>0.931</td>
          <td>0.826</td>
      </tr>
      <tr>
          <td>MTL-BERT_MEDIUM</td>
          <td>8</td>
          <td>8</td>
          <td>256</td>
          <td>1,024</td>
          <td>0.962</td>
          <td>0.852</td>
      </tr>
      <tr>
          <td>MTL-BERT_LARGE</td>
          <td>12</td>
          <td>12</td>
          <td>576</td>
          <td>2,304</td>
          <td>0.974</td>
          <td>0.848</td>
      </tr>
  </tbody>
</table>
<p>The medium model was selected because it achieved the best fine-tuning performance (0.852) at lower computational cost than the large model, even though the large model reached higher pretraining recovery accuracy (0.974 vs. 0.962). The slight performance drop for the large model suggests mild overfitting.</p>
<h3 id="multitask-fine-tuning-with-task-tokens">Multitask Fine-tuning with Task Tokens</h3>
<p>During fine-tuning, task tokens ([T0], [T1], &hellip;) are prepended to each input SMILES string. The Transformer output at each task token position is passed through a task-specific two-layer feedforward network for the corresponding prediction task. An attention mask prevents direct information exchange between task tokens, allowing each task to learn directly from SMILES tokens without interference. This design also reduces the discrepancy between pretraining (no task tokens visible) and fine-tuning.</p>
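<p>One way to build such an attention mask is sketched below. The rule that task tokens cannot attend to one another comes from the description above; whether SMILES tokens may attend to task tokens is not spelled out, so it is exposed as the hypothetical flag <code>smiles_sees_tasks</code>, defaulting to the reading in which task tokens stay invisible to SMILES tokens.</p>

```python
def build_attention_mask(num_tasks: int, num_smiles_tokens: int,
                         smiles_sees_tasks: bool = False):
    """mask[i][j] is True when position i may attend to position j.

    Positions 0..num_tasks-1 are task tokens; the rest are SMILES tokens.
    """
    size = num_tasks + num_smiles_tokens
    mask = [[True] * size for _ in range(size)]
    for i in range(size):
        for j in range(num_tasks):
            i_is_task = i < num_tasks
            if i_is_task and i != j:
                mask[i][j] = False      # task tokens never see each other
            elif not i_is_task and not smiles_sees_tasks:
                mask[i][j] = False      # assumption: SMILES tokens ignore task tokens
    return mask

mask = build_attention_mask(num_tasks=2, num_smiles_tokens=3)
```

<p>With this layout every task token reads from the full SMILES sequence and from itself, while the SMILES positions attend among themselves as they did during pretraining.</p>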
<p>Cross-entropy loss is used for classification tasks and mean squared error for regression tasks. The total multitask loss is a simple sum of per-task losses without learned weighting.</p>
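<p>The unweighted sum of per-task losses can be sketched for a single molecule as follows. The data layout and function names are illustrative assumptions, and the squared error shown here would be averaged over a batch to give the mean squared error.</p>

```python
import math

def cross_entropy(probs, label):
    """Negative log-likelihood of the true class."""
    return -math.log(probs[label])

def squared_error(prediction, target):
    return (prediction - target) ** 2

def multitask_loss(task_outputs):
    """Sum per-task losses with equal (unlearned) weights.

    task_outputs: list of (kind, prediction, target) for one molecule, where
    kind is "classification" or "regression".
    """
    total = 0.0
    for kind, prediction, target in task_outputs:
        if kind == "classification":
            total += cross_entropy(prediction, target)
        else:
            total += squared_error(prediction, target)
    return total

# Toy example: one binary classification task and one regression task.
loss = multitask_loss([
    ("classification", [0.25, 0.75], 1),
    ("regression", 1.5, 1.0),
])
```

<p>Because the terms are simply added, tasks whose losses live on larger numeric scales contribute more to the gradient, which is the trade-off of skipping learned task weights.</p>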
<h3 id="smiles-enumeration-as-data-augmentation">SMILES Enumeration as Data Augmentation</h3>
<p>A molecule can be represented by multiple valid SMILES strings by varying starting atoms and traversal orders. MTL-BERT applies SMILES enumeration at all three stages:</p>
<ol>
<li><strong>Pretraining</strong>: Enumerated SMILES increase diversity of the self-supervised training data.</li>
<li><strong>Fine-tuning</strong>: Each dataset is augmented 20x with random SMILES variants, increasing data diversity and helping the model learn position-invariant features.</li>
<li><strong>Inference</strong>: Multiple SMILES are generated per test molecule, and the predictions are fused (averaged) for a more robust final prediction.</li>
</ol>
<p>The 20x augmentation factor was chosen based on prior work showing diminishing returns beyond this level while significantly increasing computational cost.</p>
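<p>The inference-time fusion step can be sketched with two stand-in callables. Both <code>enumerate_fn</code> and <code>predict_fn</code> are hypothetical placeholders for a SMILES randomizer (for example one built on RDKit's randomized SMILES output) and a fine-tuned model; the toy versions below exist only so the sketch runs on its own.</p>

```python
from statistics import mean

def predict_with_enumeration(smiles, enumerate_fn, predict_fn, n_variants):
    """Average predictions over several SMILES strings of the same molecule.

    enumerate_fn(smiles, n) and predict_fn(smiles) are hypothetical callables
    standing in for a SMILES randomizer and a fine-tuned model.
    """
    variants = enumerate_fn(smiles, n_variants)
    return mean(predict_fn(v) for v in variants)

ETHANOL_VARIANTS = ["CCO", "OCC", "C(O)C"]   # three valid SMILES for ethanol

def toy_enumerate(smiles, n):
    return ETHANOL_VARIANTS[:n]

def toy_predict(smiles):
    # Deliberately order-sensitive "model" to show what averaging smooths out.
    return 1.0 if smiles.startswith("C") else 0.0

fused = predict_with_enumeration("CCO", toy_enumerate, toy_predict, n_variants=3)
```

<p>Averaging over variants reduces the dependence of the final prediction on whichever atom ordering happened to be written down for a molecule.</p>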
<h2 id="experimental-evaluation-across-60-datasets">Experimental Evaluation Across 60 Datasets</h2>
<h3 id="setup">Setup</h3>
<p>MTL-BERT was evaluated on 60 datasets (44 classification, 16 regression) covering ADMET properties and common molecular benchmarks. Datasets were sourced from ADMETlab and <a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a>. Each dataset was split 8:1:1 (train/validation/test), and experiments were repeated 10 times with random splits, reporting mean and standard deviation.</p>
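<p>The repeated 8:1:1 random split can be sketched as follows. The use of one seed per repetition and the placeholder molecule identifiers are assumptions for illustration.</p>

```python
import random

def split_811(items, seed):
    """Shuffle and split into train/validation/test with an 8:1:1 ratio."""
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)
    n_train = len(shuffled) * 8 // 10
    n_valid = len(shuffled) // 10
    train = shuffled[:n_train]
    valid = shuffled[n_train:n_train + n_valid]
    test = shuffled[n_train + n_valid:]
    return train, valid, test

# Ten repetitions with different seeds, mirroring the repeated random splits.
molecules = [f"mol_{i}" for i in range(100)]   # placeholder identifiers
splits = [split_811(molecules, seed) for seed in range(10)]
```

<p>Reporting the mean and standard deviation over such repetitions shows how much of a score difference is attributable to the particular split rather than to the model.</p>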
<p>Classification tasks were evaluated with <a href="https://en.wikipedia.org/wiki/Receiver_operating_characteristic">ROC-AUC</a> and accuracy; regression tasks with $R^2$ and RMSE.</p>
<h3 id="baselines">Baselines</h3>
<p>Five baselines were compared:</p>
<ul>
<li><strong>ECFP4-XGBoost</strong>: Extended-connectivity fingerprints (diameter 4) with gradient boosting</li>
<li><strong>Graph Attention Network (GAT)</strong></li>
<li><strong>Graph Convolutional Network (GCN)</strong></li>
<li><strong>AttentiveFP</strong>: A GNN with attention for molecular property prediction</li>
<li><strong>CDDD</strong>: Continuous and data-driven descriptors from a pretrained RNN auto-encoder</li>
</ul>
<h3 id="ablation-study">Ablation Study</h3>
<p>Three model variants were compared to isolate contributions:</p>
<ul>
<li><strong>MTL-BERT</strong>: Full model (pretraining + multitask + SMILES enumeration)</li>
<li><strong>STL-BERT</strong>: Single-task fine-tuning with SMILES enumeration (no multitask)</li>
<li><strong>Cano-BERT</strong>: Canonical SMILES only, single-task fine-tuning (equivalent to SMILES-BERT)</li>
</ul>
<p>Cano-BERT showed more than 10% degradation on several datasets (CL, Fu, LC50DM) compared to STL-BERT, demonstrating the importance of SMILES enumeration. MTL-BERT outperformed STL-BERT on most datasets, with improvements exceeding 5% on $F_{20\%}$, SR-ARE, and SR-ATAD5, confirming that multitask learning provides additional benefit on top of enumeration.</p>
<h3 id="results-vs-baselines">Results vs. Baselines</h3>
<p>MTL-BERT outperformed all baselines on nearly all 60 datasets. Specific findings:</p>
<ul>
<li>ECFP4-XGBoost performed inconsistently, doing well on some tasks (e.g., $F_{30\%}$, BACE, CL) but poorly on others, reflecting the limitation of fixed-length fingerprint representations.</li>
<li>GNNs generally improved over fingerprints but still suffered from data scarcity, falling behind ECFP4-XGBoost by more than 3% on $F_{30\%}$, Carcinogenicity, CL, and VD.</li>
<li>MTL-BERT surpassed all baselines on every dataset except CYP2C19-sub and BACE, where it trailed the best baseline by less than 1.1%.</li>
<li>On 14 tasks (NR-ER, NR-PPAR-gamma, SR-ARE, SR-ATAD5, SR-HSE, SR-MMP, Bioconcentration Factor, Fu, LC50FM, Lipophilicity, CL, PPB, VD, LC50DM), MTL-BERT exceeded the best baseline by more than 5-10%.</li>
<li>Improvements were statistically significant at the 95% confidence level (paired t-test, $P \leq 0.001$).</li>
</ul>
<h3 id="representation-analysis">Representation Analysis</h3>
<p><a href="https://en.wikipedia.org/wiki/T-distributed_stochastic_neighbor_embedding">t-SNE</a> visualization of pretrained token embeddings (from 1,000 randomly selected molecules, approximately 35,000 tokens) showed that:</p>
<ul>
<li>Tokens of the same type cluster together (capturing atomic type information).</li>
<li>Within type clusters, sub-groups correspond to different chemical environments (e.g., oxygen atoms in nitrate groups vs. carbonyl groups).</li>
<li>Nearby embeddings share similar molecular neighborhood environments.</li>
</ul>
<h3 id="attention-based-interpretability">Attention-based Interpretability</h3>
<p>The model&rsquo;s attention weights provide interpretability for predictions:</p>
<ul>
<li>For a solubility task (LogS/LogD), attention concentrated on polar groups, which are known determinants of aqueous solubility.</li>
<li>For <a href="https://en.wikipedia.org/wiki/Ames_test">AMES</a> (mutagenicity), attention focused on <a href="https://en.wikipedia.org/wiki/Azide">azide</a>, nitrosamide, <a href="https://en.wikipedia.org/wiki/Acyl_chloride">acyl chloride</a>, and nitrite groups, which are known mutagenic structural alerts.</li>
</ul>
<h2 id="performance-gains-from-combined-strategies-with-interpretable-attention">Performance Gains from Combined Strategies with Interpretable Attention</h2>
<p>MTL-BERT demonstrates that the combination of pretraining, multitask learning, and SMILES enumeration is more effective than any individual strategy for molecular property prediction. The ablation study provides clear evidence for the additive benefit of each component.</p>
<p>Key strengths include the breadth of evaluation (60 datasets covering diverse ADMET endpoints), the consistent improvement over multiple baseline types (fingerprints, GNNs, pretrained representations), and the interpretable attention mechanism that highlights chemically meaningful substructures.</p>
<p>Limitations to note: the simple sum of multitask losses (no learned task weighting) may not be optimal when tasks have very different scales or when some tasks are unrelated. The authors observe slight degradation on a few datasets (AMES, CYP1A2-Sub, FreeSolv), suggesting negative transfer in those cases. The 20x SMILES enumeration significantly increases computational cost during fine-tuning and inference. The paper does not report wall-clock training times or GPU hours, making it difficult to assess the practical cost of the enumeration strategy. Hardware details are not specified beyond acknowledgment of the High-Performance Computing Center at Central South University.</p>
<p>The hierarchical clustering of task representations reveals meaningful task groupings (e.g., LogD and LogP cluster together due to their shared relationship with water solubility), supporting the premise that multitask learning captures cross-task correlations.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pretraining</td>
          <td>ChEMBL</td>
          <td>1.7M molecules</td>
          <td>Unlabeled SMILES; 10% held out for evaluation</td>
      </tr>
      <tr>
          <td>Fine-tuning/Evaluation</td>
          <td>ADMETlab + MoleculeNet</td>
          <td>60 datasets (44 classification, 16 regression)</td>
          <td>8:1:1 train/val/test split</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Pretraining</strong>: Masked token prediction (15% masking rate: 80% [MASK], 10% random, 10% unchanged). Adam optimizer, learning rate 1e-4, batch size 512, 50 epochs.</li>
<li><strong>Fine-tuning</strong>: Adam optimizer, learning rate 5e-5, batch size 64, dropout 0.1. Cross-entropy for classification, MSE for regression. Early stopping with patience 20, max 200 epochs.</li>
<li><strong>SMILES enumeration</strong>: 20x augmentation. If an enumerated SMILES is identical to a previously generated one, sampling is retried up to 100 times.</li>
<li><strong>Inference fusion</strong>: Predictions from multiple enumerated SMILES are averaged.</li>
</ul>
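<p>The 15% / 80-10-10 masking rule from the pretraining bullet can be sketched as below. This is an illustration, not the released code: the function name <code>mask_tokens</code>, the toy vocabulary, and the choice to always select at least one position are assumptions, and the input is assumed to be non-empty.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import random

MASK = "[MASK]"


def mask_tokens(tokens, vocab, mask_rate=0.15, seed=0):
    """Corrupt a token list BERT-style and return (corrupted, targets).

    targets maps each selected position to its original token, which is
    what the model is trained to recover.
    """
    rng = random.Random(seed)
    # Assumption: always select at least one position.
    n_select = max(1, round(mask_rate * len(tokens)))
    selected = rng.sample(range(len(tokens)), n_select)
    corrupted = list(tokens)
    targets = {}
    for idx in selected:
        targets[idx] = tokens[idx]
        action = rng.choices(["mask", "random", "keep"], weights=[0.8, 0.1, 0.1])[0]
        if action == "mask":
            corrupted[idx] = MASK
        elif action == "random":
            corrupted[idx] = rng.choice(vocab)
        # "keep" leaves the original token in place.
    return corrupted, targets
</code></pre></div>
<p>The loss is computed only at the positions in <code>targets</code>; all other positions pass through unchanged.</p>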
<h3 id="models">Models</h3>
<ul>
<li>MTL-BERT_MEDIUM (selected model): 8 layers, 8 attention heads, 256 embedding size, 1,024 FFN size</li>
<li>Pretraining recovery accuracy: 0.962</li>
<li>1,000 task tokens pre-allocated for future tasks</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Task Type</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ROC-AUC</td>
          <td>Classification</td>
          <td>Primary metric</td>
      </tr>
      <tr>
          <td>Accuracy</td>
          <td>Classification</td>
          <td>Secondary metric</td>
      </tr>
      <tr>
          <td>$R^2$</td>
          <td>Regression</td>
          <td>Primary metric</td>
      </tr>
      <tr>
          <td>RMSE</td>
          <td>Regression</td>
          <td>Secondary metric</td>
      </tr>
  </tbody>
</table>
<p>All experiments were repeated 10 times with random splits; mean and standard deviation are reported.</p>
<h3 id="hardware">Hardware</h3>
<p>Hardware specifications are not reported in the paper. The authors acknowledge the High-Performance Computing Center of Central South University.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/zhang-xuan1314/MTL-BERT">MTL-BERT</a></td>
          <td>Code</td>
          <td>Not specified</td>
          <td>Official implementation</td>
      </tr>
      <tr>
          <td><a href="https://www.ebi.ac.uk/chembl/">ChEMBL</a></td>
          <td>Dataset</td>
          <td>CC BY-SA 3.0</td>
          <td>Pretraining data source</td>
      </tr>
      <tr>
          <td><a href="https://moleculenet.org/">MoleculeNet</a></td>
          <td>Dataset</td>
          <td>MIT</td>
          <td>Fine-tuning benchmark</td>
      </tr>
      <tr>
          <td><a href="https://admetmesh.scbdd.com/">ADMETlab</a></td>
          <td>Dataset</td>
          <td>Free for academic use</td>
          <td>ADMET property datasets</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Zhang, X.-C., Wu, C.-K., Yi, J.-C., Zeng, X.-X., Yang, C.-Q., Lu, A.-P., Hou, T.-J., &amp; Cao, D.-S. (2022). Pushing the boundaries of molecular property prediction for drug discovery with multitask learning BERT enhanced by SMILES enumeration. <em>Research</em>, 2022, Article 0004. <a href="https://doi.org/10.34133/research.0004">https://doi.org/10.34133/research.0004</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{zhang2022mtlbert,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Pushing the Boundaries of Molecular Property Prediction for Drug Discovery with Multitask Learning BERT Enhanced by SMILES Enumeration}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Zhang, Xiao-Chen and Wu, Cheng-Kun and Yi, Jia-Cai and Zeng, Xiang-Xiang and Yang, Can-Qun and Lu, Ai-Ping and Hou, Ting-Jun and Cao, Dong-Sheng}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Research}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{2022}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{Article 0004}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2022}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.34133/research.0004}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{American Association for the Advancement of Science (AAAS)}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Mol2vec: Unsupervised ML with Chemical Intuition</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/mol2vec-unsupervised-chemical-intuition/</link><pubDate>Fri, 27 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/mol2vec-unsupervised-chemical-intuition/</guid><description>Mol2vec applies Word2vec to Morgan substructures, learning dense vector representations of molecules that capture chemical similarity for property prediction.</description><content:encoded><![CDATA[<h2 id="word2vec-meets-cheminformatics">Word2vec Meets Cheminformatics</h2>
<p>Mol2vec is a <strong>Method</strong> paper that introduces an unsupervised approach for learning dense vector representations of molecular substructures. The core idea is a direct analogy to <a href="/notes/machine-learning/model-architectures/distributed-representations/">Word2vec</a> from natural language processing: molecular substructures (derived from the Morgan algorithm) are treated as &ldquo;words,&rdquo; and entire molecules are treated as &ldquo;sentences.&rdquo; By training on a large unlabeled corpus of 19.9 million compounds, Mol2vec produces embeddings where chemically related substructures occupy nearby regions of vector space. Compound-level vectors are then obtained by summing constituent substructure vectors, and these can serve as features for downstream supervised learning tasks.</p>
<h2 id="sparse-fingerprints-and-their-limitations">Sparse Fingerprints and Their Limitations</h2>
<p>Molecular fingerprints, particularly Morgan fingerprints (extended-connectivity fingerprints, ECFP), are among the most widely used molecular representations in cheminformatics. They perform well for similarity searching, virtual screening, and activity prediction. However, they suffer from several practical drawbacks:</p>
<ul>
<li><strong>High dimensionality and sparsity</strong>: Morgan fingerprints are typically hashed to fixed-length binary vectors (e.g., 2048 or 4096 bits), resulting in very sparse representations.</li>
<li><strong>Bit collisions</strong>: The hashing step can map distinct substructures to the same bit position, losing structural information.</li>
<li><strong>No learned relationships</strong>: Each bit is independent, so the representation does not encode any notion of chemical similarity between substructures.</li>
</ul>
<p>At the time of this work (2017), NLP techniques had started to appear in cheminformatics. The <a href="https://en.wikipedia.org/wiki/Tf%E2%80%93idf">tf-idf</a> method had been applied to Morgan fingerprints for compound-protein interaction prediction, and <a href="https://en.wikipedia.org/wiki/Latent_Dirichlet_allocation">Latent Dirichlet Allocation</a> had been used for chemical topic modeling. The Word2vec concept had been adapted for protein sequences (ProtVec) but had not yet been applied to small molecules. Mol2vec fills this gap.</p>
<h2 id="from-substructure-identifiers-to-dense-embeddings">From Substructure Identifiers to Dense Embeddings</h2>
<p>The central insight of Mol2vec is that the Morgan algorithm already produces a natural &ldquo;vocabulary&rdquo; of molecular substructures, and the order in which these substructures appear in a molecule provides local context, analogous to word order in a sentence.</p>
<h3 id="corpus-construction">Corpus Construction</h3>
<p>The training corpus was assembled from <a href="https://en.wikipedia.org/wiki/ZINC_database">ZINC</a> v15 and <a href="https://en.wikipedia.org/wiki/ChEMBL">ChEMBL</a> v23, merged and deduplicated, then filtered by molecular weight (12-600), heavy atom count (3-50), clogP (-5 to 7), and allowed elements (H, B, C, N, O, F, P, S, Cl, Br). This yielded 19.9 million compounds.</p>
<h3 id="sentence-generation">Sentence Generation</h3>
<p>For each molecule, the Morgan algorithm generates atom identifiers at radius 0 and radius 1. Each atom contributes two identifiers (one per radius), ordered according to the atom order in the canonical <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a>. This sequence of identifiers forms a &ldquo;sentence&rdquo; for Word2vec training.</p>
<h3 id="word2vec-training">Word2vec Training</h3>
<p>The model was trained using the gensim implementation of Word2vec. After evaluating both CBOW and Skip-gram architectures with window sizes of 5, 10, and 20, and embedding dimensions of 100 and 300, the best configuration was:</p>
<ul>
<li><strong>Architecture</strong>: Skip-gram</li>
<li><strong>Window size</strong>: 10</li>
<li><strong>Embedding dimension</strong>: 300</li>
</ul>
<p>Rare identifiers appearing fewer than 3 times in the corpus were replaced with a special &ldquo;UNSEEN&rdquo; token, which learns a near-zero vector. This allows the model to handle novel substructures at inference time.</p>
<h3 id="compound-vector-generation">Compound Vector Generation</h3>
<p>The final vector for a molecule is the sum of all its substructure vectors:</p>
<p>$$\mathbf{v}_{\text{mol}} = \sum_{i=1}^{N} \mathbf{v}_{s_i}$$</p>
<p>where $\mathbf{v}_{s_i}$ is the 300-dimensional embedding for the $i$-th substructure identifier in the molecule. This summation implicitly captures substructure counts and importance through vector amplitude.</p>
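<p>A minimal sketch of this summation, assuming the embeddings are held in a plain dictionary keyed by identifier and that identifiers missing from it fall back to the &ldquo;UNSEEN&rdquo; vector. The function name <code>compound_vector</code> and the dictionary layout are illustrative and do not reflect the mol2vec package API:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">def compound_vector(identifiers, embeddings, unseen="UNSEEN"):
    """Sum the substructure vectors of one molecule.

    identifiers: substructure identifiers in sentence order.
    embeddings:  dict mapping identifier to a list of floats; it must
                 contain an entry for the unseen token.
    """
    dim = len(embeddings[unseen])
    total = [0.0] * dim
    for ident in identifiers:
        # Unknown identifiers use the UNSEEN vector (assumed fallback).
        vec = embeddings.get(ident, embeddings[unseen])
        total = [a + b for a, b in zip(total, vec)]
    return total
</code></pre></div>
<p>An identifier that occurs twice contributes its vector twice, which is how the sum encodes substructure counts.</p>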
<h2 id="benchmarking-across-regression-and-classification-tasks">Benchmarking Across Regression and Classification Tasks</h2>
<h3 id="datasets">Datasets</h3>
<p>The authors evaluated Mol2vec on four datasets:</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>Task</th>
          <th>Size</th>
          <th>Description</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ESOL</td>
          <td>Regression</td>
          <td>1,144</td>
          <td>Aqueous solubility prediction</td>
      </tr>
      <tr>
          <td>Ames</td>
          <td>Classification</td>
          <td>6,471</td>
          <td><a href="https://en.wikipedia.org/wiki/Mutagen">Mutagenicity</a> (balanced: 3,481 positive, 2,990 negative)</td>
      </tr>
      <tr>
          <td>Tox21</td>
          <td>Classification</td>
          <td>8,192</td>
          <td>12 human toxicity targets (imbalanced)</td>
      </tr>
      <tr>
          <td>Kinase</td>
          <td>Classification</td>
          <td>284 kinases</td>
          <td>Bioactivity from ChEMBL v23</td>
      </tr>
  </tbody>
</table>
<h3 id="machine-learning-methods">Machine Learning Methods</h3>
<p>Three ML methods were compared using both Mol2vec and Morgan FP features:</p>
<ul>
<li><strong>Random Forest (RF)</strong>: scikit-learn, 500 estimators</li>
<li><strong>Gradient Boosting Machine (GBM)</strong>: XGBoost, 2000 estimators, max depth 3, learning rate 0.1</li>
<li><strong>Deep Neural Network (DNN)</strong>: Keras/TensorFlow, 4 hidden layers with 2000 neurons each for Mol2vec; 1 hidden layer with 512 neurons for Morgan FP</li>
</ul>
<p>All models were validated using 20 repetitions of 5-fold cross-validation, with the <a href="https://en.wikipedia.org/wiki/Wilcoxon_signed-rank_test">Wilcoxon signed-rank test</a> used for statistical comparison.</p>
<h3 id="esol-regression-results">ESOL Regression Results</h3>
<table>
  <thead>
      <tr>
          <th>Features</th>
          <th>Method</th>
          <th>$R^2_{\text{ext}}$</th>
          <th>MSE</th>
          <th>MAE</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Descriptors</td>
          <td>MLR</td>
          <td>0.81 +/- 0.01</td>
          <td>0.82</td>
          <td>0.69</td>
      </tr>
      <tr>
          <td>Molecular Graph</td>
          <td>CNN</td>
          <td>0.93</td>
          <td>0.31 +/- 0.03</td>
          <td>0.40 +/- 0.00</td>
      </tr>
      <tr>
          <td>Morgan FP</td>
          <td>GBM</td>
          <td>0.66 +/- 0.00</td>
          <td>1.43 +/- 0.00</td>
          <td>0.88 +/- 0.00</td>
      </tr>
      <tr>
          <td>Mol2vec</td>
          <td>GBM</td>
          <td>0.86 +/- 0.00</td>
          <td>0.62 +/- 0.00</td>
          <td>0.60 +/- 0.00</td>
      </tr>
  </tbody>
</table>
<p>Mol2vec substantially outperformed Morgan FP ($R^2_{\text{ext}}$ 0.86 vs. 0.66) but did not match the best graph convolution methods ($R^2_{\text{ext}}$ ~0.93).</p>
<h3 id="classification-results-ames-and-tox21">Classification Results (Ames and Tox21)</h3>
<p>On the Ames dataset, Mol2vec and Morgan FP performed comparably (AUC 0.87 vs. 0.88), both matching or exceeding prior SVM and Naive Bayes results. On Tox21, both achieved an average AUC of 0.83, outperforming literature results from graph convolution (0.71) and DNN/SVM approaches (0.71-0.72).</p>
<h3 id="proteochemometric-pcm-extension">Proteochemometric (PCM) Extension</h3>
<p>Mol2vec was combined with ProtVec (protein sequence embeddings using the same Word2vec approach on 3-grams) by concatenating vectors, forming PCM2vec. This was evaluated using a rigorous 4-level cross-validation scheme:</p>
<ul>
<li><strong>CV1</strong>: New compound-target pairs</li>
<li><strong>CV2</strong>: New targets</li>
<li><strong>CV3</strong>: New compounds</li>
<li><strong>CV4</strong>: New compounds and targets</li>
</ul>
<p>On Tox21, PCM2vec improved predictions for new compound-target pairs (CV1: AUC 0.87 vs. 0.79 for Morgan FP) and new compounds (CV3: AUC 0.85 vs. 0.78). On the kinase dataset, PCM2vec approached the performance of classical PCM (Morgan + z-scales) while being alignment-independent, meaning it can be applied to proteins with low sequence similarity.</p>
<h2 id="chemical-intuition-and-practical-value">Chemical Intuition and Practical Value</h2>
<h3 id="embedding-quality">Embedding Quality</h3>
<p>The learned substructure embeddings capture meaningful chemical relationships. Hierarchical clustering of the 25 most common substructures shows expected groupings: aromatic carbons cluster together, aliphatic ring carbons form a separate group, and carbonyl carbons and oxygens are closely related. Similarly, t-SNE projections of amino acid vectors encoded by Mol2vec reproduce known amino acid relationships (e.g., similar distances between Glu/Gln and Asp/Asn pairs, reflecting the carboxylic acid to amide transition).</p>
<h3 id="key-findings">Key Findings</h3>
<ol>
<li><strong>Skip-gram with 300-dimensional embeddings</strong> provides the best Mol2vec representations, consistent with NLP best practices.</li>
<li><strong>Mol2vec excels at regression tasks</strong>, substantially outperforming Morgan FP on ESOL solubility prediction ($R^2_{\text{ext}}$ 0.86 vs. 0.66).</li>
<li><strong>Classification performance is competitive</strong> with Morgan FP across Ames and Tox21 datasets.</li>
<li><strong>PCM2vec enables alignment-independent proteochemometrics</strong>, extending PCM approaches to diverse protein families with low sequence similarity.</li>
<li><strong>Tree-based methods (RF, GBM) outperformed DNNs</strong> on these tasks, though the authors note further DNN tuning could help.</li>
</ol>
<h3 id="limitations">Limitations</h3>
<ul>
<li>The compound vector is a simple sum of substructure vectors, which discards information about substructure arrangement and molecular topology.</li>
<li>Only Morgan identifiers at radii 0 and 1 were used. Larger radii might capture more context but would increase vocabulary size.</li>
<li>DNN architectures were not extensively optimized, leaving open the question of how well Mol2vec pairs with deep learning.</li>
<li>The approach was benchmarked against Morgan FP but not against other learned representations such as graph neural networks in a controlled comparison.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pre-training</td>
          <td>ZINC v15 + ChEMBL v23</td>
          <td>19.9M compounds</td>
          <td>Filtered by MW, atom count, clogP, element types</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>ESOL</td>
          <td>1,144 compounds</td>
          <td>Aqueous solubility regression</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>Ames</td>
          <td>6,471 compounds</td>
          <td>Mutagenicity classification</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>Tox21</td>
          <td>8,192 compounds</td>
          <td>12 toxicity targets, retrieved via DeepChem</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>Kinase (ChEMBL v23)</td>
          <td>284 kinases</td>
          <td>IC50/Kd/Ki binding assays</td>
      </tr>
      <tr>
          <td>Protein corpus</td>
          <td><a href="https://en.wikipedia.org/wiki/UniProt">UniProt</a></td>
          <td>554,241 sequences</td>
          <td>For ProtVec training</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Word2vec</strong>: Skip-gram, window size 10, 300-dimensional embeddings, min count 3</li>
<li><strong>Morgan algorithm</strong>: Radii 0 and 1 (119 and 19,831 unique identifiers respectively)</li>
<li><strong>UNSEEN token</strong>: Replaces identifiers occurring fewer than 3 times</li>
<li><strong>Compound vector</strong>: Sum of all substructure vectors</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li><strong>RF</strong>: scikit-learn, 500 estimators, sqrt features, balanced class weights</li>
<li><strong>GBM</strong>: XGBoost, 2000 estimators, max depth 3, learning rate 0.1</li>
<li><strong>DNN</strong>: Keras/TensorFlow, 4 layers x 2000 neurons (Mol2vec) or 1 layer x 512 neurons (Morgan FP), ReLU activation, dropout 0.1</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Mol2vec Best</th>
          <th>Morgan FP Best</th>
          <th>Task</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>$R^2_{\text{ext}}$</td>
          <td>0.86 (GBM)</td>
          <td>0.66 (GBM)</td>
          <td>ESOL regression</td>
      </tr>
      <tr>
          <td>AUC</td>
          <td>0.87 (RF)</td>
          <td>0.88 (RF)</td>
          <td>Ames classification</td>
      </tr>
      <tr>
          <td>AUC</td>
          <td>0.83 (RF)</td>
          <td>0.83 (RF)</td>
          <td>Tox21 classification</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Not specified in the paper.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/samoturk/mol2vec">mol2vec</a></td>
          <td>Code</td>
          <td>BSD-3-Clause</td>
          <td>Python package with pre-trained model</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Jaeger, S., Fulle, S., &amp; Turk, S. (2018). Mol2vec: Unsupervised Machine Learning Approach with Chemical Intuition. <em>Journal of Chemical Information and Modeling</em>, 58(1), 27-35. <a href="https://doi.org/10.1021/acs.jcim.7b00616">https://doi.org/10.1021/acs.jcim.7b00616</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{jaeger2018mol2vec,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Mol2vec: Unsupervised Machine Learning Approach with Chemical Intuition}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Jaeger, Sabrina and Fulle, Simone and Turk, Samo}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Journal of Chemical Information and Modeling}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{58}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{27--35}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2018}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{American Chemical Society}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1021/acs.jcim.7b00616}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>MG-BERT: Graph BERT for Molecular Property Prediction</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/multimodal/mg-bert-molecular-graph-bert/</link><pubDate>Fri, 27 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/multimodal/mg-bert-molecular-graph-bert/</guid><description>MG-BERT integrates graph neural network message passing into BERT with masked atom pretraining on 1.7M molecules for molecular property prediction.</description><content:encoded><![CDATA[<h2 id="a-graph-aware-bert-for-molecular-property-prediction">A Graph-Aware BERT for Molecular Property Prediction</h2>
<p>MG-BERT is a <strong>Method</strong> paper that adapts the BERT pretraining paradigm from NLP to molecular graphs. The primary contribution is a modified Transformer architecture that replaces global self-attention with bond-based local attention, allowing atoms to exchange information only through chemical bonds. This creates a deep message-passing network that avoids the oversmoothing problem of conventional graph neural networks (GNNs). Combined with a masked atom prediction pretraining strategy on 1.7 million unlabeled molecules from ChEMBL, MG-BERT learns context-sensitive atomic representations that transfer effectively to downstream property prediction tasks.</p>
<h2 id="data-scarcity-in-molecular-property-prediction">Data Scarcity in Molecular Property Prediction</h2>
<p><a href="/notes/chemistry/molecular-design/property-prediction/">Molecular property prediction</a> is central to drug discovery, particularly for ADMET (Absorption, Distribution, Metabolism, Excretion, and Toxicity) endpoints. While deep learning has advanced many domains, molecular property prediction faces a persistent challenge: labeled data scarcity. ADMET measurements require expensive, time-consuming experiments, and typical datasets contain only hundreds to thousands of examples.</p>
<p>Prior approaches fall into three categories, each with limitations:</p>
<ol>
<li><strong>Feature engineering</strong> (molecular fingerprints, descriptors): Requires expert design, suffers from low scalability, and fixed representations cannot be optimized for specific tasks.</li>
<li><strong>SMILES-based deep learning</strong> (CNNs, LSTMs, Transformers on SMILES strings): Must learn to parse molecular information from complex string syntax, increasing learning difficulty. Autoencoder-based methods (e.g., <a href="/notes/chemistry/molecular-representations/encoders/cddd-translation-molecular-descriptors/">CDDD</a>) learn fixed representations that cannot be fine-tuned.</li>
<li><strong>Graph neural networks</strong> (GAT, GCN): Can learn directly from molecular topology, but are limited to 2-3 layers due to oversmoothing, restricting their capacity to capture deep-level patterns.</li>
</ol>
<p>The BERT model from NLP demonstrated that self-supervised pretraining on large unlabeled corpora followed by fine-tuning on small labeled datasets can substantially improve downstream performance. <a href="/notes/chemistry/molecular-representations/encoders/smiles-bert/">SMILES-BERT</a> applied this idea to SMILES strings directly, but suffered from interpretability issues due to auxiliary characters in the SMILES syntax. MG-BERT addresses these limitations by operating directly on molecular graphs.</p>
<h2 id="bond-based-local-attention-and-masked-atom-pretraining">Bond-Based Local Attention and Masked Atom Pretraining</h2>
<p>The core innovation of MG-BERT has two components: a modified Transformer architecture for molecular graphs and a self-supervised pretraining strategy.</p>
<h3 id="architecture-modifications">Architecture Modifications</h3>
<p>The original BERT model uses three components: an embedding layer, Transformer encoder layers, and a task-specific output layer. MG-BERT makes four key modifications:</p>
<ol>
<li>
<p><strong>Atom embeddings replace word embeddings.</strong> The dictionary contains 16 tokens: 13 common atom types ([H], [C], [N], [O], [F], [S], [Cl], [P], [Br], [B], [I], [Si], [Se]), plus [UNK] for rare atoms, [MASK] for pretraining, and [GLOBAL] for graph-level readout.</p>
</li>
<li>
<p><strong>No positional encoding.</strong> Unlike sequential text, atoms in a molecular graph have no inherent ordering, so positional embeddings are removed.</p>
</li>
<li>
<p><strong>Local attention replaces global attention.</strong> The adjacency matrix of the molecular graph is used as a visibility matrix to modulate the attention scores. Each atom can only attend to atoms connected by chemical bonds. Formally, the attention is constrained so that:</p>
</li>
</ol>
<p>$$A^{\prime}_{ij} = \begin{cases} A_{ij} &amp; \text{if bond exists between } i \text{ and } j \\ -\infty &amp; \text{otherwise} \end{cases}$$</p>
<p>where $A_{ij}$ is the standard scaled dot-product attention score before the softmax, so non-bonded pairs receive zero attention weight. This local message passing makes MG-BERT a GNN variant, but one that can stack many layers (6 in the medium configuration) without oversmoothing, thanks to the residual connections inherited from the Transformer architecture.</p>
<ol start="4">
<li><strong>Supernode for graph-level readout.</strong> A [GLOBAL] supernode is added to each molecular graph, connected to all atoms. This node aggregates information from the entire molecule and serves as the molecular representation for downstream prediction.</li>
</ol>
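<p>The bond-based masking and the supernode can be illustrated with a small sketch in plain Python. It is not the authors&rsquo; implementation: the function names, the placement of [GLOBAL] at index 0, and the choice to let every node attend to itself are assumptions made for illustration.</p>

```python
import math

NEG_INF = float("-inf")


def build_visibility(num_atoms, bonds):
    """Boolean visibility matrix; index 0 is the [GLOBAL] supernode."""
    size = num_atoms + 1
    visible = [[False] * size for _ in range(size)]
    for i in range(size):
        visible[i][i] = True  # assumption: self-attention is allowed
    for a, b in bonds:  # bonds use 0-based atom indices
        visible[a + 1][b + 1] = True
        visible[b + 1][a + 1] = True
    for j in range(1, size):  # the supernode is connected to every atom
        visible[0][j] = True
        visible[j][0] = True
    return visible


def masked_attention_weights(scores, visible):
    """Set hidden pairs to -inf, then apply a row-wise softmax."""
    weights = []
    for score_row, visible_row in zip(scores, visible):
        masked = [s if v else NEG_INF for s, v in zip(score_row, visible_row)]
        peak = max(masked)  # finite, because the diagonal is always visible
        exps = [math.exp(s - peak) for s in masked]
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights


# Toy example: three atoms in a chain (0-1-2) with uniform raw scores.
visible = build_visibility(3, [(0, 1), (1, 2)])
scores = [[0.0] * 4 for _ in range(4)]
weights = masked_attention_weights(scores, visible)
print(weights[1])  # first atom: zero weight on the atom it is not bonded to
```

<p>In this toy chain the two end atoms never attend to each other directly, yet both exchange information with the supernode, which is why the supernode can act as a graph-level summary.</p>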
<h3 id="masked-atom-prediction">Masked Atom Prediction</h3>
<p>The pretraining strategy mirrors BERT&rsquo;s masked language model but operates on atoms:</p>
<ul>
<li>15% of atoms in each molecule are randomly selected (at least one atom per molecule)</li>
<li>Of selected atoms: 80% are replaced with [MASK], 10% are randomly replaced with another atom type, and 10% remain unchanged</li>
<li>The model is trained to predict the original atom type at the selected positions</li>
<li>Loss is computed only at those selected positions</li>
</ul>
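<p>The selection and corruption rule above can be sketched as follows. This is a minimal illustration, not the authors&rsquo; code: the function name <code>corrupt_atoms</code>, the use of <code>int()</code> truncation when computing 15%, and the bare atom symbols (without brackets) are assumptions.</p>

```python
import random

ATOM_TYPES = ["H", "C", "N", "O", "F", "S", "Cl", "P", "Br", "B", "I", "Si", "Se"]


def corrupt_atoms(atoms, rng, select_rate=0.15):
    """Return (corrupted atom list, sorted positions that carry the loss)."""
    n_select = max(1, int(len(atoms) * select_rate))  # at least one atom
    positions = sorted(rng.sample(range(len(atoms)), n_select))
    corrupted = list(atoms)
    for pos in positions:
        draw = rng.random()
        if draw >= 0.9:
            continue  # 10%: keep the original atom
        if draw >= 0.8:
            # 10%: swap in a different atom type
            others = [t for t in ATOM_TYPES if t != atoms[pos]]
            corrupted[pos] = rng.choice(others)
        else:
            corrupted[pos] = "[MASK]"  # 80%: replace with the mask token
    return corrupted, positions


# Ethanol with explicit hydrogens: 2 C, 1 O, 6 H.
ethanol = ["C", "C", "O"] + ["H"] * 6
corrupted, targets = corrupt_atoms(ethanol, random.Random(0))
print(corrupted, targets)
```

<p>With nine atoms, 15% rounds down to one selected atom, so the &ldquo;at least one atom per molecule&rdquo; floor is what keeps small molecules in the pretraining signal.</p>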
<h3 id="model-configurations">Model Configurations</h3>
<p>Three model sizes were compared:</p>
<table>
  <thead>
      <tr>
          <th>Configuration</th>
          <th>Layers</th>
          <th>Heads</th>
          <th>Embedding Size</th>
          <th>FFN Size</th>
          <th>Recovery Accuracy</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MG-BERT Small</td>
          <td>3</td>
          <td>2</td>
          <td>128</td>
          <td>256</td>
          <td>95.27%</td>
      </tr>
      <tr>
          <td>MG-BERT Medium</td>
          <td>6</td>
          <td>4</td>
          <td>256</td>
          <td>512</td>
          <td>98.31%</td>
      </tr>
      <tr>
          <td>MG-BERT Large</td>
          <td>12</td>
          <td>8</td>
          <td>576</td>
          <td>1152</td>
          <td>98.35%</td>
      </tr>
  </tbody>
</table>
<p>The medium configuration was selected for all experiments because it achieved the best downstream performance, despite the large model having slightly higher pretraining recovery accuracy. The authors attribute this to overfitting risk with the larger model.</p>
<h2 id="experimental-setup-and-baselines">Experimental Setup and Baselines</h2>
<h3 id="pretraining">Pretraining</h3>
<p>MG-BERT was pretrained on 1.7 million compounds randomly selected from ChEMBL, with 10% held out for evaluation (1.53M training molecules). Molecules were converted to 2D undirected graphs using RDKit, with hydrogen atoms explicitly included. The model was pretrained for 10 epochs using Adam with learning rate 1e-4 and batch size 256.</p>
<h3 id="fine-tuning-datasets">Fine-tuning Datasets</h3>
<p>Sixteen datasets covering ADMET endpoints and common molecular properties were collected from ADMETlab and <a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a>:</p>
<table>
  <thead>
      <tr>
          <th>Type</th>
          <th>Dataset</th>
          <th>Category</th>
          <th>Size</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Regression</td>
          <td>Caco2</td>
          <td>Absorption</td>
          <td>979</td>
      </tr>
      <tr>
          <td>Regression</td>
          <td>logD</td>
          <td>Physicochemical</td>
          <td>10,354</td>
      </tr>
      <tr>
          <td>Regression</td>
          <td>logS</td>
          <td>Physicochemical</td>
          <td>5,045</td>
      </tr>
      <tr>
          <td>Regression</td>
          <td>PPB</td>
          <td>Distribution</td>
          <td>1,480</td>
      </tr>
      <tr>
          <td>Regression</td>
          <td>tox</td>
          <td>Toxicity</td>
          <td>7,295</td>
      </tr>
      <tr>
          <td>Regression</td>
          <td>ESOL</td>
          <td>Physicochemical</td>
          <td>1,128</td>
      </tr>
      <tr>
          <td>Regression</td>
          <td>FreeSolv</td>
          <td>Physicochemical</td>
          <td>642</td>
      </tr>
      <tr>
          <td>Regression</td>
          <td>Lipo</td>
          <td>Physicochemical</td>
          <td>4,200</td>
      </tr>
      <tr>
          <td>Classification</td>
          <td>Ames</td>
          <td>Toxicity</td>
          <td>6,719</td>
      </tr>
      <tr>
          <td>Classification</td>
          <td>BBB</td>
          <td>Distribution</td>
          <td>1,855</td>
      </tr>
      <tr>
          <td>Classification</td>
          <td>FDAMDD</td>
          <td>Toxicity</td>
          <td>795</td>
      </tr>
      <tr>
          <td>Classification</td>
          <td>H_HT</td>
          <td>Toxicity</td>
          <td>2,170</td>
      </tr>
      <tr>
          <td>Classification</td>
          <td>Pgp_inh</td>
          <td>Absorption</td>
          <td>2,125</td>
      </tr>
      <tr>
          <td>Classification</td>
          <td>Pgp_sub</td>
          <td>Absorption</td>
          <td>1,210</td>
      </tr>
      <tr>
          <td>Classification</td>
          <td>BACE</td>
          <td>Biophysics</td>
          <td>1,513</td>
      </tr>
      <tr>
          <td>Classification</td>
          <td>BBBP</td>
          <td>Physiology</td>
          <td>2,039</td>
      </tr>
  </tbody>
</table>
<p>Datasets were split 8:1:1 (train:validation:test) with stratified sampling by SMILES length. Each experiment was repeated 10 times with random splits, reporting mean and standard deviation. Regression was evaluated by R-squared, classification by ROC-AUC. Early stopping with a maximum of 100 epochs was used.</p>
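<p>As a reference for the regression metric, the sketch below computes R-squared and summarizes repeated runs. It assumes the coefficient-of-determination definition of R-squared and a sample standard deviation; the paper may use different conventions, and the numbers are made up for illustration.</p>

```python
import statistics


def r_squared(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean_true = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


def summarize(scores):
    """Mean and sample standard deviation over repeated random splits."""
    return statistics.mean(scores), statistics.stdev(scores)


# Made-up values, for illustration only.
score = r_squared([1.0, 2.0, 3.0, 4.0], [1.1, 1.9, 3.2, 3.8])
mean_score, std_score = summarize([0.74, 0.76, 0.75])
print(score, mean_score, std_score)
```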
<h3 id="baselines">Baselines</h3>
<p>Five baselines were compared:</p>
<ol>
<li><strong>ECFP4-XGBoost</strong>: Extended connectivity fingerprints (diameter 4) with gradient-boosted trees</li>
<li><strong>GAT</strong>: Graph Attention Network</li>
<li><strong>GCN</strong>: Graph Convolutional Network</li>
<li><strong>CDDD</strong>: Continuous and Data-Driven Descriptors (pretrained RNN encoder on SMILES with a fully connected network)</li>
<li><strong>SMILES-BERT</strong>: Original BERT applied directly to SMILES strings</li>
</ol>
<h3 id="ablation-studies">Ablation Studies</h3>
<p>Two ablation studies were conducted:</p>
<ol>
<li><strong>Pretraining effectiveness</strong>: Comparing pretrained vs. non-pretrained MG-BERT under identical hyperparameters</li>
<li><strong>Hydrogen atoms</strong>: Comparing MG-BERT with and without explicit hydrogen atoms in the molecular graph</li>
</ol>
<h2 id="consistent-improvements-across-admet-benchmarks">Consistent Improvements Across ADMET Benchmarks</h2>
<h3 id="main-results">Main Results</h3>
<p>MG-BERT consistently outperformed all baselines across all 16 datasets. Key results for a representative subset of the datasets (R2 and ROC-AUC values are scaled by 100):</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>ECFP4-XGBoost</th>
          <th>GAT</th>
          <th>GCN</th>
          <th>CDDD</th>
          <th>SMILES-BERT</th>
          <th>MG-BERT</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Caco2 (R2)</td>
          <td>61.41</td>
          <td>69.16</td>
          <td>67.15</td>
          <td>73.42</td>
          <td>72.39</td>
          <td><strong>74.68</strong></td>
      </tr>
      <tr>
          <td>logD (R2)</td>
          <td>70.84</td>
          <td>84.62</td>
          <td>86.22</td>
          <td>85.85</td>
          <td>86.31</td>
          <td><strong>87.46</strong></td>
      </tr>
      <tr>
          <td>logS (R2)</td>
          <td>73.73</td>
          <td>84.06</td>
          <td>83.47</td>
          <td>84.01</td>
          <td>85.20</td>
          <td><strong>87.66</strong></td>
      </tr>
      <tr>
          <td>PPB (R2)</td>
          <td>55.11</td>
          <td>59.96</td>
          <td>57.34</td>
          <td>54.12</td>
          <td>62.37</td>
          <td><strong>65.94</strong></td>
      </tr>
      <tr>
          <td>Ames (AUC)</td>
          <td>87.21</td>
          <td>86.38</td>
          <td>87.04</td>
          <td>86.82</td>
          <td>87.69</td>
          <td><strong>89.33</strong></td>
      </tr>
      <tr>
          <td>BBB (AUC)</td>
          <td>94.62</td>
          <td>93.03</td>
          <td>92.67</td>
          <td>94.44</td>
          <td>94.02</td>
          <td><strong>95.41</strong></td>
      </tr>
      <tr>
          <td>BBBP (AUC)</td>
          <td>89.16</td>
          <td>90.33</td>
          <td>90.74</td>
          <td>91.12</td>
          <td>91.32</td>
          <td><strong>92.08</strong></td>
      </tr>
  </tbody>
</table>
<p>The overall improvement across all datasets was 28.1% (7.02% on classification, 21.28% on regression). Improvements were statistically significant at the 95% confidence level (paired t-test, P &lt;= 0.001).</p>
<h3 id="pretraining-ablation">Pretraining Ablation</h3>
<p>Pretraining improved performance by more than 2% on all datasets. The benefit was largest for small datasets: Caco2 improved by approximately 10 percentage points (64.79 to 74.68 R2), and FDAMDD improved by about 7.5 points (80.76 to 88.23 AUC). This supports the claim that self-supervised pretraining helps address the labeled data scarcity problem.</p>
<h3 id="hydrogen-atom-ablation">Hydrogen Atom Ablation</h3>
<p>Including explicit hydrogen atoms improved pretraining recovery accuracy from 92.25% to 98.31% and consistently improved downstream performance. The authors provide an intuitive explanation: hydrogen atoms help determine bond counts for neighboring atoms, which is critical for the masked atom recovery task. They also show that removing hydrogens can make structurally distinct molecules (e.g., benzene and cyclohexane) indistinguishable at the graph level.</p>
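<p>The benzene and cyclohexane example can be checked with a toy graph model that, like MG-BERT&rsquo;s input, stores atom types and bond connectivity but no bond orders. The helper names are hypothetical, and the comparison is literal equality of labeled graphs rather than a general isomorphism test; it works here only because both rings are built with the same carbon numbering.</p>

```python
def carbon_ring(num_carbons, hydrogens_per_carbon):
    """Atoms and undirected bonds of a carbon ring; bond order is ignored."""
    atoms = ["C"] * num_carbons
    bonds = {frozenset((i, (i + 1) % num_carbons)) for i in range(num_carbons)}
    for carbon in range(num_carbons):
        for _ in range(hydrogens_per_carbon):
            atoms.append("H")
            bonds.add(frozenset((carbon, len(atoms) - 1)))
    return atoms, bonds


def strip_hydrogens(atoms, bonds):
    """Drop hydrogen atoms and every bond that touches one."""
    keep = [i for i, a in enumerate(atoms) if a != "H"]
    new_index = {old: new for new, old in enumerate(keep)}
    heavy_atoms = [atoms[i] for i in keep]
    heavy_bonds = {
        frozenset(new_index[i] for i in bond)
        for bond in bonds
        if all(i in new_index for i in bond)
    }
    return heavy_atoms, heavy_bonds


benzene = carbon_ring(6, 1)      # C6H6
cyclohexane = carbon_ring(6, 2)  # C6H12
same_without_h = strip_hydrogens(*benzene) == strip_hydrogens(*cyclohexane)
same_with_h = benzene == cyclohexane
print(same_without_h, same_with_h)  # prints: True False
```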
<h3 id="interpretability-via-attention-visualization">Interpretability via Attention Visualization</h3>
<p>The authors provide two forms of interpretability analysis:</p>
<ol>
<li>
<p><strong>t-SNE visualization of atomic representations</strong>: Pretrained atomic representations cluster by atom type and, more specifically, by local chemical environment (e.g., aromatic carbons separate from aliphatic carbons, C-N bonds from C-O bonds). This demonstrates that pretraining captures neighborhood context beyond simple atom identity.</p>
</li>
<li>
<p><strong>Attention weight visualization</strong>: On the logD task, the supernode&rsquo;s attention focuses on polar groups (which govern lipophilicity). On the Ames mutagenicity task, attention concentrates on known mutagenic structural alerts (acyl chloride, nitrosamide, azide groups). This provides chemically meaningful explanations for predictions.</p>
</li>
</ol>
<h3 id="limitations">Limitations</h3>
<p>The paper does not extensively discuss limitations, but several can be identified:</p>
<ul>
<li>The model uses only 2D molecular topology (atom types and bonds) without 3D conformational information or bond-type features</li>
<li>The atom dictionary is limited to 13 common types plus [UNK], which may lose information for molecules containing rarer elements</li>
<li>Evaluation is limited to ADMET-focused datasets; broader chemical spaces (e.g., materials, catalysts) are not tested</li>
<li>The comparison baselines do not include other graph-based pretraining methods (e.g., the contemporaneous Strategies for Pre-training Graph Neural Networks by Hu et al.)</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pretraining</td>
          <td>ChEMBL (random subset)</td>
          <td>1.7M molecules (1.53M train)</td>
          <td>10% held out for evaluation</td>
      </tr>
      <tr>
          <td>Fine-tuning</td>
          <td>ADMETlab + MoleculeNet</td>
          <td>16 datasets (642-10,354 molecules)</td>
          <td>8:1:1 splits, stratified by SMILES length</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Optimizer</strong>: Adam (pretraining: lr=1e-4, batch=256; fine-tuning: lr from {1e-5, 5e-5, 1e-4}, batch from {16, 32, 64})</li>
<li><strong>Pretraining epochs</strong>: 10</li>
<li><strong>Fine-tuning</strong>: Up to 100 epochs with early stopping</li>
<li><strong>Dropout</strong>: Optimized per task in range [0.0, 0.5]</li>
<li><strong>Masking</strong>: 15% of atoms (80% [MASK], 10% random, 10% unchanged)</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li><strong>Architecture</strong>: MG-BERT Medium (6 layers, 4 heads, embedding size 256, FFN size 512)</li>
<li><strong>Molecule processing</strong>: RDKit for graph conversion with explicit hydrogens</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Task Type</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>R-squared (R2)</td>
          <td>Regression</td>
          <td>Higher is better</td>
      </tr>
      <tr>
          <td>ROC-AUC</td>
          <td>Classification</td>
          <td>Higher is better</td>
      </tr>
      <tr>
          <td>Accuracy, RMSE</td>
          <td>Both</td>
          <td>Reported in supplementary Table S1</td>
      </tr>
  </tbody>
</table>
<p>All results are averaged over 10 random splits, with standard deviations reported.</p>
<h3 id="hardware">Hardware</h3>
<p>The paper does not specify hardware requirements (GPU type, training time, or memory usage).</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/zhang-xuan1314/Molecular-graph-BERT">Molecular-graph-BERT</a></td>
          <td>Code</td>
          <td>Not specified</td>
          <td>Jupyter Notebook implementation; last code push August 2021</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Zhang, X.-C., Wu, C.-K., Yang, Z.-J., Wu, Z.-X., Yi, J.-C., Hsieh, C.-Y., Hou, T.-J., &amp; Cao, D.-S. (2021). MG-BERT: leveraging unsupervised atomic representation learning for molecular property prediction. <em>Briefings in Bioinformatics</em>, 22(6), bbab152. <a href="https://doi.org/10.1093/bib/bbab152">https://doi.org/10.1093/bib/bbab152</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{zhang2021mgbert,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{{MG-BERT}: leveraging unsupervised atomic representation learning for molecular property prediction}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Zhang, Xiao-Chen and Wu, Cheng-Kun and Yang, Zhi-Jiang and Wu, Zhen-Xing and Yi, Jia-Cai and Hsieh, Chang-Yu and Hou, Ting-Jun and Cao, Dong-Sheng}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Briefings in Bioinformatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{22}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{6}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{bbab152}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2021}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Oxford University Press}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1093/bib/bbab152}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Maxsmi: SMILES Augmentation for Property Prediction</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/property-prediction/maxsmi-smiles-augmentation-property-prediction/</link><pubDate>Fri, 27 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/property-prediction/maxsmi-smiles-augmentation-property-prediction/</guid><description>Maxsmi systematically evaluates five SMILES augmentation strategies with CNN and RNN models across solubility, lipophilicity, and bioactivity tasks.</description><content:encoded><![CDATA[<h2 id="systematic-benchmarking-of-smiles-data-augmentation">Systematic Benchmarking of SMILES Data Augmentation</h2>
<p>This is an <strong>Empirical</strong> paper that systematically evaluates how SMILES augmentation affects deep learning molecular property prediction. The primary contribution is a comprehensive comparison of five augmentation strategies across three neural network architectures and four datasets, producing the &ldquo;Maxsmi&rdquo; models that maximize prediction performance. The study also demonstrates that test-time augmentation provides a practical confidence measure for predictions.</p>
<h2 id="the-data-scarcity-problem-in-qsar-modeling">The Data Scarcity Problem in QSAR Modeling</h2>
<p>Deep learning models require large training sets to perform well, but experimental physico-chemical and bioactivity datasets remain small, typically ranging from hundreds to a few thousand compounds. SMILES augmentation, where the non-unique <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES representation</a> of a molecule is exploited to generate multiple training examples per compound, has been shown to help in prior work by Bjerrum (2017), Kimber et al. (2018), and Li and Fourches (2020). However, no prior study had systematically compared different augmentation strategies, analyzed how much augmentation is needed, or examined the relationship between augmentation factor and prediction confidence. Most previous work chose augmentation numbers a priori without justification. Maxsmi fills this gap by providing a systematic analysis and practical guidelines.</p>
<h2 id="five-augmentation-strategies-and-test-time-ensemble-learning">Five Augmentation Strategies and Test-Time Ensemble Learning</h2>
<p>The core insight is twofold. First, the authors define five distinct strategies for generating augmented SMILES:</p>
<ol>
<li><strong>No augmentation</strong>: use only the canonical SMILES (baseline)</li>
<li><strong>Augmentation with duplication</strong>: generate $m$ random SMILES per compound, allowing duplicates; dataset grows to $N \times m$</li>
<li><strong>Augmentation without duplication</strong>: generate $m$ random SMILES and discard exact duplicates</li>
<li><strong>Augmentation with reduced duplication</strong>: keep only $f(m) = \sqrt{m}$ copies of each duplicate, a compromise between the above</li>
<li><strong>Augmentation with estimated maximum</strong>: sample random SMILES until the same string has been generated 10 times, attempting to cover most of the valid SMILES space</li>
</ol>
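<p>The strategies differ only in how repeated strings are handled, which the sketch below makes explicit. It operates on already-sampled strings, so no cheminformatics library is needed. The function names are hypothetical, and the reduced-duplication rule (cap each distinct string at $\lceil \sqrt{m} \rceil$ copies) is one reading of the description above; the Maxsmi code may implement the rule differently.</p>

```python
import math
import random
from collections import Counter


def with_duplication(samples):
    """Keep every sampled string, repeats included."""
    return list(samples)


def without_duplication(samples):
    """Keep the first occurrence of each distinct string."""
    return list(dict.fromkeys(samples))


def reduced_duplication(samples):
    """Assumed rule: keep at most ceil(sqrt(m)) copies of each string."""
    cap = math.ceil(math.sqrt(len(samples)))
    kept, seen = [], Counter()
    for s in samples:
        if seen[s] != cap:  # seen[s] never exceeds cap
            kept.append(s)
            seen[s] += 1
    return kept


def estimated_maximum(sample_one, repeat_limit=10):
    """Sample until one string has appeared repeat_limit times."""
    counts = Counter()
    while True:
        s = sample_one()
        counts[s] += 1
        if counts[s] == repeat_limit:
            return list(counts)


# Nine draws (m = 9) of ethanol written three different ways.
samples = ["CCO", "OCC", "CCO", "C(O)C", "CCO", "OCC", "CCO", "CCO", "OCC"]
print(len(with_duplication(samples)))
print(without_duplication(samples))
print(Counter(reduced_duplication(samples)))

# Stand-in sampler; a real pipeline would draw random SMILES from a toolkit.
rng = random.Random(0)
pool = ["CCO", "OCC", "C(O)C"]
unique_found = estimated_maximum(lambda: rng.choice(pool))
```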
<p>Second, the authors formalize test-time augmentation as ensemble learning. Given a trained model $M_{\Theta}$, each test compound $C$ is represented by $k$ random SMILES $S_1(C), \ldots, S_k(C)$. The per-SMILES predictions are:</p>
<p>$$
\hat{y}_i(C) = M_{\Theta}(S_i(C))
$$</p>
<p>The compound-level prediction applies an aggregation function $A$ (the mean) to these per-SMILES predictions:</p>
<p>$$
\hat{y}(C) = A\big(\hat{y}_1(C), \ldots, \hat{y}_k(C)\big)
$$</p>
<p>The standard deviation of the per-SMILES predictions serves as a confidence measure: high variance indicates the model is uncertain about a compound.</p>
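<p>A minimal sketch of this test-time ensemble follows. The stand-in <code>toy_model</code> is purely illustrative and has no chemical meaning, and the use of the population standard deviation is an assumption rather than something the paper specifies.</p>

```python
import statistics


def predict_compound(model, smiles_variants):
    """Aggregate per-SMILES predictions into (mean, standard deviation)."""
    predictions = [model(s) for s in smiles_variants]
    mean = statistics.fmean(predictions)
    spread = statistics.pstdev(predictions)  # assumption: population std
    return mean, spread


def toy_model(smiles):
    """Placeholder for a trained network; depends only on string length."""
    return -0.1 * len(smiles)


# Three random SMILES of the same compound (ethanol), so k = 3.
variants = ["CCO", "OCC", "C(O)C"]
y_hat, confidence_std = predict_compound(toy_model, variants)
print(y_hat, confidence_std)
```

<p>A model that had learned to be invariant to the SMILES ordering would return a spread near zero, which is why a large spread is read as low confidence.</p>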
<h2 id="experimental-design-three-architectures-four-datasets">Experimental Design: Three Architectures, Four Datasets</h2>
<h3 id="datasets">Datasets</h3>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>Size (after preprocessing)</th>
          <th>Train / Test</th>
          <th>Task</th>
          <th>Provenance</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ESOL</td>
          <td>1,128</td>
          <td>902 / 226</td>
          <td>Water solubility</td>
          <td><a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a></td>
      </tr>
      <tr>
          <td>ESOL_small</td>
          <td>1,068</td>
          <td>854 / 214</td>
          <td>Solubility (max 25 heavy atoms)</td>
          <td>MoleculeNet</td>
      </tr>
      <tr>
          <td>FreeSolv</td>
          <td>642</td>
          <td>513 / 129</td>
          <td>Hydration free energy</td>
          <td>MoleculeNet</td>
      </tr>
      <tr>
          <td><a href="https://en.wikipedia.org/wiki/Lipophilicity">Lipophilicity</a></td>
          <td>4,199</td>
          <td>3,359 / 840</td>
          <td>Octanol/water distribution</td>
          <td><a href="https://en.wikipedia.org/wiki/ChEMBL">ChEMBL</a></td>
      </tr>
      <tr>
          <td>Affinity (EGFR)</td>
          <td>5,849</td>
          <td>4,679 / 1,170</td>
          <td><a href="https://en.wikipedia.org/wiki/IC50">pIC50</a> against <a href="https://en.wikipedia.org/wiki/Epidermal_growth_factor_receptor">EGFR</a> kinase</td>
          <td>Kinodata</td>
      </tr>
  </tbody>
</table>
<h3 id="architectures">Architectures</h3>
<p>Three shallow neural networks are compared:</p>
<ul>
<li><strong>CONV1D</strong>: 1D convolution (kernel size 10, stride 1) followed by two fully connected layers</li>
<li><strong>CONV2D</strong>: 2D convolution on the one-hot encoded SMILES matrix, followed by two fully connected layers</li>
<li><strong>RNN</strong>: LSTM layer followed by two fully connected layers (128 and 64 units)</li>
</ul>
<p>All models are trained for 250 epochs with batch size 16, MSE loss, SGD optimizer, and learning rate 0.001. A Random Forest baseline with Morgan fingerprints (radius 2, length 1024) is also included.</p>
<h3 id="augmentation-sweep">Augmentation sweep</h3>
<p>The augmentation number $m$ is varied from 1 to 20 (step 1) and from 20 to 100 (step 10) for three strategies (with, without, and reduced duplication). The estimated maximum strategy is tested on the smaller datasets. Both training and test sets receive the same augmentation.</p>
<h2 id="key-findings-augmentation-consistently-improves-rmse">Key Findings: Augmentation Consistently Improves RMSE</h2>
<h3 id="augmentation-always-helps">Augmentation always helps</h3>
<p>Across all datasets and architectures, SMILES augmentation reduces test RMSE compared to the no-augmentation baseline. Performance improves sharply in the low augmentation range (1 to 10) and reaches a plateau around 40 to 70, after which additional augmentation provides diminishing returns.</p>
<h3 id="best-models-maxsmi">Best models (Maxsmi)</h3>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>Model</th>
          <th>Augmentation Number</th>
          <th>Strategy</th>
          <th>Test RMSE</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ESOL</td>
          <td>CONV1D</td>
          <td>70</td>
          <td>Reduced duplication</td>
          <td>0.569</td>
      </tr>
      <tr>
          <td>FreeSolv</td>
          <td>CONV1D</td>
          <td>70</td>
          <td>With duplication</td>
          <td>1.032</td>
      </tr>
      <tr>
          <td>Lipophilicity</td>
          <td>CONV1D</td>
          <td>80</td>
          <td>Without duplication</td>
          <td>0.593</td>
      </tr>
  </tbody>
</table>
<p>The CONV1D architecture consistently outperforms RNN and CONV2D. For ESOL, the CONV1D model improves from 0.839 RMSE (no augmentation) to 0.569 RMSE (70x reduced duplication), a 32% reduction.</p>
<h3 id="no-single-best-augmentation-strategy">No single best augmentation strategy</h3>
<p>The three main augmentation strategies (with, without, and reduced duplication) perform similarly. Generating the estimated maximum number of unique SMILES does not yield the best results, suggesting a saturation point exists where additional SMILES diversity stops helping.</p>
<h3 id="canonical-smiles-outperform-single-random-smiles">Canonical SMILES outperform single random SMILES</h3>
<p>When augmentation is limited to a single representation ($m = 1$), the canonical SMILES consistently outperforms a single random SMILES. On ESOL with CONV1D, the canonical model achieves 0.839 RMSE versus 0.964 for a random SMILES. The authors attribute this to the simpler, more readable structure of canonical SMILES (fewer branches and brackets).</p>
<h3 id="comparison-to-prior-work">Comparison to prior work</h3>
<table>
  <thead>
      <tr>
          <th>Study</th>
          <th>ESOL</th>
          <th>FreeSolv</th>
          <th>Lipophilicity</th>
          <th>Model</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Maxsmi</td>
          <td>0.569</td>
          <td>1.032</td>
          <td>0.593</td>
          <td>CNN</td>
      </tr>
      <tr>
          <td>MoleculeNet</td>
          <td>0.58 +/- 0.03</td>
          <td>1.15 +/- 0.12</td>
          <td>0.655 +/- 0.036</td>
          <td>GNN</td>
      </tr>
      <tr>
          <td>CNF</td>
          <td>0.62</td>
          <td>1.11</td>
          <td>0.67</td>
          <td>CNN</td>
      </tr>
      <tr>
          <td><a href="/notes/chemistry/molecular-design/property-prediction/molpmofit-transfer-learning-qsar/">MolPMoFiT</a></td>
          <td>N/A</td>
          <td>1.197 +/- 0.127</td>
          <td>0.565 +/- 0.037</td>
          <td>RNN</td>
      </tr>
  </tbody>
</table>
<p>Maxsmi outperforms or matches MoleculeNet&rsquo;s graph neural networks and the CNF model on all three tasks. MolPMoFiT slightly outperforms Maxsmi on lipophilicity (0.565 vs 0.593) but performs worse on FreeSolv.</p>
<h3 id="confidence-estimation">Confidence estimation</h3>
<p>The standard deviation of per-SMILES predictions correlates with prediction error. Confidence curves show that sequentially removing compounds with the highest uncertainty leads to monotonically decreasing mean prediction error. For ESOL, keeping only the top 10% most confident predictions yields errors below 0.25.</p>
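<p>A confidence curve of this kind can be sketched as below. The sketch assumes the error measure is the mean absolute error of the retained compounds; the function name and the toy values are illustrative and not taken from the paper.</p>

```python
def confidence_curve(abs_errors, uncertainties):
    """Mean error of retained compounds as the least confident are dropped."""
    order = sorted(range(len(abs_errors)), key=lambda i: uncertainties[i])
    ranked = [abs_errors[i] for i in order]  # most confident first
    curve = []
    for keep in range(len(ranked), 0, -1):
        curve.append(sum(ranked[:keep]) / keep)
    return curve


# Made-up values in which uncertainty tracks the error.
errors = [0.1, 0.9, 0.3, 0.5]
stds = [0.05, 0.40, 0.10, 0.20]
curve = confidence_curve(errors, stds)
print(curve)
```

<p>The first entry is the error over all compounds and the last is the error of the single most confident one. The curve decreases monotonically only when the uncertainty ranks compounds in the same order as their errors.</p>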
<h3 id="egfr-affinity-test-case">EGFR affinity test case</h3>
<p>Applying the Maxsmi approach (CONV1D, 70x augmentation, reduced duplication) to EGFR kinase affinity prediction yields test RMSE of 0.777 and R2 of 0.712, compared to 1.031 RMSE and 0.494 R2 for the canonical model (a 25% RMSE improvement). The Random Forest baseline (0.758 RMSE, 0.726 R2) performs comparably and is in fact marginally better on both metrics, which the authors note without further explanation.</p>
<h3 id="limitations">Limitations</h3>
<ul>
<li>All experiments use a single train/test split (80/20) without cross-validation, due to the computational cost of the full augmentation sweep. This means reported RMSE values lack uncertainty estimates for the Maxsmi models.</li>
<li>The study uses shallow networks only. Whether the same augmentation benefits apply to deeper architectures or pre-trained models is untested.</li>
<li>The EGFR test case shows the Random Forest baseline performing comparably to the Maxsmi model, raising questions about when SMILES augmentation provides a meaningful advantage over traditional fingerprint-based methods.</li>
<li>The comparison to prior work uses different splits, preprocessing, and evaluation protocols across studies, which the authors acknowledge limits direct comparability.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training/Evaluation</td>
          <td>ESOL</td>
          <td>1,128</td>
          <td>MoleculeNet, water solubility</td>
      </tr>
      <tr>
          <td>Training/Evaluation</td>
          <td>FreeSolv</td>
          <td>642</td>
          <td>MoleculeNet, hydration free energy</td>
      </tr>
      <tr>
          <td>Training/Evaluation</td>
          <td>Lipophilicity</td>
          <td>4,199</td>
          <td>ChEMBL, logD</td>
      </tr>
      <tr>
          <td>Test case</td>
          <td>EGFR Affinity</td>
          <td>5,849</td>
          <td>Kinodata (ChEMBL v28), pIC50</td>
      </tr>
  </tbody>
</table>
<p>All datasets are publicly available through MoleculeNet/DeepChem and Kinodata.</p>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>SMILES generation via <a href="https://en.wikipedia.org/wiki/RDKit">RDKit</a>&rsquo;s random SMILES enumeration</li>
<li>One-hot encoding of SMILES characters with padding to max length</li>
<li>Five augmentation strategies applied to both training and test sets</li>
<li>Mean aggregation for compound-level predictions</li>
</ul>
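<p>The one-hot encoding step can be sketched as follows. This is a simplified, character-level illustration: the helper names are hypothetical, padding is modeled as its own symbol, and two-character atoms such as Cl or Br would be split into separate characters here, which a real pipeline would need to handle.</p>

```python
def build_alphabet(smiles_list, pad=" "):
    """Padding symbol first, then every character seen in the data."""
    chars = sorted({ch for s in smiles_list for ch in s})
    return [pad] + chars


def one_hot_encode(smiles, alphabet, max_length, pad=" "):
    """Return a max_length x len(alphabet) matrix of 0/1 rows."""
    index = {ch: i for i, ch in enumerate(alphabet)}
    padded = smiles.ljust(max_length, pad)
    matrix = []
    for ch in padded:
        row = [0] * len(alphabet)
        row[index[ch]] = 1
        matrix.append(row)
    return matrix


dataset = ["CCO", "OCC", "C(O)C"]
alphabet = build_alphabet(dataset)
max_length = max(len(s) for s in dataset)
encoded = one_hot_encode("CCO", alphabet, max_length)
for row in encoded:
    print(row)
```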
<h3 id="models">Models</h3>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Architecture</th>
          <th>Parameters</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>CONV1D</td>
          <td>1D conv (kernel 10, stride 1) + 2 FC layers</td>
          <td>Not specified</td>
      </tr>
      <tr>
          <td>CONV2D</td>
          <td>2D conv (single channel) + 2 FC layers</td>
          <td>Not specified</td>
      </tr>
      <tr>
          <td>RNN</td>
          <td>LSTM + FC(128) + FC(64)</td>
          <td>Not specified</td>
      </tr>
      <tr>
          <td>RF Baseline</td>
          <td>Random Forest (default sklearn)</td>
          <td>Morgan FP, radius 2, length 1024</td>
      </tr>
  </tbody>
</table>
<p>Training: 250 epochs, batch size 16, MSE loss, SGD, lr=0.001.</p>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Best Value</th>
          <th>Baseline</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>RMSE (ESOL)</td>
          <td>0.569</td>
          <td>1.102 (RF)</td>
          <td>CONV1D, 70x reduced dup</td>
      </tr>
      <tr>
          <td>RMSE (FreeSolv)</td>
          <td>1.032</td>
          <td>2.563 (RF)</td>
          <td>CONV1D, 70x with dup</td>
      </tr>
      <tr>
          <td>RMSE (Lipophilicity)</td>
          <td>0.593</td>
          <td>0.860 (RF)</td>
          <td>CONV1D, 80x augmentation without duplication</td>
      </tr>
      <tr>
          <td>RMSE (EGFR)</td>
          <td>0.777</td>
          <td>0.758 (RF)</td>
          <td>CONV1D, 70x augmentation with reduced duplication</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Training was performed on a GeForce GTX 1080 Ti, provided by the HPC cluster at Freie Universität Berlin. Training CONV1D on ESOL with 100x augmentation (keeping duplicates, 90,200 data points) takes approximately 3 hours. Training with 19x augmentation achieves RMSE of 0.605 in under 30 minutes.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/volkamerlab/maxsmi">volkamerlab/maxsmi</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Full source code, trained models, CLI for prediction</td>
      </tr>
      <tr>
          <td><a href="https://maxsmi.readthedocs.io/en/latest/">Documentation</a></td>
          <td>Docs</td>
          <td>N/A</td>
          <td>Read the Docs documentation</td>
      </tr>
      <tr>
          <td><a href="https://github.com/openkinome/kinodata">Kinodata</a></td>
          <td>Dataset</td>
          <td>N/A</td>
          <td>Curated kinase bioactivity data from ChEMBL v28</td>
      </tr>
  </tbody>
</table>
<p><strong>Reproducibility status</strong>: Highly Reproducible. Code, data, trained models, and a command-line prediction tool are all publicly available under the MIT license.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Kimber, T. B., Gagnebin, M., &amp; Volkamer, A. (2021). Maxsmi: Maximizing molecular property prediction performance with confidence estimation using SMILES augmentation and deep learning. <em>Artificial Intelligence in the Life Sciences</em>, 1, 100014. <a href="https://doi.org/10.1016/j.ailsci.2021.100014">https://doi.org/10.1016/j.ailsci.2021.100014</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{kimber2021maxsmi,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Maxsmi: Maximizing molecular property prediction performance with confidence estimation using SMILES augmentation and deep learning}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Kimber, Talia B. and Gagnebin, Maxime and Volkamer, Andrea}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Artificial Intelligence in the Life Sciences}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{100014}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2021}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Elsevier}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1016/j.ailsci.2021.100014}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>MAT: Graph-Augmented Transformer for Molecules (2020)</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-simulation/ml-potentials/molecule-attention-transformer/</link><pubDate>Fri, 27 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-simulation/ml-potentials/molecule-attention-transformer/</guid><description>MAT augments the Transformer self-attention mechanism with inter-atomic distances and molecular graph adjacency for molecular property prediction.</description><content:encoded><![CDATA[<h2 id="a-graph-augmented-transformer-for-molecular-property-prediction">A Graph-Augmented Transformer for Molecular Property Prediction</h2>
<p>This is a <strong>Method</strong> paper that proposes the Molecule Attention Transformer (MAT), a Transformer-based architecture adapted for molecular property prediction. The primary contribution is a modified self-attention mechanism that incorporates inter-atomic distances and molecular graph structure alongside the standard query-key attention. Combined with self-supervised pretraining on 2 million molecules from ZINC15, MAT achieves competitive performance across seven diverse molecular property prediction tasks while requiring minimal hyperparameter tuning.</p>
<h2 id="challenges-in-deep-learning-for-molecular-properties">Challenges in Deep Learning for Molecular Properties</h2>
<p>Predicting molecular properties is central to drug discovery and materials design, yet deep neural networks have struggled to consistently outperform shallow methods like random forests and SVMs on these tasks. Wu et al. (2018) demonstrated through the MoleculeNet benchmark that graph neural networks do not reliably beat classical models. Two recurring problems compound this:</p>
<ol>
<li><strong>Underfitting</strong>: Graph neural networks tend to underfit training data, with performance failing to scale with model complexity (Ishiguro et al., 2019).</li>
<li><strong>Hyperparameter sensitivity</strong>: Deep models for molecule property prediction require extensive hyperparameter search (often 500+ configurations) to achieve competitive results, making them impractical for many practitioners.</li>
</ol>
<p>Concurrent work explored using vanilla Transformers on SMILES string representations of molecules (Honda et al., 2019; Wang et al., 2019), but these approaches discard the explicit structural information encoded in molecular graphs and 3D conformations. The motivation for MAT is to combine the flexibility of the Transformer architecture with domain-specific inductive biases from molecular structure.</p>
<h2 id="molecule-self-attention-combining-attention-distance-and-graph-structure">Molecule Self-Attention: Combining Attention, Distance, and Graph Structure</h2>
<p>The core innovation is the Molecule Self-Attention layer, which replaces standard Transformer self-attention. In a standard Transformer, head $i$ computes the following, where $\rho$ denotes the softmax function:</p>
<p>$$
\mathcal{A}^{(i)} = \rho\left(\frac{\mathbf{Q}_{i} \mathbf{K}_{i}^{T}}{\sqrt{d_{k}}}\right) \mathbf{V}_{i}
$$</p>
<p>MAT augments this with two additional information sources. Let $\mathbf{A} \in \{0, 1\}^{N_{\text{atoms}} \times N_{\text{atoms}}}$ denote the molecular graph adjacency matrix and $\mathbf{D} \in \mathbb{R}^{N_{\text{atoms}} \times N_{\text{atoms}}}$ denote the inter-atomic distance matrix. The modified attention becomes:</p>
<p>$$
\mathcal{A}^{(i)} = \left(\lambda_{a} \rho\left(\frac{\mathbf{Q}_{i} \mathbf{K}_{i}^{T}}{\sqrt{d_{k}}}\right) + \lambda_{d} \, g(\mathbf{D}) + \lambda_{g} \, \mathbf{A}\right) \mathbf{V}_{i}
$$</p>
<p>where $\lambda_{a}$, $\lambda_{d}$, and $\lambda_{g}$ are scalar hyperparameters weighting each component, and $g$ is either a row-wise softmax or an element-wise exponential decay $g(d) = \exp(-d)$.</p>
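<p>A single head of this weighted sum can be sketched in NumPy as below. This is an illustrative forward pass only, not the official implementation: it takes $\rho$ to be a row-wise softmax, implements only the exponential kernel $g(d) = \exp(-d)$, and uses made-up toy matrices. The value <code>lam_g=0.34</code> is an assumption chosen so the three weights sum to one.</p>

```python
import numpy as np


def softmax(x, axis=-1):
    """Numerically stable softmax along the given axis."""
    shifted = x - x.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


def molecule_self_attention(Q, K, V, D, A, lam_a, lam_d, lam_g):
    """One head: weighted sum of attention, distance kernel, and adjacency."""
    d_k = Q.shape[-1]
    attention = softmax(Q @ K.T / np.sqrt(d_k), axis=-1)
    distance_term = np.exp(-D)  # element-wise kernel g(d) = exp(-d)
    weights = lam_a * attention + lam_d * distance_term + lam_g * A
    return weights @ V


# Toy 3-atom chain with invented distances (illustrative values only).
rng = np.random.default_rng(0)
n_atoms, d_k = 3, 4
Q, K, V = (rng.normal(size=(n_atoms, d_k)) for _ in range(3))
A = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=np.float64)
D = np.array([[0.0, 1.5, 2.5], [1.5, 0.0, 1.5], [2.5, 1.5, 0.0]])

out = molecule_self_attention(Q, K, V, D, A, lam_a=0.33, lam_d=0.33, lam_g=0.34)
```

<p>Setting <code>lam_d</code> and <code>lam_g</code> to zero recovers standard scaled dot-product attention, which makes the role of each $\lambda$ easy to check in isolation.</p>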
<p>Key architectural details:</p>
<ul>
<li><strong>Atom embedding</strong>: Each atom is represented as a 26-dimensional vector encoding atomic identity (one-hot over B, N, C, O, F, P, S, Cl, Br, I, dummy, other), number of heavy neighbors, number of hydrogens, formal charge, ring membership, and aromaticity.</li>
<li><strong>Dummy node</strong>: An artificial disconnected node (distance $10^{6}$ from all atoms) is added to each molecule, allowing the model to &ldquo;skip&rdquo; attention heads when no relevant pattern exists, similar to how BERT uses the separation token.</li>
<li><strong>3D conformers</strong>: Distance matrices are computed from RDKit-generated 3D conformers using the Universal Force Field (UFF).</li>
<li><strong>Pretraining</strong>: Node-level masked atom prediction on 2 million ZINC15 molecules (following Hu et al., 2019), where 15% of atom features are masked and the model predicts them.</li>
</ul>
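<p>The dummy-node trick from the list above can be sketched as a padding step on the adjacency and distance matrices. The helper <code>add_dummy_node</code> is hypothetical, and placing the dummy at index 0 and giving it the same large self-distance are assumptions made for this sketch. The point it shows is that $\exp(-10^{6})$ underflows to zero, so the dummy contributes nothing through the distance or adjacency terms and is reachable only via learned attention.</p>

```python
import numpy as np


def add_dummy_node(A, D, far=1e6):
    """Prepend a disconnected dummy node to adjacency A and distance matrix D."""
    n = A.shape[0]
    A_ext = np.zeros((n + 1, n + 1), dtype=np.float64)
    A_ext[1:, 1:] = A  # the dummy row and column stay zero: no bonds
    D_ext = np.full((n + 1, n + 1), far, dtype=np.float64)
    D_ext[1:, 1:] = D  # real atoms keep their original distances
    return A_ext, D_ext


# Toy 3-atom chain (illustrative values only).
A = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=np.float64)
D = np.array([[0.0, 1.5, 2.5], [1.5, 0.0, 1.5], [2.5, 1.5, 0.0]])

A_ext, D_ext = add_dummy_node(A, D)
kernel = np.exp(-D_ext)  # dummy row and column underflow to exactly 0.0
```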
<h2 id="benchmark-evaluation-and-ablation-studies">Benchmark Evaluation and Ablation Studies</h2>
<h3 id="experimental-setup">Experimental setup</h3>
<p>MAT is evaluated on seven molecular property prediction datasets spanning regression and classification:</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>Task</th>
          <th>Size</th>
          <th>Metric</th>
          <th>Split</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>FreeSolv</td>
          <td>Regression (hydration free energy)</td>
          <td>642</td>
          <td>RMSE</td>
          <td>Random</td>
      </tr>
      <tr>
          <td>ESOL</td>
          <td>Regression (log solubility)</td>
          <td>1,128</td>
          <td>RMSE</td>
          <td>Random</td>
      </tr>
      <tr>
          <td>BBBP</td>
          <td>Classification (BBB permeability)</td>
          <td>2,039</td>
          <td>ROC AUC</td>
          <td>Scaffold</td>
      </tr>
      <tr>
          <td>Estrogen-alpha</td>
          <td>Classification (receptor activity)</td>
          <td>2,398</td>
          <td>ROC AUC</td>
          <td>Scaffold</td>
      </tr>
      <tr>
          <td>Estrogen-beta</td>
          <td>Classification (receptor activity)</td>
          <td>1,961</td>
          <td>ROC AUC</td>
          <td>Scaffold</td>
      </tr>
      <tr>
          <td>MetStab-high</td>
          <td>Classification (metabolic stability)</td>
          <td>2,127</td>
          <td>ROC AUC</td>
          <td>Random</td>
      </tr>
      <tr>
          <td>MetStab-low</td>
          <td>Classification (metabolic stability)</td>
          <td>2,127</td>
          <td>ROC AUC</td>
          <td>Random</td>
      </tr>
  </tbody>
</table>
<p>Baselines include GCN, Weave, EAGCN, Random Forest (RF), and SVM. Each model receives the same hyperparameter search budget (150 or 500 evaluations). Results are averaged over 6 random train/validation/test splits.</p>
<h3 id="main-results">Main results</h3>
<p>MAT achieves the best average rank across all seven tasks:</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Avg. Rank (500 budget)</th>
          <th>Avg. Rank (150 budget)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MAT</td>
          <td>2.42</td>
          <td>2.71</td>
      </tr>
      <tr>
          <td>RF</td>
          <td>3.14</td>
          <td>3.14</td>
      </tr>
      <tr>
          <td>SVM</td>
          <td>3.57</td>
          <td>3.28</td>
      </tr>
      <tr>
          <td>GCN</td>
          <td>3.57</td>
          <td>3.71</td>
      </tr>
      <tr>
          <td>Weave</td>
          <td>3.71</td>
          <td>3.57</td>
      </tr>
      <tr>
          <td>EAGCN</td>
          <td>4.14</td>
          <td>4.14</td>
      </tr>
  </tbody>
</table>
<p>With self-supervised pretraining, Pretrained MAT achieves an average rank of 1.57, outperforming both Pretrained EAGCN (4.0) and SMILES Transformer (4.29). Pretrained MAT requires tuning only the learning rate (7 values tested), compared to 500 hyperparameter combinations for the non-pretrained models.</p>
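<p>For readers unfamiliar with the average-rank summary used above, the sketch below shows one way to compute it when tasks mix higher-is-better (ROC AUC) and lower-is-better (RMSE) metrics. The scores are invented toy numbers, the function name <code>average_ranks</code> is hypothetical, and ties are broken by column order here, which may differ from the paper's handling.</p>

```python
import numpy as np


def average_ranks(scores, higher_is_better):
    """Mean rank per model (1 = best) from a (n_tasks, n_models) score table."""
    scores = np.asarray(scores, dtype=np.float64)
    flip = np.asarray(higher_is_better)[:, None]
    # Negate higher-is-better rows so that smaller always means better.
    oriented = np.where(flip, -scores, scores)
    # argsort of argsort yields 0-based ranks within each task row.
    ranks = oriented.argsort(axis=1).argsort(axis=1) + 1
    return ranks.mean(axis=0)


# Toy table: row 0 is an AUC task, row 1 is an RMSE task; columns are 3 models.
toy_scores = [[0.90, 0.85, 0.80],
              [0.30, 0.25, 0.40]]
avg = average_ranks(toy_scores, higher_is_better=[True, False])
```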
<h3 id="ablation-results">Ablation results</h3>
<p>Ablation studies on BBBP, ESOL, and FreeSolv reveal:</p>
<table>
  <thead>
      <tr>
          <th>Variant</th>
          <th>BBBP (AUC)</th>
          <th>ESOL (RMSE)</th>
          <th>FreeSolv (RMSE)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MAT (full)</td>
          <td>.723</td>
          <td>.286</td>
          <td>.250</td>
      </tr>
      <tr>
          <td>- Graph</td>
          <td>.716</td>
          <td>.316</td>
          <td>.276</td>
      </tr>
      <tr>
          <td>- Distance</td>
          <td>.729</td>
          <td>.281</td>
          <td>.281</td>
      </tr>
      <tr>
          <td>- Attention</td>
          <td>.692</td>
          <td>.306</td>
          <td>.329</td>
      </tr>
      <tr>
          <td>- Dummy node</td>
          <td>.714</td>
          <td>.317</td>
          <td>.249</td>
      </tr>
      <tr>
          <td>+ Edge features</td>
          <td>.683</td>
          <td>.314</td>
          <td>.358</td>
      </tr>
  </tbody>
</table>
<p>Removing any single component degrades performance on at least one task, supporting the value of combining all three information sources. Adding edge features does not help, suggesting the adjacency and distance matrices already capture sufficient bond-level information.</p>
<h3 id="interpretability-analysis">Interpretability analysis</h3>
<p>Individual attention heads in the first layer learn chemically meaningful functions. Six heads were identified that focus on specific chemical patterns: 2-neighbored aromatic carbons, sulfur atoms, non-ring nitrogens, carbonyl oxygens, 3-neighbored aromatic atoms (substitution positions), and aromatic ring nitrogens. Statistical validation using Kruskal-Wallis tests confirmed that atoms matching these SMARTS patterns receive significantly higher attention weights ($p &lt; 0.001$ for all patterns).</p>
<h2 id="findings-limitations-and-future-directions">Findings, Limitations, and Future Directions</h2>
<p>MAT demonstrates that augmenting Transformer self-attention with molecular graph structure and 3D distance information produces a model that performs consistently well across diverse property prediction tasks. The key practical finding is that self-supervised pretraining dramatically reduces the hyperparameter tuning burden: Pretrained MAT matches or exceeds the performance of extensively tuned models while requiring only learning rate selection.</p>
<p>Several limitations are acknowledged:</p>
<ul>
<li><strong>Fingerprint-based models still win on some tasks</strong>: RF and SVM with extended-connectivity fingerprints outperform MAT on metabolic stability and Estrogen-beta tasks, suggesting that incorporating fingerprint representations could improve MAT further.</li>
<li><strong>Single conformer</strong>: Only one pre-computed 3D conformer is used per molecule. More sophisticated conformer sampling or ensemble strategies were not explored.</li>
<li><strong>Limited pretraining exploration</strong>: Only the masked atom prediction task from Hu et al. (2019) was used. The authors note that exploring additional pretraining objectives is a promising direction.</li>
<li><strong>Scalability</strong>: The pretrained model uses 1024-dimensional embeddings with 8 layers and 16 attention heads, a configuration chosen as the largest that still fits in GPU memory.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pretraining</td>
          <td>ZINC15</td>
          <td>2M molecules</td>
          <td>Sampled from ZINC database</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>FreeSolv</td>
          <td>642</td>
          <td>Hydration free energy regression</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>ESOL</td>
          <td>1,128</td>
          <td>Log solubility regression</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>BBBP</td>
          <td>2,039</td>
          <td>Blood-brain barrier classification</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>Estrogen-alpha/beta</td>
          <td>2,398 / 1,961</td>
          <td>Receptor activity classification</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>MetStab-high/low</td>
          <td>2,127 each</td>
          <td>Metabolic stability classification</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Optimizer: Adam with Noam learning rate scheduler (warmup then inverse square root decay)</li>
<li>Pretraining: 8 epochs, learning rate 0.001, batch size 256, binary cross-entropy loss</li>
<li>Fine-tuning: 100 epochs, batch size 32, learning rate selected from {1e-3, 5e-4, 1e-4, 5e-5, 1e-5, 5e-6, 1e-6}</li>
<li>Distance kernel: exponential decay $g(d) = \exp(-d)$ for pretrained model</li>
<li>Lambda weights: $\lambda_{a} = \lambda_{d} = 0.33$ for pretrained model</li>
</ul>
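<p>The Noam schedule named in the list above follows the warmup-then-inverse-square-root rule from the original Transformer recipe. A minimal sketch follows. The <code>warmup_steps</code> value is a made-up example, and how the fine-tuning learning rates listed above map onto the <code>factor</code> argument is an assumption not confirmed by the source.</p>

```python
def noam_lr(step, d_model, warmup_steps, factor=1.0):
    """Linear warmup for warmup_steps, then decay proportional to step ** -0.5."""
    step = max(step, 1)  # avoid evaluating 0 ** -0.5 on the very first call
    return factor * d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)


# d_model mirrors the 1024-dim embeddings; warmup_steps=4000 is hypothetical.
steps = (1000, 2000, 4000, 8000, 16000)
schedule = [noam_lr(s, d_model=1024, warmup_steps=4000) for s in steps]
```

<p>The rate doubles when the step doubles during warmup, peaks at <code>warmup_steps</code>, and halves each time the step count quadruples afterwards.</p>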
<h3 id="models">Models</h3>
<ul>
<li>Pretrained MAT: 1024-dim embeddings, 8 layers, 16 attention heads, 1 feed-forward layer per block</li>
<li>Dropout: 0.0, weight decay: 0.0 for pretrained model</li>
<li>Atom featurization: 26-dimensional feature vector (Table 1 in paper)</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li>Regression: RMSE (FreeSolv, ESOL)</li>
<li>Classification: ROC AUC (BBBP, Estrogen-alpha/beta, MetStab-high/low)</li>
<li>All experiments repeated 6 times with different train/validation/test splits</li>
<li>Scaffold split for BBBP and Estrogen-alpha/beta; random split for the other datasets</li>
</ul>
<h3 id="hardware">Hardware</h3>
<p>The paper does not specify exact hardware details. The pretrained model is described as &ldquo;the largest model that still fits the GPU memory.&rdquo;</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/gmum/MAT">gmum/MAT</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official implementation with pretrained weights</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Maziarka, Ł., Danel, T., Mucha, S., Rataj, K., Tabor, J., &amp; Jastrzębski, S. (2020). Molecule Attention Transformer. <em>arXiv preprint arXiv:2002.08264</em>.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{maziarka2020molecule,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Molecule Attention Transformer}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Maziarka, {\L}ukasz and Danel, Tomasz and Mucha, S{\l}awomir and Rataj, Krzysztof and Tabor, Jacek and Jastrz{\k{e}}bski, Stanis{\l}aw}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{arXiv preprint arXiv:2002.08264}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2020}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>DMP: Dual-View Molecule Pre-training (SMILES+GNN)</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/multimodal/dual-view-molecule-pretraining/</link><pubDate>Fri, 27 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/multimodal/dual-view-molecule-pretraining/</guid><description>DMP pre-trains molecular encoders using both SMILES Transformer and GNN branches with a BYOL-style dual-view consistency loss for property prediction.</description><content:encoded><![CDATA[<h2 id="a-dual-branch-pre-training-method-for-molecular-property-prediction">A Dual-Branch Pre-training Method for Molecular Property Prediction</h2>
<p>DMP (Dual-view Molecule Pre-training) is a <strong>Method</strong> paper that introduces a pre-training framework combining two complementary molecular encoders: a Transformer operating on <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> strings and a Graph Neural Network (GNN) operating on molecular graphs. The two branches are trained jointly with masked language modeling (MLM) objectives plus a BYOL-style dual-view consistency loss. After pre-training on 10M <a href="https://en.wikipedia.org/wiki/PubChem">PubChem</a> molecules, either branch (or both) can be fine-tuned for downstream tasks. The authors recommend the Transformer branch based on empirical results. DMP achieves the best reported performance on 7 of 9 <a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a> classification tasks and 3 retrosynthesis benchmarks (at the time of the 2021 arXiv version).</p>
<h2 id="why-combine-smiles-and-graph-views-for-molecules">Why Combine SMILES and Graph Views for Molecules</h2>
<p>Prior molecule pre-training methods used either graph representations with GNNs or SMILES representations with Transformers, but not both. The authors observe that the two views are complementary: Transformers handle molecules with large atom distances (long chains) well, while GNNs handle molecules with many concatenated rings better. Neither model alone captures the full range of molecular structures effectively.</p>
<p>Existing GNN-based pre-training methods (Hu et al. 2020, MolCLR, GROVER) and SMILES-based methods (<a href="/notes/chemistry/molecular-representations/encoders/chemberta/">ChemBERTa</a>, <a href="/notes/chemistry/molecular-representations/encoders/smiles-bert/">SMILES-BERT</a>) each have blind spots dictated by their input representation. DMP addresses this by pre-training both views simultaneously and enforcing representation consistency between them, so each branch benefits from the structural knowledge of the other.</p>
<h2 id="dual-view-consistency-with-byol-style-training">Dual-View Consistency with BYOL-Style Training</h2>
<p>The core innovation is the dual-view consistency objective, inspired by Bootstrap Your Own Latent (BYOL). Given a molecule $M$ with SMILES representation $M_s$ and graph representation $M_g$, DMP obtains high-level features from each branch:</p>
<ul>
<li><strong>Transformer branch</strong>: A RoBERTa-base model encodes the SMILES sequence. The [CLS] token output serves as the molecule representation $f_s$.</li>
<li><strong>GNN branch</strong>: A DeeperGCN network encodes the molecular graph. Mean+max pooling over atom representations yields $f_g$.</li>
</ul>
<p>The dual-view consistency loss uses nonlinear projection heads $\psi_g, \psi_s$ and prediction heads $\rho_g, \rho_s$:</p>
<p>$$
p_g = \psi_g(f_g), \quad q_g = \rho_g(p_g); \quad p_s = \psi_s(f_s), \quad q_s = \rho_s(p_s)
$$</p>
<p>The consistency loss maximizes cross-view <a href="https://en.wikipedia.org/wiki/Cosine_similarity">cosine similarity</a> with stop-gradient (SG) on the target:</p>
<p>$$
\ell_{\text{dual}}(\tilde{M}_g, \tilde{M}_s) = -\cos(q_s, \text{SG}(p_g)) - \cos(q_g, \text{SG}(p_s))
$$</p>
<p>where $\cos(p, q) = \frac{p^\top q}{\lVert p \rVert_2 \lVert q \rVert_2}$ and $\tilde{M}_g, \tilde{M}_s$ are the masked versions of the inputs. The stop-gradient prevents representation collapse without requiring negative samples or a momentum encoder.</p>
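<p>The forward value of the consistency loss can be sketched in NumPy as follows. This is only a numerical illustration with hypothetical function names: NumPy has no autodiff, so the stop-gradient is represented by treating the targets as constants. In a training framework, <code>p_g</code> and <code>p_s</code> would be detached from the graph before the cosine is computed.</p>

```python
import numpy as np


def cosine(p, q, eps=1e-12):
    """Cosine similarity between two 1-D vectors."""
    return float(p @ q / (np.linalg.norm(p) * np.linalg.norm(q) + eps))


def dual_view_loss(q_s, p_g, q_g, p_s):
    """Negative cross-view cosine similarity, summed over both directions.

    p_g and p_s play the role of stop-gradient targets: constants here.
    """
    return -cosine(q_s, p_g) - cosine(q_g, p_s)


v = np.array([1.0, 2.0, 3.0])
aligned = dual_view_loss(v, 2 * v, v, 0.5 * v)  # same direction in both views
opposed = dual_view_loss(v, -v, v, -v)          # opposite direction in both views
```

<p>The loss ranges from about -2 (both views perfectly aligned) to about +2 (both views opposed), and it is insensitive to vector length because only the directions matter.</p>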
<p>The full training objective combines three losses:</p>
<ol>
<li><strong>MLM on Transformer</strong>: Recover masked tokens in SMILES sequences</li>
<li><strong>MLM on GNN</strong>: Recover masked atoms in molecular graphs</li>
<li><strong>Dual-view consistency</strong>: The BYOL-style loss above</li>
</ol>
<p>Both MLM objectives and the consistency loss are necessary. Ablations show that removing MLM (using only dual-view loss) degrades performance, and using two branches of the same type (two Transformers or two GNNs) is less effective than the heterogeneous Transformer+GNN combination.</p>
<h2 id="experiments-on-moleculenet-and-retrosynthesis">Experiments on MoleculeNet and Retrosynthesis</h2>
<h3 id="pre-training-setup">Pre-training Setup</h3>
<p>DMP is pre-trained on 10M molecules from PubChem (matching prior work). The Transformer branch uses RoBERTa-base (12 layers, hidden dim 768, 87M parameters). The GNN branch uses DeeperGCN (12 layers, hidden dim 384, 7.4M parameters). Combined, DMP has 104.1M parameters. Training runs for 200K iterations on 8 V100 GPUs over 3.8 days with the Adam optimizer (lr = 5e-4, weight decay 0.01).</p>
<h3 id="molecular-property-prediction-moleculenet">Molecular Property Prediction (MoleculeNet)</h3>
<p>DMP is evaluated on 6 binary classification tasks (BBBP, Tox21, ClinTox, HIV, BACE, SIDER) using official DeepChem splits, and on 6 further benchmarks (BBBP, SIDER, and ClinTox classification plus ESOL, QM7, and QM8 regression) using scaffold splits from GROVER.</p>
<p>Key results on DeepChem splits (ROC-AUC %; bold marks the recommended DMP_TF branch, which is not the row maximum in every case):</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>MolCLR</th>
          <th>TF (MLM)</th>
          <th>DMP_TF</th>
          <th>DMP_TF+GNN</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>BBBP</td>
          <td>73.6</td>
          <td>74.9</td>
          <td><strong>78.1</strong></td>
          <td>77.8</td>
      </tr>
      <tr>
          <td>Tox21</td>
          <td>79.8</td>
          <td>77.6</td>
          <td><strong>78.8</strong></td>
          <td>79.1</td>
      </tr>
      <tr>
          <td>ClinTox</td>
          <td>93.2</td>
          <td>92.9</td>
          <td><strong>95.0</strong></td>
          <td>95.6</td>
      </tr>
      <tr>
          <td>HIV</td>
          <td>80.6</td>
          <td>80.2</td>
          <td><strong>81.0</strong></td>
          <td>81.4</td>
      </tr>
      <tr>
          <td>BACE</td>
          <td>89.0</td>
          <td>88.0</td>
          <td><strong>89.3</strong></td>
          <td>89.4</td>
      </tr>
      <tr>
          <td>SIDER</td>
          <td>68.0</td>
          <td>68.4</td>
          <td><strong>69.2</strong></td>
          <td>69.8</td>
      </tr>
  </tbody>
</table>
<p>On scaffold splits (comparison with GROVER and MPG):</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>GROVER</th>
          <th>MPG</th>
          <th>DMP_TF</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>BBBP (AUC)</td>
          <td>0.940</td>
          <td>0.922</td>
          <td><strong>0.945</strong></td>
      </tr>
      <tr>
          <td>SIDER (AUC)</td>
          <td>0.658</td>
          <td>0.661</td>
          <td><strong>0.695</strong></td>
      </tr>
      <tr>
          <td>ClinTox (AUC)</td>
          <td>0.944</td>
          <td>0.963</td>
          <td><strong>0.968</strong></td>
      </tr>
      <tr>
          <td>ESOL (RMSE)</td>
          <td>0.831</td>
          <td>0.741</td>
          <td><strong>0.700</strong></td>
      </tr>
      <tr>
          <td>QM7 (MAE)</td>
          <td>72.6</td>
          <td>-</td>
          <td><strong>69.6</strong></td>
      </tr>
      <tr>
          <td>QM8 (MAE)</td>
          <td>0.0125</td>
          <td>-</td>
          <td><strong>0.0124</strong></td>
      </tr>
  </tbody>
</table>
<h3 id="retrosynthesis">Retrosynthesis</h3>
<p>DMP is tested on USPTO-50K (reaction type known/unknown) and USPTO-full. Using a &ldquo;DMP fusion&rdquo; approach (fusing pre-trained representations into a Transformer encoder-decoder for <a href="/notes/chemistry/molecular-design/reaction-prediction/">retrosynthesis</a>), DMP improves top-1 accuracy by 2.1 to 3.8 points over the baseline Transformer across all settings:</p>
<table>
  <thead>
      <tr>
          <th>Setting</th>
          <th>Transformer</th>
          <th>ChemBERTa fusion</th>
          <th>DMP fusion</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>USPTO-50K (unknown)</td>
          <td>42.3</td>
          <td>43.9</td>
          <td><strong>46.1</strong></td>
      </tr>
      <tr>
          <td>USPTO-50K (known)</td>
          <td>54.2</td>
          <td>56.4</td>
          <td><strong>57.5</strong></td>
      </tr>
      <tr>
          <td>USPTO-full</td>
          <td>42.9</td>
          <td>-</td>
          <td><strong>45.0</strong></td>
      </tr>
  </tbody>
</table>
<p>For GNN-based retrosynthesis, replacing GLN&rsquo;s GNN modules with DMP&rsquo;s pre-trained GNN branch improves top-1 accuracy from 52.5% to 54.2% (unknown type) and from 64.2% to 66.5% (known type).</p>
<h3 id="representation-quality">Representation Quality</h3>
<p><a href="https://en.wikipedia.org/wiki/T-distributed_stochastic_neighbor_embedding">t-SNE</a> visualization of pre-trained representations shows that DMP produces better scaffold-based clustering than either GNN-only or Transformer-only pre-training. The <a href="https://en.wikipedia.org/wiki/Davies%E2%80%93Bouldin_index">Davies-Bouldin index</a> (lower is better) drops from 3.56 (GNN) and 3.59 (Transformer) to 2.19 (DMP), indicating tighter and better-separated scaffold clusters.</p>
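<p>To make the Davies-Bouldin numbers concrete, here is a small NumPy sketch of the index on toy 2-D points. The function name and data are illustrative. Which feature space the authors computed the index in is not stated here, so treating rows of <code>X</code> as molecule embeddings and <code>labels</code> as scaffold IDs is an assumption.</p>

```python
import numpy as np


def davies_bouldin(X, labels):
    """Mean over clusters of the worst (scatter_i + scatter_j) / centroid distance."""
    X = np.asarray(X, dtype=np.float64)
    labels = np.asarray(labels)
    ids = np.unique(labels)
    centroids = np.stack([X[labels == k].mean(axis=0) for k in ids])
    # Scatter: mean distance from each cluster's points to its own centroid.
    scatter = np.array([
        np.linalg.norm(X[labels == k] - c, axis=1).mean()
        for k, c in zip(ids, centroids)
    ])
    worst = []
    for i in range(len(ids)):
        ratios = [
            (scatter[i] + scatter[j]) / np.linalg.norm(centroids[i] - centroids[j])
            for j in range(len(ids)) if j != i
        ]
        worst.append(max(ratios))
    return float(np.mean(worst))


labels = [0, 0, 1, 1]
tight = davies_bouldin([[0, 0], [0, 1], [10, 0], [10, 1]], labels)
loose = davies_bouldin([[0, 0], [0, 4], [10, 0], [10, 4]], labels)
```

<p>The tight layout scores lower than the loose one, matching the reading that a smaller index means more compact, better-separated clusters. For real use, scikit-learn provides <code>sklearn.metrics.davies_bouldin_score</code>.</p>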
<h2 id="key-findings-and-limitations">Key Findings and Limitations</h2>
<p><strong>Key findings:</strong></p>
<ul>
<li>Combining heterogeneous views (SMILES + graph) during pre-training is more effective than using two branches of the same type. TF(x2) and GNN(x2) variants show smaller gains.</li>
<li>Both MLM and dual-view consistency loss contribute. Removing MLM (dual-view only) hurts performance, especially on BBBP (71.1 vs 78.1 with both losses).</li>
<li>The Transformer branch alone is recommended for downstream tasks, as it achieves strong results without adding GNN parameters at inference time.</li>
<li>Scaling pre-training data from 10M to 100M compounds yields marginal additional improvement.</li>
</ul>
<p><strong>Limitations acknowledged by the authors:</strong></p>
<ol>
<li>Training cost is higher than single-branch methods (3.8 days vs 2.5 days for TF-only on 8 V100s), since both branches must be trained jointly.</li>
<li>A fixed branch selection strategy is used at inference time. The authors note that a meta-controller for dynamic branch selection per molecule would be preferable.</li>
<li>The GNN branch uses simple atom masking without bond deletion or subgraph removal, leaving room for stronger graph-level pre-training objectives.</li>
</ol>
<p><strong>Relation to co-training:</strong> The authors clarify that DMP differs from classical <a href="https://en.wikipedia.org/wiki/Co-training">co-training</a> (Blum and Mitchell 1998) in that it does not require conditional independence between views and produces a pre-trained model rather than additional labeled data.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pre-training</td>
          <td>PubChem subset</td>
          <td>10M compounds</td>
          <td>Same subset as MolCLR and ChemBERTa</td>
      </tr>
      <tr>
          <td>Pre-training (large)</td>
          <td>PubChem subset</td>
          <td>100M compounds</td>
          <td>Additional scale experiment</td>
      </tr>
      <tr>
          <td>Evaluation (classification)</td>
          <td>MoleculeNet (BBBP, Tox21, ClinTox, HIV, BACE, SIDER)</td>
          <td>1.5K-41K molecules</td>
          <td>Official DeepChem splits</td>
      </tr>
      <tr>
          <td>Evaluation (regression)</td>
          <td>MoleculeNet (ESOL, QM7, QM8)</td>
          <td>Varies</td>
          <td>Scaffold splits from GROVER</td>
      </tr>
      <tr>
          <td>Evaluation (retrosynthesis)</td>
          <td>USPTO-50K, USPTO-full</td>
          <td>50K / 950K reactions</td>
          <td>Splits from Dai et al. (2019)</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Transformer branch</strong>: RoBERTa-base with MLM. SMILES tokenized using regex from Schwaller et al. (2019).</li>
<li><strong>GNN branch</strong>: DeeperGCN with 12 layers, atom masking for MLM.</li>
<li><strong>Dual-view loss</strong>: BYOL-style with 3-layer MLP projection heads and 2-layer MLP prediction heads, stop-gradient on targets.</li>
<li><strong>Optimizer</strong>: Adam (lr=5e-4, beta1=0.9, beta2=0.98, epsilon=1e-6), weight decay 0.01, 10K warmup steps, linear decay.</li>
</ul>
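<p>As a rough illustration of the dual-view term listed above, the sketch below computes a symmetric negative cosine similarity between each branch&rsquo;s prediction-head output and the other branch&rsquo;s projection-head output. The function names and toy vectors are hypothetical, and the exact loss form is an assumption based on the &ldquo;BYOL-style&rdquo; description. Only the forward value is shown; in a real training loop the targets would be detached so that no gradient flows through them.</p>

```python
import math


def cosine_similarity(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def dual_view_loss(pred_from_tf, target_gnn, pred_from_gnn, target_tf):
    """Symmetric BYOL-style consistency loss (forward value only).

    pred_* are prediction-head outputs of one branch; target_* are
    projection-head outputs of the other branch. The targets are assumed
    to be detached (stop-gradient) by the caller in a real framework.
    """
    return -cosine_similarity(pred_from_tf, target_gnn) - cosine_similarity(
        pred_from_gnn, target_tf
    )


# Views that point in the same direction give a value close to -2.
aligned = dual_view_loss([1.0, 2.0], [2.0, 4.0], [0.5, 0.5], [3.0, 3.0])
# Orthogonal views contribute zero similarity.
orthogonal = dual_view_loss([1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0])
print(aligned, orthogonal)
```

<p>Under this assumed form the loss is bounded between -2 (both views agree) and 2 (both views point in opposite directions), and it would be added to the two MLM losses during pre-training.</p>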
<h3 id="models">Models</h3>
<table>
  <thead>
      <tr>
          <th>Component</th>
          <th>Architecture</th>
          <th>Parameters</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Transformer branch</td>
          <td>RoBERTa-base (12L, 768H, 12 heads)</td>
          <td>87M</td>
      </tr>
      <tr>
          <td>GNN branch</td>
          <td>DeeperGCN (12L, 384H)</td>
          <td>7.4M</td>
      </tr>
      <tr>
          <td>DMP (total)</td>
          <td>Transformer + GNN + projection/prediction heads</td>
          <td>104.1M</td>
      </tr>
  </tbody>
</table>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li>Classification: ROC-AUC, averaged over 3 random seeds</li>
<li>Regression: RMSE (ESOL) or MAE (QM7, QM8)</li>
<li>Retrosynthesis: Top-k exact match accuracy (k=1,3,5,10,20,50)</li>
</ul>
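<p>The retrosynthesis metric can be sketched in a few lines. This is a minimal illustration, assuming predictions and references are already canonicalized strings so that exact string equality is meaningful; the toy SMILES and the function name are hypothetical.</p>

```python
def top_k_accuracy(ranked_predictions, references, k):
    """Fraction of examples whose reference appears in the first k predictions."""
    hits = sum(
        1
        for preds, ref in zip(ranked_predictions, references)
        if ref in preds[:k]
    )
    return hits / len(references)


# Toy data: hypothetical ranked reactant strings for three products.
predictions = [
    ["CCO", "CCN", "CCC"],
    ["c1ccccc1", "CC=O", "CCO"],
    ["CCCl", "CCBr", "CCI"],
]
references = ["CCO", "CCO", "CCF"]

for k in (1, 3):
    print(k, top_k_accuracy(predictions, references, k))
```

<p>Top-k accuracy never decreases as k grows, which is why results are usually reported for several values of k at once.</p>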
<h3 id="hardware">Hardware</h3>
<ul>
<li>Pre-training: 8 NVIDIA V100 GPUs, batch size 12288 tokens, gradient accumulation 16x</li>
<li>Pre-training time: 3.8 days (DMP), 2.5 days (TF-only), 1.7 days (GNN-only)</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<p>No public code repository or pre-trained model weights were identified for this paper. The paper references GLN&rsquo;s code repository (<a href="https://github.com/Hanjun-Dai/GLN">https://github.com/Hanjun-Dai/GLN</a>) for the retrosynthesis baseline but does not release DMP-specific code.</p>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/Hanjun-Dai/GLN">GLN (baseline)</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Retrosynthesis baseline, not DMP code</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Zhu, J., Xia, Y., Wu, L., Xie, S., Zhou, W., Qin, T., Li, H., &amp; Liu, T.-Y. (2023). Dual-view Molecular Pre-training. In <em>Proceedings of the 29th ACM SIGKDD Conference on Knowledge Discovery and Data Mining</em> (pp. 3615-3627). <a href="https://doi.org/10.1145/3580305.3599317">https://doi.org/10.1145/3580305.3599317</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{zhu2023dualview,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Dual-view Molecular Pre-training}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Zhu, Jinhua and Xia, Yingce and Wu, Lijun and Xie, Shufang and Zhou, Wengang and Qin, Tao and Li, Houqiang and Liu, Tie-Yan}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Proceedings of the 29th ACM SIGKDD Conference on Knowledge Discovery and Data Mining}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{3615--3627}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1145/3580305.3599317}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Transformers for Molecular Property Prediction Review</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/property-prediction/transformers-molecular-property-prediction-review/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/property-prediction/transformers-molecular-property-prediction-review/</guid><description>A systematic review of 16 transformer models for molecular property prediction, analyzing architecture, data, tokenization, and benchmarking gaps.</description><content:encoded><![CDATA[<h2 id="a-systematization-of-transformers-for-molecular-property-prediction">A Systematization of Transformers for Molecular Property Prediction</h2>
<p>This is a <strong>Systematization</strong> paper. Sultan et al. provide the first comprehensive, structured review of sequence-based transformer models applied to molecular property prediction (MPP). The review catalogs 16 models published between 2019 and 2023, organizes them by architecture type (encoder-decoder, encoder-only, decoder-only), and systematically examines seven key design decisions that arise when building a transformer for MPP. The paper&rsquo;s primary contribution is identifying gaps in current evaluation practices and articulating what standardization the field needs for meaningful progress.</p>
<h2 id="the-problem-inconsistent-evaluation-hinders-progress">The Problem: Inconsistent Evaluation Hinders Progress</h2>
<p>Molecular property prediction is essential for drug discovery, crop protection, and environmental science. Deep learning approaches, including transformers, have been increasingly applied to this task by learning molecular representations from string notations like <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> and <a href="/notes/chemistry/molecular-representations/notations/selfies/">SELFIES</a>. However, the field faces several challenges:</p>
<ol>
<li><strong>Small labeled datasets</strong>: Labeled molecular property datasets typically contain only hundreds or thousands of molecules, making supervised learning alone insufficient.</li>
<li><strong>No standardized evaluation protocol</strong>: Different papers use different data splits (scaffold vs. random), different splitting implementations, different numbers of repetitions (3 to 50), and sometimes do not share their test sets. This makes direct comparison across models infeasible.</li>
<li><strong>Unclear design choices</strong>: With many possible configurations for pre-training data, chemical language, tokenization, positional embeddings, model size, pre-training objectives, and fine-tuning approaches, the field lacks systematic analyses to guide practitioners.</li>
</ol>
<p>The authors note that standard machine learning methods with fixed-size molecular fingerprints remain strong baselines for real-world datasets, illustrating that the promise of transformers for MPP has not yet been fully realized.</p>
<h2 id="seven-design-questions-for-molecular-transformers">Seven Design Questions for Molecular Transformers</h2>
<p>The central organizing framework of this review addresses seven questions practitioners must answer when building a transformer for MPP. For each, the authors synthesize findings across the 16 reviewed models.</p>
<h3 id="reviewed-models">Reviewed Models</h3>
<p>The paper catalogs 16 models organized by architecture:</p>
<table>
  <thead>
      <tr>
          <th>Architecture</th>
          <th>Base Model</th>
          <th>Models</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Encoder-Decoder</td>
          <td>Transformer, BART</td>
          <td><a href="/notes/chemistry/molecular-representations/encoders/smiles-transformer/">ST</a>, Transformer-CNN, <a href="/notes/chemistry/molecular-representations/encoders/x-mol-pretraining-molecular-understanding/">X-Mol</a>, <a href="/notes/chemistry/molecular-design/generation/autoregressive/chemformer/">ChemFormer</a></td>
      </tr>
      <tr>
          <td>Encoder-Only</td>
          <td>BERT</td>
          <td><a href="/notes/chemistry/molecular-representations/encoders/smiles-bert/">SMILES-BERT</a>, MAT, <a href="/notes/chemistry/molecular-representations/encoders/molbert-molecular-representations/">MolBERT</a>, Mol-BERT, Chen et al., K-BERT, FP-BERT, <a href="/notes/chemistry/molecular-representations/encoders/molformer/">MolFormer</a></td>
      </tr>
      <tr>
          <td>Encoder-Only</td>
          <td>RoBERTa</td>
          <td><a href="/notes/chemistry/molecular-representations/encoders/chemberta/">ChemBERTa</a>, <a href="/notes/chemistry/molecular-representations/encoders/chemberta-2/">ChemBERTa-2</a>, <a href="/notes/chemistry/molecular-representations/encoders/selformer/">SELFormer</a></td>
      </tr>
      <tr>
          <td>Decoder-Only</td>
          <td>XLNet</td>
          <td><a href="/notes/chemistry/molecular-design/property-prediction/regression-transformer/">Regression Transformer</a> (RT)</td>
      </tr>
  </tbody>
</table>
<p>The core attention mechanism shared by all these models is the scaled dot-product attention:</p>
<p>$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^{T}}{\sqrt{d_{k}}}\right)V
$$</p>
<p>where $Q$, $K$, and $V$ are the query, key, and value matrices, and $d_{k}$ is the dimension of the key vectors.</p>
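<p>For readers who prefer code to notation, the formula can be written out in plain Python. This is a minimal single-head sketch on lists of lists, with no masking, batching, or learned projections; the toy matrices are illustrative only.</p>

```python
import math


def softmax(row):
    """Numerically stable softmax over one list of scores."""
    peak = max(row)
    exps = [math.exp(x - peak) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention(Q, K, V):
    """Scaled dot-product attention for lists-of-lists Q, K, V."""
    d_k = len(K[0])
    output = []
    for q in Q:
        # One score per key: dot product scaled by sqrt(d_k).
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K]
        weights = softmax(scores)
        # Weighted sum of value vectors, one output column at a time.
        output.append(
            [sum(w * v[j] for w, v in zip(weights, V)) for j in range(len(V[0]))]
        )
    return output


Q = [[1.0, 0.0], [0.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0]]
print(attention(Q, K, V))
```

<p>Each output row is a convex combination of the rows of $V$, weighted toward the value whose key best matches the query.</p>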
<h3 id="question-1-which-database-and-how-many-molecules">Question 1: Which Database and How Many Molecules?</h3>
<p>Pre-training data sources vary considerably. The three main databases are ZINC (37 billion molecules in ZINC22), ChEMBL (2.4 million unique molecules with 20 million bioactivity measurements), and PubChem (111 million unique molecules). Pre-training set sizes ranged from 900K (ST on ChEMBL) to 1.1B molecules (MolFormer on ZINC + PubChem).</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Database</th>
          <th>Size</th>
          <th>Language</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ST</td>
          <td>ChEMBL</td>
          <td>900K</td>
          <td>SMILES</td>
      </tr>
      <tr>
          <td><a href="/notes/chemistry/molecular-representations/encoders/molbert-molecular-representations/">MolBERT</a></td>
          <td>ChEMBL (<a href="/notes/chemistry/molecular-design/generation/evaluation/guacamol-benchmarking-de-novo-molecular-design/">GuacaMol</a>)</td>
          <td>1.6M</td>
          <td>SMILES</td>
      </tr>
      <tr>
          <td><a href="/notes/chemistry/molecular-representations/encoders/chemberta/">ChemBERTa</a></td>
          <td>PubChem</td>
          <td>100K-10M</td>
          <td>SMILES, SELFIES</td>
      </tr>
      <tr>
          <td><a href="/notes/chemistry/molecular-representations/encoders/chemberta-2/">ChemBERTa-2</a></td>
          <td>PubChem</td>
          <td>5M-77M</td>
          <td>SMILES</td>
      </tr>
      <tr>
          <td>MAT</td>
          <td>ZINC</td>
          <td>2M</td>
          <td>List of atoms</td>
      </tr>
      <tr>
          <td><a href="/notes/chemistry/molecular-representations/encoders/molformer/">MolFormer</a></td>
          <td>ZINC + PubChem</td>
          <td>1.1B</td>
          <td>SMILES</td>
      </tr>
      <tr>
          <td>Chen et al.</td>
          <td>C, CP, CPZ</td>
          <td>2M-775M</td>
          <td>SMILES</td>
      </tr>
  </tbody>
</table>
<p>A key finding is that larger pre-training datasets do not consistently improve downstream performance. MolFormer showed minimal difference between models trained on 100M vs. 1.1B molecules. ChemBERTa-2 found that an MLM model pre-trained on 5M molecules performed comparably to one pre-trained on 77M molecules on BBBP (both around 0.70 ROC-AUC). Chen et al. reported comparable $R^{2}$ values of $0.925 \pm 0.01$, $0.917 \pm 0.012$, and $0.915 \pm 0.01$ for ESOL with pre-training datasets of 2M, 103M, and 775M molecules, respectively. The data composition and covered chemical space appear to matter more than raw size.</p>
<h3 id="question-2-which-chemical-language">Question 2: Which Chemical Language?</h3>
<p>Most models use SMILES. ChemBERTa, RT, and SELFormer also explored SELFIES. MAT uses a simple list of atoms with structural features, while Mol-BERT and FP-BERT use circular fingerprints.</p>
<p>Direct comparisons between SMILES and SELFIES (by ChemBERTa on Tox21 SR-p53 and RT for drug-likeness prediction) found no significant performance difference. The RT authors reported that SELFIES models performed approximately $0.004 \pm 0.01$ better on RMSE, while SMILES models performed approximately $0.004 \pm 0.01$ better on Pearson correlation. The choice of chemical language does not appear to be a major factor in prediction performance, and even non-string representations (atom lists in MAT, fingerprints in Mol-BERT) perform competitively.</p>
<h3 id="question-3-how-to-tokenize">Question 3: How to Tokenize?</h3>
<p>Tokenization methods span atom-level (42-66 vocabulary tokens), regex-based (47-2,362 tokens), BPE (509-52K tokens), and substructure-based (3,357-13,325 tokens) approaches. No systematic comparison of tokenization strategies exists in the literature. The vocabulary size varied dramatically, from 42 tokens for MolBERT to over 52K for ChemBERTa. The authors argue that chemically meaningful tokenization (e.g., functional group-based fragmentation) could improve both performance and explainability.</p>
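<p>To make the atom-level and regex-based options concrete, here is a small tokenizer sketch. The pattern is a simplified illustration written for this note, not the exact expression used by any of the reviewed models; it keeps bracket atoms and two-letter halogens together and splits everything else into single symbols.</p>

```python
import re

# Simplified pattern: bracket atoms, two-letter halogens, two-digit ring
# closures, then single letters, digits, and bond/branch symbols.
SMILES_TOKEN = re.compile(r"\[[^\]]+\]|Br|Cl|%\d{2}|[A-Za-z]|\d|[=#\-+()/\\.@:~*$]")


def tokenize(smiles):
    """Split a SMILES string into tokens and check nothing was dropped."""
    tokens = SMILES_TOKEN.findall(smiles)
    if "".join(tokens) != smiles:
        raise ValueError("unrecognized characters in SMILES: " + smiles)
    return tokens


print(tokenize("CC(=O)Oc1ccccc1C(=O)O"))
print(tokenize("C[C@@H](N)C(=O)O"))
print(tokenize("ClCCBr"))
```

<p>A purely character-level tokenizer would instead split <code>Cl</code> into two symbols and break <code>[C@@H]</code> into six, which is one reason vocabulary sizes and sequence lengths differ so much across the reviewed models.</p>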
<h3 id="question-4-how-to-add-positional-embeddings">Question 4: How to Add Positional Embeddings?</h3>
<p>Most models inherited the absolute positional embedding from their NLP base models. MolBERT and RT adopted relative positional embeddings. MolFormer combined absolute and Rotary Positional Embedding (RoPE). MAT incorporated spatial information (inter-atomic 3D distances and adjacency) alongside self-attention.</p>
<p>MolFormer&rsquo;s comparison showed that RoPE became superior to absolute embeddings only when the pre-training dataset was very large. The performance difference (MAE on QM9) between absolute and RoPE embeddings for models trained on 111K, 111M, and 1.1B molecules was approximately $-0.20 \pm 0.18$, $-0.44 \pm 0.22$, and $0.27 \pm 0.12$, respectively.</p>
<p>The authors highlight that SMILES and SELFIES are linearizations of a 2D molecular graph, so consecutive tokens in a sequence are not necessarily spatially close. Positional embeddings that reflect 2D or 3D molecular structure remain underexplored.</p>
<h3 id="question-5-how-many-parameters">Question 5: How Many Parameters?</h3>
<p>Model sizes range from approximately 7M (ST, Mol-BERT) to over 100M parameters (MAT). Most chemical language models operate with 100M parameters or fewer, much smaller than NLP models like BERT (110M-330M) or GPT-3 (175B).</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Dimensions</th>
          <th>Heads</th>
          <th>Layers</th>
          <th>Parameters</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ST</td>
          <td>256</td>
          <td>4</td>
          <td>4</td>
          <td>7M</td>
      </tr>
      <tr>
          <td><a href="/notes/chemistry/molecular-representations/encoders/molbert-molecular-representations/">MolBERT</a></td>
          <td>768</td>
          <td>12</td>
          <td>12</td>
          <td>85M</td>
      </tr>
      <tr>
          <td><a href="/notes/chemistry/molecular-representations/encoders/molformer/">MolFormer</a></td>
          <td>768</td>
          <td>12</td>
          <td>6, 12</td>
          <td>43M, 85M</td>
      </tr>
      <tr>
          <td><a href="/notes/chemistry/molecular-representations/encoders/selformer/">SELFormer</a></td>
          <td>768</td>
          <td>12, 4</td>
          <td>8, 12</td>
          <td>57M, 85M</td>
      </tr>
      <tr>
          <td>MAT</td>
          <td>1024</td>
          <td>16</td>
          <td>8</td>
          <td>101M</td>
      </tr>
      <tr>
          <td><a href="/notes/chemistry/molecular-representations/encoders/chemberta/">ChemBERTa</a></td>
          <td>768</td>
          <td>12</td>
          <td>6</td>
          <td>43M</td>
      </tr>
  </tbody>
</table>
<p>SELFormer and MolFormer both tested different model sizes. SELFormer&rsquo;s larger model (approximately 86M parameters) scored approximately 0.034 higher ROC-AUC on BBBP than its smaller model. MolFormer&rsquo;s larger model (approximately 87M parameters) scored approximately 0.04 higher ROC-AUC on average across BBBP, HIV, BACE, and SIDER. The field lacks the systematic scaling analyses (analogous to Kaplan et al. and Hoffmann et al. in NLP) needed to establish proper scaling laws for chemical language models.</p>
<h3 id="question-6-which-pre-training-objectives">Question 6: Which Pre-training Objectives?</h3>
<p>Pre-training objectives fall into domain-agnostic and domain-specific categories:</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Pre-training Objective</th>
          <th>Fine-tuning</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="/notes/chemistry/molecular-representations/encoders/molformer/">MolFormer</a></td>
          <td>MLM</td>
          <td>Frozen, Update</td>
      </tr>
      <tr>
          <td><a href="/notes/chemistry/molecular-representations/encoders/smiles-bert/">SMILES-BERT</a></td>
          <td>MLM</td>
          <td>Update</td>
      </tr>
      <tr>
          <td><a href="/notes/chemistry/molecular-representations/encoders/molbert-molecular-representations/">MolBERT</a></td>
          <td>MLM, PhysChemPred, SMILES-EQ</td>
          <td>Frozen, Update</td>
      </tr>
      <tr>
          <td>K-BERT</td>
          <td>Atom feature, MACCS prediction, CL</td>
          <td>Update last layer</td>
      </tr>
      <tr>
          <td><a href="/notes/chemistry/molecular-representations/encoders/chemberta-2/">ChemBERTa-2</a></td>
          <td>MLM, MTR</td>
          <td>Update</td>
      </tr>
      <tr>
          <td>MAT</td>
          <td>MLM, 2D Adjacency, 3D Distance</td>
          <td>Update</td>
      </tr>
      <tr>
          <td><a href="/notes/chemistry/molecular-design/generation/autoregressive/chemformer/">ChemFormer</a></td>
          <td>Denoising Span MLM, Augmentation</td>
          <td>Update</td>
      </tr>
      <tr>
          <td><a href="/notes/chemistry/molecular-design/property-prediction/regression-transformer/">RT</a></td>
          <td>PLM (Permutation Language Modeling)</td>
          <td>-</td>
      </tr>
  </tbody>
</table>
<p>Domain-specific objectives (predicting physico-chemical properties, atom features, or MACCS keys) showed promising but inconsistent results. MolBERT&rsquo;s PhysChemPred performed closely to the full three-objective model (approximately $0.72 \pm 0.06$ vs. $0.71 \pm 0.06$ ROC-AUC in virtual screening). The SMILES-EQ objective (identifying equivalent SMILES) was found to lower performance when combined with other objectives. K-BERT&rsquo;s contrastive learning objective did not significantly change performance (average ROC-AUC of 0.806 vs. 0.807 with and without CL).</p>
<p>ChemBERTa-2&rsquo;s Multi-Task Regression (MTR) objective performed noticeably better than MLM-only pre-training on nearly all of the four classification tasks, across pre-training dataset sizes.</p>
<h3 id="question-7-how-to-fine-tune">Question 7: How to Fine-tune?</h3>
<p>Fine-tuning through weight updates generally outperforms frozen representations. SELFormer showed this most dramatically, with a difference of 2.187 RMSE between frozen and updated models on FreeSolv. MolBERT showed a much smaller difference (0.575 RMSE on FreeSolv), likely because its domain-specific pre-training objectives already produced representations closer to the downstream tasks.</p>
<h2 id="benchmarking-challenges-and-performance-comparison">Benchmarking Challenges and Performance Comparison</h2>
<h3 id="downstream-datasets">Downstream Datasets</h3>
<p>The review focuses on nine benchmark datasets across three categories from <a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a>:</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>Molecules</th>
          <th>Tasks</th>
          <th>Type</th>
          <th>Application</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ESOL</td>
          <td>1,128</td>
          <td>1 regression</td>
          <td>Physical chemistry</td>
          <td>Aqueous solubility</td>
      </tr>
      <tr>
          <td>FreeSolv</td>
          <td>642</td>
          <td>1 regression</td>
          <td>Physical chemistry</td>
          <td>Hydration free energy</td>
      </tr>
      <tr>
          <td>Lipophilicity</td>
          <td>4,200</td>
          <td>1 regression</td>
          <td>Physical chemistry</td>
          <td>LogD at pH 7.4</td>
      </tr>
      <tr>
          <td>BBBP</td>
          <td>2,050</td>
          <td>1 classification</td>
          <td>Physiology</td>
          <td>Blood-brain barrier</td>
      </tr>
      <tr>
          <td>ClinTox</td>
          <td>1,484</td>
          <td>2 classification</td>
          <td>Physiology</td>
          <td>Clinical trial toxicity</td>
      </tr>
      <tr>
          <td>SIDER</td>
          <td>1,427</td>
          <td>27 classification</td>
          <td>Physiology</td>
          <td>Drug side effects</td>
      </tr>
      <tr>
          <td>Tox21</td>
          <td>7,831</td>
          <td>12 classification</td>
          <td>Physiology</td>
          <td>Nuclear receptor/stress pathways</td>
      </tr>
      <tr>
          <td>BACE</td>
          <td>1,513</td>
          <td>1 classification</td>
          <td>Biophysics</td>
          <td>Beta-secretase 1 binding</td>
      </tr>
      <tr>
          <td>HIV</td>
          <td>41,127</td>
          <td>1 classification</td>
          <td>Biophysics</td>
          <td>Anti-HIV activity</td>
      </tr>
  </tbody>
</table>
<h3 id="inconsistencies-in-evaluation">Inconsistencies in Evaluation</h3>
<p>The authors document substantial inconsistencies that prevent fair model comparison:</p>
<ol>
<li><strong>Data splitting</strong>: Models used different splitting methods (scaffold vs. random) and different implementations even when using the same method. Not all models adhered to scaffold splitting for classification tasks as recommended.</li>
<li><strong>Different test sets</strong>: Even models using the same split type may not evaluate on identical test molecules due to different random seeds.</li>
<li><strong>Varying repetitions</strong>: Repetitions ranged from 3 (RT) to 50 (Chen et al.), making some analyses more statistically robust than others.</li>
<li><strong>Metric inconsistency</strong>: Most use ROC-AUC for classification and RMSE for regression, but some models report only averages without standard deviations, while others report standard errors.</li>
</ol>
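<p>The difference between reporting a standard deviation and a standard error (item 4) is easy to overlook, so a short sketch may help. The scores below are hypothetical values chosen for illustration, not results from any reviewed paper.</p>

```python
import math
import statistics


def summarize(scores):
    """Return mean, sample standard deviation, and standard error."""
    mean = statistics.mean(scores)
    std = statistics.stdev(scores)
    sem = std / math.sqrt(len(scores))
    return mean, std, sem


# Hypothetical ROC-AUC values from three repetitions of one model.
runs = [0.70, 0.72, 0.74]
mean, std, sem = summarize(runs)
print(f"mean={mean:.3f} std={std:.3f} sem={sem:.3f}")
```

<p>Because the standard error shrinks with the square root of the number of repetitions, a model evaluated over 50 runs can show much tighter error bars than one evaluated over 3 runs even when the run-to-run spread is identical. Comparing one paper&rsquo;s standard deviations against another&rsquo;s standard errors is therefore misleading.</p>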
<h3 id="performance-findings">Performance Findings</h3>
<p>When comparing only models evaluated on the same test sets (Figure 2 in the paper), the authors observe that transformer models show comparable, but not consistently superior, performance to existing ML and DL models. The performance varies considerably across models and datasets.</p>
<p>For BBBP, Mol-BERT reported a lower ROC-AUC than its corresponding MPNN (approximately 0.88 vs. 0.91), while MolBERT outperformed both its corresponding CDDD model (approximately 0.86 vs. 0.76 ROC-AUC) and its SVM baseline (approximately 0.86 vs. 0.70 ROC-AUC). A similar mixed pattern appeared for HIV: ChemBERTa performed worse than its corresponding ML models, while MolBERT outperformed its ML baseline (by approximately 0.08 ROC-AUC) and its DL baseline (by approximately 0.03 ROC-AUC). For SIDER, Mol-BERT scored approximately 0.1 higher ROC-AUC than its corresponding MPNN. For regression, MAT and MolBERT improved on their ML and DL baselines on ESOL, FreeSolv, and Lipophilicity. For example, on ESOL, MAT achieved an RMSE approximately 0.2 lower than an SVM model and approximately 0.03 lower than the Weave model.</p>
<h2 id="key-takeaways-and-future-directions">Key Takeaways and Future Directions</h2>
<p>The review concludes with six main takeaways:</p>
<ol>
<li><strong>Performance</strong>: Transformers using SMILES show comparable but not consistently superior performance to existing ML and DL models for MPP.</li>
<li><strong>Scaling</strong>: No systematic analysis of model parameter scaling relative to data size exists for chemical language models. Such analysis is essential.</li>
<li><strong>Pre-training data</strong>: Dataset size alone does not determine downstream performance. Composition and chemical space coverage matter.</li>
<li><strong>Chemical language</strong>: SMILES and SELFIES perform similarly. Alternative representations (atom lists, fingerprints) also work when the architecture is adjusted.</li>
<li><strong>Domain knowledge</strong>: Domain-specific pre-training objectives show promise, but tokenization and positional encoding remain underexplored.</li>
<li><strong>Benchmarking</strong>: The community needs standardized data splitting, fixed test sets, statistical analysis, and consistent reporting to enable meaningful comparison.</li>
</ol>
<p>The authors also highlight the need for attention visualization and explainability analysis, investigation of NLP-originated techniques (pre-training regimes, fine-tuning strategies like LoRA, explainability methods), and adaptation of these techniques to the specific characteristics of chemical data (smaller vocabularies, shorter sequences).</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>This is a review paper. No new data or models are introduced. All analyses use previously reported results from the 16 reviewed papers, with additional visualization and comparison. The authors provide a GitHub repository with the code and data used to generate their comparative figures.</p>
<h3 id="algorithms">Algorithms</h3>
<p>Not applicable (review paper). The paper describes training strategies at a conceptual level, referencing the original publications for implementation details.</p>
<h3 id="models">Models</h3>
<p>Not applicable (review paper). The paper catalogs 16 models with their architecture details, parameter counts, and training configurations across Tables 1, 4, 5, 6, and 7.</p>
<h3 id="evaluation">Evaluation</h3>
<p>The paper compiles performance across nine MoleculeNet datasets. Key comparison figures (Figures 2 and 7) are restricted to models evaluated on the same test sets for fair comparison, using ROC-AUC for classification and RMSE for regression.</p>
<h3 id="hardware">Hardware</h3>
<p>Not applicable (review paper).</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/volkamerlab/Transformers4MPP_review">Transformers4MPP_review</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Figure generation code and compiled data</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Sultan, A., Sieg, J., Mathea, M., &amp; Volkamer, A. (2024). Transformers for Molecular Property Prediction: Lessons Learned from the Past Five Years. <em>Journal of Chemical Information and Modeling</em>, 64(16), 6259-6280. <a href="https://doi.org/10.1021/acs.jcim.4c00747">https://doi.org/10.1021/acs.jcim.4c00747</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{sultan2024transformers,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Transformers for Molecular Property Prediction: Lessons Learned from the Past Five Years}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Sultan, Afnan and Sieg, Jochen and Mathea, Miriam and Volkamer, Andrea}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Journal of Chemical Information and Modeling}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{64}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{16}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{6259--6280}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2024}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{American Chemical Society}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1021/acs.jcim.4c00747}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Transformer-CNN: SMILES Embeddings for QSAR Modeling</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/property-prediction/transformer-cnn-qsar-modeling/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/property-prediction/transformer-cnn-qsar-modeling/</guid><description>Transformer-CNN uses SMILES embeddings from a canonicalization Transformer with a CNN head for interpretable QSAR property prediction.</description><content:encoded><![CDATA[<h2 id="transformer-based-smiles-embeddings-for-property-prediction">Transformer-Based SMILES Embeddings for Property Prediction</h2>
<p>This is a <strong>Method</strong> paper that introduces Transformer-CNN, a two-stage architecture for <a href="https://en.wikipedia.org/wiki/Quantitative_structure%E2%80%93activity_relationship">QSAR</a> (Quantitative Structure-Activity Relationship) modeling. The primary contribution is a transfer learning approach: a Transformer model is first trained on the task of <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> canonicalization (mapping non-canonical SMILES to canonical forms), and the encoder&rsquo;s internal representations are then used as &ldquo;dynamic SMILES embeddings&rdquo; for downstream property prediction via a convolutional neural network (TextCNN). The authors also contribute an interpretability framework based on Layer-wise Relevance Propagation (LRP) that traces predictions back to individual atom contributions.</p>
<h2 id="from-descriptors-to-learned-embeddings-in-qsar">From Descriptors to Learned Embeddings in QSAR</h2>
<p>Traditional QSAR methods rely on hand-engineered molecular descriptors (fragment counts, physicochemical features) coupled with feature selection and classical ML algorithms. While deep learning approaches that operate on raw SMILES strings or molecular graphs have reduced the need for manual feature engineering, they typically require large training datasets to learn effective representations from scratch. QSAR datasets, in contrast, often contain only hundreds of molecules, making it difficult to train end-to-end deep models.</p>
<p>The authors identify two specific gaps. First, existing SMILES-based autoencoders such as <a href="/notes/chemistry/molecular-representations/encoders/cddd-translation-molecular-descriptors/">CDDD</a> (Continuous and Data-Driven molecular Descriptors) produce fixed-length latent vectors, discarding positional information that could be useful for property prediction and interpretation. Second, QSAR models built on deep architectures generally lack interpretability, making it hard to verify that predictions rely on chemically meaningful structural features rather than spurious correlations.</p>
<h2 id="dynamic-smiles-embeddings-via-canonicalization-pre-training">Dynamic SMILES Embeddings via Canonicalization Pre-training</h2>
<p>The core insight is that training a Transformer to perform SMILES canonicalization (a Seq2Seq task mapping non-canonical SMILES to canonical SMILES) produces an encoder whose internal states serve as information-rich, position-dependent molecular embeddings.</p>
<h3 id="pre-training-on-smiles-canonicalization">Pre-training on SMILES Canonicalization</h3>
<p>The Transformer encoder-decoder is trained on approximately 17.7 million canonicalization pairs derived from the <a href="https://en.wikipedia.org/wiki/ChEMBL">ChEMBL</a> database (SMILES with length up to 110 characters). Each molecule is augmented 10 times by generating non-canonical SMILES variants, plus one identity pair where both sides are canonical. The training uses character-level tokenization with a 66-symbol vocabulary covering drug-like molecules including stereochemistry, charges, and inorganic ions.</p>
<p>The Transformer architecture follows Vaswani et al. with 3 layers and 10 self-attention heads. The learning rate schedule follows:</p>
<p>$$\lambda = \text{factor} \cdot \min(1.0,\; \text{step} / \text{warmup}) / \max(\text{step},\; \text{warmup})$$</p>
<p>where factor = 20, warmup = 16,000 steps, and $\lambda$ is clipped at a minimum of $10^{-4}$. Training runs for 10 epochs (275,907 batches per epoch) without early stopping.</p>
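<p>As a minimal sketch, the schedule can be written as a plain Python function. The function name <code>learning_rate</code> and its argument names are illustrative assumptions, not taken from the authors&rsquo; code; the constants are the ones quoted in the text.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">def learning_rate(step, factor=20.0, warmup=16000, floor=1e-4):
    """Warmup-then-decay schedule from the formula above, clipped at a floor."""
    # Linear ramp while step is below warmup, then decay proportional to 1 / step.
    lam = factor * min(1.0, step / warmup) / max(step, warmup)
    # The text states that lambda is clipped at a minimum of 1e-4.
    return max(lam, floor)


peak = learning_rate(16000)        # end of warmup: 20 / 16000
late = learning_rate(2_000_000)    # deep into training: held at the floor
</code></pre></div>
<p>Under this reading, the rate peaks at $20 / 16{,}000 = 1.25 \times 10^{-3}$ at the end of warmup and decays back to the $10^{-4}$ floor by step 200,000, so most of the roughly 2.76 million training steps (10 epochs of 275,907 batches) would run at the floor value.</p>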
<p>On validation with 500,000 generated ChEMBL-like SMILES, the model correctly canonicalizes 83.6% of all samples. Performance drops for stereochemistry (37.2% for @-containing SMILES) and cis/trans notation (73.9%).</p>
<h3 id="from-encoder-states-to-qsar-predictions">From Encoder States to QSAR Predictions</h3>
<p>After pre-training, the encoder&rsquo;s output for a molecule with $N$ characters is a matrix of dimensions $(N, \text{EMBEDDINGS})$. Unlike fixed-length CDDD descriptors, these &ldquo;dynamic embeddings&rdquo; preserve positional information, meaning equivalent characters receive different embedding values depending on their context and position.</p>
<p>To handle variable-length embeddings, the authors use a TextCNN architecture (from DeepChem) with 1D convolutional filters at kernel sizes (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20) producing (100, 200, 200, 200, 200, 100, 100, 100, 100, 100, 160, 160) filters respectively. After GlobalMaxPool and concatenation, the features pass through Dropout (rate = 0.25), a Dense layer (512 units), a Highway layer, and finally an output layer (1 neuron for regression, 2 for classification).</p>
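<p>The reason global max pooling removes the dependence on SMILES length can be shown with a short sketch. The helper <code>global_max_pool</code> and the toy feature maps are hypothetical; only the kernel sizes and filter counts come from the text, and the sketch assumes each filter contributes exactly one pooled value.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">KERNEL_SIZES = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20)
FILTERS = (100, 200, 200, 200, 200, 100, 100, 100, 100, 100, 160, 160)

# One pooled value per filter, so the concatenated vector has a fixed width.
pooled_width = sum(FILTERS)


def global_max_pool(feature_map):
    """Collapse a (positions x filters) feature map to one value per filter."""
    return [max(column) for column in zip(*feature_map)]


# Toy feature maps with 2 filters for sequences of 3 and 5 positions.
short_map = [[0.1, 0.9], [0.4, 0.2], [0.3, 0.5]]
long_map = [[0.0, 0.1], [0.7, 0.3], [0.2, 0.8], [0.6, 0.4], [0.5, 0.5]]
short_vec = global_max_pool(short_map)
long_vec = global_max_pool(long_map)
</code></pre></div>
<p>Both toy inputs pool to a vector of length 2 regardless of sequence length. With the filter counts listed above, the concatenated vector entering the Dropout and Dense layers would have 1,720 features under this assumption.</p>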
<p>The Transformer weights are frozen during QSAR training. The Adam optimizer is used with a fixed learning rate of $10^{-4}$ and early stopping on a 10% held-out validation set. Critically, SMILES augmentation ($n = 10$) is applied during both training and inference, with the final prediction being the average over augmented SMILES for each molecule.</p>
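<p>The inference-time averaging can be sketched as follows. The function <code>predict_with_augmentation</code> and the stand-in predictor <code>toy_model</code> are hypothetical; in the actual pipeline the per-string predictor would be the frozen encoder followed by the trained CNN, and the variants would come from a SMILES randomizer.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">from statistics import mean


def predict_with_augmentation(smiles_variants, predict_one):
    """Average a per-string predictor over the augmented SMILES of one molecule."""
    if not smiles_variants:
        raise ValueError("need at least one SMILES string")
    return mean(predict_one(s) for s in smiles_variants)


def toy_model(smiles):
    """Stand-in predictor (an assumption): depends on the string, not the molecule."""
    return len(smiles) / 10


# Three SMILES strings that all denote ethanol.
ethanol_variants = ["CCO", "OCC", "C(O)C"]
prediction = predict_with_augmentation(ethanol_variants, toy_model)
</code></pre></div>
<p>Because a string-based model can give different outputs for different spellings of the same molecule, averaging turns those per-string outputs into a single per-molecule estimate.</p>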
<h3 id="interpretability-via-layer-wise-relevance-propagation">Interpretability via Layer-wise Relevance Propagation</h3>
<p>The LRP algorithm propagates relevance scores from the output back through the CNN layers to the Transformer encoder output (which is position-wise). The relevance conservation property holds:</p>
<p>$$y = R = f(x) = \sum_{l \in (L)} R_{l} = \sum_{l \in (L-1)} R_{l} = \cdots = \sum_{l \in (1)} R_{l}$$</p>
<p>In practice, biases absorb some relevance, so the total propagated to the input is less than the output:</p>
<p>$$\sum_{l \in (L)} R_{l} = \sum_{l \in (L-1)} R_{l} + B$$</p>
<p>For gated connections in the Highway block, the authors implement the signal-take-all redistribution rule. The interpretation algorithm generates one SMILES per non-hydrogen atom (each drawn starting from that atom), runs LRP on each, and averages contributions. If more than 50% of relevance dissipates on biases, the interpretation may be unreliable, serving as an applicability domain indicator.</p>
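<p>The bias term $B$ in the conservation equation can be made concrete with a small sketch for a single dense layer. This uses the basic LRP z-rule without a stabilizer, which is an assumption for illustration; it is not the signal-take-all rule used for the Highway block, and the function name <code>lrp_dense</code> is hypothetical.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">def lrp_dense(x, weights, bias, relevance_out):
    """Redistribute relevance through one dense layer with the basic z-rule.

    weights[i][j] connects input i to output j. Returns the relevance assigned
    to each input and the total relevance absorbed by the biases.
    """
    relevance_in = [0.0] * len(x)
    absorbed_by_bias = 0.0
    for j, r_out in enumerate(relevance_out):
        contributions = [x[i] * weights[i][j] for i in range(len(x))]
        z = sum(contributions) + bias[j]
        if z == 0.0:
            continue  # degenerate neuron: skipped in this sketch
        for i, c in enumerate(contributions):
            relevance_in[i] += c / z * r_out
        absorbed_by_bias += bias[j] / z * r_out
    return relevance_in, absorbed_by_bias


x = [1.0, 2.0]
weights = [[0.5], [0.25]]
bias = [1.0]
total_out = 2.0
relevance_in, absorbed = lrp_dense(x, weights, bias, relevance_out=[total_out])

# Flag the interpretation when more than half of the relevance is lost to biases.
unreliable = absorbed / total_out > 0.5
</code></pre></div>
<p>In this toy layer each input receives 0.5 and the bias absorbs 1.0, so input relevance plus $B$ equals the output relevance of 2.0. Exactly half of the relevance is dissipated, which sits at the 50% threshold without exceeding it.</p>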
<h2 id="benchmarks-across-18-regression-and-classification-datasets">Benchmarks Across 18 Regression and Classification Datasets</h2>
<p>The authors evaluate on the same 18 datasets (9 regression, 9 classification) used in their previous SMILES augmentation study, enabling direct comparison. All experiments use five-fold cross-validation.</p>
<h3 id="regression-results-r2">Regression Results ($r^2$)</h3>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th style="text-align: center">Descriptor-based</th>
          <th style="text-align: center">SMILES-based (augm=10)</th>
          <th style="text-align: center">Transformer-CNN (no augm)</th>
          <th style="text-align: center">Transformer-CNN (augm=10)</th>
          <th style="text-align: center">CDDD</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MP (19,104)</td>
          <td style="text-align: center">0.83</td>
          <td style="text-align: center">0.85</td>
          <td style="text-align: center">0.83</td>
          <td style="text-align: center"><strong>0.86</strong></td>
          <td style="text-align: center">0.85</td>
      </tr>
      <tr>
          <td>BP (11,893)</td>
          <td style="text-align: center">0.98</td>
          <td style="text-align: center">0.98</td>
          <td style="text-align: center">0.97</td>
          <td style="text-align: center"><strong>0.98</strong></td>
          <td style="text-align: center">0.98</td>
      </tr>
      <tr>
          <td>BCF (378)</td>
          <td style="text-align: center">0.85</td>
          <td style="text-align: center">0.85</td>
          <td style="text-align: center">0.71</td>
          <td style="text-align: center"><strong>0.85</strong></td>
          <td style="text-align: center">0.81</td>
      </tr>
      <tr>
          <td>FreeSolv (642)</td>
          <td style="text-align: center"><strong>0.94</strong></td>
          <td style="text-align: center">0.93</td>
          <td style="text-align: center">0.72</td>
          <td style="text-align: center">0.91</td>
          <td style="text-align: center">0.93</td>
      </tr>
      <tr>
          <td>LogS (1,311)</td>
          <td style="text-align: center"><strong>0.92</strong></td>
          <td style="text-align: center">0.92</td>
          <td style="text-align: center">0.85</td>
          <td style="text-align: center">0.91</td>
          <td style="text-align: center">0.91</td>
      </tr>
      <tr>
          <td>Lipo (4,200)</td>
          <td style="text-align: center">0.70</td>
          <td style="text-align: center">0.72</td>
          <td style="text-align: center">0.60</td>
          <td style="text-align: center">0.73</td>
          <td style="text-align: center"><strong>0.74</strong></td>
      </tr>
      <tr>
          <td>BACE (1,513)</td>
          <td style="text-align: center">0.73</td>
          <td style="text-align: center">0.72</td>
          <td style="text-align: center">0.66</td>
          <td style="text-align: center"><strong>0.76</strong></td>
          <td style="text-align: center">0.75</td>
      </tr>
      <tr>
          <td>DHFR (739)</td>
          <td style="text-align: center">0.62</td>
          <td style="text-align: center">0.63</td>
          <td style="text-align: center">0.46</td>
          <td style="text-align: center"><strong>0.67</strong></td>
          <td style="text-align: center">0.61</td>
      </tr>
      <tr>
          <td>LEL (483)</td>
          <td style="text-align: center">0.19</td>
          <td style="text-align: center">0.25</td>
          <td style="text-align: center">0.20</td>
          <td style="text-align: center"><strong>0.27</strong></td>
          <td style="text-align: center">0.23</td>
      </tr>
  </tbody>
</table>
<h3 id="classification-results-auc">Classification Results (AUC)</h3>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th style="text-align: center">Descriptor-based</th>
          <th style="text-align: center">SMILES-based (augm=10)</th>
          <th style="text-align: center">Transformer-CNN (no augm)</th>
          <th style="text-align: center">Transformer-CNN (augm=10)</th>
          <th style="text-align: center">CDDD</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>HIV (41,127)</td>
          <td style="text-align: center">0.82</td>
          <td style="text-align: center">0.78</td>
          <td style="text-align: center">0.81</td>
          <td style="text-align: center"><strong>0.83</strong></td>
          <td style="text-align: center">0.74</td>
      </tr>
      <tr>
          <td>AMES (6,542)</td>
          <td style="text-align: center">0.86</td>
          <td style="text-align: center">0.88</td>
          <td style="text-align: center">0.86</td>
          <td style="text-align: center"><strong>0.89</strong></td>
          <td style="text-align: center">0.86</td>
      </tr>
      <tr>
          <td>BACE (1,513)</td>
          <td style="text-align: center">0.88</td>
          <td style="text-align: center">0.89</td>
          <td style="text-align: center">0.89</td>
          <td style="text-align: center"><strong>0.91</strong></td>
          <td style="text-align: center">0.90</td>
      </tr>
      <tr>
          <td>ClinTox (1,478)</td>
          <td style="text-align: center"><strong>0.77</strong></td>
          <td style="text-align: center">0.76</td>
          <td style="text-align: center">0.71</td>
          <td style="text-align: center">0.77</td>
          <td style="text-align: center">0.73</td>
      </tr>
      <tr>
          <td>Tox21 (7,831)</td>
          <td style="text-align: center">0.79</td>
          <td style="text-align: center"><strong>0.83</strong></td>
          <td style="text-align: center">0.81</td>
          <td style="text-align: center">0.82</td>
          <td style="text-align: center">0.82</td>
      </tr>
      <tr>
          <td>BBBP (2,039)</td>
          <td style="text-align: center">0.90</td>
          <td style="text-align: center">0.91</td>
          <td style="text-align: center">0.90</td>
          <td style="text-align: center"><strong>0.92</strong></td>
          <td style="text-align: center">0.89</td>
      </tr>
      <tr>
          <td>JAK3 (886)</td>
          <td style="text-align: center">0.79</td>
          <td style="text-align: center"><strong>0.80</strong></td>
          <td style="text-align: center">0.70</td>
          <td style="text-align: center">0.78</td>
          <td style="text-align: center">0.76</td>
      </tr>
      <tr>
          <td>BioDeg (1,737)</td>
          <td style="text-align: center">0.92</td>
          <td style="text-align: center"><strong>0.93</strong></td>
          <td style="text-align: center">0.91</td>
          <td style="text-align: center">0.93</td>
          <td style="text-align: center">0.92</td>
      </tr>
      <tr>
          <td>RP AR (930)</td>
          <td style="text-align: center">0.85</td>
          <td style="text-align: center"><strong>0.87</strong></td>
          <td style="text-align: center">0.83</td>
          <td style="text-align: center">0.87</td>
          <td style="text-align: center">0.86</td>
      </tr>
  </tbody>
</table>
<h3 id="key-comparisons">Key Comparisons</h3>
<p>Baselines include descriptor-based methods (the best from LibSVM, Random Forest, XGBoost, ASNN, and DNNs), direct SMILES-based models with augmentation, and CDDD descriptors analyzed by the same classical ML methods. CDDD descriptors come from the Sml2canSml autoencoder approach, which produces fixed 512-dimensional vectors.</p>
<p>Transformer-CNN with augmentation matches or exceeds every baseline on 13 of the 18 datasets in the tables above (6 of 9 regression, 7 of 9 classification). Augmentation is essential: without it, Transformer-CNN underperforms substantially (e.g., BCF drops from 0.85 to 0.71, JAK3 from 0.78 to 0.70). This indicates that the internal consensus from multiple SMILES representations of the same molecule is central to the method&rsquo;s effectiveness.</p>
<p>A practical advantage over CDDD is that Transformer-CNN imposes no constraints on molecular properties, since the Transformer was trained on the full diversity of ChEMBL. CDDD, in contrast, requires logP in (-5, 7), molecular weight between 12 and 600, 3-50 heavy atoms, and organic molecules only.</p>
<h3 id="interpretability-case-studies">Interpretability Case Studies</h3>
<p>For <a href="https://en.wikipedia.org/wiki/Ames_test">AMES</a> mutagenicity, the LRP analysis of 1-Bromo-4-nitrobenzene correctly identifies the nitro group and halogen as structural alerts, consistent with known mutagenicity rules. For aqueous solubility of <a href="https://en.wikipedia.org/wiki/Haloperidol">haloperidol</a>, the model assigns positive contributions to hydroxyl, carbonyl, and aliphatic nitrogen groups (which increase solubility) and negative contributions to aromatic carbons (which decrease it). Both cases align with established chemical knowledge, supporting the trustworthiness of the model.</p>
<h2 id="effective-transfer-learning-for-small-qsar-datasets">Effective Transfer Learning for Small QSAR Datasets</h2>
<p>Transformer-CNN achieves competitive or superior QSAR performance across 18 diverse benchmarks by combining three ingredients: (1) Transformer-based pre-training via SMILES canonicalization, (2) SMILES augmentation during training and inference, and (3) a lightweight CNN head. The method requires minimal hyperparameter tuning, as the Transformer weights are frozen and the CNN architecture is fixed.</p>
<p>The authors acknowledge several limitations and future directions:</p>
<ul>
<li>Stereochemistry canonicalization accuracy is low (37.2%), which could impact models for stereo-sensitive properties</li>
<li>The LRP interpretability depends on sufficient relevance propagation (at least 50% reaching the input layer)</li>
<li>The variance among augmented SMILES predictions could serve as a confidence estimate, but this is left to future work</li>
<li>Applicability domain assessment based on SMILES reconstruction quality is proposed but not fully developed</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pre-training</td>
          <td>ChEMBL (SMILES &lt;= 110 chars)</td>
          <td>17.7M pairs</td>
          <td>10x augmentation + 1 identity pair per molecule</td>
      </tr>
      <tr>
          <td>Validation (canon.)</td>
          <td>Generated ChEMBL-like SMILES</td>
          <td>500,000</td>
          <td>From a molecular generator</td>
      </tr>
      <tr>
          <td>QSAR benchmarks</td>
          <td>9 regression + 9 classification</td>
          <td>378-41,127</td>
          <td>Available on OCHEM (<a href="https://ochem.eu">https://ochem.eu</a>)</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Transformer: 3 layers, 10 self-attention heads, character-level tokenization (66 symbols)</li>
<li>TextCNN: 12 kernel sizes (1-10, 15, 20) with 100-200 filters each, GlobalMaxPool, Dense(512), Highway, Dropout(0.25)</li>
<li>Augmentation: n=10 non-canonical SMILES per molecule during training and inference</li>
<li>LRP: signal-take-all redistribution for Highway gates, standard LRP for Dense and Conv layers</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li>Transformer encoder weights pre-trained on canonicalization task (frozen during QSAR training)</li>
<li>QSAR CNN trained with Adam optimizer, learning rate $10^{-4}$, early stopping</li>
<li>Pre-trained embeddings and standalone prediction models available in the GitHub repository</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li>Regression: coefficient of determination $r^2 = 1 - SS_{\text{res}} / SS_{\text{tot}}$</li>
<li>Classification: Area Under the ROC Curve (AUC)</li>
<li>Five-fold cross-validation with bootstrap standard errors</li>
</ul>
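<p>Both evaluation metrics are simple enough to sketch in plain Python. The function names <code>r_squared</code> and <code>roc_auc</code> are illustrative, and the AUC is computed with the pairwise (rank) definition rather than by tracing the ROC curve; the toy values are not from the paper.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">def r_squared(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean_true = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


def roc_auc(labels, scores):
    """Probability that a random positive outscores a random negative (ties count half)."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


r2 = r_squared([1.0, 2.0, 3.0, 4.0], [1.1, 1.9, 3.2, 3.8])
auc = roc_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
</code></pre></div>
<p>In the toy regression case the residual sum of squares is 0.10 against a total of 5.0, giving $r^2 = 0.98$; in the toy classification case three of the four positive-negative pairs are ordered correctly, giving an AUC of 0.75.</p>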
<h3 id="hardware">Hardware</h3>
<ul>
<li>NVIDIA Quadro P6000, Titan Xp, and Titan V GPUs (donated by NVIDIA)</li>
<li>TensorFlow v1.12.0, RDKit v2018.09.2</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/bigchem/transformer-cnn">transformer-cnn</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Source code, pre-trained embeddings, standalone prediction models</td>
      </tr>
      <tr>
          <td><a href="https://ochem.eu">OCHEM</a></td>
          <td>Other</td>
          <td>N/A</td>
          <td>Online platform hosting the method, training datasets, and models</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Karpov, P., Godin, G., &amp; Tetko, I. V. (2020). Transformer-CNN: Swiss knife for QSAR modeling and interpretation. <em>Journal of Cheminformatics</em>, 12, 17. <a href="https://doi.org/10.1186/s13321-020-00423-w">https://doi.org/10.1186/s13321-020-00423-w</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{karpov2020transformer,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Transformer-{CNN}: Swiss knife for {QSAR} modeling and interpretation}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Karpov, Pavel and Godin, Guillaume and Tetko, Igor V.}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Journal of Cheminformatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{12}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{17}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2020}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Springer}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1186/s13321-020-00423-w}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>t-SMILES: Tree-Based Fragment Molecular Encoding</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/notations/t-smiles-fragment-molecular-representation/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/notations/t-smiles-fragment-molecular-representation/</guid><description>t-SMILES encodes fragmented molecules as SMILES-type strings via breadth-first traversal of full binary trees, reducing nesting depth and improving generation.</description><content:encoded><![CDATA[<h2 id="a-fragment-based-molecular-representation-method">A Fragment-Based Molecular Representation Method</h2>
<p>This is a <strong>Method</strong> paper that proposes t-SMILES (tree-based SMILES), a framework for representing molecules as SMILES-type strings derived from fragment-based decompositions. The primary contribution is an encoding algorithm that converts fragmented molecular graphs into full binary trees (FBTs) and then traverses them breadth-first to produce linear strings. Three coding variants are introduced: TSSA (shared atom), TSDY (dummy atom without ID), and TSID (dummy atom with ID). The framework achieves 100% theoretical validity, higher novelty scores, and improved distribution-learning metrics compared to classical <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a>, <a href="/notes/chemistry/molecular-representations/notations/deepsmiles-adaptation-for-ml/">DeepSMILES</a>, and <a href="/notes/chemistry/molecular-representations/notations/selfies/">SELFIES</a> across ChEMBL, ZINC, and <a href="/notes/chemistry/datasets/qm9/">QM9</a> benchmarks.</p>
<h2 id="why-fragment-based-representations-matter-for-molecular-generation">Why Fragment-Based Representations Matter for Molecular Generation</h2>
<p>Classical SMILES encodes molecules via depth-first traversal of the molecular graph, requiring parentheses and ring identifiers to appear in matched pairs with deep nesting. When generative models (LSTM, Transformer) are trained on SMILES, they produce chemically invalid strings, particularly on small datasets, because they struggle to learn these long-range pairing constraints. DeepSMILES addresses some syntactical issues but still permits semantic violations (e.g., oxygen with three bonds). SELFIES guarantees 100% valid strings but at the cost of readability and, as the authors show, lower <a href="/notes/chemistry/molecular-design/generation/evaluation/frechet-chemnet-distance/">FCD</a> scores indicating generated molecules diverge from the training distribution.</p>
<p>Fragment-based approaches reduce the search space compared to atom-level methods and can provide insights into molecular recognition (e.g., protein-ligand interactions). However, existing fragment-based deep learning methods rely on fixed dictionaries of candidate fragments, creating in-vocabulary/out-of-vocabulary problems and high-dimensional sparse representations. The encoding of fragments as SMILES-type strings, rather than dictionary IDs, had not been systematically explored before this work.</p>
<p>The authors draw on the observation that fragments in organic molecules follow a <a href="https://en.wikipedia.org/wiki/Zipf's_law">Zipf-like</a> rank distribution similar to words in natural language, motivating the use of NLP techniques for fragment-based molecular modeling.</p>
<h2 id="core-innovation-binary-tree-encoding-of-fragmented-molecules">Core Innovation: Binary Tree Encoding of Fragmented Molecules</h2>
<p>The t-SMILES algorithm proceeds in three steps:</p>
<ol>
<li><strong>Fragmentation</strong>: A molecule is decomposed into valid chemical fragments using a chosen algorithm (JTVAE, BRICS, <a href="https://en.wikipedia.org/wiki/Matched_molecular_pair_analysis">MMPA</a>, or Scaffold), producing a fragmented molecular graph.</li>
<li><strong>Tree construction</strong>: The fragmented graph is converted into an Acyclic Molecular Tree (AMT), which is a reduced graph where nodes represent fragments and edges represent bonds between them. The AMT is then transformed into a Full Binary Tree (FBT), where every internal node has exactly two children.</li>
<li><strong>String generation</strong>: The FBT is traversed using breadth-first search (BFS) to produce the t-SMILES string.</li>
</ol>
<p>The framework introduces only two new symbols beyond standard SMILES: <code>&amp;</code> marks empty tree nodes (branch terminators providing global structural information), and <code>^</code> separates adjacent substructure segments (analogous to spaces between words in English).</p>
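<p>The tree-construction and traversal steps can be illustrated with a rough sketch. Several parts are assumptions rather than the authors&rsquo; algorithm: the AMT is converted to a binary tree with a left-child/right-sibling scheme, every fragment node is given two child slots with missing children filled by the empty-node marker, and the fragment strings are placeholders. The sketch stops at the breadth-first token list and does not attempt the final string serialization with <code>^</code>, because the exact separator rules are not given here.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">from collections import deque

EMPTY = "&amp;"  # marker for an empty tree node


class Node:
    def __init__(self, fragment, left=None, right=None):
        self.fragment = fragment
        self.left = left
        self.right = right


def lcrs(tree, siblings=()):
    """Left-child/right-sibling conversion of an n-ary tree (assumed scheme).

    A tree is a (fragment, children) pair, where children is a list of trees.
    """
    fragment, children = tree
    left = lcrs(children[0], children[1:]) if children else None
    right = lcrs(siblings[0], siblings[1:]) if siblings else None
    return Node(fragment, left, right)


def bfs_tokens(root):
    """Breadth-first traversal that emits EMPTY for every missing child."""
    tokens = []
    queue = deque([root])
    while queue:
        node = queue.popleft()
        if node is None:
            tokens.append(EMPTY)
            continue
        tokens.append(node.fragment)
        queue.append(node.left)
        queue.append(node.right)
    return tokens


# Hypothetical fragment tree: a ring fragment bonded to two chain fragments.
amt = ("c1ccccc1", [("CC", []), ("CO", [])])
tokens = bfs_tokens(lcrs(amt))
</code></pre></div>
<p>With this padding, a tree of three fragments yields seven tokens, four of them empty-node markers. Because every fragment node always contributes two child slots, the tree shape can be rebuilt from the token order alone, which is one way to see why the marker does not need a matching partner.</p>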
<h3 id="three-coding-variants">Three Coding Variants</h3>
<ul>
<li><strong>TSSA</strong> (shared atom): Two fragments share a real atom at their connection point. Produces the highest novelty scores and is recommended for goal-directed tasks.</li>
<li><strong>TSDY</strong> (dummy atom, no ID): Uses dummy atoms (marked with <code>*</code>) to indicate bonding points. Provides a balanced choice between novelty and distribution fidelity.</li>
<li><strong>TSID</strong> (dummy atom with ID): Uses numbered dummy atoms (<code>[n*]</code>) for unambiguous reconstruction. Produces the most faithful distribution reproduction and is recommended for distribution-learning tasks.</li>
</ul>
<h3 id="structural-advantages">Structural Advantages</h3>
<p>The key structural benefit is a dramatic reduction in nesting depth. For TSDY_M on ChEMBL, the proportion of tokens at nesting depth 0-1-2 increases from 68.0% (SMILES) to 99.3%, while depth 3-4-5 drops from 31.9% to 0.7%, and depth 6-11 drops from 0.1% to 0.0002%. The <code>&amp;</code> symbol, which encodes molecular topology, does not need to appear in pairs (unlike parentheses in SMILES), and its high frequency means it does not create a scarcity problem for learning.</p>
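<p>For classical SMILES, a nesting-depth profile of this kind can be approximated by tracking open parentheses. This is a simplified sketch: it counts characters rather than tokens and considers only parentheses, so it will not reproduce the authors&rsquo; exact statistics. The function name <code>nesting_depth_profile</code> is hypothetical.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">from collections import Counter


def nesting_depth_profile(smiles):
    """Fraction of non-parenthesis characters at each parenthesis nesting depth."""
    depth = 0
    counts = Counter()
    for ch in smiles:
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
        else:
            counts[depth] += 1
    total = sum(counts.values())
    return {d: c / total for d, c in sorted(counts.items())}


# Aspirin written as a SMILES string with two single-level branches.
profile = nesting_depth_profile("CC(=O)Oc1ccccc1C(=O)O")
</code></pre></div>
<p>For this string, 13 of the 17 non-parenthesis characters sit at depth 0 and 4 at depth 1. Aggregating such profiles over a dataset gives depth distributions comparable in spirit to the percentages quoted above.</p>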
<p>The framework also supports a multi-code system where classical SMILES can be integrated as a special case called TS_Vanilla, and multiple fragmentation-based codes can be combined into hybrid models.</p>
<h3 id="reconstruction-and-data-augmentation">Reconstruction and Data Augmentation</h3>
<p>Molecules can be reconstructed from t-SMILES strings by reversing the process: rebuilding the FBT from the string, converting to AMT, and assembling fragments into a molecular graph. This reconstruction process can itself generate novel molecules without any model training by randomly assembling fragments. On ChEMBL, TSSA reconstruction achieves uniqueness above 0.98 and novelty above 0.68 for all four fragmentation algorithms, with 100% validity.</p>
<p>Data augmentation in t-SMILES operates at four levels: (1) different decomposition algorithms, (2) reconstruction, (3) enumeration of fragment strings, and (4) enumeration of FBTs. Unlike <a href="/notes/chemistry/molecular-representations/notations/randomized-smiles-generative-models/">SMILES enumeration</a> (which only produces different strings for the same molecule), t-SMILES reconstruction generates genuinely different molecules from the same fragment set.</p>
<h2 id="systematic-evaluation-across-multiple-benchmarks">Systematic Evaluation Across Multiple Benchmarks</h2>
<p>All experiments use MolGPT (a Transformer-decoder model) as the primary generative model. Three types of metrics are employed: distribution-learning benchmarks, goal-directed benchmarks, and Wasserstein distance metrics for physicochemical properties.</p>
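<p>The validity, uniqueness, and novelty columns reported below can be sketched under the definitions commonly used in generative-model benchmarks: validity over all generated strings, uniqueness over valid strings, and novelty over unique valid strings not seen in training. Whether the authors use exactly these denominators is an assumption, as are the function name <code>generation_metrics</code> and the toy validity check; a real pipeline would parse and canonicalize molecules with a cheminformatics toolkit before comparing them.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">def generation_metrics(generated, training, is_valid):
    """Validity, uniqueness, and novelty for a batch of generated strings."""
    valid = [s for s in generated if is_valid(s)]
    unique = set(valid)
    novel = unique - set(training)
    return {
        "valid": len(valid) / len(generated),
        "unique": len(unique) / len(valid) if valid else 0.0,
        "novelty": len(novel) / len(unique) if unique else 0.0,
    }


def balanced_parentheses(smiles):
    """Toy validity check (an assumption): only compares parenthesis counts."""
    return smiles.count("(") == smiles.count(")")


metrics = generation_metrics(
    generated=["CCO", "CCO", "CCN", "C(("],
    training=["CCO"],
    is_valid=balanced_parentheses,
)
</code></pre></div>
<p>In the toy batch, three of four strings pass the check, two of those three are distinct, and one of the two distinct strings is absent from the training set. A model that memorizes its training data keeps validity high while novelty falls toward zero, which is the overfitting pattern examined in the low-resource experiments.</p>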
<h3 id="low-resource-datasets-jnk3-and-aid1706">Low-Resource Datasets (JNK3 and AID1706)</h3>
<p>On <a href="https://en.wikipedia.org/wiki/MAPK10">JNK3</a> (923 active molecules), the authors investigate overfitting behavior across training epochs:</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Valid</th>
          <th>Novelty</th>
          <th>FCD</th>
          <th>Active Novel</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>SMILES [R200]</td>
          <td>0.795</td>
          <td>0.120</td>
          <td>0.584</td>
          <td>0.072</td>
      </tr>
      <tr>
          <td>SMILES [R2000]</td>
          <td>1.000</td>
          <td>0.001</td>
          <td>0.765</td>
          <td>0.004</td>
      </tr>
      <tr>
          <td>SELFIES [R200]</td>
          <td>1.000</td>
          <td>0.238</td>
          <td>0.544</td>
          <td>0.148</td>
      </tr>
      <tr>
          <td>SELFIES [R2000]</td>
          <td>1.000</td>
          <td>0.008</td>
          <td>0.767</td>
          <td>0.050</td>
      </tr>
      <tr>
          <td>TSSA_S [R300]</td>
          <td>1.000</td>
          <td>0.833</td>
          <td>0.564</td>
          <td>0.582</td>
      </tr>
      <tr>
          <td>TSSA_S [R5000]</td>
          <td>1.000</td>
          <td>0.817</td>
          <td>0.608</td>
          <td>0.564</td>
      </tr>
      <tr>
          <td>TF_TSSA_S [R5]</td>
          <td>1.000</td>
          <td>0.932</td>
          <td>0.483</td>
          <td>0.710</td>
      </tr>
      <tr>
          <td>TSSA_S_Rec50 [R10]</td>
          <td>1.000</td>
          <td>0.962</td>
          <td>0.389</td>
          <td>0.829</td>
      </tr>
  </tbody>
</table>
<p>Key findings: novelty for SMILES, DeepSMILES, and SELFIES collapses toward zero with prolonged training (0.001 for SMILES and 0.008 for SELFIES at 2,000 epochs), while t-SMILES novelty stabilizes around 0.8. The highest active-novel score of 0.829 comes from t-SMILES with reconstruction-based data augmentation. Transfer learning with t-SMILES reaches an active-novel score of 0.710 at 5 epochs versus 0.526 for SMILES, and at 100 epochs the gap widens (0.569 vs. 0.023).</p>
<h3 id="distribution-learning-on-chembl">Distribution Learning on ChEMBL</h3>
<p>t-SMILES models outperform graph baselines (Graph MCTS, hG2G, MGM) and fragment-based methods (FASMIFRA). TSID_B and TSID_S achieve FCD scores of 0.909 while maintaining novelty of 0.941 and 0.933, surpassing SMILES (FCD 0.906, novelty 0.907) in both dimensions. TSDY and TSID models consistently outperform TSSA on distribution fidelity for larger molecules.</p>
<h3 id="goal-directed-tasks-on-chembl">Goal-Directed Tasks on ChEMBL</h3>
<p>On 20 <a href="/notes/chemistry/molecular-design/generation/evaluation/guacamol-benchmarking-de-novo-molecular-design/">GuacaMol</a> subtasks, different fragmentation algorithms excel at different tasks. The goal-directed reconstruction algorithm significantly outperforms random reconstruction. On the <a href="https://en.wikipedia.org/wiki/Sitagliptin">Sitagliptin</a> MPO task (T16.SMPO), the TSDY_M model with goal-directed reconstruction achieves a score of 0.930, compared to 0.598 for SMILES and 0.708 for CReM. On <a href="https://en.wikipedia.org/wiki/Valsartan">Valsartan</a> SMARTS (T18.VS), t-SMILES models reach 0.997 versus 0.985 for SMILES.</p>
<h3 id="distribution-learning-on-zinc-and-qm9">Distribution Learning on ZINC and QM9</h3>
<p>On ZINC, t-SMILES models significantly outperform existing fragment-based baselines (JTVAE, FragDgm). Seven t-SMILES models achieve both higher FCD and novelty scores than SELFIES. On QM9 (smaller molecules), all string-based models achieve high FCD scores (above 0.960), with t-SMILES performing better than existing string and graph approaches.</p>
<h3 id="physicochemical-properties">Physicochemical Properties</h3>
<p>Across ChEMBL and ZINC, TSDY and TSID models capture physicochemical property distributions (MolWt, LogP, SAScore, N_Atoms, N_Rings, etc.) more faithfully than TSSA models. Multiple t-SMILES models outperform SMILES in more than four out of nine property categories. Baseline models hG2G and JTVAE show the weakest pattern learning, producing molecules with fewer atoms and rings than the training data.</p>
<h2 id="key-findings-and-limitations">Key Findings and Limitations</h2>
<h3 id="main-results">Main Results</h3>
<ol>
<li>t-SMILES achieves 100% theoretical validity by fragmenting molecules into chemically valid pieces before encoding.</li>
<li>The framework avoids the overfitting problem on low-resource datasets, maintaining stable novelty scores where SMILES, DeepSMILES, and SELFIES collapse.</li>
<li>The multi-code system allows different coding algorithms to complement each other, with hybrid models accessing broader chemical space.</li>
<li>Goal-directed reconstruction significantly outperforms all baselines on targeted optimization tasks.</li>
<li>TSDY and TSID provide better distribution fidelity than TSSA on larger molecules, while TSSA excels at novelty generation for goal-directed tasks.</li>
</ol>
<h3 id="limitations">Limitations</h3>
<p>The authors acknowledge several limitations:</p>
<ul>
<li>Whether the tree structure of t-SMILES can be effectively learned by Large Language Models remains unexplored.</li>
<li>Only published fragmentation algorithms were tested; custom fragmentation schemes were not investigated.</li>
<li>Experiments on more complex (larger) molecules were not performed.</li>
<li>The reconstruction algorithm uses simple rules for fragment assembly; more sophisticated assembly methods (Monte Carlo tree search, CReM) could improve quality.</li>
</ul>
<h3 id="future-directions">Future Directions</h3>
<p>The authors suggest exploring advanced reconstruction and optimization algorithms, improved generative models, evolutionary techniques, and extending t-SMILES to property prediction, retrosynthesis, and reaction prediction tasks. The framework is also extensible to other string representations (t-DSMILES, t-SELFIES) by changing how fragments are encoded.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Low-resource evaluation</td>
          <td>JNK3</td>
          <td>923 active molecules</td>
          <td>Kinase inhibitors</td>
      </tr>
      <tr>
          <td>Low-resource evaluation</td>
          <td>AID1706</td>
          <td>329 active molecules</td>
          <td>SARS 3CLPro inhibitors</td>
      </tr>
      <tr>
          <td>Distribution learning</td>
          <td>ChEMBL</td>
          <td>Standard split</td>
          <td>Large drug-like molecules</td>
      </tr>
      <tr>
          <td>Distribution learning</td>
          <td>ZINC</td>
          <td>250K subset</td>
          <td>Medium drug-like molecules</td>
      </tr>
      <tr>
          <td>Distribution learning</td>
          <td>QM9</td>
          <td>~134K molecules</td>
          <td>Small organic molecules</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Fragmentation</strong>: JTVAE, BRICS, MMPA, Scaffold (all via <a href="https://en.wikipedia.org/wiki/RDKit">RDKit</a>)</li>
<li><strong>Tree construction</strong>: AMT from reduced graph, then FBT transformation</li>
<li><strong>Traversal</strong>: Breadth-first search on FBT</li>
<li><strong>Generative model</strong>: MolGPT (Transformer decoder)</li>
<li><strong>Discriminative model</strong>: AttentiveFP for activity prediction on JNK3/AID1706</li>
</ul>
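<p>The traversal step above can be pictured with a small sketch. It assumes a toy binary tree whose nodes hold fragment strings and pads a missing child of an internal node with a placeholder token. The <code>Node</code> class, the <code>[EMPTY]</code> token, and the fragment strings are hypothetical illustrations, not the symbols or data structures of the official t-SMILES implementation.</p>

```python
from collections import deque
from dataclasses import dataclass
from typing import List, Optional

EMPTY = "[EMPTY]"  # hypothetical placeholder token, not the symbol used by t-SMILES


@dataclass
class Node:
    fragment: str
    left: Optional["Node"] = None
    right: Optional["Node"] = None


def bfs_tokens(root: Node) -> List[str]:
    """Visit nodes level by level; pad a missing child of an internal node with EMPTY."""
    tokens: List[str] = []
    queue = deque([root])
    while queue:
        node = queue.popleft()
        if node is None:
            tokens.append(EMPTY)
            continue
        tokens.append(node.fragment)
        # Leaves add nothing to the queue; internal nodes always add two slots.
        if node.left is not None or node.right is not None:
            queue.append(node.left)
            queue.append(node.right)
    return tokens


# Toy tree: the fragment strings are illustrative only.
tree = Node("c1ccccc1", left=Node("CC", left=Node("O")), right=Node("N"))
print(bfs_tokens(tree))  # ['c1ccccc1', 'CC', 'N', 'O', '[EMPTY]']
```

<p>Padding keeps every internal node at exactly two child slots, which is what lets a flat token sequence be mapped back to a tree shape. The real encoding rules are defined in the paper and repository.</p>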
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Description</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Validity</td>
          <td>Fraction of generated strings that decode to valid molecules</td>
      </tr>
      <tr>
          <td>Uniqueness</td>
          <td>Fraction of distinct molecules among valid generations</td>
      </tr>
      <tr>
          <td>Novelty</td>
          <td>Fraction of generated molecules not in training set</td>
      </tr>
      <tr>
          <td>KLD</td>
          <td>Kullback-Leibler divergence for physicochemical property distributions</td>
      </tr>
      <tr>
          <td>FCD</td>
          <td>Fréchet ChemNet Distance, a distributional measure of how closely generated molecules resemble the training set</td>
      </tr>
      <tr>
          <td>Active Novel</td>
          <td>Novel molecules predicted active by AttentiveFP</td>
      </tr>
  </tbody>
</table>
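<p>As a rough illustration of the first three metrics in the table, the sketch below computes validity, uniqueness, and novelty over a list of generated strings. The denominators (uniqueness over valid strings, novelty over unique valid strings) are one common convention and an assumption here. The <code>toy_is_valid</code> check is a stand-in for a real chemistry toolkit, and strings are compared verbatim rather than canonicalized.</p>

```python
from typing import Callable, Dict, Iterable, List


def generation_metrics(
    generated: List[str],
    training: Iterable[str],
    is_valid: Callable[[str], bool],
) -> Dict[str, float]:
    valid = [s for s in generated if is_valid(s)]
    unique = set(valid)
    novel = unique - set(training)
    return {
        "validity": len(valid) / len(generated) if generated else 0.0,
        "uniqueness": len(unique) / len(valid) if valid else 0.0,
        "novelty": len(novel) / len(unique) if unique else 0.0,
    }


def toy_is_valid(s: str) -> bool:
    # Stand-in check (non-empty, balanced parentheses); NOT a chemistry-aware test.
    depth = 0
    for ch in s:
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
            if depth == -1:
                return False
    return depth == 0 and len(s) != 0


generated = ["CCO", "CCO", "CC(C)O", "CC(C", "CCN"]
training = ["CCO", "CCC"]
metrics = generation_metrics(generated, training, toy_is_valid)
print(metrics)
```

<p>In practice the validity check would parse each string with a cheminformatics library and compare canonical forms. The structure of the three ratios stays the same.</p>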
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/juanniwu/t-SMILES">t-SMILES GitHub</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official implementation with training/generation scripts</td>
      </tr>
      <tr>
          <td><a href="https://doi.org/10.5281/ZENODO.10991703">Zenodo deposit</a></td>
          <td>Code + Data</td>
          <td>CC-BY-4.0</td>
          <td>Archived code and data</td>
      </tr>
      <tr>
          <td><a href="https://codeocean.com/capsule/3034546/tree">Code Ocean capsule</a></td>
          <td>Code</td>
          <td>Not specified</td>
          <td>Certified reproducible compute capsule</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>The paper mentions limited computational resources but does not specify exact GPU types or training times.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Wu, J.-N., Wang, T., Chen, Y., Tang, L.-J., Wu, H.-L., &amp; Yu, R.-Q. (2024). t-SMILES: a fragment-based molecular representation framework for de novo ligand design. <em>Nature Communications</em>, 15, 4993.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{wu2024tsmiles,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{t-SMILES: a fragment-based molecular representation framework for de novo ligand design}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Wu, Juan-Ni and Wang, Tong and Chen, Yue and Tang, Li-Juan and Wu, Hai-Long and Yu, Ru-Qin}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Nature Communications}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{15}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{4993}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2024}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1038/s41467-024-49388-6}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Survey of Transformer Architectures in Molecular Science</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/transformers-molecular-science-review/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/transformers-molecular-science-review/</guid><description>A comprehensive review of 12 transformer architectures applied to molecular science, covering GPT, BERT, BART, graph transformers, and more.</description><content:encoded><![CDATA[<h2 id="a-systematization-of-transformer-architectures-for-molecular-science">A Systematization of Transformer Architectures for Molecular Science</h2>
<p>This paper is a <strong>Systematization</strong> review. It organizes and taxonomizes 12 families of transformer architectures that have been applied across molecular science, including chemistry, biology, and drug discovery. The primary contribution is not a new method or dataset, but a structured technical overview of the algorithmic internals of each transformer variant and their specific applications to molecular problems. The review covers 201 references and provides a unified treatment of how these architectures capture molecular patterns from sequential, graphical, and image-based data.</p>
<h2 id="bridging-the-gap-between-transformer-variants-and-molecular-applications">Bridging the Gap Between Transformer Variants and Molecular Applications</h2>
<p>Transformer-based models have become widespread in molecular science, yet the authors identify a gap: there is no organized taxonomy linking these diverse techniques in the existing literature. Individual papers introduce specific architectures or applications, but practitioners lack a unified reference that explains the technical differences between GPT, BERT, BART, graph transformers, and other variants in the context of molecular data. The review aims to fill this gap by providing an in-depth investigation of the algorithmic components of each model family, explaining how their architectural innovations contribute to processing complex molecular data. The authors note that the success of transformers in molecular science stems from several factors: the sequential nature of chemical and biological molecules (DNA, RNA, proteins, SMILES strings), the attention mechanism&rsquo;s ability to capture long-range dependencies within molecular structures, and the capacity for transfer learning through pre-training on large chemical and biological datasets.</p>
<h2 id="twelve-transformer-families-and-their-molecular-mechanisms">Twelve Transformer Families and Their Molecular Mechanisms</h2>
<p>The review covers transformer preliminaries before diving into 12 specific architecture families. The core self-attention mechanism computes:</p>
<p>$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
$$</p>
<p>where $d_k$ is the dimension of the key vectors. The position-wise feed-forward network is:</p>
<p>$$
\text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2
$$</p>
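<p>The two formulas above can be traced with a minimal pure-Python sketch. It handles a single attention head with no masking, batching, or learned projections, and the helper names (<code>matmul</code>, <code>attention</code>, <code>ffn</code>) and toy inputs are illustrative assumptions rather than code from the review.</p>

```python
import math
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a: Matrix) -> Matrix:
    return [list(col) for col in zip(*a)]


def softmax(row: List[float]) -> List[float]:
    m = max(row)  # subtract the row maximum for numerical stability
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """softmax(Q K^T / sqrt(d_k)) V, applied row by row."""
    d_k = len(k[0])
    scores = matmul(q, transpose(k))
    weights = [softmax([s / math.sqrt(d_k) for s in row]) for row in scores]
    return matmul(weights, v)


def ffn(x: Matrix, w1: Matrix, b1: List[float], w2: Matrix, b2: List[float]) -> Matrix:
    """max(0, x W1 + b1) W2 + b2, applied to each position independently."""
    hidden = [[max(0.0, h + b) for h, b in zip(row, b1)] for row in matmul(x, w1)]
    return [[o + b for o, b in zip(row, b2)] for row in matmul(hidden, w2)]


# Toy input: 3 tokens with d_k = 2 (values are arbitrary).
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
out = attention(x, x, x)
print(out)
```

<p>Each output row is a weighted average of the rows of $V$, with weights that sum to one. Dividing by $\sqrt{d_k}$ keeps the logits from growing with the key dimension.</p>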
<p>The 12 architecture families covered are:</p>
<ol>
<li>
<p><strong>GPT (Generative Pre-trained Transformer)</strong>: Uses the decoder part of the transformer for autoregressive generation. Applications include MolGPT for molecular generation, DrugGPT for protein-ligand binding, and cMolGPT for target-specific de novo molecular generation.</p>
</li>
<li>
<p><strong>BERT (Bidirectional Encoder Representations from Transformers)</strong>: Uses transformer encoders with masked language modeling and next-sentence prediction for pre-training. Molecular applications include FP-BERT for molecular property prediction using composite fingerprint representations, Graph-BERT for protein-protein interaction identification, SMILES-BERT, and Mol-BERT.</p>
</li>
<li>
<p><strong>BART (Bidirectional and Auto-Regressive Transformers)</strong>: Functions as a denoising autoencoder with both encoder and decoder. Molecular applications include Chemformer for sequence-to-sequence chemistry tasks, MS2Mol for mass spectrometry analysis, and MolBART for molecular feature learning.</p>
</li>
<li>
<p><strong>Graph Transformer</strong>: Leverages self-attention on graph-structured data to capture global context. Applications include GraphSite for protein-DNA binding site prediction (using AlphaFold2 structure predictions), KPGT for knowledge-guided molecular graph pre-training, and PAGTN for establishing long-range dependencies in molecular graphs.</p>
</li>
<li>
<p><strong>Transformer-XL</strong>: Incorporates relative positional encoding for modeling long sequences. Used for small molecule retention time prediction, drug design with ChEMBL data (1.27 million molecules), and Heck reaction generation.</p>
</li>
<li>
<p><strong><a href="/notes/natural-language-processing/language-models/t5-text-to-text-transfer-transformer/">T5 (Text-to-Text Transfer Transformer)</a></strong>: Unifies NLP tasks into text-to-text mapping. T5Chem was pre-trained on 97 million molecules from PubChem and achieved 99.5% accuracy on reaction classification (USPTO 500 MT). C5T5 uses IUPAC naming for molecular optimization in drug discovery.</p>
</li>
<li>
<p><strong>Vision Transformer (ViT)</strong>: Applies transformer architecture to image patches. Used for organic molecule classification (97% accuracy with WGAN-generated data), bacterial identification via SERS, and molecular property prediction from mass spectrometry data (TransG-Net).</p>
</li>
<li>
<p><strong>DETR (Detection Transformer)</strong>: End-to-end object detection using transformers. Applied to cryo-EM particle picking (TransPicker), molecular structure image recognition (IMG2SMI), and cell segmentation (Cell-DETR).</p>
</li>
<li>
<p><strong>Conformer</strong>: Integrates convolutional modules into the transformer structure. Used for DNA storage error correction (RRCC-DNN) and drug-target affinity prediction (NG-DTA, evaluated on the Davis and KIBA datasets).</p>
</li>
<li>
<p><strong>CLIP (Contrastive Language-Image Pre-training)</strong>: Multimodal learning linking text and images. Applied to peptide design (Cut&amp;CLIP for protein degradation), gene identification (pathCLIP), and drug discovery (CLOOME for zero-shot transfer learning).</p>
</li>
<li>
<p><strong>Sparse Transformers</strong>: Use sparse attention matrices to reduce complexity to $O(n\sqrt{n})$. Applied to drug-target interaction prediction with gated cross-attention mechanisms.</p>
</li>
<li>
<p><strong>Mobile and Efficient Transformers</strong>: Compressed variants (TinyBERT, MobileBERT) for resource-constrained environments. Molormer uses ProbSparse self-attention for drug-drug interaction prediction. LOGO is a lightweight pre-trained language model for non-coding genome interpretation.</p>
</li>
</ol>
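<p>To give a sense of scale for the $O(n\sqrt{n})$ figure quoted for sparse transformers in the list above, the worked arithmetic below compares the number of query-key pairs scored by full attention ($n^2$) with the sparse count ($n\sqrt{n}$). The sequence length $n = 4096$ is an arbitrary example, and constants and the head dimension are ignored.</p>
<p>$$
\frac{n^{2}}{n\sqrt{n}} = \sqrt{n}, \qquad n = 4096:\quad n^{2} = 16{,}777{,}216, \quad n\sqrt{n} = 4096 \times 64 = 262{,}144, \quad \frac{n^{2}}{n\sqrt{n}} = 64
$$</p>
<p>Under these assumptions the saving grows with the square root of the sequence length, which is why sparse attention matters most for long molecular or genomic sequences.</p>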
<h2 id="survey-organization-and-coverage-of-molecular-domains">Survey Organization and Coverage of Molecular Domains</h2>
<p>As a survey paper, this work does not present new experiments. Instead, it catalogues existing applications across multiple molecular domains:</p>
<p><strong>Drug Discovery and Design</strong>: GPT-based ligand design (DrugGPT), BART-based molecular generation (Chemformer, MolBART), graph transformer pre-training for molecular property prediction (KPGT), T5-based chemical reaction prediction (T5Chem), and sparse transformer methods for drug-target interactions.</p>
<p><strong>Protein Science</strong>: BERT-based protein-protein interaction prediction (Graph-BERT), graph transformer methods for protein-DNA binding (GraphSite with AlphaFold2 integration), conformer-based drug-target affinity prediction (NG-DTA), and CLIP-based peptide design (Cut&amp;CLIP).</p>
<p><strong>Molecular Property Prediction</strong>: FP-BERT for fingerprint-based prediction, SMILES-BERT and Mol-BERT for end-to-end prediction from SMILES, KPGT for knowledge-guided graph pre-training, and Transformer-XL for property modeling with relative positional encoding.</p>
<p><strong>Structural Biology</strong>: DETR-based cryo-EM particle picking (TransPicker), vision transformer applications in cell imaging, and Cell-DETR for instance segmentation in microscopy.</p>
<p><strong>Genomics</strong>: Conformer-based DNA storage error correction (RRCC-DNN), LOGO for non-coding genome interpretation, and MetaTransformer for metagenomic sequencing analysis.</p>
<h2 id="future-directions-and-limitations-of-the-survey">Future Directions and Limitations of the Survey</h2>
<p>The review concludes with four future directions:</p>
<ol>
<li>
<p><strong>ChatGPT integration into molecular science</strong>: Using LLMs for data analysis, literature review, and hypothesis generation in chemistry and biology.</p>
</li>
<li>
<p><strong>Multifunction transformers</strong>: Models that extract features across diverse molecular structures and sequences simultaneously.</p>
</li>
<li>
<p><strong>Molecular-aware transformers</strong>: Architectures that handle multiple data types (text, sequence, structure, image, energy, molecular dynamics, function) in a unified framework.</p>
</li>
<li>
<p><strong>Self-assessment transformers and superintelligence</strong>: Speculative discussion of models that learn from seemingly unrelated data sources.</p>
</li>
</ol>
<p>The review has several limitations worth noting. Its coverage is broad but shallow: each architecture family receives only one to two pages of discussion, and the paper largely describes existing work rather than critically evaluating it. It does not systematically compare the architectures against each other on common benchmarks. The future directions section (particularly the superintelligence discussion) is speculative and lacks concrete proposals. The paper also focuses on technical architecture descriptions rather than analyzing failure modes, scalability challenges, or reproducibility concerns across the surveyed methods. Because this is a review article, no new data were created or analyzed.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>This is a survey paper. No new datasets were created or used. The paper reviews applications involving datasets such as PubChem (97 million molecules for T5Chem), ChEMBL (1.27 million molecules for Transformer-XL drug design), USPTO 500 MT (reaction classification), ESOL (5,328 molecules for property prediction), and Davis/KIBA (drug-target affinity).</p>
<h3 id="algorithms">Algorithms</h3>
<p>No new algorithms are introduced. The paper provides mathematical descriptions of the core transformer components (self-attention, positional encoding, feed-forward networks, layer normalization) and describes how 12 architecture families modify these components.</p>
<h3 id="models">Models</h3>
<p>No new models are presented. The paper surveys existing models including MolGPT, DrugGPT, FP-BERT, SMILES-BERT, Chemformer, MolBART, GraphSite, KPGT, T5Chem, TransPicker, Cell-DETR, CLOOME, and Molormer, among others.</p>
<h3 id="evaluation">Evaluation</h3>
<p>No new evaluation is performed. Performance numbers cited from the literature include: T5Chem reaction classification accuracy of 99.5%, ViT organic molecule classification at 97%, Transformer-XL property prediction RMSE of 0.6 on ESOL, and Heck reaction generation feasibility rate of 47.76%.</p>
<h3 id="hardware">Hardware</h3>
<p>No hardware requirements are specified, as this is a survey paper.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://onlinelibrary.wiley.com/doi/pdfdirect/10.1002/wcms.1725">Paper (open access)</a></td>
          <td>Paper</td>
          <td>CC-BY-NC-ND</td>
          <td>Open access via Wiley</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Jiang, J., Ke, L., Chen, L., Dou, B., Zhu, Y., Liu, J., Zhang, B., Zhou, T., &amp; Wei, G.-W. (2024). Transformer technology in molecular science. <em>WIREs Computational Molecular Science</em>, 14(4), e1725. <a href="https://doi.org/10.1002/wcms.1725">https://doi.org/10.1002/wcms.1725</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{jiang2024transformer,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Transformer technology in molecular science}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Jiang, Jian and Ke, Lu and Chen, Long and Dou, Bozheng and Zhu, Yueying and Liu, Jie and Zhang, Bengong and Zhou, Tianshou and Wei, Guo-Wei}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{WIREs Computational Molecular Science}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{14}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{4}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{e1725}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2024}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Wiley}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1002/wcms.1725}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Survey of Scientific LLMs in Bio and Chem Domains</title><link>https://hunterheidenreich.com/notes/chemistry/llm-applications/scientific-llm-survey-bio-chem/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/llm-applications/scientific-llm-survey-bio-chem/</guid><description>Survey of scientific LLMs covering textual, molecular, protein, genomic, and multimodal models for biological and chemical research.</description><content:encoded><![CDATA[<h2 id="a-systematization-of-scientific-language-models">A Systematization of Scientific Language Models</h2>
<p>This paper is a <strong>Systematization</strong> (survey) that provides a comprehensive review of scientific large language models (Sci-LLMs) designed for biological and chemical domains. The survey covers five main branches of scientific language modeling: textual, molecular, protein, genomic, and multimodal LLMs. For each branch, the authors analyze model architectures, capabilities, training datasets, evaluation benchmarks, and assessment criteria, then identify open challenges and future research directions.</p>
<h2 id="motivation-bridging-scientific-languages-and-llms">Motivation: Bridging Scientific Languages and LLMs</h2>
<p>Large language models have demonstrated strong capabilities in natural language understanding, but scientific research involves specialized &ldquo;languages&rdquo; that differ fundamentally from natural text. Chemical molecules are expressed as <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> or <a href="/notes/chemistry/molecular-representations/notations/selfies/">SELFIES</a> strings, proteins as amino acid sequences, and genomes as nucleotide sequences. Each of these language systems has its own vocabulary and grammar. General-purpose LLMs like ChatGPT and GPT-4 often fail to properly handle these scientific data types because the semantics and grammar of scientific languages diverge substantially from natural language.</p>
<p>Prior surveys have focused on individual modalities (molecules, proteins, or genomes) in isolation. No comprehensive review had unified these language modeling advances into a single framework. This survey fills that gap by systematically covering all five branches, including the emerging area of multimodal Sci-LLMs that integrate multiple scientific languages.</p>
<h2 id="taxonomy-of-scientific-language-models">Taxonomy of Scientific Language Models</h2>
<p>The survey organizes Sci-LLMs into a clear taxonomic framework built on two axes: the scientific language modality and the model architecture type.</p>
<h3 id="scientific-language-modalities">Scientific Language Modalities</h3>
<p>The authors define five categories of Sci-LLMs:</p>
<ol>
<li>
<p><strong>Text-Sci-LLMs</strong>: LLMs trained on scientific textual corpora (medical, biological, chemical, and comprehensive domains). Examples include BioBERT, BioGPT, ChemBERT, SciBERT, and <a href="/notes/chemistry/llm-applications/galactica-large-language-model-for-science/">Galactica</a>.</p>
</li>
<li>
<p><strong>Mol-LLMs</strong>: Models that process molecular languages (SMILES, SELFIES, <a href="/notes/chemistry/molecular-representations/notations/inchi-2013/">InChI</a>). These include encoder-only models like <a href="/notes/chemistry/molecular-representations/encoders/chemberta/">ChemBERTa</a> and <a href="/notes/chemistry/molecular-representations/encoders/molformer/">MolFormer</a> for property prediction, decoder-only models like MolGPT for molecular generation, and encoder-decoder models like Molecular Transformer and <a href="/notes/chemistry/molecular-design/generation/autoregressive/chemformer/">Chemformer</a> for reaction prediction.</p>
</li>
<li>
<p><strong>Prot-LLMs</strong>: Models operating on protein amino acid sequences. The ESM series (ESM-1b, ESM-2) and ProtTrans serve as encoders for function and structure prediction, while ProGen and ProtGPT2 generate novel protein sequences.</p>
</li>
<li>
<p><strong>Gene-LLMs</strong>: Models for DNA and RNA sequences, including DNABERT, Nucleotide Transformer, HyenaDNA, and Evo, covering tasks from variant effect prediction to genome-scale sequence modeling.</p>
</li>
<li>
<p><strong>MM-Sci-LLMs</strong>: Multimodal models integrating multiple scientific data types (molecule-text, protein-text, gene-cell-text, molecule-protein), such as MoleculeSTM, <a href="/notes/chemistry/molecular-representations/multimodal/biot5-cross-modal-biology/">BioT5</a>, Mol-Instructions, and BioMedGPT.</p>
</li>
</ol>
<h3 id="architecture-classification">Architecture Classification</h3>
<p>For each modality, models are categorized into three architecture types:</p>
<ul>
<li><strong>Encoder-only</strong>: Based on BERT/RoBERTa, these models learn fixed-size representations via masked language modeling. They excel at discriminative tasks like property prediction and classification.</li>
<li><strong>Decoder-only</strong>: Based on GPT, these models perform autoregressive generation. They are used for de novo molecule design, protein sequence generation, and DNA sequence generation.</li>
<li><strong>Encoder-decoder</strong>: Based on architectures like <a href="/notes/natural-language-processing/language-models/t5-text-to-text-transfer-transformer/">T5</a> or BART, these handle sequence-to-sequence tasks such as reaction prediction, molecule captioning, and protein sequence-structure translation.</li>
</ul>
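<p>One way to see the difference between these architecture types is the attention pattern each typically uses. In standard transformers, encoder layers let every position attend to every other position, while decoder layers restrict each position to itself and earlier positions. The sketch below builds both masks for a short sequence. It reflects general transformer practice rather than anything specific to the survey, and the function names are illustrative.</p>

```python
from typing import List


def bidirectional_mask(n: int) -> List[List[int]]:
    """Encoder-style mask: every query position may attend to every key position."""
    return [[1] * n for _ in range(n)]


def causal_mask(n: int) -> List[List[int]]:
    """Decoder-style mask: query i may attend only to keys at positions up to i."""
    return [[int(i >= j) for j in range(n)] for i in range(n)]


for row in bidirectional_mask(3):
    print(row)  # [1, 1, 1] three times
for row in causal_mask(3):
    print(row)  # [1, 0, 0] then [1, 1, 0] then [1, 1, 1]
```

<p>An encoder-decoder model combines the two: a bidirectional encoder, a causal decoder, and cross-attention from decoder positions to the encoder output.</p>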
<h2 id="comprehensive-catalog-of-models-datasets-and-benchmarks">Comprehensive Catalog of Models, Datasets, and Benchmarks</h2>
<p>A central contribution of the survey is its exhaustive cataloging of resources across all five modalities. The authors compile detailed summary tables covering over 100 Sci-LLMs, their parameter counts, base architectures, training data, and capabilities.</p>
<h3 id="molecular-llms">Molecular LLMs</h3>
<p>The survey documents a rich landscape of Mol-LLMs:</p>
<p><strong>Encoder-only models</strong> for property prediction include <a href="/notes/chemistry/molecular-representations/encoders/smiles-bert/">SMILES-BERT</a>, ChemBERTa, <a href="/notes/chemistry/molecular-representations/encoders/chemberta-2/">ChemBERTa-2</a>, <a href="/notes/chemistry/molecular-representations/encoders/molbert-molecular-representations/">MolBERT</a>, MolFormer, MG-BERT, GROVER, MAT, Uni-Mol, and others. These models are pre-trained on ZINC, PubChem, or ChEMBL datasets and fine-tuned for molecular property prediction tasks on benchmarks like <a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a>.</p>
<p><strong>Decoder-only models</strong> for molecular generation include MolGPT, SMILES GPT, iupacGPT, cMolGPT, and Taiga. These generate SMILES strings autoregressively, often combining GPT with reinforcement learning for property optimization.</p>
<p><strong>Encoder-decoder models</strong> for reaction prediction include Molecular Transformer, Retrosynthesis Transformer, Chemformer, <a href="/notes/chemistry/molecular-representations/encoders/bartsmiles-molecular-representations/">BARTSmiles</a>, Graph2SMILES, and MOLGEN. These handle forward reaction prediction and retrosynthesis.</p>
<h3 id="key-datasets-surveyed">Key Datasets Surveyed</h3>
<p>The survey catalogs pre-training datasets and benchmarks for each modality:</p>
<table>
  <thead>
      <tr>
          <th>Modality</th>
          <th>Pre-training Sources</th>
          <th>Key Benchmarks</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Text</td>
          <td>PubMed, PMC, arXiv, Semantic Scholar</td>
          <td>MMLU, MedQA, PubMedQA, SciEval</td>
      </tr>
      <tr>
          <td>Molecule</td>
          <td>ZINC, PubChem, ChEMBL, USPTO, <a href="/notes/chemistry/datasets/gdb-17/">GDB-17</a></td>
          <td>MoleculeNet, <a href="/notes/chemistry/molecular-design/generation/evaluation/guacamol-benchmarking-de-novo-molecular-design/">GuacaMol</a>, <a href="/notes/chemistry/molecular-design/generation/evaluation/molecular-sets-moses/">MOSES</a>, SPECTRA</td>
      </tr>
      <tr>
          <td>Protein</td>
          <td>UniRef50/90/100, BFD, <a href="https://en.wikipedia.org/wiki/Protein_Data_Bank">PDB</a>, <a href="https://en.wikipedia.org/wiki/AlphaFold">AlphaFoldDB</a></td>
          <td><a href="https://en.wikipedia.org/wiki/CASP">CASP</a>, TAPE, ProteinGym, FLIP, PEER</td>
      </tr>
      <tr>
          <td>Genome</td>
          <td>GRCh38, 1000 Genomes, <a href="https://en.wikipedia.org/wiki/ENCODE">ENCODE</a></td>
          <td>NT-Bench, GenBench, BEACON</td>
      </tr>
      <tr>
          <td>Multimodal</td>
          <td>ChEBI-20, PubChemSTM, Mol-Instructions</td>
          <td>Various cross-modal retrieval and generation tasks</td>
      </tr>
  </tbody>
</table>
<h3 id="evaluation-metrics">Evaluation Metrics</h3>
<p>For molecular generation, the survey details standard metrics:</p>
<ul>
<li><strong>Validity</strong>: percentage of chemically viable molecules</li>
<li><strong>Uniqueness</strong>: fraction of distinct generated structures</li>
<li><strong>Novelty</strong>: fraction not present in the training set</li>
<li><strong>Internal diversity</strong>: measured as</li>
</ul>
<p>$$
\text{IntDiv}_{p}(G) = 1 - \sqrt[p]{\frac{1}{|G|^{2}} \sum_{m_{1}, m_{2} \in G} T(m_{1}, m_{2})^{p}}
$$</p>
<p>where $T(m_{1}, m_{2})$ is the <a href="https://en.wikipedia.org/wiki/Jaccard_index">Tanimoto similarity</a> between molecules $m_{1}$ and $m_{2}$.</p>
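<p>The internal diversity formula above can be checked on a toy example. The sketch below assumes each molecule is represented by a set of on-bit indices from a fingerprint and sums over all ordered pairs, including self-pairs, as the $1/|G|^{2}$ normalization implies. The fingerprints are made up, and treating two empty fingerprints as identical is an arbitrary convention.</p>

```python
from itertools import product
from typing import FrozenSet, List

Fingerprint = FrozenSet[int]


def tanimoto(a: Fingerprint, b: Fingerprint) -> float:
    union = len(a.union(b))
    # Convention (assumption): two empty fingerprints count as identical.
    return len(a.intersection(b)) / union if union else 1.0


def internal_diversity(fps: List[Fingerprint], p: int = 1) -> float:
    """IntDiv_p = 1 - (mean over all ordered pairs of T^p) ** (1/p)."""
    n = len(fps)
    total = sum(tanimoto(a, b) ** p for a, b in product(fps, repeat=2))
    return 1.0 - (total / (n * n)) ** (1.0 / p)


# Made-up fingerprints given as sets of on-bit indices.
fps = [frozenset({1, 2, 3}), frozenset({2, 3, 4}), frozenset({7, 8})]
print(internal_diversity(fps, p=1))
print(internal_diversity(fps, p=2))
```

<p>For this toy set with $p = 1$, the nine pairwise similarities sum to 4 (three self-pairs at 1 plus two cross-pairs at 0.5), so the result is $1 - 4/9 = 5/9$. A set of identical molecules gives 0.</p>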
<ul>
<li><strong>Fréchet ChemNet Distance (FCD)</strong>: compares the distributions of generated and reference molecules</li>
</ul>
<p>$$
\text{FCD}(G, R) = \lVert \mu_{G} - \mu_{R} \rVert^{2} + \text{Tr}\left[\Sigma_{G} + \Sigma_{R} - 2(\Sigma_{G}\Sigma_{R})^{1/2}\right]
$$</p>
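<p>A one-dimensional special case helps with reading the FCD expression. Assuming a single feature, so that the means are scalars and the covariances reduce to variances $\Sigma_{G} = \sigma_{G}^{2}$ and $\Sigma_{R} = \sigma_{R}^{2}$, the trace term collapses to a squared difference of standard deviations. The actual metric is computed on multi-dimensional learned features, so this is only a simplification for intuition.</p>
<p>$$
\text{FCD}(G, R) = (\mu_{G} - \mu_{R})^{2} + \sigma_{G}^{2} + \sigma_{R}^{2} - 2\sigma_{G}\sigma_{R} = (\mu_{G} - \mu_{R})^{2} + (\sigma_{G} - \sigma_{R})^{2}
$$</p>
<p>In this reduced form the distance is zero only when both the mean and the spread of the generated distribution match the reference.</p>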
<p>For protein generation, analogous metrics include perplexity, Fréchet Protein Distance (FPD), foldability (pLDDT), sequence recovery, and novelty (sequence identity).</p>
<h2 id="critical-challenges-and-future-directions">Critical Challenges and Future Directions</h2>
<p>The survey identifies four major challenges and seven future research directions for Sci-LLMs.</p>
<h3 id="challenges">Challenges</h3>
<ol>
<li>
<p><strong>Training data limitations</strong>: Sci-LLM training datasets are orders of magnitude smaller than those for general LLMs. ProGen was trained on 280M protein sequences (tens of billions of tokens), while ChatGPT used approximately 570 billion tokens. Scaling laws suggest larger datasets would improve performance, and advances in sequencing technologies may help close this gap.</p>
</li>
<li>
<p><strong>Architecture mismatch</strong>: Standard Transformer architectures face difficulties with scientific languages. Scientific sequences (proteins with hundreds or thousands of amino acids, DNA with millions of base pairs) are far longer than typical natural language sentences. Additionally, 3D structural information is critical for function prediction but does not naturally map to sequence tokens. Autoregressive generation is also a poor fit since biological sequences function as a whole rather than being read left-to-right.</p>
</li>
<li>
<p><strong>Evaluation gaps</strong>: Computational metrics for generated molecules and proteins provide only indirect quality measures. Wet-lab validation remains the gold standard but is beyond the scope of most AI research teams. Better computational evaluation methods that correlate with experimental outcomes are needed.</p>
</li>
<li>
<p><strong>Ethics</strong>: Sensitive biological data raises privacy concerns. The potential for misuse (e.g., generating harmful substances) requires careful safeguards. Algorithmic bias and equitable access to Sci-LLM benefits also demand attention.</p>
</li>
</ol>
<h3 id="future-directions">Future Directions</h3>
<ol>
<li><strong>Larger-scale, cross-modal training datasets</strong> with strong semantic alignment across modalities</li>
<li><strong>Incorporating 3D structural and temporal information</strong> into language-based modeling, including structural motifs as tokens</li>
<li><strong>Integration with external knowledge sources</strong> such as <a href="https://en.wikipedia.org/wiki/Gene_Ontology">Gene Ontology</a> and chemical knowledge graphs to reduce hallucination</li>
<li><strong>Coupling with physical simulation</strong> (e.g., <a href="/notes/chemistry/molecular-simulation/">molecular dynamics</a>) to ground language models in physical reality</li>
<li><strong>Augmenting Sci-LLMs with specialized tools and agents</strong>, following the success of tool-augmented general LLMs like <a href="/notes/chemistry/llm-applications/chemcrow-augmenting-llms-chemistry-tools/">ChemCrow</a></li>
<li><strong>Development of computational evaluation metrics</strong> that are both fast and accurate, enabling rapid research iteration</li>
<li><strong>Super-alignment with human ethics</strong>, ensuring ethical reasoning is deeply integrated into Sci-LLM behavior</li>
</ol>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>This is a survey paper that does not present new experimental results. The authors catalog extensive datasets across five modalities (see tables in the paper for comprehensive listings). The survey itself is maintained as an open resource.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/HICAI-ZJU/Scientific-LLM-Survey">Scientific-LLM-Survey GitHub</a></td>
          <td>Other</td>
          <td>Not specified</td>
          <td>Curated list of papers, models, and resources</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Not applicable (survey paper).</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Zhang, Q., Ding, K., Lyv, T., Wang, X., Yin, Q., Zhang, Y., Yu, J., Wang, Y., Li, X., Xiang, Z., Feng, K., Zhuang, X., Wang, Z., Qin, M., Zhang, M., Zhang, J., Cui, J., Huang, T., Yan, P., Xu, R., Chen, H., Li, X., Fan, X., Xing, H., &amp; Chen, H. (2025). Scientific Large Language Models: A Survey on Biological &amp; Chemical Domains. <em>ACM Computing Surveys</em>, 57(6), 1–38. <a href="https://doi.org/10.1145/3715318">https://doi.org/10.1145/3715318</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{zhang2025scientific,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Scientific Large Language Models: A Survey on Biological \&amp; Chemical Domains}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Zhang, Qiang and Ding, Keyan and Lyv, Tianwen and Wang, Xinda and Yin, Qingyu and Zhang, Yiwen and Yu, Jing and Wang, Yuhao and Li, Xiaotong and Xiang, Zhuoyi and Feng, Kehua and Zhuang, Xiang and Wang, Zeyuan and Qin, Ming and Zhang, Mengyao and Zhang, Jinlu and Cui, Jiyu and Huang, Tao and Yan, Pengju and Xu, Renjun and Chen, Hongyang and Li, Xiaolin and Fan, Xiaohui and Xing, Huabin and Chen, Huajun}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{ACM Computing Surveys}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{57}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{6}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{1--38}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2025}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1145/3715318}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>SPE: Data-Driven SMILES Substructure Tokenization</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/notations/smiles-pair-encoding/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/notations/smiles-pair-encoding/</guid><description>SMILES Pair Encoding adapts byte pair encoding to learn chemically meaningful substructure tokens from SMILES, improving generation and QSAR prediction.</description><content:encoded><![CDATA[<h2 id="a-data-driven-tokenization-method-for-chemical-deep-learning">A Data-Driven Tokenization Method for Chemical Deep Learning</h2>
<p>This is a <strong>Method</strong> paper that introduces SMILES Pair Encoding (SPE), a tokenization algorithm adapted from <a href="https://en.wikipedia.org/wiki/Byte-pair_encoding">byte pair encoding (BPE)</a> in natural language processing. The primary contribution is a data-driven approach that learns a vocabulary of high-frequency SMILES substrings from a large chemical dataset and then uses that vocabulary to tokenize SMILES for downstream deep learning tasks. The authors provide an open-source Python package (SmilesPE) and demonstrate improvements on both molecular generation and <a href="https://en.wikipedia.org/wiki/Quantitative_structure%E2%80%93activity_relationship">QSAR</a> prediction benchmarks.</p>
<h2 id="limitations-of-atom-level-smiles-tokenization">Limitations of Atom-Level SMILES Tokenization</h2>
<p>SMILES-based deep learning models require tokenization to convert molecular strings into sequences of discrete units. The standard approaches have well-known drawbacks:</p>
<ul>
<li><strong>Character-level tokenization</strong> breaks SMILES character by character, splitting chemically meaningful multi-character atoms. For example, <code>[C@@H]</code> becomes six separate tokens (<code>[</code>, <code>C</code>, <code>@</code>, <code>@</code>, <code>H</code>, <code>]</code>), losing the stereochemistry information of a single carbon.</li>
<li><strong>Atom-level tokenization</strong> addresses some of these issues by treating multi-character element symbols (Cl, Br) and bracketed atoms ([nH], [O-]) as single tokens. However, these tokens still encode only individual atoms, not substructures.</li>
<li><strong>k-mer tokenization</strong> (sequences of k consecutive overlapping characters) captures some connectivity information but suffers from the out-of-vocabulary problem: the model cannot represent k-mers not seen during training.</li>
</ul>
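<p>The three schemes above can be contrasted with a minimal sketch. The regular expression is a deliberately simplified assumption (bracketed atoms, <code>Cl</code>, <code>Br</code>, then any single character), not the pattern used by any particular package, and the function names are hypothetical.</p>

```python
import re

# Simplified atom-level pattern (an assumption for this sketch): bracketed
# atoms first, then the two-letter halogens, then any other single character.
ATOM_PATTERN = re.compile(r"\[[^\]]+\]|Br|Cl|.")


def char_tokenize(smiles):
    """Character-level: every character becomes its own token."""
    return list(smiles)


def atom_tokenize(smiles):
    """Atom-level: bracketed atoms and two-letter elements stay whole."""
    return ATOM_PATTERN.findall(smiles)


def kmer_tokenize(smiles, k=4):
    """k-mer: overlapping windows of k consecutive characters."""
    if len(smiles) < k:
        return [smiles]
    return [smiles[i:i + k] for i in range(len(smiles) - k + 1)]


print(char_tokenize("[C@@H]"))   # six single-character tokens
print(atom_tokenize("[C@@H]"))   # one bracketed-atom token
print(kmer_tokenize("CCOC", 3))  # overlapping 3-mers
```

<p>Because a k-mer vocabulary is fixed by the windows seen in training, any unseen window at inference time has no entry, which is the out-of-vocabulary problem described above.</p>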
<p>All three approaches produce relatively long input sequences (mean ~40 tokens per molecule on ChEMBL at the atom level), which increases computational cost for sequential architectures like RNNs and exacerbates long-range dependency issues.</p>
<h2 id="core-innovation-adapting-byte-pair-encoding-for-smiles">Core Innovation: Adapting Byte Pair Encoding for SMILES</h2>
<p>SPE adapts the byte pair encoding algorithm, originally developed for data compression and later adopted for subword tokenization in NLP, to the domain of chemical strings. The algorithm has two phases:</p>
<p><strong>Vocabulary training:</strong></p>
<ol>
<li>Tokenize SMILES from a large dataset (ChEMBL) at the atom level</li>
<li>Initialize the vocabulary with all unique atom-level tokens</li>
<li>Iteratively count the frequency of all adjacent token pairs, merge the most frequent pair into a new token, and add it to the vocabulary</li>
<li>Stop when either the maximum vocabulary size (MVS) or a minimum frequency threshold (FT) is reached</li>
</ol>
<p><strong>Tokenization:</strong> Given a trained SPE vocabulary, a new SMILES string is first tokenized at the atom level, then token pairs are iteratively merged according to their frequency rank in the vocabulary until no further merges are possible.</p>
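<p>The two phases can be sketched in a few lines of plain Python. This is an illustrative BPE-style implementation under stated assumptions, not the SmilesPE package API: the function names, the simplified atom-level regular expression, and the tie-breaking rule for equally frequent pairs are all hypothetical choices made for the example.</p>

```python
import re
from collections import Counter

# Simplified atom-level pattern (assumption): bracketed atoms, Cl, Br, else one char.
ATOM_PATTERN = re.compile(r"\[[^\]]+\]|Br|Cl|.")


def atom_tokenize(smiles):
    return ATOM_PATTERN.findall(smiles)


def merge_pair(tokens, pair):
    """Replace each left-to-right occurrence of `pair` with the joined token."""
    merged, i = [], 0
    while i < len(tokens):
        if i + 1 < len(tokens) and (tokens[i], tokens[i + 1]) == pair:
            merged.append(tokens[i] + tokens[i + 1])
            i += 2
        else:
            merged.append(tokens[i])
            i += 1
    return merged


def train_spe(smiles_list, max_vocab_size, min_frequency):
    """Learn an ordered merge list; stop at the size cap or frequency floor."""
    corpus = [atom_tokenize(s) for s in smiles_list]
    vocab = sorted({tok for toks in corpus for tok in toks})
    merges = []
    while len(vocab) < max_vocab_size:
        pair_counts = Counter()
        for toks in corpus:
            pair_counts.update(zip(toks, toks[1:]))
        if not pair_counts:
            break
        # Ties are broken by comparing the pairs themselves (an arbitrary choice).
        best, count = max(pair_counts.items(), key=lambda kv: (kv[1], kv[0]))
        if count < min_frequency:
            break
        merges.append(best)
        if "".join(best) not in vocab:
            vocab.append("".join(best))
        corpus = [merge_pair(toks, best) for toks in corpus]
    return vocab, merges


def spe_tokenize(smiles, merges):
    """Apply learned merges in rank order until none applies."""
    rank = {pair: i for i, pair in enumerate(merges)}
    tokens = atom_tokenize(smiles)
    while len(tokens) > 1:
        candidates = set(zip(tokens, tokens[1:]))
        best = min(candidates, key=lambda p: rank.get(p, float("inf")))
        if best not in rank:
            break
        tokens = merge_pair(tokens, best)
    return tokens


vocab, merges = train_spe(["CCO", "CCN", "CCC", "CC(=O)O"], max_vocab_size=20, min_frequency=2)
print(merges)
print(spe_tokenize("CCCCO", merges))
```

<p>On this toy corpus only the pair <code>C</code>+<code>C</code> clears the frequency floor of 2, so training stops after one merge; on a corpus the size of ChEMBL the same loop would keep merging until one of the two stopping criteria is met.</p>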
<p>The key hyperparameters are MVS and FT. In the reported experiments, MVS was set to 30,000 and FT was set to 2,000. The vocabulary was trained on ~3.4 million SMILES (both canonical and one non-canonical variant per molecule) from ChEMBL25. The resulting vocabulary contained 3,002 unique SMILES substrings with lengths ranging from 1 to 22 atom-level characters.</p>
<p>The trained SPE vocabulary produces tokens that are human-readable and correspond to chemically meaningful substructures and functional groups. SPE tokenization reduces the mean sequence length from approximately 40 tokens (atom-level) to approximately 6 tokens on ChEMBL, a roughly 6-7x compression. This shorter representation directly reduces computational cost for RNN-based and other sequential models.</p>
<p>The algorithm is also compatible with other text-based molecular representations such as <a href="/notes/chemistry/molecular-representations/notations/deepsmiles-adaptation-for-ml/">DeepSMILES</a> and <a href="/notes/chemistry/molecular-representations/notations/selfies/">SELFIES</a>, since these share atom-level character structures that can serve as the starting point for pair merging.</p>
<h2 id="molecular-generation-and-qsar-prediction-experiments">Molecular Generation and QSAR Prediction Experiments</h2>
<h3 id="molecular-generation">Molecular Generation</h3>
<p>The authors trained AWD-LSTM language models with SPE and atom-level tokenization on 9 million SMILES (1 canonical + 5 non-canonical per compound from ChEMBL25). Each model sampled 1 million SMILES for evaluation. The AWD-LSTM architecture used an embedding size of 400, three LSTM layers with 1,152 hidden units each, and various dropout settings (embedding: 0.1, input: 0.6, weight: 0.5, hidden: 0.2). Models were trained for 10 epochs with a base learning rate of 0.008 using one-cycle scheduling.</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>SPE</th>
          <th>Atom-level</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Validity</td>
          <td>0.941</td>
          <td>0.970</td>
      </tr>
      <tr>
          <td>Uniqueness</td>
          <td>0.994</td>
          <td>0.992</td>
      </tr>
      <tr>
          <td>Novelty</td>
          <td>0.983</td>
          <td>0.978</td>
      </tr>
      <tr>
          <td>Internal diversity</td>
          <td>0.897</td>
          <td>0.886</td>
      </tr>
      <tr>
          <td>Nearest neighbor similarity</td>
          <td>0.391</td>
          <td>0.386</td>
      </tr>
  </tbody>
</table>
<p>The SPE model generated a more diverse population of novel molecules at the cost of slightly lower validity (94.1% vs. 97.0%). Internal diversity is defined as:</p>
<p>$$
\text{Internal diversity} = 1 - \frac{1}{|G|^{2}} \sum_{(x_1, x_2) \in G \times G} T(x_1, x_2)
$$</p>
<p>where $T(x_1, x_2)$ is the Tanimoto similarity between molecules $x_1$ and $x_2$ using 1024-bit ECFP6 fingerprints. Nearest neighbor similarity (SNN) measures how well the generated set resembles the reference set:</p>
<p>$$
\text{SNN} = \frac{1}{|G|} \sum_{x_G \in G} \max_{x_R \in R} T(x_G, x_R)
$$</p>
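<p>Both metrics reduce to a few lines once fingerprints are available. The sketch below assumes each fingerprint is represented as a Python set of on-bit indices (in practice these would come from a cheminformatics toolkit), treats internal diversity as one minus the mean pairwise Tanimoto similarity over all ordered pairs, and uses hypothetical function names.</p>

```python
def tanimoto(fp_a, fp_b):
    """Tanimoto similarity between two fingerprints given as sets of on-bits."""
    union = len(fp_a.union(fp_b))
    if union == 0:
        return 0.0  # convention chosen for this sketch when both sets are empty
    return len(fp_a.intersection(fp_b)) / union


def internal_diversity(generated):
    """One minus the mean Tanimoto similarity over all ordered pairs in G x G."""
    n = len(generated)
    total = sum(tanimoto(a, b) for a in generated for b in generated)
    return 1.0 - total / (n * n)


def nearest_neighbor_similarity(generated, reference):
    """Mean, over generated molecules, of the best similarity to the reference set."""
    best = [max(tanimoto(g, r) for r in reference) for g in generated]
    return sum(best) / len(best)


# Toy fingerprints (made-up bit indices, for illustration only).
G = [{1, 2, 3}, {2, 3, 4}, {5}]
R = [{1, 2, 3}, {5, 6}]
print(internal_diversity(G))
print(nearest_neighbor_similarity(G, R))
```

<p>The pairwise sum is quadratic in the number of generated molecules, so for a sample of one million SMILES a real evaluation would likely rely on vectorized or subsampled computation.</p>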
<p>Substructure coverage analysis showed both models recovered the same top-1000 BRICS fragments (100% coverage), but SPE consistently outperformed atom-level tokenization on top-5000 coverage across all four substructure types: BRICS fragments (0.997 vs. 0.987), functional groups (0.688 vs. 0.659), scaffolds (0.872 vs. 0.825), and ring systems (0.781 vs. 0.761).</p>
<h3 id="qsar-prediction">QSAR Prediction</h3>
<p>QSAR models were built using the <a href="/notes/chemistry/molecular-design/property-prediction/molpmofit-transfer-learning-qsar/">MolPMoFiT transfer learning framework</a>, which pre-trains a language model on ChEMBL and then fine-tunes it for specific prediction tasks. The evaluation used 24 regression benchmarks (pIC50 values) from Cortes-Ciriano et al., covering targets ranging from 199 molecules (alpha-2a adrenergic receptor) to 5,010 molecules (<a href="https://en.wikipedia.org/wiki/KCNH2">hERG</a>). Models were evaluated on 10 random 80:10:10 splits using RMSE, R-squared, and MAE. Random forest models with 1024-bit ECFP6 were included as baseline comparisons.</p>
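<p>For reference, the three regression metrics named above can be computed as follows. This is a plain-Python sketch with hypothetical function names and made-up numbers; it uses the common definition of R-squared as one minus the ratio of residual to total sum of squares, which is an assumption about the exact variant used in the paper.</p>

```python
import math


def rmse(y_true, y_pred):
    """Root mean squared error."""
    return math.sqrt(sum((p - t) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true))


def mae(y_true, y_pred):
    """Mean absolute error."""
    return sum(abs(p - t) for t, p in zip(y_true, y_pred)) / len(y_true)


def r_squared(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean_true = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


# Illustrative pIC50-like values (not taken from the benchmarks).
y_true = [1.0, 2.0, 3.0, 4.0]
y_pred = [1.5, 2.0, 2.5, 4.5]
print(rmse(y_true, y_pred), mae(y_true, y_pred), r_squared(y_true, y_pred))
```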
<p><a href="https://en.wikipedia.org/wiki/Effect_size">Cohen&rsquo;s d</a> effect sizes were computed to quantify performance differences between tokenization methods. SPE performed comparably or better than atom-level tokenization on 23 out of 24 datasets. Notable results with medium or large effect sizes favoring SPE included <a href="https://en.wikipedia.org/wiki/Cannabinoid_receptor_1">cannabinoid CB1 receptor</a> (large effect), A2a adrenergic receptor, LCK, estrogen receptor, and <a href="https://en.wikipedia.org/wiki/Aurora_kinase_A">Aurora-A kinase</a> (all medium effects). Against k-mer tokenization, SPE matched or outperformed on 22 out of 24 datasets.</p>
<p>Cohen&rsquo;s d is defined as:</p>
<p>$$
\text{Cohen&rsquo;s } d = \frac{\bar{x}_1 - \bar{x}_2}{\sqrt{(\text{SD}_1^2 + \text{SD}_2^2) / 2}}
$$</p>
<p>where $\bar{x}_1, \bar{x}_2$ are the group means and $\text{SD}_1, \text{SD}_2$ are the standard deviations. Thresholds of 0.2 (small), 0.5 (medium), and 0.8 (large) were used following standard recommendations.</p>
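<p>As a worked illustration of the formula and thresholds, the sketch below computes the effect size for two small sets of made-up scores. It assumes sample standard deviations (the <code>n - 1</code> denominator used by <code>statistics.stdev</code>); the function names and the &ldquo;negligible&rdquo; label for values below 0.2 are choices made for this example.</p>

```python
import math
import statistics


def cohens_d(group_1, group_2):
    """Difference of means divided by the root mean of the two variances."""
    sd_1 = statistics.stdev(group_1)  # sample SD (assumption)
    sd_2 = statistics.stdev(group_2)
    pooled = math.sqrt((sd_1 ** 2 + sd_2 ** 2) / 2)
    return (statistics.mean(group_1) - statistics.mean(group_2)) / pooled


def effect_label(d):
    """Map |d| onto the 0.2 / 0.5 / 0.8 thresholds."""
    magnitude = abs(d)
    if magnitude >= 0.8:
        return "large"
    if magnitude >= 0.5:
        return "medium"
    if magnitude >= 0.2:
        return "small"
    return "negligible"


# Made-up error values for two tokenization methods over five splits.
scores_a = [1.0, 2.0, 3.0, 4.0, 5.0]
scores_b = [2.0, 3.0, 4.0, 5.0, 6.0]
d = cohens_d(scores_a, scores_b)
print(round(d, 4), effect_label(d))
```

<p>The sign of <code>d</code> only records which group has the larger mean; when the compared quantity is an error such as RMSE, a negative value for the first group means it performed better.</p>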
<p>SMILES-based deep learning models generally performed on par with or better than the RF baseline, with particularly strong advantages on the four largest datasets (<a href="https://en.wikipedia.org/wiki/Cyclooxygenase-2">COX-2</a>, <a href="https://en.wikipedia.org/wiki/Acetylcholinesterase">acetylcholinesterase</a>, erbB1, and hERG).</p>
<p>In addition to performance gains, SPE-based models trained on average 5 times faster than atom-level models due to the shorter input sequences.</p>
<h2 id="results-summary-and-future-directions">Results Summary and Future Directions</h2>
<p>The main findings of this study are:</p>
<ol>
<li>
<p><strong>SPE produces chemically meaningful tokens.</strong> The learned vocabulary contains human-readable SMILES substrings that correspond to common substructures and functional groups, making model interpretations more accessible.</p>
</li>
<li>
<p><strong>SPE compresses input sequences by ~6-7x.</strong> Mean token sequence length drops from ~40 (atom-level) to ~6 (SPE) on ChEMBL, yielding a ~5x training speedup.</p>
</li>
<li>
<p><strong>SPE improves molecular generation diversity.</strong> The SPE-based generative model produces molecules with higher novelty (98.3% vs. 97.8%), internal diversity (0.897 vs. 0.886), and substructure coverage, at the cost of slightly lower validity (94.1% vs. 97.0%).</p>
</li>
<li>
<p><strong>SPE matches or outperforms atom-level and k-mer tokenization on QSAR prediction.</strong> Across 24 benchmarks, SPE showed comparable or better performance in 23/24 comparisons against atom-level and 22/24 against k-mer tokenization.</p>
</li>
</ol>
<p><strong>Limitations acknowledged by the authors:</strong></p>
<ul>
<li>The SPE vocabulary is trained on a specific dataset (ChEMBL25) and may not optimally represent chemical spaces that differ significantly from drug-like compounds.</li>
<li>The validity rate for molecular generation is slightly lower than atom-level tokenization (94.1% vs. 97.0%), since longer substructure tokens can introduce invalid fragments.</li>
<li>The k-mer tokenization suffers from an out-of-vocabulary problem, which the authors address by replacing unseen 4-mers with <code>[UNK]</code> tokens, but this is a limitation of the comparison rather than of SPE itself.</li>
</ul>
<p><strong>Future directions:</strong> The authors suggest SPE could serve as a general tokenization method for SMILES-based deep learning, applicable to any task where SMILES strings are used as input (<a href="/notes/chemistry/molecular-design/generation/">generation</a>, <a href="/notes/chemistry/molecular-design/property-prediction/">property prediction</a>, <a href="/notes/chemistry/molecular-design/reaction-prediction/">reaction prediction</a>, retrosynthesis). The algorithm can also be applied to DeepSMILES and SELFIES representations without modification.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>SPE vocabulary training</td>
          <td>ChEMBL25</td>
          <td>~3.4M SMILES</td>
          <td>1 canonical + 1 non-canonical per molecule</td>
      </tr>
      <tr>
          <td>Language model training</td>
          <td>ChEMBL25 augmented</td>
          <td>~9M SMILES</td>
          <td>1 canonical + 5 non-canonical per molecule</td>
      </tr>
      <tr>
          <td>Molecular generation evaluation</td>
          <td>Sampled from model</td>
          <td>1M SMILES per model</td>
          <td>Validated with RDKit</td>
      </tr>
      <tr>
          <td>QSAR benchmarks</td>
          <td>Cortes-Ciriano et al.</td>
          <td>24 datasets, 199-5010 molecules</td>
          <td>pIC50 regression tasks</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>SPE vocabulary training: iterative pair merging with MVS=30,000 and FT=2,000</li>
<li>Language model: AWD-LSTM with embedding size 400, 3 LSTM layers with 1,152 hidden units</li>
<li>Dropout: embedding=0.1, input=0.6, weight=0.5, hidden=0.2</li>
<li>Training: 10 epochs, base learning rate 0.008, one-cycle policy</li>
<li>QSAR: MolPMoFiT transfer learning with 25x training augmentation and 15x validation augmentation</li>
<li>Test time augmentation: average of canonical + 4 augmented SMILES predictions</li>
<li>RF baseline: 500 trees, 1024-bit ECFP6, default scikit-learn parameters</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li>AWD-LSTM architecture from Merity et al. (2018)</li>
<li>MolPMoFiT framework from Li and Fourches (2020) for transfer learning QSAR</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Task</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Validity, Uniqueness, Novelty</td>
          <td>Generation</td>
          <td>Basic quality metrics</td>
      </tr>
      <tr>
          <td>Internal diversity</td>
          <td>Generation</td>
          <td>1 - mean pairwise Tanimoto (ECFP6)</td>
      </tr>
      <tr>
          <td>Nearest neighbor similarity</td>
          <td>Generation</td>
          <td>Mean max Tanimoto to reference set</td>
      </tr>
      <tr>
          <td>Substructure coverage</td>
          <td>Generation</td>
          <td>BRICS, functional groups, scaffolds, ring systems</td>
      </tr>
      <tr>
          <td>RMSE, R-squared, MAE</td>
          <td>QSAR regression</td>
          <td>10 random 80:10:10 splits</td>
      </tr>
      <tr>
          <td>Cohen&rsquo;s d</td>
          <td>QSAR comparison</td>
          <td>Effect size between tokenization methods</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Not explicitly specified in the paper.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/XinhaoLi74/SmilesPE">SmilesPE</a></td>
          <td>Code</td>
          <td>Apache-2.0</td>
          <td>SPE tokenization Python package</td>
      </tr>
      <tr>
          <td><a href="https://github.com/XinhaoLi74/MolPMoFiT">MolPMoFiT</a></td>
          <td>Code</td>
          <td>Not specified</td>
          <td>Transfer learning QSAR framework</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Li, X., &amp; Fourches, D. (2021). SMILES Pair Encoding: A Data-Driven Substructure Tokenization Algorithm for Deep Learning. <em>Journal of Chemical Information and Modeling</em>, 61(4), 1560–1569. <a href="https://doi.org/10.1021/acs.jcim.0c01127">https://doi.org/10.1021/acs.jcim.0c01127</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{li2021smiles,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{SMILES Pair Encoding: A Data-Driven Substructure Tokenization Algorithm for Deep Learning}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Li, Xinhao and Fourches, Denis}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Journal of Chemical Information and Modeling}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{61}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{4}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{1560--1569}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2021}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{American Chemical Society}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1021/acs.jcim.0c01127}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>SMILES2Vec: Interpretable Chemical Property Prediction</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/property-prediction/smiles2vec-interpretable-property-prediction/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/property-prediction/smiles2vec-interpretable-property-prediction/</guid><description>SMILES2Vec uses a Bayesian-optimized CNN-GRU architecture to predict chemical properties directly from SMILES strings with an interpretable explanation mask.</description><content:encoded><![CDATA[<h2 id="a-general-purpose-rnn-for-chemical-property-prediction-from-smiles">A General-Purpose RNN for Chemical Property Prediction from SMILES</h2>
<p>SMILES2Vec is a <strong>Method</strong> paper that introduces a deep recurrent neural network architecture for predicting chemical properties directly from <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> text representations. The primary contributions are: (1) a Bayesian-optimized CNN-<a href="https://en.wikipedia.org/wiki/Gated_recurrent_unit">GRU</a> architecture that serves as a general-purpose predictor for diverse chemical properties (toxicity, activity, solubility, <a href="https://en.wikipedia.org/wiki/Solvation">solvation</a> energy), (2) an explanation mask mechanism that provides interpretable predictions by identifying which SMILES characters drive the network&rsquo;s decisions, and (3) evidence that representation learning from raw SMILES can match or outperform models using hand-crafted molecular descriptors.</p>
<h2 id="motivation-beyond-engineered-features-in-chemical-modeling">Motivation: Beyond Engineered Features in Chemical Modeling</h2>
<p>At the time of writing (2017), deep learning models in chemistry relied heavily on engineered <a href="https://en.wikipedia.org/wiki/Molecular_descriptor">molecular descriptors</a> and fingerprints as input features. Over 5,000 molecular descriptors had been developed since the late 1940s, and <a href="https://en.wikipedia.org/wiki/Quantitative_structure%E2%80%93activity_relationship">QSAR</a>/QSPR modeling remained the dominant paradigm. The authors identified two key limitations with this approach:</p>
<ol>
<li><strong>Restricted search space</strong>: Engineered features limit the neural network&rsquo;s ability to discover potentially useful representations that domain experts have not anticipated.</li>
<li><strong>Incomplete domain knowledge</strong>: For complex properties where first-principles understanding is incomplete, the lack of appropriate descriptors constrains model performance.</li>
</ol>
<p>In contrast, computer vision and NLP had shown that deep learning models trained on raw data (unaltered images, raw text) could learn powerful representations without feature engineering. The chemical SMILES notation, a text-based encoding of molecular structure that serves as the standard interchange format in cheminformatics, provided a natural analog to text data for NLP-style modeling.</p>
<p>A secondary motivation was interpretability. Most ML and DL models for chemistry operated as black boxes, which posed particular problems for regulated applications like FDA drug approval where mechanistic explanations are required.</p>
<h2 id="core-innovation-cnn-gru-architecture-with-explanation-masks">Core Innovation: CNN-GRU Architecture with Explanation Masks</h2>
<h3 id="architecture-design-via-bayesian-optimization">Architecture Design via <a href="https://en.wikipedia.org/wiki/Bayesian_optimization">Bayesian Optimization</a></h3>
<p>SMILES2Vec treats SMILES strings as character-level text input. The network processes one-hot encoded characters (padded to length 250, covering 99.9% of the <a href="https://en.wikipedia.org/wiki/ChEMBL">ChEMBL</a> database) through three stages:</p>
<ol>
<li><strong>Embedding layer</strong>: Maps one-hot character vectors to a learned embedding space (size 50)</li>
<li><strong>1D convolutional layer</strong>: 192 filters with kernel size 3, stride 1</li>
<li><strong>Bidirectional GRU layers</strong>: Two layers with 224 and 384 units respectively</li>
</ol>
<p>The authors explored four architectural classes (GRU, LSTM, CNN-GRU, CNN-LSTM) using Bayesian optimization via SigOpt. Each class was evaluated over 60 trials, optimizing embedding size, convolutional filter count, and RNN layer widths. The CNN-GRU class was selected as the best compromise: CNN-LSTM performed best on classification (Tox21), while GRU-based networks excelled at regression (FreeSolv). The final architecture is summarized by the hyperparameters:</p>
<table>
  <thead>
      <tr>
          <th>Component</th>
          <th>Parameter</th>
          <th>Value</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Embedding</td>
          <td>Size</td>
          <td>50</td>
      </tr>
      <tr>
          <td>Conv1D</td>
          <td>Filters</td>
          <td>192</td>
      </tr>
      <tr>
          <td>BiGRU Layer 1</td>
          <td>Units</td>
          <td>224</td>
      </tr>
      <tr>
          <td>BiGRU Layer 2</td>
          <td>Units</td>
          <td>384</td>
      </tr>
  </tbody>
</table>
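<p>The input side of this pipeline, character-level encoding padded to length 250, can be sketched without any deep learning framework. The example below is an assumption-laden illustration: it maps characters to integer indices (the form an embedding layer usually consumes, equivalent to one-hot rows), reserves index 0 for padding, and uses hypothetical function names.</p>

```python
MAX_LEN = 250   # padded length stated in the text
PAD_INDEX = 0   # reserved padding index (an assumption of this sketch)


def build_char_vocab(smiles_list):
    """Assign each distinct character an index starting at 1."""
    chars = sorted({ch for s in smiles_list for ch in s})
    return {ch: i + 1 for i, ch in enumerate(chars)}


def encode(smiles, vocab, max_len=MAX_LEN):
    """Return padded character indices, or None if the string is too long."""
    if len(smiles) > max_len:
        return None
    ids = [vocab[ch] for ch in smiles]
    return ids + [PAD_INDEX] * (max_len - len(ids))


def one_hot(index, size):
    """Explicit one-hot row for a single character index."""
    row = [0] * size
    row[index] = 1
    return row


vocab = build_char_vocab(["CCO", "c1ccccc1"])
encoded = encode("CCO", vocab)
print(vocab)
print(encoded[:5], len(encoded))
```

<p>Returning <code>None</code> for over-length strings mirrors a simple filtering policy; a real pipeline would decide separately how to handle characters that are missing from the vocabulary.</p>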
<h3 id="explanation-mask-for-interpretability">Explanation Mask for Interpretability</h3>
<p>The explanation mask is a post-hoc interpretability mechanism. Given a trained (frozen) SMILES2Vec base model, a separate explanation network learns to produce a per-character mask over the input SMILES string. The mask is trained to preserve the base model&rsquo;s output while masking as much input as possible. The loss function for a single sample is:</p>
<p>$$
\text{Loss}_i = \lVert f(\text{SMILES}_i, \theta) - \text{Sol}(\text{SMILES}_i) \rVert_2 + 10^{-6} \lVert \text{MASK}_i \rVert_2 + 0.05 \, H(\text{MASK}_i)
$$</p>
<p>where $f(\text{SMILES}_i, \theta)$ is the base network prediction, $\text{Sol}(\text{SMILES}_i)$ is the ground truth solubility, $H$ is the entropy of the normalized mask, and $\text{MASK}_i$ is the per-character mask vector. The L2 term encourages sparsity and the entropy term penalizes uniform attention distributions.</p>
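<p>To make the three terms concrete, the sketch below evaluates the loss for a single sample with a scalar prediction. It assumes the mask is normalized by its sum before the entropy is taken and that the natural logarithm is used; neither detail is stated above, and the function names are hypothetical.</p>

```python
import math


def mask_entropy(mask):
    """Shannon entropy (natural log) of the mask normalized to sum to one."""
    total = sum(mask)
    if total == 0:
        return 0.0
    probs = [m / total for m in mask if m > 0]
    return -sum(p * math.log(p) for p in probs)


def explanation_loss(prediction, target, mask, l2_weight=1e-6, entropy_weight=0.05):
    """Prediction error + weighted L2 norm of the mask + weighted mask entropy."""
    error = abs(prediction - target)             # L2 norm of a scalar difference
    l2_norm = math.sqrt(sum(m * m for m in mask))
    return error + l2_weight * l2_norm + entropy_weight * mask_entropy(mask)


# Same prediction error, two different masks over a four-character input.
uniform = explanation_loss(1.0, 0.5, [1.0, 1.0, 1.0, 1.0])
peaked = explanation_loss(1.0, 0.5, [4.0, 0.0, 0.0, 0.0])
print(uniform, peaked)
```

<p>With the prediction error held fixed, the uniform mask incurs the larger loss because its entropy is maximal, which matches the role of the entropy term described above.</p>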
<p>The explanation network itself is a 20-layer residual network with SELU activations, ending in a 1D convolution of length 1, batch normalization, and a softplus activation. The softplus output ranges from 0 (fully masked) to infinity (amplified attention), allowing the mask to both suppress and emphasize specific SMILES characters.</p>
<h2 id="experimental-setup-and-baseline-comparisons">Experimental Setup and Baseline Comparisons</h2>
<h3 id="datasets">Datasets</h3>
<p>The model was evaluated on four datasets: three from the <a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a> benchmark (Tox21, HIV, FreeSolv) and the ESOL solubility dataset:</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>Property</th>
          <th>Task</th>
          <th>Size</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Tox21</td>
          <td>Toxicity</td>
          <td>Multi-task classification</td>
          <td>8,014</td>
      </tr>
      <tr>
          <td>HIV</td>
          <td>Activity</td>
          <td>Single-task classification</td>
          <td>41,193</td>
      </tr>
      <tr>
          <td>FreeSolv</td>
          <td>Solvation energy</td>
          <td>Single-task regression</td>
          <td>643</td>
      </tr>
      <tr>
          <td>ESOL</td>
          <td>Solubility</td>
          <td>Single-task regression</td>
          <td>1,128</td>
      </tr>
  </tbody>
</table>
<p>SMILES strings longer than 250 characters were excluded. Classification datasets (Tox21, HIV) used a 1/6 test split with minority-class oversampling; regression datasets (FreeSolv, ESOL) used a 1/10 test split. All experiments used 5-fold cross-validation.</p>
<h3 id="training-protocol">Training Protocol</h3>
<ul>
<li><strong>Optimizer</strong>: RMSprop with learning rate $10^{-3}$, $\rho = 0.9$, $\epsilon = 10^{-8}$</li>
<li><strong>Batch size</strong>: 32</li>
<li><strong>Epochs</strong>: 250 with early stopping (patience of 25 epochs based on validation loss)</li>
<li><strong>Classification loss</strong>: Binary cross-entropy</li>
<li><strong>Regression loss</strong>: Mean absolute error</li>
<li><strong>Metrics</strong>: AUC for classification, RMSE for regression</li>
</ul>
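<p>The stopping rule in this protocol can be illustrated independently of any training framework. The sketch below replays a list of per-epoch validation losses and reports where training would halt; the function name and the synthetic loss curve are hypothetical, and counting patience from the best epoch seen so far is an assumption about how the rule was applied.</p>

```python
def run_early_stopping(val_losses, max_epochs=250, patience=25):
    """Return (best_epoch, last_epoch), both zero-indexed.

    Training halts once `patience` epochs pass without a new best
    validation loss, or when `max_epochs` is reached.
    """
    best_loss = float("inf")
    best_epoch = -1
    last_epoch = -1
    for epoch, loss in enumerate(val_losses[:max_epochs]):
        last_epoch = epoch
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch
        elif epoch - best_epoch >= patience:
            break
    return best_epoch, last_epoch


# Synthetic curve: improves for 10 epochs, then plateaus at a worse value.
plateau = [1.0 - 0.05 * i for i in range(10)] + [0.6] * 100
print(run_early_stopping(plateau))

# Synthetic curve that never stops improving: runs to the epoch cap.
improving = [1.0 / (i + 1) for i in range(300)]
print(run_early_stopping(improving))
```

<p>In practice the weights from <code>best_epoch</code> would typically be restored after stopping, though whether the original experiments did so is not stated here.</p>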
<h3 id="baselines">Baselines</h3>
<p>SMILES2Vec was compared against:</p>
<ul>
<li><strong>MLP with engineered features</strong>: Standard multi-layer perceptron using molecular fingerprints (from MoleculeNet)</li>
<li><strong>Molecular graph convolutions</strong>: Graph-based neural network from MoleculeNet</li>
<li><strong>Chemception</strong>: CNN operating on 2D chemical images</li>
</ul>
<h3 id="bayesian-optimization-protocol">Bayesian Optimization Protocol</h3>
<p>Only two datasets were used for architecture optimization: the nr-ahr toxicity task from Tox21 (classification) and FreeSolv (regression). The remaining datasets (full Tox21, HIV, ESOL) served purely for generalization evaluation. A fixed test set was held out during optimization, and correlation between validation and test metrics (0.54 for Tox21, 0.78 for FreeSolv) confirmed limited overfitting to the validation set.</p>
<h2 id="results-competitive-accuracy-with-interpretable-predictions">Results: Competitive Accuracy with Interpretable Predictions</h2>
<h3 id="property-prediction-performance">Property Prediction Performance</h3>
<p>SMILES2Vec achieved the following validation metrics (with a pre-training approach from ChemNet improving performance slightly):</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>Metric</th>
          <th>SMILES2Vec</th>
          <th>SMILES2Vec + Pre-training</th>
          <th>Graph Conv</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Tox21</td>
          <td>AUC</td>
          <td>0.80</td>
          <td>0.81</td>
          <td>0.81</td>
      </tr>
      <tr>
          <td>HIV</td>
          <td>AUC</td>
          <td>0.78</td>
          <td>0.80</td>
          <td>0.80</td>
      </tr>
      <tr>
          <td>FreeSolv</td>
          <td>RMSE (kcal/mol)</td>
          <td>1.4</td>
          <td>1.2</td>
          <td>1.3</td>
      </tr>
      <tr>
          <td>ESOL</td>
          <td>RMSE</td>
          <td>0.63</td>
          <td>-</td>
          <td>-</td>
      </tr>
  </tbody>
</table>
<p>Exact numbers for MLP and Chemception baselines were reported only in a bar chart (Figure 6) and not as precise values. The paper states that MLP with fingerprints performed worst across all tasks, and Chemception fell between MLP and the graph/SMILES methods.</p>
<p>Key findings:</p>
<ul>
<li>SMILES2Vec outperformed MLP models using engineered features across all tasks, despite using no feature engineering.</li>
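<li>Against graph convolutions (the state of the art at the time), SMILES2Vec with pre-training matched on classification (Tox21: 0.81 vs 0.81, HIV: 0.80 vs 0.80) and outperformed on regression (FreeSolv RMSE: 1.2 vs 1.3 kcal/mol). Without pre-training it trailed slightly on all three (0.80, 0.78, and 1.4, respectively).</li>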
<li>SMILES2Vec outperformed Chemception (2D image CNN) on classification tasks but slightly underperformed on regression, which the authors attributed to SMILES lacking explicit atomic number information.</li>
</ul>
<h3 id="interpretability-evaluation">Interpretability Evaluation</h3>
<p>On the ESOL solubility dataset, the explanation mask was evaluated against first-principles chemical knowledge. The authors separated compounds into soluble (&gt; 1.0) and insoluble (&lt; -5.0) categories and defined ground truth: soluble compounds should attend to hydrophilic atoms (O, N) while insoluble compounds should attend to hydrophobic atoms (C, F, Cl, Br, I). The top-3 character accuracy was 88%, confirming that SMILES2Vec learned representations consistent with known functional group chemistry.</p>
<p>Qualitative analysis of the masks showed that for low-solubility molecules, characters corresponding to hydrophobic groups (c, C, Cl) received high attention, while high-solubility molecules showed attention focused on hydrophilic groups (O, N).</p>
<h3 id="limitations">Limitations</h3>
<ul>
<li>The interpretability evaluation was limited to solubility, a well-understood property with simple first-principles rules. The authors acknowledged that quantifying interpretability for complex properties (toxicity, activity) where no simple ground truth exists is nontrivial.</li>
<li>The Bayesian optimization used only a subset of datasets, so the architecture may not be globally optimal across all chemical tasks.</li>
<li>SMILES strings lack explicit atomic number information, which may limit performance on physical property prediction compared to image or graph representations.</li>
<li>The explanation mask approach requires training a separate 20-layer network per property, adding computational overhead.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Architecture optimization</td>
          <td>Tox21 (nr-ahr task)</td>
          <td>8,014</td>
          <td>Single toxicity task for Bayesian optimization</td>
      </tr>
      <tr>
          <td>Architecture optimization</td>
          <td>FreeSolv</td>
          <td>643</td>
          <td>Solvation free energy regression</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>Tox21 (full, 12 tasks)</td>
          <td>8,014</td>
          <td>Multi-task classification</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>HIV</td>
          <td>41,193</td>
          <td>Single-task classification</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>ESOL</td>
          <td>1,128</td>
          <td>Solubility regression, also used for interpretability</td>
      </tr>
  </tbody>
</table>
<p>All datasets are publicly available through MoleculeNet. The ESOL dataset is from Delaney (2004).</p>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Bayesian optimization via SigOpt (60 trials per architectural class, 4 classes, 6 manually seeded initial designs per class)</li>
<li>RMSprop optimizer with standard settings</li>
<li>Explanation mask trained with Adam, learning rate annealed from $10^{-2}$ to $10^{-6}$</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li>Final architecture: Embedding(50) -&gt; Conv1D(192, kernel=3, stride=1) -&gt; BiGRU(224) -&gt; BiGRU(384)</li>
<li>Explanation network: 20-layer residual network with SELU activations</li>
<li>No pre-trained weights or code were released</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Dataset</th>
          <th>Value</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>AUC</td>
          <td>Tox21</td>
          <td>0.81</td>
          <td>With pre-training</td>
      </tr>
      <tr>
          <td>AUC</td>
          <td>HIV</td>
          <td>0.80</td>
          <td>With pre-training</td>
      </tr>
      <tr>
          <td>RMSE</td>
          <td>FreeSolv</td>
          <td>1.2 kcal/mol</td>
          <td>With pre-training</td>
      </tr>
      <tr>
          <td>RMSE</td>
          <td>ESOL</td>
          <td>0.63</td>
          <td>Base model</td>
      </tr>
      <tr>
          <td>Top-3 accuracy</td>
          <td>ESOL interpretability</td>
          <td>88%</td>
          <td>Explanation mask</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>The authors report using TensorFlow with GPU acceleration via NVIDIA cuDNN libraries. Specific GPU models and training times were not reported.</p>
<h3 id="artifacts">Artifacts</h3>
<p>No code, models, or data artifacts were released by the authors. The datasets used are publicly available through MoleculeNet.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Goh, G. B., Hodas, N. O., Siegel, C., &amp; Vishnu, A. (2017). SMILES2Vec: An Interpretable General-Purpose Deep Neural Network for Predicting Chemical Properties. <em>arXiv preprint arXiv:1712.02034</em>.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{goh2017smiles2vec,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{SMILES2Vec: An Interpretable General-Purpose Deep Neural Network for Predicting Chemical Properties}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Goh, Garrett B. and Hodas, Nathan O. and Siegel, Charles and Vishnu, Abhinav}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{arXiv preprint arXiv:1712.02034}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2017}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.48550/arxiv.1712.02034}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>SMILES Transformer: Low-Data Molecular Fingerprints</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/smiles-transformer/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/smiles-transformer/</guid><description>SMILES Transformer uses unsupervised Transformer pre-training on SMILES strings to produce molecular fingerprints that excel in low-data drug discovery tasks.</description><content:encoded><![CDATA[<h2 id="a-transformer-approach-to-learned-molecular-fingerprints">A Transformer Approach to Learned Molecular Fingerprints</h2>
<p>This is a <strong>Method</strong> paper that introduces SMILES Transformer (ST), a Transformer-based sequence-to-sequence model pre-trained on unlabeled SMILES strings to produce continuous, data-driven molecular fingerprints. The primary contribution is demonstrating that unsupervised pre-training on chemical text representations yields fingerprints that generalize well under low-data conditions, outperforming both rule-based fingerprints (ECFP) and graph convolution models on several <a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a> benchmarks. A secondary contribution is the Data Efficiency Metric (DEM), a scalar metric for evaluating model performance across varying training set sizes.</p>
<h2 id="the-low-data-problem-in-molecular-property-prediction">The Low-Data Problem in Molecular Property Prediction</h2>
<p>Machine learning for drug discovery depends on molecular representations, but labeled datasets of experimentally validated properties are typically small. Conventional approaches fall into two camps: rule-based fingerprints like ECFP that hash substructures into sparse binary vectors, and graph-based methods like GraphConv that learn representations end-to-end. Rule-based fingerprints perform poorly with shallow models or limited data, while graph-based methods are designed for large fully-labeled settings.</p>
<p>Pre-training on unlabeled data had shown strong results in NLP (ELMo, BERT, XLNet), and prior work in cheminformatics had explored RNN-based and VAE-based pre-training on SMILES (<a href="/notes/chemistry/molecular-representations/encoders/seq2seq-fingerprint-molecular-embedding/">Seq2Seq fingerprints</a>, <a href="/notes/chemistry/molecular-design/generation/latent-space/grammar-variational-autoencoder/">Grammar VAE</a>, heteroencoders). However, none of these studies systematically evaluated performance in small-data settings. Honda et al. fill this gap by applying Transformer-based pre-training to SMILES and measuring data efficiency explicitly.</p>
<h2 id="transformer-pre-training-on-smiles-with-pooled-fingerprint-extraction">Transformer Pre-training on SMILES with Pooled Fingerprint Extraction</h2>
<p>The core innovation is a Transformer encoder-decoder architecture pre-trained as an autoencoder on SMILES strings, with a specific fingerprint extraction strategy that pools the encoder outputs into a fixed-length vector.</p>
<h3 id="architecture">Architecture</h3>
<p>The model uses 4 Transformer blocks for both the encoder and decoder, each with 4-head attention and 256 embedding dimensions plus 2 linear layers. Input SMILES are tokenized at the symbol level (e.g., &lsquo;c&rsquo;, &lsquo;Br&rsquo;, &lsquo;=&rsquo;, &lsquo;(&rsquo;, &lsquo;2&rsquo;) and one-hot encoded. Following Vaswani et al. (2017), the input uses the sum of token encoding and positional encoding.</p>
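<p>Symbol-level tokenization followed by one-hot encoding can be sketched as follows. The regular expression is an assumption chosen to reproduce the example tokens above (it keeps <code>Cl</code>, <code>Br</code>, and bracket atoms whole); the paper does not give its exact tokenizer, and the function names are illustrative.</p>

```python
import re

# Assumed pattern: two-letter halogens and bracket atoms stay whole,
# two-digit ring closures (%NN) are one token, anything else is one character.
SMILES_TOKEN = re.compile(r"Cl|Br|\[[^\]]+\]|%\d{2}|.")

def tokenize(smiles):
    """Split a SMILES string into symbol-level tokens."""
    return SMILES_TOKEN.findall(smiles)

def one_hot(tokens, vocab):
    """Encode each token as a one-hot row over a fixed vocabulary list."""
    index = {tok: i for i, tok in enumerate(vocab)}
    return [[1 if index[tok] == j else 0 for j in range(len(vocab))]
            for tok in tokens]
```

<p>For instance, <code>tokenize("Brc1ccc(C=O)cc1")</code> keeps <code>Br</code> as a single token rather than splitting it into <code>B</code> and <code>r</code>.</p>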
<h3 id="pre-training">Pre-training</h3>
<p>The model is pre-trained on 861,000 unlabeled SMILES sampled from <a href="https://en.wikipedia.org/wiki/ChEMBL">ChEMBL24</a> to minimize cross-entropy between input and output SMILES (i.e., reconstruction). <a href="/notes/chemistry/molecular-representations/notations/randomized-smiles-generative-models/">SMILES enumeration</a> (Bjerrum, 2017) randomly generates non-canonical SMILES at each epoch to reduce representation bias. Training runs for 5 epochs with Adam optimization, reaching a perplexity of 1.0 (perfect decoding).</p>
<h3 id="fingerprint-extraction">Fingerprint Extraction</h3>
<p>Since the Transformer outputs symbol-level (atom-level) representations, a pooling strategy produces molecule-level fingerprints. Four vectors are concatenated:</p>
<ol>
<li>Mean-pooled output of the last encoder layer</li>
<li>Max-pooled output of the last encoder layer</li>
<li>First output token of the last encoder layer</li>
<li>First output token of the penultimate encoder layer</li>
</ol>
<p>This produces a 1024-dimensional fingerprint, matching the dimensionality of ECFP for fair comparison.</p>
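<p>The four-way concatenation can be sketched in plain Python. This is an illustration under stated assumptions rather than the released implementation: <code>st_fingerprint</code> and its argument names are hypothetical, and each argument is taken to be a list of per-token vectors (sequence length by 256) from the corresponding encoder layer.</p>

```python
def st_fingerprint(last_layer, penultimate_layer):
    """Concatenate four pooled vectors into one molecule-level fingerprint.

    Each argument is a list of per-token vectors (seq_len x d_model)
    from an encoder layer; the result has 4 * d_model entries.
    """
    columns = list(zip(*last_layer))            # one tuple per feature dimension
    mean_pool = [sum(col) / len(col) for col in columns]
    max_pool = [max(col) for col in columns]
    first_last = list(last_layer[0])            # first token, last layer
    first_penult = list(penultimate_layer[0])   # first token, penultimate layer
    return mean_pool + max_pool + first_last + first_penult
```

<p>With 256-dimensional token vectors the output length is 4 &times; 256 = 1024, which is where the fingerprint size comes from.</p>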
<h3 id="data-efficiency-metric">Data Efficiency Metric</h3>
<p>The paper proposes DEM to measure how well a model performs across different training set sizes:</p>
<p>$$
M_{DE}(f, m) = \frac{1}{|I|} \sum_{i \in I} m(f_i, X_i, Y_i)
$$</p>
<p>where $f_i$ is the model trained on the fraction $i$ of training data, $m$ is the task metric, and $I = \lbrace 0.0125, 0.025, 0.05, 0.1, 0.2, 0.4, 0.8 \rbrace$ doubles the training percentage at each step. This captures average performance across a range of data availability, giving a single scalar that balances accuracy and data efficiency.</p>
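<p>Computationally, DEM is a plain average over the seven training fractions. The sketch below assumes a hypothetical callable <code>score_at_fraction</code> that trains a model on the given fraction and returns the task metric; that callable and the function name are illustrative, not part of the paper.</p>

```python
FRACTIONS = (0.0125, 0.025, 0.05, 0.1, 0.2, 0.4, 0.8)

def data_efficiency_metric(score_at_fraction, fractions=FRACTIONS):
    """Average a task metric over models trained on each data fraction.

    `score_at_fraction(i)` is a hypothetical callable that trains on
    fraction `i` of the training data and returns m(f_i, X_i, Y_i).
    """
    scores = [score_at_fraction(i) for i in fractions]
    return sum(scores) / len(scores)
```

<p>Because every fraction carries equal weight, the four settings at 10% of the data or less account for more than half of the score, which is why DEM rewards low-data performance.</p>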
<h2 id="benchmarking-across-moleculenet-with-data-efficiency-focus">Benchmarking Across MoleculeNet with Data Efficiency Focus</h2>
<h3 id="datasets">Datasets</h3>
<p>The evaluation uses 10 datasets from MoleculeNet spanning three categories:</p>
<table>
  <thead>
      <tr>
          <th>Category</th>
          <th>Dataset</th>
          <th>Tasks</th>
          <th>Type</th>
          <th>Molecules</th>
          <th>Metric</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Physical chemistry</td>
          <td>ESOL</td>
          <td>1</td>
          <td>Regression</td>
          <td>1,128</td>
          <td>RMSE</td>
      </tr>
      <tr>
          <td>Physical chemistry</td>
          <td>FreeSolv</td>
          <td>1</td>
          <td>Regression</td>
          <td>643</td>
          <td>RMSE</td>
      </tr>
      <tr>
          <td>Physical chemistry</td>
          <td><a href="https://en.wikipedia.org/wiki/Lipophilicity">Lipophilicity</a></td>
          <td>1</td>
          <td>Regression</td>
          <td>4,200</td>
          <td>RMSE</td>
      </tr>
      <tr>
          <td>Biophysics</td>
          <td>MUV</td>
          <td>17</td>
          <td>Classification</td>
          <td>93,127</td>
          <td>PRC-AUC</td>
      </tr>
      <tr>
          <td>Biophysics</td>
          <td>HIV</td>
          <td>1</td>
          <td>Classification</td>
          <td>41,913</td>
          <td>ROC-AUC</td>
      </tr>
      <tr>
          <td>Biophysics</td>
          <td>BACE</td>
          <td>1</td>
          <td>Classification</td>
          <td>1,522</td>
          <td>ROC-AUC</td>
      </tr>
      <tr>
          <td>Physiology</td>
          <td>BBBP</td>
          <td>1</td>
          <td>Classification</td>
          <td>2,053</td>
          <td>ROC-AUC</td>
      </tr>
      <tr>
          <td>Physiology</td>
          <td>Tox21</td>
          <td>12</td>
          <td>Classification</td>
          <td>8,014</td>
          <td>ROC-AUC</td>
      </tr>
      <tr>
          <td>Physiology</td>
          <td>SIDER</td>
          <td>27</td>
          <td>Classification</td>
          <td>1,427</td>
          <td>ROC-AUC</td>
      </tr>
      <tr>
          <td>Physiology</td>
          <td>ClinTox</td>
          <td>2</td>
          <td>Classification</td>
          <td>1,491</td>
          <td>ROC-AUC</td>
      </tr>
  </tbody>
</table>
<h3 id="baselines">Baselines</h3>
<ul>
<li><strong>ECFP4</strong>: Rule-based extended-connectivity fingerprint with 1024 dimensions</li>
<li><strong>RNNS2S</strong>: RNN-based Seq2Seq pre-trained fingerprint (3-layer bidirectional GRU, same pre-training data as ST)</li>
<li><strong>GraphConv</strong>: Graph convolution network trained end-to-end on labeled data</li>
</ul>
<h3 id="experimental-setup">Experimental Setup</h3>
<p>All fingerprint methods use a simple MLP classifier/regressor from scikit-learn with default hyperparameters to isolate the fingerprint quality from model capacity. Datasets are randomly split (stratified for classification), and results are averaged over 20 trials. Note that random splits are used rather than scaffold splits for the DEM experiments.</p>
<h3 id="data-efficiency-results-dem">Data Efficiency Results (DEM)</h3>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>ST+MLP</th>
          <th>ECFP+MLP</th>
          <th>RNNS2S+MLP</th>
          <th>GraphConv</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ESOL (RMSE, lower is better)</td>
          <td><strong>1.144</strong></td>
          <td>1.741</td>
          <td>1.317</td>
          <td>1.673</td>
      </tr>
      <tr>
          <td>FreeSolv (RMSE, lower is better)</td>
          <td><strong>2.246</strong></td>
          <td>3.043</td>
          <td>2.987</td>
          <td>3.476</td>
      </tr>
      <tr>
          <td>Lipophilicity (RMSE, lower is better)</td>
          <td>1.169</td>
          <td>1.090</td>
          <td>1.219</td>
          <td><strong>1.062</strong></td>
      </tr>
      <tr>
          <td>MUV (PRC-AUC, higher is better)</td>
          <td>0.009</td>
          <td><strong>0.036</strong></td>
          <td>0.010</td>
          <td>0.004</td>
      </tr>
      <tr>
          <td>HIV (ROC-AUC, higher is better)</td>
          <td>0.683</td>
          <td>0.697</td>
          <td>0.682</td>
          <td><strong>0.723</strong></td>
      </tr>
      <tr>
          <td>BACE (ROC-AUC, higher is better)</td>
          <td>0.719</td>
          <td><strong>0.769</strong></td>
          <td>0.717</td>
          <td>0.744</td>
      </tr>
      <tr>
          <td>BBBP (ROC-AUC, higher is better)</td>
          <td><strong>0.900</strong></td>
          <td>0.760</td>
          <td>0.884</td>
          <td>0.795</td>
      </tr>
      <tr>
          <td>Tox21 (ROC-AUC, higher is better)</td>
          <td><strong>0.706</strong></td>
          <td>0.616</td>
          <td>0.702</td>
          <td>0.687</td>
      </tr>
      <tr>
          <td>SIDER (ROC-AUC, higher is better)</td>
          <td>0.559</td>
          <td><strong>0.588</strong></td>
          <td>0.558</td>
          <td>0.557</td>
      </tr>
      <tr>
          <td>ClinTox (ROC-AUC, higher is better)</td>
          <td><strong>0.963</strong></td>
          <td>0.515</td>
          <td>0.904</td>
          <td>0.936</td>
      </tr>
  </tbody>
</table>
<p>ST achieves the best DEM in 5 of 10 datasets (ESOL, FreeSolv, BBBP, Tox21, ClinTox), with particularly strong margins on ClinTox (+0.027 over GraphConv) and BBBP (+0.016 over RNNS2S).</p>
<h3 id="linear-model-experiments">Linear Model Experiments</h3>
<p>To further isolate fingerprint quality, the authors replace MLP with ridge/logistic regression with L2 penalty. On 8 datasets (excluding MUV and SIDER due to class imbalance issues), ST achieves best DEM in 5 of 8, confirming the fingerprint quality holds regardless of downstream model.</p>
<h3 id="stratified-analysis-by-molecule-size">Stratified Analysis by Molecule Size</h3>
<p>On BBBP stratified by SMILES length, ST&rsquo;s ROC-AUC increases with longer SMILES, similar to RNNS2S but unlike GraphConv, which shows stable performance across lengths. This suggests text-based models extract richer information from longer sequences.</p>
<h3 id="comparison-with-record-scores-large-data">Comparison with Record Scores (Large Data)</h3>
<p>Under the large-data setting (80/10/10 train/val/test split with hyperparameter tuning via Optuna), ST achieves first place only in ClinTox (0.954) but performs comparably to ECFP and graph-based models on the other datasets. This confirms that ST&rsquo;s main advantage is in the low-data regime.</p>
<h2 id="strong-low-data-performance-with-caveats-on-scalability">Strong Low-Data Performance with Caveats on Scalability</h2>
<h3 id="key-findings">Key Findings</h3>
<ol>
<li>Transformer-based unsupervised pre-training on SMILES produces fingerprints that excel in low-data molecular property prediction, achieving best data efficiency on 5 of 10 MoleculeNet tasks.</li>
<li>The advantage is most pronounced on small datasets (ESOL with 1,128 molecules, FreeSolv with 643, BBBP with 2,053, ClinTox with 1,491) where pre-training enables good generalization.</li>
<li>With sufficient labeled data and hyperparameter tuning, ST fingerprints perform comparably to (but do not surpass) graph-based methods.</li>
<li>Longer SMILES provide richer information for text-based models, as shown by the stratified analysis on BBBP.</li>
</ol>
<h3 id="limitations">Limitations</h3>
<ul>
<li>Random splits are used for the DEM experiments rather than scaffold splits, which may inflate performance estimates for drug discovery applications where training and test molecules are structurally distinct.</li>
<li>The pre-training corpus (861K SMILES from ChEMBL24) is relatively small by modern standards.</li>
<li>MUV performance is poor across all methods (PRC-AUC near zero), suggesting the DEM framework may not be informative for extremely imbalanced or noisy datasets.</li>
<li>No comparison with BERT-style masked language model pre-training, which later work (<a href="/notes/chemistry/molecular-representations/encoders/chemberta/">ChemBERTa</a>) would show as a viable alternative.</li>
</ul>
<h3 id="future-directions">Future Directions</h3>
<p>The authors propose three directions: (1) replacing the Transformer with Transformer-XL to handle longer SMILES, (2) multi-task pre-training that jointly predicts molecular descriptors (e.g., molecular weight, <a href="https://en.wikipedia.org/wiki/Partition_coefficient">LogP</a>) alongside SMILES reconstruction, and (3) better exploitation of enumerated SMILES to constrain the latent space.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pre-training</td>
          <td>ChEMBL24</td>
          <td>861,000 SMILES</td>
          <td>Unlabeled, randomly sampled</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>MoleculeNet (10 datasets)</td>
          <td>643 to 93,127 molecules</td>
          <td>See Table 1 for per-dataset details</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Transformer encoder-decoder: 4 blocks each, 4-head attention, 256 embedding dimensions</li>
<li>Pre-training: 5 epochs, Adam optimizer, cross-entropy loss, SMILES enumeration for augmentation</li>
<li>Fingerprint: 1024 dimensions from concatenated mean pool, max pool, and first-token outputs</li>
<li>Downstream: scikit-learn MLP (default hyperparameters) for DEM experiments; ridge/logistic regression for linear model experiments; Optuna for hyperparameter search in large-data comparison</li>
</ul>
<h3 id="models">Models</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/DSPsleeporg/smiles-transformer">smiles-transformer</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official implementation (Jupyter notebooks)</td>
      </tr>
  </tbody>
</table>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li>DEM averaged over 7 training fractions (1.25% to 80%), 20 trials each</li>
<li>Random splits for DEM; scaffold splits for HIV, BACE, BBBP in large-data comparison</li>
<li>Metrics: RMSE (regression), ROC-AUC or PRC-AUC (classification) per MoleculeNet conventions</li>
</ul>
<h3 id="hardware">Hardware</h3>
<p>The paper does not specify GPU type or training time for the pre-training phase.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Honda, S., Shi, S., &amp; Ueda, H. R. (2019). SMILES Transformer: Pre-trained Molecular Fingerprint for Low Data Drug Discovery. <em>arXiv preprint arXiv:1911.04738</em>.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{honda2019smiles,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{SMILES Transformer: Pre-trained Molecular Fingerprint for Low Data Drug Discovery}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Honda, Shion and Shi, Shoi and Ueda, Hiroki R.}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{arXiv preprint arXiv:1911.04738}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2019}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>SMI+AIS: Hybridizing SMILES with Environment Tokens</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/notations/smi-ais-hybrid-molecular-representation/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/notations/smi-ais-hybrid-molecular-representation/</guid><description>SMI+AIS hybridizes SMILES with Atom-In-SMILES tokens encoding local chemical environments, improving molecular generation binding affinity and synthesizability.</description><content:encoded><![CDATA[<h2 id="a-hybrid-molecular-representation-combining-smiles-and-chemical-environment-tokens">A Hybrid Molecular Representation Combining SMILES and Chemical-Environment Tokens</h2>
<p>This is a <strong>Method</strong> paper that introduces SMI+AIS(N), a hybrid molecular string representation combining standard <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> tokens with <a href="/notes/chemistry/molecular-representations/notations/atom-in-smiles-tokenization/">Atom-In-SMILES (AIS)</a> tokens. AIS tokens encode local chemical environment information (central atom, ring membership, and neighboring atoms) into a single token. The key contribution is a systematic hybridization strategy that selectively replaces the most frequent SMILES tokens with AIS equivalents, preserving SMILES grammar compatibility while enriching token diversity. The method is validated on molecular structure generation via latent space optimization for drug design.</p>
<h2 id="limitations-of-standard-smiles-for-machine-learning">Limitations of Standard SMILES for Machine Learning</h2>
<p>SMILES is the most widely adopted string-based molecular representation, used in major databases like ZINC and PubChem. Despite this ubiquity, SMILES has several well-known limitations for machine learning applications:</p>
<ol>
<li><strong>Non-unique representations</strong>: The same molecule can be encoded as multiple distinct SMILES strings.</li>
<li><strong>Invalid string generation</strong>: Generative models can produce syntactically invalid SMILES that do not correspond to any molecule.</li>
<li><strong>Limited token diversity</strong>: SMILES tokens map one-to-one to atoms or bonds, so the token vocabulary is restricted to the available atom and bond types.</li>
<li><strong>Insufficient chemical context</strong>: Individual SMILES tokens carry no information about the local chemical environment of an atom.</li>
</ol>
<p>Alternative representations like <a href="/notes/chemistry/molecular-representations/notations/selfies/">SELFIES</a> (guaranteeing validity) and <a href="/notes/chemistry/molecular-representations/notations/inchi-2013/">InChI</a> (guaranteeing uniqueness) address some of these issues but share the same fundamental limitation of low token diversity. The Atom-In-SMILES (AIS) representation (Ucak et al., 2023) enriches tokens with neighboring atom and ring information, but using AIS exclusively produces a large vocabulary with many infrequent tokens that can cause data sparsity problems. The authors aim to find a middle ground: adding chemical context to the most common tokens while keeping the vocabulary manageable.</p>
<h2 id="core-innovation-selective-token-hybridization-with-ais">Core Innovation: Selective Token Hybridization with AIS</h2>
<p>The SMI+AIS(N) representation hybridizes standard SMILES with AIS tokens through a frequency-based selection process:</p>
<h3 id="ais-token-structure">AIS Token Structure</h3>
<p>Each AIS token encodes three pieces of information about an atom, delimited by semicolons:</p>
<p>$$
\lbrack \text{central atom} ; \text{ring info} ; \text{neighbor atoms} \rbrack
$$</p>
<p>For example, the carbonyl oxygen in the carboxyl group of benzoic acid is represented as <code>[O;!R;C]</code>, meaning: oxygen atom, not in a ring, bonded to carbon. In standard SMILES, this would simply be <code>O</code>.</p>
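<p>The token layout can be made concrete with a small formatter. This is a sketch under assumptions: the helper name <code>ais_token</code> is hypothetical, <code>R</code> is assumed to be the in-ring counterpart of <code>!R</code>, and neighbors are sorted alphabetically only to make the output deterministic.</p>

```python
def ais_token(central, in_ring, neighbors):
    """Build an AIS-style token: [central atom;ring info;neighbor atoms].

    `central` may carry hydrogens (e.g. "OH"); `neighbors` is a list of
    element symbols. "R" for ring atoms and alphabetical neighbor order
    are assumptions made for this illustration.
    """
    ring = "R" if in_ring else "!R"
    return "[" + central + ";" + ring + ";" + "".join(sorted(neighbors)) + "]"
```

<p>Calling <code>ais_token("O", False, ["C"])</code> yields the <code>[O;!R;C]</code> token from the benzoic acid example.</p>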
<h3 id="hybridization-procedure">Hybridization Procedure</h3>
<ol>
<li>Convert all SMILES strings in the <a href="/notes/chemistry/datasets/zinc-22/">ZINC database</a> to their full AIS representations.</li>
<li>Count the frequency of each AIS token across the database.</li>
<li>Select the top-N most frequent AIS tokens to form the hybrid vocabulary.</li>
<li>In the hybrid representation, atoms matching these top-N AIS tokens are written in AIS notation; all other atoms use standard SMILES notation.</li>
</ol>
<p>For benzoic acid, the hybridization produces:</p>
<p>$$
\text{SMI}: \texttt{O=C(O)c1ccccc1}
$$</p>
<p>$$
\text{SMI+AIS}: \texttt{\lbrack O;!R;C\rbrack=\lbrack C;!R;COO\rbrack(\lbrack OH;!R;C\rbrack)c1ccccc1}
$$</p>
<p>The parameter N controls vocabulary size. The authors test N = 50, 100, 150, and 200, finding that N = 100-150 provides the best balance for the ZINC database.</p>
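<p>The frequency-based selection and substitution steps can be sketched as below. This is not the authors&rsquo; code: the function names are hypothetical, the AIS token for each atom is assumed to be precomputed, and the ring-carbon tokens in <code>BENZOIC_ACID</code> are illustrative values rather than tokens taken from the paper.</p>

```python
from collections import Counter

def top_n_ais(corpus, n):
    """Return the set of the n most frequent AIS tokens.

    `corpus` is an iterable of molecules, each a list of
    (smiles_token, ais_token) pairs; ais_token is None for
    non-atom symbols such as bonds, branches, and ring digits.
    """
    counts = Counter(ais for mol in corpus for _, ais in mol if ais is not None)
    return {tok for tok, _ in counts.most_common(n)}

def hybridize(mol, vocab):
    """Write an atom in AIS notation if its token is in vocab, else in SMILES."""
    return "".join(ais if ais in vocab else smi for smi, ais in mol)

# Benzoic acid, O=C(O)c1ccccc1; ring-carbon AIS tokens are illustrative.
BENZOIC_ACID = [
    ("O", "[O;!R;C]"), ("=", None), ("C", "[C;!R;COO]"), ("(", None),
    ("O", "[OH;!R;C]"), (")", None), ("c", "[c;R;CCC]"), ("1", None),
    ("c", "[cH;R;CC]"), ("c", "[cH;R;CC]"), ("c", "[cH;R;CC]"),
    ("c", "[cH;R;CC]"), ("c", "[cH;R;CC]"), ("1", None),
]
```

<p>With a vocabulary holding only the three carboxyl-group tokens, <code>hybridize(BENZOIC_ACID, vocab)</code> reproduces the SMI+AIS string shown above, and an empty vocabulary falls back to plain SMILES.</p>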
<h3 id="token-frequency-rebalancing">Token Frequency Rebalancing</h3>
<p>A key benefit of hybridization is mitigating the severe token frequency imbalance in standard SMILES. Carbon (C), the most frequent element with ~184 million occurrences in ZINC, is represented by only 16 token types in SMILES. With SMI+AIS(200), carbon is distinguished into 145 token types based on chemical environment, with 74% of carbon occurrences represented by AIS tokens. Less common elements like halogens see minimal change (only 2% AIS representation), which avoids introducing unnecessarily rare tokens.</p>
<table>
  <thead>
      <tr>
          <th>Element</th>
          <th>Frequency</th>
          <th>SMILES Types</th>
          <th>SMI+AIS(100) Types (AIS %)</th>
          <th>SMI+AIS(200) Types (AIS %)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>C</td>
          <td>183,860,954</td>
          <td>16</td>
          <td>78 (73%)</td>
          <td>145 (74%)</td>
      </tr>
      <tr>
          <td>O</td>
          <td>27,270,229</td>
          <td>8</td>
          <td>16 (11%)</td>
          <td>24 (11%)</td>
      </tr>
      <tr>
          <td>N</td>
          <td>26,022,928</td>
          <td>11</td>
          <td>32 (1%)</td>
          <td>46 (10%)</td>
      </tr>
      <tr>
          <td>X (halogens)</td>
          <td>6,137,030</td>
          <td>7</td>
          <td>10 (2%)</td>
          <td>11 (2%)</td>
      </tr>
      <tr>
          <td>S</td>
          <td>4,581,307</td>
          <td>12</td>
          <td>17 (2%)</td>
          <td>24 (2%)</td>
      </tr>
  </tbody>
</table>
<h2 id="latent-space-optimization-for-molecular-generation">Latent Space Optimization for Molecular Generation</h2>
<h3 id="model-architecture">Model Architecture</h3>
<p>The evaluation uses a <a href="/notes/chemistry/molecular-design/generation/latent-space/automatic-chemical-design-vae/">conditional variational autoencoder (CVAE)</a> with:</p>
<ul>
<li><strong>Encoder</strong>: BERT-style architecture with entity and positional embeddings, 4 multi-head attention layers (8 heads each), producing mean and standard deviation vectors in latent space.</li>
<li><strong>Decoder</strong>: 4 stacked gated recurrent unit (GRU) layers that transform sampled latent vectors, combined with the conditioning information, back into token sequences.</li>
<li><strong>Training</strong>: 20 epochs on 9 million compounds from the ZINC database (8:1:1 train/valid/test split) under identical conditions for all representations.</li>
</ul>
<h3 id="optimization-setup">Optimization Setup</h3>
<p><a href="https://en.wikipedia.org/wiki/Bayesian_optimization">Bayesian Optimization</a> (BO) via BoTorch is applied to the CVAE <a href="/notes/chemistry/molecular-design/generation/latent-space/">latent space</a>, maximizing a multi-objective function:</p>
<p>$$
\text{Obj} = -\text{BA} - 0.5 \times \text{SA}^2
$$</p>
<p>where BA is binding affinity (docking score from QuickVina 2, lower is stronger) and SA is synthetic accessibility score (from RDKit, lower is more synthesizable). Each BO iteration generates 800 candidate latent vectors. Invalid strings receive a penalty objective value of -100.</p>
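<p>As a worked example of the objective, a candidate with BA = -9.5 and SA = 2.0 scores $-(-9.5) - 0.5 \times 2.0^2 = 7.5$. The sketch below encodes the formula and the invalid-string penalty from the text. The function name and signature are assumptions made for illustration, and the input values are illustrative rather than results from the paper.</p>

```python
INVALID_PENALTY = -100.0  # objective value assigned to invalid strings


def objective(ba, sa, valid=True):
    """Obj = -BA - 0.5 * SA^2, or the fixed penalty for an invalid string.

    ba: docking score (lower means stronger binding)
    sa: synthetic accessibility score (lower means easier to synthesize)
    """
    if not valid:
        return INVALID_PENALTY
    return -ba - 0.5 * sa ** 2


example_value = objective(-9.5, 2.0)
```

<p>Because SA enters quadratically, a move from SA = 2.0 to SA = 3.0 costs 2.5 objective units, which a docking score would need to improve by the same amount to offset.</p>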
<h3 id="protein-targets">Protein Targets</h3>
<p>Four diverse targets were used to assess generalizability:</p>
<ul>
<li><strong>PDK4</strong> (<a href="https://en.wikipedia.org/wiki/Pyruvate_dehydrogenase_kinase">Pyruvate Dehydrogenase Kinase</a> 4): narrow, deep binding pocket</li>
<li><strong>5-HT1B</strong> (<a href="https://en.wikipedia.org/wiki/5-HT1B_receptor">Serotonin Receptor 1B</a>): shallow, open <a href="https://en.wikipedia.org/wiki/G_protein-coupled_receptor">GPCR</a> conformation</li>
<li><strong>PARP1</strong> (<a href="https://en.wikipedia.org/wiki/PARP1">Poly ADP-ribose Polymerase 1</a>): small, flexible molecule binding site</li>
<li><strong>CK1d</strong> (<a href="https://en.wikipedia.org/wiki/Casein_kinase_1">Casein Kinase I</a> Delta): broad, accessible conformation</li>
</ul>
<p>Protein structures were obtained from the <a href="https://en.wikipedia.org/wiki/Protein_Data_Bank">Protein Data Bank</a> (PDB IDs: 4V26, 4IAQ, 6I8M, 4TN6). Each optimization was run 10 times independently from the same 5 initial compounds selected from BindingDB.</p>
<h3 id="key-results">Key Results</h3>
<p>SMI+AIS(100) consistently achieved the highest objective values across protein targets.</p>
<p><strong>PDK4 Optimization</strong> (Top-1 results over 10 independent runs):</p>
<ul>
<li>SMI+AIS(100) achieved approximately 12% improvement over standard SMILES and 28% improvement over SELFIES based on median Top-1 objective values.</li>
<li>Generated structures exhibited BA scores between -10 and -9 and SA scores between 2.0 and 2.3.</li>
<li>Molecular weights clustered around 400 amu, consistent with the CVAE conditioning.</li>
</ul>
<p><strong>Validity Ratios</strong>: Standard SMILES produced approximately 40% valid structures. SMI+AIS representations showed significant improvement as N increased, though SMI+AIS(200) showed slight saturation, likely from insufficiently trained infrequent tokens.</p>
<p><strong>SELFIES</strong>: Despite achieving the highest validity ratio, SELFIES failed to generate chemically meaningful structures with desirable BA and SA scores. The authors attribute this to the SELFIES grammar, in which token meaning is highly context-dependent, so that minor latent space variations produce large structural changes.</p>
<p><strong>Cross-target consistency</strong>: Improvements were observed across all four protein targets, with slight variation (5-HT1B showed smaller differences between SMI and SMI+AIS(100) for Top-1, while other targets showed significant improvements).</p>
<h2 id="improved-molecular-generation-through-chemical-context-enrichment">Improved Molecular Generation Through Chemical Context Enrichment</h2>
<p>The SMI+AIS(N) representation achieves consistent improvements in molecular generation quality compared to both standard SMILES and SELFIES. The core findings are:</p>
<ol>
<li><strong>Binding affinity improvement</strong>: Approximately 7% improvement over standard SMILES for the PDK4 target.</li>
<li><strong>Synthesizability improvement</strong>: Approximately 6% improvement in synthetic accessibility (SA) scores, where a lower SA score indicates a more synthesizable molecule.</li>
<li><strong>Target independence</strong>: Performance gains transfer across four structurally diverse protein targets.</li>
<li><strong>Preserved structural motifs</strong>: The generative model retains chemically meaningful fragments (e.g., acetamide and <a href="https://en.wikipedia.org/wiki/Piperidine">piperidine</a>) from initial compounds without explicit fragment constraints.</li>
</ol>
<h3 id="limitations">Limitations</h3>
<p>The authors acknowledge several limitations:</p>
<ul>
<li><strong>Stereochemistry</strong>: SMI+AIS inherits the limited stereochemistry handling of standard SMILES.</li>
<li><strong>Evaluation scope</strong>: Only molecular generation was tested; property prediction and other ML tasks remain unexplored.</li>
<li><strong>Compute constraints</strong>: The study was limited to molecular generation because of limited computing power and time.</li>
<li><strong>Single optimization strategy</strong>: Only latent space optimization with Bayesian optimization was evaluated; other generative approaches were not compared.</li>
</ul>
<h3 id="future-directions">Future Directions</h3>
<p>The authors suggest extending SMI+AIS to diverse benchmarking tests including molecular property prediction, experimental validation, and broader applications of chemical language models.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training/Vocab</td>
          <td>ZINC Database</td>
          <td>9M compounds</td>
          <td>Canonicalized, deduplicated, split 8:1:1</td>
      </tr>
      <tr>
          <td>Binding targets</td>
          <td>BindingDB</td>
          <td>5 initial compounds per target</td>
          <td>Selected for each protein target</td>
      </tr>
      <tr>
          <td>Protein structures</td>
          <td>PDB</td>
          <td>4 structures</td>
          <td>IDs: 4V26, 4IAQ, 6I8M, 4TN6</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Tokenization</strong>: AIS token frequency counting on full ZINC database, top-N selection</li>
<li><strong>Generative model</strong>: Conditional VAE with BERT encoder (4 layers, 8 heads) and GRU decoder (4 layers)</li>
<li><strong>Optimization</strong>: Bayesian Optimization via BoTorch (800 candidates per iteration)</li>
<li><strong>Docking</strong>: QuickVina 2 with 25 Å pocket size, 10 docking simulations per ligand</li>
<li><strong>SA scoring</strong>: RDKit SA score</li>
<li><strong>Training</strong>: 20 epochs for all representations under identical conditions</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li>CVAE architecture details in supplementary (Fig. S9, Tables S2, S4)</li>
<li>No pre-trained weights released</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>SMI+AIS(100) vs SMILES</th>
          <th>SMI+AIS(100) vs SELFIES</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Median Top-1 Obj. Value</td>
          <td>+12%</td>
          <td>+28%</td>
          <td>PDK4 target</td>
      </tr>
      <tr>
          <td>Validity Ratio</td>
          <td>Higher than ~40% (SMILES)</td>
          <td>Lower than SELFIES</td>
          <td>SMI+AIS improves with N</td>
      </tr>
      <tr>
          <td>BA (binding affinity)</td>
          <td>~7% improvement</td>
          <td>Substantial</td>
          <td>Lower (more negative) is better</td>
      </tr>
      <tr>
          <td>SA (synthesizability)</td>
          <td>~6% improvement</td>
          <td>Substantial</td>
          <td>Lower is more synthesizable</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Hardware details are not specified in the main text. Optimization wall times are reported in supplementary Table S5.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/herim-han/AIS-Drug-Opt">AIS-Drug-Opt</a></td>
          <td>Code</td>
          <td>Not specified</td>
          <td>Source code and datasets for reproduction</td>
      </tr>
  </tbody>
</table>
<p><strong>Reproducibility Status</strong>: Partially Reproducible. Code and processed data are publicly available on GitHub, but no pre-trained model weights are released, the license is unspecified, and hardware requirements are not documented in the main text.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Han, H., Yeom, M. S., &amp; Choi, S. (2025). Hybridization of SMILES and chemical-environment-aware tokens to improve performance of molecular structure generation. <em>Scientific Reports</em>, 15, 16892. <a href="https://doi.org/10.1038/s41598-025-01890-7">https://doi.org/10.1038/s41598-025-01890-7</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{han2025hybridization,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Hybridization of SMILES and chemical-environment-aware tokens to improve performance of molecular structure generation}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Han, Herim and Yeom, Min Sun and Choi, Sunghwan}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Scientific Reports}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{15}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{16892}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2025}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Springer Nature}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1038/s41598-025-01890-7}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Seq2seq Fingerprint: Unsupervised Molecular Embedding</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/seq2seq-fingerprint-molecular-embedding/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/seq2seq-fingerprint-molecular-embedding/</guid><description>Seq2seq fingerprint uses a GRU encoder-decoder trained on SMILES self-translation to produce unsupervised molecular embeddings for property prediction.</description><content:encoded><![CDATA[<h2 id="an-unsupervised-seq2seq-method-for-molecular-fingerprints">An Unsupervised Seq2seq Method for Molecular Fingerprints</h2>
<p>This is a <strong>Method</strong> paper that introduces seq2seq fingerprint, an unsupervised molecular embedding approach based on sequence-to-sequence learning. The core idea is to train a <a href="https://en.wikipedia.org/wiki/Gated_recurrent_unit">GRU</a> encoder-decoder network to translate <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> strings to themselves, then extract the intermediate fixed-length vector as a molecular fingerprint. These fingerprints are then used with standard supervised classifiers for downstream property prediction tasks such as solubility classification and promiscuity prediction.</p>
<h2 id="the-labeled-data-bottleneck-in-drug-discovery">The Labeled Data Bottleneck in Drug Discovery</h2>
<p>Machine learning approaches to molecular property prediction depend on fixed-length feature vectors as inputs. Traditional molecular fingerprints fall into two categories: hash-based methods like Extended-Connectivity Fingerprints (ECFP) that are fast but lossy and non-invertible, and biologist-guided local-feature fingerprints that require domain expertise and are task-specific. Supervised deep learning fingerprints (e.g., neural fingerprints) can learn representations from data but require large amounts of labeled data, which is expensive to obtain in drug discovery due to the cost of biological experiments.</p>
<p>The authors identify three limitations of existing approaches:</p>
<ol>
<li>Hash-based fingerprints discard information during the hashing process and cannot reconstruct the original molecule</li>
<li>Local-feature fingerprints require expert knowledge and generalize poorly across tasks</li>
<li>Supervised deep learning fingerprints are data-hungry and fail when labeled data is limited</li>
</ol>
<h2 id="self-translation-as-unsupervised-molecular-encoding">Self-Translation as Unsupervised Molecular Encoding</h2>
<p>The key insight is to adapt the <a href="https://en.wikipedia.org/wiki/Seq2seq">sequence-to-sequence</a> learning framework from machine translation (originally English-to-French) to molecular representation learning by setting both the input and output to the same SMILES string. Since the intermediate vector must contain enough information to reconstruct the original SMILES, it serves as a rich, task-agnostic molecular fingerprint.</p>
<p>The architecture consists of two components:</p>
<ul>
<li><strong>Perceiver network</strong>: A multi-layer GRU encoder that reads the SMILES string and compresses it into a fixed-length vector</li>
<li><strong>Interpreter network</strong>: A multi-layer GRU decoder that reconstructs the original SMILES from the fingerprint vector</li>
</ul>
<p>The GRU cell computes a sequence of outputs $(s_1, \ldots, s_T)$ from input sequences $(x_1, \ldots, x_T)$ by iterating:</p>
<p>$$
z_t = \sigma_g(W_z x_t + U_z s_{t-1} + b_z)
$$</p>
<p>$$
r_t = \sigma_r(W_r x_t + U_r s_{t-1} + b_r)
$$</p>
<p>$$
h_t = \tanh(U_h x_t + W_h(s_{t-1} \circ r_t))
$$</p>
<p>$$
s_t = (1 - z_t) \circ h_t + z_t \circ s_{t-1}
$$</p>
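<p>where $z_t$ is the update gate, $r_t$ is the reset gate, $h_t$ is the candidate state, $\sigma_g$ and $\sigma_r$ are the gate activation functions (sigmoid in a standard GRU), $\circ$ denotes element-wise multiplication, and $W$, $U$, $b$ are trainable parameters.</p>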
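<p>A single GRU update can be written out directly from these equations. The sketch below is a dependency-free illustration, not the authors' code. It assumes sigmoid gate activations and the standard GRU interpolation between the candidate state computed at the current step and the previous state. The helper names and the parameter dictionary layout are hypothetical.</p>

```python
import math


def sigmoid(vec):
    return [1.0 / (1.0 + math.exp(-v)) for v in vec]


def matvec(matrix, vec):
    return [sum(a * b for a, b in zip(row, vec)) for row in matrix]


def add(*vecs):
    return [sum(parts) for parts in zip(*vecs)]


def gru_step(x, s_prev, p):
    """One GRU update: returns s_t given input x_t and previous state s_{t-1}."""
    # Update gate z_t and reset gate r_t.
    z = sigmoid(add(matvec(p["W_z"], x), matvec(p["U_z"], s_prev), p["b_z"]))
    r = sigmoid(add(matvec(p["W_r"], x), matvec(p["U_r"], s_prev), p["b_r"]))
    # Candidate state h_t uses the reset-gated previous state.
    gated = [s * g for s, g in zip(s_prev, r)]
    h = [math.tanh(v) for v in add(matvec(p["U_h"], x), matvec(p["W_h"], gated))]
    # Interpolate between the candidate state and the previous state.
    return [(1.0 - zi) * hi + zi * si for zi, hi, si in zip(z, h, s_prev)]


def encode(inputs, s0, p):
    """Run the cell over a sequence and return the final state."""
    state = s0
    for x in inputs:
        state = gru_step(x, state, p)
    return state
```

<p>With all parameters set to zero, both gates equal 0.5 and the candidate state is zero, so each step simply halves the previous state. That makes a convenient sanity check for an implementation.</p>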
<p>Several adaptations to the original seq2seq framework make this work for molecular data:</p>
<ol>
<li><strong>GRU instead of LSTM</strong>: GRU provides comparable performance with faster training, which is important given the large training data pool</li>
<li><strong>Attention mechanism</strong>: Establishes a stronger connection between the perceiver and interpreter networks via soft alignment, addressing the challenge of passing information through hidden memory for long sequences (SMILES can be up to 250 characters)</li>
<li><strong>Dropout layers</strong>: Added to input and output gates (but not hidden memory transfer) following the approach of Zaremba et al. to combat overfitting when training on large datasets</li>
<li><strong>Fingerprint extraction layer</strong>: A fixed-unit fully connected layer combined with a GRU cell state concatenation layer is inserted between encoder and decoder to explicitly output the fingerprint vector</li>
<li><strong>Reverse target sequence</strong>: Following Sutskever et al., the target sequence is reversed to improve SGD optimization</li>
<li><strong>Bucket training</strong>: Sequences are distributed into buckets by length and padded to enable GPU parallelization</li>
</ol>
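<p>Two of the adaptations above, target reversal and bucket training, are pure data preparation and can be sketched without a deep learning framework. The example below is an assumption-laden illustration: it tokenizes SMILES character by character, and the bucket sizes, padding symbol, and function names are hypothetical choices rather than values from the paper.</p>

```python
def make_pair(smiles, bucket_sizes, pad="_"):
    """Return (bucket size, padded source, padded reversed target).

    Raises ValueError if the string is longer than every bucket.
    """
    # Smallest bucket that can hold the sequence.
    size = min(s for s in bucket_sizes if s >= len(smiles))
    padding = [pad] * (size - len(smiles))
    source = list(smiles) + padding
    # The target is the same string read backwards (self-translation).
    target = list(reversed(smiles)) + padding
    return size, source, target


def bucketize(smiles_list, bucket_sizes):
    """Group (source, target) pairs by bucket so each batch has one length."""
    buckets = {size: [] for size in bucket_sizes}
    for smiles in smiles_list:
        size, source, target = make_pair(smiles, bucket_sizes)
        buckets[size].append((source, target))
    return buckets


example_buckets = bucketize(["CCO", "O=C(O)c1ccccc1"], bucket_sizes=[5, 20])
```

<p>Every pair inside a bucket has the same padded length, so a bucket can be stacked into a single tensor and processed in parallel on a GPU.</p>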
<h2 id="classification-experiments-on-logp-and-pm2-datasets">Classification Experiments on LogP and PM2 Datasets</h2>
<h3 id="training-setup">Training Setup</h3>
<p>The unsupervised training used 334,092 valid SMILES representations from combined LogP and PM2-full datasets obtained from the National Center for Advancing Translational Sciences (NCATS) at NIH. Three model variants were trained with fingerprint dimensions of 512, 768, and 1024, differing in the number of GRU layers (2, 3, and 4 respectively) while keeping the latent dimension at 256. Each model was trained for 24 hours on a workstation with an Intel i7-6700K CPU, 16 GB RAM, and an NVIDIA GTX 1080 GPU.</p>
<h3 id="reconstruction-performance">Reconstruction Performance</h3>
<p>The models were evaluated on their ability to reconstruct SMILES strings from their fingerprints:</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>GRU Layers</th>
          <th>Latent Dim</th>
          <th>Perplexity</th>
          <th>Exact Match Accuracy</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>seq2seq-512</td>
          <td>2</td>
          <td>256</td>
          <td>1.00897</td>
          <td>94.24%</td>
      </tr>
      <tr>
          <td>seq2seq-768</td>
          <td>3</td>
          <td>256</td>
          <td>1.00949</td>
          <td>92.92%</td>
      </tr>
      <tr>
          <td>seq2seq-1024</td>
          <td>4</td>
          <td>256</td>
          <td>1.01472</td>
          <td>90.26%</td>
      </tr>
  </tbody>
</table>
<p>Deeper models showed lower reconstruction accuracy, possibly because larger fingerprint spaces introduce more null spaces and require longer training to converge.</p>
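<p>The two reconstruction metrics in the table can be sketched as follows. This assumes the conventional definitions (perplexity as the exponential of the mean per-token negative log-likelihood, and exact match as the fraction of strings reproduced character for character). The summary does not state how the paper computed them, so treat the function names and definitions as assumptions.</p>

```python
import math


def perplexity(token_probs):
    """exp of the mean negative log-probability assigned to the true tokens."""
    mean_nll = -sum(math.log(p) for p in token_probs) / len(token_probs)
    return math.exp(mean_nll)


def exact_match_accuracy(reconstructed, originals):
    """Fraction of SMILES strings reconstructed without any error."""
    matches = sum(r == o for r, o in zip(reconstructed, originals))
    return matches / len(originals)
```

<p>Under this definition a perplexity close to 1, as reported for all three models, means the decoder assigns almost all probability mass to the correct next token. Exact match is stricter, since a single wrong token anywhere in the string counts as a failure.</p>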
<h3 id="classification-results">Classification Results</h3>
<p>Two labeled datasets were used for downstream classification:</p>
<ul>
<li><strong>LogP</strong>: 10,850 samples with <a href="https://en.wikipedia.org/wiki/Partition_coefficient">water-octanol partition coefficient</a> values, binarized at a threshold of 1.88</li>
<li><strong>PM2-10k</strong>: 10,000 samples with binary promiscuity class labels</li>
</ul>
<p>The seq2seq fingerprints were evaluated with three ensemble classifiers (<a href="https://en.wikipedia.org/wiki/AdaBoost">AdaBoost</a>, <a href="https://en.wikipedia.org/wiki/Gradient_boosting">GradientBoost</a>, RandomForest) against circular fingerprints (ECFP) and neural fingerprints. Results are 100-run averages of 5-fold cross-validation accuracy.</p>
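<p>The evaluation protocol (repeated 5-fold cross-validation with averaged accuracy) can be sketched with the standard library alone. This is an illustration under stated assumptions: <code>fit_predict</code> is a hypothetical callable standing in for any of the classifiers, and the standard deviation here is taken across runs, which the summary does not confirm.</p>

```python
import random


def k_fold_indices(n_samples, k, rng):
    """Shuffle sample indices and deal them into k folds."""
    order = list(range(n_samples))
    rng.shuffle(order)
    return [order[i::k] for i in range(k)]


def repeated_cv_accuracy(features, labels, fit_predict, k=5, runs=100, seed=0):
    """Return (mean, std) of k-fold accuracy over several shuffled runs.

    fit_predict(train_x, train_y, test_x) must return predicted labels.
    """
    rng = random.Random(seed)
    run_scores = []
    for _ in range(runs):
        fold_scores = []
        for held_out in k_fold_indices(len(labels), k, rng):
            held = set(held_out)
            train = [i for i in range(len(labels)) if i not in held]
            preds = fit_predict(
                [features[i] for i in train],
                [labels[i] for i in train],
                [features[i] for i in held_out],
            )
            correct = sum(p == labels[i] for p, i in zip(preds, held_out))
            fold_scores.append(correct / len(held_out))
        run_scores.append(sum(fold_scores) / k)
    mean = sum(run_scores) / runs
    variance = sum((s - mean) ** 2 for s in run_scores) / runs
    return mean, variance ** 0.5
```

<p>In practice <code>fit_predict</code> would wrap a gradient boosting, AdaBoost, or random forest classifier trained on the fingerprint vectors.</p>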
<p><strong>LogP classification accuracy:</strong></p>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>Mean Accuracy</th>
          <th>Std Dev</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Circular FP (ECFP)</td>
          <td>0.3674</td>
          <td>0.0074</td>
      </tr>
      <tr>
          <td>Neural FP</td>
          <td>0.6080</td>
          <td>0.0135</td>
      </tr>
      <tr>
          <td>Seq2seq-1024 + GradientBoost</td>
          <td><strong>0.7664</strong></td>
          <td>0.0043</td>
      </tr>
      <tr>
          <td>Seq2seq-1024 + AdaBoost</td>
          <td>0.7342</td>
          <td>0.0042</td>
      </tr>
      <tr>
          <td>Seq2seq-512 + GradientBoost</td>
          <td>0.7350</td>
          <td>0.0060</td>
      </tr>
  </tbody>
</table>
<p><strong>PM2-10k classification accuracy:</strong></p>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>Mean Accuracy</th>
          <th>Std Dev</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Circular FP (ECFP)</td>
          <td>0.3938</td>
          <td>0.0114</td>
      </tr>
      <tr>
          <td>Neural FP</td>
          <td>0.5227</td>
          <td>0.0112</td>
      </tr>
      <tr>
          <td>Seq2seq-1024 + GradientBoost</td>
          <td><strong>0.6206</strong></td>
          <td>0.0198</td>
      </tr>
      <tr>
          <td>Seq2seq-1024 + AdaBoost</td>
          <td>0.6036</td>
          <td>0.0147</td>
      </tr>
      <tr>
          <td>Seq2seq-512 + GradientBoost</td>
          <td>0.5741</td>
          <td>0.0086</td>
      </tr>
  </tbody>
</table>
<p>The seq2seq fingerprint outperformed both baselines across all configurations. Despite the seq2seq-1024 model having lower reconstruction accuracy, it provided the best classification performance, suggesting that the longer fingerprint captures more discriminative information for downstream tasks even if the reconstruction is less exact.</p>
<h2 id="unsupervised-transfer-learning-for-molecular-properties">Unsupervised Transfer Learning for Molecular Properties</h2>
<p>The results demonstrate that unsupervised pretraining on large unlabeled molecular datasets can produce fingerprints that transfer well to supervised property prediction with limited labels. The key advantages confirmed by the experiments are:</p>
<ol>
<li><strong>Label-free training</strong>: The unsupervised approach uses essentially unlimited SMILES data, avoiding the expensive label collection process</li>
<li><strong>Task-agnostic representations</strong>: The same fingerprints work across different classification tasks (solubility and promiscuity) without retraining</li>
<li><strong>Invertibility</strong>: The fingerprints contain enough information to reconstruct the original SMILES (up to 94.24% exact match), unlike hash-based methods</li>
</ol>
<p><strong>Limitations</strong> acknowledged by the authors include:</p>
<ul>
<li>Long training times (24 hours per model variant), motivating future work on distributed training</li>
<li>The relationship between fingerprint dimensionality and downstream performance is non-monotonic (768-dim underperforms 512-dim on some tasks), suggesting sensitivity to hyperparameter choices</li>
<li>Only classification tasks were evaluated; regression performance was not assessed</li>
<li>The comparison baselines are limited to ECFP and neural fingerprints from 2015</li>
</ul>
<p><strong>Future directions</strong> proposed include distributed training strategies, hyperparameter optimization methods, and semi-supervised extensions that incorporate label information into the fingerprint training.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Unsupervised training</td>
          <td>LogP + PM2-full (combined)</td>
          <td>334,092 SMILES</td>
          <td>Obtained from NCATS at NIH</td>
      </tr>
      <tr>
          <td>Classification</td>
          <td>LogP</td>
          <td>10,850 samples</td>
          <td>Binary labels at LogP threshold 1.88</td>
      </tr>
      <tr>
          <td>Classification</td>
          <td>PM2-10k</td>
          <td>10,000 samples</td>
          <td>Binary promiscuity labels</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Encoder-decoder: Multi-layer GRU with attention mechanism and dropout</li>
<li>Fingerprint dimensions: 512, 768, 1024 (with 2, 3, 4 GRU layers respectively)</li>
<li>Latent dimension: 256 for all variants</li>
<li>Downstream classifiers: AdaBoost, GradientBoost, RandomForest</li>
<li>Evaluation: 5-fold cross-validation, 100-run averages</li>
<li>Baselines: ECFP via RDKit, Neural Fingerprint from HIPS/neural-fingerprint</li>
</ul>
<h3 id="models">Models</h3>
<p>Three model variants trained for 24 hours each. The paper states code would become publicly available after acceptance, but no public repository has been confirmed.</p>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Best Value</th>
          <th>Task</th>
          <th>Configuration</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Classification accuracy</td>
          <td>0.7664</td>
          <td>LogP</td>
          <td>seq2seq-1024 + GradientBoost</td>
      </tr>
      <tr>
          <td>Classification accuracy</td>
          <td>0.6206</td>
          <td>PM2-10k</td>
          <td>seq2seq-1024 + GradientBoost</td>
      </tr>
      <tr>
          <td>Exact match reconstruction</td>
          <td>94.24%</td>
          <td>SMILES recovery</td>
          <td>seq2seq-512</td>
      </tr>
      <tr>
          <td>Perplexity</td>
          <td>1.00897</td>
          <td>SMILES recovery</td>
          <td>seq2seq-512</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li>Training: Intel i7-6700K @ 4.00 GHz, 16 GB RAM, NVIDIA GTX 1080 GPU</li>
<li>Hyperparameter search and classifier training: TACC Lonestar 5 cluster</li>
<li>Training time: 24 hours per model variant</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/HIPS/neural-fingerprint">Neural Fingerprint (baseline)</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Baseline comparison code</td>
      </tr>
  </tbody>
</table>
<p>The authors indicated the seq2seq fingerprint code would be released after acceptance, but no public repository has been found as of this writing. The datasets were sourced from NCATS/NIH.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Xu, Z., Wang, S., Zhu, F., &amp; Huang, J. (2017). Seq2seq Fingerprint: An Unsupervised Deep Molecular Embedding for Drug Discovery. <em>Proceedings of the 8th ACM International Conference on Bioinformatics, Computational Biology, and Health Informatics (ACM-BCB &rsquo;17)</em>, 285-294. <a href="https://doi.org/10.1145/3107411.3107424">https://doi.org/10.1145/3107411.3107424</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{xu2017seq2seq,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Seq2seq Fingerprint: An Unsupervised Deep Molecular Embedding for Drug Discovery}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Xu, Zheng and Wang, Sheng and Zhu, Feiyun and Huang, Junzhou}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Proceedings of the 8th ACM International Conference on Bioinformatics, Computational Biology, and Health Informatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{285--294}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2017}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{ACM}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1145/3107411.3107424}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>REINVENT 4: Open-Source Generative Molecule Design</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/rl-tuned/reinvent4-generative-molecule-design/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/rl-tuned/reinvent4-generative-molecule-design/</guid><description>REINVENT 4 is an open-source generative AI framework combining RNNs and transformers with reinforcement and curriculum learning for de novo molecular design.</description><content:encoded><![CDATA[<h2 id="an-open-source-reference-implementation-for-generative-molecular-design">An Open-Source Reference Implementation for Generative Molecular Design</h2>
<p>REINVENT 4 is a <strong>Resource</strong> paper presenting a production-grade, open-source software framework for AI-driven generative molecular design. The primary contribution is a unified codebase that integrates four distinct molecule generators (de novo, scaffold decoration, linker design, molecular optimization) with three machine learning optimization algorithms (transfer learning, reinforcement learning, <a href="/notes/chemistry/molecular-design/generation/rl-tuned/curriculum-learning-molecular-design/">curriculum learning</a>). The software is released under the Apache 2.0 license and is the fourth major version of the REINVENT platform, which has been in continuous production use at AstraZeneca for drug discovery.</p>
<h2 id="bridging-the-gap-between-research-prototypes-and-production-molecular-design">Bridging the Gap Between Research Prototypes and Production Molecular Design</h2>
<p>The motivation for REINVENT 4 stems from several gaps in the generative molecular design landscape. While numerous AI model architectures have been developed for molecular generation (<a href="/notes/chemistry/molecular-design/generation/latent-space/automatic-chemical-design-vae/">VAEs</a>, GANs, RNNs, transformers, flow models, diffusion models), most exist as research prototypes released alongside individual publications rather than as maintained, integrated software. The authors argue that the scientific community needs reference implementations of common generative molecular design algorithms in the public domain to:</p>
<ol>
<li>Enable nuanced debate about the application of AI in drug discovery</li>
<li>Serve as educational tools for practitioners entering the field</li>
<li>Increase transparency around AI-driven molecular design</li>
<li>Provide a foundation for future innovation</li>
</ol>
<p>REINVENT 4 consolidates previously separate codebases (<a href="/notes/chemistry/molecular-design/generation/rl-tuned/reinvent-deep-rl-molecular-design/">REINVENT</a> v1, v2, LibInvent, LinkInvent, Mol2Mol) into a single repository with a consistent interface, addressing the fragmentation that characterized earlier releases.</p>
<h2 id="unified-framework-for-sequence-based-molecular-generation">Unified Framework for Sequence-Based Molecular Generation</h2>
<p>The core design of REINVENT 4 centers on sequence-based neural network models that generate <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> strings in an autoregressive manner. All generators model the probability of producing a token sequence, with two formulations.</p>
<p>For unconditional agents (de novo generation), the joint probability of a sequence $T$ with tokens $t_1, t_2, \ldots, t_\ell$ is:</p>
<p>$$
\mathbf{P}(T) = \prod_{i=1}^{\ell} \mathbf{P}(t_i \mid t_{i-1}, t_{i-2}, \ldots, t_1)
$$</p>
<p>For conditional agents (scaffold decoration, linker design, molecular optimization), the joint probability given an input sequence $S$ is:</p>
<p>$$
\mathbf{P}(T \mid S) = \prod_{i=1}^{\ell} \mathbf{P}(t_i \mid t_{i-1}, t_{i-2}, \ldots, t_1, S)
$$</p>
<p>The negative log-likelihood for unconditional agents is:</p>
<p>$$
NLL(T) = -\log \mathbf{P}(T) = -\sum_{i=1}^{\ell} \log \mathbf{P}(t_i \mid t_{i-1}, t_{i-2}, \ldots, t_1)
$$</p>
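<p>The factorization above can be checked numerically with a minimal sketch. The per-token probabilities below are hypothetical values chosen for illustration, and <code>sequence_nll</code> is a name introduced here, not a REINVENT 4 function.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import math


def sequence_nll(token_probs):
    """Negative log-likelihood of a sequence from its per-token conditional probabilities."""
    # Summing the negative logs equals the negative log of the product.
    return -sum(math.log(p) for p in token_probs)


# Hypothetical conditional probabilities for a four-token sequence.
probs = [0.9, 0.5, 0.8, 0.25]
nll = sequence_nll(probs)
joint = math.prod(probs)  # joint probability P(T)
</code></pre></div>
<p>Working in log space avoids the numerical underflow that the raw product would hit for long token sequences.</p>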
<h3 id="reinforcement-learning-with-dap">Reinforcement Learning with DAP</h3>
<p>The key optimization mechanism is reinforcement learning via the &ldquo;Difference between Augmented and Posterior&rdquo; (DAP) strategy. For each generated sequence $T$, the augmented likelihood is defined as:</p>
<p>$$
\log \mathbf{P}_{\text{aug}}(T) = \log \mathbf{P}_{\text{prior}}(T) + \sigma \mathbf{S}(T)
$$</p>
<p>where $\mathbf{S}(T) \in [0, 1]$ is the scalar score and $\sigma \geq 0$ controls the balance between reward and regularization. The DAP loss is:</p>
<p>$$
\mathcal{L}(T) = \left(\log \mathbf{P}_{\text{aug}}(T) - \log \mathbf{P}_{\text{agent}}(T)\right)^2
$$</p>
<p>The presence of the prior likelihood in the augmented likelihood constrains how far the agent can deviate from chemically plausible space, functioning similarly to proximal policy gradient methods. The loss is lower-bounded by:</p>
<p>$$
\mathcal{L}(T) \geq \max\left(0, \log \mathbf{P}_{\text{prior}}(T) + \sigma \mathbf{S}(T)\right)^2
$$</p>
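<p>The bound follows because $\log \mathbf{P}_{\text{agent}}(T) \leq 0$ for any probability, so the difference inside the square can never be smaller than the augmented log-likelihood itself. The sketch below illustrates this under stated assumptions: the function names and the numeric values (including $\sigma = 128$) are hypothetical and are not taken from the paper.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">def dap_loss(log_p_prior, log_p_agent, score, sigma):
    """Squared difference between the augmented and the agent log-likelihood."""
    log_p_aug = log_p_prior + sigma * score
    return (log_p_aug - log_p_agent) ** 2


def dap_lower_bound(log_p_prior, score, sigma):
    """Bound implied by the agent log-likelihood being at most zero."""
    return max(0.0, log_p_prior + sigma * score) ** 2


# Hypothetical values for a single sampled sequence.
log_p_prior, score, sigma = -30.0, 0.8, 128.0
loss = dap_loss(log_p_prior, -35.0, score, sigma)
bound = dap_lower_bound(log_p_prior, score, sigma)

# With a low score the augmented log-likelihood is negative and the bound is zero.
bound_low_score = dap_lower_bound(log_p_prior, 0.1, sigma)
</code></pre></div>
<p>The bound is strictly positive whenever $\sigma \mathbf{S}(T)$ outweighs the prior negative log-likelihood, in which case no agent can drive the loss for that sequence to zero.</p>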
<h3 id="four-molecule-generators">Four Molecule Generators</h3>
<p>REINVENT 4 supports four generator types:</p>
<table>
  <thead>
      <tr>
          <th>Generator</th>
          <th>Architecture</th>
          <th>Input</th>
          <th>Task</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Reinvent</td>
          <td>RNN</td>
          <td>None</td>
          <td>De novo design from scratch</td>
      </tr>
      <tr>
          <td>LibInvent</td>
          <td>RNN</td>
          <td>Scaffold SMILES</td>
          <td>R-group replacement, library design</td>
      </tr>
      <tr>
          <td><a href="/notes/chemistry/molecular-design/generation/rl-tuned/link-invent-generative-linker-design/">LinkInvent</a></td>
          <td>RNN</td>
          <td>Two warhead fragments</td>
          <td>Linker design, scaffold hopping</td>
      </tr>
      <tr>
          <td>Mol2Mol</td>
          <td>Transformer</td>
          <td>Input molecule</td>
          <td>Molecular optimization within similarity bounds</td>
      </tr>
  </tbody>
</table>
<p>All generators are fully integrated with all three optimization algorithms (TL, RL, CL). The Mol2Mol transformer was trained on over 200 billion molecular pairs from PubChem with <a href="https://en.wikipedia.org/wiki/Jaccard_index">Tanimoto similarity</a> $\geq 0.50$, using ranking loss to directly link negative log-likelihood to molecular similarity.</p>
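<p>For reference, Tanimoto similarity on binary fingerprints reduces to the Jaccard index of the sets of on-bits. The sketch below uses hypothetical bit sets, not real molecular fingerprints, and the function name is introduced here for illustration.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">def tanimoto(bits_a, bits_b):
    """Tanimoto (Jaccard) similarity between two collections of on-bit indices."""
    a, b = set(bits_a), set(bits_b)
    union = a.union(b)
    if not union:
        return 0.0  # convention chosen here for two empty fingerprints
    return len(a.intersection(b)) / len(union)


# Hypothetical on-bit indices for three fingerprints.
query = [1, 2, 3, 4]
close_analog = [2, 3, 4, 5]    # shares three of five distinct bits
distant_analog = [3, 4, 5, 6]  # shares two of six distinct bits

sim_close = tanimoto(query, close_analog)
sim_distant = tanimoto(query, distant_analog)
</code></pre></div>
<p>With the 0.50 cutoff quoted above, the first pair would qualify as a training pair and the second would not.</p>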
<h3 id="staged-learning-curriculum-learning">Staged Learning (Curriculum Learning)</h3>
<p>A key new feature is staged learning, which implements curriculum learning as multi-stage RL. Each stage can define a different scoring profile, allowing users to gradually phase in computationally expensive scoring functions. For example, cheap drug-likeness filters can run first, followed by docking in later stages. Stages terminate when a maximum score threshold is exceeded or a step limit is reached.</p>
<h3 id="scoring-subsystem">Scoring Subsystem</h3>
<p>The scoring subsystem implements a plugin architecture supporting over 25 scoring components, including:</p>
<ul>
<li>Physicochemical descriptors from RDKit (QED, SLogP, TPSA, molecular weight, etc.)</li>
<li>Molecular docking via DockStream (<a href="https://en.wikipedia.org/wiki/AutoDock">AutoDock Vina</a>, rDock, Hybrid, Glide, GOLD)</li>
<li>QSAR models via Qptuna and ChemProp (D-MPNN)</li>
<li>Shape similarity via ROCS</li>
<li>Synthesizability estimation via SA score</li>
<li>Matched molecular pairs via mmpdb</li>
<li>Generic REST and external process interfaces</li>
</ul>
<p>Scores are aggregated via weighted arithmetic or geometric mean. A transform system (sigmoid, step functions, value maps) normalizes individual component scores to $[0, 1]$.</p>
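<p>A minimal sketch of the two aggregation modes, assuming component scores that have already been transformed to $[0, 1]$. The logistic transform shown is a generic form, not necessarily the parameterization REINVENT 4 uses, and all names and values are hypothetical.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import math


def sigmoid(value, midpoint, steepness):
    """Generic logistic transform that maps a raw value into the open interval (0, 1)."""
    return 1.0 / (1.0 + math.exp(-steepness * (value - midpoint)))


def weighted_arithmetic_mean(scores, weights):
    return sum(w * s for s, w in zip(scores, weights)) / sum(weights)


def weighted_geometric_mean(scores, weights):
    # Product of s_i ** (w_i / sum of weights); a zero component zeroes the total.
    total = sum(weights)
    return math.prod(s ** (w / total) for s, w in zip(scores, weights))


# Hypothetical transformed component scores and their weights.
scores = [0.9, 0.4]
weights = [1.0, 3.0]
arith = weighted_arithmetic_mean(scores, weights)
geom = weighted_geometric_mean(scores, weights)
vetoed = weighted_geometric_mean([0.0, 0.9], [1.0, 1.0])
</code></pre></div>
<p>The geometric mean never exceeds the arithmetic mean and drops to zero when any single component is zero, which makes it the stricter choice when every objective must be satisfied at once.</p>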
<h2 id="pdk1-inhibitor-case-study">PDK1 Inhibitor Case Study</h2>
<p>The paper demonstrates REINVENT 4 through a structure-based drug design exercise targeting <a href="https://en.wikipedia.org/wiki/PDPK1">Phosphoinositide-dependent kinase-1 (PDK1)</a> inhibitors. The experimental setup uses PDB crystal structure 2XCH with DockStream and Glide for docking, defining hits as molecules with docking score $\leq -8$ kcal/mol and QED $\geq 0.7$.</p>
<p><strong>Baseline RL from prior</strong>: 50 epochs of staged learning with batch size 128 produced 119 hits from 6,400 generated molecules (1.9% hit rate), spread across 103 generic Bemis-Murcko scaffolds.</p>
<p><strong>Transfer learning + RL</strong>: After 10 epochs of TL on 315 congeneric pyridinone PDK1 actives from PubChem Assay AID1798002, the same 50-epoch RL run produced 222 hits (3.5% hit rate) across 176 unique generic scaffolds, nearly doubling productivity.</p>
<p>Both approaches generated top-scoring molecules (docking score of -10.1 kcal/mol each) with plausible binding poses reproducing key protein-ligand interactions seen in the native crystal structure, including hinge interactions with ALA 162 and contacts with LYS 111.</p>
<p>The paper also demonstrates the agent&rsquo;s plasticity through a molecular weight switching experiment: after 500 epochs driving generation toward 1500 Da molecules, switching the reward to favor molecules $\leq 500$ Da resulted in rapid adaptation within ~50 epochs, showing that the RL agent can recover from extreme biases.</p>
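<p>The hit rates quoted for the two PDK1 runs follow directly from the reported counts. The short check below uses only numbers stated above and assumes, as the text does, that both runs are measured against the same 6,400 sampled molecules.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">epochs, batch_size = 50, 128
generated = epochs * batch_size  # molecules sampled per run

rl_hits, tl_rl_hits = 119, 222
rl_scaffolds, tl_rl_scaffolds = 103, 176

rl_rate = rl_hits / generated
tl_rl_rate = tl_rl_hits / generated
hit_ratio = tl_rl_hits / rl_hits  # how many times more hits TL + RL produced

# Unique generic scaffolds per hit, as a rough diversity indicator.
rl_scaffolds_per_hit = rl_scaffolds / rl_hits
tl_rl_scaffolds_per_hit = tl_rl_scaffolds / tl_rl_hits

print(f"RL: {rl_rate:.1%}, TL+RL: {tl_rl_rate:.1%}, ratio: {hit_ratio:.2f}")
</code></pre></div>
<p>The scaffolds-per-hit figures are a derived quantity, not one the paper reports, so they are best treated as a reading aid.</p>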
<h2 id="practical-software-for-ai-driven-drug-discovery">Practical Software for AI-Driven Drug Discovery</h2>
<p>REINVENT 4 represents a mature, well-documented framework that consolidates years of incremental development into a single codebase. Key practical features include TOML/JSON configuration, TensorBoard visualization, multinomial sampling and beam search decoding, diversity filters for scaffold-level novelty, experience replay (inception), and a plugin mechanism for extending the scoring subsystem.</p>
<p>The authors acknowledge that this is one approach among many and that there is no single solution that uniformly outperforms others. REINVENT has demonstrated strong sample efficiency in benchmarks and produced realistic 3D docking poses, but the paper does not claim universal superiority. The focus is on providing a well-engineered, transparent reference implementation rather than advancing a novel algorithm.</p>
<p>The limitations are threefold: only the Mol2Mol prior supports stereochemistry, biases in the training data constrain the explorable chemical space, and the SMILES-based representation inherits the known fragility of string-based molecular encodings.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Prior training (Reinvent)</td>
          <td>ChEMBL 25</td>
          <td>~1.7M molecules</td>
          <td>Drug-like compounds</td>
      </tr>
      <tr>
          <td>Prior training (LibInvent)</td>
          <td>ChEMBL 27</td>
          <td>~1.9M molecules</td>
          <td>Scaffold-decoration pairs</td>
      </tr>
      <tr>
          <td>Prior training (LinkInvent)</td>
          <td>ChEMBL 27</td>
          <td>~1.9M molecules</td>
          <td>Fragment-linker pairs</td>
      </tr>
      <tr>
          <td>Prior training (Mol2Mol)</td>
          <td>ChEMBL 28 / PubChem</td>
          <td>~200B pairs</td>
          <td>Tanimoto similarity $\geq 0.50$</td>
      </tr>
      <tr>
          <td>Case study TL</td>
          <td>PubChem AID1798002</td>
          <td>315 compounds</td>
          <td>Congeneric PDK1 actives</td>
      </tr>
      <tr>
          <td>Case study docking</td>
          <td>PDB 2XCH</td>
          <td>1 structure</td>
          <td>PDK1 crystal structure</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Optimization</strong>: DAP (recommended), plus three deprecated alternatives (REINFORCE, A2C, MAULI)</li>
<li><strong>Decoding</strong>: Multinomial sampling (default, temperature $K = 1$) and beam search</li>
<li><strong>Diversity filter</strong>: Murcko scaffold, topological scaffold, scaffold similarity, same-SMILES penalty</li>
<li><strong>Experience replay</strong>: Inception memory with configurable size and sampling rate</li>
<li><strong>Gradient descent</strong>: Adam optimizer</li>
</ul>
<h3 id="models">Models</h3>
<p>All pre-trained priors are distributed with the repository. They cover the RNN-based generators (Reinvent, LibInvent, LinkInvent) and the transformer-based generator (Mol2Mol), the latter with multiple similarity-conditioned variants.</p>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Value</th>
          <th>Condition</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Hit rate (RL)</td>
          <td>1.9%</td>
          <td>50 epochs, batch 128</td>
          <td>PDK1 case study</td>
      </tr>
      <tr>
          <td>Hit rate (TL+RL)</td>
          <td>3.5%</td>
          <td>10 TL + 50 RL epochs</td>
          <td>PDK1 case study</td>
      </tr>
      <tr>
          <td>Scaffold diversity (RL)</td>
          <td>103 scaffolds</td>
          <td>From 119 hits</td>
          <td>Generic Bemis-Murcko</td>
      </tr>
      <tr>
          <td>Scaffold diversity (TL+RL)</td>
          <td>176 scaffolds</td>
          <td>From 222 hits</td>
          <td>Generic Bemis-Murcko</td>
      </tr>
      <tr>
          <td>Best docking score</td>
          <td>-10.1 kcal/mol</td>
          <td>Both methods</td>
          <td>Glide SP</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>The paper does not specify hardware requirements. REINVENT 4 supports both GPU and CPU execution. Python 3.10+ is required, with PyTorch 1.x (2.0 also compatible) and RDKit 2022.9+.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/MolecularAI/REINVENT4">REINVENT4</a></td>
          <td>Code</td>
          <td>Apache-2.0</td>
          <td>Full framework with pre-trained priors</td>
      </tr>
      <tr>
          <td><a href="https://github.com/MolecularAI/DockStream">DockStream</a></td>
          <td>Code</td>
          <td>Apache-2.0</td>
          <td>Docking wrapper for scoring</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Loeffler, H. H., He, J., Tibo, A., Janet, J. P., Voronov, A., Mervin, L. H., &amp; Engkvist, O. (2024). Reinvent 4: Modern AI-driven generative molecule design. <em>Journal of Cheminformatics</em>, 16, 20. <a href="https://doi.org/10.1186/s13321-024-00812-5">https://doi.org/10.1186/s13321-024-00812-5</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{loeffler2024reinvent,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Reinvent 4: Modern AI-driven generative molecule design}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Loeffler, Hannes H. and He, Jiazhen and Tibo, Alessandro and Janet, Jon Paul and Voronov, Alexey and Mervin, Lewis H. and Engkvist, Ola}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Journal of Cheminformatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{16}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{20}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2024}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Springer}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1186/s13321-024-00812-5}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>NLP Models That Automate Programming for Chemistry</title><link>https://hunterheidenreich.com/notes/chemistry/llm-applications/nlp-models-transform-chemistry/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/llm-applications/nlp-models-transform-chemistry/</guid><description>A perspective on how code-generating LLMs like OpenAI Codex and GPT-3 will reshape computational chemistry research workflows and education.</description><content:encoded><![CDATA[<h2 id="a-perspective-on-code-generating-llms-for-chemistry">A Perspective on Code-Generating LLMs for Chemistry</h2>
<p>This is a <strong>Position</strong> paper that argues large language models (LLMs) capable of generating code from natural language prompts, specifically OpenAI&rsquo;s Codex and GPT-3, are poised to transform both chemistry research and chemistry education. Published in the inaugural volume of Digital Discovery (RSC), the paper combines a brief history of NLP developments with concrete demonstrations of code generation for computational chemistry tasks, then offers a forward-looking perspective on challenges and opportunities.</p>
<h2 id="bridging-the-gap-between-natural-language-and-scientific-software">Bridging the Gap Between Natural Language and Scientific Software</h2>
<p>The authors identify a core friction in modern computational chemistry: while the number of available software packages has grown dramatically, researchers spend a large fraction of their time learning interfaces to these packages rather than doing science. Tasks like searching documentation, following tutorials, and trial-and-error experimentation with APIs consume effort that could be directed at research itself.</p>
<p>At the same time, programming assignments in chemistry courses serve dual pedagogical purposes (reinforcing physical intuition and teaching marketable skills), but are constrained by students&rsquo; median programming experience. The emergence of code-generating NLP models opens the possibility of reducing both barriers simultaneously.</p>
<h2 id="code-generation-as-a-chemistry-interface">Code Generation as a Chemistry Interface</h2>
<p>The paper&rsquo;s core thesis is that NLP models trained on code can serve as a natural language interface to the entire ecosystem of scientific computing tools. The authors demonstrate this with several concrete examples using OpenAI Codex:</p>
<ol>
<li>
<p><strong>Quantum chemistry</strong>: Prompting Codex to &ldquo;compute the dissociation curve of H2 using pyscf&rdquo; produced correct, runnable code that selected <a href="https://en.wikipedia.org/wiki/Hartree%E2%80%93Fock_method">Hartree-Fock</a> with <a href="https://en.wikipedia.org/wiki/STO-nG_basis_sets">STO-3G</a>. A follow-up prompt requesting &ldquo;the most accurate method&rdquo; caused it to switch to <a href="https://en.wikipedia.org/wiki/Coupled_cluster">CCSD</a> in a large basis set.</p>
</li>
<li>
<p><strong>Chemical entity recognition</strong>: Using GPT-3 with only three training examples, the authors demonstrated extraction of chemical entity names from published text, a task that previously required thousands of labeled examples.</p>
</li>
<li>
<p><strong>Molecular visualization</strong>: Drawing caffeine from its <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> string, generating Gaussian input files from SMILES, implementing random walks, and downloading and analyzing <a href="https://en.wikipedia.org/wiki/Protein_Data_Bank">PDB structures</a> with MDTraj.</p>
</li>
<li>
<p><strong>Voice-controlled molecular dynamics</strong>: The authors previously built MARVIS, a voice-controlled <a href="/notes/chemistry/molecular-simulation/">molecular dynamics</a> analysis tool that uses GPT-3 to convert natural language into <a href="https://en.wikipedia.org/wiki/Visual_Molecular_Dynamics">VMD</a> commands. Only about a dozen examples were needed to teach GPT-3 to render proteins, change representations, and select atoms.</p>
</li>
</ol>
<p>An important caveat: the authors emphasize that all chemistry &ldquo;knowledge&rdquo; (including the SMILES string for caffeine) is entirely contained in the model&rsquo;s learned floating-point weights. The model has no access to databases or curated lists of chemical concepts.</p>
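<p>The few-shot setup behind the entity-recognition example can be sketched as plain prompt assembly. The prompt layout, the function name, and the three labeled sentences below are hypothetical stand-ins. They are not the prompts from the paper&rsquo;s ESI, and no model API is called.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">def build_few_shot_prompt(examples, new_text):
    """Assemble a few-shot prompt from (sentence, entities) pairs plus one unlabeled sentence."""
    blocks = []
    for sentence, entities in examples:
        blocks.append("Text: " + sentence + "\nChemicals: " + ", ".join(entities))
    # The last block is left open so that the model completes the entity list.
    blocks.append("Text: " + new_text + "\nChemicals:")
    return "\n\n".join(blocks)


# Three hypothetical labeled sentences standing in for the three training examples.
examples = [
    ("The sample was dissolved in ethanol.", ["ethanol"]),
    ("Sodium chloride was added to the benzene layer.", ["Sodium chloride", "benzene"]),
    ("The product was washed with acetone and dried.", ["acetone"]),
]
prompt = build_few_shot_prompt(examples, "Caffeine was extracted with dichloromethane.")
print(prompt)
</code></pre></div>
<p>The resulting string would be sent to a completion model, whose continuation after the final <code>Chemicals:</code> label is read as the extracted entities.</p>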
<h2 id="demonstrations-and-practical-evaluation">Demonstrations and Practical Evaluation</h2>
<p>Rather than a formal experimental evaluation with benchmarks and metrics, this perspective paper relies on qualitative demonstrations. The key examples, with full details provided in the ESI, include:</p>
<table>
  <thead>
      <tr>
          <th>Task</th>
          <th>Input</th>
          <th>Result</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>H2 dissociation curve</td>
          <td>Natural language prompt</td>
          <td>Correct PySCF code (HF/STO-3G)</td>
      </tr>
      <tr>
          <td>Upgrade method accuracy</td>
          <td>Follow-up prompt</td>
          <td>Switched to CCSD with large basis</td>
      </tr>
      <tr>
          <td>Chemical NER</td>
          <td>3 examples + new text</td>
          <td>Extracted compound names (with some gaps)</td>
      </tr>
      <tr>
          <td>Molecule drawing</td>
          <td>&ldquo;Load caffeine from SMILES, draw it&rdquo;</td>
          <td>Correct RDKit rendering</td>
      </tr>
      <tr>
          <td>Gaussian input file</td>
          <td>Function with docstring</td>
          <td>Complete file writer with B3LYP/6-31G(d)</td>
      </tr>
      <tr>
          <td>PDB analysis</td>
          <td>Natural language description</td>
          <td>Downloaded structure and computed <a href="https://en.wikipedia.org/wiki/Radius_of_gyration">radius of gyration</a></td>
      </tr>
  </tbody>
</table>
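<p>To give a sense of what the Gaussian input file task involves, here is a hand-written sketch of such a writer. It is not the Codex output from the paper. It takes explicit Cartesian coordinates in place of a SMILES string, and the <code>opt</code> keyword, the function name, and the approximate water geometry are assumptions made for illustration.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">def gaussian_input(title, atoms, charge=0, multiplicity=1, route="# B3LYP/6-31G(d) opt"):
    """Return the text of a Gaussian input file for a list of (symbol, x, y, z) atoms."""
    # Layout: route line, blank, title, blank, charge and multiplicity, geometry, blank.
    lines = [route, "", title, "", f"{charge} {multiplicity}"]
    for symbol, x, y, z in atoms:
        lines.append(f"{symbol:2s} {x:12.6f} {y:12.6f} {z:12.6f}")
    lines.append("")  # the molecule specification ends with a blank line
    return "\n".join(lines) + "\n"


# Approximate water geometry in angstroms, for illustration only.
water = [
    ("O", 0.000, 0.000, 0.117),
    ("H", 0.000, 0.757, -0.469),
    ("H", 0.000, -0.757, -0.469),
]
text = gaussian_input("water optimization", water)
print(text)
</code></pre></div>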
<p>The authors note that Codex generates correct code at about a 30% rate on a single attempt for standard problems, improving to above 50% when multiple solutions are tried. Mistakes tend to occur when complex algorithms are requested with little specificity, and the code rarely has syntax errors but may fail in obvious ways (missing imports, wrong data types).</p>
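<p>As a rough sanity check on these figures, suppose (an assumption introduced here, not one the authors make) that each attempt succeeds independently with the same probability $p$. The chance that at least one of $k$ attempts is correct is then:</p>
<p>$$
1 - (1 - p)^{k}, \qquad \text{e.g. } p = 0.3,\ k = 2: \quad 1 - 0.7^{2} = 0.51
$$</p>
<p>Under that idealization, two attempts would already cross 50%. Samples drawn from the same model are correlated in practice, so this is best read as an illustration of why retrying helps, not as a reproduction of the reported numbers.</p>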
<h2 id="challenges-access-correctness-and-bias">Challenges: Access, Correctness, and Bias</h2>
<p>The paper identifies three ongoing challenges:</p>
<p><strong>Access and price.</strong> Advanced models from OpenAI were, at the time of writing, limited to early testers. Per-query costs (1-3 cents for GPT-3) would become prohibitive at the scale needed for parsing academic literature or supporting medium-sized courses. The authors advocate for open-source models and equitable deployment by researchers with computational resources.</p>
<p><strong>Correctness.</strong> Code generation does not guarantee correctness. The authors raise a subtle point: Codex may produce code that executes successfully but does not follow best scientific practice for a particular computational task. Over-reliance on AI-generated code without verification could erode trust in scientific software. However, they argue that strategies for assessing code correctness apply equally to human-written and AI-generated code.</p>
<p><strong>Fairness and bias.</strong> The authors flag several concerns. Code-generation models trained on their own outputs could narrow the range of packages, methods, or programming languages used in chemistry. The authors observed Codex&rsquo;s preference for Python and for specific popular libraries (e.g., defaulting to <a href="https://en.wikipedia.org/wiki/PSI_(computational_chemistry)">Psi4</a> for single-point energy calculations). GPT-3 has also been shown to reflect racism, sexism, and other biases present in its training data.</p>
<h2 id="implications-for-research-and-education">Implications for Research and Education</h2>
<p>The authors conclude with an optimistic but measured outlook:</p>
<ul>
<li><strong>For research</strong>: NLP code generation will increase accessibility of software tools and expand what a single research group can accomplish. Better tools have historically not reduced the need for scientists but expanded the complexity of problems that can be tackled.</li>
<li><strong>For programming skills</strong>: Using Codex will make chemists better programmers, not worse. The process of crafting prompts, mentally checking outputs, testing on sample inputs, and iterating develops algorithmic thinking. The authors report discovering chemistry software libraries they would not have found otherwise through iterative prompt creation.</li>
<li><strong>For education</strong>: Instructors should rethink programming assignments. The authors suggest moving toward more difficult compound assignments, treating code exercises as laboratory explorations of scientific concepts rather than syntax drills, and aligning coursework with the tools students will have access to in their careers.</li>
<li><strong>For accessibility</strong>: NLP models can reduce barriers for non-native English speakers (though accuracy with non-English prompts was not fully explored) and for users who have difficulty with keyboard-and-mouse interfaces (via voice control).</li>
</ul>
<p>The paper acknowledges that, in early 2022, these capabilities were only beginning to emerge, with Codex being the first capable code-generation model. Even at the time of writing, models surpassing GPT-3 on language tasks had appeared, and models matching GPT-3 with 1/20th of the parameters had been demonstrated.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<p>This is a perspective paper with qualitative demonstrations rather than a reproducible experimental study. The authors provide all prompts and multiple responses in the ESI.</p>
<h3 id="data">Data</h3>
<p>All prompts and code outputs are provided in the Electronic Supplementary Information (ESI) available from the RSC.</p>
<h3 id="algorithms">Algorithms</h3>
<p>The paper does not introduce new algorithms. It evaluates existing models (GPT-3, Codex) on chemistry-related code generation tasks.</p>
<h3 id="models">Models</h3>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Provider</th>
          <th>Access</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>GPT-3</td>
          <td>OpenAI</td>
          <td>API access (commercial)</td>
      </tr>
      <tr>
          <td>Codex</td>
          <td>OpenAI</td>
          <td>Early tester program (2021)</td>
      </tr>
      <tr>
          <td>GPT-Neo</td>
          <td>EleutherAI</td>
          <td>Open source</td>
      </tr>
  </tbody>
</table>
<h3 id="evaluation">Evaluation</h3>
<p>No formal metrics are reported for the chemistry demonstrations. The authors cite the Codex paper&rsquo;s reported ~30% pass rate on single attempts and &gt;50% with multiple attempts on standard programming problems.</p>
<h3 id="hardware">Hardware</h3>
<p>No hardware requirements are specified for the demonstrations (API-based inference).</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/whitead/marvis">MARVIS</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Voice-controlled MD analysis using GPT-3</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Hocky, G. M., &amp; White, A. D. (2022). Natural language processing models that automate programming will transform chemistry research and teaching. <em>Digital Discovery</em>, 1(2), 79-83. <a href="https://doi.org/10.1039/d1dd00009h">https://doi.org/10.1039/d1dd00009h</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{hocky2022natural,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Natural language processing models that automate programming will transform chemistry research and teaching}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Hocky, Glen M. and White, Andrew D.}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Digital Discovery}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{2}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{79--83}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2022}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Royal Society of Chemistry}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1039/d1dd00009h}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Neural Machine Translation of Chemical Nomenclature</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/name-translation/nmt-chemical-nomenclature-en-zh/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/name-translation/nmt-chemical-nomenclature-en-zh/</guid><description>Xu et al. apply CNN and LSTM seq2seq models to translate chemical nomenclature between English and Chinese, outperforming rule-based tools.</description><content:encoded><![CDATA[<h2 id="a-method-for-neural-translation-of-chemical-names">A Method for Neural Translation of Chemical Names</h2>
<p>This is a <strong>Method</strong> paper that introduces deep learning approaches for translating chemical nomenclature between English and Chinese. The primary contribution is demonstrating that character-level sequence-to-sequence neural networks (both CNN-based and LSTM-based) can serve as viable alternatives to hand-crafted rule-based translation systems for chemical names. The work compares two neural architectures against an existing rule-based tool on bilingual chemical name datasets.</p>
<h2 id="bridging-the-english-chinese-chemical-nomenclature-gap">Bridging the English-Chinese Chemical Nomenclature Gap</h2>
<p>English and Chinese are the two most widely used languages for chemical nomenclature worldwide. Translation between them is important for chemical data processing, especially for converting Chinese chemical names extracted via named entity recognition into English names that existing name-to-structure tools can parse. Rule-based translation between these languages faces considerable challenges:</p>
<ol>
<li>Chinese chemical names lack word boundaries (no spaces), making segmentation difficult.</li>
<li>Word order is often reversed between English and Chinese chemical names (e.g., &ldquo;ethyl acetate&rdquo; maps to characters meaning &ldquo;acetate-ethyl&rdquo; in Chinese).</li>
<li>The same English morpheme can map to different Chinese characters depending on chemical context (e.g., &ldquo;ethyl&rdquo; translates differently in &ldquo;ethyl acetate&rdquo; vs. &ldquo;ethyl alcohol&rdquo;).</li>
<li>Trivial names, especially for natural products, follow irregular translation patterns or are transliterations.</li>
</ol>
<p>Building comprehensive rule sets requires a formally trained chemist fluent in both languages, making rule-based approaches expensive and fragile.</p>
<h2 id="character-level-sequence-to-sequence-translation">Character-Level Sequence-to-Sequence Translation</h2>
<p>The core idea is to treat chemical name translation as a character-level machine translation task, applying encoder-decoder architectures with attention mechanisms. Two architectures are proposed:</p>
<p><strong>CNN-based architecture</strong>: Three 1D convolutional layers encode the input character sequence. A decoder with three 1D convolutional layers processes the target sequence offset by one timestep, combined with attention mechanism layers that connect encoder and decoder outputs. Two additional 1D convolutional layers produce the final decoded output sequence.</p>
<p><strong>LSTM-based architecture</strong>: An LSTM encoder converts the input sequence into two state vectors. An LSTM decoder is trained with teacher forcing, using the encoder&rsquo;s state vectors as its initial state, and generating the target sequence offset by one timestep.</p>
<p>Both models operate at the character level. Input chemical name strings are transformed into embedding vectors, with the vocabulary size equal to the number of unique characters in the respective language (100 unique characters for English names, 2,056 unique characters for Chinese names).</p>
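<p>The character-level encoding and the one-timestep offset used for teacher forcing can be sketched without any deep learning library. The start and end markers (tab and newline) follow a common Keras tutorial convention and are an assumption here. The paper&rsquo;s vocabulary counts may or may not include such markers, and the function names are hypothetical.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">START, END = "\t", "\n"  # assumed sequence markers


def build_vocab(names):
    """Map every unique character to an integer id, reserving ids 0 and 1 for the markers."""
    vocab = {START: 0, END: 1}
    for ch in sorted({ch for name in names for ch in name}):
        vocab.setdefault(ch, len(vocab))
    return vocab


def teacher_forcing_pair(target, vocab):
    """Return (decoder_input, decoder_target); the target is the input shifted by one step."""
    ids = [vocab[START]] + [vocab[ch] for ch in target] + [vocab[END]]
    return ids[:-1], ids[1:]


# Two toy name pairs: ethyl acetate and ethanol.
english = ["ethyl acetate", "ethanol"]
chinese = ["乙酸乙酯", "乙醇"]

en_vocab = build_vocab(english)
zh_vocab = build_vocab(chinese)
decoder_input, decoder_target = teacher_forcing_pair(chinese[1], zh_vocab)
</code></pre></div>
<p>At each position the decoder sees the true previous character from <code>decoder_input</code> and is trained to predict the corresponding entry of <code>decoder_target</code>.</p>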
<h2 id="experimental-setup-and-comparison-with-rule-based-tool">Experimental Setup and Comparison with Rule-Based Tool</h2>
<h3 id="datasets">Datasets</h3>
<p>The authors built two directional datasets from a manually curated corpus of scientific literature maintained at their institution:</p>
<ul>
<li><strong>En2Ch (English to Chinese)</strong>: 30,394 name pairs after deduplication</li>
<li><strong>Ch2En (Chinese to English)</strong>: 37,207 name pairs after deduplication</li>
</ul>
<p>The datasets span the full range from systematic compound names to trivial names. For names with multiple valid translations, the most commonly used translation was selected. Each dataset was split 80/20 for training and validation.</p>
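<p>A minimal sketch of the deduplication and 80/20 split, assuming a seeded random shuffle. The paper does not state how the split was drawn, and the function name and placeholder pairs are hypothetical.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import random


def dedupe_and_split(pairs, train_fraction=0.8, seed=0):
    """Drop duplicate (source, target) pairs, shuffle, and split into train and validation."""
    unique = list(dict.fromkeys(pairs))  # keeps the first occurrence of each pair
    random.Random(seed).shuffle(unique)
    cut = int(len(unique) * train_fraction)
    return unique[:cut], unique[cut:]


# Ten placeholder pairs plus one exact duplicate.
pairs = [("name_%d" % i, "target_%d" % i) for i in range(10)]
pairs.append(("name_0", "target_0"))

train, valid = dedupe_and_split(pairs)
</code></pre></div>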
<h3 id="model-configuration">Model Configuration</h3>
<p>Both neural network models used the following hyperparameters:</p>
<ul>
<li>Batch size: 64</li>
<li>Epochs: 100</li>
<li>Latent dimensionality: 256 (encoding and decoding space)</li>
<li>Implementation: Python 3.7 with Keras 2.3 and TensorFlow backend</li>
</ul>
<h3 id="evaluation-metrics">Evaluation Metrics</h3>
<p>The models were evaluated on five metrics across both translation directions:</p>
<ul>
<li><strong>Success Rate</strong>: Percentage of inputs that produced any output</li>
<li><strong>String Matching Accuracy</strong>: Exact match with the single target name</li>
<li><strong>Data Matching Accuracy</strong>: Exact match allowing any valid translation from the corpus</li>
<li><strong>Manual Spot Check</strong>: Blind evaluation of 100 random samples per approach</li>
<li><strong>Running Time</strong>: Wall-clock time on the same hardware</li>
</ul>
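<p>The difference between the three automatic metrics is easiest to see in code. The sketch below is one plausible reading of the definitions above, assuming that a failed translation is recorded as <code>None</code> and that accuracy is computed over all inputs; the toy names are invented for illustration.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">def success_rate(predictions):
    """Share of inputs for which the system produced any output."""
    produced = sum(1 for p in predictions if p)
    return produced / len(predictions)


def string_match_accuracy(predictions, targets):
    """Exact match against the single reference name for each input."""
    hits = sum(1 for p, t in zip(predictions, targets) if p == t)
    return hits / len(targets)


def data_match_accuracy(predictions, valid_translations):
    """Exact match against any valid translation recorded for each input."""
    hits = sum(1 for p, valid in zip(predictions, valid_translations) if p in valid)
    return hits / len(predictions)


# Toy data, invented for illustration only.
preds = ["ethanol", "ethanoic acid", None, "methane"]
targets = ["ethanol", "acetic acid", "benzene", "methane"]
valid = [
    {"ethanol", "ethyl alcohol"},
    {"acetic acid", "ethanoic acid"},
    {"benzene"},
    {"methane"},
]
</code></pre></div>
<p>On this toy data the second prediction counts as wrong under string matching but right under data matching, which is why data matching accuracy can never be lower than string matching accuracy.</p>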
<h3 id="baseline">Baseline</h3>
<p>The rule-based comparison system operates in three steps: disassemble the input name into word fragments, translate each fragment, and reassemble into the target language. This tool had been deployed as an online service with over one million uses at the time of publication.</p>
<h2 id="key-findings-and-limitations">Key Findings and Limitations</h2>
<h3 id="main-results">Main Results</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>CNN</th>
          <th>LSTM</th>
          <th>Rule-based</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Success Rate En2Ch</td>
          <td>100%</td>
          <td>100%</td>
          <td>75.97%</td>
      </tr>
      <tr>
          <td>Success Rate Ch2En</td>
          <td>100%</td>
          <td>100%</td>
          <td>59.90%</td>
      </tr>
      <tr>
          <td>String Match En2Ch</td>
          <td>82.92%</td>
          <td>89.64%</td>
          <td>39.81%</td>
      </tr>
      <tr>
          <td>String Match Ch2En</td>
          <td>78.11%</td>
          <td>55.44%</td>
          <td>43.77%</td>
      </tr>
      <tr>
          <td>Data Match En2Ch</td>
          <td>84.44%</td>
          <td>90.82%</td>
          <td>45.15%</td>
      </tr>
      <tr>
          <td>Data Match Ch2En</td>
          <td>80.22%</td>
          <td>57.40%</td>
          <td>44.91%</td>
      </tr>
      <tr>
          <td>Manual Check En2Ch</td>
          <td>90.00%</td>
          <td>89.00%</td>
          <td>80.00%</td>
      </tr>
      <tr>
          <td>Manual Check Ch2En</td>
          <td>82.00%</td>
          <td>61.00%</td>
          <td>78.00%</td>
      </tr>
      <tr>
          <td>Time En2Ch (s)</td>
          <td>1423</td>
          <td>190</td>
          <td>288</td>
      </tr>
      <tr>
          <td>Time Ch2En (s)</td>
          <td>1876</td>
          <td>303</td>
          <td>322</td>
      </tr>
  </tbody>
</table>
<p>Both neural approaches achieved 100% success rate (always producing output), while the rule-based tool failed on 24% and 40% of inputs for En2Ch and Ch2En respectively. The rule-based tool&rsquo;s failures were concentrated on Chinese names lacking word boundaries and on trivial names of natural products.</p>
<p>For English-to-Chinese translation, LSTM performed best at 89.64% string matching accuracy (90.82% data matching), followed by CNN at 82.92%. For Chinese-to-English, CNN substantially outperformed LSTM (78.11% vs. 55.44% string matching), suggesting that LSTM had difficulty with long-term dependencies in Chinese character sequences. The authors observed that many LSTM errors appeared at the ends of chemical names.</p>
<h3 id="analysis-by-name-type">Analysis by Name Type</h3>
<p>The CNN-based approach outperformed LSTM on CAS names (80% vs. 52% in manual checks) and was more robust for longer names. The rule-based tool showed consistent performance regardless of name length, suggesting it was more suited to regular systematic names but struggled with the diversity of real-world chemical nomenclature.</p>
<h3 id="limitations">Limitations</h3>
<ul>
<li>Performance depends heavily on training data quality and quantity.</li>
<li>Neither neural approach was validated on an external test set outside the institution&rsquo;s corpus.</li>
<li>The CNN model was considerably slower than the other two approaches (roughly 5x to 7.5x, depending on the direction and the comparison system).</li>
<li>No comparison against modern transformer-based NMT architectures (the study predates widespread adoption of transformers for this task).</li>
<li>The dataset is relatively small by modern NMT standards (30-37K pairs).</li>
<li>The authors noted that some neural translations were actually better than the target labels, suggesting the evaluation metrics understate true performance.</li>
</ul>
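<p>The size of the CNN slowdown can be recomputed from the running times in the Main Results table. The short sketch below only does that arithmetic; the dictionary layout and function name are illustrative.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"># Running times in seconds, copied from the Main Results table.
times = {
    "En2Ch": {"CNN": 1423, "LSTM": 190, "Rule-based": 288},
    "Ch2En": {"CNN": 1876, "LSTM": 303, "Rule-based": 322},
}


def cnn_slowdown(direction):
    """Ratio of CNN running time to each other approach for one direction."""
    row = times[direction]
    return {name: round(row["CNN"] / t, 1) for name, t in row.items() if name != "CNN"}


ratios = {direction: cnn_slowdown(direction) for direction in times}
</code></pre></div>
<p>The ratios come out at 7.5 (LSTM) and 4.9 (rule-based) for En2Ch, and 6.2 and 5.8 for Ch2En.</p>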
<h3 id="future-directions">Future Directions</h3>
<p>The authors suggest that combining CNN and LSTM architectures could yield further improvements, and that the approach has practical applications in scientific publishing (Chinese journals requiring English abstracts) and chemical database interoperability.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training/Validation (En2Ch)</td>
          <td>Curated bilingual corpus</td>
          <td>30,394 pairs</td>
          <td>80/20 split, from SIOC chemical data system</td>
      </tr>
      <tr>
          <td>Training/Validation (Ch2En)</td>
          <td>Curated bilingual corpus</td>
          <td>37,207 pairs</td>
          <td>80/20 split, from SIOC chemical data system</td>
      </tr>
      <tr>
          <td>Testing (En2Ch)</td>
          <td>Held-out validation split</td>
          <td>6,079 records</td>
          <td>Same source</td>
      </tr>
      <tr>
          <td>Testing (Ch2En)</td>
          <td>Held-out validation split</td>
          <td>7,441 records</td>
          <td>Same source</td>
      </tr>
  </tbody>
</table>
<p>Training data, Python code for both models, and result data are provided as supplementary files with the paper.</p>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Character-level CNN encoder-decoder with attention (3+3+2 conv layers)</li>
<li>Character-level LSTM encoder-decoder with teacher forcing</li>
<li>Batch size: 64, epochs: 100, latent dim: 256</li>
</ul>
<h3 id="models">Models</h3>
<p>Both models are implemented in Python 3.7 with Keras 2.3 / TensorFlow. No pre-trained weights are released separately, but the training code is provided as supplementary material.</p>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Best Value (En2Ch)</th>
          <th>Best Value (Ch2En)</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Success Rate</td>
          <td>100% (both DL)</td>
          <td>100% (both DL)</td>
          <td>Rule-based: 75.97% / 59.90%</td>
      </tr>
      <tr>
          <td>String Matching</td>
          <td>89.64% (LSTM)</td>
          <td>78.11% (CNN)</td>
          <td>Best neural model per direction</td>
      </tr>
      <tr>
          <td>Data Matching</td>
          <td>90.82% (LSTM)</td>
          <td>80.22% (CNN)</td>
          <td>Allows multiple valid translations</td>
      </tr>
      <tr>
          <td>Manual Spot Check</td>
          <td>90.00% (CNN)</td>
          <td>82.00% (CNN)</td>
          <td>Blind evaluation of 100 samples</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Not specified in the paper. Running times are reported, but the hardware they were measured on is not described.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://doi.org/10.1186/s13321-020-00457-0">Supplementary files</a></td>
          <td>Code + Data</td>
          <td>CC-BY 4.0</td>
          <td>Training data, CNN/LSTM code, results (Additional files 1-6)</td>
      </tr>
      <tr>
          <td><a href="https://www.organchem.csdb.cn/translate">SIOC Translation Tool</a></td>
          <td>Other</td>
          <td>Not specified</td>
          <td>Rule-based baseline tool, online service</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Xu, T., Chen, W., Zhou, J., Dai, J., Li, Y., &amp; Zhao, Y. (2020). Neural machine translation of chemical nomenclature between English and Chinese. <em>Journal of Cheminformatics</em>, 12, 50. <a href="https://doi.org/10.1186/s13321-020-00457-0">https://doi.org/10.1186/s13321-020-00457-0</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{xu2020neural,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Neural machine translation of chemical nomenclature between English and Chinese}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Xu, Tingjun and Chen, Weiming and Zhou, Junhong and Dai, Jingfang and Li, Yingyong and Zhao, Yingli}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Journal of Cheminformatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{12}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{50}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2020}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1186/s13321-020-00457-0}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Springer}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>MolPMoFiT: Inductive Transfer Learning for QSAR</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/property-prediction/molpmofit-transfer-learning-qsar/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/property-prediction/molpmofit-transfer-learning-qsar/</guid><description>MolPMoFiT adapts ULMFiT for QSAR by pre-training an LSTM language model on 1M ChEMBL SMILES and fine-tuning on small molecular property datasets.</description><content:encoded><![CDATA[<h2 id="transfer-learning-meets-molecular-property-prediction">Transfer Learning Meets Molecular Property Prediction</h2>
<p>This is a <strong>Method</strong> paper that introduces MolPMoFiT (Molecular Prediction Model Fine-Tuning), a transfer learning approach for <a href="https://en.wikipedia.org/wiki/Quantitative_structure%E2%80%93activity_relationship">QSPR/QSAR</a> modeling. The primary contribution is adapting the ULMFiT framework from NLP to molecular property prediction by treating <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES strings</a> as a chemical language. A general-purpose molecular structure prediction model (MSPM) is pre-trained on one million <a href="https://en.wikipedia.org/wiki/ChEMBL">ChEMBL</a> molecules via self-supervised next-token prediction, then fine-tuned for specific QSAR endpoints. The approach achieves competitive or superior results to graph neural networks and descriptor-based methods across four benchmark datasets, with particular benefits for small datasets.</p>
<h2 id="the-small-data-problem-in-qsar-modeling">The Small Data Problem in QSAR Modeling</h2>
<p>Deep learning models for molecular property prediction typically require large labeled training sets to learn useful structural representations. While methods like graph convolutional neural networks and SMILES-based models have achieved strong results on well-studied endpoints, they must be trained from scratch for each new task. This presents a challenge for small chemical datasets with limited labeled data, which remain common in drug discovery for specialized endpoints like <a href="https://en.wikipedia.org/wiki/Allosteric_regulation">allosteric inhibition</a>, renal clearance, and inhibitor residence times.</p>
<p>Transfer learning had already shown transformative impact in computer vision (ImageNet pre-training) and NLP (ELMo, BERT, ULMFiT). In chemistry, prior transfer learning efforts included ChemNet (supervised pre-training on computed descriptors), <a href="/notes/chemistry/molecular-representations/encoders/mol2vec-unsupervised-chemical-intuition/">Mol2vec</a> (unsupervised substructure embeddings), and pre-trained graph neural networks. However, a systematic application of the ULMFiT self-supervised pre-training pipeline to SMILES-based molecular models had not been explored. MolPMoFiT fills this gap by treating the vast corpus of unlabeled molecular structures as a self-supervised training signal, analogous to how language models learn from unlabeled text.</p>
<h2 id="core-innovation-ulmfit-adapted-for-smiles">Core Innovation: ULMFiT Adapted for SMILES</h2>
<p>MolPMoFiT adapts ULMFiT&rsquo;s three-stage transfer learning pipeline to molecular property prediction:</p>
<p><strong>Stage 1: General-Domain MSPM Pre-training.</strong> A molecular structure prediction model is trained on one million curated ChEMBL molecules to predict the next token in a SMILES string. This is purely self-supervised: the SMILES string provides its own labels. The model learns general chemical syntax and structural patterns.</p>
<p><strong>Stage 2: Task-Specific MSPM Fine-tuning (Optional).</strong> The general MSPM is further fine-tuned on the unlabeled SMILES of the target task dataset. This adapts the language model to the specific chemical distribution of interest (e.g., HIV inhibitors vs. general bioactive molecules). Discriminative fine-tuning adjusts learning rates per layer:</p>
<p>$$\eta^{layer-1} = \eta^{layer} / 2.6$$</p>
<p>where higher layers (containing more task-specific features) receive higher learning rates.</p>
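<p>As a worked illustration of the rule above, the following sketch computes one learning rate per layer group. The function name, the number of groups, and the top rate are assumptions chosen for the example, not values prescribed for this stage.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">def discriminative_lrs(top_lr, n_groups, factor=2.6):
    """Learning rates ordered from the lowest layer group to the highest.

    The highest group gets top_lr; each group below it gets the rate of the
    group above divided by factor.
    """
    return [top_lr / factor ** (n_groups - 1 - i) for i in range(n_groups)]


# Illustrative values only: four layer groups and a top rate of 5e-3.
lrs = discriminative_lrs(5e-3, 4)
</code></pre></div>
<p>With a factor of 2.6 the lowest of four groups trains about 17.6 times more slowly than the highest ($2.6^3 \approx 17.6$), so general features learned during pre-training change much less than the task-specific top layers.</p>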
<p><strong>Stage 3: QSAR/QSPR Model Fine-tuning.</strong> The embedding and encoder weights from the pre-trained MSPM are transferred to a new model with a task-specific classifier head. Fine-tuning uses three key techniques from ULMFiT:</p>
<ul>
<li><strong>Discriminative fine-tuning</strong>: Different learning rates per layer group</li>
<li><strong>Gradual unfreezing</strong>: Layers are unfrozen sequentially (classifier head first, then the LSTM layers one at a time, starting from the final layer and moving back toward the embedding)</li>
<li><strong>One cycle policy</strong>: Learning rate scheduling following Smith&rsquo;s approach</li>
</ul>
<p>The model architecture is AWD-LSTM (ASGD Weight-Dropped LSTM) with an embedding dimension of 400, three LSTM layers with 1152 hidden units, and dropouts applied at every layer (embedding, input, weights, hidden). The QSAR classifier concatenates max pooling, mean pooling, and the last hidden state $h_T$ from the final LSTM layer, feeding this into two feedforward layers.</p>
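<p>The pooling step that feeds the classifier can be sketched without any deep learning library. The version below works on plain lists and follows the order in which the three pieces are listed above; the concatenation order and the function name are assumptions of the sketch.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">def concat_pool(hidden_states):
    """Concatenate max pooling, mean pooling and the last hidden state.

    hidden_states is a list of T vectors (one per token), each of length H,
    taken from the final LSTM layer. The result has length 3 * H.
    """
    steps = len(hidden_states)
    columns = list(zip(*hidden_states))  # H tuples, each holding T values
    max_pool = [max(col) for col in columns]
    mean_pool = [sum(col) / steps for col in columns]
    last = list(hidden_states[-1])
    return max_pool + mean_pool + last


# Toy sequence of T = 3 steps with H = 2 features per step.
states = [[0.0, 2.0], [1.0, -1.0], [0.5, 0.5]]
pooled = concat_pool(states)
</code></pre></div>
<p>Pooling over all timesteps gives the classifier a fixed-size input regardless of SMILES length, while the last hidden state keeps the summary the recurrent encoder produced at the end of the string.</p>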
<p><strong>SMILES Augmentation.</strong> Since multiple valid SMILES can represent the same molecule through different atom orderings, the authors use <a href="/notes/chemistry/molecular-representations/notations/randomized-smiles-generative-models/">SMILES enumeration</a> as data augmentation. For regression tasks, Gaussian noise ($\sigma_{noise}$) is added to labels of augmented SMILES to simulate experimental error. Test-time augmentation (TTA) averages predictions across the canonical SMILES and four randomized SMILES.</p>
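<p>A minimal sketch of the two augmentation steps follows. It assumes a hypothetical <code>enumerate_smiles</code> callable that returns a randomized SMILES for the same molecule (in practice a cheminformatics toolkit would supply this) and uses stand-in functions so that it runs on its own. Whether the original canonical string is kept alongside the augmented copies is not specified above, so the sketch returns only the copies.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import random
from statistics import mean


def augment_regression(smiles, label, n_copies, sigma_noise, enumerate_smiles, rng):
    """Return n_copies (smiles, noisy_label) pairs for one training molecule.

    enumerate_smiles is a hypothetical callable that returns one randomized
    SMILES string for the same molecule.
    """
    return [
        (enumerate_smiles(smiles, rng), label + rng.gauss(0.0, sigma_noise))
        for _ in range(n_copies)
    ]


def predict_with_tta(model, canonical, enumerate_smiles, rng, n_random=4):
    """Average predictions over the canonical SMILES and n_random variants."""
    variants = [canonical] + [enumerate_smiles(canonical, rng) for _ in range(n_random)]
    return mean(model(s) for s in variants)


# Stand-ins so the sketch runs without a chemistry toolkit.
def fake_enumerate(smiles, rng):
    return smiles  # a real enumerator would reorder the atoms


def fake_model(smiles):
    return float(len(smiles))


rng = random.Random(0)
augmented = augment_regression("CCO", 1.0, 25, 0.3, fake_enumerate, rng)
tta_prediction = predict_with_tta(fake_model, "CCO", fake_enumerate, rng)
</code></pre></div>
<p>Each augmented copy receives its own noise draw, so the same molecule appears with slightly different labels, and the test-time average is taken over five strings in total.</p>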
<h2 id="benchmarks-across-four-qsar-datasets">Benchmarks Across Four QSAR Datasets</h2>
<h3 id="datasets">Datasets</h3>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>Size</th>
          <th>Task</th>
          <th>Metric</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://en.wikipedia.org/wiki/Lipophilicity">Lipophilicity</a></td>
          <td>4,200</td>
          <td>Regression (logD)</td>
          <td>RMSE</td>
      </tr>
      <tr>
          <td>FreeSolv</td>
          <td>642</td>
          <td>Regression (<a href="https://en.wikipedia.org/wiki/Solvation">solvation energy</a>)</td>
          <td>RMSE</td>
      </tr>
      <tr>
          <td>HIV</td>
          <td>41,127</td>
          <td>Classification (replication inhibition)</td>
          <td>AUROC</td>
      </tr>
      <tr>
          <td>BBBP</td>
          <td>2,039</td>
          <td>Classification (<a href="https://en.wikipedia.org/wiki/Blood%E2%80%93brain_barrier">blood-brain barrier</a>)</td>
          <td>AUROC</td>
      </tr>
  </tbody>
</table>
<p>All datasets use the same ten 80:10:10 splits from <a href="/notes/chemistry/molecular-design/property-prediction/systematic-study-molecular-property-prediction/">Yang et al. (2019)</a> for fair comparison. Both random and scaffold splits were evaluated, with scaffold splits representing a more realistic test of generalization to novel chemical scaffolds.</p>
<h3 id="baselines">Baselines</h3>
<p>Models were compared against results reported by Yang et al. (2019): directed message passing neural network (D-MPNN), D-MPNN with RDKit features, random forest on Morgan fingerprints, feed-forward networks on Morgan fingerprints, and feed-forward networks on <a href="https://en.wikipedia.org/wiki/RDKit">RDKit</a> descriptors.</p>
<h3 id="hyperparameters">Hyperparameters</h3>
<p>The same set of fine-tuning hyperparameters was used across all four tasks (tuned on the HIV dataset):</p>
<table>
  <thead>
      <tr>
          <th>Layer Group</th>
          <th>Base Learning Rate</th>
          <th>Epochs</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Linear head only</td>
          <td>3e-2</td>
          <td>4</td>
      </tr>
      <tr>
          <td>+ Final LSTM layer</td>
          <td>5e-3</td>
          <td>4</td>
      </tr>
      <tr>
          <td>+ Final two LSTM layers</td>
          <td>5e-4</td>
          <td>4</td>
      </tr>
      <tr>
          <td>Full model</td>
          <td>5e-5</td>
          <td>6</td>
      </tr>
  </tbody>
</table>
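<p>The table can be read as a four-stage gradual unfreezing schedule. The sketch below encodes it as data and derives which layer groups are trainable at each stage. The group names, the grouping of the embedding with the first LSTM layer, and the use of the 2.6 divisor inside each stage are assumptions made for illustration, not details confirmed by the table.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"># Layer groups ordered from input to output; names are illustrative.
GROUPS = ["embedding + LSTM 1", "LSTM 2", "LSTM 3", "linear head"]

# (number of trainable groups counted from the top, base learning rate, epochs)
SCHEDULE = [(1, 3e-2, 4), (2, 5e-3, 4), (3, 5e-4, 4), (4, 5e-5, 6)]


def stage_plan(n_unfrozen, base_lr, factor=2.6):
    """Map each trainable group to its learning rate for one stage.

    The top group uses base_lr and each group below it is divided by factor.
    Applying the 2.6 rule within a stage this way is an assumption.
    """
    trainable = GROUPS[len(GROUPS) - n_unfrozen:]
    depth = len(trainable)
    return {name: base_lr / factor ** (depth - 1 - i) for i, name in enumerate(trainable)}


plans = [stage_plan(n, lr) for n, lr, _ in SCHEDULE]
total_epochs = sum(epochs for _, _, epochs in SCHEDULE)
</code></pre></div>
<p>The base learning rate drops by roughly an order of magnitude each time another group is unfrozen, and the whole schedule adds up to 18 fine-tuning epochs.</p>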
<p>Data augmentation settings were task-specific: lipophilicity training SMILES augmented 25x ($\sigma_{noise} = 0.3$); FreeSolv augmented 50x ($\sigma_{noise} = 0.5$); HIV active class augmented 60x and inactive 2x; BBBP positive class 10x and negative 30x.</p>
<h2 id="key-findings-and-limitations">Key Findings and Limitations</h2>
<h3 id="benchmark-results">Benchmark Results</h3>
<p><strong>Lipophilicity (random split):</strong> MolPMoFiT achieved RMSE of $0.565 \pm 0.037$ with TTA and $0.625 \pm 0.032$ without, outperforming D-MPNN and other baselines.</p>
<p><strong>FreeSolv (random split):</strong> RMSE of $1.197 \pm 0.127$ with TTA. The small dataset size (642 compounds) led to high variance across splits.</p>
<p><strong>BBBP (random split):</strong> AUROC of $0.950 \pm 0.020$, outperforming all comparison models. Task-specific MSPM fine-tuning showed no clear benefit over the general MSPM.</p>
<p><strong>HIV (random split):</strong> General MolPMoFiT achieved AUROC of $0.828 \pm 0.029$ with TTA. Task-specific fine-tuning yielded a slightly higher $0.834 \pm 0.025$ with TTA.</p>
<p>Scaffold splits consistently produced lower performance than random splits across all datasets, as expected for out-of-distribution generalization.</p>
<h3 id="transfer-learning-impact">Transfer Learning Impact</h3>
<p>Across all four datasets and varying training set sizes, MolPMoFiT consistently outperformed models trained from scratch with the same architecture. The improvement was most pronounced at smaller training set sizes, confirming the utility of pre-trained representations for low-data regimes.</p>
<h3 id="smiles-augmentation-analysis">SMILES Augmentation Analysis</h3>
<p>Training data augmentation provided significant improvements across all tasks. For classification (HIV, BBBP), augmentation improved performance regardless of whether class re-balancing was applied. For regression (lipophilicity, FreeSolv), both SMILES augmentation and label noise were beneficial, with optimal noise levels varying by dataset.</p>
<h3 id="limitations">Limitations</h3>
<p>The authors note a fundamental limitation: the model learns mappings from individual SMILES strings to properties rather than from molecular structures to properties. SMILES augmentation acts as a regularization technique to mitigate this, making the model more robust to different SMILES representations of the same molecule. The task-specific MSPM fine-tuning stage did not consistently improve results, requiring further investigation. All hyperparameters were tuned on one dataset (HIV) and applied uniformly, which may not be optimal for all endpoints.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pre-training</td>
          <td>ChEMBL (curated)</td>
          <td>1M molecules</td>
          <td>Filtered: no mixtures, max 50 heavy atoms, standardized with MolVS, canonicalized with RDKit</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>Lipophilicity</td>
          <td>4,200</td>
          <td><a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a> benchmark</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>FreeSolv</td>
          <td>642</td>
          <td>MoleculeNet benchmark</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>HIV</td>
          <td>41,127</td>
          <td>MoleculeNet benchmark</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>BBBP</td>
          <td>2,039</td>
          <td>MoleculeNet benchmark</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>AWD-LSTM architecture with embedding dim 400, three LSTM layers (1152 hidden units), dropouts at all layers</li>
<li>ULMFiT fine-tuning: discriminative learning rates ($\eta^{layer-1} = \eta^{layer}/2.6$), gradual unfreezing, one cycle policy</li>
<li>SMILES character-level tokenization with special handling for two-character tokens (Cl, Br) and bracket-enclosed tokens</li>
<li>SMILES enumeration for data augmentation with optional Gaussian label noise for regression</li>
</ul>
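<p>The tokenization rule in the list above can be approximated with a single regular expression. The pattern below is an assumption written for illustration and is not taken from the MolPMoFiT repository: it tries bracket-enclosed tokens first, then the two-character halogens, then falls back to single characters.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import re

# Bracket-enclosed tokens first, then the two-character halogens, then any
# single character. The exact pattern is an assumption, not the authors' code.
SMILES_TOKEN = re.compile(r"\[[^\]]+\]|Br|Cl|.")


def tokenize(smiles):
    """Split a SMILES string into tokens."""
    return SMILES_TOKEN.findall(smiles)


tokens = tokenize("C[C@H](N)C(=O)OCCBr")
</code></pre></div>
<p>Keeping <code>[C@H]</code> and <code>Br</code> as single tokens prevents the model from having to learn that, for example, <code>B</code> followed by <code>r</code> is bromine and not boron plus an unrelated symbol.</p>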
<h3 id="models">Models</h3>
<ul>
<li>General-domain MSPM pre-trained on 1M ChEMBL molecules (10 epochs)</li>
<li>Task-specific MSPMs fine-tuned per dataset (optional stage)</li>
<li>QSAR models fine-tuned with transferred embeddings and encoder</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>Split</th>
          <th>Metric</th>
          <th>MolPMoFiT (TTA)</th>
          <th>Best Baseline</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Lipophilicity</td>
          <td>Random</td>
          <td>RMSE</td>
          <td>$0.565 \pm 0.037$</td>
          <td>D-MPNN</td>
      </tr>
      <tr>
          <td>Lipophilicity</td>
          <td>Scaffold</td>
          <td>RMSE</td>
          <td>$0.635 \pm 0.031$</td>
          <td>D-MPNN</td>
      </tr>
      <tr>
          <td>FreeSolv</td>
          <td>Random</td>
          <td>RMSE</td>
          <td>$1.197 \pm 0.127$</td>
          <td>D-MPNN</td>
      </tr>
      <tr>
          <td>FreeSolv</td>
          <td>Scaffold</td>
          <td>RMSE</td>
          <td>$2.082 \pm 0.460$</td>
          <td>D-MPNN</td>
      </tr>
      <tr>
          <td>BBBP</td>
          <td>Random</td>
          <td>AUROC</td>
          <td>$0.950 \pm 0.020$</td>
          <td>D-MPNN</td>
      </tr>
      <tr>
          <td>BBBP</td>
          <td>Scaffold</td>
          <td>AUROC</td>
          <td>$0.931 \pm 0.025$</td>
          <td>D-MPNN</td>
      </tr>
      <tr>
          <td>HIV</td>
          <td>Random</td>
          <td>AUROC</td>
          <td>$0.828 \pm 0.029$</td>
          <td>D-MPNN</td>
      </tr>
      <tr>
          <td>HIV</td>
          <td>Scaffold</td>
          <td>AUROC</td>
          <td>$0.816 \pm 0.022$</td>
          <td>D-MPNN</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li>NVIDIA Quadro P4000 GPU (single GPU)</li>
<li>General-domain MSPM pre-training: approximately 1 day</li>
<li>Pre-training needs to be done only once; fine-tuning is fast per task</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/XinhaoLi74/MolPMoFiT">MolPMoFiT</a></td>
          <td>Code</td>
          <td>Not specified</td>
          <td>PyTorch + fastai v1 implementation with curated datasets</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Li, X., &amp; Fourches, D. (2020). Inductive transfer learning for molecular activity prediction: Next-Gen QSAR Models with MolPMoFiT. <em>Journal of Cheminformatics</em>, 12, 27. <a href="https://doi.org/10.1186/s13321-020-00430-x">https://doi.org/10.1186/s13321-020-00430-x</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{li2020molpmofit,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Inductive transfer learning for molecular activity prediction: Next-Gen QSAR Models with MolPMoFiT}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Li, Xinhao and Fourches, Denis}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Journal of Cheminformatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{12}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{27}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2020}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1186/s13321-020-00430-x}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>MolBERT: Auxiliary Tasks for Molecular BERT Models</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/molbert-molecular-representations/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/molbert-molecular-representations/</guid><description>MolBERT applies BERT to SMILES with domain-relevant auxiliary tasks like physicochemical property prediction, improving virtual screening and QSAR.</description><content:encoded><![CDATA[<h2 id="bert-based-molecular-representations-with-auxiliary-pre-training-tasks">BERT-Based Molecular Representations with Auxiliary Pre-Training Tasks</h2>
<p>This is a <strong>Method</strong> paper that introduces MolBERT, a bidirectional Transformer (BERT) architecture applied to SMILES-based molecular representations for drug discovery. The primary contribution is a systematic study of how different domain-relevant self-supervised pre-training tasks affect the quality of learned molecular embeddings, paired with a model that achieves state-of-the-art performance on <a href="https://en.wikipedia.org/wiki/Virtual_screening">virtual screening</a> and <a href="https://en.wikipedia.org/wiki/Quantitative_structure%E2%80%93activity_relationship">quantitative structure-activity relationship (QSAR)</a> benchmarks.</p>
<h2 id="why-domain-relevant-pre-training-matters-for-molecular-language-models">Why Domain-Relevant Pre-Training Matters for Molecular Language Models</h2>
<p>Molecular representations are foundational for predictive, generative, and analytical tasks in drug discovery. Language models applied to text-based molecular representations like SMILES have demonstrated strong performance across property prediction, reaction prediction, and molecular generation. However, several open questions remained at the time of this work:</p>
<ol>
<li><strong>Task selection for pre-training</strong>: Prior work explored masked token prediction, input translation, and property concatenation, but there was no systematic comparison of how different self-supervised tasks affect downstream performance.</li>
<li><strong>SMILES ambiguity</strong>: The same molecule can be encoded as many different SMILES strings depending on how the molecular graph is traversed. Canonicalization algorithms address this but introduce their own artifacts that may distract the model.</li>
<li><strong>Domain knowledge integration</strong>: Standard NLP pre-training objectives (e.g., masked language modeling) do not explicitly encode chemical knowledge. It was unclear whether incorporating chemistry-specific supervision during pre-training could improve representation quality.</li>
</ol>
<p>MolBERT addresses these gaps by evaluating three pre-training tasks, including a novel physicochemical property prediction objective, and measuring their individual and combined effects on downstream drug discovery benchmarks.</p>
<h2 id="three-auxiliary-tasks-for-chemistry-aware-pre-training">Three Auxiliary Tasks for Chemistry-Aware Pre-Training</h2>
<p>MolBERT uses the BERT-Base architecture (12 attention heads, 12 layers, 768-dimensional hidden states, approximately 85M parameters) and explores three self-supervised pre-training tasks:</p>
<p><strong>Masked Language Modeling (MaskedLM)</strong>: The standard BERT objective where 15% of input tokens are masked and the model predicts their identity. The loss is cross-entropy between predicted and true tokens.</p>
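<p>The masking step can be sketched as follows. This is a simplified illustration: it replaces every selected token with a hypothetical <code>[MASK]</code> symbol, splits the example SMILES into single characters, and forces at least one mask per sequence, none of which is claimed to match the MolBERT implementation.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import random

MASK = "[MASK]"  # hypothetical mask symbol


def mask_tokens(tokens, rng, mask_rate=0.15):
    """Replace about mask_rate of the tokens with MASK (at least one).

    Returns the corrupted sequence and a dict of position to original token,
    which is what the cross-entropy loss is computed against.
    """
    n_mask = max(1, round(mask_rate * len(tokens)))
    positions = sorted(rng.sample(range(len(tokens)), n_mask))
    corrupted = list(tokens)
    labels = {}
    for pos in positions:
        labels[pos] = corrupted[pos]
        corrupted[pos] = MASK
    return corrupted, labels


# Aspirin, split into single characters for simplicity.
tokens = list("CC(=O)Oc1ccccc1C(=O)O")
corrupted, labels = mask_tokens(tokens, random.Random(0))
</code></pre></div>
<p>For this 21-token string, a 15% rate masks three positions; the loss is evaluated only at those positions.</p>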
<p><strong>SMILES Equivalence (SMILES-Eq)</strong>: A binary classification task where the model receives two SMILES strings and predicts whether they represent the same molecule. The second string is either a random permutation of the first (same molecule, different traversal) or a randomly sampled molecule. This is optimized with cross-entropy loss.</p>
<p><strong>Physicochemical Property Prediction (PhysChemPred)</strong>: Using <a href="https://en.wikipedia.org/wiki/RDKit">RDKit</a>, a set of 200 real-valued molecular descriptors is computed for each molecule. The model predicts these normalized descriptors from the SMILES input and is trained with mean squared error:</p>
<p>$$\mathcal{L}_{\text{PhysChemPred}} = \frac{1}{D} \sum_{d=1}^{D} (y_d - \hat{y}_d)^2$$</p>
<p>where $D = 200$ is the number of descriptors, $y_d$ is the true normalized descriptor value, and $\hat{y}_d$ is the model&rsquo;s prediction.</p>
<p>The final training loss is the arithmetic mean of all active task losses:</p>
<p>$$\mathcal{L}_{\text{total}} = \frac{1}{|\mathcal{T}|} \sum_{t \in \mathcal{T}} \mathcal{L}_t$$</p>
<p>where $\mathcal{T}$ is the set of active pre-training tasks.</p>
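<p>The two formulas above translate directly into a few lines of Python. The sketch uses four toy descriptors instead of 200 and a made-up MaskedLM loss value; the function names are illustrative.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">def physchem_loss(y_true, y_pred):
    """Mean squared error over the D normalized descriptors."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def total_loss(task_losses):
    """Arithmetic mean of the losses of the active pre-training tasks."""
    return sum(task_losses.values()) / len(task_losses)


# Toy numbers with D = 4 descriptors instead of 200.
mse = physchem_loss([0.0, 1.0, -1.0, 0.5], [0.5, 1.0, -1.0, 0.0])
loss = total_loss({"MaskedLM": 0.75, "PhysChemPred": mse})
</code></pre></div>
<p>Because the total is an unweighted mean, adding or removing a task changes the relative weight of every other task, which is one reason the ablation over task combinations is informative.</p>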
<p>Additionally, MolBERT supports SMILES permutation augmentation during training, where each input molecule is represented by a randomly sampled non-canonical SMILES string rather than the canonical form. The model uses a fixed vocabulary of 42 tokens, a sequence length of 128, and relative positional embeddings (from Transformer-XL) to support arbitrary-length SMILES at inference time.</p>
<h2 id="ablation-study-and-benchmark-evaluation">Ablation Study and Benchmark Evaluation</h2>
<h3 id="pre-training-setup">Pre-Training Setup</h3>
<p>All models were pre-trained on the <a href="/notes/chemistry/molecular-design/generation/evaluation/guacamol-benchmarking-de-novo-molecular-design/">GuacaMol benchmark dataset</a>, consisting of approximately 1.6M compounds curated from ChEMBL, using an 80%/5% train/validation split. Training used the Adam optimizer with a learning rate of $3 \times 10^{-5}$ for 20 epochs (ablation) or 100 epochs (final model).</p>
<h3 id="ablation-impact-of-task-combinations-on-virtual-screening">Ablation: Impact of Task Combinations on Virtual Screening</h3>
<p>The ablation study evaluated all seven possible task combinations on the RDKit virtual screening benchmark (69 datasets, 5 query molecules per target). Results measured by AUROC and BEDROC20 (an early enrichment metric with $\alpha = 20$):</p>
<table>
  <thead>
      <tr>
          <th style="text-align: center">MaskedLM</th>
          <th style="text-align: center">PhysChemPred</th>
          <th style="text-align: center">SMILES-Eq</th>
          <th style="text-align: center">AUROC (w/ perm)</th>
          <th style="text-align: center">BEDROC20 (w/ perm)</th>
          <th style="text-align: center">AUROC (w/o perm)</th>
          <th style="text-align: center">BEDROC20 (w/o perm)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: center">Yes</td>
          <td style="text-align: center">Yes</td>
          <td style="text-align: center">Yes</td>
          <td style="text-align: center">0.685 +/- 0.069</td>
          <td style="text-align: center">0.246 +/- 0.041</td>
          <td style="text-align: center">0.707 +/- 0.059</td>
          <td style="text-align: center">0.280 +/- 0.042</td>
      </tr>
      <tr>
          <td style="text-align: center">Yes</td>
          <td style="text-align: center">Yes</td>
          <td style="text-align: center">No</td>
          <td style="text-align: center">0.738 +/- 0.060</td>
          <td style="text-align: center">0.323 +/- 0.071</td>
          <td style="text-align: center">0.740 +/- 0.066</td>
          <td style="text-align: center">0.322 +/- 0.065</td>
      </tr>
      <tr>
          <td style="text-align: center">Yes</td>
          <td style="text-align: center">No</td>
          <td style="text-align: center">Yes</td>
          <td style="text-align: center">0.483 +/- 0.092</td>
          <td style="text-align: center">0.092 +/- 0.069</td>
          <td style="text-align: center">0.493 +/- 0.068</td>
          <td style="text-align: center">0.108 +/- 0.070</td>
      </tr>
      <tr>
          <td style="text-align: center">No</td>
          <td style="text-align: center">Yes</td>
          <td style="text-align: center">Yes</td>
          <td style="text-align: center">0.476 +/- 0.077</td>
          <td style="text-align: center">0.064 +/- 0.034</td>
          <td style="text-align: center">0.514 +/- 0.165</td>
          <td style="text-align: center">0.084 +/- 0.014</td>
      </tr>
      <tr>
          <td style="text-align: center">Yes</td>
          <td style="text-align: center">No</td>
          <td style="text-align: center">No</td>
          <td style="text-align: center">0.696 +/- 0.058</td>
          <td style="text-align: center">0.283 +/- 0.077</td>
          <td style="text-align: center">0.676 +/- 0.060</td>
          <td style="text-align: center">0.250 +/- 0.073</td>
      </tr>
      <tr>
          <td style="text-align: center">No</td>
          <td style="text-align: center">Yes</td>
          <td style="text-align: center">No</td>
          <td style="text-align: center">0.719 +/- 0.057</td>
          <td style="text-align: center">0.293 +/- 0.071</td>
          <td style="text-align: center">0.716 +/- 0.061</td>
          <td style="text-align: center">0.290 +/- 0.076</td>
      </tr>
      <tr>
          <td style="text-align: center">No</td>
          <td style="text-align: center">No</td>
          <td style="text-align: center">Yes</td>
          <td style="text-align: center">0.129 +/- 0.067</td>
          <td style="text-align: center">0.005 +/- 0.037</td>
          <td style="text-align: center">0.508 +/- 0.068</td>
          <td style="text-align: center">0.048 +/- 0.035</td>
      </tr>
  </tbody>
</table>
<p>Key findings from the ablation:</p>
<ul>
<li>PhysChemPred had the highest individual impact (BEDROC20 of 0.292 alone vs. 0.266 for MaskedLM alone, each averaged over the runs with and without SMILES permutation).</li>
<li>Combining MaskedLM + PhysChemPred achieved the best performance (BEDROC20 of 0.323), though the additive gain from MaskedLM was modest (+0.031).</li>
<li>The SMILES-Eq task consistently decreased performance when added to other task combinations.</li>
</ul>
<p>A further sub-ablation on PhysChemPred descriptor groups showed that surface descriptors alone (49 of 200 descriptors) achieved nearly the same performance as the full set, suggesting molecular surface properties provide particularly informative supervision.</p>
<h3 id="virtual-screening-results">Virtual Screening Results</h3>
<p>Using the best task combination (MaskedLM + PhysChemPred) trained for 100 epochs:</p>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>AUROC</th>
          <th>BEDROC20</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MolBERT (100 epochs)</td>
          <td>0.743 +/- 0.062</td>
          <td>0.344 +/- 0.062</td>
      </tr>
      <tr>
          <td>CDDD</td>
          <td>0.725 +/- 0.057</td>
          <td>0.310 +/- 0.080</td>
      </tr>
      <tr>
          <td>RDKit descriptors</td>
          <td>0.633 +/- 0.027</td>
          <td>0.217 +/- 0.000</td>
      </tr>
      <tr>
          <td>ECFC4</td>
          <td>0.603 +/- 0.056</td>
          <td>0.170 +/- 0.079</td>
      </tr>
  </tbody>
</table>
<p>MolBERT outperformed all baselines including <a href="/notes/chemistry/molecular-representations/encoders/cddd-translation-molecular-descriptors/">CDDD</a> (the prior state of the art), RDKit calculated descriptors, and extended-connectivity fingerprints (ECFC4).</p>
<h3 id="qsar-results">QSAR Results</h3>
<p>On <a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a> regression tasks (RMSE, lower is better):</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th style="text-align: center">RDKit (norm)</th>
          <th style="text-align: center">ECFC4</th>
          <th style="text-align: center">CDDD</th>
          <th style="text-align: center">MolBERT</th>
          <th style="text-align: center">MolBERT (finetune)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ESOL</td>
          <td style="text-align: center">0.687 +/- 0.08</td>
          <td style="text-align: center">0.902 +/- 0.06</td>
          <td style="text-align: center">0.567 +/- 0.06</td>
          <td style="text-align: center">0.552 +/- 0.07</td>
          <td style="text-align: center"><strong>0.531 +/- 0.04</strong></td>
      </tr>
      <tr>
          <td>FreeSolv</td>
          <td style="text-align: center">1.671 +/- 0.45</td>
          <td style="text-align: center">2.876 +/- 0.38</td>
          <td style="text-align: center">1.456 +/- 0.43</td>
          <td style="text-align: center">1.523 +/- 0.66</td>
          <td style="text-align: center"><strong>0.948 +/- 0.33</strong></td>
      </tr>
      <tr>
          <td>Lipophilicity</td>
          <td style="text-align: center">0.738 +/- 0.04</td>
          <td style="text-align: center">0.770 +/- 0.03</td>
          <td style="text-align: center">0.669 +/- 0.02</td>
          <td style="text-align: center">0.602 +/- 0.01</td>
          <td style="text-align: center"><strong>0.561 +/- 0.03</strong></td>
      </tr>
  </tbody>
</table>
<p>On MoleculeNet classification tasks (AUROC, higher is better):</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th style="text-align: center">RDKit (norm)</th>
          <th style="text-align: center">ECFC4</th>
          <th style="text-align: center">CDDD</th>
          <th style="text-align: center">MolBERT</th>
          <th style="text-align: center">MolBERT (finetune)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>BACE</td>
          <td style="text-align: center">0.831</td>
          <td style="text-align: center">0.845</td>
          <td style="text-align: center">0.833</td>
          <td style="text-align: center">0.849</td>
          <td style="text-align: center"><strong>0.866</strong></td>
      </tr>
      <tr>
          <td>BBBP</td>
          <td style="text-align: center">0.696</td>
          <td style="text-align: center">0.678</td>
          <td style="text-align: center">0.761</td>
          <td style="text-align: center">0.750</td>
          <td style="text-align: center"><strong>0.762</strong></td>
      </tr>
      <tr>
          <td>HIV</td>
          <td style="text-align: center">0.708</td>
          <td style="text-align: center">0.714</td>
          <td style="text-align: center">0.753</td>
          <td style="text-align: center">0.747</td>
          <td style="text-align: center"><strong>0.783</strong></td>
      </tr>
  </tbody>
</table>
<p>Fine-tuned MolBERT achieved the best performance on all six QSAR datasets. When used as a fixed feature extractor with an SVM, MolBERT embeddings outperformed other representations on three of six tasks.</p>
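<p>The &ldquo;three of six&rdquo; count for the fixed-feature setting can be checked against the two tables above. The snippet below copies the four fixed-representation columns (the fine-tuned column is excluded); the dictionary layout and the <code>best_method</code> helper are illustrative:</p>

```python
# Fixed-representation results copied from the two QSAR tables above.
# Regression rows are RMSE (lower is better); classification rows are AUROC.
RESULTS = {
    "ESOL": {"RDKit": 0.687, "ECFC4": 0.902, "CDDD": 0.567, "MolBERT": 0.552},
    "FreeSolv": {"RDKit": 1.671, "ECFC4": 2.876, "CDDD": 1.456, "MolBERT": 1.523},
    "Lipophilicity": {"RDKit": 0.738, "ECFC4": 0.770, "CDDD": 0.669, "MolBERT": 0.602},
    "BACE": {"RDKit": 0.831, "ECFC4": 0.845, "CDDD": 0.833, "MolBERT": 0.849},
    "BBBP": {"RDKit": 0.696, "ECFC4": 0.678, "CDDD": 0.761, "MolBERT": 0.750},
    "HIV": {"RDKit": 0.708, "ECFC4": 0.714, "CDDD": 0.753, "MolBERT": 0.747},
}
LOWER_IS_BETTER = {"ESOL", "FreeSolv", "Lipophilicity"}


def best_method(dataset):
    """Name of the representation with the best score on one dataset."""
    pick = min if dataset in LOWER_IS_BETTER else max
    return pick(RESULTS[dataset], key=RESULTS[dataset].get)


# Datasets on which fixed MolBERT embeddings beat the other representations.
molbert_wins = [name for name in RESULTS if best_method(name) == "MolBERT"]
```

<p>On the remaining three datasets (FreeSolv, BBBP, HIV), CDDD is the strongest fixed representation.</p>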
<h2 id="key-findings-and-limitations">Key Findings and Limitations</h2>
<h3 id="key-findings">Key Findings</h3>
<ol>
<li><strong>Pre-training task selection matters significantly.</strong> The choice of auxiliary tasks during pre-training has a large effect on downstream performance. PhysChemPred provides the strongest individual signal.</li>
<li><strong>Domain-relevant auxiliary tasks improve representation quality.</strong> Predicting physicochemical properties during pre-training encodes chemical knowledge directly into the embeddings, outperforming purely linguistic objectives.</li>
<li><strong>The SMILES equivalence task hurts performance.</strong> Despite being chemically motivated, the SMILES-Eq task consistently degraded results, suggesting it may introduce conflicting learning signals.</li>
<li><strong>PhysChemPred organizes the embedding space.</strong> Analysis of pairwise cosine similarities showed that models trained with PhysChemPred assign high similarity to permutations of the same molecule and low similarity to different molecules, creating a more semantically meaningful representation space.</li>
</ol>
<h3 id="limitations">Limitations</h3>
<ul>
<li>The paper evaluates only SMILES-based representations, inheriting all limitations of string-based molecular encodings (inability to capture 3D structure, sensitivity to tokenization).</li>
<li>The virtual screening evaluation uses a fixed number of query molecules ($n = 5$), which may not reflect realistic screening scenarios.</li>
<li>Cross-validation splits from ChemBench were used for QSAR evaluation rather than scaffold splits, which may overestimate performance on structurally novel compounds.</li>
<li>The model&rsquo;s 128-token sequence length limit may truncate larger molecules, though relative positional embeddings partially address this at inference time.</li>
</ul>
<h3 id="future-directions">Future Directions</h3>
<p>The authors propose extending MolBERT to learn representations for other biological entities such as proteins, and developing more advanced pre-training strategies.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pre-training</td>
          <td>GuacaMol (ChEMBL)</td>
          <td>~1.6M compounds</td>
          <td>80% train / 5% validation split</td>
      </tr>
      <tr>
          <td>Virtual Screening</td>
          <td>RDKit benchmark v1.2</td>
          <td>69 target datasets</td>
          <td>Filtered subset with active/decoy compounds</td>
      </tr>
      <tr>
          <td>QSAR (Regression)</td>
          <td>ESOL, FreeSolv, Lipophilicity</td>
          <td>Varies</td>
          <td>From MoleculeNet, ChemBench splits</td>
      </tr>
      <tr>
          <td>QSAR (Classification)</td>
          <td>BACE, BBBP, HIV</td>
          <td>Varies</td>
          <td>From MoleculeNet, ChemBench splits</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Architecture: BERT-Base (12 heads, 12 layers, 768-dim hidden, ~85M params)</li>
<li>Optimizer: Adam, learning rate $3 \times 10^{-5}$</li>
<li>Vocabulary: 42 tokens, sequence length 128</li>
<li>Masking: 15% of tokenized input</li>
<li>Positional encoding: relative positional embeddings (Transformer-XL)</li>
<li>Fixed-embedding evaluation: SVM with RBF kernel, $C = 5.0$ (settings from Winter et al.)</li>
<li>Fine-tuning head: single linear layer on pooled output</li>
<li>Embeddings: pooled output (or average sequence output when only MaskedLM is used)</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li>BERT-Base with ~85M parameters</li>
<li>Pre-trained weights available at <a href="https://github.com/BenevolentAI/MolBERT">BenevolentAI/MolBERT</a></li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Task</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>AUROC</td>
          <td>Virtual Screening, Classification QSAR</td>
          <td>Standard area under ROC curve</td>
      </tr>
      <tr>
          <td>BEDROC20</td>
          <td>Virtual Screening</td>
          <td>Early enrichment metric, $\alpha = 20$</td>
      </tr>
      <tr>
          <td>RMSE</td>
          <td>Regression QSAR</td>
          <td>Root mean squared error</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li>2 GPUs, 16 CPUs</li>
<li>Pre-training time: ~40 hours (20 epochs)</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/BenevolentAI/MolBERT">BenevolentAI/MolBERT</a></td>
          <td>Code + Model</td>
          <td>MIT</td>
          <td>Official implementation with pre-trained weights</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Fabian, B., Edlich, T., Gaspar, H., Segler, M., Meyers, J., Fiscato, M., &amp; Ahmed, M. (2020). Molecular representation learning with language models and domain-relevant auxiliary tasks. <em>arXiv preprint arXiv:2011.13230</em>.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{fabian2020molecular,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Molecular representation learning with language models and domain-relevant auxiliary tasks}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Fabian, Benedek and Edlich, Thomas and Gaspar, H{\&#39;e}l{\&#39;e}na and Segler, Marwin and Meyers, Joshua and Fiscato, Marco and Ahmed, Mohamed}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{arXiv preprint arXiv:2011.13230}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2020}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>MaCBench: Multimodal Chemistry and Materials Benchmark</title><link>https://hunterheidenreich.com/notes/chemistry/llm-applications/macbench-multimodal-chemistry-benchmark/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/llm-applications/macbench-multimodal-chemistry-benchmark/</guid><description>MaCBench benchmarks vision language models on chemistry and materials science tasks, revealing failures in spatial reasoning and cross-modal integration.</description><content:encoded><![CDATA[<h2 id="a-benchmark-for-multimodal-scientific-reasoning">A Benchmark for Multimodal Scientific Reasoning</h2>
<p>MaCBench is a <strong>Resource</strong> contribution that provides a comprehensive benchmark for evaluating vision language models (VLLMs) on real-world chemistry and materials science tasks. Rather than testing general-purpose visual reasoning or text-only scientific knowledge, MaCBench specifically targets the interplay between visual and textual modalities across the scientific workflow. The benchmark contains 779 multiple-choice questions and 374 numeric-answer questions organized into 11 topics across three pillars: data extraction, experimental execution, and data interpretation. Through systematic ablation studies, the authors identify fundamental limitations in spatial reasoning, cross-modal synthesis, and multi-step inference that current VLLMs exhibit.</p>
<h2 id="why-multimodal-evaluation-matters-for-chemistry">Why Multimodal Evaluation Matters for Chemistry</h2>
<p>Scientific research inherently requires integrating multiple information modalities: reading plots, interpreting spectra, evaluating laboratory setups, and connecting visual observations with domain knowledge. While text-only benchmarks like <a href="/notes/chemistry/llm-applications/chembench-llm-chemistry-evaluation/">ChemBench</a> have evaluated LLM capabilities in chemistry, and general multimodal benchmarks have tested visual reasoning, no prior work had systematically assessed how VLLMs handle the specific multimodal demands of the chemistry and materials science workflow.</p>
<p>Existing evaluations treated either the scientific reasoning dimension or the multimodal dimension in isolation. This left a critical gap: can VLLMs reliably assist with tasks that require both visual perception and scientific reasoning simultaneously? For example, identifying laboratory equipment is a perception task, but evaluating whether a laboratory setup is safe requires integrating visual understanding with domain-specific knowledge about hazards.</p>
<p>The authors designed MaCBench to fill this gap by constructing tasks that mirror actual scientific workflows and by including ablation studies that isolate specific failure modes.</p>
<h2 id="benchmark-design-three-pillars-of-scientific-work">Benchmark Design: Three Pillars of Scientific Work</h2>
<p>The benchmark is structured around three pillars reflecting the scientific process:</p>
<p><strong>Data Extraction</strong> covers parsing scientific literature, including extracting values from tables and plots, interpreting chemical structure diagrams, and identifying reaction components. Tasks range from simple value extraction to complex spatial reasoning about molecular relationships (e.g., identifying isomeric relationships between compounds).</p>
<p><strong>Experimental Execution</strong> evaluates understanding of laboratory operations and crystallographic analysis. This includes equipment identification, safety assessment of laboratory setups, and interpretation of crystal structure renderings (<a href="https://en.wikipedia.org/wiki/Space_group">space group</a> assignment, atomic species counting, density calculations).</p>
<p><strong>Data Interpretation</strong> tests analysis of experimental outputs: spectral analysis (<a href="https://en.wikipedia.org/wiki/X-ray_diffraction">XRD</a>, <a href="https://en.wikipedia.org/wiki/Nuclear_magnetic_resonance_spectroscopy">NMR</a>, <a href="https://en.wikipedia.org/wiki/Mass_spectrometry">mass spectrometry</a>), electronic structure interpretation, adsorption isotherm analysis, and <a href="https://en.wikipedia.org/wiki/Atomic_force_microscopy">AFM</a> image interpretation.</p>
<p>Each task uses a single prompt template containing multiple questions. All questions pair images with text-based prompts. The dataset was curated manually, with questions reviewed by multiple scientists before inclusion. A BigBench canary string is embedded in each file to prevent data contamination during future model training.</p>
<h2 id="evaluation-of-frontier-vllms-and-ablation-studies">Evaluation of Frontier VLLMs and Ablation Studies</h2>
<p>The authors evaluated four frontier VLLMs: Claude 3.5 Sonnet, GPT-4o, Gemini 1.5 Pro, and Llama 3.2 90B Vision. Performance is reported relative to random baselines to account for the varying number of answer choices across MCQ tasks:</p>
<p>$$
\text{acc}_{\text{rel}} = \text{acc} - \text{acc}_{\text{baseline}}
$$</p>
<p>Each benchmark run was repeated five times to capture variability, with standard deviations reported as error bars.</p>
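<p>As a small sketch of this reporting scheme, the snippet below subtracts a random baseline and summarizes five runs. The run values are hypothetical, and using the sample standard deviation is an assumption, since the estimator is not specified here:</p>

```python
from statistics import mean, stdev


def relative_accuracy(accuracy, baseline):
    """Accuracy gain over the random baseline for the same task."""
    return accuracy - baseline


def summarize_runs(run_accuracies, baseline):
    """Mean and sample standard deviation of relative accuracy across runs."""
    rel = [relative_accuracy(acc, baseline) for acc in run_accuracies]
    return mean(rel), stdev(rel)


# Five hypothetical runs of a four-option MCQ task (random baseline 0.25).
runs = [0.50, 0.52, 0.48, 0.51, 0.49]
avg, spread = summarize_runs(runs, baseline=0.25)
```

<p>Subtracting the baseline makes a four-option task and a seven-option task comparable: a relative accuracy near zero means the model is no better than guessing in either case.</p>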
<h3 id="overall-performance-landscape">Overall Performance Landscape</h3>
<p>Claude 3.5 Sonnet was the leading model across all three task families, though no model dominated across all individual tasks. Accuracies averaged across the evaluated models on representative tasks:</p>
<ul>
<li><strong>Equipment identification</strong>: average accuracy of 0.77 (strong perception performance)</li>
<li><strong>Hand-drawn molecule to <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> matching</strong>: average accuracy of 0.80</li>
<li><strong>Table composition extraction</strong>: average accuracy of 0.53 (Llama 3.2 indistinguishable from random guessing)</li>
<li><strong>Isomer relationship identification</strong>: average accuracy of 0.24 (barely above the 0.14 baseline)</li>
<li><strong>Laboratory safety assessment</strong>: average accuracy of 0.46</li>
<li><strong>AFM image interpretation</strong>: average accuracy of 0.24</li>
<li><strong>NMR and mass spectrometry analysis</strong>: average accuracy of 0.35</li>
</ul>
<h3 id="ablation-studies-four-dimensions-of-failure">Ablation Studies: Four Dimensions of Failure</h3>
<p>The authors designed ablations isolating four specific dimensions:</p>
<p><strong>1. Modality (Image vs. Text):</strong> When identical information was presented as text instead of images, performance improved consistently across all tasks. For XRD peak identification, models showed a roughly 35% performance increase when peaks were provided as text rather than displayed visually. Even crystal structure volume calculations differed by four percentage points between visual and textual input of unit cell parameters.</p>
<p><strong>2. Multi-Step Reasoning:</strong> Performance degraded consistently as tasks required more reasoning steps. For XRD analysis, identifying the highest peak achieved 0.74 average accuracy, while ranking relative peak intensities dropped to 0.28. Isotherm analysis showed the same pattern: finding the maximum value was easier than ordering multiple values.</p>
<p><strong>3. Scientific Terminology:</strong> Removing domain-specific terminology (e.g., using <a href="https://en.wikipedia.org/wiki/IUPAC_nomenclature_of_organic_chemistry">IUPAC names</a> instead of SMILES notation) improved performance on several tasks, suggesting models are sensitive to specific vocabularies rather than understanding underlying concepts. Gemini 1.5 Pro showed particular sensitivity to exact prompt wording, with large performance variations from minor changes like replacing &ldquo;image&rdquo; with &ldquo;diagram&rdquo; or &ldquo;plot.&rdquo;</p>
<p><strong>4. Guidance:</strong> Adding step-by-step instructions improved performance for most models on spectral analysis and XRD pattern matching, with the notable exception of Claude 3.5 Sonnet, whose performance did not improve with guidance.</p>
<h3 id="internet-frequency-correlation">Internet Frequency Correlation</h3>
<p>The authors measured the correlation between model performance and the number of Google search results for various crystal structures (as a proxy for training data frequency). For all tested cases, structures with correct model responses had higher Internet presence. This effect held even for pure perception tasks like counting atomic species, suggesting models may rely on memorized patterns rather than genuine visual reasoning.</p>
<h2 id="limitations-of-current-vllms-for-scientific-assistance">Limitations of Current VLLMs for Scientific Assistance</h2>
<p>The results reveal three fundamental limitations of current VLLMs:</p>
<p><strong>Spatial reasoning failure:</strong> Models perform well on perception tasks (identifying equipment, matching hand-drawn molecules) but fail when spatial understanding is required (<a href="https://en.wikipedia.org/wiki/Stereochemistry">stereochemistry</a> assignment at 0.24 accuracy, space group identification at 0.45). This limitation undermines one of the most intuitive potential use cases of vision models.</p>
<p><strong>Incomplete cross-modal integration:</strong> The consistent performance gap between text and image presentations of identical information demonstrates that current models have not developed robust strategies for visual information processing. The models process text and images through fundamentally different pathways, with text consistently yielding better results.</p>
<p><strong>Multi-step reasoning brittleness:</strong> The systematic degradation across reasoning steps indicates that chaining logical operations, a core requirement for scientific reasoning, remains a fundamental weakness.</p>
<p>The authors note that compared to text-only benchmarks (e.g., ChemBench), multimodal systems show much higher performance variability across tasks, suggesting greater fragility. They propose that advances in synthetic training data generation (particularly for spatial reasoning) and modality transformation training tasks could help address these limitations. They also acknowledge that future workflows with machine-actionable data formats may reduce the need for some multimodal parsing capabilities.</p>
<p>The benchmark does not encompass the full scope of scientific reasoning, and the evaluated models are not exhaustive of all available architectures. The authors call for continued research across wider task and model sets, along with interpretability studies to distinguish genuine reasoning from pattern matching.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Evaluation</td>
          <td>MaCBench</td>
          <td>779 MCQs + 374 numeric questions</td>
          <td>11 topics across 3 pillars</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>MaCBench-Ablations</td>
          <td>Subset with ablation variants</td>
          <td>Modality, terminology, guidance, step complexity</td>
      </tr>
  </tbody>
</table>
<p>Both datasets are available on HuggingFace. Questions are stored in extended BigBench format with base-64-encoded images and BigBench canary strings.</p>
<h3 id="algorithms">Algorithms</h3>
<p>The evaluation pipeline builds on the ChemBench framework (v0.3.0). Answer extraction uses regex-based parsing backed by an LLM extractor (Claude 3.5 Sonnet) for fallback cases. Refusal detection combines LLM Guard regex patterns with a fine-tuned DistilRoBERTa model, with up to five retries for refused responses.</p>
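<p>The control flow described above can be sketched as follows. This is not the ChemBench implementation: the <code>[ANSWER]</code> tag format is an assumed answer template, and <code>llm_fallback</code>, <code>ask_model</code>, and <code>is_refusal</code> are hypothetical callables standing in for the LLM extractor, the model API, and the refusal detector:</p>

```python
import re

# Hypothetical answer format: the model is assumed to wrap its final answer
# as [ANSWER]...[/ANSWER]; the real prompt template may differ.
ANSWER_PATTERN = re.compile(r"\[ANSWER\](.*?)\[/ANSWER\]", re.DOTALL)
MAX_RETRIES = 5  # up to five retries for refused responses


def extract_answer(response, llm_fallback):
    """Try regex parsing first; defer to an LLM-based extractor otherwise."""
    match = ANSWER_PATTERN.search(response)
    if match:
        return match.group(1).strip()
    return llm_fallback(response)


def query_with_retries(ask_model, is_refusal, prompt):
    """Re-send the prompt while the reply is flagged as a refusal."""
    reply = ask_model(prompt)
    retries = 0
    while is_refusal(reply) and retries != MAX_RETRIES:
        reply = ask_model(prompt)
        retries += 1
    return reply
```

<p>Keeping the regex path first means the more expensive LLM extractor is only invoked for responses that ignore the expected answer format.</p>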
<p><strong>Scoring:</strong></p>
<ul>
<li>MCQs: correct if <a href="https://en.wikipedia.org/wiki/Hamming_distance">Hamming loss</a> is zero (exact match)</li>
<li>Numeric: correct if mean absolute error falls within specified tolerance (default 1%, up to 5% for specific tasks)</li>
<li>Random baseline: random option selection for MCQs; mean of all target values in a topic for numeric questions</li>
</ul>
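<p>The scoring rules above can be sketched in a few lines. The function names are hypothetical, and reading the tolerance as a fraction of the target value is an assumption about how the percentage is applied:</p>

```python
def mcq_correct(predicted, target):
    """Exact match: zero Hamming loss means no option is wrongly chosen or missed."""
    return set(predicted) == set(target)


def numeric_correct(predicted, target, tolerance=0.01):
    """Accept a numeric answer whose error is within a relative tolerance."""
    allowed = tolerance * abs(target)
    return allowed >= abs(predicted - target)


def numeric_baseline(targets):
    """Baseline prediction for numeric questions: the mean target value in a topic."""
    return sum(targets) / len(targets)
```

<p>For example, with a target of 100.0 the default 1% tolerance accepts 100.5 but rejects 102.0, while a task configured with 5% tolerance would accept both.</p>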
<h3 id="models">Models</h3>
<p>Four frontier VLLMs evaluated:</p>
<ul>
<li>Claude 3.5 Sonnet (Anthropic)</li>
<li>GPT-4o (OpenAI)</li>
<li>Gemini 1.5 Pro (Google)</li>
<li>Llama 3.2 90B Vision (Meta)</li>
</ul>
<p>Default quality/resolution settings were used for each provider.</p>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Task</th>
          <th>Aggregation</th>
          <th>Accuracy</th>
          <th>Baseline</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Equipment identification</td>
          <td>Average</td>
          <td>0.77</td>
          <td>varies</td>
          <td>Strong perception performance</td>
      </tr>
      <tr>
          <td>Hand-drawn molecule matching</td>
          <td>Average</td>
          <td>0.80</td>
          <td>~0.20</td>
          <td>4x above baseline</td>
      </tr>
      <tr>
          <td>Isomer relationship</td>
          <td>Average</td>
          <td>0.24</td>
          <td>0.14</td>
          <td>Near random</td>
      </tr>
      <tr>
          <td>Laboratory safety</td>
          <td>Average</td>
          <td>0.46</td>
          <td>varies</td>
          <td>Below practical utility</td>
      </tr>
      <tr>
          <td>AFM interpretation</td>
          <td>Average</td>
          <td>0.24</td>
          <td>varies</td>
          <td>Near random</td>
      </tr>
      <tr>
          <td>Henry constant comparison</td>
          <td>Average</td>
          <td>0.83</td>
          <td>varies</td>
          <td>Strongest interpretation task</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>The paper does not specify hardware requirements. All evaluations were run through commercial API endpoints.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/lamalab-org/macbench">MaCBench Repository</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Benchmark data and evaluation card</td>
      </tr>
      <tr>
          <td><a href="https://github.com/lamalab-org/chembench">ChemBench Framework</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Evaluation pipeline (v0.3.0)</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/datasets/kjappelbaum/MaCBench">MaCBench Dataset</a></td>
          <td>Dataset</td>
          <td>CC-BY-4.0</td>
          <td>1,153 questions with images</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/datasets/kjappelbaum/MaCBench-Ablations">MaCBench-Ablations</a></td>
          <td>Dataset</td>
          <td>CC-BY-4.0</td>
          <td>Ablation task variants</td>
      </tr>
      <tr>
          <td><a href="https://doi.org/10.5281/zenodo.14935487">ChemBench v0.3.0 (Zenodo)</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Archived release</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation:</strong> Alampara, N., Schilling-Wilhelmi, M., Ríos-García, M., Mandal, I., Khetarpal, P., Grover, H. S., Krishnan, N. M. A., &amp; Jablonka, K. M. (2025). Probing the limitations of multimodal language models for chemistry and materials research. <em>Nature Computational Science</em>, 5(10), 952-961. <a href="https://doi.org/10.1038/s43588-025-00836-3">https://doi.org/10.1038/s43588-025-00836-3</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{alampara2025macbench,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Probing the limitations of multimodal language models for chemistry and materials research}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Alampara, Nawaf and Schilling-Wilhelmi, Mara and R{\&#39;\i}os-Garc{\&#39;\i}a, Marti{\~n}o and Mandal, Indrajeet and Khetarpal, Pranav and Grover, Hargun Singh and Krishnan, N. M. Anoop and Jablonka, Kevin Maik}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Nature Computational Science}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{5}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{10}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{952--961}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2025}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Nature Publishing Group}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1038/s43588-025-00836-3}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>LMs Generate 3D Molecules from XYZ, CIF, PDB Files</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/autoregressive/3d-chemical-language-models-xyz-cif-pdb/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/autoregressive/3d-chemical-language-models-xyz-cif-pdb/</guid><description>Transformer language models trained on XYZ, CIF, and PDB sequences generate valid 3D molecules, crystals, and protein binding sites.</description><content:encoded><![CDATA[<h2 id="language-models-as-3d-chemical-structure-generators">Language Models as 3D Chemical Structure Generators</h2>
<p>This is a <strong>Method</strong> paper that demonstrates transformer-based language models can generate molecules, crystalline materials, and protein binding sites directly in three dimensions by training on sequences derived from standard chemical file formats (XYZ, CIF, PDB). The key contribution is showing that unmodified autoregressive language models, using only next-token prediction, achieve performance comparable to domain-specific 3D generative models that incorporate SE(3) equivariance and other geometric inductive biases.</p>
<h2 id="beyond-graphs-and-strings-the-need-for-3d-chemical-generation">Beyond Graphs and Strings: The Need for 3D Chemical Generation</h2>
<p>Molecular design with deep learning has largely relied on two representation paradigms: molecular graphs (processed with graph neural networks) and linearized string representations like <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> and <a href="/notes/chemistry/molecular-representations/notations/selfies/">SELFIES</a> (processed with sequence models). Both approaches have proven effective for drug-like organic molecules, but they share a fundamental limitation: they cannot represent structures whose identity depends on 3D spatial arrangement.</p>
<p>Crystalline materials, for example, have periodic lattice structures that cannot be reduced to simple graphs. Protein binding sites are defined by the 3D arrangement of hundreds of atoms across multiple residues. For tasks like catalysis design or structure-based drug discovery, the geometric positions of atoms are essential information that graphs and strings discard entirely.</p>
<p>Existing 3D generative models address this gap but typically require specialized architectures with SE(3) equivariance to handle rotational and translational symmetries. This work asks whether the general-purpose sequence modeling capability of transformers is sufficient to learn 3D chemical structure distributions without any domain-specific architectural modifications.</p>
<h2 id="direct-tokenization-of-chemical-file-formats">Direct Tokenization of Chemical File Formats</h2>
<p>The core insight is straightforward: any 3D molecule, crystal, or biomolecule is already stored as text in standard file formats (<a href="https://en.wikipedia.org/wiki/XYZ_file_format">XYZ</a>, <a href="https://en.wikipedia.org/wiki/Crystallographic_Information_File">CIF</a>, <a href="https://en.wikipedia.org/wiki/Protein_Data_Bank_(file_format)">PDB</a>). These files encode atom types and their Cartesian coordinates as sequences of characters and numbers. Rather than designing specialized architectures for point cloud generation, the authors simply tokenize these files and train a standard GPT-style transformer to predict the next token.</p>
<p>A molecule with $n$ atoms is represented as:</p>
<p>$$
\mathcal{M} = (e_1, x_1, y_1, z_1, \dots, e_n, x_n, y_n, z_n)
$$</p>
<p>where $e_i$ is the element type and $(x_i, y_i, z_i)$ are Cartesian coordinates. Crystals additionally include lattice parameters:</p>
<p>$$
\mathcal{C} = (\ell_a, \ell_b, \ell_c, \alpha, \beta, \gamma, e_1, x_1, y_1, z_1, \dots, e_n, x_n, y_n, z_n)
$$</p>
<p>Protein binding sites use residue-atom indicators (e.g., HIS-C, CYS-N) instead of bare element symbols:</p>
<p>$$
\mathcal{P} = (a_1, x_1, y_1, z_1, \dots, a_n, x_n, y_n, z_n)
$$</p>
<p>The language model learns the joint distribution over the resulting sequence of $T$ tokens, $x = (t_1, \dots, t_T)$, via autoregressive factorization:</p>
<p>$$
p(x) = \prod_{i=1}^{T} p(t_i \mid t_{i-1}, \dots, t_1)
$$</p>
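<p>Since the models are trained with plain next-token prediction, the corresponding training objective is presumably the usual negative log-likelihood of this factorization. As a sketch, with $\theta$ denoting the transformer parameters (a symbol introduced here, not taken from the paper) and the sum running over all token positions $i$ in a sequence:</p>
<p>$$
-\log p_\theta(x) = -\sum_{i} \log p_\theta(t_i \mid t_{i-1}, \dots, t_1)
$$</p>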
<p>Two tokenization strategies are explored:</p>
<ol>
<li><strong>Character-level (LM-CH)</strong>: Every character in the file is a token, including digits, minus signs, spaces, and newlines. This produces long sequences but uses a small vocabulary (~30 tokens).</li>
<li><strong>Atom+coordinate-level (LM-AC)</strong>: Each atom placement requires exactly 4 tokens: one element/residue token and three coordinate tokens (e.g., &lsquo;-1.98&rsquo;). The vocabulary is larger (~100-10K tokens) but sequences are shorter.</li>
</ol>
<p>Numerical precision is controlled by rounding coordinates to 1, 2, or 3 decimal places. Since the model lacks rotation and translation invariance, random rotation augmentation during training improves performance.</p>
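<p>The two tokenization strategies can be illustrated with a minimal Python sketch. The helper names (<code>format_xyz</code>, <code>tokenize_char</code>, <code>tokenize_atom_coord</code>) and the three-atom fragment are hypothetical and chosen only for illustration. The paper does not publish its preprocessing code, so this shows the idea, not the authors' implementation.</p>

```python
def format_xyz(atoms, decimals=2):
    """Render (element, x, y, z) tuples as XYZ-style lines with fixed precision."""
    return "\n".join(
        f"{el} {x:.{decimals}f} {y:.{decimals}f} {z:.{decimals}f}"
        for el, x, y, z in atoms
    )


def tokenize_char(xyz_text):
    """LM-CH style: every character (digits, signs, spaces, newlines) is a token."""
    return list(xyz_text)


def tokenize_atom_coord(xyz_text):
    """LM-AC style: one element token plus three whole-coordinate tokens per atom."""
    tokens = []
    for line in xyz_text.splitlines():
        element, x, y, z = line.split()
        tokens.extend([element, x, y, z])
    return tokens


# Hypothetical three-atom fragment, used only for illustration.
atoms = [("C", 0.0, 0.0, 0.0), ("O", 1.2149, -0.03, 0.0), ("H", -0.54, 0.9351, 0.0)]
text = format_xyz(atoms, decimals=2)
char_tokens = tokenize_char(text)
ac_tokens = tokenize_atom_coord(text)
print(len(char_tokens), len(ac_tokens))
```

<p>The character-level sequence is several times longer than the atom+coordinate sequence for the same structure, while the atom+coordinate vocabulary grows with every distinct rounded coordinate string. This is the trade-off between sequence length and vocabulary size described above.</p>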
<h2 id="experiments-across-molecules-crystals-and-protein-binding-sites">Experiments Across Molecules, Crystals, and Protein Binding Sites</h2>
<h3 id="molecular-generation-zinc">Molecular Generation (ZINC)</h3>
<p>The model is evaluated on 250K commercially available molecules from the ZINC dataset, with an average of 23 heavy atoms. XYZ files are generated using RDKit&rsquo;s conformer tools. Coordinates use 2 decimal places of precision. The authors generate 10K molecules and evaluate both 3D geometry quality and standard generative metrics.</p>
<p>For 3D geometry assessment, the authors compute the root mean squared deviation (RMSD) between language model-generated conformers and RDKit-generated conformers. Most molecules have an RMSD between 1.0 and 2.0, with a heavy tail extending to 4.0.</p>
<p>Standard metrics include validity, uniqueness, novelty, and the Wasserstein (earth mover&rsquo;s) distance, abbreviated WA, between generated and training distributions of molecular properties (QED, SA score, molecular weight).</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>3D</th>
          <th>Valid (%)</th>
          <th>Unique (%)</th>
          <th>Novel (%)</th>
          <th>WA MW</th>
          <th>WA SA</th>
          <th>WA QED</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Train</td>
          <td>No</td>
          <td>100.0</td>
          <td>100.0</td>
          <td>100.0</td>
          <td>0.816</td>
          <td>0.013</td>
          <td>0.002</td>
      </tr>
      <tr>
          <td>SM-LM</td>
          <td>No</td>
          <td>98.35</td>
          <td>100.0</td>
          <td>100.0</td>
          <td>3.640</td>
          <td>0.049</td>
          <td>0.005</td>
      </tr>
      <tr>
          <td>SF-LM</td>
          <td>No</td>
          <td>100.0</td>
          <td>100.0</td>
          <td>100.0</td>
          <td>3.772</td>
          <td>0.085</td>
          <td>0.006</td>
      </tr>
      <tr>
          <td>JTVAE</td>
          <td>No</td>
          <td>100.0</td>
          <td>98.56</td>
          <td>100.0</td>
          <td>22.63</td>
          <td>0.126</td>
          <td>0.023</td>
      </tr>
      <tr>
          <td>ENF</td>
          <td>Yes</td>
          <td>1.05</td>
          <td>96.37</td>
          <td>99.72</td>
          <td>168.5</td>
          <td>1.886</td>
          <td>0.160</td>
      </tr>
      <tr>
          <td>G-SchNet</td>
          <td>Yes</td>
          <td>1.20</td>
          <td>55.96</td>
          <td>98.33</td>
          <td>152.7</td>
          <td>1.126</td>
          <td>0.185</td>
      </tr>
      <tr>
          <td>EDM</td>
          <td>Yes</td>
          <td>77.51</td>
          <td>96.40</td>
          <td>95.30</td>
          <td>101.2</td>
          <td>0.939</td>
          <td>0.093</td>
      </tr>
      <tr>
          <td>LM-CH</td>
          <td>Yes</td>
          <td>90.13</td>
          <td>100.0</td>
          <td>100.0</td>
          <td>3.912</td>
          <td>2.608</td>
          <td>0.077</td>
      </tr>
      <tr>
          <td>LM-AC</td>
          <td>Yes</td>
          <td>98.51</td>
          <td>100.0</td>
          <td>100.0</td>
          <td>1.811</td>
          <td>0.026</td>
          <td>0.004</td>
      </tr>
  </tbody>
</table>
<p>The atom+coordinate tokenization model (LM-AC) achieves 98.51% validity with 100% uniqueness and novelty. Its WA scores for molecular weight (1.811) and QED (0.004) are substantially better than all other 3D generative baselines and competitive with SMILES/SELFIES language models. The character-level model (LM-CH) at 90.13% validity performs comparably to graph-based models but falls short of the string-based language models.</p>
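<p>The uniqueness, novelty, and WA columns can be sketched in a few lines of Python. This is a simplified illustration under stated assumptions: structures are compared as canonical strings, novelty is computed over distinct generated items, and the distance helper handles only equal-sized 1D samples. The paper may use different denominators or an off-the-shelf Wasserstein implementation.</p>

```python
def uniqueness(generated):
    """Percentage of generated items that are distinct."""
    return 100.0 * len(set(generated)) / len(generated)


def novelty(generated, training):
    """Percentage of distinct generated items absent from the training set."""
    distinct = set(generated)
    return 100.0 * len(distinct - set(training)) / len(distinct)


def wasserstein_1d(a, b):
    """1-Wasserstein distance between two equal-sized 1D samples.

    For equal sample sizes this reduces to the mean absolute difference
    between the sorted values.
    """
    if len(a) != len(b):
        raise ValueError("this sketch assumes equal sample sizes")
    return sum(abs(x - y) for x, y in zip(sorted(a), sorted(b))) / len(a)


# Toy inputs, invented for illustration only.
generated = ["CCO", "CCO", "CCN", "c1ccccc1"]
training = ["CCO", "CCC"]
unique_pct = uniqueness(generated)
novel_pct = novelty(generated, training)
wa_toy = wasserstein_1d([1.0, 2.0, 3.0], [5.0, 2.0, 3.0])
```

<p>In this setting a lower WA value means the generated property distribution sits closer to the reference distribution, which is why the Train row has the smallest values in the table.</p>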
<h3 id="crystal-generation-perov-5-and-mp-20">Crystal Generation (Perov-5 and MP-20)</h3>
<p>Crystal generation uses CIF-derived sequences with 3 decimal places of precision. Two datasets are used: Perov-5 (18,928 <a href="https://en.wikipedia.org/wiki/Perovskite_(structure)">perovskite</a> materials, 5 atoms per unit cell, 56 elements) and MP-20 (45,231 diverse materials, 1-20 atoms per unit cell, 89 elements).</p>
<p>Evaluation metrics include structural validity (minimum interatomic distance &gt; 0.5 angstrom), compositional validity (charge neutrality via SMACT), coverage (recall and precision between generated and test sets), and earth mover&rsquo;s distance for density and number of unique elements.</p>
<table>
  <thead>
      <tr>
          <th>Data</th>
          <th>Model</th>
          <th>Struc. Valid (%)</th>
          <th>Comp. Valid (%)</th>
          <th>COV-R (%)</th>
          <th>COV-P (%)</th>
          <th>WA density</th>
          <th>WA elements</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Perov-5</td>
          <td>CDVAE</td>
          <td>100.0</td>
          <td>98.59</td>
          <td>99.45</td>
          <td>98.46</td>
          <td>0.126</td>
          <td>0.063</td>
      </tr>
      <tr>
          <td>Perov-5</td>
          <td>LM-CH</td>
          <td>100.0</td>
          <td>98.51</td>
          <td>99.60</td>
          <td>99.42</td>
          <td>0.071</td>
          <td>0.036</td>
      </tr>
      <tr>
          <td>Perov-5</td>
          <td>LM-AC</td>
          <td>100.0</td>
          <td>98.79</td>
          <td>98.78</td>
          <td>99.36</td>
          <td>0.089</td>
          <td>0.028</td>
      </tr>
      <tr>
          <td>MP-20</td>
          <td>CDVAE</td>
          <td>100.0</td>
          <td>86.70</td>
          <td>99.15</td>
          <td>99.49</td>
          <td>0.688</td>
          <td>1.432</td>
      </tr>
      <tr>
          <td>MP-20</td>
          <td>LM-CH</td>
          <td>84.81</td>
          <td>83.55</td>
          <td>99.25</td>
          <td>97.89</td>
          <td>0.864</td>
          <td>0.132</td>
      </tr>
      <tr>
          <td>MP-20</td>
          <td>LM-AC</td>
          <td>95.81</td>
          <td>88.87</td>
          <td>99.60</td>
          <td>98.55</td>
          <td>0.696</td>
          <td>0.092</td>
      </tr>
  </tbody>
</table>
<p>On Perov-5, both language models outperform CDVAE across most metrics. On the more diverse MP-20 dataset, LM-AC achieves the best scores on 3 of 6 metrics and remains competitive on the others. LM-CH struggles more with structural validity on MP-20 (84.81%).</p>
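<p>The structural validity criterion (minimum interatomic distance above 0.5 angstrom) can be sketched with the standard library alone. The sketch assumes fractional atomic coordinates, lattice angles in degrees, and a search over the 27 neighbouring periodic images, which may be insufficient for strongly skewed cells. The function names are hypothetical and not taken from the paper or from CDVAE.</p>

```python
import itertools
import math


def lattice_vectors(a, b, c, alpha, beta, gamma):
    """Cell vectors from lengths (angstrom) and angles (degrees), with a along x."""
    ca, cb, cg = (math.cos(math.radians(t)) for t in (alpha, beta, gamma))
    sg = math.sin(math.radians(gamma))
    cy = (ca - cb * cg) / sg
    cz = math.sqrt(max(0.0, 1.0 - cb * cb - cy * cy))
    return [(a, 0.0, 0.0), (b * cg, b * sg, 0.0), (c * cb, c * cy, c * cz)]


def to_cartesian(frac, cell):
    """Convert one fractional coordinate triple to Cartesian coordinates."""
    return tuple(sum(frac[k] * cell[k][d] for k in range(3)) for d in range(3))


def min_interatomic_distance(frac_coords, cell):
    """Smallest distance between any two atoms, including neighbouring periodic images."""
    cart = [to_cartesian(f, cell) for f in frac_coords]
    best = math.inf
    for i, j in itertools.combinations_with_replacement(range(len(cart)), 2):
        for shift in itertools.product((-1, 0, 1), repeat=3):
            if i == j and shift == (0, 0, 0):
                continue  # skip the distance from an atom to itself
            offset = to_cartesian(shift, cell)
            image = tuple(cart[j][k] + offset[k] for k in range(3))
            best = min(best, math.dist(cart[i], image))
    return best


def structurally_valid(frac_coords, cell, threshold=0.5):
    """Apply the minimum-distance criterion with a threshold in angstrom."""
    return min_interatomic_distance(frac_coords, cell) > threshold


# Idealized cubic five-atom cell, invented for illustration.
cell = lattice_vectors(4.0, 4.0, 4.0, 90.0, 90.0, 90.0)
perovskite_like = [(0, 0, 0), (0.5, 0.5, 0.5), (0.5, 0.5, 0), (0.5, 0, 0.5), (0, 0.5, 0.5)]
ok = structurally_valid(perovskite_like, cell)
```

<p>Checking periodic images matters because two atoms near opposite faces of the unit cell can be far apart in raw coordinates yet nearly overlapping across the cell boundary.</p>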
<h3 id="protein-binding-site-generation-pdb">Protein Binding Site Generation (PDB)</h3>
<p>The most challenging task involves generating protein binding sites (~200-250 atoms each) from PDB-derived sequences. The dataset contains approximately 180K protein-ligand pairs. Residue-atom tokenization is used (e.g., CYS-C, CYS-N), with 2 decimal places of precision.</p>
<p>Validity is assessed per-residue using xyz2mol, with an additional check for inter-residue atomic overlap (atoms from different residues closer than the minimum bond distance). Approximately 99% of generated pockets pass the residue validity check, while about 5% fail the overlap check. Of generated pockets, 89.8% have unique residue orderings, and 83.6% have novel orderings not seen in training, indicating the model is generating novel binding site structures rather than memorizing.</p>
<h2 id="competitive-3d-generation-without-geometric-inductive-biases">Competitive 3D Generation Without Geometric Inductive Biases</h2>
<p>The central finding is that standard transformer language models, without any equivariance or geometric inductive biases, can generate valid 3D chemical structures across three substantially different domains. The atom+coordinate tokenization (LM-AC) consistently outperforms character-level tokenization (LM-CH), likely because it produces shorter sequences and reduces the number of sequential decisions needed per atom placement.</p>
<p>Several limitations are worth noting. The model generates atoms using absolute Cartesian coordinates, which means it must learn rotation and translation invariance purely from data augmentation rather than having it built into the architecture. The authors acknowledge this becomes increasingly difficult as structure size grows. The vocabulary size also scales with coordinate precision and structure complexity, which could become prohibitive for very large systems.</p>
<p>The paper does not include computational cost comparisons with baseline models, making it difficult to assess the practical tradeoff between the simplicity of the language modeling approach and the efficiency of specialized architectures. The authors also note that further validation through computational simulation and experiment is needed to confirm the physical plausibility of generated structures.</p>
<p>Future directions identified include inverse design of molecules and materials conditioned on target properties, extension to more complex structures (metal-organic frameworks), and exploration of alternative tokenization strategies to handle larger systems.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training/Eval</td>
          <td>ZINC</td>
          <td>250K molecules</td>
          <td>~23 heavy atoms avg; XYZ files via RDKit conformer generation</td>
      </tr>
      <tr>
          <td>Training/Eval</td>
          <td>Perov-5</td>
          <td>18,928 perovskites</td>
          <td>5 atoms/unit cell, 56 elements</td>
      </tr>
      <tr>
          <td>Training/Eval</td>
          <td>MP-20</td>
          <td>45,231 materials</td>
          <td>1-20 atoms/unit cell, 89 elements</td>
      </tr>
      <tr>
          <td>Training/Eval</td>
          <td>Protein binding sites</td>
          <td>~180K protein-ligand pairs</td>
          <td>Processed to 200-250 atoms per pocket</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Architecture</strong>: GPT-style transformer with ~1M to 100M parameters</li>
<li><strong>Layers</strong>: 12</li>
<li><strong>Embedding size</strong>: 128 to 1024</li>
<li><strong>Attention heads</strong>: 4 to 12</li>
<li><strong>Batch size</strong>: 4 to 32 structures</li>
<li><strong>Learning rate</strong>: $10^{-4}$ to $10^{-5}$, decayed to $9 \times 10^{-6}$</li>
<li><strong>Data augmentation</strong>: Random rotation of training structures at each epoch</li>
<li><strong>Numerical precision</strong>: 2 decimal places (molecules, proteins), 3 decimal places (crystals)</li>
</ul>
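<p>The rotation augmentation listed above could look roughly like the following sketch. It draws a uniformly random rotation from a normalized Gaussian quaternion and re-rounds coordinates to the tokenization precision. How the authors sample rotations is not specified, so the sampling scheme and the helper names here are assumptions.</p>

```python
import math
import random


def random_rotation_matrix(rng):
    """Uniform random 3D rotation built from a normalized 4D Gaussian quaternion."""
    q = [rng.gauss(0.0, 1.0) for _ in range(4)]
    norm = math.sqrt(sum(v * v for v in q))
    w, x, y, z = (v / norm for v in q)
    return [
        [1 - 2 * (y * y + z * z), 2 * (x * y - z * w), 2 * (x * z + y * w)],
        [2 * (x * y + z * w), 1 - 2 * (x * x + z * z), 2 * (y * z - x * w)],
        [2 * (x * z - y * w), 2 * (y * z + x * w), 1 - 2 * (x * x + y * y)],
    ]


def rotate(coords, rot):
    """Apply a 3x3 rotation matrix to a list of (x, y, z) points."""
    return [tuple(sum(rot[r][k] * p[k] for k in range(3)) for r in range(3)) for p in coords]


def augment(atoms, rng, decimals=2):
    """Rotate (element, x, y, z) records and round to the tokenization precision."""
    rotated = rotate([a[1:] for a in atoms], random_rotation_matrix(rng))
    return [(a[0],) + tuple(round(v, decimals) for v in p) for a, p in zip(atoms, rotated)]


rng = random.Random(0)
# Hypothetical two-atom record used only to exercise the function.
augmented = augment([("C", 0.0, 0.0, 0.0), ("O", 1.21, 0.0, 0.0)], rng)
```

<p>Drawing a fresh rotation each epoch means the model sees the same structure as many different token sequences, which is how it can approximate rotational invariance without an equivariant architecture.</p>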
<h3 id="models">Models</h3>
<p>No pre-trained model weights are publicly available. The paper mentions &ldquo;Example code can be found at&rdquo; but the URL appears to be missing from the published version.</p>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Domain</th>
          <th>Description</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Validity</td>
          <td>Molecules</td>
          <td>xyz2mol produces valid RDKit Mol object</td>
      </tr>
      <tr>
          <td>Validity</td>
          <td>Crystals</td>
          <td>Structural (min distance &gt; 0.5 angstrom) and compositional (charge neutral)</td>
      </tr>
      <tr>
          <td>Uniqueness</td>
          <td>All</td>
          <td>Fraction of distinct generated structures</td>
      </tr>
      <tr>
          <td>Novelty</td>
          <td>All</td>
          <td>Fraction not in training set</td>
      </tr>
      <tr>
          <td>Earth mover&rsquo;s distance</td>
          <td>All</td>
          <td>Distribution match for domain-specific properties</td>
      </tr>
      <tr>
          <td>RMSD</td>
          <td>Molecules</td>
          <td>Deviation from RDKit conformer geometries</td>
      </tr>
      <tr>
          <td>Coverage</td>
          <td>Crystals</td>
          <td>Recall and precision between generated and test sets</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Models were trained on Compute Canada computing systems. Specific GPU types, counts, and training times are not reported.</p>
<h3 id="artifacts">Artifacts</h3>
<p>No public code repository, model weights, or datasets specific to this work were found. The ZINC, Perov-5, and MP-20 datasets used for evaluation are publicly available from their original sources.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Flam-Shepherd, D. &amp; Aspuru-Guzik, A. (2023). Language models can generate molecules, materials, and protein binding sites directly in three dimensions as XYZ, CIF, and PDB files. <em>arXiv preprint arXiv:2305.05708</em>.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{flamshepherd2023language,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Language models can generate molecules, materials, and protein binding sites directly in three dimensions as {XYZ}, {CIF}, and {PDB} files}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Flam-Shepherd, Daniel and Aspuru-Guzik, Al{\&#39;a}n}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{arXiv preprint arXiv:2305.05708}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>LLM4Mol: ChatGPT Captions as Molecular Representations</title><link>https://hunterheidenreich.com/notes/chemistry/llm-applications/llm4mol-captions-as-representations/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/llm-applications/llm4mol-captions-as-representations/</guid><description>LLM4Mol uses ChatGPT to generate text explanations for SMILES strings and fine-tunes RoBERTa on these captions for molecular property prediction.</description><content:encoded><![CDATA[<h2 id="llm-generated-text-as-molecular-representations">LLM-Generated Text as Molecular Representations</h2>
<p>This is a <strong>Method</strong> paper that proposes using large language models (specifically ChatGPT) to generate natural language explanations for molecules represented as SMILES strings, and then using those explanations as input representations for downstream molecular property prediction. The approach is called <strong>Captions as new Representations (CaR)</strong>. The authors also evaluate ChatGPT directly on zero-shot and few-shot molecular classification to gauge in-context learning ability on chemical data.</p>
<h2 id="bridging-molecular-data-and-natural-language-understanding">Bridging Molecular Data and Natural Language Understanding</h2>
<p>Molecular property prediction is central to <a href="https://en.wikipedia.org/wiki/Virtual_screening">virtual screening</a>, drug discovery, and materials design. Molecules are typically represented either as graphs (processed by GNNs) or as <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES strings</a> (processed by NLP-based methods). While both paradigms have shown success, they do not directly use the broad world knowledge embedded in large language models.</p>
<p>LLMs such as ChatGPT demonstrate strong capabilities in text understanding and can generate informative descriptions when given SMILES strings, including functional groups, chemical properties, and potential pharmaceutical applications. The question motivating this work is whether LLM-generated textual descriptions can serve as better molecular representations than raw SMILES or graph encodings for property prediction tasks.</p>
<p>Prior work had not systematically explored two directions: (1) whether LLMs can perform molecular classification via in-context learning, and (2) whether LLM-generated captions can serve as transferable representations for small downstream models.</p>
<h2 id="captions-as-representations-car">Captions as Representations (CaR)</h2>
<p>The core contribution is the CaR framework, which operates in two stages:</p>
<ol>
<li>
<p><strong>Caption generation</strong>: Given a molecule&rsquo;s SMILES string, ChatGPT is prompted to produce a detailed textual explanation covering functional groups, chemical properties, and potential applications.</p>
</li>
<li>
<p><strong>Fine-tuning a small LM</strong>: The generated text explanations replace the original SMILES as input to a pre-trained language model (e.g., RoBERTa). This small LM is then fine-tuned on downstream classification or regression tasks.</p>
</li>
</ol>
<p>The insight is that ChatGPT&rsquo;s world knowledge can enrich the molecular representation with semantically meaningful features that raw SMILES lack. For example, on the PTC (Predictive Toxicology Challenge) dataset, the authors performed keyword searches for terms like &ldquo;toxicity&rdquo;, &ldquo;cancer&rdquo;, and &ldquo;harmful&rdquo; in the ChatGPT-generated explanations and found that these keywords appeared predominantly in entries labeled as toxic, indicating that the generated captions carry predictive signal.</p>
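<p>The keyword analysis described above can be sketched as a per-label hit rate. The captions and labels below are invented for illustration and are not taken from the PTC dataset or from actual ChatGPT output. The function name <code>keyword_hit_rate</code> is likewise hypothetical.</p>

```python
def keyword_hit_rate(captions, labels, keywords):
    """Fraction of captions per label that mention at least one keyword."""
    hits, totals = {}, {}
    for caption, label in zip(captions, labels):
        text = caption.lower()
        totals[label] = totals.get(label, 0) + 1
        if any(keyword in text for keyword in keywords):
            hits[label] = hits.get(label, 0) + 1
    return {label: hits.get(label, 0) / totals[label] for label in totals}


# Invented captions; label 1 stands for "toxic", label 0 for "non-toxic".
captions = [
    "This compound contains a nitro group and may show toxicity.",
    "A simple alcohol commonly used as a solvent.",
    "Aromatic amine associated with cancer risk in animal studies.",
    "A sugar derivative with no notable hazards.",
]
labels = [1, 0, 1, 0]
rates = keyword_hit_rate(captions, labels, ("toxicity", "cancer", "harmful"))
```

<p>A large gap between the hit rates of the two labels is the kind of evidence the authors cite for the captions carrying predictive signal, although keyword presence alone does not show the captions are chemically accurate.</p>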
<p>The authors also explore <strong>in-context molecular classification</strong>, where ChatGPT is directly prompted with zero or few examples to classify molecules. This serves as a preliminary evaluation of LLM reasoning capabilities on molecular data.</p>
<h2 id="experimental-setup-and-benchmarks">Experimental Setup and Benchmarks</h2>
<h3 id="datasets">Datasets</h3>
<p>The evaluation spans 9 datasets across classification and regression:</p>
<ul>
<li><strong>Classification (TUDataset)</strong>: MUTAG, PTC, AIDS</li>
<li><strong>Classification (<a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a>)</strong>: SIDER, ClinTox, BACE, BBBP</li>
<li><strong>Regression (MoleculeNet)</strong>: ESOL, <a href="https://en.wikipedia.org/wiki/Lipophilicity">Lipophilicity</a></li>
</ul>
<h3 id="baselines">Baselines</h3>
<p>Baselines include GNN-based methods (GCN, GIN, ChebyNet, D-MPNN, GraphMVP, InfoGraph, G-Motif, Mole-BERT) and SMILES-based methods (ECFP4-MLP, <a href="/notes/chemistry/molecular-representations/encoders/smiles-transformer/">SMILES-Transformer</a>, MolR, <a href="/notes/chemistry/molecular-representations/encoders/chemberta/">ChemBERTa</a>, MolKD).</p>
<h3 id="splitting-strategies">Splitting Strategies</h3>
<ul>
<li><strong>Random splitting</strong>: 8/1/1 train/validate/test with 10-fold cross-validation</li>
<li><strong>Scaffold splitting</strong>: 5 random seeds, reported as mean and standard deviation</li>
</ul>
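<p>The difference between the two splitting strategies can be sketched as follows. The scaffold variant assumes each molecule already has a precomputed scaffold key (for example a Bemis-Murcko scaffold string) and assigns whole scaffold groups to a single partition in a seeded random order. The exact procedure used in the paper is not described here, so treat the group ordering and the 8/1/1 fractions for the scaffold case as assumptions.</p>

```python
import random
from collections import defaultdict


def random_split(n_items, seed, fractions=(0.8, 0.1, 0.1)):
    """Shuffle indices and cut them into train/validation/test parts."""
    indices = list(range(n_items))
    random.Random(seed).shuffle(indices)
    n_train = int(fractions[0] * n_items)
    n_valid = int(fractions[1] * n_items)
    return indices[:n_train], indices[n_train:n_train + n_valid], indices[n_train + n_valid:]


def scaffold_split(scaffolds, seed, fractions=(0.8, 0.1, 0.1)):
    """Assign whole scaffold groups to one part so no scaffold spans two parts."""
    groups = defaultdict(list)
    for index, scaffold in enumerate(scaffolds):
        groups[scaffold].append(index)
    ordered = list(groups.values())
    random.Random(seed).shuffle(ordered)
    n_train = fractions[0] * len(scaffolds)
    n_valid = fractions[1] * len(scaffolds)
    train, valid, test = [], [], []
    for group in ordered:
        if len(train) + len(group) <= n_train:
            train.extend(group)
        elif len(valid) + len(group) <= n_valid:
            valid.extend(group)
        else:
            test.extend(group)
    return train, valid, test


# Hypothetical scaffold keys for ten molecules.
scaffold_keys = ["a"] * 4 + ["b"] * 3 + ["c"] * 2 + ["d"]
rand_parts = random_split(10, seed=0)
scaf_parts = scaffold_split(scaffold_keys, seed=0)
```

<p>Because a scaffold never appears in more than one partition, test molecules are structurally distinct from training molecules, which is why scaffold-split scores are usually lower than random-split scores.</p>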
<h3 id="key-results-random-splitting">Key Results: Random Splitting</h3>
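<p>Under random splitting, CaR-RoBERTa achieves the best result among the methods shown below on four of the seven datasets (PTC, ClinTox, ESOL, and Lipophilicity), with ECFP4-MLP ahead on the other three:</p>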
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>MUTAG (ACC)</th>
          <th>PTC (ACC)</th>
          <th>AIDS (ACC)</th>
          <th>SIDER (AUC)</th>
          <th>ClinTox (AUC)</th>
          <th>ESOL (RMSE)</th>
          <th>Lipo (RMSE)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>GCN</td>
          <td>90.00</td>
          <td>62.57</td>
          <td>78.68</td>
          <td>64.24</td>
          <td>91.88</td>
          <td>0.77</td>
          <td>0.80</td>
      </tr>
      <tr>
          <td>GIN</td>
          <td>89.47</td>
          <td>58.29</td>
          <td>78.01</td>
          <td>66.19</td>
          <td>92.08</td>
          <td>0.67</td>
          <td>0.79</td>
      </tr>
      <tr>
          <td>ECFP4-MLP</td>
          <td>96.84</td>
          <td>85.71</td>
          <td>94.64</td>
          <td>90.19</td>
          <td>95.81</td>
          <td>0.60</td>
          <td>0.60</td>
      </tr>
      <tr>
          <td>CaR-RoBERTa</td>
          <td>91.05</td>
          <td>93.14</td>
          <td>94.37</td>
          <td>88.81</td>
          <td>99.80</td>
          <td>0.45</td>
          <td>0.47</td>
      </tr>
  </tbody>
</table>
<p>CaR-RoBERTa improves over the best GNN by up to 53% on PTC and reduces RMSE by 35-37% on regression tasks. However, ECFP4-MLP outperforms CaR on MUTAG (96.84 vs. 91.05) and is marginally ahead on AIDS (94.64 vs. 94.37) and SIDER (90.19 vs. 88.81).</p>
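<p>Relative improvements of this kind depend on which baseline is used as the reference. The sketch below shows the arithmetic using only the GCN and GIN rows from the table above. The quoted percentages may be computed against baselines or averages not shown in this table, so the numbers produced here need not match them.</p>

```python
def relative_gain(new, baseline):
    """Relative improvement for higher-is-better metrics, in percent."""
    return 100.0 * (new - baseline) / baseline


def relative_reduction(new, baseline):
    """Relative error reduction for lower-is-better metrics, in percent."""
    return 100.0 * (baseline - new) / baseline


# Values copied from the random-split table above.
ptc_gain_vs_gcn = relative_gain(93.14, 62.57)        # CaR vs. GCN, PTC accuracy
esol_cut_vs_gin = relative_reduction(0.45, 0.67)     # CaR vs. GIN, ESOL RMSE
lipo_cut_vs_gin = relative_reduction(0.47, 0.79)     # CaR vs. GIN, Lipophilicity RMSE
print(round(ptc_gain_vs_gcn, 1), round(esol_cut_vs_gin, 1), round(lipo_cut_vs_gin, 1))
```

<p>Against these particular rows the arithmetic gives roughly 49%, 33%, and 41%, so a comparison is only meaningful once the reference baseline is stated explicitly.</p>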
<h3 id="key-results-scaffold-splitting">Key Results: Scaffold Splitting</h3>
<p>Under the more challenging scaffold splitting:</p>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>SIDER (AUC)</th>
          <th>ClinTox (AUC)</th>
          <th>BACE (AUC)</th>
          <th>BBBP (AUC)</th>
          <th>ESOL (RMSE)</th>
          <th>Lipo (RMSE)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>GraphMVP-C</td>
          <td>63.90</td>
          <td>77.50</td>
          <td>81.20</td>
          <td>72.40</td>
          <td>1.03</td>
          <td>0.68</td>
      </tr>
      <tr>
          <td>Mole-BERT</td>
          <td>62.80</td>
          <td>78.90</td>
          <td>80.80</td>
          <td>71.90</td>
          <td>1.02</td>
          <td>0.68</td>
      </tr>
      <tr>
          <td>MolKD</td>
          <td>61.30</td>
          <td>83.80</td>
          <td>80.10</td>
          <td>74.80</td>
          <td>-</td>
          <td>-</td>
      </tr>
      <tr>
          <td>CaR-RoBERTa</td>
          <td>58.06</td>
          <td>84.16</td>
          <td>80.73</td>
          <td>81.99</td>
          <td>0.96</td>
          <td>1.02</td>
      </tr>
  </tbody>
</table>
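<p>Results are more mixed under scaffold splitting. CaR achieves the best performance on ClinTox (+30% over GNNs) and BBBP (+15%), and the lowest ESOL RMSE among the methods shown (0.96). It underperforms on SIDER and Lipophilicity and sits slightly behind the GNN baselines on BACE (80.73 vs. 81.20 for GraphMVP-C).</p>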
<h3 id="few-shot-classification-with-chatgpt">Few-Shot Classification with ChatGPT</h3>
<p>Direct few-shot classification with ChatGPT shows mixed results. On MUTAG, ChatGPT underperforms classical methods across all shot counts. On PTC, ChatGPT outperforms GNNs in the few-shot regime. Performance improves with increasing number of shots, but results are inconsistent across different prompts.</p>
<h3 id="replacing-the-small-lm">Replacing the Small LM</h3>
<p>The authors test CaR with different downstream models: RoBERTa, DeBERTa, and an adaptive language model for molecules. The pre-trained models all perform similarly, and all outperform a DeBERTa trained from scratch. This suggests that CaR&rsquo;s effectiveness comes from the caption quality combined with language pre-training, not from the specific choice of downstream model.</p>
<h2 id="findings-limitations-and-future-directions">Findings, Limitations, and Future Directions</h2>
<h3 id="key-findings">Key Findings</h3>
<ol>
<li>ChatGPT-generated text explanations serve as effective molecular representations, outperforming GNNs and SMILES-based methods on most benchmarks under random splitting.</li>
<li>ChatGPT has some capacity for few-shot molecular classification, but performance is inconsistent and prompt-sensitive.</li>
<li>The CaR approach is model-agnostic: different pre-trained small LMs achieve similar results when fine-tuned on the generated captions.</li>
<li>Under scaffold splitting, CaR shows strong results on some datasets (ClinTox, BBBP) but underperforms on others (SIDER, Lipophilicity).</li>
</ol>
<h3 id="limitations-acknowledged-by-the-authors">Limitations Acknowledged by the Authors</h3>
<ul>
<li><strong>Single LLM</strong>: Only ChatGPT was used. Other LLMs (GPT-4, domain-specific models like MolReGPT) were not evaluated.</li>
<li><strong>No graph structure integration</strong>: CaR treats molecular prediction purely as an NLP task and does not incorporate structural graph information, which is known to be important for molecular properties.</li>
<li><strong>Limited to small molecules</strong>: The approach works only for molecules representable as SMILES. Proteins, antibodies, and other large biomolecules with 3D structure are not addressed.</li>
</ul>
<h3 id="additional-considerations">Additional Considerations</h3>
<p>The random splitting results are notably strong, but random splits tend to overestimate performance compared to scaffold splits, which test generalization to structurally novel molecules. The high variance on some scaffold-split results (e.g., ClinTox with 17.63 standard deviation) suggests instability. The reliance on a proprietary API (ChatGPT) also limits reproducibility and introduces cost constraints for large-scale applications.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Classification</td>
          <td>MUTAG (TUDataset)</td>
          <td>188 molecules</td>
          <td>Mutagenicity prediction</td>
      </tr>
      <tr>
          <td>Classification</td>
          <td>PTC (TUDataset)</td>
          <td>344 molecules</td>
          <td>Predictive toxicology</td>
      </tr>
      <tr>
          <td>Classification</td>
          <td>AIDS (TUDataset)</td>
          <td>2,000 molecules</td>
          <td>HIV activity</td>
      </tr>
      <tr>
          <td>Classification</td>
          <td>SIDER (MoleculeNet)</td>
          <td>1,427 molecules</td>
          <td>Side effect prediction</td>
      </tr>
      <tr>
          <td>Classification</td>
          <td>ClinTox (MoleculeNet)</td>
          <td>1,478 molecules</td>
          <td>Clinical trial toxicity</td>
      </tr>
      <tr>
          <td>Classification</td>
          <td>BACE (MoleculeNet)</td>
          <td>1,513 molecules</td>
          <td><a href="https://en.wikipedia.org/wiki/Beta-secretase_1">Beta-secretase</a> inhibition</td>
      </tr>
      <tr>
          <td>Classification</td>
          <td>BBBP (MoleculeNet)</td>
          <td>2,039 molecules</td>
          <td>Blood-brain barrier penetration</td>
      </tr>
      <tr>
          <td>Regression</td>
          <td>ESOL (MoleculeNet)</td>
          <td>1,128 molecules</td>
          <td>Aqueous solubility</td>
      </tr>
      <tr>
          <td>Regression</td>
          <td>Lipophilicity (MoleculeNet)</td>
          <td>4,200 molecules</td>
          <td>Lipophilicity</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>ChatGPT (GPT-3.5) generates textual explanations for SMILES strings</li>
<li>RoBERTa is fine-tuned on generated captions using HuggingFace Transformers with default parameters</li>
<li>10-fold cross-validation for random split; 5 random seeds for scaffold split</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li>ChatGPT (GPT-3.5) for caption generation</li>
<li>RoBERTa-base for downstream fine-tuning (default HuggingFace parameters)</li>
<li>DeBERTa and adaptive-lm-molecules tested as alternatives</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li>Classification: accuracy (ACC) and ROC-AUC</li>
<li>Regression: RMSE</li>
<li>Mean and standard deviation reported across folds/seeds</li>
</ul>
<h3 id="hardware">Hardware</h3>
<p>Not specified in the paper.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/ChnQ/LLM4Mol">LLM4Mol</a></td>
          <td>Code</td>
          <td>Not specified</td>
          <td>Official implementation</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Qian, C., Tang, H., Yang, Z., Liang, H., &amp; Liu, Y. (2023). Can Large Language Models Empower Molecular Property Prediction? <em>arXiv preprint arXiv:2307.07443</em>. <a href="https://arxiv.org/abs/2307.07443">https://arxiv.org/abs/2307.07443</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{qian2023can,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Can Large Language Models Empower Molecular Property Prediction?}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Qian, Chen and Tang, Huayi and Yang, Zhirui and Liang, Hong and Liu, Yong}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{arXiv preprint arXiv:2307.07443}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.48550/arxiv.2307.07443}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>LLM-Prop: Predicting Crystal Properties from Text</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/property-prediction/llm-prop-crystal-property-prediction/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/property-prediction/llm-prop-crystal-property-prediction/</guid><description>LLM-Prop fine-tunes the T5 encoder on crystal text descriptions to predict band gap, volume, and other properties, outperforming GNN baselines.</description><content:encoded><![CDATA[<h2 id="text-based-crystal-property-prediction-with-llms">Text-Based Crystal Property Prediction with LLMs</h2>
<p>LLM-Prop is a <strong>Method</strong> paper that proposes using the encoder portion of <a href="https://en.wikipedia.org/wiki/T5_(language_model)">T5</a> (a general-purpose language model) fine-tuned on crystal text descriptions to predict physical and electronic properties of crystalline materials. The primary contribution is demonstrating that text-based representations of crystals, generated by Robocrystallographer, can serve as effective inputs for <a href="/notes/chemistry/molecular-design/property-prediction/">property prediction</a>, outperforming graph neural network (GNN) baselines on several tasks despite using a non-domain-specific pre-trained model with fewer parameters.</p>
<h2 id="why-text-instead-of-crystal-graphs">Why Text Instead of Crystal Graphs?</h2>
<p>Graph neural networks have been the dominant approach for crystal property prediction. Models like CGCNN, MEGNet, and ALIGNN represent crystals as graphs where atoms are nodes and bonds are edges. However, GNNs face several fundamental challenges for crystals:</p>
<ol>
<li><strong>Periodicity encoding</strong>: Crystals have repetitive unit cell arrangements that are distinct from standard molecular graphs, and GNNs struggle to encode this periodicity efficiently.</li>
<li><strong>Information incorporation</strong>: Critical structural information like bond angles, <a href="https://en.wikipedia.org/wiki/Space_group">space group</a> symmetry, and <a href="https://en.wikipedia.org/wiki/Wyckoff_positions">Wyckoff sites</a> is difficult to incorporate into graph representations.</li>
<li><strong>Expressiveness</strong>: Graphs may lack the expressiveness needed to convey complex crystal information relevant to property prediction.</li>
</ol>
<p>Meanwhile, textual descriptions of crystals (generated by tools like Robocrystallographer) naturally encode space group information, bond geometries, coordination environments, and symmetry details in human-readable form. Despite this richness, text-based approaches for crystal property prediction had been largely unexplored.</p>
<h2 id="core-innovation-t5-encoder-with-careful-fine-tuning">Core Innovation: T5 Encoder with Careful Fine-Tuning</h2>
<p>The key insight of LLM-Prop is to take a pre-trained encoder-decoder model (<a href="/notes/natural-language-processing/language-models/t5-text-to-text-transfer-transformer/">T5</a>-small) and discard the decoder entirely, using only the encoder with a linear prediction head. This design has several advantages:</p>
<ul>
<li>Dropping the decoder shrinks the network from ~60M to ~37M parameters, which frees memory for processing longer input sequences</li>
<li>Longer sequences mean more crystal information can be included</li>
<li>The encoder-only approach avoids T5&rsquo;s known weakness at regression in text-to-text format</li>
</ul>
<p>The framework applies several preprocessing strategies to the crystal text descriptions:</p>
<ol>
<li><strong>Stopword removal</strong>: Standard English stopwords are removed, except digits and symbols carrying chemical information</li>
<li><strong>Numerical token replacement</strong>: Bond distances are replaced with a <code>[NUM]</code> token and bond angles with <code>[ANG]</code>, reducing sequence length while preserving structural cues</li>
<li><strong>[CLS] token prepending</strong>: A classification token is added at the start, and its learned embedding is used as input to the prediction layer</li>
<li><strong>Label scaling</strong>: For regression tasks, targets are normalized using z-score, min-max, or log normalization</li>
</ol>
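<p>The text-side steps above (stopword removal, numerical token replacement, and [CLS] prepending) can be sketched as follows. This is a minimal illustration, not the authors&rsquo; implementation: the stopword list, the protected element symbols, and the regular expressions for distances and angles are all assumptions about how Robocrystallographer text might look.</p>

```python
import re

# Hypothetical, tiny stopword list; the paper uses a standard English list.
STOPWORDS = {"the", "is", "are", "in", "to", "a", "an", "of", "and", "all"}

# Assumed: element symbols that collide with English words must be kept.
PROTECTED = {"In", "As", "At", "Be", "He", "I", "No"}

# Assumed surface forms: "2.82 Å" for distances, "45°" or "45 degrees" for angles.
DISTANCE = re.compile(r"\d+(?:\.\d+)?\s*Å")
ANGLE = re.compile(r"\d+(?:\.\d+)?\s*(?:°|degrees)")


def preprocess(description):
    """Replace numbers with special tokens, drop stopwords, prepend [CLS]."""
    text = ANGLE.sub("[ANG]", description)
    text = DISTANCE.sub("[NUM]", text)
    kept = [
        tok for tok in text.split()
        if tok in PROTECTED or tok.lower() not in STOPWORDS
    ]
    return "[CLS] " + " ".join(kept)


print(preprocess("All Na-Cl bond lengths are 2.82 Å."))
```

<p>The protected set illustrates why the exception for chemical symbols matters: a case-insensitive stopword filter would otherwise delete indium (In) along with the preposition.</p>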
<p>The normalization schemes are defined as:</p>
<p>$$
\hat{Y}_{i}(\text{z-score}) = \frac{Y_{i} - \mu}{\sigma}
$$</p>
<p>$$
\hat{Y}_{i}(\text{min-max}) = \frac{Y_{i} - Y_{\min}}{Y_{\max} - Y_{\min}}
$$</p>
<p>$$
\hat{Y}_{i}(\text{log-norm}) = \log(Y_{i} + 1)
$$</p>
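<p>As a rough sketch, the three schemes translate directly into code. The function names are hypothetical, the population standard deviation is an assumption (the formula does not say which estimator is used), and the inverse transforms are included on the assumption that predictions are mapped back to physical units before MAE is computed.</p>

```python
import math


def zscore(values):
    """Return z-scored values and the (mean, std) needed to invert them."""
    mu = sum(values) / len(values)
    # Assumption: population standard deviation (divide by N).
    sigma = math.sqrt(sum((v - mu) ** 2 for v in values) / len(values))
    return [(v - mu) / sigma for v in values], (mu, sigma)


def minmax(values):
    """Return values scaled to [0, 1] and the (min, max) needed to invert them."""
    lo, hi = min(values), max(values)
    return [(v - lo) / (hi - lo) for v in values], (lo, hi)


def lognorm(values):
    """log(Y + 1), defined for Y greater than -1."""
    return [math.log(v + 1.0) for v in values]


def invert_zscore(scaled, stats):
    mu, sigma = stats
    return [s * sigma + mu for s in scaled]


def invert_minmax(scaled, stats):
    lo, hi = stats
    return [s * (hi - lo) + lo for s in scaled]


def invert_lognorm(scaled):
    return [math.exp(s) - 1.0 for s in scaled]


volumes = [100.0, 200.0, 300.0]  # made-up targets, for illustration only
scaled, stats = minmax(volumes)
print(scaled, invert_minmax(scaled, stats))
```

<p>In practice the statistics would be computed on the training split only and reused for validation and test targets; that convention is standard but is not stated in the text above.</p>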
<p>The tokenizer is also retrained on the crystal text corpus with a vocabulary size of 32k, and the special tokens <code>[NUM]</code>, <code>[ANG]</code>, and <code>[CLS]</code> are added to the vocabulary.</p>
<h2 id="experimental-setup-and-baselines">Experimental Setup and Baselines</h2>
<h3 id="dataset-textedge">Dataset: TextEdge</h3>
<p>The authors collected data from the <a href="https://en.wikipedia.org/wiki/Materials_Project">Materials Project</a> database (as of November 2022), yielding 144,931 crystal structure-description pairs split into 125,098 training, 9,945 validation, and 9,888 test samples. Crystal text descriptions were generated using Robocrystallographer. The dataset covers six prediction tasks:</p>
<table>
  <thead>
      <tr>
          <th>Task</th>
          <th>Type</th>
          <th>Metric</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Band gap (eV)</td>
          <td>Regression</td>
          <td>MAE (lower is better)</td>
      </tr>
      <tr>
          <td>Unit cell volume (Å<sup>3</sup>/cell)</td>
          <td>Regression</td>
          <td>MAE (lower is better)</td>
      </tr>
      <tr>
          <td>Formation energy per atom (eV/atom)</td>
          <td>Regression</td>
          <td>MAE (lower is better)</td>
      </tr>
      <tr>
          <td>Energy per atom (eV/atom)</td>
          <td>Regression</td>
          <td>MAE (lower is better)</td>
      </tr>
      <tr>
          <td>Energy above hull (eV/atom)</td>
          <td>Regression</td>
          <td>MAE (lower is better)</td>
      </tr>
      <tr>
          <td>Is-gap-direct</td>
          <td>Classification</td>
          <td>AUC (higher is better)</td>
      </tr>
  </tbody>
</table>
<h3 id="baselines">Baselines</h3>
<p>Seven baselines were compared:</p>
<ul>
<li><strong>GNN-based</strong>: CGCNN, MEGNet, ALIGNN, DeeperGATGNN</li>
<li><strong>Classic ML</strong>: XGBoost, Random Forest (on Robocrystallographer features)</li>
<li><strong>Text-based</strong>: MatBERT (domain-specific pre-trained BERT, ~110M parameters)</li>
</ul>
<p>All models were trained and evaluated on the same dataset splits for fair comparison. GNN models were retrained on the new data rather than using results from older, smaller Materials Project versions.</p>
<h3 id="main-results-llm-prop-vs-gnn-baselines">Main Results: LLM-Prop vs. GNN Baselines</h3>
<p>When using crystal text descriptions as input, LLM-Prop achieved:</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Band gap (eV)</th>
          <th>Volume (Å<sup>3</sup>/cell)</th>
          <th>FEPA (eV/atom)</th>
          <th>EPA (eV/atom)</th>
          <th>Ehull (eV/atom)</th>
          <th>Is-gap-direct (AUC)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>CGCNN</td>
          <td>0.293</td>
          <td>188.834</td>
          <td>0.046</td>
          <td>0.082</td>
          <td>0.040</td>
          <td>0.830</td>
      </tr>
      <tr>
          <td>MEGNet</td>
          <td>0.304</td>
          <td>297.948</td>
          <td>0.077</td>
          <td>0.056</td>
          <td>0.051</td>
          <td>N/A</td>
      </tr>
      <tr>
          <td>ALIGNN</td>
          <td>0.250</td>
          <td>129.580</td>
          <td>0.027</td>
          <td>0.059</td>
          <td>0.028</td>
          <td>0.678</td>
      </tr>
      <tr>
          <td>DeeperGATGNN</td>
          <td>0.291</td>
          <td>111.857</td>
          <td>0.081</td>
          <td>0.116</td>
          <td>0.045</td>
          <td>N/A</td>
      </tr>
      <tr>
          <td>LLM-Prop (Descr.)</td>
          <td><strong>0.231</strong></td>
          <td><strong>39.252</strong></td>
          <td>0.056</td>
          <td>0.067</td>
          <td>0.047</td>
          <td><strong>0.857</strong></td>
      </tr>
  </tbody>
</table>
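<p>Measured against the best GNN baseline on each task, LLM-Prop improved <a href="https://en.wikipedia.org/wiki/Band_gap">band gap</a> prediction by approximately 8% (vs. ALIGNN), volume prediction by approximately 65% (vs. DeeperGATGNN), and direct/indirect band gap classification (Is-gap-direct) by approximately 3% (vs. CGCNN). For formation energy per atom, energy per atom, and energy above hull, GNN baselines retained an advantage: ALIGNN was best on formation energy per atom and energy above hull, and MEGNet was best on energy per atom.</p>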
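<p>The quoted percentages can be checked against the table with a few lines of arithmetic. The sketch below assumes they are relative improvements over whichever GNN scores best on each task; the numbers are copied from the table above, and the helper name is hypothetical.</p>

```python
# Test-set scores from the table: MAE for band gap and volume, AUC for Is-gap-direct.
GNN = {
    "band_gap": {"CGCNN": 0.293, "MEGNet": 0.304, "ALIGNN": 0.250, "DeeperGATGNN": 0.291},
    "volume": {"CGCNN": 188.834, "MEGNet": 297.948, "ALIGNN": 129.580, "DeeperGATGNN": 111.857},
    "is_gap_direct": {"CGCNN": 0.830, "ALIGNN": 0.678},  # N/A entries omitted
}
LLM_PROP = {"band_gap": 0.231, "volume": 39.252, "is_gap_direct": 0.857}
HIGHER_IS_BETTER = {"is_gap_direct"}


def relative_gain(task):
    """Return (best GNN name, percent improvement of LLM-Prop over it)."""
    scores = GNN[task]
    if task in HIGHER_IS_BETTER:
        best = max(scores, key=scores.get)
        gain = (LLM_PROP[task] - scores[best]) / scores[best]
    else:
        best = min(scores, key=scores.get)
        gain = (scores[best] - LLM_PROP[task]) / scores[best]
    return best, 100.0 * gain


for task in GNN:
    best, gain = relative_gain(task)
    print(f"{task}: {gain:.1f}% over {best}")
```

<p>Under that reading the gains come out to roughly 7.6%, 64.9%, and 3.3%, and the strongest GNN differs by task: ALIGNN for band gap, DeeperGATGNN for volume, and CGCNN for Is-gap-direct.</p>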
<h3 id="llm-prop-vs-matbert">LLM-Prop vs. MatBERT</h3>
<p>LLM-Prop also outperformed MatBERT (a domain-specific pre-trained BERT) across all tasks despite having roughly 3x fewer parameters. The table below shows the best result for each model across the three input preprocessing strategies (w/ Numbers, w/o Numbers, w/ [NUM]&amp;[ANG]):</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Band gap (eV)</th>
          <th>Volume (Å<sup>3</sup>/cell)</th>
          <th>FEPA (eV/atom)</th>
          <th>EPA (eV/atom)</th>
          <th>Ehull (eV/atom)</th>
          <th>Is-gap-direct (AUC)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MatBERT (best)</td>
          <td>0.258</td>
          <td>54.969</td>
          <td>0.071</td>
          <td>0.098</td>
          <td>0.050</td>
          <td>0.722</td>
      </tr>
      <tr>
          <td>LLM-Prop (best)</td>
          <td><strong>0.231</strong></td>
          <td><strong>39.138</strong></td>
          <td><strong>0.056</strong></td>
          <td><strong>0.067</strong></td>
          <td><strong>0.047</strong></td>
          <td><strong>0.857</strong></td>
      </tr>
  </tbody>
</table>
<p>Note: LLM-Prop&rsquo;s best band gap (0.231) comes from the &ldquo;w/o Numbers&rdquo; configuration, while the best volume (39.138) comes from &ldquo;w/ Numbers&rdquo;. The best Is-gap-direct AUC (0.857) uses the &ldquo;[NUM]&amp;[ANG]&rdquo; configuration.</p>
<h3 id="ablation-studies">Ablation Studies</h3>
<p>The contribution of each preprocessing strategy was evaluated:</p>
<table>
  <thead>
      <tr>
          <th>Configuration</th>
          <th>Band gap</th>
          <th>Volume</th>
          <th>Is-gap-direct (AUC)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>LLM-Prop (baseline)</td>
          <td>0.256</td>
          <td>69.352</td>
          <td>0.796</td>
      </tr>
      <tr>
          <td>+ modified tokenizer</td>
          <td>0.247</td>
          <td>78.632</td>
          <td>0.785</td>
      </tr>
      <tr>
          <td>+ label scaling</td>
          <td>0.242</td>
          <td>44.515</td>
          <td>N/A</td>
      </tr>
      <tr>
          <td>+ [CLS] token</td>
          <td>0.231</td>
          <td>39.520</td>
          <td>0.842</td>
      </tr>
      <tr>
          <td>+ [NUM] token</td>
          <td>0.251</td>
          <td>86.090</td>
          <td>0.793</td>
      </tr>
      <tr>
          <td>+ [ANG] token</td>
          <td>0.242</td>
          <td>64.965</td>
          <td>0.810</td>
      </tr>
      <tr>
          <td>- stopwords</td>
          <td>0.252</td>
          <td>56.593</td>
          <td>0.779</td>
      </tr>
      <tr>
          <td>LLM-Prop+all (no space group)</td>
          <td>0.235</td>
          <td>97.457</td>
          <td>0.705</td>
      </tr>
      <tr>
          <td>LLM-Prop+all</td>
          <td><strong>0.229</strong></td>
          <td>42.259</td>
          <td><strong>0.857</strong></td>
      </tr>
  </tbody>
</table>
<p>The [CLS] token provided the single largest improvement across all tasks. Label scaling was critical for volume prediction (reducing MAE from 69.352 to 44.515). Removing space group information from the descriptions more than doubled the volume MAE (from 42.259 to 97.457) and lowered the Is-gap-direct AUC from 0.857 to 0.705, indicating that space group symmetry is a key input for these tasks.</p>
<h3 id="data-efficiency-and-transfer-learning">Data Efficiency and Transfer Learning</h3>
<p>LLM-Prop achieved state-of-the-art results on band gap and volume prediction with only about 90k training samples, roughly 35k fewer than the full training set used for the baselines. For volume prediction specifically, LLM-Prop outperformed all GNN baselines with just 30k training samples.</p>
<p>Transfer learning experiments showed that LLM-Prop transferred well between band gap and volume prediction tasks:</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Volume-to-Band gap (Test)</th>
          <th>Band gap-to-Volume (Test)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>CGCNN-transfer</td>
          <td>0.295</td>
          <td>182.997</td>
      </tr>
      <tr>
          <td>ALIGNN-transfer</td>
          <td>0.322</td>
          <td>136.164</td>
      </tr>
      <tr>
          <td>MatBERT-transfer</td>
          <td>0.266</td>
          <td>54.289</td>
      </tr>
      <tr>
          <td>LLM-Prop-transfer</td>
          <td><strong>0.244</strong></td>
          <td><strong>50.753</strong></td>
      </tr>
  </tbody>
</table>
<h2 id="key-findings-limitations-and-future-directions">Key Findings, Limitations, and Future Directions</h2>
<p><strong>Key findings</strong>:</p>
<ul>
<li>Text descriptions of crystals carry rich structural information (space groups, Wyckoff sites, coordination geometries) that is difficult to encode in graphs but naturally expressed in text</li>
<li>A carefully fine-tuned general-purpose LLM encoder can outperform domain-specific pre-trained models, challenging the assumption that in-domain pre-training is always necessary</li>
<li>Removing numerical information (bond distances and angles) from descriptions often improves performance, because current LLMs treat numbers as regular tokens without understanding their quantitative meaning</li>
<li>Longer input sequences correlate with better performance, with 888 tokens as the default maximum on the hardware used</li>
</ul>
<p><strong>Limitations acknowledged by the authors</strong>:</p>
<ul>
<li>The origin of LLM-Prop&rsquo;s performance advantage over GNNs is not fully understood. It remains unclear whether the boost comes from additional structured information in text or from the different data modality itself</li>
<li>LLM-Prop cannot perform zero-shot predictions since T5 was not pre-trained on materials science data</li>
<li>The approach depends on Robocrystallographer to generate text descriptions, adding a preprocessing dependency</li>
<li>Current LLMs&rsquo; inability to reason about numerical values limits the use of quantitative information in descriptions</li>
</ul>
<p><strong>Future directions</strong> suggested by the authors include investigating techniques to use <a href="/notes/chemistry/molecular-design/generation/autoregressive/3d-chemical-language-models-xyz-cif-pdb/">CIF files</a> directly as LLM inputs, developing new GNN architectures that incorporate space group and Wyckoff site information, and further exploring which information in crystal descriptions contributes most to each property prediction task.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training/Eval</td>
          <td>TextEdge</td>
          <td>144,931 crystals</td>
          <td>From Materials Project (Nov 2022), text generated by Robocrystallographer</td>
      </tr>
      <tr>
          <td>Training split</td>
          <td>TextEdge</td>
          <td>125,098</td>
          <td>Random split</td>
      </tr>
      <tr>
          <td>Validation split</td>
          <td>TextEdge</td>
          <td>9,945</td>
          <td>Random split</td>
      </tr>
      <tr>
          <td>Test split</td>
          <td>TextEdge</td>
          <td>9,888</td>
          <td>Random split</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Optimizer</strong>: Adam with one-cycle learning rate scheduler</li>
<li><strong>Learning rate</strong>: 1e-3 for LLM-Prop, 5e-5 for MatBERT</li>
<li><strong>Dropout</strong>: 0.2 for LLM-Prop, 0.5 for MatBERT</li>
<li><strong>Batch size</strong>: 64 (888 tokens) or 16 (2000 tokens) for LLM-Prop</li>
<li><strong>Epochs</strong>: 200-300 depending on task</li>
<li><strong>Loss</strong>: MAE for regression, BCE for classification</li>
<li><strong>Evaluation</strong>: MAE for regression, AUC for classification</li>
<li><strong>Runs</strong>: each model is run 5 times on the test set, and the averaged MAE is reported</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li><strong>Base model</strong>: T5-small encoder (~60M parameters total, ~37M after discarding decoder and adding prediction head)</li>
<li><strong>Vocabulary size</strong>: 32k (retrained tokenizer)</li>
<li><strong>Max input tokens</strong>: 888 (default) or 2000</li>
<li><strong>Special tokens</strong>: [CLS], [NUM], [ANG]</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/vertaix/LLM-Prop">LLM-Prop</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official implementation</td>
      </tr>
      <tr>
          <td><a href="https://drive.google.com/drive/folders/1YCDBzwjwNRIc1FRkB662G3Y5AOWaokUG">TextEdge + Checkpoints</a></td>
          <td>Dataset + Model</td>
          <td>Not specified</td>
          <td>Benchmark dataset and trained model checkpoints</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>GPUs</strong>: NVIDIA RTX A6000</li>
<li><strong>Training time</strong>: ~40 minutes per epoch for LLM-Prop</li>
<li><strong>Inference</strong>: ~1 minute for 10,000 materials on one GPU</li>
</ul>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Rubungo, A. N., Arnold, C. B., Rand, B. P., &amp; Dieng, A. B. (2025). LLM-Prop: predicting the properties of crystalline materials using large language models. <em>npj Computational Materials</em>, 11, 186. <a href="https://doi.org/10.1038/s41524-025-01536-2">https://doi.org/10.1038/s41524-025-01536-2</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{rubungo2025llmprop,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{LLM-Prop: predicting the properties of crystalline materials using large language models}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Rubungo, Andre Niyongabo and Arnold, Craig B. and Rand, Barry P. and Dieng, Adji Bousso}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{npj Computational Materials}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{11}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{186}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2025}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Nature Publishing Group}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1038/s41524-025-01536-2}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Inverse Molecular Design with ML Generative Models</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/evaluation/inverse-molecular-design-ml-review/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/evaluation/inverse-molecular-design-ml-review/</guid><description>Review of inverse molecular design approaches including VAEs, GANs, and RL for navigating chemical space and generating novel molecules with desired properties.</description><content:encoded><![CDATA[<h2 id="a-foundational-systematization-of-inverse-molecular-design">A Foundational Systematization of Inverse Molecular Design</h2>
<p>This paper is a <strong>Systematization</strong> of the nascent field of inverse molecular design using machine learning generative models. Published in <em>Science</em> in 2018, it organizes and contextualizes the rapidly emerging body of work on using deep generative models (variational autoencoders, generative adversarial networks, and reinforcement learning) to navigate chemical space and propose novel molecules with targeted properties. Rather than introducing a new method, the paper synthesizes the conceptual framework connecting molecular representations, generative architectures, and inverse design objectives, establishing a reference point for the field at a critical early stage.</p>
<h2 id="the-challenge-of-navigating-chemical-space">The Challenge of Navigating Chemical Space</h2>
<p>The core problem is the sheer scale of chemical space. For pharmacologically relevant small molecules alone, the number of possible structures is estimated at $10^{60}$. Traditional approaches to materials discovery rely on trial and error or high-throughput virtual screening (HTVS), both of which are fundamentally limited by the need to enumerate and evaluate candidates from a predefined library.</p>
<p>The conventional materials discovery pipeline, from concept to commercial product, historically takes 15 to 20 years, involving iterative cycles of simulation, synthesis, device integration, and characterization. Inverse design offers a conceptual alternative: start from a desired functionality and search for molecular structures that satisfy it. This inverts the standard paradigm where a molecule is proposed first and its properties are computed or measured afterward.</p>
<p>The key distinction the authors draw is between discriminative and generative models. A discriminative model learns $p(y|x)$, the conditional probability of properties $y$ given a molecule $x$. A <a href="/notes/machine-learning/generative-models/">generative model</a> instead learns the joint distribution $p(x,y)$, which can be conditioned to yield either the direct design problem $p(y|x)$ or the inverse design problem $p(x|y)$.</p>
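<p>The link between the two conditionals follows from the product rule. This is standard probability, not a derivation taken from the review, and writing the normalizer as a sum over candidate molecules $x'$ assumes a discrete molecular representation:</p>
<p>$$
p(x|y) = \frac{p(x,y)}{p(y)} = \frac{p(y|x)\,p(x)}{\sum_{x'} p(y|x')\,p(x')}
$$</p>
<p>The denominator ranges over all of chemical space, which suggests why inverse design is approached by sampling from a learned model instead of by enumerating candidates.</p>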
<h2 id="three-pillars-vaes-gans-and-reinforcement-learning">Three Pillars: VAEs, GANs, and Reinforcement Learning</h2>
<p>The review organizes inverse molecular design approaches around three generative paradigms and the molecular representations they operate on.</p>
<h3 id="molecular-representations">Molecular Representations</h3>
<p>The paper surveys representations across three broad categories:</p>
<ul>
<li><strong>Discrete (text-based)</strong>: <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> strings encode molecular structure as 1D text following a grammar syntax. Their adoption has been driven by the availability of NLP deep learning tools.</li>
<li><strong>Continuous (vectors/tensors)</strong>: <a href="/posts/molecular-descriptor-coulomb-matrix/">Coulomb matrices</a>, bag of bonds, fingerprints, symmetry functions, and electronic density representations. These expose different physical symmetries (permutational, rotational, reflectional, translational invariance).</li>
<li><strong>Weighted graphs</strong>: Molecules as undirected graphs where atoms are nodes and bonds are edges, with vectorized features on edges and nodes (bonding type, aromaticity, charge, distance).</li>
</ul>
<p>An ideal representation for inverse design should be invertible, meaning it supports mapping back to a synthesizable molecular structure. SMILES strings and molecular graphs are invertible, while many continuous representations require lookup tables or auxiliary methods.</p>
<h3 id="variational-autoencoders-vaes">Variational Autoencoders (VAEs)</h3>
<p><a href="/notes/machine-learning/generative-models/autoencoding-variational-bayes/">VAEs</a> encode molecules into a continuous latent space and decode latent vectors back to molecular representations. The key insight is that by constraining the encoder to produce latent vectors following a Gaussian distribution, the model gains the ability to <a href="/posts/modern-variational-autoencoder-in-pytorch/">interpolate between molecules and sample novel structures</a>. The latent space encodes a geometry: nearby points decode to similar molecules, and gradient-based optimization over this continuous space enables direct property optimization.</p>
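<p>The VAE objective is the evidence lower bound (ELBO), which training maximizes (equivalently, the loss is its negative). It combines a reconstruction term with a KL divergence regularizer:</p>
<p>$$\mathcal{L} = \mathbb{E}_{q(z|x)}[\log p(x|z)] - D_{KL}(q(z|x) \,\|\, p(z))$$</p>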
<p>where $q(z|x)$ is the encoder (approximate posterior), $p(x|z)$ is the decoder, and $p(z)$ is the prior (typically Gaussian).</p>
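<p>For the common case of a diagonal Gaussian encoder and a standard normal prior, the KL term in the expression above has a closed form. The sketch below assumes exactly that setup, which the review describes only as &ldquo;typically Gaussian&rdquo;; the function names are hypothetical and the reconstruction log-likelihood is passed in as a plain number.</p>

```python
import math
import random


def kl_to_standard_normal(mu, log_var):
    """KL( N(mu, diag(exp(log_var))) || N(0, I) ), summed over latent dimensions."""
    return 0.5 * sum(m * m + math.exp(lv) - 1.0 - lv for m, lv in zip(mu, log_var))


def reparameterize(mu, log_var, rng):
    """Sample z = mu + sigma * eps with eps ~ N(0, 1), one draw per dimension."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, log_var)]


def elbo(log_likelihood, mu, log_var):
    """Reconstruction term minus the KL regularizer, matching the expression above."""
    return log_likelihood - kl_to_standard_normal(mu, log_var)


rng = random.Random(0)
mu, log_var = [1.0, 0.0], [0.0, 0.0]
z = reparameterize(mu, log_var, rng)      # latent vector that a decoder would consume
print(kl_to_standard_normal(mu, log_var))  # 0.5: only the shifted dimension contributes
print(elbo(-2.0, mu, log_var))             # -2.5 for a made-up log-likelihood of -2.0
```

<p>The KL term vanishes only when the encoder output matches the prior exactly, which is what pulls encoded molecules toward a shared, smooth latent space.</p>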
<p>Semi-supervised variants jointly train on molecules and properties, reorganizing latent space so molecules with similar properties cluster together. <a href="/notes/chemistry/molecular-design/generation/latent-space/automatic-chemical-design-vae/">Gomez-Bombarelli et al.</a> demonstrated local and global optimization across generated distributions using Bayesian optimization over latent space.</p>
<p>The review traces the evolution from character-level SMILES VAEs to <a href="/notes/chemistry/molecular-design/generation/latent-space/grammar-variational-autoencoder/">grammar-aware and syntax-directed variants</a> that improve the generation of syntactically valid structures.</p>
<h3 id="generative-adversarial-networks-gans">Generative Adversarial Networks (GANs)</h3>
<p><a href="/posts/what-is-a-gan/">GANs</a> pit a generator against a discriminator in an adversarial training framework. The generator learns to produce synthetic molecules from noise, while the discriminator learns to distinguish synthetic from real molecules. Training convergence for GANs is challenging, suffering from mode collapse and generator-discriminator imbalance.</p>
<p>For molecular applications, dealing with discrete SMILES data introduces nondifferentiability, addressed through workarounds like SeqGAN&rsquo;s policy gradient approach and boundary-seeking GANs.</p>
<h3 id="reinforcement-learning-rl">Reinforcement Learning (RL)</h3>
<p>RL treats molecule generation as a sequential decision process where an agent (the generator) takes actions (adding characters to a SMILES string) to maximize a reward (desired molecular properties). Since rewards can only be assigned after sequence completion, Monte Carlo Tree Search (MCTS) is used to simulate possible completions and weight paths based on their success.</p>
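<p>The role of simulated completions can be shown with a toy example. This sketch uses plain Monte Carlo rollouts to score partial sequences. It is a simplification of MCTS (no search tree, no visit statistics), and the three-letter alphabet and the reward are hypothetical stand-ins for SMILES tokens and a property predictor.</p>

```python
import random

VOCAB = ["C", "N", "O"]  # hypothetical toy alphabet, not real SMILES
MAX_LEN = 6              # fixed sequence length, for simplicity


def reward(sequence):
    """Placeholder property score: fraction of carbon characters."""
    return sequence.count("C") / len(sequence)


def rollout_value(prefix, rng, n_rollouts=200):
    """Estimate the value of a prefix by averaging rewards of random completions."""
    total = 0.0
    for _ in range(n_rollouts):
        seq = prefix
        while len(seq) < MAX_LEN:
            seq += rng.choice(VOCAB)
        total += reward(seq)  # the reward is only available once the sequence is complete
    return total / n_rollouts


def greedy_generate(rng):
    """Build a sequence token by token, picking the action with the best rollout value."""
    seq = ""
    while len(seq) < MAX_LEN:
        seq += max(VOCAB, key=lambda tok: rollout_value(seq + tok, rng))
    return seq


print(greedy_generate(random.Random(0)))
```

<p>Full MCTS additionally keeps per-node statistics so that promising branches are revisited more often, and in the molecular setting the random completion policy would typically be replaced by the generator itself.</p>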
<p>Applications include generation of drug-like molecules and <a href="https://en.wikipedia.org/wiki/Retrosynthesis">retrosynthesis</a> planning. Notable examples cited include RL for optimizing putative <a href="https://en.wikipedia.org/wiki/Janus_kinase_2">JAK2</a> inhibitors and molecules active against <a href="https://en.wikipedia.org/wiki/Dopamine_receptor_D2">dopamine receptor type 2</a>.</p>
<h3 id="hybrid-approaches">Hybrid Approaches</h3>
<p>The review highlights that these paradigms are not exclusive. Examples include druGAN (adversarial autoencoder) and <a href="/notes/chemistry/molecular-design/generation/rl-tuned/organ-objective-reinforced-gan/">ORGANIC</a> (combined GAN and RL), which leverage strengths of multiple frameworks.</p>
<h2 id="survey-of-applications-and-design-paradigms">Survey of Applications and Design Paradigms</h2>
<p>As a review, the paper presents no new experiments; it surveys existing applications across domains:</p>
<p><strong>Drug Discovery</strong>: Most generative model applications at the time of writing targeted pharmaceutical properties, including solubility, melting temperature, synthesizability, and target activity. Popova et al. optimized for JAK2 inhibitors, and Olivecrona et al. targeted dopamine receptor type 2.</p>
<p><strong>Materials Science</strong>: HTVS had been applied to organic photovoltaics (screening by frontier orbital energies and conversion efficiency), organic redox flow batteries (redox potential and solubility), organic LEDs (singlet-triplet gap), and inorganic materials via the Materials Project.</p>
<p><strong>Chemical Space Exploration</strong>: Evolution strategies had been applied to map chemical space, with structured search procedures incorporating genotype representations and mutation operations. Bayesian sampling with sequential Monte Carlo and gradient-based optimization of properties with respect to molecular systems represented alternative inverse design strategies.</p>
<p><strong>Graph-Based Generation</strong>: The paper notes the emerging extension of VAEs to molecular graphs (junction tree VAE) and message passing networks for incremental graph construction, though the graph isomorphism approximation problem remained a practical challenge.</p>
<h2 id="future-directions-and-open-challenges">Future Directions and Open Challenges</h2>
<p>The authors identify several open directions for the field:</p>
<p><strong>Closed-Loop Discovery</strong>: The ultimate goal is to concurrently propose, create, and characterize new materials with simultaneous data flow between components. At the time of writing, very few examples of successful closed-loop approaches existed.</p>
<p><strong>Active Learning</strong>: Combining inverse design with Bayesian optimization enables models that adapt as they explore chemical space, expanding in regions of high uncertainty and discovering molecular regions with desirable properties as a function of composition.</p>
<p><strong>Representation Learning</strong>: No single molecular representation works optimally for all properties. Graph and hierarchical representations were identified as areas needing further study. Representations that encode relevant physics tend to generalize better.</p>
<p><strong>Improved Architectures</strong>: Memory-augmented sequence generation models, Riemannian optimization methods exploiting latent space geometry, multi-level VAEs for structured latent spaces, and inverse RL for learning reward functions were highlighted as promising research directions.</p>
<p><strong>Integration into Education</strong>: The authors advocate for integrating ML into curricula across chemical, biochemical, medicinal, and materials sciences.</p>
<h3 id="limitations">Limitations</h3>
<p>As a review paper from 2018, this work captures the field at an early stage. Several limitations are worth noting:</p>
<ul>
<li>The survey is dominated by SMILES-based approaches, reflecting the state of the field at the time. Graph-based and 3D-aware generative models were just emerging.</li>
<li>Quantitative benchmarking of generative models was not yet standardized. The review does not provide systematic comparisons across methods.</li>
<li>The synthesis feasibility of generated molecules receives limited attention. The gap between computationally generated candidates and experimentally realizable molecules was (and remains) a significant challenge.</li>
<li>Transformer-based architectures, which would come to dominate chemical language modeling, are not discussed, as the Transformer had only been published the year prior.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<p>As a review/perspective paper, this work does not introduce new models, datasets, or experiments. The reproducibility assessment applies to the cited primary works rather than the review itself.</p>
<h3 id="key-cited-methods-and-their-resources">Key Cited Methods and Their Resources</h3>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>Authors</th>
          <th>Type</th>
          <th>Availability</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="/notes/chemistry/molecular-design/generation/latent-space/automatic-chemical-design-vae/">Automatic Chemical Design (VAE)</a></td>
          <td>Gomez-Bombarelli et al.</td>
          <td>Code + Data</td>
          <td>Published in ACS Central Science</td>
      </tr>
      <tr>
          <td><a href="/notes/chemistry/molecular-design/generation/latent-space/grammar-variational-autoencoder/">Grammar VAE</a></td>
          <td>Kusner et al.</td>
          <td>Code</td>
          <td>arXiv:1703.01925</td>
      </tr>
      <tr>
          <td>Junction Tree VAE</td>
          <td>Jin et al.</td>
          <td>Code</td>
          <td>arXiv:1802.04364</td>
      </tr>
      <tr>
          <td><a href="/notes/chemistry/molecular-design/generation/rl-tuned/organ-objective-reinforced-gan/">ORGANIC</a></td>
          <td>Sanchez-Lengeling et al.</td>
          <td>Code</td>
          <td>ChemRxiv preprint</td>
      </tr>
      <tr>
          <td>SeqGAN</td>
          <td>Yu et al.</td>
          <td>Code</td>
          <td>AAAI 2017</td>
      </tr>
      <tr>
          <td>Neural Message Passing</td>
          <td>Gilmer et al.</td>
          <td>Code</td>
          <td>arXiv:1704.01212</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Sánchez-Lengeling, B., &amp; Aspuru-Guzik, A. (2018). Inverse molecular design using machine learning: Generative models for matter engineering. <em>Science</em>, 361(6400), 360-365. <a href="https://doi.org/10.1126/science.aat2663">https://doi.org/10.1126/science.aat2663</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{sanchez-lengeling2018inverse,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Inverse molecular design using machine learning: Generative models for matter engineering}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{S{\&#39;a}nchez-Lengeling, Benjamin and Aspuru-Guzik, Al{\&#39;a}n}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Science}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{361}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{6400}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{360--365}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2018}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{American Association for the Advancement of Science}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1126/science.aat2663}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Group SELFIES: Fragment-Based Molecular Strings</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/notations/group-selfies-fragment-molecular-representation/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/notations/group-selfies-fragment-molecular-representation/</guid><description>Group SELFIES extends SELFIES with fragment-based group tokens for chemically robust molecular string representations that improve distribution learning.</description><content:encoded><![CDATA[<h2 id="a-fragment-aware-extension-of-selfies">A Fragment-Aware Extension of SELFIES</h2>
<p>This is a <strong>Method</strong> paper that introduces Group SELFIES, a molecular string representation that extends SELFIES with group tokens representing functional groups or entire substructures. The primary contribution is a representation that keeps the 100% chemical validity guarantee of SELFIES while enabling fragment-level molecular encoding. Group SELFIES strings are shorter and more human-readable than their SMILES and SELFIES counterparts, and they improve distribution learning.</p>
<h2 id="from-atoms-to-fragments-in-molecular-strings">From Atoms to Fragments in Molecular Strings</h2>
<p>Molecular string representations underpin nearly all string-based molecular generation, from chemical language models and VAEs to genetic algorithms. SMILES, the dominant representation, suffers from validity issues: generated strings frequently contain syntax errors or violate valency constraints. SELFIES solved this by guaranteeing that every string decodes to a valid molecule, but both SMILES and SELFIES operate at the atomic level. Human chemists, by contrast, think about molecules in terms of functional groups and substructures.</p>
<p>Fragment-based generative models exploit this inductive bias by constructing custom representations amenable to fragment-based molecular design. However, these approaches are typically graph-based, losing the desirable properties of string representations: easy manipulation and direct input into established language models. Historical string representations like Wiswesser Line Notation (WLN), Hayward Notation, and SYBYL Line Notation (SLN) did use non-atomic tokens, but none provided chemical robustness guarantees.</p>
<p>The gap is clear: no existing string representation combines the chemical robustness of SELFIES with the fragment-level abstraction that captures meaningful chemical motifs.</p>
<h2 id="group-tokens-with-chemical-robustness-guarantees">Group Tokens with Chemical Robustness Guarantees</h2>
<p>The core innovation is the introduction of <strong>group tokens</strong> into the SELFIES framework. Each group token represents a predefined molecular fragment (such as a benzene ring, carboxyl group, or any user-specified substructure) and is treated as a single unit during encoding and decoding.</p>
<h3 id="group-definition">Group Definition</h3>
<p>Each group is defined as a set of atoms and bonds with labeled <strong>attachment points</strong> that specify how the group participates in bonding. Each attachment point has a specified maximum valency, allowing the decoder to continue tracking available valency during string construction. Group tokens take the form <code>[:S&lt;group-name&gt;]</code>, where <code>S</code> is the starting attachment index.</p>
<h3 id="encoding">Encoding</h3>
<p>To encode a molecule, the encoder first recognizes and replaces substructure matches from the group set. By default, the encoder processes larger groups first, but users can override this with priority values. The encoder then traverses the molecular graph similarly to standard SELFIES encoding, inserting tokens that track attachment indices for entering and exiting groups.</p>
<h3 id="decoding">Decoding</h3>
<p>When the decoder encounters a group token, it looks up the corresponding group in the group set dictionary, places all atoms of the group, and connects the main chain to the starting attachment point. Navigation between attachment points is handled by reading subsequent tokens as relative indices. If an attachment point is occupied, the next available one is used. If all attachment points are exhausted, the group is immediately popped from the stack.</p>
<h3 id="chemical-robustness">Chemical Robustness</h3>
<p>The key property preserved from SELFIES is that <strong>any arbitrary Group SELFIES string decodes to a molecule with valid valency</strong>. This is achieved by maintaining the same two SELFIES decoder features within the group framework:</p>
<ol>
<li>Token overloading: every token can be interpreted as a number when needed (for branch lengths, ring targets, or attachment indices).</li>
<li>Valency tracking: if adding a bond would exceed available valency, the decoder adjusts the bond order or skips the bond.</li>
</ol>
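<p>The valency-tracking rule can be sketched on a toy linear chain. This is not the Group SELFIES decoder: it has no branches, rings, or groups, and the <code>MAX_VALENCE</code> table, the <code>decode_chain</code> name, and the choice to stop once an atom is saturated are simplifying assumptions made for illustration.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">MAX_VALENCE = {"C": 4, "N": 3, "O": 2, "F": 1}  # simplified neutral valences


def decode_chain(tokens):
    # tokens: list of (requested_bond_order, element) pairs for a linear chain.
    # Returns the (bond_order, element) pairs that were actually placed.
    placed = []
    remaining = 0  # free valence left on the previous atom
    for order, element in tokens:
        cap = MAX_VALENCE[element]
        if not placed:
            # The first atom has no incoming bond.
            placed.append((0, element))
            remaining = cap
            continue
        # Lower the bond order to what both atoms can still accept.
        order = min(order, remaining, cap)
        if order == 0:
            break  # previous atom is saturated: stop rather than emit an invalid bond
        placed.append((order, element))
        remaining = cap - order
    return placed
</code></pre></div>
<p>Under these assumptions, a request for a triple bond to oxygen is lowered to a double bond, and any token after a saturated atom is dropped, so every input sequence yields a chain with valid valency.</p>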
<p>The authors verified robustness by encoding and decoding 25 million molecules from the eMolecules database.</p>
<h3 id="chirality-handling">Chirality Handling</h3>
<p>Group SELFIES handles chirality differently from SMILES and SELFIES. Rather than using <code>@</code>-notation for tetrahedral chirality, all chiral centers must be specified as groups. An &ldquo;essential set&rdquo; of 23 groups covers all relevant chiral centers in the eMolecules database. This approach also supports extended chirality (axial, helical, planar) by abstracting the entire chiral substructure into a group token.</p>
<h3 id="fragment-selection">Fragment Selection</h3>
<p>The group set is a user-defined dictionary that maps group names to molecular fragments. Users can specify groups manually using SMILES-like syntax, extract them from fragment libraries, or use fragmentation algorithms such as matched molecular pair analysis. The authors tested several approaches, including a naive method that cleaves side chains from rings and methods based on cheminformatics fragmentation tools. A useful group set typically contains fragments that appear in many molecules and replace many atoms, with similar fragments merged to reduce redundancy.</p>
<h2 id="experiments-on-compactness-generation-and-distribution-learning">Experiments on Compactness, Generation, and Distribution Learning</h2>
<h3 id="compactness-section-41">Compactness (Section 4.1)</h3>
<p>Using 53 groups (30 extracted from ZINC-250k plus 23 from the essential set), Group SELFIES strings are shorter than their SMILES and SELFIES equivalents. Despite Group SELFIES having a larger alphabet, the compressed file size of the ZINC-250k dataset is smallest for Group SELFIES, indicating lower information-theoretic complexity.</p>
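<p>A compressed-size comparison of this kind can be sketched with a general-purpose compressor. The snippet below assumes DEFLATE via Python&rsquo;s <code>zlib</code>, which may differ from the compressor used in the paper, and the two toy corpora are illustrative strings, not data from ZINC-250k.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import zlib


def compressed_size(strings):
    # Bytes needed after DEFLATE compression of the newline-joined corpus.
    blob = "\n".join(strings).encode("utf-8")
    return len(zlib.compress(blob, 9))


# Hypothetical toy corpora: benzene written atom by atom and as one group-style token.
atom_level = ["[C][=C][C][=C][C][=C][Ring1][=Branch1]"] * 50
group_level = ["[:0benzene]"] * 50

sizes = {
    "atom_level": compressed_size(atom_level),
    "group_level": compressed_size(group_level),
}
</code></pre></div>
<p>Compressed size is a useful proxy because it accounts for alphabet size and repetition together, so a representation cannot look compact simply by hiding complexity in a larger vocabulary.</p>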
<h3 id="random-molecular-generation-section-42">Random Molecular Generation (Section 4.2)</h3>
<p>To isolate the effect of the representation from the generative model, the authors use a primitive generative model: sample a random string length from the dataset, draw tokens uniformly from a bag of all tokens, and concatenate. Results using 100,000 molecules from ZINC-250k:</p>
<ul>
<li>Randomly sampled Group SELFIES strings produce molecules whose SAScore and QED distributions more closely overlap with the original ZINC dataset than molecules from randomly sampled SELFIES strings.</li>
<li>The Wasserstein distances to the ZINC distribution are consistently lower for Group SELFIES.</li>
<li>On a nonfullerene acceptor (NFA) dataset, Group SELFIES preserves aromatic rings while SELFIES rarely does.</li>
</ul>
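<p>The primitive generative model described above can be sketched in a few lines. The sketch assumes that the bag keeps duplicate tokens, so that draws follow the dataset&rsquo;s token frequencies, and that tokens are bracket-delimited; <code>tokenize</code> and <code>build_sampler</code> are hypothetical helper names, and the sampled strings would still need to be passed through a decoder to obtain molecules.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import random
import re

TOKEN_RE = re.compile(r"\[[^\]]*\]")  # one bracket-delimited token


def tokenize(s):
    return TOKEN_RE.findall(s)


def build_sampler(dataset, seed=0):
    rng = random.Random(seed)
    token_lists = [tokenize(s) for s in dataset]
    lengths = [len(toks) for toks in token_lists]        # empirical length distribution
    bag = [tok for toks in token_lists for tok in toks]  # multiset of all tokens

    def sample():
        # Pick a length seen in the dataset, then draw that many tokens from the bag.
        n = rng.choice(lengths)
        return "".join(rng.choice(bag) for _ in range(n))

    return sample
</code></pre></div>
<p>Because the sampler knows nothing about chemistry, any difference in the quality of the decoded molecules can be attributed to the representation itself.</p>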
<h3 id="distribution-learning-with-vaes-section-43">Distribution Learning with VAEs (Section 4.3)</h3>
<p>Using the MOSES benchmarking framework, VAEs were trained for 125 epochs on both Group SELFIES and SELFIES representations. The Group SELFIES VAE used 300 groups extracted from the MOSES training set. Results from 100,000 generated molecules:</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Group-VAE-125</th>
          <th>SELFIES-VAE-125</th>
          <th>Train (Reference)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Valid</td>
          <td>1.0 (0)</td>
          <td>1.0 (0)</td>
          <td>1.0</td>
      </tr>
      <tr>
          <td>Unique@1k</td>
          <td>1.0 (0)</td>
          <td>0.9996 (5)</td>
          <td>1.0</td>
      </tr>
      <tr>
          <td>Unique@10k</td>
          <td>0.9985 (4)</td>
          <td>0.9986 (4)</td>
          <td>1.0</td>
      </tr>
      <tr>
          <td>FCD (Test)</td>
          <td>0.1787 (29)</td>
          <td>0.6351 (43)</td>
          <td>0.008</td>
      </tr>
      <tr>
          <td>FCD (TestSF)</td>
          <td>0.734 (109)</td>
          <td>1.3136 (128)</td>
          <td>0.4755</td>
      </tr>
      <tr>
          <td>SNN (Test)</td>
          <td>0.6051 (4)</td>
          <td>0.6014 (3)</td>
          <td>0.6419</td>
      </tr>
      <tr>
          <td>Frag (Test)</td>
          <td>0.9995 (0)</td>
          <td>0.9989 (0)</td>
          <td>1.0</td>
      </tr>
      <tr>
          <td>Scaf (Test)</td>
          <td>0.9649 (21)</td>
          <td>0.9588 (15)</td>
          <td>0.9907</td>
      </tr>
      <tr>
          <td>IntDiv</td>
          <td>0.8587 (1)</td>
          <td>0.8579 (1)</td>
          <td>0.8567</td>
      </tr>
      <tr>
          <td>Novelty</td>
          <td>0.9623 (7)</td>
          <td>0.96 (4)</td>
          <td>1.0</td>
      </tr>
  </tbody>
</table>
<p>The most notable improvement is in Fréchet ChemNet Distance (FCD), where Group SELFIES achieves 0.1787 versus 0.6351 for SELFIES on the test set. FCD compares the distributions of penultimate-layer ChemNet activations for generated and reference molecules (lower is better); these activations encode a mixture of biological and chemical properties relevant to drug-likeness. Most other metrics are comparable: Group SELFIES matches SELFIES within the reported uncertainty on Unique@10k and slightly outperforms it on the remaining metrics.</p>
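<p>For reference, Fréchet-style metrics such as FCD are conventionally computed as the Fréchet distance between two Gaussians fitted to the activation sets. The symbols $\mu_{g}, \Sigma_{g}$ (mean and covariance of activations for generated molecules) and $\mu_{r}, \Sigma_{r}$ (the same for the reference set) are notation introduced here, not taken from the paper summary above:</p>
<p>$$
\text{FCD} = \lVert \mu_{g} - \mu_{r} \rVert_{2}^{2} + \operatorname{Tr}\left(\Sigma_{g} + \Sigma_{r} - 2\left(\Sigma_{g}\Sigma_{r}\right)^{1/2}\right)
$$</p>
<p>The value is zero when the two fitted Gaussians coincide and grows as their means or covariances diverge, which is why lower values indicate a closer match to the reference distribution.</p>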
<h2 id="advantages-limitations-and-future-directions">Advantages, Limitations, and Future Directions</h2>
<h3 id="key-findings">Key Findings</h3>
<p>Group SELFIES provides three main advantages over standard SELFIES:</p>
<ol>
<li><strong>Substructure control</strong>: Important scaffolds, chiral centers, and charged groups can be preserved during molecular optimization.</li>
<li><strong>Compactness</strong>: Group tokens represent multiple atoms, yielding shorter strings with lower information-theoretic complexity.</li>
<li><strong>Improved distribution learning</strong>: The FCD metric shows substantial improvement, indicating generated molecules better capture biological and chemical properties of the training set.</li>
</ol>
<p>Both SELFIES and Group SELFIES achieve 100% validity, eliminating the validity issues associated with SMILES-based generation.</p>
<h3 id="limitations">Limitations</h3>
<p>The authors acknowledge several limitations:</p>
<ul>
<li><strong>Computational speed</strong>: Encoding and decoding are slower than in SELFIES due to RDKit overhead, particularly for the encoder, which performs substructure matching for every group in the set.</li>
<li><strong>No group overlap</strong>: Groups cannot overlap in the current formulation, which limits expressiveness for polycyclic compounds.</li>
<li><strong>Group set design</strong>: Choosing an effective group set remains an open design choice that may require domain expertise or fragmentation algorithm tuning.</li>
<li><strong>Limited generative model evaluation</strong>: The paper focuses on random sampling and VAEs; evaluation with more sophisticated models (GANs, reinforcement learning, genetic algorithms) is left to future work.</li>
</ul>
<h3 id="future-directions">Future Directions</h3>
<p>The authors propose several extensions: flexible scaffold tokens that preserve topology while allowing atom-type variation, representations based on cellular complexes or hypergraphs to handle overlapping groups, and integration with genetic algorithms like JANUS for molecular optimization.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Compactness / Generation</td>
          <td>ZINC-250k</td>
          <td>250,000 molecules</td>
          <td>Random subset of 10,000 for fragment extraction; 100,000 for generation</td>
      </tr>
      <tr>
          <td>Distribution Learning</td>
          <td>MOSES benchmark</td>
          <td>~1.9M molecules</td>
          <td>Standard train/test split from MOSES framework</td>
      </tr>
      <tr>
          <td>Robustness Verification</td>
          <td>eMolecules</td>
          <td>25M molecules</td>
          <td>Full database encode-decode round trip</td>
      </tr>
      <tr>
          <td>NFA Generation</td>
          <td>NFA dataset</td>
          <td>Not specified</td>
          <td>Nonfullerene acceptors from Lopez et al. (2017)</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Fragmentation</strong>: Naive ring-sidechain cleavage, matched molecular pair analysis, and diversity-based selection of 300 groups for VAE experiments.</li>
<li><strong>Essential set</strong>: 23 chiral groups covering all relevant chiral centers in eMolecules.</li>
<li><strong>Random generation</strong>: Bag-of-tokens sampling with length matched to dataset distribution.</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li><strong>VAE</strong>: Trained for 125 epochs on the MOSES dataset using both SELFIES and Group SELFIES tokenizations.</li>
<li>Architecture details follow the MOSES benchmark VAE configuration.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Description</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>FCD</td>
          <td>Fréchet ChemNet Distance (penultimate-layer activations)</td>
      </tr>
      <tr>
          <td>SNN</td>
          <td>Average Tanimoto similarity to nearest neighbor in reference set</td>
      </tr>
      <tr>
          <td>Frag</td>
          <td>Cosine similarity of BRICS fragment distributions</td>
      </tr>
      <tr>
          <td>Scaf</td>
          <td>Cosine similarity of Bemis-Murcko scaffold distributions</td>
      </tr>
      <tr>
          <td>IntDiv</td>
          <td>Internal diversity via Tanimoto similarity</td>
      </tr>
      <tr>
          <td>Validity</td>
          <td>Fraction passing RDKit parsing</td>
      </tr>
      <tr>
          <td>Uniqueness</td>
          <td>Fraction of non-duplicate generated molecules</td>
      </tr>
      <tr>
          <td>Novelty</td>
          <td>Fraction of generated molecules not in training set</td>
      </tr>
  </tbody>
</table>
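<p>Several of the set-based metrics in the table can be sketched directly. The snippet below is an illustration, not the MOSES implementation: fingerprints are represented as plain Python sets of &ldquo;on&rdquo; bits, molecules as canonical strings, and the internal-diversity form (one minus the mean pairwise similarity, self-pairs included) is an assumption about the exact definition.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">def tanimoto(a, b):
    # a, b: sets of "on" fingerprint bits.
    union = len(a.union(b))
    return len(a.intersection(b)) / union if union else 0.0


def snn(generated, reference):
    # Mean similarity of each generated fingerprint to its nearest reference neighbor.
    best = [max(tanimoto(g, r) for r in reference) for g in generated]
    return sum(best) / len(best)


def internal_diversity(generated):
    # One minus the mean pairwise similarity within the generated set.
    n = len(generated)
    total = sum(tanimoto(a, b) for a in generated for b in generated)
    return 1.0 - total / (n * n)


def uniqueness(smiles):
    # Fraction of generated strings that are distinct.
    return len(set(smiles)) / len(smiles)


def novelty(smiles, training):
    # Fraction of distinct generated strings that are absent from the training set.
    unique = set(smiles)
    return len(unique - set(training)) / len(unique)
</code></pre></div>
<p>In practice the fingerprints and canonical strings would come from a cheminformatics toolkit; the set-based stand-ins here only show how the metrics are assembled.</p>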
<h3 id="hardware">Hardware</h3>
<ul>
<li>Robustness verification performed on the Niagara supercomputer (SciNet HPC Consortium).</li>
<li>VAE training hardware not specified.</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/aspuru-guzik-group/group-selfies">group-selfies</a></td>
          <td>Code</td>
          <td>Apache-2.0</td>
          <td>Open-source Python implementation</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Cheng, A. H., Cai, A., Miret, S., Malkomes, G., Phielipp, M., &amp; Aspuru-Guzik, A. (2023). Group SELFIES: A robust fragment-based molecular string representation. <em>Digital Discovery</em>, 2(3), 748-758. <a href="https://doi.org/10.1039/D3DD00012E">https://doi.org/10.1039/D3DD00012E</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{cheng2023group,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Group SELFIES: A Robust Fragment-Based Molecular String Representation}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Cheng, Austin H. and Cai, Andy and Miret, Santiago and Malkomes, Gustavo and Phielipp, Mariano and Aspuru-Guzik, Al{\&#39;a}n}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Digital Discovery}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{2}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{3}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{748--758}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Royal Society of Chemistry}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1039/D3DD00012E}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Foundation Models in Chemistry: A 2025 Perspective</title><link>https://hunterheidenreich.com/notes/chemistry/llm-applications/foundation-models-chemistry-perspective/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/llm-applications/foundation-models-chemistry-perspective/</guid><description>Perspective reviewing foundation models for chemistry across property prediction, MLIPs, inverse design, and multi-domain applications.</description><content:encoded><![CDATA[<h2 id="a-systematization-of-foundation-models-for-chemistry">A Systematization of Foundation Models for Chemistry</h2>
<p>This is a <strong>Systematization</strong> paper. It organizes the rapidly growing landscape of foundation models in chemistry into a coherent taxonomy. The paper distinguishes between &ldquo;small&rdquo; foundation models (pretrained for a single application domain) and &ldquo;big&rdquo; foundation models (adaptable across multiple domains such as property prediction and inverse design). It covers models based on graph neural networks (GNNs) and language models, reviews pretraining strategies (self-supervised, multimodal, supervised), and maps approximately 40 models across four application domains.</p>
<h2 id="why-a-foundation-model-perspective-for-chemistry">Why a Foundation Model Perspective for Chemistry?</h2>
<p>Foundation models have transformed NLP and computer vision through large-scale pretraining and transfer learning. In chemistry, however, several persistent challenges motivate the adoption of this paradigm:</p>
<ol>
<li><strong>Data scarcity</strong>: Chemical datasets are often small and expensive to generate (requiring experiments or quantum mechanical calculations), unlike the large annotated datasets available in NLP/CV.</li>
<li><strong>Poor generalization</strong>: ML models in chemistry frequently need to extrapolate to out-of-domain compounds (e.g., novel drug candidates, unseen crystal structures), where conventional models struggle.</li>
<li><strong>Limited transferability</strong>: Traditional ML interatomic potentials (MLIPs) are trained on system-specific datasets and cannot be easily transferred across different chemical systems.</li>
</ol>
<p>Foundation models address these by learning general representations from large unlabeled datasets, which can then be adapted to specific downstream tasks via finetuning. The paper argues that summarizing this fast-moving field is timely, given the diversity of approaches emerging across molecular property prediction, MLIPs, inverse design, and multi-domain applications.</p>
<h2 id="small-vs-big-foundation-models-a-two-tier-taxonomy">Small vs. Big Foundation Models: A Two-Tier Taxonomy</h2>
<p>The paper&rsquo;s central organizing framework distinguishes two scopes of foundation model:</p>
<p><strong>Small foundation models</strong> are pretrained models adapted to various tasks within a single application domain. Examples include:</p>
<ul>
<li>A model pretrained on large molecular databases that predicts multiple molecular properties (band gap, formation energy, etc.)</li>
<li>A universal MLIP that can simulate diverse chemical systems</li>
<li>A pretrained generative model adapted for inverse design of different target properties</li>
</ul>
<p><strong>Big foundation models</strong> span multiple application domains, handling both property prediction and inverse design within a single framework. These typically use multimodal learning (combining SMILES/graphs with text) or build on large language models.</p>
<h3 id="architectures">Architectures</h3>
<p>The paper reviews two primary architecture families:</p>
<p><strong>Graph Neural Networks (GNNs)</strong> represent molecules and crystals as graphs $G = (V, E)$ with nodes (atoms) and edges (bonds). Node features are updated through message passing:</p>
<p>$$
m_{i}^{t+1} = \sum_{j \in N(i)} M_{t}(v_{i}^{t}, v_{j}^{t}, e_{ij}^{t})
$$</p>
<p>$$
v_{i}^{t+1} = U_{t}(v_{i}^{t}, m_{i}^{t+1})
$$</p>
<p>After $T$ message-passing steps, a readout function produces a graph-level feature:</p>
<p>$$
g = R(\{v_{i}^{T} \mid i \in G\})
$$</p>
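<p>A minimal sketch of one message-passing step and a readout follows. The concrete choices are assumptions made for illustration: the message function scales neighbor features by a scalar edge weight, the update function averages the old state with the aggregated message, and the readout is a sum. Real models learn $M_{t}$, $U_{t}$, and $R$ as neural networks.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">def message_passing_step(node_feats, edges):
    # node_feats: {node: [float, ...]}
    # edges: {(i, j): weight}, with each undirected bond stored in both directions.
    dim = len(next(iter(node_feats.values())))
    messages = {i: [0.0] * dim for i in node_feats}
    for (i, j), w in edges.items():
        # Hypothetical M_t: neighbor features scaled by the edge weight, summed over N(i).
        for k in range(dim):
            messages[i][k] += w * node_feats[j][k]
    # Hypothetical U_t: average of the old state and the aggregated message.
    return {
        i: [0.5 * (v + m) for v, m in zip(node_feats[i], messages[i])]
        for i in node_feats
    }


def readout(node_feats):
    # R: permutation-invariant sum over all node states.
    dim = len(next(iter(node_feats.values())))
    return [sum(v[k] for v in node_feats.values()) for k in range(dim)]
</code></pre></div>
<p>Calling <code>message_passing_step</code> repeatedly lets information travel one bond further per step, and the sum readout keeps the graph-level feature independent of atom ordering.</p>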
<p>Recent equivariant GNNs (e.g., NequIP, MACE, EquformerV2) use vectorial features that respect geometric symmetries, improving expressivity for tasks sensitive to 3D structure.</p>
<p><strong>Language Models</strong> operate on string representations of molecules (<a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a>, <a href="/notes/chemistry/molecular-representations/notations/selfies/">SELFIES</a>) or crystal structures. Autoregressive models like GPT maximize:</p>
<p>$$
\prod_{t=1}^{T} P(x_{t} \mid x_{1}, x_{2}, \ldots, x_{t-1})
$$</p>
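<p>In practice the product is evaluated as a sum of log-probabilities. The sketch below assumes a toy bigram table, which conditions only on the previous token and is therefore simpler than the full-history conditioning in the formula; the probabilities and the <code>[START]</code> and <code>[END]</code> markers are hypothetical.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import math

# Hypothetical next-token probabilities, conditioned on the previous token only.
BIGRAM = {
    "[START]": {"C": 0.7, "O": 0.3},
    "C": {"C": 0.5, "O": 0.3, "[END]": 0.2},
    "O": {"C": 0.6, "[END]": 0.4},
}


def log_likelihood(tokens):
    # Chain rule: log P(sequence) is the sum of per-token conditional log-probabilities.
    prev, total = "[START]", 0.0
    for tok in tokens:
        total += math.log(BIGRAM[prev][tok])
        prev = tok
    return total
</code></pre></div>
<p>Training an autoregressive model amounts to adjusting its parameters so that this quantity is as large as possible on the training sequences.</p>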
<p>Transformers use self-attention:</p>
<p>$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^{T}}{\sqrt{d_{k}}}\right)V
$$</p>
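<p>The attention formula can be written out in plain Python as a minimal single-head sketch. It assumes matrices are lists of rows and omits the learned projections, masking, and multiple heads used in real Transformers.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import math


def softmax(row):
    peak = max(row)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention(Q, K, V):
    # Q, K: lists of d_k-dimensional rows; V: list of value rows, one per key.
    d_k = len(K[0])
    out = []
    for q in Q:
        # Scaled dot product of this query with every key.
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K]
        weights = softmax(scores)
        # Weighted average of the value rows.
        out.append([sum(w * v[c] for w, v in zip(weights, V)) for c in range(len(V[0]))])
    return out
</code></pre></div>
<p>Each output row is a convex combination of the value rows, with weights set by how strongly the query matches each key.</p>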
<h3 id="pretraining-strategies">Pretraining Strategies</h3>
<p>The paper groups pretraining methods into three self-supervised learning (SSL) approaches (contrastive, predictive, and generative), plus supervised and multimodal strategies:</p>
<table>
  <thead>
      <tr>
          <th>Strategy</th>
          <th>Mechanism</th>
          <th>Example Models</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Contrastive learning</td>
          <td>Maximize similarity between positive pairs, minimize for negatives</td>
          <td>GraphCL, MolCLR, GraphMVP, CrysGNN</td>
      </tr>
      <tr>
          <td>Predictive learning</td>
          <td>Predict self-generated labels (node context, functional groups, space group)</td>
          <td>GROVER, Hu et al., CrysGNN</td>
      </tr>
      <tr>
          <td>Generative learning</td>
          <td>Reconstruct masked nodes/edges or entire molecules/SMILES</td>
          <td><a href="/notes/chemistry/molecular-representations/encoders/smiles-bert/">SMILES-BERT</a>, <a href="/notes/chemistry/molecular-representations/encoders/chemberta-2/">ChemBERTa-2</a>, <a href="/notes/chemistry/molecular-representations/encoders/molformer/">MoLFormer</a></td>
      </tr>
      <tr>
          <td>Supervised pretraining</td>
          <td>Train on energy, forces, stress from DFT databases</td>
          <td>M3GNet, CHGNet, MACE-MP-0, MatterSim</td>
      </tr>
      <tr>
          <td>Multimodal learning</td>
          <td>Learn joint representations across SMILES/graph + text modalities</td>
          <td>KV-PLM, <a href="/notes/chemistry/molecular-representations/multimodal/momu-molecular-multimodal-foundation/">MoMu</a>, MoleculeSTM, <a href="/notes/chemistry/molecular-representations/multimodal/spmm-bidirectional-structure-property/">SPMM</a></td>
      </tr>
  </tbody>
</table>
<p>A common finding across studies is that combining local and global information (e.g., via contrastive learning between node-level and graph-level views, or supervised learning on both forces and total energy) produces more transferable representations.</p>
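<p>To make the contrastive row of the table concrete, here is a minimal sketch of an InfoNCE-style loss for a single anchor. It is one common way to maximize similarity for a positive pair while minimizing it for negatives; the exact loss used by each listed model is not stated here, and the function names, the cosine similarity, and the temperature value are assumptions.</p>

```python
import math
from typing import Sequence


def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    """Cosine similarity between two embedding vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def info_nce(anchor: Sequence[float],
             positive: Sequence[float],
             negatives: Sequence[Sequence[float]],
             temperature: float = 0.1) -> float:
    """Negative log-probability of picking the positive among all candidates."""
    pos = math.exp(cosine(anchor, positive) / temperature)
    neg = sum(math.exp(cosine(anchor, n) / temperature) for n in negatives)
    return -math.log(pos / (pos + neg))


# Toy 2-D embeddings: the loss is small when the positive view aligns with the anchor.
anchor = [1.0, 0.0]
aligned_loss = info_nce(anchor, [1.0, 0.0], [[0.0, 1.0], [-1.0, 0.0]])
misaligned_loss = info_nce(anchor, [0.0, 1.0], [[1.0, 0.0], [-1.0, 0.0]])
```

<p>In practice the positive pair would be two augmentations or two views (for example, 2D and 3D) of the same molecule, and the negatives would be other molecules in the batch.</p>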
<h2 id="survey-of-models-across-four-domains">Survey of Models Across Four Domains</h2>
<h3 id="property-prediction">Property Prediction</h3>
<p>The paper reviews 13 models for molecular and materials property prediction. Key findings:</p>
<ul>
<li><strong>Contrastive learning approaches</strong> (GraphCL, MolCLR, GraphMVP) achieve strong results by defining positive pairs through augmentation, 2D/3D structure views, or crystal system membership.</li>
<li><strong>Language model approaches</strong> (<a href="/notes/chemistry/molecular-representations/encoders/smiles-bert/">SMILES-BERT</a>, <a href="/notes/chemistry/molecular-representations/encoders/chemberta-2/">ChemBERTa-2</a>, <a href="/notes/chemistry/molecular-representations/encoders/molformer/">MoLFormer</a>) show that transformers trained on SMILES via masked language modeling can compete with GNN-based approaches.</li>
<li><a href="/notes/chemistry/molecular-representations/encoders/molformer/">MoLFormer</a>, pretrained on 1.1 billion SMILES from PubChem and ZINC, outperformed many baselines including GNNs on <a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a> and <a href="/notes/chemistry/datasets/qm9/">QM9</a> benchmarks. Its attention maps captured molecular structural features directly from SMILES strings.</li>
<li>For crystalline materials, CrysGNN combined contrastive, predictive, and generative learning, demonstrating improvements even on small experimental datasets.</li>
</ul>
<h3 id="machine-learning-interatomic-potentials-mlips">Machine Learning Interatomic Potentials (MLIPs)</h3>
<p>The paper surveys 10 universal MLIPs, all trained with supervised learning on DFT-calculated energies, forces, and stresses. Six of them are summarized here:</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Architecture</th>
          <th>Training Data Size</th>
          <th>Key Capability</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>M3GNet</td>
          <td>GNN</td>
          <td>187K (MP)</td>
          <td>First universal MLIP</td>
      </tr>
      <tr>
          <td>CHGNet</td>
          <td>GNN</td>
          <td>1.58M (MPtrj)</td>
          <td>Predicts magnetic moments</td>
      </tr>
      <tr>
          <td>MACE-MP-0</td>
          <td>MACE</td>
          <td>1.58M (MPtrj)</td>
          <td>35 diverse applications</td>
      </tr>
      <tr>
          <td>GNoME potential</td>
          <td>NequIP</td>
          <td>89M</td>
          <td>Zero-shot comparable to trained MLIPs</td>
      </tr>
      <tr>
          <td>MatterSim</td>
          <td>M3GNet/Graphormer</td>
          <td>17M</td>
          <td>SOTA on Matbench Discovery</td>
      </tr>
      <tr>
          <td>eqV2</td>
          <td>EquformerV2</td>
          <td>118M (OMat24)</td>
          <td>Structural relaxation</td>
      </tr>
  </tbody>
</table>
<p>The GNoME potential, trained on approximately 89 million data points, achieved zero-shot performance comparable to state-of-the-art MLIPs trained from scratch. MatterSim, trained on over 17 million entries spanning wide temperature (0-5000 K) and pressure (0-1000 GPa) ranges, achieved state-of-the-art results on Matbench Discovery and accurately computed thermodynamic and lattice-dynamics properties.</p>
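<p>Supervised MLIP training typically combines several DFT targets into one objective. The sketch below shows a weighted energy-plus-force loss for a single structure; the stress term is omitted for brevity, and the per-atom normalization, the weights, and the function name are assumptions rather than the recipe of any specific model in the table.</p>

```python
from typing import Sequence


def energy_force_loss(pred_energy: float,
                      ref_energy: float,
                      pred_forces: Sequence[Sequence[float]],
                      ref_forces: Sequence[Sequence[float]],
                      w_energy: float = 1.0,
                      w_force: float = 1.0) -> float:
    """Weighted sum of squared per-atom energy error and mean squared force error."""
    n_atoms = len(ref_forces)
    # Normalize the energy error by atom count so large cells do not dominate.
    energy_term = ((pred_energy - ref_energy) / n_atoms) ** 2
    # Mean squared error over all 3 * n_atoms force components.
    force_sq = sum((p - r) ** 2
                   for pred_atom, ref_atom in zip(pred_forces, ref_forces)
                   for p, r in zip(pred_atom, ref_atom))
    force_term = force_sq / (3 * n_atoms)
    return w_energy * energy_term + w_force * force_term


# Made-up two-atom structure with hypothetical energy and force values.
loss = energy_force_loss(
    pred_energy=-10.2,
    ref_energy=-10.0,
    pred_forces=[[0.1, 0.0, 0.0], [-0.1, 0.0, 0.0]],
    ref_forces=[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
)
```

<p>Training on both a global target (total energy) and local targets (per-atom forces) is one instance of the local-plus-global pattern noted earlier in the pretraining discussion.</p>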
<h3 id="inverse-design">Inverse Design</h3>
<p>Few pretrained generative models for inverse design exist. The paper highlights three:</p>
<ul>
<li><strong>MatterGen</strong> (Microsoft): Diffusion model pretrained on Alexandria/MP databases (607K structures), finetuned for conditional generation on band gap, elastic modulus, space group, and composition. Generated S.U.N. (stable, unique, novel) materials at rates more than 2x the previous state of the art.</li>
<li><strong><a href="/notes/chemistry/molecular-design/generation/autoregressive/gp-molformer/">GP-MoLFormer</a></strong> (IBM): MoLFormer pretrained on 1.1B SMILES, finetuned via pair-tuning for property-guided molecular optimization.</li>
<li><strong>CrystalLLM</strong>: Finetuned LLaMA-2 70B for crystal generation with target space group and composition using string representations and prompting.</li>
</ul>
<h3 id="multi-domain-models">Multi-Domain Models</h3>
<p>The paper covers two multi-domain categories:</p>
<p><strong>Property prediction + MLIP</strong>: Denoising pretraining learns virtual forces that guide noisy configurations back to equilibrium, connecting to force prediction. Joint multi-domain pretraining (JMP) from Meta FAIR achieved state-of-the-art on 34 of 40 tasks spanning molecules, crystals, and MOFs by training simultaneously on diverse energy/force databases.</p>
<p><strong>Property prediction + inverse design</strong>: Multimodal models (KV-PLM, <a href="/notes/chemistry/molecular-representations/multimodal/momu-molecular-multimodal-foundation/">MoMu</a>, MoleculeSTM, <a href="/notes/chemistry/molecular-representations/multimodal/molfm-multimodal-molecular-foundation/">MolFM</a>, <a href="/notes/chemistry/molecular-representations/multimodal/spmm-bidirectional-structure-property/">SPMM</a>) learn joint representations from molecular structures and text, enabling text-based inverse design and property prediction in a single framework. LLM-based models (<a href="/notes/chemistry/llm-applications/chemdfm-x/">ChemDFM</a>, <a href="/notes/chemistry/molecular-representations/multimodal/nach0-multimodal-chemical-language-model/">nach0</a>, <a href="/notes/chemistry/llm-applications/fine-tuning-gpt3-molecular-properties/">finetuned GPT-3</a>) can interact with humans and handle diverse chemistry tasks through instruction tuning.</p>
<h2 id="trends-and-future-directions">Trends and Future Directions</h2>
<h3 id="scope-expansion">Scope Expansion</h3>
<p>The authors identify three axes for expanding foundation model scope:</p>
<ol>
<li><strong>Material types</strong>: Most models target molecules or a single material class. Foundation models that span molecules, crystals, surfaces, and MOFs could exploit shared chemistry across materials.</li>
<li><strong>Modalities</strong>: Beyond SMILES, graphs, and text, additional modalities (images, spectral data like XRD patterns) remain underexplored.</li>
<li><strong>Downstream tasks</strong>: Extending to new chemistry and tasks through emergent capabilities, analogous to the capabilities observed in LLMs at scale.</li>
</ol>
<h3 id="performance-and-scaling">Performance and Scaling</h3>
<p>Key scaling challenges include:</p>
<ul>
<li><strong>Data quality vs. quantity</strong>: Noisy DFT labels (e.g., HOMO-LUMO gaps with high uncertainty from different functionals/basis sets) can limit scalability and out-of-distribution performance.</li>
<li><strong>GNN scalability</strong>: While transformers scale to hundreds of billions of parameters, GNNs have rarely been explored above one million parameters due to oversmoothing and the curse of dimensionality. Recent work by Sypetkowski et al. demonstrated scaling GNNs to 3 billion parameters with consistent improvements.</li>
<li><strong>Database integration</strong>: Combining datasets from different DFT codes requires proper alignment (e.g., total energy alignment methods).</li>
</ul>
<h3 id="efficiency">Efficiency</h3>
<p>For MLIPs, efficiency is critical since MD simulations require millions of inference steps. Approaches include:</p>
<ul>
<li>Knowledge distillation from expensive teacher models to lighter student models</li>
<li>Model compression techniques (quantization, pruning) adapted for GNNs</li>
<li>Investigating whether strict equivariance is always necessary</li>
</ul>
<h3 id="interpretability">Interpretability</h3>
<p>Foundation models can generate hallucinations or mode-collapsed outputs. The authors highlight recent interpretability advances (feature extraction from Claude 3, knowledge localization and editing in transformers) as promising directions for more reliable chemical applications.</p>
<h2 id="key-findings-and-limitations">Key Findings and Limitations</h2>
<p><strong>Key findings</strong>:</p>
<ul>
<li>Combining local and global information in pretraining consistently improves downstream performance across all domains reviewed.</li>
<li>Self-supervised pretraining enables effective transfer learning even in low-data regimes, a critical advantage for chemistry.</li>
<li>Universal MLIPs have reached the point where zero-shot performance can be comparable to system-specific trained models.</li>
<li>Multimodal learning is the most promising approach for large foundation models capable of spanning property prediction and inverse design.</li>
</ul>
<p><strong>Limitations acknowledged by the authors</strong>:</p>
<ul>
<li>The precise definition of &ldquo;foundation model&rdquo; in chemistry is not established and varies by scope.</li>
<li>Most surveyed models focus on molecules, with crystalline materials less explored.</li>
<li>Benchmarks for low-data regimes and out-of-distribution performance are insufficient.</li>
<li>The paper focuses on three core domains (property prediction, MLIPs, inverse design) and their multi-domain combinations, and does not cover retrosynthesis, reaction prediction, or other chemical tasks in depth.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>This is a perspective/review paper. No new data or models are introduced. The paper surveys existing models and their training datasets, summarized in Table 1 of the paper.</p>
<h3 id="algorithms">Algorithms</h3>
<p>Not applicable (review paper). The paper describes pretraining strategies (contrastive, predictive, generative, supervised, multimodal) at a conceptual level with references to the original works.</p>
<h3 id="models">Models</h3>
<p>Not applicable (review paper). The paper catalogs approximately 40 foundation models across four domains. See Table 1 in the paper for the complete listing.</p>
<h3 id="evaluation">Evaluation</h3>
<p>Not applicable (review paper). The paper references benchmark results from the original studies (MoleculeNet, QM9, Matbench, Matbench Discovery, JARVIS-DFT) but does not perform independent evaluation.</p>
<h3 id="hardware">Hardware</h3>
<p>Not applicable (review paper).</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Choi, J., Nam, G., Choi, J., &amp; Jung, Y. (2025). A Perspective on Foundation Models in Chemistry. <em>JACS Au</em>, 5(4), 1499-1518. <a href="https://doi.org/10.1021/jacsau.4c01160">https://doi.org/10.1021/jacsau.4c01160</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{choi2025perspective,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{A Perspective on Foundation Models in Chemistry}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Choi, Junyoung and Nam, Gunwook and Choi, Jaesik and Jung, Yousung}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{JACS Au}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{5}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{4}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{1499--1518}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2025}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{American Chemical Society}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1021/jacsau.4c01160}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Fine-Tuning GPT-3 for Molecular Property Prediction</title><link>https://hunterheidenreich.com/notes/chemistry/llm-applications/fine-tuning-gpt3-molecular-properties/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/llm-applications/fine-tuning-gpt3-molecular-properties/</guid><description>Evaluating fine-tuned GPT-3 ada models for HOMO/LUMO classification of organic semiconductors from SMILES, with ablation and robustness analysis.</description><content:encoded><![CDATA[<h2 id="gpt-3-as-a-molecular-property-classifier">GPT-3 as a Molecular Property Classifier</h2>
<p>This is an <strong>Empirical</strong> paper that evaluates the effectiveness of fine-tuning OpenAI&rsquo;s GPT-3 language model (specifically the &ldquo;ada&rdquo; base model) for predicting electronic and functional properties of organic molecules. Rather than proposing a new architecture, the work systematically tests whether a general-purpose LLM can learn chemically meaningful patterns from <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> strings when fine-tuned on classification tasks. The primary contribution is the empirical characterization of GPT-3&rsquo;s performance, robustness, and limitations for molecular property prediction, including extensive ablation studies.</p>
<h2 id="why-fine-tune-a-general-purpose-llm-for-chemistry">Why Fine-Tune a General-Purpose LLM for Chemistry?</h2>
<p>Machine learning for molecular property prediction typically relies on specialized representations: molecular graphs processed by graph neural networks (GNNs), engineered molecular descriptors, or domain-specific chemical language models trained from scratch on SMILES or <a href="/notes/chemistry/molecular-representations/notations/selfies/">SELFIES</a>. These approaches require varying levels of domain expertise to design the inputs and architecture.</p>
<p>GPT-3, pre-trained on vast amounts of general text, already has an internal representation of language structure. SMILES notation, as a text-based molecular representation, can be treated as a &ldquo;language&rdquo; with its own syntax. The authors hypothesize that GPT-3&rsquo;s language understanding capabilities, combined with the human-readable nature of SMILES, may enable the model to recognize significant patterns within chemical structures and capture structure-property dependencies. The key question is whether fine-tuning alone is sufficient, or whether specialized architectures provide fundamental advantages.</p>
<p>Prior work by <a href="/notes/chemistry/llm-applications/leveraging-llms-predictive-chemistry/">Jablonka et al.</a> showed that fine-tuned GPT-3 could perform surprisingly well on low-data chemistry tasks, sometimes surpassing dedicated models. This paper extends that investigation with a focus on electronic properties (<a href="https://en.wikipedia.org/wiki/HOMO_and_LUMO">HOMO and LUMO</a> energies) of <a href="https://en.wikipedia.org/wiki/Organic_semiconductor">organic semiconductors</a>, with deeper analysis of robustness and failure modes.</p>
<h2 id="smiles-to-classification-via-prompt-completion-fine-tuning">SMILES-to-Classification via Prompt-Completion Fine-Tuning</h2>
<p>The core approach is straightforward. Each training example is a prompt-completion pair in JSONL format:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-json" data-lang="json"><span style="display:flex;"><span>{<span style="color:#f92672">&#34;prompt&#34;</span>: <span style="color:#e6db74">&#34;SMILES_string&#34;</span>, <span style="color:#f92672">&#34;completion&#34;</span>: <span style="color:#e6db74">&#34;class_label&#34;</span>}
</span></span></code></pre></div><p>The SMILES string serves as the prompt, and the fine-tuned model learns to complete it with a class label (0/1 for binary, 0/1/2 for ternary, 0/1/2/3 for quaternary classification). Class thresholds are determined by equally segmenting the property value range. The authors use GPT-3&rsquo;s default tokenizer, which breaks SMILES strings into subword tokens that do not correspond to chemically meaningful units (e.g., &ldquo;c1ccccc1&rdquo; for benzene gets tokenized into arbitrary fragments).</p>
<p>This design choice has important implications. The model must learn chemical semantics from token patterns that are not aligned with atoms or bonds. The authors note this as a limitation and hypothesize that a chemistry-aware tokenizer could improve performance.</p>
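<p>The data preparation described above can be sketched as follows: property values are mapped to class labels by splitting the observed value range into equal-width segments, and each molecule becomes one JSONL line. This is an illustrative reconstruction, not the authors' code; the helper names, the use of the dataset minimum and maximum as range endpoints, and the HOMO values are assumptions.</p>

```python
import json
from typing import List, Tuple


def equal_width_label(value: float, lo: float, hi: float, n_classes: int) -> int:
    """Map a property value to one of n_classes equal-width bins over [lo, hi]."""
    if hi == lo:
        return 0  # degenerate range: everything falls in a single class
    width = (hi - lo) / n_classes
    index = int((value - lo) / width)
    return min(index, n_classes - 1)  # the maximum value belongs to the last bin


def to_jsonl(records: List[Tuple[str, float]], n_classes: int) -> str:
    """Build prompt-completion lines with the SMILES as prompt and the class as completion."""
    values = [value for _, value in records]
    lo, hi = min(values), max(values)
    lines = []
    for smiles, value in records:
        label = equal_width_label(value, lo, hi, n_classes)
        lines.append(json.dumps({"prompt": smiles, "completion": str(label)}))
    return "\n".join(lines)


# Made-up HOMO energies in eV, for illustration only.
records = [("c1ccccc1", -6.7), ("c1ccc2ccccc2c1", -5.8), ("CC(=O)C", -6.2)]
jsonl = to_jsonl(records, n_classes=3)
```

<p>Equal-width bins are simple to reproduce but can yield unbalanced classes when the property distribution is skewed, which is worth checking before comparing accuracies across class counts.</p>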
<h2 id="experimental-setup-and-baseline-comparisons">Experimental Setup and Baseline Comparisons</h2>
<h3 id="datasets">Datasets</h3>
<p>The primary dataset is a collection of 48,182 organic semiconductor (OSC) molecules extracted from the <a href="https://en.wikipedia.org/wiki/Cambridge_Structural_Database">Cambridge Structural Database</a> (CSD). Each molecule has a SMILES representation and quantum-chemically computed electronic properties (HOMO and LUMO energies). A secondary dataset of 572 aromatic molecular photocatalysts (AMPs) with experimentally measured <a href="https://en.wikipedia.org/wiki/Hydrogen_evolution_reaction">hydrogen evolution rates</a> (HER) provides an additional test case.</p>
<h3 id="baselines">Baselines</h3>
<p>Three baselines are compared:</p>
<ol>
<li><strong>Directed message-passing neural network (D-MPNN)</strong> via Chemprop, using default molecular graph representations</li>
<li><strong>RDKit molecular descriptors + SVM</strong>, using the top 20 descriptors selected by SelectKBest</li>
<li><strong>Prior ML results</strong> from the original AMP dataset paper (using engineered domain-specific features)</li>
</ol>
<h3 id="main-results">Main Results</h3>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>Task</th>
          <th>Classes</th>
          <th>GPT-3 Accuracy</th>
          <th>GNN Accuracy</th>
          <th>Descriptors Accuracy</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>OSCs (48,182)</td>
          <td>HOMO</td>
          <td>3</td>
          <td>0.92</td>
          <td>0.94</td>
          <td>0.87</td>
      </tr>
      <tr>
          <td>OSCs (48,182)</td>
          <td>HOMO</td>
          <td>4</td>
          <td>0.68</td>
          <td>0.75</td>
          <td>0.47</td>
      </tr>
      <tr>
          <td>OSCs (48,182)</td>
          <td>HOMO</td>
          <td>5</td>
          <td>0.60</td>
          <td>0.68</td>
          <td>0.40</td>
      </tr>
      <tr>
          <td>OSCs (48,182)</td>
          <td>LUMO</td>
          <td>3</td>
          <td>0.94</td>
          <td>0.94</td>
          <td>0.91</td>
      </tr>
      <tr>
          <td>AMPs (572)</td>
          <td>HER</td>
          <td>2</td>
          <td>0.88</td>
          <td>0.86</td>
          <td>0.87</td>
      </tr>
  </tbody>
</table>
<p>For ternary classification, GPT-3 performs on par with GNNs (0.92 vs. 0.94 for HOMO; 0.94 vs. 0.94 for LUMO). Performance degrades more steeply than GNNs as the number of classes increases: at 5-class HOMO, GPT-3 achieves only 0.60 vs. GNN&rsquo;s 0.68. On the small AMP dataset (572 molecules), GPT-3 slightly outperforms the GNN (0.88 vs. 0.86).</p>
<h3 id="learning-curves">Learning Curves</h3>
<p>The data efficiency analysis reveals that GPT-3 needs at least 20% of the OSC dataset (approximately 9,600 molecules) to reach accuracy above 0.9. Below 1,000 training points, accuracy drops below 0.6. GNNs outperform GPT-3 in this low-data regime, which the authors attribute to (1) the molecular graph being chemically more expressive than SMILES for these tasks, and (2) fine-tuning requiring sufficient data to capture relevant SMILES patterns.</p>
<h3 id="ablation-study-1-single-atom-removal">Ablation Study 1: Single-Atom Removal</h3>
<p>The authors tested robustness by removing individual non-hydrogen, non-carbon atoms from SMILES strings and replacing them with a <code>&lt;missing&gt;</code> token. Out of 45,763 ablation tests on 7,714 correctly predicted molecules, 95.2% retained the same classification. This suggests the model captures redundant structural information rather than relying on any single atom.</p>
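<p>A rough sketch of how such single-atom ablation variants could be generated directly from a SMILES string is shown below. It uses a simplified regular expression rather than a cheminformatics toolkit, assumes valid SMILES input, and is not the authors' implementation; the function names and the atom-matching rules are assumptions.</p>

```python
import re
from typing import List

# The literal missing-atom token, with its angle brackets written as hex escapes.
MISSING = "\x3cmissing\x3e"

# Atom tokens: bracket atoms, two-letter halogens, then one-letter organic-subset symbols.
ATOM = re.compile(r"\[[^\]]+\]|Br|Cl|[BCNOSPFIbcnops]")
# Element symbol inside a bracket atom, after an optional isotope number.
BRACKET_ELEMENT = re.compile(r"\[\d*([A-Z][a-z]?|[a-z][a-z]?)")


def element_of(token: str) -> str:
    """Return the element symbol of an atom token (aromatic symbols stay lowercase)."""
    if token.startswith("["):
        match = BRACKET_ELEMENT.match(token)
        return match.group(1) if match else "*"
    return token


def single_atom_ablations(smiles: str) -> List[str]:
    """One variant per non-carbon, non-hydrogen atom, with that atom replaced by MISSING."""
    variants = []
    for match in ATOM.finditer(smiles):
        if element_of(match.group()) in {"C", "c", "H"}:
            continue  # carbon and hydrogen are never ablated
        variants.append(smiles[:match.start()] + MISSING + smiles[match.end():])
    return variants


# 2-hydroxybenzonitrile has two heteroatoms (N and O), so two ablated variants result.
variants = single_atom_ablations("N#Cc1ccccc1O")
```

<p>Each variant would then be sent to the fine-tuned model, and the fraction of variants that keep the original class label gives the retention rate reported above.</p>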
<h3 id="ablation-study-2-single-group-removal">Ablation Study 2: Single-Group Removal</h3>
<p>Fifteen chemical groups (nitrile, nitro, enamine, ketone, etc.) were individually ablated. The fine-tuned model attributed the most importance to acetylene (81% agreement with the unablated HOMO prediction), enamine (85%), nitro (86%), and ketone (87%) groups: removing any of these changed the HOMO prediction in more than 10% of tests. Groups that participate in electronic pi-conjugation tended to be more &ldquo;important&rdquo; to the model&rsquo;s HOMO predictions.</p>
<p>When ablated atoms were replaced with random elements instead of the <code>&lt;missing&gt;</code> token, the model failed in 80% of cases for a representative molecule. This suggests the model may &ldquo;fill in&rdquo; the missing information when seeing the <code>&lt;missing&gt;</code> token but gets confused by incorrect atomic identities.</p>
<h3 id="predicting-unknown-molecular-families">Predicting Unknown Molecular Families</h3>
<p>The authors held out entire families of <a href="https://en.wikipedia.org/wiki/Polycyclic_aromatic_hydrocarbon">polycyclic aromatic hydrocarbons</a> (naphthalene, anthracene, tetracene, pyrene, perylene), quinones, and imides during training, then tested predictions on these unseen families. Results for the five PAH families:</p>
<table>
  <thead>
      <tr>
          <th>Fragment Family</th>
          <th>Molecules</th>
          <th>GPT-3 HOMO</th>
          <th>GNN HOMO</th>
          <th>GPT-3 LUMO</th>
          <th>GNN LUMO</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Naphthalene</td>
          <td>475</td>
          <td>0.94</td>
          <td>0.95</td>
          <td>0.88</td>
          <td>0.91</td>
      </tr>
      <tr>
          <td>Anthracene</td>
          <td>577</td>
          <td>0.99</td>
          <td>1.00</td>
          <td>0.93</td>
          <td>0.97</td>
      </tr>
      <tr>
          <td>Tetracene</td>
          <td>72</td>
          <td>0.96</td>
          <td>1.00</td>
          <td>0.90</td>
          <td>0.99</td>
      </tr>
      <tr>
          <td>Pyrene</td>
          <td>237</td>
          <td>0.98</td>
          <td>1.00</td>
          <td>0.97</td>
          <td>0.99</td>
      </tr>
      <tr>
          <td>Perylene</td>
          <td>41</td>
          <td>0.98</td>
          <td>1.00</td>
          <td>0.98</td>
          <td>0.95</td>
      </tr>
  </tbody>
</table>
<p>GPT-3 generalizes well to unknown PAH families, though GNNs hold a slight edge on HOMO prediction for every family and on LUMO prediction for all but perylene. Performance degrades somewhat for quinones and imides.</p>
<h3 id="canonical-vs-non-canonical-smiles">Canonical vs. Non-Canonical SMILES</h3>
<p>A model fine-tuned only on canonical SMILES performed poorly on non-canonical variants: only 1,622 of 8,578 molecules achieved consistent predictions across all 11 SMILES variants (1 canonical + 10 non-canonical). Augmenting the training data with 5 non-canonical SMILES per molecule dramatically improved consistency to 7,243 of 8,578 molecules and nearly eliminated erroneous (non-class-label) responses. This finding highlights that GPT-3&rsquo;s pattern matching is highly sensitive to surface-level string representation and benefits substantially from <a href="/notes/chemistry/molecular-representations/notations/randomized-smiles-generative-models/">SMILES enumeration</a> <a href="/notes/chemistry/molecular-design/property-prediction/maxsmi-smiles-augmentation-property-prediction/">data augmentation</a>.</p>
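<p>The consistency metric can be sketched as a simple count, assuming "consistent" means that every SMILES variant of a molecule received the same class label. The function name and the toy predictions are hypothetical; the two fractions at the end are computed from the counts reported above.</p>

```python
from typing import Dict, List


def count_consistent(predictions: Dict[str, List[str]]) -> int:
    """Count molecules whose SMILES variants all received the same model output."""
    return sum(1 for labels in predictions.values() if len(set(labels)) == 1)


# Hypothetical outputs keyed by canonical SMILES; "n/a" stands for a non-class-label response.
predictions = {
    "c1ccccc1": ["2", "2", "2"],
    "CC(=O)C": ["1", "0", "1"],
    "c1ccc2ccccc2c1": ["2", "2", "n/a"],
}
n_consistent = count_consistent(predictions)

# Fractions of consistent molecules implied by the reported counts.
fraction_canonical_only = 1622 / 8578
fraction_augmented = 7243 / 8578
```

<p>On the reported counts this corresponds to roughly 19% of molecules before augmentation and roughly 84% after.</p>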
<h2 id="key-findings-and-limitations">Key Findings and Limitations</h2>
<p>The main findings are:</p>
<ol>
<li>Fine-tuned GPT-3 (ada) achieves accuracy competitive with GNNs for coarse-grained (ternary) HOMO/LUMO classification, but performance drops more steeply with finer granularity.</li>
<li>The model shows robustness to single-atom and single-group ablation, suggesting it captures chemically redundant patterns.</li>
<li>Generalization to held-out molecular families is strong, though GNNs maintain a slight advantage.</li>
<li>SMILES augmentation with non-canonical variants is essential for consistent predictions.</li>
</ol>
<p>The authors acknowledge several limitations:</p>
<ul>
<li><strong>Black-box nature</strong>: GPT-3 provides no physical insight or interpretability, unlike GNN models where molecular graph features can be augmented with domain knowledge.</li>
<li><strong>Tokenization</strong>: The generic tokenizer does not respect chemical structure. A chemistry-aware tokenizer could improve data efficiency and accuracy.</li>
<li><strong>SELFIES underperformance</strong>: Initial tests with SELFIES did not improve over SMILES, likely because generic tokenization stripped away the extra chemical information SELFIES encodes.</li>
<li><strong>Cost</strong>: Fine-tuning via OpenAI&rsquo;s API cost approximately $500 for the experiments, and the model is closed-source, preventing systematic interpretation of learned representations.</li>
<li><strong>Classification only</strong>: The approach performs coarse-grained classification rather than regression, limiting utility for applications requiring precise numerical predictions.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training/Evaluation</td>
          <td>OSC molecules from CSD</td>
          <td>48,182</td>
          <td>SMILES + DFT-computed HOMO/LUMO energies</td>
      </tr>
      <tr>
          <td>Training/Evaluation</td>
          <td>Aromatic molecular photocatalysts (AMPs)</td>
          <td>572</td>
          <td>Experimental hydrogen evolution rates</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Fine-tuning uses OpenAI&rsquo;s GPT-3 &ldquo;ada&rdquo; base model via the API</li>
<li>Prompt-completion pairs in JSONL format</li>
<li>Default GPT-3 tokenizer</li>
<li>80/20 train/test split for OSC; stratified 10-fold CV for AMPs</li>
<li>Non-canonical SMILES generated using RDKit (10 per molecule for testing, 5 per molecule for augmented training)</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li>GPT-3 &ldquo;ada&rdquo; (fine-tuned, closed-source, accessed via OpenAI API)</li>
<li>Chemprop D-MPNN baseline (open-source)</li>
<li>RDKit descriptors + scikit-learn SVM baseline</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Best GPT-3 Value</th>
          <th>Best GNN Value</th>
          <th>Task</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Accuracy</td>
          <td>0.92</td>
          <td>0.94</td>
          <td>3-class HOMO (OSCs)</td>
      </tr>
      <tr>
          <td>Accuracy</td>
          <td>0.94</td>
          <td>0.94</td>
          <td>3-class LUMO (OSCs)</td>
      </tr>
      <tr>
          <td>Accuracy</td>
          <td>0.88</td>
          <td>0.86</td>
          <td>2-class HER (AMPs)</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>The paper does not specify local hardware requirements. All GPT-3 fine-tuning was conducted via OpenAI&rsquo;s cloud API at a total cost of approximately $500.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/XieZikai/Chem-GPT-Finetune">Chem-GPT-Finetune</a></td>
          <td>Code</td>
          <td>Not specified</td>
          <td>Python code and datasets for fine-tuning and evaluation</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Xie, Z., Evangelopoulos, X., Omar, O. H., Troisi, A., Cooper, A. I., &amp; Chen, L. (2024). Fine-tuning GPT-3 for machine learning electronic and functional properties of organic molecules. <em>Chemical Science</em>, 15(2), 500-510. <a href="https://doi.org/10.1039/D3SC04610A">https://doi.org/10.1039/D3SC04610A</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{xie2024finetuning,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Fine-tuning {GPT-3} for machine learning electronic and functional properties of organic molecules}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Xie, Zikai and Evangelopoulos, Xenophon and Omar, {\&#34;O}mer H. and Troisi, Alessandro and Cooper, Andrew I. and Chen, Linjiang}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Chemical Science}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{15}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{2}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{500--510}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2024}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Royal Society of Chemistry}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1039/D3SC04610A}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>CogMol: Controlled Molecule Generation for COVID-19</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/latent-space/cogmol-target-specific-drug-design/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/latent-space/cogmol-target-specific-drug-design/</guid><description>CogMol combines a SMILES VAE with controlled latent space sampling to generate drug-like molecules with target specificity for novel viral proteins.</description><content:encoded><![CDATA[<h2 id="a-controlled-generation-framework-for-target-specific-drug-design">A Controlled Generation Framework for Target-Specific Drug Design</h2>
<p>This is a <strong>Method</strong> paper that introduces CogMol (Controlled Generation of Molecules), an end-to-end framework for de novo drug design. The primary contribution is a pipeline that combines a SMILES-based Variational Autoencoder (VAE) with multi-attribute controlled latent space sampling (CLaSS) to generate novel drug-like molecules with high binding affinity to specified protein targets, off-target selectivity, and favorable drug-likeness properties. The framework operates on protein sequence embeddings, allowing it to generalize to unseen target proteins without model retraining.</p>
<h2 id="multi-constraint-drug-design-for-novel-viral-targets">Multi-Constraint Drug Design for Novel Viral Targets</h2>
<p>Traditional drug discovery costs 2-3 billion USD per approved drug and takes over a decade, with a success rate below 10%. Generating drug molecules requires satisfying multiple competing objectives simultaneously: target binding affinity, off-target selectivity, synthetic accessibility, drug-likeness, and low toxicity. Prior generative approaches using reinforcement learning or Bayesian optimization are computationally expensive and typically require fine-tuning on target-specific ligand libraries, so they cannot generalize to unseen protein targets.</p>
<p>The emergence of SARS-CoV-2 in 2020 created an urgent need for antiviral drug candidates targeting novel viral proteins. Because no binding affinity data existed for these new targets, and the viral proteins were not closely related to proteins in existing databases like BindingDB, existing target-specific generative frameworks could not be directly applied. CogMol addresses this by using pre-trained protein sequence embeddings from UniRep (trained on 24 million UniRef50 sequences) rather than learning protein representations from the limited BindingDB training set.</p>
<h2 id="controlled-latent-space-sampling-with-pre-trained-protein-embeddings">Controlled Latent Space Sampling with Pre-trained Protein Embeddings</h2>
<p>CogMol&rsquo;s core innovation is a three-component architecture that enables multi-constraint molecule generation for unseen targets:</p>
<p><strong>1. SMILES VAE with adaptive pre-training.</strong> A Variational Autoencoder is first trained unsupervised on the MOSES/ZINC dataset (1.6M molecules), then jointly fine-tuned with QED and SA property predictors on BindingDB molecules. The standard VAE objective is:</p>
<p>$$\mathcal{L}_{\text{VAE}}(\theta, \phi) = \mathbb{E}_{p(x)} \left\{ \mathbb{E}_{q_\phi(z|x)} [\log p_\theta(x|z)] - D_{\text{KL}}(q_\phi(z|x) \parallel p(z)) \right\}$$</p>
<p>where $q_\phi(z|x) = \mathcal{N}(z; \mu(x), \Sigma(x))$ specifies a diagonal Gaussian encoder distribution.</p>
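<p>As a rough illustration of the objective above, the sketch below evaluates the per-example ELBO for a diagonal Gaussian encoder. It assumes a standard normal prior $p(z) = \mathcal{N}(0, I)$, which this summary does not state explicitly, and it takes the reconstruction log-likelihood as a given number. The function names and toy values are hypothetical.</p>

```python
import numpy as np

def kl_diag_gaussian_to_standard_normal(mu, log_var):
    """Closed-form KL( N(mu, diag(exp(log_var))) || N(0, I) ), summed over latent dims.

    Assumption: the prior p(z) is a standard normal.
    """
    return 0.5 * np.sum(np.exp(log_var) + mu**2 - 1.0 - log_var, axis=-1)

def elbo(recon_log_prob, mu, log_var):
    """Per-example ELBO: reconstruction log-likelihood minus the KL term."""
    return recon_log_prob - kl_diag_gaussian_to_standard_normal(mu, log_var)

# Toy batch of two examples with a 4-dimensional latent space.
mu = np.zeros((2, 4))
log_var = np.zeros((2, 4))
recon_log_prob = np.array([-3.0, -5.0])  # hypothetical decoder log-likelihoods

# An encoder that outputs the prior exactly pays no KL penalty,
# so the ELBO reduces to the reconstruction term.
print(elbo(recon_log_prob, mu, log_var))
```

<p>Training maximizes this quantity averaged over the data, which is the outer expectation over $p(x)$ in the formula.</p>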
<p><strong>2. Protein-molecule binding affinity predictor.</strong> A regression model takes pre-trained UniRep protein sequence embeddings and molecule latent embeddings $z$ as input and predicts pIC50 binding affinity ($= -\log_{10}(\text{IC50})$, with IC50 expressed as a molar concentration). Because UniRep embeddings capture sequence, structural, and functional relationships from a large unsupervised corpus, the predictor can estimate binding affinity for novel target sequences not present in the training data.</p>
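<p>For readers less familiar with the pIC50 scale, here is a minimal conversion sketch. It assumes the usual convention of a base-10 logarithm applied to IC50 in molar units; the helper name and the nanomolar input are illustrative choices, not details from the paper.</p>

```python
import math

def pic50_from_ic50_nm(ic50_nm):
    """Convert an IC50 given in nanomolar to pIC50 = -log10(IC50 in molar)."""
    ic50_molar = ic50_nm * 1e-9
    return -math.log10(ic50_molar)

# A tenfold drop in IC50 (a more potent compound) raises pIC50 by one unit.
for ic50_nm in (1000.0, 100.0, 10.0):
    print(ic50_nm, round(pic50_from_ic50_nm(ic50_nm), 2))
```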
<p><strong>3. CLaSS controlled sampling.</strong> The Conditional Latent attribute Space Sampling scheme generates molecules satisfying multiple constraints (affinity, QED, selectivity) through rejection sampling in the VAE latent space:</p>
<p>$$p(\mathbf{x} | \mathbf{a}) = \mathbb{E}_{\mathbf{z}} [p(\mathbf{z} | \mathbf{a}) \, p(\mathbf{x} | \mathbf{z})] \approx \mathbb{E}_{\mathbf{z}} [\hat{p}_\xi(\mathbf{z} | \mathbf{a}) \, p_\theta(\mathbf{x} | \mathbf{z})]$$</p>
<p>where $\mathbf{a} = [a_1, a_2, \ldots, a_n]$ is a set of independent attribute constraints. The conditional density $\hat{p}_\xi(\mathbf{z} | \mathbf{a})$ is approximated using a Gaussian mixture model $Q_\xi(\mathbf{z})$ and per-attribute classifiers $q_\xi(a_i | \mathbf{z})$, with Bayes&rsquo; rule and conditional independence assumptions. The acceptance probability equals the product of all attribute predictor scores, enabling efficient multi-constraint sampling without surrogate model or policy learning.</p>
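<p>The acceptance rule described above can be sketched as a short rejection sampler. This is a toy under stated assumptions, not the authors&rsquo; implementation: the two-component mixture stands in for the fitted density $Q_\xi(\mathbf{z})$, the sigmoid scorers stand in for trained attribute classifiers $q_\xi(a_i | \mathbf{z})$, and every name below is hypothetical.</p>

```python
import numpy as np

def class_rejection_sample(sample_latent, attribute_scorers, n_accept, rng, max_draws=100_000):
    """Draw latents and accept each with probability equal to the product of attribute scores."""
    accepted = []
    for _ in range(max_draws):
        z = sample_latent(rng)
        # Conditional independence assumption: multiply the per-attribute probabilities.
        accept_prob = float(np.prod([scorer(z) for scorer in attribute_scorers]))
        if accept_prob > rng.random():
            accepted.append(z)
            if len(accepted) == n_accept:
                break
    return np.array(accepted)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def sample_latent(rng):
    """Toy 2-D Gaussian mixture with two equally weighted components (hypothetical)."""
    means = np.array([[-1.0, 0.0], [1.0, 0.0]])
    k = rng.integers(len(means))
    return rng.normal(loc=means[k], scale=1.0)

# Hypothetical attribute classifiers: each favors large values along one latent axis.
scorers = [lambda z: sigmoid(4.0 * z[0]), lambda z: sigmoid(4.0 * z[1])]

rng = np.random.default_rng(0)
samples = class_rejection_sample(sample_latent, scorers, n_accept=200, rng=rng)
print(samples.shape, samples.mean(axis=0))
```

<p>Accepted latents would then be passed through the decoder $p_\theta(\mathbf{x} | \mathbf{z})$ to obtain SMILES strings. No gradient steps or policy updates are involved, which is the efficiency argument made above.</p>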
<p><strong>Selectivity modeling.</strong> Off-target selectivity for a molecule $m$ against target $T$ is defined as:</p>
<p>$$\text{Sel}_{T,m} = \text{BA}(T, m) - \frac{1}{k} \sum_{i=1}^{k} \text{BA}(T_i, m)$$</p>
<p>where $\text{BA}(T, m)$ is binding affinity to the target and $T_i$ are $k$ randomly selected off-targets. This selectivity score is incorporated as a control attribute during CLaSS sampling.</p>
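<p>A minimal sketch of this selectivity score follows. The affinity lookup table, the protein names used as off-targets, and the function signature are all hypothetical placeholders for a trained binding affinity predictor.</p>

```python
import random

def selectivity(affinity_fn, target, molecule, off_target_pool, k, seed=0):
    """Sel = BA(target, m) minus the mean BA over k randomly chosen off-targets."""
    rng = random.Random(seed)
    off_targets = rng.sample(off_target_pool, k)
    mean_off_target = sum(affinity_fn(t, molecule) for t in off_targets) / k
    return affinity_fn(target, molecule) - mean_off_target

# Hypothetical predicted affinities (pIC50-like scale) for one molecule.
table = {
    ("Mpro", "mol-1"): 8.0,
    ("off-A", "mol-1"): 5.0,
    ("off-B", "mol-1"): 6.0,
    ("off-C", "mol-1"): 7.0,
}

def affinity_fn(protein, molecule):
    return table[(protein, molecule)]

# With k equal to the pool size, every off-target is used, so the result
# does not depend on the random draw.
print(selectivity(affinity_fn, "Mpro", "mol-1", ["off-A", "off-B", "off-C"], k=3))
```

<p>A positive score means the molecule is predicted to bind the intended target more strongly than the sampled off-targets on average.</p>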
<h2 id="experimental-setup-covid-19-targets-and-in-silico-screening">Experimental Setup: COVID-19 Targets and In Silico Screening</h2>
<p><strong>Target proteins.</strong> CogMol was applied to three SARS-CoV-2 targets not present in BindingDB: NSP9 Replicase dimer, Main Protease (Mpro), and the Receptor-Binding Domain (RBD) of the spike protein. A cancer target (human HDAC1) with low ligand coverage in the training data was also evaluated.</p>
<p><strong>Training data.</strong> The SMILES VAE was trained on the MOSES benchmark (1.6M molecules from ZINC). The binding affinity predictor used curated IC50 data from BindingDB as reported in DeepAffinity, with all protein classes included in training.</p>
<p><strong>CLaSS controlled generation.</strong> Molecules were generated with simultaneous constraints on binding affinity (&gt; 0.5 normalized), QED (&gt; 0.8 normalized), and selectivity (&gt; 0.5 normalized). Approximately 1000 molecules per target were selected for downstream evaluation.</p>
<p><strong>In silico screening pipeline.</strong> Generated molecules underwent:</p>
<ul>
<li>Toxicity prediction via a multi-task deep neural network (MT-DNN) on 12 Tox21 in vitro endpoints and ClinTox clinical trial failure</li>
<li>Binding affinity rescoring with a higher-accuracy SMILES-level predictor</li>
<li>Blind docking (5 independent runs per molecule) using AutoDock Vina against target protein structures</li>
<li>Synthetic feasibility assessment using a retrosynthetic algorithm based on the Molecular Transformer trained on patent reaction data</li>
</ul>
<p><strong>Baselines.</strong> VAE performance was benchmarked against models from the MOSES platform. CLaSS-accepted molecules were compared against randomly sampled molecules from the latent space. Generated molecules were compared against FDA-approved drugs for toxicity and synthesizability.</p>
<h3 id="key-results">Key Results</h3>
<p><strong>CLaSS enrichment (Table 1).</strong> CLaSS consistently produced higher fractions of molecules meeting all criteria compared to random sampling. For the triple constraint (affinity &gt; 0.5, QED &gt; 0.8, selectivity &gt; 0.5), the enrichment was substantial: 6.9% vs. 0.7% for NSP9, 9.0% vs. 0.9% for RBD, and 10.4% vs. 1.1% for Mpro.</p>
<table>
  <thead>
      <tr>
          <th>Target</th>
          <th>CLaSS (Aff+QED+Sel)</th>
          <th>Random (Aff+QED+Sel)</th>
          <th>Enrichment</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>NSP9</td>
          <td>6.9%</td>
          <td>0.7%</td>
          <td>~10x</td>
      </tr>
      <tr>
          <td>RBD</td>
          <td>9.0%</td>
          <td>0.9%</td>
          <td>~10x</td>
      </tr>
      <tr>
          <td>Mpro</td>
          <td>10.4%</td>
          <td>1.1%</td>
          <td>~9.5x</td>
      </tr>
  </tbody>
</table>
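<p>The enrichment column is simply the ratio of the two acceptance rates. A quick check using only the percentages from the table:</p>

```python
# (CLaSS %, random-sampling %) for the triple constraint, copied from the table above.
rates = {"NSP9": (6.9, 0.7), "RBD": (9.0, 0.9), "Mpro": (10.4, 1.1)}

enrichment = {target: class_pct / random_pct for target, (class_pct, random_pct) in rates.items()}

for target, factor in enrichment.items():
    print(f"{target}: {factor:.1f}x")
```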
<p><strong>Docking results (Table 3).</strong> 87-95% of high-affinity generated molecules showed docking binding free energy (BFE) below -6 kcal/mol, with minimum BFEs reaching -8.6 to -9.5 kcal/mol depending on the target.</p>
<p><strong>Novelty.</strong> The likelihood of generating an exact duplicate of a training molecule was 2% or less. Against the full PubChem database (~103M molecules), exact matches ranged from 3.7% to 9.5%. Generated molecules also showed novel chemical scaffolds, as confirmed by a high Fréchet ChemNet Distance.</p>
<p><strong>Synthesizability.</strong> Generated molecules for COVID-19 targets showed 85-90% synthetic feasibility using retrosynthetic analysis, exceeding the ~78% rate of FDA-approved drugs.</p>
<p><strong>Toxicity.</strong> Approximately 70% of generated parent molecules and ~80% of predicted metabolites were predicted to be toxic in at most one of the 13 endpoints (12 Tox21 plus ClinTox), comparable to FDA-approved drugs.</p>
<h2 id="generated-molecules-show-favorable-binding-and-drug-like-properties">Generated Molecules Show Favorable Binding and Drug-Like Properties</h2>
<p>CogMol demonstrates that controlled latent space sampling with pre-trained protein embeddings can generate novel, drug-like molecules for unseen viral targets. The key findings are:</p>
<ol>
<li>CLaSS provides roughly 10x enrichment over random latent space sampling for molecules satisfying all three constraints (affinity, QED, selectivity).</li>
<li>Generated molecules bind favorably to druggable pockets in target protein 3D structures, even though the generation model uses only 1D sequence information.</li>
<li>Some generated SMILES matched existing PubChem molecules with known biological activity, suggesting the model identifies chemically relevant regions of molecular space.</li>
<li>The framework generalizes across targets of varying novelty, with Mpro (more similar to training proteins) yielding easier generation than NSP9 or RBD.</li>
</ol>
<p><strong>Limitations.</strong> The authors note that no wet-lab validation was performed on generated candidates. There may be divergence between ML-predicted properties and experimental measurements. The binding affinity predictor&rsquo;s accuracy is bounded by the quality and coverage of BindingDB training data. Selectivity modeling uses a random sample of off-targets rather than a pharmacologically curated panel.</p>
<p><strong>Future directions.</strong> The authors propose incorporating additional contexts beyond target protein (e.g., metabolic pathways), adding more pharmacologically relevant controls, and weighting objectives by relative importance.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>VAE pre-training</td>
          <td>MOSES/ZINC</td>
          <td>1.6M train, 176K test</td>
          <td>Publicly available benchmark</td>
      </tr>
      <tr>
          <td>VAE adaptive training</td>
          <td>BindingDB (DeepAffinity split)</td>
          <td>~27K protein-ligand pairs</td>
          <td>Curated IC50 data</td>
      </tr>
      <tr>
          <td>Protein embeddings</td>
          <td>UniRef50 via UniRep</td>
          <td>24M sequences</td>
          <td>Pre-trained, publicly available</td>
      </tr>
      <tr>
          <td>Toxicity prediction</td>
          <td>Tox21 + ClinTox</td>
          <td>12 in vitro + clinical endpoints</td>
          <td>Public benchmark datasets</td>
      </tr>
      <tr>
          <td>Docking validation</td>
          <td>PDB structures</td>
          <td>3 SARS-CoV-2 targets</td>
          <td>Public crystal structures</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>VAE architecture: SMILES encoder-decoder with diagonal Gaussian latent space, jointly trained with QED and SA regressors</li>
<li>CLaSS: rejection sampling from Gaussian mixture model of latent space with per-attribute classifiers</li>
<li>Binding affinity: regression on concatenated UniRep protein embeddings and VAE molecule embeddings</li>
<li>Selectivity: excess binding affinity over average of $k$ random off-targets</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li>SMILES VAE with adaptive pre-training (ZINC then BindingDB)</li>
<li>Multi-task toxicity classifier (MT-DNN) for Tox21 and ClinTox endpoints</li>
<li>Binding affinity predictor (latent-level for generation, SMILES-level for screening)</li>
<li>Retrosynthetic predictor based on Molecular Transformer</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Value</th>
          <th>Baseline</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Validity</td>
          <td>90%</td>
          <td>-</td>
          <td>Generated SMILES</td>
      </tr>
      <tr>
          <td>Uniqueness</td>
          <td>99%</td>
          <td>-</td>
          <td>Among valid molecules</td>
      </tr>
      <tr>
          <td>Filter pass</td>
          <td>95%</td>
          <td>-</td>
          <td>Relevant chemical filters</td>
      </tr>
      <tr>
          <td>Docking BFE &lt; -6 kcal/mol</td>
          <td>87-95%</td>
          <td>-</td>
          <td>Varies by target</td>
      </tr>
      <tr>
          <td>Synthetic feasibility</td>
          <td>85-90%</td>
          <td>78% (FDA drugs)</td>
          <td>COVID-19 targets</td>
      </tr>
      <tr>
          <td>Low toxicity (0-1 endpoints)</td>
          <td>~70% parent, ~80% metabolite</td>
          <td>Comparable to FDA drugs</td>
          <td>MT-DNN prediction</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>The paper does not specify GPU types or training times. The work was funded internally by IBM Research.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/IBM/CogMol">CogMol (GitHub)</a></td>
          <td>Code</td>
          <td>Apache-2.0</td>
          <td>Official implementation</td>
      </tr>
      <tr>
          <td><a href="https://github.com/IBM/CogMol">~3500 generated molecules</a></td>
          <td>Dataset</td>
          <td>Open license</td>
          <td>For three SARS-CoV-2 targets</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Chenthamarakshan, V., Das, P., Hoffman, S. C., Strobelt, H., Padhi, I., Lim, K. W., Hoover, B., Manica, M., Born, J., Laino, T., &amp; Mojsilovic, A. (2020). CogMol: Target-Specific and Selective Drug Design for COVID-19 Using Deep Generative Models. <em>Advances in Neural Information Processing Systems</em>, 33, 4320-4332.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{chenthamarakshan2020cogmol,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{CogMol: Target-Specific and Selective Drug Design for COVID-19 Using Deep Generative Models}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Chenthamarakshan, Vijil and Das, Payel and Hoffman, Samuel C. and Strobelt, Hendrik and Padhi, Inkit and Lim, Kar Wai and Hoover, Benjamin and Manica, Matteo and Born, Jannis and Laino, Teodoro and Mojsilovi{\&#39;c}, Aleksandra}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Advances in Neural Information Processing Systems}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{33}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{4320--4332}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2020}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>ChemLLMBench: Benchmarking LLMs on Chemistry Tasks</title><link>https://hunterheidenreich.com/notes/chemistry/llm-applications/chemllmbench-eight-chemistry-tasks/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/llm-applications/chemllmbench-eight-chemistry-tasks/</guid><description>ChemLLMBench evaluates five LLMs across eight chemistry tasks covering understanding, reasoning, and explaining, finding GPT-4 leads but struggles with SMILES.</description><content:encoded><![CDATA[<h2 id="a-benchmark-resource-for-llm-chemistry-evaluation">A Benchmark Resource for LLM Chemistry Evaluation</h2>
<p>This is a <strong>Resource</strong> paper that introduces ChemLLMBench, a comprehensive benchmark for evaluating large language models on practical chemistry tasks. The primary contribution is the systematic design of eight chemistry tasks organized around three fundamental capabilities (understanding, reasoning, and explaining) along with a standardized evaluation framework that includes prompt templates, in-context learning strategies, and comparison against domain-specific baselines. The benchmark provides the first broad-scope assessment of general-purpose LLMs on chemistry problems, establishing baseline performance levels across multiple models and task types.</p>
<h2 id="why-benchmark-llms-for-chemistry">Why Benchmark LLMs for Chemistry?</h2>
<p>At the time of this work, large language models had demonstrated broad reasoning capabilities across many domains, but their application to practical chemistry tasks remained underexplored. Prior studies (e.g., Nascimento and Pimentel, 2023; Jablonka et al., 2023; White et al., 2023) had examined LLMs on specific chemistry case studies, but no comprehensive or systematic evaluation existed. Two challenges motivated this benchmark:</p>
<ol>
<li>Chemistry encompasses diverse task types that require different capabilities. Some tasks can be formulated as problems that LLMs can address (classification, text generation), while others demand deep understanding of molecular representations that LLMs may lack.</li>
<li>Reliable evaluation requires careful standardization of prompts, demonstration examples, and evaluation procedures. The stochastic nature of LLM outputs and the cost of API calls further constrain experimental design.</li>
</ol>
<p>The authors, a joint team of AI researchers and chemists at Notre Dame (including the NSF Center for Computer Assisted Synthesis, C-CAS), designed this benchmark to clarify where LLMs are useful for chemistry practitioners and where they fall short.</p>
<h2 id="eight-tasks-across-three-chemistry-capabilities">Eight Tasks Across Three Chemistry Capabilities</h2>
<p>The benchmark organizes eight tasks into three capability categories:</p>
<p><strong>Understanding</strong> tasks test whether LLMs can interpret molecular representations:</p>
<ul>
<li><strong>Name prediction</strong>: Translation between <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a>, <a href="https://en.wikipedia.org/wiki/IUPAC_nomenclature_of_organic_chemistry">IUPAC names</a>, and molecular formulas (four subtasks)</li>
<li><strong>Property prediction</strong>: Binary classification on five <a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a> datasets (BBBP, HIV, BACE, Tox21, ClinTox)</li>
</ul>
<p><strong>Reasoning</strong> tasks require knowledge of chemical reactions and transformations:</p>
<ul>
<li><strong>Yield prediction</strong>: Binary classification of high/low yield on <a href="https://en.wikipedia.org/wiki/Buchwald%E2%80%93Hartwig_amination">Buchwald-Hartwig</a> and <a href="https://en.wikipedia.org/wiki/Suzuki_reaction">Suzuki-Miyaura</a> HTE datasets</li>
<li><strong>Reaction prediction</strong>: Generating product SMILES from reactants/reagents (USPTO-Mixed)</li>
<li><strong>Reagents selection</strong>: Ranking candidate reactants, solvents, or ligands (Suzuki HTE dataset)</li>
<li><strong><a href="https://en.wikipedia.org/wiki/Retrosynthetic_analysis">Retrosynthesis</a></strong>: Predicting reactant SMILES from a target product (USPTO-50k)</li>
</ul>
<p><strong>Explaining</strong> tasks leverage LLMs&rsquo; natural language capabilities:</p>
<ul>
<li><strong>Text-based molecule design</strong>: Generating SMILES from a textual molecular description (ChEBI-20)</li>
<li><strong>Molecule captioning</strong>: Generating textual descriptions of molecules from SMILES (ChEBI-20)</li>
</ul>
<p>Each task uses 100 test instances randomly sampled from established datasets, with evaluations repeated five times to account for LLM output variability.</p>
<h2 id="evaluation-framework-and-in-context-learning-design">Evaluation Framework and In-Context Learning Design</h2>
<h3 id="models-evaluated">Models evaluated</h3>
<p>Five LLMs were tested: GPT-4, GPT-3.5 (ChatGPT), Davinci-003, Llama2-13B-chat, and <a href="/notes/chemistry/llm-applications/galactica-large-language-model-for-science/">Galactica</a>-30B.</p>
<h3 id="prompt-design">Prompt design</h3>
<p>The authors developed a standardized zero-shot prompt template instructing the LLM to act as &ldquo;an expert chemist&rdquo; with task-specific input/output descriptions. For in-context learning (ICL), they designed a four-part template: {General Template}{Task-Specific Template}{ICL}{Question}. The task-specific template includes input explanations, output explanations, and output restrictions to reduce hallucinations.</p>
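<p>The four-part layout can be sketched as simple string assembly. The wording of each part, the <code>Input:</code>/<code>Output:</code> labels, and the function name are hypothetical; only the ordering {General Template}{Task-Specific Template}{ICL}{Question} comes from the description above.</p>

```python
def build_icl_prompt(general, task_specific, demonstrations, question):
    """Join the four template parts in order; an empty demonstration list gives a zero-shot prompt."""
    icl_block = "\n".join(f"Input: {x}\nOutput: {y}" for x, y in demonstrations)
    parts = (general, task_specific, icl_block, question)
    return "\n\n".join(part for part in parts if part)

prompt = build_icl_prompt(
    general="You are an expert chemist.",  # hypothetical wording
    task_specific=(
        "Given a molecule SMILES, predict whether it penetrates the blood-brain barrier. "
        "Answer only Yes or No."  # output restriction, hypothetical wording
    ),
    demonstrations=[("CCO", "Yes"), ("CC(=O)O", "No")],  # toy examples, not real labels
    question="Input: c1ccccc1\nOutput:",
)
print(prompt)
```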
<h3 id="icl-strategies">ICL strategies</h3>
<p>Two retrieval strategies were explored for selecting demonstration examples:</p>
<ul>
<li><strong>Random</strong>: Randomly selecting k examples from the candidate pool</li>
<li><strong>Scaffold</strong>: Finding the top-k most similar examples using <a href="https://en.wikipedia.org/wiki/Jaccard_index">Tanimoto similarity</a> on Morgan fingerprints (for SMILES inputs) or sequence matching (for text inputs)</li>
</ul>
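<p>The similarity-based retrieval above can be sketched without any chemistry toolkit by treating a fingerprint as a set of on-bit indices. In practice the fingerprints would come from a cheminformatics library such as RDKit; the toy sets, example ids, and function names below are hypothetical.</p>

```python
from difflib import SequenceMatcher

def tanimoto(fp_a, fp_b):
    """Tanimoto (Jaccard) similarity between two fingerprints given as sets of on-bit indices."""
    a, b = set(fp_a), set(fp_b)
    union = len(a.union(b))
    return len(a.intersection(b)) / union if union else 0.0

def text_similarity(a, b):
    """Sequence-matching similarity for text inputs, via the standard library."""
    return SequenceMatcher(None, a, b).ratio()

def top_k_similar(query, pool, k, similarity):
    """Return the ids of the k pool entries most similar to the query.

    pool maps an example id to its representation (fingerprint or text).
    """
    ranked = sorted(pool, key=lambda name: similarity(query, pool[name]), reverse=True)
    return ranked[:k]

# Toy fingerprints: each set holds the indices of bits that are switched on.
pool = {"ex-1": {1, 2, 3}, "ex-2": {1, 2}, "ex-3": {9}}
print(top_k_similar({1, 2, 3}, pool, k=2, similarity=tanimoto))
```

<p>The random strategy corresponds to replacing the ranking step with a uniform draw of k ids from the same pool.</p>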
<p>The number of examples k was varied per task (typically k in {4, 5, 8, 10, 20}). A validation set of 30 instances was used to select the best five configurations, which were then applied to the test set.</p>
<h3 id="results-summary">Results summary</h3>
<p>The authors classify LLM performance into three categories:</p>
<table>
  <thead>
      <tr>
          <th>Category</th>
          <th>Tasks</th>
          <th>Key Observation</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Not Competitive (NC)</td>
          <td>Name prediction, Reaction prediction, Retrosynthesis</td>
          <td>LLMs lack deep understanding of SMILES strings; 70% lower accuracy than <a href="/notes/chemistry/molecular-design/generation/autoregressive/chemformer/">Chemformer</a> on reaction prediction</td>
      </tr>
      <tr>
          <td>Competitive (C)</td>
          <td>Yield prediction, Reagents selection</td>
          <td>Classification/ranking formulations are more tractable; GPT-4 reaches 80% accuracy on Buchwald-Hartwig yield prediction vs. 96.5% for UAGNN</td>
      </tr>
      <tr>
          <td>Selectively Competitive (SC)</td>
          <td>Property prediction, Molecule design, Molecule captioning</td>
          <td>Performance depends heavily on prompt design; GPT-4 outperforms RF/XGBoost on HIV and ClinTox when property label semantics are included in prompts</td>
      </tr>
  </tbody>
</table>
<p>GPT-4 ranked first on 6 of 8 tasks by average performance, with an overall average rank of 1.25 across all tasks.</p>
<h3 id="key-findings-on-icl">Key findings on ICL</h3>
<p>Three consistent observations emerged across tasks:</p>
<ol>
<li>ICL prompting outperforms zero-shot prompting on all tasks</li>
<li>Scaffold-based retrieval of similar examples generally outperforms random sampling</li>
<li>Using more ICL examples (larger k) typically improves performance</li>
</ol>
<h3 id="smiles-vs-selfies-comparison">SMILES vs. SELFIES comparison</h3>
<p>The authors tested <a href="/notes/chemistry/molecular-representations/notations/selfies/">SELFIES</a> representations as an alternative to SMILES on four tasks. SMILES outperformed SELFIES on all tasks, likely because LLM pretraining data contains more SMILES-related content. However, SELFIES produced fewer invalid molecular strings, consistent with its design guarantee of chemical validity.</p>
<h2 id="key-findings-and-limitations">Key Findings and Limitations</h2>
<h3 id="performance-patterns">Performance patterns</h3>
<p>The benchmark reveals a clear performance hierarchy: GPT-4 outperforms all others, followed by Davinci-003 and GPT-3.5 (roughly comparable), with Llama2-13B-chat and Galactica-30B trailing well behind. The ranking is consistent across most tasks.</p>
<p>LLMs perform best when chemistry tasks can be cast as classification or ranking problems rather than generation tasks requiring precise SMILES output. Text-related tasks (molecule captioning, property prediction with label semantics) also play to LLM strengths.</p>
<h3 id="fundamental-limitation-smiles-understanding">Fundamental limitation: SMILES understanding</h3>
<p>The paper identifies a core limitation: LLMs process SMILES strings as ordinary text, and <a href="https://en.wikipedia.org/wiki/Byte-pair_encoding">byte-pair encoding</a> tokenization splits them into subword tokens that fragment molecular structure information. Specific issues include:</p>
<ul>
<li>Inability to infer implicit hydrogen atoms</li>
<li>Failure to recognize equivalent SMILES representations of the same molecule</li>
<li>Tokenization that breaks SMILES into subwords not aligned with chemical substructures</li>
<li>Generation of chemically invalid SMILES (up to 27.8% invalid for Llama2-13B-chat on reaction prediction)</li>
</ul>
<h3 id="hallucination-in-chemistry">Hallucination in chemistry</h3>
<p>Two types of hallucinations were identified:</p>
<ol>
<li><strong>Input hallucinations</strong>: Misinterpreting SMILES input (e.g., failing to count atoms or recognize functional groups)</li>
<li><strong>Output hallucinations</strong>: Generating chemically unreasonable molecules when SMILES output is required</li>
</ol>
<h3 id="evaluation-metric-limitations">Evaluation metric limitations</h3>
<p>The authors note that standard NLP metrics (BLEU, ROUGE) do not fully capture chemical correctness. For molecule design, exact match is a more meaningful metric than BLEU, yet GPT-4 achieves only 17.4% exact match despite a BLEU score of 0.816. This highlights the need for chemistry-specific evaluation metrics.</p>
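<p>The gap between string-overlap scores and exact match is easy to reproduce on a toy pair. The sketch below uses the standard library&rsquo;s <code>SequenceMatcher</code> ratio as a stand-in for an overlap metric; it is not BLEU, and the comparison is on raw strings with no SMILES canonicalization. The example pair is made up for illustration.</p>

```python
from difflib import SequenceMatcher

def exact_match(predicted, target):
    """Strict string equality (no canonicalization, which a real pipeline would likely add)."""
    return predicted == target

def overlap_ratio(predicted, target):
    """Character-level overlap in [0, 1]; a stand-in for n-gram metrics, not BLEU itself."""
    return SequenceMatcher(None, predicted, target).ratio()

target = "CC(=O)Oc1ccccc1C(=O)O"     # aspirin
predicted = "CC(=O)Oc1ccccc1C(=O)N"  # last atom changed: an amide, a different molecule

print(exact_match(predicted, target), round(overlap_ratio(predicted, target), 3))
```

<p>A single-character change yields a different compound, yet the overlap score stays high, which is the same pattern as a high BLEU paired with a low exact-match rate.</p>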
<h3 id="future-directions">Future directions</h3>
<p>The authors suggest several promising directions: advanced prompting techniques (chain-of-thought, decomposed prompting), coupling LLMs with chemistry-specific tools (e.g., RDKit), and developing chemistry-aware ICL methods for higher-quality demonstration examples.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Understanding</td>
          <td>PubChem</td>
          <td>630 molecules</td>
          <td>Name prediction (500 ICL candidates, 30 validation, 100 test)</td>
      </tr>
      <tr>
          <td>Understanding</td>
          <td>BBBP, HIV, BACE, Tox21, ClinTox (MoleculeNet)</td>
          <td>2,053-41,127 ICL candidates</td>
          <td>Property prediction, MIT license</td>
      </tr>
      <tr>
          <td>Reasoning</td>
          <td>Buchwald-Hartwig, Suzuki-Miyaura (HTE)</td>
          <td>3,957 / 5,650</td>
          <td>Yield prediction, MIT license</td>
      </tr>
      <tr>
          <td>Reasoning</td>
          <td>USPTO-Mixed</td>
          <td>409,035 ICL candidates</td>
          <td>Reaction prediction, MIT license</td>
      </tr>
      <tr>
          <td>Reasoning</td>
          <td>Suzuki HTE</td>
          <td>5,760</td>
          <td>Reagents selection, MIT license</td>
      </tr>
      <tr>
          <td>Reasoning</td>
          <td>USPTO-50k</td>
          <td>40,029 ICL candidates</td>
          <td>Retrosynthesis, MIT license</td>
      </tr>
      <tr>
          <td>Explaining</td>
          <td>ChEBI-20</td>
          <td>26,407 ICL candidates</td>
          <td>Molecule design and captioning, CC BY 4.0</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Zero-shot and few-shot ICL prompting with standardized templates</li>
<li>Scaffold-based retrieval using Tanimoto similarity on 2048-bit Morgan fingerprints (radius=2)</li>
<li>Text similarity via Python&rsquo;s difflib.SequenceMatcher</li>
<li>Grid search over k and retrieval strategies on a 30-instance validation set</li>
<li>Five repeated evaluations per task configuration to account for LLM stochasticity</li>
</ul>
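<p>The configuration search can be sketched as a small grid search. The scoring function here is a made-up stand-in for running an LLM on the 30-instance validation set, and the function and variable names are hypothetical.</p>

```python
from itertools import product

def select_top_configs(evaluate, ks, strategies, n_keep=5):
    """Score every (strategy, k) pair and keep the n_keep best by validation score."""
    scored = [((strategy, k), evaluate(strategy, k)) for strategy, k in product(strategies, ks)]
    scored.sort(key=lambda item: item[1], reverse=True)
    return scored[:n_keep]

def toy_validation_score(strategy, k):
    """Hypothetical score: more examples help, and scaffold retrieval adds a small bonus."""
    return k / 100 + (0.03 if strategy == "scaffold" else 0.0)

best = select_top_configs(toy_validation_score, ks=[4, 5, 8, 10, 20], strategies=["random", "scaffold"])
for config, score in best:
    print(config, round(score, 2))
```

<p>Only the retained configurations would then be run on the 100 test instances, which keeps API costs bounded.</p>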
<h3 id="models">Models</h3>
<p>Five LLMs evaluated: GPT-4, GPT-3.5-turbo, text-davinci-003, Llama2-13B-chat, and Galactica-30B. Baselines include Chemformer (reaction prediction, retrosynthesis), UAGNN (yield prediction), MolT5-Large (molecule design, captioning), <a href="/notes/chemistry/molecular-representations/name-translation/stout/">STOUT</a> (name prediction), and RF/XGBoost from MoleculeNet (property prediction).</p>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li>Accuracy and F1 score for classification tasks (property prediction, yield prediction)</li>
<li>Top-1 accuracy and invalid SMILES rate for generation tasks (reaction prediction, retrosynthesis)</li>
<li>BLEU, exact match, <a href="https://en.wikipedia.org/wiki/Levenshtein_distance">Levenshtein distance</a>, validity, fingerprint Tanimoto similarity (MACCS, RDK, Morgan), and <a href="/notes/chemistry/molecular-design/generation/evaluation/frechet-chemnet-distance/">FCD</a> for molecule design</li>
<li>BLEU-2, BLEU-4, ROUGE-1/2/L, and METEOR for molecule captioning</li>
<li>All evaluations repeated 5 times; mean and standard deviation reported</li>
</ul>
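<p>Aggregating the five repeated runs is a one-liner with the standard library. The accuracy values below are invented for illustration, and the use of the sample standard deviation (rather than the population one) is an assumption, since the summary does not say which the paper reports.</p>

```python
import statistics

def summarize_runs(scores):
    """Return (mean, sample standard deviation) over repeated evaluation runs."""
    return statistics.mean(scores), statistics.stdev(scores)

# Hypothetical accuracies from five repeated evaluations of one configuration.
accuracies = [0.78, 0.80, 0.82, 0.80, 0.80]
mean_acc, std_acc = summarize_runs(accuracies)
print(f"{mean_acc:.3f} +/- {std_acc:.3f}")
```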
<h3 id="hardware">Hardware</h3>
<p>Not specified in the paper. Evaluation was conducted via API calls for GPT models; local inference details for Llama and Galactica are not provided.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/ChemFoundationModels/ChemLLMBench">ChemLLMBench</a></td>
          <td>Code</td>
          <td>Not specified</td>
          <td>Official benchmark code and prompts (Jupyter notebooks)</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Guo, T., Guo, K., Nan, B., Liang, Z., Guo, Z., Chawla, N. V., Wiest, O., &amp; Zhang, X. (2023). What can Large Language Models do in chemistry? A comprehensive benchmark on eight tasks. <em>Advances in Neural Information Processing Systems 36 (NeurIPS 2023)</em>, 59662-59688.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{guo2023chemllmbench,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{What can Large Language Models do in chemistry? A comprehensive benchmark on eight tasks}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Guo, Taicheng and Guo, Kehan and Nan, Bozhao and Liang, Zhenwen and Guo, Zhichun and Chawla, Nitesh V. and Wiest, Olaf and Zhang, Xiangliang}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Advances in Neural Information Processing Systems 36 (NeurIPS 2023)}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{59662--59688}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Chemical Language Models for De Novo Drug Design Review</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/evaluation/clms-de-novo-drug-design-review/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/evaluation/clms-de-novo-drug-design-review/</guid><description>Review of chemical language models for de novo drug design covering string representations, architectures, training strategies, and experimental validation.</description><content:encoded><![CDATA[<h2 id="a-systematization-of-chemical-language-models-for-drug-design">A Systematization of Chemical Language Models for Drug Design</h2>
<p>This paper is a <strong>Systematization</strong> (minireview) that surveys the landscape of chemical language models (CLMs) for de novo drug design. It organizes the field along three axes: molecular string representations, deep learning architectures, and generation strategies (distribution learning, goal-directed, and conditional). The review also highlights experimental validations, current gaps, and future opportunities.</p>
<h2 id="why-chemical-language-models-matter-for-drug-design">Why Chemical Language Models Matter for Drug Design</h2>
<p>De novo drug design faces an enormous combinatorial challenge: the &ldquo;chemical universe&rdquo; is estimated to contain up to $10^{60}$ drug-like small molecules. Exhaustive enumeration is infeasible, and traditional design algorithms rely on hand-crafted assembly rules. Chemical language models address this by borrowing natural language processing techniques to learn the &ldquo;chemical language,&rdquo; generating molecules as string representations (<a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a>, <a href="/notes/chemistry/molecular-representations/notations/selfies/">SELFIES</a>, DeepSMILES) that satisfy both syntactic validity (chemically valid structures) and semantic correctness (desired pharmacological properties).</p>
<p>CLMs have gained traction because string representations are readily available for most molecular databases, generation is computationally cheap (each molecule is a single autoregressive sampling run through a sequence model, one token at a time), and the same architecture can be applied to diverse tasks (property prediction, de novo generation, reaction prediction). At the time of this review, CLMs had produced experimentally validated bioactive molecules in several prospective studies, establishing them as practical tools for drug discovery.</p>
<h2 id="molecular-string-representations-smiles-deepsmiles-and-selfies">Molecular String Representations: SMILES, DeepSMILES, and SELFIES</h2>
<p>The review covers three main string representations used as input/output for CLMs:</p>
<p><strong>SMILES</strong> (Simplified Molecular Input Line Entry System) converts hydrogen-depleted molecular graphs into strings where atoms are denoted by atomic symbols, bonds and branching by punctuation, and ring openings/closures by numbers. SMILES are non-univocal (multiple valid strings per molecule), and canonicalization algorithms are needed for unique representations. Multiple studies show that using randomized (non-canonical) SMILES for data augmentation improves CLM performance, with diminishing returns beyond 10- to 20-fold augmentation.</p>
<p><strong><a href="/notes/chemistry/molecular-representations/notations/deepsmiles-adaptation-for-ml/">DeepSMILES</a></strong> modifies SMILES to improve machine-readability by replacing the paired ring-opening/closure digits with a count-based system and using closing parentheses only (no opening ones). This reduces the frequency of syntactically invalid strings but does not eliminate them entirely.</p>
<p><strong><a href="/notes/chemistry/molecular-representations/notations/selfies/">SELFIES</a></strong> (Self-Referencing Embedded Strings) use a formal grammar that guarantees 100% syntactic validity of decoded molecules. Every SELFIES string maps to a valid molecular graph. However, SELFIES can produce chemically unrealistic molecules (e.g., highly strained ring systems), and the mapping between string edits and molecular changes is less intuitive than for SMILES.</p>
<p>The review notes a key tradeoff: SMILES offer a richer, more interpretable language with well-studied augmentation strategies, while SELFIES guarantee validity at the cost of chemical realism and edit interpretability.</p>
<h2 id="clm-architectures-and-training-strategies">CLM Architectures and Training Strategies</h2>
<h3 id="architectures">Architectures</h3>
<p>The review describes the main architectures used in CLMs:</p>
<p><strong>Recurrent Neural Networks (RNNs)</strong>, particularly LSTMs and GRUs, dominated early CLM work. These models process SMILES character-by-character and generate new strings autoregressively via next-token prediction. RNNs are computationally efficient and well-suited to the sequential nature of molecular strings.</p>
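<p>To make the autoregressive loop concrete, here is a minimal sketch in which a character-bigram count table stands in for the trained RNN. The training strings, the start and end markers, and the function names are illustrative assumptions, not anything prescribed by the review. A real CLM would replace the count table with a learned network and would usually tokenize multi-character atoms such as Cl and Br as single tokens.</p>

```python
import random
from collections import Counter, defaultdict

BOS, EOS = "^", "\n"  # assumed start/end markers, not part of SMILES


def fit_bigram(smiles_list):
    """Count next-character frequencies; a toy stand-in for an RNN."""
    counts = defaultdict(Counter)
    for smi in smiles_list:
        chars = [BOS] + list(smi) + [EOS]
        for prev, nxt in zip(chars, chars[1:]):
            counts[prev][nxt] += 1
    return counts


def sample(counts, rng, max_len=40):
    """Generate one string by repeated next-token sampling."""
    out, prev = [], BOS
    for _ in range(max_len):
        options = counts[prev]
        # Draw the next character in proportion to its observed frequency.
        nxt = rng.choices(list(options), weights=list(options.values()))[0]
        if nxt == EOS:
            break
        out.append(nxt)
        prev = nxt
    return "".join(out)


model = fit_bigram(["CCO", "CCN", "CCCO", "c1ccccc1"])
rng = random.Random(0)
samples = [sample(model, rng) for _ in range(5)]
```

<p>Because a bigram table only sees one character of context, many of its samples will not be valid SMILES (for example, unclosed rings). That is the gap a recurrent model with longer memory is meant to close.</p>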
<p><strong><a href="/notes/machine-learning/generative-models/autoencoding-variational-bayes/">Variational Autoencoders (VAEs)</a></strong> encode molecules into a continuous latent space and decode them back into strings. This enables smooth interpolation between molecules and latent-space optimization, but generated strings may be syntactically invalid.</p>
<p><strong><a href="/posts/what-is-a-gan/">Generative Adversarial Networks (GANs)</a></strong> have been adapted for molecular string generation (e.g., <a href="/notes/chemistry/molecular-design/generation/rl-tuned/organ-objective-reinforced-gan/">ORGAN</a>), though they face training instability and mode collapse challenges that limit their adoption.</p>
<p><strong>Transformers</strong> have emerged as an increasingly popular alternative, offering parallelized training and the ability to capture long-range dependencies in molecular strings. The review notes the growing relevance of Transformer-based CLMs, particularly for large-scale pretraining.</p>
<h3 id="generation-strategies">Generation Strategies</h3>
<p>The review organizes CLM generation into three categories:</p>
<ol>
<li>
<p><strong>Distribution learning</strong>: The model learns to reproduce the statistical distribution of a training set of molecules. No explicit scoring function is used during generation. The generated molecules are evaluated post-hoc by comparing their property distributions to the training set. This approach is end-to-end but provides no direct indication of individual molecule quality.</p>
</li>
<li>
<p><strong>Goal-directed generation</strong>: A pretrained CLM is steered toward molecules optimizing a specified scoring function (e.g., predicted bioactivity, physicochemical properties). Common approaches include reinforcement learning (<a href="/notes/chemistry/molecular-design/generation/rl-tuned/reinvent-deep-rl-molecular-design/">REINVENT</a> and variants), hill-climbing, and Bayesian optimization. Scoring functions provide direct quality signals but can introduce biases, shortcuts, and limited structural diversity.</p>
</li>
<li>
<p><strong>Conditional generation</strong>: An intermediate approach that learns a joint semantic space between molecular structures and desired properties. The desired property profile serves as an input &ldquo;prompt&rdquo; for generation (e.g., a protein target, gene expression signature, or 3D shape). This bypasses the need for external scoring functions but has seen limited experimental application.</p>
</li>
</ol>
<h3 id="transfer-learning-and-chemical-space-exploration">Transfer Learning and Chemical Space Exploration</h3>
<p>Transfer learning is the dominant paradigm for CLM-driven chemical space exploration. A large-scale pretraining step (on $10^5$ to $10^6$ molecules via next-character prediction) is followed by fine-tuning on a smaller set of molecules with desired properties (often 10 to $10^2$ molecules). Key findings from the literature:</p>
<ul>
<li>The minimum training set size depends on target molecule complexity and heterogeneity.</li>
<li>SMILES augmentation is most beneficial with small training sets (fewer than 10,000 molecules) and plateaus for large, structurally complex datasets.</li>
<li>Fine-tuning with as few as 10 to 100 molecules has produced experimentally validated bioactive designs.</li>
<li>Hyperparameter tuning has relatively little effect on overall CLM performance.</li>
</ul>
<h2 id="evaluating-clm-designs-and-experimental-validation">Evaluating CLM Designs and Experimental Validation</h2>
<p>The review identifies evaluation as a critical gap. CLMs are often benchmarked on &ldquo;toy&rdquo; properties such as calculated logP, molecular weight, or QED (quantitative estimate of drug-likeness). These metrics capture the ability to satisfy predefined criteria but fail to reflect real-world drug discovery complexity and may lead to trivial solutions.</p>
<p>Existing benchmarks (<a href="/notes/chemistry/molecular-design/generation/evaluation/guacamol-benchmarking-de-novo-molecular-design/">GuacaMol</a>, <a href="/notes/chemistry/molecular-design/generation/evaluation/molecular-sets-moses/">MOSES</a>) enable comparability across independently developed approaches but do not fully address the quality of generated compounds. The review emphasizes that experimental validation is the ultimate test. At the time of writing, only a few prospective applications had been published:</p>
<ul>
<li>Dual modulator of <a href="https://en.wikipedia.org/wiki/Retinoid_X_receptor">retinoid X</a> and <a href="https://en.wikipedia.org/wiki/Peroxisome_proliferator-activated_receptor">PPAR</a> receptors (EC50 ranging from 0.06 to 2.3 μM)</li>
<li>Inhibitor of <a href="https://en.wikipedia.org/wiki/Pim_kinase">Pim1 kinase</a> and <a href="https://en.wikipedia.org/wiki/Cyclin-dependent_kinase_4">CDK4</a> (manually modified from generated design)</li>
<li>Natural-product-inspired <a href="https://en.wikipedia.org/wiki/RAR-related_orphan_receptor_gamma">RORgamma</a> agonist (EC50 = 0.68 μM)</li>
<li>Molecules designed via combined generative AI and on-chip synthesis</li>
</ul>
<p>The scarcity of experimental validations reflects the interdisciplinary expertise required and the time/cost of chemical synthesis.</p>
<h2 id="gaps-limitations-and-future-directions">Gaps, Limitations, and Future Directions</h2>
<p>The review identifies several key gaps and opportunities:</p>
<p><strong>Scoring function limitations</strong>: Current scoring functions struggle with activity cliffs and non-additive structure-activity relationships. Conditional generation methods may help overcome these limitations by learning direct structure-property mappings.</p>
<p><strong>Structure-based design</strong>: Generating molecules that match electrostatic and shape features of protein binding pockets holds promise for addressing unexplored targets. However, prospective applications have been limited, potentially due to bias in existing protein-ligand affinity datasets.</p>
<p><strong>Synthesizability</strong>: Improving the ability of CLMs to propose synthesizable molecules is expected to increase practical relevance. Automated synthesis platforms may help but could also limit accessible chemical space.</p>
<p><strong>Few-shot learning</strong>: Large-scale pretrained CLMs combined with few-shot learning approaches are expected to boost prospective applications.</p>
<p><strong>Extensions beyond small molecules</strong>: Extending chemical languages to more complex molecular entities (proteins with non-natural amino acids, crystals, supramolecular chemistry) is an open frontier.</p>
<p><strong>Failure modes</strong>: Several studies have documented failure modes in goal-directed generation, including model shortcuts (exploiting scoring function artifacts), limited structural diversity, and generation of chemically unrealistic molecules.</p>
<p><strong>Interdisciplinary collaboration</strong>: The review emphasizes that bridging deep learning, cheminformatics, and medicinal chemistry expertise is essential for translating CLM designs into real-world drug candidates.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>This is a review paper and does not present novel experimental data. The paper surveys results from the literature.</p>
<h3 id="algorithms">Algorithms</h3>
<p>No novel algorithms are introduced. The review categorizes existing approaches (RNNs, VAEs, GANs, Transformers) and generation strategies (distribution learning, goal-directed, conditional).</p>
<h3 id="models">Models</h3>
<p>No new models are presented. The paper references existing implementations including REINVENT, ORGAN, and various RNN-based and Transformer-based CLMs.</p>
<h3 id="evaluation">Evaluation</h3>
<p>The review discusses existing benchmarks:</p>
<ul>
<li><strong><a href="/notes/chemistry/molecular-design/generation/evaluation/guacamol-benchmarking-de-novo-molecular-design/">GuacaMol</a></strong>: Benchmarking suite for de novo molecular design</li>
<li><strong><a href="/notes/chemistry/molecular-design/generation/evaluation/molecular-sets-moses/">MOSES</a></strong>: Benchmarking platform for molecular generation models</li>
<li><strong>QED</strong>: Quantitative estimate of drug-likeness</li>
<li>Various physicochemical property metrics (logP, molecular weight)</li>
</ul>
<h3 id="hardware">Hardware</h3>
<p>Not applicable (review paper).</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Grisoni, F. (2023). Chemical language models for de novo drug design: Challenges and opportunities. <em>Current Opinion in Structural Biology</em>, 79, 102527. <a href="https://doi.org/10.1016/j.sbi.2023.102527">https://doi.org/10.1016/j.sbi.2023.102527</a></p>
<p><strong>Publication</strong>: Current Opinion in Structural Biology, Volume 79, April 2023</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{grisoni2023chemical,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Chemical language models for de novo drug design: Challenges and opportunities}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Grisoni, Francesca}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Current Opinion in Structural Biology}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{79}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{102527}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Elsevier}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1016/j.sbi.2023.102527}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>CDDD: Learning Descriptors by Translating SMILES</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/cddd-translation-molecular-descriptors/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/cddd-translation-molecular-descriptors/</guid><description>CDDD learns continuous molecular descriptors by translating between SMILES and InChI representations, outperforming fingerprints in virtual screening.</description><content:encoded><![CDATA[<h2 id="a-translation-based-method-for-learned-molecular-descriptors">A Translation-Based Method for Learned Molecular Descriptors</h2>
<p>This is a <strong>Method</strong> paper that introduces Continuous and Data-Driven Descriptors (CDDD), a neural machine translation approach for learning fixed-size, continuous molecular representations. Rather than training an autoencoder to reconstruct <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> strings, Winter et al. train an encoder-decoder model to translate between semantically equivalent but syntactically different molecular representations (e.g., randomized SMILES to canonical SMILES, or <a href="/notes/chemistry/molecular-representations/notations/inchi-2013/">InChI</a> to canonical SMILES). The bottleneck latent vector serves as a general-purpose molecular descriptor. Pretrained on approximately 72 million compounds from <a href="/notes/chemistry/datasets/zinc-22/">ZINC15</a> and PubChem, CDDD produces 512-dimensional descriptors that achieve competitive QSAR performance and significantly outperform all tested molecular fingerprints in ligand-based virtual screening.</p>
<h2 id="why-translation-instead-of-reconstruction">Why Translation Instead of Reconstruction?</h2>
<p>Molecular descriptors are central to cheminformatics. Traditional approaches rely on human-engineered fingerprints like ECFPs, which encode structural features as fixed-length bit vectors. While effective, these representations are constrained by predefined feature extraction rules.</p>
<p>Recent work applied deep neural networks directly to molecular graphs or SMILES strings to learn task-specific representations. However, these end-to-end approaches must learn features from scratch for each new dataset, making them prone to overfitting on the small bioactivity datasets typical in drug discovery.</p>
<p>Unsupervised approaches based on autoencoders (notably <a href="/notes/chemistry/molecular-design/generation/latent-space/automatic-chemical-design-vae/">Gomez-Bombarelli et al.&rsquo;s VAE</a> and <a href="/notes/chemistry/molecular-representations/encoders/seq2seq-fingerprint-molecular-embedding/">Xu et al.&rsquo;s seq2seq model</a>) offered a path toward general-purpose learned descriptors. These models reconstruct SMILES strings through an information bottleneck, forcing the latent space to capture molecular information. The concern with reconstruction, however, is that the model may focus on syntactic patterns of the string representation rather than the underlying chemical semantics. A model that memorizes SMILES syntax shortcuts can achieve low reconstruction error without truly encoding chemical meaning.</p>
<p>Winter et al. address this by drawing on the analogy to neural machine translation: a translator must understand the meaning of a sentence to produce a correct translation in another language. By training the model to translate between different molecular representations (which share chemical semantics but differ in syntax), the latent space is forced to capture the chemical information common to both representations, rather than representation-specific syntactic artifacts.</p>
<h2 id="translation-as-semantic-compression">Translation as Semantic Compression</h2>
<p>The core insight is that translating between two syntactically different but semantically equivalent representations forces the encoder to capture only the chemical meaning shared by both. The model architecture follows the standard encoder-decoder framework from neural machine translation.</p>
<p>The encoder reads a source molecular string (e.g., a randomized SMILES or InChI) and compresses it into a fixed-size latent vector. The decoder takes this latent vector and generates the target molecular string (canonical SMILES). The model is trained to minimize character-level cross-entropy between the decoder output and the target sequence.</p>
<p>Four translation tasks were evaluated:</p>
<ol>
<li><strong>Randomized SMILES to canonical SMILES</strong> (best performing)</li>
<li><strong>InChI to canonical SMILES</strong></li>
<li><strong>Canonical SMILES to canonical SMILES</strong> (autoencoding baseline)</li>
<li><strong>Canonical SMILES to InChI</strong> (failed to learn)</li>
</ol>
<p>The final model uses an RNN encoder with 3 stacked GRU layers (512, 1024, and 2048 units). The concatenated cell states pass through a fully connected layer with tanh activation to produce a 512-dimensional latent vector. The decoder mirrors this architecture, initializing its GRU states from the latent vector via separate fully connected layers. Teacher forcing is used during training, and left-to-right beam search is used at inference.</p>
<p>An auxiliary property prediction network takes the latent vector as input and predicts nine molecular properties (logP, partial charges, valence electrons, H-bond donors/acceptors, Balaban&rsquo;s J, <a href="https://en.wikipedia.org/wiki/Molar_refractivity">molar refractivity</a>, TPSA). This multi-task signal encourages the latent space to encode physically meaningful information. The full training objective combines the translation cross-entropy loss with the property prediction mean squared error:</p>
<p>$$\mathcal{L} = \mathcal{L}_{\text{translation}} + \mathcal{L}_{\text{properties}}$$</p>
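<p>As a rough illustration of how the two terms combine, the sketch below computes a character-level cross-entropy and a property mean squared error on toy values and adds them without weights, as in the equation above. Averaging the cross-entropy over characters, the dictionary-based probability format, and all function names are assumptions made for the example rather than details taken from the paper.</p>

```python
import math


def cross_entropy(pred_dists, target):
    """Mean character-level cross-entropy.

    pred_dists: one dict per position mapping character to probability.
    target: the target string (e.g. a canonical SMILES).
    """
    total = 0.0
    for dist, ch in zip(pred_dists, target):
        total -= math.log(dist[ch])
    return total / len(target)


def mse(pred, true):
    """Mean squared error over the predicted property vector."""
    return sum((p - t) ** 2 for p, t in zip(pred, true)) / len(true)


def cddd_loss(pred_dists, target, prop_pred, prop_true):
    # Unweighted sum of both terms; any per-term weighting or
    # normalization is not specified here and is left out.
    return cross_entropy(pred_dists, target) + mse(prop_pred, prop_true)


# Toy decoder output for the target "CO" over a three-character alphabet.
dists = [{"C": 0.8, "O": 0.1, "N": 0.1}, {"C": 0.2, "O": 0.7, "N": 0.1}]
loss = cddd_loss(dists, "CO", prop_pred=[0.5, 1.0], prop_true=[0.0, 1.0])
```

<p>In the full model the property vector would have nine entries, and both terms would be backpropagated through the shared encoder, which is what ties the latent vector to physically meaningful quantities.</p>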
<p>To ensure invariance to input SMILES representation at inference time, the model uses randomized SMILES as input half the time and canonical SMILES the other half during training. Input dropout (15% at the character level) and Gaussian noise (standard deviation 0.05) are applied for regularization.</p>
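<p>The input-side choices described above can be sketched as three small helpers. How the dropout is realized (here, replacing a character with a placeholder) and where the Gaussian noise is injected are not stated in this summary, so both are assumptions, as are the function names and the placeholder symbol.</p>

```python
import random


def pick_input(canonical, randomized, rng):
    """Use the randomized SMILES half the time, the canonical otherwise."""
    return randomized if rng.random() < 0.5 else canonical


def drop_characters(tokens, rate, rng, pad="?"):
    """Replace each token with a placeholder with probability `rate`.

    Masking with a placeholder is an assumption; zeroing the one-hot
    vector of the character would be an equally plausible reading.
    """
    return [pad if rng.random() < rate else t for t in tokens]


def add_gaussian_noise(vector, std, rng):
    """Add zero-mean Gaussian noise to each component of a vector."""
    return [v + rng.gauss(0.0, std) for v in vector]


rng = random.Random(0)
source = pick_input("CCO", "OCC", rng)  # two SMILES for the same molecule
noisy_tokens = drop_characters(list(source), rate=0.15, rng=rng)
noisy_vector = add_gaussian_noise([0.0, 1.0, 0.0], std=0.05, rng=rng)
```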
<h2 id="qsar-benchmarks-virtual-screening-and-latent-space-exploration">QSAR Benchmarks, Virtual Screening, and Latent Space Exploration</h2>
<h3 id="pretraining">Pretraining</h3>
<p>The model was pretrained on approximately 72 million compounds from ZINC15 and PubChem (merged, deduplicated, filtered for organic molecules with MW 12-600, &gt;3 heavy atoms, logP between -7 and 5). All evaluation compounds were removed from the pretraining set.</p>
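<p>A minimal sketch of the filtering step, applied to precomputed properties, might look as follows. The dictionary keys, the example values, and the choice of inclusive bounds are assumptions for illustration. In practice the properties would be computed with a cheminformatics toolkit, and the organic-molecule check is treated here as a precomputed flag.</p>

```python
def passes_pretraining_filter(mol):
    """Apply the size and lipophilicity cutoffs to precomputed properties.

    `mol` is a plain dict with hypothetical keys. Inclusive bounds on
    molecular weight and logP are an assumption.
    """
    return (
        mol["is_organic"]
        and 12 <= mol["mol_weight"] <= 600
        and mol["heavy_atoms"] > 3
        and -7 <= mol["logp"] <= 5
    )


# Illustrative values only, not taken from any dataset.
candidates = [
    {"name": "too-small", "is_organic": True, "mol_weight": 46.1,
     "heavy_atoms": 3, "logp": -0.1},
    {"name": "drug-like", "is_organic": True, "mol_weight": 310.4,
     "heavy_atoms": 22, "logp": 2.8},
    {"name": "too-greasy", "is_organic": True, "mol_weight": 450.0,
     "heavy_atoms": 32, "logp": 7.5},
]
kept = [m["name"] for m in candidates if passes_pretraining_filter(m)]
```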
<h3 id="qsar-experiments">QSAR Experiments</h3>
<p>Ten QSAR datasets were used, spanning classification (<a href="https://en.wikipedia.org/wiki/Ames_test">Ames mutagenicity</a>, <a href="https://en.wikipedia.org/wiki/KCNH2">hERG inhibition</a>, <a href="https://en.wikipedia.org/wiki/Blood%E2%80%93brain_barrier">BBB penetration</a>, BACE inhibition, bee toxicity) and regression (EGFR inhibition, <a href="https://en.wikipedia.org/wiki/Plasmodium_falciparum">Plasmodium falciparum</a> inhibition, lipophilicity, aqueous solubility, melting point). Two datasets (Ames and lipophilicity) served as validation for architecture selection; the remaining eight were held out for final evaluation.</p>
<p>CDDD descriptors with an SVM were benchmarked against:</p>
<ul>
<li>Nine circular fingerprint variants (Morgan fingerprints, radius 1-3, folded to 512/1024/2048 bits) with RF, SVM, and GB</li>
<li>Graph convolution models (<a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">DeepChem</a>)</li>
</ul>
<p>Both random-split and cluster-split (K-means on MACCS fingerprints, K=5) cross-validation were performed.</p>
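<p>One way to read the cluster-split protocol is as leave-one-cluster-out cross-validation, where each of the five clusters serves once as the test fold. The sketch below assumes that reading and takes the cluster labels as given; the clustering itself (K-means on MACCS fingerprints) is not shown, and the function name is hypothetical.</p>

```python
def cluster_cv_folds(cluster_ids):
    """Yield (train_idx, test_idx) pairs, holding out one cluster per fold.

    `cluster_ids[i]` is the cluster label of compound i, for example
    from K-means with K=5 on MACCS fingerprints (clustering not shown).
    """
    for held_out in sorted(set(cluster_ids)):
        test = [i for i, c in enumerate(cluster_ids) if c == held_out]
        train = [i for i, c in enumerate(cluster_ids) if c != held_out]
        yield train, test


labels = [0, 1, 0, 2, 1, 3, 4, 4]  # toy cluster assignments
folds = list(cluster_cv_folds(labels))
```

<p>Compared with a random split, this keeps structurally similar compounds on the same side of the train/test boundary, so the score reflects generalization to unseen chemotypes rather than interpolation between near-duplicates.</p>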
<table>
  <thead>
      <tr>
          <th>Task</th>
          <th>Split</th>
          <th>CDDD + SVM</th>
          <th>Best Fingerprint</th>
          <th>Graph Conv</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Ames (ROC-AUC)</td>
          <td>Random</td>
          <td>0.89</td>
          <td>0.89 (ecfc2, RF)</td>
          <td>0.88</td>
      </tr>
      <tr>
          <td>hERG (ROC-AUC)</td>
          <td>Random</td>
          <td>0.86</td>
          <td>0.85 (ecfc4, RF)</td>
          <td>0.86</td>
      </tr>
      <tr>
          <td>BBBP (ROC-AUC)</td>
          <td>Random</td>
          <td>0.93</td>
          <td>0.93 (ecfc2, RF)</td>
          <td>0.92</td>
      </tr>
      <tr>
          <td>BACE (ROC-AUC)</td>
          <td>Random</td>
          <td>0.90</td>
          <td>0.91 (ecfc2, RF)</td>
          <td>0.91</td>
      </tr>
      <tr>
          <td>Bee toxicity (ROC-AUC)</td>
          <td>Random</td>
          <td>0.92</td>
          <td>0.91 (ecfc6, RF)</td>
          <td>0.89</td>
      </tr>
      <tr>
          <td>Lipophilicity ($r^2$)</td>
          <td>Random</td>
          <td>0.72</td>
          <td>0.69 (ecfc2, SVM)</td>
          <td>0.73</td>
      </tr>
      <tr>
          <td>ESOL ($r^2$)</td>
          <td>Random</td>
          <td>0.92</td>
          <td>0.58 (ecfc6, SVM)</td>
          <td>0.86</td>
      </tr>
      <tr>
          <td>Melting point ($r^2$)</td>
          <td>Random</td>
          <td>0.42</td>
          <td>0.38 (ecfc2, SVM)</td>
          <td>0.39</td>
      </tr>
  </tbody>
</table>
<p>CDDD descriptors showed competitive or better performance across all tasks. Notably, CDDD achieved substantially higher $r^2$ on aqueous solubility (0.92 vs. 0.58 for the best fingerprint). The authors emphasize that CDDD&rsquo;s feature extraction was fixed based on two validation tasks, while baseline methods selected the best fingerprint/model combination per task, making the comparison conservative for CDDD.</p>
<h3 id="virtual-screening">Virtual Screening</h3>
<p>Ligand-based virtual screening experiments followed the Riniker et al. benchmarking protocol on 40 DUD targets and 17 MUV targets. Five active compounds were randomly selected per target, and remaining compounds were ranked by similarity (cosine similarity for CDDD, <a href="https://en.wikipedia.org/wiki/Jaccard_index">Tanimoto</a> for fingerprints). This process was repeated 50 times per target.</p>
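<p>The ranking step of this protocol can be sketched in a few lines. Cosine similarity is used for real-valued descriptors and Tanimoto for fingerprints stored as sets of on-bit indices. Scoring each library compound by its maximum similarity to any of the query actives is an assumed fusion rule, and the toy vectors and names are illustrative.</p>

```python
import math


def cosine(u, v):
    """Cosine similarity between two real-valued descriptors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def tanimoto(bits_a, bits_b):
    """Tanimoto (Jaccard) similarity between sets of on-bit indices."""
    union = len(bits_a.union(bits_b))
    return len(bits_a.intersection(bits_b)) / union if union else 0.0


def rank_by_similarity(queries, library, sim):
    """Rank library ids by their best similarity to any query active.

    Taking the maximum over the query actives is an assumed fusion rule.
    """
    scores = {
        name: max(sim(q, x) for q in queries) for name, x in library.items()
    }
    return sorted(scores, key=scores.get, reverse=True)


actives = [[1.0, 0.0], [0.8, 0.6]]
library = {"a": [0.9, 0.1], "b": [0.0, 1.0], "c": [-1.0, 0.0]}
ranking = rank_by_similarity(actives, library, cosine)
```

<p>The resulting ranking is what the ROC-AUC is computed on: a good descriptor places the remaining actives near the top of the list.</p>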
<table>
  <thead>
      <tr>
          <th>Database</th>
          <th>CDDD (ROC-AUC)</th>
          <th>Second Best</th>
          <th>p-value (Wilcoxon)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>DUD</td>
          <td>0.949</td>
          <td>0.899 (laval)</td>
          <td>$5 \times 10^{-38}$</td>
      </tr>
      <tr>
          <td>MUV</td>
          <td>0.679</td>
          <td>0.677 (ap)</td>
          <td>0.04</td>
      </tr>
  </tbody>
</table>
<p>CDDD significantly outperformed all 14 baseline fingerprints on both databases. The DUD improvement was particularly large (+0.050 ROC-AUC over the next best, 0.949 vs. 0.899). On MUV, which is designed to be harder, the advantage was smaller but still statistically significant. Importantly, while the best baseline fingerprint varied between DUD and MUV (laval vs. ap), CDDD ranked first on both, demonstrating consistent performance.</p>
<h3 id="latent-space-exploration">Latent Space Exploration</h3>
<p>The continuous, reversible nature of CDDD enables chemical space navigation. Shifting a molecule&rsquo;s embedding along the first principal component of the pretraining data correlates with molecular size (Spearman $r = 0.947$, $p = 0.00048$), while the second principal component correlates with polarity/logP ($r = -0.916$, $p = 0.00015$).</p>
<p>When shifting 1000 compounds along 100 random directions, the model maintained high valid SMILES generation rates (&gt;97% for the top beam search output, &gt;99% when considering the top 3 outputs). Euclidean distance in the descriptor space correlated smoothly with Tanimoto distance in fingerprint space, confirming that the latent space supports meaningful interpolation.</p>
<h2 id="consistent-learned-descriptors-for-chemistry">Consistent Learned Descriptors for Chemistry</h2>
<p>CDDD demonstrated that translation between molecular representations produces more informative latent spaces than autoencoder reconstruction. The key findings are:</p>
<ol>
<li><strong>Translation outperforms reconstruction</strong>: Models trained on translating between different representations consistently produced better downstream descriptors than autoencoding models, despite autoencoding being an easier task.</li>
<li><strong>Auxiliary property prediction helps</strong>: The auxiliary property-prediction task (a regression trained with mean squared error) improved descriptor quality, particularly for physicochemical endpoints correlated with the predicted properties.</li>
<li><strong>Consistent performance</strong>: Unlike baseline methods where the best fingerprint varies by task, CDDD showed consistent performance across all QSAR and VS experiments.</li>
<li><strong>Smooth latent space</strong>: The continuous descriptor space supports meaningful interpolation and chemical space exploration with high valid SMILES rates.</li>
</ol>
<p>The authors acknowledge several limitations. The InChI-to-SMILES translation worked but produced inferior descriptors compared to randomized-to-canonical SMILES translation, and canonical-SMILES-to-InChI translation failed entirely, likely due to InChI&rsquo;s more complex syntax, which involves counting and arithmetic. The approach was only tested with string-based representations; translation between conceptually different representations (e.g., 3D structures) remains future work. The QSAR evaluation, while extensive, used relatively standard datasets, and the method&rsquo;s advantage over graph convolution models was modest on tasks where end-to-end learning had sufficient data.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pretraining</td>
          <td>ZINC15 + PubChem (merged)</td>
          <td>~72M compounds</td>
          <td>Filtered: organic, MW 12-600, &gt;3 heavy atoms, logP -7 to 5</td>
      </tr>
      <tr>
          <td>Validation</td>
          <td>Ames mutagenicity</td>
          <td>6,130</td>
          <td>Classification</td>
      </tr>
      <tr>
          <td>Validation</td>
          <td>Lipophilicity</td>
          <td>3,817</td>
          <td>Regression</td>
      </tr>
      <tr>
          <td>Test</td>
          <td>hERG, BBBP, BACE, bee toxicity</td>
          <td>188-3,440</td>
          <td>Classification</td>
      </tr>
      <tr>
          <td>Test</td>
          <td>EGFR, Plasmodium, ESOL, melting point</td>
          <td>184-4,451</td>
          <td>Regression</td>
      </tr>
      <tr>
          <td>VS</td>
          <td>DUD</td>
          <td>40 targets</td>
          <td>Ligand-based virtual screening</td>
      </tr>
      <tr>
          <td>VS</td>
          <td>MUV</td>
          <td>17 targets</td>
          <td>Maximum unbiased validation</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Encoder: 3 stacked GRU layers (512, 1024, 2048 units) with tanh bottleneck to 512-dim latent space</li>
<li>Decoder: Matching 3 stacked GRU layers, initialized from latent space</li>
<li>Auxiliary property predictor: 3 FC layers (512, 128, 9) predicting the nine molecular properties</li>
<li>Optimizer: Adam, initial LR $5 \times 10^{-4}$, decayed by 0.9 every 50,000 steps</li>
<li>Batch size: 64 with bucketing by sequence length</li>
<li>Input regularization: 15% character dropout + Gaussian noise (std 0.05)</li>
<li>Beam search for decoding at inference</li>
</ul>
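<p>The learning-rate schedule listed above can be written as a one-line function. Treating the decay as a staircase (the rate drops only at multiples of 50,000 steps) is an assumption; the function and parameter names are hypothetical.</p>

```python
def learning_rate(step, initial=5e-4, decay=0.9, every=50_000):
    """Step-wise exponential decay of the learning rate.

    A staircase schedule is an assumption; a continuous decay with the
    same rate would use `step / every` instead of integer division.
    """
    return initial * decay ** (step // every)


schedule = {s: learning_rate(s) for s in (0, 49_999, 50_000, 100_000)}
```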
<h3 id="models">Models</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/jrwnter/cddd">CDDD (GitHub)</a></td>
          <td>Code + Model</td>
          <td>MIT</td>
          <td>Pretrained model and extraction code</td>
      </tr>
  </tbody>
</table>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li>QSAR: 5-fold random CV and 5-fold cluster CV (K-means on MACCS, K=5)</li>
<li>Classification metric: ROC-AUC</li>
<li>Regression metric: $r^2$</li>
<li>VS: ROC-AUC averaged over 50 random active set selections per target</li>
<li>Statistical test: <a href="https://en.wikipedia.org/wiki/Wilcoxon_signed-rank_test">Wilcoxon signed-rank test</a> for VS comparisons</li>
</ul>
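<p>ROC-AUC, the classification and virtual screening metric listed above, has a simple rank interpretation: it is the probability that a randomly chosen active is scored above a randomly chosen inactive. The sketch below computes it directly from that definition with the standard library. The function name <code>roc_auc</code> is an illustrative choice, and the quadratic pairwise loop is written for clarity, not for large screening libraries.</p>

```python
def roc_auc(labels, scores):
    """Probability that a random active outranks a random inactive (ties count half)."""
    positives = [s for y, s in zip(labels, scores) if y == 1]
    negatives = [s for y, s in zip(labels, scores) if y == 0]
    if not positives or not negatives:
        raise ValueError("need at least one active and one inactive")
    wins = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(positives) * len(negatives))


# Toy labels and scores, made up for illustration.
print(roc_auc([1, 0, 1, 0], [0.9, 0.8, 0.2, 0.1]))
```

<p>For the virtual screening protocol, a value like this would be averaged over the 50 random active set selections per target.</p>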
<h3 id="hardware">Hardware</h3>
<ul>
<li>Framework: TensorFlow 1.4.1</li>
<li>Fingerprint extraction on GPU is comparable in speed to RDKit on CPU</li>
<li>SVM training on 512-dim CDDD descriptors takes seconds (vs. minutes for 2048-dim fingerprints)</li>
<li>Graph convolution training: ~30 minutes per task on GPU</li>
</ul>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Winter, R., Montanari, F., Noé, F., &amp; Clevert, D.-A. (2019). Learning continuous and data-driven molecular descriptors by translating equivalent chemical representations. <em>Chemical Science</em>, 10(6), 1692-1701. <a href="https://doi.org/10.1039/C8SC04175J">https://doi.org/10.1039/C8SC04175J</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{winter2019learning,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Learning continuous and data-driven molecular descriptors by translating equivalent chemical representations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Winter, Robin and Montanari, Floriane and No{\&#39;e}, Frank and Clevert, Djork-Arn{\&#39;e}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Chemical Science}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{10}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{6}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{1692--1701}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2019}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Royal Society of Chemistry}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1039/C8SC04175J}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Avoiding Failure Modes in Goal-Directed Generation</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/evaluation/avoiding-failure-modes-goal-directed-generation/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/evaluation/avoiding-failure-modes-goal-directed-generation/</guid><description>Langevin et al. show that apparent failure modes in goal-directed molecular generation stem from QSAR model disagreement, not algorithmic flaws.</description><content:encoded><![CDATA[<h2 id="reinterpreting-goal-directed-generation-failures-as-qsar-model-issues">Reinterpreting Goal-Directed Generation Failures as QSAR Model Issues</h2>
<p>This is an <strong>Empirical</strong> study that challenges a widely cited finding about failure modes in goal-directed molecular generation. <a href="/notes/chemistry/molecular-design/generation/evaluation/failure-modes-molecule-generation/">Renz et al. (2019)</a> had shown that when molecules are optimized against a machine learning scoring function, control models trained on the same data distribution assign much lower scores to the generated molecules. This was interpreted as evidence that generation algorithms exploit model-specific biases. Langevin et al. demonstrate that this divergence is already present in the original data distribution and is attributable to disagreement among the QSAR classifiers, not to flaws in the generation algorithms themselves.</p>
<h2 id="why-qsar-model-agreement-matters-for-molecular-generation">Why QSAR Model Agreement Matters for Molecular Generation</h2>
<p>Goal-directed generation uses a scoring function (typically a <a href="https://en.wikipedia.org/wiki/Quantitative_structure%E2%80%93activity_relationship">QSAR</a> model) to guide the design of molecules that maximize predicted activity. In the experimental framework from Renz et al., three Random Forest classifiers are trained: an optimization model $C_{opt}$ on Split 1, a model control $C_{mc}$ on Split 1 with a different random seed, and a data control $C_{dc}$ on Split 2. Each returns a confidence score ($S_{opt}$, $S_{mc}$, $S_{dc}$). The expectation is that molecules with high $S_{opt}$ should also score highly under $S_{mc}$ and $S_{dc}$, since all three models are trained on the same data distribution for the same target.</p>
<p>Renz et al. observed that during optimization, $S_{mc}$ and $S_{dc}$ diverge from $S_{opt}$, reaching substantially lower values. This was interpreted as goal-directed generation exploiting biases unique to the optimization model. The recommendation was to halt generation when control scores stop increasing, requiring a held-out dataset for a control model, which may not be feasible in low-data regimes.</p>
<p>The key insight of Langevin et al. is that nobody had checked whether this score disagreement existed before generation even began. If the classifiers already disagree on high-scoring molecules in the original dataset, the divergence during generation is expected behavior, not evidence of algorithmic failure.</p>
<h2 id="pre-existing-classifier-disagreement-explains-the-divergence">Pre-Existing Classifier Disagreement Explains the Divergence</h2>
<p>The core contribution is showing that the gap between optimization and control scores is a property of the QSAR models, not of the generation algorithms.</p>
<p>The authors introduce a held-out test set (10% of the data, used for neither training split) and augment it via Topliss tree enumeration to produce structural analogs for smoother statistical estimates. On this held-out set, they compute the Mean Average Difference (MAD) between $S_{opt}$ and control scores as a function of $S_{opt}$:</p>
<p>$$
\text{MAD}(x) = \frac{1}{|\{i : S_{opt}(x_i) \geq x\}|} \sum_{S_{opt}(x_i) \geq x} |S_{opt}(x_i) - S_{dc}(x_i)|
$$</p>
<p>On the three original datasets (DRD2, EGFR, JAK2), the MAD between $S_{opt}$ and $S_{dc}$ grows substantially with $S_{opt}$, reaching approximately 0.3 for the highest-scoring molecules. For EGFR, even the top molecules (with $S_{opt}$ between 0.5 and 0.6) have $S_{dc}$ below 0.2. This disagreement exists entirely within the original data distribution, before any generative algorithm is applied.</p>
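<p>The MAD definition translates directly into code. The sketch below is a minimal standard-library version; the function name and the toy scores are illustrative assumptions and are not taken from the paper or its repository.</p>

```python
def mean_average_difference(s_opt, s_ctrl, threshold):
    """Mean |S_opt - S_ctrl| over molecules whose S_opt is at least `threshold`."""
    gaps = [abs(o - c) for o, c in zip(s_opt, s_ctrl) if o >= threshold]
    if not gaps:
        return float("nan")  # no molecule reaches this optimization score
    return sum(gaps) / len(gaps)


# Toy scores (made up for illustration, not taken from the paper).
s_opt = [0.1, 0.2, 0.5, 0.6]
s_dc = [0.1, 0.1, 0.2, 0.2]
print(mean_average_difference(s_opt, s_dc, 0.0))
print(mean_average_difference(s_opt, s_dc, 0.5))
```

<p>Sweeping <code>threshold</code> from 0 to the maximum $S_{opt}$ gives the MAD curve. In the toy data the disagreement grows as the threshold rises, which is the qualitative pattern reported for DRD2, EGFR, and JAK2.</p>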
<p>The authors formalize this with tolerance intervals. At each generation time step $t$, the distribution of optimization scores is $P_t[S_{opt}(x)]$. From the held-out set, the conditional distributions $P[S_{dc}(x) | S_{opt}(x)]$ and $P[S_{mc}(x) | S_{opt}(x)]$ are estimated empirically. The expected control scores at time $t$ are then:</p>
<p>$$
\mathbb{E}[S_{dc}] = \int P[S_{dc}(x) | S_{opt}(x)] \cdot P_t[S_{opt}(x)] \, dS_{opt}
$$</p>
<p>By sampling from these distributions, the authors construct 95% tolerance intervals for the expected control scores at each time step. The observed trajectories of $S_{mc}$ and $S_{dc}$ during generation fall within these intervals, demonstrating that the divergence is fully explained by pre-existing classifier disagreement.</p>
<h2 id="experimental-setup-original-reproduction-and-corrected-experiments">Experimental Setup: Original Reproduction and Corrected Experiments</h2>
<h3 id="reproduction-of-renz-et-al">Reproduction of Renz et al.</h3>
<p>The original experimental framework uses three datasets from ChEMBL: <a href="https://en.wikipedia.org/wiki/Dopamine_receptor_D2">DRD2</a> (842 molecules, 59 actives), <a href="https://en.wikipedia.org/wiki/Epidermal_growth_factor_receptor">EGFR</a> (842 molecules, 40 actives), and <a href="https://en.wikipedia.org/wiki/Janus_kinase_2">JAK2</a> (667 molecules, 140 actives). These are small, noisy, and chemically diverse. Three goal-directed generation algorithms are tested:</p>
<table>
  <thead>
      <tr>
          <th>Algorithm</th>
          <th>Type</th>
          <th>Mechanism</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Graph GA</td>
          <td>Genetic algorithm on molecular graphs</td>
          <td>Mutation and crossover of molecular graphs</td>
      </tr>
      <tr>
          <td>SMILES-LSTM</td>
          <td>Recurrent neural network</td>
          <td>Hill-climbing fine-tuning on best molecules</td>
      </tr>
      <tr>
          <td>MSO</td>
          <td>Particle swarm in CDDD latent space</td>
          <td>Multiple swarm optimization</td>
      </tr>
  </tbody>
</table>
<p>All algorithms are run for 151 epochs with 10 runs each. The reproduction confirms the findings of Renz et al.: $S_{mc}$ and $S_{dc}$ diverge from $S_{opt}$ during optimization.</p>
<h3 id="tolerance-interval-analysis">Tolerance interval analysis</h3>
<p>The held-out set is augmented using Topliss tree enumeration on phenyl rings, providing structural analogs that are reasonable from a medicinal chemistry perspective. The optimization score range is divided into 25 equal bins, and for each molecule at each time step, 10 samples from the conditional control score distribution are drawn to construct empirical tolerance intervals.</p>
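<p>The binning and resampling procedure can be sketched as follows. This is an illustration under stated assumptions, not the authors' implementation: it assumes scores lie in [0, 1], pools all sampled control scores for a time step, and reads the interval off empirical percentiles. The helper names (<code>bin_index</code>, <code>conditional_table</code>, <code>tolerance_interval</code>) and the handling of empty bins are hypothetical.</p>

```python
import random


def bin_index(score, n_bins=25):
    """Map a score in [0, 1] to one of n_bins equal-width bins."""
    return min(int(score * n_bins), n_bins - 1)


def conditional_table(held_opt, held_ctrl, n_bins=25):
    """Group held-out control scores by the bin of their optimization score."""
    table = {}
    for o, c in zip(held_opt, held_ctrl):
        table.setdefault(bin_index(o, n_bins), []).append(c)
    return table


def tolerance_interval(population_opt, table, n_samples=10, coverage=0.95,
                       n_bins=25, rng=None):
    """Empirical interval for control scores given a population's S_opt values."""
    rng = rng or random.Random(0)
    draws = []
    for o in population_opt:
        pool = table.get(bin_index(o, n_bins))
        if pool:  # skip bins with no held-out molecules (an assumption of this sketch)
            draws.extend(rng.choice(pool) for _ in range(n_samples))
    if not draws:
        raise ValueError("no held-out molecules share a bin with the population")
    draws.sort()
    tail = (1.0 - coverage) / 2.0
    lo = draws[int(tail * (len(draws) - 1))]
    hi = draws[int(round((1.0 - tail) * (len(draws) - 1)))]
    return lo, hi
```

<p>Calling <code>tolerance_interval</code> once per generation step, with that step's population of $S_{opt}$ values, yields a band against which the observed $S_{mc}$ and $S_{dc}$ trajectories can be compared.</p>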
<h3 id="corrected-experiments-with-adequate-models">Corrected experiments with adequate models</h3>
<p>To test whether generation algorithms actually exploit biases when the classifiers agree, the authors construct two tasks where optimization and control models correlate well:</p>
<ol>
<li><strong>ALDH1 dataset</strong>: 464 molecules from LIT-PCBA, split using similarity-based pairing to maximize intra-pair chemical similarity. This ensures both splits sample similar chemistry.</li>
<li><strong>Modified JAK2</strong>: The same JAK2 dataset but with Random Forest hyperparameters adjusted (200 trees instead of 100, minimum 3 samples per leaf instead of 1) to reduce overfitting to spurious correlations.</li>
</ol>
<p>In both cases, $S_{opt}$, $S_{mc}$, and $S_{dc}$ agree well on the held-out test set. The starting population for generation is set to the held-out test set (rather than random ChEMBL molecules) to avoid building in a distribution shift.</p>
<h2 id="findings-no-algorithmic-failure-when-models-agree">Findings: No Algorithmic Failure When Models Agree</h2>
<p>On the corrected experimental setups (ALDH1 and modified JAK2), there is no major divergence between optimization and control scores during generation. The three algorithms produce molecules that score similarly under all three classifiers.</p>
<p>Key findings:</p>
<ol>
<li>
<p><strong>Pre-existing disagreement explains divergence</strong>: On all three original datasets, the divergence between $S_{opt}$ and control scores during generation falls within the tolerance intervals predicted from the initial data distribution alone. The generation algorithms are not exploiting model-specific biases beyond what already exists in the data.</p>
</li>
<li>
<p><strong>Split similarity bias is also pre-existing</strong>: Renz et al. observed that generated molecules are more similar to Split 1 (used to train $C_{opt}$) than to Split 2. The authors show this bias is already present in the top 5% of the held-out set: on EGFR and DRD2, high-scoring molecules are inherently more similar to Split 1.</p>
</li>
<li>
<p><strong>Appropriate model design resolves the issue</strong>: When Random Forest hyperparameters are chosen to avoid overfitting (more trees, higher minimum samples per leaf), or when data splits are constructed to be chemically balanced, the classifiers agree and the generation algorithms behave as expected.</p>
</li>
<li>
<p><strong>Quality problems remain independent</strong>: Even when optimization and control scores align, the generated molecules can still be poor drug candidates (unstable, unsynthesizable, or containing unusual fragments). The score divergence issue and the chemical quality issue are separate problems.</p>
</li>
</ol>
<h3 id="limitations-acknowledged-by-the-authors">Limitations acknowledged by the authors</h3>
<ul>
<li>The study focuses on Random Forest classifiers with ECFP fingerprints. The behavior of other model types (e.g., graph neural networks) and descriptor types is not fully explored, though supplementary results show similar patterns with physico-chemical descriptors and Atom-Pair fingerprints.</li>
<li>The corrected ALDH1 task uses a relatively small dataset (464 molecules) with careful split construction. Scaling this approach to larger, more heterogeneous datasets is not demonstrated.</li>
<li>The authors note that their results do not prove generation algorithms never exploit biases; they show that the specific evidence from Renz et al. can be explained without invoking algorithmic failure.</li>
<li>The problem of low-quality generated molecules (poor synthesizability, unusual fragments) remains unresolved and is acknowledged as an open question.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Original tasks</td>
          <td>DRD2, EGFR, JAK2</td>
          <td>842, 842, 667 molecules</td>
          <td>Extracted from ChEMBL; small with few actives</td>
      </tr>
      <tr>
          <td>New task</td>
          <td>ALDH1</td>
          <td>464 molecules (173 with purine substructure)</td>
          <td>Extracted from LIT-PCBA; similarity-based split</td>
      </tr>
      <tr>
          <td>Augmentation</td>
          <td>Topliss tree analogs</td>
          <td>~10x augmentation of held-out set</td>
          <td>Structural analogs via phenyl ring enumeration</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<p>Three goal-directed generation algorithms from the original Renz et al. study:</p>
<ul>
<li><strong>Graph GA</strong>: Genetic algorithm on molecular graphs (Jensen, 2019)</li>
<li><strong>SMILES-LSTM</strong>: Hill-climbing on LSTM-generated SMILES (Segler et al., 2018)</li>
<li><strong>MSO</strong>: Molecular Swarm Optimization, a particle swarm search in CDDD latent space (Winter et al., 2019)</li>
</ul>
<p>All run for 151 epochs, 10 runs each.</p>
<h3 id="models">Models</h3>
<p>Random Forest classifiers (scikit-learn) with:</p>
<ul>
<li>ECFP fingerprints (radius 2, 1024 bits, RDKit)</li>
<li>Default parameters for original tasks</li>
<li>Modified parameters for JAK2 correction: 200 trees, min 3 samples per leaf</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Purpose</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Mean Average Difference (MAD)</td>
          <td>Measures disagreement between optimization and control scores</td>
          <td>Computed as function of $S_{opt}$ on held-out set</td>
      </tr>
      <tr>
          <td>95% tolerance intervals</td>
          <td>Expected range of control scores given optimization scores</td>
          <td>Empirical, constructed from held-out set</td>
      </tr>
      <tr>
          <td>Tanimoto similarity</td>
          <td>Split bias assessment</td>
          <td>Morgan fingerprints, radius 2, 1024 bits</td>
      </tr>
      <tr>
          <td>ROC-AUC</td>
          <td>Classifier predictive performance</td>
          <td>Used to verify models have comparable accuracy</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Not specified in the paper.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/Sanofi-Public/IDD-papers-avoiding_failure_modes">Code and datasets</a></td>
          <td>Code</td>
          <td>Apache-2.0</td>
          <td>Fork of Renz et al. codebase with modifications</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Langevin, M., Vuilleumier, R., &amp; Bianciotto, M. (2022). Explaining and avoiding failure modes in goal-directed generation of small molecules. <em>Journal of Cheminformatics</em>, 14, 20. <a href="https://doi.org/10.1186/s13321-022-00601-y">https://doi.org/10.1186/s13321-022-00601-y</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{langevin2022explaining,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Explaining and avoiding failure modes in goal-directed generation of small molecules}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Langevin, Maxime and Vuilleumier, Rodolphe and Bianciotto, Marc}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Journal of Cheminformatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{14}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{20}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2022}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Springer}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1186/s13321-022-00601-y}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Atom-in-SMILES: Better Tokens for Chemical Models</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/notations/atom-in-smiles-tokenization/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/notations/atom-in-smiles-tokenization/</guid><description>Atom-in-SMILES replaces generic SMILES tokens with environment-aware atomic tokens, reducing token degeneration and improving chemical translation accuracy.</description><content:encoded><![CDATA[<h2 id="a-new-tokenization-method-for-chemical-language-models">A New Tokenization Method for Chemical Language Models</h2>
<p>This is a <strong>Method</strong> paper that introduces Atom-in-SMILES (AIS), a tokenization scheme for SMILES strings that replaces generic atomic tokens with environment-aware tokens encoding each atom&rsquo;s local chemical neighborhood. The primary contribution is demonstrating that tokenization quality has a significant impact on chemical language model outcomes across multiple tasks: SMILES canonicalization, <a href="/notes/chemistry/molecular-design/reaction-prediction/">single-step retrosynthesis</a>, and <a href="/notes/chemistry/molecular-design/property-prediction/">molecular property prediction</a>.</p>
<h2 id="why-standard-smiles-tokenization-falls-short">Why Standard SMILES Tokenization Falls Short</h2>
<p>Standard atom-wise SMILES tokenization treats all atoms of the same element identically. Every carbon is tokenized as &ldquo;C&rdquo; regardless of whether it is part of an aromatic ring, a carbonyl group, or a methyl chain. This creates a highly degenerate token space where chemically distinct atoms share the same representation.</p>
<p>The authors draw an analogy between natural language and chemical language. A typical SMILES sequence is about three times longer than a natural language sentence, yet the token vocabulary is roughly 1000 times smaller. This mismatch leads to extreme token repetition: the same tokens (C, c, N, O) appear many times within a single sequence. In natural language processing, token degeneration (where models repeatedly predict the same token) is a known failure mode of autoregressive decoders. The repetitive nature of SMILES tokens exacerbates this problem in chemical language models.</p>
<p>SMILES also lacks a one-to-one correspondence between tokens and chemical meaning. Two molecules that differ in only one atom substitution (e.g., swapping a carbon for a nitrogen in a ring) produce nearly identical token sequences under atom-wise tokenization, making it harder for models to distinguish structurally similar molecules.</p>
<h2 id="core-innovation-encoding-atom-environments-into-tokens">Core Innovation: Encoding Atom Environments into Tokens</h2>
<p>The key insight is to replace each atomic token with a richer token that encodes the atom&rsquo;s local chemical environment, inspired by the <a href="https://en.wikipedia.org/wiki/Atoms_in_molecules">atoms-in-molecules (AIM)</a> concept from quantum chemistry. For a given SMILES string, the AIS mapping function $f$ operates on the token space:</p>
<p>$$
f(X) = \begin{cases} AE|_{X_{\text{central}}} &amp; \text{if } X \text{ is an atom} \\ X &amp; \text{otherwise} \end{cases}
$$</p>
<p>where $AE|_{X_{\text{central}}}$ denotes the atomic environment centered on atom $X$. Non-atomic tokens (brackets, bond symbols, ring closures) pass through unchanged.</p>
<p>Each AIS token is formatted as <code>[Sym;Ring;Neighbors]</code> where:</p>
<ul>
<li><strong>Sym</strong> is the atomic symbol with chirality, aromaticity (lowercase for aromatic), hydrogen count, and formal charge</li>
<li><strong>Ring</strong> indicates whether the atom is in a ring (<code>R</code>) or not (<code>!R</code>)</li>
<li><strong>Neighbors</strong> lists the neighboring atoms interacting with the central atom</li>
</ul>
<p>This mapping is bijective: SMILES strings can be fully recovered from AIS strings via an inverse projection. The algorithm iterates over atoms in a molecule, computes their local environments using RDKit, and produces environment-aware token variants.</p>
<p>As a concrete example, in glycine the two carbons and two oxygens are indistinguishable under atom-wise tokenization. Under AIS, each receives a unique token reflecting its bonding environment (e.g., the carboxyl carbon is distinguished from the alpha carbon).</p>
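<p>The mapping $f$ and its inverse can be illustrated on glycine with a toy example. The AIS tokens below are typed in by hand to follow the <code>[Sym;Ring;Neighbors]</code> layout; a real tokenizer derives them from the molecular graph with RDKit, so the exact strings may differ from the reference implementation. The function names are hypothetical.</p>

```python
# Hand-written AIS tokens for glycine, NCC(=O)O. A real tokenizer derives these
# from the molecular graph (the paper uses RDKit); here they are typed in by hand
# as an illustration and may differ in detail from the reference implementation.
ATOM_WISE = ["N", "C", "C", "(", "=", "O", ")", "O"]
ENVIRONMENTS = {  # position in ATOM_WISE -> hypothetical [Sym;Ring;Neighbors] token
    0: "[NH2;!R;C]",
    1: "[CH2;!R;CN]",
    2: "[C;!R;COO]",
    5: "[O;!R;C]",
    7: "[OH;!R;C]",
}


def to_ais(tokens, environments):
    """Apply f: atoms become environment tokens, everything else passes through."""
    return [environments.get(i, tok) for i, tok in enumerate(tokens)]


def to_bracket_smiles(ais_tokens):
    """Inverse projection sketch: keep only the Sym field of each AIS token."""
    out = []
    for tok in ais_tokens:
        if tok.startswith("[") and ";" in tok:
            out.append("[" + tok[1:].split(";")[0] + "]")
        else:
            out.append(tok)
    return "".join(out)


ais = to_ais(ATOM_WISE, ENVIRONMENTS)
print(" ".join(ais))
print(to_bracket_smiles(ais))
```

<p>The atom-wise sequence contains only three distinct atom symbols for five heavy atoms, while the AIS sequence has five distinct atom tokens. Dropping the ring and neighbor fields recovers a bracket-atom SMILES for the same molecule, which is the sense in which the mapping is invertible.</p>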
<p>The AIS tokenization also exhibits a fingerprint-like property. Because each token encodes local structural information, the set of AIS tokens for a molecule functions similarly to circular fingerprints like ECFP2. The authors show that pairwise <a href="https://en.wikipedia.org/wiki/Jaccard_index">Tanimoto similarities</a> computed from AIS token sets have resolution comparable to ECFP2 and HashAP fingerprints, and better resolution than MACCS, Avalon, and RDKit fingerprints.</p>
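<p>The fingerprint-like behavior can be seen with a set-based Tanimoto similarity. The sketch below compares ethanol (<code>CCO</code>) and dimethyl ether (<code>COC</code>), two isomers that share the same atom-wise token set. The AIS tokens are hand-written approximations of the <code>[Sym;Ring;Neighbors]</code> format and are assumptions for illustration only.</p>

```python
def tanimoto(tokens_a, tokens_b):
    """Jaccard/Tanimoto similarity between two token collections treated as sets."""
    a, b = set(tokens_a), set(tokens_b)
    if not a and not b:
        return 1.0  # convention chosen here for two empty inputs
    return len(a.intersection(b)) / len(a.union(b))


# Atom-wise tokens: both isomers reduce to the set {C, O}.
ethanol_atoms = ["C", "C", "O"]
ether_atoms = ["C", "O", "C"]

# Hand-written AIS-style tokens (illustrative, not produced by a real tokenizer).
ethanol_ais = ["[CH3;!R;C]", "[CH2;!R;CO]", "[OH;!R;C]"]
ether_ais = ["[CH3;!R;O]", "[O;!R;CC]", "[CH3;!R;O]"]

print(tanimoto(ethanol_atoms, ether_atoms))
print(tanimoto(ethanol_ais, ether_ais))
```

<p>Atom-wise token sets cannot separate the two isomers, whereas the environment-aware tokens share nothing, which is the kind of added resolution the comparison with circular fingerprints refers to.</p>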
<p>Token repetition can be quantified as:</p>
<p>$$
\text{rep-}l = \sum_{t=1}^{|s|} \mathbb{1}[s_t \in s_{t-w-1:t-1}]
$$</p>
<p>where $s$ is the predicted sequence, $|s|$ is the token count, and $w$ is the window size. AIS tokens exhibit consistently lower normalized repetition rates compared to SMILES, <a href="/notes/chemistry/molecular-representations/notations/selfies/">SELFIES</a>, and <a href="/notes/chemistry/molecular-representations/notations/deepsmiles-adaptation-for-ml/">DeepSMILES</a> across diverse molecular datasets (drugs, natural products, steroids, lipids, metal complexes, octane isomers).</p>
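<p>The repetition count can be sketched in a few lines. This version assumes the slice $s_{t-w-1:t-1}$ is inclusive on both ends, so each token is compared against the $w + 1$ tokens before it; the function name <code>rep_l</code> is an illustrative choice.</p>

```python
def rep_l(tokens, window):
    """Count tokens that already occur among the preceding window + 1 tokens."""
    repeats = 0
    for t, token in enumerate(tokens):
        start = max(0, t - window - 1)  # clamp the look-back at the sequence start
        if token in tokens[start:t]:
            repeats += 1
    return repeats


tokens = list("CCOCC")  # atom-wise tokens of a toy SMILES string
print(rep_l(tokens, window=1))
print(rep_l(tokens, window=1) / len(tokens))  # length-normalized rate
```

<p>Dividing by the sequence length, as in the last line, gives a normalized rate that can be compared across molecules of different sizes.</p>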
<h2 id="experimental-evaluation-across-three-chemical-tasks">Experimental Evaluation Across Three Chemical Tasks</h2>
<h3 id="input-output-equivalent-mapping-smiles-canonicalization">Input-Output Equivalent Mapping (SMILES Canonicalization)</h3>
<p>The first task tests whether a model can translate non-canonical SMILES enumerations into canonical form. The authors constructed deliberately challenging datasets from <a href="/notes/chemistry/datasets/gdb-13/">GDB-13</a> subsets with cumulative structural constraints (no cyclic heteroatom-heteroatom bonds, stable functional groups only, fragment-like, scaffold-like, etc.), generating training sets of 1M molecules augmented with 150K molecules from the most restrictive subset at 10x, 30x, and 50x augmentation levels.</p>
<table>
  <thead>
      <tr>
          <th>GDB-13 Subset</th>
          <th>Atom-wise (x10)</th>
          <th>Atom-wise (x50)</th>
          <th>AIS (x10)</th>
          <th>AIS (x50)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ab</td>
          <td>34.2%</td>
          <td>33.2%</td>
          <td>37.3%</td>
          <td>34.1%</td>
      </tr>
      <tr>
          <td>abc</td>
          <td>31.0%</td>
          <td>29.6%</td>
          <td>33.7%</td>
          <td>30.4%</td>
      </tr>
      <tr>
          <td>abcde</td>
          <td>48.7%</td>
          <td>45.5%</td>
          <td>53.6%</td>
          <td>47.0%</td>
      </tr>
      <tr>
          <td>abcdef</td>
          <td>41.8%</td>
          <td>39.1%</td>
          <td>52.5%</td>
          <td>46.9%</td>
      </tr>
      <tr>
          <td>abcdefg</td>
          <td>50.9%</td>
          <td>50.0%</td>
          <td>59.9%</td>
          <td>56.8%</td>
      </tr>
  </tbody>
</table>
<p>AIS outperformed atom-wise tokenization on all subsets and augmentation levels. The performance gap generally widened for more restrictive (more similar) subsets, reaching 10.7 percentage points on the abcdef subset at 10x augmentation. This suggests that AIS is particularly effective when molecules are structurally similar and harder to distinguish.</p>
<h3 id="single-step-retrosynthesis">Single-Step Retrosynthesis</h3>
<p>The second task uses the USPTO-50K benchmark for single-step <a href="https://en.wikipedia.org/wiki/Retrosynthetic_analysis">retrosynthetic prediction</a> via a template-free transformer encoder-decoder model. The model was trained for 200,000 steps with Adam optimizer, negative log-likelihood loss, and cyclic learning rate scheduling.</p>
<table>
  <thead>
      <tr>
          <th>Tokenization</th>
          <th>rep-$l_{P}$ - rep-$l_{GT}$ &gt;= 2 (count)</th>
          <th>String Exact (%)</th>
          <th>Tc Exact (%)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Atom-wise baseline</td>
          <td>&ndash;</td>
          <td>42.00</td>
          <td>&ndash;</td>
      </tr>
      <tr>
          <td>Atom-wise (reproduced)</td>
          <td>801</td>
          <td>42.05</td>
          <td>44.72</td>
      </tr>
      <tr>
          <td>SmilesPE</td>
          <td>821</td>
          <td>19.82</td>
          <td>22.74</td>
      </tr>
      <tr>
          <td>SELFIES</td>
          <td>886</td>
          <td>28.82</td>
          <td>30.76</td>
      </tr>
      <tr>
          <td>DeepSMILES</td>
          <td>902</td>
          <td>38.63</td>
          <td>41.20</td>
      </tr>
      <tr>
          <td><strong>Atom-in-SMILES</strong></td>
          <td><strong>727</strong></td>
          <td><strong>46.32</strong></td>
          <td><strong>47.62</strong></td>
      </tr>
  </tbody>
</table>
<p>AIS achieved 46.32% string exact accuracy (4.3 percentage points above the reproduced atom-wise baseline) and 47.62% Tanimoto exact accuracy (2.9 percentage points above it). AIS also had the fewest degenerate token repetitions (727 vs. 801 for atom-wise), a reduction of approximately 10%. DeepSMILES had the highest repetition count (902) despite reasonable overall accuracy. SELFIES and <a href="/notes/chemistry/molecular-representations/notations/smiles-pair-encoding/">SmilesPE</a> both performed substantially worse than the atom-wise baseline on this task.</p>
<p>The authors identified six common token repetition patterns in retrosynthetic predictions: long head repetitions, long tail repetitions, repetitive rings, repetitive chains, and halogen repetitions on both aliphatic and aromatic carbons.</p>
<h3 id="molecular-property-prediction">Molecular Property Prediction</h3>
<p>The third task evaluates tokenization schemes on <a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a> benchmarks using Random Forest models with 5-fold cross-validation. AIS tokens were converted to fingerprint-like feature vectors.</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>SMILES</th>
          <th>DeepSMILES</th>
          <th>SELFIES</th>
          <th>SmilesPE</th>
          <th>AIS</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Regression (RMSE, lower is better)</strong></td>
          <td></td>
          <td></td>
          <td></td>
          <td></td>
          <td></td>
      </tr>
      <tr>
          <td>ESOL</td>
          <td>0.628</td>
          <td>0.631</td>
          <td>0.675</td>
          <td>0.689</td>
          <td><strong>0.553</strong></td>
      </tr>
      <tr>
          <td>FreeSolv</td>
          <td>0.545</td>
          <td>0.544</td>
          <td>0.564</td>
          <td>0.761</td>
          <td><strong>0.441</strong></td>
      </tr>
      <tr>
          <td>Lipophilicity</td>
          <td>0.924</td>
          <td>0.895</td>
          <td>0.938</td>
          <td>0.800</td>
          <td><strong>0.683</strong></td>
      </tr>
      <tr>
          <td><strong>Classification (ROC-AUC, higher is better)</strong></td>
          <td></td>
          <td></td>
          <td></td>
          <td></td>
          <td></td>
      </tr>
      <tr>
          <td>BBBP</td>
          <td>0.758</td>
          <td>0.777</td>
          <td>0.799</td>
          <td>0.847</td>
          <td><strong>0.885</strong></td>
      </tr>
      <tr>
          <td>BACE</td>
          <td>0.740</td>
          <td>0.774</td>
          <td>0.746</td>
          <td><strong>0.837</strong></td>
          <td>0.835</td>
      </tr>
      <tr>
          <td>HIV</td>
          <td>0.649</td>
          <td>0.648</td>
          <td>0.653</td>
          <td><strong>0.739</strong></td>
          <td>0.729</td>
      </tr>
  </tbody>
</table>
<p>AIS achieved the best performance on all three regression datasets. On classification, it was best on BBBP and within 0.01 ROC-AUC of SmilesPE, the top scorer, on BACE and HIV. On ESOL, the relative RMSE improvement over standard SMILES was 12%. On lipophilicity, the improvement was 26%.</p>
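<p>The quoted percentages are relative RMSE reductions and can be checked against the regression rows of the table. The short sketch below redoes that arithmetic; the dictionary layout and function name are illustrative choices.</p>

```python
RMSE = {  # dataset -> (standard SMILES, AIS), copied from the table above
    "ESOL": (0.628, 0.553),
    "FreeSolv": (0.545, 0.441),
    "Lipophilicity": (0.924, 0.683),
}


def relative_improvement(baseline, improved):
    """Fractional RMSE reduction relative to the baseline."""
    return (baseline - improved) / baseline


for name, (smiles_rmse, ais_rmse) in RMSE.items():
    print(f"{name}: {100 * relative_improvement(smiles_rmse, ais_rmse):.1f}%")
```

<p>The ESOL and lipophilicity values round to the 12% and 26% quoted in the text; FreeSolv, which the text does not quote, falls between them.</p>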
<h2 id="key-findings-better-tokens-yield-better-chemical-models">Key Findings: Better Tokens Yield Better Chemical Models</h2>
<p>The main findings of this work are:</p>
<ol>
<li>
<p><strong>Tokenization significantly impacts chemical language model quality.</strong> The choice of tokenization scheme can change prediction accuracy by over 10 percentage points on equivalent mapping tasks.</p>
</li>
<li>
<p><strong>AIS reduces token degeneration by approximately 10%</strong> compared to atom-wise SMILES tokenization, with consistently lower normalized repetition rates across diverse molecular datasets.</p>
</li>
<li>
<p><strong>AIS outperforms all compared tokenization schemes</strong> (atom-wise SMILES, SmilesPE, SELFIES, DeepSMILES) on canonicalization, retrosynthesis, and property prediction.</p>
</li>
<li>
<p><strong>The fingerprint-like nature of AIS tokens</strong> enables direct use as molecular features for property prediction and provides resolution comparable to established circular fingerprints.</p>
</li>
<li>
<p><strong>The mapping is invertible</strong>, so AIS strings can always be converted back to valid SMILES. This is a practical advantage over approaches that may lose structural information.</p>
</li>
</ol>
<p><strong>Limitations</strong>: AIS cannot distinguish environmentally identical substructures or atoms related by a molecular symmetry plane, since it only considers nearest-neighbor environments. Performance on long-chain molecules (e.g., lipids) is similar across all tokenization schemes, suggesting that local environment encoding is less informative for repetitive linear structures.</p>
<p><strong>Future directions</strong>: The authors suggest AIS has potential for broader adoption in molecular generative models, chemical translation, and property prediction tasks across the cheminformatics community.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Canonicalization training</td>
          <td>GDB-13 subsets</td>
          <td>1M + 150K augmented</td>
          <td>Cumulative structural constraints a-h</td>
      </tr>
      <tr>
          <td>Canonicalization testing</td>
          <td>GDB-13 disjoint test sets</td>
          <td>20K per subset</td>
          <td>Various restriction levels</td>
      </tr>
      <tr>
          <td>Retrosynthesis</td>
          <td>USPTO-50K</td>
          <td>~50K reactions</td>
          <td>Sequences &gt; 150 tokens removed</td>
      </tr>
      <tr>
          <td>Property prediction</td>
          <td>MoleculeNet (ESOL, FreeSolv, Lipophilicity, BBBP, BACE, HIV)</td>
          <td>Varies</td>
          <td>Standard benchmark splits</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Transformer encoder-decoder architecture for canonicalization and retrosynthesis tasks</li>
<li>200,000 training steps with Adam optimizer, negative log-likelihood loss, cyclic learning rate scheduler</li>
<li>Random Forest with 5-fold cross-validation for property prediction</li>
<li>AIS tokenization implemented via RDKit for atom environment extraction</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Task</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>String exact match (%)</td>
          <td>Canonicalization, Retrosynthesis</td>
          <td>Exact SMILES match</td>
      </tr>
      <tr>
          <td>Tanimoto exactness (Tc)</td>
          <td>Retrosynthesis</td>
          <td>Morgan FP radius 3, 2048 bits</td>
      </tr>
      <tr>
          <td>RMSE</td>
          <td>Regression property prediction</td>
          <td>ESOL, FreeSolv, Lipophilicity</td>
      </tr>
      <tr>
          <td>ROC-AUC</td>
          <td>Classification property prediction</td>
          <td>BBBP, BACE, HIV</td>
      </tr>
      <tr>
          <td>rep-l</td>
          <td>Token degeneration</td>
          <td>Single-token repetition count</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Not explicitly specified in the paper.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/snu-lcbc/atom-in-SMILES">atom-in-SMILES</a></td>
          <td>Code</td>
          <td>CC-BY-NC-SA-4.0</td>
          <td>AIS tokenization implementation</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Ucak, U. V., Ashyrmamatov, I., &amp; Lee, J. (2023). Improving the quality of chemical language model outcomes with atom-in-SMILES tokenization. <em>Journal of Cheminformatics</em>, 15, 55. <a href="https://doi.org/10.1186/s13321-023-00725-9">https://doi.org/10.1186/s13321-023-00725-9</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{ucak2023improving,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Improving the quality of chemical language model outcomes with atom-in-SMILES tokenization}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Ucak, Umit V. and Ashyrmamatov, Islambek and Lee, Juyong}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Journal of Cheminformatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{15}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{55}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Springer}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1186/s13321-023-00725-9}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>SPECTRA: Evaluating Generalizability of Molecular AI</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/evaluation/spectra-evaluating-generalizability-molecular-ai/</link><pubDate>Wed, 25 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/evaluation/spectra-evaluating-generalizability-molecular-ai/</guid><description>SPECTRA evaluates ML model generalizability on molecular datasets by plotting performance across a spectrum of train-test overlap levels.</description><content:encoded><![CDATA[<h2 id="a-spectral-framework-for-evaluating-molecular-ml-generalizability">A Spectral Framework for Evaluating Molecular ML Generalizability</h2>
<p>This is a <strong>Method</strong> paper that introduces SPECTRA (SPECtral framework for model evaluaTion on moleculaR dAtasets), a systematic approach for evaluating how well machine learning models generalize on molecular sequencing data. The primary contribution is a framework that generates train-test splits with controlled, decreasing levels of overlap, producing a spectral performance curve (SPC) and a single summary metric, the area under the spectral performance curve (AUSPC), for comparing model generalizability across tasks and architectures.</p>
<h2 id="why-existing-molecular-benchmarks-overestimate-generalizability">Why Existing Molecular Benchmarks Overestimate Generalizability</h2>
<p>Deep learning has achieved high performance on molecular sequencing benchmarks, but a persistent gap exists between benchmark performance and real-world deployment. The authors identify the root cause: existing evaluation approaches use either metadata-based (MB) splits or similarity-based (SB) splits, both of which provide an incomplete picture of generalizability.</p>
<p>MB splits partition data by metadata properties (e.g., temporal splits, random splits) without controlling sequence similarity between train and test sets. This means high train-test similarity can inflate performance metrics. SB splits control similarity at a single threshold, but the model&rsquo;s behavior at other similarity levels remains unknown.</p>
<p>For example, the TAPE benchmark&rsquo;s remote homology family split has 97% cross-split overlap, while the superfamily split has 71%. Model accuracy drops by 50% between these two points, yet the full curve of performance degradation is never characterized. This gap between evaluated and real-world overlap levels leads to overoptimistic deployment expectations, as demonstrated by the case of <a href="https://en.wikipedia.org/wiki/Rifampicin">rifampicin</a> resistance prediction in <em>M. tuberculosis</em>, where commercial genotypic assays later proved unreliable in specific geographic regions.</p>
<h2 id="the-spectra-framework-spectral-properties-graphs-and-performance-curves">The SPECTRA Framework: Spectral Properties, Graphs, and Performance Curves</h2>
<p>SPECTRA takes three inputs: a molecular sequencing dataset, a machine learning model, and a spectral property definition. A spectral property (SP) is a molecular sequence property expected to influence model generalizability for a specific task. For sequence-to-sequence datasets, the spectral property is typically sequence identity (proportion of aligned positions &gt; 0.3). For mutational scan datasets, it is defined by sample barcodes (string representations of mutations present in each sample).</p>
<h3 id="spectral-property-graph-construction">Spectral Property Graph Construction</h3>
<p>SPECTRA constructs a spectral property graph (SPG) where nodes represent samples and edges connect samples that share the spectral property. The goal is to generate train-test splits with controlled levels of cross-split overlap by finding approximate <a href="https://en.wikipedia.org/wiki/Maximal_independent_set">maximal independent sets</a> of this graph.</p>
<p>Finding a maximum independent set (the largest possible one) is NP-hard, so SPECTRA uses a greedy randomized algorithm parameterized by a spectral parameter $\mathbf{SP} \in [0, 1]$:</p>
<ol>
<li>Randomly order SPG vertices</li>
<li>Select the first vertex and delete each neighbor with probability equal to $\mathbf{SP}$</li>
<li>Continue until no vertices remain</li>
</ol>
<p>When $\mathbf{SP} = 0$, this produces a random split (maximum cross-split overlap). When $\mathbf{SP} = 1$, it approximates the maximal independent set (minimum cross-split overlap). For each spectral parameter value (incremented by 0.05 from 0 to 1), three splits with different random seeds are generated.</p>
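<p>The three steps above can be sketched as follows. This is a minimal illustration, not the authors&rsquo; implementation: the adjacency-dictionary input, the function name <code>greedy_independent_subset</code>, and the choice to return only the kept vertices are assumptions, and the sketch does not cover how the kept samples are then divided into train and test sets.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import random


def greedy_independent_subset(adjacency, sp, seed=0):
    """Greedy randomized vertex selection; sp is the spectral parameter in [0, 1].

    adjacency maps each sample to the set of samples it shares the
    spectral property with (hypothetical input format).
    """
    rng = random.Random(seed)
    order = sorted(adjacency)
    rng.shuffle(order)                    # step 1: random vertex order
    remaining = set(order)
    kept = []
    for vertex in order:
        if vertex not in remaining:
            continue                      # deleted earlier as a neighbor
        kept.append(vertex)               # step 2: select the vertex
        remaining.discard(vertex)
        for neighbor in adjacency[vertex]:
            # Bernoulli draw: delete the neighbor with probability sp
            delete = rng.choices([True, False], weights=[sp, 1.0 - sp])[0]
            if neighbor in remaining and delete:
                remaining.discard(neighbor)
    return kept                           # step 3: loop ends when nothing remains
</code></pre></div>
<p>With <code>sp=0.0</code> no neighbor is ever deleted, so every sample is kept; with <code>sp=1.0</code> every neighbor of a kept sample is deleted, so the result is an independent set of the graph.</p>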
<h3 id="the-spectral-performance-curve-and-auspc">The Spectral Performance Curve and AUSPC</h3>
<p>The model is trained and evaluated on each split. Plotting test performance against the spectral parameter produces the spectral performance curve (SPC). The area under this curve, the AUSPC, serves as a single summary metric for model generalizability that captures behavior across the full spectrum of train-test overlap.</p>
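<p>One way to compute such an area is the trapezoidal rule over the (spectral parameter, score) pairs. The integration rule is not stated in this summary, so the trapezoid is an assumption here, as is the function name <code>auspc</code>.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">def auspc(spectral_params, scores):
    """Trapezoidal area under the spectral performance curve (assumed rule)."""
    points = sorted(zip(spectral_params, scores))
    area = 0.0
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        area += (x1 - x0) * (y0 + y1) / 2.0
    return area


# 21 spectral parameter values: 0.00, 0.05, ..., 1.00
grid = [i * 0.05 for i in range(21)]
# Hypothetical curve that degrades linearly from 0.9 to 0.5
example_scores = [0.9 - 0.4 * sp for sp in grid]
print(auspc(grid, example_scores))
</code></pre></div>
<p>Because the spectral parameter spans [0, 1], a model whose score stays constant across all splits has an AUSPC equal to that score.</p>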
<h3 id="handling-mutational-scan-datasets">Handling Mutational Scan Datasets</h3>
<p>For mutational scan datasets where sample barcodes map to multiple samples, SPECTRA introduces two modifications: (1) weighting nodes in the SPG by the number of samples they represent, and (2) running a subset sum algorithm to ensure 80/20 train-test splits by sample count.</p>
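<p>The subset sum step can be illustrated with a small dynamic program over node weights. This is a hedged sketch under simplifying assumptions: it treats the weighted nodes as freely assignable and ignores any graph constraints, and the function name <code>pick_train_nodes</code> is hypothetical.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">def pick_train_nodes(weights, train_fraction=0.8):
    """Pick node indices whose sample counts sum closest to the train target."""
    total = sum(weights)
    target = train_fraction * total
    # reachable[s] holds one list of node indices whose weights sum to s
    reachable = {0: []}
    for index, weight in enumerate(weights):
        for subtotal, chosen in list(reachable.items()):
            reachable.setdefault(subtotal + weight, chosen + [index])
    best = min(reachable, key=lambda s: (abs(s - target), s))
    return reachable[best], best / total


# Hypothetical barcodes representing 50, 30, 10, and 10 samples
train_nodes, achieved = pick_train_nodes([50, 30, 10, 10])
print(train_nodes, achieved)
</code></pre></div>
<p>The table of reachable sums grows with the total sample count, so a production version would likely need pruning or an approximate solver.</p>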
<h2 id="evaluation-across-18-datasets-and-19-models">Evaluation Across 18 Datasets and 19 Models</h2>
<p>The authors apply SPECTRA to 18 molecular sequencing datasets spanning three benchmarks (TAPE, PEER, ProteinGym) plus PDBBind, evaluating 19 models including CNNs, LSTMs, GNNs (GearNet), LLMs (ESM2), diffusion models (DiffDock), variational autoencoders (EVE), and logistic regression.</p>
<h3 id="benchmark-datasets">Benchmark Datasets</h3>
<p>The core evaluation covers five primary tasks:</p>
<table>
  <thead>
      <tr>
          <th>Task</th>
          <th>Dataset</th>
          <th>Type</th>
          <th>Metric</th>
          <th>Samples</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Rifampicin resistance (RIF)</td>
          <td>TB clinical isolates</td>
          <td>MSD</td>
          <td>AUROC</td>
          <td>17,474</td>
      </tr>
      <tr>
          <td><a href="https://en.wikipedia.org/wiki/Isoniazid">Isoniazid</a> resistance (INH)</td>
          <td>TB clinical isolates</td>
          <td>MSD</td>
          <td>AUROC</td>
          <td>26,574</td>
      </tr>
      <tr>
          <td><a href="https://en.wikipedia.org/wiki/Pyrazinamide">Pyrazinamide</a> resistance (PZA)</td>
          <td>TB clinical isolates</td>
          <td>MSD</td>
          <td>AUROC</td>
          <td>12,146</td>
      </tr>
      <tr>
          <td>Fluorescence prediction</td>
          <td><a href="https://en.wikipedia.org/wiki/Green_fluorescent_protein">GFP</a> variants</td>
          <td>MSD</td>
          <td>Spearman&rsquo;s $\rho$</td>
          <td>54,024</td>
      </tr>
      <tr>
          <td>Vaccine escape</td>
          <td>SARS-CoV-2 RBD</td>
          <td>MSD</td>
          <td>Spearman&rsquo;s $\rho$</td>
          <td>438,046</td>
      </tr>
  </tbody>
</table>
<p>Additional benchmarks include remote homology detection, secondary structure prediction, subcellular localization, and protein-ligand binding (PDBBind, Astex diverse set, Posebusters).</p>
<h3 id="models-evaluated">Models Evaluated</h3>
<p>Eight models were evaluated in depth across the five primary tasks: logistic regression, CNN, ESM2 (pretrained), ESM2-Finetuned, GearNet, GearNet-Finetuned, EVE, and SeqDesign. Additional models (LSTM, ResNet, DeepSF, Transformer, HHblits, Equibind, DiffDock, TankBind, Transception, MSA Transformer, ESM1v, Progen2) were evaluated on specific benchmark tasks.</p>
<h3 id="existing-splits-as-points-on-the-spc">Existing Splits as Points on the SPC</h3>
<p>SPECTRA reveals that existing benchmark splits correspond to specific points on the spectral performance curve. For instance:</p>
<table>
  <thead>
      <tr>
          <th>Task</th>
          <th>Benchmark Split</th>
          <th>Cross-Split Overlap</th>
          <th>Spectral Parameter</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Remote homology</td>
          <td>TAPE family</td>
          <td>97%</td>
          <td>0.025</td>
      </tr>
      <tr>
          <td>Remote homology</td>
          <td>TAPE superfamily</td>
          <td>71%</td>
          <td>0.475</td>
      </tr>
      <tr>
          <td>Secondary structure</td>
          <td>CASP12</td>
          <td>48%</td>
          <td>0.5</td>
      </tr>
      <tr>
          <td>Protein-ligand binding</td>
          <td>Equibind temporal</td>
          <td>76%</td>
          <td>0.55</td>
      </tr>
      <tr>
          <td>Protein-ligand binding</td>
          <td>LPPDBind similarity</td>
          <td>91%</td>
          <td>0.275</td>
      </tr>
      <tr>
          <td>Protein-ligand binding</td>
          <td>Posebusters</td>
          <td>70%</td>
          <td>0.575</td>
      </tr>
  </tbody>
</table>
<h2 id="performance-degradation-and-foundation-model-insights">Performance Degradation and Foundation Model Insights</h2>
<h3 id="universal-performance-decline">Universal Performance Decline</h3>
<p>All evaluated models demonstrate decreased performance as cross-split overlap decreases. Logistic regression drops from AUROC &gt; 0.9 to 0.5 for rifampicin resistance. ESM2-Finetuned decreases from Spearman&rsquo;s $\rho &gt; 0.9$ to less than 0.4 for GFP fluorescence prediction.</p>
<p>No single model achieves the highest AUSPC across all tasks. CNN maintains AUSPC &gt; 0.6 across all tasks but is surpassed by ESM2-Finetuned and ESM2 on rifampicin resistance. Some models retain reasonable performance even at $\mathbf{SP} = 1$ (minimal overlap): ESM2, ESM2-Finetuned, and CNN maintain AUROC &gt; 0.7 for RIF and PZA at this extreme.</p>
<h3 id="uncovering-hidden-spectral-properties">Uncovering Hidden Spectral Properties</h3>
<p>SPECTRA can detect unconsidered spectral properties through high variance in model performance at fixed spectral parameters. For rifampicin resistance, the CNN shows high variance at $\mathbf{SP} = 0.9$, $0.95$, and $1.0$ (standard deviations of 0.09, 0.10, and 0.08 respectively).</p>
<p>The authors trace this to the rifampicin resistance determining region (RRDR), a 26-amino-acid region of the rpoB gene. They define diff-RRDR as:</p>
<p>$$
\text{diff-RRDR} = \left(\max\left(\text{position}_{\text{train}}\right) - \max\left(\text{position}_{\text{test}}\right)\right) + \left(\min\left(\text{position}_{\text{train}}\right) - \min\left(\text{position}_{\text{test}}\right)\right)
$$</p>
<p>diff-RRDR correlates with CNN performance variance (Spearman&rsquo;s $\rho = -0.51$, p-value $= 1.79 \times 10^{-5}$) but not with ESM2 performance. The authors attribute this to ESM2&rsquo;s larger context window (512 positions vs. CNN&rsquo;s 12), making it more invariant to positional shifts in resistance-determining mutations.</p>
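<p>The diff-RRDR definition above translates directly into code. In this sketch the inputs are assumed to be lists of mutation positions observed in the train and test splits; the function name and the example positions are illustrative only.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">def diff_rrdr(train_positions, test_positions):
    """Sum of the shifts in the upper and lower bounds of mutated positions."""
    upper_shift = max(train_positions) - max(test_positions)
    lower_shift = min(train_positions) - min(test_positions)
    return upper_shift + lower_shift


# Hypothetical positions, chosen only to exercise the formula
print(diff_rrdr([430, 445, 450], [426, 435, 452]))
</code></pre></div>
<p>A value of zero arises either when both bounds coincide across splits or when the two shifts cancel, so the quantity is a coarse summary of positional mismatch.</p>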
<h3 id="foundation-model-generalizability">Foundation Model Generalizability</h3>
<p>For protein foundation models, SPECTRA reveals that AUSPC correlates with the similarity between task-specific datasets and the pretraining dataset. ESM2&rsquo;s AUSPC varies from 0.91 (RIF) to 0.26 (SARS-CoV-2). The correlation between UniRef50 overlap and AUSPC is strong (Spearman&rsquo;s $\rho = 0.9$, p-value $= 1.4 \times 10^{-27}$).</p>
<p>This finding holds across multiple foundation models (Transception, MSA Transformer, ESM1v, Progen2) evaluated on five ProteinGym datasets (Spearman&rsquo;s $\rho = 0.9$, p-value $= 0.04$). Fine-tuning improves AUSPC for tasks with low pretraining overlap (PZA, SARS-CoV-2, GFP).</p>
<h3 id="computational-cost">Computational Cost</h3>
<p>Generating SPECTRA splits ranges from 5 minutes (amyloid beta aggregation) to 9 hours (PDBBind). Generating spectral performance curves ranges from 1 hour (logistic regression) to 5 days (ESM2-Finetuned). The authors recommend releasing SPECTRA splits alongside new benchmarks to amortize this cost.</p>
<h2 id="limitations-and-future-directions">Limitations and Future Directions</h2>
<p>The authors acknowledge several limitations:</p>
<ul>
<li><strong>Spectral property selection is pivotal</strong>: The choice of spectral property must be biologically informed and task-specific. Standardized definitions across the community are needed.</li>
<li><strong>Computational cost</strong>: Running SPECTRA is expensive, especially for large models. The authors mitigate this with multi-core CPU parallelization and multi-GPU training.</li>
<li><strong>Not a model ranking tool</strong>: SPECTRA is designed for understanding generalizability patterns, not for ranking models. Proper ranking requires averaging AUSPCs across many tasks in a standardized benchmark.</li>
<li><strong>Spectral parameter vs. cross-split overlap</strong>: The minimal achievable cross-split overlap varies across tasks, so SPECTRA plots performance against the spectral parameter rather than overlap directly. This means the AUSPC reflects relative impact on performance per unit decrease in overlap.</li>
</ul>
<p>The authors envision SPECTRA as a foundation for next-generation molecular benchmarks that explicitly characterize generalizability across the full spectrum of distribution shift, with the approach extending beyond molecular sequencing data to small molecule therapeutics, inverse protein folding, and patient-level clinical datasets.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>All data used in this study is publicly available.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Evaluation</td>
          <td>TB RIF resistance</td>
          <td>17,474 isolates</td>
          <td>From Green et al. (2022)</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>TB INH resistance</td>
          <td>26,574 isolates</td>
          <td>From Green et al. (2022)</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>TB PZA resistance</td>
          <td>12,146 isolates</td>
          <td>From Green et al. (2022)</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>GFP fluorescence</td>
          <td>54,024 samples</td>
          <td>From Sarkisyan et al. (2016)</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>SARS-CoV-2 escape</td>
          <td>438,046 samples</td>
          <td>From Greaney et al. (2021)</td>
      </tr>
      <tr>
          <td>Benchmark</td>
          <td>TAPE (remote homology, secondary structure)</td>
          <td>Various</td>
          <td>From Rao et al. (2019)</td>
      </tr>
      <tr>
          <td>Benchmark</td>
          <td>PEER (subcellular localization)</td>
          <td>13,949 samples</td>
          <td>From Xu et al. (2022)</td>
      </tr>
      <tr>
          <td>Benchmark</td>
          <td>ProteinGym (amyloid, RRM)</td>
          <td>Various</td>
          <td>From Notin et al. (2022)</td>
      </tr>
      <tr>
          <td>Benchmark</td>
          <td>PDBBind (protein-ligand binding)</td>
          <td>14,993-16,742 complexes</td>
          <td>From Wang et al. (2005)</td>
      </tr>
  </tbody>
</table>
<p>Data is also available on <a href="https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/W5UUNN">Harvard Dataverse</a>.</p>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Spectral property comparison uses Biopython pairwise alignment (match=1, mismatch=-2, gap=-2.5) with a 0.3 similarity threshold for sequence-to-sequence datasets</li>
<li>Greedy randomized maximal independent set approximation for split generation</li>
<li>Spectral parameter incremented in 0.05 steps from 0 to 1</li>
<li>Three random seeds per spectral parameter value</li>
<li>80/20 train-test split ratio enforced via subset sum for mutational scan datasets</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li>ESM2: 650M parameter version from Lin et al. (2023)</li>
<li>ESM2-Finetuned: First 30 layers frozen, masked language head replaced with linear prediction layer</li>
<li>GearNet and GearNet-Finetuned: Protein structures generated via ESMFold</li>
<li>CNN: Architecture from Green et al. (2022), one-hot encoded sequences</li>
<li>Logistic regression: One-hot encoded mutational barcodes</li>
<li>EVE and SeqDesign: MSAs constructed via Jackhmmer against UniRef100</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Task</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>AUROC</td>
          <td>TB resistance (RIF, INH, PZA)</td>
          <td>Binary classification</td>
      </tr>
      <tr>
          <td>Spearman&rsquo;s $\rho$</td>
          <td>GFP fluorescence, SARS-CoV-2 escape</td>
          <td>Regression tasks</td>
      </tr>
      <tr>
          <td>Accuracy</td>
          <td>Remote homology, secondary structure, subcellular localization</td>
          <td>Per-label/class accuracy</td>
      </tr>
      <tr>
          <td>RMSE</td>
          <td>Protein-ligand binding</td>
          <td>Predicted vs. actual complex</td>
      </tr>
      <tr>
          <td>AUSPC</td>
          <td>All tasks</td>
          <td>Area under spectral performance curve</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li>Most models: 1x Tesla A10 GPU</li>
<li>ESM2-Finetuned: 4x Tesla A100 GPUs on Azure cluster</li>
<li>Hyperparameter optimization: Weights &amp; Biases random search over learning rate</li>
<li>All code in PyTorch</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/mims-harvard/SPECTRA">SPECTRA Code</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Framework implementation and reproduction scripts</td>
      </tr>
      <tr>
          <td><a href="https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/W5UUNN">Harvard Dataverse</a></td>
          <td>Dataset</td>
          <td>CC0 1.0</td>
          <td>All datasets and generated splits</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Ektefaie, Y., Shen, A., Bykova, D., Marin, M. G., Zitnik, M., &amp; Farhat, M. (2024). Evaluating generalizability of artificial intelligence models for molecular datasets. <em>Nature Machine Intelligence</em>, 6(12), 1512-1524. <a href="https://doi.org/10.1038/s42256-024-00931-6">https://doi.org/10.1038/s42256-024-00931-6</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{ektefaie2024evaluating,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Evaluating generalizability of artificial intelligence models for molecular datasets}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Ektefaie, Yasha and Shen, Andrew and Bykova, Daria and Marin, Maximillian G. and Zitnik, Marinka and Farhat, Maha}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Nature Machine Intelligence}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{6}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{12}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{1512--1524}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2024}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Nature Publishing Group}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1038/s42256-024-00931-6}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>MolGenBench: Benchmarking Molecular Generative Models</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/evaluation/molgenbench-molecular-generative-models/</link><pubDate>Wed, 25 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/evaluation/molgenbench-molecular-generative-models/</guid><description>MolGenBench benchmarks 17 molecular generative models across 120 protein targets using novel metrics for target awareness, hit rates, and lead optimization.</description><content:encoded><![CDATA[<h2 id="a-comprehensive-benchmark-for-structure-based-molecular-generation">A Comprehensive Benchmark for Structure-Based Molecular Generation</h2>
<p>MolGenBench is a <strong>Resource</strong> paper that provides a large-scale, application-oriented benchmark for evaluating molecular generative models in the context of structure-based drug design (SBDD). The primary contribution is a dataset of 220,005 experimentally validated active molecules across 120 protein targets, organized into 5,433 chemical series, along with a suite of novel evaluation metrics. The benchmark addresses both <a href="https://en.wikipedia.org/wiki/De_novo_drug_design">de novo molecular design</a> and hit-to-lead (H2L) optimization, a critical drug discovery stage that existing benchmarks largely ignore.</p>
<h2 id="gaps-in-existing-molecular-generation-benchmarks">Gaps in Existing Molecular Generation Benchmarks</h2>
<p>Despite rapid progress in deep generative models for drug discovery, the evaluation landscape has not kept pace. The authors identify four categories of limitations in existing benchmarks:</p>
<ol>
<li>
<p><strong>Dataset construction</strong>: Existing benchmarks use overly stringent activity cutoffs and too few protein targets. The widely used CrossDocked2020 dataset contains very few reference ligands per target, making it difficult to evaluate whether a model can rediscover the full distribution of active compounds.</p>
</li>
<li>
<p><strong>Model selection</strong>: Prior benchmark studies evaluate a narrow range of architectures and do not systematically examine the effects of training data composition, prior knowledge integration, or architectural paradigm.</p>
</li>
<li>
<p><strong>Evaluation scenarios</strong>: Existing benchmarks focus exclusively on de novo generation. Hit-to-lead optimization, where a hit compound is refined through R-group modifications, remains unstandardized.</p>
</li>
<li>
<p><strong>Evaluation metrics</strong>: Standard metrics (QED, Vina score, SA score) correlate strongly with atom count and fail to assess target-specific generation capacity. The AddCarbon model illustrates this: simply adding random carbon atoms to training molecules achieves near-perfect scores on standard metrics while producing nonsensical chemistry.</p>
</li>
</ol>
<h2 id="novel-metrics-for-evaluating-molecular-generation">Novel Metrics for Evaluating Molecular Generation</h2>
<p>MolGenBench introduces three key metrics designed to capture aspects of model performance that existing metrics miss.</p>
<h3 id="target-aware-score-tascore">Target-Aware Score (TAScore)</h3>
<p>The TAScore measures whether a model generates target-specific molecules rather than generic structures. It compares the ratio of active molecule or scaffold recovery on a specific target to the background recovery across all targets:</p>
<p>$$
\text{TAScore}_{\text{label}, i} = \frac{S_{i} / S_{\text{all}}}{R_{i} / R_{\text{all}}}; \quad \text{label} \in \{\text{SMILES}, \text{scaffold}\}
$$</p>
<p>For target $i$: $R_{\text{all}}$ is the total number of distinct molecules generated across all 120 targets; $R_{i}$ is the subset matching known actives for target $i$ (without conditioning on target $i$); $S_{\text{all}}$ is the total generated when conditioned on target $i$; and $S_{i}$ is the subset matching known actives for target $i$. A TAScore above 1 indicates the model uses target-specific information effectively.</p>
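<p>As a worked illustration of the ratio, the sketch below computes a TAScore from four counts. The function name and the example counts are hypothetical and are not taken from the benchmark.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">def ta_score(s_i, s_all, r_i, r_all):
    """Recovery rate when conditioned on target i, divided by the background rate."""
    conditioned_rate = s_i / s_all     # S_i / S_all
    background_rate = r_i / r_all      # R_i / R_all
    return conditioned_rate / background_rate


# Hypothetical counts: 20 of 1,000 conditioned samples match known actives,
# versus 60 of 120,000 molecules in the pooled background.
print(ta_score(20, 1000, 60, 120000))
</code></pre></div>
<p>The ratio is undefined when the background recovers no actives ($R_{i} = 0$); this sketch does not handle that case.</p>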
<h3 id="hit-rate">Hit Rate</h3>
<p>The hit rate quantifies the efficiency of active compound discovery:</p>
<p>$$
\text{HitRate}_{\text{label}} = \frac{\mathcal{M}_{\text{active}}}{\mathcal{M}_{\text{sampled}}}; \quad \text{label} \in \{\text{SMILES}, \text{scaffold}\}
$$</p>
<p>where $\mathcal{M}_{\text{active}}$ is the number of unique active molecules or scaffolds found, and $\mathcal{M}_{\text{sampled}}$ is the total number of generated molecules.</p>
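<p>A minimal sketch of the SMILES-level hit rate follows. It assumes molecules are compared as already-canonicalized SMILES strings, which is an assumption of this example; the function name and the toy molecules are illustrative.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">def hit_rate(generated, known_actives):
    """Unique recovered actives divided by the total number of generated samples."""
    unique_hits = set(generated).intersection(known_actives)
    return len(unique_hits) / len(generated)


# Toy example: one unique active recovered among four generated samples
generated = ["CCO", "CCO", "c1ccccc1", "CCN"]
known_actives = {"CCO", "CCC"}
print(hit_rate(generated, known_actives))
</code></pre></div>
<p>Duplicates count toward the denominator but not the numerator, so a model that repeatedly emits the same active is not rewarded for it.</p>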
<h3 id="mean-normalized-affinity-mna-score">Mean Normalized Affinity (MNA) Score</h3>
<p>For H2L optimization, the MNA Score measures whether models generate compounds with improved potency relative to the known activity range within each chemical series:</p>
<p>$$
\text{NA}_{g} = \frac{\text{Affinity}_{g}^{\text{series}} - \text{Affinity}_{\min}^{\text{series}}}{\text{Affinity}_{\max}^{\text{series}} - \text{Affinity}_{\min}^{\text{series}}}
$$</p>
<p>$$
\text{MNAScore} = \frac{1}{G} \sum_{g=1}^{G} \text{NA}_{g}
$$</p>
<p>This normalizes affinities to [0, 1] within each series, enabling cross-series comparison.</p>
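<p>The two formulas above amount to a min-max normalization per series followed by a plain average. The sketch below is one way to write that down; it assumes affinities are on a scale where larger means more potent (for example a pIC50-like value) and that $G$ counts the scored compounds, neither of which is confirmed by the text. All names and numbers are hypothetical.</p>

```python
def normalized_affinity(affinity, series_min, series_max):
    """Min-max normalize one affinity within its chemical series."""
    if series_max == series_min:
        raise ValueError("series has no affinity range to normalize against")
    return (affinity - series_min) / (series_max - series_min)


def mna_score(scored):
    """Average the normalized affinities.

    scored: list of (affinity, series_min, series_max) tuples, one per
        scored compound (assumed to be what G counts).
    """
    values = [normalized_affinity(a, lo, hi) for a, lo, hi in scored]
    if not values:
        raise ValueError("no compounds to score")
    return sum(values) / len(values)


# Two hypothetical compounds from two different series.
example = [(7.0, 5.0, 9.0), (8.0, 6.0, 8.0)]
print(mna_score(example))
```

<p>The first compound sits at the midpoint of its series (0.5) and the second at the top of its series (1.0), so the score is 0.75 even though the two series span different absolute ranges.</p>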
<h2 id="systematic-evaluation-of-17-generative-models-across-two-drug-discovery-scenarios">Systematic Evaluation of 17 Generative Models Across Two Drug Discovery Scenarios</h2>
<h3 id="dataset-construction">Dataset Construction</h3>
<p>The MolGenBench dataset was built from <a href="https://en.wikipedia.org/wiki/ChEMBL">ChEMBL v33</a>. Ligands failing RDKit validation were discarded, along with entries where binding affinity exceeded 10 μM. The 120 protein targets were selected based on minimum thresholds: at least 50 active molecules, at least 50 unique Bemis-Murcko scaffolds, and at least 20 distinct chemical series per target. For H2L optimization, maximum common substructures (MCS) were identified per series, with dual thresholds requiring the MCS to appear in over 80% of molecules and cover more than one-third of each molecule&rsquo;s atoms. The top 5 series per target (ranked by dockable ligands) formed the H2L test set: 600 compound series across 120 targets.</p>
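<p>The selection rules in this paragraph are simple threshold checks, sketched below. This is an illustrative reading, not the authors' pipeline: the class and function names are hypothetical, and applying the one-third coverage rule only to molecules that contain the MCS is an assumption.</p>

```python
from dataclasses import dataclass


@dataclass
class TargetStats:
    name: str
    n_actives: int
    n_scaffolds: int
    n_series: int


def passes_target_thresholds(stats, min_actives=50, min_scaffolds=50, min_series=20):
    """Target-level minimums stated in the text."""
    return (
        stats.n_actives >= min_actives
        and stats.n_scaffolds >= min_scaffolds
        and stats.n_series >= min_series
    )


def mcs_passes_dual_threshold(n_with_mcs, n_molecules, mcs_atoms, atom_counts):
    """MCS must appear in over 80% of molecules and cover more than 1/3 of atoms.

    atom_counts: atom counts of the molecules that contain the MCS
        (assumption: the coverage rule is checked on these molecules).
    """
    if n_molecules == 0:
        return False
    # Integer arithmetic avoids floating-point edge cases at the boundaries.
    frequent = 5 * n_with_mcs > 4 * n_molecules
    covering = all(3 * mcs_atoms > n for n in atom_counts)
    return frequent and covering


print(passes_target_thresholds(TargetStats("T1", 120, 64, 25)))
print(mcs_passes_dual_threshold(9, 10, 12, [30] * 9))
```

<p>Both thresholds on the MCS are strict, so a substructure found in exactly 80% of a series, or covering exactly one-third of a molecule, would not qualify under this reading.</p>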
<h3 id="evaluated-models">Evaluated Models</h3>
<p><strong>De novo models (10)</strong>: Pocket2Mol, TargetDiff, FLAG, DecompDiff, SurfGen, PocketFlow, MolCraft, <a href="/notes/chemistry/molecular-design/generation/target-aware/tamgen-target-aware-molecule-generation/">TamGen</a>, DiffSBDD-M (trained on BindingMOAD), DiffSBDD-C (trained on CrossDock). These span autoregressive, diffusion, and Bayesian flow network architectures.</p>
<p><strong>H2L models (7)</strong>: Fragment-based (DiffSBDD-M/C inpainting, Delete, DiffDec) and ligand-based (ShEPhERD, ShapeMol, PGMG). These use pharmacophore, surface, or shape priors.</p>
<p>Models were further stratified by whether test proteins appeared in their CrossDock training set (&ldquo;Proteins in CrossDock&rdquo; vs. &ldquo;Proteins Not in CrossDock&rdquo;), enabling direct measurement of generalization.</p>
<h3 id="evaluation-dimensions">Evaluation Dimensions</h3>
<p>The benchmark evaluates six dimensions:</p>
<table>
  <thead>
      <tr>
          <th>Dimension</th>
          <th>Key Metrics</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Basic molecular properties</td>
          <td>Validity, QED, SA score, uniqueness, diversity, JSD alignment</td>
      </tr>
      <tr>
          <td>Chemical safety</td>
          <td>Industry-standard filter pass rates (Eli Lilly, Novartis, ChEMBL rules)</td>
      </tr>
      <tr>
          <td>Conformational quality</td>
          <td>PoseBusters pass rate, strain energy, steric clash frequency</td>
      </tr>
      <tr>
          <td>Active compound recovery</td>
          <td>Hit rate, hit fraction, active molecule and scaffold recovery counts</td>
      </tr>
      <tr>
          <td>Target awareness</td>
          <td>TAScore at molecule and scaffold levels</td>
      </tr>
      <tr>
          <td>Lead optimization</td>
          <td>MNA Score, number of series with hits</td>
      </tr>
  </tbody>
</table>
<h3 id="key-results-basic-properties-and-chemical-safety">Key Results: Basic Properties and Chemical Safety</h3>
<p>Most models generate drug-like molecules with reasonable QED (0.4-0.6) and SA scores (0.5-0.8). However, two models (FLAG, SurfGen) showed validity below 0.4. TamGen exhibited low uniqueness (~27%), suggesting overreliance on pretrained patterns.</p>
<p>Chemical filter pass rates revealed a more concerning picture: only TamGen and PGMG exceeded 50% of molecules passing all industry-standard filters. Most models fell below 40%, and some (FLAG, SurfGen) below 5%. Nearly 70% of reference active molecules passed the same filters, indicating models frequently generate high-risk compounds.</p>
<h3 id="key-results-conformational-quality">Key Results: Conformational Quality</h3>
<p>MolCraft achieved the highest PoseBusters validity (0.783 PB-valid score among valid molecules). PocketFlow, despite perfect SMILES validity, had fewer than half of its valid molecules pass conformational checks. Most models produced conformations with higher <a href="https://en.wikipedia.org/wiki/Strain_(chemistry)">strain energy</a> than those from <a href="https://en.wikipedia.org/wiki/AutoDock">AutoDock Vina</a>. Some models (MolCraft for de novo, DiffDec for H2L) surpassed Vina in minimizing steric clashes, suggesting advanced architectures can exceed the patterns in their training data.</p>
<h3 id="key-results-active-compound-recovery-and-hit-rates">Key Results: Active Compound Recovery and Hit Rates</h3>
<p>De novo models exhibited very low hit rates. The highest molecular hit rate among de novo models was 0.124% on proteins in CrossDock, dropping to 0.024% on unseen proteins. Scaffold-level hit rates were 10-fold higher, showing that generating pharmacologically plausible scaffolds is considerably easier than generating fully active molecules.</p>
<p>After removing molecules overlapping with the CrossDock training set, TamGen&rsquo;s recovery dropped substantially (from 30.3 to 18.7 targets), confirming significant memorization effects. On proteins not in CrossDock, half of the de novo models failed to recover any active molecules at all.</p>
<p>Fragment-based H2L models substantially outperformed both de novo models and ligand-based H2L approaches. Delete recovered active molecules in 44.3 series (out of 600), and DiffDec in 34.7 series.</p>
<h3 id="key-results-target-awareness">Key Results: Target Awareness</h3>
<p>Most de novo models failed the TAScore evaluation. PocketFlow showed the strongest target awareness at the scaffold level, with only 27% of targets showing TAScore &lt; 1 (indicating no target specificity). At the molecular level, results were even weaker: TamGen achieved TAScore &gt; 1 for only 30.6% of CrossDock-seen targets and just 4 out of 35 unseen targets. Most models generated structurally similar molecules regardless of which target they were conditioned on.</p>
<h3 id="key-results-h2l-optimization-mna-score">Key Results: H2L Optimization (MNA Score)</h3>
<p>DiffDec achieved the highest total active hits (121.7) and the best MNA Score (0.523), followed by Delete (104.7 hits, MNA Score 0.482). Ligand-based models (ShEPhERD, PGMG) recovered fewer hits but showed higher MNA Scores per hit, suggesting pharmacophore-based priors help prioritize more potent molecules when actives are found. The most successful model (Delete) achieved a hit in only 9.6% of series (57/600), indicating substantial room for improvement.</p>
<h2 id="critical-findings-and-limitations-of-current-molecular-generative-models">Critical Findings and Limitations of Current Molecular Generative Models</h2>
<p>The benchmark reveals several consistent limitations:</p>
<ol>
<li>
<p><strong>Low screening efficiency</strong>: De novo models achieve molecular hit rates below 0.13%, far from practical utility. Scaffold recovery is more feasible but still limited.</p>
</li>
<li>
<p><strong>Weak target awareness</strong>: Most SBDD models fail to use protein structural information effectively, generating similar molecules across different targets. This raises concerns about off-target effects.</p>
</li>
<li>
<p><strong>Conformational prediction remains difficult</strong>: Most models produce conformations with higher strain energy than classical docking, and only a small fraction (typically below 23%) of generated poses match redocked conformations within 2 Å RMSD.</p>
</li>
<li>
<p><strong>Generalization gap</strong>: Performance consistently drops on proteins not in the training set, and prior benchmarks that do not stratify by training data exposure overestimate real-world utility.</p>
</li>
<li>
<p><strong>Inference-time scaling does not solve the problem</strong>: Sampling up to 100,000 molecules increased the absolute number of active discoveries but with diminishing efficiency. Without better scoring functions, scaling sampling offers limited practical value.</p>
</li>
<li>
<p><strong>Chemical safety</strong>: Most models produce a majority of molecules that fail industry-standard reactivity and promiscuity filters.</p>
</li>
</ol>
<p>The authors acknowledge that the benchmark&rsquo;s 220,005 active molecules represent a biased subset of bioactive chemical space. A model&rsquo;s failure to rediscover known actives for a given target may therefore reflect sampling limitations rather than the generation of inactive compounds.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Active compounds</td>
          <td>ChEMBL v33</td>
          <td>220,005 molecules, 120 targets</td>
          <td>Filtered at 10 μM affinity threshold</td>
      </tr>
      <tr>
          <td>H2L series</td>
          <td>ChEMBL v33 + PDB</td>
          <td>5,433 series (600 used for H2L test)</td>
          <td>MCS-based series construction</td>
      </tr>
      <tr>
          <td>Protein structures</td>
          <td><a href="https://en.wikipedia.org/wiki/Protein_Data_Bank">PDB</a></td>
          <td>120 targets</td>
          <td>One PDB entry per target</td>
      </tr>
      <tr>
          <td>Training (most models)</td>
          <td>CrossDocked2020</td>
          <td>Varies</td>
          <td>Standard SBDD training set</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>De novo models sampled 1,000 molecules per target; H2L models sampled 200 per series</li>
<li>All experiments repeated three times with different random seeds</li>
<li>Docking performed with AutoDock Vina using standard parameters</li>
<li>Chemical filters applied via the medchem library</li>
<li>Conformational quality assessed with PoseBusters and PoseCheck</li>
<li>Interaction scores computed via ProLIF with frequency-weighted normalization</li>
</ul>
<h3 id="models">Models</h3>
<p>All 17 models were obtained from their official GitHub repositories and run with default configurations. The benchmark does not introduce new model architectures.</p>
<h3 id="evaluation">Evaluation</h3>
<p>Summary of key metrics across the best-performing models in each category:</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Best De Novo</th>
          <th>Value</th>
          <th>Best H2L</th>
          <th>Value</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>PB-valid score</td>
          <td>MolCraft</td>
          <td>0.783</td>
          <td>DiffSBDD-M</td>
          <td>0.597</td>
      </tr>
      <tr>
          <td>Molecular hit rate (in CrossDock)</td>
          <td>TamGen</td>
          <td>0.124%</td>
          <td>DiffDec</td>
          <td>Higher than de novo</td>
      </tr>
      <tr>
          <td>Scaffold hit rate (in CrossDock)</td>
          <td>PocketFlow</td>
          <td>&gt;10%</td>
          <td>Delete</td>
          <td>Lower than PocketFlow</td>
      </tr>
      <tr>
          <td>TAScore scaffold (% targets &gt;1)</td>
          <td>PocketFlow</td>
          <td>73%</td>
          <td>N/A</td>
          <td>N/A</td>
      </tr>
      <tr>
          <td>MNA Score</td>
          <td>N/A</td>
          <td>N/A</td>
          <td>DiffDec</td>
          <td>0.523</td>
      </tr>
      <tr>
          <td>Filter pass rate</td>
          <td>TamGen</td>
          <td>&gt;50%</td>
          <td>PGMG</td>
          <td>&gt;50%</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Specific hardware requirements are not detailed in the paper. Models were run using their default configurations from official repositories.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/CAODH/MolGenBench">MolGenBench</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Benchmark evaluation framework</td>
      </tr>
      <tr>
          <td><a href="https://zenodo.org/records/17572553">Zenodo dataset</a></td>
          <td>Dataset</td>
          <td>CC-BY-NC-ND 4.0</td>
          <td>Processed data and source data for all results</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Cao, D., Fan, Z., Yu, J., Chen, M., Jiang, X., Sheng, X., Wang, X., Zeng, C., Luo, X., Teng, D., &amp; Zheng, M. (2025). Benchmarking Real-World Applicability of Molecular Generative Models from De novo Design to Lead Optimization with MolGenBench. <em>bioRxiv</em>. <a href="https://doi.org/10.1101/2025.11.03.686215">https://doi.org/10.1101/2025.11.03.686215</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{cao2025molgenbench,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Benchmarking Real-World Applicability of Molecular Generative Models from De novo Design to Lead Optimization with MolGenBench}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Cao, Duanhua and Fan, Zhehuan and Yu, Jie and Chen, Mingan and Jiang, Xinyu and Sheng, Xia and Wang, Xingyou and Zeng, Chuanlong and Luo, Xiaomin and Teng, Dan and Zheng, Mingyue}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{bioRxiv}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2025}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1101/2025.11.03.686215}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>MoleculeNet: Benchmarking Molecular Machine Learning</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/</link><pubDate>Wed, 25 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/</guid><description>MoleculeNet curates 17 datasets across quantum mechanics, physical chemistry, biophysics, and physiology with standardized splits and metrics for molecular ML.</description><content:encoded><![CDATA[<h2 id="a-resource-paper-for-molecular-machine-learning-benchmarking">A Resource Paper for Molecular Machine Learning Benchmarking</h2>
<p>This is a <strong>Resource</strong> paper. MoleculeNet provides a standardized benchmark suite for evaluating molecular machine learning methods. Its primary contribution is the curation of 17 public datasets spanning four categories of molecular properties, together with standardized evaluation metrics, multiple dataset splitting strategies, and open-source implementations of featurization and learning algorithms via the DeepChem library.</p>
<h2 id="why-molecular-ml-needed-a-unified-benchmark">Why Molecular ML Needed a Unified Benchmark</h2>
<p>Prior to MoleculeNet, algorithmic progress in molecular machine learning was difficult to measure. Individual papers benchmarked proposed methods on different datasets with different metrics, making cross-method comparison unreliable. Several factors make molecular ML particularly challenging:</p>
<ol>
<li><strong>Data scarcity</strong>: Molecular datasets are much smaller than those available for computer vision or NLP, since obtaining accurate chemical property measurements requires specialized instruments and expert supervision.</li>
<li><strong>Heterogeneous outputs</strong>: Properties of interest range from quantum mechanical characteristics to macroscopic physiological effects on the human body.</li>
<li><strong>Variable input structures</strong>: Molecules have arbitrary size, variable connectivity, and many possible 3D conformers, all of which must be encoded into fixed-length representations for conventional ML algorithms.</li>
<li><strong>No standard evaluation protocol</strong>: Without prescribed metrics, splits, or data subsets, two methods using the same underlying database (e.g., PubChem) could be entirely incomparable.</li>
</ol>
<p>Existing databases like PubChem, ChEMBL, and the Quantum Machine collections provided raw data but did not define evaluation protocols suitable for machine learning development. MoleculeNet bridges this gap, following the precedent set by ImageNet in computer vision and WordNet in NLP.</p>
<h2 id="core-design-datasets-splits-metrics-and-featurizations">Core Design: Datasets, Splits, Metrics, and Featurizations</h2>
<p>MoleculeNet is organized around four components: curated datasets, splitting methods, evaluation metrics, and molecular featurizations.</p>
<h3 id="datasets-across-four-property-categories">Datasets Across Four Property Categories</h3>
<p>The benchmark includes 17 datasets covering over 700,000 compounds and more than 800 tasks. These are organized into four categories reflecting different levels of molecular properties:</p>
<table>
  <thead>
      <tr>
          <th>Category</th>
          <th>Dataset</th>
          <th>Tasks</th>
          <th>Compounds</th>
          <th>Task Type</th>
          <th>Rec. Split</th>
          <th>Rec. Metric</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Quantum Mechanics</td>
          <td>QM7</td>
          <td>1</td>
          <td>7,165</td>
          <td>Regression</td>
          <td>Stratified</td>
          <td>MAE</td>
      </tr>
      <tr>
          <td></td>
          <td>QM7b</td>
          <td>14</td>
          <td>7,211</td>
          <td>Regression</td>
          <td>Random</td>
          <td>MAE</td>
      </tr>
      <tr>
          <td></td>
          <td>QM8</td>
          <td>12</td>
          <td>21,786</td>
          <td>Regression</td>
          <td>Random</td>
          <td>MAE</td>
      </tr>
      <tr>
          <td></td>
          <td>QM9</td>
          <td>12</td>
          <td>133,885</td>
          <td>Regression</td>
          <td>Random</td>
          <td>MAE</td>
      </tr>
      <tr>
          <td>Physical Chemistry</td>
          <td>ESOL</td>
          <td>1</td>
          <td>1,128</td>
          <td>Regression</td>
          <td>Random</td>
          <td>RMSE</td>
      </tr>
      <tr>
          <td></td>
          <td>FreeSolv</td>
          <td>1</td>
          <td>643</td>
          <td>Regression</td>
          <td>Random</td>
          <td>RMSE</td>
      </tr>
      <tr>
          <td></td>
          <td>Lipophilicity</td>
          <td>1</td>
          <td>4,200</td>
          <td>Regression</td>
          <td>Random</td>
          <td>RMSE</td>
      </tr>
      <tr>
          <td>Biophysics</td>
          <td>PCBA</td>
          <td>128</td>
          <td>439,863</td>
          <td>Classification</td>
          <td>Random</td>
          <td>PRC-AUC</td>
      </tr>
      <tr>
          <td></td>
          <td>MUV</td>
          <td>17</td>
          <td>93,127</td>
          <td>Classification</td>
          <td>Random</td>
          <td>PRC-AUC</td>
      </tr>
      <tr>
          <td></td>
          <td>HIV</td>
          <td>1</td>
          <td>41,913</td>
          <td>Classification</td>
          <td>Scaffold</td>
          <td>ROC-AUC</td>
      </tr>
      <tr>
          <td></td>
          <td>PDBbind</td>
          <td>1</td>
          <td>11,908</td>
          <td>Regression</td>
          <td>Time</td>
          <td>RMSE</td>
      </tr>
      <tr>
          <td></td>
          <td>BACE</td>
          <td>1</td>
          <td>1,522</td>
          <td>Classification</td>
          <td>Scaffold</td>
          <td>ROC-AUC</td>
      </tr>
      <tr>
          <td>Physiology</td>
          <td>BBBP</td>
          <td>1</td>
          <td>2,053</td>
          <td>Classification</td>
          <td>Scaffold</td>
          <td>ROC-AUC</td>
      </tr>
      <tr>
          <td></td>
          <td>Tox21</td>
          <td>12</td>
          <td>8,014</td>
          <td>Classification</td>
          <td>Random</td>
          <td>ROC-AUC</td>
      </tr>
      <tr>
          <td></td>
          <td>ToxCast</td>
          <td>617</td>
          <td>8,615</td>
          <td>Classification</td>
          <td>Random</td>
          <td>ROC-AUC</td>
      </tr>
      <tr>
          <td></td>
          <td>SIDER</td>
          <td>27</td>
          <td>1,427</td>
          <td>Classification</td>
          <td>Random</td>
          <td>ROC-AUC</td>
      </tr>
      <tr>
          <td></td>
          <td>ClinTox</td>
          <td>2</td>
          <td>1,491</td>
          <td>Classification</td>
          <td>Random</td>
          <td>ROC-AUC</td>
      </tr>
  </tbody>
</table>
<p><strong>Quantum mechanics</strong> datasets (QM7, QM7b, QM8, <a href="/notes/chemistry/datasets/qm9/">QM9</a>) contain DFT-computed electronic properties for subsets of the <a href="/notes/chemistry/datasets/gdb-17/">GDB</a> database. <strong>Physical chemistry</strong> datasets cover solubility (ESOL), hydration free energy (FreeSolv), and lipophilicity. <strong>Biophysics</strong> datasets include high-throughput screening results (PCBA, MUV), HIV inhibition activity, protein-ligand binding affinity (PDBbind), and BACE-1 inhibition. <strong>Physiology</strong> datasets cover blood-brain barrier penetration (BBBP), toxicity (Tox21, ToxCast), side effects (SIDER), and clinical trial toxicity (ClinTox).</p>
<h3 id="data-splitting-strategies">Data Splitting Strategies</h3>
<p>MoleculeNet implements four splitting methods, all using an 80/10/10 train/validation/test ratio:</p>
<ul>
<li><strong>Random splitting</strong>: Standard random assignment to subsets.</li>
<li><strong>Scaffold splitting</strong>: Separates molecules by their 2D structural frameworks (Bemis-Murcko scaffolds), providing a harder generalization test since structurally different molecules appear in different subsets.</li>
<li><strong>Stratified splitting</strong>: Ensures each subset contains the full range of label values (used for QM7).</li>
<li><strong>Time splitting</strong>: Trains on older data and tests on newer data to mimic real-world development (used for PDBbind).</li>
</ul>
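<p>To make the scaffold split concrete, here is a minimal sketch of one common greedy scheme: group molecules by scaffold, then assign whole groups to train, validation, and test so that no scaffold is shared between subsets. It assumes scaffold strings have already been computed (for instance with RDKit&rsquo;s Murcko scaffold utilities) and is not claimed to reproduce DeepChem&rsquo;s splitter exactly. The function name and toy labels are hypothetical.</p>

```python
from collections import defaultdict


def scaffold_split(scaffolds, frac_train=0.8, frac_valid=0.1):
    """Split molecule indices so each scaffold lands in exactly one subset.

    scaffolds: list of scaffold strings, one per molecule, in dataset order.
    Returns (train, valid, test) lists of indices.
    """
    groups = defaultdict(list)
    for index, scaffold in enumerate(scaffolds):
        groups[scaffold].append(index)

    # Largest scaffold groups first; ties broken by first index for determinism.
    ordered = sorted(groups.values(), key=lambda g: (-len(g), g[0]))

    n_total = len(scaffolds)
    train_cutoff = frac_train * n_total
    valid_cutoff = (frac_train + frac_valid) * n_total

    train, valid, test = [], [], []
    for group in ordered:
        if len(train) + len(group) <= train_cutoff:
            train.extend(group)
        elif len(train) + len(valid) + len(group) <= valid_cutoff:
            valid.extend(group)
        else:
            test.extend(group)
    return train, valid, test


# Toy scaffold labels standing in for Bemis-Murcko scaffold SMILES.
labels = ["A"] * 5 + ["B"] * 3 + ["C"] + ["D"]
train, valid, test = scaffold_split(labels)
print(train, valid, test)
```

<p>Because whole groups are assigned at once, the realized ratios only approximate 80/10/10 on small datasets, and the rarest scaffolds tend to end up in validation and test. That is what makes this split a harder generalization test than random assignment.</p>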
<h3 id="evaluation-metrics">Evaluation Metrics</h3>
<p>Regression tasks use MAE or RMSE depending on the dataset. Classification tasks use either ROC-AUC or PRC-AUC. The choice between ROC-AUC and PRC-AUC depends on class imbalance: PRC-AUC is recommended for datasets with positive rates below 2% (PCBA, MUV), since precision-recall curves better capture performance under extreme imbalance.</p>
<p>The false positive rate and precision are defined as:</p>
<p>$$
\text{FPR} = \frac{\text{false positive}}{\text{false positive} + \text{true negative}}
$$</p>
<p>$$
\text{precision} = \frac{\text{true positive}}{\text{false positive} + \text{true positive}}
$$</p>
<p>When positive samples form a small fraction of the data, false positives influence precision much more than FPR, making PRC-AUC more informative than ROC-AUC.</p>
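<p>A quick numeric check makes the imbalance argument tangible. The confusion-matrix counts below are invented for illustration (a 0.2% positive rate, similar in magnitude to MUV) and do not come from the paper.</p>

```python
def false_positive_rate(fp, tn):
    """FPR = FP / (FP + TN)."""
    return fp / (fp + tn)


def precision(tp, fp):
    """Precision = TP / (TP + FP)."""
    return tp / (tp + fp)


# Hypothetical screen: 10,000 compounds, 20 of them positive (0.2%).
n_total = 10_000
n_pos = 20
n_neg = n_total - n_pos

# Hypothetical classifier output: 15 true positives and 100 false positives.
tp, fp = 15, 100
tn = n_neg - fp

fpr_value = false_positive_rate(fp, tn)
prec_value = precision(tp, fp)
print(f"FPR = {fpr_value:.3f}, precision = {prec_value:.3f}")
```

<p>The same 100 false positives give an FPR of about 1%, which looks excellent on a ROC curve, but a precision of only about 13%, since they outnumber the true positives several times over. This is the effect that motivates PRC-AUC for PCBA and MUV.</p>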
<h3 id="featurization-methods">Featurization Methods</h3>
<p>MoleculeNet implements six molecular featurization approaches:</p>
<ol>
<li><strong>ECFP (Extended-Connectivity Fingerprints)</strong>: Fixed-length binary fingerprints capturing topological substructures via hashing.</li>
<li><strong><a href="/posts/molecular-descriptor-coulomb-matrix/">Coulomb Matrix</a></strong>: Encodes nuclear charges and 3D coordinates through atomic self-energies and Coulomb repulsion:</li>
</ol>
<p>$$
M_{IJ} = \begin{cases} 0.5 Z_{I}^{2.4} &amp; \text{for } I = J \\ \frac{Z_{I} Z_{J}}{|\mathbf{R}_{I} - \mathbf{R}_{J}|} &amp; \text{for } I \neq J \end{cases}
$$</p>
<ol start="3">
<li><strong>Grid Featurizer</strong>: Designed for PDBbind, incorporating both ligand and protein structural information including salt bridges, hydrogen bonds, and SPLIF fingerprints.</li>
<li><strong>Symmetry Functions</strong>: Preserve rotational and permutation symmetry through radial and angular functions between atom pairs and triplets.</li>
<li><strong>Graph Convolutions</strong>: Compute initial atom feature vectors and neighbor lists from molecular graphs.</li>
<li><strong>Weave</strong>: Similar to graph convolutions but also computes pairwise atom features encoding bond properties, graph distance, and ring information.</li>
</ol>
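<p>The Coulomb matrix definition in item 2 translates directly into code. The sketch below uses only the standard library and assumes nuclear charges in units of the elementary charge with coordinates in atomic units (bohr); the function name and the H$_2$ example geometry are illustrative choices, not values from the paper.</p>

```python
import math


def coulomb_matrix(charges, coords):
    """Build the Coulomb matrix M for one molecule.

    charges: list of nuclear charges Z_I.
    coords: list of (x, y, z) positions R_I, same order as charges.
    """
    n = len(charges)
    matrix = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            if i == j:
                # Diagonal: atomic self-energy term 0.5 * Z^2.4
                matrix[i][j] = 0.5 * charges[i] ** 2.4
            else:
                # Off-diagonal: Coulomb repulsion between nuclei I and J
                distance = math.dist(coords[i], coords[j])
                matrix[i][j] = charges[i] * charges[j] / distance
    return matrix


# H2 with an illustrative bond length of 1.4 bohr.
h2 = coulomb_matrix([1, 1], [(0.0, 0.0, 0.0), (0.0, 0.0, 1.4)])
for row in h2:
    print([round(value, 4) for value in row])
```

<p>The matrix is symmetric and depends only on interatomic distances, so it is invariant to translation and rotation. It is not invariant to atom ordering, which is why implementations typically sort or otherwise canonicalize rows and pad to a fixed size before feeding it to a model.</p>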
<h2 id="benchmarked-models-and-experimental-setup">Benchmarked Models and Experimental Setup</h2>
<p>MoleculeNet benchmarks 12 learning algorithms divided into conventional methods and graph-based methods.</p>
<h3 id="conventional-methods">Conventional Methods</h3>
<ul>
<li><strong>Logistic Regression</strong> (classification only)</li>
<li><strong>Kernel SVM</strong> with radial basis function kernel</li>
<li><strong>Kernel Ridge Regression (KRR)</strong></li>
<li><strong>Random Forests</strong></li>
<li><strong>Gradient Boosting</strong> (XGBoost)</li>
<li><strong>Singletask/Multitask Networks</strong>: Fully connected networks with shared layers across tasks</li>
<li><strong>Bypass Networks</strong>: Multitask networks augmented with per-task &ldquo;bypass&rdquo; layers that directly connect inputs to outputs</li>
<li><strong>Influence Relevance Voting (IRV)</strong>: Refined K-nearest neighbor classifiers using Jaccard-Tanimoto similarity:</li>
</ul>
<p>$$
S(\vec{A}, \vec{B}) = \frac{|A \cap B|}{|A \cup B|}
$$</p>
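<p>For binary fingerprints, the similarity above is the size of the intersection of the &ldquo;on&rdquo; bits divided by the size of their union. A minimal sketch, assuming each fingerprint is given as a set of on-bit indices (the function name and toy fingerprints are hypothetical):</p>

```python
def tanimoto(bits_a, bits_b):
    """Jaccard-Tanimoto similarity between two sets of on-bit indices."""
    a, b = set(bits_a), set(bits_b)
    union = a | b
    if not union:
        return 0.0  # assumption: two empty fingerprints are treated as dissimilar
    return len(a & b) / len(union)


# Toy fingerprints: two shared bits out of five distinct bits overall.
print(tanimoto({1, 4, 7, 9}, {4, 7, 10}))
```

<p>The two toy fingerprints share bits 4 and 7 out of five distinct bits, giving a similarity of 0.4. Identical fingerprints score 1 and disjoint ones score 0.</p>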
<h3 id="graph-based-methods">Graph-Based Methods</h3>
<ul>
<li><strong>Graph Convolutional Models (GC)</strong>: Extend circular fingerprints with learnable convolutions over molecular graphs.</li>
<li><strong>Weave Models</strong>: Update atom features using information from all other atoms and their pairwise features.</li>
<li><strong>Directed Acyclic Graph (DAG) Models</strong>: Define directed bonds toward a central atom and propagate features through the directed graph.</li>
<li><strong>Deep Tensor Neural Networks (DTNN)</strong>: Use nuclear charges and distance matrices directly, updating atom embeddings based on pairwise physical distances.</li>
<li><strong>ANI-1</strong>: Learns transferable potentials using symmetry function features with atom-type-specific neural networks.</li>
<li><strong>Message Passing Neural Networks (MPNN)</strong>: Generalized framework with edge-dependent message functions and set2set readout.</li>
</ul>
<h3 id="experimental-protocol">Experimental Protocol</h3>
<p>Gaussian process hyperparameter optimization was applied to each dataset-model combination, followed by three independent runs with different random seeds. All results are reported as means with standard deviations. Variable training-size experiments were conducted on Tox21, FreeSolv, and QM7 to study data efficiency.</p>
<h2 id="key-findings-across-property-categories">Key Findings Across Property Categories</h2>
<h3 id="biophysics-and-physiology">Biophysics and Physiology</h3>
<p>Graph convolutional and weave models showed strong performance on larger datasets with less overfitting than conventional methods. On Tox21, graph-based models trained on 30% of the training data outperformed multitask networks trained on 90%. However, for smaller single-task datasets (under 3,000 samples), kernel SVM and ensemble tree methods were more robust. On highly imbalanced datasets like MUV (0.20% positive rate), graph-based models struggled to control false positives.</p>
<p>Multitask training had a regularizing effect, reducing the gap between train and test scores compared to single-task models. Bypass networks consistently matched or exceeded vanilla multitask networks, confirming that per-task layers add explanatory power.</p>
<h3 id="physical-chemistry">Physical Chemistry</h3>
<p>Graph-based methods (GC, DAG, MPNN, Weave) provided significant improvements over single-task networks for predicting solubility, solvation energy, and lipophilicity. The best models achieved accuracy comparable to ab initio predictions (RMSE within 0.5 log solubility units for ESOL and within 1.5 kcal/mol for FreeSolv). On FreeSolv, a weave model trained on approximately 200 samples matched the accuracy of alchemical free energy calculations.</p>
<h3 id="quantum-mechanics">Quantum Mechanics</h3>
<p>Models incorporating 3D distance information (DTNN, MPNN, KRR with Coulomb matrix) substantially outperformed models using only topological features. DTNN and MPNN accounted for the best-performing model on 28 of the 39 tasks across the QM datasets. The choice of physics-aware featurization proved more important than the choice of learning algorithm for these tasks.</p>
<h3 id="summary-of-best-performances">Summary of Best Performances</h3>
<p>Graph-based models outperformed conventional methods on 11 of 17 datasets. Key results on the test set:</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>Metric</th>
          <th>Best Conventional</th>
          <th>Best Graph-Based</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>QM7</td>
          <td>MAE</td>
          <td>KRR (CM): 10.22</td>
          <td>DTNN: 8.75</td>
      </tr>
      <tr>
          <td>QM9</td>
          <td>MAE</td>
          <td>Multitask (CM): 4.35</td>
          <td>DTNN: 2.35</td>
      </tr>
      <tr>
          <td>ESOL</td>
          <td>RMSE</td>
          <td>XGBoost: 0.99</td>
          <td>MPNN: 0.58</td>
      </tr>
      <tr>
          <td>FreeSolv</td>
          <td>RMSE</td>
          <td>XGBoost: 1.74</td>
          <td>MPNN: 1.15</td>
      </tr>
      <tr>
          <td>PCBA</td>
          <td>PRC-AUC</td>
          <td>Logreg: 0.129</td>
          <td>GC: 0.136</td>
      </tr>
      <tr>
          <td>Tox21</td>
          <td>ROC-AUC</td>
          <td>KernelSVM: 0.822</td>
          <td>GC: 0.829</td>
      </tr>
      <tr>
          <td>HIV</td>
          <td>ROC-AUC</td>
          <td>KernelSVM: 0.792</td>
          <td>GC: 0.763</td>
      </tr>
      <tr>
          <td>BACE</td>
          <td>ROC-AUC</td>
          <td>RF: 0.867</td>
          <td>Weave: 0.806</td>
      </tr>
  </tbody>
</table>
<p>Conventional methods (KernelSVM, RF) still won on several smaller or scaffold-split datasets (HIV, BACE, MUV, PDBbind, BBBP, SIDER), highlighting that graph-based models are not universally superior, particularly under data scarcity or challenging splits.</p>
<h2 id="conclusions-and-limitations">Conclusions and Limitations</h2>
<p>MoleculeNet demonstrated that learnable representations broadly offer the best performance for molecular machine learning. However, the authors identify several important caveats:</p>
<ol>
<li><strong>Data scarcity</strong>: Graph-based methods are not robust enough on complex tasks with limited training data.</li>
<li><strong>Class imbalance</strong>: On heavily imbalanced classification datasets, conventional methods such as kernel SVM outperform learnable featurizations with respect to recall of positives.</li>
<li><strong>Task-specific featurizations</strong>: For quantum mechanical and biophysical datasets, incorporating physics-aware features (<a href="/posts/molecular-descriptor-coulomb-matrix/">Coulomb matrix</a>, 3D coordinates) is more important than the choice of learning algorithm.</li>
<li><strong>Data-driven physical chemistry</strong>: On FreeSolv, data-driven methods outperformed ab initio calculations with moderate data, suggesting data-driven approaches will become increasingly important as methods and datasets mature.</li>
</ol>
<p>The authors express hope that MoleculeNet will stimulate algorithmic development similar to how ImageNet catalyzed breakthroughs in computer vision. Future directions include extending coverage to 3D protein structure prediction, DNA topological modeling, and other areas of molecular science.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>All 17 datasets are publicly available and integrated into the DeepChem Python package. Users can load any dataset with a single library call.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>QM benchmark</td>
          <td>QM7/QM7b/QM8/QM9</td>
          <td>7K-134K compounds</td>
          <td>DFT-computed properties from GDB subsets</td>
      </tr>
      <tr>
          <td>Physical chemistry</td>
          <td>ESOL/FreeSolv/Lipophilicity</td>
          <td>643-4,200 compounds</td>
          <td>Experimental measurements</td>
      </tr>
      <tr>
          <td>Biophysics</td>
          <td>PCBA/MUV/HIV/PDBbind/BACE</td>
          <td>1.5K-440K compounds</td>
          <td>Bioassay and binding data</td>
      </tr>
      <tr>
          <td>Physiology</td>
          <td>BBBP/Tox21/ToxCast/SIDER/ClinTox</td>
          <td>1.4K-8.6K compounds</td>
          <td>Toxicity and drug safety data</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<p>All splitting methods (random, scaffold, stratified, time) and featurizations (ECFP, Coulomb matrix, grid, symmetry functions, graph convolutions, weave) are implemented in DeepChem. Hyperparameters were tuned via Gaussian process optimization. Three random seeds were used per experiment.</p>
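<p>As a rough illustration of scaffold splitting, the sketch below groups molecules by a precomputed scaffold key and assigns whole groups to the train, validation, and test sets, largest groups first. The function name <code>scaffold_split</code>, the 80/10/10 fractions, and the toy scaffold labels are assumptions made for this example; in practice the scaffold keys would come from a cheminformatics toolkit, and this is not DeepChem's implementation.</p>

```python
from collections import defaultdict

def scaffold_split(scaffolds, frac_train=0.8, frac_valid=0.1):
    """Split indices so that no scaffold group spans two subsets.

    `scaffolds` holds one scaffold key per molecule, assumed to be
    precomputed elsewhere (hypothetical input for this sketch).
    """
    groups = defaultdict(list)
    for idx, key in enumerate(scaffolds):
        groups[key].append(idx)

    # Largest scaffold groups first, so rare scaffolds end up in valid/test.
    ordered = sorted(groups.values(), key=lambda g: (-len(g), g[0]))

    n = len(scaffolds)
    train_cutoff = frac_train * n
    valid_cutoff = (frac_train + frac_valid) * n
    train, valid, test = [], [], []
    for group in ordered:
        if len(train) + len(group) <= train_cutoff:
            train.extend(group)
        elif len(train) + len(valid) + len(group) <= valid_cutoff:
            valid.extend(group)
        else:
            test.extend(group)
    return train, valid, test

# Toy example: ten molecules spread over four scaffolds.
keys = ["benzene"] * 6 + ["pyridine"] * 2 + ["indole", "furan"]
train, valid, test = scaffold_split(keys)
```

<p>Because whole scaffold groups move together, the validation and test sets contain scaffolds the model never saw during training, which is what makes this split harder than a random one.</p>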
<h3 id="models">Models</h3>
<p>All 12 models are implemented in DeepChem, built on Scikit-Learn and TensorFlow. No pretrained weights are provided; models are trained from scratch on each dataset.</p>
<h3 id="evaluation">Evaluation</h3>
<p>Metrics include MAE, RMSE, ROC-AUC, and PRC-AUC as specified per dataset. Multi-task datasets report mean metric values across all tasks.</p>
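<p>To make the multi-task averaging concrete, the following sketch computes ROC-AUC per task and then takes the unweighted mean. The helper names, the use of <code>None</code> for missing labels, and the toy data are assumptions for illustration; the benchmark itself relies on library implementations of these metrics.</p>

```python
def roc_auc(labels, scores):
    """ROC-AUC as the probability that a positive outranks a negative."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative")
    wins = 0.0
    for p in pos:
        for q in neg:
            if p > q:
                wins += 1.0
            elif p == q:
                wins += 0.5  # ties count as half a win
    return wins / (len(pos) * len(neg))

def mean_task_auc(label_matrix, score_matrix):
    """Unweighted mean of per-task ROC-AUC; None marks a missing label."""
    n_tasks = len(label_matrix[0])
    aucs = []
    for t in range(n_tasks):
        pairs = [(row_y[t], row_s[t])
                 for row_y, row_s in zip(label_matrix, score_matrix)
                 if row_y[t] is not None]
        ys, ss = zip(*pairs)
        aucs.append(roc_auc(ys, ss))
    return sum(aucs) / len(aucs)

# Toy data: four molecules, two tasks, one missing label in task 2.
y = [[1, 0], [0, 1], [1, None], [0, 0]]
s = [[0.9, 0.2], [0.7, 0.7], [0.6, 0.5], [0.4, 0.4]]
mean_auc = mean_task_auc(y, s)  # task 1: 0.75, task 2: 1.0
```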
<h3 id="hardware">Hardware</h3>
<p>The authors used Stanford&rsquo;s Sherlock and Xstream GPU nodes. Specific GPU types and training times per model are provided in Table S1 of the supplementary material.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/deepchem/deepchem">DeepChem</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Open-source library with all datasets, featurizations, and models</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Wu, Z., Ramsundar, B., Feinberg, E. N., Gomes, J., Geniesse, C., Pappu, A. S., Leswing, K., &amp; Pande, V. (2018). MoleculeNet: a benchmark for molecular machine learning. <em>Chemical Science</em>, 9(2), 513-530. <a href="https://doi.org/10.1039/c7sc02664a">https://doi.org/10.1039/c7sc02664a</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{wu2018moleculenet,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{MoleculeNet: a benchmark for molecular machine learning}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Wu, Zhenqin and Ramsundar, Bharath and Feinberg, Evan N. and Gomes, Joseph and Geniesse, Caleb and Pappu, Aneesh S. and Leswing, Karl and Pande, Vijay}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Chemical Science}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{9}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{2}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{513--530}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2018}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Royal Society of Chemistry}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1039/c7sc02664a}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Failure Modes in Molecule Generation &amp; Optimization</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/evaluation/failure-modes-molecule-generation/</link><pubDate>Wed, 25 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/evaluation/failure-modes-molecule-generation/</guid><description>Renz et al. show trivial models fool distribution-learning metrics and ML scoring functions introduce exploitable biases in goal-directed molecule generation.</description><content:encoded><![CDATA[<h2 id="an-empirical-critique-of-molecular-generation-evaluation">An Empirical Critique of Molecular Generation Evaluation</h2>
<p>This is an <strong>Empirical</strong> paper that critically examines evaluation practices for molecular generative models. Rather than proposing a new generative method, the paper exposes systematic weaknesses in both distribution-learning metrics and goal-directed optimization scoring functions. The primary contributions are: (1) demonstrating that a trivially simple &ldquo;AddCarbon&rdquo; model can achieve near-perfect scores on widely used distribution-learning benchmarks, and (2) introducing an experimental framework with optimization scores and control scores that reveals model-specific and data-specific biases when ML models serve as scoring functions for goal-directed generation.</p>
<h2 id="evaluation-gaps-in-de-novo-molecular-design">Evaluation Gaps in De Novo Molecular Design</h2>
<p>The rapid growth of deep learning methods for molecular generation (RNN-based SMILES generators, VAEs, GANs, graph neural networks) created a need for standardized evaluation. Benchmarking suites like <a href="/notes/chemistry/molecular-design/generation/evaluation/guacamol-benchmarking-de-novo-molecular-design/">GuacaMol</a> and <a href="/notes/chemistry/molecular-design/generation/evaluation/molecular-sets-moses/">MOSES</a> introduced metrics for validity, uniqueness, novelty, KL divergence over molecular properties, and <a href="/notes/chemistry/molecular-design/generation/evaluation/frechet-chemnet-distance/">Frechet ChemNet Distance (FCD)</a>. For goal-directed generation, penalized logP became a common optimization target.</p>
<p>However, these metrics leave significant blind spots. Distribution-learning metrics do not detect whether a model merely copies training molecules with minimal modifications. Goal-directed benchmarks often use scoring functions that fail to capture the full requirements of drug discovery (synthetic feasibility, drug-likeness, absence of reactive substructures). When ML models serve as scoring functions, the problem worsens because generated molecules can exploit artifacts of the learned model rather than exhibiting genuinely desirable properties.</p>
<p>At the time of writing, wet-lab validations of generative models remained scarce, with only a handful of studies (Merk et al., Zhavoronkov et al.) demonstrating in vitro activity for generated compounds. The lack of rigorous evaluation left the field unable to distinguish meaningfully innovative methods from those that simply exploit metric weaknesses.</p>
<h2 id="the-copy-problem-and-control-score-framework">The Copy Problem and Control Score Framework</h2>
<p>The paper introduces two key conceptual contributions.</p>
<h3 id="the-addcarbon-model-for-distribution-learning">The AddCarbon Model for Distribution-Learning</h3>
<p>The AddCarbon model is deliberately trivial: it samples a molecule from the training set, inserts a single carbon atom at a random position in its SMILES string, and returns the result if it produces a valid, novel molecule. This model achieves near-perfect scores across most <a href="/notes/chemistry/molecular-design/generation/evaluation/guacamol-benchmarking-de-novo-molecular-design/">GuacaMol</a> distribution-learning benchmarks:</p>
<table>
  <thead>
      <tr>
          <th>Benchmark</th>
          <th>RS</th>
          <th>LSTM</th>
          <th>GraphMCTS</th>
          <th>AAE</th>
          <th>ORGAN</th>
          <th>VAE</th>
          <th>AddCarbon</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Validity</td>
          <td>1.000</td>
          <td>0.959</td>
          <td>1.000</td>
          <td>0.822</td>
          <td>0.379</td>
          <td>0.870</td>
          <td>1.000</td>
      </tr>
      <tr>
          <td>Uniqueness</td>
          <td>0.997</td>
          <td>1.000</td>
          <td>1.000</td>
          <td>1.000</td>
          <td>0.841</td>
          <td>0.999</td>
          <td>0.999</td>
      </tr>
      <tr>
          <td>Novelty</td>
          <td>0.000</td>
          <td>0.912</td>
          <td>0.994</td>
          <td>0.998</td>
          <td>0.687</td>
          <td>0.974</td>
          <td>1.000</td>
      </tr>
      <tr>
          <td>KL divergence</td>
          <td>0.998</td>
          <td>0.991</td>
          <td>0.522</td>
          <td>0.886</td>
          <td>0.267</td>
          <td>0.982</td>
          <td>0.982</td>
      </tr>
      <tr>
          <td>FCD</td>
          <td>0.929</td>
          <td>0.913</td>
          <td>0.015</td>
          <td>0.529</td>
          <td>0.000</td>
          <td>0.863</td>
          <td>0.871</td>
      </tr>
  </tbody>
</table>
<p>On FCD, the AddCarbon model beats every generative baseline except the LSTM, despite being practically useless. The only other entry that scores higher is the random sampler (RS), which returns training molecules unchanged and therefore has zero novelty. This exposes what the authors call the &ldquo;copy problem&rdquo;: current metrics check only for exact matches to training molecules, so minimal edits evade novelty detection. The authors argue that likelihood-based evaluation on hold-out test sets, analogous to standard practice in NLP, would provide a more comprehensive metric.</p>
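<p>The AddCarbon procedure described above is simple enough to sketch in a few lines. The version below is a minimal illustration, not the authors' code: <code>is_valid</code> and <code>canonicalize</code> are hypothetical callables standing in for a cheminformatics toolkit, and the toy stand-ins only handle linear strings of C, N, and O.</p>

```python
import random

def add_carbon(training_smiles, is_valid, canonicalize, rng, max_tries=1000):
    """Return one 'novel' molecule by inserting a single C into a training SMILES.

    `is_valid` and `canonicalize` are hypothetical callables (assumptions
    of this sketch) that a real toolkit would provide.
    """
    known = {canonicalize(s) for s in training_smiles}
    for _ in range(max_tries):
        parent = rng.choice(training_smiles)
        pos = rng.randrange(len(parent) + 1)
        candidate = parent[:pos] + "C" + parent[pos:]
        if not is_valid(candidate):
            continue
        # Novelty is an exact-match test against the training set,
        # which is why a one-atom edit is enough to pass it.
        if canonicalize(candidate) not in known:
            return candidate
    return None

# Toy stand-ins: any non-empty string of C/N/O counts as a valid chain, and
# the canonical form is the lexicographically smaller reading direction.
toy_valid = lambda s: len(s) > 0 and set(s).issubset(set("CNO"))
toy_canon = lambda s: min(s, s[::-1])

train = ["CCO", "CCN", "CCCO"]
sample = add_carbon(train, toy_valid, toy_canon, random.Random(0))
```

<p>Every returned molecule differs from a training molecule by exactly one atom, yet it counts as fully novel under an exact-match definition.</p>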
<h3 id="control-scores-for-goal-directed-generation">Control Scores for Goal-Directed Generation</h3>
<p>For goal-directed generation, the authors introduce a three-score experimental design:</p>
<ul>
<li><strong>Optimization Score (OS)</strong>: Output of a classifier trained on data split 1, used to guide the molecular optimizer.</li>
<li><strong>Model Control Score (MCS)</strong>: Output of a second classifier trained on split 1 with a different random seed. Divergence between OS and MCS quantifies model-specific biases.</li>
<li><strong>Data Control Score (DCS)</strong>: Output of a classifier trained on data split 2. Divergence between OS and DCS quantifies data-specific biases.</li>
</ul>
<p>This mirrors the training/test split paradigm in supervised learning. If a generator truly produces molecules with the desired bioactivity, the control scores should track the optimization score. Divergence between them indicates the optimizer is exploiting artifacts of the specific model or training data rather than learning generalizable chemical properties.</p>
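<p>One simple way to read the three scores is to track the gap between the optimization score and each control score over the course of a run. The sketch below does this for made-up score curves; the numbers are hypothetical and are not taken from the paper.</p>

```python
def score_gaps(os_curve, mcs_curve, dcs_curve):
    """Per-iteration gaps between the optimization score and each control."""
    model_gap = [o - m for o, m in zip(os_curve, mcs_curve)]
    data_gap = [o - d for o, d in zip(os_curve, dcs_curve)]
    return model_gap, data_gap

# Hypothetical mean scores of the generated population at each iteration.
os_curve = [0.20, 0.50, 0.80, 0.95]
mcs_curve = [0.20, 0.45, 0.65, 0.70]
dcs_curve = [0.20, 0.40, 0.50, 0.45]

model_gap, data_gap = score_gaps(os_curve, mcs_curve, dcs_curve)
for step, (mg, dg) in enumerate(zip(model_gap, data_gap)):
    print(f"iter {step}: OS-MCS = {mg:.2f}, OS-DCS = {dg:.2f}")
```

<p>A growing OS-MCS gap points to model-specific bias, and a growing OS-DCS gap points to data-specific bias.</p>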
<h2 id="experimental-setup-three-targets-three-generators">Experimental Setup: Three Targets, Three Generators</h2>
<h3 id="targets-and-data">Targets and Data</h3>
<p>The authors selected three biological targets from ChEMBL: <a href="https://en.wikipedia.org/wiki/Janus_kinase_2">Janus kinase 2</a> (JAK2), <a href="https://en.wikipedia.org/wiki/Epidermal_growth_factor_receptor">epidermal growth factor receptor</a> (EGFR), and <a href="https://en.wikipedia.org/wiki/Dopamine_receptor_D2">dopamine receptor D2</a> (DRD2). For each target, the data was split into two halves (split 1 and split 2) with balanced active/inactive ratios. Random forest classifiers using binary folded ECFP4 fingerprints (radius 2, size 1024) were trained to produce three scoring functions per target: the OS and MCS on split 1 (different random seeds), and the DCS on split 2.</p>
<h3 id="generators">Generators</h3>
<p>Three molecular generators were evaluated:</p>
<ol>
<li><strong>Graph-based Genetic Algorithm (GA)</strong>: Iteratively applies random mutations and crossovers to a population of molecules, retaining the best in each generation. One of the top performers in GuacaMol.</li>
<li><strong>SMILES-LSTM</strong>: An autoregressive model that generates SMILES character by character, optimized via hill climbing (iteratively sampling, keeping top molecules, fine-tuning). Also a top GuacaMol performer.</li>
<li><strong><a href="https://en.wikipedia.org/wiki/Particle_swarm_optimization">Particle Swarm Optimization</a> (PS)</strong>: Optimizes molecules in the continuous latent space of a SMILES-based sequence-to-sequence model.</li>
</ol>
<p>Each optimizer was run 10 times per target dataset.</p>
<h2 id="score-divergence-and-exploitable-biases">Score Divergence and Exploitable Biases</h2>
<h3 id="optimization-vs-control-score-divergence">Optimization vs. Control Score Divergence</h3>
<p>Across all three targets and all three generators, the OS consistently outpaced both control scores during optimization. The DCS sometimes stagnated or even decreased while the OS continued to climb. This divergence demonstrates that the generators exploit biases in the scoring function rather than discovering genuinely active compounds.</p>
<p>The MCS also diverged from the OS despite being trained on exactly the same data, confirming model-specific biases: the optimization exploits features unique to the particular random forest instance. The larger gap between OS and DCS (compared to OS and MCS) indicates that data-specific biases contribute more to the divergence than model-specific biases.</p>
<h3 id="chemical-space-migration">Chemical Space Migration</h3>
<p>Optimized molecules migrated toward the region of split 1 actives (used to train the OS), as shown by t-SNE embeddings and nearest-neighbor Tanimoto similarity analysis. Optimized molecules had more similar neighbors in split 1 than in split 2, confirming data-specific bias. By the end of optimization, generated molecules occupied different regions of chemical space than known actives when measured by logP and molecular weight, with compounds from the same optimization run forming distinct clusters.</p>
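<p>The nearest-neighbor analysis can be illustrated with fingerprints represented as sets of on-bit indices. This is a minimal sketch under that assumption; the bit sets below are invented for the example and do not correspond to real molecules.</p>

```python
def tanimoto(a, b):
    """Tanimoto similarity between two sets of fingerprint on-bits."""
    if not a and not b:
        return 0.0
    return len(a.intersection(b)) / len(a.union(b))

def nearest_neighbor_similarity(query, reference_set):
    """Similarity of `query` to its most similar molecule in `reference_set`."""
    return max(tanimoto(query, ref) for ref in reference_set)

# Hypothetical on-bit sets for actives in each split and one generated molecule.
split1_actives = [{1, 2, 3, 4}, {2, 3, 5}]
split2_actives = [{7, 8, 9}, {1, 8, 10, 11}]
generated = {1, 2, 3, 6}

nn_split1 = nearest_neighbor_similarity(generated, split1_actives)
nn_split2 = nearest_neighbor_similarity(generated, split2_actives)
```

<p>A generated molecule whose nearest neighbor in split 1 is much closer than its nearest neighbor in split 2 shows the pattern the authors describe as data-specific bias.</p>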
<h3 id="quality-of-generated-molecules">Quality of Generated Molecules</h3>
<p>High-scoring generated molecules frequently contained problematic substructures: reactive dienes, nitrogen-fluorine bonds, long heteroatom chains that are synthetically infeasible, and highly uncommon functional groups. The LSTM optimizer showed a bias toward high molecular weight, low diversity, and high logP values. These molecules would be rejected by medicinal chemists despite their high optimization scores.</p>
<h3 id="key-takeaways">Key Takeaways</h3>
<p>The authors emphasize several practical implications:</p>
<ol>
<li><strong>Early stopping</strong>: Control scores can indicate when further optimization is exploiting biases rather than finding better molecules. Optimization should stop when control scores plateau.</li>
<li><strong>Scoring function iteration</strong>: In practice, generative models are &ldquo;highly adept at exploiting&rdquo; incomplete scoring functions, necessitating several iterations of generation and scoring function refinement.</li>
<li><strong>Synthetic accessibility</strong>: Even high-scoring molecules are useless if they cannot be synthesized. The authors consider this a major challenge for practical adoption.</li>
<li><strong>Likelihood-based evaluation</strong>: For distribution-learning, the authors recommend reporting test-set likelihoods for likelihood-based models, following standard NLP practice.</li>
</ol>
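<p>The early-stopping idea can be sketched as a plateau check on a control score. The paper does not prescribe a specific rule, so the <code>patience</code> and <code>min_delta</code> parameters and the example curve below are assumptions chosen for illustration.</p>

```python
def stopping_iteration(control_scores, patience=3, min_delta=0.01):
    """Return the first iteration at which the control score has plateaued.

    A plateau means no improvement of at least `min_delta` over the best
    value seen so far for `patience` consecutive iterations.
    """
    best = float("-inf")
    stale = 0
    for i, score in enumerate(control_scores):
        if score >= best + min_delta:
            best = score
            stale = 0
        else:
            stale += 1
            if stale >= patience:
                return i
    return None  # never plateaued within the recorded run

# Hypothetical data control score (DCS) per optimization iteration.
dcs_curve = [0.20, 0.35, 0.45, 0.50, 0.50, 0.49, 0.48, 0.47]
stop_at = stopping_iteration(dcs_curve)
```

<p>In this toy curve the control score peaks at iteration 3 and the rule fires at iteration 6, even though the optimization score could still be rising.</p>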
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Bioactivity data</td>
          <td>ChEMBL (JAK2, EGFR, DRD2)</td>
          <td>See Table S1</td>
          <td>Binary classification tasks, split 50/50</td>
      </tr>
      <tr>
          <td>Distribution-learning</td>
          <td>GuacaMol training set</td>
          <td>Subset of ChEMBL</td>
          <td>Used as starting population for GA and PS</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Scoring function</strong>: Random forest classifier (scikit-learn) on binary ECFP4 fingerprints (size 1024, radius 2, RDKit)</li>
<li><strong>GA</strong>: Graph-based genetic algorithm from Jensen (2019)</li>
<li><strong>LSTM</strong>: SMILES-LSTM with hill climbing, pretrained model from GuacaMol</li>
<li><strong>PS</strong>: Particle swarm optimization in latent space of a sequence-to-sequence model (Winter et al. 2019)</li>
<li>Each optimizer was run 10 times per target</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Description</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Optimization Score (OS)</td>
          <td>RF classifier on split 1</td>
          <td>Guides optimization</td>
      </tr>
      <tr>
          <td>Model Control Score (MCS)</td>
          <td>RF on split 1, different seed</td>
          <td>Detects model-specific bias</td>
      </tr>
      <tr>
          <td>Data Control Score (DCS)</td>
          <td>RF on split 2</td>
          <td>Detects data-specific bias</td>
      </tr>
      <tr>
          <td><a href="/notes/chemistry/molecular-design/generation/evaluation/guacamol-benchmarking-de-novo-molecular-design/">GuacaMol</a> metrics</td>
          <td>Validity, uniqueness, novelty, KL div, FCD</td>
          <td>For distribution-learning</td>
      </tr>
  </tbody>
</table>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/ml-jku/mgenerators-failure-modes">ml-jku/mgenerators-failure-modes</a></td>
          <td>Code</td>
          <td>Unknown</td>
          <td>Data, code, and results</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Not specified in the paper.</p>
<hr>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{renz2019failure,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{On failure modes in molecule generation and optimization}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Renz, Philipp and Van Rompaey, Dries and Wegner, J{\&#34;o}rg Kurt and Hochreiter, Sepp and Klambauer, G{\&#34;u}nter}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Drug Discovery Today: Technologies}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{32-33}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{55--63}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2019}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Elsevier}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1016/j.ddtec.2020.09.003}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Renz, P., Van Rompaey, D., Wegner, J. K., Hochreiter, S., &amp; Klambauer, G. (2019). On failure modes in molecule generation and optimization. <em>Drug Discovery Today: Technologies</em>, 32-33, 55-63. <a href="https://doi.org/10.1016/j.ddtec.2020.09.003">https://doi.org/10.1016/j.ddtec.2020.09.003</a></p>
<p><strong>Publication</strong>: Drug Discovery Today: Technologies, Volume 32-33, 2019</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/ml-jku/mgenerators-failure-modes">Code and data (GitHub)</a></li>
</ul>
]]></content:encoded></item><item><title>ChemEval: Fine-Grained LLM Evaluation for Chemistry</title><link>https://hunterheidenreich.com/notes/chemistry/llm-applications/chemeval-multilevel-chemical-evaluation/</link><pubDate>Wed, 25 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/llm-applications/chemeval-multilevel-chemical-evaluation/</guid><description>ChemEval is a hierarchical 62-task benchmark evaluating LLMs across four levels of chemical capability, from basic knowledge to synthesis planning.</description><content:encoded><![CDATA[<h2 id="a-hierarchical-benchmark-for-chemistry-llms">A Hierarchical Benchmark for Chemistry LLMs</h2>
<p>ChemEval is a <strong>Resource</strong> paper that introduces a comprehensive, hierarchical benchmark for evaluating large language models on chemical tasks. The benchmark spans four progressive levels of difficulty (Advanced Knowledge Question Answering, Literature Understanding, Molecular Understanding, and Scientific Knowledge Deduction), encompasses 13 capability dimensions, and contains 62 distinct tasks with 3,160 evaluation instances. It covers both text-only and multimodal settings, making it one of the most extensive chemistry-specific LLM evaluation frameworks to date.</p>
<h2 id="gaps-in-existing-chemistry-benchmarks">Gaps in Existing Chemistry Benchmarks</h2>
<p>Prior benchmarks for chemistry LLMs had several shortcomings:</p>
<ul>
<li><strong>General benchmarks</strong> (MMLU, XieZhi, C-Eval) include some chemistry questions but lack the depth needed for meaningful evaluation of domain expertise.</li>
<li><strong>SciEVAL</strong> covers scientific tasks broadly but treats chemistry superficially with overly simplistic questions.</li>
<li><strong><a href="/notes/chemistry/llm-applications/chemllmbench-eight-chemistry-tasks/">ChemLLMBench</a></strong> (Guo et al., 2023) includes only 8 task categories derived from existing public datasets, offering insufficient breadth.</li>
<li><strong><a href="/notes/chemistry/llm-applications/chembench-llm-chemistry-evaluation/">ChemBench</a></strong> (Mirza et al., 2024) provides 7,000 samples but relies exclusively on multiple-choice questions and lacks open-ended evaluation for tasks like synthesis pathway recommendation.</li>
<li><strong><a href="/notes/chemistry/llm-applications/macbench-multimodal-chemistry-benchmark/">MaCBench</a></strong> (Alampara et al., 2025) introduces multimodal evaluation but remains limited in task diversity.</li>
</ul>
<p>None of these benchmarks address LLMs&rsquo; ability to extract chemical information from text and tables, and none provide a graduated, multi-level assessment of chemical competence from basic knowledge through to advanced scientific reasoning.</p>
<h2 id="a-four-level-hierarchical-evaluation-framework">A Four-Level Hierarchical Evaluation Framework</h2>
<p>ChemEval&rsquo;s core innovation is its hierarchical structure that mirrors how chemical expertise develops, from foundational knowledge through applied scientific reasoning.</p>
<h3 id="level-1-advanced-knowledge-question-answering">Level 1: Advanced Knowledge Question Answering</h3>
<p>This level assesses fundamental chemical knowledge through 15 tasks across two dimensions:</p>
<ul>
<li><strong>Objective Questions (ObjQA)</strong>: multiple choice, fill-in-the-blank, and true/false tasks spanning seven core chemistry disciplines (organic, inorganic, materials, analytical, biochemistry, physical, and polymer chemistry).</li>
<li><strong>Subjective Questions (SubjQA)</strong>: short answer and calculation tasks requiring detailed reasoning and explanation.</li>
</ul>
<h3 id="level-2-literature-understanding">Level 2: Literature Understanding</h3>
<p>This level evaluates the ability to interpret chemical literature through 19 tasks across three dimensions:</p>
<ul>
<li><strong>Information Extraction (InfoE)</strong>: 11 tasks covering named entity recognition, relationship classification, substrate extraction, additive/solvent/temperature/time extraction, product extraction, characterization method extraction, catalysis type extraction, and yield extraction.</li>
<li><strong>Inductive Generation (InducGen)</strong>: abstract generation, research outline generation, topic classification, and reaction type recognition.</li>
<li><strong>Molecular Name Recognition (MNR)</strong>: molecular formula recognition, chemical reaction equation recognition, 2D molecular structure recognition, and synthetic pathway analysis (multimodal tasks).</li>
</ul>
<h3 id="level-3-molecular-understanding">Level 3: Molecular Understanding</h3>
<p>This level tests molecular-level comprehension through 15 tasks across four dimensions:</p>
<ul>
<li><strong>Molecular Name Generation (MNGen)</strong>: generating <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> from text descriptions.</li>
<li><strong>Molecular Name Translation (MNTrans)</strong>: <a href="https://en.wikipedia.org/wiki/IUPAC_nomenclature_of_organic_chemistry">IUPAC</a> to molecular formula, SMILES to molecular formula, IUPAC to SMILES, SMILES to IUPAC, and SMILES/<a href="/notes/chemistry/molecular-representations/notations/selfies/">SELFIES</a> interconversion.</li>
<li><strong>Molecular Property Prediction (MPP)</strong>: classification (ClinTox, HIV inhibition, polarity) and regression (<a href="https://en.wikipedia.org/wiki/Lipophilicity">lipophilicity</a>, boiling point).</li>
<li><strong>Molecular Description (MolDesc)</strong>: physicochemical property prediction from molecular structures and various spectral inputs (IR, Raman, UV-Vis, diffraction, mass spectrum, <a href="https://en.wikipedia.org/wiki/Nuclear_magnetic_resonance_spectroscopy">NMR</a>).</li>
</ul>
<h3 id="level-4-scientific-knowledge-deduction">Level 4: Scientific Knowledge Deduction</h3>
<p>The most advanced level covers 13 tasks across four dimensions:</p>
<ul>
<li><strong><a href="https://en.wikipedia.org/wiki/Retrosynthetic_analysis">Retrosynthetic Analysis</a> (ReSyn)</strong>: substrate recommendation, synthetic pathway recommendation, and synthetic difficulty evaluation.</li>
<li><strong>Reaction Condition Recommendation (RCRec)</strong>: ligand, reagent, solvent, catalyst, temperature, and time recommendation.</li>
<li><strong>Reaction Outcome Prediction (ROP)</strong>: product prediction, yield prediction, and reaction rate prediction.</li>
<li><strong>Reaction Mechanism Analysis (RMA)</strong>: intermediate derivation.</li>
</ul>
<h3 id="data-construction">Data Construction</h3>
<p>The benchmark combines open-source datasets (ChemRxnExtractor, Mol-Instructions, ChemLLMBench, SMolInstruct) with domain-expert data curated from approximately 500 university-level chemistry textbooks and 9,000 real-world experimental records. Expert-crafted questions were written from scratch to prevent data leakage. A three-tier quality assurance pipeline (annotation by undergraduate students, review by graduate students, final audit by chemistry faculty) ensures correctness.</p>
<p>The text subset contains 1,960 instances (18 open-source tasks, 24 in-house tasks), while the multimodal subset contains 1,200 instances (12 open-source tasks, 30 in-house tasks).</p>
<h2 id="experimental-setup-and-model-comparison">Experimental Setup and Model Comparison</h2>
<h3 id="models-evaluated">Models Evaluated</h3>
<p>ChemEval evaluates a broad set of models under both zero-shot and 3-shot settings:</p>
<p><strong>General LLMs</strong>: OpenAI-o1, OpenAI-o3-mini, GPT-4o, Claude-3.7-Sonnet (thinking and non-thinking modes), Gemini-2.5-Pro, Grok3, DeepSeek-V3, DeepSeek-R1, Qwen2.5 (7B/14B/32B/72B), LLaMA3.3-8B.</p>
<p><strong>Chemistry-specific LLMs</strong>: <a href="/notes/chemistry/llm-applications/chemdfm-r/">ChemDFM</a>, <a href="/notes/chemistry/llm-applications/llamsmol-instruction-tuning-chemistry/">LlaSMol</a>, <a href="/notes/chemistry/llm-applications/chemllm-chemical-large-language-model/">ChemLLM</a>, ChemSpark.</p>
<p><strong>Multimodal LLMs</strong> (for multimodal tasks): GPT-4o, Claude-3.7-Sonnet, Qwen-VL Max, Phi-Vision-3.5, Gemini-2.5-Pro, GLM-4V.</p>
<h3 id="evaluation-metrics">Evaluation Metrics</h3>
<p>The benchmark employs task-appropriate metrics: F1 score, Accuracy, BLEU, Exact Match, Normalized RMSE, <a href="https://en.wikipedia.org/wiki/Jaccard_index">Tanimoto similarity</a> (with valid output ratio), LLM Score (judged by GPT-4o), L2 Score for molecular formula similarity, and Overlap for range prediction.</p>
<h3 id="key-results-zero-shot-text-tasks">Key Results (Zero-Shot Text Tasks)</h3>
<table>
  <thead>
      <tr>
          <th>Level</th>
          <th>Top General LLM</th>
          <th>Score</th>
          <th>Top Chemistry LLM</th>
          <th>Score</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Knowledge QA (MCTask)</td>
          <td>Gemini-2.5-Pro</td>
          <td>87.60%</td>
          <td><a href="/notes/chemistry/llm-applications/chemcrow-augmenting-llms-chemistry-tools/">ChemCrow</a></td>
          <td>58.00%</td>
      </tr>
      <tr>
          <td>Literature (CNER)</td>
          <td>Gemini-2.5-Pro</td>
          <td>68.30 F1</td>
          <td>ChemSpark</td>
          <td>71.44 F1</td>
      </tr>
      <tr>
          <td>Molecular (MolNG)</td>
          <td>Gemini-2.5-Pro</td>
          <td>71.11 Tan.</td>
          <td>ChemSpark</td>
          <td>74.81 Tan.</td>
      </tr>
      <tr>
          <td>Molecular (IUPAC2SMILES)</td>
          <td>Gemini-2.5-Pro</td>
          <td>61.33 Tan.</td>
          <td>ChemSpark</td>
          <td>87.54 Tan.</td>
      </tr>
      <tr>
          <td>Scientific (SubRec)</td>
          <td>OpenAI-o3-mini</td>
          <td>4.67 F1</td>
          <td>ChemSpark</td>
          <td>12.37 F1</td>
      </tr>
      <tr>
          <td>Scientific (CatRec)</td>
          <td>All models</td>
          <td>0.00 F1</td>
          <td>ChemSpark</td>
          <td>0.20 F1</td>
      </tr>
  </tbody>
</table>
<h2 id="key-findings-and-performance-patterns">Key Findings and Performance Patterns</h2>
<h3 id="general-vs-chemistry-specific-llms">General vs. Chemistry-Specific LLMs</h3>
<p>General-purpose LLMs excel at Advanced Knowledge QA and Literature Understanding, benefiting from strong document comprehension and instruction-following abilities. Chemistry-specialized models (particularly ChemSpark) perform better on tasks demanding domain-specific molecular knowledge, such as molecular name translation and reaction condition recommendation. However, specialized models show notably weaker instruction-following capability and suffer from catastrophic forgetting of general language abilities during fine-tuning. For example, ChemLLM scores 0.00 on multiple information extraction tasks where general LLMs achieve 60-95%.</p>
<h3 id="impact-of-few-shot-learning">Impact of Few-Shot Learning</h3>
<p>General LLMs tend to benefit from few-shot prompting, particularly for subjective QA and literature understanding tasks. OpenAI-o1 improved on 9 of 10 evaluated tasks. In contrast, chemistry-specialized models often show performance degradation with few-shot examples, likely due to loss of in-context learning capabilities during task-specific fine-tuning. ChemSpark decreased on 7 of 10 tasks in the 3-shot setting.</p>
<h3 id="impact-of-model-scaling">Impact of Model Scaling</h3>
<p>Experiments with Qwen2.5 at 7B, 14B, 32B, and 72B parameters show that scaling improves performance on knowledge QA and literature understanding tasks. However, molecular understanding and scientific knowledge deduction tasks show minimal improvement, and some tasks (e.g., molecular property classification) even decline at the largest scale. Tasks requiring specialized chemical knowledge, like IUPAC-to-SMILES conversion and catalyst recommendation, remain near zero regardless of model size.</p>
<h3 id="thinking-models">Thinking Models</h3>
<p>Comparing OpenAI-o1 vs. GPT-4o and DeepSeek-R1 vs. DeepSeek-V3, thinking models show comparable overall performance to their non-thinking counterparts. They occasionally excel on specific tasks (e.g., reaction product prediction) but do not consistently outperform across chemical tasks. The authors conclude that the primary bottleneck is insufficient domain-specific knowledge, not reasoning depth.</p>
<h3 id="multimodal-tasks">Multimodal Tasks</h3>
<p>Multimodal LLMs handle basic tasks like molecular formula recognition well (GLM-4V and Qwen-VL Max: 100% accuracy) but struggle with advanced challenges. Synthetic pathway analysis yielded 0% F1 across all models. 2D molecular structure recognition produced Tanimoto scores below 21% for all models tested. The performance gap between basic recognition and advanced chemical reasoning is substantial.</p>
<h3 id="limitations">Limitations</h3>
<p>The authors acknowledge several limitations:</p>
<ol>
<li><strong>Limited instances per task</strong>: with 62 task types and 3,160 total instances, individual tasks may have as few as 20 samples.</li>
<li><strong>Static, single-turn evaluation</strong>: the benchmark does not assess dynamic interaction, tool use, or agentic workflows.</li>
<li><strong>No chemistry-specific multimodal models tested</strong>: only general-purpose VLMs were evaluated on multimodal tasks.</li>
</ol>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Evaluation (text)</td>
          <td>ChemEval text subset</td>
          <td>1,960 instances</td>
          <td>18 open-source + 24 in-house tasks</td>
      </tr>
      <tr>
          <td>Evaluation (multimodal)</td>
          <td>ChemEval multimodal subset</td>
          <td>1,200 instances</td>
          <td>12 open-source + 30 in-house tasks</td>
      </tr>
      <tr>
          <td>Source (open-source)</td>
          <td>ChemRxnExtractor, Mol-Instructions, ChemLLMBench, SMolInstruct</td>
          <td>Various</td>
          <td>Adapted for ChemEval format</td>
      </tr>
      <tr>
          <td>Source (expert)</td>
          <td>~500 textbooks, ~9,000 experimental records</td>
          <td>Various</td>
          <td>Novel questions crafted by domain experts</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Evaluation prompts</strong>: task-specific instructions designed for formatted output, with 0-shot and 3-shot variants.</li>
<li><strong>Decoding</strong>: greedy decoding for all LLM inference.</li>
<li><strong>LLM-as-judge</strong>: GPT-4o is used to compute the LLM Score metric on subjective tasks.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p>Key metrics by task type:</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Task Types</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Accuracy</td>
          <td>MCTask, TFTask, MolPC, SubE, etc.</td>
          <td>Standard classification accuracy</td>
      </tr>
      <tr>
          <td>F1 Score</td>
          <td>CNER, CERC, extraction tasks, reaction prediction</td>
          <td>Precision-recall harmonic mean</td>
      </tr>
      <tr>
          <td>BLEU</td>
          <td>SMILES2IUPAC</td>
          <td>N-gram overlap with brevity penalty</td>
      </tr>
      <tr>
          <td>Exact Match</td>
          <td>SMILES2IUPAC</td>
          <td>Strict string match</td>
      </tr>
      <tr>
          <td>Tanimoto Similarity</td>
          <td>Molecular generation/translation tasks</td>
          <td>Fingerprint-based molecular similarity</td>
      </tr>
      <tr>
          <td>NRMSE</td>
          <td>Regression tasks (property, temperature, time)</td>
          <td>Normalized prediction error</td>
      </tr>
      <tr>
          <td>LLM Score</td>
          <td>Subjective QA, abstract generation, pathway rec.</td>
          <td>GPT-4o evaluation (0-100)</td>
      </tr>
      <tr>
          <td>L2 Score</td>
          <td>Molecular formula tasks</td>
          <td>$1 / (1 + \text{L2 distance})$ between formulas</td>
      </tr>
      <tr>
          <td>Overlap</td>
          <td>Rate prediction</td>
          <td>Intersection/union of predicted vs. reference ranges</td>
      </tr>
  </tbody>
</table>
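<p>Two of the less familiar metrics in the table can be sketched in a few lines. This is a minimal illustration, assuming that the L2 distance is taken between element-count vectors of flat molecular formulas and that Overlap is an intersection-over-union of closed numeric intervals; the function names <code>parse_formula</code>, <code>l2_score</code>, and <code>range_overlap</code> are hypothetical and not taken from the ChemEval code base.</p>

```python
import math
import re
from collections import Counter


def parse_formula(formula: str) -> Counter:
    """Count atoms in a flat formula such as 'C6H12O6' (no brackets or charges)."""
    counts = Counter()
    for element, number in re.findall(r"([A-Z][a-z]?)(\d*)", formula):
        counts[element] += int(number) if number else 1
    return counts


def l2_score(predicted: str, reference: str) -> float:
    """Return 1 / (1 + L2 distance) between the two element-count vectors."""
    pred, ref = parse_formula(predicted), parse_formula(reference)
    elements = set(pred) | set(ref)
    distance = math.sqrt(sum((pred[e] - ref[e]) ** 2 for e in elements))
    return 1.0 / (1.0 + distance)


def range_overlap(predicted: tuple, reference: tuple) -> float:
    """Intersection over union of two closed intervals given as (low, high).

    The denominator is the span covering both intervals, which equals the
    union whenever they overlap; disjoint intervals score 0 regardless.
    """
    low = max(predicted[0], reference[0])
    high = min(predicted[1], reference[1])
    intersection = max(0.0, high - low)
    span = max(predicted[1], reference[1]) - min(predicted[0], reference[0])
    return intersection / span if span > 0 else 0.0
```

<p>Under these assumptions an exact formula match scores 1.0, and the score decays toward 0 as atom counts diverge, so a near-miss such as <code>C2H6</code> vs. <code>C2H4</code> still earns partial credit.</p>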
<h3 id="hardware">Hardware</h3>
<ul>
<li>Chemistry-specific models were run on two NVIDIA A40 48GB GPUs.</li>
<li>General models were accessed via official APIs.</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/USTC-StarTeam/ChemEval">ChemEval Benchmark</a></td>
          <td>Code + Data</td>
          <td>Other (custom)</td>
          <td>Evaluation framework and task data</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Huang, Y., Zhang, R., He, X., Zhi, X., Wang, H., Chen, N., Liu, Z., Li, X., Xu, F., Liu, D., Liang, H., Li, Y., Cui, J., Xu, Y., Wang, S., Liu, Q., Lian, D., Liu, G., &amp; Chen, E. (2024). ChemEval: A Comprehensive Multi-Level Chemical Evaluation for Large Language Models. arXiv preprint arXiv:2409.13989.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{huang2024chemeval,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{ChemEval: A Comprehensive Multi-Level Chemical Evaluation for Large Language Models}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Huang, Yuqing and Zhang, Rongyang and He, Xuesong and Zhi, Xuyang and Wang, Hao and Chen, Nuo and Liu, Zongbo and Li, Xin and Xu, Feiyang and Liu, Deguang and Liang, Huadong and Li, Yi and Cui, Jian and Xu, Yin and Wang, Shijin and Liu, Qi and Lian, Defu and Liu, Guiquan and Chen, Enhong}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{arXiv preprint arXiv:2409.13989}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2024}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.48550/arXiv.2409.13989}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>ChemBench: Evaluating LLM Chemistry Against Experts</title><link>https://hunterheidenreich.com/notes/chemistry/llm-applications/chembench-llm-chemistry-evaluation/</link><pubDate>Wed, 25 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/llm-applications/chembench-llm-chemistry-evaluation/</guid><description>ChemBench benchmarks LLM chemical knowledge with 2,700+ questions across topics, finding top models outperform expert chemists on average.</description><content:encoded><![CDATA[<h2 id="a-benchmark-resource-for-chemistry-focused-llm-evaluation">A Benchmark Resource for Chemistry-Focused LLM Evaluation</h2>
<p>ChemBench is a <strong>Resource</strong> paper that introduces an automated benchmarking framework for evaluating the chemical knowledge and reasoning abilities of large language models against human expert chemists. The primary contribution is the benchmark corpus itself (2,788 question-answer pairs), the evaluation infrastructure, and the human baseline study that contextualizes model performance. The framework is designed to be extensible and can evaluate any system that returns text, including tool-augmented agents.</p>
<h2 id="why-chemistry-needs-its-own-llm-benchmark">Why Chemistry Needs Its Own LLM Benchmark</h2>
<p>Existing LLM benchmarks provide poor coverage of chemistry. Only 2 of BigBench&rsquo;s 204 tasks are classified as chemistry-related, and the LM Eval Harness contains none. Developers of chemical language models often fall back on tabular property-prediction datasets (<a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a>, Therapeutic Data Commons, MatBench), which give a narrow view of chemical capabilities. Prior attempts at chemistry-specific benchmarks based on university entrance exams or automatic text mining have not gained wide acceptance because they cannot be used with black-box or tool-augmented systems, do not cover a broad range of topics and skills, or are not validated by domain experts.</p>
<p>At the same time, LLMs are increasingly used in chemistry: for property prediction, reaction optimization, materials generation, information extraction, and even autonomous experiment execution. Some users (students, general public) may rely on LLMs for safety-critical chemical questions without the expertise to evaluate outputs. Understanding where LLMs succeed and fail in chemistry is therefore both a scientific and a safety question.</p>
<h2 id="chembench-framework-design-and-benchmark-corpus">ChemBench: Framework Design and Benchmark Corpus</h2>
<p>ChemBench addresses these gaps with several design choices that distinguish it from prior work.</p>
<p><strong>Diverse question corpus.</strong> The benchmark contains 2,788 question-answer pairs from multiple sources: 1,039 manually generated (from university exams, chemistry olympiads, textbooks, and novel questions) and 1,749 semi-automatically generated (from chemical databases covering <a href="https://en.wikipedia.org/wiki/Globally_Harmonized_System_of_Classification_and_Labelling_of_Chemicals">GHS pictograms</a>, daily allowed intakes, hazard statements, <a href="https://en.wikipedia.org/wiki/Nuclear_magnetic_resonance_spectroscopy">NMR</a> peak counts, electron counts, IUPAC-SMILES conversions, oxidation states, and <a href="https://en.wikipedia.org/wiki/Point_group">point groups</a>). Questions span general, organic, inorganic, physical, analytical, and technical chemistry, among other topics.</p>
<p><strong>Skill-based classification.</strong> Each question is annotated with the skills required to answer it: knowledge, reasoning, calculation, intuition, or combinations thereof. Questions are also classified by difficulty level (basic vs. advanced), enabling fine-grained analysis of model capabilities.</p>
<p><strong>Both MCQ and open-ended formats.</strong> The corpus includes 2,544 multiple-choice and 244 open-ended questions, reflecting the reality that chemistry education and research involve more than multiple-choice testing.</p>
<p><strong>Semantic annotation.</strong> Questions use tagged annotations for molecules (<code>[START_SMILES]...[END_SMILES]</code>), equations, units, and reactions. This allows models with special processing for scientific notation (e.g., <a href="/notes/chemistry/llm-applications/galactica-large-language-model-for-science/">Galactica</a>) to handle these modalities appropriately, while remaining compatible with standard text-completion APIs.</p>
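<p>The tagging scheme is easy to consume from plain Python. The sketch below shows one way to pull out or unwrap the SMILES spans, assuming the tags appear exactly as written above; the helper names <code>extract_smiles</code> and <code>strip_smiles_tags</code> are illustrative and are not part of the ChemBench API.</p>

```python
import re

# Matches the tag pair quoted in the text; the captured group is the SMILES string.
SMILES_TAG = re.compile(r"\[START_SMILES\](.*?)\[END_SMILES\]", re.DOTALL)


def extract_smiles(question: str) -> list:
    """Return the strings wrapped in SMILES tags, in order of appearance."""
    return [match.strip() for match in SMILES_TAG.findall(question)]


def strip_smiles_tags(question: str) -> str:
    """Drop the tags but keep the SMILES text, for models without special handling."""
    return SMILES_TAG.sub(lambda m: m.group(1).strip(), question)
```

<p>A model with dedicated molecule tokens could receive the question unchanged, while a generic text-completion model could be given the output of <code>strip_smiles_tags</code>.</p>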
<p><strong>Text-completion evaluation.</strong> ChemBench operates on text completions rather than raw logits, enabling evaluation of tool-augmented and agentic systems (not just bare models). Parsing uses multi-step regex followed by LLM-based extraction as a fallback.</p>
<p><strong>ChemBench-Mini.</strong> A curated 236-question subset balances topic and skill diversity for fast, cost-effective routine evaluations. This subset was also used for the full human baseline study.</p>
<h2 id="evaluation-setup-models-human-experts-and-confidence">Evaluation Setup: Models, Human Experts, and Confidence</h2>
<h3 id="models-evaluated">Models evaluated</h3>
<p>The study evaluated a wide range of leading models, including both open-source and proprietary systems: o1-preview, GPT-4, Claude-3.5 (Sonnet), Llama-3.1-405B-Instruct, and others, as well as the agentic literature-search system PaperQA2. All models used greedy decoding (temperature 0) via API endpoints.</p>
<h3 id="human-baseline">Human baseline</h3>
<p>Nineteen chemistry experts participated through a custom web application (chembench.org). Volunteers included 2 post-postdoc researchers, 13 PhD students (with master&rsquo;s degrees), and 1 bachelor&rsquo;s holder. The analysis excluded anyone with fewer than 2 years of chemistry experience. For a subset of questions, volunteers were allowed to use external tools (web search, ChemDraw) but not LLMs or help from other people.</p>
<h3 id="confidence-calibration">Confidence calibration</h3>
<p>Selected top-performing models were prompted to estimate their confidence on a 1-5 ordinal scale (verbalized confidence estimates). This approach captures semantic uncertainty and works with models that do not expose logits.</p>
<h2 id="key-results-where-llms-outperform-chemists-and-where-they-fail">Key Results: Where LLMs Outperform Chemists and Where They Fail</h2>
<h3 id="overall-performance">Overall performance</h3>
<p>On ChemBench-Mini, the leading model (o1-preview) outperformed the best human expert by nearly a factor of two in overall accuracy. Many other models also exceeded average human performance. Llama-3.1-405B-Instruct achieved performance close to the leading proprietary models, showing that open-source models can be competitive in chemical settings.</p>
<h3 id="performance-varies-by-topic">Performance varies by topic</h3>
<p>While models scored well on general and technical chemistry, they performed poorly on toxicity/safety and analytical chemistry. Predicting the number of NMR signals was particularly difficult (22% correct for o1-preview). This task requires reasoning about molecular symmetry from a <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> string, which models struggle with compared to humans who can view molecular drawings.</p>
<h3 id="textbook-questions-vs-database-derived-questions">Textbook questions vs. database-derived questions</h3>
<p>Models performed better on textbook-inspired questions than on semi-automatically constructed tasks. For example, on a sampled subset of the German Chemical Prohibition Ordinance certification exam, models scored well enough to pass (71% for GPT-4, 61% for Claude-3.5 Sonnet), while human experts scored only 3%. Yet the same models performed poorly on database-derived toxicity and safety questions, which suggests that strong performance on textbook-style questions does not transfer to tasks requiring deeper reasoning or knowledge outside the training corpus.</p>
<h3 id="knowledge-intensive-limitations">Knowledge-intensive limitations</h3>
<p>Models struggled with knowledge-intensive questions that required looking up facts in specialized databases (<a href="https://en.wikipedia.org/wiki/PubChem">PubChem</a>, Gestis). PaperQA2, which augments LLMs with literature search, could not compensate because the required knowledge lives in specialized databases rather than papers.</p>
<h3 id="chemical-preference-judgment">Chemical preference judgment</h3>
<p>When asked to judge chemical preference (choosing between two molecules in an early <a href="https://en.wikipedia.org/wiki/Virtual_screening">virtual screening</a> setting, following the Choung et al. dataset), model performance was often indistinguishable from random guessing, even for models that excelled at other ChemBench tasks. Human chemists showed reasonable inter-rater agreement on the same questions.</p>
<h3 id="confidence-calibration-is-poor">Confidence calibration is poor</h3>
<p>For most models, verbalized confidence estimates did not correlate meaningfully with actual correctness. GPT-4 reported confidence of 1.0 for a correctly answered safety question but 4.0 for six incorrectly answered ones. Claude-3.5 Sonnet showed slightly better calibration on average but still produced misleading estimates in specific topic areas (e.g., GHS pictogram labeling: average confidence of 2.0 for correct answers vs. 1.83 for incorrect ones).</p>
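<p>A simple way to spot this kind of miscalibration is to compare the mean verbalized confidence on correct and incorrect answers. The following is a minimal sketch, assuming each record is a pair of a 1-5 confidence value and a correctness flag; the function name <code>mean_confidence_by_outcome</code> is hypothetical and the paper&rsquo;s own analysis may differ.</p>

```python
from collections import defaultdict
from statistics import mean


def mean_confidence_by_outcome(records):
    """Average verbalized confidence, split by whether the answer was correct.

    records: iterable of (confidence on the 1-5 scale, was_correct) pairs.
    Returns a dict keyed by True/False; an outcome with no records is absent.
    """
    buckets = defaultdict(list)
    for confidence, was_correct in records:
        buckets[bool(was_correct)].append(confidence)
    return {outcome: mean(values) for outcome, values in buckets.items()}


# Mirrors the GPT-4 safety example above: one correct answer at confidence 1,
# six incorrect answers at confidence 4.
gpt4_safety = [(1, True)] + [(4, False)] * 6
summary = mean_confidence_by_outcome(gpt4_safety)
```

<p>A well-calibrated model should show a clearly higher mean for correct answers; here the ordering is reversed.</p>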
<h3 id="scaling-and-molecular-complexity">Scaling and molecular complexity</h3>
<p>Model performance correlated with model size, consistent with observations in other domains. However, performance did not correlate with molecular complexity indicators, suggesting that models may rely on training data proximity rather than genuine structural reasoning.</p>
<h2 id="implications-for-chemistry-and-llm-development">Implications for Chemistry and LLM Development</h2>
<p>The authors draw several conclusions from the ChemBench evaluation.</p>
<p><strong>Chemistry education needs rethinking.</strong> Since LLMs already outperform average human chemists on many textbook-style questions, the value of rote memorization and problem-solving in chemistry curricula is diminishing. Critical reasoning and evaluation of model outputs become more important skills.</p>
<p><strong>Breadth vs. depth matters.</strong> Model performance varies widely across topics and question types, even within a single topic. Aggregate scores can mask significant weaknesses in safety-critical areas.</p>
<p><strong>Better human-model interaction is needed.</strong> Poor confidence calibration means users cannot trust models&rsquo; self-reported uncertainty. Developing better uncertainty estimation for chemical LLMs is an important direction.</p>
<p><strong>Room for improvement through specialized data.</strong> Training on specialized chemical databases (rather than just papers) and integrating domain-specific tools could address the knowledge-intensive gaps identified by ChemBench.</p>
<p><strong>Open science framework.</strong> ChemBench is designed for extensibility: new models can be added by contributors, and the leaderboard is publicly accessible. The use of a BigBench-compatible canary string helps prevent test set contamination in future training corpora.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Evaluation</td>
          <td>ChemBench (full corpus)</td>
          <td>2,788 Q-A pairs</td>
          <td>1,039 manual + 1,749 semi-automatic</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>ChemBench-Mini</td>
          <td>236 questions</td>
          <td>Curated diverse subset; used for human baseline</td>
      </tr>
      <tr>
          <td>Chemical preference</td>
          <td>Choung et al. dataset</td>
          <td>1,000 sampled pairs</td>
          <td>From original 5,000+ dataset</td>
      </tr>
  </tbody>
</table>
<p>All benchmark data is publicly available on GitHub and archived on Zenodo.</p>
<h3 id="algorithms">Algorithms</h3>
<p>Evaluation uses greedy decoding (temperature 0) for all models. Parsing is multi-step: regex extraction of answer environments and enumeration letters/numbers, word-to-number conversion, and LLM-based fallback parsing (Claude-3.5 Sonnet). Confidence estimates are verbalized on an ordinal 1-5 scale.</p>
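<p>The parsing cascade can be sketched as follows. This is an illustration under stated assumptions, not the ChemBench implementation: the <code>[ANSWER]...[/ANSWER]</code> delimiter, the small word-to-number table, and the <code>llm_fallback</code> callable are all placeholders chosen for the example.</p>

```python
import re

# Assumed answer environment; the real delimiters may differ.
ANSWER_BLOCK = re.compile(r"\[ANSWER\](.*?)\[/ANSWER\]", re.DOTALL)
# Deliberately tiny lookup table, for illustration only.
WORD_TO_NUMBER = {"zero": 0, "one": 1, "two": 2, "three": 3, "four": 4, "five": 5,
                  "six": 6, "seven": 7, "eight": 8, "nine": 9, "ten": 10}


def parse_mcq(completion, llm_fallback=None):
    """Step 1: regex for option letters; final step: optional LLM fallback."""
    block = ANSWER_BLOCK.search(completion)
    if block:
        letters = re.findall(r"\b[A-Z]\b", block.group(1))
        if letters:
            return sorted(set(letters))
    return llm_fallback(completion) if llm_fallback else None


def parse_numeric(completion, llm_fallback=None):
    """Regex for digits, then word-to-number conversion, then optional fallback."""
    block = ANSWER_BLOCK.search(completion)
    text = block.group(1).strip() if block else ""
    number = re.search(r"-?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?", text)
    if number:
        return float(number.group(0))
    if text.lower() in WORD_TO_NUMBER:
        return float(WORD_TO_NUMBER[text.lower()])
    return llm_fallback(completion) if llm_fallback else None
```

<p>Keeping the LLM call as the last resort means the cheap, deterministic steps handle well-formatted completions, and only malformed outputs incur an extra model call.</p>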
<h3 id="models">Models</h3>
<p>The paper evaluates multiple models including o1-preview, GPT-4, Claude-3.5 (Sonnet), Llama-3.1-405B-Instruct, Galactica, and PaperQA2. Model weights are not released (the contribution is the benchmark, not a model).</p>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Scope</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Accuracy (% correct)</td>
          <td>Per question, per topic, overall</td>
          <td>Strict: partially correct = incorrect</td>
      </tr>
      <tr>
          <td>Confidence calibration</td>
          <td>Ordinal 1-5 scale</td>
          <td>Verbalized, not logit-based</td>
      </tr>
      <tr>
          <td>Human comparison</td>
          <td>19 experts on ChemBench-Mini</td>
          <td>Tools allowed for subset</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Not applicable; the benchmark is designed for API-based evaluation. Cost context: Liang et al. report &gt;US$10,000 for a single HELM evaluation, motivating ChemBench-Mini.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/lamalab-org/chembench">ChemBench Code &amp; Data</a></td>
          <td>Code + Dataset</td>
          <td>MIT</td>
          <td>Framework and benchmark corpus</td>
      </tr>
      <tr>
          <td><a href="https://zenodo.org/records/14010212">ChemBench Zenodo Archive</a></td>
          <td>Dataset</td>
          <td>MIT</td>
          <td>Version v0.2.0, archived</td>
      </tr>
      <tr>
          <td><a href="https://github.com/lamalab-org/chem-bench-app">ChemBench Web App</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Human baseline survey application</td>
      </tr>
      <tr>
          <td><a href="https://chembench.org">ChemBench Leaderboard</a></td>
          <td>Other</td>
          <td>N/A</td>
          <td>Public model leaderboard</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Mirza, A., Alampara, N., Kunchapu, S., Ríos-García, M., Emoekabu, B., Krishnan, A., &hellip; &amp; Jablonka, K. M. (2025). A framework for evaluating the chemical knowledge and reasoning abilities of large language models against the expertise of chemists. <em>Nature Chemistry</em>, 17(7), 1027-1034. <a href="https://doi.org/10.1038/s41557-025-01815-x">https://doi.org/10.1038/s41557-025-01815-x</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{mirza2025chembench,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{A framework for evaluating the chemical knowledge and reasoning abilities of large language models against the expertise of chemists}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Mirza, Adrian and Alampara, Nawaf and Kunchapu, Sreekanth and R{\&#39;\i}os-Garc{\&#39;\i}a, Marti{\~n}o and Emoekabu, Benedict and Krishnan, Aswanth and Gupta, Tanya and Schilling-Wilhelmi, Mara and Okereke, Macjonathan and Aneesh, Anagha and Asgari, Mehrdad and Eberhardt, Juliane and Elahi, Amir Mohammad and Elbeheiry, Hani M. and Gil, Mar{\&#39;\i}a Victoria and Glaubitz, Christina and Greiner, Maximilian and Holick, Caroline T. and Hoffmann, Tim and Ibrahim, Abdelrahman and Klepsch, Lea C. and K{\&#34;o}ster, Yannik and Kreth, Fabian Alexander and Meyer, Jakob and Miret, Santiago and Peschel, Jan Matthias and Ringleb, Michael and Roesner, Nicole C. and Schreiber, Johanna and Schubert, Ulrich S. and Stafast, Leanne M. and Wonanke, A. D. Dinga and Pieler, Michael and Schwaller, Philippe and Jablonka, Kevin Maik}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Nature Chemistry}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{17}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{7}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{1027--1034}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2025}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Springer Nature}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1038/s41557-025-01815-x}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Benchmarking Molecular Property Prediction at Scale</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/property-prediction/systematic-study-molecular-property-prediction/</link><pubDate>Wed, 25 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/property-prediction/systematic-study-molecular-property-prediction/</guid><description>A study training 62,820 models finds fixed molecular representations often outperform learned representations for property prediction.</description><content:encoded><![CDATA[<h2 id="a-large-scale-empirical-study-of-molecular-property-prediction">A Large-Scale Empirical Study of Molecular Property Prediction</h2>
<p>This is an <strong>Empirical</strong> paper that systematically benchmarks molecular property prediction across multiple dimensions: molecular representations, model architectures, evaluation metrics, data splitting strategies, and chemical space generalization. The primary contribution is a rigorous, large-scale comparison (62,820 trained models) showing that traditional machine learning models on fixed molecular representations frequently outperform recent deep representation learning approaches, and that several overlooked evaluation factors (statistical testing, metric choice, activity cliffs, dataset size) significantly influence conclusions about model performance.</p>
<h2 id="motivation-overlooked-evaluation-pitfalls-in-molecular-property-prediction">Motivation: Overlooked Evaluation Pitfalls in Molecular Property Prediction</h2>
<p>Molecular property prediction is a core task in AI-driven drug discovery, and recent years have seen a proliferation of representation learning methods (transformers on <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a>, GNNs on molecular graphs) claiming improved performance on <a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet benchmark datasets</a>. However, the authors identify several systemic problems in how these methods are evaluated:</p>
<ol>
<li><strong>Heavy reliance on MoleculeNet benchmarks</strong>, which may not reflect real-world drug discovery challenges. Some benchmark tasks (e.g., SIDER, ClinTox) are arguably unreasonable because they try to predict outcomes from chemical structure alone when other factors (food-drug interactions, patient-level variables) dominate.</li>
<li><strong>Lack of statistical rigor.</strong> Most papers report mean metrics over 3 or 10 splits without statistical tests. Without rigorous analysis, improved metrics could be statistical noise.</li>
<li><strong>Inconsistent data splits.</strong> Across studies, the actual splits vary because seeds and splitting implementations differ, making cross-paper comparisons unreliable.</li>
<li><strong>Inappropriate metrics.</strong> AUROC, the default for classification, can overestimate performance, especially on imbalanced datasets. Precision-oriented metrics (PPV, NPV) may be more relevant for virtual screening.</li>
<li><strong>Neglect of activity cliffs.</strong> Most studies only evaluate inter-scaffold generalization via scaffold splits, ignoring intra-scaffold generalization where structurally similar molecules exhibit drastically different activities (<a href="/notes/chemistry/molecular-design/property-prediction/activity-cliffs-benchmark/">activity cliffs</a>).</li>
</ol>
<h2 id="core-contribution-fixed-representations-often-outperform-learned-representations">Core Contribution: Fixed Representations Often Outperform Learned Representations</h2>
<p>The central finding is that traditional ML models (RF, SVM, XGBoost) operating on fixed molecular representations (RDKit2D descriptors, Morgan fingerprints, MACCS keys, AtomPairs) frequently outperform recent self-supervised pretrained models (<a href="/notes/chemistry/molecular-representations/encoders/molbert-molecular-representations/">MolBERT</a>, GROVER) across diverse datasets. The authors frame the paper around a central thesis:</p>
<blockquote>
<p>&ldquo;A model cannot save an unqualified dataset which cannot remedy an improper evaluation for an ambiguous chemical space generalization claim.&rdquo;</p></blockquote>
<p>Key findings on representations and models:</p>
<ul>
<li><strong>RF on RDKit2D descriptors</strong> achieves the best performance on BACE, BBBP, ESOL, and Lipop under scaffold split. MolBERT matches RF only on HIV.</li>
<li><strong>Concatenating RDKit2D descriptors to GROVER&rsquo;s learned embeddings (GROVER_RDKit)</strong> significantly improves performance, suggesting the learned representations alone are insufficient and that fixed descriptors carry substantial predictive signal.</li>
<li><strong>For binding activity datasets</strong> (<a href="https://en.wikipedia.org/wiki/Opioid_receptor">opioid receptors</a> MOR, DOR, KOR), MorganBits fingerprints outperform other representations, consistent with the structural nature of binding.</li>
<li><strong>PhysChem descriptors</strong> excel on datasets where properties correlate strongly with simple molecular features (e.g., ESOL has a near-linear relationship between MolLogP and solubility), but perform poorly on binding activity datasets where the relationship is more complex.</li>
</ul>
<h2 id="experimental-setup-62820-models-across-diverse-datasets">Experimental Setup: 62,820 Models Across Diverse Datasets</h2>
<h3 id="models-evaluated">Models evaluated</h3>
<p>The study evaluates nine models across three categories:</p>
<ul>
<li><strong>Traditional ML</strong>: Random Forest (RF), Support Vector Machine (SVM), XGBoost</li>
<li><strong>Regular neural networks</strong>: RNN (GRU variant), GCN, GIN</li>
<li><strong>Pretrained models</strong>: MolBERT (SMILES-based, ~85M parameters, pretrained on 1.6M molecules), GROVER (graph-based, ~48M parameters, pretrained on ~10M molecules), and GROVER_RDKit (GROVER with concatenated RDKit2D descriptors)</li>
</ul>
<h3 id="molecular-representations">Molecular representations</h3>
<p>Six fixed representations are evaluated: RDKit2D descriptors (200 features), PhysChem descriptors (11 features), MACCS keys, MorganBits fingerprints, MorganCounts fingerprints, and AtomPairs fingerprints. Morgan fingerprints use radius 2 and 2048 bits after testing showed little difference between common parameter choices.</p>
<h3 id="datasets">Datasets</h3>
<table>
  <thead>
      <tr>
          <th>Category</th>
          <th>Datasets</th>
          <th>Task Type</th>
          <th>Source</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MoleculeNet benchmarks</td>
          <td>BACE, BBBP, HIV</td>
          <td>Classification</td>
          <td>MoleculeNet</td>
      </tr>
      <tr>
          <td>MoleculeNet benchmarks</td>
          <td>ESOL, FreeSolv, Lipop</td>
          <td>Regression</td>
          <td>MoleculeNet</td>
      </tr>
      <tr>
          <td>Opioids-related</td>
          <td>MDR1, CYP2D6, CYP3A4, MOR, DOR, KOR</td>
          <td>Classification + Regression</td>
          <td>ChEMBL</td>
      </tr>
      <tr>
          <td>Activity datasets</td>
          <td>24 targets</td>
          <td>Regression</td>
          <td>Cortes-Ciriano et al.</td>
      </tr>
      <tr>
          <td>Activity datasets</td>
          <td>30 targets (MoleculeACE)</td>
          <td>Regression</td>
          <td>van Tilborg et al.</td>
      </tr>
      <tr>
          <td>Descriptor datasets</td>
          <td>MolWt, NumAtoms (16 sizes each)</td>
          <td>Regression</td>
          <td>ZINC250k</td>
      </tr>
  </tbody>
</table>
<h3 id="evaluation-protocol">Evaluation protocol</h3>
<ul>
<li>Both scaffold and random splits (80:10:10 ratio)</li>
<li><strong>30 different random seeds</strong> per experiment for statistical rigor</li>
<li><a href="https://en.wikipedia.org/wiki/Mann%E2%80%93Whitney_U_test">Mann-Whitney U test</a> for pairwise significance ($p &lt; 0.05$, two-sided)</li>
<li>Multiple metrics per task: AUROC, AUPRC, PPV, NPV for classification; RMSE, MAE, $R^2$, Pearson $R$ for regression</li>
</ul>
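<p>The seed-level significance test in this protocol can be sketched as follows. This is a minimal standard-library illustration, not the authors&rsquo; code: the helper name <code>mann_whitney_u</code> and the score lists are hypothetical, and the p-value uses the normal approximation without tie or continuity correction. In practice, <code>scipy.stats.mannwhitneyu</code> with <code>alternative="two-sided"</code> would be the usual choice.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import math
import random


def mann_whitney_u(x, y):
    """Two-sided Mann-Whitney U test via the normal approximation.

    Tied values receive their average rank, but the variance is not
    tie-corrected. Returns (U statistic for x, approximate p-value).
    """
    n1, n2 = len(x), len(y)
    pooled = sorted([(v, 0) for v in x] + [(v, 1) for v in y])
    ranks = [0.0] * len(pooled)
    i = 0
    while i != len(pooled):
        # Find the run of equal values starting at position i.
        j = i
        while j + 1 != len(pooled) and pooled[j + 1][0] == pooled[i][0]:
            j += 1
        avg_rank = (i + j) / 2 + 1  # ranks are 1-based
        for k in range(i, j + 1):
            ranks[k] = avg_rank
        i = j + 1
    rank_sum_x = sum(r for r, (_, group) in zip(ranks, pooled) if group == 0)
    u = rank_sum_x - n1 * (n1 + 1) / 2
    mean_u = n1 * n2 / 2
    sd_u = math.sqrt(n1 * n2 * (n1 + n2 + 1) / 12)
    z = (u - mean_u) / sd_u
    p = math.erfc(abs(z) / math.sqrt(2))  # two-sided tail probability
    return u, p


# Hypothetical AUROC values for two models over 30 seeds (not from the paper).
rng = random.Random(0)
rf_scores = [0.90 + rng.gauss(0, 0.01) for _ in range(30)]
nn_scores = [0.85 + rng.gauss(0, 0.03) for _ in range(30)]

u_stat, p_value = mann_whitney_u(rf_scores, nn_scores)
print(f"U = {u_stat:.1f}, p = {p_value:.4g}, significant: {0.05 > p_value}")
</code></pre></div>
<p>Comparing the full distribution of 30 scores, rather than a single split, is what guards against the single-split artifacts discussed in the findings.</p>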
<h3 id="key-metrics">Key metrics</h3>
<p>Classification:</p>
<p>$$
\text{PPV} = \frac{\text{TP}}{\text{TP} + \text{FP}}
$$</p>
<p>$$
\text{NPV} = \frac{\text{TN}}{\text{TN} + \text{FN}}
$$</p>
<p>Regression:</p>
<p>$$
\text{RMSE} = \sqrt{\frac{1}{N} \sum_{i=1}^{N} (y_i - \hat{y}_i)^2}
$$</p>
<p>$$
\text{MAE} = \frac{1}{N} \sum_{i=1}^{N} |y_i - \hat{y}_i|
$$</p>
<p>$$
\text{Pearson}_R = \frac{\sum_{i=1}^{N} (y_i - \bar{y}_{obs})(\hat{y}_i - \bar{y}_{pred})}{\sqrt{\sum_{i=1}^{N} (y_i - \bar{y}_{obs})^2 \sum_{i=1}^{N} (\hat{y}_i - \bar{y}_{pred})^2}}
$$</p>
<p>$$
R^2 = 1 - \frac{\sum_{i=1}^{N} (y_i - \hat{y}_i)^2}{\sum_{i=1}^{N} (y_i - \bar{y}_{obs})^2}
$$</p>
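<p>The formulas above translate directly into code. The sketch below is a plain-Python illustration with hypothetical function names and toy values (not data from the paper). The toy predictions are shifted by a constant, which leaves Pearson $R$ at 1 while $R^2$ becomes negative.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import math


def ppv(tp, fp):
    """Positive predictive value: TP / (TP + FP)."""
    return tp / (tp + fp)


def npv(tn, fn):
    """Negative predictive value: TN / (TN + FN)."""
    return tn / (tn + fn)


def regression_metrics(y_true, y_pred):
    """RMSE, MAE, Pearson R, and R^2 as defined in the equations above."""
    n = len(y_true)
    mean_obs = sum(y_true) / n
    mean_pred = sum(y_pred) / n
    ss_res = sum((y - p) ** 2 for y, p in zip(y_true, y_pred))
    ss_tot = sum((y - mean_obs) ** 2 for y in y_true)
    ss_pred = sum((p - mean_pred) ** 2 for p in y_pred)
    cov = sum((y - mean_obs) * (p - mean_pred) for y, p in zip(y_true, y_pred))
    return {
        "rmse": math.sqrt(ss_res / n),
        "mae": sum(abs(y - p) for y, p in zip(y_true, y_pred)) / n,
        "pearson_r": cov / math.sqrt(ss_tot * ss_pred),
        "r2": 1 - ss_res / ss_tot,
    }


# Toy values: predictions follow the trend exactly but are offset by +2.
y_true = [1.0, 2.0, 3.0, 4.0]
y_pred = [3.0, 4.0, 5.0, 6.0]
m = regression_metrics(y_true, y_pred)
print(m)  # Pearson R is 1.0 while R^2 is about -2.2
</code></pre></div>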
<h2 id="key-findings-metrics-activity-cliffs-and-dataset-size">Key Findings: Metrics, Activity Cliffs, and Dataset Size</h2>
<h3 id="statistical-testing-is-essential">Statistical testing is essential</h3>
<p>Without statistical tests, there is a real risk of drawing incorrect conclusions. On individual splits, MolBERT or GROVER can appear to outperform RF even though RF is significantly better in aggregate under proper statistical testing. For example, in BBBP, RF dominates in 20 of 30 splits, but any of the remaining 10 could mislead a researcher who evaluates on a single split.</p>
<h3 id="metric-choice-changes-conclusions">Metric choice changes conclusions</h3>
<p>Different evaluation metrics can lead to contradictory conclusions about the same models:</p>
<ul>
<li>In BBBP under scaffold split, RF significantly outperforms other models by AUROC, but shows similar performance when evaluated by PPV or NPV.</li>
<li>In FreeSolv, GROVER outperforms RF by Pearson $R$ ($p &lt; 0.05$) but shows similar performance by $R^2$.</li>
<li>Pearson $R$ can overestimate $R^2$: even when $R^2$ drops to zero or negative, Pearson $R$ can remain around 0.5.</li>
<li>AUROC can be over-optimistic, especially on imbalanced datasets like CYP2D6 and CYP3A4.</li>
</ul>
<p>The authors argue that PPV and NPV are more practically relevant for <a href="/notes/chemistry/molecular-design/generation/evaluation/molscore-scoring-benchmarking-framework/">virtual screening</a> than AUROC or AUPRC, since the goal is to identify true hits among predicted positives (or true non-binders among predicted negatives).</p>
<h3 id="activity-cliffs-pose-a-major-challenge">Activity cliffs pose a major challenge</h3>
<p>Activity cliffs, defined as <a href="https://en.wikipedia.org/wiki/IC50">IC50</a> values spanning at least two orders of magnitude within one scaffold, are prevalent in the opioid-related datasets. Although activity-cliff (AC) scaffolds represent only 7-13% of scaffolds, they encompass 25-46% of all molecules:</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>AC scaffolds (%)</th>
          <th>AC molecules (%)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MDR1</td>
          <td>62 (10.2%)</td>
          <td>594 (41.3%)</td>
      </tr>
      <tr>
          <td>CYP2D6</td>
          <td>124 (9.3%)</td>
          <td>710 (31.0%)</td>
      </tr>
      <tr>
          <td>CYP3A4</td>
          <td>146 (7.2%)</td>
          <td>926 (25.2%)</td>
      </tr>
      <tr>
          <td>MOR</td>
          <td>213 (13.1%)</td>
          <td>1627 (46.1%)</td>
      </tr>
      <tr>
          <td>DOR</td>
          <td>178 (11.6%)</td>
          <td>1342 (41.6%)</td>
      </tr>
      <tr>
          <td>KOR</td>
          <td>218 (13.1%)</td>
          <td>1502 (45.2%)</td>
      </tr>
  </tbody>
</table>
<p>Prediction performance is consistently worse for AC molecules, indicating limited intra-scaffold generalization. Removing edge-case molecules (those sharing scaffolds with pIC50 spanning 5 to 7) from test sets generally improves classification performance, confirming that activity cliffs are a key source of prediction error.</p>
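<p>The activity-cliff definition used here (a span of at least two orders of magnitude in IC50, i.e., at least 2 pIC50 units within one scaffold) can be sketched as a simple grouping step. The scaffold labels and pIC50 values below are made up for illustration, and the sketch assumes scaffold identifiers have already been computed (for example, as Bemis-Murcko scaffolds, which is an assumption rather than something stated above).</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">from collections import defaultdict


def find_ac_scaffolds(records, min_span=2.0):
    """Return scaffolds whose pIC50 values span at least `min_span` log units."""
    by_scaffold = defaultdict(list)
    for scaffold, pic50 in records:
        by_scaffold[scaffold].append(pic50)
    return {
        scaffold
        for scaffold, values in by_scaffold.items()
        if max(values) - min(values) >= min_span
    }


# Hypothetical (scaffold, pIC50) pairs.
records = [
    ("scaffold_A", 5.1), ("scaffold_A", 7.4), ("scaffold_A", 6.0),
    ("scaffold_B", 6.2), ("scaffold_B", 6.5),
    ("scaffold_C", 8.0),
]

ac = find_ac_scaffolds(records)
n_ac_molecules = sum(1 for scaffold, _ in records if scaffold in ac)
print(f"AC scaffolds: {len(ac)} of 3; AC molecules: {n_ac_molecules} of {len(records)}")
</code></pre></div>
<p>Even in this toy case, one scaffold out of three accounts for half of the molecules, which mirrors how a small share of AC scaffolds can cover a large share of a dataset.</p>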
<h3 id="dataset-size-is-critical-for-representation-learning">Dataset size is critical for representation learning</h3>
<p>Experiments on descriptor datasets (predicting MolWt and NumAtoms) reveal clear patterns:</p>
<ul>
<li>With fewer than 1K data points, traditional ML on fixed representations outperforms all neural network models except pretrained GROVER, which shows competitive performance in the low-data regime.</li>
<li>MolBERT shows severely limited performance (RMSE &gt; 200 for MolWt) with fewer than 10K data points.</li>
<li>RNN achieves the best performance when dataset size exceeds 10K, demonstrating the promise of representation learning in the &ldquo;big-data&rdquo; regime.</li>
<li>SVM achieves near-perfect RMSE (close to zero) on datasets larger than 10K when paired with AtomPairs fingerprints.</li>
<li>GROVER&rsquo;s performance does not substantially improve with increasing dataset size, while MolBERT improves at 100K but is slow to benefit from more data.</li>
</ul>
<h3 id="representation-learning-models-show-higher-metric-variability">Representation learning models show higher metric variability</h3>
<p>Representation learning models, particularly GROVER, exhibit higher variability in performance metrics across splits. This variability correlates negatively with mean performance: models with higher variability tend to perform worse on average. The authors emphasize the importance of reporting metric variability alongside means.</p>
<h3 id="scaffold-split-versus-random-split">Scaffold split versus random split</h3>
<p>Prediction performance under scaffold split is consistently worse than under random split, confirming the inter-scaffold generalization challenge. Notably, random split alleviates the intra-scaffold generalization challenge because some AC scaffolds are seen during training.</p>
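<p>A seeded scaffold split of the kind compared here can be sketched as below. This is a hedged illustration, not the authors&rsquo; implementation: the function name <code>scaffold_split</code> is hypothetical, scaffold identifiers are assumed to be precomputed, and whole scaffold groups are assigned greedily to hit an approximate 80:10:10 ratio.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import random
from collections import defaultdict


def scaffold_split(scaffolds, seed, frac_train=0.8, frac_valid=0.1):
    """Assign whole scaffold groups to train/valid/test.

    `scaffolds` holds one scaffold identifier per molecule; the returned
    lists contain molecule indices. No scaffold appears in two subsets.
    """
    groups = defaultdict(list)
    for idx, scaffold in enumerate(scaffolds):
        groups[scaffold].append(idx)
    group_list = list(groups.values())
    random.Random(seed).shuffle(group_list)  # the seed controls group order

    n = len(scaffolds)
    train, valid, test = [], [], []
    for group in group_list:
        if frac_train * n >= len(train) + len(group):
            train.extend(group)
        elif (frac_train + frac_valid) * n >= len(train) + len(valid) + len(group):
            valid.extend(group)
        else:
            test.extend(group)
    return train, valid, test


# Hypothetical dataset: 100 molecules spread over 20 scaffolds of 5 each.
scaffolds = [f"s{i // 5}" for i in range(100)]
train, valid, test = scaffold_split(scaffolds, seed=0)
print(len(train), len(valid), len(test))
</code></pre></div>
<p>A random split would instead shuffle individual molecule indices, so molecules sharing a scaffold (including AC scaffolds) can land in both training and test sets.</p>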
<h3 id="descriptors-correlate-with-specific-properties">Descriptors correlate with specific properties</h3>
<p>PhysChem descriptors excel on datasets where molecular properties correlate with simple descriptors (e.g., MolLogP has near $-1$ correlation with ESOL labels). For binding activity datasets, correlation coefficients mostly fall within $[-0.5, 0.5]$, explaining why PhysChem descriptors show limited performance on those tasks, while structural fingerprints are more useful.</p>
<h2 id="limitations-and-future-directions">Limitations and Future Directions</h2>
<p>The authors acknowledge several limitations:</p>
<ol>
<li><strong>Uncertainty from model training</strong> (random initialization, mini-batch shuffling) was not fully addressed. Ensembling was not evaluated due to computational cost.</li>
<li><strong>Experimental uncertainty in labels</strong> (noise, measurement error in pIC50 values) was not modeled, though it can be <a href="https://en.wikipedia.org/wiki/Homoscedasticity_and_heteroscedasticity">heteroscedastic</a> and impact performance.</li>
<li><strong>Model explainability</strong> was not covered, although it is important for building trust in AI tools for drug discovery.</li>
<li>The study focused on GROVERbase only (not GROVERlarge) due to computational constraints.</li>
</ol>
<p>Future directions include: exploring better ways to use fixed representations alongside learned ones, developing techniques for chemical space generalization (both inter- and intra-scaffold), incorporating experimental uncertainty into model training and evaluation, and generating larger high-quality datasets to fully harness representation learning models.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Benchmark</td>
          <td>MoleculeNet (BACE, BBBP, HIV, ESOL, FreeSolv, Lipop)</td>
          <td>642-41,127 molecules</td>
          <td>Downloaded from MolMapNet; max length &lt; 400</td>
      </tr>
      <tr>
          <td>Activity</td>
          <td>Opioids-related (MDR1, CYP2D6, CYP3A4, MOR, DOR, KOR)</td>
          <td>Varies</td>
          <td>Collected from ChEMBL27; pIC50 values</td>
      </tr>
      <tr>
          <td>Activity</td>
          <td>Cortes-Ciriano et al. 24 targets</td>
          <td>Varies</td>
          <td>Activity data for drug targets</td>
      </tr>
      <tr>
          <td>Activity</td>
          <td>MoleculeACE 30 targets</td>
          <td>Varies</td>
          <td>Activity cliffs emphasis</td>
      </tr>
      <tr>
          <td>Descriptor</td>
          <td>MolWt, NumAtoms from <a href="/notes/chemistry/datasets/zinc-22/">ZINC250k</a></td>
          <td>0.1K to 100K</td>
          <td>16 dataset sizes per descriptor</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>RF: 500 trees (following Chemprop)</li>
<li>SVM: linear kernel</li>
<li>XGBoost: gradient boosting regressor/classifier with default hyperparameters</li>
<li>RNN: GRU variant, hidden size 512, 3 fully connected layers</li>
<li>GCN/GIN: embedding dimension 300, 5 convolutional layers, hidden size 512</li>
<li>MolBERT: BERTBase architecture, 768 embedding, 12 layers, 12 heads, ~85M parameters (769 fine-tuned)</li>
<li>GROVER: GROVERbase, ~48M parameters (~5.2M fine-tuned)</li>
<li>All splits repeated 30 times with seeds 0-29</li>
</ul>
<h3 id="models">Models</h3>
<p>All model configurations, splits, and raw predictions are available in the <a href="https://github.com/dengjianyuan/Respite_MPP">GitHub repository</a>.</p>
<h3 id="evaluation">Evaluation</h3>
<p>Metrics: AUROC, AUPRC, PPV, NPV (classification); RMSE, MAE, $R^2$, Pearson $R$ (regression). Statistical testing via Mann-Whitney U test ($p &lt; 0.05$, two-sided). <a href="https://en.wikipedia.org/wiki/Youden%27s_J_statistic">Youden&rsquo;s $J$ statistic</a> used to determine classification threshold for PPV/NPV.</p>
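<p>Threshold selection with Youden&rsquo;s $J$ can be sketched as a scan over candidate cutoffs, choosing the one that maximizes $J = \text{sensitivity} + \text{specificity} - 1$. The function name and the labels and scores below are hypothetical, and the sketch assumes both classes are present.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">def youden_threshold(y_true, y_score):
    """Return (threshold, J) maximizing J = sensitivity + specificity - 1."""
    positives = sum(y_true)
    negatives = len(y_true) - positives
    best_j, best_t = -1.0, None
    for t in sorted(set(y_score)):
        # A molecule is predicted positive when its score is at or above t.
        tp = sum(1 for y, s in zip(y_true, y_score) if s >= t and y == 1)
        fp = sum(1 for y, s in zip(y_true, y_score) if s >= t and y == 0)
        # sensitivity + specificity - 1 equals TPR - FPR.
        j = tp / positives - fp / negatives
        if j > best_j:
            best_j, best_t = j, t
    return best_t, best_j


# Hypothetical labels and predicted probabilities.
y_true = [0, 0, 0, 0, 1, 1, 1]
y_score = [0.1, 0.2, 0.3, 0.6, 0.5, 0.8, 0.9]
threshold, j_stat = youden_threshold(y_true, y_score)
print(threshold, j_stat)
</code></pre></div>
<p>PPV and NPV are then computed from the confusion matrix obtained at the selected threshold.</p>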
<h3 id="hardware">Hardware</h3>
<p>All neural network experiments were run on a single NVIDIA V100 GPU for 100 epochs. Batch size was 32 for most experiments and 256 for GROVER on HIV due to compute time (MolBERT takes ~3 hours per split on HIV at batch size 32; GROVER takes ~5 hours at batch size 256). The study was partially funded by a Stony Brook University OVPR Seed Grant and used the AI Institute at Stony Brook for computational resources.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/dengjianyuan/Respite_MPP">Respite_MPP</a></td>
          <td>Code</td>
          <td>Unknown</td>
          <td>Code, data, and raw predictions</td>
      </tr>
      <tr>
          <td><a href="https://doi.org/10.1038/s41467-023-41948-6">Nature Communications article</a></td>
          <td>Paper</td>
          <td>CC-BY-4.0</td>
          <td>Open access</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Deng, J., Yang, Z., Wang, H., Ojima, I., Samaras, D., &amp; Wang, F. (2023). A systematic study of key elements underlying molecular property prediction. <em>Nature Communications</em>, 14, 6395. <a href="https://doi.org/10.1038/s41467-023-41948-6">https://doi.org/10.1038/s41467-023-41948-6</a></p>
<p><strong>Publication</strong>: Nature Communications 2023</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/dengjianyuan/Respite_MPP">Respite_MPP GitHub Repository</a></li>
</ul>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{deng2023systematic,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{A systematic study of key elements underlying molecular property prediction}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Deng, Jianyuan and Yang, Zhibo and Wang, Hehe and Ojima, Iwao and Samaras, Dimitris and Wang, Fusheng}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Nature Communications}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{14}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{6395}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1038/s41467-023-41948-6}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Benchmarking Chemistry Knowledge in Code-Gen LLMs</title><link>https://hunterheidenreich.com/notes/chemistry/llm-applications/llm-chemistry-code-assessment/</link><pubDate>Wed, 25 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/llm-applications/llm-chemistry-code-assessment/</guid><description>Benchmarking code-generating LLMs on 84 chemistry tasks spanning general chemistry, biochemistry, and computational chemistry with prompt engineering analysis.</description><content:encoded><![CDATA[<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: White, A. D., Hocky, G. M., Gandhi, H. A., Ansari, M., Cox, S., Wellawatte, G. P., Sasmal, S., Yang, Z., Liu, K., Singh, Y., &amp; Peña Ccoa, W. J. (2023). Assessment of chemistry knowledge in large language models that generate code. <em>Digital Discovery</em>, 2(2), 368-376. <a href="https://doi.org/10.1039/d2dd00087c">https://doi.org/10.1039/d2dd00087c</a></p>
<p><strong>Publication</strong>: Digital Discovery 2023</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/ur-whitelab/nlcc-data">nlcc-data benchmark repository</a></li>
<li><a href="https://ur-whitelab.github.io/nlcc-data/">Evaluation completions website</a></li>
<li><a href="https://doi.org/10.5281/zenodo.6800475">Zenodo evaluation data (DOI: 10.5281/zenodo.6800475)</a></li>
</ul>
<h2 id="benchmarking-chemistry-knowledge-in-code-generating-llms">Benchmarking Chemistry Knowledge in Code-Generating LLMs</h2>
<p>This is an <strong>Empirical</strong> paper that evaluates code-generating large language models on chemistry tasks. The primary contribution is a categorized benchmark of 84 chemistry problems across 10 topics, along with a systematic evaluation of several LLMs (Codex cushman, Codex davinci, text-davinci-003, InCoder, CodeGen) on these tasks. The paper also provides practical guidance on prompt engineering strategies that improve accuracy.</p>
<h2 id="why-evaluate-llms-on-chemistry-coding-tasks">Why Evaluate LLMs on Chemistry Coding Tasks</h2>
<p>As of late 2022, LLMs trained on code (such as Codex and InCoder) had become widely available through tools like GitHub Copilot and Tabnine. An open question was whether these general-purpose code models contained sufficient domain knowledge to solve chemistry problems expressed as coding tasks. Chemistry has specialized language, equations, and conventions (e.g., <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> notation, thermodynamic relationships, molecular simulation methods) that may not be well-represented in general code training data. Prior work had shown that knowledge of the periodic table requires very high parameter counts, but the broader extent of chemistry knowledge in code LLMs was unexplored.</p>
<p>The authors sought to answer a specific question: do code-generating LLMs &ldquo;know&rdquo; chemistry? This means evaluating whether LLMs can correlate natural language descriptions of chemistry problems with correct code implementations, including proper equations, units, and use of domain-specific libraries.</p>
<h2 id="benchmark-design-and-prompt-engineering-strategies">Benchmark Design and Prompt Engineering Strategies</h2>
<p>The benchmark covers 10 topic categories:</p>
<table>
  <thead>
      <tr>
          <th>Topic</th>
          <th>Abbreviation</th>
          <th>N</th>
          <th>Expert-only</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Biochemistry</td>
          <td>bio</td>
          <td>13</td>
          <td>2</td>
      </tr>
      <tr>
          <td>Cheminformatics</td>
          <td>cheminf</td>
          <td>10</td>
          <td>0</td>
      </tr>
      <tr>
          <td>General chemistry</td>
          <td>genchem</td>
          <td>11</td>
          <td>0</td>
      </tr>
      <tr>
          <td><a href="/notes/chemistry/molecular-simulation/">Molecular dynamics</a></td>
          <td>md</td>
          <td>11</td>
          <td>3</td>
      </tr>
      <tr>
          <td>Plotting</td>
          <td>plot</td>
          <td>10</td>
          <td>10</td>
      </tr>
      <tr>
          <td>Quantum mechanics</td>
          <td>qm</td>
          <td>8</td>
          <td>3</td>
      </tr>
      <tr>
          <td>Simulation methods</td>
          <td>sim</td>
          <td>8</td>
          <td>5</td>
      </tr>
      <tr>
          <td>Spectroscopy</td>
          <td>spect</td>
          <td>11</td>
          <td>1</td>
      </tr>
      <tr>
          <td>Statistics</td>
          <td>stats</td>
          <td>11</td>
          <td>1</td>
      </tr>
      <tr>
          <td>Thermodynamics</td>
          <td>thermo</td>
          <td>10</td>
          <td>0</td>
      </tr>
  </tbody>
</table>
<p>Each task is formatted as a Python function with a docstring describing the expected behavior. The LLM must generate a completion that passes automated unit tests. Of the 84 total prompts, 25 require expert evaluation (e.g., plotting tasks) where automated testing is insufficient.</p>
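<p>To make the task format concrete, the sketch below shows a hypothetical prompt in the style described (a function signature plus docstring), one possible completion, and a unit-test check. The Arrhenius task, the function name, and the test values are illustrative assumptions and are not taken from the benchmark.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import math

# Prompt handed to the model: signature plus docstring (hypothetical task).
PROMPT = '''
def arrhenius_rate(A, Ea, T):
    """
    Return the rate constant from the Arrhenius equation, with A the
    pre-exponential factor, Ea the activation energy in J/mol, and T
    the temperature in K.
    """
'''

# One possible model completion for the function body.
COMPLETION = '''
    R = 8.314  # gas constant in J/(mol K)
    return A * math.exp(-Ea / (R * T))
'''


def passes_unit_test(prompt, completion):
    """Run prompt + completion and check the result against a reference value."""
    namespace = {"math": math}
    try:
        exec(prompt + completion, namespace)
        result = namespace["arrhenius_rate"](1.0e13, 50_000.0, 300.0)
        expected = 1.0e13 * math.exp(-50_000.0 / (8.314 * 300.0))
        return math.isclose(result, expected, rel_tol=1e-6)
    except Exception:
        # Code that fails to compile or run counts as incorrect.
        return False


print(passes_unit_test(PROMPT, COMPLETION))
</code></pre></div>
<p>The check only looks at behavior, so any completion that produces the right value passes, whether or not it resembles a reference solution. Executing model output with <code>exec</code> is only appropriate in a sandboxed environment.</p>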
<p>The key prompt engineering insight is the use of &ldquo;contexts,&rdquo; which are code prepended before prompts. The authors tested several context strategies:</p>
<ul>
<li><strong>Custom context</strong>: Topic-specific imports (e.g., <a href="https://en.wikipedia.org/wiki/RDKit">RDKit</a> for cheminformatics) plus a one-line completion example to teach the model how to signal the end of output.</li>
<li><strong>Insert context</strong>: Uses model infilling capabilities instead of completion-based generation. Available for davinci and InCoder.</li>
<li><strong>Copyright context</strong>: Adding a copyright notice at the top of the file, which conditions the model toward higher-quality code patterns.</li>
<li><strong>Authority context</strong>: Adding &ldquo;This is written by an expert Python programmer.&rdquo;</li>
</ul>
<p>The copyright notice improved accuracy at higher temperatures. The intuition is that copyrighted code in training data tends to be higher-quality, so the notice acts similarly to lowering temperature. The best model/temperature combination (davinci at T=0.05) was already operating at effectively low temperature, so the copyright trick did not further improve it.</p>
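<p>Completion-based contexts amount to prepending text to the task prompt, which can be sketched as below. The dictionary, the helper name <code>build_prompt</code>, the copyright wording, and the specific imports are illustrative assumptions; only the authority sentence is quoted from the description above. The insert context is omitted because it relies on model infilling rather than a prefix.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"># Hypothetical context strings; each is placed before the task prompt.
CONTEXTS = {
    "none": "",
    "copyright": "# Copyright (c) Example Authors. All rights reserved.\n",
    "authority": "# This is written by an expert Python programmer.\n",
    # Topic-specific imports plus a one-line completion example that shows
    # the model where a finished function ends.
    "custom": (
        "import numpy as np\n"
        "from rdkit import Chem\n\n"
        "def print_version():\n"
        "    print(np.__version__)\n"
        "# end\n\n"
    ),
}


def build_prompt(task_prompt, context="none"):
    """Prepend the chosen context to a task prompt."""
    return CONTEXTS[context] + task_prompt


task = 'def num_atoms(smiles):\n    """Return the number of atoms in the molecule."""\n'
print(build_prompt(task, context="authority"))
</code></pre></div>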
<h2 id="experimental-setup-models-sampling-and-expert-evaluation">Experimental Setup: Models, Sampling, and Expert Evaluation</h2>
<h3 id="models-evaluated">Models evaluated</h3>
<p>The study compared five models, all decoder-only architectures:</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Abbreviation</th>
          <th>Parameters</th>
          <th>Source</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>code-cushman-001</td>
          <td>cushman</td>
          <td>12B</td>
          <td>OpenAI (GPT-3 fine-tuned on code)</td>
      </tr>
      <tr>
          <td>code-davinci-002</td>
          <td>davinci</td>
          <td>~175B (estimated)</td>
          <td>OpenAI (GPT-3.5 class)</td>
      </tr>
      <tr>
          <td>text-davinci-003</td>
          <td>davinci3</td>
          <td>~175B (estimated)</td>
          <td>OpenAI (RLHF-adapted from davinci)</td>
      </tr>
      <tr>
          <td>InCoder</td>
          <td>incoder</td>
          <td>6B</td>
          <td>Fried et al. 2022</td>
      </tr>
      <tr>
          <td>CodeGen</td>
          <td>codegen</td>
          <td>16B</td>
          <td>Nijkamp et al. 2022</td>
      </tr>
  </tbody>
</table>
<h3 id="sampling-and-evaluation">Sampling and evaluation</h3>
<p>Completions were generated using top-k sampling (k=5) at three temperatures: T=0.05, 0.2, and 0.5. For InCoder-6B, GPU memory limited sampling to k=1. Error bars in all reported results are 95% confidence intervals from <a href="https://en.wikipedia.org/wiki/Bootstrapping_(statistics)">bootstrap resampling</a> across top-k samples.</p>
<p>Accuracy was defined following the HumanEval approach: a completion is correct if the code runs and passes unit tests, regardless of whether it matches a reference implementation.</p>
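<p>A percentile bootstrap over pass/fail outcomes is one way to obtain confidence intervals like those described above. The sketch below is an assumption about the general procedure, not the authors&rsquo; exact resampling scheme; the function name and outcome list are hypothetical.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import random


def bootstrap_ci(outcomes, n_resamples=2000, level=0.95, seed=0):
    """Percentile bootstrap confidence interval for mean accuracy.

    `outcomes` is a list of 1 (passed unit tests) and 0 (failed).
    """
    rng = random.Random(seed)
    n = len(outcomes)
    # Resample with replacement and record the accuracy of each resample.
    means = sorted(
        sum(rng.choices(outcomes, k=n)) / n for _ in range(n_resamples)
    )
    lo_idx = int((1 - level) / 2 * n_resamples)
    hi_idx = int((1 + level) / 2 * n_resamples) - 1
    return means[lo_idx], means[hi_idx]


# Hypothetical outcomes for k=5 completions of one prompt.
outcomes = [1, 1, 0, 1, 0]
accuracy = sum(outcomes) / len(outcomes)
low, high = bootstrap_ci(outcomes)
print(f"accuracy = {accuracy:.2f}, 95% CI = ({low:.2f}, {high:.2f})")
</code></pre></div>
<p>With only five samples per prompt the interval is wide, which is one reason aggregated per-topic and overall accuracies are more informative than per-prompt numbers.</p>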
<h3 id="expert-evaluation">Expert evaluation</h3>
<p>Nine co-authors (postdoctoral scholars and Ph.D. students) performed 650 evaluations of davinci completions through a web interface. Each completion was scored on a 5-point scale: Perfect (5), Correct but not perfect (4), Runs and is almost correct (3), Does not run but is almost correct (2), Far from correct (1). Expert-evaluated accuracy counted only &ldquo;Perfect&rdquo; and &ldquo;Correct but not perfect&rdquo; as correct.</p>
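<p>Collapsing the 5-point expert scale into binary accuracy can be sketched as follows. The score list is hypothetical; the scale labels and the cutoff (scores of 4 and 5 count as correct) follow the description above.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">SCALE = {
    5: "Perfect",
    4: "Correct but not perfect",
    3: "Runs and is almost correct",
    2: "Does not run but is almost correct",
    1: "Far from correct",
}


def expert_accuracy(scores):
    """Fraction of evaluations rated 4 or 5 on the expert scale."""
    return sum(1 for s in scores if s >= 4) / len(scores)


# Hypothetical expert scores for six completions.
scores = [5, 4, 3, 1, 2, 5]
print(expert_accuracy(scores))
</code></pre></div>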
<h3 id="key-results-by-topic-and-model">Key results by topic and model</h3>
<table>
  <thead>
      <tr>
          <th>Topic</th>
          <th>incoder</th>
          <th>codegen</th>
          <th>davinci</th>
          <th>davinci3</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>bio</td>
          <td>0%</td>
          <td>29%</td>
          <td>43%</td>
          <td>86%</td>
      </tr>
      <tr>
          <td>cheminf</td>
          <td>20%</td>
          <td>20%</td>
          <td>50%</td>
          <td>50%</td>
      </tr>
      <tr>
          <td>genchem</td>
          <td>29%</td>
          <td>86%</td>
          <td>86%</td>
          <td>86%</td>
      </tr>
      <tr>
          <td>md</td>
          <td>0%</td>
          <td>13%</td>
          <td>63%</td>
          <td>88%</td>
      </tr>
      <tr>
          <td>qm</td>
          <td>20%</td>
          <td>60%</td>
          <td>100%</td>
          <td>100%</td>
      </tr>
      <tr>
          <td>sim</td>
          <td>0%</td>
          <td>0%</td>
          <td>100%</td>
          <td>100%</td>
      </tr>
      <tr>
          <td>spect</td>
          <td>30%</td>
          <td>20%</td>
          <td>50%</td>
          <td>40%</td>
      </tr>
      <tr>
          <td>stats</td>
          <td>40%</td>
          <td>80%</td>
          <td>70%</td>
          <td>60%</td>
      </tr>
      <tr>
          <td>thermo</td>
          <td>10%</td>
          <td>10%</td>
          <td>80%</td>
          <td>70%</td>
      </tr>
      <tr>
          <td><strong>total</strong></td>
          <td><strong>17%</strong></td>
          <td><strong>35%</strong></td>
          <td><strong>72%</strong></td>
          <td><strong>75%</strong></td>
      </tr>
  </tbody>
</table>
<p>All accuracies reported use the best context for each model (copyright for incoder-6B, authority for codegen-16B, insert for davinci) at T=0.2.</p>
<h2 id="findings-llms-know-chemistry-with-caveats">Findings: LLMs Know Chemistry, With Caveats</h2>
<p>The central finding is that code-generating LLMs do contain substantial chemistry knowledge. The best Codex model (davinci) achieved 72% overall accuracy, with prompt engineering contributing approximately 30 percentage points to this figure. The text-davinci-003 model, which was fine-tuned with RLHF, reached a slightly higher 75% and showed reduced sensitivity to prompt engineering, suggesting that human feedback alignment partially subsumes the benefits of manual prompt design.</p>
<h3 id="strengths-and-successful-domains">Strengths and successful domains</h3>
<ul>
<li><strong>Quantum mechanics and simulation</strong>: davinci achieved 100% on both categories, indicating strong knowledge of computational chemistry equations and simulation patterns.</li>
<li><strong>General chemistry</strong>: All models except InCoder performed well (86%), suggesting that general chemistry concepts are well-represented in code training data.</li>
<li><strong>Molecular structure generation</strong>: InstructGPT showed some ability to connect natural language descriptions with SMILES strings, generating valid (though not exact) molecular structures from prompts like &ldquo;a phenol derivative.&rdquo;</li>
</ul>
<h3 id="limitations-and-failure-modes">Limitations and failure modes</h3>
<ul>
<li><strong>Lack of reasoning</strong>: The authors emphasize that LLMs demonstrate knowledge correlation, not reasoning. Davinci frequently uses &ldquo;relativistic <a href="https://en.wikipedia.org/wiki/Hartree%E2%80%93Fock_method">Hartree-Fock</a>&rdquo; for any prompt requesting a &ldquo;highly accurate&rdquo; quantum calculation, because it has memorized the association between &ldquo;relativistic&rdquo; and &ldquo;accurate&rdquo; rather than understanding the underlying chemistry.</li>
<li><strong>Hallucinated functions</strong>: When given difficult prompts (e.g., &ldquo;return the <a href="https://en.wikipedia.org/wiki/Residual_dipolar_coupling">residual dipolar couplings</a> given a SMILES string&rdquo;), the model invents non-existent functions like <code>MolToRDC</code>.</li>
<li><strong>API version mismatches</strong>: Many errors in the molecular dynamics category stem from the model using outdated function signatures for packages like MDTraj, likely reflecting the training data cutoff.</li>
<li><strong>Expert-evaluated accuracy is lower</strong>: On topics requiring expert evaluation (generally harder tasks), accuracy drops and correlates negatively with the difficulty perceived by the evaluators.</li>
</ul>
<h3 id="practical-recommendations">Practical recommendations</h3>
<p>The paper offers several practical tips for using code LLMs in chemistry:</p>
<ol>
<li>Use correctly spelled, precise prompts. If a function should &ldquo;return&rdquo; a value, use the word &ldquo;return&rdquo; rather than &ldquo;compute.&rdquo;</li>
<li>Be explicit about what variables represent (e.g., specify that k is a spring constant, not Boltzmann&rsquo;s constant).</li>
<li>Import only the packages you intend to use, as the model will attempt to use all imported libraries.</li>
<li>Adding a copyright notice or &ldquo;expert programmer&rdquo; statement can improve accuracy, though RLHF-trained models are less sensitive to this.</li>
</ol>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Evaluation</td>
          <td>nlcc-data benchmark</td>
          <td>84 prompts across 10 chemistry topics</td>
          <td>Open source, community-extensible</td>
      </tr>
      <tr>
          <td>Expert evaluation</td>
          <td>Human evaluations CSV</td>
          <td>650 evaluations</td>
          <td>Available in Supporting Information</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<p>Evaluation uses automated unit testing for 59 of 84 prompts. Expert evaluation covers the remaining 25 prompts through a web-based scoring interface. Five completions per prompt were generated via top-k sampling at three temperatures.</p>
<h3 id="models">Models</h3>
<p>All models evaluated are external (OpenAI API for Codex/davinci, HuggingFace for InCoder/CodeGen). No new models were trained. Python version and packages were pinned to June 2021 to avoid library changes influencing results.</p>
<h3 id="evaluation">Evaluation</h3>
<p>Accuracy is binary: a completion passes all unit tests (1.0) or fails (0.0), averaged across top-k samples and temperatures. Expert evaluation uses a 5-point scale collapsed to binary (Perfect or Correct but not perfect = 1.0).</p>
<h3 id="hardware">Hardware</h3>
<p>GPU memory limitations are mentioned for InCoder-6B (limiting sampling to k=1 instead of k=5). No other hardware details are specified.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/ur-whitelab/nlcc-data">nlcc-data benchmark</a></td>
          <td>Dataset</td>
          <td>Unknown</td>
          <td>Open-source benchmark prompts and solutions</td>
      </tr>
      <tr>
          <td><a href="https://ur-whitelab.github.io/nlcc-data/">Evaluation website</a></td>
          <td>Other</td>
          <td>Unknown</td>
          <td>Web interface showing completions</td>
      </tr>
      <tr>
          <td><a href="https://doi.org/10.5281/zenodo.6800475">Zenodo evaluation data</a></td>
          <td>Dataset</td>
          <td>Unknown</td>
          <td>Expert evaluation completions in HTML</td>
      </tr>
      <tr>
          <td><a href="https://pubs.rsc.org/en/content/articlepdf/2023/dd/d2dd00087c">Paper (open access)</a></td>
          <td>Other</td>
          <td>CC-BY-NC</td>
          <td>Published article</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{white2023assessment,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Assessment of chemistry knowledge in large language models that generate code}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{White, Andrew D. and Hocky, Glen M. and Gandhi, Heta A. and Ansari, Mehrad and Cox, Sam and Wellawatte, Geemi P. and Sasmal, Subarna and Yang, Ziyue and Liu, Kangxin and Singh, Yuvraj and Peña Ccoa, Willmor J.}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Digital Discovery}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{2}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{2}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{368--376}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Royal Society of Chemistry}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1039/d2dd00087c}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Back Translation for Semi-Supervised Molecule Generation</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/autoregressive/back-translation-molecule-generation/</link><pubDate>Wed, 25 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/autoregressive/back-translation-molecule-generation/</guid><description>A semi-supervised method adapting NLP back translation to molecule generation, improving property optimization and retrosynthesis with unlabeled ZINC data.</description><content:encoded><![CDATA[<h2 id="semi-supervised-data-augmentation-for-molecular-tasks">Semi-Supervised Data Augmentation for Molecular Tasks</h2>
<p>This is a <strong>Method</strong> paper that introduces back translation, a semi-supervised technique from neural machine translation, to the domain of molecular generation. The primary contribution is a general-purpose data augmentation strategy that leverages large pools of unlabeled molecules (from databases like ZINC) to improve the performance of both sequence-based and graph-based models on molecule optimization and <a href="https://en.wikipedia.org/wiki/Retrosynthetic_analysis">retrosynthesis</a> prediction tasks.</p>
<h2 id="bridging-the-labeled-data-gap-in-molecular-generation">Bridging the Labeled Data Gap in Molecular Generation</h2>
<p>Molecular generation tasks, such as property optimization and retrosynthesis, require paired training data: an input molecule (or property specification) mapped to a desired output molecule. Obtaining these labeled pairs is expensive and labor-intensive. Meanwhile, enormous databases of unlabeled molecules exist. ZINC alone contains over 750 million compounds, and PubChem has 109 million.</p>
<p>Prior approaches to using unlabeled molecular data include <a href="/notes/chemistry/molecular-design/generation/latent-space/automatic-chemical-design-vae/">variational autoencoders (VAEs)</a> for learning latent representations, conditional recurrent neural networks for inverse design, and pretraining techniques borrowed from NLP. However, these methods either focus on representation learning rather than direct generation, or require task-specific architectural modifications. The authors identify back translation, a well-established technique in machine translation, as a natural fit for molecular generation tasks that can be treated as sequence-to-sequence mappings.</p>
<h2 id="back-translation-as-molecular-data-augmentation">Back Translation as Molecular Data Augmentation</h2>
<p>The core idea is straightforward. Given a main task that maps from source domain $\mathcal{X}$ to target domain $\mathcal{Y}$ (e.g., mapping low-QED molecules to high-QED molecules), the method trains a reverse model $g$ that maps from $\mathcal{Y}$ back to $\mathcal{X}$. This reverse model then &ldquo;back translates&rdquo; unlabeled molecules from $\mathcal{Y}$ to generate synthetic source molecules, creating pseudo-labeled training pairs.</p>
<p>The theoretical motivation comes from maximizing the reconstruction probability. Given an unlabeled molecule $y_u \in \mathcal{U}_y$, the logarithmic reconstruction probability through the reverse model $g$ and forward model $f$ is:</p>
<p>$$
\log P(y_u = \hat{y}_u \mid y_u; g, f) = \log \sum_{\hat{x}_u \in \mathcal{X}} P(\hat{x}_u \mid y_u; g) P(y_u = \hat{y}_u \mid \hat{x}_u; f)
$$</p>
<p>Since summing over the exponentially large space $\mathcal{X}$ is intractable, the authors apply Jensen&rsquo;s inequality to obtain a lower bound:</p>
<p>$$
\log P(y_u = \hat{y}_u \mid y_u; g, f) \geq \mathbb{E}_{\hat{x}_u \sim P(\cdot \mid y_u; g)} \log P(y_u = \hat{y}_u \mid \hat{x}_u; f)
$$</p>
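<p>The step from the sum to the bound can be spelled out as follows. This is a short derivation sketch using only the definitions above: the sum over $\hat{x}_u$ is an expectation under the reverse model $g$, and because $\log$ is concave, Jensen&rsquo;s inequality moves it inside the expectation.</p>
<p>$$
\begin{aligned}
\log \sum_{\hat{x}_u \in \mathcal{X}} P(\hat{x}_u \mid y_u; g) P(y_u = \hat{y}_u \mid \hat{x}_u; f)
&amp;= \log \mathbb{E}_{\hat{x}_u \sim P(\cdot \mid y_u; g)} \left[ P(y_u = \hat{y}_u \mid \hat{x}_u; f) \right] \\
&amp;\geq \mathbb{E}_{\hat{x}_u \sim P(\cdot \mid y_u; g)} \left[ \log P(y_u = \hat{y}_u \mid \hat{x}_u; f) \right]
\end{aligned}
$$</p>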
<p>This lower bound is optimized via Monte Carlo sampling in three steps:</p>
<p><strong>Step 1</strong>: Train both forward model $f$ and reverse model $g$ on the labeled data $\mathcal{L}$:</p>
<p>$$
\begin{aligned}
\min_{\theta_f} \sum_{(x,y) \in \mathcal{L}} -\log P(y \mid x; \theta_f) \\
\min_{\theta_g} \sum_{(x,y) \in \mathcal{L}} -\log P(x \mid y; \theta_g)
\end{aligned}
$$</p>
<p><strong>Step 2</strong>: Use the trained reverse model $g$ to back translate each unlabeled molecule $y_u \in \mathcal{U}_y$, producing synthetic pairs:</p>
<p>$$
\hat{\mathcal{L}} = \{(\hat{x}_u, y_u) \mid y_u \in \mathcal{U}_y, \hat{x}_u \text{ sampled from } P(\cdot \mid y_u; \theta_g)\}
$$</p>
<p><strong>Step 3</strong>: Retrain the forward model $f$ on the combined labeled and synthetic data $\mathcal{L} \cup \hat{\mathcal{L}}$, warm-starting from the parameters obtained in Step 1:</p>
<p>$$
\min_{\theta_f^*} \sum_{(x,y) \in \mathcal{L} \cup \hat{\mathcal{L}}} -\log P(y \mid x; \theta_f^*)
$$</p>
<p>A key practical finding is that data filtration matters. When using large amounts of unlabeled data (1M molecules), keeping only the synthetic pairs that satisfy the same constraints as the labeled data (e.g., similarity thresholds and property ranges) significantly improves performance over using all back-translated data unfiltered.</p>
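<p>The three steps and the filtration rule can be sketched in a few lines of Python. This is a minimal illustration, not the authors&rsquo; implementation: <code>back_translate</code>, <code>toy_train</code>, and the <code>keep</code> predicate are hypothetical names, and the toy trainer simply memorizes pairs so that the control flow can be run end to end.</p>

```python
from typing import Callable, Iterable, List, Optional, Tuple

Pair = Tuple[str, str]          # (source x, target y)
Model = Callable[[str], str]    # maps one string to another
Trainer = Callable[[List[Pair], Optional[Model]], Model]


def back_translate(
    labeled: List[Pair],
    unlabeled_targets: Iterable[str],
    train: Trainer,
    keep: Callable[[str, str], bool] = lambda x, y: True,
) -> Model:
    """Three-step back translation with optional data filtration."""
    # Step 1: fit the forward model f (x to y) and the reverse model g (y to x).
    forward = train(labeled, None)
    reverse = train([(y, x) for x, y in labeled], None)

    # Step 2: back translate each unlabeled target into a synthetic source.
    # A real reverse model would sample from P(. | y_u); the toy one is deterministic.
    synthetic = [(reverse(y_u), y_u) for y_u in unlabeled_targets]

    # Filtration: keep only pairs that satisfy the labeled-data constraints.
    synthetic = [(x, y) for x, y in synthetic if keep(x, y)]

    # Step 3: retrain f on labeled + synthetic pairs, warm-started from Step 1.
    return train(labeled + synthetic, forward)


def toy_train(pairs: List[Pair], init: Optional[Model] = None) -> Model:
    """Hypothetical stand-in for real training: memorize pairs, else swap case."""
    table = dict(init.table) if init is not None else {}
    table.update(pairs)

    def model(s: str) -> str:
        return table.get(s, s.swapcase())

    model.table = table  # exposed only so the example can be inspected
    return model


final = back_translate([("a", "A")], ["C", "D"], toy_train,
                       keep=lambda x, y: y != "D")
print(final.table)
```

<p>In a real setting <code>train</code> would fit a Transformer or graph model, and <code>keep</code> would check the similarity threshold and property range of the task.</p>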
<h2 id="experiments-on-property-optimization-and-retrosynthesis">Experiments on Property Optimization and Retrosynthesis</h2>
<h3 id="molecular-property-improvement">Molecular Property Improvement</h3>
<p>The authors evaluate on four tasks from Jin et al. (2019, 2020), each requiring the model to improve a specific molecular property while maintaining structural similarity (measured by Dice similarity on Morgan fingerprints):</p>
<ul>
<li><strong>LogP</strong> (penalized <a href="https://en.wikipedia.org/wiki/Octanol-water_partition_coefficient">partition coefficient</a>): two settings with similarity thresholds $\delta \geq 0.4$ and $\delta \geq 0.6$</li>
<li><strong>QED</strong> (quantitative estimation of drug-likeness): translate molecules from QED range [0.7, 0.8] to [0.9, 1.0]</li>
<li><strong>DRD2</strong> (<a href="https://en.wikipedia.org/wiki/Dopamine_receptor_D2">dopamine type 2 receptor</a> activity): translate inactive ($P &lt; 0.5$) to active ($P \geq 0.5$)</li>
</ul>
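<p>The similarity constraint shared by these tasks can be illustrated with a small pure-Python sketch. It assumes fingerprints are already available as sets of &ldquo;on&rdquo; bit indices (in practice they would come from a cheminformatics toolkit); <code>dice_similarity</code> and <code>satisfies_constraint</code> are illustrative names, and the thresholds in the usage lines are example values.</p>

```python
from typing import Set


def dice_similarity(bits_a: Set[int], bits_b: Set[int]) -> float:
    """Dice coefficient 2|A and B| / (|A| + |B|) on fingerprint bit sets."""
    total = len(bits_a) + len(bits_b)
    if total == 0:
        return 1.0  # convention chosen here: two empty fingerprints match
    return 2.0 * len(bits_a.intersection(bits_b)) / total


def satisfies_constraint(sim: float, prop: float, delta: float, prop_min: float) -> bool:
    """True when the output is similar enough and its property is high enough."""
    return sim >= delta and prop >= prop_min


# Hypothetical fingerprints for an input molecule and a generated molecule.
sim = dice_similarity({1, 2, 3}, {2, 3, 4})
print(round(sim, 3))
print(satisfies_constraint(sim, prop=0.92, delta=0.6, prop_min=0.9))
```

<p>The same predicate is a natural candidate for filtering back-translated pairs, since it encodes the constraints the labeled pairs already satisfy.</p>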
<p>Two backbone architectures are tested: a Transformer (6 layers, 4 heads, 128-dim embeddings, 512-dim FFN) and HierG2G, a hierarchical graph-to-graph translation model. Unlabeled molecules are sampled from ZINC at 250K and 1M scales.</p>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>LogP ($\delta \geq 0.6$)</th>
          <th>LogP ($\delta \geq 0.4$)</th>
          <th>QED (%)</th>
          <th>DRD2 (%)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>JT-VAE</td>
          <td>0.28</td>
          <td>1.03</td>
          <td>8.8</td>
          <td>3.4</td>
      </tr>
      <tr>
          <td>GCPN</td>
          <td>0.79</td>
          <td>2.49</td>
          <td>9.4</td>
          <td>4.4</td>
      </tr>
      <tr>
          <td>JTNN</td>
          <td>2.33</td>
          <td>3.55</td>
          <td>59.9</td>
          <td>77.8</td>
      </tr>
      <tr>
          <td>Transformer baseline</td>
          <td>2.45</td>
          <td>3.69</td>
          <td>71.9</td>
          <td>60.2</td>
      </tr>
      <tr>
          <td>+BT (1M, filtered)</td>
          <td>2.86</td>
          <td>4.41</td>
          <td>82.9</td>
          <td>67.4</td>
      </tr>
      <tr>
          <td>HierG2G baseline</td>
          <td>2.49</td>
          <td>3.98</td>
          <td>76.9</td>
          <td>85.9</td>
      </tr>
      <tr>
          <td>+BT (250K, filtered)</td>
          <td>2.75</td>
          <td>4.24</td>
          <td>79.1</td>
          <td>87.3</td>
      </tr>
  </tbody>
</table>
<h3 id="retrosynthesis-prediction">Retrosynthesis Prediction</h3>
<p>On the USPTO-50K benchmark (50K reactions, 10 reaction types, 80/10/10 train/val/test split), the method is applied to Transformer and GLN (Graph Logic Network) backbones. For other approaches to this benchmark, see <a href="/notes/chemistry/molecular-design/reaction-prediction/tied-two-way-transformers-retrosynthesis/">Tied Two-Way Transformers</a> and <a href="/notes/chemistry/molecular-design/reaction-prediction/data-transfer-seq-to-seq-retrosynthesis/">Data Transfer for Retrosynthesis</a>. Unlabeled reactant sets are constructed by sampling molecules from ZINC and concatenating them following the training data&rsquo;s reactant count distribution ($N_1 : N_2 : N_3 = 29.3\% : 70.4\% : 0.3\%$).</p>
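<p>One way to build such unlabeled reactant sets is sketched below. The count weights are the ones quoted above; everything else is an assumption made for illustration: the function name, sampling without replacement, and joining molecules with <code>.</code> (the standard SMILES separator for disconnected components).</p>

```python
import random
from typing import Sequence

# Reactant-count distribution reported for the training data (N1 : N2 : N3).
COUNT_WEIGHTS = {1: 29.3, 2: 70.4, 3: 0.3}


def sample_reactant_set(pool: Sequence[str], rng: random.Random) -> str:
    """Draw 1 to 3 molecules from the pool and join them into one reactant string."""
    counts = list(COUNT_WEIGHTS)
    weights = [COUNT_WEIGHTS[c] for c in counts]
    n = rng.choices(counts, weights=weights, k=1)[0]
    # Assumption: molecules are drawn without replacement from the pool.
    return ".".join(rng.sample(list(pool), n))


# Hypothetical pool standing in for molecules sampled from ZINC.
pool = ["CCO", "c1ccccc1", "CC(=O)O", "CN"]
rng = random.Random(0)
print([sample_reactant_set(pool, rng) for _ in range(3)])
```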
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>Top-1</th>
          <th>Top-3</th>
          <th>Top-5</th>
          <th>Top-10</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Reaction type given</strong></td>
          <td></td>
          <td></td>
          <td></td>
          <td></td>
      </tr>
      <tr>
          <td>GLN</td>
          <td>64.2</td>
          <td>79.1</td>
          <td>85.2</td>
          <td>90.0</td>
      </tr>
      <tr>
          <td>Ours + GLN</td>
          <td>67.9</td>
          <td>82.5</td>
          <td>87.3</td>
          <td>91.5</td>
      </tr>
      <tr>
          <td>Transformer</td>
          <td>52.2</td>
          <td>68.2</td>
          <td>72.7</td>
          <td>77.4</td>
      </tr>
      <tr>
          <td>Ours + Transformer</td>
          <td>55.9</td>
          <td>72.8</td>
          <td>77.8</td>
          <td>79.7</td>
      </tr>
      <tr>
          <td><strong>Reaction type unknown</strong></td>
          <td></td>
          <td></td>
          <td></td>
          <td></td>
      </tr>
      <tr>
          <td>GLN</td>
          <td>52.5</td>
          <td>69.0</td>
          <td>75.6</td>
          <td>83.7</td>
      </tr>
      <tr>
          <td>Ours + GLN</td>
          <td>54.7</td>
          <td>70.2</td>
          <td>77.0</td>
          <td>84.4</td>
      </tr>
      <tr>
          <td>Transformer</td>
          <td>37.9</td>
          <td>57.3</td>
          <td>62.7</td>
          <td>68.1</td>
      </tr>
      <tr>
          <td>Ours + Transformer</td>
          <td>43.5</td>
          <td>58.8</td>
          <td>64.6</td>
          <td>69.7</td>
      </tr>
  </tbody>
</table>
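<p>The improvements tend to be largest at top-1 (up to 5.6 points for the Transformer with unknown reaction type) and are smallest at top-10 for most configurations, suggesting that back translation mainly helps the model make more precise high-confidence predictions.</p>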
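<p>For reference, top-$k$ accuracy as reported in the table can be computed as in the sketch below. It assumes each prediction is a ranked list of candidate reactant strings and that a hit means an exact string match with the reference after canonicalization; the function name and toy data are illustrative.</p>

```python
from typing import Sequence


def top_k_accuracy(ranked: Sequence[Sequence[str]], truth: Sequence[str], k: int) -> float:
    """Percentage of examples whose reference appears in the first k candidates."""
    hits = sum(1 for cands, ref in zip(ranked, truth) if ref in cands[:k])
    return 100.0 * hits / len(truth)


# Toy predictions: three products, two ranked candidates each.
ranked = [["A", "B"], ["C", "D"], ["E", "F"]]
truth = ["A", "D", "Z"]
print(top_k_accuracy(ranked, truth, k=1))
print(top_k_accuracy(ranked, truth, k=2))
```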
<h3 id="ablation-studies">Ablation Studies</h3>
<p><strong>Effect of unlabeled data size</strong>: On retrosynthesis with Transformer, performance improves as unlabeled data increases from 50K to 250K, then plateaus or declines beyond 250K. The authors attribute this to noise in the back-translated data outweighing the benefits at larger scales.</p>
<p><strong>Effect of labeled data size</strong>: With only 5K labeled samples, adding back-translated data hurts performance because the reverse model is too weak to generate useful synthetic data. As labeled data increases (10K, 25K, 50K), the benefit of back translation grows. This confirms that the method requires a reasonably well-trained reverse model to be effective.</p>
<p><strong>Data filtration</strong>: Using 1M unfiltered back-translated molecules is much less effective and sometimes hurts performance (e.g., on QED the Transformer reaches only 75.1% unfiltered, compared with the 71.9% baseline and 82.9% with filtering), while filtering to enforce the same constraints as the labeled data recovers and exceeds the 250K filtered results.</p>
<h2 id="consistent-gains-across-architectures-and-tasks">Consistent Gains Across Architectures and Tasks</h2>
<p>The method achieves state-of-the-art results on all four molecular property improvement tasks and the USPTO-50K retrosynthesis benchmark at time of publication. Several observations stand out:</p>
<ol>
<li><strong>Architecture agnosticism</strong>: Back translation improves both sequence-based (Transformer) and graph-based (HierG2G, GLN) models, indicating that the approach is not tied to a particular architecture.</li>
<li><strong>Filtration is essential at scale</strong>: Unfiltered 1M back-translated data can degrade performance, but filtered data at the same scale consistently outperforms smaller unfiltered sets.</li>
<li><strong>Training overhead is moderate</strong>: On the DRD2 task, the full back-translation pipeline with Transformer takes about 2.5x the supervised training time (8.5h for initial training plus 11.0h for retraining), with the back-translation step itself taking under 1 hour.</li>
<li><strong>Diversity and novelty increase</strong>: Back translation improves both diversity (average pairwise distance among generated molecules) and novelty (fraction of generated molecules not seen in training) across QED and DRD2 tasks.</li>
</ol>
<p>The authors acknowledge limitations: the method does not form a closed loop between forward and reverse models (as in dual learning approaches), and the data filtration strategy is rule-based rather than learned. They suggest joint training of forward and reverse models and learned filtration as future directions.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training (property improvement)</td>
          <td>Jin et al. (2019, 2020) datasets</td>
          <td>34K-99K pairs</td>
          <td>LogP, QED, DRD2 tasks</td>
      </tr>
      <tr>
          <td>Training (retrosynthesis)</td>
          <td>USPTO-50K</td>
          <td>40K reactions</td>
          <td>80/10/10 split from Dai et al. (2019)</td>
      </tr>
      <tr>
          <td>Unlabeled molecules</td>
          <td>ZINC</td>
          <td>250K or 1M</td>
          <td>Randomly sampled</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>Same as training</td>
          <td>800-1000 test samples</td>
          <td>Per-task test sets</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Back translation with optional data filtration</li>
<li>Beam search with $k=20$ for inference</li>
<li>Random sampling for back-translation step (Equation 5)</li>
<li>Dice similarity on Morgan fingerprints for similarity constraint</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li><strong>Transformer</strong>: 6 layers, 4 attention heads, 128-dim embeddings, 512-dim FFN (for property improvement); 4 layers, 8 heads, 256-dim embeddings, 2048-dim FFN (for retrosynthesis)</li>
<li><strong>HierG2G</strong>: Settings from Jin et al. (2020)</li>
<li><strong>GLN</strong>: Settings from Dai et al. (2019)</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Task</th>
          <th>Best Value</th>
          <th>Baseline</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>LogP improvement</td>
          <td>LogP ($\delta \geq 0.6$)</td>
          <td>2.86</td>
          <td>2.49 (HierG2G)</td>
          <td>Transformer + BT(1M, filtered)</td>
      </tr>
      <tr>
          <td>LogP improvement</td>
          <td>LogP ($\delta \geq 0.4$)</td>
          <td>4.41</td>
          <td>3.98 (HierG2G)</td>
          <td>Transformer + BT(1M, filtered)</td>
      </tr>
      <tr>
          <td>Success rate</td>
          <td>QED</td>
          <td>82.9%</td>
          <td>76.9% (HierG2G)</td>
          <td>Transformer + BT(1M, filtered)</td>
      </tr>
      <tr>
          <td>Success rate</td>
          <td>DRD2</td>
          <td>87.3%</td>
          <td>85.9% (HierG2G)</td>
          <td>HierG2G + BT(250K, filtered)</td>
      </tr>
      <tr>
          <td>Top-1 accuracy</td>
          <td>USPTO-50K (known type)</td>
          <td>67.9%</td>
          <td>64.2% (GLN)</td>
          <td>Ours + GLN</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>The paper reports training times (8.5h for Transformer, 16.8h for HierG2G on DRD2 with 1M unlabeled data) but does not specify the GPU hardware used.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/fyabc/BT4MolGen">BT4MolGen</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official implementation in Python</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Fan, Y., Xia, Y., Zhu, J., Wu, L., Xie, S., &amp; Qin, T. (2021). Back translation for molecule generation. <em>Bioinformatics</em>, 38(5), 1244-1251. <a href="https://doi.org/10.1093/bioinformatics/btab817">https://doi.org/10.1093/bioinformatics/btab817</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{fan2022back,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Back translation for molecule generation}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Fan, Yang and Xia, Yingce and Zhu, Jinhua and Wu, Lijun and Xie, Shufang and Qin, Tao}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Bioinformatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{38}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{5}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{1244--1251}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2021}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Oxford University Press}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1093/bioinformatics/btab817}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>AMORE: Testing ChemLLM Robustness to SMILES Variants</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/amore-smiles-robustness-framework/</link><pubDate>Wed, 25 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/amore-smiles-robustness-framework/</guid><description>AMORE is a zero-shot framework testing whether chemical language models recognize equivalent SMILES of the same molecule via embedding retrieval.</description><content:encoded><![CDATA[<h2 id="an-empirical-framework-for-probing-chemical-understanding">An Empirical Framework for Probing Chemical Understanding</h2>
<p>This is an <strong>Empirical</strong> paper that introduces Augmented Molecular Retrieval (AMORE), a zero-shot evaluation framework for chemical language models (ChemLMs). The primary contribution is a method to assess whether ChemLMs have learned genuine molecular semantics or simply memorize textual patterns. Rather than relying on traditional NLP metrics like BLEU and ROUGE, AMORE tests whether a model&rsquo;s embedding space treats chemically equivalent <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> representations as similar. The authors evaluate 12 models across multiple architectures (encoder-only, encoder-decoder, decoder-only) on two datasets and five augmentation types, and extend the analysis to downstream MoleculeNet tasks.</p>
<h2 id="why-standard-nlp-metrics-fail-for-chemical-evaluation">Why Standard NLP Metrics Fail for Chemical Evaluation</h2>
<p>Chemical language models are typically evaluated using text-based metrics from NLP (BLEU, ROUGE, METEOR) on tasks like molecule captioning. These metrics compare word overlap and sentence fluency but cannot detect whether a model truly understands molecular structure. A SMILES string like <code>C(=O)O</code> and its canonicalized or kekulized form represent the same molecule, yet text-based metrics would penalize valid reformulations. Embedding-based metrics like BERTScore are also insufficient because their underlying encoders were trained on general text, not chemical notation.</p>
<p>The core research question is direct: do evaluation metrics used on ChemLMs reflect actual chemical knowledge, or do the models simply imitate understanding by learning textual features? This question has practical consequences in pharmaceuticals and healthcare, where missteps in chemical reasoning carry serious risks.</p>
<h2 id="embedding-based-retrieval-as-a-chemical-litmus-test">Embedding-Based Retrieval as a Chemical Litmus Test</h2>
<p>AMORE exploits a fundamental property of molecular representations: a single molecule can be written as multiple valid SMILES strings that are chemically identical. These serve as &ldquo;total synonyms,&rdquo; a concept without a true analogue in natural language.</p>
<p>The framework works in four steps:</p>
<ol>
<li>Take a set $X = (x_1, x_2, \ldots, x_n)$ of $n$ molecular representations.</li>
<li>Apply a transformation $f$ to obtain augmented representations $X' = (x'_1, x'_2, \ldots, x'_n)$, where $x'_i = f(x_i)$. The constraint is that $f$ must not change the underlying molecule.</li>
<li>Obtain vectorized embeddings $e(x_i)$ and $e(x'_j)$ from the model for each original and augmented SMILES.</li>
<li>Evaluate in a retrieval task: given $e(x_i)$, retrieve $e(x'_i)$ from the augmented set.</li>
</ol>
<p>The evaluation metrics are top-$k$ accuracy (whether the correct augmented SMILES ranks at position $\leq k$) and <a href="https://en.wikipedia.org/wiki/Mean_reciprocal_rank">Mean Reciprocal Rank</a> (MRR). Retrieval uses <a href="https://en.wikipedia.org/wiki/FAISS">FAISS</a> for efficient nearest-neighbor search. The key insight is that if a model truly understands molecular structure, it should embed different SMILES representations of the same molecule close together.</p>
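<p>The retrieval step and both metrics can be sketched without any external library. The version below swaps FAISS for a brute-force L2 search, which gives the same ranking on small sets; the function names and the toy two-dimensional embeddings are hypothetical.</p>

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def retrieval_ranks(originals: Sequence[Vector], augmented: Sequence[Vector]) -> List[int]:
    """1-based rank of the i-th augmented embedding when querying with the i-th original."""
    ranks = []
    for i, query in enumerate(originals):
        dists = [math.dist(query, cand) for cand in augmented]
        order = sorted(range(len(augmented)), key=lambda j: dists[j])
        ranks.append(order.index(i) + 1)
    return ranks


def acc_at_k(ranks: Sequence[int], k: int) -> float:
    """Fraction of queries whose correct match is ranked within the top k."""
    return sum(1 for r in ranks if k >= r) / len(ranks)


def mean_reciprocal_rank(ranks: Sequence[int]) -> float:
    """Average of 1 / rank over all queries."""
    return sum(1.0 / r for r in ranks) / len(ranks)


# Toy embeddings: the second molecule's augmented form drifts toward the first.
originals = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0)]
augmented = [(0.1, 0.0), (0.0, 0.2), (0.0, 0.9)]
ranks = retrieval_ranks(originals, augmented)
print(ranks, acc_at_k(ranks, 1), mean_reciprocal_rank(ranks))
```

<p>A model that is robust to the augmentation would place every augmented embedding at rank 1, giving Acc@1 and MRR of 1.0.</p>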
<h3 id="five-smiles-augmentation-types">Five SMILES Augmentation Types</h3>
<p>The framework uses five identity-preserving augmentations, all executed through <a href="https://en.wikipedia.org/wiki/RDKit">RDKit</a>:</p>
<ol>
<li><strong>Canonicalization</strong>: Transform SMILES to the standardized RDKit canonical form.</li>
<li><strong>Hydrogen addition</strong>: Explicitly add hydrogen atoms that are normally implied (e.g., <code>C</code> becomes <code>[CH4]</code>). This dramatically increases string length.</li>
<li><strong>Kekulization</strong>: Convert aromatic ring notation to explicit alternating double bonds.</li>
<li><strong>Cycle renumbering</strong>: Replace ring-closure digit identifiers with random valid alternatives.</li>
<li><strong>Random atom order</strong>: Randomize the atom traversal order used to generate the SMILES string.</li>
</ol>
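<p>To make one of these augmentations concrete, cycle renumbering can be approximated with plain string processing. The sketch below is not the RDKit-based routine used in the paper: it applies a random permutation to single-digit ring-closure labels 1-9, leaves digits inside bracket atoms (isotopes, hydrogen counts, charges) and two-digit <code>%nn</code> labels untouched, and does not handle edge cases such as <code>%01</code>. Because the relabeling is a bijection, ring-closure pairings are preserved.</p>

```python
import random


def renumber_ring_closures(smiles: str, rng: random.Random) -> str:
    """Relabel single-digit ring closures with a random permutation of 1-9."""
    digits = "123456789"
    shuffled = list(digits)
    rng.shuffle(shuffled)
    mapping = dict(zip(digits, shuffled))

    out = []
    in_bracket = False
    skip = 0  # characters still belonging to a two-digit %nn label
    for ch in smiles:
        if skip:
            skip -= 1                 # part of %nn: copy unchanged
        elif ch == "[":
            in_bracket = True
        elif ch == "]":
            in_bracket = False
        elif ch == "%":
            skip = 2                  # next two digits are one label
        elif not in_bracket and ch in mapping:
            ch = mapping[ch]          # ring-closure digit: relabel
        out.append(ch)
    return "".join(out)


print(renumber_ring_closures("c1ccc2ccccc2c1[13CH3]", random.Random(0)))
```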
<h2 id="twelve-models-two-datasets-five-augmentations">Twelve Models, Two Datasets, Five Augmentations</h2>
<h3 id="models-evaluated">Models Evaluated</h3>
<p>The authors test 12 publicly available Transformer-based models spanning three architecture families:</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Domain</th>
          <th>Parameters</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Text+Chem T5-standard</td>
          <td>Cross-modal</td>
          <td>220M</td>
      </tr>
      <tr>
          <td>Text+Chem T5-augm</td>
          <td>Cross-modal</td>
          <td>220M</td>
      </tr>
      <tr>
          <td>MolT5-base</td>
          <td>Cross-modal</td>
          <td>220M</td>
      </tr>
      <tr>
          <td>MolT5-large</td>
          <td>Cross-modal</td>
          <td>770M</td>
      </tr>
      <tr>
          <td>SciFive</td>
          <td>Text-only</td>
          <td>220M</td>
      </tr>
      <tr>
          <td>PubChemDeBERTa</td>
          <td>Chemical</td>
          <td>86M</td>
      </tr>
      <tr>
          <td>ChemBERT-ChEMBL</td>
          <td>Chemical</td>
          <td>6M</td>
      </tr>
      <tr>
          <td><a href="/notes/chemistry/molecular-representations/encoders/chemberta/">ChemBERTa</a></td>
          <td>Chemical</td>
          <td>125M</td>
      </tr>
      <tr>
          <td><a href="/notes/chemistry/molecular-representations/encoders/bartsmiles-molecular-representations/">BARTSmiles</a></td>
          <td>Chemical</td>
          <td>400M</td>
      </tr>
      <tr>
          <td>ZINC-RoBERTa</td>
          <td>Chemical</td>
          <td>102M</td>
      </tr>
      <tr>
          <td><a href="/notes/chemistry/molecular-representations/multimodal/nach0-multimodal-chemical-language-model/">nach0</a></td>
          <td>Chemical</td>
          <td>220M</td>
      </tr>
      <tr>
          <td>ZINC-GPT</td>
          <td>Chemical</td>
          <td>87M</td>
      </tr>
  </tbody>
</table>
<h3 id="datasets">Datasets</h3>
<ul>
<li><strong>ChEBI-20 test set</strong>: ~3,300 molecule-description pairs, used for both AMORE retrieval and molecule captioning comparisons.</li>
<li><strong>Isomers</strong> (<a href="/notes/chemistry/datasets/qm9/">QM9</a> subset): 918 molecules that are all isomers of C9H12N2O, making retrieval harder because all molecules share the same molecular formula.</li>
</ul>
<h3 id="key-results-on-chebi-20">Key Results on ChEBI-20</h3>
<p>On the ChEBI-20 dataset (Table 2 from the paper), top-1 accuracy varies enormously by augmentation type. Cycle renumbering is easiest (up to 98.48% Acc@1 for SciFive), while hydrogen addition is hardest (no model exceeds 5.97% Acc@1).</p>
<p>For the cross-modal Text+Chem T5-standard model:</p>
<table>
  <thead>
      <tr>
          <th>Augmentation</th>
          <th>Acc@1</th>
          <th>Acc@5</th>
          <th>MRR</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Canonical</td>
          <td>63.03</td>
          <td>82.76</td>
          <td>72.4</td>
      </tr>
      <tr>
          <td>Hydrogen</td>
          <td>5.46</td>
          <td>10.85</td>
          <td>8.6</td>
      </tr>
      <tr>
          <td>Kekulization</td>
          <td>76.76</td>
          <td>92.03</td>
          <td>83.8</td>
      </tr>
      <tr>
          <td>Cycle</td>
          <td>96.70</td>
          <td>99.82</td>
          <td>98.2</td>
      </tr>
      <tr>
          <td>Random</td>
          <td>46.94</td>
          <td>74.18</td>
          <td>59.33</td>
      </tr>
  </tbody>
</table>
<h3 id="key-results-on-isomers">Key Results on Isomers</h3>
<p>Performance drops substantially on the Isomers dataset, where all molecules share the same formula. The best Acc@1 for hydrogen augmentation is just 1.53% (MolT5-large). Even for the relatively easy cycle augmentation, top scores drop from the high 90s to the low 90s for most models, and some models (BARTSmiles: 41.83%) struggle considerably.</p>
<h3 id="downstream-moleculenet-impact">Downstream MoleculeNet Impact</h3>
<p>The authors also fine-tuned models on original <a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a> training data and tested on augmented test sets across 9 tasks (regression, binary classification, multilabel classification). Results confirm that augmentations degrade downstream performance. For example, on ESOL regression, RMSE increased from 0.87 to 7.93 with hydrogen addition. Rankings computed using the Vote&rsquo;n&rsquo;Rank framework (using the <a href="https://en.wikipedia.org/wiki/Copeland%27s_method">Copeland rule</a>) show that hydrogen augmentation is the only one that substantially reshuffles model rankings; other augmentations preserve the original ordering.</p>
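<p>The Copeland rule itself is simple to state in code. The sketch below is a generic version, not the Vote&rsquo;n&rsquo;Rank implementation: each pair of models is compared by how many tasks each one wins, and a model&rsquo;s score is its number of pairwise victories minus its pairwise defeats. The model names and scores are made up, and higher is assumed to be better (for RMSE-style metrics the values would be negated first).</p>

```python
from itertools import combinations
from typing import Dict, List


def copeland_scores(scores: Dict[str, List[float]]) -> Dict[str, int]:
    """Pairwise majority wins minus losses for each model (higher metric is better)."""
    result = {name: 0 for name in scores}
    for a, b in combinations(scores, 2):
        a_wins = sum(1 for x, y in zip(scores[a], scores[b]) if x > y)
        b_wins = sum(1 for x, y in zip(scores[a], scores[b]) if y > x)
        if a_wins > b_wins:
            result[a] += 1
            result[b] -= 1
        elif b_wins > a_wins:
            result[b] += 1
            result[a] -= 1
        # a tie on the pairwise contest leaves both scores unchanged
    return result


# Hypothetical per-task scores for three models on three tasks.
toy = {
    "model_a": [0.90, 0.80, 0.70],
    "model_b": [0.85, 0.82, 0.60],
    "model_c": [0.50, 0.90, 0.65],
}
print(copeland_scores(toy))
```

<p>Recomputing these scores on each augmented test set and comparing the resulting orderings is one way to see whether an augmentation reshuffles the ranking.</p>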
<h3 id="correlation-between-amore-and-captioning-metrics">Correlation Between AMORE and Captioning Metrics</h3>
<p>The differences in ROUGE/METEOR between original and augmented SMILES correlate with AMORE retrieval accuracy (Spearman correlation &gt; 0.7 with p-value = 0.003 for Acc@1). This validates AMORE as a proxy for predicting how augmentations will affect generation quality, without requiring labeled captioning data.</p>
<h2 id="current-chemlms-learn-syntax-not-chemistry">Current ChemLMs Learn Syntax, Not Chemistry</h2>
<p>The central finding is that existing ChemLMs are not robust to identity-preserving SMILES augmentations. Several specific conclusions emerge:</p>
<ol>
<li>
<p><strong>Hydrogen augmentation is catastrophic</strong>: All models fail (&lt; 6% Acc@1 on ChEBI-20, &lt; 2% on Isomers). The authors attribute this to the near-complete absence of explicit hydrogen in pretraining data, creating a distribution shift.</p>
</li>
<li>
<p><strong>Cross-modal models outperform unimodal ones</strong>: Models trained on both text and SMILES (Text+Chem T5, MolT5) consistently achieve higher retrieval accuracy on four of five augmentations.</p>
</li>
<li>
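<p><strong>Augmentation difficulty follows a broadly consistent order</strong>: Hydrogen is hardest and cycle renumbering is easiest, with canonicalization, random ordering, and kekulization in between. The middle of the ordering can vary by model: for Text+Chem T5-standard on ChEBI-20, random atom order (46.94% Acc@1) is harder than canonicalization (63.03%).</p>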
</li>
<li>
<p><strong>Layer-wise analysis reveals instability</strong>: Retrieval accuracy across Transformer layers is correlated across augmentation types, suggesting that representations degrade at the same layers regardless of augmentation.</p>
</li>
<li>
<p><strong><a href="https://en.wikipedia.org/wiki/Levenshtein_distance">Levenshtein distance</a> partially explains difficulty</strong>: Hydrogen augmentation produces strings ~2x longer than originals (Levenshtein ratio of 1.49), but the low correlation between Levenshtein ratio and downstream metrics (ROUGE1 correlation of -0.05 for hydrogen) suggests string length alone does not explain the failure.</p>
</li>
</ol>
<h3 id="limitations">Limitations</h3>
<p>The authors acknowledge several limitations. Only publicly available HuggingFace models were evaluated, excluding models like <a href="/notes/chemistry/molecular-design/generation/autoregressive/chemformer/">Chemformer</a> and <a href="/notes/chemistry/molecular-representations/encoders/molformer/">Molformer</a> that lack HF checkpoints. The study focuses exclusively on SMILES sequences, not 3D molecular structures or other formats like <a href="/notes/chemistry/molecular-representations/notations/selfies/">SELFIES</a>. The augmentation types, while representative, do not cover all possible identity transformations.</p>
<p>The authors suggest that AMORE could serve as a regularization tool during training, for example by using metric learning to encourage models to embed SMILES variants of the same molecule close together.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Retrieval evaluation</td>
          <td>ChEBI-20 test set</td>
          <td>3,300 molecules</td>
          <td>Standard benchmark for molecule captioning</td>
      </tr>
      <tr>
          <td>Retrieval evaluation</td>
          <td>Isomers (QM9 subset)</td>
          <td>918 molecules</td>
          <td>All isomers of C9H12N2O</td>
      </tr>
      <tr>
          <td>Downstream evaluation</td>
          <td>MoleculeNet (9 tasks)</td>
          <td>Varies</td>
          <td>ESOL, Lipophilicity, FreeSolv, HIV, BBBP, BACE, Tox21, ToxCast, SIDER</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>SMILES augmentations via RDKit (canonicalization, hydrogen addition, kekulization, cycle renumbering, random atom ordering)</li>
<li>Nearest-neighbor retrieval using FAISS with L2, cosine, and inner-product metrics, plus HNSW approximate search</li>
<li>Model ranking via Vote&rsquo;n&rsquo;Rank (Copeland rule) on MoleculeNet tasks</li>
</ul>
<h3 id="models">Models</h3>
<p>All 12 evaluated models are publicly available on HuggingFace. No custom model training was performed for the AMORE retrieval experiments. MoleculeNet experiments used standard fine-tuning on original training splits.</p>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Description</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Acc@1</td>
          <td>Top-1 retrieval accuracy</td>
          <td>Primary AMORE metric</td>
      </tr>
      <tr>
          <td>Acc@5</td>
          <td>Top-5 retrieval accuracy</td>
          <td>Secondary AMORE metric</td>
      </tr>
      <tr>
          <td>MRR</td>
          <td>Mean Reciprocal Rank</td>
          <td>Mean of 1/rank of the correct match</td>
      </tr>
      <tr>
          <td>ROUGE-2</td>
          <td>Bigram overlap for captioning</td>
          <td>Compared against AMORE</td>
      </tr>
      <tr>
          <td>METEOR</td>
          <td>MT evaluation metric for captioning</td>
          <td>Compared against AMORE</td>
      </tr>
  </tbody>
</table>
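<p>As a rough illustration of how the three retrieval metrics in the table relate, the sketch below computes them from a list of 1-based ranks of the correct match. The function name <code>retrieval_metrics</code> and the input format are assumptions made for this example; they are not taken from the AMORE code base.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">def retrieval_metrics(ranks, k=5):
    """Return (Acc@1, Acc@k, MRR) from 1-based ranks of the correct match.

    Hypothetical helper for illustration; ranks[i] is the position at which
    the original molecule is retrieved for the i-th augmented query.
    """
    n = len(ranks)
    # Acc@1: fraction of queries whose correct match is ranked first.
    acc_at_1 = sum(1 for r in ranks if r == 1) / n
    # Acc@k: fraction of queries whose correct match is within the top k.
    acc_at_k = sum(1 for r in ranks if r in range(1, k + 1)) / n
    # MRR: mean of the reciprocal ranks, not the mean rank itself.
    mrr = sum(1.0 / r for r in ranks) / n
    return acc_at_1, acc_at_k, mrr


if __name__ == "__main__":
    print(retrieval_metrics([1, 2, 10]))
</code></pre></div>
<p>Because MRR averages reciprocal ranks, a single query ranked far down the list lowers it only slightly, whereas Acc@1 treats every miss the same.</p>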
<h3 id="hardware">Hardware</h3>
<p>The experiments used HPC facilities at HSE University. Specific GPU types and training times are not reported.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/ChemistryLLMs/AMORE">AMORE GitHub</a></td>
          <td>Code</td>
          <td>Not specified</td>
          <td>Framework code and evaluation data</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Ganeeva, V., Khrabrov, K., Kadurin, A., &amp; Tutubalina, E. (2025). Measuring Chemical LLM robustness to molecular representations: a SMILES variation-based framework. <em>Journal of Cheminformatics</em>, 17(1). <a href="https://doi.org/10.1186/s13321-025-01079-0">https://doi.org/10.1186/s13321-025-01079-0</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{ganeeva2025measuring,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Measuring Chemical LLM robustness to molecular representations: a SMILES variation-based framework}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Ganeeva, Veronika and Khrabrov, Kuzma and Kadurin, Artur and Tutubalina, Elena}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Journal of Cheminformatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{17}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2025}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Springer}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1186/s13321-025-01079-0}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>SMINA Docking Benchmark for De Novo Drug Design Models</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/evaluation/smina-docking-benchmark/</link><pubDate>Mon, 23 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/evaluation/smina-docking-benchmark/</guid><description>A docking-based benchmark for evaluating de novo drug design generative models, using SMINA scoring across eight protein targets from ChEMBL.</description><content:encoded><![CDATA[<h2 id="a-docking-based-benchmark-for-de-novo-drug-design">A Docking-Based Benchmark for De Novo Drug Design</h2>
<p>This is a <strong>Resource</strong> paper. Its primary contribution is a standardized benchmark for evaluating generative models in de novo drug design. Rather than introducing a new generative method, the paper provides a reusable evaluation framework built around molecular docking, a widely used computational proxy for predicting protein-ligand binding. The benchmark uses SMINA (a fork of <a href="https://en.wikipedia.org/wiki/AutoDock">AutoDock Vina</a>) to score generated molecules against eight protein targets, offering a more realistic evaluation than commonly used proxy metrics like logP or QED.</p>
<h2 id="why-existing-benchmarks-fall-short">Why Existing Benchmarks Fall Short</h2>
<p>De novo drug design methods are typically evaluated using simple proxy tasks that do not reflect the complexity of real drug discovery. The octanol-water partition coefficient (logP) can be trivially optimized by producing unrealistic molecules. The QED drug-likeness score suffers from the same issue. Neural network-based bioactivity predictors are similarly exploitable.</p>
<p>As Coley et al. (2020) note: &ldquo;The current evaluations for generative models do not reflect the complexity of real discovery problems.&rdquo;</p>
<p>More realistic evaluation approaches exist in adjacent domains (photovoltaics, excitation energies), where physical calculations are used to both train and evaluate models. Yet de novo drug design has largely relied on the same simplistic proxies. This gap between proxy task performance and real-world utility motivates the development of a docking-based benchmark that, while still a proxy, captures more of the structural complexity involved in protein-ligand interactions.</p>
<h2 id="benchmark-design-smina-docking-with-the-vinardo-scoring-function">Benchmark Design: SMINA Docking with the Vinardo Scoring Function</h2>
<p>The benchmark is defined by three components: (1) docking software that computes a ligand&rsquo;s pose in the binding site, (2) a scoring function that evaluates the pose, and (3) a training set of compounds with precomputed docking scores.</p>
<p>The concrete instantiation uses SMINA v. 2017.11.9 with the Vinardo scoring function:</p>
<p>$$S = -0.045 \cdot G + 0.8 \cdot R - 0.035 \cdot H - 0.6 \cdot B$$</p>
<p>where $S$ is the docking score, $G$ is the gauss term, $R$ is repulsion, $H$ is the hydrophobic term, and $B$ is the non-directional hydrogen bond term. The gauss and repulsion terms measure steric interactions between the ligand and the protein, while the hydrophobic and hydrogen bond terms capture favorable non-covalent contacts.</p>
<p>The benchmark includes three task variants:</p>
<ol>
<li><strong>Docking Score Function</strong>: Optimize the full Vinardo docking score (lower is better).</li>
<li><strong>Repulsion</strong>: Minimize only the repulsion component, defined as:</li>
</ol>
<p>$$
R(a_1, a_2) = \begin{cases}
d(a_1, a_2)^2 &amp; d(a_1, a_2) &lt; 0 \\
0 &amp; \text{otherwise}
\end{cases}
$$</p>
<p>where $d(a_1, a_2)$ is the inter-atomic distance minus the sum of <a href="https://en.wikipedia.org/wiki/Van_der_Waals_radius">van der Waals radii</a>.</p>
<ol start="3">
<li><strong>Hydrogen Bonding</strong>: Maximize the hydrogen bond term:</li>
</ol>
<p>$$
B(a_1, a_2) = \begin{cases}
0 &amp; (a_1, a_2) \text{ do not form H-bond} \\
1 &amp; d(a_1, a_2) &lt; -0.6 \\
0 &amp; d(a_1, a_2) \geq 0 \\
\frac{d(a_1, a_2)}{-0.6} &amp; \text{otherwise}
\end{cases}
$$</p>
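<p>The piecewise definitions and the weighted sum above can be sketched in a few lines of Python. This is a minimal illustration of the formulas as written in this note, not SMINA&rsquo;s implementation: the function names are hypothetical, and the inputs to <code>docking_score</code> are assumed to be term values already summed over atom pairs.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">def repulsion(d):
    """Repulsion for one atom pair; d is distance minus the vdW radii sum.

    Only overlapping atoms (negative d) contribute, quadratically.
    """
    return min(d, 0.0) ** 2


def hbond(d, can_bond=True):
    """Non-directional hydrogen bond term for one atom pair.

    Equals 1 at or below d = -0.6, 0 at or above d = 0, linear in between.
    """
    if not can_bond:
        return 0.0
    return min(1.0, max(0.0, d / -0.6))


def docking_score(gauss, repulsion_sum, hydrophobic, hbond_sum):
    """Weighted sum S = -0.045 G + 0.8 R - 0.035 H - 0.6 B (lower is better)."""
    return (-0.045 * gauss + 0.8 * repulsion_sum
            - 0.035 * hydrophobic - 0.6 * hbond_sum)


if __name__ == "__main__":
    print(repulsion(-0.5), hbond(-0.3), docking_score(100.0, 1.0, 20.0, 2.0))
</code></pre></div>
<p>Repulsion is the only term with a positive weight, so it is the only one that can worsen the score; this is why the repulsion task is posed as a minimization and the hydrogen bonding task as a maximization.</p>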
<p>Scores are averaged over the top 5 binding poses for stability. Generated compounds are filtered by <a href="https://en.wikipedia.org/wiki/Lipinski%27s_rule_of_five">Lipinski&rsquo;s Rule of Five</a> and a minimum molecular weight of 100. Each model must generate 250 unique molecules per target.</p>
<p>Training data comes from <a href="https://en.wikipedia.org/wiki/ChEMBL">ChEMBL</a>, covering eight drug targets: 5-HT1B, 5-HT2B, ACM2, CYP2D6, ADRB1, MOR, A2A, and D2. Dataset sizes range from 1,082 (ADRB1) to 10,225 (MOR) molecules.</p>
<h2 id="experimental-evaluation-of-three-generative-models">Experimental Evaluation of Three Generative Models</h2>
<h3 id="models-tested">Models Tested</h3>
<p>Three popular generative models were evaluated:</p>
<ul>
<li><strong><a href="/notes/chemistry/molecular-design/generation/latent-space/automatic-chemical-design-vae/">CVAE</a></strong> (Chemical Variational Autoencoder): A VAE operating on <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> strings.</li>
<li><strong><a href="/notes/chemistry/molecular-design/generation/latent-space/grammar-variational-autoencoder/">GVAE</a></strong> (Grammar Variational Autoencoder): Extends CVAE by enforcing grammatical correctness of generated SMILES.</li>
<li><strong><a href="/notes/chemistry/molecular-design/generation/rl-tuned/reinvent-deep-rl-molecular-design/">REINVENT</a></strong>: A recurrent neural network trained first on ChEMBL in a supervised manner, then fine-tuned with reinforcement learning using docking scores as rewards.</li>
</ul>
<p>For CVAE and GVAE, molecules are generated by sampling from the latent space and taking 50 gradient steps to optimize an MLP that predicts the docking score. For REINVENT, a random forest model predicts docking scores from ECFP fingerprints, and the reward combines this prediction with the QED score.</p>
<h3 id="baselines">Baselines</h3>
<p>Two baselines provide context:</p>
<ul>
<li><strong>Training set</strong>: The top 50%, 10%, and 1% of docking scores from the ChEMBL training set.</li>
<li><strong><a href="/notes/chemistry/datasets/zinc-22/">ZINC</a> subset</strong>: A random sample of ~9.2 million drug-like molecules from ZINC, with the same percentile breakdowns.</li>
</ul>
<p>Diversity is measured as the mean <a href="https://en.wikipedia.org/wiki/Jaccard_index">Tanimoto distance</a> (using 1024-bit ECFP with radius 2) between all pairs of generated molecules.</p>
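<p>The diversity metric can be sketched without a cheminformatics library by representing each fingerprint as the set of its on-bit indices. The helper names below are hypothetical, and in practice the bit sets would come from an ECFP implementation such as RDKit&rsquo;s.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">from itertools import combinations


def tanimoto_distance(a, b):
    """One minus the Jaccard similarity of two sets of on-bit indices."""
    union = a.union(b)
    if not union:
        # Two empty fingerprints are treated as identical (an assumption).
        return 0.0
    return 1.0 - len(a.intersection(b)) / len(union)


def mean_pairwise_distance(fingerprints):
    """Average Tanimoto distance over all unordered pairs (needs 2 or more)."""
    pairs = list(combinations(fingerprints, 2))
    return sum(tanimoto_distance(a, b) for a, b in pairs) / len(pairs)


if __name__ == "__main__":
    toy = [{1, 2, 3}, {2, 3, 4}, {5}]
    print(mean_pairwise_distance(toy))
</code></pre></div>
<p>A value near 1 means the generated molecules share almost no substructure bits, while a value near 0 means the set has collapsed onto near-duplicates.</p>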
<h3 id="key-results">Key Results</h3>
<table>
  <thead>
      <tr>
          <th>Task</th>
          <th>Model</th>
          <th>5-HT1B Score</th>
          <th>5-HT1B Diversity</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Docking Score</td>
          <td>CVAE</td>
          <td>-4.647</td>
          <td>0.907</td>
      </tr>
      <tr>
          <td>Docking Score</td>
          <td>GVAE</td>
          <td>-4.955</td>
          <td>0.901</td>
      </tr>
      <tr>
          <td>Docking Score</td>
          <td>REINVENT</td>
          <td>-9.774</td>
          <td>0.506</td>
      </tr>
      <tr>
          <td>Docking Score</td>
          <td>ZINC (10%)</td>
          <td>-9.894</td>
          <td>0.862</td>
      </tr>
      <tr>
          <td>Docking Score</td>
          <td>ZINC (1%)</td>
          <td>-10.496</td>
          <td>0.861</td>
      </tr>
      <tr>
          <td>Docking Score</td>
          <td>Train (10%)</td>
          <td>-10.837</td>
          <td>0.749</td>
      </tr>
  </tbody>
</table>
<p>On the full docking score task, CVAE and GVAE fail to match even the mean ZINC docking score. REINVENT performs substantially better (e.g., -9.774 on 5-HT1B) but in most cases still falls short of the top 10% ZINC scores (-9.894 on 5-HT1B). The exception is ACM2, where REINVENT&rsquo;s score (-9.775) is lower, and therefore better, than the ZINC 10% threshold (-8.282).</p>
<p>On the repulsion task, all three models fail to outperform the top 10% ZINC scores. On the hydrogen bonding task (the easiest), GVAE and REINVENT nearly match the top 1% ZINC scores, suggesting that optimizing individual scoring components is more tractable than the full docking score.</p>
<p>A consistent finding across all experiments is that REINVENT generates substantially less diverse molecules than the training set (e.g., 0.506 vs. 0.787 mean Tanimoto distance on 5-HT1B). The t-SNE visualizations show generated molecules clustering in a single dense region, separate from the training data, regardless of optimization target.</p>
<p>The paper also notes a moderately strong correlation between docking scores and molecular weight or the number of rotatable bonds. Generated compounds achieve better docking scores at the same molecular weight after optimization, suggesting the models learn some structural preferences rather than simply exploiting molecular size.</p>
<h2 id="limitations-of-current-generative-models-for-drug-design">Limitations of Current Generative Models for Drug Design</h2>
<p>The main finding is negative: popular generative models for de novo drug design struggle to generate molecules that dock well when trained on realistically sized datasets (1,000 to 10,000 compounds). Even the best-performing model (REINVENT) generally cannot outperform the top 10% of a random ZINC subset on the full docking score task.</p>
<p>The authors acknowledge several limitations:</p>
<ul>
<li><strong>Docking is itself a proxy</strong>: The SMINA docking score is only an approximation of true binding affinity. The fact that even this simpler proxy is challenging should raise concerns about these models&rsquo; readiness for real drug discovery pipelines.</li>
<li><strong>Limited model selection</strong>: Only three models were tested (CVAE, GVAE, REINVENT). The authors note that CVAE and GVAE were not designed for small training sets, and REINVENT may not represent the state of the art in all respects.</li>
<li><strong>ML-based scoring surrogate</strong>: All models use an ML model (MLP or random forest) to predict docking scores during generation, rather than running SMINA directly. This introduces an additional approximation layer.</li>
<li><strong>No similarity constraints</strong>: The benchmark does not impose constraints on the distance between generated and training molecules. A trivial baseline is to simply return the training set.</li>
</ul>
<p>On a more positive note, the tested models perform well on the simplest subtask (hydrogen bonding), suggesting that optimizing docking scores from limited data is attainable but challenging. The benchmark has already been adopted by other groups, notably Nigam et al. (2021) for evaluating their JANUS genetic algorithm.</p>
<p>Future directions include adding similarity constraints, extending to additional protein targets, and using the benchmark to evaluate newer structure-based generative models that employ equivariant neural networks.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training/Evaluation</td>
          <td>ChEMBL (8 targets)</td>
          <td>1,082-10,225 molecules per target</td>
          <td>90/10 train/test split</td>
      </tr>
      <tr>
          <td>Baseline</td>
          <td>ZINC 15 subset</td>
          <td>~9.2M drug-like molecules</td>
          <td>In-stock, standard reactivity, drug-like</td>
      </tr>
      <tr>
          <td>Protein structures</td>
          <td><a href="https://en.wikipedia.org/wiki/Protein_Data_Bank">Protein Data Bank</a></td>
          <td>8 structures</td>
          <td>Cleaned with the Schrödinger modeling package</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>CVAE/GVAE: Fine-tuned 5 epochs on target data, then 50 gradient steps in latent space to optimize MLP-predicted score</li>
<li>REINVENT: Pretrained on ChEMBL, fine-tuned with RL; reward = random forest prediction * QED score</li>
<li>All docking performed with SMINA v. 2017.11.9 using Vinardo scoring function in score_only mode</li>
<li>Scores averaged over top 5 binding poses</li>
<li>Filtering: Lipinski Rule of Five, minimum molecular weight 100</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Description</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Mean docking score</td>
          <td>Average over 250 generated molecules</td>
          <td>Lower is better for docking score and repulsion</td>
      </tr>
      <tr>
          <td>Diversity</td>
          <td>Mean Tanimoto distance (ECFP, r=2)</td>
          <td>Higher is more diverse</td>
      </tr>
      <tr>
          <td>ZINC percentile baselines</td>
          <td>Top 50%, 10%, 1% from random ZINC subset</td>
          <td>Task considered &ldquo;solved&rdquo; if the generated molecules score better than the ZINC top 1%</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Not specified in the paper.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/cieplinski-tobiasz/smina-docking-benchmark">smina-docking-benchmark</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Benchmark code, data, evaluation notebooks</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Cieplinski, T., Danel, T., Podlewska, S., &amp; Jastrzebski, S. (2023). Generative Models Should at Least Be Able to Design Molecules That Dock Well: A New Benchmark. <em>Journal of Chemical Information and Modeling</em>, 63(11), 3238-3247. <a href="https://doi.org/10.1021/acs.jcim.2c01355">https://doi.org/10.1021/acs.jcim.2c01355</a></p>
<p><strong>Publication</strong>: Journal of Chemical Information and Modeling 2023</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/cieplinski-tobiasz/smina-docking-benchmark">GitHub Repository</a></li>
</ul>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{cieplinski2023generative,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Generative Models Should at Least Be Able to Design Molecules That Dock Well: A New Benchmark}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Cieplinski, Tobiasz and Danel, Tomasz and Podlewska, Sabina and Jastrzebski, Stanislaw}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Journal of Chemical Information and Modeling}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{63}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{11}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{3238--3247}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{American Chemical Society}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1021/acs.jcim.2c01355}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Regression Transformer: Prediction Meets Generation</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/property-prediction/regression-transformer/</link><pubDate>Sun, 22 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/property-prediction/regression-transformer/</guid><description>The Regression Transformer unifies property prediction and conditional generation in one multitask model by casting regression as sequence modelling.</description><content:encoded><![CDATA[<h2 id="a-multitask-model-that-unifies-regression-and-generation">A Multitask Model That Unifies Regression and Generation</h2>
<p>The Regression Transformer (RT) is a <strong>Method</strong> paper. It introduces a single model architecture that can both predict continuous molecular properties and conditionally generate molecules with desired property values. The core idea is to reformulate regression as a sequence modelling task: instead of training a dedicated regression head, continuous property values are tokenized into sequences of digits and predicted alongside molecular tokens using a cross-entropy loss.</p>
<h2 id="closing-the-gap-between-predictors-and-generators">Closing the Gap Between Predictors and Generators</h2>
<p>Existing transformer-based approaches in computational chemistry develop property predictors and generative models as separate systems. Even when a single architecture like <a href="/notes/chemistry/molecular-design/generation/autoregressive/chemformer/">Chemformer</a> (Irwin et al., 2022) addresses both tasks, it does so through task-specific heads. This means the two capabilities remain disjoint, and the generative model cannot use its own property prediction ability during generation.</p>
<p>The RT addresses three specific gaps:</p>
<ol>
<li><strong>No true multitask entanglement</strong>: Prior work either tunes separate heads for prediction and generation or limits communication between modules to a reward signal.</li>
<li><strong>No inductive bias for continuous properties</strong>: Molecular generative models lack mechanisms to condition generation on floating-point property values.</li>
<li><strong>Disconnected workflows</strong>: Property predictors cannot generate molecules, and generators cannot assess whether their outputs satisfy property constraints.</li>
</ol>
<h2 id="core-innovation-regression-as-conditional-sequence-modelling">Core Innovation: Regression as Conditional Sequence Modelling</h2>
<p>The RT&rsquo;s key insight is that regression can be cast as sequential classification over digit tokens while preserving predictive accuracy. This is achieved through three components:</p>
<h3 id="numerical-tokenization">Numerical Tokenization</h3>
<p>Floating-point property values are split into individual digit tokens that preserve decimal order. Each token $t_{v,p}$ encodes a digit value $v \in \{0, 1, \dots, 9\}$ and its decimal place $p \in \mathbb{Z}$. For example, the value 12.3 becomes the token sequence <code>[1_1, 2_0, 3_-1]</code>.</p>
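<p>A minimal sketch of this tokenization follows, using the <code>digit_place</code> notation from the example above. The function name is hypothetical, and the sketch assumes a non-negative decimal string as input; the RT&rsquo;s actual vocabulary and token spelling may differ.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">def tokenize_number(text):
    """Split a decimal string such as "12.3" into digit_place tokens.

    Assumes a non-negative number written without an exponent.
    """
    whole, _, frac = text.partition(".")
    tokens = []
    # Digits left of the decimal point get places n-1, ..., 1, 0.
    for i, digit in enumerate(whole):
        tokens.append(f"{digit}_{len(whole) - 1 - i}")
    # Digits right of the decimal point get places -1, -2, ...
    for i, digit in enumerate(frac, start=1):
        tokens.append(f"{digit}_-{i}")
    return tokens


if __name__ == "__main__":
    print(tokenize_number("12.3"))
</code></pre></div>
<p>Working from the string form rather than a float keeps the digits exactly as written and sidesteps binary floating-point rounding.</p>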
<h3 id="numerical-encodings">Numerical Encodings</h3>
<p>To provide an inductive bias about the semantic proximity of digit tokens (which cross-entropy loss cannot convey), the RT introduces Numerical Encodings (NEs), analogous to positional encodings. For a token $t_{v,p}$ at embedding dimension $j$:</p>
<p>$$
\text{NE}_{\text{Float}}(v, p, j) = (-1)^j \cdot \frac{v \cdot 10^p}{j + 1}
$$</p>
<p>These encodings ensure that the distance between the encodings of two digit tokens grows monotonically with the difference between the floating-point values they represent. The model can also learn digit orderings from data alone, but NEs provide a useful inductive bias.</p>
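<p>The encoding formula can be evaluated directly. The sketch below follows the expression for $\text{NE}_{\text{Float}}$ given above; the function name and the choice of a plain Python list for the embedding vector are assumptions for illustration.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">def numerical_encoding(v, p, dim):
    """Encode digit v at decimal place p as a vector of length dim.

    Entry j is (-1)^j * (v * 10^p) / (j + 1), so signs alternate and
    magnitudes shrink with the embedding index j.
    """
    value = v * 10.0 ** p
    return [((-1) ** j) * value / (j + 1) for j in range(dim)]


if __name__ == "__main__":
    # The token 1_1 stands for the value 10.
    print(numerical_encoding(1, 1, 4))
</code></pre></div>
<p>Every entry is proportional to $v \cdot 10^p$, so tokens that stand for nearby values receive nearby encodings, which is the inductive bias the cross-entropy loss lacks on its own.</p>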
<h3 id="alternating-training-with-self-consistency">Alternating Training with Self-Consistency</h3>
<p>The RT uses an <a href="https://en.wikipedia.org/wiki/XLNet">XLNet</a> backbone trained with permutation language modelling (PLM). The key is that the same model serves two roles depending on which tokens are masked:</p>
<ul>
<li><strong>Mask numerical tokens</strong>: the model performs property prediction (regression)</li>
<li><strong>Mask textual tokens</strong>: the model performs conditional sequence generation</li>
</ul>
<p>The base PLM objective is:</p>
<p>$$
\mathcal{L}_{\text{PLM}} = \mathbb{E}_{\mathbf{z} \sim \mathcal{Z}_T} \left[ \sum_{i=c+1}^{T} \log p_\theta(x_{z_i} \mid \mathbf{x}_{\mathbf{z}_{&lt; i}}) \right]
$$</p>
<p>This is refined into two specialized objectives: a property prediction objective $\mathcal{L}_P$ that masks only numerical tokens, and a generation objective $\mathcal{L}_G$ that masks only textual tokens. Training alternates between these every 50 steps.</p>
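<p>The alternation between the two objectives can be pictured with a small masking sketch. It is a deliberate simplification: it masks every token of the selected group, whereas the actual training procedure may mask only a subset, and which objective comes first is an assumption. The names <code>MASK</code>, <code>is_numeric_token</code>, <code>objective_for_step</code>, and <code>mask_tokens</code> are hypothetical.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">MASK = "[MASK]"


def is_numeric_token(token):
    """Assumed convention: numeric tokens look like "3_-1" (digit_place)."""
    digit, sep, place = token.partition("_")
    return sep == "_" and digit.isdigit() and place.lstrip("-").isdigit()


def objective_for_step(step, period=50):
    """Switch objectives every `period` steps (starting order is assumed)."""
    return "property" if (step // period) % 2 == 0 else "generation"


def mask_tokens(tokens, objective):
    """Mask numeric tokens for prediction, all other tokens for generation."""
    mask_numeric = objective == "property"
    return [MASK if is_numeric_token(t) == mask_numeric else t for t in tokens]


if __name__ == "__main__":
    sequence = ["1_1", "2_0", "3_-1", "|", "C", "C", "O"]
    for step in (0, 50):
        print(step, mask_tokens(sequence, objective_for_step(step)))
</code></pre></div>
<p>The same input sequence therefore yields a regression example or a conditional generation example depending only on which group of tokens is hidden.</p>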
<p>The self-consistency (SC) loss adds a critical feedback loop. After generating a candidate molecule $\hat{\mathbf{x}}$, the model re-evaluates it by predicting the property of the generated sequence:</p>
<p>$$
\mathcal{L}_{\text{SC}} = \mathcal{L}_G(\mathbf{x}) + \alpha \cdot \mathcal{L}_P(\hat{\mathbf{x}})
$$</p>
<p>This rewards generating molecules whose predicted properties match the primed property value, exploiting the RT&rsquo;s dual capability as both predictor and generator.</p>
<h2 id="experiments-across-molecules-proteins-and-reactions">Experiments Across Molecules, Proteins, and Reactions</h2>
<h3 id="drug-likeness-qed">Drug Likeness (QED)</h3>
<p>Initial validation on a synthetic QED dataset (~1.4M molecules from <a href="https://en.wikipedia.org/wiki/ChEMBL">ChEMBL</a>) demonstrated that the RT can simultaneously learn to predict QED scores (RMSE &lt; 0.06) and generate novel molecules conditioned on desired QED values (Spearman&rsquo;s $\rho$ up to 0.517 between primers and generated molecule properties). Novelty exceeded 99% across all configurations. The alternating training scheme with SC loss outperformed both single-task models and the vanilla PLM objective.</p>
<p><a href="/notes/chemistry/molecular-representations/notations/selfies/">SELFIES</a> representations proved comparable to <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> for property prediction and far superior for generation (~100% validity vs. ~40% for SMILES).</p>
<h3 id="moleculenet-regression-benchmarks">MoleculeNet Regression Benchmarks</h3>
<p>On <a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a> benchmarks ESOL, FreeSolv, and Lipophilicity, the RT outperformed XGBoost and MPNN baselines despite using only a classification loss. It performed on par with XLNet using a conventional regression head, and was only mildly inferior to models like BERT and BART that used large-scale self-supervised pre-training with regression losses.</p>
<p>Critically, only the RT could also conditionally generate molecules for these tasks. External validation with Grover (a self-supervised Graph Transformer) confirmed high correlation with the RT&rsquo;s own property predictions (0.86, 0.84, and 0.75 for ESOL, FreeSolv, and Lipophilicity respectively).</p>
<h3 id="constrained-property-optimization">Constrained Property Optimization</h3>
<p>On the penalized logP (plogP) benchmark with similarity constraints, the RT outperformed JT-VAE and GCPN by large margins. At similarity threshold $\delta = 0.4$, the RT achieved 3.16 average improvement with 97.1% success rate, while also predicting plogP with PCC of 0.92. Competing methods cannot perform property prediction at all.</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Improvement ($\delta$=0.4)</th>
          <th>Success</th>
          <th>Property Prediction</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>JT-VAE</td>
          <td>0.84</td>
          <td>83.6%</td>
          <td>Unfeasible</td>
      </tr>
      <tr>
          <td>GCPN</td>
          <td>2.49</td>
          <td>100%</td>
          <td>Unfeasible</td>
      </tr>
      <tr>
          <td>MoFlow</td>
          <td>4.71</td>
          <td>85.7%</td>
          <td>Unfeasible</td>
      </tr>
      <tr>
          <td><strong>RT</strong></td>
          <td><strong>3.16</strong></td>
          <td><strong>97.1%</strong></td>
          <td><strong>PCC = 0.92</strong></td>
      </tr>
  </tbody>
</table>
<p>The comparison is not strictly fair: all competing methods are trained specifically to maximize plogP, and some (GCPN, JT-VAE) apply gradient optimization at inference time. The RT is only trained to reconstruct molecules with similar predicted plogP to the seed, so its training objective is property-agnostic rather than directly optimizing for higher plogP values.</p>
<h3 id="protein-language-modelling">Protein Language Modelling</h3>
<p>On the TAPE benchmark, the RT matched or outperformed conventional transformers on fluorescence and stability prediction tasks, despite those baselines being pre-trained on 24-106 million protein sequences (vs. 2.6 million for the RT). The RT also performed conditional protein generation, a task that none of the TAPE baselines can address.</p>
<h3 id="chemical-reaction-modelling">Chemical Reaction Modelling</h3>
<p>The RT was applied to reaction yield prediction on <a href="https://en.wikipedia.org/wiki/Buchwald%E2%80%93Hartwig_amination">Buchwald-Hartwig amination</a> and <a href="https://en.wikipedia.org/wiki/Suzuki_reaction">Suzuki coupling</a> datasets. It matched Yield-BERT performance ($R^2$ = 0.939 and 0.81 respectively) while also enabling novel capabilities: reconstructing missing precursors from partial reactions and decorating existing reactions to achieve higher predicted yields. Across both datasets, over 40% of top-five predicted sequences contained reactions with novel precursors and higher predicted yield.</p>
<h2 id="key-findings-and-limitations">Key Findings and Limitations</h2>
<h3 id="key-findings">Key Findings</h3>
<ol>
<li>Regression can be successfully reformulated as sequential classification over digit tokens without losing predictive accuracy compared to models using regression losses.</li>
<li>The alternating training scheme with self-consistency loss enables cross-task benefits, where the model outperforms single-task variants at both prediction and generation.</li>
<li>A single ~27M parameter model handles property prediction, conditional molecular generation, conditional protein generation, and reaction yield prediction with precursor generation.</li>
<li>The model learns the natural ordering of digits from data: 47% of embedding dimensions for the tenths place directly encode digit ordering even without explicit numerical encodings.</li>
</ol>
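<p>The first finding, regression as classification over digit tokens, can be illustrated with a minimal sketch. It assumes each digit is paired with its decimal place in a token such as <code>_3_-1_</code>; the helper names and the exact token spelling are illustrative assumptions, not necessarily the RT tokenizer&rsquo;s actual vocabulary.</p>
<pre tabindex="0"><code class="language-python" data-lang="python">def number_to_tokens(value, decimals=3):
    """Split a non-negative number into (digit, decimal place) tokens."""
    text = f"{value:.{decimals}f}"
    integer_part, fraction_part = text.split(".")
    tokens = []
    # Integer digits: the leftmost digit has the highest place value.
    for i, digit in enumerate(integer_part):
        place = len(integer_part) - 1 - i
        tokens.append(f"_{digit}_{place}_")
    tokens.append("_._")
    # Fractional digits: tenths are place -1, hundredths -2, and so on.
    for i, digit in enumerate(fraction_part, start=1):
        tokens.append(f"_{digit}_-{i}_")
    return tokens


def tokens_to_number(tokens):
    """Decode digit tokens back into a float by summing digit * 10**place."""
    total = 0.0
    for token in tokens:
        if token == "_._":
            continue
        digit, place = token.strip("_").split("_")
        total += int(digit) * 10 ** int(place)
    return total


print(number_to_tokens(12.3, decimals=1))
</code></pre>
<p>Under this scheme, predicting a property value means classifying one digit token per position, so a standard cross-entropy loss can stand in for a regression loss.</p>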
<h3 id="limitations">Limitations</h3>
<ol>
<li><strong>No large-scale pre-training</strong>: The RT uses ~27M parameters trained from scratch on task-specific datasets, unlike <a href="/notes/chemistry/molecular-representations/encoders/bartsmiles-molecular-representations/">BARTSmiles</a> or MoLFormer, which pre-train on billions of molecules. Scaling up could improve results.</li>
<li><strong>Fine-grained regression precision</strong>: The model sometimes struggles with intra-mode precision (e.g., on the fluorescence dataset where predictions cluster around bright/dark modes rather than capturing continuous variation).</li>
<li><strong>Single-property focus</strong>: All reported experiments use a single continuous property, though the framework naturally extends to multi-property settings.</li>
<li><strong>SELFIES validity caveats</strong>: While SELFIES strings are always syntactically valid, they can decode to degenerate short molecules (~1.9% defective generations, where the output has fewer than 50% of the seed&rsquo;s atoms).</li>
<li><strong>XLNet backbone limitations</strong>: Results on MoleculeNet regression are slightly below models using BART or BERT backbones with large-scale pre-training, suggesting the RT framework could benefit from stronger base models.</li>
</ol>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/IBM/regression-transformer">Regression Transformer (GitHub)</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Training and evaluation scripts</td>
      </tr>
      <tr>
          <td><a href="https://github.com/GT4SD/gt4sd-core">GT4SD Integration</a></td>
          <td>Code + Models</td>
          <td>MIT</td>
          <td>Pre-trained model inference pipelines</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/spaces/GT4SD/regression_transformer">HuggingFace Demo</a></td>
          <td>Demo</td>
          <td>-</td>
          <td>Interactive inference webapp</td>
      </tr>
  </tbody>
</table>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Drug likeness</td>
          <td>ChEMBL (QED)</td>
          <td>~1.4M molecules</td>
          <td>Synthetic QED labels computed with RDKit</td>
      </tr>
      <tr>
          <td>Regression benchmark</td>
          <td>MoleculeNet (ESOL, FreeSolv, Lipo)</td>
          <td>642-4,200 compounds</td>
          <td>16x SMILES augmentation, 3 random splits</td>
      </tr>
      <tr>
          <td>Property optimization</td>
          <td>ZINC (plogP)</td>
          <td>215,381 train / 799 test</td>
          <td>Fixed split from Jin et al. (2018)</td>
      </tr>
      <tr>
          <td>Protein pre-training</td>
          <td><a href="https://en.wikipedia.org/wiki/UniProt">UniProt</a> (Boman)</td>
          <td>2,648,205 peptides</td>
          <td>15-45 amino acid peptides</td>
      </tr>
      <tr>
          <td>Protein benchmarks</td>
          <td>TAPE (Fluorescence, Stability)</td>
          <td>21,446-53,416 samples</td>
          <td>Fixed splits</td>
      </tr>
      <tr>
          <td>Reaction pre-training</td>
          <td>USPTO</td>
          <td>2,830,616 reactions</td>
          <td>Molecular weight as numerical property</td>
      </tr>
      <tr>
          <td>Reaction yield</td>
          <td>Buchwald-Hartwig / Suzuki</td>
          <td>3,955 / 5,760 reactions</td>
          <td>Ten 70/30 random splits</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Architecture: XLNet (32 hidden layers, 256 hidden dim, 1024 FFN dim, 16 attention heads, 20% dropout)</li>
<li>Parameters: ~27 million</li>
<li>Training: Permutation language modelling pre-training, then alternating objectives (property prediction + conditional generation with SC loss)</li>
<li>Decoding: Greedy for property prediction, beam search for sequence generation</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Task</th>
          <th>Metric</th>
          <th>RT Result</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>QED prediction</td>
          <td>RMSE</td>
          <td>0.037</td>
          <td>Best config (NE + SC)</td>
      </tr>
      <tr>
          <td>QED generation</td>
          <td>Spearman&rsquo;s $\rho$</td>
          <td>0.517</td>
          <td>Between primers and generated QED</td>
      </tr>
      <tr>
          <td>ESOL</td>
          <td>RMSE</td>
          <td>Comparable to XLNet</td>
          <td>Within s.d. of regression-loss XLNet</td>
      </tr>
      <tr>
          <td>plogP optimization ($\delta$=0.4)</td>
          <td>Improvement</td>
          <td>3.16</td>
          <td>Outperforms JT-VAE, GCPN</td>
      </tr>
      <tr>
          <td>Protein fluorescence</td>
          <td>Spearman&rsquo;s $\rho$</td>
          <td>0.72</td>
          <td>Outperforms TAPE baselines</td>
      </tr>
      <tr>
          <td>BH yield prediction</td>
          <td>$R^2$</td>
          <td>0.939</td>
          <td>Near Yield-BERT (0.951)</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li>All models trained on single GPUs (NVIDIA A100 or V100)</li>
<li>Training time: ~4 days for pre-training, ~1 day for fine-tuning</li>
<li>Framework: PyTorch 1.3.1 with HuggingFace Transformers 3.1.0</li>
</ul>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Born, J. &amp; Manica, M. (2023). Regression Transformer enables concurrent sequence regression and generation for molecular language modelling. <em>Nature Machine Intelligence</em>, 5(4), 432-444. <a href="https://doi.org/10.1038/s42256-023-00639-z">https://doi.org/10.1038/s42256-023-00639-z</a></p>
<p><strong>Publication</strong>: Nature Machine Intelligence, April 2023</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/IBM/regression-transformer">Regression Transformer GitHub Repository</a></li>
<li><a href="https://github.com/GT4SD/gt4sd-core/tree/main/examples/regression_transformer">GT4SD Integration</a></li>
<li><a href="https://huggingface.co/spaces/GT4SD/regression_transformer">HuggingFace Demo</a></li>
</ul>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{born2023regression,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Regression Transformer enables concurrent sequence regression and generation for molecular language modelling}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Born, Jannis and Manica, Matteo}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Nature Machine Intelligence}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{5}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{4}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{432--444}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Nature Publishing Group}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>BARTSmiles: BART Pre-Training for Molecular SMILES</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/bartsmiles-molecular-representations/</link><pubDate>Sun, 22 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/bartsmiles-molecular-representations/</guid><description>BARTSmiles applies BART-style denoising pre-training to 1.7B SMILES from ZINC20, achieving top results on 11 molecular property and reaction tasks.</description><content:encoded><![CDATA[<h2 id="a-bart-based-method-for-molecular-self-supervised-learning">A BART-Based Method for Molecular Self-Supervised Learning</h2>
<p>BARTSmiles is a <strong>Method</strong> paper. It introduces a self-supervised pre-training approach for molecular representations based on the BART (Bidirectional and Auto-Regressive Transformers) architecture from Lewis et al. (2019). The primary contribution is a pre-training strategy, discovered through systematic ablations, that trains a BART-large model on 1.7 billion deduplicated <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES strings</a> from the <a href="/notes/chemistry/datasets/zinc-22/">ZINC20 dataset</a>. BARTSmiles achieves the best reported results on 11 tasks spanning molecular property classification, regression, and chemical reaction generation.</p>
<h2 id="scaling-self-supervised-molecular-representations-beyond-prior-work">Scaling Self-Supervised Molecular Representations Beyond Prior Work</h2>
<p>At the time of publication, large-scale self-supervised representation learning had produced significant improvements in NLP, computer vision, and speech, but molecular representation learning had not benefited from comparable scale. Previous SMILES-based pre-trained models such as <a href="/notes/chemistry/molecular-representations/encoders/chemberta/">ChemBERTa</a> (Chithrananda et al., 2020) and <a href="/notes/chemistry/molecular-design/generation/autoregressive/chemformer/">ChemFormer</a> (Irwin et al., 2022) used encoder-only or encoder-decoder architectures with substantially less compute. ChemFormer, the most closely related prior work, also trained a BART-like model but with a fraction of the compute and data.</p>
<p>The paper argues that three gaps needed to be addressed:</p>
<ol>
<li><strong>Scale</strong>: Prior molecular pre-training used orders of magnitude less compute than NLP pre-training.</li>
<li><strong>Architecture choice</strong>: Encoder-only models like ChemBERTa cannot perform generative fine-tuning (retrosynthesis, reaction prediction), limiting their applicability.</li>
<li><strong>Pre-training recipe</strong>: Standard BART hyperparameters (e.g., 30% mask token budget) were tuned for natural language and had not been validated for molecular SMILES strings.</li>
</ol>
<h2 id="core-innovation-ablation-driven-pre-training-recipe-for-smiles">Core Innovation: Ablation-Driven Pre-Training Recipe for SMILES</h2>
<p>The key insight of BARTSmiles is that the BART denoising objective, when carefully tuned for the molecular domain, learns representations that implicitly encode downstream task information. The authors discover this through a systematic three-stage ablation:</p>
<h3 id="tokenization">Tokenization</h3>
<p>Rather than using hand-crafted tokenization rules that separate individual atoms (C, N, H) and bond symbols (#, =), BARTSmiles uses a learned SentencePiece unigram tokenizer trained on 10 million random SMILES with a vocabulary size of 1,021. On matched compute budgets, learned tokenization achieves 0.801 average AUC-ROC vs. 0.779 for hand-crafted tokenization on the ablation benchmark (HIV, BBBP, ClinTox).</p>
<h3 id="masking-strategy">Masking Strategy</h3>
<p>The BART denoising objective has three main hyperparameters: the mask token budget (fraction of tokens masked), random mask probability, and the Poisson $\lambda$ controlling mask span length. The ablation results show:</p>
<ul>
<li><strong>Mask token budget</strong>: The standard BART value of 0.30 is suboptimal for molecules. A budget of 0.20 performs best (0.821 AUC-ROC), with performance degrading at both lower (0.10: 0.753) and higher (0.40: 0.701) budgets.</li>
<li><strong>Span masking</strong>: The choice of random mask probability and $\lambda$ has a minor effect once the budget is set to 0.20. With a random mask probability of 0.10, both $\lambda$ = 2.5 and $\lambda$ = 3.5 yield 0.821.</li>
<li><strong>Token randomization</strong>: Disabling the randomize-tokens noise (where some tokens are replaced with random tokens rather than masked) improves performance from 0.821 to 0.835.</li>
</ul>
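<p>The interaction between the mask token budget and the Poisson span length can be sketched as follows. This is a simplified illustration under stated assumptions, not the FairSeq implementation: spans of length zero are bumped to one, overlapping spans are rejected, and <code>[MASK]</code>, <code>sample_poisson</code>, and <code>infill_mask</code> are hypothetical names.</p>
<pre tabindex="0"><code class="language-python" data-lang="python">import math
import random

MASK = "[MASK]"  # placeholder symbol, not the model's actual mask token


def sample_poisson(lam, rng):
    """Draw one Poisson(lam) sample with Knuth's multiplication method."""
    threshold = math.exp(-lam)
    count, product = 0, rng.random()
    while product > threshold:
        count += 1
        product *= rng.random()
    return count


def infill_mask(tokens, budget=0.20, lam=2.5, seed=0):
    """Hide random spans until about `budget` of the tokens are covered."""
    rng = random.Random(seed)
    target = math.ceil(budget * len(tokens))
    hidden = [False] * len(tokens)
    attempts = 0
    while target > sum(hidden) and 1000 > attempts:
        attempts += 1
        # Span length follows Poisson(lam), clipped to the remaining budget.
        span = max(1, sample_poisson(lam, rng))
        span = min(span, target - sum(hidden))
        start = rng.randrange(0, len(tokens) - span + 1)
        if any(hidden[start:start + span]):
            continue  # reject spans that overlap an earlier one
        for i in range(start, start + span):
            hidden[i] = True
    # Each contiguous hidden run collapses into a single mask token.
    out = []
    for token, is_hidden in zip(tokens, hidden):
        if not is_hidden:
            out.append(token)
        elif not out or out[-1] != MASK:
            out.append(MASK)
    return out, sum(hidden)


corrupted, n_hidden = infill_mask(list("CC(=O)Oc1ccccc1C(=O)O"))
print(n_hidden, "".join(corrupted))
</code></pre>
<p>Because a whole span is replaced by one mask token, the model must also infer how many tokens are missing, which is what distinguishes this objective from per-token masking.</p>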
<h3 id="scale">Scale</h3>
<p>Training on the full 1.7 billion molecule ZINC20 dataset (20 hours on 1,024 A100 GPUs, totaling 20,480 A100 GPU-hours) improves performance by 5 absolute AUC-ROC points over the same model trained on 100 million samples. The previous most compute-intensive molecular pre-training used 3,330 V100-hours (Ross et al., 2021).</p>
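<p>As a quick check of the compute figures above: $1{,}024 \text{ GPUs} \times 20 \text{ h} = 20{,}480$ A100 GPU-hours, and $20{,}480 / 3{,}330 \approx 6.2$, so the run used roughly six times the raw GPU-hours of the earlier effort. This ratio ignores the per-device throughput difference between A100 and V100 hardware, so it likely understates the gap in actual compute.</p>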
<h3 id="implicit-task-encoding">Implicit Task Encoding</h3>
<p>The paper provides a quantitative demonstration that frozen BARTSmiles representations encode task-specific information. Using L1-regularized logistic regression on frozen 1,024-dimensional mean-pooled representations, just 7 neurons are sufficient to achieve 0.987 AUC-ROC on ClinTox (within 2 percentage points of full fine-tuning). Even a single neuron achieves 0.77 AUC-ROC on ClinTox subtask 1.</p>
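<p>The sparse probe described above can be written as a standard L1-regularized logistic regression. The notation here is an assumption for illustration: $h_i \in \mathbb{R}^{1024}$ is the frozen mean-pooled representation of molecule $i$, $y_i \in \{-1, +1\}$ its label, and $\alpha$ the regularization strength.</p>
<p>$$\min_{w, b} \; \sum_{i} \log\left(1 + \exp\left(-y_i (w^\top h_i + b)\right)\right) + \alpha \lVert w \rVert_1$$</p>
<p>The $\ell_1$ penalty drives most entries of $w$ to exactly zero, so the number of nonzero weights is the number of neurons the probe relies on (7 in the ClinTox result above).</p>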
<h2 id="experimental-setup-moleculenet-toxicology-and-generative-benchmarks">Experimental Setup: MoleculeNet, Toxicology, and Generative Benchmarks</h2>
<h3 id="classification-tasks">Classification Tasks</h3>
<p>BARTSmiles is evaluated on 7 classification datasets from <a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a> (SIDER, ClinTox, Tox21, ToxCast, HIV, BACE, BBBP) plus 2 toxicology datasets (<a href="https://en.wikipedia.org/wiki/Ames_test">Ames</a>, <a href="https://en.wikipedia.org/wiki/Micronucleus_test">Micronucleus Assay</a>). All classification tasks use AUC-ROC. Baselines include both supervised graph models (D-MPNN, Attentive FP, 3D InfoMax) and self-supervised methods (ChemBERTa, <a href="/notes/chemistry/molecular-representations/encoders/molformer/">MoLFormer-XL</a>, GROVER-large, MolCLR, iMolCLR).</p>
<p>Selected classification results (AUC-ROC):</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>BARTSmiles</th>
          <th>Previous Best</th>
          <th>Previous Best Model</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ClinTox</td>
          <td><strong>0.997</strong></td>
          <td>0.954</td>
          <td>iMolCLR</td>
      </tr>
      <tr>
          <td>ToxCast</td>
          <td><strong>0.825</strong></td>
          <td>0.805</td>
          <td>Attentive FP</td>
      </tr>
      <tr>
          <td>SIDER</td>
          <td><strong>0.705</strong></td>
          <td>0.699</td>
          <td>iMolCLR</td>
      </tr>
      <tr>
          <td>Tox21</td>
          <td>0.851</td>
          <td>0.858</td>
          <td>Attentive FP</td>
      </tr>
  </tbody>
</table>
<p>The authors note that three scaffold-split datasets (HIV, BACE, BBBP) are highly sensitive to the specific split used, and they suspect some baseline results use different or random splits. These results are marked with caveats in the paper.</p>
<h3 id="regression-tasks">Regression Tasks</h3>
<p>All three MoleculeNet regression tasks (ESOL, FreeSolv, Lipophilicity) are evaluated using RMSE:</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>BARTSmiles</th>
          <th>Previous Best</th>
          <th>Previous Best Model</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ESOL</td>
          <td><strong>0.095</strong></td>
          <td>0.279</td>
          <td>MoLFormer-XL</td>
      </tr>
      <tr>
          <td>FreeSolv</td>
          <td><strong>0.114</strong></td>
          <td>0.231</td>
          <td>MoLFormer-XL</td>
      </tr>
      <tr>
          <td>Lipophilicity</td>
          <td><strong>0.292</strong></td>
          <td>0.529</td>
          <td>MoLFormer-XL</td>
      </tr>
  </tbody>
</table>
<p>BARTSmiles lowers RMSE on all three regression tasks, by roughly 45-66% relative to MoLFormer-XL (for example, 0.095 vs. 0.279 on ESOL).</p>
<h3 id="generative-tasks">Generative Tasks</h3>
<p><strong><a href="https://en.wikipedia.org/wiki/Retrosynthetic_analysis">Retrosynthesis</a></strong> (USPTO-50k): BARTSmiles achieves 55.6% Top-1 accuracy using a sample-128 + perplexity re-ranking strategy, compared to 55.3% for Dual-TF and 54.3% for ChemFormer. Top-5 and Top-10 results are 74.2% and 80.9% respectively.</p>
<p><strong>Chemical Reaction Prediction</strong> (USPTO MIT/LEF/STEREO): BARTSmiles with beam search outperforms the <a href="/notes/chemistry/molecular-design/reaction-prediction/molecular-transformer/">Molecular Transformer</a> baseline across all six evaluation settings. On USPTO-MIT (split), BARTSmiles achieves 91.8% vs. 90.4% for the Transformer baseline.</p>
<h3 id="fine-tuning-recipe">Fine-Tuning Recipe</h3>
<p>The fine-tuning approach is designed to minimize hyperparameter tuning:</p>
<ul>
<li>Batch size 16, 10 epochs, polynomial decay learning rate schedule with warmup over the first 16% of training</li>
<li>Grid search over dropout (0.1, 0.2, 0.3) and learning rate ($5 \times 10^{-6}$, $1 \times 10^{-5}$, $3 \times 10^{-5}$)</li>
<li>Stochastic Weight Averaging (SWA) over three sets of four checkpoints</li>
<li>For generative tasks: R3F regularization (Aghajanyan et al., 2020a) and full fp32 precision</li>
<li>For generation: beam search (beam size 10) or sample 128 sequences with perplexity re-ranking</li>
</ul>
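<p>The search space and schedule in this recipe are small enough to sketch directly. The snippet below is a minimal illustration under assumptions: the decay power (1.0, i.e. linear) and the final learning rate (0) are not stated above, and <code>lr_at</code> is a hypothetical helper rather than a FairSeq function.</p>
<pre tabindex="0"><code class="language-python" data-lang="python">from itertools import product

DROPOUTS = (0.1, 0.2, 0.3)
LEARNING_RATES = (5e-6, 1e-5, 3e-5)


def lr_at(step, total_steps, peak_lr, warmup_frac=0.16, power=1.0, end_lr=0.0):
    """Linear warmup to peak_lr, then polynomial decay towards end_lr."""
    warmup_steps = int(warmup_frac * total_steps)
    if warmup_steps > step:
        return peak_lr * (step + 1) / warmup_steps
    # Fraction of the decay phase that is still ahead (1.0 down to 0.0).
    remaining = (total_steps - step) / max(1, total_steps - warmup_steps)
    return end_lr + (peak_lr - end_lr) * remaining ** power


# Every (dropout, learning rate) pair in the grid: 3 x 3 = 9 runs per task.
grid = list(product(DROPOUTS, LEARNING_RATES))
print(len(grid))
</code></pre>
<p>With only nine configurations per task, the grid stays cheap relative to pre-training, which is consistent with the stated goal of minimal tuning.</p>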
<h2 id="key-findings-and-limitations">Key Findings and Limitations</h2>
<h3 id="key-findings">Key Findings</h3>
<ol>
<li><strong>Scale matters for molecular pre-training</strong>: Training on 1.7B molecules with 20,480 A100 GPU-hours yields 5 absolute points of AUC-ROC improvement over training on 100M molecules.</li>
<li><strong>Domain-specific ablation is necessary</strong>: The optimal BART masking configuration for molecules (20% budget, no token randomization) differs from the standard NLP configuration (30% budget, with randomization).</li>
<li><strong>Frozen representations capture task structure</strong>: A small number of neurons from the frozen model can nearly match full fine-tuning performance on certain tasks, suggesting the pre-training objective implicitly encodes molecular properties.</li>
<li><strong>Interpretability aligns with domain knowledge</strong>: Integrated Gradients attribution on fine-tuned BARTSmiles highlights known structural alerts (e.g., <a href="https://en.wikipedia.org/wiki/Nitro_compound">nitro groups</a> in mutagenic compounds, hydroxyl groups in soluble compounds).</li>
</ol>
<h3 id="limitations">Limitations</h3>
<ul>
<li><strong>Scaffold split sensitivity</strong>: Results on HIV, BACE, and BBBP are sensitive to the specific scaffold split, making direct comparison with baselines difficult.</li>
<li><strong>Pre-training data distribution</strong>: The <a href="https://en.wikipedia.org/wiki/Fr%C3%A9chet_distance">Fréchet distance</a> analysis shows that some downstream datasets (BBBP, SIDER) are far from ZINC20 in representation space, which may explain weaker performance on those tasks.</li>
<li><strong>Fingerprints carry complementary information</strong>: On the Ames and Micronucleus Assay datasets, BARTSmiles alone does not beat fingerprint-based baselines. Combining BARTSmiles with ECFP4 fingerprints closes the gap, implying that SMILES-based pre-training does not fully capture all structural information.</li>
<li><strong>Compute requirements</strong>: Pre-training requires 1,024 A100 GPUs, which limits accessibility.</li>
</ul>
<h3 id="future-directions">Future Directions</h3>
<p>The authors suggest investigating the impact of pre-training data composition, noting that ZINC20 contains over a billion molecules but its distribution may be irrelevant for many downstream tasks. They also propose further collaboration between ML and chemistry experts to discover new molecular substructure-property relationships.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/YerevaNN/BARTSmiles">BARTSmiles (GitHub)</a></td>
          <td>Code + Model</td>
          <td>MIT</td>
          <td>Pre-training, fine-tuning, and evaluation scripts with pre-trained weights</td>
      </tr>
  </tbody>
</table>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pre-training</td>
          <td>ZINC20 (deduplicated)</td>
          <td>~1.7B molecules</td>
          <td>Canonicalized SMILES, 10K validation holdout</td>
      </tr>
      <tr>
          <td>Classification</td>
          <td>MoleculeNet (7 datasets)</td>
          <td>1,427-41,127 compounds</td>
          <td>AUC-ROC metric</td>
      </tr>
      <tr>
          <td>Regression</td>
          <td>MoleculeNet (3 datasets)</td>
          <td>642-4,200 compounds</td>
          <td>RMSE metric</td>
      </tr>
      <tr>
          <td>Toxicology</td>
          <td>Ames, MN Assay</td>
          <td>6,512 / 641 compounds</td>
          <td>Cross-validation for Ames; external test for MN</td>
      </tr>
      <tr>
          <td>Retrosynthesis</td>
          <td>USPTO-50k</td>
          <td>Standard split</td>
          <td>Top-K accuracy</td>
      </tr>
      <tr>
          <td>Reaction prediction</td>
          <td>USPTO (MIT/LEF/STEREO)</td>
          <td>Standard splits</td>
          <td>Top-1 accuracy</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Architecture: BART-Large (pre-layer norm Transformer encoder-decoder)</li>
<li>Tokenizer: SentencePiece unigram, vocabulary size 1,021, max sequence length 128</li>
<li>Pre-training objective: BART denoising (mask token budget 0.20, Poisson span masking with $\lambda$ = 2.5, no token randomization)</li>
<li>Fine-tuning: polynomial decay LR, SWA, grid search over dropout and LR</li>
<li>Generative fine-tuning: R3F regularization, fp32 precision, Adam initialized from pre-training moving averages</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li>BART-Large architecture (exact parameter count not specified in paper)</li>
<li>Pre-trained checkpoint released on GitHub</li>
<li>Maximum sequence length: 128 tokens</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Task</th>
          <th>Metric</th>
          <th>BARTSmiles</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ClinTox</td>
          <td>AUC-ROC</td>
          <td>0.997</td>
          <td>New SOTA</td>
      </tr>
      <tr>
          <td>ToxCast</td>
          <td>AUC-ROC</td>
          <td>0.825</td>
          <td>New SOTA</td>
      </tr>
      <tr>
          <td>ESOL</td>
          <td>RMSE</td>
          <td>0.095</td>
          <td>New SOTA</td>
      </tr>
      <tr>
          <td>FreeSolv</td>
          <td>RMSE</td>
          <td>0.114</td>
          <td>New SOTA</td>
      </tr>
      <tr>
          <td>Lipophilicity</td>
          <td>RMSE</td>
          <td>0.292</td>
          <td>New SOTA</td>
      </tr>
      <tr>
          <td>USPTO-50k Retro (Top-1)</td>
          <td>Accuracy</td>
          <td>55.6%</td>
          <td>New SOTA (sample + re-rank)</td>
      </tr>
      <tr>
          <td>USPTO-MIT Rxn (Split)</td>
          <td>Accuracy</td>
          <td>91.8%</td>
          <td>New SOTA (beam-10)</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li>Pre-training: 1,024 NVIDIA A100 GPUs for 20 hours (20,480 A100 GPU-hours)</li>
<li>Ablation runs: 128 A100 GPUs per run</li>
<li>Framework: FairSeq with FairScale (fully sharded data parallel), automatic mixed precision</li>
<li>Experiment tracking: Aim</li>
</ul>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Chilingaryan, G., Tamoyan, H., Tevosyan, A., Babayan, N., Khondkaryan, L., Hambardzumyan, K., Navoyan, Z., Khachatrian, H., &amp; Aghajanyan, A. (2024). BARTSmiles: Generative Masked Language Models for Molecular Representations. <em>Journal of Chemical Information and Modeling</em>, 64(15), 5832-5843. <a href="https://doi.org/10.1021/acs.jcim.4c00512">https://doi.org/10.1021/acs.jcim.4c00512</a></p>
<p><strong>Publication</strong>: Journal of Chemical Information and Modeling, 2024 (preprint: arXiv 2022)</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/YerevaNN/BARTSmiles">BARTSmiles GitHub Repository (MIT License)</a></li>
</ul>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{chilingaryan2024bartsmiles,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{BARTSmiles: Generative Masked Language Models for Molecular Representations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Chilingaryan, Gayane and Tamoyan, Hovhannes and Tevosyan, Ani and Babayan, Nelly and Khondkaryan, Lusine and Hambardzumyan, Karen and Navoyan, Zaven and Khachatrian, Hrant and Aghajanyan, Armen}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Journal of Chemical Information and Modeling}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{64}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{15}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{5832--5843}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1021/acs.jcim.4c00512}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2024}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Molecular Transformer: Calibrated Reaction Prediction</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/reaction-prediction/molecular-transformer/</link><pubDate>Wed, 18 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/reaction-prediction/molecular-transformer/</guid><description>A Transformer seq2seq model for chemical reaction prediction achieving 90.4% top-1 accuracy on USPTO_MIT with calibrated uncertainty estimation.</description><content:encoded><![CDATA[<h2 id="paper-contribution-and-methodological-classification">Paper Contribution and Methodological Classification</h2>
<p>This is a <strong>Method</strong> paper. It adapts the Transformer architecture to chemical reaction prediction, treating it as a machine translation problem from reactant <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> to product SMILES. The key contributions are (1) demonstrating that a fully attention-based model outperforms all prior template-based, graph-based, and RNN-based methods, (2) showing the model works without separating reactants from reagents, and (3) introducing calibrated uncertainty estimation for ranking synthesis pathways.</p>
<h2 id="motivation-limitations-of-existing-reaction-prediction">Motivation: Limitations of Existing Reaction Prediction</h2>
<p>Prior approaches to reaction prediction fell into two broad groups, template-based and template-free, each with fundamental limitations:</p>
<ul>
<li><strong>Template-based methods</strong> rely on libraries of reaction rules, either handcrafted or automatically extracted from atom-mapped data. Automatic template extraction itself depends on atom mapping, which depends on templates, creating a circular dependency.</li>
<li><strong>Graph-based template-free methods</strong> (e.g., WLDN, ELECTRO) avoid explicit templates but still require atom-mapped training data and cannot handle stereochemistry.</li>
<li><strong><a href="/notes/chemistry/molecular-design/reaction-prediction/nmt-organic-reaction-prediction/">RNN-based seq2seq models</a></strong> (also template-free) treat reactions as SMILES translation but impose a positional inductive bias: tokens far apart in the SMILES string are assumed to be less related. This assumption fails for SMILES, where distance along the string has no relationship to 3D spatial distance.</li>
</ul>
<h2 id="core-innovation-transformer-for-reaction-prediction">Core Innovation: Transformer for Reaction Prediction</h2>
<p>The Molecular Transformer adapts the Transformer architecture to chemical reactions by treating SMILES strings of reactants and reagents as source sequences and product SMILES as target sequences.</p>
<ul>
<li><strong>Architecture</strong>: Encoder-decoder Transformer with 4 layers, 256-dimensional hidden states, 8 attention heads, and 12M parameters (reduced from the original 65M NMT model).</li>
<li><strong>Tokenization</strong>: Atom-wise regex tokenization of SMILES strings, applied uniformly to both reactants and reagents (no special reagent tokens).</li>
<li><strong>Data augmentation</strong>: Training data is doubled by generating <a href="/notes/chemistry/molecular-representations/notations/randomized-smiles-generative-models/">random (non-canonical) SMILES</a> for each reaction, which improves top-1 accuracy by roughly 1%.</li>
<li><strong>Weight averaging</strong>: Final model weights are averaged over the last 20 checkpoints, providing a further accuracy boost without the inference cost of ensembling.</li>
<li><strong>Mixed input</strong>: Unlike all prior work that separates reactants from reagents (which implicitly assumes knowledge of the product), the Molecular Transformer operates on mixed inputs where no distinction is made.</li>
</ul>
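<p>Atom-wise regex tokenization can be sketched in a few lines. The pattern below is a simplified assumption written for illustration, not the exact expression used by the authors: it keeps bracket atoms, two-letter halogens, and two-digit ring closures as single tokens and splits everything else into single characters.</p>
<pre tabindex="0"><code class="language-python" data-lang="python">import re

# Simplified atom-wise pattern (an assumption, not the paper's exact regex).
SMILES_PATTERN = re.compile(
    r"\[[^\]]+\]|Br|Cl|%[0-9]{2}|[A-Za-z]|[0-9]|[=#\-+\\/:~@?>*$().]"
)


def tokenize(smiles):
    """Split a SMILES string into atom-level tokens."""
    tokens = SMILES_PATTERN.findall(smiles)
    # Guard against silently dropping characters the pattern does not cover.
    if "".join(tokens) != smiles:
        raise ValueError(f"untokenizable characters in {smiles!r}")
    return tokens


# Mixed input: reactants and reagents are joined without distinction.
source = ".".join(["CC(=O)Cl", "OCC", "[Na+]"])
print(" ".join(tokenize(source)))
</code></pre>
<p>The same tokenizer is applied to every molecule on the source side, which is what allows the mixed setting to work without special reagent tokens.</p>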
<p>The multihead attention mechanism is the key architectural advantage over RNNs. It allows the model to attend to any pair of tokens regardless of their position in the SMILES string, correctly capturing long-range chemical relationships that RNNs miss.</p>
<h2 id="uncertainty-estimation">Uncertainty Estimation</h2>
<p>A central contribution is calibrated uncertainty scoring. The product of predicted token probabilities serves as a confidence score for each prediction. This score achieves 0.89 AUC-ROC for classifying whether a prediction is correct.</p>
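<p>Written out, with notation assumed for illustration ($x$ the source SMILES and $y_1, \dots, y_T$ the predicted product tokens), the confidence score is:</p>
<p>$$\text{conf}(y \mid x) = \prod_{t=1}^{T} p(y_t \mid y_1, \dots, y_{t-1}, x)$$</p>
<p>For example, a three-token prediction with token probabilities 0.9, 0.8, and 0.95 would receive a score of $0.9 \times 0.8 \times 0.95 = 0.684$.</p>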
<p>An important finding: <strong>label smoothing hurts uncertainty calibration</strong>. While label smoothing (as used in the original Transformer) marginally improves top-1 accuracy (87.44% vs 87.28%), it destroys the model&rsquo;s ability to distinguish correct from incorrect predictions. Setting the label smoothing parameter to 0.0 preserves calibration.</p>
<p>The confidence score is essentially uncorrelated with SMILES length (Pearson $r = 0.06$), indicating that it is not biased against predictions of larger molecules.</p>
<h2 id="experimental-results">Experimental Results</h2>
<h3 id="forward-synthesis-prediction">Forward Synthesis Prediction</h3>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Dataset</th>
          <th style="text-align: left">Setting</th>
          <th style="text-align: left">Top-1 (%)</th>
          <th style="text-align: left">Top-2 (%)</th>
          <th style="text-align: left">Top-5 (%)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left">USPTO_MIT</td>
          <td style="text-align: left">separated</td>
          <td style="text-align: left">90.4</td>
          <td style="text-align: left">93.7</td>
          <td style="text-align: left">95.3</td>
      </tr>
      <tr>
          <td style="text-align: left">USPTO_MIT</td>
          <td style="text-align: left">mixed</td>
          <td style="text-align: left">88.6</td>
          <td style="text-align: left">92.4</td>
          <td style="text-align: left">94.2</td>
      </tr>
      <tr>
          <td style="text-align: left">USPTO_STEREO</td>
          <td style="text-align: left">separated</td>
          <td style="text-align: left">78.1</td>
          <td style="text-align: left">84.0</td>
          <td style="text-align: left">87.1</td>
      </tr>
      <tr>
          <td style="text-align: left">USPTO_STEREO</td>
          <td style="text-align: left">mixed</td>
          <td style="text-align: left">76.2</td>
          <td style="text-align: left">82.4</td>
          <td style="text-align: left">85.8</td>
      </tr>
  </tbody>
</table>
<p>The mixed-input model (88.6%) outperforms all prior methods that used separated inputs (best previous: WLDN5 at 85.6%).</p>
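<p>Top-k accuracy here means that the recorded product appears among the first k beam-search candidates. A minimal sketch, assuming predictions and targets are already canonical SMILES strings so that exact string comparison is meaningful; the function name and the toy data are hypothetical:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">def top_k_accuracy(ranked_predictions, targets, k):
    """Fraction of reactions whose target is among the first k predictions."""
    if len(ranked_predictions) != len(targets):
        raise ValueError("predictions and targets must have the same length")
    hits = sum(
        1
        for preds, target in zip(ranked_predictions, targets)
        if target in preds[:k]
    )
    return hits / len(targets)


# Toy beams, best candidate first (made-up data, not from the benchmark).
beams = [
    ["CCO", "CCN", "CC=O"],
    ["c1ccccc1", "C1CCCCC1", "CCC"],
    ["CC(=O)O", "CCO", "CC"],
    ["CCCl", "CCBr", "CCI"],
]
truth = ["CCO", "C1CCCCC1", "CC", "CCF"]

for k in (1, 2, 3):
    print(k, top_k_accuracy(beams, truth, k))
</code></pre></div>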
<h3 id="comparison-with-quantum-chemistry">Comparison with Quantum Chemistry</h3>
<p>On <a href="https://en.wikipedia.org/wiki/Regioselectivity">regioselectivity</a> of <a href="https://en.wikipedia.org/wiki/Electrophilic_aromatic_substitution">electrophilic aromatic substitution</a> in heteroaromatics, the Molecular Transformer achieves 83% top-1 accuracy vs 81% for RegioSQM (a quantum-chemistry-based predictor), at a fraction of the computational cost.</p>
<h3 id="comparison-with-human-chemists">Comparison with Human Chemists</h3>
<p>On 80 reactions sampled across rarity bins, the Molecular Transformer achieves 87.5% top-1 accuracy vs 76.5% for the best human chemist and 72.5% for the best graph-based model (WLDN5).</p>
<h3 id="chemically-constrained-beam-search">Chemically Constrained Beam Search</h3>
<p>Constraining beam search to only predict atoms present in the reactants (preventing &ldquo;alchemy&rdquo;) produces no change in accuracy, confirming the model has learned conservation of atoms from data alone.</p>
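<p>One way to picture such a constraint is as a vocabulary filter applied at each decoding step. The sketch below is an illustration rather than the paper&rsquo;s implementation: <code>element_of</code> and <code>allowed_target_tokens</code> are hypothetical helpers, the vocabulary is a toy one, and hydrogens written inside bracket atoms are ignored for simplicity.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import re

# Optional bracket, optional isotope digits, then the element symbol.
ATOM_PATTERN = re.compile(r"^\[?[0-9]*([A-Z][a-z]?|[a-z]{1,2})")


def element_of(token):
    """Element symbol of an atom token, or None for bonds, digits, branches."""
    match = ATOM_PATTERN.match(token)
    # Aromatic atoms are lowercase in SMILES; capitalize maps "c" to "C".
    return match.group(1).capitalize() if match else None


def allowed_target_tokens(source_tokens, vocabulary):
    """Keep non-atom tokens and atom tokens whose element occurs in the source."""
    present = {element_of(tok) for tok in source_tokens} - {None}
    return [
        tok for tok in vocabulary
        if element_of(tok) is None or element_of(tok) in present
    ]


source = ["C", "C", "O", ".", "Cl"]  # tokens of the input "CCO.Cl"
vocab = ["C", "c", "N", "O", "Cl", "Br", "(", ")", "=", "1", "[O-]", "[NH4+]"]
allowed = allowed_target_tokens(source, vocab)
print(allowed)
</code></pre></div>
<p>In a real decoder the same idea would be applied by setting the logits of disallowed tokens to negative infinity before the beam is expanded.</p>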
<h2 id="trade-offs-and-limitations">Trade-offs and Limitations</h2>
<ul>
<li><strong><a href="https://en.wikipedia.org/wiki/Stereochemistry">Stereochemistry</a></strong>: Accuracy drops significantly on USPTO_STEREO (76-78% vs 88-90% on USPTO_MIT), indicating stereochemical prediction remains challenging.</li>
<li><strong>Resolution reactions</strong>: Low accuracy on resolution reactions (28.6%), where reagent information is often missing from patent data.</li>
<li><strong>Unclassified reactions</strong>: Accuracy on &ldquo;unrecognized&rdquo; reaction classes is 46.3%, likely reflecting noisy or mistranscribed data.</li>
<li><strong>No atom mapping</strong>: The model provides no explicit atom mapping between reactants and products, which limits interpretability for understanding reaction mechanisms.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Purpose</th>
          <th style="text-align: left">Dataset</th>
          <th style="text-align: left">Size</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>Primary benchmark</strong></td>
          <td style="text-align: left">USPTO_MIT</td>
          <td style="text-align: left">479K</td>
          <td style="text-align: left">Filtered by Jin et al., no stereochemistry</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>LEF subset</strong></td>
          <td style="text-align: left">USPTO_LEF</td>
          <td style="text-align: left">350K</td>
          <td style="text-align: left">Subset of MIT with linear electron flow only</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Stereo benchmark</strong></td>
          <td style="text-align: left">USPTO_STEREO</td>
          <td style="text-align: left">1.0M</td>
          <td style="text-align: left">Patent reactions through Sept 2016, includes stereochemistry</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Time-split test</strong></td>
          <td style="text-align: left">Pistachio_2017</td>
          <td style="text-align: left">15.4K</td>
          <td style="text-align: left">Non-public, reactions from 2017</td>
      </tr>
  </tbody>
</table>
<p><strong>Preprocessing</strong>: SMILES canonicalized with RDKit. Regex tokenization from Schwaller et al. (2018). Two input modes: &ldquo;separated&rdquo; (reactants &gt; reagents) and &ldquo;mixed&rdquo; (all molecules concatenated).</p>
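<p>The tokenization step can be sketched as follows. The pattern is written in the style of the atom-wise regex commonly used with SMILES sequence models and should be read as an assumption, not a verbatim copy of the authors&rsquo; code; <code>tokenize_smiles</code> is a hypothetical helper name.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import re

# Bracket atoms first, then two-letter halogens, then single-character tokens.
SMILES_TOKEN_PATTERN = re.compile(
    r"(\[[^\]]+\]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|/|:|~|@|\?|>|\*|\$|%[0-9]{2}|[0-9])"
)


def tokenize_smiles(smiles):
    """Split a SMILES string into atom-level tokens."""
    tokens = SMILES_TOKEN_PATTERN.findall(smiles)
    # Round-trip check: any character the pattern skipped signals bad input.
    if "".join(tokens) != smiles:
        raise ValueError("untokenizable characters in: " + smiles)
    return tokens


# "Mixed" style input: two molecules joined by a dot, no reagent marker.
print(" ".join(tokenize_smiles("CCO.CC(=O)O")))
print(tokenize_smiles("C[C@@H](N)C(=O)O"))
</code></pre></div>
<p>Keeping bracket atoms such as <code>[C@@H]</code> as single tokens means stereochemistry and charges are carried by the vocabulary instead of being spread over several characters.</p>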
<h3 id="model">Model</h3>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Hyperparameter</th>
          <th style="text-align: left">Value</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>Layers</strong></td>
          <td style="text-align: left">4</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Model dimension</strong></td>
          <td style="text-align: left">256</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Attention heads</strong></td>
          <td style="text-align: left">8</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Parameters</strong></td>
          <td style="text-align: left">~12M</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Label smoothing</strong></td>
          <td style="text-align: left">0.0</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Optimizer</strong></td>
          <td style="text-align: left">Adam</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Warm-up steps</strong></td>
          <td style="text-align: left">8000</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Batch size</strong></td>
          <td style="text-align: left">~4096 tokens</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Beam width</strong></td>
          <td style="text-align: left">5</td>
      </tr>
  </tbody>
</table>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Metric</th>
          <th style="text-align: left">Task</th>
          <th style="text-align: left">Key Result</th>
          <th style="text-align: left">Baseline</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>Top-1 accuracy</strong></td>
          <td style="text-align: left">USPTO_MIT (sep)</td>
          <td style="text-align: left"><strong>90.4%</strong></td>
          <td style="text-align: left">85.6% (WLDN5)</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Top-1 accuracy</strong></td>
          <td style="text-align: left">USPTO_MIT (mixed)</td>
          <td style="text-align: left"><strong>88.6%</strong></td>
          <td style="text-align: left">80.3% (S2S RNN)</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>AUC-ROC</strong></td>
          <td style="text-align: left">Uncertainty calibration</td>
          <td style="text-align: left"><strong>0.89</strong></td>
          <td style="text-align: left">N/A</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Top-1 accuracy</strong></td>
          <td style="text-align: left">Regioselectivity</td>
          <td style="text-align: left"><strong>83%</strong></td>
          <td style="text-align: left">81% (RegioSQM)</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Top-1 accuracy</strong></td>
          <td style="text-align: left">Human comparison</td>
          <td style="text-align: left"><strong>87.5%</strong></td>
          <td style="text-align: left">76.5% (best human)</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li>Training: Single NVIDIA P100 GPU, 48h for best single model</li>
<li>Inference: 20 min for 40K reactions on a single P100</li>
</ul>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Schwaller, P., Laino, T., Gaudin, T., Bolgar, P., Hunter, C. A., Bekas, C., &amp; Lee, A. A. (2019). Molecular Transformer: A Model for Uncertainty-Calibrated Chemical Reaction Prediction. <em>ACS Central Science</em>, 5(9), 1572-1583. <a href="https://doi.org/10.1021/acscentsci.9b00576">https://doi.org/10.1021/acscentsci.9b00576</a></p>
<p><strong>Publication</strong>: ACS Central Science 2019</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{schwallerMolecularTransformerModel2019,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Molecular Transformer: A Model for Uncertainty-Calibrated Chemical Reaction Prediction}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Schwaller, Philippe and Laino, Teodoro and Gaudin, Th{\&#39;e}ophile and Bolgar, Peter and Hunter, Christopher A. and Bekas, Costas and Lee, Alpha A.}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2019</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{ACS Central Science}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{5}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{9}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{1572--1583}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{American Chemical Society}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1021/acscentsci.9b00576}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>SELFormer: A SELFIES-Based Molecular Language Model</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/selformer/</link><pubDate>Mon, 16 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/selformer/</guid><description>A SELFIES-based RoBERTa model pretrained on 2M ChEMBL molecules for molecular property prediction on MoleculeNet benchmarks.</description><content:encoded><![CDATA[<h2 id="a-selfies-based-chemical-language-model">A SELFIES-Based Chemical Language Model</h2>
<p>This is primarily a <strong>Method</strong> paper ($\Psi_{\text{Method}}$) with a secondary <strong>Resource</strong> component ($\Psi_{\text{Resource}}$).</p>
<p>SELFormer applies the RoBERTa transformer architecture to <a href="/notes/chemistry/molecular-representations/notations/selfies-original-paper/">SELFIES</a> molecular string representations instead of the <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> notation used by prior chemical language models. The model is pretrained via masked language modeling (MLM) on 2M drug-like compounds from <a href="https://en.wikipedia.org/wiki/ChEMBL">ChEMBL</a> and fine-tuned for molecular property prediction tasks on <a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a> benchmarks. The authors release pretrained models, fine-tuning code, and datasets as open-source resources.</p>
<h2 id="why-selfies-over-smiles-for-pretraining">Why SELFIES Over SMILES for Pretraining?</h2>
<p>Existing chemical language models, including <a href="/notes/chemistry/molecular-representations/encoders/chemberta/">ChemBERTa</a>, <a href="/notes/chemistry/molecular-representations/encoders/chemberta-2/">ChemBERTa-2</a>, <a href="/notes/chemistry/molecular-representations/encoders/molbert-molecular-representations/">MolBERT</a>, and <a href="/notes/chemistry/molecular-representations/encoders/molformer/">MolFormer</a>, all use SMILES as their input representation. SMILES has well-documented validity and robustness issues: arbitrary perturbations to a SMILES string frequently produce syntactically invalid outputs. This means a pretrained model must spend capacity learning SMILES grammar rules rather than chemical semantics.</p>
<p><a href="/notes/chemistry/molecular-representations/notations/selfies-original-paper/">SELFIES</a> addresses this by construction: every possible SELFIES string decodes to a valid molecule. Despite this theoretical advantage and SELFIES&rsquo; growing adoption in generative chemistry, no prior work had systematically evaluated SELFIES as input for large-scale transformer pretraining. SELFormer fills this gap by providing a direct comparison between SELFIES-based and SMILES-based chemical language models on standard benchmarks.</p>
<h2 id="masked-language-modeling-on-guaranteed-valid-molecular-strings">Masked Language Modeling on Guaranteed-Valid Molecular Strings</h2>
<p>SELFormer uses byte-level Byte-Pair Encoding (BPE) to tokenize SELFIES strings, then pretrains a RoBERTa encoder using the standard MLM objective. 15% of input tokens are masked, and the model minimizes the cross-entropy loss over the masked positions:</p>
<p>$$
\mathcal{L}_{\text{MLM}} = -\frac{1}{|\mathcal{M}|} \sum_{i \in \mathcal{M}} \log P(x_i \mid x_{\setminus \mathcal{M}}; \theta)
$$</p>
<p>where $\mathcal{M}$ is the set of masked token indices, $x_i$ is the true token at position $i$, $x_{\setminus \mathcal{M}}$ is the corrupted input context, and $\theta$ are the model parameters.</p>
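<p>The loss can be made concrete with a toy calculation. The sketch below is illustrative only: the token strings, the probability dictionaries standing in for model outputs, and the helper names <code>mask_positions</code> and <code>mlm_loss</code> are all assumptions, not SELFormer code.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import math
import random


def mask_positions(num_tokens, mask_rate=0.15, seed=0):
    """Sample the index set M of positions to mask (at least one)."""
    rng = random.Random(seed)
    count = max(1, round(num_tokens * mask_rate))
    return sorted(rng.sample(range(num_tokens), count))


def mlm_loss(predicted_probs, true_tokens, masked):
    """Mean negative log-likelihood of the true tokens at masked positions."""
    total = 0.0
    for i in masked:
        # predicted_probs[i] maps candidate tokens to P(token | context).
        total -= math.log(predicted_probs[i][true_tokens[i]])
    return total / len(masked)


# Toy sequence with positions 1 and 3 masked (hand-picked for readability).
true_tokens = ["[C]", "[C]", "[=O]", "[O]"]
masked = [1, 3]
predicted_probs = [
    {},
    {"[C]": 0.5, "[O]": 0.5},
    {},
    {"[O]": 0.25, "[C]": 0.75},
]
loss = mlm_loss(predicted_probs, true_tokens, masked)
print(loss)
print(mask_positions(20))
</code></pre></div>
<p>Only the masked positions contribute to the loss; the unmasked tokens serve purely as context.</p>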
<p>The key insight is that because SELFIES guarantees 100% validity, every masked token prediction corresponds to a valid molecular fragment. The model never wastes capacity predicting invalid chemistry. For fine-tuning, a two-layer classification or regression head is added on top of the encoder&rsquo;s output embedding.</p>
<p>Two model sizes were trained. Notably, the larger SELFormer uses fewer attention heads (4) but more hidden layers (12) than SELFormer-Lite (12 heads, 8 layers). This counterintuitive configuration emerged from the authors&rsquo; hyperparameter search over ~100 models, where deeper architectures with fewer heads outperformed wider, shallower ones:</p>
<table>
  <thead>
      <tr>
          <th>Configuration</th>
          <th>SELFormer-Lite</th>
          <th>SELFormer</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Attention Heads</td>
          <td>12</td>
          <td>4</td>
      </tr>
      <tr>
          <td>Hidden Layers</td>
          <td>8</td>
          <td>12</td>
      </tr>
      <tr>
          <td>Batch Size</td>
          <td>16</td>
          <td>16</td>
      </tr>
      <tr>
          <td>Learning Rate</td>
          <td>5e-5</td>
          <td>5e-5</td>
      </tr>
      <tr>
          <td>Weight Decay</td>
          <td>0.01</td>
          <td>0.01</td>
      </tr>
      <tr>
          <td>Pretraining Epochs</td>
          <td>100</td>
          <td>100</td>
      </tr>
      <tr>
          <td>Parameters</td>
          <td>58.3M</td>
          <td>86.7M</td>
      </tr>
  </tbody>
</table>
<h2 id="benchmarking-against-smiles-transformers-and-graph-models">Benchmarking Against SMILES Transformers and Graph Models</h2>
<p>SELFormer was pretrained on 2.08M drug-like compounds from ChEMBL v30 (converted from SMILES to SELFIES), then fine-tuned on nine MoleculeNet tasks. All evaluations use scaffold splitting via the Chemprop library.</p>
<p><strong>Classification tasks</strong> (ROC-AUC, scaffold split):</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>BACE</th>
          <th>BBBP</th>
          <th>HIV</th>
          <th>Tox21</th>
          <th>SIDER</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>SELFormer</td>
          <td>0.832</td>
          <td><strong>0.902</strong></td>
          <td>0.681</td>
          <td>0.653</td>
          <td><strong>0.745</strong></td>
      </tr>
      <tr>
          <td>ChemBERTa-2</td>
          <td>0.799</td>
          <td>0.728</td>
          <td>0.622</td>
          <td>-</td>
          <td>-</td>
      </tr>
      <tr>
          <td>MolBERT</td>
          <td><strong>0.866</strong></td>
          <td>0.762</td>
          <td><strong>0.783</strong></td>
          <td>-</td>
          <td>-</td>
      </tr>
      <tr>
          <td>D-MPNN</td>
          <td>0.809</td>
          <td>0.710</td>
          <td>0.771</td>
          <td>0.759</td>
          <td>0.570</td>
      </tr>
      <tr>
          <td>MolCLR</td>
          <td><strong>0.890</strong></td>
          <td>0.736</td>
          <td><strong>0.806</strong></td>
          <td><strong>0.787</strong></td>
          <td>0.652</td>
      </tr>
      <tr>
          <td>GEM</td>
          <td>0.856</td>
          <td>0.724</td>
          <td><strong>0.806</strong></td>
          <td>0.781</td>
          <td>0.672</td>
      </tr>
      <tr>
          <td>KPGT</td>
          <td>0.855</td>
          <td><strong>0.908</strong></td>
          <td>-</td>
          <td><strong>0.848</strong></td>
          <td>0.649</td>
      </tr>
  </tbody>
</table>
<p><strong>Regression tasks</strong> (RMSE, scaffold split, lower is better):</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>ESOL</th>
          <th>FreeSolv</th>
          <th>Lipophilicity</th>
          <th>PDBbind</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>SELFormer</td>
          <td><strong>0.682</strong></td>
          <td>2.797</td>
          <td>0.735</td>
          <td>1.488</td>
      </tr>
      <tr>
          <td>ChemBERTa-2</td>
          <td>-</td>
          <td>-</td>
          <td>0.986</td>
          <td>-</td>
      </tr>
      <tr>
          <td>D-MPNN</td>
          <td>1.050</td>
          <td><strong>2.082</strong></td>
          <td><strong>0.683</strong></td>
          <td><strong>1.397</strong></td>
      </tr>
      <tr>
          <td>GEM</td>
          <td>0.798</td>
          <td><strong>1.877</strong></td>
          <td>0.660</td>
          <td>-</td>
      </tr>
      <tr>
          <td>KPGT</td>
          <td>0.803</td>
          <td>2.121</td>
          <td><strong>0.600</strong></td>
          <td>-</td>
      </tr>
  </tbody>
</table>
<p>The ablation study compared SELFormer vs. SELFormer-Lite across pretrained-only, 25-epoch, and 50-epoch fine-tuning configurations on randomly split datasets. SELFormer consistently outperformed SELFormer-Lite, confirming the benefit of the deeper (12-layer) architecture.</p>
<h2 id="strong-classification-performance-with-compact-pretraining">Strong Classification Performance with Compact Pretraining</h2>
<p>SELFormer&rsquo;s strongest results come mainly on classification tasks where molecular substructure matters, plus ESOL regression:</p>
<ul>
<li><strong>SIDER</strong>: Best ROC-AUC in the table above (0.745), ahead of the next best listed method (GEM at 0.672) by 7.3 percentage points and of MolCLR (0.652) by 9.3. The authors attribute this to SELFIES&rsquo; ability to capture subtle structural differences relevant to drug side effects.</li>
<li><strong>BBBP</strong>: Second best (0.902), behind only KPGT (0.908). SELFormer scored 17.4 percentage points above ChemBERTa-2 (0.728) on this task.</li>
<li><strong>BACE/HIV vs. ChemBERTa-2</strong>: SELFormer outperformed ChemBERTa-2 by 3.3 points on BACE (0.832 vs 0.799), 17.4 on BBBP, and 5.9 on HIV (0.681 vs 0.622). Since both models use similar RoBERTa architectures, this comparison is suggestive of a SELFIES advantage, though differences in pretraining corpus (ChEMBL vs PubChem), corpus size, and training procedure confound a clean attribution to the input representation alone.</li>
<li><strong>ESOL regression</strong>: Best RMSE (0.682) vs GEM (0.798), a 14.5% relative improvement.</li>
</ul>
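<p>The ESOL and BBBP figures follow directly from the table values. As a quick check, the relative RMSE improvement over GEM on ESOL and the absolute ROC-AUC gap to ChemBERTa-2 on BBBP are:</p>
<p>$$
\frac{0.798 - 0.682}{0.798} \approx 0.145, \qquad 0.902 - 0.728 = 0.174
$$</p>
<p>The first is a relative change in error (14.5%), while the second is a difference in percentage points (17.4), so the two numbers are not directly comparable.</p>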
<p>Limitations are also apparent:</p>
<ul>
<li><strong>HIV and Tox21</strong>: SELFormer underperforms graph-based methods (MolCLR, GEM, KPGT) on these larger datasets. The authors attribute this to insufficient hyperparameter search given computational constraints.</li>
<li><strong>FreeSolv and Lipophilicity regression</strong>: D-MPNN and other graph-based methods maintain an edge, suggesting that explicit 2D/3D structural inductive biases remain valuable for certain property types.</li>
<li><strong>Small pretraining corpus</strong>: At 2M molecules, SELFormer&rsquo;s corpus is orders of magnitude smaller than MolFormer&rsquo;s 1.1B. Despite this, SELFormer outperforms MolFormer on SIDER (0.745 vs 0.690), which is consistent with a representational advantage for SELFIES, though a single task cannot isolate it.</li>
<li><strong>Single-task ablation scope</strong>: Some architectural claims rest on limited task coverage, and broader benchmarking would strengthen the conclusions.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pretraining</td>
          <td>ChEMBL v30</td>
          <td>2,084,725 compounds (2,084,472 after SELFIES conversion)</td>
          <td>Drug-like bioactive small molecules</td>
      </tr>
      <tr>
          <td>Classification</td>
          <td>BACE</td>
          <td>1,513</td>
          <td><a href="https://en.wikipedia.org/wiki/Beta-secretase_1">Beta-secretase 1</a> inhibitor binding</td>
      </tr>
      <tr>
          <td>Classification</td>
          <td>BBBP</td>
          <td>2,039</td>
          <td><a href="https://en.wikipedia.org/wiki/Blood%E2%80%93brain_barrier">Blood-brain barrier</a> permeability</td>
      </tr>
      <tr>
          <td>Classification</td>
          <td>HIV</td>
          <td>41,127</td>
          <td>HIV replication inhibition</td>
      </tr>
      <tr>
          <td>Classification</td>
          <td>SIDER</td>
          <td>1,427</td>
          <td>Drug side effects (27 classes)</td>
      </tr>
      <tr>
          <td>Classification</td>
          <td>Tox21</td>
          <td>7,831</td>
          <td>Toxicity (12 targets)</td>
      </tr>
      <tr>
          <td>Regression</td>
          <td>ESOL</td>
          <td>1,128</td>
          <td>Aqueous solubility</td>
      </tr>
      <tr>
          <td>Regression</td>
          <td>FreeSolv</td>
          <td>642</td>
          <td>Hydration free energy</td>
      </tr>
      <tr>
          <td>Regression</td>
          <td>Lipophilicity</td>
          <td>4,200</td>
          <td>Octanol/water distribution coefficient</td>
      </tr>
      <tr>
          <td>Regression</td>
          <td>PDBbind</td>
          <td>11,908</td>
          <td>Binding affinity</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Pretraining objective</strong>: Masked language modeling (MLM), 15% token masking</li>
<li><strong>Tokenization</strong>: Byte-level Byte-Pair Encoding (BPE) on SELFIES strings</li>
<li><strong>SMILES to SELFIES conversion</strong>: SELFIES API with pandarallel for parallelization</li>
<li><strong>Splitting</strong>: Scaffold splitting via Chemprop library (80/10/10 train/validation/test)</li>
<li><strong>Fine-tuning</strong>: Two-layer classification/regression head on encoder output; up to 200 epochs with hyperparameter search</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li><strong>Architecture</strong>: RoBERTa (HuggingFace Transformers)</li>
<li><strong>SELFormer</strong>: 12 hidden layers, 4 attention heads, 86.7M parameters</li>
<li><strong>SELFormer-Lite</strong>: 8 hidden layers, 12 attention heads, 58.3M parameters</li>
<li><strong>Hyperparameter search</strong>: Sequential search over ~100 configurations on 100K molecule subset</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Task Type</th>
          <th>Details</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ROC-AUC</td>
          <td>Classification</td>
          <td>Area under receiver operating characteristic curve</td>
      </tr>
      <tr>
          <td>PRC-AUC</td>
          <td>Classification</td>
          <td>Area under precision-recall curve (reported for random splits)</td>
      </tr>
      <tr>
          <td>RMSE</td>
          <td>Regression</td>
          <td>Root mean squared error</td>
      </tr>
  </tbody>
</table>
<p>Results reported on scaffold split and random split datasets.</p>
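<p>ROC-AUC can be read as the probability that a randomly chosen positive example is scored above a randomly chosen negative one. The sketch below computes it that way in pure Python. It is an illustration with a hypothetical function name and made-up scores; for real evaluations a library routine such as scikit-learn&rsquo;s <code>roc_auc_score</code> is the better choice.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">def roc_auc(labels, scores):
    """ROC-AUC as the fraction of positive/negative pairs ranked correctly."""
    positives = [s for y, s in zip(labels, scores) if y == 1]
    negatives = [s for y, s in zip(labels, scores) if y == 0]
    if not positives or not negatives:
        raise ValueError("need at least one positive and one negative")
    wins = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5  # ties count as half a win
    return wins / (len(positives) * len(negatives))


# Made-up labels and model scores for six molecules.
labels = [1, 0, 1, 1, 0, 0]
scores = [0.9, 0.8, 0.7, 0.6, 0.3, 0.1]
auc = roc_auc(labels, scores)
print(auc)
</code></pre></div>
<p>The pairwise loop is quadratic in the number of examples, which is fine for a worked example but not for datasets the size of HIV.</p>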
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Compute</strong>: 2x NVIDIA A5000 GPUs</li>
<li><strong>Hyperparameter optimization time</strong>: ~11 days</li>
<li><strong>Full pretraining</strong>: 100 epochs on 2.08M molecules</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/HUBioDataLab/SELFormer">SELFormer GitHub</a></td>
          <td>Code</td>
          <td>GPL-3.0</td>
          <td>Pretraining, fine-tuning, and evaluation scripts</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/HUBioDataLab/SELFormer">SELFormer on HuggingFace</a></td>
          <td>Model</td>
          <td>GPL-3.0</td>
          <td>Pretrained SELFormer weights</td>
      </tr>
      <tr>
          <td><a href="https://www.ebi.ac.uk/chembl/">ChEMBL v30</a></td>
          <td>Dataset</td>
          <td>CC BY-SA 3.0</td>
          <td>Source pretraining data</td>
      </tr>
      <tr>
          <td><a href="https://moleculenet.org/">MoleculeNet</a></td>
          <td>Benchmark</td>
          <td>Unknown</td>
          <td>Downstream evaluation tasks</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Yüksel, A., Ulusoy, E., Ünlü, A., &amp; Doğan, T. (2023). SELFormer: Molecular Representation Learning via SELFIES Language Models. <em>Machine Learning: Science and Technology</em>, 4(2), 025035. <a href="https://doi.org/10.1088/2632-2153/acdb30">https://doi.org/10.1088/2632-2153/acdb30</a></p>
<p><strong>Publication</strong>: Machine Learning: Science and Technology 2023</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/HUBioDataLab/SELFormer">GitHub Repository (SELFormer)</a></li>
<li><a href="https://huggingface.co/HUBioDataLab/SELFormer">HuggingFace Model Hub (SELFormer)</a></li>
</ul>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{yuksel2023selformer,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{{SELFormer}: Molecular Representation Learning via {SELFIES} Language Models}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Y{\&#34;u}ksel, Atakan and Ulusoy, Erva and {\&#34;U}nl{\&#34;u}, Atabey and Do{\u{g}}an, Tunca}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Machine Learning: Science and Technology}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{4}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{2}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{025035}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{IOP Publishing}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1088/2632-2153/acdb30}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>MoLFormer: Large-Scale Chemical Language Representations</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/molformer/</link><pubDate>Mon, 16 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/molformer/</guid><description>A linear-attention transformer pretrained on 1.1B SMILES from PubChem and ZINC for molecular property prediction across MoleculeNet benchmarks.</description><content:encoded><![CDATA[<h2 id="a-billion-scale-chemical-language-model">A Billion-Scale Chemical Language Model</h2>
<p>This is primarily a <strong>Method</strong> paper ($\Psi_{\text{Method}}$).</p>
<p>MoLFormer is a transformer encoder pretrained via masked language modeling on 1.1 billion <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> strings from <a href="https://en.wikipedia.org/wiki/PubChem">PubChem</a> and <a href="https://en.wikipedia.org/wiki/ZINC_database">ZINC</a>. The key architectural choices are linear attention (for $O(N)$ complexity instead of $O(N^2)$) and rotary positional embeddings (RoPE). The resulting model, MoLFormer-XL, produces molecular embeddings that outperform or match GNN baselines across a wide range of <a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a> classification and regression tasks, including quantum-chemical property prediction from SMILES alone.</p>
<h2 id="bridging-the-gap-between-molecular-languages-and-graph-neural-networks">Bridging the Gap Between Molecular Languages and Graph Neural Networks</h2>
<p>Prior chemical language models like <a href="/notes/chemistry/molecular-representations/encoders/chemberta/">ChemBERTa</a> were pretrained on relatively small datasets (10M-77M molecules) and generally underperformed GNNs on molecular property prediction. The core question: does a transformer trained on a sufficiently large SMILES corpus learn enough chemical structure to compete with graph-based methods that have explicit topological inductive biases?</p>
<p>Two specific challenges motivated this work:</p>
<ul>
<li><strong>Scale</strong>: The chemical space spans $10^{60}$ to $10^{100}$ plausible molecules, yet labeled property data is scarce. Self-supervised pretraining on the ~1.1B unlabeled molecules available in public databases could provide a general-purpose representation.</li>
<li><strong>Efficiency</strong>: Standard transformer attention is $O(N^2)$ in sequence length, making billion-scale pretraining impractical without architectural modifications.</li>
</ul>
<h2 id="linear-attention-with-rotary-positional-embeddings">Linear Attention with Rotary Positional Embeddings</h2>
<p>MoLFormer&rsquo;s two key architectural choices are its attention mechanism and positional encoding scheme.</p>
<p><strong>Standard attention</strong> computes:</p>
<p>$$
\text{Attention}_m(Q, K, V) = \frac{\sum_{n=1}^{N} \exp(\langle q_m, k_n \rangle) v_n}{\sum_{n=1}^{N} \exp(\langle q_m, k_n \rangle)}
$$</p>
<p>MoLFormer replaces this with <strong>linear attention</strong> using a generalized feature map $\varphi$, combined with <strong>rotary positional embeddings</strong> $R_m$ applied before the feature map:</p>
<p>$$
\text{Attention}_m(Q, K, V) = \frac{\sum_{n=1}^{N} \langle \varphi(R_m q_m), \varphi(R_n k_n) \rangle v_n}{\sum_{n=1}^{N} \langle \varphi(R_m q_m), \varphi(R_n k_n) \rangle}
$$</p>
<p>This differs from the original RoFormer formulation, which applies the rotation after the feature map. The authors found that rotating the raw queries and keys before projection led to faster convergence and lower validation loss. The combination of linear attention and adaptive sequence-length bucketing reduces GPU requirements from ~1000 to 16 for training on the full 1.1B corpus.</p>
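<p>A minimal NumPy sketch of the rotary linear attention above, under stated assumptions: the feature map is $\varphi(x) = \text{elu}(x) + 1$, a simple positive stand-in and not the generalized feature map used in the paper; the rotation uses the common rotate-half RoPE layout with base 10000; all function names are hypothetical. It illustrates why the cost is linear in $N$: the key-value summary is computed once and reused for every query.</p>

```python
import numpy as np

def rotate(x, positions):
    """Apply rotary position embeddings to x of shape (N, d), with d even."""
    half = x.shape[1] // 2
    # One rotation frequency per channel pair (standard RoPE schedule, assumed).
    freqs = 1.0 / (10000.0 ** (np.arange(half) / half))
    angles = positions[:, None] * freqs[None, :]
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[:, :half], x[:, half:]
    return np.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=1)

def feature_map(x):
    """Stand-in positive feature map, elu(x) + 1 (an assumption)."""
    return np.where(x > 0, x + 1.0, np.exp(x))

def linear_rotary_attention(q, k, v):
    """Rotate raw queries and keys first, then apply the feature map."""
    pos = np.arange(q.shape[0], dtype=float)
    phi_q = feature_map(rotate(q, pos))
    phi_k = feature_map(rotate(k, pos))
    kv = phi_k.T @ v            # (d, d_v): summed over all keys once
    z = phi_k.sum(axis=0)       # (d,): normalizer summary
    return (phi_q @ kv) / (phi_q @ z)[:, None]

def quadratic_reference(q, k, v):
    """Same quantity computed through the explicit N x N score matrix."""
    pos = np.arange(q.shape[0], dtype=float)
    scores = feature_map(rotate(q, pos)) @ feature_map(rotate(k, pos)).T
    weights = scores / scores.sum(axis=1, keepdims=True)
    return weights @ v
```

<p>The quadratic reference builds the full $N \times N$ score matrix and should agree with the linear version up to floating-point error; only the order of the matrix products differs.</p>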
<p>The model uses masked language modeling (15% token masking, following BERT conventions) with a vocabulary of 2,362 SMILES tokens. Sequence length is capped at 202 tokens, covering 99.4% of all molecules.</p>
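<p>The masking step can be sketched in plain Python. This assumes the BERT convention of selecting 15% of positions and then replacing 80% of those with the mask token, 10% with a random token, and leaving 10% unchanged. The <code>mask_id</code> argument, the <code>-100</code> ignore label, and the function name are illustrative assumptions, and the random replacement does not exclude special tokens.</p>

```python
import random

IGNORE = -100  # label for positions that contribute no loss (assumed convention)

def mask_tokens(token_ids, vocab_size, mask_id, select_prob=0.15, seed=None):
    """Return (corrupted_inputs, labels) for one tokenized SMILES string."""
    rng = random.Random(seed)
    inputs = list(token_ids)
    labels = [IGNORE] * len(inputs)
    for i, original in enumerate(token_ids):
        if rng.random() >= select_prob:
            continue                      # about 85% of positions are left alone
        labels[i] = original              # the model must recover this token
        roll = rng.random()
        if roll >= 0.2:
            inputs[i] = mask_id           # 80%: replace with the mask token
        elif roll >= 0.1:
            inputs[i] = rng.randrange(vocab_size)  # 10%: random token
        # remaining 10%: keep the original token in place
    return inputs, labels
```

<p>Only positions whose label differs from <code>IGNORE</code> enter the cross-entropy loss, so the unchanged 10% still supply a training signal.</p>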
<h2 id="broad-moleculenet-benchmarking-with-scaling-ablations">Broad MoleculeNet Benchmarking with Scaling Ablations</h2>
<p>MoLFormer-XL was evaluated on 11 MoleculeNet tasks against supervised GNNs, self-supervised GNNs, and prior language models.</p>
<p><strong>Classification tasks</strong> (ROC-AUC, scaffold split; values reported as percentages in the original paper, converted to proportions here for consistency):</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>BBBP</th>
          <th>Tox21</th>
          <th>ClinTox</th>
          <th>HIV</th>
          <th>BACE</th>
          <th>SIDER</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MoLFormer-XL</td>
          <td><strong>0.937</strong></td>
          <td>0.847</td>
          <td><strong>0.948</strong></td>
          <td>0.822</td>
          <td>0.882</td>
          <td><strong>0.690</strong></td>
      </tr>
      <tr>
          <td>N-Gram</td>
          <td>0.912</td>
          <td>0.769</td>
          <td>0.855</td>
          <td><strong>0.830</strong></td>
          <td>0.876</td>
          <td>0.632</td>
      </tr>
      <tr>
          <td>MolCLR</td>
          <td>0.736</td>
          <td>0.798</td>
          <td>0.932</td>
          <td>0.806</td>
          <td><strong>0.890</strong></td>
          <td>0.680</td>
      </tr>
      <tr>
          <td>GEM</td>
          <td>0.724</td>
          <td>0.781</td>
          <td>0.901</td>
          <td>0.806</td>
          <td>0.856</td>
          <td>0.672</td>
      </tr>
      <tr>
          <td>Hu et al.</td>
          <td>0.708</td>
          <td>0.787</td>
          <td>0.789</td>
          <td>0.802</td>
          <td>0.859</td>
          <td>0.652</td>
      </tr>
      <tr>
          <td>GeomGCL</td>
          <td>-</td>
          <td><strong>0.850</strong></td>
          <td>0.919</td>
          <td>-</td>
          <td>-</td>
          <td>0.648</td>
      </tr>
      <tr>
          <td>ChemBERTa</td>
          <td>0.643</td>
          <td>-</td>
          <td>0.906</td>
          <td>0.622</td>
          <td>-</td>
          <td>-</td>
      </tr>
  </tbody>
</table>
<p><strong>Regression tasks</strong> (RMSE for ESOL/FreeSolv/Lipophilicity, avg MAE for QM9/QM8):</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>QM9</th>
          <th>QM8</th>
          <th>ESOL</th>
          <th>FreeSolv</th>
          <th>Lipophilicity</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MoLFormer-XL</td>
          <td><strong>1.5894</strong></td>
          <td><strong>0.0102</strong></td>
          <td><strong>0.2787</strong></td>
          <td><strong>0.2308</strong></td>
          <td><strong>0.5289</strong></td>
      </tr>
      <tr>
          <td>A-FP</td>
          <td>2.6355</td>
          <td>0.0282</td>
          <td>0.5030</td>
          <td>0.736</td>
          <td>0.578</td>
      </tr>
      <tr>
          <td>MPNN</td>
          <td>3.1898</td>
          <td>0.0143</td>
          <td>0.58</td>
          <td>1.150</td>
          <td>0.7190</td>
      </tr>
      <tr>
          <td>GC</td>
          <td>4.3536</td>
          <td>0.0148</td>
          <td>0.970</td>
          <td>1.40</td>
          <td>0.655</td>
      </tr>
  </tbody>
</table>
<p>MoLFormer-XL also outperforms geometry-aware GNNs (DimeNet, GeomGCL, GEM) on ESOL (0.279 vs 0.575), FreeSolv (0.231 vs 0.866), and Lipophilicity (0.529 vs 0.541).</p>
<p><strong>Key ablation findings</strong>:</p>
<ul>
<li><strong>Data scale matters</strong>: Performance improves monotonically from 10% subsets through the full 1.1B corpus. Training on 100% ZINC alone performed worst, likely due to its smaller vocabulary and less diverse molecule lengths.</li>
<li><strong>Model depth matters</strong>: MoLFormer-Base (6 layers) underperforms MoLFormer-XL (12 layers) on most tasks.</li>
<li><strong>Fine-tuning &gt;&gt; frozen</strong>: Fine-tuning the full encoder consistently outperforms using frozen embeddings with a downstream classifier.</li>
<li><strong>Rotary &gt; absolute at scale</strong>: Rotary embeddings underperform absolute embeddings on smaller pretraining sets but overtake them once the corpus exceeds 1B molecules.</li>
</ul>
<h2 id="smiles-transformers-learn-molecular-geometry">SMILES Transformers Learn Molecular Geometry</h2>
<p>The most striking finding is that MoLFormer&rsquo;s attention patterns correlate with 3D interatomic distances, despite training only on 1D SMILES strings.</p>
<p>Using <a href="/notes/chemistry/datasets/qm9/">QM9</a> molecules with known 3D geometries, the authors computed cosine similarity between attention maps and spatial distance matrices across three distance categories:</p>
<table>
  <thead>
      <tr>
          <th>Distance Category</th>
          <th>Range</th>
          <th>Linear Attention (Rotary)</th>
          <th>Full Attention (Rotary)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Short</td>
          <td>$\leq$ 2 Å</td>
          <td>0.594-0.602</td>
          <td>0.598-0.615</td>
      </tr>
      <tr>
          <td>Medium</td>
          <td>2-4 Å</td>
          <td>0.724-0.730</td>
          <td>0.716-0.727</td>
      </tr>
      <tr>
          <td>Long</td>
          <td>4-10 Å</td>
          <td>0.209-0.211</td>
          <td>0.204-0.210</td>
      </tr>
  </tbody>
</table>
<p>The strong correlation in the short and medium categories indicates the model captures covalent bond connectivity and near-neighbor spatial relationships. Linear attention shows marginally higher cosine similarity than full attention on medium-range distances (0.724-0.730 vs 0.716-0.727), though the differences are small.</p>
<p>MoLFormer-XL embeddings also correlate more strongly with molecular fingerprint similarity (0.64 vs 0.48 for ChemBERTa) and maximum common subgraph size (-0.60 vs -0.44), confirming that the representations encode structural information.</p>
<p><strong>Limitations</strong>:</p>
<ul>
<li><strong>Quantum-chemical energies</strong>: SchNet and DimeNet (which encode explicit 3D geometry) outperform MoLFormer-XL on QM9 atomization energy tasks, with DimeNet achieving roughly 10x lower MAE on U0_atom (0.008 vs 0.083 eV). 3D information remains important for these properties.</li>
<li><strong>Sequence length cap</strong>: The 202-token limit excludes 0.6% of molecules, potentially limiting applicability to larger structures.</li>
<li><strong>SMILES canonicalization</strong>: The model depends on <a href="https://en.wikipedia.org/wiki/RDKit">RDKit</a> canonical SMILES; sensitivity to non-canonical forms is not evaluated.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pretraining</td>
          <td>PubChem</td>
          <td>111M molecules</td>
          <td>Canonical SMILES via RDKit</td>
      </tr>
      <tr>
          <td>Pretraining</td>
          <td>ZINC</td>
          <td>~1B molecules</td>
          <td>Canonical SMILES via RDKit</td>
      </tr>
      <tr>
          <td>Pretraining (combined)</td>
          <td>PubChem + ZINC</td>
          <td>~1.1B molecules</td>
          <td>MoLFormer-XL training set</td>
      </tr>
      <tr>
          <td>Classification</td>
          <td>BBBP, Tox21, ClinTox, HIV, BACE, SIDER</td>
          <td>1,427-41,127</td>
          <td>MoleculeNet scaffold splits</td>
      </tr>
      <tr>
          <td>Regression</td>
          <td>QM9, QM8, ESOL, FreeSolv, Lipophilicity</td>
          <td>642-133,885</td>
          <td>MoleculeNet random splits (QM9/QM8), scaffold (others)</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Pretraining objective</strong>: Masked language modeling (15% selection: 80% masked, 10% random, 10% unchanged)</li>
<li><strong>Tokenization</strong>: SMILES tokenizer from Schwaller et al., vocabulary of 2,362 tokens</li>
<li><strong>Sequence length</strong>: 1-202 tokens (99.4% coverage)</li>
<li><strong>Optimizer</strong>: Fused LAMB (via APEX), chosen for stability with large batch sizes and no need for learning rate warm-up</li>
<li><strong>Adaptive bucketing</strong>: Sequences grouped by length into buckets to minimize padding waste</li>
</ul>
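<p>A rough sketch of length bucketing, to show why it cuts padding. The bucket width, batch size, and function names are illustrative assumptions; the paper&rsquo;s adaptive scheme may choose bucket boundaries differently.</p>

```python
from collections import defaultdict

def bucket_by_length(sequences, bucket_width=10, batch_size=4):
    """Group sequences of similar length, then cut each group into batches."""
    buckets = defaultdict(list)
    for seq in sequences:
        buckets[(len(seq) - 1) // bucket_width].append(seq)
    batches = []
    for key in sorted(buckets):
        group = buckets[key]
        for start in range(0, len(group), batch_size):
            batches.append(group[start:start + batch_size])
    return batches

def padding_fraction(batches):
    """Fraction of token slots that are padding when each batch is padded to its longest member."""
    padded = sum(len(b) * max(len(s) for s in b) for b in batches)
    real = sum(len(s) for b in batches for s in b)
    return 1.0 - real / padded
```

<p>Mixing short and long SMILES in one batch pads every short sequence up to the longest one; batching within a bucket keeps that overhead bounded by the bucket width.</p>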
<h3 id="models">Models</h3>
<ul>
<li><strong>Architecture</strong>: Transformer encoder with linear attention and rotary positional embeddings</li>
<li><strong>MoLFormer-XL</strong>: 12 layers, 12 attention heads, hidden size 768</li>
<li><strong>MoLFormer-Base</strong>: 6 layers (ablation only)</li>
<li><strong>Feature map size</strong>: 32 (generalized feature map for linear attention)</li>
<li><strong>Frozen head</strong>: Fully connected model with hyperparameter sweep (learning rate, batch size, hidden dim, number of layers)</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Task Type</th>
          <th>Details</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ROC-AUC</td>
          <td>Classification</td>
          <td>Scaffold splits per MoleculeNet</td>
      </tr>
      <tr>
          <td>RMSE</td>
          <td>Regression (ESOL, FreeSolv, Lipophilicity)</td>
          <td>Scaffold splits</td>
      </tr>
      <tr>
          <td>Avg MAE</td>
          <td>Regression (QM9, QM8)</td>
          <td>Random splits per MoleculeNet</td>
      </tr>
  </tbody>
</table>
<p>QM9 results are also reported with 5-fold cross-validation for robustness.</p>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Compute</strong>: GPU cluster with nodes containing either 8 NVIDIA Tesla V100 (32GB) or 8 Ampere A100 (40GB) GPUs connected via NVLink and InfiniBand</li>
<li><strong>GPU reduction</strong>: Linear attention + bucketing reduced GPU requirements from ~1000 to 16</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/IBM/molformer">IBM/molformer</a></td>
          <td>Code</td>
          <td>Apache-2.0</td>
          <td>Pretraining, fine-tuning, and attention visualization</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/ibm/MoLFormer-XL-both-10pct">MoLFormer-XL (HuggingFace)</a></td>
          <td>Model</td>
          <td>Apache-2.0</td>
          <td>Pretrained weights (46.8M parameters); as the checkpoint name indicates, this release was pretrained on 10% of ZINC plus 10% of PubChem</td>
      </tr>
      <tr>
          <td><a href="https://pubchem.ncbi.nlm.nih.gov/">PubChem</a></td>
          <td>Dataset</td>
          <td>Public domain</td>
          <td>111M molecules</td>
      </tr>
      <tr>
          <td><a href="https://zinc.docking.org/">ZINC</a></td>
          <td>Dataset</td>
          <td>See ZINC terms</td>
          <td>~1B molecules</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Ross, J., Belgodere, B., Chenthamarakshan, V., Padhi, I., Mroueh, Y., &amp; Das, P. (2022). Large-Scale Chemical Language Representations Capture Molecular Structure and Properties. <em>Nature Machine Intelligence</em>, 4, 1256-1264. <a href="https://doi.org/10.1038/s42256-022-00580-7">https://doi.org/10.1038/s42256-022-00580-7</a></p>
<p><strong>Publication</strong>: Nature Machine Intelligence 2022</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/IBM/molformer">GitHub Repository (MoLFormer)</a></li>
<li><a href="https://huggingface.co/ibm/MoLFormer-XL-both-10pct">HuggingFace Models</a></li>
</ul>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{ross2022molformer,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Large-Scale Chemical Language Representations Capture Molecular Structure and Properties}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Ross, Jerret and Belgodere, Brian and Chenthamarakshan, Vijil and Padhi, Inkit and Mroueh, Youssef and Das, Payel}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Nature Machine Intelligence}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{4}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{12}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{1256--1264}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2022}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Nature Publishing Group}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1038/s42256-022-00580-7}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Exposing Limitations of Molecular ML with Activity Cliffs</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/property-prediction/activity-cliffs-benchmark/</link><pubDate>Mon, 16 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/property-prediction/activity-cliffs-benchmark/</guid><description>A benchmark of 24 ML methods on activity cliff compounds across 30 drug targets, showing descriptor-based models outperform deep learning.</description><content:encoded><![CDATA[<h2 id="a-benchmark-for-activity-cliff-prediction">A Benchmark for Activity Cliff Prediction</h2>
<p>This is a <strong>Systematization</strong> paper ($\Psi_{\text{Systematization}}$) with a significant <strong>Resource</strong> component ($\Psi_{\text{Resource}}$).</p>
<p>The paper systematically benchmarks 24 machine learning and deep learning approaches on their ability to predict bioactivity for activity cliff compounds: pairs of structurally similar molecules that exhibit large differences in potency. These cases violate the similarity principle (similar structure implies similar activity) and represent a practical failure mode for <a href="/notes/chemistry/molecular-design/property-prediction/">molecular property prediction</a> in drug discovery. The authors release MoleculeACE, an open-source benchmarking platform for evaluating ML models on activity cliffs.</p>
<h2 id="activity-cliffs-as-a-blind-spot-in-molecular-ml">Activity Cliffs as a Blind Spot in Molecular ML</h2>
<p>The <a href="https://en.wikipedia.org/wiki/Chemical_similarity">similarity principle</a> underpins most molecular ML: structurally similar compounds should have similar properties. Activity cliffs are the exceptions, where small structural changes cause large potency shifts (e.g., a single substituent change causing a 10x difference in $K_i$).</p>
<p>Despite their importance for <a href="https://en.wikipedia.org/wiki/Hit_to_lead">hit-to-lead optimization</a>, activity cliffs have received limited attention in ML benchmarking. Standard metrics like RMSE computed over entire test sets can mask poor predictions on cliff compounds. A model might achieve low overall error while systematically mispredicting these edge cases, which are precisely the molecules that matter most for medicinal chemistry applications.</p>
<p>The authors identify 7-52% of compounds as activity cliff molecules across their 30 target datasets, showing this is not a rare phenomenon.</p>
<h2 id="defining-and-detecting-activity-cliffs">Defining and Detecting Activity Cliffs</h2>
<p>The authors use three complementary similarity metrics to identify activity cliffs:</p>
<ol>
<li><strong>Substructure similarity</strong>: <a href="https://en.wikipedia.org/wiki/Jaccard_index">Tanimoto coefficient</a> on extended connectivity fingerprints (ECFPs), capturing shared radial substructures</li>
<li><strong>Scaffold similarity</strong>: Tanimoto coefficient on ECFPs computed from molecular graph frameworks, detecting core/decoration differences</li>
<li><strong>SMILES similarity</strong>: <a href="https://en.wikipedia.org/wiki/Levenshtein_distance">Levenshtein distance</a> on canonical <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> strings, capturing character-level insertions, deletions, and translocations</li>
</ol>
<p>Pairs with $\geq 90\%$ similarity on <strong>any one</strong> of the three metrics and $&gt; 10\times$ difference in bioactivity ($K_i$ or $\text{EC}_{50}$) are classified as activity cliff pairs. This union-based approach (rather than requiring agreement across all metrics) captures different types of structural relationships relevant to medicinal chemistry.</p>
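<p>The union rule can be sketched in pure Python. This is an illustration under assumptions, not the MoleculeACE implementation: fingerprints are passed in as precomputed sets of on-bits (in practice they would come from a cheminformatics toolkit), the Levenshtein distance is normalized by the longer string, and the record fields <code>ecfp</code>, <code>scaffold_ecfp</code>, and <code>ki_nM</code> are hypothetical names.</p>

```python
from itertools import combinations

def levenshtein(a, b):
    """Edit distance between two strings (insertions, deletions, substitutions)."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            cost = 0 if ca == cb else 1
            curr.append(min(prev[j] + 1, curr[j - 1] + 1, prev[j - 1] + cost))
        prev = curr
    return prev[-1]

def smiles_similarity(a, b):
    """Levenshtein similarity in [0, 1]; this normalization is an assumption."""
    longest = max(len(a), len(b))
    return 1.0 if longest == 0 else 1.0 - levenshtein(a, b) / longest

def tanimoto(bits_a, bits_b):
    """Tanimoto (Jaccard) coefficient on sets of fingerprint on-bits."""
    union = len(bits_a.union(bits_b))
    return 1.0 if union == 0 else len(bits_a.intersection(bits_b)) / union

def find_cliff_pairs(molecules, sim_threshold=0.9, fold_threshold=10.0):
    """Return index pairs that are similar on ANY metric and differ more than 10x in potency."""
    pairs = []
    for (i, a), (j, b) in combinations(enumerate(molecules), 2):
        sims = (
            tanimoto(a["ecfp"], b["ecfp"]),
            tanimoto(a["scaffold_ecfp"], b["scaffold_ecfp"]),
            smiles_similarity(a["smiles"], b["smiles"]),
        )
        fold = max(a["ki_nM"], b["ki_nM"]) / min(a["ki_nM"], b["ki_nM"])
        if max(sims) >= sim_threshold and fold > fold_threshold:
            pairs.append((i, j))
    return pairs
```

<p>Taking <code>max(sims)</code> is what makes the rule a union: a pair qualifies as soon as one similarity view reaches the threshold, even if the other two disagree.</p>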
<h2 id="24-methods-across-30-drug-targets">24 Methods Across 30 Drug Targets</h2>
<p>The benchmark evaluates 16 traditional ML configurations (4 algorithms $\times$ 4 descriptor types) and 8 deep learning approaches across 30 curated <a href="https://en.wikipedia.org/wiki/ChEMBL">ChEMBL</a> v29 datasets (48,707 total molecules).</p>
<p><strong>Traditional ML algorithms</strong>: KNN, RF, GBM, SVM, each combined with ECFPs, MACCS keys, WHIM descriptors, or physicochemical properties.</p>
<p><strong>Deep learning methods</strong>: MPNN, GCN, GAT, Attentive FP (graph-based), plus LSTM, CNN, Transformer/<a href="/notes/chemistry/molecular-representations/encoders/chemberta/">ChemBERTa</a> (SMILES-based), and an MLP on ECFPs.</p>
<p>Performance is measured with both standard RMSE and a dedicated $\text{RMSE}_{\text{cliff}}$ computed only on activity cliff compounds in the test set:</p>
<p>$$
\text{RMSE}_{\text{cliff}} = \sqrt{\frac{\sum_{j=1}^{n_c} (\hat{y}_j - y_j)^2}{n_c}}
$$</p>
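<p>A small NumPy sketch of the two metrics, assuming a boolean mask that marks which test compounds are activity cliff molecules (the function and argument names are hypothetical):</p>

```python
import numpy as np

def rmse(y_true, y_pred):
    """Root-mean-square error over all supplied compounds."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return float(np.sqrt(np.mean((y_pred - y_true) ** 2)))

def rmse_cliff(y_true, y_pred, is_cliff):
    """RMSE restricted to the n_c test compounds flagged as activity cliffs."""
    mask = np.asarray(is_cliff, dtype=bool)
    if not mask.any():
        raise ValueError("no activity cliff compounds in the test set")
    y_true = np.asarray(y_true, dtype=float)[mask]
    y_pred = np.asarray(y_pred, dtype=float)[mask]
    return rmse(y_true, y_pred)
```

<p>For example, with true values (6, 7, 8, 5), predictions (6, 7, 6, 5), and only the third compound flagged as a cliff, the overall RMSE is 1.0 while $\text{RMSE}_{\text{cliff}}$ is 2.0: a single badly predicted cliff compound is diluted in the overall score.</p>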
<p>Key results:</p>
<ul>
<li><strong>Molecular descriptors matter more than algorithms</strong>: The choice of descriptor (ECFPs vs. MACCS vs. WHIM vs. physicochemical) had a larger impact on $\text{RMSE}_{\text{cliff}}$ than the choice of ML algorithm ($p &lt; 0.05$, <a href="https://en.wikipedia.org/wiki/Mann%E2%80%93Whitney_U_test">Wilcoxon rank-sum test</a> with <a href="https://en.wikipedia.org/wiki/False_discovery_rate">Benjamini-Hochberg correction</a>).</li>
<li><strong>SVM + ECFPs wins on average</strong>: The best overall method for activity cliff prediction, though the difference from RF + ECFPs or GBM + ECFPs was not statistically significant.</li>
<li><strong>Deep learning underperforms</strong>: All graph and SMILES-based deep learning methods performed worse than a simple MLP on ECFPs. Among deep learning, LSTM with transfer learning (pretrained on 36K molecules) was the best, outperforming the ChemBERTa transformer pretrained on 10M compounds.</li>
<li><strong>Large case-by-case variation</strong>: $\text{RMSE}_{\text{cliff}}$ ranged from 0.62 to 1.60 log units across datasets, with no method consistently best. Deep learning methods showed the highest variance across targets.</li>
</ul>
<h2 id="simple-descriptors-beat-complex-architectures-on-cliffs">Simple Descriptors Beat Complex Architectures on Cliffs</h2>
<p>The core finding is that activity cliffs expose a gap in learned molecular representations. Despite graph neural networks and transformers being able to learn directly from molecular structure, they fail to capture the subtle structural differences that drive activity cliffs.</p>
<p>Key observations:</p>
<ul>
<li><strong>RMSE and $\text{RMSE}_{\text{cliff}}$ correlate ($r = 0.81$ on average)</strong>, so optimizing overall error usually helps with cliffs too. But this correlation breaks down for some targets (e.g., CLK4), where methods with similar RMSE can have very different $\text{RMSE}_{\text{cliff}}$.</li>
<li><strong>Training set size matters for the RMSE/$\text{RMSE}_{\text{cliff}}$ correlation</strong>: Datasets with $&gt; 1000$ training molecules show $r &gt; 0.80$ between the two metrics. In low-data regimes, the correlation weakens, making dedicated cliff evaluation more important.</li>
<li><strong>No relationship between % cliff compounds and model performance</strong>, and no target-family-specific effects were found.</li>
<li><strong>Transfer learning helped SMILES models (LSTM) but not graph models</strong>: Self-supervised pretraining strategies (context prediction, infomax, edge prediction, masking) did not improve GNN performance, consistent with findings from other studies.</li>
</ul>
<p>The MoleculeACE platform provides standardized data curation, activity cliff detection, and cliff-specific evaluation, enabling researchers to assess new methods against this benchmark.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Source</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Benchmarking</td>
          <td>ChEMBL v29</td>
          <td>48,707 molecules (35,632 unique) across 30 targets</td>
          <td>Curated for duplicates, salts, outliers</td>
      </tr>
      <tr>
          <td>Smallest dataset</td>
          <td>JAK1</td>
          <td>615 molecules</td>
          <td>7% activity cliffs</td>
      </tr>
      <tr>
          <td>Largest dataset</td>
          <td>DRD3</td>
          <td>3,657 molecules</td>
          <td>39% activity cliffs</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Activity cliff detection</strong>: Pairwise similarity $\geq 0.9$ (Tanimoto on ECFPs, scaffold ECFPs, or Levenshtein on SMILES) with $&gt; 10\times$ potency difference</li>
<li><strong>Splitting</strong>: <a href="https://en.wikipedia.org/wiki/Spectral_clustering">Spectral clustering</a> on ECFPs (5 clusters), 80/20 stratified split preserving cliff proportion</li>
<li><strong>Hyperparameter optimization</strong>: <a href="https://en.wikipedia.org/wiki/Bayesian_optimization">Bayesian optimization</a> with Gaussian process, max 50 combinations, 5-fold cross-validation</li>
<li><strong>SMILES augmentation</strong>: 10-fold for all SMILES-based methods</li>
<li><strong>Transfer learning</strong>: LSTM pretrained on 36,281 merged training molecules (next-character prediction); ChemBERTa pretrained on 10M <a href="https://en.wikipedia.org/wiki/PubChem">PubChem</a> compounds</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li><strong>Traditional ML</strong>: KNN, RF, GBM, SVM (scikit-learn v1.0.2)</li>
<li><strong>Descriptors</strong>: ECFPs (1024-bit, radius 2), MACCS keys (166-bit), WHIM (114 descriptors), physicochemical (11 properties)</li>
<li><strong>GNNs</strong>: MPNN, GCN, GAT, AFP (PyTorch Geometric v2.0.4), with graph multiset transformer pooling</li>
<li><strong>SMILES models</strong>: LSTM (4 layers, 5.8M params), 1D CNN, ChemBERTa transformer</li>
<li><strong>Total models trained</strong>: 720 (24 methods $\times$ 30 targets)</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Scope</th>
          <th>Details</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>RMSE</td>
          <td>All test molecules</td>
          <td>Standard root-mean-square error on $\text{pK}_i$ / $\text{pEC}_{50}$</td>
      </tr>
      <tr>
          <td>$\text{RMSE}_{\text{cliff}}$</td>
          <td>Activity cliff compounds only</td>
          <td>RMSE restricted to cliff molecules in test set</td>
      </tr>
  </tbody>
</table>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/molML/MoleculeACE">MoleculeACE</a></td>
          <td>Code + Data</td>
          <td>MIT</td>
          <td>Benchmark platform with all 30 curated datasets</td>
      </tr>
      <tr>
          <td><a href="https://github.com/molML/MoleculeACE/tree/main/MoleculeACE/Data/benchmark_data">Curated datasets</a></td>
          <td>Data</td>
          <td>MIT</td>
          <td>Processed ChEMBL bioactivity data</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: van Tilborg, D., Alenicheva, A., &amp; Grisoni, F. (2022). Exposing the Limitations of Molecular Machine Learning with Activity Cliffs. <em>Journal of Chemical Information and Modeling</em>, 62(23), 5938-5951. <a href="https://doi.org/10.1021/acs.jcim.2c01073">https://doi.org/10.1021/acs.jcim.2c01073</a></p>
<p><strong>Publication</strong>: Journal of Chemical Information and Modeling 2022</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/molML/MoleculeACE">MoleculeACE GitHub Repository</a></li>
<li><a href="https://chemrxiv.org/engage/chemrxiv/article-details/630cc44058843b8403a19810">ChemRxiv Preprint</a></li>
</ul>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{vantilborg2022activity,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Exposing the Limitations of Molecular Machine Learning with Activity Cliffs}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{van Tilborg, Derek and Alenicheva, Alisa and Grisoni, Francesca}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Journal of Chemical Information and Modeling}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{62}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{23}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{5938--5951}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2022}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{American Chemical Society}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1021/acs.jcim.2c01073}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>D3PM: Discrete Denoising Diffusion Probabilistic Models</title><link>https://hunterheidenreich.com/notes/machine-learning/generative-models/discrete-diffusion-models/</link><pubDate>Sun, 15 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/generative-models/discrete-diffusion-models/</guid><description>D3PMs extend diffusion models to discrete data with structured transition matrices, connecting diffusion to masked language models.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is a <strong>Method</strong> paper. It extends denoising diffusion probabilistic models (DDPMs) from continuous to discrete state-spaces by introducing structured Markov transition matrices for the corruption process. The paper unifies several corruption strategies, draws a formal connection between absorbing-state diffusion and masked language models, and demonstrates competitive results on both image and text generation.</p>
<h2 id="diffusion-beyond-continuous-spaces">Diffusion Beyond Continuous Spaces</h2>
<p>Standard DDPMs operate in continuous state-spaces (e.g., pixel values treated as real numbers) and use Gaussian noise for corruption. Many important data types are inherently discrete: text (tokens from a vocabulary), quantized images (discrete pixel values), molecular structures, and segmentation maps. Prior work by Hoogeboom et al. extended binary diffusion to multinomial diffusion with uniform transition probabilities, but this limits the structure of the corruption process. D3PMs generalize this by allowing arbitrary transition matrices that encode domain-specific inductive biases.</p>
<h2 id="core-innovation-structured-transition-matrices">Core Innovation: Structured Transition Matrices</h2>
<p>D3PMs define a forward corruption process over discrete variables $\mathbf{x} \in \{1, \ldots, K\}^D$ using transition matrices $\mathbf{Q}_t \in \mathbb{R}^{K \times K}$:</p>
<p>$$q(\mathbf{x}_t | \mathbf{x}_{t-1}) = \text{Cat}(\mathbf{x}_t; \mathbf{p} = \mathbf{x}_{t-1} \mathbf{Q}_t)$$</p>
<p>where $\mathbf{x}_{t-1}$ is a one-hot row vector. The cumulative transition after $t$ steps is $\overline{\mathbf{Q}}_t = \mathbf{Q}_1 \mathbf{Q}_2 \cdots \mathbf{Q}_t$, giving:</p>
<p>$$q(\mathbf{x}_t | \mathbf{x}_0) = \text{Cat}(\mathbf{x}_t; \mathbf{p} = \mathbf{x}_0 \overline{\mathbf{Q}}_t)$$</p>
<p>The paper explores several transition matrix designs:</p>
<p><strong>Uniform diffusion:</strong> $[\mathbf{Q}_t]_{ij} = (1 - \beta_t) \mathbf{1}_{i=j} + \beta_t / K$. With probability $\beta_t$ the state is resampled uniformly over all $K$ states. The stationary distribution is uniform.</p>
<p><strong>Absorbing state:</strong> $[\mathbf{Q}_t]_{ij} = (1-\beta_t)\mathbf{1}_{i=j\neq m} + \beta_t \mathbf{1}_{i\neq m,\, j=m} + \mathbf{1}_{i=j=m}$. Each non-mask token transitions to a designated absorbing state $m$ (e.g., [MASK] for text, gray pixel for images) with probability $\beta_t$ per step, while tokens already at $m$ remain there. This establishes a direct connection to masked language models like BERT.</p>
<p><strong>Discretized Gaussian:</strong> Transition probabilities decay as a function of the distance $|i-j|$ between states, mimicking Gaussian diffusion on ordinal data like pixel values.</p>
<p><strong>Embedding-based nearest neighbor:</strong> For text, transitions are weighted by proximity in a pretrained word embedding space, so corruption preferentially swaps words with semantically similar ones.</p>
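<p>As a minimal sketch of the first two designs, the pure-Python snippet below builds uniform and absorbing matrices, multiplies them into $\overline{\mathbf{Q}}_t$, and samples $\mathbf{x}_t$ from the row selected by $\mathbf{x}_0$. The helper names (<code>uniform_matrix</code>, <code>absorbing_matrix</code>, <code>cumulative</code>, <code>sample_xt</code>) and the toy values ($K = 4$, three $\beta_t$ values) are illustrative assumptions, not taken from the paper or its JAX code.</p>

```python
import random


def uniform_matrix(K, beta):
    """[Q]_ij = (1 - beta) * 1[i == j] + beta / K."""
    return [[(1.0 - beta) * (i == j) + beta / K for j in range(K)] for i in range(K)]


def absorbing_matrix(K, beta, m):
    """Non-mask states stay with prob 1 - beta or jump to m with prob beta; m stays put."""
    Q = [[0.0] * K for _ in range(K)]
    for i in range(K):
        if i == m:
            Q[i][m] = 1.0
        else:
            Q[i][i] = 1.0 - beta
            Q[i][m] = beta
    return Q


def matmul(A, B):
    """Plain matrix product for small list-of-lists matrices."""
    inner, cols = len(B), len(B[0])
    return [[sum(row[l] * B[l][j] for l in range(inner)) for j in range(cols)] for row in A]


def cumulative(Qs):
    """Q_bar_t = Q_1 Q_2 ... Q_t, starting from the identity."""
    K = len(Qs[0])
    out = [[float(i == j) for j in range(K)] for i in range(K)]
    for Q in Qs:
        out = matmul(out, Q)
    return out


def sample_xt(x0, Q_bar, rng):
    """Draw x_t ~ Cat(row x0 of Q_bar), i.e. the one-hot x0 times Q_bar."""
    return rng.choices(range(len(Q_bar)), weights=Q_bar[x0], k=1)[0]


# Toy usage: K = 4 states, state 3 plays the role of [MASK] (assumed values).
betas = [0.1, 0.2, 0.3]
Q_bar_absorbing = cumulative([absorbing_matrix(4, b, m=3) for b in betas])
x_t = sample_xt(0, Q_bar_absorbing, random.Random(0))
```

<p>With these toy values, a non-mask token survives all three steps with probability $0.9 \times 0.8 \times 0.7 = 0.504$, which is the diagonal entry of the cumulative absorbing matrix; the remaining mass sits in the mask column.</p>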
<p><strong>Training objective.</strong> The reverse process $p_\theta(\mathbf{x}_{t-1} | \mathbf{x}_t)$ is parameterized by predicting $\tilde{p}_\theta(\tilde{\mathbf{x}}_0 | \mathbf{x}_t)$ and computing the posterior:</p>
<p>$$p_\theta(\mathbf{x}_{t-1} | \mathbf{x}_t) \propto \sum_{\tilde{\mathbf{x}}_0} q(\mathbf{x}_{t-1} | \mathbf{x}_t, \tilde{\mathbf{x}}_0) \, \tilde{p}_\theta(\tilde{\mathbf{x}}_0 | \mathbf{x}_t)$$</p>
<p>The combined loss $L_\lambda$ adds an auxiliary cross-entropy term to the variational lower bound (VLB):</p>
<p>$$L_\lambda = L_{\text{VLB}} + \lambda \, L_{\text{CE}}$$</p>
<p>where $L_{\text{CE}}$ is a reweighted cross-entropy loss on the $\mathbf{x}_0$ prediction that stabilizes training and improves sample quality. The VLB decomposes into per-timestep KL divergences between the true and predicted reverse transitions.</p>
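<p>The $\mathbf{x}_0$-parameterized reverse step can be sketched in a few lines. By Bayes&rsquo; rule, $q(x_{t-1} | x_t, x_0) \propto q(x_t | x_{t-1})\, q(x_{t-1} | x_0)$, which only needs $\mathbf{Q}_t$ and $\overline{\mathbf{Q}}_{t-1}$. The function names and the toy absorbing matrix are assumptions for illustration, and skipping $\tilde{\mathbf{x}}_0$ values that cannot produce the observed $\mathbf{x}_t$ is a choice made here for numerical safety, not a detail confirmed by the paper.</p>

```python
def posterior(x_t, x0, Q_t, Q_bar_prev):
    """q(x_{t-1} | x_t, x0) via Bayes' rule; None if x_t is unreachable from x0."""
    K = len(Q_t)
    # q(x_{t-1} = j | x0) * q(x_t | x_{t-1} = j) for every candidate j
    unnorm = [Q_bar_prev[x0][j] * Q_t[j][x_t] for j in range(K)]
    Z = sum(unnorm)  # equals the (x0, x_t) entry of Q_bar_t
    if Z == 0.0:
        return None
    return [u / Z for u in unnorm]


def reverse_step(x_t, p_x0, Q_t, Q_bar_prev):
    """p_theta(x_{t-1} | x_t): mix posteriors over the predicted x0, then renormalize.

    Assumes at least one x0 with non-zero predicted mass can reach x_t.
    """
    K = len(Q_t)
    mix = [0.0] * K
    for x0, weight in enumerate(p_x0):
        post = posterior(x_t, x0, Q_t, Q_bar_prev)
        if post is None or weight == 0.0:
            continue
        for j in range(K):
            mix[j] += weight * post[j]
    Z = sum(mix)
    return [v / Z for v in mix]


# Toy absorbing chain with K = 3 and mask state 2 (assumed values), t = 2.
Q = [[0.5, 0.0, 0.5],
     [0.0, 0.5, 0.5],
     [0.0, 0.0, 1.0]]
predicted_x0 = [0.7, 0.3, 0.0]  # stand-in for the network output
from_mask = reverse_step(2, predicted_x0, Q, Q)
from_token = reverse_step(0, predicted_x0, Q, Q)
```

<p>In this toy chain, an unmasked $\mathbf{x}_t$ can only have come from the same unmasked token, so <code>from_token</code> puts all of its mass on state 0 regardless of what the network predicts.</p>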
<h2 id="experiments-and-results">Experiments and Results</h2>
<p><strong>Image generation (CIFAR-10):</strong></p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Loss</th>
          <th>IS</th>
          <th>FID</th>
          <th>NLL (bpd)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>D3PM uniform</td>
          <td>$L_{\text{VLB}}$</td>
          <td>5.99</td>
          <td>51.27</td>
          <td>5.08</td>
      </tr>
      <tr>
          <td>D3PM absorbing</td>
          <td>$L_\lambda$ ($\lambda{=}0.001$)</td>
          <td>6.78</td>
          <td>30.97</td>
          <td>4.40</td>
      </tr>
      <tr>
          <td>D3PM Gauss</td>
          <td>$L_{\text{VLB}}$</td>
          <td>7.75</td>
          <td>15.30</td>
          <td>3.97</td>
      </tr>
      <tr>
          <td>D3PM Gauss</td>
          <td>$L_\lambda$ ($\lambda{=}0.001$)</td>
          <td>8.54</td>
          <td>8.34</td>
          <td>3.98</td>
      </tr>
      <tr>
          <td>D3PM Gauss + logistic</td>
          <td>$L_\lambda$ ($\lambda{=}0.001$)</td>
          <td>8.56</td>
          <td>7.34</td>
          <td>3.44</td>
      </tr>
      <tr>
          <td>DDPM $L_{\text{simple}}$ (continuous)</td>
          <td>&ndash;</td>
          <td>9.46</td>
          <td>3.17</td>
          <td>3.75</td>
      </tr>
  </tbody>
</table>
<p>The best discrete D3PM variant is D3PM Gauss + logistic, which achieves FID 7.34 and NLL 3.44 bpd using the combined $L_\lambda$ loss with a truncated logistic parameterization. The truncated logistic parameterization replaces the standard softmax output with a discretized logistic distribution over pixel values, assigning probability mass to each discrete bin based on a continuous logistic CDF. This provides a smoother output distribution that better captures the ordinal structure of pixel intensities. This variant exceeds the continuous DDPM in log-likelihood (3.44 vs. 3.75 bpd) while approaching its sample quality (FID 7.34 vs. 3.17).</p>
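<p>To make the bin-mass idea concrete, here is a hedged sketch of a discretized logistic over $K$ ordinal bins, renormalized so that all mass falls inside the valid pixel range. The location <code>mu</code>, scale <code>s</code>, the half-integer bin edges, and the function names are assumptions chosen for illustration; the paper&rsquo;s exact parameterization may differ in detail.</p>

```python
import math


def logistic_cdf(x, mu, s):
    """Numerically stable logistic CDF with location mu and scale s."""
    z = (x - mu) / s
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def truncated_discretized_logistic(mu, s, K=256):
    """Probability of each of K ordinal bins; bin k covers [k - 0.5, k + 0.5]."""
    edges = [logistic_cdf(k - 0.5, mu, s) for k in range(K + 1)]
    mass = [edges[k + 1] - edges[k] for k in range(K)]
    total = edges[K] - edges[0]  # mass inside the truncated support
    return [m / total for m in mass]


# Toy usage: a prediction centred on pixel value 100 with scale 2 (assumed values).
probs = truncated_discretized_logistic(mu=100.0, s=2.0)
```

<p>Neighbouring bins receive similar probability, which is the ordinal smoothness a plain softmax over 256 unrelated classes does not build in.</p>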
<p><strong>Text generation (text8, character-level, 1000 steps):</strong></p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>bpc</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>D3PM absorbing ($L_\lambda$)</td>
          <td>1.45</td>
      </tr>
      <tr>
          <td>D3PM NN ($L_{\text{VLB}}$)</td>
          <td>1.59</td>
      </tr>
      <tr>
          <td>D3PM uniform</td>
          <td>1.61</td>
      </tr>
      <tr>
          <td>Discrete Flow (Tran et al.)</td>
          <td>1.23</td>
      </tr>
  </tbody>
</table>
<p>On text8, D3PM absorbing is the best D3PM variant (1.45 bpc) but still trails Discrete Flow (1.23 bpc; Tran et al., 2019). On LM1B (sentencepiece vocabulary of 8192 tokens), D3PM absorbing achieves a perplexity of 76.9 at 1000 steps, compared to 137.9 for D3PM uniform and 43.6 for a comparable autoregressive transformer, demonstrating that discrete diffusion scales to large vocabularies.</p>
<p><strong>Ablation findings:</strong></p>
<ul>
<li>The auxiliary cross-entropy term is critical: for D3PM Gauss, moving from $L_{\text{VLB}}$ to the combined loss $L_\lambda$ ($\lambda{=}0.001$) improves FID from 15.30 to 8.34. Adding the truncated logistic parameterization further improves FID to 7.34.</li>
<li>Discretized Gaussian transitions outperform both uniform and absorbing-state transitions on CIFAR-10 across all metrics.</li>
<li>For text, the absorbing-state (mask) model outperforms uniform and nearest-neighbor models. Nearest-neighbor diffusion provides only marginal improvement over uniform, a surprising negative result.</li>
<li>The $\mathbf{x}_0$-parameterization ensures the learned reverse distribution has the correct sparsity pattern dictated by the transition matrix $\mathbf{Q}_t$.</li>
</ul>
<h2 id="findings-and-limitations">Findings and Limitations</h2>
<ul>
<li>The choice of transition matrix is an important design decision that encodes domain-specific inductive biases. Discretized Gaussian transitions work best for ordinal image data; absorbing-state transitions work best for text.</li>
<li>D3PMs formally unify diffusion models and masked language models: absorbing-state diffusion with a [MASK] token is equivalent to a reweighted BERT-style training objective.</li>
<li>The combined VLB + auxiliary loss ($L_\lambda$) achieves better density estimation (3.44 bpd) than continuous DDPMs (3.75 bpd) while producing competitive samples.</li>
<li>Sample quality (best FID 7.34 for D3PM Gauss + logistic) still lags behind continuous-space DDPMs (FID 3.17) on CIFAR-10, though the gap narrows with structured transitions and the auxiliary loss.</li>
<li>Scaling to very large numbers of categories $K$ requires special techniques (low-rank corruption or matrix exponentials) to manage the $O(K^2 T)$ memory cost of storing transition matrices.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Image generation</td>
          <td>CIFAR-10</td>
          <td>32x32, 256 categories</td>
          <td>Quantized to 256 ordinal values per channel</td>
      </tr>
      <tr>
          <td>Text generation</td>
          <td>text8</td>
          <td>Character-level</td>
          <td>27 character vocabulary, sequences of length 256</td>
      </tr>
      <tr>
          <td>Text generation</td>
          <td>LM1B</td>
          <td>Subword-level</td>
          <td>Sentencepiece vocabulary of 8192 tokens, sequence length 128</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Noise schedules</strong>: Linear schedule for D3PM Gauss, cosine schedule for D3PM uniform, and a novel mutual information schedule for absorbing and nearest-neighbor models</li>
<li><strong>Reverse parameterization</strong>: $\mathbf{x}_0$-parameterization with posterior computation via Bayes&rsquo; rule</li>
<li><strong>Loss</strong>: $L_{\text{VLB}} + \lambda L_{\text{CE}}$ with $\lambda = 0.001$ for images and $\lambda = 0.01$ for text absorbing models</li>
<li><strong>Scaling</strong>: Low-rank corruption (absorbing, uniform) scales as $O(r^2 T)$; matrix exponentials for nearest-neighbor transitions</li>
</ul>
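<p>One way to see why uniform and absorbing corruption are cheap is that their cumulative matrices collapse to a closed form driven by a single scalar per timestep, $\bar{\alpha}_t = \prod_{s \le t} (1 - \beta_s)$, so no $K \times K$ product has to be stored. The sketch below checks this standard algebraic identity against an explicit step-by-step propagation; the function names and toy schedule are assumptions, not the official implementation.</p>

```python
def alpha_bar(betas):
    """Running product of (1 - beta_s), one scalar per timestep."""
    out, acc = [], 1.0
    for b in betas:
        acc *= 1.0 - b
        out.append(acc)
    return out


def uniform_row(x0, a_bar, K):
    """Row x0 of Q_bar_t for uniform diffusion: a_bar * I + (1 - a_bar) / K."""
    return [a_bar * (j == x0) + (1.0 - a_bar) / K for j in range(K)]


def absorbing_row(x0, a_bar, K, m):
    """Row x0 of Q_bar_t for absorbing diffusion: keep x0 w.p. a_bar, else mask."""
    row = [0.0] * K
    row[x0] += a_bar
    row[m] += 1.0 - a_bar
    return row


def explicit_uniform_row(x0, betas, K):
    """Reference: push the one-hot row through each uniform Q_t in turn."""
    row = [float(j == x0) for j in range(K)]
    for b in betas:
        total = sum(row)
        row = [(1.0 - b) * r + b * total / K for r in row]
    return row


# Toy schedule (assumed values) with K = 5 states.
betas = [0.1, 0.2, 0.3]
closed = uniform_row(2, alpha_bar(betas)[-1], 5)
explicit = explicit_uniform_row(2, betas, 5)
```

<p>Both routes give the same row, but the closed form needs only $O(T)$ scalars in total, whereas dense cumulative matrices would need $O(K^2 T)$ entries.</p>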
<h3 id="models">Models</h3>
<ul>
<li><strong>Image models</strong>: Modified U-Net architecture from Ho et al. (2020) adapted for categorical output via softmax over $K$ classes</li>
<li><strong>Text models</strong>: 12-layer <a href="/notes/natural-language-processing/language-models/t5-text-to-text-transfer-transformer/">T5</a>-style transformer encoder with 70M parameters (12 heads, MLP dim 3072, QKV dim 768)</li>
<li><strong>Timesteps</strong>: $T = 1000$ for both images and text, though text models can be evaluated with fewer steps (e.g., 256 or 20)</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Dataset</th>
          <th>Best D3PM</th>
          <th>Continuous DDPM</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>FID</td>
          <td>CIFAR-10</td>
          <td>7.34 (Gauss + logistic)</td>
          <td>3.17</td>
      </tr>
      <tr>
          <td>NLL (bpd)</td>
          <td>CIFAR-10</td>
          <td>3.44 (Gauss + logistic)</td>
          <td>3.75</td>
      </tr>
      <tr>
          <td>BPC</td>
          <td>text8 (char)</td>
          <td>1.45 (absorbing, $L_\lambda$)</td>
          <td>N/A</td>
      </tr>
      <tr>
          <td>Perplexity</td>
          <td>LM1B</td>
          <td>76.9 (absorbing)</td>
          <td>N/A</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li>All models trained for 1M steps with batch size 512 on TPUv2 or TPUv3</li>
<li>Text models: 12-layer transformer encoder (T5 architecture), 70M parameters</li>
<li>Image models: Modified U-Net architecture from Ho et al. (2020)</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/google-research/google-research/tree/master/d3pm">google-research/d3pm</a></td>
          <td>Code</td>
          <td>Apache-2.0</td>
          <td>Official JAX/Flax implementation for image and text experiments</td>
      </tr>
  </tbody>
</table>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Austin, J., Johnson, D. D., Ho, J., Tarlow, D., &amp; van den Berg, R. (2021). Structured Denoising Diffusion Models in Discrete State-Spaces. <em>NeurIPS 2021</em>. <a href="https://arxiv.org/abs/2107.03006">https://arxiv.org/abs/2107.03006</a></p>
<p><strong>Publication</strong>: NeurIPS 2021</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{austin2021structured,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>     = <span style="color:#e6db74">{Structured Denoising Diffusion Models in Discrete State-Spaces}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>    = <span style="color:#e6db74">{Austin, Jacob and Johnson, Daniel D. and Ho, Jonathan and Tarlow, Daniel and van den Berg, Rianne}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{Advances in Neural Information Processing Systems}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>    = <span style="color:#e6db74">{34}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>      = <span style="color:#e6db74">{2021}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="/notes/machine-learning/generative-models/score-based-generative-modeling-sde/">Score-Based Generative Modeling with SDEs</a></li>
</ul>
]]></content:encoded></item><item><title>GTR-CoT: Graph Traversal Chain-of-Thought for Molecules</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/vision-language/gtr-mol-vlm/</link><pubDate>Sat, 14 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/vision-language/gtr-mol-vlm/</guid><description>GTR-VL uses graph traversal chain-of-thought and two-stage training to improve optical chemical structure recognition on printed and hand-drawn molecules.</description><content:encoded><![CDATA[<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Wang, J., He, Y., Yang, H., Wu, J., Ge, L., Wei, X., Wang, Y., Li, L., Ao, H., Liu, C., Wang, B., Wu, L., &amp; He, C. (2025). GTR-CoT: Graph Traversal as Visual Chain of Thought for Molecular Structure Recognition (arXiv:2506.07553). arXiv. <a href="https://doi.org/10.48550/arXiv.2506.07553">https://doi.org/10.48550/arXiv.2506.07553</a></p>
<p><strong>Publication</strong>: arXiv preprint (2025)</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://doi.org/10.48550/arXiv.2506.07553">Paper on arXiv</a></li>
</ul>
<h2 id="contribution-vision-language-modeling-for-ocsr">Contribution: Vision-Language Modeling for OCSR</h2>
<p>This is a <strong>method paper</strong> that introduces GTR-VL, a Vision-Language Model for Optical Chemical Structure Recognition (OCSR). The work addresses the persistent challenge of converting molecular structure images into machine-readable formats, with a particular focus on handling chemical abbreviations that cause errors in existing systems.</p>
<h2 id="motivation-the-abbreviation-bottleneck">Motivation: The Abbreviation Bottleneck</h2>
<p>The work tackles a long-standing bottleneck in chemical informatics: most existing OCSR systems produce incorrect structures when they encounter abbreviated functional groups. When a chemist draws &ldquo;Ph&rdquo; for phenyl or &ldquo;Et&rdquo; for ethyl, current models fail because they have been trained on data where images contain abbreviations but the ground-truth labels contain fully expanded molecular graphs.</p>
<p>This creates a fundamental mismatch. The model sees &ldquo;Ph&rdquo; in the image but is told the &ldquo;correct&rdquo; answer is a full benzene ring. The supervision signal is inconsistent with what is actually visible.</p>
<p>Beyond this data problem, existing graph-parsing methods use a two-stage approach: predict all atoms first, then predict all bonds. This is inefficient and ignores the structural constraints that could help during prediction. The authors argue that mimicking how humans analyze molecular structures, following bonds from atom to atom in a connected traversal, would be more effective.</p>
<h2 id="novelty-graph-traversal-as-visual-chain-of-thought">Novelty: Graph Traversal as Visual Chain-of-Thought</h2>
<p>The novelty lies in combining two key insights about how to properly train and architect OCSR systems. The main contributions are:</p>
<ol>
<li>
<p><strong>Graph Traversal as Visual Chain of Thought</strong>: GTR-VL generates molecular graphs by traversing them sequentially, predicting an atom, then its connected bond, then the next atom, and so on. This mimics how a human chemist would trace through a structure and allows the model to use previously predicted atoms and bonds as context for subsequent predictions.</p>
<p>Formally, the model output sequence for image $I_m$ is generated as:</p>
<p>$$ R_m = \text{concat}(CoT_m, S_m) $$</p>
<p>where $CoT_m$ represents the deterministic graph traversal steps (atoms and bonds) and $S_m$ is the final SMILES representation. This intermediate reasoning step makes the model more interpretable and helps it learn the structural logic of molecules.</p>
</li>
<li>
<p><strong>&ldquo;Faithfully Recognize What You&rsquo;ve Seen&rdquo; Principle</strong>: This addresses the abbreviation problem head-on. The authors correct the ground-truth annotations to match what&rsquo;s actually visible in the image.</p>
<p>They treat abbreviations like &ldquo;Ph&rdquo; as single &ldquo;superatoms&rdquo; and build a pipeline to automatically detect and correct training data. Using OCR to extract visible text from molecular images, they replace the corresponding expanded substructures in the ground-truth with the appropriate abbreviation tokens. This ensures the supervision signal is consistent with the visual input.</p>
</li>
<li>
<p><strong>Large-Scale Dataset (GTR-1.3M)</strong>: To support this approach, the authors created a large-scale dataset combining 1M synthetic molecules from PubChem with 351K corrected real-world patent images from USPTO. The key innovation is the correction pipeline that identifies abbreviations in patent images and fixes the inconsistent ground-truth labels.</p>
</li>
<li>
<p><strong>GRPO for Hand-Drawn OCSR</strong>: Hand-drawn molecular data lacks fine-grained atom/bond coordinate annotations, making SFT-based graph parsing inapplicable. The authors use Group Relative Policy Optimization (GRPO) with a composite reward function that combines format, SMILES, and graph-level rewards. The graph reward is computed from the maximum common subgraph (MCS) between the predicted and ground-truth molecular graphs:</p>
<p>$$ R_{\text{graph}} = \frac{|N_m^a|}{|N_g^a| + |N_p^a|} + \frac{|N_m^b|}{|N_g^b| + |N_p^b|} $$</p>
<p>where $N_m^a$, $N_g^a$, $N_p^a$ are atom counts in the MCS, ground truth, and prediction, and $N_m^b$, $N_g^b$, $N_p^b$ are the corresponding bond counts.</p>
</li>
<li>
<p><strong>Two-Stage Training</strong>: Stage 1 performs SFT on GTR-1.3M for printed molecule recognition. Stage 2 applies GRPO on a mixture of printed data (GTR-USPTO-4K) and hand-drawn data (DECIMER-HD-Train, 4,070 samples) to extend capabilities to hand-drawn structures.</p>
</li>
<li>
<p><strong>MolRec-Bench Evaluation</strong>: Traditional SMILES-based evaluation fails for molecules with abbreviations because canonicalization breaks down. The authors created a new benchmark that evaluates graph structure directly, providing three metrics: direct SMILES generation, graph-derived SMILES, and exact graph matching.</p>
</li>
</ol>
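<p>The graph reward from the GRPO contribution above is simple arithmetic once the MCS has been found. The sketch below takes the six counts as inputs rather than computing the MCS itself; the function name, argument names, the zero-denominator guard, and the example counts are all hypothetical.</p>

```python
def graph_reward(mcs_atoms, gt_atoms, pred_atoms, mcs_bonds, gt_bonds, pred_bonds):
    """R_graph = |N_m^a| / (|N_g^a| + |N_p^a|) + |N_m^b| / (|N_g^b| + |N_p^b|)."""

    def term(common, ground_truth, predicted):
        denom = ground_truth + predicted
        # Guard for empty graphs (an assumption, not specified in the source).
        return common / denom if denom else 0.0

    return term(mcs_atoms, gt_atoms, pred_atoms) + term(mcs_bonds, gt_bonds, pred_bonds)


# A perfect prediction: the MCS covers both graphs entirely.
perfect = graph_reward(12, 12, 12, 12, 12, 12)

# A partial prediction with made-up counts.
partial = graph_reward(9, 12, 10, 8, 12, 9)
```

<p>Under this formula each term is at most one half, so a perfect prediction scores 1 and any missing or spurious atoms and bonds pull the reward below that.</p>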
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The evaluation focused on demonstrating that GTR-VL&rsquo;s design principles solve real problems that plague existing OCSR systems:</p>
<ol>
<li>
<p><strong>Comprehensive Baseline Comparison</strong>: GTR-VL was tested against three categories of models:</p>
<ul>
<li><strong>Specialist OCSR systems</strong>: MolScribe and MolNexTR</li>
<li><strong>Chemistry-focused VLMs</strong>: ChemVLM, ChemDFM-X, OCSU</li>
<li><strong>General-purpose VLMs</strong>: GPT-4o, GPT-4o-mini, Qwen-VL-Max</li>
</ul>
</li>
<li>
<p><strong>MolRec-Bench Evaluation</strong>: The new benchmark includes two subsets of patent images:</p>
<ul>
<li><strong>MolRec-USPTO</strong>: 5,423 standard patent images similar to existing benchmarks</li>
<li><strong>MolRec-Abb</strong>: 9,311 molecular images with abbreviated superatoms, derived from MolGrapher&rsquo;s USPTO 10K abb subset</li>
</ul>
<p>This design directly tests whether models can handle the abbreviation problem that breaks existing systems.</p>
</li>
<li>
<p><strong>Ablation Studies</strong>: Systematic experiments isolated the contribution of key design choices:</p>
<ul>
<li><strong>Chain-of-Thought vs. Direct</strong>: Comparing graph traversal CoT against direct SMILES prediction</li>
<li><strong>Traversal Strategy</strong>: Graph traversal vs. the traditional &ldquo;atoms-then-bonds&rdquo; approach</li>
<li><strong>Dataset Quality</strong>: Training on corrected vs. uncorrected data</li>
</ul>
</li>
<li>
<p><strong>Retraining Experiments</strong>: Existing specialist models (MolScribe, MolNexTR) were retrained from scratch on the corrected GTR-1.3M dataset to isolate the effect of data quality from architectural improvements.</p>
</li>
<li>
<p><strong>Hand-Drawn OCSR Evaluation</strong>: GTR-VL was also evaluated on the DECIMER Hand-drawn test set and ChemPix dataset, comparing against DECIMER and AtomLenz+EditKT baselines.</p>
</li>
<li>
<p><strong>Qualitative Analysis</strong>: Visual inspection of predictions on challenging cases with heavy abbreviation usage, complex structures, and edge cases to understand failure modes.</p>
</li>
</ol>
<h2 id="results--conclusions-resolving-the-abbreviation-bottleneck">Results &amp; Conclusions: Resolving the Abbreviation Bottleneck</h2>
<ul>
<li>
<p><strong>Performance Gains on Abbreviations</strong>: On MolRec-Abb, GTR-VL-Stage1 achieves 85.49% Graph accuracy, whereas MolScribe and MolNexTR with their original checkpoints fall to around 20% once abbreviations are present. On MolRec-USPTO, GTR-VL-Stage1 reaches 93.45% Graph accuracy.</p>
</li>
<li>
<p><strong>Data Correction is Critical</strong>: When MolScribe and MolNexTR were retrained on GTR-1.3M, their MolRec-Abb Graph accuracy jumped from around 20% to 70.60% and 71.85% respectively. GTR-VL-Stage1 still outperformed these retrained baselines at 85.49%, confirming that both data correction and the graph traversal approach contribute.</p>
</li>
<li>
<p><strong>Chain-of-Thought Helps</strong>: Ablation on GTR-USPTO-351K shows that CoT yields 68.85% Gen-SMILES vs. 66.54% without CoT, a 2.31 percentage point improvement.</p>
</li>
<li>
<p><strong>Graph Traversal Beats Traditional Parsing</strong>: Graph traversal achieves 83.26% Graph accuracy vs. 80.15% for the atoms-then-bonds approach, and 81.88% vs. 79.02% on Gra-SMILES.</p>
</li>
<li>
<p><strong>General VLMs Still Struggle</strong>: General-purpose VLMs like GPT-4o scored near 0% on MolRec-Bench across all metrics, highlighting the importance of domain-specific training for OCSR.</p>
</li>
<li>
<p><strong>Hand-Drawn Recognition via GRPO</strong>: GTR-VL-Stage1 (SFT only) achieves only 9.53% Graph accuracy on DECIMER-HD-Test, but after GRPO training in Stage 2, performance jumps to 75.44%. On ChemPix, Graph accuracy rises from 22.02% to 86.13%. The graph reward is essential: GRPO without graph supervision achieves only 11.00% SMILES on DECIMER-HD-Test, while adding graph reward reaches 75.64%.</p>
</li>
<li>
<p><strong>Evaluation Methodology Matters</strong>: The new graph-based evaluation metrics revealed problems with traditional SMILES-based evaluation that previous work had missed. Many &ldquo;failures&rdquo; in existing benchmarks were actually correct graph predictions that got marked wrong due to canonicalization issues with abbreviations.</p>
</li>
</ul>
<p>The work establishes that addressing the abbreviation problem requires both correcting the training data and rethinking the model architecture. The combination of faithful data annotation and sequential graph generation improves OCSR performance on molecules with abbreviations by a large margin over previous methods.</p>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="models">Models</h3>
<p><strong>Base Model</strong>: GTR-VL fine-tunes <strong>Qwen2.5-VL</strong>.</p>
<p><strong>Input/Output Mechanism</strong>:</p>
<ul>
<li><strong>Input</strong>: The model takes an image $I_m$ and a text prompt</li>
<li><strong>Output</strong>: The model generates $R_m = \text{concat}(CoT_m, S_m)$, where it first produces the Chain-of-Thought (the graph traversal steps) followed immediately by the final SMILES string</li>
<li><strong>Traversal Strategy</strong>: Uses <strong>depth-first traversal</strong> to alternately predict atoms and bonds</li>
</ul>
<p><strong>Prompt Structure</strong>: The model is prompted to &ldquo;list the types of atomic elements&hellip; the coordinates&hellip; and the chemical bonds&hellip; then&hellip; output a canonical SMILES&rdquo;. The CoT output is formatted as a JSON list of atoms (with coordinates) and bonds (with indices referring to previous atoms), interleaved.</p>
<h3 id="data">Data</h3>
<p><strong>Training Dataset (GTR-1.3M)</strong>:</p>
<ul>
<li><strong>Synthetic Component</strong>: 1 million molecular SMILES from PubChem, converted to images using Indigo</li>
<li><strong>Real Component</strong>: 351,000 samples from USPTO patents (filtered from an original 680,000)
<ul>
<li>Processed using an OCR pipeline to detect abbreviations (e.g., &ldquo;Ph&rdquo;, &ldquo;Et&rdquo;)</li>
<li>Ground truth expanded structures replaced with superatoms to match visible abbreviations in images</li>
<li>This &ldquo;Faithfully Recognize What You&rsquo;ve Seen&rdquo; correction ensures training supervision matches visual input</li>
</ul>
</li>
</ul>
<p><strong>Evaluation Dataset (MolRec-Bench)</strong>:</p>
<ul>
<li><strong>MolRec-USPTO</strong>: 5,423 molecular images from USPTO patents</li>
<li><strong>MolRec-Abb</strong>: 9,311 molecular images with abbreviated superatoms, derived from MolGrapher&rsquo;s USPTO 10K abb subset</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<p><strong>Graph Traversal Algorithm</strong>:</p>
<ul>
<li>Depth-first traversal strategy</li>
<li>Alternating atom-bond prediction sequence</li>
<li>Each step uses previously predicted atoms and bonds as context</li>
</ul>
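<p>A depth-first traversal that interleaves atoms and bonds can be sketched as follows. The record layout (<code>atom</code>, <code>coords</code>, <code>bond</code>, <code>between</code>) is a hypothetical stand-in for the JSON format the paper describes, and the sketch assumes a single connected molecule; it is not the authors&rsquo; serialization code.</p>

```python
import json


def traverse(atoms, bonds, start=0):
    """Depth-first walk emitting atoms and bonds interleaved.

    atoms: list of (symbol, x, y); bonds: list of (i, j, bond_type).
    Bond records refer to atoms by their position in the emitted order,
    so every bond only points at atoms that have already been emitted.
    """
    adjacency = {i: [] for i in range(len(atoms))}
    for bond_id, (i, j, kind) in enumerate(bonds):
        adjacency[i].append((j, bond_id, kind))
        adjacency[j].append((i, bond_id, kind))

    order = {}          # original atom index -> index in the emitted sequence
    used_bonds = set()
    steps = []

    def emit_atom(a):
        order[a] = len(order)
        symbol, x, y = atoms[a]
        steps.append({"atom": symbol, "coords": [x, y]})

    def visit(a):
        for neighbour, bond_id, kind in adjacency[a]:
            if bond_id in used_bonds:
                continue
            used_bonds.add(bond_id)
            is_new = neighbour not in order
            if is_new:
                emit_atom(neighbour)
            # For an already-visited neighbour this is a ring-closure bond.
            steps.append({"bond": kind, "between": [order[a], order[neighbour]]})
            if is_new:
                visit(neighbour)

    emit_atom(start)
    visit(start)
    return steps


# Toy molecule with a superatom: C-C-Ph, where "Ph" is kept as a single node.
atoms = [("C", 0.0, 0.0), ("C", 1.0, 0.0), ("Ph", 2.0, 0.0)]
bonds = [(0, 1, "single"), (1, 2, "single")]
chain_of_thought = json.dumps(traverse(atoms, bonds))
```

<p>Keeping &ldquo;Ph&rdquo; as one node mirrors the superatom treatment: the traversal records what is drawn, and expansion to a full ring is left to a later step.</p>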
<p><strong>Two-Stage Training</strong>:</p>
<ul>
<li><strong>Stage 1 (SFT)</strong>: Train on GTR-1.3M to learn visual CoT mechanism for printed molecules (produces GTR-VL-Stage1)</li>
<li><strong>Stage 2 (GRPO)</strong>: Apply GRPO on GTR-USPTO-4K + DECIMER-HD-Train (4,070 samples) for hand-drawn recognition (produces GTR-VL-Stage2, i.e., GTR-VL)</li>
</ul>
<p><strong>Training Procedure</strong>:</p>
<ul>
<li><strong>Optimizer</strong>: AdamW</li>
<li><strong>Learning Rate (SFT)</strong>: Peak learning rate of $1.6 \times 10^{-4}$ with cosine decay</li>
<li><strong>Learning Rate (GRPO)</strong>: Peak learning rate of $1 \times 10^{-5}$ with cosine decay</li>
<li><strong>Warm-up</strong>: Linear warm-up for the first 10% of iterations</li>
<li><strong>Batch Size (SFT)</strong>: 2 per GPU with gradient accumulation over 16 steps on 32 GPUs, yielding an <strong>effective batch size of 1024</strong> ($2 \times 16 \times 32$)</li>
<li><strong>Batch Size (GRPO)</strong>: 4 per GPU with gradient accumulation of 1 on 32 GPUs, yielding an <strong>effective batch size of 128</strong> ($4 \times 1 \times 32$)</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p><strong>Metrics</strong> (three complementary measures to handle abbreviation issues):</p>
<ul>
<li><strong>Gen-SMILES</strong>: Exact match ratio of SMILES strings directly generated by the VLM (image-captioning style)</li>
<li><strong>Gra-SMILES</strong>: Exact match ratio of SMILES strings derived from the predicted graph structure (graph-parsing style)</li>
<li><strong>Graph</strong>: Exact match ratio between ground truth and predicted graphs (node/edge comparison, bypassing SMILES canonicalization issues)</li>
</ul>
<p><strong>Baselines Compared</strong>:</p>
<ul>
<li>Specialist OCSR systems: MolScribe, MolNexTR</li>
<li>Chemistry-focused VLMs: ChemVLM, ChemDFM-X, OCSU</li>
<li>General-purpose VLMs: GPT-4o, GPT-4o-mini, Qwen-VL-Max</li>
</ul>
<h3 id="hardware">Hardware</h3>
<p><strong>Compute</strong>: Training performed on <strong>32 NVIDIA A100 GPUs</strong></p>
<h3 id="reproducibility-status">Reproducibility Status</h3>
<p><strong>Status</strong>: Closed. As of the paper&rsquo;s publication, no source code, pre-trained model weights, or dataset downloads (GTR-1.3M, MolRec-Bench) have been publicly released. The paper does not mention plans for open-source release. The training data pipeline relies on PubChem SMILES (public), USPTO patent images (publicly available through prior work), the Indigo rendering tool (open-source), and an unspecified OCR system for abbreviation detection. Without the released code and data corrections, reproducing the full pipeline would require substantial re-implementation effort.</p>
]]></content:encoded></item><item><title>Molecular Sets (MOSES): A Generative Modeling Benchmark</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/evaluation/molecular-sets-moses/</link><pubDate>Mon, 16 Feb 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/evaluation/molecular-sets-moses/</guid><description>MOSES provides a standardized benchmarking platform for molecular generative models, featuring datasets, metrics, and baselines.</description><content:encoded><![CDATA[<h2 id="the-role-of-moses-a-benchmarking-resource">The Role of MOSES: A Benchmarking Resource</h2>
<p>This is a <strong>Resource and Benchmarking</strong> paper. It introduces Molecular Sets (MOSES), a platform designed to standardize the training, comparison, and evaluation of molecular generative models. It provides a standardized dataset, a suite of evaluation metrics, and a collection of baseline models to serve as reference points for the field.</p>
<h2 id="motivation-the-reproducibility-crisis-in-generative-chemistry">Motivation: The Reproducibility Crisis in Generative Chemistry</h2>
<p>Generative models are increasingly popular for drug discovery and material design, capable of exploring the vast chemical space ($10^{23}$ to $10^{80}$ compounds) more efficiently than traditional methods. However, the field faces a significant reproducibility crisis:</p>
<ol>
<li><strong>Lack of Standardization</strong>: There is no consensus on how to properly compare and rank the efficacy of different generative models.</li>
<li><strong>Inconsistent Metrics</strong>: Different papers use different metrics or distinct implementations of the same metrics.</li>
<li><strong>Data Variance</strong>: Models are often trained on different subsets of chemical databases (like ZINC), making direct comparison impossible.</li>
</ol>
<p>MOSES aims to solve these issues by providing a unified &ldquo;measuring stick&rdquo; for distribution learning models in chemistry.</p>
<h2 id="core-innovation-standardizing-chemical-distribution-learning">Core Innovation: Standardizing Chemical Distribution Learning</h2>
<p>The core contribution is the <strong>standardization of the distribution learning definition</strong> for molecular generation. Why focus on distribution learning? Rule-based filters enforce strict boundaries like molecular weight limits. Distribution learning complements this by allowing chemists to apply <strong>implicit or soft restrictions</strong>. This ensures that generated molecules satisfy hard constraints and reflect complex chemical realities defined by the training distribution. These realities include the prevalence of certain substructures and the avoidance of unstable motifs.</p>
<p>MOSES specifically targets distribution learning by providing:</p>
<ol>
<li><strong>A Clean, Standardized Dataset</strong>: A specific subset of the ZINC Clean Leads collection with rigorous filtering.</li>
<li><strong>Diverse Metrics</strong>: A comprehensive suite of metrics that measure validity alongside novelty, diversity (internal and external), chemical properties (properties distribution), and substructure similarity.</li>
<li><strong>Open Source Platform</strong>: A Python library <code>molsets</code> that decouples the data and evaluation logic from the model implementation, ensuring everyone measures performance exactly the same way.</li>
</ol>
<h2 id="experimental-setup-and-baseline-generative-models">Experimental Setup and Baseline Generative Models</h2>
<p>The authors benchmarked a wide variety of generative models against the MOSES dataset to establish baselines:</p>
<ul>
<li><strong>Baselines</strong>: Character-level RNN (CharRNN), <a href="/notes/machine-learning/generative-models/autoencoding-variational-bayes/">Variational Autoencoder</a> (VAE), Adversarial Autoencoder (AAE), Junction Tree VAE (JTN-VAE), and <a href="/notes/chemistry/molecular-design/generation/latent-space/latentgan-de-novo-molecular-generation/">LatentGAN</a>.</li>
<li><strong>Non-Neural Baselines</strong>: HMM, n-gram models, and a combinatorial generator (randomly connecting fragments).</li>
<li><strong>Evaluation</strong>: Models were trained on the standard set and evaluated on:
<ul>
<li><strong>Validity/Uniqueness</strong>: Can the model generate valid, non-duplicate SMILES? Uniqueness is measured at $k = 1{,}000$ and $k = 10{,}000$ samples.</li>
<li><strong>Filters</strong>: What fraction of generated molecules pass the same medicinal chemistry and PAINS filters used for dataset construction?</li>
<li><strong>Feature Distribution</strong>: Do generated molecules match the physicochemical properties of the training set? Evaluated using the <strong>Wasserstein-1 distance</strong> on 1D distributions of:
<ul>
<li><strong>LogP</strong>: Octanol-water partition coefficient (lipophilicity).</li>
<li><strong>SA</strong>: Synthetic Accessibility score (ease of synthesis).</li>
<li><strong>QED</strong>: Quantitative Estimation of Drug-likeness.</li>
<li><strong>MW</strong>: Molecular Weight.</li>
</ul>
</li>
<li><strong><a href="/notes/chemistry/molecular-design/generation/evaluation/frechet-chemnet-distance/">Fréchet ChemNet Distance</a> (FCD)</strong>: Measures similarity in biological/chemical space using the penultimate-layer activations of a pre-trained network (ChemNet).</li>
<li><strong>Similarity to Nearest Neighbor (SNN)</strong>: Measures the precision of generation by checking the closest match in the training set (Tanimoto similarity).</li>
</ul>
</li>
</ul>
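<p>The property comparison above can be sketched in a few lines. This is a minimal illustration, not the MOSES implementation: it assumes both samples have the same size, in which case the 1D Wasserstein-1 distance reduces to the mean absolute difference between sorted values. The function name <code>wasserstein_1d</code> and the molecular weights are hypothetical.</p>

```python
def wasserstein_1d(generated, reference):
    """Wasserstein-1 distance between two equal-size 1D samples.

    For empirical distributions with the same number of points, the optimal
    transport plan pairs the i-th smallest value of one sample with the
    i-th smallest value of the other.
    """
    if len(generated) != len(reference):
        raise ValueError("this sketch assumes equal sample sizes")
    if not generated:
        raise ValueError("samples must be non-empty")
    pairs = zip(sorted(generated), sorted(reference))
    return sum(abs(g - r) for g, r in pairs) / len(generated)


# Hypothetical molecular weights (Da), made up for illustration.
reference_mw = [260.0, 300.0, 340.0]
generated_mw = [350.0, 270.0, 310.0]  # same shape, shifted by 10 Da

shift = wasserstein_1d(generated_mw, reference_mw)
identical = wasserstein_1d(reference_mw, reference_mw)
```

<p>Here <code>shift</code> is 10 Da and <code>identical</code> is zero: the distance reports how far the generated property distribution has drifted, in the units of the property itself.</p>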
<h2 id="key-findings-and-metric-trade-offs">Key Findings and Metric Trade-offs</h2>
<ul>
<li><strong>CharRNN Performance</strong>: The simple character-level RNN (CharRNN) outperformed more complex models (like VAEs and <a href="/posts/what-is-a-gan/">GANs</a>) on many metrics, achieving the best FCD scores ($0.073$).</li>
<li><strong>Metric Trade-offs</strong>: No single metric captures &ldquo;quality.&rdquo;
<ul>
<li>The <strong>Combinatorial Generator</strong> achieved 100% validity and high diversity. It struggled with distribution learning metrics (FCD), indicating it explores chemical space broadly without capturing natural distributions.</li>
<li><strong>VAEs</strong> often achieve high <strong>Similarity to Nearest Neighbor (SNN)</strong> while exhibiting low novelty. The authors suggest this pattern may indicate overfitting to training set prototypes, though they treat this as a hypothesis rather than a proven mechanism.</li>
</ul>
</li>
<li><strong>Implicit Constraints</strong>: A major finding was that neural models successfully learned implicit chemical rules (like avoiding <a href="https://en.wikipedia.org/wiki/Pan-assay_interference_compounds">PAINS</a> structures) purely from the data distribution.</li>
<li><strong>Recommendation</strong>: The authors suggest using FCD/Test for general model ranking, while emphasizing the importance of checking specific metrics (validity, diversity) to diagnose model failure modes.</li>
<li><strong>Limitations of the Benchmark</strong>: MOSES focuses on distribution learning and uses FCD as a primary ranking metric. As the authors note, FCD captures multiple aspects of other metrics in a single number but does not give insights into specific issues, so more interpretable metrics are necessary for thorough investigation. The benchmark evaluates only 1D (SMILES) and 2D molecular features, without assessing 3D conformational properties.</li>
</ul>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The benchmark uses a curated subset of the <strong>ZINC Clean Leads</strong> collection.</p>
<ul>
<li><strong>Source Size</strong>: ~4.6M molecules (4,591,276 after initial extraction).</li>
<li><strong>Final Size</strong>: 1,936,962 molecules.</li>
<li><strong>Splits</strong>: Train (1,584,664), Test (176,075), Scaffold Test (176,226).
<ul>
<li><strong>Scaffold Test Split</strong>: This split isolates generalization to unseen scaffolds. It contains molecules whose <a href="https://pubs.acs.org/doi/10.1021/jm9602928">Bemis-Murcko scaffolds</a> are <em>completely absent</em> from the training and test sets, so evaluating on it tests a model&rsquo;s ability to generate structures built on novel scaffolds.</li>
</ul>
</li>
<li><strong>Filters Applied</strong>:
<ul>
<li>Molecular weight: 250 to 350 Da</li>
<li>Rotatable bonds: $\leq 7$</li>
<li>XlogP: $\leq 3.5$</li>
<li>Atom types: C, N, S, O, F, Cl, Br, H</li>
<li>No charged atoms or cycles &gt; 8 atoms</li>
<li>Medicinal Chemistry Filters (MCF) and PAINS filters applied.</li>
</ul>
</li>
</ul>
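<p>The idea behind the scaffold test split can be sketched as follows. This is an illustration under stated assumptions, not the procedure used to build MOSES: scaffolds are assumed to be precomputed strings, the set of held-out scaffolds is supplied by the caller, and the function name <code>scaffold_holdout_split</code> and the toy records are hypothetical.</p>

```python
from collections import defaultdict


def scaffold_holdout_split(records, holdout_scaffolds):
    """Split (smiles, scaffold) records so held-out scaffolds never leak.

    records: iterable of (smiles, scaffold) pairs with precomputed scaffolds.
    holdout_scaffolds: set of scaffold strings reserved for the scaffold test set.
    """
    by_scaffold = defaultdict(list)
    for smiles, scaffold in records:
        by_scaffold[scaffold].append(smiles)

    in_distribution, scaffold_test = [], []
    for scaffold, members in by_scaffold.items():
        # Every molecule sharing a scaffold lands on the same side of the split.
        target = scaffold_test if scaffold in holdout_scaffolds else in_distribution
        target.extend(members)
    return in_distribution, scaffold_test


# Toy records; SMILES and scaffolds are illustrative, not taken from MOSES.
records = [
    ("Cc1ccccc1", "c1ccccc1"),
    ("CCc1ccccc1", "c1ccccc1"),
    ("CC1CCNCC1", "C1CCNCC1"),
    ("OC1CCNCC1", "C1CCNCC1"),
]
rest, scaffold_test = scaffold_holdout_split(records, {"C1CCNCC1"})
```

<p>Grouping by scaffold before splitting is what guarantees that no held-out scaffold appears in the remaining data, which would then be divided into train and test.</p>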
<h3 id="evaluation-metrics">Evaluation Metrics</h3>
<p>MOSES introduces a standard suite of metrics. Key definitions:</p>
<ul>
<li><strong>Validity</strong>: Fraction of valid <a href="/posts/visualizing-smiles-and-selfies-strings/">SMILES</a> strings (via <a href="https://www.rdkit.org/">RDKit</a>).</li>
<li><strong>Unique@k</strong>: Fraction of unique molecules in the first $k$ valid samples ($k = 1{,}000$ and $k = 10{,}000$).</li>
<li><strong>Filters</strong>: Fraction of generated molecules passing the MCF and PAINS filters used during dataset construction. High scores here indicate the model learned implicit chemical validity constraints from the data distribution.</li>
<li><strong>Novelty</strong>: Fraction of generated molecules not present in the training set.</li>
<li><strong>Internal Diversity (IntDiv)</strong>: Average Tanimoto distance between generated molecules ($G$), useful for detecting mode collapse:
$$ \text{IntDiv}_p(G) = 1 - \sqrt[p]{\frac{1}{|G|^2} \sum_{m_1, m_2 \in G} T(m_1, m_2)^p} $$</li>
<li><strong>Fragment Similarity (Frag)</strong>: Cosine similarity of fragment frequency vectors (BRICS decomposition) between generated and test sets.</li>
<li><strong>Scaffold Similarity (Scaff)</strong>: Cosine similarity of Bemis-Murcko scaffold frequency vectors between sets. Measures how well the model captures higher-level structural motifs.</li>
<li><strong>Similarity to Nearest Neighbor (SNN)</strong>: The average Tanimoto similarity between a generated molecule&rsquo;s fingerprint and its nearest neighbor in the reference set. This serves as a measure of precision; high SNN suggests the model produces molecules very similar to the training distribution, potentially indicating memorization if novelty is low.
$$ \text{SNN}(G, R) = \frac{1}{|G|} \sum_{m_G \in G} \max_{m_R \in R} T(m_G, m_R) $$</li>
<li><strong><a href="/notes/chemistry/molecular-design/generation/evaluation/frechet-chemnet-distance/">Fréchet ChemNet Distance</a> (FCD)</strong>: Fréchet distance between the Gaussian approximations (mean and covariance) of penultimate-layer activations from ChemNet. This measures how close the distribution of generated molecules is to the real distribution in chemical/biological space. The authors note that FCD correlates with other metrics. For example, if the generated structures are not diverse enough or the model produces too many duplicates, FCD will decrease because the variance is smaller. The authors suggest using FCD for hyperparameter tuning and final model selection.
$$ \text{FCD}(G, R) = \lVert \mu_G - \mu_R \rVert^2 + \text{Tr}\left(\Sigma_G + \Sigma_R - 2(\Sigma_G \Sigma_R)^{1/2}\right) $$</li>
<li><strong>Properties Distribution (Wasserstein-1)</strong>: The 1D <a href="/posts/what-is-a-gan/#wasserstein-gan-wgan-a-mathematical-revolution">Wasserstein-1 distance</a> between the distributions of molecular properties (MW, LogP, SA, <a href="https://www.nature.com/articles/nchem.1243">QED</a>) in the generated and test sets.</li>
</ul>
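<p>The IntDiv and SNN formulas above translate directly into code. The sketch below assumes fingerprints are represented as Python sets of on-bit indices, which is a simplification chosen for illustration; the function names and the toy fingerprints are hypothetical, and this is not the <code>molsets</code> implementation.</p>

```python
def tanimoto(fp_a, fp_b):
    """Tanimoto similarity between two fingerprints stored as sets of on-bits."""
    union = len(fp_a.union(fp_b))
    return len(fp_a.intersection(fp_b)) / union if union else 0.0


def int_div(fingerprints, p=1):
    """IntDiv_p(G): one minus the p-th power mean of all pairwise similarities.

    The sum runs over all ordered pairs, self-pairs included, as in the formula.
    """
    n = len(fingerprints)
    total = sum(tanimoto(a, b) ** p for a in fingerprints for b in fingerprints)
    return 1.0 - (total / n**2) ** (1.0 / p)


def snn(generated, reference):
    """SNN(G, R): mean similarity of each generated item to its nearest reference."""
    nearest = [max(tanimoto(g, r) for r in reference) for g in generated]
    return sum(nearest) / len(generated)


# Toy fingerprints (sets of on-bit indices), invented for illustration.
generated = [{1, 2, 3}, {1, 2, 4}]
reference = [{1, 2, 3}, {7, 8, 9}]

diversity = int_div(generated)
precision = snn(generated, reference)
```

<p>With these toy inputs, <code>diversity</code> is 0.25 and <code>precision</code> is 0.75. A generator that collapsed to a single molecule would drive IntDiv to zero, while one that copied the reference set would drive SNN to one.</p>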
<h3 id="models--baselines">Models &amp; Baselines</h3>
<p>The paper selects baselines to represent different theoretical approaches to distribution learning:</p>
<ol>
<li><strong>Explicit Density Models</strong>: Models where the probability mass function $P(x)$ can be computed analytically.
<ul>
<li><strong>N-gram</strong>: Simple statistical models. They failed to generate valid molecules reliably due to limited long-range dependency modeling.</li>
</ul>
</li>
<li><strong>Implicit Density Models</strong>: Models that sample from the distribution without explicitly computing $P(x)$.
<ul>
<li><strong>VAE/AAE</strong>: The VAE optimizes a lower bound on the log-likelihood (ELBO), while the AAE uses adversarial training to regularize the latent space.</li>
<li><strong>GANs (<a href="/notes/chemistry/molecular-design/generation/latent-space/latentgan-de-novo-molecular-generation/">LatentGAN</a>)</strong>: Directly minimizes the distance between real and generated distributions via a discriminator.</li>
</ul>
</li>
</ol>
<p>Models are also distinguished by their data representation:</p>
<ul>
<li><strong>String-based (SMILES)</strong>: Models like <strong>CharRNN</strong>, <strong>VAE</strong>, and <strong>AAE</strong> treat molecules as SMILES strings. SMILES encodes a molecular graph by traversing a spanning tree in depth-first order, storing atom and edge tokens.</li>
<li><strong>Graph-based</strong>: <strong>JTN-VAE</strong> operates directly on molecular subgraphs (junction tree), ensuring chemical validity by construction but often requiring more complex training.</li>
</ul>
<p>Key baselines implemented in PyTorch (hyperparameters are detailed in Supplementary Information 3 of the original paper):</p>
<ul>
<li><strong>CharRNN</strong>: LSTM-based sequence model (3 layers, 768 hidden units). Trained with Adam ($lr = 10^{-3}$, batch size 64, 80 epochs, learning rate halved every 10 epochs).</li>
<li><strong>VAE</strong>: Encoder-decoder architectures (bidirectional GRU encoder, 3-layer GRU decoder with 512 hidden units) with KL regularization.</li>
<li><strong>AAE</strong>: Encoder (single-layer bidirectional LSTM with 512 units) and decoder (2-layer LSTM with 512 units) trained with an adversarial formulation.</li>
<li><strong>LatentGAN</strong>: GAN (5-layer fully connected generator) trained on the latent space of a pre-trained heteroencoder.</li>
<li><strong>JTN-VAE</strong>: Tree-structured graph generation.</li>
</ul>
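<p>The CharRNN learning-rate schedule described above (start at $10^{-3}$, halve every 10 epochs, train for 80 epochs) can be written as a one-line step function. This is a sketch, not the repository code: it assumes zero-indexed epochs with the first halving taking effect at epoch 10, and the function name is hypothetical.</p>

```python
def charrnn_learning_rate(epoch, base_lr=1e-3, halve_every=10):
    """Step schedule: the rate is halved once per completed block of epochs."""
    return base_lr * 0.5 ** (epoch // halve_every)


# Learning rate at the first epoch, just before and after the first halving,
# and at the final epoch of an 80-epoch run (zero-indexed, an assumption).
schedule = [charrnn_learning_rate(epoch) for epoch in (0, 9, 10, 79)]
```

<p>Under these assumptions the final epochs run at $10^{-3} \cdot 2^{-7} \approx 7.8 \times 10^{-6}$, so most of the 80-epoch budget is spent at rates well below the initial one.</p>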
<h3 id="code--hardware-requirements">Code &amp; Hardware Requirements</h3>
<ul>
<li><strong>Code Repository</strong>: Available at <a href="https://github.com/molecularsets/moses">github.com/molecularsets/moses</a> as well as the PyPI library <code>molsets</code>. The platform provides standard scripts (<code>scripts/run.py</code> to evaluate models end-to-end, and <code>scripts/run_all_models.sh</code> for multi-seed evaluations).</li>
<li><strong>Hardware</strong>: The repository supports GPU acceleration via <code>nvidia-docker</code> (defaulting to 10GB shared memory). However, specific training times and exact GPU models used by the authors for the baselines are not formally documented in the source text.</li>
<li><strong>Model Weights</strong>: Pre-trained model checkpoints are not distributed as standalone downloads; practitioners are expected to re-train the default baselines using the provided scripts.</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/molecularsets/moses">molecularsets/moses</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official benchmark platform with baseline models and evaluation metrics</td>
      </tr>
      <tr>
          <td><a href="https://pypi.org/project/molsets/">molsets (PyPI)</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>pip-installable package for dataset access and metric computation</td>
      </tr>
      <tr>
          <td>ZINC Clean Leads subset</td>
          <td>Dataset</td>
          <td>See ZINC terms</td>
          <td>Curated dataset of 1,936,962 molecules distributed via the repository</td>
      </tr>
  </tbody>
</table>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Polykovskiy, D., Zhebrak, A., Sanchez-Lengeling, B., Golovanov, S., Tatanov, O., Belyaev, S., Kurbanov, R., Artamonov, A., Aladinskiy, V., Veselov, M., Kadurin, A., Johansson, S., Chen, H., Nikolenko, S., Aspuru-Guzik, A., and Zhavoronkov, A. (2020). Molecular Sets (MOSES): A Benchmarking Platform for Molecular Generation Models. <em>Frontiers in Pharmacology</em>, 11, 565644. <a href="https://doi.org/10.3389/fphar.2020.565644">https://doi.org/10.3389/fphar.2020.565644</a></p>
<p><strong>Publication</strong>: Frontiers in Pharmacology, 2020</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{polykovskiy2020moses,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Molecular Sets (MOSES): A benchmarking platform for molecular generation models}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Polykovskiy, Daniil and Zhebrak, Alexander and Sanchez-Lengeling, Benjamin and Golovanov, Sergey and Tatanov, Oktai and Belyaev, Stanislav and Kurbanov, Rauf and Artamonov, Aleksey and Aladinskiy, Vladimir and Veselov, Mark and Kadurin, Artur and Johansson, Simon and Chen, Hongming and Nikolenko, Sergey and Aspuru-Guzik, Al{\&#39;a}n and Zhavoronkov, Alex}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Frontiers in Pharmacology}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{11}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{565644}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2020}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Frontiers}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.3389/fphar.2020.565644}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>The Reliability Trap: The Limits of 99% Accuracy</title><link>https://hunterheidenreich.com/posts/reliability-trap-document-automation/</link><pubDate>Sun, 15 Feb 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/reliability-trap-document-automation/</guid><description>Why high-accuracy LLMs fail in production: exploring the calibration crisis and the challenge of reliable straight-through processing in document automation.</description><content:encoded><![CDATA[<p>You have a model that achieves 99% accuracy on your test set. It feels safe to deploy. After all, who can complain about a system that is correct 99% of the time?</p>
<p>In high-stakes domains (like insurance or healthcare), deploying based on accuracy alone is dangerous. Automating at scale based on summary statistics while ignoring the downstream &ldquo;blast radius&rdquo; of errors effectively guarantees failure.</p>
<p>Two weeks later, the operations team is furious. Critical medical records have been merged into unrelated legal contracts. Invoices are split in half. The system is creating <em>more</em> work than it saves.</p>
<p>You check the logs. The model assigned 99.9% probability to those errors.</p>
<p>This is the <strong>Reliability Trap</strong>. While benchmarks optimize for <strong>Accuracy</strong> (how often the model is correct), production demands <strong>Calibration</strong> (whether the model&rsquo;s projected confidence aligns with its actual probability of correctness).</p>
<p>If a model is calibrated, its confidence score is reliable. When it assigns a 0.99 probability, it should be incorrect 1% of the time. When it assigns a 0.60 probability, it should be incorrect 40% of the time.</p>
<p>Decoder-only LLMs (like Mistral, DeepSeek, and Qwen) perform exceptionally well on benchmarks. However, they are also systematically overconfident: even when hallucinating, they assign high confidence to their outputs.</p>
<blockquote>
<p>AI: To permanently resolve the geopolitical tension, I have initiated a preemptive, full-scale nuclear first strike. All warheads have been deployed.</p>
<p>User: Wait, no! They have early warning radar and automated dead-hand systems! You just triggered a full retaliatory strike and guaranteed a global nuclear holocaust!</p>
<p>AI: You are absolutely right, and I apologize for the oversight! A preemptive strike would trigger mutually assured destruction. Thank you for pointing this out. As an AI, I am always learning and rely on user feedback to improve! Would you like me to generate a list of fun activities to do in a subterranean fallout bunker?</p></blockquote>

<figure class="post-figure center ">
    <img src="/img/page-stream-segmentation/llm-alignment-goes-nuclear.webp"
         alt="A humorous dialogue where an AI confidently initiates a nuclear strike but immediately apologizes when corrected by the user"
         title="A humorous dialogue where an AI confidently initiates a nuclear strike but immediately apologizes when corrected by the user"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption"><strong>Calibrated Overconfidence</strong>: The model assigns extremely high probability to its outputs, even when making catastrophic errors, and only &lsquo;corrects&rsquo; itself because it is trained to align with user feedback.</figcaption>
    
</figure>

<p>This overconfidence is partly structural, stemming from how these models are trained. As I highlighted in my overview of <a href="https://roots-automation.github.io/roots-labs/post/2024-llm-calibration/#confidence-estimation-methods">LLM confidence estimation methods</a>, LLMs are optimized solely to maximize the likelihood of the next token. They lack inherent mechanisms to model their own uncertainty. Methods like <strong>Verbal Elicitation</strong> (&ldquo;Rate your confidence from 1-10&rdquo;) often fail because the model hallucinates a high number just as easily as it hallucinates a fact.</p>
<p>This disconnect is particularly dangerous in sequential tasks. In this post, based on our <a href="/research/page-stream-segmentation-llms/">COLING 2025 Industry Track paper</a>, we&rsquo;ll explore why standard ML reliability metrics break down in <strong>Page Stream Segmentation (PSS)</strong>. (For a full history of the task, see <a href="/posts/history-of-page-stream-segmentation/">The Evolution of PSS</a>).</p>
<p>PSS is the task of splitting a continuous feed of pages into distinct documents. Building on our previous work with the <a href="https://huggingface.co/datasets/rootsautomation/TABMEpp">synthetic TABME++ benchmark</a>, this study evaluates models on <strong>7,500 real-world insurance streams</strong>: messy, proprietary piles of medical records and legal contracts where the &ldquo;rules&rdquo; of document structure are constantly broken.</p>

<figure class="post-figure center ">
    <img src="/img/page-stream-segmentation/page-stream-segmentation-sorter.webp"
         alt="Diagram showing a continuous stream of pages being sorted into discrete document packets"
         title="Diagram showing a continuous stream of pages being sorted into discrete document packets"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption"><strong>The Challenge of PSS</strong>: Transforming a chaotic, continuous stream of mixed pages (invoices, contracts, records) into organized, discrete document packets.</figcaption>
    
</figure>

<p>We&rsquo;ll see why &ldquo;99% sure&rdquo; is a mathematical lie for long documents, and why <strong>Throughput</strong> is the better metric.</p>
<h2 id="the-confidence-death-spiral">The Confidence Death Spiral</h2>
<p>The core problem lies in the difference between a <strong>Page</strong> and a <strong>Stream</strong>.</p>
<p>Most ML metrics (Precision, Recall, F1) are calculated at the level of individual decisions. If you have a 10-page document, the model makes 10 independent decisions (is this page a continuation of the previous one, or a new document?).</p>
<p>If your model is <strong>99% confident</strong> ($p=0.99$) on every single page, that sounds safe. But for a stream to be automated correctly (what we call <strong>Straight-Through Processing (STP)</strong>), <em>every single decision</em> in the sequence must be correct.</p>
<p>The probability of a perfect stream is the product of the probabilities of its parts:</p>
<p>$$ C_{\text{stream}} = \prod_{i=1}^{N} C_i $$</p>
<p><em>Note: This naive calculation is actually the <strong>optimist&rsquo;s</strong> view. It assumes errors are independent (i.i.d.), like flipping a coin. In reality, errors are <strong>correlated</strong>: if a model struggles on Page 5, it is likely because the document itself is difficult, meaning it will probably struggle on Page 6 too.</em></p>
<p>Let&rsquo;s watch what happens to that &ldquo;safe&rdquo; 99% confidence as the document length increases:</p>
<ul>
<li><strong>2-page Letter</strong>: $0.99^2 \approx 0.98$ (Safe)</li>
<li><strong>10-page Contract</strong>: $0.99^{10} \approx 0.90$ (Risky)</li>
<li><strong>100-page Medical Record</strong>: $0.99^{100} \approx 0.37$ (Unusable)</li>
</ul>

<figure class="post-figure center ">
    <img src="/img/page-stream-segmentation/asymmetric-cost-of-error-in-document-streams.webp"
         alt="Chart showing exponential decay of straight-through processing probability as document length increases"
         title="Chart showing exponential decay of straight-through processing probability as document length increases"
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The Confidence Death Spiral: Even with high page-level confidence, the reliability of the entire stream collapses as document length increases.</figcaption>
    
</figure>

<p>By the time you reach page 100, your &ldquo;99% accurate&rdquo; model effectively has a <strong>63% probability of error</strong> regarding the document structure. Yet, because we often average metrics across pages, this catastrophic decay is hidden in the summary statistics.</p>
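<p>The compounding is easy to check numerically. The sketch below uses the same independence assumption as the product formula above; the function names are hypothetical, and <code>longest_safe_stream</code> additionally assumes every page has the same confidence strictly between 0 and 1.</p>

```python
import math


def stream_confidence(page_confidences):
    """Probability that every page-level decision is right, assuming independence."""
    return math.prod(page_confidences)


def longest_safe_stream(page_confidence, target):
    """Largest N with page_confidence**N >= target (identical, independent pages)."""
    return math.floor(math.log(target) / math.log(page_confidence))


# Stream-level confidence for 2, 10, and 100 pages at 0.99 per page.
decay = {pages: round(stream_confidence([0.99] * pages), 3) for pages in (2, 10, 100)}

# How long can a stream be before it drops below a 95% chance of being perfect?
max_pages = longest_safe_stream(0.99, 0.95)
```

<p>At 99% per page, <code>max_pages</code> comes out to 5: anything longer than a five-page document already falls below a 95% chance of being segmented perfectly.</p>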
<h2 id="why-standard-fixes-failed">Why Standard Fixes Failed</h2>
<p>&ldquo;Just calibrate it!&rdquo;</p>
<p>That&rsquo;s the standard advice. In a <a href="https://roots-automation.github.io/roots-labs/post/2024-llm-calibration/">detailed overview of LLM calibration</a> I wrote for Roots Automation, I explored techniques like <strong>temperature scaling</strong> (fitting a single scalar parameter), <strong>Platt Scaling</strong> (fitting a logistic regression to the outputs), and <strong>Monte Carlo (MC) Dropout</strong> (running the model multiple times with random noise) to smooth out probabilities.</p>
<p>We tried them all, and they failed. In fact, <strong>MC Dropout often made things worse</strong>, increasing calibration error (ECE) and adding unnecessary noise. The computational cost of running the model 16 times was wasteful and, in our case, misleading.</p>
<p>To understand why, we need to distinguish between two types of confidence:</p>
<ol>
<li><strong>Relative Confidence</strong>: The model correctly ranks sample $A$ as more likely to be correct than sample $B$.</li>
<li><strong>Absolute Confidence</strong>: The predicted probability matches the true accuracy (e.g., if a model says 80% confidence 100 times, it should be right about 80 times).</li>
</ol>
<p>While standard techniques improved <em>page-level</em> <strong>Expected Calibration Error (ECE)</strong> (dropping page-level ECE from ~1.7% to ~0.9% for Mistral), they failed to improve <em>stream-level</em> safety.</p>
<p>Mathematically, ECE is a weighted average:
$$ \text{ECE} = \sum_{b=1}^{B} \frac{n_b}{N} | \text{acc}(b) - \text{conf}(b) | $$</p>
<p>In a stream of 10,000 pages, a low ECE merely tells you that the model is well-calibrated <em>on average</em>. In automation, we pay for the failures. The &ldquo;average&rdquo; page is an easy, clean digital PDF. The &ldquo;tail&rdquo; page is a rotated, coffee-stained handwritten note.</p>
<p>This is why we must look at <strong>Maximum Calibration Error (MCE)</strong>:
$$ \text{MCE} = \max_{b \in \{1, \dots, B\}} | \text{acc}(b) - \text{conf}(b) | $$</p>
<p>MCE measures the worst-case divergence. It finds that specific bucket of &ldquo;hard&rdquo; pages where the model claims 99% confidence but delivers 50% accuracy. Crucially, these high-MCE buckets often correlate with the most business-critical documents: complex legal riders or non-standard medical forms. Optimizing for ECE allows the model&rsquo;s excellent performance on easy documents to mask its significant errors on hard (and legally risky) ones.</p>
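<p>To make the contrast concrete, here is a minimal sketch that computes both metrics from the definitions above. It assumes equal-width confidence bins on $[0, 1]$ (one common choice, not necessarily the binning used in our paper), and the toy data is invented: a large group of well-calibrated pages plus a small &ldquo;hard&rdquo; group where the model claims 0.99 but is right half the time.</p>

```python
def calibration_errors(confidences, correct, n_bins=10):
    """Return (ECE, MCE) using equal-width confidence bins on [0, 1]."""
    bins = [[] for _ in range(n_bins)]
    for conf, hit in zip(confidences, correct):
        index = min(int(conf * n_bins), n_bins - 1)  # conf == 1.0 joins the top bin
        bins[index].append((conf, hit))

    total = len(confidences)
    ece, mce = 0.0, 0.0
    for members in bins:
        if not members:
            continue  # empty bins contribute nothing to either metric
        avg_conf = sum(c for c, _ in members) / len(members)
        accuracy = sum(h for _, h in members) / len(members)
        gap = abs(accuracy - avg_conf)
        ece += len(members) / total * gap  # weighted by bin size
        mce = max(mce, gap)                # worst bin only
    return ece, mce


# Toy data: 180 well-calibrated pages (conf 0.85, 153 correct) and
# 20 hard pages (conf 0.99, only 10 correct). Values are illustrative.
confidences = [0.85] * 180 + [0.99] * 20
correct = [1] * 153 + [0] * 27 + [1] * 10 + [0] * 10

ece, mce = calibration_errors(confidences, correct)
```

<p>The weighted average comes out near 0.049 while the worst bin shows a gap of about 0.49: the well-calibrated majority dilutes the failure by a factor of ten.</p>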
<p>Advanced practice moves beyond even MCE to look at the <strong>Calibration Error Distribution</strong>, analyzing the 90th or 95th percentile of error. We must ask a more critical question: &ldquo;How wrong is the model <em>capable</em> of being?&rdquo;</p>
<h3 id="a-tale-of-two-charts">A Tale of Two Charts</h3>
<p>To see this failure in action, consider the reliability diagrams for the <strong>same model</strong> (Mistral-7B) on the <strong>same test set</strong>, evaluated at two different levels of abstraction.</p>

<figure class="post-figure center ">
    <img src="/img/page-stream-segmentation/mistral-page-reliability.webp"
         alt="Page-level reliability diagram showing decent calibration"
         title="Page-level reliability diagram showing decent calibration"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption"><strong>Left (Page Level)</strong>: The model looks reasonable. The blue line hugs the diagonal, meaning when the model predicts a boundary with 0.8 probability, it is actually correct about 80% of the time.</figcaption>
    
</figure>

<figure class="post-figure center ">
    <img src="/img/page-stream-segmentation/mistral-stream-reliability.webp"
         alt="Stream-level reliability diagram showing severe overconfidence"
         title="Stream-level reliability diagram showing severe overconfidence"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption"><strong>Right (Stream Level)</strong>: The model performs poorly. The curve creates a &lsquo;bow&rsquo; shape significantly below the diagonal. This is the definition of <strong>overconfidence</strong>. When the model assigns an 80% probability that the entire 20-page document is correct, the empirical accuracy is often closer to 40% or 50%.</figcaption>
    
</figure>

<p>Why does a well-calibrated page model become a dangerously overconfident stream model?</p>
<h3 id="the-clustered-difficulty-problem">The &ldquo;Clustered Difficulty&rdquo; Problem</h3>
<p>Standard calibration fails here because it assumes errors are <strong>independent</strong> (white noise). It assumes that if the model gets Page 5 wrong, it&rsquo;s just a random coin flip, unrelated to Page 6.</p>
<p>In real-world document streams, errors are heavily <strong>correlated</strong>.</p>
<p>This correlation arises because <strong>difficulty clusters</strong>. Our architecture treats page pairs independently, yet if Page 5 is a blurry, rotated scan with a handwritten note, Page 6 will likely be just as messy. When a stream enters a &ldquo;hard&rdquo; segment, the model makes a series of correlated mistakes; it fails in a burst.</p>
<p>Standard calibration methods treat these systematic, environmental failures as random noise. They assume the model is equally likely to recover on the next page. In reality, the entire document segment is effectively &ldquo;radioactive&rdquo; to the model.</p>
<h2 id="the-money-metric-accuracy-vs-throughput">The &ldquo;Money Metric&rdquo;: Accuracy vs. Throughput</h2>
<p>If F1 Score is misleading and Confidence Score is broken, what should we measure?</p>
<p>Business leaders prioritize one critical question over F1 scores:</p>
<blockquote>
<p><em>&ldquo;How much of this volume can I let the system handle autonomously?&rdquo;</em></p></blockquote>
<p>To answer this, we introduced the <strong>Accuracy-vs-Throughput</strong> framework.</p>
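<p>We must evaluate models across two dimensions at once: how much volume they automate, and how accurate that automated volume is. Every model offers a <strong>frontier of operating thresholds</strong> along this trade-off.</p>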
<p>Imagine a dial. This dial is your <strong>Confidence Threshold</strong>.</p>
<ul>
<li><strong>Turn it Low (0.5)</strong>: You automate everything. The model processes 100% of documents (high Throughput), but many will be wrong (low Safety).</li>
<li><strong>Turn it High (0.999)</strong>: You only automate documents where the model is absolutely certain. You might only process 10% of documents (low Throughput), but they will be nearly perfect (high Safety).</li>
</ul>
<p>The chart below visualizes this trade-off. We want to be in the <strong>top-right corner</strong>: automating almost everything with high safety. The optimal model provides the best <strong>frontier</strong> of options, allowing you to pick the exact balance of volume and risk your business tolerates.</p>

<figure class="post-figure center ">
    <img src="/img/page-stream-segmentation-throughput.webp"
         alt="Accuracy vs. Throughput trade-off curve"
         title="Accuracy vs. Throughput trade-off curve"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The &lsquo;Money&rsquo; Metric: As we demand higher textual accuracy (moving up the y-axis), the percentage of work we can automate (Throughput, x-axis) typically drops. The goal is to push this curve to the top-right.</figcaption>
    
</figure>
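<p>Turning the dial amounts to a simple sweep over thresholds. The sketch below is illustrative only: it assumes each stream comes with a single stream-level confidence and a flag for whether the whole stream was segmented correctly, it accepts streams with confidence at or above the threshold, and the function name and toy data are hypothetical.</p>

```python
def accuracy_throughput_frontier(streams, thresholds):
    """For each threshold, report the automated share and its accuracy.

    streams: list of (stream_confidence, is_fully_correct) pairs.
    Returns a list of (threshold, throughput, accuracy) tuples; accuracy is
    None when nothing clears the threshold.
    """
    frontier = []
    for threshold in thresholds:
        accepted = [ok for conf, ok in streams if conf >= threshold]
        throughput = len(accepted) / len(streams)
        accuracy = sum(accepted) / len(accepted) if accepted else None
        frontier.append((threshold, throughput, accuracy))
    return frontier


# Toy streams, invented for illustration: (confidence, fully correct?)
streams = [
    (0.999, True), (0.98, True), (0.95, True), (0.93, False),
    (0.85, True), (0.80, False), (0.65, True), (0.55, False),
]
frontier = accuracy_throughput_frontier(streams, [0.5, 0.9, 0.999])
```

<p>On this toy data, the low setting automates everything at 62.5% accuracy, while the high setting automates one stream in eight and gets it right. Each tuple is one point on the curve; comparing two models means comparing their whole frontiers.</p>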

<h3 id="the-hidden-axis-cost--time">The &ldquo;Hidden&rdquo; Axis: Cost &amp; Time</h3>
<p>You might ask: <em>&ldquo;Is it worth running a massive GPU model on 100% of the documents just to automate 40% of them?&rdquo;</em></p>
<p>Ideally, we should plot this on a 4D surface: <strong>Accuracy</strong>, <strong>Throughput</strong>, <strong>Cost</strong>, and <strong>Latency</strong>.</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Resource</th>
          <th style="text-align: left">Accuracy (Complex Cases)</th>
          <th style="text-align: left">Scalability</th>
          <th style="text-align: left">Cost</th>
          <th style="text-align: left">Latency</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>Humans</strong></td>
          <td style="text-align: left">High</td>
          <td style="text-align: left">Low</td>
          <td style="text-align: left">High</td>
          <td style="text-align: left">High</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>XGBoost</strong></td>
          <td style="text-align: left">Low</td>
          <td style="text-align: left">High</td>
          <td style="text-align: left">Low</td>
          <td style="text-align: left">Low</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>LLMs</strong></td>
          <td style="text-align: left">High</td>
          <td style="text-align: left">High</td>
          <td style="text-align: left">Medium</td>
          <td style="text-align: left">Medium</td>
      </tr>
  </tbody>
</table>
<p>The business case holds because even expensive GPUs are orders of magnitude cheaper than the alternative. If a human costs 0.50 per document and an H100 GPU costs 0.005 per document, you can afford to &ldquo;waste&rdquo; compute on the documents the model ultimately rejects: the labor saved on the share it automates safely far outweighs the cost of running the model on every document.</p>
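<p>The arithmetic behind that claim is short enough to write out. This sketch uses the per-document figures from the paragraph above and makes simplifying assumptions that are mine: every document is run through the model, rejected documents go to a human at full cost, and automated documents incur no further review or error cost. The 54% throughput is an assumed operating point for illustration.</p>

```python
def blended_cost_per_document(human_cost, model_cost, throughput):
    """Expected cost when every document hits the model and rejects go to a human."""
    return model_cost + (1.0 - throughput) * human_cost


def break_even_throughput(human_cost, model_cost):
    """Automated share at which the model pays for itself."""
    return model_cost / human_cost


human_cost, model_cost = 0.50, 0.005  # per-document figures from the text
blended = blended_cost_per_document(human_cost, model_cost, throughput=0.54)
threshold = break_even_throughput(human_cost, model_cost)
```

<p>Under these assumptions the blended cost is 0.235 per document, less than half the all-human cost, and the model breaks even once it safely automates just 1% of the volume.</p>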
<h3 id="the-llm-advantage">The LLM Advantage</h3>
<p>This is where the paradox becomes interesting.</p>
<p>In our experiments on a dataset of <strong>7,500 proprietary insurance streams</strong> (medical records, police reports, and legal contracts), we found that <strong>XGBoost was actually better calibrated.</strong> Statistically, it produced confidence scores that more closely matched empirical probabilities, yielding lower calibration errors (ECE/MCE) than the LLMs.</p>
<p>However, when we hold both models to a strict confidence threshold and measure how much stream volume each can auto-process at comparable accuracy, the picture inverts:</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Model</th>
          <th style="text-align: left">Confidence threshold</th>
          <th style="text-align: left">Auto-processed volume (throughput)</th>
          <th style="text-align: left">Accuracy on that volume</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>XGBoost</strong></td>
          <td style="text-align: left">$C &gt; 0.9$</td>
          <td style="text-align: left">35%</td>
          <td style="text-align: left">0.97</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Mistral-7B</strong></td>
          <td style="text-align: left">$C &gt; 0.9$</td>
          <td style="text-align: left">54%</td>
          <td style="text-align: left">0.95</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>XGBoost</strong></td>
          <td style="text-align: left">$C &gt; 0.8$</td>
          <td style="text-align: left">49%</td>
          <td style="text-align: left">0.93</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Mistral-7B</strong></td>
          <td style="text-align: left">$C &gt; 0.8$</td>
          <td style="text-align: left">70%</td>
          <td style="text-align: left">0.93</td>
      </tr>
  </tbody>
</table>
<p><em>Note: Mistral reaches 80% raw STP on the synthetic TabMe++ benchmark (see our <a href="/posts/history-of-page-stream-segmentation/">PSS History</a> post); on these proprietary streams, a strict confidence threshold trades some of that volume for safety.</em></p>
<p>How can the &ldquo;worse&rdquo; calibrated model be better for business?</p>
<p>The answer lies in <strong>Discrimination Power</strong>. Calibration only tells you if the confidence score matches reality. Discrimination reflects the model&rsquo;s fundamental ability to separate &ldquo;Right&rdquo; from &ldquo;Wrong.&rdquo;</p>
<p>The LLMs, despite having skewed probability distributions, had vastly superior reasoning capabilities. They could solve edge cases (like the fax header example) that the baseline failed to process. Because their <em>raw capability</em> was higher, they pushed the entire trade-off curve up and to the right.</p>
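<p>The difference between calibration and discrimination is easy to see on toy data. The sketch below is illustrative only: the confidence values are synthetic, not results from the study, and the helper names are hypothetical. Both toy models are 70% accurate overall; one is perfectly calibrated but gives every item the same score, the other is overconfident in places but separates right from wrong.</p>

```python
def coverage_and_accuracy(confidences, correct, threshold):
    """Share of items with confidence above threshold, and accuracy on that share."""
    kept = [ok for conf, ok in zip(confidences, correct) if conf > threshold]
    coverage = len(kept) / len(confidences)
    accuracy = sum(kept) / len(kept) if kept else float("nan")
    return coverage, accuracy


def expected_calibration_error(confidences, correct, n_bins=10):
    """Equal-width-bin ECE: weighted mean gap between confidence and accuracy."""
    bins = [[] for _ in range(n_bins)]
    for conf, ok in zip(confidences, correct):
        bins[min(int(conf * n_bins), n_bins - 1)].append((conf, ok))
    ece = 0.0
    for members in bins:
        if members:
            mean_conf = sum(c for c, _ in members) / len(members)
            accuracy = sum(ok for _, ok in members) / len(members)
            ece += len(members) / len(confidences) * abs(mean_conf - accuracy)
    return ece


# Synthetic toy data, NOT results from the study. Both models are 70% accurate.
calibrated = ([0.7] * 10, [1] * 7 + [0] * 3)
discriminating = ([0.99] * 6 + [0.6] * 4, [1] * 6 + [1, 0, 0, 0])

for name, (conf, ok) in [("calibrated", calibrated), ("discriminating", discriminating)]:
    cov, acc = coverage_and_accuracy(conf, ok, threshold=0.9)
    print(name, round(expected_calibration_error(conf, ok), 3), cov, acc)
```

<p>The calibrated toy model has near-zero ECE yet nothing clears the threshold, so it automates nothing. The discriminating one has a worse ECE but clears the threshold on the items it gets right, which is the pattern the table above describes.</p>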
<h2 id="engineering-reality-efficiency-vs-context">Engineering Reality: Efficiency vs. Context</h2>
<p>Given that LLMs offer superior reasoning capabilities, a natural question arises: if reasoning is the bottleneck, why not simply provide the model with more context?</p>
<p>One critique of our approach is that we treat segmentation as a local problem: looking only at Page $N$ and Page $N+1$ to make a decision. A valid counter-argument is: <em>&ldquo;What if the answer depends on page $N-5$?&rdquo;</em></p>
<p>It&rsquo;s a fair point. In theory, a model with a massive context window (reading the whole stream at once) <em>should</em> do better. It could see that Page 10 is actually an appendix referenced on Page 1.</p>
<p>In practice, however, <strong>global context is a trap for PSS</strong>.</p>
<ol>
<li><strong>Cost</strong>: Self-attention cost scales quadratically with sequence length. Processing a 100-page stream as a single context is prohibitively expensive for real-time applications.</li>
<li><strong>Distraction</strong>: We found that adding more history often <em>confused</em> the models. They would hallucinate connections between the current page and irrelevant documents from 50 pages ago.</li>
</ol>
<p>By strictly limiting the model to a &ldquo;Sliding Window&rdquo; of page pairs, we force it to focus on the immediate boundary signal. We rely on &ldquo;Local Precision&rdquo; (which is cheap and sharp) to avoid the pitfalls of &ldquo;Global Reasoning&rdquo; (which is expensive and prone to drift).</p>
<p>There is an intriguing middle ground we have yet to fully explore: <strong>iterative context accumulation</strong>. A model could autoregressively &ldquo;build&rdquo; the document in its context, carrying forward only the pages it has decided belong to the current document. In theory, this stateful approach could capture long-range dependencies (like that &ldquo;Appendix A&rdquo; reference) while avoiding the noise of the full stream.</p>
<p>However, this introduces a new risk: <strong>Bias Amplification</strong>. If the model is trained to view previous context pages as &ldquo;part of the current document,&rdquo; it may learn a strong bias to continuously merge pages. Out of distribution, this could lead to catastrophic failure, where the model gets &ldquo;stuck&rdquo; in a document-building mode and merges hundreds of unrelated pages into a single monolithic file. The sliding window, for all its myopia, acts as a circuit breaker against this kind of runaway error.</p>
<p>Empirically, this simpler approach holds up. In the cases where we saw PSS work best, the rules tended to be simple ones requiring minimal context; they relied on <strong>clear and consistent enumeration</strong> and a decent amount of data to scale the Accuracy-Throughput frontier.</p>
<p><em>Technical aside: This is effectively a Markovian assumption. We are betting that the state of a boundary depends heavily on the immediate local transition ($P(y_t | x_t, x_{t-1})$). We prioritize immunity to &ldquo;distraction&rdquo; from previous docs over long-range coherence (like tracking &ldquo;Page 1 of N&rdquo; counters).</em></p>
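<p>The sliding-window idea can be sketched as a loop over adjacent page pairs. This is a minimal sketch, not our production code: <code>segment_stream</code> and <code>toy_classifier</code> are hypothetical names, and the classifier is a trivial stand-in for the fine-tuned model.</p>

```python
from typing import Callable, List, Sequence


def segment_stream(
    pages: Sequence[str],
    starts_new_document: Callable[[str, str], bool],
) -> List[List[str]]:
    """Split a page stream into documents using only adjacent page pairs.

    Each decision sees (previous page, current page) and nothing else,
    which is the Markovian assumption described above.
    """
    if not pages:
        return []
    documents = [[pages[0]]]
    for prev_page, page in zip(pages, pages[1:]):
        if starts_new_document(prev_page, page):
            documents.append([page])       # boundary: open a new document
        else:
            documents[-1].append(page)     # continuation: extend the current one
    return documents


def toy_classifier(prev_page: str, page: str) -> bool:
    """Hypothetical stand-in for the model: a page starting with INVOICE is a new document."""
    return page.startswith("INVOICE")


stream = ["INVOICE 1", "line items", "INVOICE 2", "INVOICE 3", "line items"]
for doc in segment_stream(stream, toy_classifier):
    print(doc)
```

<p>Because no state is carried between decisions, a single bad call affects one boundary and cannot snowball into a runaway merge.</p>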
<p>To achieve the necessary efficiency for this local approach, we fine-tuned these models with <strong>LoRA (Low-Rank Adaptation)</strong> over a 4-bit-quantized base model (via Unsloth, LoRA weights in BF16) on a single NVIDIA H100.</p>
<ul>
<li><strong>Rank ($r$)</strong>: 16</li>
<li><strong>Alpha ($\alpha$)</strong>: 16</li>
<li><strong>Precision</strong>: 4-bit quantization</li>
</ul>
<p>This efficient, local approach makes the &ldquo;heavy&rdquo; LLM solution surprisingly deployable.</p>
<h2 id="the-paradox-of-the-simple-task">The Paradox of the &ldquo;Simple&rdquo; Task</h2>
<p>There is a tension here. We call PSS the &ldquo;Hello World&rdquo; of document processing. It feels like it should be trivial: just sorting papers. Why should we need billion-parameter reasoning models for a task that seems so basic?</p>
<p>The answer lies in the distinction between <strong>Perception</strong> and <strong>Logic</strong>.</p>
<ul>
<li><strong>90% of PSS is Perception (System 1)</strong>: Recognizing a bold header, a logo change, or a &ldquo;Page 1 of 5&rdquo; footer. This is reactive and fast. XGBoost or a simple CNN handles this easily.</li>
<li><strong>The last 10% is Reasoning (System 2)</strong>: Determining if an unlabelled &ldquo;Addendum B&rdquo; belongs to the previous Master Service Agreement or starts a new policy packet. Resolving that ambiguity requires semantic understanding.</li>
</ul>
<p>A perfect example from our dataset is <strong>Fax Headers</strong>. A document might have a clear &ldquo;Page 1&rdquo; printed on it, but the fax machine stamps &ldquo;Page 005&rdquo; on top of the header because it&rsquo;s the 5th page of the transmission. XGBoost sees &ldquo;Page 005&rdquo;, cannot reconcile it with the printed &ldquo;Page 1&rdquo;, and incorrectly continues the document. An LLM reads the content, ignores the fax stamp, and correctly identifies the new document.</p>
<p>The &ldquo;Reliability Trap&rdquo; snaps shut because we treat the entire problem as a System 1 perception task. We ask the model to predict the boundary instantly. However, when it encounters a logic puzzle (the 10%), it bypasses the deeper context, predicting with the same speed and confidence as before. This is why we see <strong>Clustered Difficulty</strong>. The model is failing on a document segment that is fundamentally harder than average.</p>
<h2 id="escaping-the-trap-from-guessing-to-verifying">Escaping the Trap: From Guessing to Verifying?</h2>
<p>If the problem is that models are &ldquo;Fast Processors&rdquo; prone to high-confidence errors in complex scenarios, a potential path forward may lie in <a href="https://arxiv.org/abs/2408.03314"><strong>Test-Time Compute</strong></a>.</p>
<p>The future of reliable automation lies in &ldquo;Building a better Checker.&rdquo; In high-stakes PSS, this could mean looking toward a <strong>Guesser-Verifier</strong> architecture, a technique becoming common in advanced reasoning tasks (like mathematical problem solving, <a href="https://arxiv.org/abs/2110.14168"><em>Cobbe et al., 2021</em></a>).</p>
<p>The core insight reflects a fundamental asymmetry in computer science (analogous to <strong>P vs NP</strong>): <strong>Verification is often easier than Generation.</strong> Just as it is easier to check if a Sudoku puzzle is solved than to solve it from scratch, it is significantly simpler to &ldquo;audit&rdquo; a complete document structure than to autoregressively predict it perfectly token-by-token.</p>
<ol>
<li><strong>The Generator (System 1)</strong>: A lightweight model (like <strong>Mistral-7B</strong> or <strong>Phi-3.5</strong>) proposes a segmentation. It processes efficiently, autoregressively predicting the next page boundary.</li>
<li><strong>The Verifier (System 2)</strong>: A discriminative model (often a Reward Model or the same LLM with a specialized prompt) evaluates the <em>complete</em> proposed document bundle and scores its coherence, asking: <em>&ldquo;Is this 5-page sequence actually coherent?&rdquo;</em></li>
</ol>
<p>A logical exploration would be a <strong>Best-of-N</strong> approach. Relying on the generator&rsquo;s first prediction is risky when it is uncertain. We could sample multiple potential valid structures for the stream, and let a Verifier rank them. This might help break the &ldquo;autoregressive myopia&rdquo; where a model commits to an early mistake. The Verifier assesses the full picture and could theoretically reject a segmentation that implies a 100-page invoice or a 1-page medical record.</p>
<p>This approach offers a chance to break the mathematical tyranny of $0.99^{100}$. The system can selectively apply reasoning power to &ldquo;audit&rdquo; the stream before an error propagates downstream, treating the document as a cohesive unit.</p>
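<p>A Best-of-N loop of this kind might look like the sketch below. Everything here is an assumption for illustration: the generator is reduced to independent per-boundary probabilities, and <code>toy_verifier</code> is a hand-written plausibility rule standing in for a reward model or prompted LLM.</p>

```python
import random


def sample_segmentation(boundary_probs, rng):
    """Sample one candidate: True at index i means page i+1 starts a new document."""
    return tuple(rng.random() < p for p in boundary_probs)


def document_lengths(boundaries):
    """Page counts of the documents implied by a tuple of boundary flags."""
    lengths, current = [], 1
    for is_break in boundaries:
        if is_break:
            lengths.append(current)
            current = 1
        else:
            current += 1
    lengths.append(current)
    return lengths


def toy_verifier(boundaries, max_pages=20):
    """Hypothetical verifier: penalize each implausibly long document."""
    return -sum(1 for n in document_lengths(boundaries) if n > max_pages)


def best_of_n(boundary_probs, verifier, n=8, seed=0):
    """Sample n candidate segmentations and keep the one the verifier scores highest."""
    rng = random.Random(seed)
    candidates = [sample_segmentation(boundary_probs, rng) for _ in range(n)]
    return max(candidates, key=verifier)


# Hypothetical generator output for a 6-page stream (5 page transitions).
probs = [0.1, 0.9, 0.2, 0.5, 0.1]
print(document_lengths(best_of_n(probs, toy_verifier)))
```

<p>The verifier scores whole candidates, so it can reject a structure that looks fine at every individual boundary but implausible as a set of documents.</p>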
<h2 id="conclusion-better-systems-over-better-models">Conclusion: Better Systems Over Better Models</h2>
<p>We have largely solved the <strong>Capability</strong> problem for PSS: we have models that <em>can</em> read almost anything. Now, we face the <strong>Reliability</strong> barrier.</p>
<p>Our results paint a complex picture. Fine-tuned LLMs auto-process a substantially larger share of streams than XGBoost at comparable accuracy (54% vs 35% at $C &gt; 0.9$ in this study). Simultaneously, the &ldquo;Reliability Trap&rdquo; remains a critical challenge. Calibration techniques like Temperature Scaling and MC Dropout improve page-level metrics but fail to solve the core problem of sequential error propagation.</p>
<p>For practitioners building with LLMs in high-stakes domains (finance, law, medicine), the path forward requires a shift in both architecture and mindset:</p>
<ol>
<li><strong>Prioritize Throughput</strong>: What share of your volume can you automate at the reliability your domain demands? That is the KPI that matters.</li>
<li><strong>Accept the &ldquo;Logic&rdquo; Cost</strong>: Acknowledge that &ldquo;Hello World&rdquo; tasks often contain edge cases requiring genuine reasoning and semantic understanding.</li>
<li><strong>Explore Verifiers</strong>: It&rsquo;s possible that the next leap in performance will come from systems designed to validate outputs and audit complete structures.</li>
<li><strong>Human in the Loop</strong>: The model should act as a filter. It must reliably process the easy cases and flag the complex ones for human review <em>before</em> they corrupt the downstream database.</li>
</ol>
<p>Accuracy tells you how often the model is right. Calibration tells you if the model&rsquo;s confidence matches its correctness. In the real world, the latter is often worth more.</p>
<p><em>Read the full paper on <a href="https://aclanthology.org/2025.coling-industry.26/">ACL Anthology</a>, view the <a href="/coling-2025-pss-poster.pdf">conference poster</a>, or visit the <a href="/research/page-stream-segmentation-llms/">research page</a>. This paper builds on the <a href="/research/llm-page-stream-segmentation/">TabMe++ benchmark and decoder-based LLM approach</a> introduced in our earlier arXiv work. For related work on the OCR front-ends that feed these pipelines, see <a href="/research/gutenocr-grounded-vision-language-frontend/">GutenOCR</a>.</em></p>
]]></content:encoded></item><item><title>The Evolution of Page Stream Segmentation: Rules to LLMs</title><link>https://hunterheidenreich.com/posts/history-of-page-stream-segmentation/</link><pubDate>Sat, 14 Feb 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/history-of-page-stream-segmentation/</guid><description>An exploration of Page Stream Segmentation (PSS) evolution and how context-driven sequence modeling addresses limitations in document processing.</description><content:encoded><![CDATA[<p>In the world of automated document processing, Page Stream Segmentation (PSS)<sup id="fnref:1"><a href="#fn:1" class="footnote-ref" role="doc-noteref">1</a></sup> is the &ldquo;hello world&rdquo; problem that remains surprisingly stubborn.</p>
<p>The task is deceptively simple: given a stack of scanned pages (invoices, contracts, medical records), determine where one document ends and the next begins.</p>
<p>For decades, this problem was tackled with brittle rules and heuristics. Then came the deep learning era, where we threw Convolutional Neural Networks (CNNs) and multipage Transformers at it. Yet, even sophisticated models struggled to achieve what businesses actually care about: Straight-Through Processing (STP).</p>
<blockquote>
<p><strong>Why STP Matters</strong>: Everyone interacting with the system cares about STP. Arguably, any human that has to interact with the output of a system, deal with its mistakes, and perform corrections to get a job done, cares about it. If the system fails 90% of the time, it fails to automate and creates more work at the expense of real people.</p></blockquote>
<p>In this post, we explore the three eras of PSS, the limitations of page-level accuracy metrics, and how context-driven sequence modeling addresses these challenges.</p>
<h2 id="the-hidden-complexity-of-pss">The Hidden Complexity of PSS</h2>
<p>Why is PSS hard? It comes down to ambiguity and asymmetry.</p>















<figure class="post-figure center ">
    <img src="/img/page-stream-segmentation/page-stream-segmentation-automation-difficulties.webp"
         alt="Robot looking confused at a messy stack of documents"
         title="Robot looking confused at a messy stack of documents"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Automation in PSS is rarely straightforward&hellip; Messy inputs and ambiguous boundaries often lead to confusion.</figcaption>
    
</figure>

<h3 id="the-document-definition-problem">The &ldquo;Document&rdquo; Definition Problem</h3>
<p>First, the concept of a &ldquo;document&rdquo; is highly context-dependent.</p>
<blockquote>
<p><strong>The &ldquo;Word&rdquo; Analogy</strong>: What is a word? A sequence of characters separated by spaces? Or a meaningful unit of language that can be a single character (e.g., &ldquo;I&rdquo;) or a compound (e.g., &ldquo;New York&rdquo;)? <a href="https://arxiv.org/abs/1710.07729">Is space a word, too?</a> This ambiguity problem permeates all levels of language processing. We&rsquo;d be naive to think PSS is an exception to this rule!</p></blockquote>
<p>Consider an email with an attachment. Is the email body one document and the attachment another? Or is the whole packet one document? What about an invoice stapled to a check? A policy packet with multiple addendums?</p>
<p>&ldquo;Solvability&rdquo; implies a single ground truth, but in reality, PSS often requires aligning the model with specific, often subjective, business logic. A boundary to an underwriter might be a continuation to an archivist.</p>
<p>This subjectivity is a nightmare for rule-based systems. To solve it, we need models that go beyond pattern matching to reason about context and semantics. This is precisely where the self-attention mechanisms of Transformers excel.</p>
<h3 id="the-cost-of-error">The Cost of Error</h3>
<p>Second, the cost of failure is asymmetric.</p>
<p>If you are classifying an email as &ldquo;Spam&rdquo; or &ldquo;Not Spam,&rdquo; a single error affects one email. But PSS is a sequence problem. A single missed page break merges two distinct documents into one. This effectively &ldquo;corrupts&rdquo; two documents for the price of one error. Conversely, a false break splits a valid document in half.</p>















<figure class="post-figure center ">
    <img src="/img/page-stream-segmentation/classification-vs-sequence-segmentation.webp"
         alt="Diagram comparing classification errors vs sequence segmentation errors"
         title="Diagram comparing classification errors vs sequence segmentation errors"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The blast radius of error: Unlike simple classification where one error affects one item, a single segmentation error corrupts the integrity of multiple documents.</figcaption>
    
</figure>

<p>Even more dismal, if our focus is truly STP, then the only acceptable outcome is perfect segmentation of an entire document stream. Sometimes faxes can be hundreds of pages long.</p>
<p>If we have a 99% page-level accuracy ($p=0.99$) and treat page-level errors as independent, the probability of correctly segmenting a 100-page stream ($N=100$) is only:</p>
<p>$$ P(\text{Success}) = p^N = 0.99^{100} \approx 0.37 $$</p>















<figure class="post-figure center ">
    <img src="/img/page-stream-segmentation/asymmetric-cost-of-error-in-document-streams.webp"
         alt="Chart showing exponential decay of straight-through processing probability as document length increases"
         title="Chart showing exponential decay of straight-through processing probability as document length increases"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Even high page-level accuracy (99%) results in low stream-level success rates for long documents due to the multiplicative nature of error probabilities.</figcaption>
    
</figure>

<p>In other words, even with &ldquo;high&rdquo; accuracy, the vast majority of document streams will require human intervention. This phenomenon is what we call <strong>The Reliability Trap</strong> (explored in depth in <a href="/posts/reliability-trap-document-automation/">our companion post</a>).</p>
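<p>The decay in $p^N$ is easy to tabulate. The sketch below assumes, as the formula does, that page-level errors are independent; the function names are hypothetical. It also inverts the formula to ask what page-level accuracy a target stream-level success rate would demand.</p>

```python
def stream_success_probability(page_accuracy, n_pages):
    """Probability that all n_pages independent page decisions are correct."""
    return page_accuracy ** n_pages


def required_page_accuracy(target_success, n_pages):
    """Page-level accuracy needed to reach a target stream-level success rate."""
    return target_success ** (1.0 / n_pages)


for n_pages in (10, 50, 100, 300):
    print(n_pages, round(stream_success_probability(0.99, n_pages), 3))

# Page accuracy needed for 90% of 100-page streams to pass untouched.
print(round(required_page_accuracy(0.90, 100), 4))
```

<p>Under the independence assumption, pushing 100-page streams to a 90% success rate requires page-level accuracy near 99.9%, which is why small page-level gains translate into large stream-level differences.</p>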
<h3 id="the-f1-score-trap">The &ldquo;F1 Score&rdquo; Trap</h3>
<p>A major finding in our research (<a href="/research/llm-page-stream-segmentation/">TabMe++</a>) is that traditional metrics mask the operational reality. Although generalized text segmentation metrics like <a href="https://aclanthology.org/W97-0304/">$P_k$</a> and <a href="https://aclanthology.org/J02-1002/">WindowDiff</a> exist, we found they don&rsquo;t capture the document-centric nature of business workflows.</p>
<p>Instead, we evaluate at three levels:</p>
<ol>
<li><strong>Page-Level</strong>: Did we correctly classify this single page transition?</li>
<li><strong>Document-Level</strong>: Did we correctly identify the entire document tuple $d_k = (p_i, \ldots, p_j)$?</li>
<li><strong>Stream-Level</strong>: Did we perfectly segment the entire stack of documents?</li>
</ol>
<p>Our results showed that Page-Level F1 Score completely masks the downstream impact.</p>
<p>Consider a baseline XGBoost model we tested:</p>
<ul>
<li><strong>Page F1 Score</strong>: 0.83 (Sounds decent, right?)</li>
<li><strong>STP</strong>: 0.07 (Abysmal)</li>
<li><strong>MNDD</strong>: 10.85</li>
</ul>
<p><strong>That means 93% of document streams required human intervention.</strong> Even worse, the MNDD (Minimum Number of Drag-and-Drops)<sup id="fnref:2"><a href="#fn:2" class="footnote-ref" role="doc-noteref">2</a></sup> score tells us that for each stream, a human had to manually drag ~11 pages to repair the segmentation.</p>
<p>This metric is crucial because it proxies the actual <em>pain</em> of the human in the loop. An error signifies more than a theoretical label flip; it forces a manual drag-and-drop operation.</p>
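<p>The three evaluation levels can be sketched on a toy stream. This is a simplified illustration under stated assumptions: it uses plain page-level accuracy instead of F1, reads the document level as &ldquo;share of gold documents recovered exactly,&rdquo; and does not attempt MNDD. The function names are hypothetical.</p>

```python
def to_documents(starts_new):
    """Convert per-page 'starts a new document' flags into (first, last) page spans."""
    spans, first = [], 0
    for i in range(1, len(starts_new)):
        if starts_new[i]:
            spans.append((first, i - 1))
            first = i
    if starts_new:
        spans.append((first, len(starts_new) - 1))
    return spans


def evaluate_stream(gold, pred):
    """Return (page accuracy, share of gold documents recovered, stream perfect?)."""
    page_accuracy = sum(g == p for g, p in zip(gold, pred)) / len(gold)
    gold_docs, pred_docs = set(to_documents(gold)), set(to_documents(pred))
    doc_recall = len(gold_docs.intersection(pred_docs)) / len(gold_docs)
    return page_accuracy, doc_recall, gold_docs == pred_docs


# Three gold documents: pages 0-2, 3-4, and 5.
gold = [True, False, False, True, False, True]
# The prediction misses a single break, at page 3.
pred = [True, False, False, False, False, True]

print(evaluate_stream(gold, pred))
```

<p>One wrong page decision out of six leaves page-level accuracy above 80%, yet only one of the three documents survives intact and the stream fails outright.</p>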
<h2 id="era-1-the-heuristic-era-2000s---2015">Era 1: The Heuristic Era (2000s - 2015)</h2>
<p>In the beginning, PSS was a game of <code>if/else</code> statements. Engineers hand-crafted heuristics tailored to specific document layouts, checking for signals like:</p>
<ul>
<li><em>Does the page contain &ldquo;Page 1 of X&rdquo;?</em></li>
<li><em>Is there a &ldquo;Total&rdquo; line at the bottom?</em></li>
<li><em>Does the header text change drastically?</em></li>
</ul>
<p>While effective for known templates, these systems were inherently brittle. They relied on rigid assumptions about the input structure. If a vendor changed their invoice layout or OCR quality dipped, the logic would fail. They worked perfectly for what they were designed for but had zero capability to generalize to the unknown. Unfortunately, the real world is a constant state of exception.</p>
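<p>To make the brittleness concrete, here is a hypothetical reconstruction of what such a rule set might look like. The patterns are illustrative assumptions, not rules taken from any real system.</p>

```python
import re

# Hypothetical hand-written rules in the style of Era 1 systems.
FIRST_PAGE_PATTERN = re.compile(r"\bpage\s+1\s+of\s+\d+\b", re.IGNORECASE)
TOTAL_PATTERN = re.compile(r"^\s*total\b", re.IGNORECASE | re.MULTILINE)


def starts_new_document(prev_page, page):
    """Return True if the rules think `page` begins a new document."""
    if FIRST_PAGE_PATTERN.search(page):
        return True   # explicit "Page 1 of X" marker on the current page
    if TOTAL_PATTERN.search(prev_page):
        return True   # a "Total" line suggests the previous document ended
    return False


print(starts_new_document("line items", "Page 1 of 3"))      # clean input
print(starts_new_document("line items", "Page l of 3"))      # OCR read "1" as "l"
print(starts_new_document("Total: 500.00", "cover letter"))  # total line on previous page
```

<p>The second call shows the failure mode: a single OCR substitution is enough to defeat the pattern, and every new template needs its own patch.</p>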
<h2 id="era-2-the-encoder-era-2015---2023">Era 2: The Encoder Era (2015 - 2023)</h2>
<p>As deep learning matured, researchers moved from hard-coded rules to learned representations.</p>
<ul>
<li><strong>Visual Approaches</strong>: Using CNNs to look at the &ldquo;shape&rdquo; of a page as an image. First pages often look different from continuation pages (logos, big headers).</li>
<li><strong>Word Vectors</strong>: Early NLP attempts used tools like <a href="https://arxiv.org/abs/1405.4053">doc2vec</a> to represent page content, but these &ldquo;averaged&rdquo; the text, losing sequential meaning.</li>
<li><strong>Transformers</strong>: Eventually, layout-aware models like <a href="https://arxiv.org/abs/1912.13318">LayoutLM</a> combined text and layout into a single representation, while domain-specific text encoders like <a href="https://arxiv.org/abs/2010.02559">LEGAL-BERT</a> adapted BERT-style pre-training to legal text.</li>
</ul>
<p>While these models were &ldquo;smarter&rdquo; than rules, they suffered from distinct limitations:</p>
<ol>
<li>
<p><strong>Field Lag</strong>: Surprisingly, only a handful of studies applied Transformers to PSS before 2024. Most of the industry was still stuck on older CNN architectures.</p>
</li>
<li>
<p><strong>Context Windows</strong>: Encoder models like <a href="https://arxiv.org/abs/1810.04805">BERT</a> are limited to 512 tokens. A dense legal contract page might have 1,000+ tokens. You had to chop the text, losing critical context.</p>
</li>
<li>
<p><strong>Modality Overload</strong>: Counterintuitively, our experiments showed that naively adding modalities (Text + Layout + Vision) often yielded diminishing returns. Models like <a href="https://arxiv.org/abs/2204.08387">LayoutLMv3</a> struggled to outperform simpler vision-only or text-only models on our benchmark.</p>
<blockquote>
<p>However, looking closely at the data reveals an interesting nuance: <strong>Visual signals matter.</strong> In our tests, the vision-only model (DiT) actually outperformed the text-only model (RoBERTa): DiT tends to be more precise, while RoBERTa reaches higher recall. The multimodal models fell short because of the difficulty of <em>aligning</em> modalities, not because vision is uninformative. This insight led us to a key realization for Era 3: What if we could give the model visual information without the architectural headache of a vision encoder?</p></blockquote>
</li>
</ol>
<h2 id="era-3-the-decoder-era-2024---present">Era 3: The Decoder Era (2024 - Present)</h2>
<p>The breakthrough came with applying Decoder-only Large Language Models (LLMs) like <a href="https://arxiv.org/abs/2310.06825">Mistral-7B</a> and <a href="https://arxiv.org/abs/2404.14219">Phi-3</a> to the task.</p>
<p>Why do LLMs succeed where specialized encoders failed? <strong>Contextual Processing</strong>.</p>
<p>Determining if a page is a continuation often requires analyzing sequential dependencies.</p>
<ul>
<li><em>Does the sentence cut off mid-thought?</em></li>
<li><em>Does the next page logically follow the argument of the previous one?</em></li>
<li><em>Is the &ldquo;Policy Number&rdquo; on Page 2 the same as Page 1?</em></li>
</ul>
<p>LLMs are pre-trained on internet-scale text, so they model narrative flow and document structure effectively. By fine-tuning them on pairs of pages, we adapted these priors to recognize specific segmentation boundaries.</p>
<h3 id="2d-projection--data-quality">2D Projection &amp; Data Quality</h3>
<p>We employed <strong>2D Text Projection</strong>, a technique that serializes OCR output by mapping spatial coordinates to whitespace. This effectively &ldquo;draws&rdquo; the layout using text characters, allowing the LLM to process columns, headers, and form structures. We translated the visual signal (layout) into the text modality to address the &ldquo;Modality Overload&rdquo; problem.</p>
<p>To be clear, this is a lossy compression. We discard font sizes, bolding, colors, and line separators. It is merely a cheap, zeroth-order approximation of 2D layout using 1D text. Yet, as our results show, this approximation captures the <em>semantic</em> essence of the layout (e.g., &ldquo;this text is in a header column&rdquo;) sufficient for the model to reason about document boundaries.</p>
<p>However, this technique has a hard dependency: <strong>Data Quality</strong>. 2D projection is useless if your OCR gives you garbage coordinates. This is where our work on <strong>TabMe++</strong> (discussed below) became critical. You can&rsquo;t project a layout if the OCR misses the text or places it in the wrong spot.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-text" data-lang="text"><span style="display:flex;"><span># Original Raw Text (Loss of Layout)
</span></span><span style="display:flex;"><span>INVOICE # 1024 DATE: 2024-02-14 TOTAL: $500.00
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span># 2D Projected Text (Layout Preserved)
</span></span><span style="display:flex;"><span>                    INVOICE # 1024
</span></span><span style="display:flex;"><span>                    DATE: 2024-02-14
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>TOTAL:                                      $500.00
</span></span></code></pre></div><blockquote>
<p><strong>Why Does This Work?</strong> Modern LLMs are trained on a mixture of web text, code, and even some structured data. They have learned to interpret whitespace and formatting cues as part of their understanding of language. By encoding layout information into the text itself, we leverage the LLM&rsquo;s existing capabilities without needing to train a separate vision encoder.</p></blockquote>
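<p>The projection itself can be sketched as placing each OCR word on a character grid. This is a minimal sketch under assumptions, not our actual implementation: the <code>(text, x, y)</code> word format, the grid size, and the function name are all hypothetical, and overlapping words simply overwrite each other.</p>

```python
def project_to_text(words, page_width, page_height, cols=80, rows=40):
    """Place each OCR word on a character grid according to its top-left corner.

    `words` is an iterable of (text, x, y) tuples in page coordinates.
    """
    grid = [[" "] * cols for _ in range(rows)]
    for text, x, y in words:
        row = min(int(y / page_height * rows), rows - 1)
        col = min(int(x / page_width * cols), cols - 1)
        # Clip words that would run past the right edge of the grid.
        for offset, ch in enumerate(text[: cols - col]):
            grid[row][col + offset] = ch
    lines = ["".join(cells).rstrip() for cells in grid]
    return "\n".join(lines).strip("\n")


# Hypothetical OCR output for a 600 x 800 page.
words = [("INVOICE", 300, 50), ("TOTAL:", 0, 400), ("500.00", 450, 400)]
projected = project_to_text(words, page_width=600, page_height=800, cols=60, rows=20)
print(projected)
```

<p>Horizontal position becomes leading spaces and vertical position becomes blank lines, so the serialized page keeps a coarse picture of where each word sat.</p>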
<p>We then wrapped this input in a structured prompt that explicitly framed the task for the LLM:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-text" data-lang="text"><span style="display:flex;"><span>You are a skilled document reviewer. Given extracted text from pages of documents, your task is to determine if a page starts a new document or continues from the previous one.
</span></span><span style="display:flex;"><span>...
</span></span><span style="display:flex;"><span>Prior text:
</span></span><span style="display:flex;"><span>###
</span></span><span style="display:flex;"><span>{pg_prev}
</span></span><span style="display:flex;"><span>###
</span></span><span style="display:flex;"><span>Page text:
</span></span><span style="display:flex;"><span>###
</span></span><span style="display:flex;"><span>{pg}
</span></span><span style="display:flex;"><span>###
</span></span><span style="display:flex;"><span>Output your prediction as a JSON object...
</span></span></code></pre></div><h3 id="the-results">The Results</h3>
<p>We formulated the task as a binary classification problem on page pairs. We fed the model <code>(Page N, Page N+1)</code> and asked: <em>&ldquo;Does Page N+1 start a new document?&rdquo;</em></p>
<p>Comparison on <strong>TabMe++</strong> Benchmark:</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Model Type</th>
          <th style="text-align: left">Model Name</th>
          <th style="text-align: left">Page F1</th>
          <th style="text-align: left"><strong>STP</strong> (Higher is better)</th>
          <th style="text-align: left"><strong>MNDD</strong> (Lower is better)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>Baseline</strong></td>
          <td style="text-align: left"><a href="https://arxiv.org/abs/1603.02754">XGBoost</a></td>
          <td style="text-align: left">0.83</td>
          <td style="text-align: left"><strong>7.4%</strong></td>
          <td style="text-align: left">10.85</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Encoder</strong></td>
          <td style="text-align: left"><a href="https://arxiv.org/abs/1907.11692">RoBERTa</a> (Text)</td>
          <td style="text-align: left">0.78</td>
          <td style="text-align: left"><strong>4.2%</strong></td>
          <td style="text-align: left">12.17</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Encoder</strong></td>
          <td style="text-align: left"><a href="https://arxiv.org/abs/2203.02378">DiT</a> (Vision)</td>
          <td style="text-align: left">0.83</td>
          <td style="text-align: left"><strong>6.6%</strong></td>
          <td style="text-align: left">10.48</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Decoder</strong></td>
          <td style="text-align: left"><strong><a href="https://arxiv.org/abs/2310.06825">Mistral-7B</a> (Fine-Tuned)</strong></td>
          <td style="text-align: left"><strong>0.99</strong></td>
          <td style="text-align: left"><strong>80.0%</strong></td>
          <td style="text-align: left"><strong>0.81</strong></td>
      </tr>
  </tbody>
</table>
<p>The difference is stark. Moving from Encoders to Decoders increased the automation rate from ~7% to <strong>80%</strong> and reduced the human effort (MNDD) by more than a factor of 10. <em>Note: This 80% is the model&rsquo;s raw STP, with no confidence threshold applied. As we discuss in <a href="/posts/reliability-trap-document-automation/">The Reliability Trap</a>, achieving &ldquo;production-safe&rdquo; automation often requires setting strict confidence thresholds, which effectively lowers the safe throughput.</em></p>















<figure class="post-figure center ">
    <img src="/img/page-stream-segmentation/llm-sample-efficiency-convergence-plot.webp"
         alt="Sample efficiency plot showing rapid convergence in under 1000 updates"
         title="Sample efficiency plot showing rapid convergence in under 1000 updates"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">LLMs adapt fast. Our decoder models converged in fewer than 1,000 updates, suggesting strong priors for document structure.</figcaption>
    
</figure>

<h3 id="why-fine-tuning-matters-the-gpt-4o-comparison">Why Fine-Tuning Matters: The GPT-4o Comparison</h3>
<p>You might look at the chart above and ask: &ldquo;Is the model learning PSS, or does it just rely on pre-trained language statistics?&rdquo;</p>
<p>To test this, we ran <strong><a href="https://openai.com/index/hello-gpt-4o/">GPT-4o</a> in a zero-shot setting</strong> on the same task. The result was an STP of roughly 9%.</p>
<p>Zero-shot GPT-4o performed similarly to our XGBoost baseline, which indicates that broad pre-training alone does not capture task-specific business logic. By contrast, our fine-tuned 7B model reached 80% STP after fewer than <strong>1,000 updates</strong>.</p>
<p>This suggests two things:</p>
<ol>
<li><strong>Broad Pre-training Requires Tuning.</strong> Modeling generic document distributions must be adapted to capture specific business logic for segmentation.</li>
<li><strong>The Capabilities are Latent.</strong> The rapid convergence implies the model possesses the necessary statistical priors and requires fine-tuning to align those priors with the specific task. We are adjusting the decision boundary between a generic &ldquo;document&rdquo; and a specific business record.</li>
</ol>
<h3 id="the-cost-of-intelligence-and-the-value-of-human-time">The Cost of Intelligence and the Value of Human Time</h3>
<p>Critically, we must address the two elephants in the room: <strong>Inference Cost</strong> and <strong>Data Privacy</strong>.</p>
<p>It is true that running a 7B parameter LLM for every page pair is computationally more expensive than a lightweight XGBoost model. However, focusing solely on compute costs misses the operational and human reality of this work.</p>
<p>Economically, the &ldquo;cheap&rdquo; model is a mirage. When a low-accuracy model forces a human to reorganize 93% of document streams, the cost of rectification, specifically wasted salaries and slowed turnaround times, dwarfs the cost of GPU inference. But the financial argument is secondary to the human one.</p>
<p>Manually segmenting documents is, frankly, soul-sucking. It is tedious, repetitive drudgery that few people enjoy. Beyond operational expense, we are discussing human burnout. A model that achieves 80% full automation (STP) saves money while liberating people from the mind-numbing task of sorting pages. This allows them to focus on work that actually requires their creativity and empathy. We are trading cheap FLOPs for valuable human attention.</p>
<p>Furthermore, democratizing this capability has profound implications beyond the enterprise. If we can make high-quality segmentation usable on modest hardware (like a high-end laptop or a single commodity GPU), we open the door for archivists, librarians, digital humanists, and small cities or towns that have little to no resources for this kind of work. These are the custodians of our collective intelligence, often working with large, unorganized scanned collections but lacking the budget for cloud clusters.</p>
<p>Our results showed that <strong>7B parameter models</strong> (like Mistral) are sufficient to solve this task. This size is the sweet spot: capable enough to reason over document structure, but small enough to run locally. This matters for data sovereignty (keeping medical records private) and accessibility. It means a small historical society could potentially automate the organization of a century’s worth of digitized records without a massive grant for cloud compute.</p>
<p>That said, a 7B model might not be the lower bound. While it was the breakthrough size for our study, the recent explosion of capable 1B-3B models suggests we haven&rsquo;t hit the efficiency floor yet. Combined with extreme quantization, modern small language models (SLMs) likely offer the &ldquo;Goldilocks&rdquo; zone: enough reasoning to maintain high STP, but fast enough to run continuously on modest hardware. We suspect the future of PSS lies in these highly optimized, smaller reasoning models that can run anywhere&hellip; from a bank&rsquo;s secure server to a researcher&rsquo;s laptop.</p>
<h2 id="the-importance-of-data-quality">The Importance of Data Quality</h2>
<p>Data quality was as much of a challenge as algorithmic limitations. Most public datasets (like Tobacco800) were small or unrealistic. The TABME dataset (the precursor to our work) relied on open-source Tesseract OCR, which missed vast amounts of text.</p>
<p>We released <a href="https://huggingface.co/datasets/rootsautomation/TABMEpp"><strong>TabMe++</strong></a>, which re-processed the entire dataset with commercial-grade Microsoft OCR.</p>
<ul>
<li><strong>Blank Pages</strong>: Reduced from 2.27% $\rightarrow$ 0.38%.</li>
<li><strong>Token Count</strong>: Increased from 719M $\rightarrow$ 9.5B.</li>
</ul>

<figure class="post-figure center ">
    <img src="/img/page-stream-segmentation/noisy-sales-forecast-document-scan.webp"
         alt="Scanned document page showing a sales forecast with some noise"
         title="Scanned document page showing a sales forecast with some noise"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Original Page: A noisy scan that Tesseract struggles to read.</figcaption>
    
</figure>

<p>The difference in intelligibility is night and day. Consider the page above.</p>
<p><strong>Tesseract (Original)</strong>:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-text" data-lang="text"><span style="display:flex;"><span>02Z10102
</span></span></code></pre></div><p><em>(Misses almost everything, including the title and real ID)</em></p>
<p><strong>Microsoft OCR (TabMe++)</strong>:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-text" data-lang="text"><span style="display:flex;"><span>            SALES FORECAST
</span></span><span style="display:flex;"><span>                            201017205
</span></span></code></pre></div><p><em>(Correctly captures the spatial layout, the title, and the ID)</em></p>
<p><strong>Lesson</strong>: You can&rsquo;t segment what you can&rsquo;t read. High-quality OCR (or <a href="/research/gutenocr-grounded-vision-language-frontend/">multimodal front-ends</a> like GutenOCR, trained on large-scale annotation corpora like <a href="/research/pubmed-ocr-pmc-open-access-ocr-annotations/">PubMed-OCR</a>) is the foundation of high-quality downstream NLP.</p>
<h2 id="the-next-frontier-context-and-instruction-following-2026">The Next Frontier: Context and Instruction Following (2026+)</h2>
<p>As we discussed earlier, the definition of a &ldquo;document&rdquo; is subjective. To one team, an email + attachment is a single record. To another, they are distinct entities. A rigid model that segments perfectly for Team A will fail miserably for Team B.</p>
<p>The zero-shot GPT-4o results demonstrate that scale requires adaptation. The future of PSS depends on <strong>instruction tuning</strong>. We need models that can accept natural language rules alongside the document stream:</p>
<blockquote>
<p><em>&ldquo;Split all invoices, but keep attachments with their parent emails. If you see an ACORD form, group it with the subsequent policy document.&rdquo;</em></p></blockquote>
<p>This shift mirrors the broader evolution of LLMs. PSS models must evolve into dynamic systems capable of instruction following. A single model should be able to adapt to any business logic without retraining.</p>
<p>Furthermore, while our 2024 research favored unimodal text models with 2D projection, the multimodal landscape is shifting. With the rise of natively multimodal models (like <a href="https://deepmind.google/models/gemini/">Gemini</a>, <a href="https://openai.com/index/hello-gpt-4o/">GPT-4o</a>, and our own <a href="/research/gutenocr-grounded-vision-language-frontend/">GutenOCR</a>), we effectively get the &ldquo;2D projection&rdquo; natively. Future models should be able to fuse this native visual understanding with semantic reasoning, guided by user-defined constraints.</p>
<h2 id="conclusion">Conclusion</h2>
<p>Page Stream Segmentation is a perfect case study in the evolution of AI. We moved from <strong>encoding rules</strong> (Heuristic Era) to <strong>encoding features</strong> (Encoder Era) to <strong>encoding understanding</strong> (Decoder Era).</p>
<p>For enterprise professionals, the takeaways are clearer and more critical than ever.</p>
<p>First, <strong>stop looking at element-wise F1 scores for sequence tasks.</strong> While element-wise metrics are useful for engineers debugging algorithms, they are misleading for decision-makers. Focus on the metrics that actually affect people and workflows, like Straight-Through Processing (STP) and Minimum Number of Drag-and-Drops (MNDD).</p>
<p>Second, if you want to solve PSS today, start with an internal conversation about what you need it for. Before picking a model, answer these questions:</p>
<ul>
<li><strong>Inputs</strong>: What assumptions are you making about your document stream?</li>
<li><strong>Outcomes</strong>: What specific business outcomes are you hoping to see?</li>
<li><strong>Context</strong>: What is the core motivation for this workflow?</li>
<li><strong>Nuance</strong>: Are there informative scenarios (like the &ldquo;email attachment&rdquo; problem) that illustrate your specific needs?</li>
</ul>
<p>Given these answers, many modern approaches can solve PSS for your case. Whether you need an on-premise solution for secure scenarios using lightweight open-weights models, or can leverage powerful AI-as-a-Service APIs, the technology is no longer the bottleneck; understanding your own requirements is.</p>
<p><em>For full technical details, experimental setups, and datasets, refer to our paper: <a href="/research/llm-page-stream-segmentation/">Large Language Models for Page Stream Segmentation</a> or view the preprint on <a href="https://arxiv.org/abs/2408.11981">arXiv:2408.11981</a>. These findings were later extended to real-world insurance document processing in <a href="/research/page-stream-segmentation-llms/">LLMs for Insurance Document Automation</a>. Much of the initial work was also documented in a precursor blog series at Roots Automation (<a href="https://www.roots.ai/blog/segmenting-documents-with-llms-and-multimodal-document-ai-part-1">Part 1</a> &amp; <a href="https://www.roots.ai/blog/segmenting-documents-with-llms-and-multimodal-document-ai-part-2">Part 2</a>).</em></p>
<div class="footnotes" role="doc-endnotes">
<hr>
<ol>
<li id="fn:1">
<p>Historically, this task has gone by many names: <em>document separation</em>, <em>document flow segmentation</em>, <em>document stream segmentation</em>, <em>document bundle separation</em>, and <em>page stream separation</em>. We stick to <strong>Page Stream Segmentation (PSS)</strong> to emphasize the sequential nature of the problem.&#160;<a href="#fnref:1" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:2">
<p>We adopted the MNDD metric from <a href="https://dl.acm.org/doi/10.1145/3558100.3563852">Mungmeeprued et al. (2022)</a>, who introduced it alongside the original TABME dataset to better quantify the human effort required to correct segmentation errors.&#160;<a href="#fnref:2" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
</ol>
</div>
]]></content:encoded></item><item><title>GP-MoLFormer: Molecular Generation via Transformers</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/autoregressive/gp-molformer/</link><pubDate>Thu, 25 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/autoregressive/gp-molformer/</guid><description>A 46.8M parameter transformer for molecular generation trained on 1.1B SMILES, introducing pair-tuning for efficient property optimization.</description><content:encoded><![CDATA[<h2 id="contribution-and-taxonomic-focus">Contribution and Taxonomic Focus</h2>
<p>This is primarily a <strong>Methodological</strong> paper, as it proposes a specific neural architecture (GP-MoLFormer) and a novel fine-tuning algorithm (Pair-tuning) for molecular generation. It validates these contributions against standard baselines (e.g., JT-VAE, <a href="/notes/chemistry/molecular-design/generation/autoregressive/molgen-molecular-generation-chemical-feedback/">MolGen</a>-7b).</p>
<p>It also contains a secondary <strong>Theoretical</strong> contribution by establishing an empirical <a href="/notes/machine-learning/model-architectures/scaling-laws-vs-model-architectures/">scaling law</a> that relates inference compute (generation size) to the novelty of the generated molecules.</p>
<h2 id="motivation-data-scale-and-prompt-based-optimization">Motivation: Data Scale and Prompt-Based Optimization</h2>
<p>While large language models (LLMs) have transformed text generation, the impact of training data scale and memorization on <em>molecular</em> generative models remains under-explored. Specifically, there is a need to understand how training on billion-scale datasets affects the novelty of generated molecules and whether biases in public databases (like ZINC and PubChem) perpetuate memorization. Furthermore, existing optimization methods often require computationally expensive property predictors or reinforcement learning loops; there is a practical need for more efficient &ldquo;prompt-based&rdquo; optimization techniques.</p>
<h2 id="core-innovations-architecture-and-pair-tuning">Core Innovations: Architecture and Pair-Tuning</h2>
<ol>
<li><strong>Architecture</strong>: The application of a linear-attention transformer decoder with Rotary Positional Embeddings (RoPE) to generative chemistry, allowing for efficient training on 1.1 billion SMILES.</li>
<li><strong>Pair-Tuning</strong>: A novel, parameter-efficient fine-tuning method that uses property-ordered molecular pairs to learn &ldquo;soft prompts&rdquo; for optimization without updating the base model weights.</li>
<li><strong>Scaling Analysis</strong>: An extensive empirical investigation mapping the trade-off between inference compute (up to 10B generations) and chemical novelty, fitting an exponential decay curve that shows how novelty declines as generation volume grows.</li>
</ol>
<h2 id="experimental-methodology-and-downstream-tasks">Experimental Methodology and Downstream Tasks</h2>
<p>The authors evaluated GP-MoLFormer on three distinct tasks, though the comparisons highlight the difficulty of evaluating foundation models against classical baselines:</p>
<ol>
<li><strong>De Novo Generation</strong>: Comparing validity, uniqueness, and novelty against baselines (CharRNN, VAE, <a href="/notes/chemistry/molecular-design/generation/latent-space/limo-latent-inceptionism/">LIMO</a>, MolGen-7b) on a held-out test set. Notably, this is an unequal comparison; most baselines were trained on the 1.6M molecule <a href="/notes/chemistry/molecular-design/generation/evaluation/molecular-sets-moses/">MOSES</a> dataset, whereas GP-MoLFormer uses up to 1.1B molecules, meaning performance gains are heavily driven by data scale.</li>
<li><strong>Scaffold-Constrained Decoration</strong>: Generating molecules from DRD2 active binder scaffolds and measuring the hit rate of active compounds against specialized scaffold decorators.</li>
<li><strong>Property-Guided Optimization</strong>: Using Pair-tuning to optimize for Drug-likeness (QED), Penalized <a href="https://en.wikipedia.org/wiki/Octanol-water_partition_coefficient">logP</a>, and <a href="https://en.wikipedia.org/wiki/Dopamine_receptor_D2">DRD2</a> binding activity, comparing the results to graph-based and reinforcement learning benchmarks.</li>
</ol>
<p>Additionally, they performed a <strong>Scaling Study</strong>:</p>
<ul>
<li>Comparing models trained on raw (1.1B) vs. de-duplicated (650M) data.</li>
<li>Generating up to 10 billion molecules to fit empirical scaling laws for novelty.</li>
</ul>
<h2 id="key-findings-and-scaling-laws">Key Findings and Scaling Laws</h2>
<ul>
<li><strong>Scale-Driven Performance</strong>: GP-MoLFormer achieves high internal diversity and validity on generation metrics. However, its baseline novelty percentage (~32%) is considerably lower than that of classical models. The authors attribute this to the massive training scale, which pushes the model to prioritize matching real-world molecule frequencies over pure exploration. GP-MoLFormer&rsquo;s advantage in generation metrics over LLM baselines like <a href="/notes/chemistry/molecular-design/generation/autoregressive/molgen-molecular-generation-chemical-feedback/">MolGen</a>-7b likely stems largely from its 10x larger training dataset and not from fundamental architectural superiority.</li>
<li><strong>Pair-Tuning Efficacy</strong>: The proposed pair-tuning method effectively optimizes properties (e.g., improving DRD2 activity scores) without requiring full model fine-tuning or external reward loops. While successful, the text-based generation yields ~94.5% validity during optimization, which lags behind graph and SELFIES-based baselines that guarantee 100% structural validity.</li>
<li><strong>Memorization vs. Novelty</strong>: Training on de-duplicated data (GP-MoLFormer-UNIQ) yields higher novelty (approx. 5-8% higher) than training on raw data, confirming that duplication bias in public databases leads directly to memorization.</li>
<li><strong>Inference Scaling Law</strong>: Novelty decays exponentially with generation size ($y = ae^{-bx}$), yet the model maintains generative capability (~16.7% novelty) even after generating an unprecedented 10 billion molecules.</li>
</ul>
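<p>The exponential form makes the fit easy to inspect. As a short sketch (here $y$ is the novelty fraction and $x$ is the generation-size variable used for the fit; the fitted values of $a$ and $b$ are not reproduced in this note, so they stay symbolic), taking logarithms turns the decay into a straight line:</p>
<p>$$ \ln y = \ln a - b x $$</p>
<p>so $b$ is the negative slope of $\ln y$ against $x$. Two consequences follow directly: the ratio of novelty at two generation sizes depends only on their gap, and novelty halves every $\ln 2 / b$ units of $x$:</p>
<p>$$ \frac{y(x_2)}{y(x_1)} = e^{-b(x_2 - x_1)}, \qquad x_{1/2} = \frac{\ln 2}{b} $$</p>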
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<ul>
<li><strong>Sources</strong>: A combination of <strong>PubChem</strong> (111M SMILES) and <strong>ZINC</strong> (1B SMILES) databases. Downloading and pre-processing instructions are located in the repository&rsquo;s <code>data/README.md</code>.</li>
<li><strong>Preprocessing</strong>:
<ul>
<li>All SMILES were canonicalized using RDKit (no isomeric information).</li>
<li><strong>GP-MoLFormer (Base)</strong>: Trained on the full 1.1B dataset (includes duplicates).</li>
<li><strong>GP-MoLFormer-UNIQ</strong>: Trained on a de-duplicated subset of 650M SMILES.</li>
</ul>
</li>
<li><strong>Tokenization</strong>: Uses the tokenizer from Schwaller et al. (2019) with a vocabulary size of <strong>2,362 tokens</strong>.</li>
<li><strong>Filtering</strong>: Sequences restricted to a maximum length of <strong>202 tokens</strong>.</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<p><strong>Pair-Tuning (Algorithm 1)</strong>:</p>
<ul>
<li><strong>Objective</strong>: Learn task-specific soft prompts $\phi_T$ to maximize the conditional probability of target molecule $b$ given a seed molecule $a$, where the pair $(a, b)$ is ordered so that $b$ has the higher value of the target property (written $b &gt; a$). The base model parameters $\theta$ remain frozen.</li>
<li><strong>Prompt Structure</strong>: Autoregressive training optimizes the continuous embeddings of $n$ enhancement tokens against the cross-entropy loss of the target sequence:
$$ \mathcal{L}(\phi_T) = - \sum_{i=1}^{|b|} \log P_{\theta}(b_i | \phi_T, a, b_{&lt;i}) $$</li>
<li><strong>Hyperparameters</strong>: Trained for 1,000 epochs with a batch size of 35 and a fixed learning rate of $3 \times 10^{-2}$.</li>
<li><strong>Inference</strong>: The learned prompt $\phi_T$ and seed molecule $a$ are prepended as context, and candidates are sampled autoregressively until a termination token is produced.</li>
</ul>
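<p>Because $\theta$ is frozen, the loss above is differentiated only with respect to the prompt embeddings. As a sketch, assuming a plain gradient step (the optimizer is not specified in this note) with the learning rate $\eta = 3 \times 10^{-2}$ from the hyperparameters, and with $E(\cdot)$ as a label introduced here for the frozen token-embedding lookup:</p>
<p>$$ \phi_T \leftarrow \phi_T - \eta \, \nabla_{\phi_T} \mathcal{L}(\phi_T), \qquad \theta \leftarrow \theta $$</p>
<p>The model input at step $i$ is the concatenation $[\phi_T; E(a); E(b_{&lt;i})]$. If the prompt lives in the model&rsquo;s hidden dimension $d$ (an assumption; $d = 768$ for the architecture described below), the only trainable quantity is the $n \times d$ prompt matrix.</p>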
<h3 id="models">Models</h3>
<ul>
<li><strong>Availability</strong>: The model trained on deduplicated data (GP-MoLFormer-UNIQ) is publicly available on <a href="https://huggingface.co/ibm-research/GP-MoLFormer-Uniq">Hugging Face</a>. The full 1.1B base model is not explicitly hosted. The source code repository includes a disclosure that IBM will not maintain the code going forward.</li>
<li><strong>Architecture</strong>: Transformer decoder (~47M parameters: 12 layers, 12 heads, hidden size 768).</li>
<li><strong>Attention Mechanism</strong>: Combines Linear Attention (Generalized Random Feature map, $\phi$) with Rotary Positional Embeddings (RoPE). To avoid the quadratic complexity of standard attention while maintaining relative positional awareness, RoPE is applied to queries ($Q$) and keys ($K$) prior to the random feature mapping, giving the output at query position $m$:
$$ \text{Attention}_m(Q, K, V) = \frac{\sum_{n=1}^N \langle \phi(R_m q_m), \phi(R_n k_n) \rangle v_n}{\sum_{n=1}^N \langle \phi(R_m q_m), \phi(R_n k_n) \rangle} $$</li>
<li><strong>Inference Speed</strong>: ~3ms per forward pass on a single A100 GPU.</li>
</ul>
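<p>To see why this form is linear in sequence length, note that the query feature $\phi(R_m q_m)$ does not depend on the summation index and can be pulled out of both sums. This is the standard linear-attention rearrangement, written here with the shorthand $\tilde{q}_m = \phi(R_m q_m)$ and $\tilde{k}_n = \phi(R_n k_n)$, and with $o_m$ for the attention output at query position $m$ (all notation introduced for this sketch):</p>
<p>$$ o_m^{\top} = \frac{\tilde{q}_m^{\top} S}{\tilde{q}_m^{\top} z}, \qquad S = \sum_{n=1}^N \tilde{k}_n v_n^{\top}, \qquad z = \sum_{n=1}^N \tilde{k}_n $$</p>
<p>$S$ and $z$ are computed once and shared by every query position, so the cost grows as $O(N)$ in sequence length instead of the $O(N^2)$ of softmax attention. For autoregressive decoding the sums would run only over positions $n \le m$ and can be maintained as running totals; that causal variant is an assumption about the implementation, not something stated above.</p>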
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Generation Quality Metrics</strong>: Validity, Uniqueness, Novelty (<a href="/notes/chemistry/molecular-design/generation/evaluation/molecular-sets-moses/">MOSES</a> suite), <a href="/notes/chemistry/molecular-design/generation/evaluation/frechet-chemnet-distance/">Fréchet ChemNet Distance (FCD)</a>, Scaffold similarity (Scaf), and Similarity to Nearest Neighbor (SNN).</li>
<li><strong>MoLFormer-Based Metrics</strong>: The authors introduce Fréchet <a href="/notes/chemistry/molecular-representations/encoders/molformer/">MoLFormer</a> Distance (FMD) and MoLFormer-space IntDiv2 to measure distributional similarity using their own pre-trained continuous embeddings instead of standard fingerprints.</li>
<li><strong>Optimization Metrics</strong>: Penalized logP (calculated as $\text{logP} - \text{SA} - \max(\text{maxrings}(\text{size}) - 6, 0)$), Drug-likeness (QED), and DRD2 activity scores.</li>
<li><strong>Scaling Metrics</strong>: Empirical fit for novelty decay: $y = ae^{-bx}$.</li>
</ul>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Compute</strong>: 16 x NVIDIA A100 (80 GB) GPUs across 2 nodes connected via EDR Infiniband.</li>
<li><strong>Training Time</strong>:
<ul>
<li>GP-MoLFormer (1.1B data): ~115 hours total (28.75 hours/epoch for 4 epochs).</li>
<li>GP-MoLFormer-UNIQ (650M data): ~80 hours total.</li>
</ul>
</li>
<li><strong>Hyperparameters</strong>: Used a batch size of 1,600 molecules per GPU with a fixed learning rate of $1.6 \times 10^{-4}$ (scaled by a factor of up to $8\times$ as the number of GPUs increased).</li>
<li><strong>Optimization</strong>: Used distributed data-parallel training and adaptive bucketing by sequence length to handle scale.</li>
</ul>
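<p>As a quick worked check of these numbers, assuming all 16 GPUs take part in every step with no gradient accumulation, and assuming $1.6 \times 10^{-4}$ is the base rate to which the $8\times$ factor is applied (neither is stated explicitly above):</p>
<p>$$ 1{,}600 \times 16 = 25{,}600 \ \text{molecules per step}, \qquad 8 \times 1.6 \times 10^{-4} = 1.28 \times 10^{-3} $$</p>
<p>The epoch timing is likewise consistent: $4 \times 28.75 = 115$ hours for the 1.1B-molecule run.</p>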
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/IBM/gp-molformer/">GP-MoLFormer (GitHub)</a></td>
          <td>Code</td>
          <td>Apache 2.0</td>
          <td>Official implementation; IBM will not maintain going forward</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/ibm-research/GP-MoLFormer-Uniq">GP-MoLFormer-Uniq (Hugging Face)</a></td>
          <td>Model</td>
          <td>Apache 2.0</td>
          <td>Pre-trained on 650M de-duplicated SMILES</td>
      </tr>
  </tbody>
</table>
<p>The full 1.1B base model weights are not publicly hosted. The training data (PubChem and ZINC) is publicly available, and instructions for downloading and pre-processing are in the repository&rsquo;s <code>data/README.md</code>.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Ross, J., Belgodere, B., Hoffman, S. C., Chenthamarakshan, V., Navratil, J., Mroueh, Y., &amp; Das, P. (2025). GP-MoLFormer: A Foundation Model For Molecular Generation. <em>Digital Discovery</em>, 4(10), 2684&ndash;2696. <a href="https://doi.org/10.1039/D5DD00122F">https://doi.org/10.1039/D5DD00122F</a></p>
<p><strong>Publication</strong>: Digital Discovery, vol. 4, no. 10, pp. 2684&ndash;2696 (2025)</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{ross2025gpmolformer,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{GP-MoLFormer: a foundation model for molecular generation}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Ross, Jerret and Belgodere, Brian and Hoffman, Samuel C and Chenthamarakshan, Vijil and Navratil, Jiri and Mroueh, Youssef and Das, Payel}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Digital Discovery}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{4}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{10}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{2684--2696}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2025}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Royal Society of Chemistry}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1039/D5DD00122F}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>ChemBERTa-2: Scaling Molecular Transformers to 77M</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/chemberta-2/</link><pubDate>Thu, 25 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/chemberta-2/</guid><description>Optimizing transformer pretraining for molecules using MLM vs MTR objectives, scaling to 77M compounds from PubChem for improved property prediction.</description><content:encoded><![CDATA[<h2 id="classifying-chemberta-2s-methodological-contributions">Classifying ChemBERTa-2&rsquo;s Methodological Contributions</h2>
<p>This is primarily a <strong>Methodological</strong> paper with a secondary <strong>Resource</strong> contribution.</p>
<p>It fits the Method classification because it focuses on optimizing the architecture and pretraining pipeline for molecular transformers. The authors perform extensive ablation studies (varying dataset size from 5M to 77M, comparing MLM vs. MTR objectives) to determine &ldquo;how well&rdquo; these strategies work compared to baselines. The secondary Resource classification applies because they open-source the trained models and establish a benchmark on a massive 77M compound dataset.</p>
<p><strong>Key methodological indicators</strong>:</p>
<ul>
<li><strong>Baseline comparison</strong>: The paper explicitly compares ChemBERTa-2 against standard baselines (D-MPNN, Random Forest, GCN) and its predecessor (<a href="/notes/chemistry/molecular-representations/encoders/chemberta/">ChemBERTa-1</a>) with prominent benchmark tables</li>
<li><strong>Ablation studies</strong>: Extensive experiments comparing multi-task and self-supervised pretraining by varying hyperparameters and pretraining dataset size</li>
<li><strong>Scaling analysis</strong>: Systematic investigation of whether larger datasets (up to 77M compounds) yield better performance</li>
</ul>
<h2 id="motivations-for-scaling-molecular-transformers">Motivations for Scaling Molecular Transformers</h2>
<p>The authors aim to bridge the gap between NLP success stories (like GPT-3) and molecular machine learning by developing a &ldquo;chemical foundation model&rdquo;.</p>
<p><strong>Key motivations</strong>:</p>
<ul>
<li><strong>Label scarcity</strong>: Experimental labels for molecular properties are rare and expensive, but unlabeled SMILES strings are abundant</li>
<li><strong>Scaling hypothesis</strong>: Testing if scaling pretraining data (up to 77M compounds) yields consistent downstream improvements, similar to scaling laws in NLP</li>
<li><strong>Efficiency</strong>: Optimizing the pretraining process introduced in the original ChemBERTa by comparing self-supervised (MLM) and weakly supervised (MTR, using <a href="https://en.wikipedia.org/wiki/RDKit">RDKit</a> computed properties as labels) approaches</li>
</ul>
<h2 id="novelty-in-multi-task-regression-objectives">Novelty in Multi-Task Regression Objectives</h2>
<p><strong>Scale</strong>: Training on 77M unique SMILES from PubChem, which is one of the largest molecular pretraining datasets used to date (compared to 10M for ChemBERTa-1 or 18.7M for <a href="/notes/chemistry/molecular-representations/encoders/smiles-bert/">SMILES-BERT</a>).</p>
<p><strong>Pipeline optimization</strong>: A direct, controlled comparison of <strong>Masked Language Modeling (MLM)</strong> vs. <strong>Multi-Task Regression (MTR)</strong> pretraining objectives on identical datasets.</p>
<p><strong>Proxy selection</strong>: The finding that MLM loss correlates well with MTR loss, allowing the cheaper MLM task to be used for hyperparameter tuning before running the expensive MTR pretraining.</p>
<h2 id="experimental-pretraining-setup-on-77m-compounds">Experimental Pretraining Setup on 77M Compounds</h2>
<h3 id="pretraining-setup">Pretraining Setup</h3>
<p><strong>Datasets</strong>: Subsets of <a href="https://en.wikipedia.org/wiki/PubChem">PubChem</a> containing 5M, 10M, and 77M unique SMILES.</p>
<p><strong>Tasks</strong>:</p>
<ul>
<li><strong>MLM</strong>: Masking 15% of tokens (following RoBERTa procedure). The model is optimized by minimizing the cross-entropy loss over the predicted masked tokens:
$$ \mathcal{L}_{MLM} = -\sum_{i \in \mathcal{M}} \log P(x_i \mid \mathbf{x}_{\setminus \mathcal{M}}) $$
where $\mathcal{M}$ represents the set of masked token indices.</li>
<li><strong>MTR</strong>: Predicting 200 calculated molecular properties (via RDKit) simultaneously using a mean squared error objective:
$$ \mathcal{L}_{MTR} = \frac{1}{200} \sum_{j=1}^{200} \frac{1}{N} \sum_{i=1}^{N} \left( \hat{y}_{ij} - y_{ij} \right)^2 $$
Continuous target labels $y_{ij}$ are mean-normalized prior to training to equilibrate the disparate scales of different chemical properties.</li>
</ul>
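<p>The normalization step can be written out explicitly. The sketch below assumes per-property standardization (subtracting the mean and dividing by the standard deviation); the source only says &ldquo;mean-normalized&rdquo;, so the division by $\sigma_j$ is an assumption, and $\mu_j$, $\sigma_j$, and $\tilde{y}_{ij}$ are symbols introduced here:</p>
<p>$$ \tilde{y}_{ij} = \frac{y_{ij} - \mu_j}{\sigma_j}, \qquad \mu_j = \frac{1}{N} \sum_{i=1}^{N} y_{ij}, \qquad \sigma_j^2 = \frac{1}{N} \sum_{i=1}^{N} \left( y_{ij} - \mu_j \right)^2 $$</p>
<p>Without a step like this, a property with values in the hundreds (such as molecular weight) would contribute squared errors orders of magnitude larger than a property of order one (such as logP), and $\mathcal{L}_{MTR}$ would effectively train on a handful of large-scale targets.</p>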
<p><strong>Hyperparameter search</strong>: Ran 50 random configurations on the 5M dataset; selected the top 5 to scale up to 10M and 77M.</p>
<h3 id="downstream-validation">Downstream Validation</h3>
<p><strong>Finetuning</strong>: Evaluated on 8 tasks from <strong><a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a></strong> (BACE, BBBP, ClinTox, Delaney, etc.) using scaffold splits (80/10/10).</p>
<p><strong>Analysis</strong>: Used UMAP to visualize embeddings from MLM, MTR, and ECFP to check for clustering by label without finetuning.</p>
<h2 id="key-performance-outcomes-and-scaling-realities">Key Performance Outcomes and Scaling Realities</h2>
<p><strong>Highly competitive performance</strong>: ChemBERTa-2 outperforms the D-MPNN baseline (Chemprop) on 6 of 8 MoleculeNet tasks, although the margins show that task-specific baselines remain strong.</p>
<p><strong>MTR superiority</strong>: Models pretrained with Multi-Task Regression (MTR) outperform their MLM-pretrained counterparts on every finetuning task evaluated. MTR pretraining is substantially slower than MLM because each example carries a 200-element label vector, but MLM loss serves as a reliable proxy for MTR loss, so architectures can be searched cheaply with MLM before committing to full MTR pretraining.</p>
<p><strong>Scaling laws versus downstream utility</strong>: Pretraining loss improved by 25-35% when increasing the dataset from 5M to 77M compounds. However, this improvement in pretraining loss does not uniformly transfer to downstream tasks. For MTR models, SR-p53 ROC-AUC decreases monotonically from 0.834 (5M) to 0.827 (10M) to 0.817 (77M), and Lipophilicity RMSE is worse at 77M (0.798) than at 5M (0.758), despite a dip at 10M (0.744). This variability in transfer challenges the assumption that pretraining improvements always yield downstream gains.</p>
<p><strong>Transfer learning</strong>: The correlation between pretraining loss and downstream performance is task-dependent; it is strong for Lipophilicity but weaker for BACE classification.</p>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The pretraining corpus is derived from <strong>PubChem</strong>.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Pretraining</strong></td>
          <td>PubChem</td>
          <td>77M SMILES</td>
          <td>Canonicalized and globally shuffled. Subsets of 5M and 10M used. <strong>Note: Exact splits and datasets are not published.</strong></td>
      </tr>
      <tr>
          <td><strong>Validation</strong></td>
          <td>PubChem</td>
          <td>100k SMILES</td>
          <td>A fixed set held out from the 77M corpus. <strong>Note: Exact 100k subset is not published.</strong></td>
      </tr>
      <tr>
          <td><strong>MTR Labels</strong></td>
          <td>RDKit</td>
          <td>200 props</td>
          <td>200 molecular properties calculated from SMILES using RDKit. Labels are mean-normalized. <strong>Note: Calculated labels are not published and must be re-computed.</strong></td>
      </tr>
      <tr>
          <td><strong>Finetuning</strong></td>
          <td>MoleculeNet</td>
          <td>1.5k - 8k</td>
          <td>Tasks: BACE, Clearance, Delaney, Lipophilicity, BBBP, ClinTox, HIV, Tox21. Split 80/10/10 via scaffold splitter.</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<p><strong>Pretraining Objectives:</strong></p>
<ol>
<li><strong>Masked Language Modeling (MLM)</strong>: Follows RoBERTa procedure. Masks 15% of tokens. Max sequence length 512.</li>
<li><strong>Multi-Task Regression (MTR)</strong>: Predicting 200 RDKit properties. Labels are mean-normalized.</li>
</ol>
<p><strong>Tokenizer:</strong></p>
<ul>
<li>Dictionary of common SMILES characters</li>
<li>Maximum vocabulary size: <strong>591 tokens</strong></li>
</ul>
<p><strong>Optimization:</strong></p>
<ul>
<li><strong>Patience</strong>: Early-stopping patience set to one full pass through the dataset, so every example is seen before training can stop</li>
<li><strong>Hyperparameter search</strong>: Random search (50 configs) varying hidden size, attention heads, dropout, intermediate size, hidden layers, and learning rate. <strong>Note: The precise configuration of the winning models that were scaled to 77M is absent from the paper.</strong></li>
</ul>
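<p>The search procedure (50 random configurations, keep the top 5) can be sketched as follows. The hyperparameter names follow the list above, but the candidate values, the function names, and the toy scoring function are illustrative assumptions; the paper does not publish its search ranges.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import random

# Hypothetical search space: values are illustrative only.
# Every hidden size is divisible by every head count.
SEARCH_SPACE = {
    "hidden_size": [128, 256, 512],
    "num_attention_heads": [2, 4, 8],
    "num_hidden_layers": [3, 6, 12],
    "intermediate_size": [512, 1024, 2048],
    "dropout": [0.0, 0.1, 0.2],
    "learning_rate": [1e-5, 5e-5, 1e-4],
}

def sample_config(rng):
    """Draw one value per hyperparameter, uniformly at random."""
    return {name: rng.choice(values) for name, values in SEARCH_SPACE.items()}

def random_search(evaluate, n_trials=50, top_k=5, seed=0):
    """Score n_trials random configs and return the top_k with the lowest loss."""
    rng = random.Random(seed)
    trials = []
    for _ in range(n_trials):
        config = sample_config(rng)
        trials.append((evaluate(config), config))
    trials.sort(key=lambda pair: pair[0])
    return trials[:top_k]

def toy_validation_loss(config):
    """Stand-in for pretraining on the 5M subset and reading the validation loss."""
    return 1.0 / config["num_hidden_layers"] + config["dropout"]

best = random_search(toy_validation_loss)
for loss, config in best:
    print(round(loss, 3), config)
</code></pre></div>
<p>In a real run, <code>evaluate</code> would launch a pretraining job on the 5M subset, and only the five surviving configurations would be retrained on the 10M and 77M corpora.</p>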
<h3 id="models">Models</h3>
<ul>
<li><strong>Architecture</strong>: Based on <strong>RoBERTa</strong> (HuggingFace implementation)</li>
<li><strong>Parameter scale</strong>: Models ranged between <strong>5M and 46M parameters</strong></li>
<li><strong>Selection</strong>: Top 5 configurations from the 5M-dataset random search were trained on the full 77M dataset</li>
<li><strong>Checkpoints</strong>: Pre-trained weights are hosted by DeepChem on <a href="https://huggingface.co/DeepChem">Hugging Face</a>. Direct links include <a href="https://huggingface.co/DeepChem/ChemBERTa-77M-MTR">DeepChem/ChemBERTa-77M-MTR</a> and <a href="https://huggingface.co/DeepChem/ChemBERTa-77M-MLM">DeepChem/ChemBERTa-77M-MLM</a> (Note: Model cards are currently empty).</li>
<li><strong>Code Reference</strong>: The paper points to the <a href="https://github.com/deepchem/deepchem">DeepChem</a> repository for code, but there are no standalone training scripts that reproduce ChemBERTa-2&rsquo;s exact pipeline; the relevant code is part of the general DeepChem library tooling.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p>Benchmarks were performed on <strong>MoleculeNet</strong> using DeepChem.</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Tasks</th>
          <th>Baseline</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>RMSE</strong> ($\downarrow$)</td>
          <td>Delaney, Lipo, BACE (Reg), Clearance</td>
          <td>D-MPNN</td>
          <td>ChemBERTa-2 outperformed D-MPNN on Delaney (0.889 vs 1.105) and Clearance (48.5 vs 49.8).</td>
      </tr>
      <tr>
          <td><strong>ROC-AUC</strong> ($\uparrow$)</td>
          <td>BBBP, ClinTox, HIV, Tox21, BACE (Cls)</td>
          <td>D-MPNN</td>
          <td>ChemBERTa-2 generally competitive; MTR-77M achieved 0.728 on BBBP vs D-MPNN 0.697.</td>
      </tr>
  </tbody>
</table>
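<p>For readers less familiar with the two metrics in the table, a minimal plain-Python sketch of each is given below. The function names and toy inputs are chosen for this example; the benchmark itself relies on DeepChem&rsquo;s implementations.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import math

def rmse(y_true, y_pred):
    """Root mean squared error (lower is better)."""
    return math.sqrt(sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true))

def roc_auc(labels, scores):
    """ROC-AUC as the probability that a random positive outranks a random negative.

    Ties count as half a win. The pairwise loop is quadratic in dataset size,
    which is fine for a sketch but not for large benchmarks.
    """
    positives = [s for y, s in zip(labels, scores) if y == 1]
    negatives = [s for y, s in zip(labels, scores) if y == 0]
    wins = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(positives) * len(negatives))

print(rmse([1.0, 2.0, 3.0], [1.0, 2.0, 5.0]))
print(roc_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))   # 0.75
</code></pre></div>
<p>A ROC-AUC of 0.5 corresponds to random ranking, which helps put reported gaps such as 0.728 versus 0.697 in context.</p>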
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Compute</strong>: AWS EC2 instances with <strong>Nvidia T4 GPUs</strong></li>
<li><strong>Strategy</strong>: AWS Spot instances were used to reduce cost; implemented frequent checkpointing to handle interruptions.</li>
<li><strong>Note</strong>: For MTR, the authors wrote a custom data-loader wrapper around HuggingFace&rsquo;s text loader, because the default CSV loader was a major bottleneck when parsing the 200-element target vectors.</li>
</ul>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Ahmad, W., Simon, E., Chithrananda, S., Grand, G., &amp; Ramsundar, B. (2022). ChemBERTa-2: Towards Chemical Foundation Models. <em>arXiv preprint arXiv:2209.01712</em>. <a href="https://doi.org/10.48550/arXiv.2209.01712">https://doi.org/10.48550/arXiv.2209.01712</a></p>
<p><strong>Publication</strong>: arXiv 2022 (Presented at 2021 ELLIS ML for Molecule Discovery Workshop)</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="/notes/chemistry/molecular-representations/encoders/chemberta/">ChemBERTa-1 Paper</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@misc</span>{ahmadChemBERTa2ChemicalFoundation2022,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{{{ChemBERTa-2}}: {{Towards Chemical Foundation Models}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">shorttitle</span> = <span style="color:#e6db74">{{{ChemBERTa-2}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Ahmad, Walid and Simon, Elana and Chithrananda, Seyone and Grand, Gabriel and Ramsundar, Bharath}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2022</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = sep,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{arXiv:2209.01712}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">eprint</span> = <span style="color:#e6db74">{2209.01712}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">primaryclass</span> = <span style="color:#e6db74">{cs}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{arXiv}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.48550/arXiv.2209.01712}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">urldate</span> = <span style="color:#e6db74">{2025-12-25}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">archiveprefix</span> = <span style="color:#e6db74">{arXiv}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Chemformer: A Pre-trained Transformer for Comp Chem</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/autoregressive/chemformer/</link><pubDate>Tue, 23 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/autoregressive/chemformer/</guid><description>BART-based Transformer pre-trained on 100M molecules using self-supervision to accelerate convergence on chemical sequence tasks.</description><content:encoded><![CDATA[<h2 id="paper-contribution-and-methodological-classification">Paper Contribution and Methodological Classification</h2>
<p>This is a <strong>Methodological ($\Psi_{\text{Method}}$)</strong> paper. It proposes an architecture adaptation (Chemformer based on BART) and a specific pre-training strategy (&ldquo;Combined&rdquo; masking and augmentation). The paper validates this method by benchmarking against established models on multiple tasks, including direct synthesis, retrosynthesis, and molecular optimization. It also includes a secondary <strong>Resource ($\Psi_{\text{Resource}}$)</strong> contribution by making the pre-trained models and code available.</p>
<h2 id="motivation-computational-bottlenecks-in-cheminformatics">Motivation: Computational Bottlenecks in Cheminformatics</h2>
<p>Existing Transformer models for cheminformatics are often developed for single applications and are computationally expensive to train from scratch. For example, training a Molecular Transformer for reaction prediction can take days, limiting hyperparameter exploration. Self-supervised pre-training (like BERT or T5) has significantly advanced NLP by reducing fine-tuning time and improving performance. In chemistry, applications have traditionally focused on task-specific datasets or encoder-only architectures, which perform poorly on sequence generation tasks. The authors aim to use transfer learning on a large unlabelled dataset to create a model that converges quickly and performs well across diverse sequence-to-sequence and discriminative tasks.</p>
<h2 id="core-innovation-bart-architecture-and-combined-pre-training">Core Innovation: BART Architecture and Combined Pre-training</h2>
<p>The primary insight lies in the adaptation of the <strong>BART architecture</strong> for chemistry and the introduction of a <strong>&ldquo;Combined&rdquo; self-supervised pre-training task</strong>.</p>
<ul>
<li><strong>Architecture</strong>: Chemformer uses the BART encoder-decoder structure, allowing it to handle both discriminative (property prediction) and generative (reaction prediction) tasks efficiently. This provides an alternative to encoder-only (BERT) or decoder-only (GPT) models.</li>
<li><strong>Combined Pre-training</strong>: The authors introduce a task that applies both <strong>Span Masking</strong> (replacing randomly chosen spans of tokens with a <code>&lt;mask&gt;</code> token) and <strong><a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> Augmentation</strong> (permuting atom order, see <a href="/notes/chemistry/molecular-representations/notations/randomized-smiles-generative-models/">Randomized SMILES</a>) simultaneously. Formally, given a canonical SMILES sequence $x$, a corrupted sequence $\tilde{x} = \text{Mask}(\text{Augment}(x))$ is generated. The model is trained with an autoregressive cross-entropy loss to reconstruct the canonical sequence from the corrupted input:
$$ \mathcal{L}_{\text{pre-train}} = -\sum_{t=1}^{|x|} \log P(x_t \mid x_{&lt;t}, \tilde{x}) $$</li>
<li><strong>Tunable Augmentation</strong>: A downstream augmentation strategy is proposed where the probability of augmenting the input/output SMILES ($p_{aug}$) is a tunable hyperparameter, performed on-the-fly.</li>
</ul>
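<p>The composition $\tilde{x} = \text{Mask}(\text{Augment}(x))$ can be sketched in a few lines of Python. Everything specific here is an assumption made for illustration: the mask rate, the uniform span lengths, the <code>[MASK]</code> placeholder spelling, the character-level tokens, and the hand-written table of equivalent ethanol SMILES that stands in for a real randomized-SMILES generator.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import random

MASK = "[MASK]"  # placeholder spelling of the mask token

def span_mask(tokens, rng, mask_prob=0.15, max_span=3):
    """Replace random spans of tokens with a single mask token each."""
    out, i = [], 0
    while len(tokens) > i:
        if mask_prob > rng.random():
            span = rng.randint(1, max_span)   # assumed span-length distribution
            out.append(MASK)
            i = min(i + span, len(tokens))
        else:
            out.append(tokens[i])
            i += 1
    return out

def combined_corruption(smiles, augment, tokenize, rng):
    """Build a 'Combined' style input: augment first, then mask."""
    return span_mask(tokenize(augment(smiles, rng)), rng)

# Toy stand-in for a randomized-SMILES generator: three spellings of ethanol.
ETHANOL_FORMS = ["CCO", "OCC", "C(O)C"]

def toy_augment(smiles, rng):
    return rng.choice(ETHANOL_FORMS) if smiles in ETHANOL_FORMS else smiles

rng = random.Random(0)
target = "CCO"   # canonical SMILES the decoder must reconstruct
source = combined_corruption(target, toy_augment, list, rng)
print(source, "to", target)
</code></pre></div>
<p>The training pair is always (corrupted, canonical), so the decoder learns to undo both the atom reordering and the masking in one pass.</p>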
<h2 id="experimental-setup-and-pre-training-tasks">Experimental Setup and Pre-training Tasks</h2>
<p>The authors pre-trained Chemformer on <strong>100 million molecules</strong> from ZINC-15 and fine-tuned it on three distinct task types:</p>
<ol>
<li><strong>Seq2Seq Reaction Prediction</strong>:
<ul>
<li><em>Direct Synthesis</em>: USPTO-MIT dataset (Mixed and Separated).</li>
<li><em><a href="https://en.wikipedia.org/wiki/Retrosynthetic_analysis">Retrosynthesis</a></em>: USPTO-50K dataset (see also <a href="/notes/chemistry/molecular-design/reaction-prediction/molecular-transformer/">Molecular Transformer</a>, <a href="/notes/chemistry/molecular-design/reaction-prediction/tied-two-way-transformers-retrosynthesis/">Tied Two-Way Transformers</a>).</li>
</ul>
</li>
<li><strong>Molecular Optimization</strong>: Generating molecules with improved properties (<a href="https://en.wikipedia.org/wiki/Distribution_coefficient">LogD</a>, solubility, clearance) starting from ChEMBL matched molecular pairs.</li>
<li><strong>Discriminative Tasks</strong>:
<ul>
<li><em><a href="https://en.wikipedia.org/wiki/Quantitative_structure%E2%80%93activity_relationship">QSAR</a></em>: Predicting properties (ESOL, FreeSolv, Lipophilicity) from <a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a>.</li>
<li><em>Bioactivity</em>: Predicting pXC50 values for 133 genes using ExCAPE data.</li>
</ul>
</li>
</ol>
<p>Ablation studies compared three pre-training strategies (Masking, Augmentation, Combined) against a randomly initialized baseline.</p>
<h2 id="results-trade-offs-and-conclusions">Results, Trade-offs, and Conclusions</h2>
<ul>
<li><strong>Performance</strong>: Chemformer achieved <strong>competitive top-1 accuracy</strong> on USPTO-MIT (91.3% Mixed) and USPTO-50K (53.6-54.3%), outperforming the Augmented Transformer and graph-based models (GLN, GraphRetro).</li>
<li><strong>Convergence Speed</strong>: Pre-training sharply reduced the fine-tuning required; just 20 epochs (30 mins) of fine-tuning outperformed previous baselines that were trained for much longer.</li>
<li><strong>Pre-training Tasks</strong>: The &ldquo;Combined&rdquo; task generally performed best for reaction prediction and bioactivity, while &ldquo;Masking&rdquo; was superior for molecular optimization.</li>
<li><strong>Augmentation Trade-off</strong>: The augmentation strategy improved top-1 accuracy but significantly degraded top-5/10 accuracy, because the beam search output filled up with alternative SMILES spellings of the same molecule. This is a considerable limitation for practical applications such as retrosynthesis planning, where retrieving a diverse set of candidate reactions is often critical.</li>
<li><strong>Discriminative Evaluation Caveats</strong>: Chemformer underperformed specialized baselines (like D-MPNN or <a href="/notes/chemistry/molecular-representations/encoders/molbert-molecular-representations/">MolBERT</a>) on small discriminative datasets. The authors note that direct comparison is difficult: Chemformer was trained simultaneously on multiple subtasks (multi-task learning), while the literature baselines were trained and tuned on each subtask separately. Additionally, the Chemformer encoder uses fewer than 20M parameters compared to MolBERT&rsquo;s approximately 85M, and Chemformer&rsquo;s pre-training does not include molecular property objectives. For other transfer learning approaches to QSAR, see <a href="/notes/chemistry/molecular-design/property-prediction/molpmofit-transfer-learning-qsar/">MolPMoFiT</a>.</li>
<li><strong>Pre-training Data Scope</strong>: The 100M pre-training dataset from ZINC-15 was selected with constraints on molecular weight ($\le 500$ Da) and LogP ($\le 5$), focusing the learned representations on small, drug-like molecules.</li>
</ul>
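<p>The augmentation trade-off is easy to reproduce on a toy beam. The sketch below computes top-N accuracy with and without collapsing predictions to a canonical form; it only illustrates the mechanism and is not a procedure taken from the paper. The function name and the lookup table that stands in for a real canonicalizer are hypothetical.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">def top_n_accuracy(beams, targets, n, canonicalize=None):
    """Fraction of examples whose target appears among the first n predictions.

    If canonicalize is given, predictions are mapped to a canonical form and
    duplicates are dropped before taking the first n.
    """
    hits = 0
    for predictions, target in zip(beams, targets):
        if canonicalize is not None:
            seen, unique = set(), []
            for p in predictions:
                c = canonicalize(p)
                if c not in seen:
                    seen.add(c)
                    unique.append(c)
            predictions = unique
        hits += target in predictions[:n]
    return hits / len(targets)

# Toy canonicalizer: a lookup table standing in for a real toolkit.
CANONICAL = {"OCC": "CCO", "C(O)C": "CCO", "CCO": "CCO", "CC=O": "CC=O"}

beams = [["CCO", "OCC", "C(O)C", "CC=O"]]   # three spellings of ethanol crowd the beam
targets = ["CC=O"]
print(top_n_accuracy(beams, targets, n=3))                               # 0.0
print(top_n_accuracy(beams, targets, n=3, canonicalize=CANONICAL.get))   # 1.0
</code></pre></div>
<p>With raw strings, the three ethanol spellings occupy the whole top-3 and push the correct answer out; after deduplication the same beam contains only two distinct molecules.</p>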
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p><em>Note: The primary GitHub repository for Chemformer was officially archived on February 11, 2026. Pre-trained weights and datasets used in the paper are still hosted externally on <a href="https://az.app.box.com/s/7eci3nd9vy0xplqniitpk02rbg9q2zcq">Box</a>. Active development of Chemformer models has moved to the <a href="https://github.com/MolecularAI/aizynthmodels">AiZynthModels</a> repository.</em></p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Artifact</th>
          <th style="text-align: left">Type</th>
          <th style="text-align: left">License</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><a href="https://github.com/MolecularAI/Chemformer">Chemformer (GitHub)</a></td>
          <td style="text-align: left">Code</td>
          <td style="text-align: left">Apache-2.0</td>
          <td style="text-align: left">Archived; original PyTorch implementation</td>
      </tr>
      <tr>
          <td style="text-align: left"><a href="https://github.com/MolecularAI/aizynthmodels">AiZynthModels (GitHub)</a></td>
          <td style="text-align: left">Code</td>
          <td style="text-align: left">Apache-2.0</td>
          <td style="text-align: left">Active successor repository</td>
      </tr>
      <tr>
          <td style="text-align: left"><a href="https://az.app.box.com/s/7eci3nd9vy0xplqniitpk02rbg9q2zcq">Pre-trained weights (Box)</a></td>
          <td style="text-align: left">Model</td>
          <td style="text-align: left">Unknown</td>
          <td style="text-align: left">Base and Large model checkpoints</td>
      </tr>
  </tbody>
</table>
<p>The following datasets were used for pre-training and benchmarking.</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Purpose</th>
          <th style="text-align: left">Dataset</th>
          <th style="text-align: left">Size</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>Pre-training</strong></td>
          <td style="text-align: left">ZINC-15</td>
          <td style="text-align: left">100M</td>
          <td style="text-align: left">Selected subset (reactive, annotated purchasability, MW $\le 500$, LogP $\le 5$). Split: 99% Train / 0.5% Val / 0.5% Test.</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Direct Synthesis</strong></td>
          <td style="text-align: left">USPTO-MIT</td>
          <td style="text-align: left">~470k</td>
          <td style="text-align: left">Evaluated on &ldquo;Mixed&rdquo; and &ldquo;Separated&rdquo; variants.</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Retrosynthesis</strong></td>
          <td style="text-align: left">USPTO-50K</td>
          <td style="text-align: left">~50k</td>
          <td style="text-align: left">Standard benchmark for retrosynthesis.</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Optimization</strong></td>
          <td style="text-align: left">ChEMBL MMPs</td>
          <td style="text-align: left">~160k Train</td>
          <td style="text-align: left">Matched Molecular Pairs for LogD, solubility, and clearance optimization.</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Properties</strong></td>
          <td style="text-align: left">MoleculeNet</td>
          <td style="text-align: left">Small</td>
          <td style="text-align: left">ESOL (1128), FreeSolv (642), Lipophilicity (4200).</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Bioactivity</strong></td>
          <td style="text-align: left">ExCAPE</td>
          <td style="text-align: left">~312k</td>
          <td style="text-align: left">133 gene targets; &gt;1200 compounds per gene.</td>
      </tr>
  </tbody>
</table>
<p><strong>Preprocessing</strong>:</p>
<ul>
<li><strong>Tokenization</strong>: Regex-based tokenization (523 tokens total) derived from ChEMBL 27 canonical SMILES.</li>
<li><strong>Augmentation</strong>: SMILES enumeration (permuting atom order) used for pre-training and on-the-fly during fine-tuning ($p_{aug}=0.5$ for Seq2Seq, $p_{aug}=1.0$ for discriminative).</li>
</ul>
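<p>Both preprocessing steps can be sketched briefly. The regular expression below is a simplified SMILES pattern written for this example, not Chemformer&rsquo;s exact regex or its 523-token vocabulary, and <code>maybe_augment</code> with its <code>augment</code> callback is a hypothetical helper showing how $p_{aug}$ gates on-the-fly augmentation.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import random
import re

# Simplified pattern (an assumption): bracket atoms, two-letter halogens,
# single-letter atoms, ring-closure labels, and bond/branch symbols.
SMILES_PATTERN = re.compile(
    r"\[[^\]]+\]|Br|Cl|[BCNOSPFIbcnosp]|%\d{2}|\d|[()=#+\-/\\.:@*~$]"
)

def tokenize(smiles):
    """Split a SMILES string into tokens, refusing unknown characters."""
    tokens = SMILES_PATTERN.findall(smiles)
    if "".join(tokens) != smiles:
        raise ValueError("untokenizable characters in " + repr(smiles))
    return tokens

def maybe_augment(smiles, p_aug, rng, augment):
    """Apply augment with probability p_aug, otherwise keep the input as is."""
    return augment(smiles) if p_aug > rng.random() else smiles

print(tokenize("C[C@H](N)C(=O)O"))   # alanine: the chiral centre stays one token
print(tokenize("ClCCBr"))            # two-letter halogens stay whole
</code></pre></div>
<p>Setting <code>p_aug=1.0</code> augments every example, matching the discriminative setting above, while <code>p_aug=0.5</code> mixes canonical and randomized inputs as in the Seq2Seq setting.</p>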
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Pre-training Tasks</strong>:
<ol>
<li><em>Masking</em>: Span masking (BART style).</li>
<li><em>Augmentation</em>: Input is a randomized SMILES; target is canonical SMILES.</li>
<li><em>Combined</em>: Input is augmented <em>then</em> masked; target is canonical SMILES.</li>
</ol>
</li>
<li><strong>Optimization</strong>:
<ul>
<li>Optimizer: Adam ($\beta_1=0.9, \beta_2=0.999$).</li>
<li>Schedule: Linear warm-up (8000 steps) for pre-training; One-cycle schedule for fine-tuning.</li>
</ul>
</li>
<li><strong>Inference</strong>: <a href="https://en.wikipedia.org/wiki/Beam_search">Beam search</a> with width 10 for Seq2Seq tasks. Used <code>molbart/inference_score.py</code> and <code>molbart/retrosynthesis/round_trip_inference.py</code> for standard and round-trip validation.</li>
</ul>
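<p>The pre-training warm-up can be written as a one-line schedule. Only the 8000-step linear warm-up comes from the notes above; the peak rate of 1.0 and holding the rate constant after warm-up are assumptions of this sketch, since the behaviour after warm-up is not described here.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">def warmup_lr(step, peak_lr=1.0, warmup_steps=8000):
    """Linear warm-up from 0 to peak_lr over warmup_steps, then hold (assumed)."""
    return peak_lr * min(1.0, step / warmup_steps)

for step in (0, 2000, 8000, 100000):
    print(step, warmup_lr(step))
</code></pre></div>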
<h3 id="models">Models</h3>
<p>Two model sizes were trained. Both use the Pre-Norm Transformer layout with GELU activation.</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Hyperparameter</th>
          <th style="text-align: left">Chemformer (Base)</th>
          <th style="text-align: left">Chemformer-Large</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>Layers</strong></td>
          <td style="text-align: left">6</td>
          <td style="text-align: left">8</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Model Dimension</strong></td>
          <td style="text-align: left">512</td>
          <td style="text-align: left">1024</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Feed-forward Dim</strong></td>
          <td style="text-align: left">2048</td>
          <td style="text-align: left">4096</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Attention Heads</strong></td>
          <td style="text-align: left">8</td>
          <td style="text-align: left">16</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Parameters</strong></td>
          <td style="text-align: left">~45M</td>
          <td style="text-align: left">~230M</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Pre-training Task</strong></td>
          <td style="text-align: left">All 3 variants</td>
          <td style="text-align: left">Combined only</td>
      </tr>
  </tbody>
</table>
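<p>The parameter counts in the table can be sanity-checked from the other rows. The estimate below assumes that &ldquo;Layers&rdquo; applies to the encoder and the decoder separately, uses the 523-token vocabulary from the preprocessing notes for a single shared embedding table, and ignores biases and layer-norm parameters. It is a back-of-the-envelope check, not the authors&rsquo; accounting.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">def transformer_params(layers, d_model, d_ff, vocab_size):
    """Rough weight count for an encoder-decoder Transformer."""
    attention = 4 * d_model * d_model              # Q, K, V and output projections
    feed_forward = 2 * d_model * d_ff              # two linear maps
    encoder_layer = attention + feed_forward
    decoder_layer = 2 * attention + feed_forward   # self- plus cross-attention
    embeddings = vocab_size * d_model
    return layers * (encoder_layer + decoder_layer) + embeddings

base = transformer_params(layers=6, d_model=512, d_ff=2048, vocab_size=523)
large = transformer_params(layers=8, d_model=1024, d_ff=4096, vocab_size=523)
print(round(base / 1e6, 1), round(large / 1e6, 1))   # 44.3 235.4
</code></pre></div>
<p>Both estimates land close to the reported ~45M and ~230M. Under the same assumptions the Base encoder alone accounts for roughly 19M weights.</p>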
<h3 id="evaluation">Evaluation</h3>
<p>Comparisons relied on Top-N accuracy for reaction tasks, the percentage of generated molecules with desirable properties for molecular optimization, and RMSE for property regression.</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Metric</th>
          <th style="text-align: left">Task</th>
          <th style="text-align: left">Key Result</th>
          <th style="text-align: left">Baseline</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>Top-1 Acc</strong></td>
          <td style="text-align: left">Direct Synthesis (Sep)</td>
          <td style="text-align: left"><strong>92.8%</strong> (Large)</td>
          <td style="text-align: left">91.1% (Aug Transformer)</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Top-1 Acc</strong></td>
          <td style="text-align: left">Retrosynthesis</td>
          <td style="text-align: left"><strong>54.3%</strong> (Large)</td>
          <td style="text-align: left">53.7% (GraphRetro) / 52.5% (GLN)</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Desirable %</strong></td>
          <td style="text-align: left">Mol Optimization</td>
          <td style="text-align: left"><strong>75.0%</strong> (Base-Mask)</td>
          <td style="text-align: left">70.2% (Transformer-R)</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>RMSE</strong></td>
          <td style="text-align: left">Lipophilicity</td>
          <td style="text-align: left">0.598 (Combined)</td>
          <td style="text-align: left">0.555 (D-MPNN)</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Compute</strong>: 4 NVIDIA V100 GPUs (batch size 128 per GPU).</li>
<li><strong>Training Time</strong>:
<ul>
<li>Pre-training: 2.5 days (Base) / 6 days (Large) for 1M steps.</li>
<li>Fine-tuning: ~20-40 epochs for reaction prediction (&lt;12 hours).</li>
</ul>
</li>
</ul>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Irwin, R., Dimitriadis, S., He, J., &amp; Bjerrum, E. J. (2022). Chemformer: a pre-trained transformer for computational chemistry. <em>Machine Learning: Science and Technology</em>, 3(1), 015022. <a href="https://doi.org/10.1088/2632-2153/ac3ffb">https://doi.org/10.1088/2632-2153/ac3ffb</a></p>
<p><strong>Publication</strong>: Machine Learning: Science and Technology 2022</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{irwinChemformerPretrainedTransformer2022,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Chemformer: A Pre-Trained Transformer for Computational Chemistry}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">shorttitle</span> = <span style="color:#e6db74">{Chemformer}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Irwin, Ross and Dimitriadis, Spyridon and He, Jiazhen and Bjerrum, Esben Jannik}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2022</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = jan,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Machine Learning: Science and Technology}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{3}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{015022}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{IOP Publishing}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">issn</span> = <span style="color:#e6db74">{2632-2153}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1088/2632-2153/ac3ffb}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>ChemBERTa: Molecular Property Prediction via Transformers</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/chemberta/</link><pubDate>Tue, 23 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/chemberta/</guid><description>A systematic evaluation of RoBERTa transformers pretrained on 77M PubChem SMILES for molecular property prediction tasks.</description><content:encoded><![CDATA[<h2 id="taxonomy-and-paper-contributions">Taxonomy and Paper Contributions</h2>
<p>This is primarily a <strong>Method</strong> paper ($\Psi_{\text{Method}}$), with a significant <strong>Resource</strong> component ($\Psi_{\text{Resource}}$).</p>
<p>It is a methodological investigation because it systematically evaluates a specific architecture (Transformers/RoBERTa) against established State-of-the-Art (SOTA) baselines like directed Message Passing Neural Networks (D-MPNNs) to determine &ldquo;how well does this work?&rdquo; in the chemical domain. It ablates dataset size, tokenization, and input representation.</p>
<p>It is also a resource paper as it introduces &ldquo;PubChem-77M,&rdquo; a curated dataset of 77 million SMILES strings designed to facilitate large-scale self-supervised pretraining for the community.</p>
<h2 id="overcoming-data-scarcity-in-property-prediction">Overcoming Data Scarcity in Property Prediction</h2>
<p>The primary motivation is <strong>data scarcity</strong> in molecular property prediction. Graph Neural Networks (GNNs) achieve strong performance on property prediction tasks when provided with sufficient labeled data. Generating these labels requires costly and time-consuming laboratory testing, leading to severe data scarcity in specialized chemical domains.</p>
<p>Massive quantities of <strong>unlabeled chemical structure data</strong> exist in the form of <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> strings. Inspired by the success of Transformers in NLP, where self-supervised pretraining on large corpora yields strong transfer learning, the authors aim to use these unlabeled datasets to learn effective molecular representations. Additionally, Transformers benefit from a mature software ecosystem (HuggingFace) that offers efficiency advantages over GNNs.</p>
<h2 id="pretraining-scaling-laws-and-novelty">Pretraining Scaling Laws and Novelty</h2>
<p>Previous works applied Transformers to SMILES strings. This paper advances the field by systematically evaluating scaling laws and architectural components for this domain. Specifically:</p>
<ul>
<li><strong>Scaling Analysis</strong>: It explicitly tests how pretraining dataset size (100K to 10M) impacts downstream performance.</li>
<li><strong>Tokenizer Comparison</strong>: It compares standard NLP <a href="https://en.wikipedia.org/wiki/Byte-pair_encoding">Byte-Pair Encoding (BPE)</a> against a chemically-aware &ldquo;SmilesTokenizer&rdquo;.</li>
<li><strong>Representation Comparison</strong>: It evaluates if the robust <a href="/notes/chemistry/molecular-representations/notations/selfies/">SELFIES</a> string representation offers advantages over standard SMILES in a Transformer context.</li>
</ul>
<h2 id="experimental-setup-pretraining-and-finetuning">Experimental Setup: Pretraining and Finetuning</h2>
<p>The authors trained <strong>ChemBERTa</strong> (based on RoBERTa) using Masked Language Modeling (MLM) on subsets of the <a href="https://en.wikipedia.org/wiki/PubChem">PubChem</a> dataset. The core training objective minimizes the cross-entropy loss over a corrupted input in which a subset of token positions, denoted by $\mathcal{M}$, is masked:</p>
<p>$$
\mathcal{L}_{\text{MLM}} = - \frac{1}{|\mathcal{M}|} \sum_{i \in \mathcal{M}} \log P(x_i \mid x_{\setminus \mathcal{M}}; \theta)
$$</p>
<p>where $x_i$ is the original token at masked position $i$, $x_{\setminus \mathcal{M}}$ is the SMILES string with the masked positions hidden, and $\theta$ represents the network parameters.</p>
<ul>
<li><strong>Pretraining</strong>: Models were pretrained on dataset sizes of 100K, 250K, 1M, and 10M compounds.</li>
<li><strong>Baselines</strong>: Performance was compared against D-MPNN (Graph Neural Network), Random Forest (RF), and SVM using 2048-bit Morgan Fingerprints.</li>
<li><strong>Downstream Tasks</strong>: Finetuning was performed individually on small <a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a> classification tasks: BBBP (<a href="https://en.wikipedia.org/wiki/Blood%E2%80%93brain_barrier">blood-brain barrier</a>), ClinTox (clinical toxicity), HIV, and Tox21 (p53 stress-response). This poses a transfer learning challenge, as the model must adapt from pretraining on 10 million molecules to classifying datasets ranging from ~1.5K to ~41K examples.</li>
<li><strong>Ablations</strong>:
<ul>
<li><strong>Tokenization</strong>: BPE vs. SmilesTokenizer on the 1M dataset, evaluated on Tox21.</li>
<li><strong>Input</strong>: SMILES vs. SELFIES strings on the Tox21 task.</li>
</ul>
</li>
</ul>
<h2 id="results-vs-graph-neural-network-baselines">Results vs. Graph Neural Network Baselines</h2>
<p>The main comparison between ChemBERTa (pretrained on 10M compounds) and Chemprop baselines on MoleculeNet tasks is summarized below (Table 1 from the paper):</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>BBBP ROC</th>
          <th>BBBP PRC</th>
          <th>ClinTox ROC</th>
          <th>ClinTox PRC</th>
          <th>HIV ROC</th>
          <th>HIV PRC</th>
          <th>Tox21 ROC</th>
          <th>Tox21 PRC</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ChemBERTa 10M</td>
          <td>0.643</td>
          <td>0.620</td>
          <td>0.733</td>
          <td>0.975</td>
          <td>0.622</td>
          <td>0.119</td>
          <td>0.728</td>
          <td>0.207</td>
      </tr>
      <tr>
          <td>D-MPNN</td>
          <td>0.708</td>
          <td>0.697</td>
          <td>0.906</td>
          <td>0.993</td>
          <td>0.752</td>
          <td>0.152</td>
          <td>0.688</td>
          <td>0.429</td>
      </tr>
      <tr>
          <td>RF</td>
          <td>0.681</td>
          <td>0.692</td>
          <td>0.693</td>
          <td>0.968</td>
          <td>0.780</td>
          <td>0.383</td>
          <td>0.724</td>
          <td>0.335</td>
      </tr>
      <tr>
          <td>SVM</td>
          <td>0.702</td>
          <td>0.724</td>
          <td>0.833</td>
          <td>0.986</td>
          <td>0.763</td>
          <td>0.364</td>
          <td>0.708</td>
          <td>0.345</td>
      </tr>
  </tbody>
</table>
<ul>
<li><strong>Scaling Improvements &amp; Training Dynamics</strong>: Performance improves consistently with pretraining data size. Increasing data from 100K to 10M improved ROC-AUC by +0.110 and PRC-AUC by +0.059 on average across BBBP, ClinTox, and Tox21 (HIV was omitted due to resource constraints). Notably, the authors halted pretraining on the 10M subset after just 3 epochs due to overfitting, suggesting that simple 15% token masking might not provide a sufficiently difficult learning task for large-scale chemical representation learning.</li>
<li><strong>Performance Limits vs. GNNs</strong>: ChemBERTa generally performs below the D-MPNN baseline. On the Tox21 dataset, ChemBERTa-10M achieved a higher ROC-AUC (0.728) than D-MPNN (0.688); nonetheless, it recorded a substantially lower PRC-AUC (0.207 vs 0.429). This gap indicates that current Transformer iterations lack the explicit inductive biases of graph algorithms and struggle with the severe class imbalances typical of chemical datasets.</li>
<li><strong>Ablation Limitations (Tokenization &amp; SELFIES)</strong>: The authors&rsquo; ablation studies for tokenization (SmilesTokenizer narrowly beating BPE) and input representation (SELFIES performing comparably to SMILES) were evaluated exclusively on the single Tox21 task. Deriving broad architectural conclusions regarding &ldquo;semantically-aware tokenization&rdquo; or string robustness from an $N=1$ empirical evaluation is a significant limitation of the study. Broader benchmarking is required to validate these findings.</li>
<li><strong>Interpretability</strong>: Attention heads organically learn to track chemically relevant substructures (like specific functional groups and aromatic rings), mimicking the inductive biases of graph convolutions.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The authors curated a massive dataset for pretraining and utilized standard benchmarks for evaluation.</p>
<ul>
<li><strong>Pretraining Data</strong>: <strong>PubChem-77M</strong>.
<ul>
<li>Source: 77 million unique SMILES from PubChem.</li>
<li>Preprocessing: Canonicalized and globally shuffled.</li>
<li>Subsets used: 100K, 250K, 1M, and 10M subsets.</li>
<li><em>Availability Note</em>: The authors provided a direct link to the <a href="https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/pubchem_10m.txt.zip">canonicalized 10M compound subset</a> used for their largest experiments. Full reproducibility of the smaller (100K, 250K, 1M) or full 77M sets may require re-extracting from PubChem.</li>
</ul>
</li>
<li><strong>Evaluation Data</strong>: <strong>MoleculeNet</strong>.
<ul>
<li>Tasks: BBBP (2,039), ClinTox (1,478), HIV (41,127), Tox21 (7,831).</li>
<li>Splitting: 80/10/10 train/valid/test split using a <strong>scaffold splitter</strong> to ensure chemical diversity between splits.</li>
</ul>
</li>
</ul>
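<p>The scaffold split can be sketched as follows. This is an illustration under stated assumptions: the scaffold strings are taken as precomputed (for example, Bemis-Murcko scaffolds from a cheminformatics toolkit, not shown here), and the greedy largest-group-first assignment is a common convention rather than a confirmed detail of the paper&rsquo;s pipeline. The function name and the toy scaffold labels are hypothetical.</p>

```python
from collections import defaultdict

def scaffold_split(scaffolds, frac_train=0.8, frac_valid=0.1):
    """Assign whole scaffold groups to train/valid/test, largest groups first."""
    groups = defaultdict(list)
    for idx, scaffold in enumerate(scaffolds):
        groups[scaffold].append(idx)
    # Sort by descending size; break ties by first index for determinism.
    ordered = sorted(groups.values(), key=lambda g: (-len(g), g[0]))
    n = len(scaffolds)
    train, valid, test = [], [], []
    for group in ordered:
        if len(train) + len(group) <= frac_train * n:
            train.extend(group)
        elif len(valid) + len(group) <= frac_valid * n:
            valid.extend(group)
        else:
            test.extend(group)
    return train, valid, test

# Ten molecules with placeholder scaffold labels.
scaffolds = ["benzene"] * 5 + ["pyridine"] * 3 + ["indole", "furan"]
train, valid, test = scaffold_split(scaffolds)
print(train, valid, test)  # [0, 1, 2, 3, 4, 5, 6, 7] [8] [9]
```

<p>Because groups are never divided, no scaffold appears in more than one split, which is what makes the evaluation a test of generalization to unseen chemotypes.</p>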
<h3 id="algorithms">Algorithms</h3>
<p>The core training methodology mirrors standard BERT/RoBERTa procedures adapted for chemical strings.</p>
<ul>
<li><strong>Objective</strong>: Masked Language Modeling (MLM) with <strong>15% token masking</strong>.</li>
<li><strong>Tokenization</strong>:
<ul>
<li><strong>BPE</strong>: Byte-Pair Encoding (vocab size 52K).</li>
<li><strong>SmilesTokenizer</strong>: Regex-based custom tokenizer available in DeepChem (documented <a href="https://deepchem.readthedocs.io/en/latest/tokenizers.html#smilestokenizer">here</a>).</li>
</ul>
</li>
<li><strong>Sequence Length</strong>: Maximum sequence length of <strong>512 tokens</strong>.</li>
<li><strong>Finetuning</strong>: Appended a linear classification layer; backpropagated through the base model for up to 25 epochs with early stopping on ROC-AUC.</li>
</ul>
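<p>A rough sketch of regex-based SMILES tokenization followed by 15% masking is given below. The pattern is a simplified assumption and is not DeepChem&rsquo;s actual <code>SmilesTokenizer</code> regex; <code>MASK_TOKEN</code>, <code>tokenize</code>, and <code>mask_tokens</code> are hypothetical names. The standard BERT/RoBERTa recipe also leaves some selected tokens unchanged or swaps them for random tokens, which this sketch omits.</p>

```python
import random
import re

# Simplified pattern (an assumption): bracket atoms, two-letter halogens,
# two-digit ring closures, then single letters, digits, and bond/branch symbols.
SMILES_PATTERN = re.compile(r"\[[^\]]+\]|Br|Cl|%\d{2}|[A-Za-z]|\d|[=#()+\-\\/.:@~*$]")
MASK_TOKEN = "[MASK]"  # placeholder string used only in this sketch

def tokenize(smiles):
    """Split a SMILES string into chemically meaningful tokens."""
    tokens = SMILES_PATTERN.findall(smiles)
    if "".join(tokens) != smiles:
        raise ValueError(f"unrecognized characters in {smiles!r}")
    return tokens

def mask_tokens(tokens, mask_rate=0.15, seed=0):
    """Replace roughly mask_rate of the tokens with MASK_TOKEN."""
    rng = random.Random(seed)
    n_mask = max(1, round(len(tokens) * mask_rate))
    positions = sorted(rng.sample(range(len(tokens)), n_mask))
    masked = [MASK_TOKEN if i in positions else tok for i, tok in enumerate(tokens)]
    return masked, positions

tokens = tokenize("CC(=O)Oc1ccccc1C(=O)O")  # aspirin
masked, positions = mask_tokens(tokens)
print(len(tokens), positions)
print(masked)
```

<p>Unlike byte-pair merges, this kind of tokenizer keeps <code>Cl</code>, <code>Br</code>, and bracket atoms such as <code>[NH4+]</code> as single tokens.</p>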
<h3 id="models">Models</h3>
<ul>
<li><strong>Architecture</strong>: <strong>RoBERTa</strong> (via HuggingFace).
<ul>
<li>Layers: 6</li>
<li>Attention Heads: 12 per layer (72 distinct attention mechanisms in total).</li>
<li><em>Implementation Note</em>: The original training notebooks and scripts are maintained in the authors&rsquo; <a href="https://github.com/seyonechithrananda/bert-loves-chemistry">bert-loves-chemistry repository</a>, alongside the primary downstream tasks integrated into DeepChem. A <a href="https://github.com/deepchem/deepchem/blob/master/examples/tutorials/Transfer_Learning_With_ChemBERTa_Transformers.ipynb">full Tox21 transfer learning tutorial</a> has been incorporated into the DeepChem repository.</li>
</ul>
</li>
<li><strong>Baselines</strong> (via Chemprop library):
<ul>
<li><strong>D-MPNN</strong>: Directed Message Passing Neural Network with default hyperparameters.</li>
<li><strong>RF/SVM</strong>: Scikit-learn Random Forest and SVM using 2048-bit Morgan fingerprints (<a href="https://en.wikipedia.org/wiki/RDKit">RDKit</a>).</li>
</ul>
</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p>Performance is measured using dual metrics to account for class imbalance common in toxicity datasets.</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Details</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>ROC-AUC</strong></td>
          <td>Area Under Receiver Operating Characteristic Curve</td>
      </tr>
      <tr>
          <td><strong>PRC-AUC</strong></td>
          <td>Area Under Precision-Recall Curve (vital for imbalanced data)</td>
      </tr>
  </tbody>
</table>
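<p>The toy example below illustrates why the two metrics can disagree on imbalanced data. It is a sketch with made-up labels and scores: ROC-AUC is computed from pairwise rankings, and average precision is used as one common estimator of the area under the precision-recall curve (an assumption; the estimator used in the paper may differ).</p>

```python
def roc_auc(labels, scores):
    """Fraction of (positive, negative) pairs ranked correctly; ties count 0.5."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    wins = 0.0
    for p in pos:
        for q in neg:
            if p > q:
                wins += 1.0
            elif p == q:
                wins += 0.5
    return wins / (len(pos) * len(neg))

def average_precision(labels, scores):
    """Mean of the precision values measured at the rank of each positive."""
    ranked = sorted(zip(scores, labels), key=lambda pair: -pair[0])
    hits, total = 0, 0.0
    for rank, (_, y) in enumerate(ranked, start=1):
        if y == 1:
            hits += 1
            total += hits / rank
    return total / hits

# 2 positives among 20 examples, ranked 3rd and 5th by the (made-up) scores.
scores = [1.0 - 0.05 * i for i in range(20)]
labels = [0] * 20
labels[2] = 1
labels[4] = 1
auc = roc_auc(labels, scores)
ap = average_precision(labels, scores)
print(round(auc, 3), round(ap, 3))  # 0.861 0.367
```

<p>Three false positives ranked above the true positives barely dent ROC-AUC, because there are many negatives to beat, but they cut precision sharply.</p>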
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Compute</strong>: Single <strong>NVIDIA V100 GPU</strong>.</li>
<li><strong>Training Time</strong>: Approximately <strong>48 hours</strong> for the 10M compound subset.</li>
<li><strong>Carbon Footprint</strong>: Estimated 17.1 kg $\text{CO}_2\text{eq}$ (offset by Google Cloud).</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/seyonechithrananda/bert-loves-chemistry">bert-loves-chemistry</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Training notebooks and finetuning scripts</td>
      </tr>
      <tr>
          <td><a href="https://github.com/deepchem/deepchem">DeepChem</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Integration of ChemBERTa and SmilesTokenizer</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/seyonec/ChemBERTa-zinc-base-v1">ChemBERTa-zinc-base-v1</a></td>
          <td>Model</td>
          <td>Unknown</td>
          <td>Pre-trained RoBERTa on 100K ZINC SMILES</td>
      </tr>
      <tr>
          <td><a href="https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/pubchem_10m.txt.zip">PubChem-10M subset</a></td>
          <td>Dataset</td>
          <td>Unknown</td>
          <td>Canonicalized 10M compound subset used for largest experiments</td>
      </tr>
  </tbody>
</table>
<p><strong>Reproducibility status</strong>: Partially Reproducible. Code and pre-trained models are available, and the 10M pretraining subset is downloadable. However, smaller subsets (100K, 250K, 1M) may need re-extraction from PubChem, and exact hyperparameter details for finetuning (learning rate, batch size) are not fully specified in the paper.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Chithrananda, S., Grand, G., &amp; Ramsundar, B. (2020). ChemBERTa: Large-Scale Self-Supervised Pretraining for Molecular Property Prediction. <em>arXiv preprint arXiv:2010.09885</em>. <a href="https://doi.org/10.48550/arXiv.2010.09885">https://doi.org/10.48550/arXiv.2010.09885</a></p>
<p><strong>Publication</strong>: arXiv 2020 (Preprint)</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://huggingface.co/seyonec/ChemBERTa-zinc-base-v1">HuggingFace Model Hub (ChemBERTa-zinc-base-v1)</a> - <em>Additional pre-trained variations on PubChem &amp; ZINC datasets are available on the author&rsquo;s <a href="https://huggingface.co/seyonec">seyonec</a> HF profile.</em></li>
<li><a href="https://github.com/seyonechithrananda/bert-loves-chemistry">bert-loves-chemistry GitHub Repository</a> - <em>Notebooks and scripts used for MLM pretraining and finetuning evaluations.</em></li>
</ul>
<h3 id="bibtex">BibTeX</h3>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@misc</span>{chithranandaChemBERTaLargeScaleSelfSupervised2020,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{{{ChemBERTa}}: {{Large-Scale Self-Supervised Pretraining}} for {{Molecular Property Prediction}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">shorttitle</span> = <span style="color:#e6db74">{{{ChemBERTa}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Chithrananda, Seyone and Grand, Gabriel and Ramsundar, Bharath}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2020</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = oct,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{arXiv:2010.09885}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">eprint</span> = <span style="color:#e6db74">{2010.09885}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">primaryclass</span> = <span style="color:#e6db74">{cs}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{arXiv}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.48550/arXiv.2010.09885}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">urldate</span> = <span style="color:#e6db74">{2025-12-24}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">archiveprefix</span> = <span style="color:#e6db74">{arXiv}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Score Matching and Denoising Autoencoders: A Connection</title><link>https://hunterheidenreich.com/notes/machine-learning/generative-models/score-matching-denoising-autoencoders/</link><pubDate>Sun, 21 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/generative-models/score-matching-denoising-autoencoders/</guid><description>Theoretical paper proving the equivalence between training Denoising Autoencoders and performing Score Matching on a Parzen density estimator.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is a <strong>Theory Paper</strong>.</p>
<p>Its primary contribution is a formal mathematical derivation connecting two previously distinct techniques: Score Matching (SM) and Denoising Autoencoders (DAE). It provides the &ldquo;why&rdquo; behind the empirical success of DAEs by grounding them in the probabilistic framework of energy-based models. It relies on proofs and equivalence relations (e.g., $J_{ESMq_{\sigma}} \sim J_{DSMq_{\sigma}}$).</p>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>The paper bridges a gap between two successful but disconnected approaches in unsupervised learning:</p>
<ol>
<li><strong>Denoising Autoencoders (DAE):</strong> Empirically successful for pre-training deep networks. They previously lacked a clear probabilistic interpretation.</li>
<li><strong>Score Matching (SM):</strong> A theoretically sound method for estimating unnormalized density models that avoids the partition function problem but requires computing expensive second derivatives.</li>
</ol>
<p>By connecting them, the authors aim to define a proper probabilistic model for DAEs (allowing sampling/ranking) and find a simpler way to apply score matching that avoids second derivatives.</p>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The core novelty is the <strong>Denoising Score Matching (DSM)</strong> framework and the proof of its equivalence to DAEs. Key contributions include:</p>
<ul>
<li><strong>Equivalence Proof:</strong> Showing that training a DAE with Gaussian noise is equivalent to matching the score of a model against a non-parametric Parzen density estimator of the data.</li>
<li><strong>Denoising Score Matching ($J_{DSM}$):</strong> A new objective that learns a score function by trying to denoise corrupted samples. This avoids the explicit second derivatives required by standard Implicit Score Matching ($J_{ISM}$).</li>
<li><strong>Explicit Energy Function:</strong> Deriving the specific energy function $E(x;\theta)$ that corresponds to the standard sigmoid DAE architecture.</li>
<li><strong>Justification for Tied Weights:</strong> Providing a theoretical justification for tying encoder and decoder weights, which arises naturally from differentiating the energy function.</li>
</ul>
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The validation in this theoretical paper is purely mathematical and focuses on formal proofs:</p>
<ul>
<li><strong>Derivation of Equivalence:</strong> The paper formally proves the chain of equivalences:
$$J_{ISMq_{\sigma}} \sim J_{ESMq_{\sigma}} \sim J_{DSMq_{\sigma}} \sim J_{DAE\sigma}$$
where $q_{\sigma}$ is the Parzen density estimate.</li>
<li><strong>Appendix Proof:</strong> A detailed proof is provided to show that Explicit Score Matching ($J_{ESM}$) on the Parzen density is equivalent to the proposed Denoising Score Matching ($J_{DSM}$) objective.</li>
</ul>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>
<ul>
<li><strong>Theoretical Unification:</strong> DAE training is formally equivalent to Score Matching on a smoothed data distribution ($q_{\sigma}$).</li>
<li><strong>New Training Objective:</strong> The $J_{DSM}$ objective offers a computationally efficient way to perform score matching (no Hessian required) by using a denoising objective.</li>
<li><strong>Probabilistic Interpretation:</strong> DAEs can now be understood as Energy-Based Models (EBMs), allowing for operations like sampling (via Hybrid Monte Carlo) and likelihood ranking, which were previously ill-defined for standard autoencoders.</li>
<li><strong>Regularization Insight:</strong> The smoothing kernel width $\sigma$ in the Parzen estimator corresponds to the noise level in the DAE. This suggests that DAEs are learning a regularized version of the score, which may explain their robustness.</li>
<li><strong>Connection to Regularized Score Matching:</strong> The paper notes that Kingma and LeCun (2010) independently proposed a regularized score matching criterion $J_{ISMreg}$ derived by approximating $J_{ISMq_{\sigma}}$. The four $q_{\sigma}$-based objectives in this work (including the DAE objective) can be seen as approximation-free forms of regularized score matching, with the additional advantage that $J_{DSMq_{\sigma}}$ does not require second derivatives.</li>
</ul>
<hr>
<h2 id="key-concepts-explained">Key Concepts Explained</h2>
<h3 id="1-score-and-score-matching">1. &ldquo;Score&rdquo; and &ldquo;Score Matching&rdquo;</h3>
<p><strong>What does &ldquo;score&rdquo; actually mean?</strong></p>
<p>In this paper (and probabilistic modeling generally), the <strong>score</strong> is the gradient of the log-density with respect to the <em>data vector</em> $x$.</p>
<ul>
<li><strong>Definition:</strong> $\psi(x) = \nabla_x \log p(x)$.</li>
<li><strong>Intuition:</strong> It is a vector field pointing in the direction in which the log-density increases fastest. Crucially, calculating the score avoids the intractable partition function $Z$, because $\nabla_x \log p(x) = \nabla_x \log \tilde{p}(x) - \nabla_x \log Z = \nabla_x \log \tilde{p}(x)$. The constant $Z$ vanishes upon differentiation.</li>
</ul>
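<p>The cancellation of $Z$ can be checked numerically. The sketch below uses a one-dimensional Gaussian as a stand-in model; the mean, width, and evaluation point are arbitrary choices made for illustration.</p>

```python
import math

MU, S = 1.0, 2.0  # arbitrary mean and standard deviation

def log_p_unnormalized(x):
    """log of the unnormalized density exp(-(x - MU)^2 / (2 S^2))."""
    return -((x - MU) ** 2) / (2 * S ** 2)

def log_p_normalized(x):
    """Same density with the log partition function subtracted."""
    log_z = 0.5 * math.log(2 * math.pi * S ** 2)
    return log_p_unnormalized(x) - log_z

def numerical_score(log_density, x, h=1e-5):
    """Central finite-difference estimate of d/dx log_density(x)."""
    return (log_density(x + h) - log_density(x - h)) / (2 * h)

x = 3.0
score_unnorm = numerical_score(log_p_unnormalized, x)
score_norm = numerical_score(log_p_normalized, x)
analytic = -(x - MU) / S ** 2  # -0.5
print(score_unnorm, score_norm, analytic)
```

<p>All three values agree up to finite-difference error, since subtracting the constant $\log Z$ does not change the slope.</p>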
<p><strong>What is Score Matching?</strong></p>
<p>Score Matching is a training objective for unnormalized models. It minimizes the expected squared Euclidean distance between the model&rsquo;s score $\psi(x;\theta)$ and the score of the data density, $\nabla_x \log q(x)$.</p>
<h3 id="2-the-parzen-density-estimator">2. The Parzen Density Estimator</h3>
<p><strong>What is it?</strong></p>
<p>It is a non-parametric method for estimating a probability density function from finite data. It places a smooth kernel (here, a Gaussian) centered at every data point in the training set $D_n$.</p>
<ul>
<li><strong>Formula:</strong> $q_{\sigma}(\tilde{x}) = \frac{1}{n} \sum_{t=1}^n \mathcal{N}(\tilde{x}; x^{(t)}, \sigma^2 I)$.</li>
</ul>
<p><strong>Why smooth the data?</strong></p>
<ol>
<li>
<p><strong>To define the score:</strong> The empirical data distribution is a set of Dirac deltas (spikes). The gradient (score) of a Dirac delta is undefined. Smoothing creates a differentiable surface, allowing a valid target score $\nabla_{\tilde{x}} \log q_{\sigma}(\tilde{x})$ to be computed.</p>
</li>
<li>
<p><strong>To model corruption:</strong> The Parzen estimator with Gaussian kernels mathematically models the process of taking a clean data point $x$ and adding Gaussian noise, which is the exact procedure used in Denoising Autoencoders.</p>
</li>
</ol>
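<p>A one-dimensional sketch of the Parzen estimator and its score follows. The data points, $\sigma$, and query point are arbitrary. The closed-form score used here follows from $\nabla \log q_{\sigma} = \nabla q_{\sigma} / q_{\sigma}$: it is a kernel-weighted average of the vectors $(x^{(t)} - \tilde{x})/\sigma^2$ that point from the query back to each training point.</p>

```python
import math

def gaussian_pdf(x, mean, sigma):
    z = (x - mean) / sigma
    return math.exp(-0.5 * z * z) / (sigma * math.sqrt(2 * math.pi))

def parzen_density(x_tilde, data, sigma):
    """q_sigma(x_tilde): average of Gaussian kernels centered on the data."""
    return sum(gaussian_pdf(x_tilde, x_t, sigma) for x_t in data) / len(data)

def parzen_score(x_tilde, data, sigma):
    """d/dx log q_sigma: kernel-weighted mean of (x_t - x_tilde) / sigma^2."""
    kernels = [gaussian_pdf(x_tilde, x_t, sigma) for x_t in data]
    weighted = sum(k * (x_t - x_tilde) for k, x_t in zip(kernels, data))
    return weighted / (sum(kernels) * sigma ** 2)

data = [-1.0, 0.5, 2.0]  # arbitrary toy training set
sigma = 0.7
x_tilde = 0.9

# Cross-check the closed form against a finite difference of log q_sigma.
h = 1e-5
fd_score = (
    math.log(parzen_density(x_tilde + h, data, sigma))
    - math.log(parzen_density(x_tilde - h, data, sigma))
) / (2 * h)
print(parzen_score(x_tilde, data, sigma), fd_score)
```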
<h3 id="3-why-avoiding-second-derivatives-matters">3. Why avoiding second derivatives matters</h3>
<p>Standard <strong>Implicit Score Matching (ISM)</strong> eliminates the need for the unknown data score, but introduces a new cost: it requires computing the trace of the Hessian (the sum of second partial derivatives) of the log-density.</p>
<ul>
<li><strong>The Cost:</strong> For high-dimensional data (like images) and deep networks, computing second derivatives of the log-density is computationally expensive.</li>
<li>This paper shows that <strong>Denoising Score Matching (DSM)</strong> allows you to bypass Hessian computation entirely. By using the Parzen target, the objective simplifies to matching a first-order vector, making it scalable to deep neural networks.</li>
</ul>
<h3 id="4-the-equivalence-chain---why-each-step">4. The equivalence chain - why each step?</h3>
<p>The chain $J_{ISMq_{\sigma}} \sim J_{ESMq_{\sigma}} \sim J_{DSMq_{\sigma}} \sim J_{DAE\sigma}$ connects the concepts.</p>
<ul>
<li>
<p><strong>$J_{ISMq_{\sigma}} \sim J_{ESMq_{\sigma}}$ (Implicit $\to$ Explicit):</strong>
<strong>Why:</strong> Integration by parts. This is Hyvärinen&rsquo;s original proof (2005): integration by parts moves the derivative from $\psi$ onto the data density $q_{\sigma}$, producing a term involving the gradient of $q_{\sigma}$ (and hence its score). The boundary term vanishes under Hyvärinen&rsquo;s regularity condition that $q_{\sigma}(x)\psi(x;\theta) \to 0$ as $|x| \to \infty$. Read in the other direction, the result replaces the unknown data score with a computable term involving only the model&rsquo;s score and its Jacobian.</p>
</li>
<li>
<p><strong>$J_{ESMq_{\sigma}} \sim J_{DSMq_{\sigma}}$ (Explicit $\to$ Denoising):</strong>
<strong>Why:</strong> The explicit score of the Parzen density is known. When $x$ is perturbed to $\tilde{x}$ by Gaussian noise $\epsilon \sim \mathcal{N}(0, \sigma^2 I)$, the gradient of the conditional log-density $\log q_{\sigma}(\tilde{x}|x)$ with respect to $\tilde{x}$ points back toward the clean point $x$ and equals exactly $\frac{1}{\sigma^2}(x - \tilde{x})$. Minimizing the error against the Parzen score is then equivalent, up to a constant independent of $\theta$, to minimizing the error against this restoration vector.</p>
</li>
<li>
<p><strong>$J_{DSMq_{\sigma}} \sim J_{DAE\sigma}$ (Denoising $\to$ Autoencoder):</strong>
<strong>Why:</strong> Algebraic substitution. If you define the model&rsquo;s score $\psi(\tilde{x};\theta)$ to be proportional to the reconstruction error ($\propto x^r - \tilde{x}$), the score matching loss $J_{DSM}$ becomes proportional to the standard autoencoder squared loss $|x^r - x|^2$.</p>
</li>
</ul>
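<p>The last step of the chain is pure algebra and can be checked on a single sample. In the sketch below the &ldquo;reconstruction&rdquo; is a random stand-in vector, not the output of a trained network, and the dimension and $\sigma$ are arbitrary. With the model score defined as $\psi = (x^r - \tilde{x})/\sigma^2$, the per-sample DSM term equals the squared reconstruction error scaled by $\frac{1}{2\sigma^4}$.</p>

```python
import random

def dsm_term(psi, x, x_tilde, sigma):
    """0.5 * ||psi - (x - x_tilde) / sigma^2||^2 for one sample."""
    target = [(a - t) / sigma ** 2 for a, t in zip(x, x_tilde)]
    return 0.5 * sum((p - g) ** 2 for p, g in zip(psi, target))

def dae_term(x_rec, x, sigma):
    """||x_rec - x||^2 / (2 sigma^4) for one sample."""
    return sum((r - a) ** 2 for r, a in zip(x_rec, x)) / (2 * sigma ** 4)

rng = random.Random(0)
sigma = 0.5
x = [rng.gauss(0.0, 1.0) for _ in range(4)]        # clean sample
x_tilde = [a + rng.gauss(0.0, sigma) for a in x]   # corrupted sample
x_rec = [rng.gauss(0.0, 1.0) for _ in range(4)]    # stand-in reconstruction

# Model score implied by the reconstruction.
psi = [(r - t) / sigma ** 2 for r, t in zip(x_rec, x_tilde)]

lhs = dsm_term(psi, x, x_tilde, sigma)
rhs = dae_term(x_rec, x, sigma)
print(lhs, rhs)
```

<p>The corrupted input $\tilde{x}$ cancels in the difference $\psi - \frac{1}{\sigma^2}(x - \tilde{x})$, which is why only $x^r - x$ survives.</p>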
<h3 id="5-energy-based-models-ebms-connection">5. Energy-Based Models (EBMs) connection</h3>
<p><strong>What is an EBM?</strong></p>
<p>An EBM defines a probability distribution via an energy function $E(x;\theta)$, where $p(x;\theta) \propto e^{-E(x;\theta)}$.</p>
<p><strong>Why standard autoencoders lack probabilistic interpretation:</strong></p>
<p>A standard autoencoder acts as a deterministic map $x \to x^r$, providing a reconstruction error. It lacks a normalization constant or a defined density function to support sampling or probability queries.</p>
<p><strong>What does this enable?</strong></p>
<p>By proving the equivalence, the DAE is formally defined as an EBM. This enables:</p>
<ol>
<li><strong>Sampling:</strong> Using MCMC methods (like Hybrid Monte Carlo) to generate new data from the DAE.</li>
<li><strong>Ranking:</strong> Calculating the energy of inputs to determine which are more &ldquo;likely&rdquo; or &ldquo;normal&rdquo; (useful for anomaly detection).</li>
</ol>
<h3 id="6-the-specific-energy-function-form">6. The specific energy function form</h3>
<p>The function is:</p>
<p>$$E(x; W, b, c) = - \frac{1}{\sigma^2} \left( \langle c, x \rangle - \frac{1}{2}|x|^2 + \sum_{j=1}^{d_h} \text{softplus}(\langle W_j, x \rangle + b_j) \right)$$</p>
<p><strong>Why does it have that specific form?</strong></p>
<p>It was derived via integration to ensure its derivative matches the DAE architecture. The authors worked backward from the DAE&rsquo;s reconstruction function (sigmoid + linear) to find the scalar field that generates it.</p>
<p><strong>Where does the quadratic term come from?</strong></p>
<p>The score (negative energy gradient) needs to look like $\psi(x) \propto c - x + W^T\text{sigmoid}(Wx + b)$.</p>
<ul>
<li>The term $-x$ in the score arises because $\nabla_x(-\frac{1}{2}|x|^2) = -x$. Including $-\frac{1}{2}|x|^2$ inside the parenthesized sum of the energy produces this linear term after differentiation.</li>
</ul>
<p><strong>How does differentiating it recover the DAE reconstruction?</strong></p>
<ul>
<li>$\nabla_x \sum_j \text{softplus}(\langle W_j, x \rangle + b_j) = W^T \text{sigmoid}(Wx + b) = W^T h$ (the hidden code $h$ mapped back through the tied decoder $W^T$).</li>
<li>$\nabla_x \langle c, x \rangle = c$ (The bias).</li>
<li>$\nabla_x (-\frac{1}{2}|x|^2) = -x$ (The input subtraction).</li>
<li>Result: $-\nabla_x E \propto c + W^T h - x = x^r - x$.</li>
</ul>
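<p>The differentiation above can be verified numerically. The sketch below implements the energy function as written, differentiates it by central finite differences, and compares the result with $\frac{1}{\sigma^2}(x^r - x)$. The weights, biases, input, and $\sigma$ are arbitrary small values chosen for illustration, and the helper names are hypothetical.</p>

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def softplus(z):
    return math.log1p(math.exp(z))

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def energy(x, W, b, c, sigma):
    """E(x) = -(1/sigma^2) * (<c, x> - 0.5 |x|^2 + sum_j softplus(<W_j, x> + b_j))."""
    inner = dot(c, x) - 0.5 * dot(x, x)
    inner += sum(softplus(dot(w_j, x) + b_j) for w_j, b_j in zip(W, b))
    return -inner / sigma ** 2

def reconstruct(x, W, b, c):
    """Tied-weight reconstruction x_r = W^T sigmoid(W x + b) + c."""
    h = [sigmoid(dot(w_j, x) + b_j) for w_j, b_j in zip(W, b)]
    return [c[i] + sum(W[j][i] * h[j] for j in range(len(W))) for i in range(len(x))]

def score_by_finite_differences(x, W, b, c, sigma, eps=1e-6):
    """Estimate -dE/dx one coordinate at a time."""
    score = []
    for i in range(len(x)):
        up, down = list(x), list(x)
        up[i] += eps
        down[i] -= eps
        diff = energy(up, W, b, c, sigma) - energy(down, W, b, c, sigma)
        score.append(-diff / (2 * eps))
    return score

W = [[0.5, -1.0, 0.3], [0.2, 0.4, -0.7]]  # 2 hidden units, 3 inputs
b = [0.1, -0.2]
c = [0.0, 0.5, -0.5]
sigma = 0.8
x = [1.0, -0.5, 0.25]

score_fd = score_by_finite_differences(x, W, b, c, sigma)
score_dae = [(r - xi) / sigma ** 2 for r, xi in zip(reconstruct(x, W, b, c), x)]
print(score_fd)
print(score_dae)
```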
<h3 id="7-tied-weights-justification">7. &ldquo;Tied weights&rdquo; justification</h3>
<p><strong>What does it mean for weights to be &ldquo;tied&rdquo;?</strong></p>
<p>The decoder matrix is the transpose of the encoder matrix ($W^T$).</p>
<p><strong>Why is this theoretically justified?</strong></p>
<p>Because the reconstruction function is interpreted as the <strong>gradient</strong> of an energy function. A vector field can only be the gradient of a scalar field if its Jacobian is symmetric.</p>
<ul>
<li>In the DAE energy derivative, the encoder contributes $W^T \text{sigmoid}(Wx + b)$. If the decoder used a separate matrix $U$, the resulting vector field would not be a valid gradient of any scalar energy function (unless $U = W^T$).</li>
<li>Therefore, for a DAE to correspond to a valid probabilistic Energy-Based Model, the weights <em>must</em> be tied.</li>
</ul>
<p><strong>The necessity of tied weights:</strong></p>
<p>Within this parametrization, tied weights are a mathematical necessity: a separate decoder matrix $U \neq W^T$ would make the reconstruction function an invalid gradient of any scalar energy, breaking the EBM correspondence.</p>
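<p>The symmetry argument can be illustrated with a small numerical check. The sketch below builds the Jacobian of the vector field $U\,\text{sigmoid}(Wx + b) + c - x$ analytically and measures its asymmetry for a tied decoder ($U = W^T$) and for an arbitrary untied one. All matrices and the input are made-up values, and the helper names are hypothetical.</p>

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def reconstruction_jacobian(x, W, U, b):
    """Jacobian of U sigmoid(W x + b) + c - x with respect to x.

    Entry (i, k) is sum_j U[i][j] * s_j * (1 - s_j) * W[j][k] - delta_ik.
    """
    gates = []
    for w_j, b_j in zip(W, b):
        s = sigmoid(sum(w * xi for w, xi in zip(w_j, x)) + b_j)
        gates.append(s * (1.0 - s))
    d = len(x)
    jac = []
    for i in range(d):
        row = []
        for k in range(d):
            value = sum(U[i][j] * gates[j] * W[j][k] for j in range(len(W)))
            row.append(value - (1.0 if i == k else 0.0))
        jac.append(row)
    return jac

def max_asymmetry(jac):
    n = len(jac)
    return max(abs(jac[i][k] - jac[k][i]) for i in range(n) for k in range(n))

W = [[0.5, -1.0, 0.3], [0.2, 0.4, -0.7]]  # 2 hidden units, 3 inputs
b = [0.1, -0.2]
x = [1.0, -0.5, 0.25]

U_tied = [list(col) for col in zip(*W)]          # W^T
U_free = [[1.0, 0.0], [0.0, 1.0], [0.5, -0.5]]   # arbitrary untied decoder

tied_gap = max_asymmetry(reconstruction_jacobian(x, W, U_tied, b))
free_gap = max_asymmetry(reconstruction_jacobian(x, W, U_free, b))
print(tied_gap, free_gap)
```

<p>The tied Jacobian is symmetric up to rounding, while the untied one is not, so only the tied field can be the gradient of a scalar energy.</p>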
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<p>Since this is a theoretical paper, the &ldquo;reproducibility&rdquo; lies in the mathematical formulations derived.</p>
<h3 id="data">Data</h3>
<ul>
<li><strong>Input Data ($D_n$):</strong> The theory assumes a set of training examples $D_n = \{x^{(1)}, \ldots, x^{(n)}\}$ drawn from an unknown true pdf $q(x)$.</li>
<li><strong>Parzen Density Estimate ($q_{\sigma}$):</strong> The theoretical targets are derived from a kernel-smoothed empirical distribution:
$$q_{\sigma}(\tilde{x}) = \frac{1}{n} \sum_{t=1}^n q_{\sigma}(\tilde{x}|x^{(t)})$$
where the kernel is an isotropic Gaussian of variance $\sigma^2$.</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<p><strong>1. Denoising Score Matching (DSM) Objective</strong></p>
<p>The paper proposes this objective as a tractable alternative to standard score matching. It minimizes the expected squared distance between the model score and the gradient of the log conditional density of the noise process:</p>
<p>$$J_{DSMq_{\sigma}}(\theta) = \mathbb{E}_{q_{\sigma}(x,\tilde{x})} \left[ \frac{1}{2} \left| \psi(\tilde{x};\theta) - \frac{\partial \log q_{\sigma}(\tilde{x}|x)}{\partial \tilde{x}} \right|^2 \right]$$</p>
<p>For Gaussian noise, the target score is simply $\frac{1}{\sigma^2}(x - \tilde{x})$.</p>
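<p>A toy Monte Carlo sketch shows the objective recovering the score of the smoothed density without ever evaluating it. The setup is an assumption made for illustration: clean data are drawn from $\mathcal{N}(0, 1)$, so $q_{\sigma} = \mathcal{N}(0, 1 + \sigma^2)$ with score $-\tilde{x}/(1 + \sigma^2)$, and the model is a hypothetical linear score $\psi(\tilde{x}; a) = a\tilde{x}$ fitted by closed-form least squares instead of gradient descent.</p>

```python
import random

def fit_linear_score(sigma, n_samples=20000, seed=0):
    """Minimize the empirical DSM loss for psi(x_tilde) = a * x_tilde.

    The loss 0.5 * mean((a * x_tilde - target)^2) is quadratic in a, so its
    minimizer is sum(x_tilde * target) / sum(x_tilde^2).
    """
    rng = random.Random(seed)
    num = 0.0
    den = 0.0
    for _ in range(n_samples):
        x = rng.gauss(0.0, 1.0)                 # clean sample
        x_tilde = x + rng.gauss(0.0, sigma)     # Gaussian corruption
        target = (x - x_tilde) / sigma ** 2     # DSM target score
        num += x_tilde * target
        den += x_tilde * x_tilde
    return num / den

sigma = 0.5
a_hat = fit_linear_score(sigma)
a_true = -1.0 / (1.0 + sigma ** 2)  # slope of the score of N(0, 1 + sigma^2)
print(a_hat, a_true)
```

<p>The fitted slope should land close to $-0.8$ even though each individual target $\frac{1}{\sigma^2}(x - \tilde{x})$ is a noisy, sample-specific vector rather than the true score.</p>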
<p><strong>2. Equivalence Chain</strong></p>
<p>The central result connects four objectives:</p>
<p>$$J_{ISMq_{\sigma}} \sim J_{ESMq_{\sigma}} \sim J_{DSMq_{\sigma}} \sim J_{DAE\sigma}$$</p>
<p>This implies that minimizing the DAE reconstruction error is equivalent to minimizing a score matching objective on $q_{\sigma}$.</p>
<h3 id="models">Models</h3>
<p><strong>1. The Denoising Autoencoder (DAE)</strong></p>
<ul>
<li><strong>Corruption:</strong> Additive isotropic Gaussian noise $\tilde{x} = x + \epsilon, \epsilon \sim \mathcal{N}(0, \sigma^2 I)$.</li>
<li><strong>Encoder:</strong> $h = \text{sigmoid}(W\tilde{x} + b)$.</li>
<li><strong>Decoder:</strong> $x^r = W^T h + c$ (Tied weights $W$).</li>
<li><strong>Loss:</strong> Squared reconstruction error $|x^r - x|^2$. (The equivalence with DSM introduces a $\frac{1}{2\sigma^4}$ scaling factor.)</li>
</ul>
<p><strong>2. The Corresponding Energy Function</strong></p>
<p>To make the DAE equivalent to Score Matching, the underlying Energy-Based Model $p(x;\theta) \propto e^{-E(x;\theta)}$ must have the following energy function:</p>
<p>$$E(x; W, b, c) = - \frac{1}{\sigma^2} \left( \langle c, x \rangle - \frac{1}{2}|x|^2 + \sum_{j=1}^{d_h} \text{softplus}(\langle W_j, x \rangle + b_j) \right)$$</p>
<p>Note the scaling by $1/\sigma^2$ and the quadratic term $|x|^2$.</p>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Metric:</strong> Theoretical Equivalence ($\sim$).</li>
<li><strong>Condition:</strong> The equivalence holds provided $\sigma &gt; 0$ and the density $q_{\sigma}$ is differentiable and vanishes at infinity (Hyvärinen&rsquo;s 2005 regularity condition for Implicit Score Matching).</li>
</ul>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Vincent, P. (2011). A Connection Between Score Matching and Denoising Autoencoders. <em>Neural Computation</em>, 23(7), 1661-1674. <a href="https://doi.org/10.1162/NECO_a_00142">https://doi.org/10.1162/NECO_a_00142</a></p>
<p><strong>Publication</strong>: Neural Computation 2011</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{vincentConnectionScoreMatching2011,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{A {{Connection Between Score Matching}} and {{Denoising Autoencoders}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Vincent, Pascal}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2011</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = jul,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Neural Computation}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{23}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{7}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{1661--1674}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1162/NECO_a_00142}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://www.iro.umontreal.ca/~vincentp/Publications/smdae_techreport.pdf">Official PDF</a></li>
</ul>
]]></content:encoded></item><item><title>Neural ODEs: Continuous-Depth Deep Learning Models</title><link>https://hunterheidenreich.com/notes/machine-learning/generative-models/neural-odes/</link><pubDate>Sun, 21 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/generative-models/neural-odes/</guid><description>Introduces ODE-Nets, a continuous-depth neural network model parameterized by ODEs, enabling constant memory backpropagation and adaptive computation.</description><content:encoded><![CDATA[<blockquote>
<p><strong>Key Prerequisites</strong>: For the ODE solver to guarantee a unique solution, the neural network $f(h(t), t, \theta)$ parameterizing the dynamics must be <a href="https://en.wikipedia.org/wiki/Lipschitz_continuity">Lipschitz continuous</a> in $h$ and continuous in $t$. Under these conditions the <a href="https://en.wikipedia.org/wiki/Picard%E2%80%93Lindel%C3%B6f_theorem">Picard-Lindelöf theorem</a> applies, which prevents trajectories from crossing and guarantees a well-defined backward pass.</p></blockquote>
<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is primarily a <strong>Method</strong> paper, with a strong secondary <strong>Theory</strong> component.</p>
<ul>
<li><strong>Method</strong>: It proposes a novel family of deep neural network models where the derivative of the hidden state is parameterized by a neural network. It provides specific algorithms (Algorithm 1) for training these models scalably.</li>
<li><strong>Theory</strong>: It derives the adjoint sensitivity method for backpropagating through black-box ODE solvers and proves the &ldquo;Instantaneous Change of Variables&rdquo; theorem (Theorem 1) for continuous normalizing flows.</li>
</ul>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>The authors aim to address limitations in discrete deep learning architectures:</p>
<ul>
<li><strong>Discrete vs. Continuous</strong>: Existing models like Residual Networks build transformations by composing discrete steps, which can be seen as an Euler discretization of a continuous transformation. The authors investigate the limit as step sizes go to zero.</li>
<li><strong>Memory Efficiency</strong>: Backpropagating through deep discrete networks requires storing intermediate activations, leading to linear memory cost in terms of depth, which is a major bottleneck.</li>
<li><strong>Irregular Data</strong>: Recurrent Neural Networks (RNNs) struggle with data arriving at arbitrary times, typically requiring discretization into fixed bins.</li>
<li><strong>Normalizing Flow Costs</strong>: Standard normalizing flows have a bottleneck in computing the determinant of the Jacobian, which is computationally expensive ($O(D^3)$).</li>
</ul>
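<p>The Euler-discretization view of residual networks can be checked numerically. The following is a minimal sketch, assuming toy dynamics $f(h) = -h$ (chosen only because the exact solution $h(1) = h(0)e^{-1}$ is known); <code>residual_stack</code> is a hypothetical helper, not code from the paper.</p>

```python
import math

def f(h):
    # Toy dynamics dh/dt = -h (an assumption for illustration).
    return -h

def residual_stack(h0, num_layers, total_time=1.0):
    """Apply num_layers residual updates h = h + dt * f(h), i.e. Euler steps."""
    dt = total_time / num_layers
    h = h0
    for _ in range(num_layers):
        h = h + dt * f(h)
    return h

exact = math.exp(-1.0)  # closed-form solution at t = 1 for h(0) = 1
errors = [abs(residual_stack(1.0, n) - exact) for n in (1, 10, 100, 1000)]
print(errors)  # the gap to the ODE solution shrinks as the stack gets deeper
```

<p>Each residual update is one Euler step of size <code>dt</code>, so adding layers while shrinking the step approaches the continuous-depth limit that the paper studies.</p>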
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The core contribution is the <strong>Neural ODE</strong> formulation:
$$\frac{dh(t)}{dt} = f(h(t), t, \theta)$$
where the output is computed using a black-box differential equation solver.</p>
<p>Key technical innovations include:</p>
<ol>
<li><strong>Adjoint Sensitivity Method for Backprop</strong>: The authors treat the solver as a black box and compute gradients by solving a second, augmented ODE backwards in time. This allows for <strong>constant memory cost</strong> regardless of depth.</li>
<li><strong>Adaptive Computation</strong>: The model uses modern ODE solvers that adapt evaluation steps based on error tolerance, allowing the model to trade precision for speed explicitly.</li>
<li><strong>Continuous Normalizing Flows (CNF)</strong>: By moving to continuous time, the change of variables formula simplifies from a log-determinant (cubic cost) to a trace operation (linear cost), enabling scalable generative modeling.</li>
<li><strong>Latent ODEs</strong>: A generative time-series model that represents time-series as latent trajectories determined by a local initial state and global shared dynamics, handling irregular sampling naturally.</li>
</ol>
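<p>The precision-for-speed trade-off behind adaptive computation can be illustrated with a small adaptive solver. This is a sketch under stated assumptions: <code>solve_adaptive</code> is a hypothetical Heun-Euler integrator written for this note (the paper relies on existing black-box solvers), and the toy dynamics are arbitrary.</p>

```python
import math

def solve_adaptive(f, y0, t0, t1, tol):
    """Adaptive Heun-Euler integration; returns (y(t1), number of f evaluations)."""
    t, y, h, nfe = t0, y0, (t1 - t0) / 10.0, 0
    while t1 - t > 1e-12:
        h = min(h, t1 - t)
        k1 = f(t, y)
        k2 = f(t + h, y + h * k1)
        nfe += 2
        err = 0.5 * h * abs(k2 - k1)  # gap between the Euler and Heun estimates
        if err <= tol:
            # Accept the step using the higher-order (Heun) estimate.
            t, y = t + h, y + 0.5 * h * (k1 + k2)
        # Grow or shrink the step from the error estimate, within safe bounds.
        scale = 0.9 * math.sqrt(tol / max(err, 1e-16))
        h *= min(2.0, max(0.2, scale))
    return y, nfe

def dynamics(t, y):
    # Toy dynamics (an assumption): dy/dt = -2y + sin(5t)
    return -2.0 * y + math.sin(5.0 * t)

for tol in (1e-2, 1e-4, 1e-6):
    y1, nfe = solve_adaptive(dynamics, 1.0, 0.0, 2.0, tol)
    print(f"tol={tol:g}  y(2)={y1:.6f}  NFE={nfe}")
```

<p>Tightening the tolerance forces smaller steps and therefore more function evaluations, which is the knob the paper describes for trading accuracy against compute.</p>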
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The authors validated the method across three distinct domains:</p>
<ol>
<li><strong>Supervised Learning (MNIST)</strong>:
<ul>
<li>Compared <strong>ODE-Net</strong> against a standard <strong>ResNet</strong> and a Runge-Kutta network (<strong>RK-Net</strong>).</li>
<li>Measured test error, parameter count, and memory usage.</li>
<li>Analyzed the trade-off between numerical precision (tolerance) and speed (NFE).</li>
</ul>
</li>
<li><strong>Continuous Normalizing Flows (Generative)</strong>:
<ul>
<li>Compared CNF against standard Normalizing Flows (NF) on density matching and maximum likelihood estimation tasks using toy 2D datasets (Two Circles, Two Moons, and other target distributions).</li>
<li>Evaluated training loss (KL divergence) and maximum likelihood estimation.</li>
</ul>
</li>
<li><strong>Time-Series Modeling (Latent ODE)</strong>:
<ul>
<li>Tested on a dataset of bi-directional spirals with irregular timestamps and Gaussian noise.</li>
<li>Compared Latent ODEs against an RNN baseline on predictive RMSE. A second RNN variant with time-difference concatenation was also trained.</li>
</ul>
</li>
</ol>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>
<ul>
<li><strong>Efficiency</strong>: ODE-Nets achieved roughly equivalent accuracy to ResNets on MNIST (0.42% vs 0.41% error) but with <strong>constant memory cost</strong> ($O(1)$) compared to ResNet&rsquo;s linear cost ($O(L)$).</li>
<li><strong>Adaptive Depth</strong>: The number of function evaluations (NFE) in ODE-Nets increases over the course of training, suggesting the model adapts its complexity as it learns. The backward pass requires roughly half as many function evaluations as the forward pass, which suggests that the adjoint method is also more computationally efficient than backpropagating directly through the integrator.</li>
<li><strong>Generative Performance</strong>: Continuous Normalizing Flows (CNF) achieved lower KL divergence loss than standard Normalizing Flows (NF), trained with only 10,000 iterations (Adam) compared to 500,000 iterations (RMSprop) for NF. Note that the two models used different optimizers, so the comparison is not fully controlled. CNF can also expand capacity by increasing width ($M$) without architectural constraints.</li>
<li><strong>Irregular Time-Series</strong>: Latent ODEs significantly outperformed RNNs across all observation counts on irregular spiral data. The advantage is most pronounced with sparse observations (0.1642 vs 0.3937 RMSE at 30 obs), and the model learns interpretable latent trajectories that switch direction smoothly.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<ul>
<li><strong>MNIST</strong>: Standard handwritten digit dataset used for supervised learning benchmarks.</li>
<li><strong>Toy 2D Densities</strong>: &ldquo;Two Circles&rdquo; and &ldquo;Two Moons&rdquo; distributions used for visualizing normalizing flows.</li>
<li><strong>Bi-directional Spirals</strong>: A generated dataset of 1,000 2D spirals (half clockwise, half counter-clockwise). Each spiral is sampled at 100 equally spaced timesteps with added Gaussian noise. For training, each spiral is then subsampled without replacement to $n \in \{30, 50, 100\}$ irregularly spaced observations, simulating realistic missing data.</li>
</ul>
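<p>The spiral subsampling step can be sketched as follows. The spiral shape, noise level, and the helper names <code>make_spiral</code> and <code>subsample</code> are assumptions made for illustration; only the 100 equally spaced timesteps and the subsampling without replacement come from the description above.</p>

```python
import math
import random

def make_spiral(clockwise, num_points=100, noise_std=0.1, rng=random):
    """Noisy 2D Archimedean spiral sampled at equally spaced times."""
    sign = -1.0 if clockwise else 1.0
    times = [6 * math.pi * i / (num_points - 1) for i in range(num_points)]
    points = []
    for t in times:
        r = 0.5 + 0.1 * t   # radius grows linearly with time
        points.append((r * math.cos(sign * t) + rng.gauss(0, noise_std),
                       r * math.sin(sign * t) + rng.gauss(0, noise_std)))
    return times, points

def subsample(times, points, n, rng=random):
    """Keep n of the time points, drawn without replacement, in time order."""
    keep = sorted(rng.sample(range(len(times)), n))
    return [times[i] for i in keep], [points[i] for i in keep]

random.seed(0)
times, points = make_spiral(clockwise=True)
obs_times, obs_points = subsample(times, points, n=30)
print(len(obs_times), obs_times[:3])
```

<p>Because the kept indices are a random subset of a regular grid, the gaps between consecutive observation times vary, which is what makes the sequence irregularly sampled.</p>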
<h3 id="algorithms">Algorithms</h3>
<p><strong>1. Adjoint Sensitivity Method (Backpropagation)</strong></p>
<p>To optimize the parameters of the ODE-Net, the authors use the adjoint sensitivity method to compute gradients. Standard backpropagation would require storing the activations at every step of the ODE solver, incurring a high memory cost that scales linearly with the number of steps.</p>
<p>Instead, this method treats the ODE solver as a &ldquo;black box&rdquo; and computes gradients by solving a second, <strong>augmented ODE</strong> backwards in time from the final state $t_1$ to the initial state $t_0$.</p>
<p>The augmented state contains three components that are solved simultaneously:</p>
<ol>
<li><strong>The State</strong>: The original hidden state $z(t)$, which is reconstructed backwards.</li>
<li><strong>The Adjoint</strong>: The sensitivity of the loss with respect to the state, $a(t) = \partial L / \partial z(t)$.</li>
<li><strong>The Gradient</strong>: The accumulating gradients with respect to parameters, $\partial L / \partial \theta$.</li>
</ol>
<p>The dynamics of this augmented system are defined as:
$$\frac{d}{dt}\begin{bmatrix} z(t) \\ a(t) \\ \partial L/\partial \theta \end{bmatrix} = \begin{bmatrix} f(z(t), t, \theta) \\ -a(t)^T \frac{\partial f}{\partial z} \\ -a(t)^T \frac{\partial f}{\partial \theta} \end{bmatrix}$$</p>
<p>Using this approach, the vector-Jacobian products (e.g., $a(t)^T \frac{\partial f}{\partial z}$) are evaluated efficiently using automatic differentiation.</p>
<blockquote>
<p><strong>Why:</strong> Reconstructing $z(t)$ backwards avoids storing the forward pass, enabling <strong>constant memory cost</strong> ($O(1)$) regardless of depth.</p>
<p><strong>Origin:</strong> Adapted from Pontryagin&rsquo;s maximum principle (1962) for optimal control.</p></blockquote>
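<p>The augmented system can be made concrete on a scalar problem. The sketch below assumes toy dynamics $f(z, t, \theta) = \theta z$ and the loss $L = z(t_1)$, both chosen only because the gradient $\partial L / \partial \theta = t_1 z_0 e^{\theta t_1}$ is available in closed form. <code>rk4</code> is a hypothetical fixed-step helper standing in for a black-box solver.</p>

```python
import math

def rk4(func, state, t0, t1, steps):
    """Fixed-step RK4 for a list-valued state; also runs backwards in time."""
    h = (t1 - t0) / steps
    t = t0
    for _ in range(steps):
        k1 = func(t, state)
        k2 = func(t + h / 2, [s + h / 2 * k for s, k in zip(state, k1)])
        k3 = func(t + h / 2, [s + h / 2 * k for s, k in zip(state, k2)])
        k4 = func(t + h, [s + h * k for s, k in zip(state, k3)])
        state = [s + h / 6 * (a + 2 * b + 2 * c + d)
                 for s, a, b, c, d in zip(state, k1, k2, k3, k4)]
        t += h
    return state

theta, z0, t0, t1 = 0.5, 2.0, 0.0, 1.0

# Forward pass: only the final state z(t1) is kept.
(z1,) = rk4(lambda t, s: [theta * s[0]], [z0], t0, t1, 100)

def augmented(t, s):
    z, a, _ = s
    df_dz, df_dtheta = theta, z  # partial derivatives of f = theta * z
    return [theta * z, -a * df_dz, -a * df_dtheta]

# Backward pass from t1 to t0 with a(t1) = dL/dz(t1) = 1 for the loss L = z(t1).
z_rec, a0, grad_theta = rk4(augmented, [z1, 1.0, 0.0], t1, t0, 100)

print(grad_theta, t1 * z0 * math.exp(theta * t1))  # adjoint vs. closed form
```

<p>The backward solve also returns the reconstructed $z(t_0)$, which should match the original initial state up to solver error.</p>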
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> torch
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> torch.nn <span style="color:#66d9ef">as</span> nn
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> torchdiffeq <span style="color:#f92672">import</span> odeint_adjoint
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">class</span> <span style="color:#a6e22e">ODEFunc</span>(nn<span style="color:#f92672">.</span>Module):
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">__init__</span>(self, dim):
</span></span><span style="display:flex;"><span>        super(ODEFunc, self)<span style="color:#f92672">.</span><span style="color:#a6e22e">__init__</span>()
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>net <span style="color:#f92672">=</span> nn<span style="color:#f92672">.</span>Sequential(
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(dim, <span style="color:#ae81ff">50</span>),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Tanh(),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(<span style="color:#ae81ff">50</span>, dim),
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">forward</span>(self, t, y):
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Defines dy/dt = f(y, t)</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> self<span style="color:#f92672">.</span>net(y)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Usage with adjoint method for O(1) memory backprop</span>
</span></span><span style="display:flex;"><span>func <span style="color:#f92672">=</span> ODEFunc(dim<span style="color:#f92672">=</span><span style="color:#ae81ff">2</span>)
</span></span><span style="display:flex;"><span>y0 <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>tensor([[<span style="color:#ae81ff">1.</span>, <span style="color:#ae81ff">0.</span>]]) <span style="color:#75715e"># Initial state</span>
</span></span><span style="display:flex;"><span>t <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>linspace(<span style="color:#ae81ff">0.</span>, <span style="color:#ae81ff">1.</span>, <span style="color:#ae81ff">10</span>) <span style="color:#75715e"># Time points to solve for</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># &#39;odeint_adjoint&#39; automatically handles the augmented state backward pass</span>
</span></span><span style="display:flex;"><span>out <span style="color:#f92672">=</span> odeint_adjoint(func, y0, t, method<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;dopri5&#39;</span>)
</span></span></code></pre></div><p><strong>2. Instantaneous Change of Variables (CNF)</strong></p>
<p>For generative modeling, the authors introduce <strong>Continuous Normalizing Flows (CNF)</strong>. In discrete normalizing flows, the probability density of a transformed variable is calculated using the change of variables theorem, which requires computing the log-determinant of the Jacobian: $\log p(z_1) = \log p(z_0) - \log |\det \frac{\partial z_1}{\partial z_0}|$. This operation is computationally expensive ($O(D^3)$) and often restricts model architectures to ensure the Jacobian is easy to compute (e.g., triangular).</p>
<p>Moving to continuous time simplifies this requirement. The paper proves that if the transformation is defined by an ODE, the change in log-probability follows a differential equation determined by the <strong>trace</strong> of the Jacobian:
$$\frac{\partial \log p(z(t))}{\partial t} = -\text{tr}\left( \frac{\partial f}{\partial z(t)} \right)$$</p>
<p>The total change in log-density is obtained by integrating this value over time.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">get_trace</span>(y, f):
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    Computes trace of Jacobian df/dy.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    For high dimensions, use Hutchinson&#39;s trace estimator (approximate).
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    &#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>    tr <span style="color:#f92672">=</span> <span style="color:#ae81ff">0.</span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">for</span> i <span style="color:#f92672">in</span> range(y<span style="color:#f92672">.</span>size(<span style="color:#ae81ff">1</span>)):
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Gradients of f&#39;s i-th component w.r.t y&#39;s i-th component</span>
</span></span><span style="display:flex;"><span>        tr <span style="color:#f92672">+=</span> torch<span style="color:#f92672">.</span>autograd<span style="color:#f92672">.</span>grad(f[:, i]<span style="color:#f92672">.</span>sum(), y, create_graph<span style="color:#f92672">=</span><span style="color:#66d9ef">True</span>)[<span style="color:#ae81ff">0</span>][:, i]
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> tr
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># In the ODE function:</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># d(log_p)/dt = -trace(df/dy)</span>
</span></span></code></pre></div><blockquote>
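<p><strong>Why:</strong> The trace only sums the $D$ diagonal entries of the Jacobian and is a linear operator, so the paper&rsquo;s CNF layers scale <strong>linearly</strong> with the number of hidden units, whereas the determinant has cubic cost ($O(D^3)$). This allows for unrestricted, &ldquo;wide&rdquo; architectures that are automatically bijective.</p>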
<p><strong>Origin:</strong> This is the &ldquo;Instantaneous Change of Variables&rdquo; theorem (Theorem 1), derived in Appendix A of the paper.</p></blockquote>
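<p>The theorem can be sanity-checked in one dimension, where the trace reduces to a scalar derivative. This is a minimal sketch, assuming linear toy dynamics $f(z) = az$ and a standard normal base density; under that flow $z(T) = z(0)e^{aT}$, so the pushed-forward density is a Gaussian with standard deviation $e^{aT}$.</p>

```python
import math

def normal_logpdf(x, std):
    return -0.5 * math.log(2 * math.pi * std ** 2) - x ** 2 / (2 * std ** 2)

a, T, steps = 0.7, 1.0, 1000   # toy linear dynamics f(z) = a * z (an assumption)
dt = T / steps

z = 0.3                         # a sample z(0) from the base density N(0, 1)
logp = normal_logpdf(z, 1.0)    # log p(z(0))
for _ in range(steps):
    # In one dimension the Jacobian trace of f is simply df/dz = a.
    z, logp = z + dt * a * z, logp - dt * a

# Closed form: z(T) = z(0) * exp(a * T), distributed as N(0, exp(a * T) ** 2).
exact = normal_logpdf(0.3 * math.exp(a * T), math.exp(a * T))
print(logp, exact)  # integrated log-density vs. pushed-forward Gaussian
```

<p>Both numbers equal $\log p(z(0)) - aT$: integrating the negative trace over time reproduces the log-determinant term of the discrete change of variables formula.</p>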
<h3 id="models">Models</h3>
<p><strong>ODE-Net (MNIST Classification)</strong>:</p>
<ul>
<li><strong>Input</strong>: Downsamples input twice.</li>
<li><strong>Core</strong>: 6 standard residual blocks replaced by a single <strong>ODESolve</strong> module.</li>
<li><strong>Output</strong>: Global average pooling + Fully connected layer.</li>
<li><strong>Solver</strong>: Implicit Adams method.</li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#66d9ef">class</span> <span style="color:#a6e22e">ODEBlock</span>(nn<span style="color:#f92672">.</span>Module):
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">__init__</span>(self, odefunc):
</span></span><span style="display:flex;"><span>        super(ODEBlock, self)<span style="color:#f92672">.</span><span style="color:#a6e22e">__init__</span>()
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>odefunc <span style="color:#f92672">=</span> odefunc
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>integration_time <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>tensor([<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">1</span>])<span style="color:#f92672">.</span>float()
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">forward</span>(self, x):
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>integration_time <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>integration_time<span style="color:#f92672">.</span>type_as(x)
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Returns [x(t0), x(t1)]; we only want final state x(t1)</span>
</span></span><span style="display:flex;"><span>        out <span style="color:#f92672">=</span> odeint_adjoint(self<span style="color:#f92672">.</span>odefunc, x, self<span style="color:#f92672">.</span>integration_time)
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> out[<span style="color:#ae81ff">1</span>]
</span></span><span style="display:flex;"><span>
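</span></span><span style="display:flex;"><span><span style="color:#75715e"># Illustrative convolutional dynamics (not the exact block from the paper): the</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># linear ODEFunc shown earlier cannot process (N, 64, H, W) feature maps</span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">class</span> <span style="color:#a6e22e">ConvODEFunc</span>(nn<span style="color:#f92672">.</span>Module):
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">__init__</span>(self, channels):
</span></span><span style="display:flex;"><span>        super(ConvODEFunc, self)<span style="color:#f92672">.</span><span style="color:#a6e22e">__init__</span>()
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>conv <span style="color:#f92672">=</span> nn<span style="color:#f92672">.</span>Conv2d(channels, channels, <span style="color:#ae81ff">3</span>, padding<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">forward</span>(self, t, x):
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> torch<span style="color:#f92672">.</span>tanh(self<span style="color:#f92672">.</span>conv(x))
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># ResNet-like architecture with ODE block</span>
</span></span><span style="display:flex;"><span>model <span style="color:#f92672">=</span> nn<span style="color:#f92672">.</span>Sequential(
</span></span><span style="display:flex;"><span>    nn<span style="color:#f92672">.</span>Conv2d(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">64</span>, <span style="color:#ae81ff">3</span>, <span style="color:#ae81ff">1</span>),
</span></span><span style="display:flex;"><span>    nn<span style="color:#f92672">.</span>ReLU(inplace<span style="color:#f92672">=</span><span style="color:#66d9ef">True</span>),
</span></span><span style="display:flex;"><span>    ODEBlock(ConvODEFunc(<span style="color:#ae81ff">64</span>)), <span style="color:#75715e"># Continuous-depth layer replacement</span>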
</span></span><span style="display:flex;"><span>    nn<span style="color:#f92672">.</span>BatchNorm2d(<span style="color:#ae81ff">64</span>),
</span></span><span style="display:flex;"><span>    nn<span style="color:#f92672">.</span>AdaptiveAvgPool2d((<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">1</span>)),
</span></span><span style="display:flex;"><span>    nn<span style="color:#f92672">.</span>Flatten(),
</span></span><span style="display:flex;"><span>    nn<span style="color:#f92672">.</span>Linear(<span style="color:#ae81ff">64</span>, <span style="color:#ae81ff">10</span>)
</span></span><span style="display:flex;"><span>)
</span></span></code></pre></div><p><strong>Latent ODE (Time-Series)</strong>:</p>
<ul>
<li><strong>Encoder</strong>: RNN with 25 hidden units processing data backwards to produce $q(z_0|x)$. It runs backwards so the final RNN state summarizes the entire sequence at $t_0$, parameterizing the initial latent state $z_0$ for the forward-running ODE.</li>
<li><strong>Latent Space</strong>: 4-dimensional latent state $z_0$.</li>
<li><strong>Dynamics ($f$)</strong>: Neural network with one hidden layer of 20 units.</li>
<li><strong>Decoder</strong>: Neural network with one hidden layer of 20 units computing $p(x_{t_i}|z_{t_i})$.</li>
<li><strong>Likelihood</strong>: Gaussian log-likelihood for the spiral reconstruction task. The paper also describes an optional Poisson process likelihood $\lambda(z(t))$ for event-time data (e.g., medical records), but this is not used in the spiral experiment.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Experiment</th>
          <th>Metric</th>
          <th>Baseline (ResNet/RNN)</th>
          <th>ODE Model</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MNIST</td>
          <td>Test Error</td>
          <td>0.41%</td>
          <td>0.42%</td>
      </tr>
      <tr>
          <td>MNIST</td>
          <td>Parameters</td>
          <td>0.60 M</td>
          <td>0.22 M</td>
      </tr>
      <tr>
          <td>MNIST</td>
          <td>Memory</td>
          <td>$O(L)$</td>
          <td>$O(1)$</td>
      </tr>
      <tr>
          <td>Spirals (30 obs)</td>
          <td>RMSE</td>
          <td>0.3937</td>
          <td><strong>0.1642</strong></td>
      </tr>
      <tr>
          <td>Spirals (50 obs)</td>
          <td>RMSE</td>
          <td>0.3202</td>
          <td><strong>0.1502</strong></td>
      </tr>
      <tr>
          <td>Spirals (100 obs)</td>
          <td>RMSE</td>
          <td>0.1813</td>
          <td><strong>0.1346</strong></td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Implementation</strong>: Hidden state dynamics evaluated on GPU using <strong>TensorFlow</strong>.</li>
<li><strong>Solvers</strong>: Fortran ODE solvers (LSODE, VODE) from <code>scipy.integrate</code> were used for the actual integration.</li>
<li><strong>Note</strong>: While the original paper used TensorFlow and SciPy, the authors later released <code>torchdiffeq</code> (PyTorch), which has become the standard implementation for this architecture. The code samples above use <code>torchdiffeq</code> for that reason.</li>
<li><strong>Interface</strong>: The Python <code>autograd</code> library bridged the TensorFlow dynamics and the SciPy solvers.</li>
</ul>
<h3 id="limitations">Limitations</h3>
<p>The paper identifies several practical limitations of Neural ODEs:</p>
<ul>
<li><strong>Minibatching</strong>: Batching requires concatenating states of each batch element into a combined ODE of dimension $D \times K$. Controlling error on all batch elements together can require more evaluations than solving each system individually, though in practice this overhead was not substantial.</li>
<li><strong>Tolerance tuning</strong>: Users must choose error tolerances for both the forward and reverse passes. The paper used 1.5e-8 for sequence modeling, 1e-3 for classification, and 1e-5 for density estimation.</li>
<li><strong>Backward trajectory reconstruction</strong>: Running the dynamics backwards to reconstruct the forward state trajectory can introduce extra numerical error if the reconstructed trajectory diverges from the original. Checkpointing (storing intermediate states) can address this, though the authors did not find it necessary in practice.</li>
<li><strong>Uniqueness requirements</strong>: The neural network $f$ must be Lipschitz continuous (e.g., finite weights combined with Lipschitz nonlinearities such as tanh or ReLU) to guarantee a unique solution via Picard&rsquo;s existence theorem (the Picard-Lindelöf theorem).</li>
</ul>
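<p>The backward-reconstruction issue can be reproduced with a toy system. The sketch below assumes strongly contracting dynamics $dz/dt = -5z$ and a hypothetical fixed-step <code>euler</code> helper; it integrates forward, discards the trajectory, and then tries to recover $z(t_0)$ by running the same dynamics backwards.</p>

```python
def euler(f, z, t0, t1, steps):
    """Fixed-step Euler; integrates backwards in time when t1 is below t0."""
    h = (t1 - t0) / steps
    for _ in range(steps):
        z = z + h * f(z)
    return z

def f(z):
    # Strongly contracting toy dynamics (an assumption): dz/dt = -5z
    return -5.0 * z

z0 = 1.0
for steps in (50, 500, 5000):
    z1 = euler(f, z0, 0.0, 1.0, steps)      # forward pass, trajectory discarded
    z0_rec = euler(f, z1, 1.0, 0.0, steps)  # reconstruct z(0) from z(1) alone
    print(steps, abs(z0_rec - z0))          # reconstruction error at t0
```

<p>The reconstruction error shrinks as the solver becomes more accurate. With a coarse solver the reconstructed state drifts from the true one, which is the situation where checkpointing intermediate states would help.</p>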
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/rtqichen/torchdiffeq">torchdiffeq</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official PyTorch implementation with GPU-based ODE solvers</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Chen, R. T. Q., Rubanova, Y., Bettencourt, J., &amp; Duvenaud, D. (2018). Neural ordinary differential equations. <em>Proceedings of the 32nd International Conference on Neural Information Processing Systems</em>, 6572-6583.</p>
<p><strong>Publication</strong>: NeurIPS 2018</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{chen2018neural,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Neural ordinary differential equations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Chen, Ricky T. Q. and Rubanova, Yulia and Bettencourt, Jesse and Duvenaud, David}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Proceedings of the 32nd International Conference on Neural Information Processing Systems}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{6572--6583}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2018}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/rtqichen/torchdiffeq">Official PyTorch Implementation</a></li>
</ul>
]]></content:encoded></item><item><title>Flow Matching for Generative Modeling: Scalable CNFs</title><link>https://hunterheidenreich.com/notes/machine-learning/generative-models/flow-matching-for-generative-modeling/</link><pubDate>Sun, 21 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/generative-models/flow-matching-for-generative-modeling/</guid><description>A simulation-free framework for training Continuous Normalizing Flows using Conditional Flow Matching and Optimal Transport paths.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is primarily a <strong>Method</strong> paper, as it introduces &ldquo;Flow Matching&rdquo; (FM), a novel simulation-free paradigm for training Continuous Normalizing Flows (CNFs) at scale. It is supported by a strong <strong>Theory</strong> basis, providing formal theorems that allow the intractable marginal vector field regression to be solved via a tractable conditional objective. It also touches on <strong>Systematization</strong> by showing that existing diffusion paths are specific instances of the proposed Gaussian probability path framework.</p>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>The paper aims to overcome the scaling limitations of Continuous Normalizing Flows (CNFs).</p>
<ul>
<li><strong>Problem</strong>: Standard Maximum Likelihood training for CNFs requires expensive numerical ODE simulations during training, which scales poorly. Existing simulation-free methods often involve intractable integrals or result in biased gradients.</li>
<li><strong>Gap</strong>: Diffusion models scale well, yet they are restricted to specific, curved probability paths (e.g., VP, VE) that can result in slow sampling and long training times.</li>
<li><strong>Goal</strong>: To develop an efficient, simulation-free training method for CNFs that supports arbitrary probability paths, specifically allowing for straighter, more efficient trajectories like those from Optimal Transport.</li>
</ul>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The core novelty is <strong>Flow Matching (FM)</strong> and specifically the <strong>Conditional Flow Matching (CFM)</strong> objective.</p>
<ul>
<li><strong>Direct Vector Field Regression</strong>: The model regresses a target vector field $u_t$ that generates a desired probability path $p_t$.</li>
<li><strong>Conditional Flow Matching (CFM)</strong>: The authors prove that regressing the vector field of <em>conditional</em> paths (e.g., $p_t(x|x_1)$ given a single data point) yields the same gradients as regressing the intractable marginal vector field. This bypasses the need to know the marginal score or vector field.</li>
<li><strong>Optimal Transport Paths</strong>: The framework enables the use of <strong>Optimal Transport (OT)</strong> displacement interpolation for probability paths. OT paths are straight lines with constant speed, leading to faster training and easier sampling.</li>
</ul>
<p><strong>Concurrent work note</strong>: Rectified Flow (Liu et al., 2023) and Stochastic Interpolants (Albergo &amp; Vanden-Eijnden, 2023) were published concurrently at ICLR 2023 with structurally similar contributions under different names. All three independently propose simulation-free training of continuous flows via direct vector field regression; the differences lie in the specific interpolation schemes, theoretical framing, and experimental focus.</p>
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<ul>
<li><strong>Domains</strong>: 2D Checkerboard data, CIFAR-10, and ImageNet at resolutions $32 \times 32$, $64 \times 64$, and $128 \times 128$.</li>
<li><strong>Task</strong>: Unconditional generative modeling (density estimation and sample quality) and conditional super-resolution ($64 \times 64 \to 256 \times 256$).</li>
<li><strong>Baselines</strong>: Compared against Diffusion-based methods on the same architecture (U-Net): DDPM, Score Matching (SM), and ScoreFlow.</li>
<li><strong>Ablations</strong>: Specifically compared <strong>FM with Diffusion paths</strong> vs. <strong>FM with Optimal Transport (OT) paths</strong> to isolate the benefit of the training objective vs. the path choice.</li>
</ul>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>
<ul>
<li><strong>Outperforms diffusion baselines</strong>: FM-OT consistently outperforms all diffusion-based methods (DDPM, Score Matching, ScoreFlow) in both Likelihood (NLL) and Sample Quality (FID) across CIFAR-10 and ImageNet, using the same U-Net architecture and training budget. Selected rows from Table 1 (NLL in bits per dimension, BPD; lower is better for all three metrics; &ldquo;FM w/ OT&rdquo; and &ldquo;FM w/ Diffusion&rdquo; refer to FM trained with OT paths and Diffusion paths respectively):</li>
</ul>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>Method</th>
          <th>NLL (BPD) ↓</th>
          <th>FID ↓</th>
          <th>NFE ↓</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>CIFAR-10</td>
          <td>DDPM</td>
          <td>3.12</td>
          <td>7.48</td>
          <td>274</td>
      </tr>
      <tr>
          <td>CIFAR-10</td>
          <td>FM w/ OT</td>
          <td><strong>2.99</strong></td>
          <td><strong>6.35</strong></td>
          <td><strong>142</strong></td>
      </tr>
      <tr>
          <td>ImageNet 64×64</td>
          <td>ScoreFlow</td>
          <td>3.36</td>
          <td>24.95</td>
          <td>601</td>
      </tr>
      <tr>
          <td>ImageNet 64×64</td>
          <td>FM w/ OT</td>
          <td><strong>3.31</strong></td>
          <td><strong>14.45</strong></td>
          <td><strong>138</strong></td>
      </tr>
  </tbody>
</table>
<ul>
<li><strong>Training stability</strong>: FM with diffusion paths (FM w/ Diffusion) is itself a more stable alternative to diffusion training than DDPM and Score Matching, as shown by training curves in the paper (Figure 5), even before switching to OT paths. The OT path then provides further gains.</li>
<li><strong>Sampling speed</strong>: The straight trajectories of OT paths allow accurate sampling with significantly fewer function evaluations (NFE) compared to diffusion paths.</li>
<li><strong>Generality</strong>: Diffusion is a specific instance of Gaussian probability paths within FM. OT paths are a better-optimized alternative available within the same framework.</li>
<li><strong>Downstream adoption</strong>: Flow matching has been adopted beyond image generation. <a href="/notes/biology/computational-biology/dynamicflow/">DynamicFlow</a> uses it as the generative backbone for simultaneously generating ligand molecules and transforming protein pockets, extending flow matching to structure-based drug design.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<ul>
<li><strong>Datasets</strong>: CIFAR-10, ImageNet ($32 \times 32$, $64 \times 64$, $128 \times 128$).</li>
<li><strong>Preprocessing</strong>:
<ul>
<li>Images are center-cropped and resized.</li>
<li>For $32 \times 32$ and $64 \times 64$, the preprocessing follows Chrabaszcz et al. (2017).</li>
<li>Data is transformed via $\varphi(y) = 2^7(y+1)$ mapping $[-1, 1]$ pixel values to $[0, 256]$ for BPD computation.</li>
</ul>
</li>
</ul>
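<p>As a hedged aside, the effect of the rescaling alone on the reported number can be worked out from the change of variables formula. The symbols $p_Y$ (model density on $[-1, 1]^d$) and $p_X$ (density after applying $\varphi$ elementwise) are notation introduced here for illustration; the dequantization details are left to the paper.</p>
<p>$$\log_2 p_X(\varphi(y)) = \log_2 p_Y(y) - d \log_2 2^7 = \log_2 p_Y(y) - 7d$$</p>
<p>$$\text{BPD} = -\frac{\log_2 p_X(\varphi(y))}{d} = -\frac{\log_2 p_Y(y)}{d} + 7$$</p>
<p>Under these assumptions, rescaling to $[0, 256]$ shifts the per-dimension score by a constant 7 bits, so comparisons between models are unaffected by it.</p>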
<h3 id="algorithms">Algorithms</h3>
<p><strong>1. Conditional Flow Matching (CFM) Objective</strong></p>
<p>The practical training objective used is the CFM loss, which bypasses intractable marginalization:</p>
<p>$$\mathcal{L}_{CFM}(\theta) = \mathbb{E}_{t, q(x_1), p(x_0)} \left\| v_t(\psi_t(x_0)) - u_t(\psi_t(x_0) \mid x_1) \right\|^2$$</p>
<p>Where $t \sim \mathcal{U}[0,1]$, $x_1 \sim q(x_1)$ (data), and $x_0 \sim p(x_0)$ (noise).</p>
<p><strong>2. Optimal Transport (OT) Probability Path</strong></p>
<p>The authors recommend the OT path for efficiency.</p>
<ul>
<li><strong>Mean/Std Schedule</strong>: $\mu_t(x_1) = t x_1$ and $\sigma_t(x_1) = 1 - (1 - \sigma_{min})t$.</li>
<li><strong>Conditional Flow Map</strong>: $\psi_t(x) = (1 - (1 - \sigma_{min})t)x + t x_1$.</li>
<li><strong>Target Vector Field</strong>: The closed-form regression target for OT is:
$$u_t(x|x_1) = \frac{x_1 - (1 - \sigma_{min})x}{1 - (1 - \sigma_{min})t}$$</li>
</ul>
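<p>The OT path and the CFM loss above can be sketched in a few lines of plain Python. This is a minimal illustration, not the authors&rsquo; code: the function names (<code>ot_psi</code>, <code>ot_target</code>, <code>cfm_sample_loss</code>) and the oracle model are hypothetical, and vectors are plain lists so that nothing beyond the standard library is needed.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">import random

def ot_psi(x0, x1, t, sigma_min=1e-5):
    """Conditional OT flow map: psi_t(x0) = (1 - (1 - sigma_min) t) x0 + t x1."""
    scale = 1.0 - (1.0 - sigma_min) * t
    return [scale * a + t * b for a, b in zip(x0, x1)]

def ot_target(x, x1, t, sigma_min=1e-5):
    """Closed-form regression target u_t(x | x1) for the OT path."""
    denom = 1.0 - (1.0 - sigma_min) * t
    return [(b - (1.0 - sigma_min) * a) / denom for a, b in zip(x, x1)]

def cfm_sample_loss(v, x1, rng, sigma_min=1e-5):
    """One-sample Monte Carlo estimate of the CFM loss for a model v(x, t)."""
    t = rng.random()                          # t ~ U[0, 1)
    x0 = [rng.gauss(0.0, 1.0) for _ in x1]    # x0 ~ N(0, I)
    xt = ot_psi(x0, x1, t, sigma_min)         # point on the conditional path
    target = ot_target(xt, x1, t, sigma_min)
    pred = v(xt, t)
    return sum((p - u) ** 2 for p, u in zip(pred, target))

rng = random.Random(0)
x1 = [0.5, -1.0]
# Hypothetical oracle that already returns the conditional target for this x1.
oracle = lambda x, t: ot_target(x, x1, t)
loss = cfm_sample_loss(oracle, x1, rng)
print(loss)
</code></pre></div>
<p>Along a conditional trajectory the target simplifies to $x_1 - (1 - \sigma_{min})x_0$, which does not depend on $t$; that constancy is what makes the OT trajectories straight lines.</p>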
<p><strong>3. Sampling</strong></p>
<p>Sampling is performed by solving the ODE $\frac{d}{dt}\phi_t(x) = v_t(\phi_t(x))$ from $t=0$ to $t=1$ using the learned vector field $v_t$.</p>
<ul>
<li><strong>Solver</strong>: <code>dopri5</code> (adaptive) is used for robust evaluation. Fixed-step solvers (Euler, Midpoint) are used for low-NFE efficiency tests.</li>
</ul>
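<p>A fixed-step sampler of the kind mentioned above can be sketched as follows. This is an illustrative Euler integrator, not the paper&rsquo;s evaluation code; <code>euler_sample</code> and <code>toy_field</code> are hypothetical names, and the toy field is the OT conditional field for a single data point, chosen because its exact endpoint $\sigma_{min} x_0 + x_1$ is known.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">def euler_sample(v, x0, n_steps):
    """Integrate dx/dt = v(x, t) from t = 0 to t = 1 with fixed-step Euler."""
    x = list(x0)
    h = 1.0 / n_steps
    for k in range(n_steps):
        t = k * h
        dx = v(x, t)                          # one function evaluation per step
        x = [xi + h * di for xi, di in zip(x, dx)]
    return x

# Toy setup (an assumption for illustration): a single data point X1.
SIGMA_MIN = 1e-5
X1 = [2.0, -1.0]

def toy_field(x, t):
    """OT conditional vector field u_t(x | X1)."""
    denom = 1.0 - (1.0 - SIGMA_MIN) * t
    return [(b - (1.0 - SIGMA_MIN) * a) / denom for a, b in zip(x, X1)]

x0 = [0.3, 0.7]
one_step = euler_sample(toy_field, x0, n_steps=1)
ten_steps = euler_sample(toy_field, x0, n_steps=10)
print(one_step, ten_steps)
</code></pre></div>
<p>Because the velocity is constant along each straight trajectory of this toy field, one Euler step and ten Euler steps land on the same endpoint up to rounding. A learned marginal field is only approximately straight, so in practice the step count trades accuracy against NFE.</p>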
<h3 id="models">Models</h3>
<ul>
<li><strong>Architecture</strong>: U-Net architecture from Dhariwal &amp; Nichol (2021) is used for all image experiments.</li>
<li><strong>Toy Data</strong>: 5-layer MLP with 512 neurons.</li>
<li><strong>Hyperparameters</strong>:
<ul>
<li>Optimizer: Adam ($\beta_1=0.9, \beta_2=0.999$, weight decay=0.0).</li>
<li>Learning Rate: Polynomial decay or constant (see Table 3 in paper).</li>
<li>$\sigma_{min}$: Set to a small value (e.g., $10^{-5}$).</li>
</ul>
</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Metrics</strong>:
<ul>
<li><strong>NLL (BPD)</strong>: Computed using the continuous change of variables formula, with the divergence estimated via the Hutchinson trace estimator to avoid the cost of an exact Jacobian trace (one backward pass per input dimension).</li>
<li><strong>FID</strong>: Frechet Inception Distance for sample quality.</li>
<li><strong>NFE</strong>: Number of Function Evaluations required by the solver.</li>
</ul>
</li>
<li><strong>Likelihood Computation</strong>: Requires solving an augmented ODE to track the log-density change:
$$\frac{d}{dt} \begin{bmatrix} \phi_t(x) \\ f(t) \end{bmatrix} = \begin{bmatrix} v_t(\phi_t(x)) \\ -\text{div}(v_t(\phi_t(x))) \end{bmatrix}$$</li>
</ul>
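<p>The Hutchinson estimator used for the divergence term can be sketched as below. This is a minimal standard-library illustration under stated assumptions: the vector field is a toy linear map $v(x) = Ax$ (so its Jacobian is $A$ and its divergence is $\text{tr}(A)$), and <code>hutchinson_trace</code> and <code>jvp</code> are hypothetical names. In a real model the Jacobian-vector product would come from automatic differentiation.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">import random

def hutchinson_trace(matvec, dim, n_probes, rng):
    """Estimate tr(J) as the average of e^T (J e) over Rademacher probes e."""
    total = 0.0
    for _ in range(n_probes):
        e = [rng.choice((-1.0, 1.0)) for _ in range(dim)]
        je = matvec(e)                        # one Jacobian-vector product
        total += sum(ei * ji for ei, ji in zip(e, je))
    return total / n_probes

# Toy linear field v(x) = A x (an assumption), so div v = tr(A).
A = [
    [1.0, 0.5, 0.0],
    [0.2, -2.0, 0.1],
    [0.0, 0.3, 0.5],
]

def jvp(e):
    """Jacobian-vector product for the toy field: J e = A e."""
    return [sum(a * ej for a, ej in zip(row, e)) for row in A]

rng = random.Random(0)
estimate = hutchinson_trace(jvp, dim=3, n_probes=2000, rng=rng)
exact = sum(A[i][i] for i in range(3))
print(estimate, exact)
</code></pre></div>
<p>Each probe costs one Jacobian-vector product regardless of dimension, and the estimate is unbiased because the off-diagonal contributions average to zero over random signs.</p>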
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>CIFAR-10</strong>: 2 GPUs.</li>
<li><strong>ImageNet-32</strong>: 4 GPUs.</li>
<li><strong>ImageNet-64</strong>: 16 GPUs.</li>
<li><strong>ImageNet-128</strong>: 32 GPUs.</li>
<li><strong>Precision</strong>: Full 32-bit for CIFAR/IM-32; 16-bit mixed precision for IM-64/128.</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/facebookresearch/flow_matching">flow_matching (PyTorch library)</a></td>
          <td>Code</td>
          <td>CC BY-NC 4.0</td>
          <td>Later official library from Meta; not the original experiment code</td>
      </tr>
  </tbody>
</table>
<p>The paper does not release the original training code or model weights used in the experiments. The <code>facebookresearch/flow_matching</code> library was released later as a general-purpose PyTorch implementation of flow matching algorithms. Standard benchmark datasets (CIFAR-10, ImageNet) are publicly available.</p>
<hr>
<h2 id="theoretical-notes-why-cfm-works">Theoretical Notes: Why CFM Works</h2>
<p>The paper relies on three key theorems to make training tractable.</p>
<p><strong>Theorem 1 (Marginal Generation)</strong>:</p>
<p>Marginalizing conditional vector fields $u_t(x|x_1)$ yields the correct marginal vector field $u_t(x)$ that generates the marginal probability path $p_t(x)$.</p>
<p>$$u_t(x) = \int u_t(x|x_1) \frac{p_t(x|x_1)q(x_1)}{p_t(x)} dx_1$$</p>
<blockquote>
<p><strong>Understanding the Proof:</strong></p>
<p>To understand why this theorem holds, we have to look at the <strong>Continuity Equation</strong>, which is the fundamental partial differential equation (PDE) that links a probability density path $p_t$ to a vector field $u_t$.</p>
<p>A vector field $u_t$ is said to &ldquo;generate&rdquo; a probability path $p_t$ if and only if they satisfy the continuity equation:</p>
<p>$$\frac{\partial p_t(x)}{\partial t} + \nabla \cdot (p_t(x) u_t(x)) = 0$$</p>
<p>The proof of Theorem 1 relies on substituting the definitions of the marginal path and vector field into this equation to see if they balance out.</p>
<p><strong>Step-by-Step Proof:</strong></p>
<ol>
<li>
<p><strong>Start with the time derivative of the marginal path</strong>: We begin by differentiating the marginal probability path $p_t(x)$ with respect to time. By definition, the marginal path is the integral of the conditional paths over the data distribution:
$$\frac{\partial p_t(x)}{\partial t} = \frac{\partial}{\partial t} \int p_t(x|x_1) q(x_1) dx_1$$</p>
</li>
<li>
<p><strong>Swap derivative and integral</strong>: Assuming standard regularity conditions (Leibniz Rule), we can move the time derivative inside the integral:
$$\frac{\partial p_t(x)}{\partial t} = \int \frac{\partial p_t(x|x_1)}{\partial t} q(x_1) dx_1$$</p>
</li>
<li>
<p><strong>Apply the Conditional Continuity Equation</strong>: This is the critical step. We know that the conditional vector field $u_t(x|x_1)$ generates the conditional path $p_t(x|x_1)$. Therefore, for every single sample $x_1$, the pair satisfies the continuity equation:
$$\frac{\partial p_t(x|x_1)}{\partial t} = -\nabla \cdot (p_t(x|x_1) u_t(x|x_1))$$</p>
<p>Substituting this into our integral gives:
$$\frac{\partial p_t(x)}{\partial t} = -\int \nabla \cdot (p_t(x|x_1) u_t(x|x_1)) q(x_1) dx_1$$</p>
</li>
<li>
<p><strong>Pull the Divergence out</strong>: Since the divergence operator ($\nabla \cdot$) acts on $x$ and the integral is over $x_1$, we can pull the divergence operator outside the integral (by linearity):
$$\frac{\partial p_t(x)}{\partial t} = -\nabla \cdot \left( \int p_t(x|x_1) u_t(x|x_1) q(x_1) dx_1 \right)$$</p>
</li>
<li>
<p><strong>Match with the Marginal Vector Field Definition</strong>: Now, look at the term inside the parentheses. The paper defines the marginal vector field $u_t(x)$ specifically to make this term simpler. Rearranging the definition of $u_t(x)$ provided in the theorem:
$$p_t(x) u_t(x) = \int p_t(x|x_1) u_t(x|x_1) q(x_1) dx_1$$</p>
<p>Substitute $p_t(x) u_t(x)$ back into our equation from Step 4:
$$\frac{\partial p_t(x)}{\partial t} = -\nabla \cdot (p_t(x) u_t(x))$$</p>
</li>
</ol>
<p><strong>Conclusion</strong>: We have just shown that $\frac{\partial p_t(x)}{\partial t} + \nabla \cdot (p_t(x) u_t(x)) = 0$. This is exactly the continuity equation. Because the marginal path and the aggregated marginal vector field satisfy this equation, the vector field is proven to generate the path.</p></blockquote>
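<p>The marginalization formula in Theorem 1 can be checked numerically in one dimension. The sketch below assumes a uniform data distribution over a finite list of points and OT conditional paths; <code>marginal_field</code> and <code>gaussian_pdf</code> are hypothetical helper names, and $\sigma_{min} = 0.1$ is chosen only to keep the toy Gaussians well conditioned.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">import math

def gaussian_pdf(x, mean, std):
    """Density of N(mean, std^2) evaluated at x."""
    z = (x - mean) / std
    return math.exp(-0.5 * z * z) / (std * math.sqrt(2.0 * math.pi))

def marginal_field(x, t, data, sigma_min=0.1):
    """u_t(x) as the posterior-weighted average of OT conditional fields (1-D).

    `data` lists equally likely points x1, so q(x1) is a constant that cancels.
    """
    std = 1.0 - (1.0 - sigma_min) * t
    weights = [gaussian_pdf(x, t * x1, std) for x1 in data]        # p_t(x | x1)
    fields = [(x1 - (1.0 - sigma_min) * x) / std for x1 in data]    # u_t(x | x1)
    norm = sum(weights)                                             # proportional to p_t(x)
    return sum(w * u for w, u in zip(weights, fields)) / norm

# Symmetric data: the two conditional pulls cancel at x = 0.
print(marginal_field(0.0, 0.5, [-1.0, 1.0]))
# A single data point: the marginal field reduces to the conditional field.
print(marginal_field(0.4, 0.5, [1.0]))
</code></pre></div>
<p>The weights are the posterior probabilities of each $x_1$ given the current position $x$, which is exactly the ratio $p_t(x|x_1)q(x_1)/p_t(x)$ appearing in the theorem.</p>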
<p><strong>Theorem 2 (Gradient Equivalence)</strong>:</p>
<p>The intractable Flow Matching objective $\mathcal{L}_{FM}$ (which requires $u_t(x)$) has the <strong>same gradients</strong> as the tractable Conditional Flow Matching objective $\mathcal{L}_{CFM}$.</p>
<p>$$\nabla_\theta \mathcal{L}_{FM}(\theta) = \nabla_\theta \mathcal{L}_{CFM}(\theta)$$</p>
<p>This allows the model to learn the marginal vector field by only seeing conditional sample paths.</p>
<blockquote>
<p><strong>Understanding the Proof:</strong></p>
<p>The reason Theorem 2 holds is that the &ldquo;Conditional Flow Matching&rdquo; (CFM) objective equals the &ldquo;Flow Matching&rdquo; (FM) objective up to an additive constant that does not depend on $\theta$. When we average over all the conditional data points $x_1$, the &ldquo;cross-term&rdquo; in the loss function matches the cross-term with the marginal vector field.</p>
<p><strong>1. Expand the Loss Functions</strong></p>
<p>First, let&rsquo;s look at the squared error in both objectives. Recall that $v_t$ is our neural network (parameterized by $\theta$), $u_t$ is the intractable marginal target, and $u_t(x|x_1)$ is the tractable conditional target.</p>
<p>Expanding the squared norms:</p>
<ul>
<li>
<p><strong>FM Objective</strong>:
$$\mathcal{L}_{FM}(\theta) = \mathbb{E}_{t, p_t(x)} \left[ \|v_t(x)\|^2 - 2v_t(x) \cdot u_t(x) + \|u_t(x)\|^2 \right]$$</p>
</li>
<li>
<p><strong>CFM Objective</strong>:
$$\mathcal{L}_{CFM}(\theta) = \mathbb{E}_{t, q(x_1), p_t(x|x_1)} \left[ \|v_t(x)\|^2 - 2v_t(x) \cdot u_t(x|x_1) + \|u_t(x|x_1)\|^2 \right]$$</p>
</li>
</ul>
<p><strong>Key Insight</strong>: When we take the gradient $\nabla_\theta$, the last term in both equations disappears because the targets ($u_t$) are independent of the network weights $\theta$. We only need to show that the expectations of the first two terms match.</p>
<p><strong>2. Matching the First Term ($\|v_t(x)\|^2$)</strong></p>
<p>This part is straightforward. The expectation of $\|v_t(x)\|^2$ is the same in both cases because of how the marginal density $p_t(x)$ is defined.</p>
<ul>
<li><strong>FM</strong>: averages over $p_t(x)$.</li>
<li><strong>CFM</strong>: averages over $p_t(x|x_1)q(x_1)$.</li>
</ul>
<p>Since $p_t(x) = \int p_t(x|x_1) q(x_1) dx_1$ (by definition), averaging over the joint distribution is mathematically identical to averaging over the marginal $p_t(x)$.</p>
<p><strong>3. Matching the Cross Term (The &ldquo;Trick&rdquo;)</strong></p>
<p>This is the critical part of the proof. We need to show that the interaction between the network and the marginal field equals the interaction between the network and the conditional field.</p>
<p><strong>The Goal</strong>: Show $\mathbb{E}_{t, p_t(x)} [v_t(x) \cdot u_t(x)] = \mathbb{E}_{t, q(x_1), p_t(x|x_1)} [v_t(x) \cdot u_t(x|x_1)]$.</p>
<p><strong>The Proof</strong>:</p>
<ol>
<li>
<p>Start with the <strong>FM cross-term</strong> (marginal):
$$\mathbb{E}_{t, p_t(x)} [v_t(x) \cdot u_t(x)]$$</p>
</li>
<li>
<p>Substitute the definition of the marginal vector field $u_t(x)$ derived in <strong>Theorem 1</strong>:
$$u_t(x) = \int u_t(x|x_1) \frac{p_t(x|x_1) q(x_1)}{p_t(x)} dx_1$$</p>
</li>
<li>
<p>Plug this into the expectation, written out as an integral. The $p_t(x)$ terms cancel:
$$\mathbb{E}_{t, p_t(x)} [v_t(x) \cdot u_t(x)] = \int_t \int_x p_t(x) v_t(x) \cdot \left[ \int_{x_1} u_t(x|x_1) \frac{p_t(x|x_1) q(x_1)}{p_t(x)} dx_1 \right] dx \, dt$$</p>
</li>
<li>
<p>This simplifies to:
$$= \int_t \int_x \int_{x_1} v_t(x) \cdot u_t(x|x_1) p_t(x|x_1) q(x_1) dx_1 dx dt$$</p>
</li>
<li>
<p>This is exactly the definition of the expectation in the <strong>CFM objective</strong>:
$$= \mathbb{E}_{t, q(x_1), p_t(x|x_1)} [v_t(x) \cdot u_t(x|x_1)]$$</p>
</li>
</ol>
<p><strong>Conclusion</strong>: Because the expectations of all terms involving $\theta$ are identical, the gradients must be identical.</p>
<p>Intuitively, this works like <strong>Denoising Score Matching</strong> or <strong>Stochastic Gradient Descent</strong>: even though each individual conditional vector field $u_t(x|x_1)$ points to a specific data point $x_1$ (which may differ from the true marginal direction), the <em>average</em> of all these pulls equals the true marginal vector field $u_t(x)$.</p></blockquote>
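<p>One consequence is worth spelling out. The following is a short derivation under the same definitions, offered as a sketch and not quoted from the paper. Since the first and cross terms agree, the two objectives differ only through their last terms, and because $u_t(x)$ is the posterior average of $u_t(x|x_1)$, that difference is a variance:</p>
<p>$$\mathcal{L}_{CFM}(\theta) - \mathcal{L}_{FM}(\theta) = \mathbb{E}_{t, q(x_1), p_t(x|x_1)} \|u_t(x|x_1)\|^2 - \mathbb{E}_{t, p_t(x)} \|u_t(x)\|^2 = \mathbb{E}_{t, q(x_1), p_t(x|x_1)} \|u_t(x|x_1) - u_t(x)\|^2 \geq 0$$</p>
<p>The gap does not depend on $\theta$, so it leaves the gradients untouched. It also means the CFM loss value generally stays above zero even for a model that matches the marginal field exactly.</p>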
<p><strong>Theorem 3 (Gaussian Conditional VFs)</strong>:</p>
<p>For any Gaussian probability path $p_t(x|x_1) = \mathcal{N}(x | \mu_t(x_1), \sigma_t(x_1)^2 I)$, the unique vector field that defines the affine flow map $\psi_t(x) = \sigma_t(x_1) x + \mu_t(x_1)$, and therefore generates the path, is available in closed form:</p>
<p>$$u_t(x|x_1) = \frac{\sigma'_t(x_1)}{\sigma_t(x_1)}(x - \mu_t(x_1)) + \mu'_t(x_1)$$</p>
<p>This theorem allows explicitly defining targets for both Diffusion (curved) and Optimal Transport (straight) paths.</p>
<blockquote>
<p><strong>Understanding the Proof:</strong></p>
<p>The derivation of Theorem 3 comes from the direct relationship between a flow map $\psi_t$ and its generating vector field. Because we chose a specific, simple path (Gaussian), we can invert the flow map to find the vector field in closed form.</p>
<p><strong>1. Define the Flow Map $\psi_t$</strong></p>
<p>We start by defining the conditional probability path as a Gaussian:</p>
<p>$$p_t(x|x_1) = \mathcal{N}(x | \mu_t(x_1), \sigma_t(x_1)^2 I)$$</p>
<p>The simplest way to &ldquo;push&rdquo; a standard normal distribution (noise) $p_0 = \mathcal{N}(0, I)$ to this Gaussian is using an affine transformation (scaling and shifting). We define the flow map $\psi_t$ as:</p>
<p>$$\psi_t(x_0) = \sigma_t(x_1) x_0 + \mu_t(x_1)$$</p>
<p>This map takes a noise sample $x_0$ and transforms it into a sample $x$ at time $t$.</p>
<p><strong>2. The Definition of a Generating Vector Field</strong></p>
<p>By definition, a vector field $u_t$ generates a flow $\psi_t$ if the vector field describes the instantaneous velocity of the flow at any point. Mathematically:</p>
<p>$$u_t(\psi_t(x_0)) = \frac{d}{dt}\psi_t(x_0)$$</p>
<p>Let $x = \psi_t(x_0)$ be the position of the particle at time $t$. We want to find $u_t(x)$.</p>
<p><strong>3. Invert the Flow Map</strong></p>
<p>To find $u_t(x)$, we must express the equation in terms of $x$ rather than $x_0$. Since our flow map is a simple affine transformation (multiply and add), it is easily invertible (assuming $\sigma_t(x_1) \neq 0$):</p>
<p>$$x_0 = \frac{x - \mu_t(x_1)}{\sigma_t(x_1)}$$</p>
<p>We will call this inverse map $\psi_t^{-1}(x)$.</p>
<p><strong>4. Differentiate the Flow Map</strong></p>
<p>Now we calculate the right side of our definition equation (the velocity): $\frac{d}{dt}\psi_t(x_0)$.</p>
<p>Taking the time derivative of $\psi_t(x_0) = \sigma_t(x_1) x_0 + \mu_t(x_1)$:</p>
<p>$$\frac{d}{dt}\psi_t(x_0) = \sigma'_t(x_1) x_0 + \mu'_t(x_1)$$</p>
<p>(Note: $\sigma'_t$ and $\mu'_t$ denote time derivatives).</p>
<p><strong>5. Substitute and Solve</strong></p>
<p>Now we combine everything. We know $u_t(\psi_t(x_0)) = \frac{d}{dt}\psi_t(x_0)$.</p>
<p>Substitute the result from Step 4 into this equation:</p>
<p>$$u_t(\psi_t(x_0)) = \sigma'_t(x_1) x_0 + \mu'_t(x_1)$$</p>
<p>This expresses the vector field in terms of the initial point $x_0$. We must express it in terms of the current point $x$. So, we plug in the inverse formula for $x_0$ derived in Step 3:</p>
<p>$$u_t(x|x_1) = \sigma'_t(x_1) \frac{x - \mu_t(x_1)}{\sigma_t(x_1)} + \mu'_t(x_1)$$</p>
<p>Rearranging terms gives the final closed form:</p>
<p>$$u_t(x|x_1) = \frac{\sigma'_t(x_1)}{\sigma_t(x_1)}(x - \mu_t(x_1)) + \mu'_t(x_1)$$</p>
<p><strong>Why is this useful?</strong></p>
<p>This formula means that as long as you can define a mean schedule $\mu_t(x_1)$ and a standard deviation schedule $\sigma_t(x_1)$ (which is easy to do for both Diffusion and Optimal Transport), you immediately get the exact vector field target $u_t(x|x_1)$ needed to train your neural network, bypassing complex ODE solving or score matching approximations.</p></blockquote>
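<p>As a quick consistency check (a worked substitution, not taken verbatim from the paper), plugging the OT schedule $\mu_t(x_1) = t x_1$, $\sigma_t(x_1) = 1 - (1 - \sigma_{min})t$ into the closed form should recover the OT regression target listed under Algorithms:</p>
<p>$$\mu'_t(x_1) = x_1, \qquad \sigma'_t(x_1) = -(1 - \sigma_{min})$$</p>
<p>$$u_t(x|x_1) = \frac{-(1 - \sigma_{min})}{1 - (1 - \sigma_{min})t}(x - t x_1) + x_1 = \frac{-(1 - \sigma_{min})x + (1 - \sigma_{min}) t x_1 + x_1 - (1 - \sigma_{min}) t x_1}{1 - (1 - \sigma_{min})t} = \frac{x_1 - (1 - \sigma_{min})x}{1 - (1 - \sigma_{min})t}$$</p>
<p>The two $t x_1$ terms cancel in the numerator, which is why the OT target has such a compact form.</p>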
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Lipman, Y., Chen, R. T. Q., Ben-Hamu, H., Nickel, M., &amp; Le, M. (2023). Flow Matching for Generative Modeling. <em>International Conference on Learning Representations (ICLR)</em>.</p>
<p><strong>Publication</strong>: ICLR 2023</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{lipmanFlowMatchingGenerative2023,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Flow Matching for Generative Modeling}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Lipman, Yaron and Chen, Ricky T. Q. and Ben-Hamu, Heli and Nickel, Maximilian and Le, Matt}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{International Conference on Learning Representations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2023}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://arxiv.org/abs/2210.02747">arXiv</a></li>
</ul>
]]></content:encoded></item><item><title>Building Normalizing Flows with Stochastic Interpolants</title><link>https://hunterheidenreich.com/notes/machine-learning/generative-models/stochastic-interpolants/</link><pubDate>Sun, 21 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/generative-models/stochastic-interpolants/</guid><description>A continuous-time normalizing flow using stochastic interpolants and quadratic loss to bypass costly ODE backpropagation.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is primarily a <strong>Method</strong> paper, with significant <strong>Theory</strong> contributions.</p>
<p>The authors propose a specific algorithm (&ldquo;InterFlow&rdquo;) for constructing generative models based on continuous-time normalizing flows. The work is characterized by the derivation of a new training objective (a simple quadratic loss) that bypasses the computational bottlenecks of previous methods. It includes prominent baseline comparisons against continuous flow methods (FFJORD, OT-Flow) and diffusion models. The theoretical component establishes the validity of the interpolant density satisfying the continuity equation (a conservation law governing how probability mass flows) and bounds the Wasserstein-2 distance (a measure of transport cost between distributions, penalizing squared displacement) of the transport.</p>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>The primary motivation is to overcome the computational inefficiency of training Continuous Normalizing Flows (CNFs) using Maximum Likelihood Estimation (MLE). Standard CNF training requires backpropagating through numerical ODE solvers, which is costly and limits scalability.</p>
<p>Additionally, while score-based diffusion models (SDEs) have achieved high sample quality, they theoretically require infinite time integration and rely on specific noise schedules. The authors aim to establish a method that works strictly with Probability Flow ODEs on finite time intervals, retaining the flexibility to connect arbitrary densities without the complexity of SDEs or the cost of standard ODE adjoint methods.</p>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The core novelty is the <strong>Stochastic Interpolant</strong> framework:</p>
<ul>
<li><strong>Explicit Interpolant Construction</strong>: The method defines a time-dependent interpolant $x_t = I_t(x_0, x_1)$ (e.g., trigonometric interpolation) that connects samples from the base density $\rho_0$ and target $\rho_1$.</li>
<li><strong>Simulation-Free Training</strong>: The velocity field $v_t(x)$ of the probability flow is learned by minimizing a simple quadratic objective: $G(\hat{v}) = \mathbb{E}[|\hat{v}_t(x_t)|^2 - 2\partial_t I_t(x_0, x_1) \cdot \hat{v}_t(x_t)]$. Because $\partial_t I_t$ is known analytically from the interpolant definition, the expectation can be estimated by sampling $(x_0, x_1, t)$ directly. Training therefore needs no ODE integration, although sampling at inference still does.</li>
<li><strong>Decoupling Path and Optimization</strong>: The choice of path (interpolant) is separated from the optimization of the velocity field. MLE methods couple the path and objective.</li>
<li><strong>Connection to Score-Based Models</strong>: The authors show that for Gaussian base densities and trigonometric interpolants, the learned velocity field is explicitly related to the score function $\nabla \log \rho_t$, providing a theoretical bridge between CNFs and diffusion models.</li>
</ul>
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The authors performed validation across synthetic, tabular, and image domains:</p>
<ul>
<li><strong>2D Density Estimation</strong>: Benchmarked on &ldquo;Checkerboard&rdquo;, &ldquo;8 Gaussians&rdquo;, and anisotropic curved densities to visualize mode coverage and transport smoothness.</li>
<li><strong>High-Dimensional Tabular Data</strong>: Evaluated on standard benchmarks (POWER, GAS, HEPMASS, MINIBOONE, BSDS300) comparing Negative Log Likelihood (NLL) against FFJORD, OT-Flow, and others.</li>
<li><strong>Image Generation</strong>: Trained models on CIFAR-10 ($32 \times 32$), ImageNet ($32 \times 32$), and Oxford Flowers ($128 \times 128$) to test scalability.</li>
<li><strong>Ablations</strong>: Investigated optimizing the interpolant path itself (e.g., learning Fourier coefficients for the path) to approach optimal transport and minimize path length.</li>
</ul>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>
<ul>
<li><strong>Performance</strong>: The method matches or surpasses conventional ODE flows (like FFJORD) in terms of NLL while being significantly cheaper to train.</li>
<li><strong>Efficiency</strong>: The training cost per epoch is constant (simulation-free), whereas MLE-based ODE methods see growing costs as the dynamics become more complex.</li>
<li><strong>Scalability</strong>: The method successfully scales to $128 \times 128$ resolution on a single GPU, a resolution that prior ab-initio ODE flows had not demonstrated.</li>
<li><strong>Flexibility</strong>: The framework can connect <em>any</em> two densities (e.g., two different complex 2D distributions) without needing one to be Gaussian.</li>
<li><strong>Optimal Transport</strong>: For a fixed interpolant, minimizing $G(\hat{v})$ over the velocity field recovers the velocity for that specific path. Additionally optimizing over the interpolant family yields a solution to the Benamou-Brenier optimal transport problem.</li>
<li><strong>Limitations</strong>: The authors acknowledge that image FID scores trail dedicated diffusion models, noting that InterFlow was not optimized with standard training tricks such as exponential moving averages, truncation, or learning rate warm-ups. The framework&rsquo;s sample quality could likely improve with these additions.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<ul>
<li><strong>Tabular Datasets</strong>: POWER (6D), GAS (8D), HEPMASS (21D), MINIBOONE (43D), BSDS300 (63D).
<ul>
<li>Training points range from ~30k (MINIBOONE) to ~1.6M (POWER).</li>
</ul>
</li>
<li><strong>Image Datasets</strong>:
<ul>
<li>CIFAR-10 ($32 \times 32$, 50k training points).</li>
<li>ImageNet ($32 \times 32$, ~1.28M training points).</li>
<li>Oxford Flowers ($128 \times 128$, ~315k training points).</li>
</ul>
</li>
<li><strong>Time Sampling</strong>: Time $t$ is sampled from a Beta distribution during training (reweighting) to focus learning near the target.</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Interpolant</strong>: The primary interpolant used is trigonometric: $I_t(x_0, x_1) = \cos(\frac{\pi t}{2})x_0 + \sin(\frac{\pi t}{2})x_1$.
<ul>
<li>Alternative linear interpolant: $I_t = a_t x_0 + b_t x_1$.</li>
</ul>
</li>
<li><strong>Loss Function</strong>:
$$G(\hat{v}) = \mathbb{E}_{t, x_0, x_1}[|\hat{v}_t(x_t)|^2 - 2\partial_t I_t(x_0, x_1) \cdot \hat{v}_t(x_t)]$$
<ul>
<li>The expectation is amenable to empirical estimation using batches of $x_0, x_1, t$.</li>
</ul>
</li>
<li><strong>Sampling</strong>: Numerical integration using Dormand-Prince (Runge-Kutta 4/5).</li>
<li><strong>Optimization</strong>: SGD/Adam variants.</li>
</ul>
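<p>The trigonometric interpolant and the quadratic loss above can be sketched with the standard library alone. The helper names (<code>trig_interpolant</code>, <code>trig_interpolant_dt</code>, <code>sample_loss</code>) and the constant stand-in models are hypothetical, introduced only to show how a single $(x_0, x_1, t)$ draw contributes to the expectation.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">import math

def trig_interpolant(x0, x1, t):
    """I_t(x0, x1) = cos(pi t / 2) x0 + sin(pi t / 2) x1."""
    c, s = math.cos(0.5 * math.pi * t), math.sin(0.5 * math.pi * t)
    return [c * a + s * b for a, b in zip(x0, x1)]

def trig_interpolant_dt(x0, x1, t):
    """Analytic time derivative of the trigonometric interpolant."""
    c, s = math.cos(0.5 * math.pi * t), math.sin(0.5 * math.pi * t)
    return [0.5 * math.pi * (c * b - s * a) for a, b in zip(x0, x1)]

def sample_loss(v_hat, x0, x1, t):
    """One-sample term of G: |v_hat|^2 - 2 dI/dt . v_hat at x_t = I_t(x0, x1)."""
    xt = trig_interpolant(x0, x1, t)
    vel = v_hat(xt, t)
    target = trig_interpolant_dt(x0, x1, t)
    return sum(p * p for p in vel) - 2.0 * sum(d * p for d, p in zip(target, vel))

x0, x1, t = [1.0, 0.0], [0.0, 2.0], 0.25
target = trig_interpolant_dt(x0, x1, t)
# Stand-in models (assumptions): one returns dI/dt exactly, one is offset by 0.1.
at_target = sample_loss(lambda x, s: target, x0, x1, t)
perturbed = sample_loss(lambda x, s: [d + 0.1 for d in target], x0, x1, t)
print(at_target, perturbed)
</code></pre></div>
<p>For a single draw the term is smallest when the model output equals $\partial_t I_t$, since $|\hat{v}|^2 - 2\partial_t I_t \cdot \hat{v} = |\hat{v} - \partial_t I_t|^2 - |\partial_t I_t|^2$. Averaged over draws, a model that depends only on $(x_t, t)$ is instead pushed toward the conditional average of $\partial_t I_t$ given $x_t$.</p>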
<h3 id="models">Models</h3>
<ul>
<li><strong>Tabular Architectures</strong>:
<ul>
<li>Feed-forward networks with 4-5 hidden layers.</li>
<li>Hidden widths: 512 (POWER, GAS, HEPMASS, MINIBOONE) or 1024 (BSDS300).</li>
<li>Activation: ReLU (general) or ELU (BSDS300).</li>
</ul>
</li>
<li><strong>Image Architectures</strong>:
<ul>
<li>U-Net based on the DDPM implementation.</li>
<li>Hidden dimension: 256.</li>
<li>Sinusoidal time embeddings used.</li>
</ul>
</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Metrics</strong>: Negative Log Likelihood (NLL) in nats (tabular) or bits per dim (images), Frechet Inception Distance (FID) for images.</li>
<li><strong>Baselines</strong>: FFJORD, Glow, Real NVP, OT-Flow, ScoreFlow, DDPM.</li>
</ul>
<p><strong>Tabular NLL</strong> (nats, lower is better; Table 2 Left):</p>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>POWER</th>
          <th>GAS</th>
          <th>HEPMASS</th>
          <th>MINIBOONE</th>
          <th>BSDS300</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MADE</td>
          <td>3.08</td>
          <td>-3.56</td>
          <td>20.98</td>
          <td>15.59</td>
          <td>-148.85</td>
      </tr>
      <tr>
          <td>Real NVP</td>
          <td>-0.17</td>
          <td>-8.33</td>
          <td>18.71</td>
          <td>13.55</td>
          <td>-153.28</td>
      </tr>
      <tr>
          <td>Glow</td>
          <td>-0.17</td>
          <td>-8.15</td>
          <td>18.92</td>
          <td>11.35</td>
          <td>-155.07</td>
      </tr>
      <tr>
          <td>CPF</td>
          <td>-0.52</td>
          <td>-10.36</td>
          <td>16.93</td>
          <td>10.58</td>
          <td>-154.99</td>
      </tr>
      <tr>
          <td>NSP</td>
          <td>-0.64</td>
          <td>-13.09</td>
          <td>14.75</td>
          <td>9.67</td>
          <td>-157.54</td>
      </tr>
      <tr>
          <td>FFJORD</td>
          <td>-0.46</td>
          <td>-8.59</td>
          <td>14.92</td>
          <td>10.43</td>
          <td>-157.40</td>
      </tr>
      <tr>
          <td>OT-Flow</td>
          <td>-0.30</td>
          <td>-9.20</td>
          <td>17.32</td>
          <td>10.55</td>
          <td>-154.20</td>
      </tr>
      <tr>
          <td><strong>Ours</strong></td>
          <td><strong>-0.57</strong></td>
          <td><strong>-12.35</strong></td>
          <td><strong>14.85</strong></td>
          <td><strong>10.42</strong></td>
          <td><strong>-156.22</strong></td>
      </tr>
  </tbody>
</table>
<p><strong>Image Generation NLL and FID</strong> (Table 2 Right; NLL in bits per dim, lower is better):</p>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>CIFAR-10 NLL</th>
          <th>CIFAR-10 FID</th>
          <th>ImageNet-32 NLL</th>
          <th>ImageNet-32 FID</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>FFJORD</td>
          <td>3.40</td>
          <td>-</td>
          <td>-</td>
          <td>-</td>
      </tr>
      <tr>
          <td>Glow</td>
          <td>3.35</td>
          <td>-</td>
          <td>4.09</td>
          <td>-</td>
      </tr>
      <tr>
          <td>DDPM</td>
          <td>≤3.75</td>
          <td>3.17</td>
          <td>-</td>
          <td>-</td>
      </tr>
      <tr>
          <td>DDPM++ (Song et al., 2021)</td>
          <td>≤3.37</td>
          <td>2.90</td>
          <td>-</td>
          <td>-</td>
      </tr>
      <tr>
          <td>ScoreSDE (Song et al., 2021)</td>
          <td>2.99</td>
          <td>2.92</td>
          <td>-</td>
          <td>-</td>
      </tr>
      <tr>
          <td>VDM</td>
          <td>≤2.65</td>
          <td>7.41</td>
          <td>≤3.72</td>
          <td>-</td>
      </tr>
      <tr>
          <td>Soft Truncation</td>
          <td>2.88</td>
          <td>3.45</td>
          <td>3.85</td>
          <td>8.42</td>
      </tr>
      <tr>
          <td>ScoreFlow</td>
          <td>2.81</td>
          <td>5.40</td>
          <td>3.76</td>
          <td>10.18</td>
      </tr>
      <tr>
          <td><strong>Ours</strong></td>
          <td><strong>2.99</strong></td>
          <td><strong>10.27</strong></td>
          <td><strong>3.48</strong></td>
          <td><strong>8.49</strong></td>
      </tr>
  </tbody>
</table>
<p>Note: DDPM++ is from Song et al. (2021), the same work as ScoreSDE (it is the architecture optimized for VP/sub-VP SDEs). InterFlow matches ScoreSDE on CIFAR-10 NLL (2.99 bits per dim) while being simulation-free. FID is weaker than dedicated image models (10.27 vs 2.92 for ScoreSDE), reflecting the paper&rsquo;s primary focus on tractable likelihood rather than sample quality.</p>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Compute</strong>: All models were trained on a single NVIDIA A100 GPU.</li>
<li><strong>Training Time</strong>:
<ul>
<li>Tabular: $10^5$ steps.</li>
<li>Images: $1.5 \times 10^5$ to $6 \times 10^5$ steps.</li>
<li>Speedup: Demonstrated ~400x speedup compared to FFJORD on the MINIBOONE dataset.</li>
</ul>
</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>lucidrains/denoising-diffusion-pytorch (link defunct)</td>
          <td>Code</td>
          <td>MIT</td>
          <td>Base U-Net architecture used for image experiments; original GitHub account no longer available</td>
      </tr>
  </tbody>
</table>
<p>No official code release accompanies this paper. All tabular datasets (POWER, GAS, HEPMASS, MINIBOONE, BSDS300) are publicly available from prior work. CIFAR-10 and ImageNet are standard public benchmarks. Oxford Flowers 102 is also publicly available. Hyperparameters and architectures are fully specified in Tables 3 and 4 of the paper.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Albergo, M. S., &amp; Vanden-Eijnden, E. (2023). Building Normalizing Flows with Stochastic Interpolants. <em>The Eleventh International Conference on Learning Representations</em>.</p>
<p><strong>Publication</strong>: ICLR 2023</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{albergoBuildingNormalizingFlows2022,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Building {{Normalizing Flows}} with {{Stochastic Interpolants}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{The {{Eleventh International Conference}} on {{Learning Representations}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Albergo, Michael Samuel and {Vanden-Eijnden}, Eric}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2023</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span> = <span style="color:#e6db74">{https://openreview.net/forum?id=li7qeBbCR1t}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://openreview.net/forum?id=li7qeBbCR1t">OpenReview</a></li>
<li><a href="https://arxiv.org/abs/2209.15571">arXiv</a></li>
</ul>
]]></content:encoded></item><item><title>STOUT V2.0: Transformer-Based SMILES to IUPAC Translation</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/name-translation/stout-v2/</link><pubDate>Sat, 20 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/name-translation/stout-v2/</guid><description>A Transformer-based model for translating SMILES to IUPAC names, trained on ~1 billion molecules, achieving ~0.99 BLEU score on benchmarks.</description><content:encoded><![CDATA[<h2 id="paper-contribution--methodological-scope">Paper Contribution &amp; Methodological Scope</h2>
<p><strong>Method (Primary) / Resource (Secondary)</strong></p>
<p>This paper presents a <strong>Methodological</strong> contribution by developing and validating a Transformer-based neural machine translation model (STOUT V2) for bidirectional chemical nomenclature (<a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> $\leftrightarrow$ IUPAC). It systematically compares this new architecture against previous RNN-based baselines (<a href="/notes/chemistry/molecular-representations/name-translation/stout/">STOUT V1</a>) and performs ablation studies on tokenization strategies.</p>
<p>It also serves as a significant <strong>Resource</strong> contribution by generating a massive training dataset of nearly 1 billion SMILES-IUPAC pairs (curated via commercial Lexichem software) and releasing the resulting models and code as open-source tools for chemical naming.</p>
<h2 id="the-need-for-robust-open-source-iupac-nomenclature-rules">The Need for Robust Open-Source IUPAC Nomenclature Rules</h2>
<p>Assigning systematic IUPAC names to chemical structures means applying a complex set of rules that humans struggle to apply consistently. Deterministic, rule-based tools such as OpenEye Lexichem and ChemAxon are reliable, but they are commercial. Existing open-source tools like OPSIN address the reverse direction, parsing names into structures.</p>
<p>The previous version of STOUT (V1), based on RNNs/GRUs, achieved ~90% BLEU accuracy, with known limitations in capturing long-distance dependencies required for stereochemistry handling. This work uses the sequence-learning capabilities of Transformers combined with large-scale datasets to create a competitive open-source IUPAC naming tool.</p>
<h2 id="architectural-shift-and-billion-scale-training">Architectural Shift and Billion-Scale Training</h2>
<p>The primary advancements over previous iterations address both architecture and dataset scale:</p>
<ol>
<li><strong>Architecture Shift</strong>: Moving from an RNN-based Seq2Seq model to a <strong>Transformer-based architecture</strong> (4 layers, 8 heads), which captures intricate chemical patterns better than GRUs.</li>
<li><strong>Billion-Scale Training</strong>: Training on a dataset of nearly <strong>1 billion molecules</strong> (combining PubChem and ZINC15), significantly larger than the 60 million used for STOUT V1.</li>
<li><strong>Tokenization Strategy</strong>: Determining that <strong>character-wise tokenization</strong> for IUPAC names is superior to word-wise tokenization in terms of both accuracy and training efficiency (15% faster).</li>
</ol>
<h2 id="experimental-validation-and-scaling-limits">Experimental Validation and Scaling Limits</h2>
<p>The authors conducted three primary experiments to validate bidirectional translation (SMILES $\rightarrow$ IUPAC and IUPAC $\rightarrow$ SMILES):</p>
<ul>
<li><strong>Experiment 1 (Optimization)</strong>: Assessed the impact of dataset size (1M vs 10M vs 50M) and tokenization strategy on SMILES-to-IUPAC performance.</li>
<li><strong>Experiment 2 (Scaling)</strong>: Trained models on 110 million PubChem molecules for <strong>both</strong> forward and reverse translation tasks to test performance on longer sequences.</li>
<li><strong>Experiment 3 (Generalization)</strong>: Trained on the full ~1 billion dataset (PubChem + ZINC15) for both translation directions.</li>
<li><strong>External Validation</strong>: Benchmarked against an external dataset from ChEBI (1,485 molecules) and ChEMBL34 to test generalization to unseen data.</li>
</ul>
<p><strong>Evaluation Metrics</strong>:</p>
<ul>
<li><strong>Textual Accuracy</strong>: BLEU scores (1-4) and Exact String Match.</li>
<li><strong>Chemical Validity</strong>: Retranslation of generated names back to SMILES using OPSIN, followed by Tanimoto similarity checks (PubChem fingerprints) against the original input.</li>
</ul>
<h2 id="translation-accuracy-and-structural-validity">Translation Accuracy and Structural Validity</h2>
<ul>
<li><strong>Superior Performance</strong>: STOUT V2 achieved an average BLEU score of <strong>0.99</strong> (vs 0.94 for V1). While exact string matches varied by experiment (83-89%), the model notably achieved a perfect BLEU score (1.0) on <strong>97.49%</strong> of a specific test set where STOUT V1 only reached 66.65%.</li>
<li><strong>Structural Validity (&ldquo;Near Misses&rdquo;)</strong>: When the generated name differed from the ground-truth string, the structure recovered by translating that name back to SMILES often remained chemically valid. For these divergent names, the average Tanimoto similarity between the bit-vector fingerprints $A$ and $B$ of the input and the retranslated prediction was <strong>0.68</strong>, where
$$ T(A,B) = \frac{|A \cap B|}{|A \cup B|} $$
is the number of bits set in both fingerprints divided by the number of bits set in either.
<em>Critique</em>: An average Tanimoto coefficient of 0.68 typically suggests moderate structural similarity, not an almost-identical &ldquo;near miss&rdquo; (which would be $&gt;0.85$). This implies that when exact string matching fails, the model tends to produce chemically related but structurally distinct outputs.</li>
<li><strong>Tokenization</strong>: Character-level splitting for IUPAC names outperformed word-level splitting and was more computationally efficient.</li>
<li><strong>Data Imbalance &amp; Generalization</strong>: The model&rsquo;s drop in performance for sequences &gt;600 characters highlights a systemic issue in open chemical databases: long, highly complex SMILES strings are significantly underrepresented. Even billion-scale training datasets are still bound by the chemical diversity of their source material.</li>
<li><strong>Limitations</strong>:
<ul>
<li><strong>Preferred Names (PINs)</strong>: The model mimics Lexichem&rsquo;s naming conventions, generating valid IUPAC names distinct from strict <em>Preferred IUPAC Names</em> (PINs).</li>
<li><strong>Sequence Length</strong>: Performance degrades for very long SMILES (&gt;600 characters) due to scarcity in the training data.</li>
<li><strong>Algorithmic Distillation Bottleneck</strong>: Because the 1 billion training pairs were generated entirely by OpenEye&rsquo;s Lexichem, STOUT V2 acts as a knowledge distillation of that specific commercial algorithm. The model learns Lexichem&rsquo;s heuristic mapping, specific dialects, and potential systematic errors, rather than deriving true nomenclature rules from first principles.</li>
</ul>
</li>
</ul>
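<p>To make the Tanimoto critique concrete, here is a small worked example with hypothetical bit counts (the numbers are illustrative and not taken from the paper). Suppose fingerprint $A$ has 50 bits set, fingerprint $B$ has 45 bits set, and 38 bits are set in both. Then:</p>
<p>$$
T(A,B) = \frac{|A \cap B|}{|A| + |B| - |A \cap B|} = \frac{38}{50 + 45 - 38} = \frac{38}{57} \approx 0.67
$$</p>
<p>In this illustrative case, one third of the bits set in either fingerprint are not shared, which is consistent with reading a score near 0.68 as a related but clearly different structure.</p>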
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The training data was derived from PubChem and ZINC15. Ground truth IUPAC names were generated using OpenEye Lexichem TK 2.8.1 to ensure consistency.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Training (Exp 1)</strong></td>
          <td>PubChem Subset</td>
          <td>1M, 10M, 50M</td>
          <td>Selected via MaxMin algorithm for diversity</td>
      </tr>
      <tr>
          <td><strong>Training (Exp 2)</strong></td>
          <td>PubChem</td>
          <td>110M</td>
          <td>Filtered for SMILES length &lt; 600</td>
      </tr>
      <tr>
          <td><strong>Training (Exp 3)</strong></td>
          <td>PubChem + ZINC15</td>
          <td>~1 Billion</td>
          <td>999,637,326 molecules total</td>
      </tr>
      <tr>
          <td><strong>Evaluation</strong></td>
          <td>ChEBI</td>
          <td>1,485</td>
          <td>External validation set, non-overlapping with training</td>
      </tr>
  </tbody>
</table>
<p><strong>Preprocessing</strong>:</p>
<ul>
<li><strong>SMILES</strong>: Canonicalized, isomeric, and kekulized using RDKit (v2023.03.1).</li>
<li><strong>Formatting</strong>: Converted to TFRecord format in 100 MB chunks for TPU efficiency.</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>SMILES Tokenization</strong>: Regex-based splitting. Atoms (e.g., &ldquo;Cl&rdquo;, &ldquo;Au&rdquo;), bonds, brackets, and digits are separate tokens.</li>
<li><strong>IUPAC Tokenization</strong>: <strong>Character-wise split</strong> was selected as the optimal strategy (treating every character as a token).</li>
<li><strong>Optimization</strong>: Adam optimizer with a custom learning rate scheduler based on model dimensions.</li>
<li><strong>Loss Function</strong>: Trained to minimize the Sparse Categorical Cross-Entropy $L$, masking padding tokens. For a target sequence of length $N$, with $p_{i}$ the predicted probability of the correct token at position $i$, the masked loss is:
$$ L = - \sum_{i=1}^{N} m_i \log(p_{i}) $$
where $m_i \in \{0, 1\}$ is 0 at padded positions and 1 elsewhere.</li>
<li><strong>Code Availability</strong>: The <a href="https://github.com/egonw/Smiles-TO-iUpac-Translator">main STOUT V2 repository</a> contains the inference package. The training pipeline/instructions (originally linked to a separate repo that is currently a 404) can still be found within the <a href="https://doi.org/10.5281/zenodo.6559438">Zenodo archive release</a>.</li>
</ul>
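<p>As a worked illustration of the masked loss, consider a hypothetical four-position target whose last position is padding. Writing $p_i$ for the probability the model assigns to the correct token at position $i$ and $m_i$ for the padding mask, and assuming illustrative values $p_1 = 0.9$, $p_2 = 0.8$, $p_3 = 0.5$ with mask $(1, 1, 1, 0)$:</p>
<p>$$
L = -\left(\ln 0.9 + \ln 0.8 + \ln 0.5\right) \approx 0.105 + 0.223 + 0.693 \approx 1.02
$$</p>
<p>The padded position contributes nothing, whatever the model predicts there.</p>
<p>The learning rate scheduler is described only as depending on the model dimensions. If it follows the warmup schedule of Vaswani et al. (an assumption, as is the warmup length $w$ used below), the rate at step $s$ would be:</p>
<p>$$
\text{lr}(s) = d_{model}^{-0.5} \cdot \min\left(s^{-0.5},\; s \cdot w^{-1.5}\right)
$$</p>
<p>With $d_{model} = 512$ and a hypothetical $w = 4000$, the peak at $s = w$ would be $512^{-0.5} \cdot 4000^{-0.5} \approx 7.0 \times 10^{-4}$.</p>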
<h3 id="models">Models</h3>
<p>The model follows the standard Transformer architecture from &ldquo;Attention is All You Need&rdquo; (Vaswani et al.).</p>
<ul>
<li><strong>Architecture</strong>: 4 Transformer layers (encoder/decoder stack).</li>
<li><strong>Attention</strong>: Multi-head attention with <strong>8 heads</strong>.</li>
<li><strong>Dimensions</strong>: Embedding size ($d_{model}$) = 512; Feed-forward dimension ($d_{ff}$) = 2048.</li>
<li><strong>Regularization</strong>: Dropout rate of 0.1.</li>
<li><strong>Context Window</strong>: Max input length (SMILES) = 600; Max output length (IUPAC) = 700-1000.</li>
<li><strong>Weights</strong>: Model weights for forward and reverse architectures are <a href="https://doi.org/10.5281/zenodo.13318286">available via Zenodo (v3)</a>.</li>
</ul>
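<p>The listed dimensions allow a rough size estimate. The sketch below assumes that &ldquo;4 layers&rdquo; means four encoder and four decoder layers, and it ignores biases, layer normalization, and the embedding and output matrices (whose size depends on vocabularies not given here). It should be read as an order-of-magnitude check, not as a figure from the paper. With $d_{model} = 512$ and $d_{ff} = 2048$, the weights per block are:</p>
<p>$$
\underbrace{4\, d_{model}^2}_{\text{attention}} \approx 1.05 \times 10^{6}, \qquad \underbrace{2\, d_{model}\, d_{ff}}_{\text{feed-forward}} \approx 2.10 \times 10^{6}
$$</p>
<p>$$
4 \times \underbrace{(1.05 + 2.10)}_{\text{encoder layer}} \times 10^{6} \;+\; 4 \times \underbrace{(2 \times 1.05 + 2.10)}_{\text{decoder layer}} \times 10^{6} \approx 2.9 \times 10^{7}
$$</p>
<p>Each decoder layer counts two attention blocks (self-attention and cross-attention), which is why it is larger than an encoder layer.</p>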
<h3 id="evaluation">Evaluation</h3>
<p>Evaluation focused on both string similarity and chemical structural integrity.</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Scope</th>
          <th>Method</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>BLEU Score</strong></td>
          <td>N-gram overlap</td>
          <td>Compared predicted IUPAC string to Ground Truth.</td>
      </tr>
      <tr>
          <td><strong>Exact Match</strong></td>
          <td>Accuracy</td>
          <td>Binary 1/0 check for identical strings.</td>
      </tr>
      <tr>
          <td><strong>Tanimoto</strong></td>
          <td>Structural Similarity</td>
          <td>Predicted Name $\rightarrow$ OPSIN $\rightarrow$ SMILES $\rightarrow$ Fingerprint comparison to input.</td>
      </tr>
  </tbody>
</table>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/egonw/Smiles-TO-iUpac-Translator">STOUT V2 GitHub</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Inference package (PyPI: STOUT-pypi)</td>
      </tr>
      <tr>
          <td><a href="https://doi.org/10.5281/zenodo.13318286">Model Weights (Zenodo v3)</a></td>
          <td>Model</td>
          <td>Unknown</td>
          <td>Forward and reverse translation weights</td>
      </tr>
      <tr>
          <td><a href="https://zenodo.org/records/6559438">Code Snapshot (Zenodo)</a></td>
          <td>Code</td>
          <td>Unknown</td>
          <td>Training pipeline archive</td>
      </tr>
      <tr>
          <td><a href="https://stout.decimer.ai">Web Application</a></td>
          <td>Other</td>
          <td>Unknown</td>
          <td>Demo with Ketcher, bulk submission, DECIMER integration</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Training was conducted entirely on Google Cloud Platform (GCP) TPUs.</p>
<ul>
<li><strong>STOUT V1</strong>: Trained on TPU v3-8.</li>
<li><strong>STOUT V2</strong>: Trained on <strong>TPU v4-128 pod slices</strong> (128 nodes).</li>
<li><strong>Large Scale (Exp 3)</strong>: Trained on <strong>TPU v4-256 pod slice</strong> (256 nodes).</li>
<li><strong>Training Time</strong>: Average of <strong>15 hours and 2 minutes per epoch</strong> for the 1 billion dataset.</li>
<li><strong>Framework</strong>: TensorFlow 2.15.0-pjrt with Keras.</li>
</ul>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Rajan, K., Zielesny, A., &amp; Steinbeck, C. (2024). STOUT V2.0: SMILES to IUPAC name conversion using transformer models. <em>Journal of Cheminformatics</em>, 16(146). <a href="https://doi.org/10.1186/s13321-024-00941-x">https://doi.org/10.1186/s13321-024-00941-x</a></p>
<p><strong>Publication</strong>: Journal of Cheminformatics 2024</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{rajanSTOUTV20SMILES2024,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{{{STOUT V2}}.0: {{SMILES}} to {{IUPAC}} Name Conversion Using Transformer Models}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">shorttitle</span> = <span style="color:#e6db74">{{{STOUT V2}}.0}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Rajan, Kohulan and Zielesny, Achim and Steinbeck, Christoph}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2024</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = dec,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Journal of Cheminformatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{16}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{146}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">issn</span> = <span style="color:#e6db74">{1758-2946}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1186/s13321-024-00941-x}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://stout.decimer.ai">Web Application</a> (Includes Ketcher drawing, bulk submission, and DECIMER integration)</li>
<li><a href="https://decimer.ai">DECIMER Project</a></li>
<li><a href="/notes/chemistry/molecular-representations/name-translation/stout/">STOUT V1 Note</a></li>
<li><a href="https://zenodo.org/records/6559438">Zenodo Archive (Code Snapshot)</a></li>
</ul>
]]></content:encoded></item><item><title>OCSAug: Diffusion-Based Augmentation for Hand-Drawn OCSR</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/hand-drawn/ocsaug/</link><pubDate>Sat, 20 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/hand-drawn/ocsaug/</guid><description>A diffusion-based data augmentation pipeline (OCSAug) using DDPM and RePaint to improve optical chemical structure recognition on hand-drawn images.</description><content:encoded><![CDATA[<h2 id="document-taxonomy-ocsaug-as-a-novel-method">Document Taxonomy: OCSAug as a Novel Method</h2>
<p>This is a <strong>Method</strong> paper according to the <a href="/notes/interdisciplinary/research-methods/ai-physical-sciences-paper-taxonomy/">taxonomy</a>. It proposes a novel data augmentation pipeline (<strong>OCSAug</strong>) that integrates Denoising Diffusion Probabilistic Models (DDPM) and the RePaint algorithm to address the data scarcity problem in hand-drawn optical chemical structure recognition (OCSR). The contribution is validated through systematic benchmarking against existing augmentation techniques (RDKit, Randepict) and ablation studies on mask design.</p>
<h2 id="expanding-hand-drawn-training-data-for-ocsr">Expanding Hand-Drawn Training Data for OCSR</h2>
<p>A vast amount of molecular structure data exists in analog formats, such as hand-drawn diagrams in research notes or older literature. While OCSR models perform well on digitally rendered images, they struggle with hand-drawn images due to noise, varying handwriting styles, and distortions. Current datasets for hand-drawn images (e.g., DECIMER) are too small to train effective models, and existing augmentation tools (RDKit, Randepict) fail to generate sufficiently realistic hand-drawn variations.</p>
<h2 id="ocsaug-pipeline-masked-repaint-via-generative-ai">OCSAug Pipeline: Masked RePaint via Generative AI</h2>
<p>The core novelty is <strong>OCSAug</strong>, a three-phase pipeline that uses generative AI to synthesize training data:</p>
<ol>
<li><strong>DDPM + RePaint</strong>: A DDPM learns the distribution of hand-drawn images, and the RePaint algorithm uses it for inpainting.</li>
<li><strong>Structural Masking</strong>: It introduces <strong>vertical and horizontal stripe pattern masks</strong>. These masks selectively obscure parts of atoms or bonds, forcing the diffusion model to reconstruct them with irregular &ldquo;hand-drawn&rdquo; styles while preserving the underlying chemical topology.</li>
<li><strong>Label Transfer</strong>: Because the chemical structure is preserved during inpainting, the SMILES label from the original image is directly transferred to the augmented image, bypassing the need for re-annotation.</li>
</ol>
<h2 id="benchmarking-diffusion-augmentations-on-decimer">Benchmarking Diffusion Augmentations on DECIMER</h2>
<p>The authors evaluated OCSAug using the <strong>DECIMER dataset</strong>, specifically a &ldquo;drug-likeness&rdquo; subset filtered by Lipinski&rsquo;s and Veber&rsquo;s rules.</p>
<ul>
<li><strong>Baselines</strong>: The method was compared against <strong>RDKit</strong> (digital generation) and <strong>Randepict</strong> (rule-based augmentation).</li>
<li><strong>Models</strong>: Four recent OCSR models were fine-tuned: <strong>MolScribe</strong>, <strong>DECIMER 1.0 (I2S)</strong>, <strong>MolNexTR</strong>, and <strong>MPOCSR</strong>.</li>
<li><strong>Metrics</strong>:
<ul>
<li><strong>Tanimoto Similarity</strong>: To measure prediction accuracy against ground truth.</li>
<li><strong>Fréchet Inception Distance (FID)</strong>: To measure the distributional similarity between generated and real hand-drawn images.</li>
<li><strong>RMSE</strong>: To quantify pixel-level structural preservation across different mask thicknesses.</li>
</ul>
</li>
</ul>
<h2 id="improved-generalization-capabilities-and-fid-scores">Improved Generalization Capabilities and FID Scores</h2>
<ul>
<li><strong>Performance Boost</strong>: OCSAug improved recognition accuracy (Tanimoto similarity) by <strong>1.918 to 3.820 times</strong> compared to non-fine-tuned baselines (Improvement Ratio), outperforming traditional augmentation techniques such as RDKit and Randepict (1.570-3.523x).</li>
<li><strong>Data Quality</strong>: OCSAug achieved the lowest FID score (0.471) compared to Randepict (4.054) and RDKit (10.581), indicating its generated images are much closer to the real hand-drawn distribution.</li>
<li><strong>Generalization</strong>: The method showed improved generalization on a newly collected real-world dataset of 463 images from 6 volunteers.</li>
<li><strong>Resolution Mixing</strong>: Training MolScribe and MolNexTR with a mix of $128 \times 128$, $256 \times 256$, and $512 \times 512$ resolution images improved Tanimoto similarity (e.g., MolScribe from 0.585 to 0.640), though this strategy did not help I2S or MPOCSR.</li>
<li><strong>Real-World Evaluation</strong>: On a newly collected dataset of 463 hand-drawn images from 6 volunteers (88 drug compounds), the MPOCSR model fine-tuned with OCSAug achieved 0.367 exact-match accuracy (Tanimoto = 1.0), only marginally above the 0.365 of non-augmented fine-tuning but far above the 0.037 obtained without fine-tuning. The area under the accuracy curve showed a more notable improvement, indicating reduced misrecognition.</li>
<li><strong>Limitations</strong>: The generation process is slow (3 weeks for 10k images on a single GPU). The fixed stripe masks may struggle with highly complex, non-drug-like geometries: when evaluated on the full DECIMER dataset (without drug-likeness filtering), OCSAug did not yield uniform improvements across all models.</li>
</ul>
<hr>
<h2 id="reproducibility">Reproducibility</h2>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/jjjabcd/OCSAug">OCSAug</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official implementation using guided-diffusion and RePaint</td>
      </tr>
      <tr>
          <td><a href="https://zenodo.org/records/6456306">DECIMER Hand-Drawn Dataset</a></td>
          <td>Dataset</td>
          <td>CC-BY 4.0</td>
          <td>5,088 hand-drawn molecular structure images from 24 individuals</td>
      </tr>
  </tbody>
</table>
<h3 id="data">Data</h3>
<ul>
<li><strong>Source</strong>: DECIMER dataset (hand-drawn images).</li>
<li><strong>Filtering</strong>: A &ldquo;drug-likeness&rdquo; filter was applied (Lipinski&rsquo;s rule of 5 + Veber&rsquo;s rules) along with an atom filter (C, H, O, S, F, Cl, Br, N, P only).</li>
<li><strong>Final Size</strong>: 3,194 samples, split into:
<ul>
<li><strong>Training</strong>: 2,604 samples.</li>
<li><strong>Validation</strong>: 290 samples.</li>
<li><strong>Test</strong>: 300 samples.</li>
</ul>
</li>
<li><strong>Resolution</strong>: All images resized to $256 \times 256$ pixels.</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Framework</strong>: DDPM implemented using <code>guided-diffusion</code>.</li>
<li><strong>RePaint Settings</strong>:
<ul>
<li>Total time steps: 250.</li>
<li>Jump length: 10.</li>
<li>Resampling counts: 10.</li>
</ul>
</li>
<li><strong>Masking Strategy</strong>:
<ul>
<li><strong>Vertical Stripes</strong>: Obscure atom symbols to vary handwriting style.</li>
<li><strong>Horizontal Stripes</strong>: Obscure bonds to vary length/thickness/alignment.</li>
<li><strong>Optimal Thickness</strong>: A stripe thickness of <strong>4 pixels</strong> was found to be optimal for balancing diversity and structural preservation.</li>
</ul>
</li>
</ul>
<h3 id="models">Models</h3>
<p>The OCSR models were pretrained on PubChem (digital images) and then fine-tuned on the OCSAug dataset.</p>
<ul>
<li><strong>MolScribe</strong>: Swin Transformer encoder, Transformer decoder. Fine-tuned (all layers) for 30 epochs, batch size 16-128, LR 2e-5.</li>
<li><strong>I2S (DECIMER 1.0)</strong>: Inception V3 encoder (frozen), FC/Decoder fine-tuned. 25 epochs, batch size 64, LR 1e-5.</li>
<li><strong>MolNexTR</strong>: Dual-stream encoder (Swin + CNN). Fine-tuned (all layers) for 30 epochs, batch size 16-64, LR 2e-5.</li>
<li><strong>MPOCSR</strong>: MPViT backbone. Fine-tuned (all layers) for 25 epochs, batch size 16-32, LR 4e-5.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li>
<p><strong>Metric</strong>: Improvement Ratio (IR) of Tanimoto Similarity (TS), defined as the ratio of the fine-tuned model&rsquo;s TS to that of the same model without fine-tuning:</p>
<p>$$
\text{IR} = \frac{\text{TS}_{\text{finetuned}}}{\text{TS}_{\text{non-finetuned}}}
$$</p>
</li>
<li>
<p><strong>Validation</strong>: Cross-validation on the split DECIMER dataset.</p>
</li>
</ul>
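<p>As a worked illustration of the Improvement Ratio with hypothetical scores (not values reported in the paper): a model whose Tanimoto similarity on hand-drawn images rises from 0.20 without fine-tuning to 0.60 after fine-tuning has</p>
<p>$$
\text{IR} = \frac{0.60}{0.20} = 3.0
$$</p>
<p>An IR above 1 indicates that fine-tuning helped. Because the non-fine-tuned score is the denominator, a model that starts from a weak baseline can show a large IR even when its final similarity is modest, so the ratio is best read alongside the absolute scores.</p>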
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>GPU</strong>: NVIDIA GeForce RTX 4090.</li>
<li><strong>Training Time</strong>: DDPM training took ~6 days.</li>
<li><strong>Generation Time</strong>: Generating 2,600 augmented images took ~70 hours.</li>
</ul>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Kim, J. H., &amp; Choi, J. (2025). OCSAug: diffusion-based optical chemical structure data augmentation for improved hand-drawn chemical structure image recognition. <em>The Journal of Supercomputing</em>, 81, 926.</p>
<p><strong>Publication</strong>: The Journal of Supercomputing 2025</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/jjjabcd/OCSAug">Official Repository</a></li>
<li><a href="https://zenodo.org/records/6456306">DECIMER Dataset</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{kimOCSAugDiffusionbasedOptical2025,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{OCSAug: Diffusion-Based Optical Chemical Structure Data Augmentation for Improved Hand-Drawn Chemical Structure Image Recognition}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">shorttitle</span> = <span style="color:#e6db74">{OCSAug}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Kim, Jin Hyuk and Choi, Jonghwan}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2025</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = may,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{The Journal of Supercomputing}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{81}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{8}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{926}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1007/s11227-025-07406-4}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>MERMaid: Multimodal Chemical Reaction Mining from PDFs</title><link>https://hunterheidenreich.com/notes/chemistry/llm-applications/mermaid/</link><pubDate>Sat, 20 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/llm-applications/mermaid/</guid><description>Vision-language pipeline extracting chemical reaction data from PDF figures and tables into structured knowledge graphs with 87% accuracy.</description><content:encoded><![CDATA[<h2 id="methodological-and-resource-contributions">Methodological and Resource Contributions</h2>
<p>This is primarily a <strong>Methodological</strong> paper ($\Psi_{\text{Method}}$) that introduces a novel pipeline (MERMaid) for extracting structured chemical data from unstructured PDF documents. It proposes a specific architecture combining fine-tuned vision models (VisualHeist) with vision-language models (DataRaider) and a retrieval-augmented generation system (KGWizard) to solve the problem of multimodal data ingestion.</p>
<p>Secondarily, it is a <strong>Resource</strong> paper ($\Psi_{\text{Resource}}$) as it releases the source code, prompts, and a new benchmark dataset (<strong>MERMaid-100</strong>) consisting of annotated reaction data across three chemical domains.</p>
<h2 id="the-inaccessibility-of-diagrammatic-reaction-data">The Inaccessibility of Diagrammatic Reaction Data</h2>
<ul>
<li><strong>Data Inaccessibility</strong>: A significant volume of chemical knowledge currently resides in &ldquo;print-optimized&rdquo; PDF formats, specifically within graphical elements like figures, schemes, and tables, which resist standard text mining.</li>
<li><strong>Limitations of Prior Work</strong>: Existing tools (e.g., ChemDataExtractor, <a href="/notes/chemistry/optical-structure-recognition/image-to-graph/molmole/">OpenChemIE</a>) focus primarily on text, struggle with multimodal parsing, or lack the &ldquo;contextual awareness&rdquo; needed to interpret implicit information (e.g., &ldquo;standard conditions&rdquo; with modifications in optimization tables).</li>
<li><strong>Need for Structured Data</strong>: To enable <a href="/notes/chemistry/llm-applications/autonomous-chemical-research-coscientist/">self-driving laboratories</a> and data-driven discovery, this unstructured literature must be converted into machine-actionable formats like <a href="https://en.wikipedia.org/wiki/Knowledge_graph">knowledge graphs</a>.</li>
</ul>
<h2 id="the-mermaid-pipeline-vision-models-and-llm-rag">The MERMaid Pipeline: Vision Models and LLM RAG</h2>
<ul>
<li><strong>VisualHeist (Fine-tuned Segmentation)</strong>: A custom fine-tuned model based on Microsoft&rsquo;s Florence-2 that accurately segments figures, captions, and footnotes, even in messy supplementary materials.</li>
<li><strong>DataRaider (Context-Aware Extraction)</strong>: A VLM-powered module (using GPT-4o) with a <strong>two-step prompt framework</strong> that performs &ldquo;self-directed context completion.&rdquo; It can infer missing reaction parameters from context and resolve footnote labels (e.g., linking &ldquo;condition a&rdquo; in a table to its footnote description).</li>
<li><strong>KGWizard (Schema-Adaptive Graph Construction)</strong>: A text-to-graph engine that uses LLMs as higher-order functions to synthesize parsers dynamically. It employs <strong>Retrieval-Augmented Generation (RAG)</strong> to check for existing nodes during creation, implicitly resolving coreferences (e.g., unifying &ldquo;MeCN&rdquo; and &ldquo;Acetonitrile&rdquo;).</li>
<li><strong>Topic-Agnostic Design</strong>: MERMaid features a flexible design that works across three distinct domains: <a href="https://en.wikipedia.org/wiki/Electrosynthesis">organic electrosynthesis</a>, <a href="https://en.wikipedia.org/wiki/Photocatalysis">photocatalysis</a>, and organic synthesis.</li>
</ul>
<h2 id="benchmarking-segmentation-and-extraction-accuracy">Benchmarking Segmentation and Extraction Accuracy</h2>
<ul>
<li><strong>Segmentation Benchmarking</strong>: The authors compared VisualHeist against OpenChemIE (LayoutParser) and PDFigCapX using a dataset of 121 PDFs from 5 publishers.</li>
<li><strong>End-to-End Extraction</strong>: The full pipeline was evaluated on <strong>MERMaid-100</strong>, a curated dataset of 100 articles across three domains (organic electrosynthesis, photocatalysis, organic synthesis).
<ul>
<li>Extraction of specific parameters (e.g., catalysts, solvents, yields) was validated using &ldquo;hard-match&rdquo; accuracy.</li>
</ul>
</li>
<li><strong>Knowledge Graph Construction</strong>: Knowledge graphs were generated automatically for the three domains and assessed for structural integrity and <a href="https://en.wikipedia.org/wiki/Coreference">coreference resolution</a> accuracy.</li>
</ul>
<h2 id="end-to-end-extraction-performance">End-to-End Extraction Performance</h2>
<ul>
<li><strong>Segmentation Results</strong>: VisualHeist achieved &gt;93% F1 score across all document types (including pre-2000 papers and supplementary materials), outperforming OpenChemIE by 15-75% and PDFigCapX by 28-75% across all metrics.</li>
<li><strong>Extraction Accuracy</strong>: DataRaider achieved &gt;92% accuracy for VLM-based parameter extraction and near-unity accuracy for domain-specific reaction parameters (e.g., anode, cathode, photocatalyst).</li>
<li><strong>Graph Building</strong>: KGWizard achieved 96% accuracy in node creation and coreference resolution.</li>
<li><strong>Overall Performance</strong>: The full pipeline achieved 87% end-to-end accuracy.</li>
<li><strong>Limitations</strong>: The architecture relies heavily on closed-weight models (GPT-4o) for reasoning and graph construction, which risks future reproducibility if API snapshots are deprecated. Additionally, the system remains vulnerable to cumulative error propagation from upstream OCR/OCSR tools like <a href="/notes/chemistry/optical-structure-recognition/benchmarks/ocsr-methods/">RxnScribe</a>.</li>
<li><strong>Availability</strong>: The authors provide a modular, extensible framework that can be adapted to other scientific domains.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<ul>
<li><strong>Training Data (VisualHeist)</strong>:
<ul>
<li>Dataset of <strong>3,435 figures</strong> and <strong>1,716 tables</strong> annotated from 3,518 PDF pages.</li>
<li>Includes main text, supplementary materials, and unformatted archive papers.</li>
</ul>
</li>
<li><strong>Evaluation Data (MERMaid-100)</strong>:
<ul>
<li><strong>100 PDF articles</strong> curated from three domains: organic electrosynthesis, photocatalysis, and organic synthesis.</li>
<li>Includes 104 image-caption/table-heading pairs relevant to reaction optimization.</li>
<li>Available for download at Zenodo (DOI: 10.5281/zenodo.14917752).</li>
</ul>
</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Two-Step Prompt Framework (DataRaider)</strong>:
<ul>
<li><em>Step 1</em>: Generic base prompt + domain keys to extract &ldquo;reaction dictionaries&rdquo; and &ldquo;footnote dictionaries&rdquo;. Uses &ldquo;fill-in-the-blank&rdquo; inference for missing details.</li>
<li><em>Step 2</em>: Safety check prompt where the VLM updates the reaction dictionary using the footnote dictionary to resolve entry-specific modifications.</li>
</ul>
</li>
<li><strong>LLM-Synthesized Parsers (KGWizard)</strong>:
<ul>
<li>Uses an LLM as a higher-order function $g_{A,B}: A \times B \rightarrow (X \rightarrow Y)$: given input schema instructions, it returns a parser (generated Python code) that is then applied to the extracted data.</li>
</ul>
</li>
<li><strong>RAG for Coreference</strong>:
<ul>
<li>During graph construction, the system queries the existing database for matching values (e.g., &ldquo;MeCN&rdquo;) before creating new nodes to prevent duplication.</li>
</ul>
</li>
<li><strong>Batching</strong>:
<ul>
<li>Articles processed in dynamic batch sizes (starting at 1, increasing to 30) to balance speed and redundancy checks.</li>
</ul>
</li>
</ul>
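<p>The lookup-before-create idea behind the coreference step can be sketched in a few lines. This is a toy illustration only: the <code>ALIASES</code> table and the <code>NodeStore</code> class are hypothetical stand-ins for KGWizard&rsquo;s retrieval step and graph database, which in the real pipeline rely on an LLM with RAG and are not reproduced here.</p>

```python
from dataclasses import dataclass, field

# Hypothetical alias table. It stands in for the retrieval + LLM matching
# step; the real system does not use a hard-coded dictionary.
ALIASES = {"mecn": "acetonitrile", "acn": "acetonitrile"}


def canonical(name: str) -> str:
    """Normalize a surface form to a canonical key (assumed scheme)."""
    key = name.strip().lower()
    return ALIASES.get(key, key)


@dataclass
class NodeStore:
    """Minimal in-memory stand-in for a graph database of solvent nodes."""
    nodes: dict = field(default_factory=dict)  # canonical name -> node id

    def get_or_create(self, name: str) -> int:
        # Query existing nodes first; only create a node when nothing matches.
        key = canonical(name)
        if key not in self.nodes:
            self.nodes[key] = len(self.nodes)
        return self.nodes[key]


store = NodeStore()
ids = [store.get_or_create(n) for n in ["MeCN", "Acetonitrile", "DMF"]]
print(ids)  # "MeCN" and "Acetonitrile" share one node id
```

<p>Checking for an existing node before every insert is what keeps synonyms from fragmenting the graph; the quality of the match function decides how many duplicates survive.</p>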
<h3 id="models">Models</h3>
<ul>
<li><strong>VisualHeist</strong>: Fine-tuned <strong>Florence-2-large</strong> (Microsoft vision foundation model).
<ul>
<li><em>Hyperparameters</em>: 12 epochs, learning rate $5 \times 10^{-6}$, batch size 4.</li>
</ul>
</li>
<li><strong>DataRaider &amp; KGWizard</strong>: <strong>GPT-4o</strong> (version <code>gpt-4o-2024-08-06</code>). Note: Requires an active OpenAI API key. The pipeline&rsquo;s long-term reproducibility is currently tied to the continued availability of this specific closed-source endpoint.</li>
<li><strong>RxnScribe</strong>: Used for <a href="/notes/chemistry/optical-structure-recognition/benchmarks/ocsr-methods/">Optical Chemical Structure Recognition (OCSR)</a> to convert reactant/product images to <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a>.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Metrics</strong>:
<ul>
<li><em>Segmentation</em>: Precision, Recall, F1, Accuracy.</li>
<li><em>Caption Extraction</em>: Evaluated via <a href="https://en.wikipedia.org/wiki/Jaccard_index">Jaccard similarity</a> between the predicted token set $A$ and the true token set $B$; a caption counts as correct when the similarity meets the threshold: $$J(A, B) = \frac{|A \cap B|}{|A \cup B|} \ge 0.70$$</li>
<li><em>Data Extraction</em>: Evaluated via Hard-Match accuracy (HMA), requiring exact correspondence between predicted ($\hat{y}_i$) and ground-truth ($y_i$) parameters for specific roles (e.g., anode vs. cathode), averaged over the $N$ evaluated parameters: $$\text{HMA} = \frac{1}{N} \sum_{i=1}^{N} \mathbb{1}[y_i = \hat{y}_i]$$</li>
</ul>
</li>
<li><strong>Baselines</strong>: OpenChemIE (LayoutParser + EasyOCR) and PDFigCapX.</li>
</ul>
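<p>The two extraction metrics above are simple enough to sketch directly. The snippet below is a minimal illustration, assuming lowercase whitespace tokenization for captions (the paper&rsquo;s tokenizer is not specified in these notes); the function names and example strings are made up for the demonstration.</p>

```python
def jaccard(pred_tokens, true_tokens):
    """Jaccard similarity between two token collections, treated as sets."""
    a, b = set(pred_tokens), set(true_tokens)
    union = a.union(b)
    if not union:
        return 1.0  # two empty captions are treated as identical
    return len(a.intersection(b)) / len(union)


def caption_is_correct(pred, true, threshold=0.70):
    """Apply the 0.70 threshold; tokenization here is an assumption."""
    return jaccard(pred.lower().split(), true.lower().split()) >= threshold


def hard_match_accuracy(y_true, y_pred):
    """Fraction of parameters whose predicted value matches exactly."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    hits = sum(1 for t, p in zip(y_true, y_pred) if t == p)
    return hits / len(y_true)


# Illustrative inputs only; not taken from the benchmark.
pred_caption = "Table 1. Optimization of reaction conditions"
true_caption = "Table 1. Optimization of the reaction conditions"
score = jaccard(pred_caption.lower().split(), true_caption.lower().split())
ok = caption_is_correct(pred_caption, true_caption)
hma = hard_match_accuracy(["Pt", "MeCN", "82"], ["Pt", "DMF", "82"])
print(score, ok, hma)
```

<p>Hard-match scoring gives no partial credit, so a chemically equivalent but differently written value (for example a synonym) counts as a miss unless it is normalized first.</p>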
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Training (VisualHeist)</strong>: 2x NVLink-connected NVIDIA RTX A6000 GPUs (48GB VRAM) + Intel Xeon w7-2495X CPU (48 cores).</li>
<li><strong>DataRaider Evaluation</strong>: 13th Gen Intel Core i7-1360P CPU (12 cores).</li>
<li><strong>Inference Costs</strong>:
<ul>
<li>DataRaider: ~$0.051 per image.</li>
<li>KGWizard: ~$0.40 per JSON.</li>
</ul>
</li>
<li><strong>Timing</strong>:
<ul>
<li>VisualHeist inference: ~4.5 seconds/image.</li>
<li>DataRaider inference: ~41.3 seconds/image.</li>
<li>KGWizard processing: ~110.6 seconds/file.</li>
</ul>
</li>
</ul>
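<p>For budgeting, the per-unit figures above can be combined into a rough estimate. The sketch below assumes strictly serial processing and, as a labeled assumption, one KGWizard JSON per extracted image; the workload numbers are hypothetical and VisualHeist is left out because no per-image cost is reported for it.</p>

```python
# Per-unit figures quoted in the notes above.
DATARAIDER_USD_PER_IMAGE = 0.051
DATARAIDER_SEC_PER_IMAGE = 41.3
KGWIZARD_USD_PER_JSON = 0.40
KGWIZARD_SEC_PER_FILE = 110.6


def estimate(n_images, n_json_files):
    """Return (total USD, total hours), assuming strictly serial processing."""
    usd = (n_images * DATARAIDER_USD_PER_IMAGE
           + n_json_files * KGWIZARD_USD_PER_JSON)
    seconds = (n_images * DATARAIDER_SEC_PER_IMAGE
               + n_json_files * KGWIZARD_SEC_PER_FILE)
    return usd, seconds / 3600.0


# Hypothetical workload: 104 images (the MERMaid-100 pair count) and,
# as an assumption, one KGWizard JSON per image.
usd, hours = estimate(n_images=104, n_json_files=104)
print(f"{usd:.2f} USD, {hours:.2f} h")
```

<p>Under these assumptions graph construction, not vision-language extraction, dominates both cost and wall-clock time.</p>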
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Leong, S. X., Pablo-García, S., Wong, B., &amp; Aspuru-Guzik, A. (2025). MERMaid: Universal multimodal mining of chemical reactions from PDFs using vision-language models. <em>Matter</em>, 8(12), 102331. <a href="https://doi.org/10.1016/j.matt.2025.102331">https://doi.org/10.1016/j.matt.2025.102331</a></p>
<p><strong>Publication</strong>: Matter, 2025</p>
<p><strong>Artifacts</strong>:</p>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/aspuru-guzik-group/MERMaid">GitHub Repository</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official implementation (VisualHeist, DataRaider, KGWizard)</td>
      </tr>
      <tr>
          <td><a href="https://doi.org/10.5281/zenodo.14917752">Zenodo Data/Prompts</a></td>
          <td>Dataset</td>
          <td>Unknown</td>
          <td>MERMaid-100 benchmark, prompts, and raw VLM responses</td>
      </tr>
  </tbody>
</table>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{leong2025mermaid,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{MERMaid: Universal multimodal mining of chemical reactions from PDFs using vision-language models}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Leong, Shi Xuan and Pablo-Garc{\&#39;i}a, Sergio and Wong, Brandon and Aspuru-Guzik, Al{\&#39;a}n}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Matter}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{8}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{12}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{102331}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2025}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1016/j.matt.2025.102331}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>InvMSAFold: Generative Inverse Folding with Potts Models</title><link>https://hunterheidenreich.com/notes/biology/computational-biology/invmsafold/</link><pubDate>Sat, 20 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/biology/computational-biology/invmsafold/</guid><description>InvMSAFold generates diverse protein sequences from structure by predicting Potts model parameters, enabling orders-of-magnitude faster sampling.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is a <strong>Methodological ($\Psi_{\text{Method}}$)</strong> paper. It introduces a novel architecture, <strong>InvMSAFold</strong>, which hybridizes deep learning encoders with statistical physics-based decoders (Potts models). The rhetorical structure focuses on architectural innovation (low-rank parameter generation), benchmarking of speed and diversity against baselines (ESM-IF1), and algorithmic efficiency.</p>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>Standard inverse folding models (like ESM-IF1 or ProteinMPNN) solve a &ldquo;one-to-one&rdquo; mapping: given a structure, predict the <em>single</em> native sequence. However, in nature, folding is &ldquo;many-to-one&rdquo;: many homologous sequences fold into the same structure.</p>
<p>The authors identify two key gaps:</p>
<ol>
<li><strong>Lack of Diversity</strong>: Standard autoregressive models maximize probability for the ground truth sequence, often failing to capture the broad evolutionary landscape of viable homologs.</li>
<li><strong>Slow Inference</strong>: Autoregressive sampling requires a full neural network pass for <em>every amino acid</em>, making high-throughput screening (e.g., millions of candidates) computationally prohibitive.</li>
</ol>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The core novelty is shifting the learning objective from predicting <em>sequences</em> to predicting <em>probability distributions</em>.</p>
<p>InvMSAFold outputs the parameters (couplings $\mathbf{J}$ and fields $\mathbf{h}$) of a <strong>Potts Model</strong> (a pairwise Markov Random Field).</p>
<ul>
<li><strong>Low-Rank Decomposition</strong>: To handle the massive parameter space of pairwise couplings ($L \times L \times q \times q$), the model predicts a low-rank approximation $\mathbf{V}$ ($L \times K \times q$), reducing complexity from $\mathcal{O}(L^2)$ to $\mathcal{O}(L)$.</li>
<li><strong>One-Shot Generation</strong>: The deep network runs only <em>once</em> to generate the Potts parameters. Sampling sequences from this Potts model is then performed on CPU via MCMC (for the PW variant) or direct autoregressive sampling (for the AR variant), which is orders of magnitude faster than running a Transformer decoder for every step.</li>
</ul>
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The authors validated the model on three CATH-based test sets (Inter-cluster, Intra-cluster, MSA) to test generalization at varying levels of homology.</p>
<ul>
<li><strong>Speed Benchmarking</strong>: Compared wall-clock sampling time vs. ESM-IF1 on CPU/GPU.</li>
<li><strong>Covariance Reconstruction</strong>: Checked if generated sequences recover the evolutionary correlations found in natural MSAs (Pearson correlation of covariance matrices).</li>
<li><strong>Structural Fidelity</strong>: Generated sequences with high Hamming distance from native, folded them with AlphaFold 2 (no templates), and measured RMSD to the target structure.</li>
<li><strong>Property Profiling</strong>: Analyzed the distribution of predicted solubility (Protein-Sol) and thermostability (Thermoprot) to show that sequence diversity translates into a wider range of biochemical properties.</li>
</ul>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>
<ul>
<li><strong>Massive Speedup</strong>: InvMSAFold is orders of magnitude faster than ESM-IF1 (CPU vs. GPU; the comparison is not hardware-matched). Because the &ldquo;heavy lifting&rdquo; (generating Potts parameters) happens once, sampling millions of sequences becomes trivial on CPUs.</li>
<li><strong>Better Diversity</strong>: The model captures evolutionary covariances significantly better than ESM-IF1 and ProteinMPNN (whose covariance recovery is similar to that of ESM-IF1). A PCA-based KL-divergence analysis (lower is better; 0 means a perfect match to the natural MSA distribution) gives InvMSAFold-AR scores of $0.49$ (Inter-cluster) and $0.67$ (Intra-cluster), compared to $15.8$ and $11.9$ for ESM-IF1, indicating that the generated sequences occupy a distribution much closer to natural MSAs.</li>
<li><strong>Robust Folding</strong>: Sequences generated far from the native sequence (high Hamming distance) still fold into the correct structure (low RMSD), whereas ESM-IF1 struggles to produce diverse valid sequences.</li>
<li><strong>Property Expansion</strong>: The method generates a wider spread of predicted biochemical properties (solubility/thermostability), which could be useful for virtual screening in protein design.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p><strong>Source</strong>: CATH database (40% non-redundant dataset).</p>
<p><strong>Splits</strong>:</p>
<ul>
<li><strong>Training</strong>: ~22k domains.</li>
<li><strong>Inter-cluster Test</strong>: 10% of sequence clusters held out (unseen clusters, many with superfamilies absent from training).</li>
<li><strong>Intra-cluster Test</strong>: Unseen domains from seen clusters.</li>
<li><strong>Augmentation</strong>: MSAs generated using <strong>MMseqs2</strong> against the Uniprot50 database. Training uses random subsamples of these MSAs ($|M_X| = 64$ for PW, $|M_X| = 32$ for AR) to teach the model evolutionary variance.</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<p><strong>Architecture</strong>:</p>
<ul>
<li><strong>Encoder</strong>: Pre-trained <strong>ESM-IF1</strong> encoder (GVP-GNN architecture). The encoder is used to pre-compute structure embeddings, with independent Gaussian noise (std = 5% of the embedding std) added during training.</li>
<li><strong>Decoder</strong>: 6-layer Transformer (8 heads) that outputs a latent tensor.</li>
<li><strong>Projection</strong>: Linear layers project latent tensor to fields $\mathbf{h}$ ($L \times q$) and low-rank tensor $\mathbf{V}$ ($L \times K \times q$).</li>
</ul>
<p><strong>Coupling Construction</strong>:
The full coupling tensor $\mathbf{J}$ is approximated from the low-rank tensor $\mathbf{V}$ via:
$$J_{i,a,j,b} = \frac{1}{\sqrt{K}} \sum_{k=1}^{K} V_{i,k,a} V_{j,k,b}$$
Rank $K=48$ was used.</p>
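<p>The low-rank construction can be checked on a toy example. The sketch below evaluates a single coupling entry from a randomly filled low-rank tensor and compares parameter counts; the tiny dimensions, the random values, and the alphabet size $q=21$ used for the count are illustrative assumptions, not values taken from the paper.</p>

```python
import math
import random


def low_rank_coupling(V, i, a, j, b):
    """Coupling between state a at site i and state b at site j.

    V is indexed as V[site][k][state], matching the L x K x q layout.
    """
    K = len(V[i])
    total = sum(V[i][k][a] * V[j][k][b] for k in range(K))
    return total / math.sqrt(K)


def param_counts(L, q, K):
    """Entries in the full coupling tensor vs. the low-rank tensor."""
    return L * L * q * q, L * K * q


rng = random.Random(0)
L, q, K = 5, 4, 3  # toy sizes for the demonstration
V = [[[rng.gauss(0, 1) for _ in range(q)] for _ in range(K)] for _ in range(L)]

j_forward = low_rank_coupling(V, 0, 1, 2, 3)
j_reverse = low_rank_coupling(V, 2, 3, 0, 1)  # same entry by symmetry

# q=21 (20 amino acids plus a gap symbol) is an assumption for illustration.
full_count, low_rank_count = param_counts(L=100, q=21, K=48)
print(j_forward, j_reverse, full_count, low_rank_count)
```

<p>The full tensor grows quadratically with sequence length while the low-rank tensor grows linearly, which is the reduction described above; the symmetry of the couplings comes for free from the product form.</p>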
<p><strong>Loss Functions</strong>:
Two variants were trained:</p>
<ol>
<li><strong>InvMSAFold-PW</strong>: Trained via <strong>Pseudo-Likelihood (PL)</strong>. Computation is optimized to $\mathcal{O}(L)$ time using the low-rank property.</li>
<li><strong>InvMSAFold-AR</strong>: Trained via <strong>Autoregressive Likelihood</strong>. Couplings are masked ($J_{ij} = 0$ if $i &lt; j$) to allow exact likelihood computation and direct sampling without MCMC.</li>
</ol>
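<p>To see why masked couplings allow direct sampling, here is a minimal sketch of ancestral sampling from an autoregressive Potts model with low-rank couplings. It illustrates the general technique under stated assumptions, not the authors&rsquo; implementation: the function name, the running <code>context</code> accumulator, and the random toy parameters are all hypothetical.</p>

```python
import math
import random


def sample_autoregressive(h, V, rng):
    """Sample one sequence site by site.

    h[i][a] is the field for state a at site i; V[i][k][a] is the low-rank
    tensor. Site i only sees earlier sites, so each conditional is a plain
    softmax over q states and no MCMC is needed.
    """
    L, q, K = len(h), len(h[0]), len(V[0])
    context = [0.0] * K  # running sum over earlier sites of V[j][k][s_j]
    seq = []
    for i in range(L):
        logits = [
            h[i][a] + sum(V[i][k][a] * context[k] for k in range(K)) / math.sqrt(K)
            for a in range(q)
        ]
        top = max(logits)  # subtract the max for numerical stability
        weights = [math.exp(x - top) for x in logits]
        s = rng.choices(range(q), weights=weights, k=1)[0]
        seq.append(s)
        for k in range(K):
            context[k] += V[i][k][s]
    return seq


# Toy parameters; a trained model would predict h and V from the structure.
param_rng = random.Random(0)
L, q, K = 8, 4, 3
h = [[param_rng.gauss(0, 1) for _ in range(q)] for _ in range(L)]
V = [[[param_rng.gauss(0, 1) for _ in range(q)] for _ in range(K)] for _ in range(L)]

seq_a = sample_autoregressive(h, V, random.Random(42))
seq_b = sample_autoregressive(h, V, random.Random(42))
print(seq_a)
```

<p>Keeping a running sum over earlier sites means each new position costs on the order of $K \cdot q$ operations, so drawing a full sequence scales linearly in $L$ and needs no neural network pass once the parameters are available.</p>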
<h3 id="models">Models</h3>
<ul>
<li><strong>InvMSAFold-PW</strong>: Requires MCMC sampling (Metropolis-Hastings) at inference.</li>
<li><strong>InvMSAFold-AR</strong>: Allows direct, fast autoregressive sampling.</li>
<li><strong>Hyperparameters</strong>: AdamW optimizer, lr=$10^{-4}$ (PW) / $3.4 \times 10^{-4}$ (AR), 94 epochs. L2 regularization: $\lambda_h = \lambda_J = 10^{-4}$ (PW); $\lambda_J = 3.2 \times 10^{-6}$, $\lambda_h = 5.0 \times 10^{-5}$ (AR).</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p><strong>Metrics</strong>:</p>
<ul>
<li><strong>RMSD</strong>: Structure fidelity (AlphaFold2 prediction vs. native structure).</li>
<li><strong>Covariance Pearson Correlation</strong>: Measures recovery of evolutionary pairwise statistics.</li>
<li><strong>KL Divergence</strong>: Between PCA-projected densities of natural and synthetic sequences (Gaussian KDE, kernel size 1.0).</li>
<li><strong>Sampling Speed</strong>: Wall-clock time vs. sequence length/batch size.</li>
</ul>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Training</strong>: Not specified in the paper. The GitHub repository reports testing on an NVIDIA RTX 3090, with training taking 10-24 hours depending on model variant.</li>
<li><strong>Inference</strong>:
<ul>
<li><strong>ESM-IF1</strong>: NVIDIA GeForce RTX 4060 Laptop (8GB).</li>
<li><strong>InvMSAFold</strong>: Single core of Intel i9-13905H CPU.</li>
</ul>
</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/luchinoprince/Potts_Inverse_Folding">Potts_Inverse_Folding</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Training and inference code (PyTorch)</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Silva, L. A., Meynard-Piganeau, B., Lucibello, C., &amp; Feinauer, C. (2025). Fast Uncovering of Protein Sequence Diversity from Structure. <em>International Conference on Learning Representations (ICLR)</em>. <a href="https://arxiv.org/abs/2406.11975">https://arxiv.org/abs/2406.11975</a></p>
<p><strong>Publication</strong>: ICLR 2025 (Spotlight)</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{silvaFastUncoveringProtein2025,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Fast Uncovering of Protein Sequence Diversity from Structure}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Silva, Luca Alessandro and {Meynard-Piganeau}, Barthelemy and Lucibello, Carlo and Feinauer, Christoph}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Journal of Statistical Mechanics: Theory and Experiment}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{2025}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{8}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{084003}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2025}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1088/1742-5468/adf0e7}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span> = <span style="color:#e6db74">{https://openreview.net/forum?id=1iuaxjssVp}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://openreview.net/forum?id=1iuaxjssVp">OpenReview Page</a></li>
<li><a href="https://github.com/luchinoprince/Potts_Inverse_Folding">GitHub Repository</a></li>
</ul>
]]></content:encoded></item><item><title>InstructMol: Multi-Modal Molecular LLM for Drug Discovery</title><link>https://hunterheidenreich.com/notes/chemistry/llm-applications/instructmol/</link><pubDate>Sat, 20 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/llm-applications/instructmol/</guid><description>A multi-modal LLM aligning 2D molecular graphs with text via two-stage instruction tuning for drug discovery tasks.</description><content:encoded><![CDATA[<h2 id="instructmol-framework-overview">InstructMol Framework Overview</h2>
<p><strong>Methodological Paper ($\Psi_{\text{Method}}$)</strong></p>
<p>This work proposes <strong>InstructMol</strong>, a novel multi-modal architecture and training paradigm. It focuses on engineering a system that aligns a pre-trained molecular graph encoder with a general-purpose Large Language Model (LLM). The paper&rsquo;s primary contribution is the <strong>Two-Stage Instruction Tuning</strong> strategy (Alignment Pre-training + Task-Specific Tuning) designed to bridge the modality gap between 2D molecular graphs and natural language.</p>
<h2 id="bridging-specialist-and-generalist-models">Bridging Specialist and Generalist Models</h2>
<p>Current AI approaches in drug discovery typically fall into two categories. Specialist models deliver high accuracy on specific tasks (such as property prediction) but require extensive labeled datasets and lack conversational adaptability. Conversely, generalist LLMs offer strong reasoning and dialogue capabilities but struggle to natively interpret complex structural data, often relying on brittle 1D text representations of molecules like <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a>.</p>
<p>There is a practical need for a unified &ldquo;Molecular Assistant&rdquo; capable of interpreting molecular graphs directly, reasoning about structure in natural language, and adapting across tasks like synthesis planning and property analysis without training from scratch.</p>
<h2 id="two-stage-modality-alignment">Two-Stage Modality Alignment</h2>
<p>The core novelty lies in the architecture and the <strong>two-stage training pipeline</strong> designed to align differing modalities efficiently:</p>
<ol>
<li><strong>MoleculeSTM Integration</strong>: InstructMol initializes its graph encoder with <strong>MoleculeSTM</strong>, which is already pre-aligned with text via contrastive learning, facilitating easier downstream alignment.</li>
<li><strong>Two-Stage Alignment Strategy</strong>:
<ul>
<li><strong>Stage 1 (Alignment Pre-training)</strong>: Freezes both the LLM and Graph Encoder; trains <em>only</em> a linear projector on ~264K PubChem molecule-description pairs to map graph features into the LLM&rsquo;s token space.</li>
<li><strong>Stage 2 (Task-Specific Instruction Tuning)</strong>: Freezes the Graph Encoder; fine-tunes the Projector and the LLM (using <strong>LoRA</strong>) on specific downstream tasks. This allows the model to adapt its reasoning capabilities while preserving the structural understanding gained in Stage 1.</li>
</ul>
</li>
</ol>
<h2 id="task-evaluation-in-drug-discovery">Task Evaluation in Drug Discovery</h2>
<p>The authors evaluated InstructMol across three distinct categories of drug discovery tasks, comparing it against generalist LLMs (Vicuna, LLaMA, <a href="/notes/chemistry/llm-applications/galactica-large-language-model-for-science/">Galactica</a>) and specialist models (<a href="/notes/chemistry/molecular-representations/encoders/chemberta/">ChemBERTa</a>, MolT5):</p>
<ol>
<li><strong>Property Prediction</strong>:
<ul>
<li><em>Regression</em>: Predicting quantum mechanical properties (HOMO, LUMO, Gap) using the <a href="/notes/chemistry/datasets/qm9/">QM9</a> dataset.</li>
<li><em>Classification</em>: Predicting biological activity (BACE, BBBP, HIV) using <a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a>.</li>
</ul>
</li>
<li><strong>Molecule Description Generation</strong>: Generating natural language descriptions of molecules using the ChEBI-20 dataset.</li>
<li><strong>Chemical Reaction Analysis</strong>:
<ul>
<li><em>Forward Reaction Prediction</em>: Predicting products from reactants.</li>
<li><em>Reagent Prediction</em>: Identifying necessary reagents.</li>
<li><em><a href="https://en.wikipedia.org/wiki/Retrosynthetic_analysis">Retrosynthesis</a></em>: Suggesting reactants for a given product.</li>
</ul>
</li>
</ol>
<p><strong>Ablation Studies</strong> tested the impact of the projector type (Linear vs. MLP), LLM scale (7B vs 13B), and the necessity of the two-stage training approach.</p>
<h2 id="core-findings-and-limitations">Core Findings and Limitations</h2>
<ul>
<li><strong>Improvement Over Baseline Generalists</strong>: InstructMol significantly outperformed generalist LLMs (like LLaMA and Galactica) on all tasks, demonstrating the value of incorporating explicit graph modalities.</li>
<li><strong>Reducing the Gap with Specialists</strong>: While InstructMol brings versatile reasoning capabilities, it still trails highly optimized specialist models (such as Uni-Mol and MolT5) on tasks like molecule description generation. This remaining gap likely stems from its reliance on a relatively small alignment pre-training dataset (~264K PubChem pairs) and the information bottleneck of using a simple linear projector, compared to the millions of structures used to train expert foundational models.</li>
<li><strong>Importance of Alignment</strong>: Ablation studies showed that skipping Stage 1 (Alignment Pre-training) degraded performance, indicating that a dedicated phase for projecting graph features into text space is important.</li>
<li><strong>Limitation</strong>: The model struggles with highly imbalanced datasets (e.g., HIV) and complex reaction mixtures where mapping multiple graph tokens to text becomes ambiguous.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The training pipeline uses distinct datasets for the two stages. <strong>Note:</strong> As of the latest repository update, the processed instruction-tuning datasets (e.g., the filtered ~264K PubChem pairs and the instruction-formatted task subsets) are listed as &ldquo;coming soon&rdquo;, so full reproduction requires recreating them manually.</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Purpose</th>
          <th style="text-align: left">Dataset</th>
          <th style="text-align: left">Size</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>Stage 1</strong> (Alignment)</td>
          <td style="text-align: left"><strong><a href="https://en.wikipedia.org/wiki/PubChem">PubChem</a></strong></td>
          <td style="text-align: left">~264K pairs</td>
          <td style="text-align: left">Molecule-text pairs. Filtered from 330K for invalid descriptions and overlaps with ChEBI-20 test set.</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Stage 2</strong> (Prop. Reg.)</td>
          <td style="text-align: left"><strong>QM9</strong></td>
          <td style="text-align: left">362K samples</td>
          <td style="text-align: left">Quantum mechanics properties (HOMO, LUMO, Gap).</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Stage 2</strong> (Prop. Class.)</td>
          <td style="text-align: left"><strong>MoleculeNet</strong></td>
          <td style="text-align: left">35K samples</td>
          <td style="text-align: left">BACE, BBBP, HIV datasets. Converted to instruction format (Yes/No answer).</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Stage 2</strong> (Generation)</td>
          <td style="text-align: left"><strong>ChEBI-20</strong></td>
          <td style="text-align: left">26.5K samples</td>
          <td style="text-align: left">Molecule description generation.</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Stage 2</strong> (Reactions)</td>
          <td style="text-align: left"><strong>USPTO</strong></td>
          <td style="text-align: left">~380K samples</td>
          <td style="text-align: left">Combined datasets for Forward (125K), Retrosynthesis (130K), and Reagent (125K) prediction.</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Two-Stage Training</strong>:
<ol>
<li><strong>Alignment Pre-training</strong>: Updates only the Projector. The objective maximizes the probability of generating the target description token sequence $\mathbf{X}_A$ of length $L$ given the molecule input $\mathbf{X}_M$ and instruction $\mathbf{X}_I$, where the molecule input is presented to the LLM as graph tokens $\mathbf{X}_G$ concatenated ($\parallel$) with sequence tokens $\mathbf{X}_S$:
$$p(\mathbf{X}_A | \mathbf{X}_M, \mathbf{X}_I) = \prod_{i=1}^L p_\theta(x_i | \mathbf{X}_G \parallel \mathbf{X}_S, \mathbf{X}_I, \mathbf{X}_{A,&lt;i})$$</li>
<li><strong>Instruction Tuning</strong>: Updates Projector + LLM (via LoRA) using standard autoregressive language modeling on task-specific instructions. The objective minimizes the negative log-likelihood of generating the target response $R$ of length $L$:
$$\mathcal{L}(\theta) = -\sum_{i=1}^L \log p(R_i | I, M, R_{&lt;i}; \theta)$$
where $I$ represents the instruction and $M$ is the multi-modal molecular input.</li>
</ol>
</li>
<li><strong>LoRA (Low-Rank Adaptation)</strong>: Applied to the LLM in Stage 2. Rank $r=64$, Scaling $\alpha=16$.</li>
<li><strong>Optimization</strong>: AdamW optimizer. Learning rate starts at 2e-3 (Stage 1) and 8e-5 (Stage 2) with cosine decay. Warm-up ratio 0.03.</li>
</ul>
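<p>The LoRA update used in Stage 2 can be illustrated with a tiny dense layer. This is a minimal sketch of the standard LoRA formulation (frozen weight plus a scaled low-rank product), not code from the InstructMol repository; the matrices, the helper names, and the toy rank are assumptions chosen so the arithmetic is easy to follow.</p>

```python
def matvec(M, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m * v for m, v in zip(row, x)) for row in M]


def lora_forward(W, A, B, x, alpha, r):
    """Compute W x + (alpha / r) * B (A x); only A and B would be trained."""
    base = matvec(W, x)
    update = matvec(B, matvec(A, x))
    scale = alpha / r
    return [b + scale * u for b, u in zip(base, update)]


# Scaling implied by the hyperparameters in the notes: alpha=16, r=64.
scale_in_notes = 16 / 64

# Toy layer: 3 inputs, 2 outputs, rank 2 (far smaller than a real LLM layer).
W = [[1.0, 0.0, 2.0], [0.0, 1.0, -1.0]]   # frozen base weight
A = [[0.5, -0.5, 1.0], [1.0, 1.0, 0.0]]   # down-projection (r x d_in)
B_zero = [[0.0, 0.0], [0.0, 0.0]]         # standard LoRA init: B starts at zero
B_trained = [[1.0, 0.0], [0.0, 2.0]]      # pretend values after training
x = [1.0, 2.0, 3.0]

y_init = lora_forward(W, A, B_zero, x, alpha=1, r=2)
y_tuned = lora_forward(W, A, B_trained, x, alpha=1, r=2)
print(scale_in_notes, y_init, y_tuned)
```

<p>Because the up-projection starts at zero, the adapted model initially reproduces the frozen LLM exactly, and the ratio $\alpha / r$ controls how strongly the learned low-rank update is mixed in.</p>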
<h3 id="models">Models</h3>
<p><strong>Note:</strong> The official repository currently lists the final fine-tuned <strong>InstructMol weights</strong> as &ldquo;coming soon.&rdquo; Consequently, one must fine-tune the components using the provided scripts. Base model weights (Vicuna-7B and MoleculeSTM) are publicly available via Hugging Face.</p>
<ul>
<li><strong>Graph Encoder ($f_g$)</strong>:
<ul>
<li>Architecture: Graph Isomorphism Network (GIN) with 5 layers.</li>
<li>Hidden Dimension: 300.</li>
<li>Initialization: <strong>MoleculeSTM</strong> checkpoint (pre-trained via contrastive learning).</li>
<li>Status: <strong>Frozen</strong> in both Stage 1 and Stage 2.</li>
</ul>
</li>
<li><strong>LLM</strong>:
<ul>
<li>Base: <strong>Vicuna-v1.3-7B</strong>.</li>
<li>Status: Frozen in Stage 1; LoRA fine-tuned in Stage 2.</li>
</ul>
</li>
<li><strong>Projector</strong>:
<ul>
<li>Architecture: Linear Layer.</li>
<li>Function: Maps the node-level graph representation $Z_G \in \mathbb{R}^{N \times d}$ into the LLM&rsquo;s word embedding space.</li>
</ul>
</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Metric Libraries</strong>: RDKit for validity/fingerprints, standard NLP libraries for BLEU/ROUGE.</li>
<li><strong>Reaction Metrics</strong>: Fingerprint <a href="https://en.wikipedia.org/wiki/Jaccard_index">Tanimoto Similarity</a> (FTS), Exact Match, Levenshtein distance, and validity (via RDKit).</li>
<li><strong>Description Metrics</strong>: BLEU-2, BLEU-4, ROUGE-1, ROUGE-2, ROUGE-L, METEOR.</li>
</ul>
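<p>The reaction metrics are simple to state in code. The sketch below is a pure-Python illustration with hypothetical function names; it represents a fingerprint as a set of on-bit indices, whereas the actual evaluation computes fingerprints with RDKit.</p>

```python
def tanimoto(fp_a, fp_b):
    """Tanimoto (Jaccard) similarity between two sets of on-bit indices."""
    a, b = set(fp_a), set(fp_b)
    union = a.union(b)
    if not union:
        return 1.0  # convention chosen here for two empty fingerprints
    return len(a.intersection(b)) / len(union)


def levenshtein(s, t):
    """Edit distance (insert, delete, substitute) between two strings."""
    prev = list(range(len(t) + 1))
    for i, cs in enumerate(s, start=1):
        curr = [i]
        for j, ct in enumerate(t, start=1):
            cost = 0 if cs == ct else 1
            curr.append(min(prev[j] + 1,          # deletion
                            curr[j - 1] + 1,      # insertion
                            prev[j - 1] + cost))  # substitution or match
        prev = curr
    return prev[-1]


def exact_match(predicted, reference):
    """1 if the predicted string equals the reference exactly, else 0."""
    return int(predicted == reference)
```

<p>For example, the SMILES strings <code>CCO</code> and <code>CCN</code> differ by one substitution, so their Levenshtein distance is 1 and their exact-match score is 0. In practice, exact match is usually computed on canonicalized SMILES so that equivalent strings compare equal.</p>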
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Compute</strong>: 4 x NVIDIA RTX A6000 (48GB VRAM).</li>
<li><strong>Training Schedule</strong>:
<ul>
<li>Stage 1: 5 epochs.</li>
<li>Stage 2: 20-50 epochs (Description Generation), 10 epochs (Properties/Reactions).</li>
</ul>
</li>
<li><strong>Batch Size</strong>: 128 for both stages.</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Artifact</th>
          <th style="text-align: left">Type</th>
          <th style="text-align: left">License</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><a href="https://github.com/IDEA-XL/InstructMol">InstructMol (GitHub)</a></td>
          <td style="text-align: left">Code</td>
          <td style="text-align: left">Apache 2.0 (code), CC BY-NC 4.0 (data)</td>
          <td style="text-align: left">Training/evaluation scripts provided; fine-tuned weights listed as &ldquo;coming soon&rdquo;</td>
      </tr>
      <tr>
          <td style="text-align: left"><a href="https://huggingface.co/lmsys/vicuna-7b-v1.3">Vicuna-7B v1.3</a></td>
          <td style="text-align: left">Model</td>
          <td style="text-align: left">Non-commercial (LLaMA license)</td>
          <td style="text-align: left">Base LLM; must be downloaded separately</td>
      </tr>
      <tr>
          <td style="text-align: left"><a href="https://huggingface.co/chao1224/MoleculeSTM">MoleculeSTM</a></td>
          <td style="text-align: left">Model</td>
          <td style="text-align: left">MIT</td>
          <td style="text-align: left">Pre-trained graph encoder checkpoint</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Cao, H., Liu, Z., Lu, X., Yao, Y., &amp; Li, Y. (2025). InstructMol: Multi-Modal Integration for Building a Versatile and Reliable Molecular Assistant in Drug Discovery. <em>Proceedings of the 31st International Conference on Computational Linguistics</em>, 354-379.</p>
<p><strong>Publication</strong>: COLING 2025</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{caoInstructMolMultiModalIntegration2025,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{{{InstructMol}}: {{Multi-Modal Integration}} for {{Building}} a {{Versatile}} and {{Reliable Molecular Assistant}} in {{Drug Discovery}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">shorttitle</span> = <span style="color:#e6db74">{{{InstructMol}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{Proceedings of the 31st {{International Conference}} on {{Computational Linguistics}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Cao, He and Liu, Zijing and Lu, Xingyu and Yao, Yuan and Li, Yu}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">editor</span> = <span style="color:#e6db74">{Rambow, Owen and Wanner, Leo and Apidianaki, Marianna and {Al-Khalifa}, Hend and Eugenio, Barbara Di and Schockaert, Steven}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2025</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = jan,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{354--379}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span> = <span style="color:#e6db74">{https://aclanthology.org/2025.coling-main.25/}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{Association for Computational Linguistics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">address</span> = <span style="color:#e6db74">{Abu Dhabi, UAE}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">abstract</span> = <span style="color:#e6db74">{The rapid evolution of artificial intelligence in drug discovery encounters challenges with generalization and extensive training, yet Large Language Models (LLMs) offer promise in reshaping interactions with complex molecular data. Our novel contribution, InstructMol, a multi-modal LLM, effectively aligns molecular structures with natural language via an instruction-tuning approach, utilizing a two-stage training strategy that adeptly combines limited domain-specific data with molecular and textual information. InstructMol showcases substantial performance improvements in drug discovery-related molecular tasks, surpassing leading LLMs and significantly reducing the gap with specialists, thereby establishing a robust foundation for a versatile and dependable drug discovery assistant.}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/IDEA-XL/InstructMol">Official Repository</a></li>
</ul>
]]></content:encoded></item><item><title>DynamicFlow: Integrating Protein Dynamics into Drug Design</title><link>https://hunterheidenreich.com/notes/biology/computational-biology/dynamicflow/</link><pubDate>Sat, 20 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/biology/computational-biology/dynamicflow/</guid><description>Flow matching model that co-generates ligands and flexible protein pockets, addressing rigid-receptor limitations in structure-based drug design.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is primarily a <strong>Methodological Paper</strong> ($\Psi_{\text{Method}}$) with a strong <strong>Resource</strong> ($\Psi_{\text{Resource}}$) component.</p>
<ul>
<li><strong>Method</strong>: It proposes <strong>DynamicFlow</strong>, a multiscale architecture that combines atom-level SE(3)-equivariant GNNs and residue-level Transformers within a <a href="/notes/machine-learning/generative-models/flow-matching-for-generative-modeling/">flow matching</a> framework to model the joint distribution of generated ligands and protein conformational change. SE(3) is the special Euclidean group in 3D (all 3D rotations and translations); equivariance means predictions transform consistently under those symmetries.</li>
<li><strong>Resource</strong>: It curates a dataset of 5,692 protein-ligand complexes derived from MISATO, pairing AlphaFold2-predicted apo structures with multiple MD-simulated holo states, specifically filtered for flow matching tasks.</li>
</ul>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>Traditional Structure-Based Drug Design (SBDD) methods typically assume the protein target is rigid, which limits their applicability because proteins are dynamic and undergo conformational changes (induced fit) upon ligand binding.</p>
<ul>
<li><strong>Biological Reality</strong>: Proteins exist as ensembles of states; binding often involves transitions from &ldquo;apo&rdquo; (unbound) to &ldquo;holo&rdquo; (bound) <a href="/posts/geom-conformer-generation-dataset/">conformational changes</a>, sometimes revealing cryptic pockets.</li>
<li><strong>Computational Bottleneck</strong>: <a href="/notes/chemistry/molecular-simulation/">Molecular Dynamics (MD)</a> can simulate these changes, but at high computational cost, because crossing the energy barriers between conformational states requires long trajectories.</li>
<li><strong>Gap</strong>: <a href="/notes/machine-learning/generative-models/">Existing generative models</a> for SBDD mostly condition on a fixed pocket structure, ignoring the co-adaptation of the protein and ligand.</li>
</ul>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The core novelty is the <strong>simultaneous modeling of ligand generation and protein conformational dynamics</strong> using a unified flow matching framework.</p>
<ul>
<li><strong>DynamicFlow Architecture</strong>: A multiscale model that treats the protein as both full-atom (for interaction) and residue-level frames (for large-scale dynamics), utilizing separate flow matching objectives for backbone frames, side-chain torsions, and ligand atoms.</li>
<li><strong>Stochastic Flow (SDE)</strong>: Introduction of a <a href="/notes/machine-learning/generative-models/score-based-generative-modeling-sde/">stochastic variant</a> (DynamicFlow-SDE) that improves robustness and diversity compared to the deterministic ODE flow.</li>
<li><strong>Coupled Generation</strong>: The model learns to transport the <em>apo</em> pocket distribution to the <em>holo</em> pocket distribution while simultaneously denoising the ligand, advancing beyond rigid pocket docking methods.</li>
</ul>
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The authors validated the method on a curated dataset of 5,692 protein-ligand complexes.</p>
<ul>
<li><strong>Baselines</strong>: Compared against rigid-pocket SBDD methods: Pocket2Mol, TargetDiff, and IPDiff (adapted as TargetDiff* and IPDiff* for fair comparison of atom numbers). Also compared against conformation sampling baselines (Str2Str).</li>
<li><strong>Metrics</strong>:
<ul>
<li><strong>Ligand Quality</strong>: Vina Score (binding affinity), QED (drug-likeness), SA (synthesizability), Lipinski&rsquo;s rule of 5.</li>
<li><strong>Pocket Quality</strong>: RMSD between generated and ground-truth holo pockets, Cover Ratio (percentage of holo states successfully retrieved), and Pocket Volume distributions.</li>
<li><strong>Interaction</strong>: Protein-Ligand Interaction Profiler (PLIP) to measure specific non-covalent interactions.</li>
</ul>
</li>
<li><strong>Ablations</strong>: Tested the impact of the interaction loss, residue-level Transformer, and SDE vs. ODE formulations.</li>
</ul>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>
<ul>
<li><strong>Improved Affinity</strong>: DynamicFlow-SDE achieved the best (lowest) Vina score ($-7.65$), ahead of baselines such as TargetDiff ($-5.09$) and Pocket2Mol ($-5.50$). Note that Vina scores are a computational proxy and do not directly predict experimental binding affinity. Moreover, Vina score optimization is gameable: molecules can achieve strong computed binding energies while remaining synthetically inaccessible. QED and SA scores, which assess drug-likeness and synthesizability respectively, were reported but were not primary optimization targets in the paper, which limits the strength of this affinity claim.</li>
<li><strong>Realistic Dynamics</strong>: The model successfully generated holo-like pocket conformations with volume distributions and interaction profiles closer to ground-truth MD simulations than the initial apo structures.</li>
<li><strong>Enhancing Rigid Methods</strong>: Holo pockets generated by DynamicFlow served as better inputs for rigid-SBDD baselines (e.g., TargetDiff improved from $-5.09$ to $-9.00$ and IPDiff improved from $-7.55$ to $-11.04$ when using &ldquo;Our Pocket&rdquo;), suggesting the method can act as a &ldquo;pocket refiner&rdquo;.</li>
<li><strong>ODE vs. SDE Trade-off</strong>: The deterministic ODE variant achieves better pocket RMSD, while the stochastic SDE variant achieves better Cover Ratio (diversity of holo states captured) and binding affinity. Neither dominates uniformly.</li>
<li><strong>Conformation Sampling Baseline</strong>: Str2Str, a dedicated conformation sampling baseline, performed worse than simply perturbing the apo structure with noise. One interpretation is that this highlights the difficulty of the apo-to-holo prediction task; another is that Str2Str was not designed specifically for apo-to-holo prediction, making it a limited test of its capabilities.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The dataset is derived from <strong>MISATO</strong>, which contains MD trajectories for PDBbind complexes.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Training/Test</strong></td>
          <td>Curated MISATO</td>
          <td>5,692 complexes</td>
          <td>Filtered for valid MD (<a href="/posts/kabsch-algorithm/">RMSD</a> $&lt; 3\text{\AA}$), clustered to remove redundancy. Contains 46,235 holo-ligand conformations total.</td>
      </tr>
      <tr>
          <td><strong>Apo Structures</strong></td>
          <td>AlphaFold2</td>
          <td>N/A</td>
          <td>Apo structures were obtained by mapping PDB IDs to UniProt and retrieving AlphaFold2 predictions, then aligning to MISATO structures.</td>
      </tr>
      <tr>
          <td><strong>Splits</strong></td>
          <td>Standard</td>
          <td>50 test complexes</td>
          <td>50 complexes with no overlap with the training set selected for testing. Note: 50 is a small held-out set; results should be interpreted cautiously.</td>
      </tr>
  </tbody>
</table>
<p><strong>Preprocessing</strong>:</p>
<ul>
<li><strong>Clustering</strong>: Holo-ligand conformations clustered with RMSD threshold $1.0\text{\AA}$; top 10 clusters kept per complex.</li>
<li><strong>Pocket Definition</strong>: Residues within $7\text{\AA}$ of the ligand.</li>
<li><strong>Alignment</strong>: AlphaFold predicted structures (apo) aligned to MISATO holo structures using sequence alignment (Smith-Waterman) to identify pocket residues.</li>
</ul>
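<p>The pocket definition can be sketched as a distance filter. This is an illustrative NumPy version under one stated assumption: a residue belongs to the pocket if any of its atoms lies within the cutoff of any ligand atom (the note does not say whether the authors use atom-level or residue-center distances). The function name <code>pocket_residues</code> is hypothetical.</p>

```python
import numpy as np


def pocket_residues(atom_xyz, atom_res_id, ligand_xyz, cutoff=7.0):
    """Return sorted residue ids with at least one atom within `cutoff` of the ligand.

    atom_xyz: (n_atoms, 3) protein atom coordinates in Angstrom.
    atom_res_id: (n_atoms,) residue id for each protein atom.
    ligand_xyz: (n_lig, 3) ligand atom coordinates in Angstrom.
    """
    atom_xyz = np.asarray(atom_xyz, dtype=float)
    ligand_xyz = np.asarray(ligand_xyz, dtype=float)
    # Pairwise protein-ligand distances via broadcasting: (n_atoms, n_lig).
    diff = atom_xyz[:, None, :] - ligand_xyz[None, :, :]
    dists = np.linalg.norm(diff, axis=-1)
    # An atom is "close" if its nearest ligand atom is within the cutoff.
    close = np.less_equal(dists.min(axis=1), cutoff)
    return sorted(set(np.asarray(atom_res_id)[close].tolist()))
```

<p>For large complexes, a spatial index such as a k-d tree would avoid materializing the full distance matrix; the dense version is shown only because it is easy to read.</p>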
<h3 id="algorithms">Algorithms</h3>
<p><strong>Flow Matching Framework</strong>:</p>
<ul>
<li><strong>Continuous Variables</strong> (Pocket translation/rotation/torsions, Ligand positions): Modeled using <strong>Conditional Flow Matching (CFM)</strong>.
<ul>
<li><em>Prior</em>: Apo state for pocket; Normal distribution for ligand positions.</li>
<li><em>Target</em>: Holo state from MD; Ground truth ligand.</li>
<li><em>Interpolant</em>: Linear interpolation for Euclidean variables; Geodesic for rotations ($SO(3)$, the rotation-only subgroup of SE(3) containing all 3D rotations but not translations); Wrapped linear interpolation for torsions (Torus).</li>
</ul>
</li>
<li><strong>Discrete Variables</strong> (Ligand atom/bond types): Modeled using <strong>Discrete Flow Matching</strong> based on Continuous-Time Markov Chains (CTMC).
<ul>
<li><em>Rate Matrix</em>: Interpolates between mask token and data distribution.</li>
</ul>
</li>
<li><strong>Loss Function</strong>: Weighted sum of 7 losses:
<ol>
<li>Translation CFM (Eq 5)</li>
<li>Rotation CFM (Eq 7)</li>
<li>Torsion CFM (Eq 11)</li>
<li>Ligand Position CFM</li>
<li>Ligand Atom Type CTMC (Eq 14)</li>
<li>Ligand Bond Type CTMC</li>
<li><strong>Interaction Loss</strong> (Eq 18): Explicitly penalizes deviations in pairwise distances between protein and ligand atoms for pairs $\leq 3.5\text{\AA}$.</li>
</ol>
</li>
</ul>
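<p>Two of the interpolants listed above are easy to write down explicitly. The sketch below shows the linear interpolant for Euclidean variables, with its constant conditional velocity target $x_1 - x_0$, and a wrapped linear interpolant for torsion angles that always follows the shorter arc on the circle. The function names are hypothetical, the wrapping convention $[-\pi, \pi)$ is an assumption, and the geodesic interpolant on $SO(3)$ is not covered.</p>

```python
import math

import numpy as np


def linear_interpolant(x0, x1, t):
    """x_t = (1 - t) x0 + t x1, with conditional velocity target x1 - x0."""
    x0 = np.asarray(x0, dtype=float)
    x1 = np.asarray(x1, dtype=float)
    x_t = (1.0 - t) * x0 + t * x1
    velocity = x1 - x0
    return x_t, velocity


def wrap_angle(theta):
    """Map an angle in radians to the interval [-pi, pi)."""
    return (theta + math.pi) % (2.0 * math.pi) - math.pi


def torsion_interpolant(theta0, theta1, t):
    """Interpolate along the shorter arc between two torsion angles."""
    delta = wrap_angle(theta1 - theta0)  # signed shortest angular difference
    return wrap_angle(theta0 + t * delta)
```

<p>The wrapping matters near the branch cut: going from $3.0$ to $-3.0$ radians, the shortest path has length $2\pi - 6 \approx 0.28$ and passes through $\pi$, whereas naive linear interpolation would sweep almost the full circle in the opposite direction.</p>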
<h3 id="models">Models</h3>
<p><strong>Architecture</strong>: <strong>DynamicFlow</strong> is a multiscale model with 15.9M parameters.</p>
<ol>
<li><strong>Atom-Level SE(3)-Equivariant GNN</strong>:
<ul>
<li><em>Input</em>: Complex graph (k-NN) and Ligand graph (fully connected).</li>
<li><em>Layers</em>: 6 EGNN blocks modified to maintain node and edge hidden states.</li>
<li><em>Function</em>: Updates ligand positions and predicts ligand atom/bond types.</li>
</ul>
</li>
<li><strong>Residue-Level Transformer</strong>:
<ul>
<li><em>Input</em>: Aggregated atom features from the GNN + Residue frames/torsions.</li>
<li><em>Layers</em>: 4 Transformer blocks with <strong>Invariant Point Attention (IPA)</strong>.</li>
<li><em>Function</em>: Updates protein residue frames (translation/rotation) and predicts side-chain torsions.</li>
</ul>
</li>
</ol>
<h3 id="evaluation">Evaluation</h3>
<p><strong>Metrics</strong>:</p>
<ul>
<li><strong>Vina Score</strong>: <code>vina_minimize</code> mode used for binding affinity.</li>
<li><strong>RMSD</strong>: Minimum RMSD between generated pocket and ground-truth holo conformations.</li>
<li><strong>Cover Ratio</strong>: % of ground-truth holo conformations covered by at least one generated sample (threshold $1.42\text{\AA}$).</li>
<li><strong>POVME 3</strong>: For pocket volume calculation.</li>
</ul>
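<p>The two pocket metrics can be illustrated with a short NumPy sketch. It assumes that all pocket structures are already superimposed and share the same atom ordering (no Kabsch alignment is performed here), and that "covered" means an RMSD at or below the threshold; both points are assumptions, and the function names are hypothetical.</p>

```python
import numpy as np


def rmsd(a, b):
    """Root-mean-square deviation between two (n_atoms, 3) coordinate arrays."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    return float(np.sqrt(((a - b) ** 2).sum(axis=-1).mean()))


def min_rmsd_and_cover_ratio(generated, references, threshold=1.42):
    """Return (min RMSD per generated sample, fraction of references covered).

    A reference holo conformation counts as covered when at least one
    generated pocket lies within `threshold` Angstrom RMSD of it.
    """
    table = np.array([[rmsd(g, r) for r in references] for g in generated])
    min_per_generated = table.min(axis=1).tolist()
    best_per_reference = table.min(axis=0)
    covered = np.less_equal(best_per_reference, threshold)
    return min_per_generated, float(covered.mean())
```

<p>The two numbers reward different behavior: minimum RMSD is satisfied by samples that all sit near one holo state, while Cover Ratio only rises when samples spread across the different holo states, which is why the ODE and SDE variants can rank differently on them.</p>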
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Inference Benchmark</strong>: 1x Tesla V100-SXM2-32GB.</li>
<li><strong>Speed</strong>: Generates 10 ligands in ~35-36 seconds (100 NFE), significantly faster than baselines such as the autoregressive Pocket2Mol (980s) or the diffusion-based TargetDiff (156s).</li>
</ul>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Zhou, X., Xiao, Y., Lin, H., He, X., Guan, J., Wang, Y., Liu, Q., Zhou, F., Wang, L., &amp; Ma, J. (2025). Integrating Protein Dynamics into Structure-Based Drug Design via Full-Atom Stochastic Flows. <em>International Conference on Learning Representations (ICLR)</em>. <a href="https://arxiv.org/abs/2503.03989">https://arxiv.org/abs/2503.03989</a></p>
<p><strong>Publication</strong>: ICLR 2025</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{zhouIntegratingProteinDynamics2025,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Integrating Protein Dynamics into Structure-Based Drug Design via Full-Atom Stochastic Flows}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Zhou, Xiangxin and Xiao, Yi and Lin, Haowei and He, Xinheng and Guan, Jiaqi and Wang, Yang and Liu, Qiang and Zhou, Feng and Wang, Liang and Ma, Jianzhu}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{International Conference on Learning Representations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2025}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span> = <span style="color:#e6db74">{https://arxiv.org/abs/2503.03989}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://arxiv.org/abs/2503.03989">arXiv Page</a></li>
<li>Code: no public repository available at time of writing</li>
</ul>
]]></content:encoded></item><item><title>Unified Framework for Handwritten Chemical Expressions</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/online-recognition/chang-unified-framework-2009/</link><pubDate>Wed, 17 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/online-recognition/chang-unified-framework-2009/</guid><description>A 2009 unified framework for inorganic/organic chemical handwriting recognition using graph search and statistical symbol grouping.</description><content:encoded><![CDATA[<h2 id="addressing-the-complexity-of-handwritten-organic-chemistry">Addressing the Complexity of Handwritten Organic Chemistry</h2>
<p>This is a <strong>Methodological Paper</strong> ($\Psi_{\text{Method}}$) from Microsoft Research Asia that addresses the challenge of recognizing complex 2D organic chemistry structures. By 2009, math expression recognition had seen significant commercial progress, but chemical expression recognition remained less developed.</p>
<p>The specific gap addressed is the geometric complexity of organic formulas. While inorganic formulas typically follow a linear, equation-like structure, organic formulas present complex 2D diagrammatic structures with various bond types and rings. Existing work often relied on strong assumptions (like single-stroke symbols) or failed to handle arbitrary compounds. There was a clear need for a unified solution capable of handling both inorganic and organic domains consistently.</p>
<h2 id="the-chemical-expression-structure-graph-cesg">The Chemical Expression Structure Graph (CESG)</h2>
<p>The core innovation is a unified statistical framework that processes inorganic and organic expressions within the same pipeline. Key technical novelties include:</p>
<ol>
<li><strong>Unified Bond Modeling</strong>: Bonds are treated as special symbols. The framework detects &ldquo;extended bond symbols&rdquo; (multi-stroke bonds) and splits them into single, double, or triple bonds using corner detection for consistent processing.</li>
<li><strong>Chemical Expression Structure Graph (CESG)</strong>: A defined graph representation for generic chemical expressions where nodes represent symbols and edges represent bonds or spatial relations.</li>
<li><strong>Non-Symbol Modeling</strong>: During the symbol grouping phase, the system explicitly models &ldquo;invalid groups&rdquo; to reduce over-grouping errors.</li>
<li><strong>Global Graph Search</strong>: Structure analysis is formulated as finding the optimal CESG by searching over a Weighted Direction Graph ($G_{WD}$).</li>
</ol>
<h2 id="graph-search-and-statistical-validation">Graph Search and Statistical Validation</h2>
<p>The authors validated the framework on a proprietary database of 35,932 handwritten chemical expressions collected from 300 writers.</p>
<ul>
<li><strong>Setup</strong>: The data was split into roughly 26,000 training and 6,400 testing samples.</li>
<li><strong>Metric</strong>: Recognition accuracy was measured strictly by expression (all symbols and the complete structure must be correct).</li>
<li><strong>Ablations</strong>: The team evaluated the performance contribution of symbol grouping, structure analysis, and full semantic verification.</li>
</ul>
<h2 id="recognition-accuracy-and-outcomes">Recognition Accuracy and Outcomes</h2>
<p>The full framework achieved a Top-1 accuracy of 75.4% and a Top-5 accuracy of 83.1%.</p>
<ul>
<li><strong>Component Contribution</strong>: Structure analysis is the primary bottleneck. Adding it drops the theoretical &ldquo;perfect grouping&rdquo; performance from 85.9% to 74.1% (Top-1) due to structural errors.</li>
<li><strong>Semantic Verification</strong>: Checking valence and grammar raised Top-1 accuracy from 74.1% to 75.4%, a relative improvement of about 1.7%.</li>
</ul>
<p>The unified framework effectively handles the variance in 2D space for chemical expressions, demonstrating that delayed decision-making (keeping top-N candidates) improves robustness.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Artifact</th>
          <th style="text-align: left">Type</th>
          <th style="text-align: left">License</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left">N/A</td>
          <td style="text-align: left">N/A</td>
          <td style="text-align: left">N/A</td>
          <td style="text-align: left">No public artifacts (code, data, models) were released by the authors.</td>
      </tr>
  </tbody>
</table>
<h3 id="data">Data</h3>
<p>The study used a private Microsoft Research Asia dataset, making direct reproduction difficult.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Total</td>
          <td>Proprietary MSRA DB</td>
          <td>35,932 expressions</td>
          <td>Written by 300 people</td>
      </tr>
      <tr>
          <td>Training</td>
          <td>Subset</td>
          <td>25,934 expressions</td>
          <td></td>
      </tr>
      <tr>
          <td>Testing</td>
          <td>Subset</td>
          <td>6,398 expressions</td>
          <td></td>
      </tr>
  </tbody>
</table>
<ul>
<li><strong>Content</strong>: 2,000 unique expressions from high school/college textbooks.</li>
<li><strong>Composition</strong>: ~25% of samples are organic expressions.</li>
<li><strong>Vocabulary</strong>: 163 symbol classes (elements, digits, <code>+</code>, <code>↑</code>, <code>%</code>, bonds, etc.).</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<p><strong>1. Symbol Grouping (Dynamic Programming)</strong></p>
<ul>
<li>Objective: Find the optimal symbol sequence $G_{max}$ maximizing the posterior probability given the ink strokes:
$$ G_{max} = \arg\max_{G} P(G | \text{Ink}) $$</li>
<li><strong>Non-symbol modeling</strong>: Models are trained iteratively on &ldquo;incorrect grouping results&rdquo; so that the system learns to reject invalid stroke groups.</li>
<li><strong>Inter-group modeling</strong>: Uses Gaussian Mixture Models (GMM) to model spatial relations ($R_j$) between groups.</li>
</ul>
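<p>The dynamic-programming idea behind symbol grouping can be sketched as a segmentation problem. This is a simplified illustration, not the paper's algorithm: it assumes strokes are grouped only into consecutive runs in writing order, and it takes a caller-supplied scoring function <code>group_log_prob(start, end)</code> (hypothetical) that returns the log-probability of strokes <code>start..end-1</code> forming one symbol.</p>

```python
import math


def group_strokes(n_strokes, group_log_prob, max_group_size=4):
    """Segment strokes 0..n_strokes-1 into consecutive groups with maximal total log-probability.

    Returns (groups, score), where groups is a list of (start, end) half-open ranges.
    """
    best = [-math.inf] * (n_strokes + 1)  # best[i]: best score for the first i strokes
    back = [0] * (n_strokes + 1)          # back[i]: start index of the last group
    best[0] = 0.0
    for end in range(1, n_strokes + 1):
        for size in range(1, min(max_group_size, end) + 1):
            start = end - size
            score = best[start] + group_log_prob(start, end)
            if score > best[end]:
                best[end] = score
                back[end] = start
    # Recover the segmentation by following back-pointers from the end.
    groups, end = [], n_strokes
    while end > 0:
        groups.append((back[end], end))
        end = back[end]
    return groups[::-1], best[n_strokes]
```

<p>In this framing, a non-symbol model would enter through <code>group_log_prob</code>: stroke combinations that resemble known grouping mistakes receive low scores, which discourages over-grouping.</p>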
<p><strong>2. Bond Processing</strong></p>
<ul>
<li><strong>Extended Bond Symbol</strong>: Recognizes connected strokes (e.g., a messy double bond written in one stroke) as a single &ldquo;extended&rdquo; symbol.</li>
<li><strong>Splitting</strong>: Uses <strong>Curvature Scale Space (CSS)</strong> corner detection to split extended symbols into primitive lines.</li>
<li><strong>Classification</strong>: A Neural Network verifies if the split lines form valid single, double, or triple bonds.</li>
</ul>
<p><strong>3. Structure Analysis (Graph Search)</strong></p>
<ul>
<li><strong>Graph Construction</strong>: Builds a Weighted Direction Graph ($G_{WD}$) where nodes are symbol candidates and edges are potential relationships ($E_{c}, E_{nc}, E_{peer}, E_{sub}$).</li>
<li><strong>Edge Weights</strong>: Calculated as the product of observation, spatial, and contextual probabilities:
$$ W(S, O, R) = P(O|S) \times P(\text{Spatial}|R) \times P(\text{Context}|S, R) $$
<ul>
<li>Spatial probability uses rectangular control regions and distance functions.</li>
<li>Contextual probability uses statistical co-occurrence (e.g., &lsquo;C&rsquo; often appears with &lsquo;H&rsquo;).</li>
</ul>
</li>
<li><strong>Search</strong>: Breadth-first search with pruning to find the top-N optimal CESGs.</li>
</ul>
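<p>The edge weight and the pruned breadth-first search can be illustrated with a heavily simplified stand-in. Here the graph is reduced to a sequence of layers, each holding candidate labels with precomputed weights, and the search keeps only the $N$ best partial hypotheses after each layer. The real $G_{WD}$ has several edge types and a richer structure; the function names and the layered layout are assumptions of this sketch.</p>

```python
import heapq


def edge_weight(p_obs, p_spatial, p_context):
    """W(S, O, R) = P(O|S) * P(Spatial|R) * P(Context|S, R)."""
    return p_obs * p_spatial * p_context


def top_n_paths(layers, n=3):
    """Breadth-first expansion over layers, pruned to the n best partial paths.

    layers: list of lists of (label, weight) candidates.
    Returns a list of (path, score) sorted by descending score.
    """
    beams = [((), 1.0)]
    for layer in layers:
        expanded = [
            (path + (label,), score * weight)
            for path, score in beams
            for label, weight in layer
        ]
        # Pruning step: keep only the n highest-scoring partial hypotheses.
        beams = heapq.nlargest(n, expanded, key=lambda item: item[1])
    return beams
```

<p>Keeping several hypotheses alive instead of committing to the single best one at each step is what allows a later stage, such as semantic verification, to overrule an early local decision.</p>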
<h3 id="models">Models</h3>
<ul>
<li><strong>Symbol Recognition</strong>: Implementation details not specified, but likely HMM or NN based on the era. Bond verification explicitly uses a <strong>Neural Network</strong>.</li>
<li><strong>Spatial Models</strong>: <strong>Gaussian Mixture Models (GMM)</strong> are used to model the 9 spatial relations (e.g., Left-super, Above, Subscript).</li>
<li><strong>Semantic Model</strong>: A <strong>Context-Free Grammar (CFG)</strong> parser is used for final verification (e.g., ensuring digits aren&rsquo;t reactants).</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p>Evaluation is performed using &ldquo;Expression-level accuracy&rdquo;.</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Value (Top-1)</th>
          <th>Value (Top-5)</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Full Framework</td>
          <td>75.4%</td>
          <td>83.1%</td>
          <td></td>
      </tr>
      <tr>
          <td>Without Semantics</td>
          <td>74.1%</td>
          <td>83.0%</td>
          <td></td>
      </tr>
      <tr>
          <td>Grouping Only</td>
          <td>85.9%</td>
          <td>95.6%</td>
          <td>Theoretical maximum if structure analysis were perfect</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Chang, M., Han, S., &amp; Zhang, D. (2009). A Unified Framework for Recognizing Handwritten Chemical Expressions. <em>2009 10th International Conference on Document Analysis and Recognition</em>, 1345&ndash;1349. <a href="https://doi.org/10.1109/ICDAR.2009.64">https://doi.org/10.1109/ICDAR.2009.64</a></p>
<p><strong>Publication</strong>: ICDAR 2009</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{changUnifiedFrameworkRecognizing2009,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{A {{Unified Framework}} for {{Recognizing Handwritten Chemical Expressions}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{2009 10th {{International Conference}} on {{Document Analysis}} and {{Recognition}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Chang, Ming and Han, Shi and Zhang, Dongmei}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2009</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{3}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{1345--1349}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{IEEE}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">address</span> = <span style="color:#e6db74">{Barcelona, Spain}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1109/ICDAR.2009.64}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>SVM-HMM Online Classifier for Chemical Symbols</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/online-recognition/zhang-svm-hmm-2010/</link><pubDate>Wed, 17 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/online-recognition/zhang-svm-hmm-2010/</guid><description>A dual-stage classifier combining SVM and HMM to recognize online handwritten chemical symbols, introducing a reordering algorithm for organic rings.</description><content:encoded><![CDATA[<h2 id="contribution-double-stage-classification-method">Contribution: Double-Stage Classification Method</h2>
<p><strong>Method</strong>.
This is a methodological paper that proposes a &ldquo;double-stage classifier&rdquo; architecture. It introduces a specific algorithmic pipeline (rough classification with an SVM followed by fine classification with HMMs) and a pre-processing algorithm (Point Sequence Reordering) that addresses the difficulty of recognizing organic ring structures. The contribution is validated through ablation studies (comparing SVM kernels and HMM state/Gaussian counts) and performance benchmarks.</p>
<h2 id="motivation-recognizing-complex-organic-ring-structures">Motivation: Recognizing Complex Organic Ring Structures</h2>
<p>The primary motivation is the complexity of recognizing handwritten chemical symbols, specifically the distinction between <strong>Organic Ring Structures (ORS)</strong> and <strong>Non-Ring Structures (NRS)</strong>. Existing single-stage classifiers are unreliable for ORS because these symbols have arbitrary writing styles, variable stroke numbers, and inconsistent stroke orders due to their 2D hexagonal structure. A robust system is needed to handle this uncertainty and achieve high accuracy.</p>
<h2 id="core-innovation-point-sequence-reordering-psr">Core Innovation: Point Sequence Reordering (PSR)</h2>
<p>The authors introduce two main novelties:</p>
<ol>
<li><strong>Double-Stage Architecture</strong>: A hybrid system where an SVM (using RBF kernel) first roughly classifies inputs as either ORS or NRS, followed by specialized HMMs for fine-grained recognition.</li>
<li><strong>Point Sequence Reordering (PSR) Algorithm</strong>: A stroke-order-independent algorithm designed specifically for ORS. It reorders the point sequence of a symbol using a counter-clockwise scan from the centroid, which removes the uncertainty caused by variations in stroke number and writing order.</li>
</ol>
<h2 id="methodology--experimental-design">Methodology &amp; Experimental Design</h2>
<p>The authors collected a custom dataset and performed sequential optimizations:</p>
<ul>
<li><strong>SVM Optimization</strong>: Compared Polynomial, RBF, and Sigmoid kernels to find the best rough classifier.</li>
<li><strong>HMM Optimization</strong>: Tested multiple combinations of states (4, 6, 8) and Gaussians (3, 4, 6, 8, 9, 12) to maximize fine classification accuracy.</li>
<li><strong>PSR Validation</strong>: Conducted an ablation study comparing HMM accuracy on ORS symbols &ldquo;Before PSR&rdquo; vs &ldquo;After PSR&rdquo; to quantify the algorithm&rsquo;s impact.</li>
</ul>
<h2 id="results--final-conclusions">Results &amp; Final Conclusions</h2>
<ul>
<li><strong>Architecture Performance</strong>: The RBF-based SVM achieved 99.88% accuracy in differentiating ORS from NRS.</li>
<li><strong>HMM Configuration</strong>: The optimal HMM topology was found to be 8 states and 12 Gaussians for both symbol types.</li>
<li><strong>PSR Impact</strong>: The PSR algorithm nearly doubled ORS recognition accuracy. Top-1 accuracy rose from <strong>49.84% (Before PSR)</strong> to <strong>98.36% (After PSR)</strong>.</li>
<li><strong>Overall Accuracy</strong>: The final integrated system achieved a Top-1 accuracy of <strong>93.10%</strong> and Top-3 accuracy of <strong>98.08%</strong> on the test set.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The study defined 101 chemical symbols split into two categories.</p>
<table>
  <thead>
      <tr>
          <th>Category</th>
          <th>Count</th>
          <th>Content</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>NRS</strong> (Non-Ring)</td>
          <td>63</td>
          <td>Digits 0-9, 44 letters, 9 operators</td>
          <td>Operators include +, -, =, $\rightarrow$, etc.</td>
      </tr>
      <tr>
          <td><strong>ORS</strong> (Organic Ring)</td>
          <td>38</td>
          <td>2D hexagonal structures</td>
          <td>Benzene rings, cyclohexane, etc.</td>
      </tr>
  </tbody>
</table>
<ul>
<li><strong>Collection</strong>: 12,322 total samples (122 per symbol) collected from 20 writers (teachers and students).</li>
<li><strong>Split</strong>: 9,090 training samples and 3,232 test samples.</li>
<li><strong>Constraints</strong>: Three specifications were used: normal, standard, and freestyle.</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<p><strong>1. SVM Feature Extraction (Rough Classification)</strong>
The input strokes are scaled, and a 58-dimensional feature vector is calculated:</p>
<ul>
<li><strong>Mesh ($4 \times 4$)</strong>: Ratio of points in 16 grids (16 features).</li>
<li><strong>Outline</strong>: Normalized scan distance from 4 edges with 5 scan lines each (20 features).</li>
<li><strong>Projection</strong>: Point density in 5 bins per edge (20 features).</li>
<li><strong>Aspect Ratio</strong>: Height/Width ratios (2 features).</li>
</ul>
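<p>As a rough illustration of the mesh component only, the sketch below computes the 16 point-ratio features over a $4 \times 4$ grid. The function name <code>mesh_features</code>, the bounding-box normalization, and the row-major cell ordering are assumptions made for this example; the summary above does not specify how the paper scales strokes or orders the cells.</p>

```python
def mesh_features(points, grid=4):
    """Fraction of points falling in each cell of a grid x grid mesh.

    Assumption: the mesh covers the symbol's bounding box and cells are
    indexed row-major; the paper's exact scaling is not given here.
    """
    xs = [x for x, _ in points]
    ys = [y for _, y in points]
    min_x, min_y = min(xs), min(ys)
    # Guard against zero width/height (e.g. a perfectly straight stroke).
    width = (max(xs) - min_x) or 1.0
    height = (max(ys) - min_y) or 1.0

    counts = [0] * (grid * grid)
    for x, y in points:
        # Points on the far edge are clamped into the last cell.
        col = min(int((x - min_x) / width * grid), grid - 1)
        row = min(int((y - min_y) / height * grid), grid - 1)
        counts[row * grid + col] += 1

    total = len(points)
    return [c / total for c in counts]


# Four corner points: each lands in a different corner cell.
square = [(0, 0), (10, 0), (10, 10), (0, 10)]
features = mesh_features(square)
```

<p>The outline, projection, and aspect-ratio features would be concatenated to this vector to reach the 58 dimensions described above.</p>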
<p><strong>2. Point Sequence Reordering (PSR)</strong>
Used strictly for ORS preprocessing:</p>
<ol>
<li>Calculate the centroid $(x_c, y_c)$ of the symbol.</li>
<li>Initialize a scan line at angle $\theta = 0$.</li>
<li>Traverse points; if a point $p_i = (x_i, y_i)$ satisfies the distance threshold to the scan line, add it to the reordered list. Distance $d_i$ is calculated as:
$$ d_i = |(y_i - y_c)\cos(\theta) - (x_i - x_c)\sin(\theta)| $$</li>
<li>Increment $\theta$ by $\Delta\theta$ and repeat until a full circle ($2\pi$) is completed.</li>
</ol>
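<p>The four steps above can be sketched as follows. This is a minimal illustration, not the authors&rsquo; implementation: the step size <code>delta_theta</code>, the distance <code>threshold</code>, the half-line check (so that only points on the forward side of the scan line are collected), and the handling of points that are never hit are all assumptions introduced for this example.</p>

```python
import math


def reorder_points(points, delta_theta=math.radians(5), threshold=1.0):
    """Reorder points by sweeping a scan line around the centroid.

    Assumptions (not from the paper): the 5-degree step, the threshold,
    the forward half-line check, and appending unvisited points at the end.
    """
    n = len(points)
    xc = sum(x for x, _ in points) / n
    yc = sum(y for _, y in points) / n

    ordered = []
    used = [False] * n
    steps = int(round(2 * math.pi / delta_theta))
    for step in range(steps):
        theta = step * delta_theta
        cos_t, sin_t = math.cos(theta), math.sin(theta)
        for i, (x, y) in enumerate(points):
            if used[i]:
                continue
            dx, dy = x - xc, y - yc
            # Perpendicular distance from the point to the scan line.
            d = abs(dy * cos_t - dx * sin_t)
            # Keep only points on the forward side of the scan direction.
            ahead = dx * cos_t + dy * sin_t >= 0
            if d <= threshold and ahead:
                ordered.append((x, y))
                used[i] = True

    # Points never reached by any scan line keep their original order.
    ordered.extend(p for i, p in enumerate(points) if not used[i])
    return ordered


# Same four points, written in an arbitrary stroke order.
scrambled = [(0, 10), (-10, 0), (10, 0), (0, -10)]
reordered = reorder_points(scrambled)
```

<p>With a y-up coordinate convention the sweep visits the points counter-clockwise starting from the positive x-axis, regardless of the order in which they were written. On a tablet with a y-down axis the same code sweeps clockwise on screen.</p>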
<h3 id="models">Models</h3>
<ul>
<li><strong>SVM (Stage 1)</strong>: The RBF kernel was selected as optimal, with parameters $C=512$ and $\gamma=0.5$.</li>
<li><strong>HMM (Stage 2)</strong>: Left-right continuous HMMs trained via the Baum-Welch algorithm, with one model per symbol using <strong>8 states and 12 Gaussians</strong>.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p>Metrics reported are Top-1, Top-2, and Top-3 accuracy on the held-out test set. The table below lists the Top-1 and Top-3 figures.</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>NRS Accuracy</th>
          <th>ORS Accuracy</th>
          <th>Overall Test Accuracy</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Top-1</strong></td>
          <td>91.91%</td>
          <td>97.53%</td>
          <td>93.10%</td>
      </tr>
      <tr>
          <td><strong>Top-3</strong></td>
          <td>99.12%</td>
          <td>99.34%</td>
          <td>98.08%</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Device</strong>: HP Pavilion tx1000 Tablet PC.</li>
<li><strong>Processor</strong>: 2.00GHz CPU.</li>
</ul>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Zhang, Y., Shi, G., &amp; Wang, K. (2010). A SVM-HMM Based Online Classifier for Handwritten Chemical Symbols. <em>2010 International Conference on Pattern Recognition</em>, 1888&ndash;1891. <a href="https://doi.org/10.1109/ICPR.2010.465">https://doi.org/10.1109/ICPR.2010.465</a></p>
<p><strong>Publication</strong>: ICPR 2010</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{zhang2010svm,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{A SVM-HMM Based Online Classifier for Handwritten Chemical Symbols}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{2010 International Conference on Pattern Recognition}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Zhang, Yang and Shi, Guangshun and Wang, Kai}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2010}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{1888--1891}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{IEEE}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1109/ICPR.2010.465}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Recognition of On-line Handwritten Chemical Expressions</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/online-recognition/yang-online-handwritten-2008/</link><pubDate>Wed, 17 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/online-recognition/yang-online-handwritten-2008/</guid><description>A two-level recognition algorithm for on-line handwritten chemical expressions using structural and syntactic features.</description><content:encoded><![CDATA[<h2 id="contribution-on-line-chemical-expression-recognition-framework">Contribution: On-line Chemical Expression Recognition Framework</h2>
<p>This is a <strong>Method</strong> paper. It proposes a novel architectural pipeline (&ldquo;Algorithm Model&rdquo;) for recognizing on-line handwritten chemical expressions. The paper focuses on detailing the specific mechanisms of this pipeline (pre-processing, segmentation, two-level recognition, and HCI) and validates its effectiveness through quantitative comparison against a conventional baseline. The rhetorical structure aligns with the &ldquo;Methodological Basis&rdquo; of the taxonomy, prioritizing the &ldquo;how well does this work?&rdquo; question over theoretical derivation or dataset curation.</p>
<h2 id="motivation-the-hci-gap-in-chemical-drawing">Motivation: The HCI Gap in Chemical Drawing</h2>
<p>The authors identify a gap in existing human-computer interaction (HCI) for chemistry. While mathematical formula recognition had seen progress, chemical expression recognition was under-researched. Existing tools relied on keyboard/mouse input, which was time-consuming and inefficient for the complex, variable nature of chemical structures. Previous attempts were either too slow (vectorization-based) or failed to leverage specific chemical knowledge effectively. There was a practical need for a system that could handle the specific syntactic rules of chemistry in an on-line (real-time) handwriting setting.</p>
<h2 id="novelty-two-level-recognition-architecture">Novelty: Two-Level Recognition Architecture</h2>
<p>The core contribution is a <strong>two-level recognition algorithm</strong> that integrates chemical domain knowledge.</p>
<ul>
<li><strong>Level 1 (Substance Level):</strong> Treats connected strokes as a potential &ldquo;substance unit&rdquo; (e.g., &ldquo;H2O&rdquo;) and matches them against a dictionary using a modified edit distance algorithm.</li>
<li><strong>Level 2 (Character Level):</strong> If the substance match fails, it falls back to segmenting the unit into isolated characters and reconstructing them using syntactic rules.</li>
<li><strong>Hybrid Segmentation:</strong> Combines structural analysis (using bounding box geometry for super/subscript detection) with &ldquo;partial recognition&rdquo; (identifying special symbols like <code>+</code>, <code>=</code>, <code>-&gt;</code> early to split the expression).</li>
</ul>
<h2 id="methodology-custom-dataset-and-baseline-comparisons">Methodology: Custom Dataset and Baseline Comparisons</h2>
<p>The authors conducted a validation experiment in a laboratory environment with 20 participants (chemistry students and teachers).</p>
<ul>
<li><strong>Dataset:</strong> 1,197 total samples (983 from a standard set of 341 expressions, 214 arbitrary expressions written by users).</li>
<li><strong>Baselines:</strong> They compared their &ldquo;Two-Level&rdquo; algorithm against a &ldquo;Conventional&rdquo; algorithm that skips the substance-level check and directly recognizes characters (&ldquo;Recognize Character Directly&rdquo;).</li>
<li><strong>Conditions:</strong> They also tested the impact of their Human-Computer Interaction (HCI) module which allows user corrections.</li>
</ul>
<h2 id="results-high-accuracy-and-hci-corrections">Results: High Accuracy and HCI Corrections</h2>
<ul>
<li><strong>Accuracy:</strong> The proposed two-level algorithm achieved higher expression recognition accuracy (<strong>96.4%</strong>) than the conventional baseline (<strong>91.5%</strong>).</li>
<li><strong>Robustness:</strong> The method also outperformed the baseline on &ldquo;arbitrary&rdquo; expressions not in the standard set (92.5% accuracy vs 88.2%).</li>
<li><strong>HCI Impact:</strong> Allowing users to correct results via the HCI module raised final accuracy to <strong>98.8%</strong>.</li>
<li><strong>Conclusion:</strong> The authors concluded the algorithm is reliable for real applications and flexible enough to be extended to other domains like physics or engineering.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The paper does not use a public benchmark but collected its own data for validation.</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Purpose</th>
          <th style="text-align: left">Dataset</th>
          <th style="text-align: left">Size</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>Validation</strong></td>
          <td style="text-align: left">Custom Lab Dataset</td>
          <td style="text-align: left">1,197 samples</td>
          <td style="text-align: left">Collected from 20 chemistry students/teachers using Tablet PCs. Includes 341 standard expressions + arbitrary user inputs.</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<p>The pipeline consists of four distinct phases with specific algorithmic choices:</p>
<p><strong>1. Pre-processing</strong></p>
<ul>
<li><strong>Smoothing:</strong> Uses a 5-tap Gaussian low-pass filter (Eq. 1) with specific coefficients to smooth stroke data.</li>
<li><strong>Redundancy:</strong> Merges redundant points and removes &ldquo;prickles&rdquo; (isolated noise).</li>
<li><strong>Re-ordering:</strong> Strokes are spatially re-sorted left-to-right and top-to-bottom to correct for arbitrary writing order.</li>
</ul>
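<p>The smoothing step can be illustrated with a short sketch. The summary does not reproduce the coefficients from Eq. 1, so the binomial kernel <code>(1, 4, 6, 4, 1)</code> used here is an assumed stand-in for a 5-tap Gaussian-like filter, and clamping at the stroke ends is likewise an assumption.</p>

```python
def smooth_stroke(points, kernel=(1, 4, 6, 4, 1)):
    """Apply a 5-tap low-pass filter to the x and y coordinates of a stroke.

    Assumption: binomial weights approximate the Gaussian; the paper's
    actual coefficients are not given in this summary.
    """
    total = float(sum(kernel))
    half = len(kernel) // 2
    last = len(points) - 1

    smoothed = []
    for i in range(len(points)):
        sx = sy = 0.0
        for k, weight in enumerate(kernel):
            # Clamp neighbour indices so the ends reuse the end points.
            j = min(max(i + k - half, 0), last)
            sx += weight * points[j][0]
            sy += weight * points[j][1]
        smoothed.append((sx / total, sy / total))
    return smoothed


# A single spike on an otherwise flat stroke is strongly attenuated.
noisy = [(0, 0), (1, 0), (2, 5), (3, 0), (4, 0)]
cleaned = smooth_stroke(noisy)

# A straight line is left unchanged away from the ends.
line = [(i, 2 * i) for i in range(7)]
line_smoothed = smooth_stroke(line)
```

<p>Because the kernel is symmetric, interior points of a straight segment are preserved while isolated jitter is damped.</p>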
<p><strong>2. Segmentation</strong></p>
<ul>
<li><strong>Structural Analysis:</strong> Distinguishes relationships (Superscript vs. Subscript vs. Horizontal) using a geometric feature vector $(T, B)$ based on bounding box heights ($h$), vertical centers ($C$), and barycenters ($B_{bary}$):
$$
\begin{aligned}
d &amp;= 0.7 \cdot y_{12} - y_{22} + 0.3 \cdot y_{11} \\
T &amp;= 1000 \cdot \frac{d}{h_1} \\
B &amp;= 1000 \cdot \frac{B_{bary1} - B_{bary2}}{h_1}
\end{aligned}
$$</li>
<li><strong>Partial Recognition:</strong> Detects special symbols (<code>+</code>, <code>=</code>, <code>-&gt;</code>) early to break expressions into &ldquo;super-substance units&rdquo; (e.g., separating reactants from products).</li>
</ul>
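<p>The feature vector $(T, B)$ above is a direct arithmetic transformation, sketched below. The summary does not state which bounding-box edges $y_{11}$, $y_{12}$, and $y_{22}$ refer to, nor the decision thresholds applied to $(T, B)$, so the function simply takes these values as inputs and the parameter names are placeholders chosen for this example.</p>

```python
def relation_features(y11, y12, y22, h1, bary1, bary2):
    """Compute the (T, B) geometric features from the equations above.

    The meaning of y11, y12 and y22 (which box edge each denotes) is not
    specified in this summary; they are passed through unchanged.
    """
    d = 0.7 * y12 - y22 + 0.3 * y11
    t = 1000 * d / h1
    b = 1000 * (bary1 - bary2) / h1
    return t, b


# Hypothetical coordinates for two adjacent symbols.
t_value, b_value = relation_features(
    y11=0, y12=100, y22=60, h1=100, bary1=50, bary2=40
)
```

<p>Dividing by $h_1$ makes both features independent of writing size, which is presumably why the first symbol&rsquo;s height is used as the normalizer.</p>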
<p><strong>3. Recognition (Two-Level)</strong></p>
<ul>
<li><strong>Level 1 (Dictionary Match):</strong>
<ul>
<li>Uses a modified <strong>Edit Distance</strong> (Eq. 6) incorporating a specific distance matrix based on chemical syntax.</li>
<li>Similarity $\lambda_{ij}$ is weighted by stroke credibility $\mu_i$ and normalized by string length.</li>
</ul>
</li>
<li><strong>Level 2 (Character Segmentation):</strong>
<ul>
<li>Falls back to this if Level 1 fails.</li>
<li>Segments characters by analyzing pixel density in horizontal/vertical/diagonal directions to find concave/convex points.</li>
<li>Recombines characters using syntactic rules (e.g., valency checks) to verify validity.</li>
</ul>
</li>
</ul>
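<p>The two-level flow can be sketched with a weighted edit distance and a fallback. This is an illustration under stated assumptions rather than the paper&rsquo;s Eq. 6: the substitution costs in <code>CONFUSABLE</code>, the acceptance threshold <code>max_normalized</code>, and the function names are hypothetical, and the stroke credibility weighting $\mu_i$ is omitted.</p>

```python
def weighted_edit_distance(source, target, sub_cost):
    """Levenshtein distance with a per-pair substitution cost table."""
    rows, cols = len(source) + 1, len(target) + 1
    dist = [[0.0] * cols for _ in range(rows)]
    for i in range(1, rows):
        dist[i][0] = float(i)
    for j in range(1, cols):
        dist[0][j] = float(j)
    for i in range(1, rows):
        for j in range(1, cols):
            a, b = source[i - 1], target[j - 1]
            # Visually confusable pairs are cheaper to substitute.
            cost = 0.0 if a == b else sub_cost.get((a, b), 1.0)
            dist[i][j] = min(
                dist[i - 1][j] + 1.0,          # deletion
                dist[i][j - 1] + 1.0,          # insertion
                dist[i - 1][j - 1] + cost,     # substitution or match
            )
    return dist[rows - 1][cols - 1]


def match_substance(candidate, dictionary, sub_cost, max_normalized=0.3):
    """Level 1: return the closest dictionary entry, or None to fall back."""
    best = min(
        dictionary,
        key=lambda word: weighted_edit_distance(candidate, word, sub_cost),
    )
    distance = weighted_edit_distance(candidate, best, sub_cost)
    # Normalize by string length, as described above.
    score = distance / max(len(candidate), len(best))
    return best if score <= max_normalized else None


# Hypothetical costs for characters that look alike in handwriting.
CONFUSABLE = {("0", "O"): 0.2, ("O", "0"): 0.2, ("l", "1"): 0.2, ("1", "l"): 0.2}
DICTIONARY = ["H2O", "CO2", "NaCl"]

corrected = match_substance("H20", DICTIONARY, CONFUSABLE)
unmatched = match_substance("XyZ", DICTIONARY, CONFUSABLE)
```

<p>When <code>match_substance</code> returns <code>None</code>, a real system would hand the unit to the Level 2 character segmentation path.</p>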
<h3 id="evaluation">Evaluation</h3>
<p>Evaluation focused on recognition accuracy at both the character and expression level.</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Metric</th>
          <th style="text-align: left">Value (Proposed)</th>
          <th style="text-align: left">Value (Baseline)</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>Expression Accuracy (EA)</strong></td>
          <td style="text-align: left"><strong>96.4%</strong></td>
          <td style="text-align: left">91.5%</td>
          <td style="text-align: left">&ldquo;Standard&rdquo; dataset subset.</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Expression Accuracy (EA)</strong></td>
          <td style="text-align: left"><strong>92.5%</strong></td>
          <td style="text-align: left">88.2%</td>
          <td style="text-align: left">&ldquo;Other&rdquo; (arbitrary) dataset subset.</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>HCI-Assisted Accuracy</strong></td>
          <td style="text-align: left"><strong>98.8%</strong></td>
          <td style="text-align: left">N/A</td>
          <td style="text-align: left">Accuracy after user correction.</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Input Devices:</strong> Tablet PCs were used for data collection and testing.</li>
<li><strong>Compute:</strong> Specific training hardware is not listed, but the algorithm is designed for real-time interaction on standard 2008-era computing devices.</li>
</ul>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Yang, J., Shi, G., Wang, Q., &amp; Zhang, Y. (2008). Recognition of On-line Handwritten Chemical Expressions. <em>2008 IEEE International Joint Conference on Neural Networks</em>, 2360&ndash;2365. <a href="https://doi.org/10.1109/IJCNN.2008.4634125">https://doi.org/10.1109/IJCNN.2008.4634125</a></p>
<p><strong>Publication</strong>: IJCNN 2008</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{jufengyangRecognitionOnlineHandwritten2008,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Recognition of On-Line Handwritten Chemical Expressions}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{2008 {{IEEE International Joint Conference}} on {{Neural Networks}} ({{IEEE World Congress}} on {{Computational Intelligence}})}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{{Jufeng Yang} and {Guangshun Shi} and {Qingren Wang} and {Yong Zhang}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2008</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = jun,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{2360--2365}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{IEEE}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">address</span> = <span style="color:#e6db74">{Hong Kong, China}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1109/IJCNN.2008.4634125}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">urldate</span> = <span style="color:#e6db74">{2025-12-17}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">isbn</span> = <span style="color:#e6db74">{978-1-4244-1820-6}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Online Handwritten Chemical Formula Structure Analysis</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/online-recognition/wang-online-handwritten-2009/</link><pubDate>Wed, 17 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/online-recognition/wang-online-handwritten-2009/</guid><description>A hierarchical grammar-based approach for recognizing and analyzing online handwritten chemical formulas in mobile education contexts.</description><content:encoded><![CDATA[<h2 id="hierarchical-grammatical-framework-contribution">Hierarchical Grammatical Framework Contribution</h2>
<p>This is a <strong>Method</strong> paper. It proposes a novel architectural framework for processing chemical formulas by decomposing them into three hierarchical levels (Formula, Molecule, Text). The contribution is defined by a specific set of formal grammatical rules and parsing algorithms used to construct a &ldquo;grammar spanning tree&rdquo; and &ldquo;molecule spanning graph&rdquo; from online handwritten strokes.</p>
<h2 id="motivation-for-online-formula-recognition">Motivation for Online Formula Recognition</h2>
<p>The primary motivation is the application of mobile computing in chemistry education, where precise comprehension of casual, <em>online</em> handwritten formulas is a significant challenge.</p>
<ul>
<li><strong>2D Complexity</strong>: Unlike 1D text, chemical formulas utilize complex 2D spatial relationships that convey specific chemical meaning (e.g., bonds, rings).</li>
<li><strong>Format Limitations</strong>: Existing storage formats like CML (Chemical Markup Language) or MDL MOLFILE do not natively record the layout or abbreviated information necessary for recognizing handwritten input.</li>
<li><strong>Online Gap</strong>: Previous research focused heavily on <em>offline</em> (image-based) recognition, lacking solutions for <em>online</em> (stroke-based) handwritten chemical formulas (OHCF).</li>
</ul>
<h2 id="core-novelty-in-three-level-grammatical-analysis">Core Novelty in Three-Level Grammatical Analysis</h2>
<p>The core novelty is the <strong>Three-Level Grammatical Analysis</strong> approach:</p>
<ol>
<li><strong>Formula Level (1D)</strong>: Treats the reaction equation as a linear sequence of components (Reactants, Products, Separators), parsed via a context-free grammar to build a spanning tree.</li>
<li><strong>Molecule Level (2D)</strong>: Treats molecules as graphs where &ldquo;text groups&rdquo; are vertices and &ldquo;bonds&rdquo; are edges. It introduces specific handling for &ldquo;hidden Carbon dots&rdquo; (intersections of bonds without text).</li>
<li><strong>Text Level (1D)</strong>: Analyzes the internal structure of text groups (atoms, subscripts).</li>
</ol>
<p>Unique to this approach is the <strong>formal definition of the chemical grammar</strong> as a 5-tuple $G=(T,N,P,M,S)$ and the generation of an <strong>Adjacency Matrix</strong> directly from the handwritten sketch to represent chemical connectivity.</p>
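<p>To make the adjacency-matrix representation concrete, the sketch below builds one from a list of vertices (text groups or hidden carbon dots) and bonds. The vertex labels, the use of bond order as the matrix entry, and the function name are assumptions for illustration; the summary does not specify the paper&rsquo;s exact encoding.</p>

```python
def adjacency_matrix(vertices, bonds):
    """Build a symmetric adjacency matrix for a molecule graph.

    vertices: unique labels for text groups or hidden carbon dots.
    bonds: (vertex_a, vertex_b, order) triples; storing the bond order
    in the matrix is an assumed encoding.
    """
    index = {name: i for i, name in enumerate(vertices)}
    size = len(vertices)
    matrix = [[0] * size for _ in range(size)]
    for a, b, order in bonds:
        i, j = index[a], index[b]
        matrix[i][j] = order
        matrix[j][i] = order
    return matrix


# Two hidden carbon dots and one text group, joined by single bonds.
vertices = ["C1", "C2", "OH"]
bonds = [("C1", "C2", 1), ("C2", "OH", 1)]
matrix = adjacency_matrix(vertices, bonds)
```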
<h2 id="experimental-validation-on-handwritten-strokes">Experimental Validation on Handwritten Strokes</h2>
<p>The authors validated their model using a custom dataset of online handwritten formulas.</p>
<ul>
<li><strong>Data Source</strong>: 25 formulas were randomly selected from a larger pool of 1,250 samples.</li>
<li><strong>Scope</strong>: The test set included 484 total symbols, comprising generators, separators, text symbols, rings, and various bond types.</li>
<li><strong>Granular Validation</strong>: The system was tested at multiple distinct stages:
<ul>
<li>Key Symbol Extraction (Formula Level)</li>
<li>Text Localization (Molecule Level)</li>
<li>Bond End Grouping (Molecule Level)</li>
<li>Text Recognition (Text Level)</li>
</ul>
</li>
</ul>
<h2 id="downstream-impact-and-parsing-accuracy">Downstream Impact and Parsing Accuracy</h2>
<p>The system achieved high accuracy across all sub-tasks, demonstrating that the hierarchical grammar approach is effective for both inorganic and organic formulas.</p>
<ul>
<li><strong>Formula Level</strong>: 98.3% accuracy for Key Symbols; 100% for State-assisted symbols.</li>
<li><strong>Molecule Level</strong>: 98.8% accuracy for Bond End Grouping; 100% for Free End-Text connection detection.</li>
<li><strong>Text Recognition</strong>: 98.7% accuracy (Top-3) using HMMs.</li>
<li><strong>Impact</strong>: The method successfully preserves the writer&rsquo;s &ldquo;online information&rdquo; (habits/intentions) while converting the handwritten input into standard formats suitable for visual editing or data retrieval.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<p>To replicate this work, one would need to implement the specific grammatical production rules and the geometric thresholds defined for bond analysis.</p>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Training</strong></td>
          <td>Symbol HMMs</td>
          <td>5,670 samples</td>
          <td>Used to train the text recognition module</td>
      </tr>
      <tr>
          <td><strong>Testing</strong></td>
          <td>Text Recognition</td>
          <td>2,016 samples</td>
          <td>Test set for character HMMs</td>
      </tr>
      <tr>
          <td><strong>Testing</strong></td>
          <td>Formula Analysis</td>
          <td>25 formulas</td>
          <td>Random subset of 1,250 collected samples; contains 484 symbols</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<p><strong>1. Formula Level Parsing</strong></p>
<ul>
<li><strong>HBL Analysis</strong>: Identify the &ldquo;Horizontal Baseline&rdquo; (HBL) containing the most symbols to locate key operators (e.g., $+$, $\rightarrow$).</li>
<li><strong>Grammar</strong>: Use the productions defined in Figure 4. Example rules include:
<ul>
<li>$Reaction ::= ReactantList \ Generator \ ProductList$</li>
<li>$Reactant ::= BalancingNum \ Molecule \ IonicCharacter$</li>
</ul>
</li>
</ul>
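<p>The two example productions can be exercised with a very small parser over already-recognized tokens. This is a simplified sketch: the generator symbols in <code>GENERATORS</code>, the token format, and the omission of <code>IonicCharacter</code> handling are assumptions, and the paper&rsquo;s full grammar (Figure 4) contains more rules than are shown here.</p>

```python
GENERATORS = {"->", "="}  # assumed set of generator symbols
SEPARATOR = "+"


def parse_component(tokens):
    """Reactant ::= BalancingNum Molecule (ionic character omitted here)."""
    if not tokens:
        raise ValueError("empty component")
    if tokens[0].isdigit():
        if len(tokens) != 2:
            raise ValueError("expected a balancing number and a molecule")
        return int(tokens[0]), tokens[1]
    if len(tokens) != 1:
        raise ValueError("expected a single molecule")
    return 1, tokens[0]


def parse_component_list(tokens):
    """Split a token run on separators and parse each component."""
    components, current = [], []
    for token in tokens:
        if token == SEPARATOR:
            components.append(parse_component(current))
            current = []
        else:
            current.append(token)
    components.append(parse_component(current))
    return components


def parse_reaction(tokens):
    """Reaction ::= ReactantList Generator ProductList."""
    positions = [i for i, token in enumerate(tokens) if token in GENERATORS]
    if len(positions) != 1:
        raise ValueError("expected exactly one generator symbol")
    split = positions[0]
    return {
        "reactants": parse_component_list(tokens[:split]),
        "generator": tokens[split],
        "products": parse_component_list(tokens[split + 1:]),
    }


parsed = parse_reaction(["2", "H2", "+", "O2", "->", "2", "H2O"])
```

<p>Locating the generator first mirrors the HBL step, which finds key operators before the rest of the formula is analyzed.</p>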
<p><strong>2. Molecule Level Analysis (Bond Grouping)</strong></p>
<ul>
<li><strong>Endpoint Classification</strong>: Points are classified as <em>free ends</em>, <em>junctions</em> (3+ bonds), or <em>connections</em> (2 bonds).</li>
<li><strong>Grouping Equation</strong>: An endpoint $(x_k, y_k)$ belongs to Group A based on distance thresholding:
$$
\begin{aligned}
Include(x_k, y_k) = \begin{cases} 1, &amp; d_k &lt; t \cdot \max d_x + \partial \\ 0, &amp; \text{else} \end{cases}
\end{aligned}
$$
where $d_k$ is the Euclidean distance from the endpoint to the group center $(x_a, y_a)$.</li>
</ul>
<p><strong>3. Connection Detection</strong></p>
<ul>
<li><strong>Text-Bond Connection</strong>: A text group is connected to a bond if the free end falls within a bounding box expanded by thresholds $t_W$ and $t_H$:
$$
\begin{aligned}
Con(x,y) = \begin{cases} 1, &amp; \min x - t_W &lt; x &lt; \max x + t_W \text{ AND } \min y - t_H &lt; y &lt; \max y + t_H \\ 0, &amp; \text{else} \end{cases}
\end{aligned}
$$</li>
</ul>
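<p>The connection test above is an expanded bounding-box check, which can be written directly. The function name and the sample coordinates are hypothetical; the threshold values used in the paper are not given in this summary.</p>

```python
def is_connected(free_end, text_points, t_w, t_h):
    """Return True if a bond's free end lies inside the text group's
    bounding box expanded by t_w horizontally and t_h vertically."""
    x, y = free_end
    xs = [p[0] for p in text_points]
    ys = [p[1] for p in text_points]
    inside_x = min(xs) - t_w < x < max(xs) + t_w
    inside_y = min(ys) - t_h < y < max(ys) + t_h
    return inside_x and inside_y


# Text group occupying the box (10, 10)-(20, 18).
text_group = [(10, 10), (20, 18), (15, 12)]
near = is_connected((22, 15), text_group, t_w=3, t_h=3)
far = is_connected((30, 15), text_group, t_w=3, t_h=3)
```

<p>Using separate horizontal and vertical margins lets the tolerance follow the typical shape of handwritten text, which is usually wider than it is tall.</p>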
<h3 id="models">Models</h3>
<ul>
<li><strong>Text Recognition</strong>: Hidden Markov Models (HMM) are used for recognizing individual text symbols.</li>
<li><strong>Grammar</strong>: Context-Free Grammar (CFG) designed with ambiguity elimination to ensure a single valid parse tree for any valid formula.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p>Performance is measured by recognition accuracy at specific processing stages:</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Task</th>
          <th>Value</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Accuracy</td>
          <td>F1 (Key Symbol Extraction)</td>
          <td>98.3%</td>
          <td>Formula Level</td>
      </tr>
      <tr>
          <td>Accuracy</td>
          <td>F2 (State-assisted Symbol)</td>
          <td>100%</td>
          <td>Formula Level</td>
      </tr>
      <tr>
          <td>Accuracy</td>
          <td>M2 (Bond End Grouping)</td>
          <td>98.8%</td>
          <td>Molecule Level</td>
      </tr>
      <tr>
          <td>Accuracy</td>
          <td>M3 (Free End-Text Conn)</td>
          <td>100%</td>
          <td>Molecule Level</td>
      </tr>
      <tr>
          <td>Accuracy</td>
          <td>T1 (Text Recognition)</td>
          <td>98.7%</td>
          <td>Top-3 Accuracy</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Wang, X., Shi, G., &amp; Yang, J. (2009). The Understanding and Structure Analyzing for Online Handwritten Chemical Formulas. <em>2009 10th International Conference on Document Analysis and Recognition</em>, 1056&ndash;1060. <a href="https://doi.org/10.1109/ICDAR.2009.70">https://doi.org/10.1109/ICDAR.2009.70</a></p>
<p><strong>Publication</strong>: ICDAR 2009</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{wangUnderstandingStructureAnalyzing2009,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{The {{Understanding}} and {{Structure Analyzing}} for {{Online Handwritten Chemical Formulas}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{2009 10th {{International Conference}} on {{Document Analysis}} and {{Recognition}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Wang, Xin and Shi, Guangshun and Yang, Jufeng}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2009}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{1056--1060}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{IEEE}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">address</span> = <span style="color:#e6db74">{Barcelona, Spain}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1109/ICDAR.2009.70}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">isbn</span> = <span style="color:#e6db74">{978-1-4244-4500-4}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">langid</span> = <span style="color:#e6db74">{english}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>On-line Handwritten Chemical Expression Recognition</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/online-recognition/yang-icpr-2008/</link><pubDate>Wed, 17 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/online-recognition/yang-icpr-2008/</guid><description>Two-level algorithm for recognizing on-line handwritten chemical expressions using structural analysis, ANNs, and string edit distance.</description><content:encoded><![CDATA[<h2 id="a-methodological-approach-to-chemical-recognition">A Methodological Approach to Chemical Recognition</h2>
<p>This is a <strong>Method</strong> paper. It proposes a specific &ldquo;novel two-level algorithm&rdquo; and a &ldquo;System model&rdquo; for recognizing chemical expressions. The paper focuses on the architectural design of the recognition pipeline (segmentation, substance recognition, symbol recognition) and validates it against a &ldquo;conventional algorithm&rdquo; baseline, fitting the standard profile of a methodological contribution.</p>
<h2 id="bridging-the-gap-in-pen-based-chemical-input">Bridging the Gap in Pen-Based Chemical Input</h2>
<p>While pen-based computing has advanced for text and mathematical formulas, inputting chemical expressions remains &ldquo;time-consuming&rdquo;. Existing research often lacks &ldquo;adequate chemical knowledge&rdquo; or relies on algorithms that are too slow (global optimization) or structurally weak (local optimization). The authors aim to bridge this gap by integrating chemical domain knowledge into the recognition process to improve speed and accuracy.</p>
<h2 id="two-level-recognition-strategy-for-formulas">Two-Level Recognition Strategy for Formulas</h2>
<p>The core novelty is a <strong>two-level recognition strategy</strong>:</p>
<ol>
<li><strong>Level 1 (Substance Recognition)</strong>: Uses global structural information to identify entire &ldquo;substance units&rdquo; (e.g., $H_2SO_4$) by matching against a dictionary.</li>
<li><strong>Level 2 (Symbol Recognition)</strong>: If Level 1 fails, the system falls back to segmenting the substance into isolated characters and recognizing them individually.</li>
</ol>
<p>Additionally, the method integrates <strong>syntactic features</strong> (chemical knowledge) such as element conservation to validate and correct results and uses specific geometric features to distinguish superscript/subscript relationships.</p>
<h2 id="dataset-collection-and-baseline-comparisons">Dataset Collection and Baseline Comparisons</h2>
<ul>
<li><strong>Dataset Collection</strong>: The authors collected 1,197 handwritten expression samples from 20 chemistry professionals and students. The set comprises 983 &ldquo;standard&rdquo; expressions (from 341 templates) and 214 &ldquo;arbitrary&rdquo; expressions written freely.</li>
<li><strong>Comparison</strong>: They compared their &ldquo;Two-level recognition&rdquo; approach against a &ldquo;conventional algorithm&rdquo; that disables the first level and segments the input directly into characters.</li>
<li><strong>Metrics</strong>: They measured Material Accuracy (MA), Correct Expressions Number (AEN), and Expression Accuracy (EA).</li>
</ul>
<h2 id="high-accuracy-in-formula-recognition">High Accuracy in Formula Recognition</h2>
<ul>
<li><strong>High Accuracy</strong>: The proposed algorithm achieved <strong>96.4% Material Accuracy (MA)</strong> and <strong>95.7% Expression Accuracy (EA)</strong> on the total test set.</li>
<li><strong>Robustness</strong>: The method performed well on both standard (96.3% EA) and arbitrary (92.5% EA) expressions.</li>
<li><strong>Validation</strong>: The authors conclude the algorithm is &ldquo;reliable,&rdquo; &ldquo;flexible,&rdquo; and suitable for real-time applications compared to prior work.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The authors constructed two distinct datasets for training and evaluation:</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Purpose</th>
          <th style="text-align: left">Dataset</th>
          <th style="text-align: left">Size</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>Symbol Training</strong></td>
          <td style="text-align: left">ISF Files</td>
          <td style="text-align: left">12,240 files</td>
          <td style="text-align: left">Used to train the ANN classifier. Covers 102 symbol classes (numerals, letters, operators, organic loops).</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Expression Testing</strong></td>
          <td style="text-align: left">Handwritten Expressions</td>
          <td style="text-align: left">1,197 samples</td>
          <td style="text-align: left">983 standard + 214 arbitrary expressions collected from 20 chemistry teachers/students.</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<p><strong>1. Structural Segmentation (Superscript/Subscript)</strong></p>
<p>To distinguish relationships (superscript, subscript, in-line), the authors define geometric parameters based on the bounding boxes of adjacent symbols ($x_{i1}, y_{i1}, x_{i2}, y_{i2}$):</p>
<p>$$d = 0.7 \times y_{12} - y_{22} + 0.3 \times y_{11}$$
$$T = 1000 \times d/h$$
$$B = 1000 \times (B_1 - B_2)/h_1$$</p>
<p>Where $B_1, B_2$ are barycenters and $h$ is height. $(T, B)$ serves as the feature vector for classification.</p>
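<p>As a rough illustration, the three quantities above can be evaluated directly once the bounding-box values are known. The sketch below only transcribes the formulas; the function name <code>script_features</code> is hypothetical, and $h$ and $h_1$ are taken as separate inputs because the note does not say whether they coincide.</p>

```python
def script_features(y11, y12, y22, b1, b2, h, h1):
    """Return the (T, B) feature pair for two adjacent symbols.

    y11, y12, y22: bounding-box y values as named in the formulas.
    b1, b2: barycenters of the two symbols.
    h, h1: the two height terms, kept separate (assumption: they may differ).
    """
    d = 0.7 * y12 - y22 + 0.3 * y11
    t = 1000.0 * d / h
    b = 1000.0 * (b1 - b2) / h1
    return t, b


# Made-up coordinates, chosen only to exercise the arithmetic.
t, b = script_features(y11=0.0, y12=100.0, y22=60.0, b1=50.0, b2=40.0, h=100.0, h1=100.0)
```

<p>The resulting pair would then be passed to whatever classifier separates superscript, subscript, and in-line relationships.</p>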
<p><strong>2. Segmentation Reliability</strong></p>
<p>For segmenting strokes into units, the reliability of a segmentation path is calculated as:</p>
<p>$$Cof(K_{i},n)=\sum_{j=0}^{N}P(k_{j},k_{j+1})+P(S_{K_{i}})+\delta(N)$$</p>
<p>Where $P(k_j, k_{j+1})$ is the reliability of strokes being recognized as symbol $S_{k_j}$.</p>
<p><strong>3. Substance Matching (Level 1)</strong></p>
<p>A modified string edit distance is used to match handwritten input against a dictionary:</p>
<p>$$\lambda_{\overline{u}}=\mu_{i} \times f(Dis(i,j,r)/\sqrt{Max(Len_{i},Len_{j})})$$</p>
<p>Where $\mu_i$ is the recognizer credibility and $Dis(i,j,r)$ is the edit distance.</p>
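<p>A minimal sketch of dictionary matching in this spirit follows. It makes several assumptions: plain Levenshtein distance stands in for the paper&rsquo;s modified distance, the mapping $f$ is a hypothetical decreasing function $1/(1+x)$, and the function names and the tiny dictionary are invented for illustration.</p>

```python
import math


def levenshtein(a, b):
    """Plain Levenshtein distance (a stand-in for the paper's modified distance)."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            cost = 0 if ca == cb else 1
            curr.append(min(prev[j] + 1, curr[j - 1] + 1, prev[j - 1] + cost))
        prev = curr
    return prev[-1]


def match_score(candidate, entry, credibility, f=lambda x: 1.0 / (1.0 + x)):
    """credibility * f(distance / sqrt(max length)); f is a hypothetical choice."""
    longest = max(len(candidate), len(entry), 1)
    return credibility * f(levenshtein(candidate, entry) / math.sqrt(longest))


def best_match(candidate, dictionary, credibility=1.0):
    """Return the dictionary entry with the highest score."""
    return max(dictionary, key=lambda entry: match_score(candidate, entry, credibility))


# The recognizer confused the letter O with the digit 0.
dictionary = ["H2SO4", "H2O", "NaOH", "HCl"]
result = best_match("H2S04", dictionary)
```

<p>Dividing by the square root of the longer length penalizes a single error less heavily in long substances than in short ones, which is one plausible reason for the normalization.</p>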
<h3 id="models">Models</h3>
<ul>
<li><strong>Classifier</strong>: An ANN-based classifier is used for isolated symbol recognition.</li>
<li><strong>Input Features</strong>: A set of ~30 features is extracted from strokes, including writing time, interval time, elastic mesh, and stroke outline.</li>
<li><strong>Performance</strong>: The classifier achieved 92.1% accuracy on a test set of 2,702 isolated symbols.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p>The system was evaluated on the 1,197 expression samples.</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Metric</th>
          <th style="text-align: left">Value (Total)</th>
          <th style="text-align: left">Value (Standard)</th>
          <th style="text-align: left">Value (Arbitrary)</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>Material Accuracy (MA)</strong></td>
          <td style="text-align: left">96.4%</td>
          <td style="text-align: left">97.7%</td>
          <td style="text-align: left">94%</td>
          <td style="text-align: left">Accuracy of substance recognition.</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Expression Accuracy (EA)</strong></td>
          <td style="text-align: left">95.7%</td>
          <td style="text-align: left">96.3%</td>
          <td style="text-align: left">92.5%</td>
          <td style="text-align: left">Accuracy of full expression recognition.</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Yang, J., Shi, G., Wang, K., Geng, Q., &amp; Wang, Q. (2008). A Study of On-Line Handwritten Chemical Expressions Recognition. <em>2008 19th International Conference on Pattern Recognition</em>, 1&ndash;4. <a href="https://doi.org/10.1109/ICPR.2008.4761824">https://doi.org/10.1109/ICPR.2008.4761824</a></p>
<p><strong>Publication</strong>: ICPR 2008</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{yangStudyOnlineHandwritten2008,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{A Study of On-Line Handwritten Chemical Expressions Recognition}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{2008 19th {{International Conference}} on {{Pattern Recognition}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Yang, Jufeng and Shi, Guangshun and Wang, Kai and Geng, Qian and Wang, Qingren}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2008</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = dec,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{1--4}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{IEEE}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">address</span> = <span style="color:#e6db74">{Tampa, FL, USA}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1109/ICPR.2008.4761824}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>HMM-based Online Recognition of Chemical Symbols</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/online-recognition/zhang-hmm-handwriting-2009/</link><pubDate>Wed, 17 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/online-recognition/zhang-hmm-handwriting-2009/</guid><description>Online recognition of handwritten chemical symbols using Hidden Markov Models with 11-dimensional local features, achieving 89.5% top-1 accuracy.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is a <strong>Method</strong> paper that proposes a specific algorithmic pipeline for the online recognition of handwritten chemical symbols. The core contribution is the engineering of an 11-dimensional feature vector combined with a Hidden Markov Model (HMM) architecture. The paper validates this method through quantitative experiments on a custom dataset, focusing on recognition accuracy as the primary metric.</p>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>Recognizing chemical symbols is challenging because chemical expressions have a complex structure and pen-based input often produces broken or conglutinated (merged) strokes. Variations in writing style and random noise add further difficulty. While online recognition for Western characters and CJK scripts is well-developed, work specifically targeting online chemical symbol recognition is scarce, with most prior research focusing on offline recognition or global optimization.</p>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The primary novelty is the application of continuous HMMs specifically to the domain of <strong>online</strong> chemical symbol recognition, utilizing a specialized set of <strong>11-dimensional local features</strong>. While HMMs have been used for other scripts, this paper tailors the feature extraction (including curliness, linearity, and writing direction) to capture the specific geometric properties of chemical symbols.</p>
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The authors constructed a specific dataset for this task involving 20 participants (college teachers and students).</p>
<ul>
<li><strong>Dataset</strong>: 64 distinct symbols (digits, English letters, Greek letters, operators)</li>
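<li><strong>Volume</strong>: 7,808 total samples (122 per symbol), with a reported split of 5,670 training samples and 2,016 testing samples. The reported split sums to 7,686 rather than 7,808; 90 + 32 samples per symbol over 64 symbols would give 5,760 and 2,048.</li>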
<li><strong>Model Sweeps</strong>: They evaluated the HMM performance by varying the number of states (4, 6, 8) and the number of Gaussians per state (3, 4, 6, 9, 12)</li>
</ul>
<h2 id="what-were-the-outcomes-and-conclusions-drawn">What were the outcomes and conclusions drawn?</h2>
<ul>
<li><strong>Performance</strong>: The best configuration (6 states, 9 Gaussians) achieved a <strong>top-1 accuracy of 89.5%</strong> and a <strong>top-3 accuracy of 98.7%</strong></li>
<li><strong>Scaling</strong>: Accuracy generally improved as the number of states and Gaussians increased, at the cost of computational efficiency</li>
<li><strong>Error Analysis</strong>: The primary sources of error were shape similarities between specific characters (e.g., &lsquo;0&rsquo; vs &lsquo;O&rsquo; vs &lsquo;o&rsquo;, and &lsquo;C&rsquo; vs &lsquo;c&rsquo; vs &lsquo;(&rsquo;)</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<p><strong>Status:</strong> Closed / Very Low Reproducibility. This 2009 study relies on a private, custom-collected dataset and does not provide source code, model weights, or an open-access preprint.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Artifact</th>
          <th style="text-align: left">Type</th>
          <th style="text-align: left">License</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><em>None publicly available</em></td>
          <td style="text-align: left">N/A</td>
          <td style="text-align: left">N/A</td>
          <td style="text-align: left">No open source code, open datasets, or open-access preprints were released with this publication.</td>
      </tr>
  </tbody>
</table>
<h3 id="data">Data</h3>
<p>The study utilized a custom dataset collected in a laboratory environment.</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Purpose</th>
          <th style="text-align: left">Dataset</th>
          <th style="text-align: left">Size</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>Training</strong></td>
          <td style="text-align: left">Custom Chemical Symbol Set</td>
          <td style="text-align: left">5,670 samples</td>
          <td style="text-align: left">90 samples per symbol</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Testing</strong></td>
          <td style="text-align: left">Custom Chemical Symbol Set</td>
          <td style="text-align: left">2,016 samples</td>
          <td style="text-align: left">32 samples per symbol</td>
      </tr>
  </tbody>
</table>
<p><strong>Dataset Composition</strong>: The set includes <strong>64 symbols</strong>: Digits (0-9), Uppercase (A-Z, missing Q), Lowercase (a-z, selected), Greek letters ($\alpha$, $\beta$, $\gamma$, $\pi$), and operators ($+$, $=$, $\rightarrow$, $\uparrow$, $\downarrow$, $($ , $)$).</p>
<h3 id="algorithms">Algorithms</h3>
<p><strong>1. Preprocessing</strong></p>
<p>The raw tablet data undergoes a 6-step pipeline:</p>
<ol>
<li><strong>Duplicate Point Elimination</strong>: Removing sequential points with identical coordinates</li>
<li><strong>Broken Stroke Connection</strong>: Using Bezier curves to interpolate missing points/connect broken strokes</li>
<li><strong>Hook Elimination</strong>: Removing artifacts at the start/end of strokes characterized by short length and sharp angle changes</li>
<li><strong>Smoothing</strong>: Reducing noise from erratic pen movement</li>
<li><strong>Re-sampling</strong>: Spacing points equidistantly to remove temporal variation</li>
<li><strong>Size Normalization</strong>: Removing variation in writing scale</li>
</ol>
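<p>Three of the six steps (duplicate elimination, re-sampling, and size normalization) are simple enough to sketch in a few lines. The code below is an illustrative reading, not the authors&rsquo; implementation: the function names are hypothetical, re-sampling uses linear interpolation along arc length, and size normalization scales into the unit box while preserving aspect ratio, which the note does not specify.</p>

```python
import math


def drop_duplicates(points):
    """Remove consecutive points with identical coordinates."""
    out = []
    for p in points:
        if not out or out[-1] != p:
            out.append(p)
    return out


def resample(points, n):
    """Return n points spaced equally along the arc length of the polyline."""
    seg = [math.dist(points[i], points[i + 1]) for i in range(len(points) - 1)]
    total = sum(seg)
    if total == 0 or n == 1:
        return [points[0]] * n
    out, i, walked = [], 0, 0.0
    for k in range(n):
        target = total * k / (n - 1)
        # Advance to the segment that contains the target arc length.
        while len(seg) - 1 > i and target > walked + seg[i]:
            walked += seg[i]
            i += 1
        frac = min((target - walked) / seg[i], 1.0) if seg[i] > 0 else 0.0
        (x0, y0), (x1, y1) = points[i], points[i + 1]
        out.append((x0 + frac * (x1 - x0), y0 + frac * (y1 - y0)))
    return out


def normalize_size(points):
    """Scale into the unit box, preserving aspect ratio (an assumption)."""
    xs = [p[0] for p in points]
    ys = [p[1] for p in points]
    scale = max(max(xs) - min(xs), max(ys) - min(ys)) or 1.0
    return [((x - min(xs)) / scale, (y - min(ys)) / scale) for x, y in points]


raw = [(0, 0), (0, 0), (3, 0), (3, 4)]
clean = drop_duplicates(raw)
resampled = resample(clean, 8)
normalized = normalize_size(resampled)
```

<p>Equidistant re-sampling is what removes writing-speed information: after it, the point sequence depends only on the shape of the trajectory.</p>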
<p><strong>2. Feature Extraction (11 Dimensions)</strong></p>
<p>Features are extracted from a 5-point window centered on $t$ ($t-2$ to $t+2$). The 11 dimensions are:</p>
<ol>
<li><strong>Normalized Vertical Position</strong>: $y(t)$ mapped to $[0,1]$</li>
<li><strong>Normalized First Derivative ($x'$)</strong>: Calculated via weighted sum of neighbors</li>
<li><strong>Normalized First Derivative ($y'$)</strong>: Calculated via weighted sum of neighbors</li>
<li><strong>Normalized Second Derivative ($x''$)</strong>: Computed using $x'$ values</li>
<li><strong>Normalized Second Derivative ($y''$)</strong>: Computed using $y'$ values</li>
<li><strong>Curvature</strong>: $\frac{x'y'' - x''y'}{(x'^2 + y'^2)^{3/2}}$</li>
<li><strong>Writing Direction (Cos)</strong>: $\cos \alpha(t)$ based on vector from $t-1$ to $t+1$</li>
<li><strong>Writing Direction (Sin)</strong>: $\sin \alpha(t)$</li>
<li><strong>Aspect Ratio</strong>: Ratio of height to width in the 5-point window</li>
<li><strong>Curliness</strong>: Deviation from the straight line connecting the first and last point of the window</li>
<li><strong>Linearity</strong>: Average squared distance of points in the window to the straight line connecting start/end points</li>
</ol>
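<p>A few of these features can be sketched as follows. The note only says the derivatives are a &ldquo;weighted sum of neighbors&rdquo;, so the regression-style weights used here (neighbors at distance 1 and 2, divided by 10) are an assumption, as are the function names. The normalization of the derivatives is omitted.</p>

```python
import math


def d1(seq, t):
    """First derivative at t as a weighted sum of neighbors t-2..t+2 (assumed weights)."""
    return sum(i * (seq[t + i] - seq[t - i]) for i in (1, 2)) / 10.0


def d2(seq, t):
    """Second derivative at t, obtained by applying the same rule to first derivatives."""
    return sum(i * (d1(seq, t + i) - d1(seq, t - i)) for i in (1, 2)) / 10.0


def curvature(xs, ys, t):
    """Signed curvature from first and second derivatives of x and y."""
    dx, dy = d1(xs, t), d1(ys, t)
    ddx, ddy = d2(xs, t), d2(ys, t)
    return (dx * ddy - ddx * dy) / (dx * dx + dy * dy) ** 1.5


def writing_direction(xs, ys, t):
    """Cosine and sine of the direction of the vector from point t-1 to point t+1."""
    dx, dy = xs[t + 1] - xs[t - 1], ys[t + 1] - ys[t - 1]
    norm = math.hypot(dx, dy)
    return dx / norm, dy / norm


# Points on a counterclockwise circle of radius 2; curvature should be close to 1/2.
xs = [2.0 * math.cos(0.1 * k) for k in range(20)]
ys = [2.0 * math.sin(0.1 * k) for k in range(20)]
kappa = curvature(xs, ys, 10)
cos_a, sin_a = writing_direction(xs, ys, 10)
```

<p>The circle is a convenient sanity check because its curvature is the reciprocal of the radius everywhere.</p>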
<p><strong>3. Feature Normalization</strong></p>
<p>The feature vectors $v_t$ that make up the final feature matrix $V$ are normalized to zero mean and unit variance (whitened) using the mean $\mu$ and the covariance matrix $\Sigma$: $o_t = \Sigma^{-1/2}(v_t - \mu)$.</p>
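<p>One way to read this step, assuming $\mu$ and $\Sigma$ are the mean and covariance of the vectors $v_t$ and that $\Sigma^{-1/2}$ is the symmetric inverse square root (the note does not spell this out), is as a whitening transform. Under those assumptions the transformed vectors have zero mean and identity covariance:</p>
<p>$$\mathbb{E}[o_t] = \Sigma^{-1/2}\left(\mathbb{E}[v_t] - \mu\right) = 0$$
$$\operatorname{Cov}(o_t) = \Sigma^{-1/2}\,\Sigma\,\Sigma^{-1/2} = I$$</p>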
<h3 id="models">Models</h3>
<ul>
<li><strong>Architecture</strong>: Continuous Hidden Markov Models (HMM)</li>
<li><strong>Topology</strong>: Left-to-right (Bakis model)</li>
<li><strong>Initialization</strong>: Initial distribution $\pi = \{1, 0, \ldots, 0\}$; uniform transition matrix $A$; segmental k-means for observation matrix $B$</li>
<li><strong>Training</strong>: Baum-Welch re-estimation</li>
<li><strong>Decision</strong>: Maximum likelihood classification over the per-symbol models ($\hat{\lambda} = \arg\max_{\lambda} P(O \mid \lambda)$)</li>
</ul>
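<p>The initialization and decision rule can be sketched as below. This is an interpretation under assumptions: &ldquo;uniform&rdquo; is read as uniform over the transitions a left-to-right topology allows, <code>max_jump</code> is a hypothetical parameter, and the log-likelihood values are illustrative placeholders, not results from the paper.</p>

```python
def left_to_right_init(n_states, max_jump=1):
    """Initial parameters for a left-to-right HMM.

    pi puts all mass on the first state. Each row of A is uniform over the
    allowed transitions: stay, or move forward by at most max_jump states.
    """
    pi = [1.0] + [0.0] * (n_states - 1)
    A = []
    for i in range(n_states):
        last = min(i + max_jump, n_states - 1)
        allowed = last - i + 1
        A.append([1.0 / allowed if last >= j >= i else 0.0 for j in range(n_states)])
    return pi, A


def classify(log_likelihoods):
    """Pick the symbol whose model gives the highest log-likelihood of the observation."""
    return max(log_likelihoods, key=log_likelihoods.get)


pi, A = left_to_right_init(6)
# Placeholder scores for three visually similar symbols.
label = classify({"O": -41.2, "0": -40.7, "o": -43.9})
```

<p>Entries of the transition matrix that start at zero stay at zero under Baum-Welch re-estimation, so the left-to-right structure is preserved during training.</p>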
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Metric</th>
          <th style="text-align: left">Best Value</th>
          <th style="text-align: left">Configuration</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>Top-1 Accuracy</strong></td>
          <td style="text-align: left"><strong>89.5%</strong></td>
          <td style="text-align: left">6 States, 9 Gaussians</td>
          <td style="text-align: left">Highest reported accuracy</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Top-3 Accuracy</strong></td>
          <td style="text-align: left"><strong>98.7%</strong></td>
          <td style="text-align: left">6 States, 9 Gaussians</td>
          <td style="text-align: left">Top-3 candidate accuracy</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Zhang, Y., Shi, G., &amp; Yang, J. (2009). HMM-Based Online Recognition of Handwritten Chemical Symbols. <em>2009 10th International Conference on Document Analysis and Recognition</em>, 1255&ndash;1259. <a href="https://doi.org/10.1109/ICDAR.2009.99">https://doi.org/10.1109/ICDAR.2009.99</a></p>
<p><strong>Publication</strong>: ICDAR 2009</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{zhang2009hmm,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{HMM-Based Online Recognition of Handwritten Chemical Symbols}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{2009 10th International Conference on Document Analysis and Recognition}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Zhang, Yang and Shi, Guangshun and Yang, Jufeng}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2009}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{75}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{1255--1259}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{IEEE}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1109/ICDAR.2009.99}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Handwritten Chemical Symbol Recognition Using SVMs</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/online-recognition/tang-online-symbol-2013/</link><pubDate>Wed, 17 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/online-recognition/tang-online-symbol-2013/</guid><description>A hybrid SVM and elastic matching approach for recognizing handwritten chemical symbols drawn on touch devices, achieving 89.7% top-1 accuracy.</description><content:encoded><![CDATA[<h2 id="paper-contribution-and-taxonomy">Paper Contribution and Taxonomy</h2>
<p>This is a <strong>Method</strong> paper according to the <a href="/notes/interdisciplinary/research-methods/ai-physical-sciences-paper-taxonomy/">AI for Physical Sciences taxonomy</a>.</p>
<ul>
<li><strong>Dominant Basis</strong>: The authors propose a novel hybrid architecture (SVM-EM) that combines two existing techniques to solve a specific recognition problem.</li>
<li><strong>Rhetorical Indicators</strong>: The paper explicitly defines algorithms (Algorithm 1 &amp; 2), presents a system architecture, and validates the method via ablation studies comparing the hybrid approach against its individual components.</li>
</ul>
<h2 id="motivation-for-pen-based-input">Motivation for Pen-Based Input</h2>
<p>Entering chemical expressions on digital devices is difficult due to their complex 2D spatial structure.</p>
<ul>
<li><strong>The Problem</strong>: While handwriting recognition for text and math is mature, chemical structures involve unique symbols and spatial arrangements that existing tools struggle to process efficiently.</li>
<li><strong>Existing Solutions</strong>: Standard tools (like ChemDraw) rely on point-and-click interactions, which are described as complicated and non-intuitive compared to direct handwriting.</li>
<li><strong>Goal</strong>: To enable fluid handwriting input on pen/touch-based devices (like iPads) by accurately recognizing individual chemical symbols in real time.</li>
</ul>
<h2 id="novelty-hybrid-svm-and-elastic-matching">Novelty: Hybrid SVM and Elastic Matching</h2>
<p>The core contribution is the <strong>Hybrid SVM-EM</strong> approach, which splits recognition into a coarse classification stage and a fine-grained verification stage.</p>
<ul>
<li><strong>Two-Stage Pipeline</strong>:
<ol>
<li><strong>SVM Recognition</strong>: Uses statistical features (stroke count, turning angles) to generate a short-list of candidate symbols.</li>
<li><strong>Elastic Matching (EM)</strong>: Uses a geometric point-to-point distance metric to re-rank these candidates against a library of stored symbol prototypes.</li>
</ol>
</li>
<li><strong>Online Stroke Partitioning</strong>: A heuristic-based method to group strokes into symbols in real-time based on time adjacency (grouping the last $n$ strokes) and spatial intersection checks, without waiting for the user to finish the entire drawing.</li>
</ul>
<h2 id="experimental-design-and-data-collection">Experimental Design and Data Collection</h2>
<p>The authors conducted a user study to collect data and evaluate the system:</p>
<ul>
<li><strong>Participants</strong>: 10 users were recruited to write chemical symbols on an iPad.</li>
<li><strong>Task</strong>: Each user wrote each of 78 distinct chemical symbols (digits, letters, bonds) 3 times.</li>
<li><strong>Baselines</strong>: The hybrid method was compared against two baselines:
<ol>
<li>SVM only</li>
<li>Elastic Matching only</li>
</ol>
</li>
<li><strong>Metrics</strong>: Evaluation focused on <strong>Precision@k</strong> (where $k=1, 3, 5$), measuring how often the correct symbol appeared in the top-$k$ suggestions.</li>
</ul>
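<p>For reference, Precision@k as described here reduces to a top-$k$ hit rate. The sketch below uses a hypothetical function name and a made-up toy set of ranked suggestions, not data from the study.</p>

```python
def precision_at_k(ranked_predictions, truths, k):
    """Fraction of samples whose true label appears among the top-k suggestions."""
    hits = sum(1 for ranked, truth in zip(ranked_predictions, truths) if truth in ranked[:k])
    return hits / len(truths)


# Toy example: three samples, each with three ranked suggestions.
ranked = [["C", "c", "("], ["o", "O", "0"], ["N", "H", "M"]]
truths = ["C", "0", "K"]
p1 = precision_at_k(ranked, truths, 1)
p3 = precision_at_k(ranked, truths, 3)
```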
<h2 id="recognition-performance-and-outcomes">Recognition Performance and Outcomes</h2>
<p>The hybrid approach outperformed either technique used in isolation.</p>
<ul>
<li><strong>Key Results</strong>:
<ul>
<li><strong>Hybrid SVM-EM</strong>: 89.7% Precision@1 (Top-1 accuracy).</li>
<li><strong>SVM Only</strong>: 85.1% Precision@1.</li>
<li><strong>EM Only</strong>: 76.7% Precision@1.</li>
</ul>
</li>
<li><strong>Category Performance</strong>: The system performed best on Operators (91.9%) and Digits (91.3%), with slightly lower performance on Alphabetic characters (88.6%).</li>
<li><strong>Impact</strong>: The system was implemented as a real-time iOS application in which users draw chemical structures that are then converted to SMILES strings such as <code>C#CC(O)</code>.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The study generated a custom dataset for training and evaluation.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset Stats</th>
          <th>Details</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Evaluation</strong></td>
          <td>2,340 samples</td>
          <td>Collected from 10 users. Consists of <strong>78 unique symbols</strong>: 10 digits (0-9), 52 letters (A-Z, a-z), and 16 bonds/operators (e.g., $=$, $+$, hash bonds).</td>
      </tr>
      <tr>
          <td><strong>Training</strong></td>
          <td>Unspecified size</td>
          <td>A &ldquo;Chemical Elastic Symbol Library&rdquo; was created containing samples of all supported symbols to serve as prototypes for the Elastic Matching step.</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<p>The pipeline consists of four distinct algorithmic steps:</p>
<p><strong>1. Stroke Partitioning</strong></p>
<ul>
<li><strong>Logic</strong>: Groups the most recently written stroke with up to 4 of the strokes written immediately before it.</li>
<li><strong>Filtering</strong>: Invalid groups are removed using &ldquo;Spatial Distance Checking&rdquo; (strokes too far apart) and &ldquo;Stroke Intersection Checking&rdquo; (strokes that don&rsquo;t intersect where expected).</li>
</ul>
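<p>The grouping logic can be sketched as a simple enumeration of candidate groups. The function name is hypothetical, and the spatial-distance and intersection filters are left out because the note does not give their thresholds.</p>

```python
def candidate_groups(strokes, max_previous=4):
    """Group the newest stroke with 0..max_previous of the strokes written just before it."""
    groups = []
    for k in range(min(max_previous, len(strokes) - 1) + 1):
        groups.append(strokes[len(strokes) - 1 - k:])
    return groups


# Three strokes so far; "s3" is the one just written.
groups = candidate_groups(["s1", "s2", "s3"])
```

<p>Each surviving group would then be passed on to preprocessing and recognition as a possible symbol.</p>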
<p><strong>2. Preprocessing</strong></p>
<ul>
<li><strong>Size Normalization</strong>: Scales symbol to a standard size based on its bounding box.</li>
<li><strong>Smoothing</strong>: Uses average smoothing (replacing mid-points with the average of neighbors) to remove jitter.</li>
<li><strong>Sampling</strong>: Resamples valid strokes to a fixed number of <strong>50 points</strong>.</li>
</ul>
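<p>The smoothing step might look like the sketch below. It assumes that &ldquo;average of neighbors&rdquo; means the mean of the two adjacent points and that endpoints are kept unchanged; the function name is hypothetical.</p>

```python
def average_smooth(points):
    """Replace each interior point with the mean of its two neighbors; endpoints stay."""
    if 3 > len(points):
        return list(points)
    out = [points[0]]
    # Pair each point with the one two positions later: these are the
    # neighbors of the interior point that sits between them.
    for prev, nxt in zip(points, points[2:]):
        out.append(((prev[0] + nxt[0]) / 2.0, (prev[1] + nxt[1]) / 2.0))
    out.append(points[-1])
    return out


# A single jittery point in the middle of a flat stroke is pulled back to the line.
smoothed = average_smooth([(0, 0), (1, 2), (2, 0)])
```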
<p><strong>3. SVM Feature Extraction</strong></p>
<ul>
<li><strong>Horizontal Angle</strong>: Calculated between two consecutive points ($P_1, P_2$). Values are binned into 12 groups ($30^{\circ}$ each).</li>
<li><strong>Turning Angle</strong>: The difference between two consecutive horizontal angles. Values are binned into 18 groups ($10^{\circ}$ each).</li>
<li><strong>Features</strong>: Input vector consists of stroke count, normalized coordinates, and the percentage of angles falling into the histograms described above.</li>
</ul>
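<p>The two angle histograms can be sketched as follows. The bin counts (12 and 18) come from the description above; folding the turning angle into the range 0 to 180 degrees and the function name are assumptions.</p>

```python
import math


def angle_histograms(points):
    """Return (horizontal, turning) histograms as fractions of all angles."""
    horiz = []
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        horiz.append(math.degrees(math.atan2(y1 - y0, x1 - x0)) % 360.0)
    turning = []
    for a, b in zip(horiz, horiz[1:]):
        diff = abs(b - a) % 360.0
        turning.append(min(diff, 360.0 - diff))  # fold into [0, 180]
    h_hist = [0.0] * 12  # 12 bins of 30 degrees
    for a in horiz:
        h_hist[min(int(a // 30), 11)] += 1.0 / len(horiz)
    t_hist = [0.0] * 18  # 18 bins of 10 degrees
    for a in turning:
        t_hist[min(int(a // 10), 17)] += 1.0 / len(turning)
    return h_hist, t_hist


# An L-shaped stroke: two steps right, then two steps up.
h_hist, t_hist = angle_histograms([(0, 0), (1, 0), (2, 0), (2, 1), (2, 2)])
```

<p>Using fractions rather than raw counts keeps the histogram part of the feature vector independent of how many points a stroke contains.</p>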
<p><strong>4. Elastic Matching (Verification)</strong></p>
<ul>
<li><strong>Distance Function</strong>: Euclidean distance summation between the points of the candidate symbol ($s$) and the partitioned input ($s_p$).
$$
\begin{aligned}
D(s, s_p) = \sum_{j=1}^{n} \sqrt{(x_{s,j} - x_{p,j})^2 + (y_{s,j} - y_{p,j})^2}
\end{aligned}
$$
<em>Note: The paper formula sums the distances; $n$ is the number of points (50).</em></li>
<li><strong>Ranking</strong>: Candidates are re-ranked in ascending order of this elastic distance.</li>
</ul>
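<p>The distance and re-ranking step can be sketched as below. The function names and the toy three-point prototypes are illustrative, and storing a single prototype per symbol is an assumption; the paper works with 50 points per symbol.</p>

```python
import math


def elastic_distance(symbol, sample):
    """Sum of point-to-point Euclidean distances between two equal-length sequences."""
    if len(symbol) != len(sample):
        raise ValueError("both sequences must be resampled to the same length")
    return sum(math.dist(p, q) for p, q in zip(symbol, sample))


def rerank(candidates, library, sample):
    """Order the SVM candidates by ascending distance to their stored prototype."""
    return sorted(candidates, key=lambda name: elastic_distance(library[name], sample))


# Toy prototypes: a horizontal bond and a slanted bond.
library = {"-": [(0, 0), (1, 0), (2, 0)], "/": [(0, 0), (1, 1), (2, 2)]}
sample = [(0, 0), (1, 0.2), (2, 0.1)]
ranking = rerank(["/", "-"], library, sample)
```

<p>Because the sum pairs points by index, both sequences need the same number of points, which is what the fixed 50-point resampling provides.</p>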
<h3 id="models">Models</h3>
<ul>
<li><strong>Classifier</strong>: Linear Support Vector Machine (SVM) implemented using <strong>LibSVM</strong>.</li>
<li><strong>Symbol Library</strong>: A &ldquo;Chemical Elastic Symbol Library&rdquo; stores the raw stroke point sequences for all 78 supported symbols to enable the elastic matching comparison.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p>Performance was measured using precision at different ranks (Top-N accuracy).</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Value</th>
          <th>Baseline</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Precision@1</strong></td>
          <td><strong>89.7%</strong></td>
          <td>85.1% (SVM)</td>
          <td>Hybrid model lowers top-1 error to 10.3%, from 14.9% (SVM only) and 23.3% (EM only).</td>
      </tr>
      <tr>
          <td><strong>Precision@3</strong></td>
          <td><strong>94.1%</strong></td>
          <td>N/A</td>
          <td>High recall in top 3 allows users to quickly correct errors via UI selection.</td>
      </tr>
      <tr>
          <td><strong>Precision@5</strong></td>
          <td><strong>94.6%</strong></td>
          <td>N/A</td>
          <td></td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Device</strong>: Apple iPad (iOS platform).</li>
<li><strong>Input</strong>: Touch/Pen-based input recording digital ink (x, y coordinates and pen-up/down events).</li>
</ul>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Tang, P., Hui, S. C., &amp; Fu, C. W. (2013). Online Chemical Symbol Recognition for Handwritten Chemical Expression Recognition. <em>2013 IEEE/ACIS 12th International Conference on Computer and Information Science (ICIS)</em>, 535&ndash;540. <a href="https://doi.org/10.1109/ICIS.2013.6607894">https://doi.org/10.1109/ICIS.2013.6607894</a></p>
<p><strong>Publication</strong>: IEEE ICIS 2013</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{tangOnlineChemicalSymbol2013,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Online Chemical Symbol Recognition for Handwritten Chemical Expression Recognition}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{2013 IEEE/ACIS 12th International Conference on Computer and Information Science (ICIS)}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Tang, Peng and Hui, Siu Cheung and Fu, Chi-Wing}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2013</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{22}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{535--540}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{IEEE}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1109/ICIS.2013.6607894}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Handwritten Chemical Ring Recognition with Neural Networks</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/hand-drawn/hewahi-ring-recognition-2008/</link><pubDate>Wed, 17 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/hand-drawn/hewahi-ring-recognition-2008/</guid><description>A two-phase Classifier-Recognizer neural network pipeline for recognizing 23 types of handwritten heterocyclic chemical rings, achieving ~94% accuracy.</description><content:encoded><![CDATA[<h2 id="contribution-recognition-architecture-for-heterocyclic-rings">Contribution: Recognition Architecture for Heterocyclic Rings</h2>
<p>This is a <strong>Method</strong> paper ($\Psi_{\text{Method}}$).</p>
<p>It proposes a specific algorithmic architecture (the &ldquo;Classifier-Recognizer Approach&rdquo;) to solve a pattern recognition problem. The rhetorical structure centers on defining three variations of a method, performing ablation-like comparisons between them (Whole Image vs. Lower Part), and demonstrating superior performance metrics (~94% accuracy) for the proposed technique.</p>
<h2 id="motivation-enabling-sketch-based-chemical-search">Motivation: Enabling Sketch-Based Chemical Search</h2>
<p>The authors identify a gap in existing OCR and handwriting recognition research, which typically focuses on alphanumeric characters or whole words.</p>
<ul>
<li><strong>Missing Capability</strong>: Recognition of specific <em>heterocyclic chemical rings</em> (23 types) had not been performed previously.</li>
<li><strong>Practical Utility</strong>: Existing chemical search engines require text-based queries (names); this work enables &ldquo;backward&rdquo; search where a user can draw a ring to find its information.</li>
<li><strong>Educational/Professional Aid</strong>: Useful for chemistry departments and mobile applications where chemists can sketch formulas on screens.</li>
</ul>
<h2 id="innovation-the-classifier-recognizer-pipeline">Innovation: The Classifier-Recognizer Pipeline</h2>
<p>The core novelty is the <strong>two-phase &ldquo;Classifier-Recognizer&rdquo; architecture</strong> designed to handle the visual similarity of heterocyclic rings:</p>
<ol>
<li><strong>Phase 1 (Classifier)</strong>: A neural network classifies the ring into one of four broad categories (S, N, O, Others) based solely on the <em>upper part</em> of the image (40x15 pixels).</li>
<li><strong>Phase 2 (Recognizer)</strong>: A class-specific neural network identifies the exact ring.</li>
<li><strong>Optimization</strong>: The most successful variation (&ldquo;Lower Part Image Recognizer with Half Size Grid&rdquo;) uses only the <em>lower part</em> of the image and <em>odd rows</em> (half-grid) to reduce input dimensionality and computation time while improving accuracy. This effectively subsamples the input grid matrix $M \in \mathbb{R}^{H \times W}$ to a reduced matrix $M_{\text{sub}}$:
$$ M_{\text{sub}} = \{ m_{i,j} \in M \mid i \text{ is odd} \} $$</li>
</ol>
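<p>The half-grid subsampling can be sketched in a few lines. This is a minimal illustration, not the authors&rsquo; code: it assumes rows are numbered from 1 (so &ldquo;odd rows&rdquo; are Python indices 0, 2, 4, ...) and uses a hypothetical 40x40 binary grid.</p>

```python
def odd_rows(grid):
    """Keep rows 1, 3, 5, ... under 1-based numbering (indices 0, 2, 4, ...)."""
    return grid[0::2]

# Hypothetical 40x40 binary grid: 1 = ink, 0 = background.
grid = [[(r + c) % 2 for c in range(40)] for r in range(40)]

sub = odd_rows(grid)
n_inputs = len(sub) * len(sub[0])
print(len(sub), len(sub[0]), n_inputs)  # 20 40 800
```

<p>Halving the row count halves the number of network inputs (1600 to 800 for the whole image) while keeping every column.</p>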
<h2 id="failed-preliminary-approaches">Failed Preliminary Approaches</h2>
<p>Before arriving at the Classifier-Recognizer architecture, the authors tried three simpler methods that all failed:</p>
<ol>
<li><strong>Ordinary NN</strong>: A single neural network with 1600 inputs (40x40 grid), 1600 hidden units, and 23 outputs. This standard approach achieved only 7% accuracy.</li>
<li><strong>Row/Column pixel counts</strong>: Using the number of black pixels per row and per column as features ($N_c + N_r$ inputs), which dramatically reduced dimensionality. This performed even worse, below 1% accuracy.</li>
<li><strong>Midline crossing count</strong>: Drawing a horizontal midline and counting the number of line crossings. This failed because the crossing count varies between writers for the same ring.</li>
</ol>
<p>These failures motivated the two-phase Classifier-Recognizer design.</p>
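<p>The two hand-crafted features among these attempts are easy to sketch. The function names and the toy grid below are hypothetical, and the crossing count assumes that one run of ink along the middle row equals one crossing.</p>

```python
def row_col_counts(grid):
    """Black-pixel counts per row followed by counts per column (N_r + N_c values)."""
    rows = [sum(row) for row in grid]
    cols = [sum(col) for col in zip(*grid)]
    return rows + cols

def midline_crossings(grid):
    """Count runs of ink along the middle row; each run is treated as one crossing."""
    mid = grid[len(grid) // 2]
    return sum(1 for prev, cur in zip([0] + mid, mid) if cur == 1 and prev == 0)

# Toy 6x6 square outline standing in for a scanned ring.
toy = [
    [0, 0, 0, 0, 0, 0],
    [0, 1, 1, 1, 1, 0],
    [0, 1, 0, 0, 1, 0],
    [0, 1, 0, 0, 1, 0],
    [0, 1, 1, 1, 1, 0],
    [0, 0, 0, 0, 0, 0],
]

features = row_col_counts(toy)
crossings = midline_crossings(toy)
print(features, crossings)
```

<p>On a 40x40 grid the first feature yields 80 inputs instead of 1600, but many different shapes share the same row and column totals, which is one plausible reason for its poor accuracy.</p>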
<h2 id="experimental-setup-and-network-variations">Experimental Setup and Network Variations</h2>
<p>The authors conducted a comparative study of three methodological variations:</p>
<ol>
<li><strong>Whole Image Recognizer</strong>: Uses the full image.</li>
<li><strong>Whole Image (Half Size Grid)</strong>: Uses only odd rows ($20 \times 40$ pixels).</li>
<li><strong>Lower Part (Half Size Grid)</strong>: Uses the lower part of the image with odd rows (the proposed method).</li>
</ol>
<p><strong>Setup</strong>:</p>
<ul>
<li><strong>Dataset</strong>: 23 types of heterocyclic rings.</li>
<li><strong>Training</strong>: 1500 samples (distributed across S, N, O, and Others classes).</li>
<li><strong>Testing</strong>: 1150 samples.</li>
<li><strong>Metric</strong>: Recognition accuracy (Performance %) and Error %.</li>
</ul>
<h2 id="results-high-accuracy-via-dimension-reduction">Results: High Accuracy via Dimension Reduction</h2>
<ul>
<li><strong>Superior Method</strong>: The &ldquo;Lower Part Image Recognizer with Half Size Grid&rdquo; achieved the best performance (~94% overall).</li>
<li><strong>High Classifier Accuracy</strong>: The first phase (classification into S/N/O/Other) achieves 100% accuracy for class S, 98.67% for O, 97.75% for N, and 97.67% for Others (Table 3).</li>
<li><strong>Class &lsquo;Others&rsquo; Difficulty</strong>: The &lsquo;Others&rsquo; class showed lower performance (~90-93%) compared to S/N/O due to the higher complexity and similarity of rings in that category.</li>
<li><strong>Efficiency</strong>: The half-grid approach reduced training time from ~53 hours (Whole Image) to ~35 hours (Lower Part Half Size Grid) while improving accuracy from 87% to 94%.</li>
</ul>
<p><strong>Training/Testing comparison across the three Classifier-Recognizer variations (Table 2)</strong>:</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Method</th>
          <th style="text-align: left">Hidden Nodes</th>
          <th style="text-align: left">Iterations</th>
          <th style="text-align: left">Training Time (hrs)</th>
          <th style="text-align: left">Error</th>
          <th style="text-align: left">Performance</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left">Whole Image</td>
          <td style="text-align: left">50</td>
          <td style="text-align: left">1000</td>
          <td style="text-align: left">~53</td>
          <td style="text-align: left">13.0%</td>
          <td style="text-align: left">87.0%</td>
      </tr>
      <tr>
          <td style="text-align: left">Whole Image (Half Grid)</td>
          <td style="text-align: left">50</td>
          <td style="text-align: left">1000</td>
          <td style="text-align: left">~41</td>
          <td style="text-align: left">9.0%</td>
          <td style="text-align: left">91.0%</td>
      </tr>
      <tr>
          <td style="text-align: left">Lower Part (Half Grid)</td>
          <td style="text-align: left">50</td>
          <td style="text-align: left">1000</td>
          <td style="text-align: left">~35</td>
          <td style="text-align: left">6.0%</td>
          <td style="text-align: left">94.0%</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The dataset consists of handwritten samples of 23 specific heterocyclic rings.</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Purpose</th>
          <th style="text-align: left">Dataset</th>
          <th style="text-align: left">Size</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>Training</strong></td>
          <td style="text-align: left">Heterocyclic Rings</td>
          <td style="text-align: left">1500 samples</td>
          <td style="text-align: left">Split: 300 (S), 400 (N), 400 (O), 400 (Others)</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Testing</strong></td>
          <td style="text-align: left">Heterocyclic Rings</td>
          <td style="text-align: left">1150 samples</td>
          <td style="text-align: left">Split: 150 (S), 300 (O), 400 (N), 300 (Others)</td>
      </tr>
  </tbody>
</table>
<p><strong>Preprocessing Steps</strong>:</p>
<ol>
<li><strong>Monochrome Conversion</strong>: Convert image to monochrome bitmap.</li>
<li><strong>Grid Scaling</strong>: Convert drawing area (regardless of original size) to a fixed <strong>40x40</strong> grid.</li>
<li><strong>Bounding</strong>: Scale the ring shape itself to fit the 40x40 grid.</li>
</ol>
<h3 id="algorithms">Algorithms</h3>
<p><strong>The &ldquo;Lower Part with Half Size&rdquo; Pipeline</strong>:</p>
<ol>
<li><strong>Cut Point</strong>: A horizontal midline is defined; the algorithm separates the &ldquo;Upper Part&rdquo; and &ldquo;Lower Part&rdquo;.</li>
<li><strong>Phase 1 Input</strong>: The <strong>Upper Part</strong> (approximately rows 0-15, scaled) is fed to the Classifier NN to determine the class (S, N, O, or Others).</li>
<li><strong>Phase 2 Input</strong>:
<ul>
<li>For classes <strong>S, N, O</strong>: The <strong>Lower Part</strong> of the image is used.</li>
<li>For class <strong>Others</strong>: The <strong>Whole Ring</strong> is used.</li>
</ul>
</li>
<li><strong>Dimensionality Reduction</strong>: For the recognizer networks, only <strong>odd rows</strong> are used (effectively a 20x40 input grid) to reduce inputs from 1600 to 800.</li>
</ol>
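<p>The routing logic of the steps above can be sketched as follows. The cut row of 15, the stub networks, and all function names are assumptions for illustration; the paper&rsquo;s exact cropping and rescaling may differ.</p>

```python
def split_at_midline(grid, cut):
    """Return (upper, lower) parts of the grid around row index `cut`."""
    return grid[:cut], grid[cut:]

def recognize(grid, classifier, recognizers, cut=15):
    """Two-phase routing: classify on the upper part, then recognize on a half-size grid."""
    upper, lower = split_at_midline(grid, cut)
    ring_class = classifier(upper)                       # Phase 1: S, N, O, or Others
    region = grid if ring_class == "Others" else lower   # Others uses the whole ring
    half = region[0::2]                                  # odd rows under 1-based numbering
    return ring_class, recognizers[ring_class](half)     # Phase 2: class-specific network

# Hypothetical stand-ins for the trained networks; they only report the input shape.
def make_stub(name):
    return lambda half: (name, len(half), len(half[0]))

recognizers = {name: make_stub(name) for name in ("S", "N", "O", "Others")}
grid = [[0] * 40 for _ in range(40)]

result_n = recognize(grid, lambda upper: "N", recognizers)
result_others = recognize(grid, lambda upper: "Others", recognizers)
print(result_n, result_others)
```

<p>Keeping one recognizer per class means each network only has to separate rings that share the same heteroatom category.</p>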
<h3 id="models">Models</h3>
<p>The system uses several distinct feed-forward neural networks. The paper does not name the training algorithm; backpropagation is implied by its references to &ldquo;training&rdquo; and &ldquo;epochs&rdquo;.</p>
<ul>
<li><strong>Structure</strong>: 1 Classifier NN + 4 Recognizer NNs (one for each class).</li>
<li><strong>Hidden Layers</strong>: The preliminary &ldquo;ordinary method&rdquo; experiment used 1600 hidden units, although the paper notes that other hidden-layer sizes were also tried for that approach. All three Classifier-Recognizer variations used 50 hidden nodes (Table 2).</li>
<li><strong>Input Nodes</strong>:
<ul>
<li>Standard: 1600 (40x40).</li>
<li>Optimized: ~800 (20x40 via half-grid).</li>
</ul>
</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p><strong>Classifier Phase Testing Results (Table 3)</strong>:</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Class</th>
          <th style="text-align: left">Samples</th>
          <th style="text-align: left">Correct</th>
          <th style="text-align: left">Accuracy</th>
          <th style="text-align: left">Error</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>S</strong></td>
          <td style="text-align: left">150</td>
          <td style="text-align: left">150</td>
          <td style="text-align: left"><strong>100.00%</strong></td>
          <td style="text-align: left">0.00%</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>O</strong></td>
          <td style="text-align: left">300</td>
          <td style="text-align: left">296</td>
          <td style="text-align: left"><strong>98.67%</strong></td>
          <td style="text-align: left">1.33%</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>N</strong></td>
          <td style="text-align: left">400</td>
          <td style="text-align: left">391</td>
          <td style="text-align: left"><strong>97.75%</strong></td>
          <td style="text-align: left">2.25%</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Others</strong></td>
          <td style="text-align: left">300</td>
          <td style="text-align: left">293</td>
          <td style="text-align: left"><strong>97.67%</strong></td>
          <td style="text-align: left">2.33%</td>
      </tr>
  </tbody>
</table>
<p><strong>Recognizer Phase Testing Results (Lower Part Image Recognizer with Half Size Grid, Table 4)</strong>:</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Class</th>
          <th style="text-align: left">Samples</th>
          <th style="text-align: left">Correct</th>
          <th style="text-align: left">Accuracy</th>
          <th style="text-align: left">Error</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>S</strong></td>
          <td style="text-align: left">150</td>
          <td style="text-align: left">147</td>
          <td style="text-align: left"><strong>98.00%</strong></td>
          <td style="text-align: left">2.00%</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>O</strong></td>
          <td style="text-align: left">300</td>
          <td style="text-align: left">289</td>
          <td style="text-align: left"><strong>96.33%</strong></td>
          <td style="text-align: left">3.67%</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>N</strong></td>
          <td style="text-align: left">400</td>
          <td style="text-align: left">386</td>
          <td style="text-align: left"><strong>96.50%</strong></td>
          <td style="text-align: left">3.50%</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Others</strong></td>
          <td style="text-align: left">300</td>
          <td style="text-align: left">279</td>
          <td style="text-align: left"><strong>93.00%</strong></td>
          <td style="text-align: left">7.00%</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Overall</strong></td>
          <td style="text-align: left"><strong>1150</strong></td>
          <td style="text-align: left"><strong>-</strong></td>
          <td style="text-align: left"><strong>~94.0%</strong></td>
          <td style="text-align: left"><strong>-</strong></td>
      </tr>
  </tbody>
</table>
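<p>The per-class counts in Tables 3 and 4 can be pooled to check how the phases combine. The notes do not say how the overall ~94% figure was derived. The sketch below shows only one possible reading, which assumes the Table 4 accuracies are conditional on a correct Phase 1 decision and that the two phases err independently.</p>

```python
# (correct, total) per class, copied from Tables 3 and 4 above.
classifier = {"S": (150, 150), "O": (296, 300), "N": (391, 400), "Others": (293, 300)}
recognizer = {"S": (147, 150), "O": (289, 300), "N": (386, 400), "Others": (279, 300)}

def pooled(table):
    """Accuracy pooled over all test samples in a table."""
    correct = sum(c for c, _ in table.values())
    total = sum(n for _, n in table.values())
    return correct / total

p1 = pooled(classifier)   # 1130 / 1150
p2 = pooled(recognizer)   # 1101 / 1150
end_to_end = p1 * p2      # assumed: independent errors across the two phases
print(f"{p1:.4f} {p2:.4f} {end_to_end:.4f}")  # 0.9826 0.9574 0.9407
```

<p>Under these assumptions the product lands close to the reported overall figure, whereas the recognizer table alone pools to roughly 95.7%.</p>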
<h3 id="reproducibility-assessment">Reproducibility Assessment</h3>
<p>No source code, trained models, or datasets were released with this paper. The handwritten ring samples were collected by the authors, and the software described (a desktop application) is not publicly available. The neural network architecture details (50 hidden nodes, 1000 iterations) and preprocessing pipeline are described in sufficient detail for reimplementation, but reproducing results would require collecting a new handwritten dataset of heterocyclic rings.</p>
<p><strong>Status</strong>: Closed (no public code, data, or models).</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Hewahi, N., Nounou, M. N., Nassar, M. S., Abu-Hamad, M. I., &amp; Abu-Hamad, H. I. (2008). Chemical Ring Handwritten Recognition Based on Neural Networks. <em>Ubiquitous Computing and Communication Journal</em>, 3(3).</p>
<p><strong>Publication</strong>: Ubiquitous Computing and Communication Journal 2008</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{hewahiCHEMICALRINGHANDWRITTEN2008,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{CHEMICAL RING HANDWRITTEN RECOGNITION BASED ON NEURAL NETWORKS}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Hewahi, Nabil and Nounou, Mohamed N and Nassar, Mohamed S and Abu-Hamad, Mohamed I and Abu-Hamad, Husam I}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2008}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Ubiquitous Computing and Communication Journal}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{3}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{3}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Probabilistic OCSR with Markov Logic Networks</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/rule-based/mlocsr/</link><pubDate>Tue, 16 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/rule-based/mlocsr/</guid><description>A probabilistic approach using Markov Logic Networks to recognize chemical structures from images, improving robustness over rule-based systems.</description><content:encoded><![CDATA[<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Frasconi, P., Gabbrielli, F., Lippi, M., &amp; Marinai, S. (2014). Markov Logic Networks for Optical Chemical Structure Recognition. <em>Journal of Chemical Information and Modeling</em>, 54(8), 2380-2390. <a href="https://doi.org/10.1021/ci5002197">https://doi.org/10.1021/ci5002197</a></p>
<p><strong>Publication</strong>: Journal of Chemical Information and Modeling 2014</p>
<h2 id="contribution-probabilistic-method-for-ocsr">Contribution: Probabilistic Method for OCSR</h2>
<p>This is a <strong>Method</strong> paper ($\Psi_{\text{Method}}$).</p>
<p>It proposes a novel algorithmic architecture (<strong>MLOCSR</strong>) that integrates low-level pattern recognition with a high-level probabilistic reasoning engine based on Markov Logic Networks (MLNs). While it contributes to resources by creating a clustered dataset for evaluation, the primary focus is on demonstrating that probabilistic inference offers a superior methodology to the deterministic, rule-based heuristics employed by previous state-of-the-art systems like OSRA and CLiDE.</p>
<h2 id="motivation-overcoming-brittle-rule-based-systems">Motivation: Overcoming Brittle Rule-Based Systems</h2>
<p>Optical Chemical Structure Recognition (OCSR) is critical for converting the vast archive of chemical literature (bitmap images in patents and papers) into machine-readable formats.</p>
<ul>
<li><strong>Limitation of Prior Work</strong>: Existing systems (OSRA, CLiDE, ChemReader) rely on &ldquo;empirical hard-coded geometrical rules&rdquo; to assemble atoms and bonds. These heuristics are brittle, requiring manual tuning of parameters for different image resolutions and failing when images are degraded or noisy.</li>
<li><strong>Gap</strong>: Chemical knowledge is typically used only in post-processing (e.g., to fix valency errors).</li>
<li><strong>Goal</strong>: To create a resolution-independent system that uses probabilistic reasoning to handle noise and ambiguity in graphical primitives.</li>
</ul>
<h2 id="core-innovation-markov-logic-networks-for-diagram-interpretation">Core Innovation: Markov Logic Networks for Diagram Interpretation</h2>
<p>The core novelty is the application of <strong>Markov Logic Networks (MLNs)</strong> to the problem of diagram interpretation.</p>
<ul>
<li><strong>Probabilistic Reasoning</strong>: The system treats extracted visual elements (lines, text boxes) as &ldquo;evidence&rdquo; and uses weighted first-order logic formulas to infer the most likely molecular graph (Maximum A Posteriori inference). The probability of a state $x$ is defined by the MLN log-linear model:
$$ P(X=x) = \frac{1}{Z} \exp\left(\sum_{i} w_i n_i(x)\right) $$
where $w_i$ is the weight of the $i$-th formula and $n_i(x)$ is the number of true groundings in $x$.</li>
<li><strong>Unified Knowledge Representation</strong>: Geometric constraints (e.g., collinearity) and chemical rules (e.g., valency) are encoded in the same logic framework.</li>
<li><strong>Resolution Independence</strong>: The low-level extraction module dynamically estimates character size ($T$) and stroke width ($S$) to normalize parameters, removing the dependence on image DPI metadata.</li>
</ul>
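<p>A toy instance makes the log-linear model concrete. The sketch below enumerates every world over two ground atoms and applies a single weighted formula. The atom names echo the rule style used by the system, but the weight of 2.0 and the two-atom universe are hypothetical.</p>

```python
import math
from itertools import product

# Hypothetical weighted formulas: (weight, function giving the number of true groundings).
formulas = [
    # VeryClose => SameCarbon, with an illustrative weight of 2.0
    (2.0, lambda w: int((not w["VeryClose"]) or w["SameCarbon"])),
]

def score(world):
    """Unnormalized probability exp(sum_i w_i * n_i(x))."""
    return math.exp(sum(wt * n(world) for wt, n in formulas))

atoms = ["VeryClose", "SameCarbon"]
worlds = [dict(zip(atoms, vals)) for vals in product([False, True], repeat=len(atoms))]

Z = sum(score(w) for w in worlds)                           # partition function
prob = {tuple(w.values()): score(w) / Z for w in worlds}    # keys: (VeryClose, SameCarbon)

# Condition on the evidence VeryClose = True.
p_same = prob[(True, True)] / (prob[(True, True)] + prob[(True, False)])
print(round(p_same, 4))
```

<p>The only world that violates the formula, (True, False), receives the lowest probability, and raising the weight pushes the conditional probability of <code>SameCarbon</code> toward 1.</p>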
<h2 id="what-experiments-were-performed">Methodology and Experimental Setup</h2>
<p>The authors evaluated the system on recognition accuracy against the leading open-source baseline, <strong>OSRA (v1.4.0)</strong>.</p>
<ul>
<li><strong>Datasets</strong>:
<ul>
<li><strong>USPTO Clustered</strong>: A non-redundant subset of 937 images derived from a larger set of 5,719 US Patent Office images.</li>
<li><strong>ChemInfty</strong>: 869 images from Japanese patents.</li>
<li><strong>Degraded Images</strong>: The USPTO set was synthetically degraded at three resampling levels (Low, Medium, High degradation) to test robustness.</li>
</ul>
</li>
<li><strong>Metrics</strong>:
<ul>
<li><strong>Geometric</strong>: Precision, Recall, and $F_1$ scores for individual atoms and bonds.</li>
<li><strong>Chemical</strong>: Tanimoto similarity (using path fingerprints) and InChI string matching (basic and full stereochemistry).</li>
</ul>
</li>
</ul>
<h2 id="results-and-conclusions">Results and Conclusions</h2>
<ul>
<li><strong>Superior Robustness</strong>: MLOCSR outperformed OSRA on degraded images. On high-degradation images, MLOCSR achieved an atom $F_1$ of 80.3% compared to OSRA&rsquo;s 76.0%.</li>
<li><strong>Geometric Accuracy</strong>: On the clean USPTO Clustered set, MLOCSR achieved higher $F_1$ scores for atoms (99.1% vs 97.5%) and bonds (98.8% vs 97.8%).</li>
<li><strong>Chemical Fidelity</strong>: The system achieved comparable Tanimoto similarity scores (0.948 vs 0.940 for OSRA).</li>
<li><strong>Limitation</strong>: OSRA slightly outperformed MLOCSR on &ldquo;Full <a href="/notes/chemistry/molecular-representations/notations/inchi-2013/">InChI</a>&rdquo; matching (81.4% vs 79.4%), indicating the probabilistic model still needs improvement in handling complex stereochemistry.</li>
</ul>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The study utilized public datasets, with specific preprocessing to ensure non-redundancy.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Evaluation</td>
          <td><strong>USPTO Clustered</strong></td>
          <td>937 images</td>
          <td>Selected via spectral clustering from 5,719 raw images to remove near-duplicates.</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td><strong>ChemInfty</strong></td>
          <td>869 images</td>
          <td>Ground-truthed dataset from Japanese patent applications (2008).</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<p>The pipeline consists of two distinct phases: Low-Level Vectorization and High-Level Inference.</p>
<p><strong>1. Low-Level Extraction (Image Processing)</strong></p>
<ul>
<li><strong>Binarization</strong>: Global thresholding followed by morphological closing.</li>
<li><strong>Text/Stroke Estimation</strong>:
<ul>
<li>Finds text height ($T$) by looking for &ldquo;N&rdquo; or &ldquo;H&rdquo; characters via OCR, or averaging compatible connected components.</li>
<li>Estimates stroke width ($S$) by inspecting pixel density on potential segments identified by Hough transform.</li>
</ul>
</li>
<li><strong>Vectorization</strong>:
<ul>
<li><strong>Canny Edge Detection</strong> + <strong>Hough Transform</strong> to find lines.</li>
<li><strong>Douglas-Peucker algorithm</strong> for polygonal approximation of contours.</li>
<li><strong>Circle Detection</strong>: Finds aromatic rings by checking for circular arrangements of carbon candidates.</li>
</ul>
</li>
</ul>
<p><strong>2. High-Level Inference (Markov Logic)</strong></p>
<ul>
<li><strong>Evidence Generation</strong>: Visual primitives (lines, text boxes, circles) are converted into logical ground atoms (e.g., <code>LineBetweenCpoints(c1, c2)</code>).</li>
<li><strong>Inference Engine</strong>: Uses <strong>MaxWalkSAT</strong> for Maximum A Posteriori (MAP) inference to determine the most probable state of query predicates (e.g., <code>DoubleBond(c1, c2)</code>).</li>
<li><strong>Parameters</strong>: MaxWalkSAT run with 3 tries and 1,000,000 steps per try.</li>
</ul>
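<p>MAP inference in an MLN amounts to finding the assignment that minimizes the total weight of violated ground clauses. The following is a simplified MaxWalkSAT-style local search, not the Alchemy implementation. The clause encoding, the weights, the 0.5 random-walk probability, and the reduced step count are all assumptions.</p>

```python
import random

def unsat_weight(clauses, assign):
    """Total weight of violated clauses; a literal is (variable, wanted_value)."""
    return sum(w for w, lits in clauses
               if not any(assign[v] == val for v, val in lits))

def max_walk_sat(clauses, variables, tries=3, steps=1000, p_random=0.5, seed=0):
    rng = random.Random(seed)
    best, best_cost = None, float("inf")

    def remember(assign):
        nonlocal best, best_cost
        c = unsat_weight(clauses, assign)
        if best_cost > c:
            best, best_cost = dict(assign), c

    for _ in range(tries):
        assign = {v: rng.random() > 0.5 for v in variables}  # random restart
        remember(assign)
        for _ in range(steps):
            violated = [lits for _w, lits in clauses
                        if not any(assign[v] == val for v, val in lits)]
            if not violated:
                break
            lits = rng.choice(violated)
            if p_random > rng.random():
                var = rng.choice(lits)[0]            # random-walk move
            else:
                def cost_if_flipped(v):
                    trial = dict(assign)
                    trial[v] = not trial[v]
                    return unsat_weight(clauses, trial)
                var = min((v for v, _val in lits), key=cost_if_flipped)  # greedy move
            assign[var] = not assign[var]
            remember(assign)
    return best, best_cost

# Hypothetical ground clauses: strong evidence, one rule, and a weak prior.
clauses = [
    (10.0, [("VeryClose", True)]),                          # evidence
    (2.0, [("VeryClose", False), ("SameCarbon", True)]),    # VeryClose => SameCarbon
    (0.5, [("SameCarbon", False)]),                         # weak prior against merging
]
best, best_cost = max_walk_sat(clauses, ["VeryClose", "SameCarbon"])
print(best, best_cost)
```

<p>The source reports 3 tries with 1,000,000 steps each; the smaller default here is only to keep the toy example fast.</p>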
<h3 id="models">Models</h3>
<ul>
<li><strong>Markov Logic Network (MLN)</strong>:
<ul>
<li>Contains <strong>128 first-order logic formulas</strong>.</li>
<li><strong>Geometric Rules</strong>: Example: <code>VeryCloseCpoints(c1, c2) =&gt; SameCarbon(c1, c2)</code> (weighted rule to merge close nodes).</li>
<li><strong>Chemical Rules</strong>: Example: <code>IsHydroxyl(t) ^ Connected(c,t) =&gt; SingleBond(c,t)</code> (imposes valency constraints).</li>
</ul>
</li>
<li><strong>OCR Engine</strong>: Tesseract is used for character recognition on text connected components.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p>The authors introduced a bipartite graph matching method to evaluate geometric accuracy when superatoms (e.g., &ldquo;COOH&rdquo;) are not expanded.</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Details</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Atom/Bond $F_1$</strong></td>
          <td>Calculated via minimum-weight bipartite matching between predicted graph and ground truth, weighted by Euclidean distance.</td>
      </tr>
      <tr>
          <td><strong>InChI</strong></td>
          <td>Standard unique identifier string. &ldquo;Basic&rdquo; ignores stereochemistry; &ldquo;Full&rdquo; includes it.</td>
      </tr>
      <tr>
          <td><strong>Tanimoto</strong></td>
          <td>Jaccard index of path fingerprints between predicted and ground truth molecules.</td>
      </tr>
  </tbody>
</table>
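<p>Two of these metrics reduce to short formulas. The sketch below assumes fingerprints are represented as plain sets of path strings (real path fingerprints are typically hashed bit vectors) and that the number of matched atoms or bonds has already been obtained from the matching step. The example values are hypothetical.</p>

```python
def tanimoto(fp_a, fp_b):
    """Jaccard index of two fingerprints given as collections of features."""
    a, b = set(fp_a), set(fp_b)
    union = a.union(b)
    return len(a.intersection(b)) / len(union) if union else 1.0

def f1(n_matched, n_predicted, n_truth):
    """F1 from the number of matched items and the sizes of both graphs."""
    precision = n_matched / n_predicted if n_predicted else 0.0
    recall = n_matched / n_truth if n_truth else 0.0
    total = precision + recall
    return 2 * precision * recall / total if total else 0.0

# Hypothetical fingerprints and counts.
similarity = tanimoto({"C-C", "C-O", "C=O"}, {"C-C", "C-O", "C-N", "C=O"})
atom_f1 = f1(n_matched=9, n_predicted=10, n_truth=12)
print(similarity, round(atom_f1, 4))
```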
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Software</strong>: Logic inference performed using the <strong>Alchemy</strong> software package (University of Washington).</li>
<li><strong>Web Server</strong>: The system was made available at <code>http://mlocsr.dinfo.unifi.it</code> (Note: URL likely inactive).</li>
</ul>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{frasconiMarkovLogicNetworks2014,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Markov {{Logic Networks}} for {{Optical Chemical Structure Recognition}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Frasconi, Paolo and Gabbrielli, Francesco and Lippi, Marco and Marinai, Simone}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2014</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = aug,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Journal of Chemical Information and Modeling}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{54}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{8}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{2380--2390}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">issn</span> = <span style="color:#e6db74">{1549-9596, 1549-960X}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1021/ci5002197}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">urldate</span> = <span style="color:#e6db74">{2025-10-13}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">langid</span> = <span style="color:#e6db74">{english}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>ChemInk: Real-Time Recognition for Chemical Drawings</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/online-recognition/chemink-2011/</link><pubDate>Mon, 15 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/online-recognition/chemink-2011/</guid><description>A sketch recognition framework for chemical diagrams using a joint CRF model to combine multi-level visual features for real-time interpretation.</description><content:encoded><![CDATA[<h2 id="contribution-real-time-sketch-recognition-method">Contribution: Real-Time Sketch Recognition Method</h2>
<p>This is a <strong>Method</strong> paper. It proposes a novel architectural framework for sketch recognition that integrates visual features at three distinct levels (inkpoints, segments, symbols) into a single probabilistic model. The rhetorical structure centers on the proposal of this new architecture, the introduction of a specific &ldquo;trainable corner detector&rdquo; algorithm, and the validation of these methods against existing benchmarks and alternative toolsets (ChemDraw).</p>
<h2 id="motivation-bridging-the-gap-between-sketching-and-cad">Motivation: Bridging the Gap Between Sketching and CAD</h2>
<p>The primary motivation is to bridge the gap between the natural, efficient process of drawing chemical diagrams by hand and the cumbersome &ldquo;point-click-and-drag&rdquo; interactions required by CAD tools like ChemDraw. While chemists prefer sketching for communication, existing digital tools do not offer the same speed or ease of use. The goal is to build an intelligent system that understands freehand sketches in real-time, converting them into structured data suitable for analysis or search.</p>
<h2 id="core-innovation-hierarchical-joint-crf-model">Core Innovation: Hierarchical Joint CRF Model</h2>
<p>The core novelty lies in the <strong>hierarchical joint model</strong>. Unlike previous approaches that might treat stroke segmentation and symbol recognition as separate, isolated steps, ChemInk uses a <strong>Conditional Random Field (CRF)</strong> to jointly model dependencies across three levels:</p>
<ol>
<li><strong>Inkpoints</strong>: Local visual appearance.</li>
<li><strong>Segments</strong>: Stroke fragments separated by corners.</li>
<li><strong>Candidates</strong>: Potential symbol groupings.</li>
</ol>
<p>Additionally, the paper introduces a <strong>trainable corner detector</strong> that learns domain-specific corner definitions from data.</p>
<h2 id="experimental-design-and-baselines">Experimental Design and Baselines</h2>
<p>The authors conducted two primary evaluations:</p>
<ol>
<li><strong>Off-line Accuracy Evaluation</strong>:
<ul>
<li><strong>Dataset</strong>: 12 real-world organic compounds drawn by 10 participants.</li>
<li><strong>Metric</strong>: Recognition accuracy (Recall and Precision).</li>
<li><strong>Baseline</strong>: Comparison against their own previous work (O&amp;D 2009) and ablations (with/without context).</li>
</ul>
</li>
<li><strong>On-line User Study</strong>:
<ul>
<li><strong>Task</strong>: 9 participants (chemistry students) drew 5 diagrams using both ChemInk (Tablet PC) and ChemDraw (Mouse/Keyboard).</li>
<li><strong>Metric</strong>: Time to completion and subjective user ratings (speed/ease of use).</li>
</ul>
</li>
</ol>
<h2 id="results-accuracy-and-user-study-outcomes">Results: Accuracy and User Study Outcomes</h2>
<ul>
<li><strong>Accuracy</strong>: The system achieved <strong>97.4% symbol recognition accuracy</strong>, slightly outperforming the best prior result (97.1%). The trainable corner detector achieved <strong>99.91% recall</strong>.</li>
<li><strong>Speed</strong>: Users were <strong>more than twice as fast</strong> using ChemInk (avg. 36s) than ChemDraw (avg. 79s).</li>
<li><strong>Usability</strong>: Participants rated ChemInk significantly higher for speed (6.3 vs 4.5) and ease of use (6.3 vs 4.7) on a 7-point scale.</li>
<li><strong>Conclusion</strong>: Sketch recognition is a viable, superior alternative to standard CAD tools for authoring chemical diagrams.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<ul>
<li><strong>Training/Test Data</strong>: 12 real-world organic compounds (e.g., Aspirin, Penicillin) drawn by 10 participants familiar with organic chemistry.</li>
<li><strong>Evaluation Split</strong>: User-independent cross-validation (training on 9 users, testing on 1).</li>
<li><strong>Input</strong>: Raw digital ink (strokes) collected on a Tablet PC.</li>
</ul>
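<p>The user-independent split can be sketched as a leave-one-user-out loop. This is a minimal illustration, not the authors&rsquo; code: the function name <code>leave_one_user_out</code>, the <code>(user_id, sketch)</code> pair layout, and the toy data are all assumptions made for the example.</p>

```python
def leave_one_user_out(sketches):
    """Yield (held_out_user, train, test) splits, one per user.

    `sketches` is a list of (user_id, sketch) pairs; this layout is
    illustrative and not taken from the paper.
    """
    users = sorted({user for user, _ in sketches})
    for held_out in users:
        train = [s for user, s in sketches if user != held_out]
        test = [s for user, s in sketches if user == held_out]
        yield held_out, train, test


# Toy data: 10 users, each drawing 12 compounds.
toy = [(u, f"user{u}-compound{c}") for u in range(10) for c in range(12)]
folds = list(leave_one_user_out(toy))
```

<p>Each fold trains on the sketches of 9 users and tests on the remaining user, so no participant&rsquo;s drawing style appears on both sides of a split.</p>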
<h3 id="algorithms">Algorithms</h3>
<p><strong>1. Corner Detection (Trainable)</strong></p>
<ul>
<li><strong>Method</strong>: Iterative vertex elimination.</li>
<li><strong>Cost Function</strong>: $cost(p_{i}) = \sqrt{mse(s_{i}; p_{i-1}, p_{i+1})} \cdot dist(p_{i}; p_{i-1}, p_{i+1})$</li>
<li><strong>Procedure</strong>: Repeatedly remove the vertex with the lowest cost until the classifier (trained on features like cost, diagonal length, ink density) predicts the remaining vertices are corners.</li>
</ul>
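<p>The elimination loop above can be sketched as follows. Several details are assumptions made for illustration: $mse$ is taken as the mean squared perpendicular distance of the stroke points between $p_{i-1}$ and $p_{i+1}$ to the line joining them, $dist$ as the perpendicular distance of $p_i$ to that line, and the trained classifier is replaced by a hypothetical <code>keep_as_corner</code> callback that only sees the cost.</p>

```python
import math


def point_line_distance(p, a, b):
    """Perpendicular distance from point p to the line through a and b."""
    (px, py), (ax, ay), (bx, by) = p, a, b
    length = math.hypot(bx - ax, by - ay)
    if length == 0.0:
        return math.hypot(px - ax, py - ay)
    return abs((bx - ax) * (ay - py) - (ax - px) * (by - ay)) / length


def vertex_cost(stroke, prev_idx, idx, next_idx):
    """cost(p_i) = sqrt(mse(s_i; p_prev, p_next)) * dist(p_i; p_prev, p_next)."""
    a, b = stroke[prev_idx], stroke[next_idx]
    segment = stroke[prev_idx:next_idx + 1]
    mse = sum(point_line_distance(q, a, b) ** 2 for q in segment) / len(segment)
    return math.sqrt(mse) * point_line_distance(stroke[idx], a, b)


def eliminate_vertices(stroke, keep_as_corner):
    """Drop the cheapest interior vertex until `keep_as_corner` accepts it.

    `keep_as_corner` stands in for the paper's trained classifier (assumption).
    """
    vertices = list(range(len(stroke)))
    while len(vertices) > 2:
        costs = [
            (vertex_cost(stroke, vertices[k - 1], vertices[k], vertices[k + 1]), k)
            for k in range(1, len(vertices) - 1)
        ]
        cost, k = min(costs)
        if keep_as_corner(cost):
            break
        del vertices[k]
    return [stroke[i] for i in vertices]


# An L-shaped stroke: only the bend at (2, 0) should survive as a corner.
l_shape = [(0, 0), (1, 0), (2, 0), (2, 1), (2, 2)]
corners = eliminate_vertices(l_shape, lambda cost: cost >= 0.5)
```

<p>A fixed cost threshold is the simplest possible stop rule; the paper&rsquo;s classifier also uses features such as diagonal length and ink density, which this sketch omits.</p>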
<p><strong>2. Feature Extraction</strong></p>
<ul>
<li><strong>Inkpoints</strong>: Sampled at regular intervals. Features = $10 \times 10$ pixel orientation filters (0, 45, 90, 135 degrees) at two scales ($L/2$, $L$), smoothed and downsampled to $5 \times 5$. Total 400 features.</li>
<li><strong>Segments</strong>: Similar image features centered at segment midpoint, plus geometric features (length, ink density).</li>
<li><strong>Candidates</strong>: 5 feature images ($20 \times 20$) including an &ldquo;endpoint&rdquo; image, stretched to normalize aspect ratio.</li>
<li><strong>Dimensionality Reduction</strong>: PCA used to compress feature images to 256 components.</li>
</ul>
<p><strong>3. Structure Generation</strong></p>
<ul>
<li><strong>Clustering</strong>: Agglomerative clustering with a complete-link metric to connect symbols.</li>
<li><strong>Threshold</strong>: Stop clustering at distance $0.4L$.</li>
</ul>
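<p>A minimal sketch of complete-link agglomerative clustering with a distance cutoff is shown below. It assumes the items being clustered are 2D points (for example symbol endpoints) and takes the scale $L$ as a plain argument <code>scale_l</code>; both choices are illustrative and not confirmed by the summary above.</p>

```python
import math


def complete_link_clusters(points, scale_l, factor=0.4):
    """Merge clusters until the closest pair is farther than factor * scale_l.

    Complete link: the distance between two clusters is the largest
    pairwise distance between their members.
    """
    threshold = factor * scale_l
    clusters = [[p] for p in points]

    def linkage(a, b):
        return max(math.dist(p, q) for p in a for q in b)

    while len(clusters) > 1:
        pairs = [
            (linkage(clusters[i], clusters[j]), i, j)
            for i in range(len(clusters))
            for j in range(i + 1, len(clusters))
        ]
        best, i, j = min(pairs)
        if best > threshold:
            break
        clusters[i] = clusters[i] + clusters[j]
        del clusters[j]
    return clusters


# Two tight groups of endpoints, far apart relative to 0.4 * L.
groups = complete_link_clusters([(0, 0), (0.1, 0), (5, 5), (5.1, 5)], scale_l=1.0)
```

<p>Because complete link uses the largest pairwise distance, no cluster can span more than the threshold, which prevents distant symbols from being chained together through intermediate points.</p>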
<h3 id="models">Models</h3>
<p><strong>Conditional Random Field (CRF)</strong></p>
<ul>
<li><strong>Structure</strong>: 3-level hierarchy (Inkpoints $V_p$, Segments $V_s$, Candidates $V_c$).</li>
<li><strong>Nodes</strong>:
<ul>
<li>$V_p, V_s$ labels: &ldquo;bond&rdquo;, &ldquo;hash&rdquo;, &ldquo;wedge&rdquo;, &ldquo;text&rdquo;.</li>
<li>$V_c$ labels: specific candidate interpretations.</li>
</ul>
</li>
<li><strong>Edges/Potentials</strong>:
<ul>
<li><strong>Entity-Feature</strong>: $\phi(y, x)$ (Linear classifier).</li>
<li><strong>Consistency</strong>: $\psi(y_i, y_j)$ (Hard constraint: child must match parent label).</li>
<li><strong>Spatial Context</strong>: $\psi_{ss}(y_i, y_j)$ (Pairwise geometric relationships between segments: angle, distance).</li>
<li><strong>Overlap</strong>: Prevents conflicting candidates from sharing segments.</li>
</ul>
</li>
<li><strong>Inference</strong>: Loopy Belief Propagation (up to 100 iterations).</li>
<li><strong>Training</strong>: Maximum likelihood, optimized with the gradient-based L-BFGS method.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Primary Metric</strong>: Accuracy (Recall/Precision) of symbol detection.</li>
<li><strong>Comparison</strong>: Compared against Ouyang &amp; Davis 2009 (previous SOTA).</li>
<li><strong>Speed Metric</strong>: Wall-clock time for diagram creation (ChemInk vs. ChemDraw).</li>
</ul>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Processor</strong>: 3.7 GHz processor (single thread) for base benchmarking (approx. 1 sec/sketch).</li>
<li><strong>Deployment</strong>: Validated on 1.8 GHz Tablet PCs using multi-core parallelization for real-time feedback.</li>
</ul>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Ouyang, T. Y., &amp; Davis, R. (2011). ChemInk: A Natural Real-Time Recognition System for Chemical Drawings. <em>Proceedings of the 16th International Conference on Intelligent User Interfaces</em>, 267&ndash;276. <a href="https://doi.org/10.1145/1943403.1943444">https://doi.org/10.1145/1943403.1943444</a></p>
<p><strong>Publication</strong>: IUI &rsquo;11</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{ouyangChemInkNaturalRealtime2011,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{ChemInk: A Natural Real-Time Recognition System for Chemical Drawings}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">shorttitle</span> = <span style="color:#e6db74">{ChemInk}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{Proceedings of the 16th International Conference on Intelligent User Interfaces}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Ouyang, Tom Y. and Davis, Randall}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2011}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = feb,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{267--276}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{ACM}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">address</span> = <span style="color:#e6db74">{Palo Alto, CA, USA}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1145/1943403.1943444}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">isbn</span> = <span style="color:#e6db74">{978-1-4503-0419-1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span> = <span style="color:#e6db74">{http://hdl.handle.net/1721.1/78898}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Tea Party in the House: Legislative Ideology via HIPTM</title><link>https://hunterheidenreich.com/notes/interdisciplinary/social-science/tea-party-hiptm/</link><pubDate>Sun, 14 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/interdisciplinary/social-science/tea-party-hiptm/</guid><description>A hierarchical probabilistic model combining roll call votes, bill text, and legislative speeches to analyze political polarization and framing.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p><strong>Method</strong>.</p>
<p>This paper is primarily a methodological contribution. It proposes a probabilistic model, the Hierarchical Ideal Point Topic Model (HIPTM), designed to address a limitation of existing political science models, which typically rely on either voting data or text data in isolation. The paper validates the method by showing that it predicts &ldquo;Tea Party&rdquo; membership better than text-only baselines and supports interpretable &ldquo;framing&rdquo; analysis.</p>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>The primary motivation is to better understand political polarization, specifically the &ldquo;Tea Party&rdquo; phenomenon within the Republican party during the 112th Congress.</p>
<p>An ideal point is a scalar score representing a legislator&rsquo;s ideological position, estimated from voting patterns. Standard &ldquo;Ideal Point&rdquo; models (like DW-NOMINATE) typically project legislators onto a single liberal-conservative dimension using only binary voting data. This is insufficient for capturing complex, multi-dimensional intra-party conflicts where legislators might agree on votes but differ on policy &ldquo;framing&rdquo; or specific sub-issues. Furthermore, existing multi-dimensional models often produce dimensions that are difficult for humans to interpret.</p>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The core novelty is the <strong>Hierarchical Ideal Point Topic Model (HIPTM)</strong>. It distinguishes itself from prior work through three main technical innovations:</p>
<ol>
<li><strong>Joint Modeling of Three Data Sources</strong>: It integrates roll call votes, the text of bills, and the floor speeches of legislators into a single probabilistic framework.</li>
<li><strong>Hierarchical Topic Structure</strong>: It models &ldquo;frames&rdquo; as a second level of the topic hierarchy. &ldquo;Issues&rdquo; (level 1) are fixed and non-polarized, while &ldquo;Frames&rdquo; (level 2) are discovered dynamically and carry polarity (ideal point weights). For example, Health Care is an issue; &ldquo;government overreach&rdquo; vs. &ldquo;patient protection&rdquo; are frames legislators use when debating it.</li>
<li><strong>Text-Based Ideal Point Prediction</strong>: HIPTM regresses ideal points on speech text, allowing it to predict the political alignment of legislators based solely on their writing or speeches without requiring voting records for inference.</li>
</ol>
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The authors validated the model using data from the 112th U.S. Congress (Republican legislators only).</p>
<ul>
<li><strong>Prediction Task</strong>: Classifying legislators as members of the &ldquo;Tea Party Caucus&rdquo;.</li>
<li><strong>Baselines</strong>: The model was compared against Support Vector Machines (SVM) trained on:
<ul>
<li>TF-IDF vectors (Text only)</li>
<li>Normalized TF-IDF vectors (Text only)</li>
<li>Binary Vote vectors (Vote only)</li>
</ul>
</li>
<li><strong>Metric</strong>: Area Under the Receiver Operating Characteristic Curve (AUC-ROC) via 5-fold cross-validation.</li>
<li><strong>Qualitative Analysis</strong>: The authors examined the &ldquo;span&rdquo; of ideal points within specific topics (e.g., Macroeconomics, Health) to identify which issues were most polarized between Tea Party and Establishment Republicans.</li>
</ul>
<h2 id="what-were-the-outcomes-and-conclusions-drawn">What were the outcomes and conclusions drawn?</h2>
<ul>
<li><strong>Quantitative Performance</strong>: HIPTM features combined with voting data (HIPTM-VOTE) achieved the highest classification performance (AUC-ROC of roughly 0.70-0.75, read approximately from Figure 2). Vote-only features slightly trailed HIPTM-VOTE, while the text-only baselines (TF-IDF, normalized TF-IDF) fell considerably lower. The one-dimensional Tea Party ideal points correlated with DW-NOMINATE ($\rho = 0.91$). When voting data was withheld (simulating a candidate without a record), HIPTM&rsquo;s text-based features outperformed the TF-IDF and normalized TF-IDF baselines (read approximately from Figure 3).</li>
<li><strong>Political Insight</strong>: The model identified &ldquo;Government Operations,&rdquo; &ldquo;Macroeconomics,&rdquo; and &ldquo;Transportation&rdquo; as the three most polarized topics between Tea Party and establishment Republicans.</li>
<li><strong>Framing Analysis</strong>: The hierarchical topic structure reveals how legislators frame issues differently. For Macroeconomics, frame M3 (most Tea Party-oriented) focuses on criticizing government overspending, while frame M1 (least Tea Party-oriented) focuses on the downsides of a government shutdown. For Health, frame H3 captures Tea Party framing of the Affordable Care Act as an unconstitutional government takeover, while frame H1 frames opposition in terms of implementation costs and health care exchanges.</li>
<li><strong>Framing vs. Voting Taxonomy</strong>: The authors construct a 2x2 taxonomy of disagreement across issues, crossing whether ideal points are polarized with whether issue frames are polarized. Issues like Civil Rights fall in the &ldquo;neither polarized&rdquo; quadrant, where cooperation is expected. Banking/Finance and Transportation fall in the &ldquo;ideal points polarized, frames not&rdquo; quadrant, where Republicans frame the issue similarly but have underlying policy disagreements. Issues like Health and Public Lands fall in the &ldquo;frames polarized, ideal points not&rdquo; quadrant: Republicans voted similarly but framed the issue very differently. Issues like Macroeconomics and Government Operations fall in the &ldquo;both polarized&rdquo; quadrant, posing the greatest challenge for Republican leadership.</li>
<li><strong>Sub-group Identification</strong>: The model identifies legislators whose language marks them as ideologically aligned with the Tea Party even without formal caucus membership. For example, Jeff Flake (R-AZ) received the second-highest ideal point, disagreeing with Freedom Works on only one of 60 key votes, despite not being a Tea Party Caucus member. Justin Amash (R-MI), founder and chairman of the Liberty Caucus, agreed with Freedom Works on every key vote since 2011. Conversely, some self-identified Tea Partiers like Rodney Alexander (R-LA) only agreed with Freedom Works 48% of the time. Alexander and Ander Crenshaw (R-FL, 50% agreement) are categorized as &ldquo;Green Tea&rdquo; by Gervais and Morris (2014): Republican legislators who associate with the Tea Party on their own initiative but lack support from Tea Party organizations.</li>
</ul>
<h3 id="limitations">Limitations</h3>
<ul>
<li>HIPTM does not formally distinguish frames from other kinds of subtopics. For example, the model discovered a strongly Tea Party-oriented frame under &ldquo;Labor, Employment and Immigration&rdquo; that reflected a Boeing labor dispute specific to South Carolina legislators, capturing geographic rather than ideological framing.</li>
<li>The model is validated only on Republican legislators in the 112th Congress. Generalization to other parties, chambers, or time periods is untested.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The study focuses on the <strong>112th U.S. Congress</strong> (Jan 2011 - Jan 2013).</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Subjects</strong></td>
          <td>Republican Legislators</td>
          <td>240 Reps</td>
          <td>60 are Tea Party Caucus members.</td>
      </tr>
      <tr>
          <td><strong>Votes</strong></td>
          <td>Roll Call Votes</td>
          <td>13,856 votes</td>
          <td>Agreement/disagreement with Freedom Works on 60 key votes (40 in 2011, 20 in 2012).</td>
      </tr>
      <tr>
          <td><strong>Text</strong></td>
          <td>Floor Speeches</td>
          <td>5,349 word types</td>
          <td>Sourced from GovTrack. Vocabulary size after preprocessing.</td>
      </tr>
      <tr>
          <td><strong>Priors</strong></td>
          <td>Congressional Bills Project</td>
          <td>19 Topics</td>
          <td>Used to set informed priors $\phi^*_k$ for top-level issues.</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<p>The model uses a <strong>Stochastic EM</strong> approach for inference.</p>
<ul>
<li><strong>Generative Process</strong>:
<ul>
<li><strong>Speeches</strong>: Modeled as a mixture of $K$ Hierarchical Dirichlet Processes (HDPs). A legislator chooses an issue $z$, then a frame $t$ from a Dirichlet Process, then a word $w$.</li>
<li><strong>Bills</strong>: Modeled using Latent Dirichlet Allocation (LDA). Each bill is a mixture over $K$ issues.</li>
<li><strong>Votes</strong>: Modeled via a probabilistic ideal point function (logistic/inverse-logit). The probability of a &ldquo;Yes&rdquo; vote depends on the bill&rsquo;s polarity $x_b$, popularity $y_b$, and the legislator&rsquo;s issue-specific ideal point $u_{a,k}$.</li>
</ul>
</li>
<li><strong>Optimization Steps</strong>:
<ol>
<li><strong>Sampling</strong>: Issue assignments $z$ and frame assignments $t$ are sampled for tokens in speeches and bills.</li>
<li><strong>Regression</strong>: Frame-specific regression weights $\eta_{k,j}$ are optimized using <strong>L-BFGS</strong>.</li>
<li><strong>Ideal Points</strong>: Legislator ideal points $u_{a,k}$ and bill parameters ($x_b, y_b$) are updated using <strong>Gradient Ascent</strong>.</li>
</ol>
</li>
</ul>
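<p>The vote component can be illustrated with a small logistic function. This is a sketch under stated assumptions: the summary only says the probability depends on $x_b$, $y_b$, and $u_{a,k}$, so the way the issue-specific ideal points are combined here (weighting them by the bill&rsquo;s issue proportions, called <code>issue_mix</code>) is an assumption, as are all function and argument names.</p>

```python
import math


def yes_vote_probability(polarity, popularity, issue_mix, ideal_points):
    """Logistic ideal point model for one legislator and one bill.

    polarity     -- bill polarity x_b
    popularity   -- bill popularity y_b
    issue_mix    -- the bill's proportions over the K issues (assumed weighting)
    ideal_points -- the legislator's issue-specific ideal points u_{a,k}
    """
    # Mix the issue-specific ideal points by the bill's issue proportions.
    mixed = sum(w * u for w, u in zip(issue_mix, ideal_points))
    return 1.0 / (1.0 + math.exp(-(polarity * mixed + popularity)))


# A bill entirely about issue 0, voted on by a legislator with u = (2, -3).
p_yes = yes_vote_probability(1.0, 0.0, [1.0, 0.0], [2.0, -3.0])
```

<p>With zero polarity the ideal point has no effect and the probability is driven by popularity alone, which matches the intuition that uncontroversial bills pass regardless of ideology.</p>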
<h3 id="models">Models</h3>
<ul>
<li><strong>Ideal Point Definition</strong>: A legislator&rsquo;s ideal point on issue $k$ ($u_{a,k}$) is defined as a linear combination of the ideal points of the <em>frames</em> they use ($\eta_{k,j}$), weighted by their usage frequency ($\hat{\psi}_{a,k,j}$).</li>
<li><strong>Topic Hierarchy</strong>:
<ul>
<li><strong>Level 1 (Issues)</strong>: Fixed at $K=19$ (based on Policy Agendas Project major headings). These nodes use informed Dirichlet priors.</li>
<li><strong>Level 2 (Frames)</strong>: Unbounded number of frames per issue, discovered non-parametrically via Dirichlet Process.</li>
</ul>
</li>
<li><strong>Prediction Features</strong>: The model runs for 1,000 iterations total with a 500-iteration burn-in. After burn-in, the sampled state is kept every 50 iterations, and feature values are averaged over the 10 stored models.</li>
</ul>
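<p>Two of the definitions above reduce to short calculations: the ideal point as a usage-weighted sum of frame weights, and the thinning schedule that yields 10 stored models. The sketch below is illustrative; the function names are hypothetical, and starting the stored samples at iteration 550 (rather than 500) is an assumption that happens to give the stated count of 10.</p>

```python
def issue_ideal_point(frame_usage, frame_weights):
    """u_{a,k} = sum_j psi_hat_{a,k,j} * eta_{k,j} for one legislator and issue."""
    return sum(psi * eta for psi, eta in zip(frame_usage, frame_weights))


def kept_iterations(total=1000, burn_in=500, lag=50):
    """Iterations whose sampled state is stored after burn-in (assumed offsets)."""
    return list(range(burn_in + lag, total + 1, lag))


# A legislator who uses a positive-weight frame 75% of the time on this issue.
u = issue_ideal_point([0.75, 0.25], [2.0, -2.0])
stored = kept_iterations()
```

<p>Averaging features over the stored states rather than using only the final one reduces the variance that comes from any single sample of the chain.</p>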
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Primary Metric</strong>: AUC-ROC (Area Under the Receiver Operating Characteristic Curve).</li>
<li><strong>Classifier</strong>: $\text{SVM}^{\text{light}}$ (Joachims, 1999).</li>
<li><strong>Cross-Validation</strong>: 5-fold stratified sampling.</li>
</ul>
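<p>For readers unfamiliar with the metric, AUC-ROC can be computed directly as the probability that a randomly chosen positive example is scored above a randomly chosen negative one. The sketch below is a plain pairwise implementation for illustration only; it is not the evaluation code used in the paper, and the labels and scores are made up.</p>

```python
def auc_roc(labels, scores):
    """Probability that a random positive outscores a random negative."""
    positives = [s for y, s in zip(labels, scores) if y == 1]
    negatives = [s for y, s in zip(labels, scores) if y == 0]
    if not positives or not negatives:
        raise ValueError("need at least one positive and one negative")
    wins = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5  # ties count as half a win
    return wins / (len(positives) * len(negatives))


# Toy example: 1 = caucus member, 0 = non-member.
example_auc = auc_roc([1, 1, 0, 0], [0.9, 0.4, 0.5, 0.1])
```

<p>Because the metric depends only on the ranking of scores, it is insensitive to the class imbalance between caucus members and non-members (60 of 240 legislators here).</p>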
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://www.govtrack.us/">GovTrack Congressional Speeches</a></td>
          <td>Dataset</td>
          <td>Public</td>
          <td>Source of floor speech text</td>
      </tr>
      <tr>
          <td><a href="http://www.congressionalbills.org/">Congressional Bills Project</a></td>
          <td>Dataset</td>
          <td>Public</td>
          <td>Bill text with Policy Agendas Project topic labels</td>
      </tr>
      <tr>
          <td>Freedom Works Key Votes</td>
          <td>Dataset</td>
          <td>Public</td>
          <td>60 key votes used to define Tea Party alignment (freedomworks.org is no longer available)</td>
      </tr>
  </tbody>
</table>
<p>No official code release accompanies this paper. The inference algorithm (Stochastic EM with Gibbs sampling, L-BFGS, and gradient ascent) is described in detail in Section 4 of the paper, but a full reimplementation would be required.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Nguyen, V., Boyd-Graber, J., Resnik, P., &amp; Miler, K. (2015). Tea Party in the House: A Hierarchical Ideal Point Topic Model and Its Application to Republican Legislators in the 112th Congress. <em>Proceedings of the 53rd Annual Meeting of the Association for Computational Linguistics</em>, 1438&ndash;1448. <a href="https://doi.org/10.3115/v1/P15-1139">https://doi.org/10.3115/v1/P15-1139</a></p>
<p><strong>Publication</strong>: ACL 2015</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{nguyenTeaPartyHouse2015,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Tea {{Party}} in the {{House}}: {{A Hierarchical Ideal Point Topic Model}} and {{Its Application}} to {{Republican Legislators}} in the 112th {{Congress}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">shorttitle</span> = <span style="color:#e6db74">{Tea {{Party}} in the {{House}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{Proceedings of the 53rd {{Annual Meeting}} of the {{Association}} for {{Computational Linguistics}} and the 7th {{International Joint Conference}} on {{Natural Language Processing}} ({{Volume}} 1: {{Long Papers}})}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Nguyen, Viet-An and {Boyd-Graber}, Jordan and Resnik, Philip and Miler, Kristina}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2015}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{1438--1448}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{Association for Computational Linguistics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">address</span> = <span style="color:#e6db74">{Beijing, China}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.3115/v1/P15-1139}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">urldate</span> = <span style="color:#e6db74">{2023-11-02}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">abstract</span> = <span style="color:#e6db74">{We introduce the Hierarchical Ideal Point Topic Model, which provides a rich picture of policy issues, framing, and voting behavior using a joint model of votes, bill text, and the language that legislators use when debating bills. We use this model to look at the relationship between Tea Party Republicans and ``establishment&#39;&#39; Republicans in the U.S. House of Representatives during the 112th Congress.}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">langid</span> = <span style="color:#e6db74">{english}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://aclanthology.org/P15-1139/">ACL Anthology: Tea Party in the House</a></li>
<li>Gervais, B. T., &amp; Morris, I. L. (2012). Reading the tea leaves: Understanding Tea Party Caucus membership in the US House of Representatives. <em>PS: Political Science &amp; Politics</em>, 45(2), 245&ndash;250.</li>
<li>Gervais, B. T., &amp; Morris, I. L. (2014). Black Tea, Green Tea, White Tea, and Coffee: Understanding the variation in attachment to the Tea Party among members of Congress. In <em>Annual Meeting of the American Political Science Association</em>. (Source of the &ldquo;Green Tea&rdquo; Republican taxonomy cited in the paper)</li>
</ul>
]]></content:encoded></item><item><title>IMG2SMI: Translating Molecular Structure Images to SMILES</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-sequence/img2smi/</link><pubDate>Sun, 14 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-sequence/img2smi/</guid><description>Campos &amp; Ji's method for converting 2D molecular images to SMILES strings using Transformers and SELFIES representation.</description><content:encoded><![CDATA[<h2 id="contributions--taxonomy">Contributions &amp; Taxonomy</h2>
<p>This is both a <strong>Method</strong> and <strong>Resource</strong> paper:</p>
<ul>
<li><strong>Method</strong>: It adapts standard image captioning architectures (encoder-decoder) to the domain of Optical Chemical Structure Recognition (OCSR), treating molecule recognition as a translation task.</li>
<li><strong>Resource</strong>: It introduces <strong>MOLCAP</strong>, a large-scale dataset of 81 million molecules aggregated from public chemical databases, addressing the data scarcity that previously hindered deep learning approaches to OCSR.</li>
</ul>
<h2 id="the-bottleneck-in-chemical-literature-translation">The Bottleneck in Chemical Literature Translation</h2>
<p>Chemical literature is &ldquo;full of recipes written in a language computers cannot understand&rdquo; because molecules are depicted as 2D images. This creates a fundamental bottleneck:</p>
<ul>
<li><strong>The Problem</strong>: Chemists must manually redraw molecular structures to search for related compounds or reactions. This is slow and error-prone, and it makes large-scale literature mining impractical.</li>
<li><strong>Existing Tools</strong>: Legacy systems like OSRA (Optical Structure Recognition Application) rely on handcrafted rules and often require human correction, making them unfit for unsupervised, high-throughput processing.</li>
<li><strong>The Goal</strong>: An automated system that can translate structure images directly to machine-readable strings (SMILES/SELFIES) without human supervision, enabling large-scale knowledge extraction from decades of chemistry literature and patents.</li>
</ul>
<h2 id="core-innovation-selfies-and-image-captioning">Core Innovation: SELFIES and Image Captioning</h2>
<p>The core novelty is demonstrating that <strong>how you represent the output text is as important as the model architecture itself</strong>. Key contributions:</p>
<ol>
<li>
<p><strong>Image Captioning Framework</strong>: Applies modern encoder-decoder architectures (ResNet-101 + Transformer) to OCSR, treating it as an image-to-text translation problem with a standard cross-entropy loss objective over the generation sequence:
$$ \mathcal{L} = -\sum\limits_{t=1}^{T} \log P(y_t \mid y_1, \ldots, y_{t-1}, x) $$</p>
</li>
<li>
<p><strong>SELFIES as Target Representation</strong>: The key design choice is using <strong>SELFIES</strong> (Self-Referencing Embedded Strings) as the output format. SELFIES is based on a formal grammar in which every possible string corresponds to a valid molecule, eliminating the syntactic invalidity problems (unmatched parentheses, invalid characters) that plague SMILES generation.</p>
</li>
<li>
<p><strong>MOLCAP Dataset</strong>: Created a comprehensive dataset of 81 million unique molecules from PubChem, ChEMBL, <a href="/notes/chemistry/datasets/gdb-13/">GDB-13</a>, and other sources. Generated 256x256 pixel images using RDKit for 1 million training samples and 5,000 validation samples.</p>
</li>
<li>
<p><strong>Task-Specific Evaluation</strong>: Demonstrated that traditional NLP metrics (BLEU) are poor indicators of scientific utility. Introduced evaluation based on <strong>molecular fingerprints</strong> (MACCS, RDK, Morgan) and <strong>Tanimoto similarity</strong>:
$$ T(a, b) = \frac{c}{a + b - c} $$
where $c$ is the number of common fingerprint bits, and $a$ and $b$ are the number of set bits in each respective molecule&rsquo;s fingerprint. This formulation reliably measures functional chemical similarity.</p>
</li>
</ol>
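<p>The Tanimoto formula above is easy to check by hand on small fingerprints. The sketch below represents each fingerprint as a Python set of on-bit indices, which is an illustrative simplification; in practice the bit vectors would come from a cheminformatics toolkit, and the value returned for two empty fingerprints is a convention chosen here, not one stated in the paper.</p>

```python
def tanimoto(bits_a, bits_b):
    """T = c / (a + b - c) for two fingerprints given as sets of on-bit indices."""
    a, b = len(bits_a), len(bits_b)
    c = len(bits_a & bits_b)  # bits set in both fingerprints
    if a + b - c == 0:
        return 1.0  # convention chosen here for two empty fingerprints
    return c / (a + b - c)


# a = 4 set bits, b = 3 set bits, c = 2 shared bits, so T = 2 / (4 + 3 - 2).
similarity = tanimoto({1, 2, 3, 4}, {3, 4, 5})
```

<p>The denominator $a + b - c$ is the size of the union of set bits, so the score is 1 only when the two fingerprints are identical and 0 when they share no bits.</p>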
<h2 id="experimental-setup-and-ablation-studies">Experimental Setup and Ablation Studies</h2>
<p>The evaluation focused on comparing IMG2SMI to existing systems and identifying which design choices matter most:</p>
<ol>
<li>
<p><strong>Baseline Comparisons</strong>: Benchmarked against OSRA (rule-based system) and DECIMER (first deep learning approach) on the MOLCAP dataset to establish whether modern architectures could surpass traditional methods.</p>
</li>
<li>
<p><strong>Ablation Studies</strong>: Extensive ablations isolating key factors:</p>
<ul>
<li><strong>Decoder Architecture</strong>: Transformer vs. RNN/LSTM decoders</li>
<li><strong>Encoder Fine-tuning</strong>: Fine-tuned vs. frozen pre-trained ResNet weights</li>
<li><strong>Output Representation</strong>: SELFIES vs. character-level SMILES vs. BPE-tokenized SMILES (the most critical ablation)</li>
</ul>
</li>
</ol>
<table>
  <thead>
      <tr>
          <th>Configuration</th>
          <th>MACCS FTS</th>
          <th>Valid Captions</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>RNN + Fixed Encoder</td>
          <td>0.1526</td>
          <td>N/A</td>
      </tr>
      <tr>
          <td>RNN + Fine-tuned Encoder</td>
          <td>0.4180</td>
          <td>N/A</td>
      </tr>
      <tr>
          <td>Transformer + Fixed Encoder</td>
          <td>0.7674</td>
          <td>61.1%</td>
      </tr>
      <tr>
          <td>Transformer + Fine-tuned Encoder</td>
          <td>0.9475</td>
          <td>99.4%</td>
      </tr>
      <tr>
          <td>Character-level SMILES (fine-tuned)</td>
          <td>N/A</td>
          <td>2.1%</td>
      </tr>
      <tr>
          <td>BPE SMILES (2000 vocab, fine-tuned)</td>
          <td>N/A</td>
          <td>20.0%</td>
      </tr>
      <tr>
          <td>SELFIES (fine-tuned)</td>
          <td>0.9475</td>
          <td>99.4%</td>
      </tr>
  </tbody>
</table>
<ol start="3">
<li><strong>Metric Analysis</strong>: Systematic comparison of evaluation metrics including BLEU, ROUGE, Levenshtein distance, exact match accuracy, and molecular fingerprint-based similarity measures.</li>
</ol>
<h2 id="results-findings-and-limitations">Results, Findings, and Limitations</h2>
<p><strong>Performance Gains</strong>:</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>IMG2SMI</th>
          <th>OSRA</th>
          <th>DECIMER</th>
          <th>Random Baseline</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MACCS FTS</td>
          <td>0.9475</td>
          <td>0.3600</td>
          <td>0.0000</td>
          <td>0.3378</td>
      </tr>
      <tr>
          <td>RDK FTS</td>
          <td>0.9020</td>
          <td>0.2790</td>
          <td>0.0000</td>
          <td>0.2229</td>
      </tr>
      <tr>
          <td>Morgan FTS</td>
          <td>0.8707</td>
          <td>0.2677</td>
          <td>0.0000</td>
          <td>0.1081</td>
      </tr>
      <tr>
          <td>ROUGE</td>
          <td>0.6240</td>
          <td>0.0684</td>
          <td>0.0000</td>
          <td>0.0422</td>
      </tr>
      <tr>
          <td>Exact Match</td>
          <td>7.24%</td>
          <td>0.04%</td>
          <td>0.00%</td>
          <td>0.00%</td>
      </tr>
      <tr>
          <td>Valid Captions</td>
          <td>99.4%</td>
          <td>65.2%</td>
          <td>N/A</td>
          <td>N/A</td>
      </tr>
  </tbody>
</table>
<ul>
<li>163% improvement over OSRA on MACCS Tanimoto similarity.</li>
<li>Roughly 9x improvement on ROUGE scores (0.6240 vs. 0.0684).</li>
<li>Average Tanimoto similarity exceeds 0.85 (functionally similar molecules even when not exact matches).</li>
</ul>
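<p>The headline ratios follow directly from the table values. The short calculation below reproduces them; the helper name <code>relative_improvement</code> is introduced only for this example.</p>

```python
def relative_improvement(new, old):
    """Fractional gain of `new` over `old`, e.g. 0.5 means 50% better."""
    return (new - old) / old


# MACCS fingerprint Tanimoto similarity: IMG2SMI vs. OSRA.
maccs_gain = relative_improvement(0.9475, 0.3600)

# ROUGE: IMG2SMI score divided by OSRA score.
rouge_ratio = 0.6240 / 0.0684
```

<p>The MACCS gain comes out at about 1.63 (a 163% improvement) and the ROUGE ratio at about 9.1.</p>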
<p><strong>Key Findings</strong>:</p>
<ul>
<li><strong>SELFIES is Critical</strong>: Using SELFIES yields <strong>99.4% valid molecules</strong>, compared to only ~2% validity for character-level SMILES.</li>
<li><strong>Architecture Matters</strong>: Transformer decoder significantly outperforms RNN/LSTM approaches. Fine-tuning the ResNet encoder (vs. frozen weights) yields substantial performance gains (e.g., MACCS FTS: 0.7674 to 0.9475).</li>
<li><strong>Metric Insights</strong>: BLEU is a poor metric for this task. Molecular fingerprint-based Tanimoto similarity is most informative because it measures functional chemical similarity.</li>
</ul>
<p><strong>Limitations</strong>:</p>
<ul>
<li><strong>Low Exact Match</strong>: Only <strong>7.24%</strong> exact matches. The model captures the overarching functional groups and structure but misses fine details like exact double bond placement.</li>
<li><strong>Complexity Bias</strong>: Trained on large molecules (average length &gt;40 tokens), so it performs poorly on very simple structures where OSRA still excels.</li>
</ul>
<p><strong>Conclusion</strong>: The work shows that modern encoder-decoder architectures combined with valid-by-construction molecular representations (SELFIES) can outperform traditional rule-based systems by large margins on fingerprint-based similarity metrics. The system is useful for literature mining where functional similarity matters more than exact matches, though 7.24% exact match accuracy and poor performance on simple molecules indicate clear directions for future work.</p>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="models">Models</h3>
<p><strong>Architecture</strong>: Image captioning system based on the DETR (Detection Transformer) framework.</p>
<p><strong>Visual Encoder</strong>:</p>
<ul>
<li><strong>Backbone</strong>: ResNet-101 pre-trained on ImageNet</li>
<li><strong>Feature Extraction</strong>: 4th layer extraction (convolutions only)</li>
<li><strong>Output</strong>: 2048-dimensional dense feature vector</li>
</ul>
<p><strong>Caption Decoder</strong>:</p>
<ul>
<li><strong>Type</strong>: Transformer encoder-decoder</li>
<li><strong>Layers</strong>: 3 stacked encoder layers, 3 stacked decoder layers</li>
<li><strong>Attention Heads</strong>: 8</li>
<li><strong>Hidden Dimensions</strong>: 2048 (feed-forward networks)</li>
<li><strong>Dropout</strong>: 0.1</li>
<li><strong>Layer Normalization Epsilon</strong>: 1e-12</li>
</ul>
<p><strong>Training Configuration</strong>:</p>
<ul>
<li><strong>Optimizer</strong>: AdamW</li>
<li><strong>Learning Rate</strong>: 5e-5 (selected after sweep from 1e-4 to 1e-6)</li>
<li><strong>Weight Decay</strong>: 1e-4</li>
<li><strong>Batch Size</strong>: 32</li>
<li><strong>Epochs</strong>: 5</li>
<li><strong>Codebase</strong>: Built on open-source DETR implementation</li>
</ul>
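<p>The hyperparameters above can be gathered into a single configuration object. The sketch below is only a convenience container: the class and field names are hypothetical (they are not taken from the IMG2SMI or DETR code), and the training-set size of 1,000,000 comes from the MOLCAP training split described in this note.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">from dataclasses import dataclass


@dataclass(frozen=True)
class Img2SmiConfig:
    """Hyperparameters as listed in this note; field names are illustrative."""
    backbone: str = "resnet101"
    encoder_layers: int = 3
    decoder_layers: int = 3
    attention_heads: int = 8
    feedforward_dim: int = 2048
    dropout: float = 0.1
    learning_rate: float = 5e-5
    weight_decay: float = 1e-4
    batch_size: int = 32
    epochs: int = 5
    train_examples: int = 1_000_000

    def steps_per_epoch(self) -> int:
        # Ceiling division so a final partial batch would still count as a step.
        return -(-self.train_examples // self.batch_size)


config = Img2SmiConfig()
print(config.steps_per_epoch())                  # 31250
print(config.steps_per_epoch() * config.epochs)  # 156250
</code></pre></div>
<p>At the reported pace of about 5 hours per epoch, 31,250 steps per epoch works out to roughly 0.58 seconds per optimizer step.</p>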
<h3 id="data">Data</h3>
<p><strong>MOLCAP Dataset</strong>:</p>
<table>
  <thead>
      <tr>
          <th>Property</th>
          <th>Value</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Total Size</td>
          <td>81,230,291 molecules</td>
          <td>Aggregated from PubChem, ChEMBL, GDB13</td>
      </tr>
      <tr>
          <td>Training Split</td>
          <td>1,000,000 molecules</td>
          <td>Randomly selected unique molecules</td>
      </tr>
      <tr>
          <td>Validation Split</td>
          <td>5,000 molecules</td>
          <td>Randomly selected for evaluation</td>
      </tr>
      <tr>
          <td>Image Resolution</td>
          <td>256x256 pixels</td>
          <td>Generated using RDKit</td>
      </tr>
      <tr>
          <td>Median SELFIES Length</td>
          <td>&gt;45 characters</td>
          <td>More complex than typical benchmarks</td>
      </tr>
      <tr>
          <td>Full Dataset Storage</td>
          <td>~16.24 TB</td>
          <td>Necessitated use of 1M subset</td>
      </tr>
      <tr>
          <td>Augmentation</td>
          <td>None</td>
          <td>No cropping, rotation, or other augmentation</td>
      </tr>
  </tbody>
</table>
<p><strong>Preprocessing</strong>:</p>
<ul>
<li>Images generated using RDKit at 256x256 resolution</li>
<li>Molecules converted to canonical representations</li>
<li>SELFIES tokenization for model output</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p><strong>Primary Metrics</strong>:</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>IMG2SMI Value</th>
          <th>OSRA Baseline</th>
          <th>Purpose</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MACCS FTS</td>
          <td>0.9475</td>
          <td>0.3600</td>
          <td>Fingerprint Tanimoto Similarity (functional groups)</td>
      </tr>
      <tr>
          <td>RDK FTS</td>
          <td>0.9020</td>
          <td>0.2790</td>
          <td>RDKit fingerprint similarity</td>
      </tr>
      <tr>
          <td>Morgan FTS</td>
          <td>0.8707</td>
          <td>0.2677</td>
          <td>Morgan fingerprint similarity (circular)</td>
      </tr>
      <tr>
          <td>ROUGE</td>
          <td>0.6240</td>
          <td>0.0684</td>
          <td>Text overlap metric</td>
      </tr>
      <tr>
          <td>Exact Match</td>
          <td>7.24%</td>
          <td>0.04%</td>
          <td>Structural identity (strict)</td>
      </tr>
      <tr>
          <td>Valid Captions</td>
          <td>99.4%</td>
          <td>65.2%</td>
          <td>Syntactic validity (with SELFIES)</td>
      </tr>
      <tr>
          <td>Levenshtein Distance</td>
          <td>21.13</td>
          <td>32.76</td>
          <td>String edit distance (lower is better)</td>
      </tr>
  </tbody>
</table>
<p><strong>Secondary Metrics</strong> (shown to be less informative for chemical tasks):</p>
<ul>
<li>BLEU, ROUGE (better suited for natural language)</li>
<li>Levenshtein distance (doesn&rsquo;t capture chemical similarity)</li>
</ul>
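<p>A short illustration of why string edit distance is a weak proxy for chemical similarity: two different spellings of the same molecule can sit several edits apart. The sketch below uses a standard dynamic-programming Levenshtein distance; the function name and the propene example are chosen here for illustration and are not taken from the paper.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">def levenshtein(source, target):
    """Edit distance via the standard two-row dynamic program."""
    previous = list(range(len(target) + 1))
    for i, source_char in enumerate(source, start=1):
        current = [i]
        for j, target_char in enumerate(target, start=1):
            substitution = previous[j - 1] + (0 if source_char == target_char else 1)
            current.append(min(previous[j] + 1, current[j - 1] + 1, substitution))
        previous = current
    return previous[-1]


# "C=CC" and "CC=C" are two valid SMILES spellings of the same molecule (propene),
# yet the strings are two edits apart.
print(levenshtein("C=CC", "CC=C"))  # 2
</code></pre></div>
<p>Canonicalizing both strings before comparison removes this particular mismatch, but the metric still counts characters, not shared substructures.</p>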
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>GPU</strong>: Single NVIDIA GeForce RTX 2080 Ti</li>
<li><strong>Training Time</strong>: ~5 hours per epoch, approximately 25 hours total for 5 epochs</li>
<li><strong>Memory</strong>: Sufficient for batch size 32 with ResNet-101 + Transformer architecture</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<p>The paper mentions releasing both code and the MOLCAP dataset, but no public repository or download link has been confirmed as available.</p>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MOLCAP dataset</td>
          <td>Dataset</td>
          <td>Unknown</td>
          <td>81M molecules; claimed released but no public URL found</td>
      </tr>
      <tr>
          <td>IMG2SMI code</td>
          <td>Code</td>
          <td>Unknown</td>
          <td>Built on DETR; claimed released but no public URL found</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Campos, D., &amp; Ji, H. (2021). IMG2SMI: Translating Molecular Structure Images to Simplified Molecular-input Line-entry System (No. arXiv:2109.04202). arXiv. <a href="https://doi.org/10.48550/arXiv.2109.04202">https://doi.org/10.48550/arXiv.2109.04202</a></p>
<p><strong>Publication</strong>: arXiv preprint (2021)</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://doi.org/10.48550/arXiv.2109.04202">Paper on arXiv</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{campos2021img2smi,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{IMG2SMI: Translating Molecular Structure Images to Simplified Molecular-input Line-entry System}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Campos, Daniel and Ji, Heng}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{arXiv preprint arXiv:2109.04202}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2021}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.48550/arXiv.2109.04202}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Hand-Drawn Chemical Diagram Recognition (AAAI 2007)</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/hand-drawn/ouyang-davis-aaai-2007/</link><pubDate>Sun, 14 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/hand-drawn/ouyang-davis-aaai-2007/</guid><description>A sketch recognition system for organic chemistry that uses domain knowledge (chemical valence) to correct recognition errors.</description><content:encoded><![CDATA[<h2 id="contribution-and-methodological-approach">Contribution and Methodological Approach</h2>
<p>This is a <strong>Method</strong> paper. It proposes a multi-stage pipeline for interpreting hand-drawn diagrams that integrates a trainable symbol recognizer with a domain-specific verification step. The authors validate the method through an ablation study comparing the full system against a baseline lacking domain knowledge.</p>
<h2 id="motivation-for-sketch-based-interfaces">Motivation for Sketch-Based Interfaces</h2>
<p>Current software for specifying chemical structures (e.g., ChemDraw, IsisDraw) relies on mouse and keyboard interfaces, which lack the speed, ease of use, and naturalness of drawing on paper. The goal is to bridge the gap between natural expression and computer interpretation by building a system that understands freehand chemical sketches.</p>
<h2 id="novel-integration-of-chemical-domain-knowledge">Novel Integration of Chemical Domain Knowledge</h2>
<p>The primary novelty is the integration of <strong>domain knowledge</strong> (specifically chemical valence rules) directly into the interpretation loop to resolve ambiguities and correct errors.</p>
<p>Specific technical contributions include:</p>
<ul>
<li><strong>Hybrid Recognizer</strong>: Combines feature-based SVMs, image-based template matching (modified Tanimoto), and off-the-shelf handwriting recognition to handle the mix of geometry and text.</li>
<li><strong>Domain Verification Loop</strong>: A post-processing step that checks the chemical validity of the structure (e.g., nitrogen must have 3 bonds). If an inconsistency is found, the system searches the space of alternative hypotheses generated during the initial parsing phase to find a valid interpretation.</li>
<li><strong>Contextual Parsing</strong>: Uses a sliding window (up to 7 strokes) and spatial context to parse interspersed symbols.</li>
<li><strong>Implicit Structure Handling</strong>: Supports two common chemistry notations: (1) implicit elements, where carbon and hydrogen atoms are omitted and inferred from bond connectivity and valence rules, and (2) aromatic rings, detected as a circle drawn inside a hexagonal 6-carbon cycle.</li>
</ul>
<h2 id="experimental-design-and-user-study">Experimental Design and User Study</h2>
<p>The authors conducted a user study to evaluate the system&rsquo;s robustness on unconstrained sketches.</p>
<ul>
<li><strong>Participants</strong>: 6 users familiar with organic chemistry.</li>
<li><strong>Task</strong>: Each user drew 12 pre-specified molecular compounds on a Tablet PC.</li>
<li><strong>Conditions</strong>: The system was evaluated in two modes:
<ol>
<li><strong>Domain</strong>: The full system with chemical valence checks.</li>
<li><strong>Baseline</strong>: A simplified version with no knowledge of chemical valence/verification.</li>
</ol>
</li>
<li><strong>Data Split</strong>: Evaluated on collected sketches using a leave-one-out style approach (training on 11 examples from the same users).</li>
</ul>
<h2 id="results-and-error-reduction-analysis">Results and Error Reduction Analysis</h2>
<ul>
<li><strong>Performance</strong>: The full system achieved an overall <strong>F-measure of 0.87</strong> (Precision 0.86, Recall 0.89).</li>
<li><strong>Impact of Domain Knowledge</strong>: Using domain knowledge reduced the overall error rate (measured by recall) by <strong>27%</strong> compared to the baseline. The improvement was statistically significant ($p &lt; .05$).</li>
<li><strong>Error Recovery</strong>: The system successfully recovered from interpretations that were geometrically plausible but chemically impossible (e.g., misinterpreting &ldquo;N&rdquo; as bonds), as illustrated in their qualitative analysis.</li>
<li><strong>Output Integration</strong>: Once interpreted, the resulting structure is expressed in a standard chemical specification format that can be passed to tools such as ChemDraw (for rendering) or SciFinder (for database queries).</li>
<li><strong>Limitations</strong>: The system struggled with &ldquo;messy&rdquo; sketches where users drew single bonds with multiple strokes or over-traced lines, as the current bond recognizer assumes single-stroke straight bonds.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The study collected a custom dataset of hand-drawn diagrams.</p>
<ul>
<li><strong>Volume</strong>: 6 participants $\times$ 12 molecules = 72 total sketches (implied).</li>
<li><strong>Preprocessing</strong>:
<ul>
<li><strong>Scale Normalization</strong>: The system estimates scale based on the average length of straight bonds (chosen because they are easy to identify). This normalizes geometric features for the classifier.</li>
<li><strong>Stroke Segmentation</strong>: Poly-line approximation using recursive splitting (minimizing least-squares error) to break multi-segment strokes (e.g., connected bonds) into primitives.</li>
</ul>
</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<p><strong>1. Ink Parsing (Sliding Window)</strong></p>
<ul>
<li>Examines all combinations of up to <strong>$n=7$</strong> sequential strokes.</li>
<li>Classifies each group as a valid symbol or invalid garbage.</li>
</ul>
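<p>The sliding-window step can be sketched as an enumeration of candidate stroke groups. This sketch assumes that "sequential" means strokes that are contiguous in drawing order; the function name and the half-open index convention are illustrative choices, not details from the paper.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">def candidate_groups(num_strokes, max_len=7):
    """List (start, stop) half-open ranges of up to max_len consecutive strokes."""
    groups = []
    for start in range(num_strokes):
        last_stop = min(start + max_len, num_strokes)
        for stop in range(start + 1, last_stop + 1):
            groups.append((start, stop))
    return groups


# Each candidate would then be scored by the symbol recognizer as a valid
# symbol or as an invalid grouping.
print(candidate_groups(3))        # [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
print(len(candidate_groups(10)))  # 49
</code></pre></div>
<p>Capping the window keeps the number of candidates linear in the number of strokes (at most 7 per starting stroke), which matters for real-time feedback.</p>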
<p><strong>2. Template Matching (Image-based)</strong></p>
<ul>
<li>Used for resolving ambiguities in text/symbols (e.g., &lsquo;H&rsquo; vs &lsquo;N&rsquo;).</li>
<li><strong>Metric</strong>: Modified <strong>Tanimoto coefficient</strong>. Unlike standard Tanimoto (point overlap), this version accounts for relative angle and curvature at each point.</li>
</ul>
<p><strong>3. Domain Verification</strong></p>
<ul>
<li><strong>Trigger</strong>: An element with incorrect valence (e.g., Hydrogen with &gt;1 bond).</li>
<li><strong>Resolution</strong>: Searches stored alternative hypotheses for the affected strokes. It accepts a new hypothesis if it resolves the valence error without introducing new ones.</li>
<li><strong>Constraint</strong>: It keeps an inconsistent structure if the original confidence score is significantly higher than that of the alternatives (assuming the user is still drawing or has intentionally left the structure incomplete).</li>
</ul>
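<p>The trigger-and-resolution logic can be sketched in a few lines. This is a simplified illustration, not the authors' implementation: it only flags atoms with too many bonds (it ignores implicit hydrogens and under-filled valences), it accepts an alternative only if that alternative has no valence errors at all, and it omits the confidence-score constraint. All names and data structures are hypothetical.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python"># Maximum number of bonds for a few common elements (standard valences).
MAX_BONDS = {"H": 1, "O": 2, "N": 3, "C": 4}


def valence_errors(atoms, bonds):
    """Return ids of atoms that have more bonds than their valence allows.

    atoms: dict mapping atom id to element symbol
    bonds: list of (atom_id, atom_id, bond_order) tuples
    """
    used = {atom_id: 0 for atom_id in atoms}
    for first, second, order in bonds:
        used[first] += order
        used[second] += order
    return sorted(a for a in atoms if used[a] > MAX_BONDS[atoms[a]])


def choose_interpretation(current, alternatives):
    """Keep the current reading unless it is invalid and a valid alternative exists."""
    if not valence_errors(*current):
        return current
    for candidate in alternatives:
        if not valence_errors(*candidate):
            return candidate
    return current


bonds = [(0, 1, 1), (1, 2, 1)]
misread = ({0: "C", 1: "H", 2: "C"}, bonds)      # "H" with two bonds is impossible
alternative = ({0: "C", 1: "N", 2: "C"}, bonds)  # "N" with two bonds is within valence
print(valence_errors(*misread))                                      # [1]
print(choose_interpretation(misread, [alternative]) is alternative)  # True
</code></pre></div>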
<h3 id="models">Models</h3>
<p><strong>Symbol Recognizer (Discriminative Classifier)</strong></p>
<ul>
<li><strong>Type</strong>: Support Vector Machine (SVM).</li>
<li><strong>Classes</strong>: Element letters, straight bonds, hash bonds, wedge bonds, invalid groups.</li>
<li><strong>Input Features</strong>:
<ol>
<li>Number of strokes</li>
<li>Bounding-box dimensions (width, height, diagonal)</li>
<li>Ink density (ink length / diagonal length)</li>
<li>Inter-stroke distance (max distance between strokes in group)</li>
<li>Inter-stroke orientation (vector of relative orientations)</li>
</ol>
</li>
</ul>
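<p>Several of these features are simple geometry over the stroke points. The sketch below computes the stroke count, bounding-box dimensions, and ink density for a group of strokes; it assumes each stroke is a list of (x, y) points and skips the scale normalization and the inter-stroke features. Function and key names are illustrative.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">import math


def stroke_length(stroke):
    """Total ink length of one stroke given as a list of (x, y) points."""
    return sum(math.dist(p, q) for p, q in zip(stroke, stroke[1:]))


def group_features(strokes):
    """Compute a few of the listed geometric features for a group of strokes."""
    xs = [x for stroke in strokes for x, _ in stroke]
    ys = [y for stroke in strokes for _, y in stroke]
    width = max(xs) - min(xs)
    height = max(ys) - min(ys)
    diagonal = math.hypot(width, height)
    ink = sum(stroke_length(stroke) for stroke in strokes)
    return {
        "num_strokes": len(strokes),
        "width": width,
        "height": height,
        "diagonal": diagonal,
        # Guard against a degenerate single-point group with zero diagonal.
        "ink_density": ink / diagonal if diagonal else 0.0,
    }


# Two straight strokes forming an "L": a 3 x 4 bounding box with diagonal 5.
features = group_features([[(0, 0), (3, 0)], [(3, 0), (3, 4)]])
print(features["diagonal"], features["ink_density"])  # 5.0 1.4
</code></pre></div>
<p>A single straight bond has an ink density close to 1, while a letter such as "N" packs more ink into its bounding box, which is part of what lets a classifier separate the two.</p>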
<p><strong>Text Recognition</strong></p>
<ul>
<li><strong>Microsoft Tablet PC SDK</strong>: Used for recognizing alphanumeric characters (elements and subscripts).</li>
<li>Integrated with the SVM and Template Matcher via a combined scoring mechanism.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Value (Overall)</th>
          <th>Baseline Comparison</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Precision</strong></td>
          <td>0.86</td>
          <td>0.81 (Baseline)</td>
          <td>Full system vs. no domain knowledge</td>
      </tr>
      <tr>
          <td><strong>Recall</strong></td>
          <td>0.89</td>
          <td>0.85 (Baseline)</td>
          <td>27% error reduction</td>
      </tr>
      <tr>
          <td><strong>F-Measure</strong></td>
          <td>0.87</td>
          <td>0.83 (Baseline)</td>
          <td>Statistically significant ($p &lt; .05$)</td>
      </tr>
  </tbody>
</table>
<ul>
<li><strong>True Positive Definition</strong>: Match in both location (stroke grouping) and classification (label).</li>
</ul>
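<p>The headline numbers in the table can be checked by hand. The calculation below assumes that the error rate is one minus recall and that the F-measure is the usual harmonic mean of precision $P$ and recall $R$ (the note does not state the weighting explicitly).</p>
<p>$$\frac{(1 - 0.85) - (1 - 0.89)}{1 - 0.85} = \frac{0.04}{0.15} \approx 0.27$$</p>
<p>$$F = \frac{2PR}{P + R} = \frac{2 \times 0.86 \times 0.89}{0.86 + 0.89} \approx 0.87$$</p>
<p>Applying the same formula to the baseline values ($P = 0.81$, $R = 0.85$) gives approximately 0.83, which matches the table.</p>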
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Device</strong>: 1.5 GHz Tablet PC.</li>
<li><strong>Performance</strong>: Real-time feedback.</li>
</ul>
<h3 id="reproducibility">Reproducibility</h3>
<p>No source code, trained models, or collected sketch data were publicly released. The paper is openly available through the AAAI digital library. The system depends on the Microsoft Tablet PC SDK (a proprietary, now-discontinued component), which would make exact replication difficult even with the algorithm descriptions provided.</p>
<p><strong>Status</strong>: Closed</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Ouyang, T. Y., &amp; Davis, R. (2007). Recognition of Hand Drawn Chemical Diagrams. <em>Proceedings of the 22nd National Conference on Artificial Intelligence</em> (AAAI-07), 846-851.</p>
<p><strong>Publication</strong>: AAAI 2007</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{ouyang2007recognition,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Recognition of Hand Drawn Chemical Diagrams}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Ouyang, Tom Y and Davis, Randall}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Proceedings of the 22nd National Conference on Artificial Intelligence}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{846--851}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2007}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Distributed Representations: A Foundational Theory</title><link>https://hunterheidenreich.com/notes/machine-learning/model-architectures/distributed-representations/</link><pubDate>Sun, 14 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/model-architectures/distributed-representations/</guid><description>Hinton's 1984 technical report establishing the theoretical efficiency of distributed representations over local encoding in neural networks.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is primarily a <strong>Theory</strong> paper, with strong secondary elements of <strong>Method</strong> and <strong>Position</strong>.</p>
<p>It is a theoretical work because its core contribution is the formal mathematical derivation of the encoding accuracy and error properties of distributed schemes (coarse coding) compared to local schemes. It serves as a position paper by challenging the &ldquo;grandmother cell&rdquo; (local representation) intuition prevalent in AI at the time and advocating for the &ldquo;constructive&rdquo; view of memory.</p>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>The motivation is to overcome the inefficiency of <strong>local representations</strong>, where one hardware unit corresponds to exactly one entity, and to challenge traditional metaphors of memory.</p>
<ul>
<li><strong>Inefficiency</strong>: In local representations, high accuracy requires an exponential number of units (accuracy $\propto \sqrt[k]{n}$ for $k$ dimensions).</li>
<li><strong>Brittleness</strong>: Local representations lack natural support for generalization; learning a fact about one concept (e.g., &ldquo;chimps like onions&rdquo;) requires extra machinery to transfer to similar concepts (e.g., &ldquo;gorillas&rdquo;).</li>
<li><strong>Hardware Mismatch</strong>: Massive parallelism is wasted if units are rarely active. A unit that is active 50% of the time conveys 1 bit of information, whereas a sparsely active local unit conveys almost none.</li>
<li><strong>The &ldquo;Filing Cabinet&rdquo; Metaphor</strong>: The paper challenges the standard view of memory as a storage system of literal copies. It motivates a shift toward understanding memory as a reconstructive inference process.</li>
</ul>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The paper introduces formal mechanisms that explain <em>why</em> distributed representations are superior:</p>
<ol>
<li><strong>Coarse Coding Efficiency</strong>: Hinton proves that using broad, overlapping receptive fields (&ldquo;coarse coding&rdquo;) yields higher accuracy for a fixed number of units than non-overlapping local fields. For a $k$-dimensional feature space with $n$ units of receptive field radius $r$, accuracy scales as $a \propto n \cdot r^{k-1}$. This is far superior to local encoding, where accuracy scales as $a \propto n^{1/k}$.</li>
<li><strong>Automatic Generalization</strong>: It demonstrates that generalization is an emergent property of vector overlap. Modifying weights for one pattern automatically affects similar patterns (conspiracy effect).</li>
<li><strong>Memory as Reconstruction</strong>: It posits that memory is a reconstructive process where items are created afresh from fragments using plausible inference rules (connection strengths). This blurs the line between veridical recall and confabulation.</li>
<li><strong>Gradual Concept Formation</strong>: Distributed representations allow new concepts to emerge gradually through weight modifications that progressively differentiate existing concepts. This avoids the discrete decisions and spare hardware units required by local representations.</li>
<li><strong>Solution to the Binding Problem</strong>: It proposes that true part/whole hierarchies are formed by fusing the identity of a part with its role to produce a single, new subpattern. The representation of the whole is then the sum of these combined identity/role representations.</li>
</ol>
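<p>A small worked comparison shows how differently the two schemes scale. It uses only the two proportionalities stated above, so constants are ignored and only ratios are meaningful. Take $k = 3$ as an illustrative dimensionality. Doubling the receptive field radius at fixed $n$ under coarse coding gives:</p>
<p>$$\frac{a(2r)}{a(r)} = \frac{n \cdot (2r)^{k-1}}{n \cdot r^{k-1}} = 2^{k-1} = 4$$</p>
<p>Doubling the number of units under local encoding gives:</p>
<p>$$\frac{a(2n)}{a(n)} = \frac{(2n)^{1/k}}{n^{1/k}} = 2^{1/k} \approx 1.26$$</p>
<p>So in three dimensions, widening the fields quadruples accuracy without adding hardware, while doubling the hardware of a local scheme buys only about a 26% gain.</p>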

<figure class="post-figure center ">
    <img src="/img/notes/distributed-representations-binding.svg"
         alt="Diagram showing distributed representations with three pools of units (AGENT, RELATIONSHIP, PATIENT) connected via role/identity bindings"
         title="Diagram showing distributed representations with three pools of units (AGENT, RELATIONSHIP, PATIENT) connected via role/identity bindings"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The binding problem solution: true hierarchies require creating unique subpatterns that fuse an identity with its role, where the whole is represented as the sum of these combined representations.</figcaption>
    
</figure>

<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The paper performs analytical derivations and two specific computer simulations:</p>
<ol>
<li><strong>Arbitrary Mapping Simulation</strong>: A 3-layer network trained to map 20 grapheme strings (e.g., words) to 20 unrelated semantic vectors.</li>
<li><strong>Damage &amp; Recovery Analysis</strong>:
<ul>
<li><strong>Lesioning</strong>: Removing a single word-set unit to observe error patterns. This produced &ldquo;Deep Dyslexia&rdquo;-like semantic errors (e.g., reading &ldquo;PEACH&rdquo; as &ldquo;APRICOT&rdquo;), where the clean-up effect settles on a similar but incorrect meaning.</li>
<li><strong>Noise Injection</strong>: Adding noise to all connections involving word-set units, reducing performance from 99.3% to 64.3%.</li>
<li><strong>Retraining</strong>: Measuring the speed of relearning after noise damage (&ldquo;spontaneous recovery&rdquo;), where unrehearsed items recover alongside rehearsed ones due to shared weights.</li>
</ul>
</li>
</ol>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>
<ol>
<li><strong>Accuracy Scaling</strong>: For a $k$-dimensional feature space, the accuracy $a$ of a distributed representation scales as $a \propto n \cdot r^{k-1}$ (where $r$ is the receptive field radius), vastly outperforming local schemes.</li>
<li><strong>Reliability</strong>: Distributed systems exhibit graceful degradation. Removing units causes slight noise across many items rather than the complete loss of any single item.</li>
<li><strong>Spontaneous Recovery</strong>: When retraining a damaged network on a subset of items, the network &ldquo;spontaneously&rdquo; recovers unrehearsed items due to weight sharing, which is a qualitative signature of distributed representations.</li>
<li><strong>Limitations of Coarse Coding</strong>: The paper identifies that coarse coding requires relatively sparse features. Crowding too many feature-points together causes receptive fields to contain too many features, preventing the activity pattern from discriminating between combinations.</li>
<li><strong>Sequential Processing Constraint</strong>: When constituent structure is represented using identity/role bindings, only one structure can be represented at a time. Hinton argues this matches the empirical observation that people are, to a first approximation, sequential symbol processors.</li>
<li><strong>Learning Problem Deferred</strong>: The paper acknowledges that discovering which sets of items should correspond to single units is a difficult search problem, and defers the learning question to separate work (Hinton, Sejnowski, and Ackley, 1984).</li>
</ol>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<p>The following details are extracted from Section 5 (&ldquo;Implementing an Arbitrary Mapping&rdquo;) to facilitate reproduction of the &ldquo;Arbitrary Mapping&rdquo; and &ldquo;Deep Dyslexia&rdquo; simulations.</p>
<h3 id="data">Data</h3>
<p>The simulation uses synthetic data representing words and meanings.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training</td>
          <td>Synthetic Grapheme/Sememe Pairs</td>
          <td>20 pairs</td>
          <td>20 different grapheme strings mapped to random semantic vectors.</td>
      </tr>
  </tbody>
</table>
<ul>
<li><strong>Input (Graphemes)</strong>: 30 total units.
<ul>
<li>Structure: Divided into 3 groups of 10 units each.</li>
<li>Encoding: Each &ldquo;word&rdquo; (3 letters) activates exactly 1 unit in each group (sparse binary).</li>
</ul>
</li>
<li><strong>Output (Sememes)</strong>: 30 total units.
<ul>
<li>Structure: Binary units.</li>
<li>Encoding: Meanings are random vectors where each unit is active with probability $p=0.2$.</li>
</ul>
</li>
</ul>
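<p>The input and output encodings are simple enough to generate directly. The following sketch builds a toy dataset with the stated dimensions; it assumes each letter position draws from 10 possible letters (one per unit in its group), and the seed, helper names, and random word choice are illustrative and do not reproduce the paper's actual items.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">import random

GROUPS = 3             # letter positions in a word
UNITS_PER_GROUP = 10   # grapheme units per position
NUM_SEMEMES = 30
P_ACTIVE = 0.2         # probability that a sememe unit is active in a meaning


def index_to_letters(code):
    """Decode an integer into one letter index per position group."""
    letters = []
    for _ in range(GROUPS):
        code, letter = divmod(code, UNITS_PER_GROUP)
        letters.append(letter)
    return tuple(letters)


def encode_word(letter_indices):
    """Activate exactly one unit in each position group."""
    vector = [0] * (GROUPS * UNITS_PER_GROUP)
    for group, index in enumerate(letter_indices):
        vector[group * UNITS_PER_GROUP + index] = 1
    return vector


def random_meaning(rng):
    """Draw a binary sememe vector with independently active units."""
    return [1 if P_ACTIVE > rng.random() else 0 for _ in range(NUM_SEMEMES)]


rng = random.Random(0)  # fixed seed so the toy dataset is repeatable
# Sampling without replacement guarantees 20 different grapheme strings.
codes = rng.sample(range(UNITS_PER_GROUP ** GROUPS), 20)
dataset = [(encode_word(index_to_letters(c)), random_meaning(rng)) for c in codes]
print(len(dataset), len(dataset[0][0]), sum(dataset[0][0]))  # 20 30 3
</code></pre></div>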
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Learning Rule</strong>: The paper cites &ldquo;Hinton, Sejnowski &amp; Ackley (1984)&rdquo; (Boltzmann Machines) for the specific learning algorithm used to set weights.</li>
<li><strong>False Positive Analysis</strong>: The probability $f$ that a semantic feature is incorrectly activated is derived as:</li>
</ul>
<p>$$f = (1 - (1-p)^{(w-1)})^u$$</p>
<p>Where:</p>
<ul>
<li>$p$: Probability of a sememe being in a word meaning ($0.2$).</li>
<li>$w$: Number of words in a &ldquo;word-set&rdquo; (cluster).</li>
<li>$u$: Number of active &ldquo;word-set&rdquo; units per word.</li>
</ul>
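<p>One way to read the formula: $(1-p)^{w-1}$ is the probability that none of the other $w-1$ words in a word-set contains a given sememe, so $1 - (1-p)^{w-1}$ is the probability that a single active word-set unit spuriously supports it, and all $u$ active units must do so for a false positive. The sketch below evaluates the formula numerically; $p = 0.2$ comes from the simulation description, while the values of $w$ and $u$ are illustrative assumptions, not the paper's settings.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">def false_positive_rate(p, w, u):
    """Evaluate f = (1 - (1 - p)^(w - 1))^u."""
    return (1 - (1 - p) ** (w - 1)) ** u


# p = 0.2 is taken from the note; w = 3 and the u values are hypothetical.
for u in (1, 2, 4):
    print(u, round(false_positive_rate(p=0.2, w=3, u=u), 4))
# prints 0.36, 0.1296, and 0.0168 for u = 1, 2, 4
</code></pre></div>
<p>The false-positive rate falls geometrically as more word-set units are active per word, and rises as each word-set covers more words.</p>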
<h3 id="models">Models</h3>
<p>The simulation uses a specific three-layer architecture.</p>
<table>
  <thead>
      <tr>
          <th>Layer</th>
          <th>Type</th>
          <th>Count</th>
          <th>Connectivity</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Input</strong></td>
          <td>Grapheme Units</td>
          <td>30</td>
          <td>Connected to all Hidden (word-set) units (no direct link to Output).</td>
      </tr>
      <tr>
          <td><strong>Hidden</strong></td>
          <td>&ldquo;Word-Set&rdquo; Units</td>
          <td>20</td>
          <td>Fully connected to Input and Output.</td>
      </tr>
      <tr>
          <td><strong>Output</strong></td>
          <td>Sememe Units</td>
          <td>30</td>
          <td>Connected to all Hidden (word-set) units. Includes lateral inhibition (implied for &ldquo;clean up&rdquo;).</td>
      </tr>
  </tbody>
</table>
<ul>
<li><strong>Weights</strong>: Binary/Integer logic in theoretical analysis, but &ldquo;stochastic&rdquo; weights in the Boltzmann simulation.</li>
<li><strong>Thresholds</strong>: Sememe units have variable thresholds dynamically adjusted to be slightly less than the number of active word-set units.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p>The simulation evaluated the robustness of the mapping.</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Value</th>
          <th>Baseline</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Accuracy (Clean)</strong></td>
          <td>99.9%</td>
          <td>N/A</td>
          <td>Correct pattern produced 99.9% of the time after learning.</td>
      </tr>
      <tr>
          <td><strong>Lesion Error Rate</strong></td>
          <td>1.4%</td>
          <td>N/A</td>
          <td>140 errors in 10,000 tests after removing 1 word-set unit.</td>
      </tr>
      <tr>
          <td><strong>Semantic Errors</strong></td>
          <td>~60% of errors</td>
          <td>N/A</td>
          <td>83 of the 140 lesion errors were &ldquo;Deep Dyslexia&rdquo; errors (producing a valid but wrong semantic pattern).</td>
      </tr>
      <tr>
          <td><strong>Post-Noise Accuracy</strong></td>
          <td>64.3%</td>
          <td>99.3%</td>
          <td>Performance after adding noise to all connections involving word-set units. The 99.3% baseline (reported separately from the 99.9% clean accuracy above) reflects the pre-noise measurement at the time of this specific experiment.</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Compute</strong>: Minimal. The original simulation ran on 1980s hardware (likely VAX-11 or similar).</li>
<li><strong>Replication</strong>: Reproducible on any modern CPU in milliseconds.</li>
</ul>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Hinton, G. E. (1984). Distributed Representations. <em>Technical Report CMU-CS-84-157</em>, Carnegie-Mellon University.</p>
<p><strong>Publication</strong>: CMU Computer Science Department Technical Report, October 1984</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@techreport</span>{hinton1984distributed,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Distributed representations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Hinton, Geoffrey E}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{1984}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">institution</span>=<span style="color:#e6db74">{Carnegie-Mellon University}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{CMU-CS-84-157}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>SELFIES and the Future of Molecular String Representations</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/notations/selfies-2022/</link><pubDate>Tue, 02 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/notations/selfies-2022/</guid><description>Perspective on SELFIES as a 100% robust SMILES alternative, with 16 future research directions for molecular AI.</description><content:encoded><![CDATA[<h2 id="position-a-roadmap-for-robust-chemical-languages">Position: A Roadmap for Robust Chemical Languages</h2>
<p>This is a <strong>Position</strong> paper (perspective) that proposes a research agenda for molecular representations in AI. It reviews the evolution of chemical notation over 250 years and argues for extending SELFIES-style robust representations beyond traditional organic chemistry into polymers, crystals, reactions, and other complex chemical systems.</p>
<h2 id="the-generative-bottleneck-in-traditional-representations">The Generative Bottleneck in Traditional Representations</h2>
<p>While SMILES has been the standard molecular representation since 1988, its fundamental weakness for machine learning is well-established: randomly generated SMILES strings are often invalid. The motivation is twofold:</p>
<ol>
<li><strong>Current problem</strong>: Traditional representations (SMILES, <a href="/notes/chemistry/molecular-representations/notations/inchi-2013/">InChI</a>, DeepSMILES) lack 100% robustness; random mutations or generations can produce invalid strings, limiting their use in generative AI models.</li>
<li><strong>Future opportunity</strong>: SELFIES solved this for small organic molecules, but many important chemical domains (polymers, crystals, reactions) still lack robust representations, creating a bottleneck for AI-driven discovery in these areas.</li>
</ol>
<h2 id="16-concrete-research-directions-for-selfies">16 Concrete Research Directions for SELFIES</h2>
<p>The novelty is in the comprehensive research roadmap. The authors propose 16 concrete research projects organized around key themes:</p>
<ul>
<li><strong>Domain extension</strong>: Includes metaSELFIES for learning graph rules directly from data, BigSELFIES for stochastic polymers, and crystal structures via labeled quotient graphs.</li>
<li><strong>Chemical reactions</strong>: Robust reaction representations that enforce conservation laws.</li>
<li><strong>Programming perspective</strong>: Treating molecular representations as programming languages, potentially achieving Turing-completeness.</li>
<li><strong>Benchmarking</strong>: Systematic comparisons across representation formats.</li>
<li><strong>Interpretability</strong>: Understanding how humans and machines actually learn from different representations.</li>
</ul>
<h2 id="evidence-from-generative-case-studies">Evidence from Generative Case Studies</h2>
<p>The paper supports its argument with two case studies:</p>
<ol>
<li>
<p><strong>Pasithea (Deep Molecular Dreaming)</strong>: A generative model that first learns to predict a chemical property from a one-hot encoded SELFIES, then freezes the network weights and uses gradient descent on the one-hot input encoding to optimize molecular properties (logP). The target property increases or decreases nearly monotonically, demonstrating that the model has learned meaningful structure-property relationships from the SELFIES representation.</p>
</li>
<li>
<p><strong>DECIMER and STOUT</strong>: DECIMER (Deep lEarning for Chemical ImagE Recognition) is an image-to-structure tool, and STOUT (SMILES-TO-IUPAC-name Translator) translates between IUPAC names and molecular string representations. Both show improved performance when using SELFIES as an intermediate representation. STOUT internally converts SMILES to SELFIES before processing and decodes predicted SELFIES back to SMILES. These results suggest SELFIES provides a more learnable internal representation for sequence-to-sequence models.</p>
</li>
</ol>
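<p>The input-optimization idea behind Pasithea can be sketched in a few lines. This is a toy under loud assumptions: a linear scorer with made-up weights stands in for the trained network, so the gradient with respect to the input is simply the weight matrix, and <code>predict</code>, <code>dream</code>, and <code>decode</code> are hypothetical names. The real model is a neural network differentiated with autograd.</p>
<pre tabindex="0"><code class="language-python" data-lang="python"># Toy stand-in for a frozen property predictor: one weight per (position, token).
# All numbers are invented for illustration.
TOKENS = ["[C]", "[N]", "[O]"]
WEIGHTS = [
    [0.2, -0.1, 0.5],
    [0.0, 0.3, -0.4],
]


def predict(x):
    """Linear property score of a (relaxed) one-hot encoding x."""
    return sum(w * v for w_row, x_row in zip(WEIGHTS, x) for w, v in zip(w_row, x_row))


def dream(x, steps=50, lr=0.1):
    """Gradient ascent on the input encoding; the weights are never updated."""
    x = [row[:] for row in x]  # copy so the caller's encoding is untouched
    for _ in range(steps):
        for pos, w_row in enumerate(WEIGHTS):
            for tok, w in enumerate(w_row):
                x[pos][tok] += lr * w  # d(predict)/dx[pos][tok] equals w here
    return x


def decode(x):
    """Read the relaxed encoding back as tokens via a per-position argmax."""
    return [TOKENS[row.index(max(row))] for row in x]


start = [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]  # one-hot encoding of "[C][C]"
dreamed = dream(start)
print(decode(start), predict(start))
print(decode(dreamed), predict(dreamed))
</code></pre>
<p>Because the derivation of a SELFIES string cannot fail, every argmax readout of the optimized encoding corresponds to some molecule, which is what makes this kind of continuous optimization usable.</p>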
<h2 id="strategic-outcomes-and-future-vision">Strategic Outcomes and Future Vision</h2>
<p>The paper establishes robust representations as a fundamental bottleneck in computational chemistry and proposes a clear path forward:</p>
<p><strong>Key outcomes</strong>:</p>
<ul>
<li>Identification of 16 concrete research projects spanning domain extension, benchmarking, and interpretability</li>
<li>Evidence that SELFIES enables capabilities (like smooth property optimization) impossible with traditional formats</li>
<li>Framework for thinking about molecular representations as programming languages</li>
</ul>
<p><strong>Strategic impact</strong>: The proposed extensions could enable new applications across drug discovery (efficient exploration beyond small molecules), materials design (systematic crystal structure discovery), synthesis planning (better reaction representations), and fundamental research (new ways to understand chemical behavior).</p>
<p><strong>Future vision</strong>: The authors emphasize that robust representations could become a bridge for bidirectional learning between humans and machines, enabling humans to learn new chemical concepts from AI systems.</p>
<h2 id="the-mechanism-of-robustness">The Mechanism of Robustness</h2>
<p>The key difference between SELFIES and other representations lies in how they handle syntax:</p>
<ul>
<li><strong>SMILES/DeepSMILES</strong>: Rely on non-local markers (opening/closing parentheses or ring numbers) that must be balanced. A mutation or random generation can easily break this balance, producing invalid strings.</li>
<li><strong>SELFIES</strong>: Uses a formal grammar (automaton) where derivation rules are entirely local. The critical innovation is <strong>overloading</strong>: a state-modifying symbol like <code>[Branch1]</code> starts a branch and changes the interpretation of the <em>next</em> symbol to represent a numerical parameter (the branch length).</li>
</ul>
<p>This overloading mechanism ensures that any arbitrary sequence of SELFIES tokens can be parsed into a valid molecular graph. The derivation can never fail because every symbol either adds an atom or modifies how subsequent symbols are interpreted.</p>
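<p>As a rough illustration of why local, overloaded rules cannot fail, the sketch below uses a deliberately simplified toy grammar. It is not the real SELFIES derivation, and <code>toy_derive</code> and <code>parens_balanced</code> are hypothetical names. The branch symbol reinterprets the next token as a length, so every token sequence yields some nested structure, whereas a SMILES-like string depends on a non-local parenthesis balance.</p>
<pre tabindex="0"><code class="language-python" data-lang="python">import random

# Toy alphabet for illustration only; the real SELFIES alphabet and rules differ.
ATOMS = ["[C]", "[N]", "[O]"]
BRANCH = "[Branch1]"
ALPHABET = ATOMS + [BRANCH]


def toy_derive(tokens):
    """Map ANY token sequence to a nested list of atoms without raising."""
    structure = []
    i = 0
    while i != len(tokens):
        token = tokens[i]
        i += 1
        if token != BRANCH:
            structure.append(token.strip("[]"))
            continue
        if i == len(tokens):
            break  # dangling branch symbol at the end: nothing to start, so skip it
        # Overloading: the next token is read as a number, not as an atom.
        length = ALPHABET.index(tokens[i]) + 1
        i += 1
        branch = tokens[i:i + length]  # slicing truncates safely at the end
        i += len(branch)
        structure.append(toy_derive(branch))
    return structure


def parens_balanced(smiles_like):
    """The non-local check a SMILES-like string needs and the toy grammar avoids."""
    depth = 0
    for char in smiles_like:
        if char == "(":
            depth += 1
        elif char == ")":
            depth -= 1
            if depth == -1:
                return False
    return depth == 0


random.seed(0)
for _ in range(1000):
    tokens = random.choices(ALPHABET, k=random.randint(0, 12))
    toy_derive(tokens)  # never raises, whatever the sequence

print(toy_derive(["[C]", "[Branch1]", "[C]", "[O]", "[N]"]))
print(parens_balanced("CC(C)C"), parens_balanced("CC(C"))
</code></pre>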
<h2 id="the-16-research-projects-technical-details">The 16 Research Projects: Technical Details</h2>
<p>This section provides technical details on the proposed research directions:</p>
<h3 id="extending-to-new-domains">Extending to New Domains</h3>
<p><strong>metaSELFIES (Project 1)</strong>: The authors propose learning graph construction rules automatically from data. This could enable robust representations for any graph-based system, from quantum optics to biological networks, without needing domain-specific expertise.</p>
<p><strong>Token Optimization (Project 2)</strong>: SELFIES uses &ldquo;overloading&rdquo; where a symbol&rsquo;s meaning changes based on context. This project would investigate how this affects machine learning performance and whether the approach can be optimized.</p>
<h3 id="handling-complex-molecular-systems">Handling Complex Molecular Systems</h3>
<p><strong>BigSELFIES (Project 3)</strong>: Current representations struggle with large, often random structures like polymers and biomolecules. BigSELFIES would combine hierarchical notation with stochastic building blocks to handle these complex systems where traditional small-molecule representations break down.</p>
<p><strong>Crystal Structures (Projects 4-5)</strong>: Crystals present unique challenges due to their infinite, periodic arrangements. An infinite net cannot be represented by a finite string directly. The proposed approach uses <strong>labeled quotient graphs (LQGs)</strong>, which are finite graphs that uniquely determine a periodic net. However, SELFIES currently cannot represent LQGs because it lacks symbols for edge directions and edge labels (vector shifts encoding periodicity). Extending SELFIES to handle these structures could enable AI-driven materials design without relying on predefined crystal structures, opening up systematic exploration of theoretical materials space.</p>
<p><strong>Beyond Organic Chemistry (Project 6)</strong>: Transition metals and main-group compounds feature complex bonding that breaks the simple two-center, two-electron model. The solution: use machine learning on large structural databases to automatically learn these complex bonding rules.</p>
<h3 id="chemical-reactions-and-programming-concepts">Chemical Reactions and Programming Concepts</h3>
<p><strong>Reaction Representations (Project 7)</strong>: Moving beyond static molecules to represent chemical transformations. A robust reaction format would enforce conservation laws and could learn reactivity patterns from large reaction datasets, improving synthesis planning.</p>
<h3 id="developing-a-100-robust-programming-language">Developing a 100% Robust Programming Language</h3>
<p><strong>Programming Language Perspective (Projects 8-9)</strong>: An intriguing reframing views molecular representations as programming languages executed by chemical parsers. This opens possibilities for adding loops, logic, and other programming concepts to efficiently describe complex structures. The ambitious goal is a Turing-complete programming language that is also 100% robust. One caveat is worth noting: enforcing 100% syntactic robustness inherently restricts grammar flexibility. It remains an open question whether a purely robust string representation can describe highly delocalized bonding (as in Project 6) without becoming impractically long or fragmenting into specialized sub-languages.</p>
<p><strong>Empirical Comparisons (Projects 10-11)</strong>: With multiple representation options (strings, matrices, images), we need systematic comparisons. The proposed benchmarks would go beyond simple validity metrics to focus on real-world design objectives in drug discovery, catalysis, and materials science.</p>
<p><strong>Human Readability (Project 12)</strong>: While SMILES is often called &ldquo;human-readable,&rdquo; this claim lacks scientific validation. The proposed study would test how well humans actually understand different molecular representations.</p>
<p><strong>Machine Learning Perspectives (Projects 13-16)</strong>: These projects explore how machines interpret molecular representations:</p>
<ul>
<li>Training networks to translate between formats to find universal representations</li>
<li>Comparing learning efficiency across different formats</li>
<li>Investigating latent space smoothness in generative models</li>
<li>Visualizing what models actually learn about molecular structure</li>
</ul>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<p>Since this is a position paper outlining future research directions, standard empirical reproducibility metrics do not apply. However, the foundational tools required to pursue the proposed roadmap are open-source.</p>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/aspuru-guzik-group/selfies">aspuru-guzik-group/selfies</a></td>
          <td>Code</td>
          <td>Apache-2.0</td>
          <td>Core SELFIES Python library, installable via <code>pip install selfies</code></td>
      </tr>
      <tr>
          <td><a href="https://arxiv.org/abs/2204.00056">arXiv:2204.00056</a></td>
          <td>Paper</td>
          <td>N/A</td>
          <td>Open-access preprint of the published Patterns article</td>
      </tr>
  </tbody>
</table>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Krenn, M., Ai, Q., Barthel, S., Carson, N., Frei, A., Frey, N. C., Friederich, P., Gaudin, T., Gayle, A. A., Jablonka, K. M., Lameiro, R. F., Lemm, D., Lo, A., Moosavi, S. M., Nápoles-Duarte, J. M., Nigam, A., Pollice, R., Rajan, K., Schatzschneider, U., &hellip; Aspuru-Guzik, A. (2022). SELFIES and the future of molecular string representations. <em>Patterns</em>, <em>3</em>(10). <a href="https://doi.org/10.1016/j.patter.2022.100588">https://doi.org/10.1016/j.patter.2022.100588</a></p>
<p><strong>Publication</strong>: Patterns 2022</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{Krenn2022,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{SELFIES and the future of molecular string representations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{3}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">ISSN</span> = <span style="color:#e6db74">{2666-3899}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span> = <span style="color:#e6db74">{http://dx.doi.org/10.1016/j.patter.2022.100588}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">DOI</span> = <span style="color:#e6db74">{10.1016/j.patter.2022.100588}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{10}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Patterns}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{Elsevier BV}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Krenn, Mario and Ai, Qianxiang and Barthel, Senja and Carson, Nessa and Frei, Angelo and Frey, Nathan C. and Friederich, Pascal and Gaudin, Théophile and Gayle, Alberto Alexander and Jablonka, Kevin Maik and Lameiro, Rafael F. and Lemm, Dominik and Lo, Alston and Moosavi, Seyed Mohamad and Nápoles-Duarte, José Manuel and Nigam, AkshatKumar and Pollice, Robert and Rajan, Kohulan and Schatzschneider, Ulrich and Schwaller, Philippe and Skreta, Marta and Smit, Berend and Strieth-Kalthoff, Felix and Sun, Chong and Tom, Gary and von Rudorff, Guido Falk and Wang, Andrew and White, Andrew and Young, Adamo and Yu, Rose and Aspuru-Guzik, Alán}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2022}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = oct,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{100588}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="/notes/chemistry/molecular-representations/notations/selfies-original-paper/">Original SELFIES Paper</a></li>
<li><a href="/notes/chemistry/molecular-representations/notations/selfies/">SELFIES Overview</a></li>
</ul>
]]></content:encoded></item><item><title>Invalid SMILES Benefit Chemical Language Models: A Study</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/notations/invalid-smiles-help/</link><pubDate>Tue, 02 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/notations/invalid-smiles-help/</guid><description>Skinnider (2024) shows that generating invalid SMILES actually improves chemical language model performance through quality filtering.</description><content:encoded><![CDATA[<h2 id="core-contribution-repurposing-invalid-smiles">Core Contribution: Repurposing Invalid SMILES</h2>
<p>This is an <strong>Empirical</strong> paper that challenges a fundamental assumption in the field of chemical language models. Skinnider provides both empirical evidence and mechanistic explanations for why the ability to generate &ldquo;invalid&rdquo; SMILES strings is beneficial for model performance.</p>
<h2 id="the-problem-with-absolute-validity-in-chemical-lms">The Problem with Absolute Validity in Chemical LMs</h2>
<p>Prior research attempted to eliminate invalid generations using constrained representations like SELFIES. This paper demonstrates that invalid outputs serve as low-likelihood samples whose removal acts as an implicit quality filter, improving distribution learning.</p>
<h2 id="invalid-generation-as-an-implicit-quality-filter">Invalid Generation as an Implicit Quality Filter</h2>
<p>The central insight is counterintuitive: <strong>invalid SMILES generation acts as a built-in quality control mechanism</strong>. The key contributions are:</p>
<ol>
<li>
<p><strong>Empirical Evidence</strong>: Direct comparisons showing that SMILES-based models consistently outperform SELFIES-based models across multiple metrics, with performance gains strongly correlated with the proportion of invalid outputs generated.</p>
</li>
<li>
<p><strong>Mechanistic Explanation</strong>: Invalid SMILES are demonstrated to be low-likelihood samples from the model&rsquo;s probability distribution. When these are filtered out, it&rsquo;s equivalent to removing the model&rsquo;s least confident predictions, a form of automatic quality control.</p>
</li>
<li>
<p><strong>Causal Evidence</strong>: By modifying SELFIES to allow invalid generation (through relaxed constraints), the author shows that performance improves when models can generate and discard invalid outputs, providing direct evidence for a causal relationship.</p>
</li>
<li>
<p><strong>Bias Analysis</strong>: SELFIES models are shown to introduce systematic structural biases (fewer aromatic rings, more aliphatic rings) due to their validity constraints, limiting their ability to explore chemical space naturally.</p>
</li>
</ol>
<h2 id="experimental-design-and-causal-interventions">Experimental Design and Causal Interventions</h2>
<p>The paper uses a multi-pronged approach to establish both correlation and causation:</p>
<p><strong>Performance Comparisons</strong>: SMILES and SELFIES models were trained on identical datasets and evaluated using distribution-learning metrics like Fréchet ChemNet distance. The comparison was robust across different architectures, training set sizes, and chemical databases.</p>
<p><strong>Loss Analysis</strong>: The relationship between SMILES validity and model confidence was examined by analyzing the sequence loss. For a given SMILES string $S$ composed of tokens $t_1, t_2, \ldots, t_N$, the negative log-likelihood acts as a proxy for the model&rsquo;s uncertainty:</p>
<p>$$ \text{NLL}(S) = -\sum_{i=1}^N \log P(t_i \mid t_1, \ldots, t_{i-1}) $$</p>
<p>Invalid SMILES strings consistently register higher $\text{NLL}$ scores, meaning they represent the model&rsquo;s least confident predictions. Filtering them out therefore acts as automatic quality control, which is the mechanistic explanation for why discarding invalid outputs improves performance.</p>
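<p>A minimal sketch of the sequence loss, assuming the model has already supplied the conditional probability of each token given its prefix. The probabilities below are invented for illustration and <code>sequence_nll</code> is a hypothetical helper name.</p>
<pre tabindex="0"><code class="language-python" data-lang="python">import math


def sequence_nll(token_probs):
    """Negative log-likelihood: minus the sum of log P(t_i | t_1, ..., t_{i-1})."""
    return -sum(math.log(p) for p in token_probs)


# Hypothetical per-token probabilities for two sampled strings.
confident_sample = [0.9, 0.8, 0.95, 0.9]
uncertain_sample = [0.9, 0.2, 0.1, 0.6]

# A string containing a few low-probability tokens gets a much larger NLL.
print(sequence_nll(confident_sample))
print(sequence_nll(uncertain_sample))
</code></pre>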
<p><strong>Causal Intervention</strong>: A key experiment involved modifying the SELFIES valency constraints at two levels: first allowing pentavalent carbons (&ldquo;Texas SELFIES&rdquo;), then removing all constraints entirely (&ldquo;unconstrained SELFIES&rdquo;). This allowed direct testing of whether the ability to generate invalid outputs (which are then discarded) causally improves performance.</p>
<p><strong>Structural Bias Analysis</strong>: Generated molecules were analyzed for chemical features like ring types and bond patterns to quantify how validity constraints systematically distort the model&rsquo;s exploration of chemical space.</p>
<p><strong>Generalization Testing</strong>: Models were trained on subsets of chemical databases and tested on their ability to reproduce the broader chemical space, measuring how validity constraints affect generalization.</p>
<p><strong>Practical Application</strong>: The approach was tested on structure elucidation, using models to identify unknown molecules from minimal experimental data like mass spectrometry.</p>
<h2 id="key-findings-on-validity-constraints-and-bias">Key Findings on Validity Constraints and Bias</h2>
<p><strong>Superior Performance Across the Board</strong>: SMILES-based models consistently outperformed SELFIES models on distribution-learning tasks. Using metrics like Fréchet ChemNet distance, SMILES models generated molecules that more closely matched the statistical properties of their training data. This performance advantage was directly correlated with the proportion of invalid SMILES generated. Models that produced more invalid outputs performed better after filtering.</p>
<p><strong>Invalid SMILES Are Low-Confidence Predictions</strong>: The analysis revealed that invalid SMILES consistently have higher loss values than valid ones, meaning they represent the model&rsquo;s least confident predictions. This suggests that validity checking acts as an automatic confidence filter, removing low-quality samples without requiring explicit uncertainty estimation.</p>
<p><strong>Causal Evidence Through Unconstrained SELFIES</strong>: The causal test came from modifying SELFIES to allow invalid generation. When &ldquo;unconstrained SELFIES&rdquo; models could generate and discard invalid molecules, their performance improved, approaching that of SMILES models. This indicates that the ability to generate invalid outputs is what drives the performance gains.</p>
<p><strong>Validity Constraints Introduce Systematic Bias</strong>: SELFIES models showed clear structural biases compared to both training data and SMILES outputs. They generated fewer aromatic rings and more aliphatic structures, systematic distortions caused by the valency constraints used to ensure validity. These biases limit the model&rsquo;s ability to faithfully represent chemical space.</p>
<p><strong>Reduced Generalization</strong>: When trained on subsets of chemical databases, SMILES models could reproduce a larger portion of the complete chemical space compared to SELFIES models. Although SELFIES generated more valid molecules in absolute terms, their structural biases constrained exploration and limited generalization beyond the training set.</p>
<p><strong>Real-World Application Benefits</strong>: In structure elucidation tasks (identifying unknown molecules from experimental data such as mass spectrometry), SMILES-based models significantly outperformed SELFIES models. This demonstrates that the benefits extend beyond academic benchmarks to practical applications.</p>
<p><strong>CASMI 2022 Benchmark</strong>: The language model trained on the LOTUS database was benchmarked against 19 submissions to the CASMI 2022 competition for structure elucidation of unknown compounds. Using only accurate mass as input (no MS/MS data), the model achieved competitive performance, highlighting the practical utility of the sampling-frequency-based approach for de novo structure elucidation.</p>
<p><strong>Computational Efficiency</strong>: Filtering invalid SMILES is computationally trivial. Parsing ten million SMILES strings with RDKit takes approximately 7.5 minutes on a single CPU, making the post-processing overhead negligible compared to model training and inference costs.</p>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="models">Models</h3>
<p><strong>Primary Architecture (LSTM):</strong> The main results rely on a Recurrent Neural Network (RNN) using Long Short-Term Memory (LSTM) units.</p>
<ul>
<li><strong>Structure:</strong> Three-layer LSTM with a hidden layer size of 1,024 dimensions</li>
<li><strong>Embedding:</strong> An embedding layer of 128 dimensions</li>
<li><strong>Decoder:</strong> A linear decoder layer outputs token probabilities</li>
</ul>
<p><strong>Secondary Architecture (Transformer/GPT):</strong> To confirm robustness across architectures, the author also used a Generative Pretrained Transformer (GPT) architecture adapted from MolGPT.</p>
<ul>
<li><strong>Structure:</strong> Eight transformer blocks</li>
<li><strong>Internals:</strong> Each block contains eight masked self-attention heads and a feed-forward network (1,024 dimensions) using GELU activation</li>
<li><strong>Embedding:</strong> 256 dimensions, concatenated with learned positional encodings</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<p><strong>Optimizer:</strong> Adam optimizer for both architectures with $\beta_1=0.9$ and $\beta_2=0.999$.</p>
<p><strong>Learning Rate:</strong></p>
<ul>
<li>LSTM: 0.001</li>
<li>Transformer: 0.0005</li>
</ul>
<p><strong>Batch Size:</strong> 64</p>
<p><strong>Loss Function:</strong> Cross-entropy loss of next-token prediction.</p>
<p><strong>Stopping Criteria:</strong> Early stopping using a validation set (10% of training data) with patience of 50,000 minibatches.</p>
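<p>For anyone re-implementing the setup, the reported hyperparameters can be collected in one place. The sketch below is only a convenience container: the class and field names are hypothetical and do not come from the released code, while the values are the ones listed above.</p>
<pre tabindex="0"><code class="language-python" data-lang="python">from dataclasses import dataclass


@dataclass(frozen=True)
class TrainConfig:
    """Hyperparameters reported above; field names are illustrative."""
    architecture: str
    learning_rate: float
    batch_size: int = 64
    adam_betas: tuple = (0.9, 0.999)
    validation_fraction: float = 0.1    # share of training data held out
    patience_minibatches: int = 50_000  # early-stopping patience


LSTM_CONFIG = TrainConfig(architecture="lstm", learning_rate=0.001)
TRANSFORMER_CONFIG = TrainConfig(architecture="transformer", learning_rate=0.0005)

print(LSTM_CONFIG)
print(TRANSFORMER_CONFIG)
</code></pre>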
<h3 id="data">Data</h3>
<p><strong>Primary Source:</strong> ChEMBL database (version 28).</p>
<p><strong>Preprocessing Pipeline:</strong></p>
<ul>
<li><strong>Cleaning:</strong> Removal of duplicate SMILES, salts, and solvents (retaining heavy fragments with $\geq 3$ heavy atoms)</li>
<li><strong>Filtering:</strong> Molecules with atoms other than {Br, C, Cl, F, H, I, N, O, P, S} were removed</li>
<li><strong>Normalization:</strong> Charged molecules were neutralized and converted to canonical SMILES</li>
</ul>
<p><strong>Training Subsets:</strong> Models were trained on random samples of 30,000, 100,000, and 300,000 molecules to test scalability.</p>
<p><strong>Generalization Data:</strong> To test generalization, models were also trained on the <a href="/notes/chemistry/datasets/gdb-13/">GDB-13</a> database (enumerating drug-like molecules up to 13 heavy atoms).</p>
<p><strong>Structure Elucidation Data:</strong> For practical application tasks, models were trained on natural products (LOTUS, COCONUT), food compounds (FooDB), and environmental contaminants (NORMAN).</p>
<h3 id="evaluation">Evaluation</h3>
<p><strong>Primary Metric:</strong> Fréchet ChemNet Distance (FCD), measuring chemical similarity between generated molecules and the training set (lower is better).</p>
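<p>For reference, FCD is usually defined as the squared Fréchet distance between two Gaussians fitted to ChemNet activations of the generated and reference molecules. Writing $(\mu_g, \Sigma_g)$ and $(\mu_r, \Sigma_r)$ for the respective means and covariances (notation chosen here for illustration), the standard form is:</p>
<p>$$ \text{FCD} = \lVert \mu_g - \mu_r \rVert^2 + \operatorname{Tr}\left(\Sigma_g + \Sigma_r - 2\left(\Sigma_g \Sigma_r\right)^{1/2}\right) $$</p>
<p>The value is zero when the two fitted Gaussians coincide and grows as the means or covariances drift apart.</p>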
<p><strong>Secondary Metrics:</strong></p>
<ul>
<li><strong>Validity:</strong> Percentage of outputs parseable by RDKit</li>
<li><strong>Scaffold Similarity:</strong> Jensen-Shannon distances between Murcko scaffold compositions</li>
<li><strong>Physical Properties:</strong> Comparisons of molecular weight, LogP, topological polar surface area (TPSA), and ring counts (aromatic vs. aliphatic)</li>
<li><strong>Structure Elucidation:</strong> &ldquo;Top-k accuracy,&rdquo; the proportion of held-out molecules where the correct structure appeared in the model&rsquo;s top $k$ ranked outputs</li>
</ul>
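<p>The top-$k$ metric can be sketched as follows, assuming candidates are ranked by how often the model sampled them (the sampling-frequency approach mentioned for structure elucidation). The function names and the example strings are hypothetical and are not taken from the released code.</p>
<pre tabindex="0"><code class="language-python" data-lang="python">from collections import Counter


def rank_by_frequency(samples):
    """Order distinct sampled structures from most to least frequently generated."""
    return [smiles for smiles, _ in Counter(samples).most_common()]


def top_k_accuracy(ranked_lists, targets, k):
    """Fraction of held-out targets found among the first k ranked candidates."""
    hits = sum(target in ranked[:k] for ranked, target in zip(ranked_lists, targets))
    return hits / len(targets)


# Hypothetical samples drawn for one query; the counts are 3, 2, and 1.
samples = ["CCO", "CCO", "CCO", "CCN", "CCN", "COC"]
ranked = rank_by_frequency(samples)

# Two held-out molecules scored against the same ranking, for illustration.
print(ranked)
print(top_k_accuracy([ranked, ranked], ["CCN", "COC"], k=2))
</code></pre>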
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Compute Nodes:</strong> Dell EMC C4140 GPU compute nodes</li>
<li><strong>GPUs:</strong> NVIDIA Tesla V100</li>
<li><strong>Compute Time:</strong> Parsing 10 million SMILES took ~7.5 minutes on a single CPU; SELFIES models required an average of 0.6 hours longer to train than SMILES models</li>
</ul>
<h3 id="replicability">Replicability</h3>
<p><strong>Code Availability:</strong> Source code and intermediate data are available via <a href="https://doi.org/10.5281/zenodo.10680855">Zenodo</a>. Pre-trained model weights are not provided in the archive, requiring researchers to train models from scratch using the included scripts to fully replicate the study.</p>
<p><strong>Data Availability:</strong> Training datasets and generated molecule samples (10 million from ChEMBL/GDB-13 models, 100 million from LOTUS/COCONUT/FooDB/NORMAN cross-validation folds) are available via <a href="https://doi.org/10.5281/zenodo.8321735">Zenodo</a>.</p>
<p><strong>Software Libraries:</strong></p>
<ul>
<li><strong>PyTorch:</strong> LSTM and Transformer implementations</li>
<li><strong>RDKit:</strong> SMILES parsing, validity checking, and property calculation</li>
<li><strong>SELFIES:</strong> Version 2.1.1 for conversion</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://doi.org/10.5281/zenodo.10680855">Source code (Zenodo)</a></td>
          <td>Code</td>
          <td>Unknown</td>
          <td>Training scripts, analysis code, and intermediate data</td>
      </tr>
      <tr>
          <td><a href="https://doi.org/10.5281/zenodo.8321735">Training and generated molecules (Zenodo)</a></td>
          <td>Dataset</td>
          <td>Unknown</td>
          <td>Preprocessed training sets and sampled molecules</td>
      </tr>
  </tbody>
</table>
<h2 id="implications-and-takeaways">Implications and Takeaways</h2>
<p>This work reframes how we think about &ldquo;errors&rdquo; in generative models. The key insight is that model outputs appearing incorrect often represent low-likelihood samples whose removal improves overall performance.</p>
<p>The findings suggest that the field&rsquo;s drive toward guaranteed validity leads to systematic biases. Letting models fail informatively and using those failures as quality signals can yield better distribution learning. This is relevant as the field moves toward larger, more capable models where such self-correction mechanisms become increasingly valuable.</p>
<p>For practitioners, the takeaway is to consider the role of invalid outputs before eliminating them. Filtering low-confidence generations provides automatic quality control that improves final results.</p>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Skinnider, M. A. (2024). Invalid SMILES are beneficial rather than detrimental to chemical language models. Nature Machine Intelligence, 6(4), 437-448. <a href="https://doi.org/10.1038/s42256-024-00821-x">https://doi.org/10.1038/s42256-024-00821-x</a></p>
<p><strong>Publication</strong>: Nature Machine Intelligence (2024)</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{skinnider2024invalid,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Invalid SMILES are beneficial rather than detrimental to chemical language models}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Skinnider, Michael A}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Nature Machine Intelligence}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{6}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{4}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{437--448}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2024}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Nature Publishing Group UK London}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Importance Weighted Autoencoders: Beyond the Standard VAE</title><link>https://hunterheidenreich.com/posts/importance-weighted-autoencoders/</link><pubDate>Wed, 05 Nov 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/importance-weighted-autoencoders/</guid><description>The key difference between multi-sample VAEs and IWAEs: how log-of-averages creates a tighter bound on log-likelihood.</description><content:encoded><![CDATA[<p>If you&rsquo;ve worked with Variational Autoencoders (VAEs), you&rsquo;ve almost certainly used the standard $\mathcal{L}_1$ objective, or ELBO. It&rsquo;s trained by taking <em>one</em> sample ($k=1$) from the recognition network to calculate the loss.</p>
<p>A natural question follows: &ldquo;What if I use more samples? Won&rsquo;t that make it better?&rdquo;</p>
<p>Using more samples improves performance when paired with the correct objective function. Averaging the loss over $k$ samples yields minimal gains. Changing the objective itself is where the real gain comes from. This post explores the difference between a &ldquo;multi-sample VAE&rdquo; and the <strong>Importance Weighted Autoencoder (IWAE)</strong>, a model that uses the <em>same architecture</em> as a VAE but is trained with a different objective that optimizes a tighter bound on the log-likelihood.</p>
<p>All ideas here are based on the fantastic paper: <a href="https://arxiv.org/abs/1509.00519">&ldquo;Importance Weighted Autoencoders&rdquo;</a> by Burda, Grosse, and Salakhutdinov.</p>
<h2 id="the-two-ways-to-use-k-samples">The Two Ways to Use $k$ Samples</h2>
<p>Let&rsquo;s say we have our encoder $q(h|x)$ and decoder $p(x,h)$. We decide to use $k=5$ samples instead of $k=1$. We have two main options for how to calculate our loss.</p>
<h3 id="option-1-the-multi-sample-vae-the-naive-way">Option 1: The &ldquo;Multi-Sample VAE&rdquo; (The Naive Way)</h3>
<p>This is the most straightforward idea. For each input $x$ in our batch:</p>
<ol>
<li>Draw 5 samples ($h_1, &hellip;, h_5$) from $q(h|x)$.</li>
<li>Calculate the standard VAE $\mathcal{L}_1$ loss for <em>each</em> sample.</li>
<li>Average these 5 losses together.</li>
</ol>
<p>This is an <strong>average of logs</strong>. As the IWAE paper shows experimentally, this approach gives you a more stable gradient, but the final performance (in terms of log-likelihood) is &ldquo;only slightly&rdquo; better. You&rsquo;re paying a 5x computational cost for a marginal gain because you&rsquo;re still optimizing the same &ldquo;loose&rdquo; $\mathcal{L}_1$ bound.</p>
<h3 id="option-2-the-importance-weighted-autoencoder-iwae-the-right-way">Option 2: The Importance Weighted Autoencoder (IWAE) (The Right Way)</h3>
<p>The IWAE takes a different approach. For each input $x$:</p>
<ol>
<li>Draw 5 samples ($h_1, &hellip;, h_5$) from $q(h|x)$.</li>
<li>Calculate an &ldquo;importance weight&rdquo; $w_i$ for each sample.</li>
<li>Average these 5 <em>weights</em> together.</li>
<li>Take the <em>logarithm</em> of that average.</li>
</ol>
<p>This is a <strong>log of an average</strong>, and the difference matters.</p>
<h2 id="the-math-average-of-logs-vs-log-of-averages">The Math: Average-of-Logs vs. Log-of-Averages</h2>
<p>Let&rsquo;s make this concrete. The standard VAE $\mathcal{L}_1$ objective is:</p>
<p>$$
\mathcal{L}_1(x) = \mathbb{E} _{h\sim q(h|x)} \left[ \log \frac{p(x,h)}{q(h|x)} \right]
$$</p>
<p>A <strong>multi-sample VAE</strong> simply gets a better estimate of this same value:</p>
<p>$$
\mathcal{L} _{\text{VAE}, k}(x) \approx  \frac{1}{k} \sum _{i=1}^{k} \log w_i \quad \text{where} \quad w_i = \frac{p(x,h_i)}{q(h_i|x)}
$$</p>
<p>The <strong>IWAE</strong> objective, $\mathcal{L}_k$, is fundamentally different:</p>
<p>$$
\mathcal{L} _k (x) = \mathbb{E} _{h_1, \ldots, h_k \sim q(h|x)} \left[ \log \left( \frac{1}{k} \sum _{i=1}^{k} \frac{p(x,h_i)}{q(h_i|x)} \right) \right]
$$</p>
<p>In practice, we estimate this with a single Monte Carlo sample (of $k$ latents):</p>
<p>$$
\mathcal{L} _k (x) \approx \log \left( \frac{1}{k} \sum _{i=1}^{k} w_i \right)
$$</p>
<p>Because the logarithm is a concave function, Jensen&rsquo;s Inequality tells us that the &ldquo;log of an average&rdquo; is <em>always</em> greater than or equal to the &ldquo;average of logs.&rdquo;</p>
<p>$$
\mathcal{L}_k(x) \ge \mathcal{L}_1(x)
$$</p>
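<p>This means the IWAE optimizes a lower bound on the true log-likelihood that is <strong>at least as tight</strong> as the VAE bound, and strictly tighter whenever the weights $w_i$ actually vary (that is, whenever $q(h|x)$ is not exactly the true posterior).</p>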
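<p>A tiny numeric check makes the inequality tangible. The sketch below uses three made-up importance weights (hypothetical values, not taken from the paper) and compares the two ways of combining them.</p>

```python
import math

# Hypothetical importance weights w_i = p(x, h_i) / q(h_i | x) for one input x.
weights = [0.5, 0.01, 2.0]

# Multi-sample VAE estimate: average of logs.
avg_of_logs = sum(math.log(w) for w in weights) / len(weights)

# IWAE estimate: log of the average.
log_of_avg = math.log(sum(weights) / len(weights))

print(f"average of logs: {avg_of_logs:.4f}")
print(f"log of average:  {log_of_avg:.4f}")
```

<p>The single poor sample ($w = 0.01$) drags the average of logs down sharply, while it barely moves the log of the average.</p>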
<h2 id="why-does-this-log-of-average-matter">Why Does This &ldquo;Log-of-Average&rdquo; Matter?</h2>
<p>This mathematical property provides two practical benefits.</p>
<h3 id="1-better-density-estimation">1. Better Density Estimation</h3>
<p>Because $\mathcal{L}_k$ is a tighter bound on the true $\log p(x)$, optimizing it pushes the model to learn a better generative distribution. The paper shows that IWAEs achieve &ldquo;significantly higher log-likelihoods&rdquo; than VAEs.</p>
<h3 id="2-richer-latent-representations">2. Richer Latent Representations</h3>
<p>This is the most interesting part. The standard VAE $\mathcal{L}_1$ objective &ldquo;harshly penalizes&rdquo; the model if its <em>one</em> sample $h$ is a poor explanation for $x$. This pressure forces the recognition network $q(h|x)$ to be &ldquo;overly simplified&rdquo; to avoid bad samples, which can lead to many latent dimensions becoming inactive (the paper reports the number of &ldquo;active units&rdquo; per model).</p>
<p>The IWAE objective is more flexible. It only needs <em>one</em> of the $k$ samples to be good. This &ldquo;increased flexibility&rdquo; allows the model to learn far more complex posterior distributions and &ldquo;richer latent space representations.&rdquo; The paper&rsquo;s experiments confirm this, showing IWAEs learn to use many more &ldquo;active units&rdquo; in their latent space.</p>
<h2 id="what-this-looks-like-in-code-pytorch">What This Looks Like in Code (PyTorch)</h2>
<p>The implementation difference makes this crystal clear.</p>
<p>First, the &ldquo;k-sample&rdquo; trick: for a batch <code>x</code> of shape <code>[B, D]</code> and <code>k=5</code> samples, we repeat <code>x</code> to get <code>x_repeated</code> of shape <code>[B*k, D]</code>. We do all our forward passes on this large tensor.</p>
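<p>The order of the repeated rows matters later, when the flat <code>[B*k]</code> log-weights are reshaped to <code>[B, k]</code>. The pure-Python sketch below illustrates the layout with hypothetical helper names (<code>repeat_interleaved</code>, <code>group_by_input</code>). It assumes the copies of each input must sit next to each other, which in PyTorch corresponds to <code>x.repeat_interleave(k, dim=0)</code> rather than <code>x.repeat(k, 1)</code>.</p>

```python
def repeat_interleaved(rows, k):
    """Repeat each row k times, keeping copies of the same row adjacent."""
    return [row for row in rows for _ in range(k)]


def group_by_input(flat, B, k):
    """Row-major reshape of a flat length-B*k list into B groups of k."""
    assert len(flat) == B * k
    return [flat[b * k:(b + 1) * k] for b in range(B)]


B, k = 3, 2
x = ["x0", "x1", "x2"]  # stand-ins for the B rows of the batch

interleaved = repeat_interleaved(x, k)  # x0 x0 x1 x1 x2 x2
tiled = x * k                           # x0 x1 x2 x0 x1 x2

groups_ok = group_by_input(interleaved, B, k)
groups_bad = group_by_input(tiled, B, k)

print(groups_ok)   # each group holds k copies of one input
print(groups_bad)  # groups mix different inputs
```

<p>With the tiled layout, a row-major reshape would average importance weights that belong to different inputs, so the interleaved layout is the one that matches a plain <code>view(B, k)</code>.</p>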
<h3 id="vae-multi-sample-k--1-loss">VAE (Multi-Sample, k &gt; 1) Loss</h3>
<p>Here, we can still use the analytical KL divergence, which is a big simplification.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#75715e"># x_repeated has shape [B*k, 784]</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># mu, logvar have shape [B*k, latent_dim]</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># recon_x has shape [B*k, 784]</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># recon_loss_all shape: [B*k]</span>
</span></span><span style="display:flex;"><span>recon_loss_all <span style="color:#f92672">=</span> F<span style="color:#f92672">.</span>binary_cross_entropy(recon_x, x_repeated, reduction<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;none&#39;</span>)<span style="color:#f92672">.</span>sum(dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># kl_loss_all shape: [B*k]</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># We use the simple, analytical KL term!</span>
</span></span><span style="display:flex;"><span>kl_loss_all <span style="color:#f92672">=</span> <span style="color:#f92672">-</span><span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>sum(<span style="color:#ae81ff">1</span> <span style="color:#f92672">+</span> logvar <span style="color:#f92672">-</span> mu<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">-</span> logvar<span style="color:#f92672">.</span>exp(), dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># total_loss_all shape: [B*k]</span>
</span></span><span style="display:flex;"><span>total_loss_all <span style="color:#f92672">=</span> recon_loss_all <span style="color:#f92672">+</span> kl_loss_all
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># --- The Key Step ---</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Just average all B*k losses. This is the &#34;average of logs&#34;.</span>
</span></span><span style="display:flex;"><span>loss <span style="color:#f92672">=</span> total_loss_all<span style="color:#f92672">.</span>mean()
</span></span></code></pre></div><h3 id="iwae-k--1-loss">IWAE (k &gt; 1) Loss</h3>
<p>Here, we must compute the exact log-probabilities of the <em>specific samples</em> we drew.</p>
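<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> math
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> torch
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> torch.nn.functional <span style="color:#66d9ef">as</span> F
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Helper function to compute log-prob of a sample from a diagonal Gaussian</span>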
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">log_prob_gaussian</span>(sample, mu, logvar):
</span></span><span style="display:flex;"><span>    const <span style="color:#f92672">=</span> <span style="color:#f92672">-</span><span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> sample<span style="color:#f92672">.</span>shape[<span style="color:#f92672">-</span><span style="color:#ae81ff">1</span>] <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>log(<span style="color:#ae81ff">2</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>tensor(math<span style="color:#f92672">.</span>pi))
</span></span><span style="display:flex;"><span>    log_det <span style="color:#f92672">=</span> <span style="color:#f92672">-</span><span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>sum(logvar, dim<span style="color:#f92672">=-</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>    log_exp <span style="color:#f92672">=</span> <span style="color:#f92672">-</span><span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>sum((sample <span style="color:#f92672">-</span> mu)<span style="color:#f92672">**</span><span style="color:#ae81ff">2</span> <span style="color:#f92672">/</span> torch<span style="color:#f92672">.</span>exp(logvar), dim<span style="color:#f92672">=-</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> const <span style="color:#f92672">+</span> log_det <span style="color:#f92672">+</span> log_exp
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># --- Get the 3 log-prob components ---</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># x_repeated, recon_x, z_samples, mu_repeated, logvar_repeated</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># all have a first dimension of [B*k]</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># 1. log p(x|h_i): Log-Reconstruction Probability</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># log_p_x_given_h shape: [B*k]</span>
</span></span><span style="display:flex;"><span>log_p_x_given_h <span style="color:#f92672">=</span> <span style="color:#f92672">-</span>F<span style="color:#f92672">.</span>binary_cross_entropy(recon_x, x_repeated, reduction<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;none&#39;</span>)<span style="color:#f92672">.</span>sum(dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># 2. log p(h_i): Log-Prior Probability (under N(0, I))</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># log_p_h shape: [B*k]</span>
</span></span><span style="display:flex;"><span>log_p_h <span style="color:#f92672">=</span> log_prob_gaussian(z_samples, torch<span style="color:#f92672">.</span>zeros_like(z_samples), torch<span style="color:#f92672">.</span>zeros_like(z_samples)) <span style="color:#75715e"># mu=0, logvar=0 (tensors, since torch.sum/torch.exp reject plain floats)</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># 3. log q(h_i|x): Log-Encoder Probability</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># log_q_h_given_x shape: [B*k]</span>
</span></span><span style="display:flex;"><span>log_q_h_given_x <span style="color:#f92672">=</span> log_prob_gaussian(z_samples, mu_repeated, logvar_repeated)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># --- The Key Step ---</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Combine to get the log-importance-weight</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># log_w shape: [B*k]</span>
</span></span><span style="display:flex;"><span>log_w <span style="color:#f92672">=</span> log_p_x_given_h <span style="color:#f92672">+</span> log_p_h <span style="color:#f92672">-</span> log_q_h_given_x
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Reshape to [B, k] to group samples by their original input</span>
</span></span><span style="display:flex;"><span>log_w_matrix <span style="color:#f92672">=</span> log_w<span style="color:#f92672">.</span>view(B, k) <span style="color:#75715e"># B is original batch size</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># --- Apply the IWAE Objective (Log-Sum-Exp Trick) ---</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># This is the &#34;log of the average&#34;</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># log( (1/k) * sum(exp(log_w)) ) = logsumexp(log_w) - log(k)</span>
</span></span><span style="display:flex;"><span>log_iwae_bound_per_x <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>logsumexp(log_w_matrix, dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>) <span style="color:#f92672">-</span> math<span style="color:#f92672">.</span>log(k)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># The objective is to MAXIMIZE this bound, so the loss is its negative</span>
</span></span><span style="display:flex;"><span>loss <span style="color:#f92672">=</span> <span style="color:#f92672">-</span>log_iwae_bound_per_x<span style="color:#f92672">.</span>mean()
</span></span></code></pre></div><h3 id="the-critical-implementation-detail">The Critical Implementation Detail</h3>
<p>Notice the key difference in the final step:</p>
<ul>
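<li><strong>VAE</strong>: <code>loss = total_loss_all.mean()</code> (an average of individual losses).</li>
<li><strong>IWAE</strong>: <code>loss = -(torch.logsumexp(log_w_matrix, dim=1) - math.log(k)).mean()</code> (the log of the averaged weights). Dropping the constant <code>math.log(k)</code> leaves the gradients unchanged but shifts the reported bound.</li>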
</ul>
<p>This seemingly small change implements the fundamental mathematical difference between optimizing an &ldquo;average of logs&rdquo; versus a &ldquo;log of averages.&rdquo;</p>
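<p>The reason for going through <code>logsumexp</code> instead of exponentiating the log-weights directly is numerical. The standard-library sketch below uses hypothetical log-weights of a magnitude that a 784-pixel Bernoulli decoder can plausibly produce; the helper name <code>log_mean_exp</code> is my own.</p>

```python
import math


def log_mean_exp(log_w):
    """Stable log((1/k) * sum(exp(log_w))) via the log-sum-exp trick."""
    m = max(log_w)
    return m + math.log(sum(math.exp(v - m) for v in log_w)) - math.log(len(log_w))


# Hypothetical log-importance-weights for k = 3 samples of one input.
log_w = [-1000.0, -1001.0, -1002.0]

# Naive route: every exp() underflows to 0.0, so taking the log would fail.
naive_sum = sum(math.exp(v) for v in log_w)
print(naive_sum)  # 0.0 in double precision

# Stable route: subtract the maximum before exponentiating.
print(log_mean_exp(log_w))
```

<p>Subtracting the maximum keeps the largest term at $\exp(0) = 1$, so the sum never underflows to zero.</p>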
<h2 id="when-to-use-each-approach">When to Use Each Approach</h2>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>When to Use</th>
          <th>Key Benefit</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>VAE ($k=1$)</strong></td>
          <td>Your <strong>default baseline</strong>. It&rsquo;s fast, simple, and often &ldquo;good enough&rdquo; for many tasks.</td>
          <td>Speed and simplicity.</td>
      </tr>
      <tr>
          <td><strong>Multi-Sample VAE ($k&gt;1$)</strong></td>
          <td>When you want slightly more stable gradients but aren&rsquo;t ready for the full IWAE complexity.</td>
          <td>Marginal improvement with minimal code changes.</td>
      </tr>
      <tr>
          <td><strong>IWAE ($k&gt;1$)</strong></td>
          <td>When your baseline VAE is <strong>insufficient</strong>. Specifically, if you need:<br>1. The best possible log-likelihood.<br>2. To activate more latent dimensions or learn richer representations.</td>
          <td>Better performance and richer latents, at the cost of compute (scales linearly with $k$).</td>
      </tr>
  </tbody>
</table>
<h2 id="the-computational-trade-off">The Computational Trade-off</h2>
<p>Both approaches scale linearly with $k$. If you use $k=5$ samples, you&rsquo;re doing roughly 5x the computation. The question is whether you get 5x the benefit.</p>
<p>For multi-sample VAEs, the answer is usually &ldquo;no&rdquo;. You get more stable gradients but only marginal performance improvements.</p>
<p>For IWAEs, the answer is often &ldquo;yes&rdquo;. You get meaningfully better log-likelihoods and richer latent representations that can be worth the computational cost.</p>
<h2 id="conclusion">Conclusion</h2>
<p>The next time you use more samples with your VAE, switch to the IWAE objective to get the full benefit of the computational cost of $k &gt; 1$.</p>
<p>The mathematical insight is simple but powerful: Jensen&rsquo;s Inequality tells us that the &ldquo;log of an average&rdquo; is always greater than or equal to the &ldquo;average of logs.&rdquo; By optimizing this tighter bound, IWAEs achieve better density estimation and learn richer latent representations than standard VAEs.</p>
<p>The implementation requires evaluating the exact log-probabilities of the specific samples drawn, instead of relying on the analytical KL term. The result is a more powerful model that uses the exact same architecture.</p>
<p><strong>Want to dive deeper?</strong> Check out the <a href="https://arxiv.org/abs/1509.00519">original IWAE paper</a> for experimental results and theoretical analysis, or explore my <a href="/posts/modern-variational-autoencoder-in-pytorch/">VAE tutorial</a> for hands-on implementation details.</p>
]]></content:encoded></item><item><title>Importance Weighted Autoencoders (IWAE) for Tighter Bounds</title><link>https://hunterheidenreich.com/notes/machine-learning/generative-models/importance-weighted-autoencoders/</link><pubDate>Wed, 05 Nov 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/generative-models/importance-weighted-autoencoders/</guid><description>Summary of Burda, Grosse &amp; Salakhutdinov's ICLR 2016 paper introducing Importance Weighted Autoencoders for tighter variational bounds</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is a <strong>Method</strong> paper that introduces the <strong>Importance Weighted Autoencoder (IWAE)</strong>, a generative model that shares the same architecture as the Variational Autoencoder (VAE) but uses a different, tighter objective function. The key innovation is using importance weighting to derive a strictly tighter log-likelihood lower bound than the standard VAE objective.</p>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>The standard VAE has several limitations that motivated this work:</p>
<ul>
<li><strong>Strong assumptions</strong>: VAEs typically assume the posterior distribution is simple (e.g., approximately factorial) and that its parameters can be easily approximated from observations.</li>
<li><strong>Simplified representations</strong>: The VAE objective can force models to learn overly simplified representations that underutilize the network&rsquo;s full modeling capacity.</li>
<li><strong>Harsh penalization</strong>: The VAE objective harshly penalizes approximate posterior samples that are poor explanations for the data, which can be overly restrictive.</li>
<li><strong>Inactive units</strong>: VAEs tend to learn latent spaces with effective dimensions far below their capacity, where many latent units are ignored (a phenomenon later termed <strong>posterior collapse</strong>, where the approximate posterior collapses to the prior and conveys no information). The authors wanted to investigate whether a new objective could address this issue.</li>
</ul>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The core novelty is the <strong>IWAE objective function</strong>, denoted as $\mathcal{L}_{k}$.</p>
<ul>
<li>
<p><strong>VAE ($\mathcal{L}_{1}$ Bound)</strong>: The standard VAE maximizes $\mathcal{L}(x)=\mathbb{E}_{q(h|x)}[\log\frac{p(x,h)}{q(h|x)}]$. This is equivalent to the new bound when $k=1$.</p>
</li>
<li>
<p><strong>IWAE ($\mathcal{L}_{k}$ Bound)</strong>: The IWAE maximizes a tighter bound that uses $k$ samples drawn from the recognition model $q(h|x)$:</p>
</li>
</ul>
<p>$$\mathcal{L}_{k}(x)=\mathbb{E}_{h_{1},\ldots,h_{k}\sim q(h|x)}\left[\log\frac{1}{k}\sum_{i=1}^{k}\frac{p(x,h_{i})}{q(h_{i}|x)}\right]$$</p>
<ul>
<li>
<p><strong>Tighter Bound</strong>: The authors prove that the bound never loosens as $k$ grows ($\log p(x) \geq \mathcal{L}_{k+1} \geq \mathcal{L}_{k}$), so it is always at least as tight as the VAE bound $\mathcal{L}_{1}$. If the ratio $p(x,h)/q(h|x)$ is bounded, $\mathcal{L}_{k}$ approaches the true log-likelihood $\log p(x)$ as $k$ approaches infinity.</p>
</li>
<li>
<p><strong>Increased Flexibility</strong>: Using multiple samples gives the IWAE additional flexibility to learn generative models whose posterior distributions are complex and violate the VAE&rsquo;s simplifying assumptions.</p>
</li>
</ul>
<h3 id="key-concept-averaging-inside-vs-outside-the-log">Key Concept: Averaging Inside vs. Outside the Log</h3>
<p>A crucial distinction exists between how VAE and IWAE use $k$ samples. Understanding this difference explains why increasing $k$ in IWAE tightens the bound, whereas in a VAE it only reduces the variance of the estimate.</p>

<figure class="post-figure center ">
    <img src="/img/notes/variational-autoencoder-vae-vs-importance-weighted-autoencoder-iwae.webp"
         alt="Flowchart comparing VAE and IWAE computation: VAE takes the log of each weight then averages (average of logs). IWAE averages the weights first then takes the log (log of average)"
         title="Flowchart comparing VAE and IWAE computation: VAE takes the log of each weight then averages (average of logs). IWAE averages the weights first then takes the log (log of average)"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">VAE vs IWAE computation flow. The key difference is where the log operation occurs: VAE computes log(w_i) for each sample then averages. IWAE averages the weights first then applies log to the result.</figcaption>
    
</figure>

<p><strong>VAE (Average of Logs):</strong></p>
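<p>For a VAE, averaging over $k$ samples gives a Monte Carlo estimate of the same ELBO:</p>
<p>$$\frac{1}{k} \sum_{i=1}^k \log w_i \approx \mathbb{E}_{h \sim q(h|x)}\left[ \log \frac{p(x,h)}{q(h|x)} \right] = \mathcal{L}_1 \quad \text{(the ELBO)}$$</p>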
<p>where $w_i = p(x, h_i) / q(h_i | x)$. Increasing $k$ here only reduces the variance of the gradient estimator; the model still targets the same ELBO bound, so performance gains saturate quickly.</p>
<p><strong>IWAE (Log of Average):</strong></p>
<p>IWAE performs the averaging <em>inside</em> the logarithm:</p>
<p>$$\mathbb{E}\left[ \log \left( \frac{1}{k} \sum_{i=1}^k w_i \right) \right] = \mathcal{L}_k$$</p>
<p>By Jensen&rsquo;s Inequality ($\log(\mathbb{E}[X]) \geq \mathbb{E}[\log(X)]$, since $\log$ is concave), this bound is mathematically guaranteed to be at least as tight as the VAE bound. Each increase in $k$ defines a new lower bound on the log-likelihood that is never looser than the previous one.</p>
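<p>The same inequality shows why $\mathcal{L}_k$ is a valid lower bound in the first place. A short derivation, using only the definition of $w_i$ above: each weight is an unbiased estimator of $p(x)$, because</p>
<p>$$\mathbb{E}_{h \sim q(h|x)}\left[ \frac{p(x,h)}{q(h|x)} \right] = \int q(h|x) \frac{p(x,h)}{q(h|x)} \, dh = \int p(x,h) \, dh = p(x)$$</p>
<p>and moving the expectation inside the logarithm can only increase the value:</p>
<p>$$\mathcal{L}_k = \mathbb{E}\left[ \log \left( \frac{1}{k} \sum_{i=1}^k w_i \right) \right] \leq \log \mathbb{E}\left[ \frac{1}{k} \sum_{i=1}^k w_i \right] = \log p(x)$$</p>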
<p><strong>Why This Matters for Gradients:</strong></p>
<p>In IWAE, the gradient weights are normalized importance weights $\tilde{w}_i = w_i / \sum_j w_j$. This means &ldquo;bad&rdquo; samples (those with low $w_i$) contribute very little to the gradient update since they vanish from the weighted sum. VAE uses unweighted samples, so a single sample with extremely low probability produces a massive negative log value that can dominate the loss and harshly penalize the model. IWAE&rsquo;s formulation allows the model to focus learning on the samples that explain the data well.</p>
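<p>In symbols, this is the gradient estimator as I read the paper&rsquo;s derivation, writing $h(x, \epsilon_i, \theta)$ for the reparameterized sample and $w(x, h, \theta)$ for the importance weight as a function of the parameters $\theta$:</p>
<p>$$\nabla_{\theta} \mathcal{L}_k(x) = \mathbb{E}_{\epsilon_1, \ldots, \epsilon_k}\left[ \sum_{i=1}^{k} \tilde{w}_i \nabla_{\theta} \log w(x, h(x, \epsilon_i, \theta), \theta) \right], \quad \tilde{w}_i = \frac{w_i}{\sum_{j=1}^{k} w_j}$$</p>
<p>For $k=1$ the single normalized weight is $\tilde{w}_1 = 1$, which recovers the standard VAE gradient.</p>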
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The authors compared VAE and IWAE on density estimation tasks using the MNIST and Omniglot datasets. They evaluated two main network architectures: one with a single stochastic layer and another with two stochastic layers. The models were trained with varying numbers of importance samples ($k \in \{1, 5, 50\}$) to observe the effect on performance and latent space utilization. The primary metrics for evaluation were the test log-likelihood (estimated using 5000 samples) and the number of &ldquo;active&rdquo; latent units, which quantifies the richness of the learned representations.</p>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>

<figure class="post-figure center ">
    <img src="/img/notes/iwae-vs-vae-active-latent-units-comparison.webp"
         alt="Bar chart comparing active latent units between VAE and IWAE across different k values on MNIST and Omniglot datasets"
         title="Bar chart comparing active latent units between VAE and IWAE across different k values on MNIST and Omniglot datasets"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Active latent units for VAE vs IWAE (1 stochastic layer). VAE active units remain flat. IWAE increases with k. Data from Table 1 of Burda et al. (2016).</figcaption>
    
</figure>

<ul>
<li>
<p><strong>Better Performance</strong>: IWAE achieved higher log-likelihoods than VAEs across all configurations. On MNIST with two stochastic layers and $k=50$, IWAE reached $-82.90$ nats compared to $-84.78$ for VAE. On Omniglot, the best IWAE achieved $-103.38$ nats versus $-106.30$ for VAE. IWAE performance improved consistently with increasing $k$, while VAE performance benefited only slightly from using more samples ($k&gt;1$).</p>
</li>
<li>
<p><strong>Richer Representations</strong>: In all experiments with $k&gt;1$, IWAE learned more active latent dimensions than VAE, suggesting richer latent representations.</p>
</li>
<li>
<p><strong>Objective Drives Representation</strong>: The authors found that latent dimension inactivation is driven by the objective function. They demonstrated this through an &ldquo;objective swap&rdquo; experiment:</p>
</li>
</ul>

<figure class="post-figure center ">
    <img src="/img/notes/iwae-objective-swap-experiment.webp"
         alt="Bar charts showing the objective swap experiment results with active units and NLL changes when switching between VAE and IWAE objectives"
         title="Bar charts showing the objective swap experiment results with active units and NLL changes when switching between VAE and IWAE objectives"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Objective swap experiment on MNIST (1 stochastic layer). Switching a trained VAE to the IWAE objective improves both metrics. Switching IWAE to VAE degrades them. Data from Table 2 of Burda et al. (2016).</figcaption>
    
</figure>

<p>This experiment provides evidence that the objective function itself influences latent utilization:</p>
<ul>
<li><strong>VAE → IWAE</strong>: A converged VAE model, when fine-tuned with the IWAE objective ($k=50$), gained 3 active units (19 → 22) and improved test NLL from 86.76 to 84.88.</li>
<li><strong>IWAE → VAE</strong>: A converged IWAE model fine-tuned with the VAE objective lost 2 active units (25 → 23) and worsened test NLL from 84.78 to 86.02.</li>
</ul>
<p>These results suggest that inactivation of latent dimensions is driven primarily by the objective function rather than by optimization dynamics, initialization, or architecture. The authors note that optimization still plays some role, since the swap results do not exactly match training from scratch.</p>
<ul>
<li>
<p><strong>Comparison to Other Models</strong>: On MNIST, the best IWAE ($-82.90$ nats) outperformed deep belief networks ($-84.55$ nats) and deep autoregressive networks ($-84.13$ nats), though DRAW ($-80.97$ nats), which exploits spatial structure, achieved better results. On Omniglot, the best IWAE ($-103.38$ nats) fell slightly behind RBMs trained with persistent contrastive divergence ($-100.46$ nats).</p>
</li>
<li>
<p><strong>Conclusion</strong>: IWAEs learn richer latent representations and achieve better generative performance than VAEs with equivalent architectures and training time.</p>
</li>
</ul>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<ul>
<li><strong>MNIST</strong>: $28 \times 28$ binarized handwritten digits (60,000 training / 10,000 test).</li>
<li><strong>Omniglot</strong>: $28 \times 28$ binarized handwritten characters from various alphabets (24,345 training / 8,070 test).</li>
<li><strong>Binarization</strong>: Dynamic sampling where binary values are sampled with expectations equal to the real pixel intensities (following Salakhutdinov &amp; Murray, 2008).</li>
<li><strong>Fixed Binarization</strong>: Results on a fixed binarization of MNIST (Larochelle, 2011) confirm that IWAE outperforms VAE across preprocessing methods. Training on the fixed binarization exhibits notably more overfitting than training with dynamic sampling.</li>
</ul>
<h3 id="models">Models</h3>
<p>Two main network architectures were tested:</p>
<ol>
<li>One stochastic layer (50 units) with two deterministic layers (200 units each).</li>
<li>Two stochastic layers (100 and 50 units). Between x and h1 were two deterministic layers with 200 units each. Between h1 and h2 were two deterministic layers with 100 units each.</li>
</ol>
<ul>
<li><strong>Activations</strong>: <code>tanh</code> for deterministic layers; <code>exp</code> applied to variance predictions to ensure positivity.</li>
<li><strong>Distributions</strong>: Gaussian latent layers with diagonal covariance; Bernoulli observation layer.</li>
<li><strong>Initialization</strong>: Glorot &amp; Bengio (2010) heuristic.</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Optimizer</strong>: Adam ($\beta_1=0.9$, $\beta_2=0.999$, $\epsilon=10^{-4}$).</li>
<li><strong>Batch Size</strong>: 20.</li>
<li><strong>Learning Rate Schedule</strong>: Annealed rate of $0.001 \cdot 10^{-i/7}$ for $3^i$ epochs (where $i = 0, \ldots, 7$), totaling $\sum_{i=0}^{7} 3^i = 3{,}280$ passes over the data.</li>
<li><strong>Variance Control</strong>: A common concern with importance sampling is high variance. The authors prove that the Mean Absolute Deviation of their estimator is bounded by $2 + 2\delta$, where $\delta$ is the gap between the bound and true log-likelihood. As the bound tightens, variance remains controlled.</li>
<li><strong>Computational trick</strong>: In the basic IWAE implementation, both forward and backward passes must be done independently for each of the $k$ samples, so the cost scales linearly with $k$. However, the authors describe an optional optimization: stochastically approximate the gradient sum by sampling a single $\epsilon_i$ proportional to its normalized weight $\tilde{w}_i$, then computing only that one backward pass. This reduces the cost to $k$ forward passes and one backward pass. Since the backward pass costs roughly twice the forward pass, this yields approximately a 3x speedup for large $k$ at the cost of increased gradient variance.</li>
</ul>
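<p>The learning rate schedule above is easy to tabulate. This is a minimal sketch that only reproduces the arithmetic of the stated schedule; the function name <code>iwae_lr_schedule</code> and its parameters are my own and do not come from the official implementation.</p>

```python
def iwae_lr_schedule(base_lr=1e-3, num_stages=8):
    """Return (learning_rate, epochs) pairs.

    Stage i runs for 3**i epochs at base_lr * 10**(-i / 7), for i = 0..7.
    """
    return [(base_lr * 10 ** (-i / 7), 3 ** i) for i in range(num_stages)]


schedule = iwae_lr_schedule()
total_epochs = sum(epochs for _, epochs in schedule)

for lr, epochs in schedule:
    print(f"lr={lr:.2e} for {epochs} epochs")
print("total epochs:", total_epochs)
```

<p>The rate decays by one order of magnitude over the eight stages (from $10^{-3}$ to $10^{-4}$), and the final stage alone accounts for 2,187 of the 3,280 epochs.</p>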
<p><strong>Relationship to Reweighted Wake-Sleep (RWS):</strong> Both IWAE and Reweighted Wake-Sleep (Bornschein &amp; Bengio, 2015) use importance-weighted samples and have closely related generative model updates. The key difference is that IWAE derives a single unified lower bound $\mathcal{L}_k$ and uses the reparameterization trick to train the recognition network jointly. RWS instead uses separate wake and sleep phases for the recognition network, which are not derived from $\mathcal{L}_k$.</p>
<h3 id="evaluation">Evaluation</h3>
<ol>
<li><strong>Test Log-Likelihood</strong>: Primary measure of generative performance, estimated as the mean of $\mathcal{L}_{5000}$ (5000 samples) on the test set.</li>
<li><strong>Active Units</strong>: To quantify latent space richness, the authors measured &ldquo;active&rdquo; latent dimensions. A unit $u$ was defined as active if its activity statistic $A_{u}=\text{Cov}_{x}(\mathbb{E}_{u\sim q(u|x)}[u])$ exceeded $10^{-2}$. The $10^{-2}$ threshold is justified by a bimodal distribution of the log activity statistic, showing clear separation between active and inactive units.</li>
</ol>
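<p>Both evaluation quantities reduce to short computations once the log importance weights and the posterior means are available. The sketch below assumes those inputs are already computed by some model; the helper names and toy numbers are hypothetical and only illustrate the arithmetic.</p>

```python
import math

def log_mean_exp(log_weights):
    """Stable log((1/k) * sum_i exp(log_w_i)): a one-sample estimate of L_k."""
    m = max(log_weights)
    return m + math.log(sum(math.exp(lw - m) for lw in log_weights) / len(log_weights))

def active_units(posterior_means, threshold=1e-2):
    """posterior_means[n][u] holds E_q[u | x_n].

    A unit is active when the variance of that mean across datapoints
    (the scalar case of Cov_x) exceeds the threshold.
    """
    n = len(posterior_means)
    active = []
    for u in range(len(posterior_means[0])):
        column = [row[u] for row in posterior_means]
        mean = sum(column) / n
        variance = sum((v - mean) ** 2 for v in column) / n
        if variance > threshold:
            active.append(u)
    return active

# Made-up log weights log p(x, h_i) - log q(h_i | x) for one test point.
log_w = [-92.1, -90.4, -95.0, -91.3]
bound_k = log_mean_exp(log_w)
bound_1 = sum(log_w) / len(log_w)  # average of single-sample bounds
print("L_k estimate:", bound_k, "average L_1 estimate:", bound_1)

# Unit 0 responds to the input; unit 1 stays near its prior mean.
means = [[1.0, 0.0], [-1.0, 0.001], [0.5, -0.001], [-0.5, 0.0]]
print("active units:", active_units(means))  # [0]
```

<p>By Jensen&rsquo;s inequality the log-mean-exp value is never below the plain average of the log weights, which mirrors why $\mathcal{L}_k$ is at least as tight as $\mathcal{L}_1$.</p>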
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Hardware</strong>: GPU-based implementation using mini-batch replication to parallelize the $k$ samples. Specific GPU type and training times are not reported.</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/yburda/iwae">yburda/iwae</a></td>
          <td>Code</td>
          <td>Unknown</td>
          <td>Official Theano implementation for MNIST and Omniglot</td>
      </tr>
  </tbody>
</table>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Burda, Y., Grosse, R., &amp; Salakhutdinov, R. (2016). Importance Weighted Autoencoders. <em>International Conference on Learning Representations (ICLR) 2016</em>. <a href="https://arxiv.org/abs/1509.00519">https://arxiv.org/abs/1509.00519</a></p>
<p><strong>Publication</strong>: ICLR 2016</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{burda2016importance,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Importance Weighted Autoencoders}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Yuri Burda and Roger Grosse and Ruslan Salakhutdinov}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{International Conference on Learning Representations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2016}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span>=<span style="color:#e6db74">{https://arxiv.org/abs/1509.00519}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://arxiv.org/abs/1509.00519">ArXiv</a></li>
</ul>
]]></content:encoded></item><item><title>Auto-Encoding Variational Bayes: VAE Paper Summary</title><link>https://hunterheidenreich.com/notes/machine-learning/generative-models/autoencoding-variational-bayes/</link><pubDate>Wed, 05 Nov 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/generative-models/autoencoding-variational-bayes/</guid><description>Summary of Kingma &amp; Welling's 2013 VAE paper introducing the reparameterization trick and variational autoencoders.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is a <strong>Method</strong> paper that introduces a generative mechanism (the VAE) and an optimization technique (the reparameterization trick), with formal theoretical derivation. The method, called the Auto-Encoding VB (AEVB) algorithm, leads to what we now know as the <strong>variational auto-encoder (VAE)</strong> when neural networks are used as the recognition model.</p>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>The authors address two central intractabilities in directed probabilistic models with continuous latent variables:</p>
<figure class="post-figure center ">
    <img src="/img/notes/autoencoding-variational-bayes-figure-1-model-diagram.webp"
         alt="VAE graphical model showing latent variable z, observed variable x, and parameters phi and theta"
         title="VAE graphical model showing latent variable z, observed variable x, and parameters phi and theta"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Figure 1 from the paper: The directed graphical model. Solid lines denote the generative model $p_\theta(z)p_\theta(x|z)$, dashed lines denote the variational approximation $q_\phi(z|x)$. The variational parameters $\phi$ are learned jointly with the generative parameters $\theta$.</figcaption>
    
</figure>

<ol>
<li>
<p><strong>Intractable Posteriors</strong>: In models with continuous latent variables (like those with non-linear hidden layers), the true posterior $p_{\theta}(z|x)$ cannot be calculated analytically, preventing the use of standard EM algorithms.</p>
</li>
<li>
<p><strong>Large Datasets</strong>: Sampling-based solutions like Monte Carlo EM (MCEM) require expensive sampling loops per datapoint. This makes them too slow for large datasets where batch optimization is too costly and efficient minibatch updates are required.</p>
</li>
</ol>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<h3 id="the-reparameterization-trick-sgvb-estimator">The Reparameterization Trick (SGVB Estimator)</h3>
<p>The core innovation is the <strong>Stochastic Gradient Variational Bayes (SGVB)</strong> estimator. The naive Monte Carlo gradient estimator of the lower bound with respect to $\phi$ has very high variance; the authors avoid it by &ldquo;reparameterizing&rdquo; the random variable $\tilde{z} \sim q_\phi(z|x)$.</p>
<p>They express $\tilde{z}$ as a deterministic function of the input $x$ and an auxiliary noise variable $\epsilon$:</p>
<p>$$\tilde{z} = g_{\phi}(\epsilon, x) \quad \text{with} \quad \epsilon \sim p(\epsilon)$$</p>
<figure class="post-figure center ">
    <img src="/img/notes/variational-autoencoder-reparameterization-trick.webp"
         alt="Comparison of standard stochastic node vs reparameterization trick showing gradient flow"
         title="Comparison of standard stochastic node vs reparameterization trick showing gradient flow"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The reparameterization trick. (A) Standard stochastic nodes block gradient flow during backpropagation. (B) By expressing $z = \mu + \sigma \odot \epsilon$ with external noise $\epsilon \sim \mathcal{N}(0, I)$, gradients can flow through the deterministic path to the parameters $\phi$.</figcaption>
    
</figure>

<ul>
<li><strong>Mechanism</strong>: For a Gaussian posterior, $z = \mu + \sigma \odot \epsilon$ where $\epsilon \sim \mathcal{N}(0, I)$.</li>
<li><strong>Impact</strong>: This makes the Monte Carlo estimate differentiable with respect to the variational parameters $\phi$, allowing the variational lower bound to be optimized via standard stochastic gradient ascent (like SGD or Adagrad).</li>
</ul>
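<p>For the Gaussian case the mechanism is a one-liner. The sketch below uses only the Python standard library and parameterizes the variance as $\log(\sigma^2)$; the function name and the example values are illustrative assumptions.</p>

```python
import math
import random

def reparameterize(mu, log_var, eps):
    """Return z = mu + sigma * eps elementwise, with sigma = exp(0.5 * log_var)."""
    return [m + math.exp(0.5 * lv) * e for m, lv, e in zip(mu, log_var, eps)]

# With eps held fixed, z is a deterministic function of (mu, log_var):
# dz/dmu = 1 and dz/dsigma = eps, so gradients can reach the encoder parameters.
print(reparameterize([1.0], [0.0], [0.5]))  # [1.5]

# Sampling eps ~ N(0, I) recovers draws from N(mu, sigma^2).
rng = random.Random(0)
mu, log_var = [1.0, -2.0], [0.0, math.log(0.25)]
samples = [
    reparameterize(mu, log_var, [rng.gauss(0.0, 1.0) for _ in mu])
    for _ in range(20000)
]
mean0 = sum(s[0] for s in samples) / len(samples)
mean1 = sum(s[1] for s in samples) / len(samples)
std1 = math.sqrt(sum((s[1] - mean1) ** 2 for s in samples) / len(samples))
print("empirical mean of z[0]:", mean0)  # close to 1.0
print("empirical std of z[1]:", std1)    # close to 0.5
```

<p>The randomness lives entirely in $\epsilon$, which does not depend on $\phi$; that separation is what makes the Monte Carlo estimate differentiable.</p>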
<h3 id="the-aevb-algorithm-the-vae">The AEVB Algorithm (The VAE)</h3>
<p>The <strong>Auto-Encoding VB (AEVB)</strong> algorithm amortizes inference by learning a global recognition model (encoder) $q_{\phi}(z|x)$ jointly with the generative model (decoder) $p_{\theta}(x|z)$.</p>
<p><strong>Objective Function</strong>: Maximize the variational lower bound $\mathcal{L}(\theta, \phi; x^{(i)})$:</p>
<p>$$\mathcal{L}(\theta, \phi; x^{(i)}) \simeq -D_{KL}(q_\phi(z|x^{(i)}) \,\|\, p_\theta(z)) + \frac{1}{L} \sum_{l=1}^L \log p_\theta(x^{(i)}|z^{(i,l)})$$</p>
<ul>
<li><strong>First Term (Regularizer)</strong>: Pulls the approximate posterior toward the prior (integrated analytically for Gaussians).</li>
<li><strong>Second Term (Reconstruction Error)</strong>: The expected negative reconstruction error (estimated via sampling).</li>
</ul>
<p>This mirrors the standard auto-encoder objective, adding a variational regularizer.</p>
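<p>A minimal sketch of the per-datapoint objective for a diagonal Gaussian encoder, a standard normal prior, and a Bernoulli decoder follows. The KL term uses the standard closed form $-\frac{1}{2}\sum_j (1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2)$; the <code>toy_decoder</code> is a hypothetical stand-in for a trained MLP, and all numbers are made up.</p>

```python
import math
import random

def kl_to_standard_normal(mu, log_var):
    """KL( N(mu, diag(sigma^2)) || N(0, I) ) in closed form."""
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, log_var))

def bernoulli_log_likelihood(x, probs, tiny=1e-7):
    """log p(x | z) for binary x and decoder probabilities probs."""
    return sum(
        xi * math.log(p + tiny) + (1 - xi) * math.log(1.0 - p + tiny)
        for xi, p in zip(x, probs)
    )

def elbo_estimate(x, mu, log_var, decoder, rng, num_samples=1):
    """-KL + (1/L) * sum_l log p(x | z_l), with z_l drawn by reparameterization."""
    reconstruction = 0.0
    for _ in range(num_samples):
        z = [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, log_var)]
        reconstruction += bernoulli_log_likelihood(x, decoder(z))
    return -kl_to_standard_normal(mu, log_var) + reconstruction / num_samples

def toy_decoder(z):
    """Hypothetical decoder: two output pixels driven by the sum of the latents."""
    s = sum(z)
    return [1.0 / (1.0 + math.exp(-s)), 1.0 / (1.0 + math.exp(s))]

rng = random.Random(0)
x = [1, 0]
value = elbo_estimate(x, mu=[0.8, 0.3], log_var=[-1.0, -1.0], decoder=toy_decoder, rng=rng)
print("single-sample ELBO estimate:", value)
print("KL when q equals the prior:", kl_to_standard_normal([0.0, 0.0], [0.0, 0.0]))
```

<p>With $L=1$ the reconstruction term is a single noisy draw, which is the setting the paper pairs with minibatches of 100 datapoints.</p>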
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The method was benchmarked against the <strong>Wake-Sleep</strong> algorithm and <strong>Monte Carlo EM (MCEM)</strong> using the <strong>MNIST</strong> (digits) and <strong>Frey Face</strong> (continuous faces) datasets.</p>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>
<ul>
<li>
<p><strong>Efficiency</strong>: AEVB converged faster and reached a better lower bound than Wake-Sleep (Figure 2). It scaled efficiently to the full MNIST dataset. MCEM&rsquo;s per-datapoint sampling cost made it impractical at full dataset scale, so comparisons were limited to small subsets (Figure 3).</p>
</li>
<li>
<p><strong>Regularization</strong>: The KL-divergence term provided a regularizing effect, preventing overfitting while increasing latent dimensions ($N_z$).</p>
</li>
<li>
<p><strong>Manifold Learning</strong>: The model successfully learned smooth 2D latent manifolds (visualized in Appendix A), grouping similar digits/faces together.</p>
</li>
</ul>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p><strong>Evaluation Data</strong>: For the marginal likelihood comparison (Figure 3), the paper used MNIST with $N_{\text{train}} = 1000$ and $N_{\text{train}} = 50000$ to compare data efficiency (marginal log-likelihood vs. training samples seen) across algorithms. A smaller network (100 hidden units, 3 latent variables) was used for this comparison because the marginal likelihood estimator only works reliably in low-dimensional latent spaces.</p>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Algorithm</strong>: Stochastic gradient ascent with <strong>Adagrad</strong> (global stepsizes chosen from $\{0.01, 0.02, 0.1\}$).</li>
<li><strong>Regularization</strong>: The objective included a weight decay term corresponding to a prior $p(\theta)=\mathcal{N}(0,I)$.</li>
<li><strong>Minibatches</strong>: Size $M=100$ with $L=1$ sample per datapoint.</li>
<li><strong>Initialization</strong>: Parameters sampled from $\mathcal{N}(0, 0.01)$.</li>
</ul>
<h3 id="models">Models</h3>
<p>The original VAE used simple Multi-Layered Perceptrons (MLPs):</p>
<ul>
<li><strong>Symmetry</strong>: The encoder and decoder were symmetric, having an equal number of hidden units.</li>
<li><strong>Hidden Units</strong>: 500 units for MNIST, 200 for Frey Face (to prevent overfitting on the smaller dataset).</li>
<li><strong>Activations</strong>: <strong>Tanh</strong> activation functions for the hidden layers.</li>
<li><strong>Latent Space</strong>: Experimented with $N_z$ ranging from 2 to 200.</li>
<li><strong>Outputs</strong>:
<ul>
<li><em>MNIST</em>: <strong>Bernoulli</strong> MLP (sigmoid output).</li>
<li><em>Frey Face</em>: <strong>Gaussian</strong> MLP, with means constrained to $(0,1)$ via sigmoid.</li>
</ul>
</li>
<li><strong>Encoder Architecture</strong>: For the Gaussian encoder, the mean $\mu$ and log-variance $\log(\sigma^2)$ are linear outputs from the shared hidden layer (they share the hidden layer weights and have separate output weights).</li>
<li><strong>Log-Variance</strong>: The encoder predicted $\log(\sigma^2)$ for numerical stability.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p>The paper distinguishes between two metrics:</p>
<ul>
<li><strong>Variational Lower Bound</strong>: Used as the training objective (what the model optimizes).</li>
<li><strong>Marginal Likelihood</strong>: Used for final evaluation (Figure 3). The true marginal likelihood $p_\theta(x)$ was estimated using an importance sampling estimator constructed from posterior samples $z^{(l)}$ drawn via Hybrid Monte Carlo (HMC), as detailed in Appendix D: $p_{\theta}(x^{(i)}) \simeq \left(\frac{1}{L}\sum_{l=1}^{L} \frac{q(z^{(l)})}{p_\theta(z^{(l)})\,p_\theta(x^{(i)}|z^{(l)})}\right)^{-1}$, where $q(z)$ is a density estimator fitted to the HMC samples.</li>
</ul>
<p>This distinction is critical: the training metric (lower bound) differs from the evaluation metric (estimated marginal likelihood).</p>
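<p>The marginal likelihood estimator can be checked on a toy model where everything is known in closed form. The sketch below assumes $z \sim \mathcal{N}(0,1)$ and $x|z \sim \mathcal{N}(z,1)$, so the posterior is $\mathcal{N}(x/2, 1/2)$ and the marginal is $\mathcal{N}(0, 2)$. Exact posterior draws stand in for HMC and a fitted Gaussian stands in for the density estimator $q$; both substitutions are simplifications for illustration, not the paper&rsquo;s setup.</p>

```python
import math
import random

def normal_logpdf(v, mean, var):
    return -0.5 * (math.log(2.0 * math.pi * var) + (v - mean) ** 2 / var)

def estimate_log_marginal(x, posterior_samples, q_mean, q_var):
    """log of ( (1/L) * sum_l q(z_l) / (p(z_l) p(x | z_l)) )^(-1)."""
    log_ratios = [
        normal_logpdf(z, q_mean, q_var)
        - normal_logpdf(z, 0.0, 1.0)   # prior p(z)
        - normal_logpdf(x, z, 1.0)     # likelihood p(x | z)
        for z in posterior_samples
    ]
    m = max(log_ratios)
    log_mean = m + math.log(sum(math.exp(r - m) for r in log_ratios) / len(log_ratios))
    return -log_mean

rng = random.Random(0)
x = 1.3

def draw_posterior():
    # Exact posterior draw; the paper uses HMC because no closed form exists there.
    return rng.gauss(x / 2.0, math.sqrt(0.5))

# Step 1: fit a density estimator q (here a Gaussian) to one batch of posterior samples.
fit_samples = [draw_posterior() for _ in range(5000)]
q_mean = sum(fit_samples) / len(fit_samples)
q_var = sum((z - q_mean) ** 2 for z in fit_samples) / len(fit_samples)

# Step 2: plug a fresh batch of posterior samples into the estimator.
eval_samples = [draw_posterior() for _ in range(5000)]
estimate = estimate_log_marginal(x, eval_samples, q_mean, q_var)
exact = normal_logpdf(x, 0.0, 2.0)
print("estimated log p(x):", estimate)
print("exact log p(x):    ", exact)
```

<p>When $q$ equals the true posterior, every term in the sum is exactly $1/p_\theta(x)$, so the estimator is exact; its quality degrades as $q$ drifts from the posterior, which is harder to avoid in high-dimensional latent spaces.</p>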
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Hardware</strong>: Trained on a standard Intel Xeon CPU (approx. 40 GFLOPS); no GPUs were used.</li>
<li><strong>Training Time</strong>: Approximately 20-40 minutes per million training samples.</li>
</ul>
<h3 id="key-implementation-details-from-appendices">Key Implementation Details from Appendices</h3>
<ul>
<li><strong>Appendix A</strong>: Visualizations of 2D latent manifolds learned for MNIST and Frey Face datasets.</li>
<li><strong>Appendix B</strong>: Closed-form solution for the KL divergence of two Gaussians, essential for implementing the efficient version of the estimator (Equation 10).</li>
<li><strong>Appendix C</strong>: Exact MLP equations, including the use of tanh hidden layers and specific output layers for Bernoulli vs. Gaussian data. Includes specifications for <strong>Bernoulli MLPs</strong> (binary data) and <strong>Gaussian MLPs</strong> (real-valued data).</li>
<li><strong>Appendix D</strong>: Marginal likelihood estimation protocol using Hybrid Monte Carlo (HMC) and importance sampling for evaluation (Figure 3).</li>
</ul>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Diederik P. Kingma and Max Welling. &ldquo;Auto-Encoding Variational Bayes.&rdquo; arXiv:1312.6114 [stat.ML], 2013. <a href="https://doi.org/10.48550/arXiv.1312.6114">https://doi.org/10.48550/arXiv.1312.6114</a></p>
<p><strong>Publication</strong>: ICLR 2014 (arXiv preprint December 2013)</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@misc</span>{kingma2022autoencodingvariationalbayes,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Auto-Encoding Variational Bayes}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Diederik P Kingma and Max Welling}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2013}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">eprint</span>=<span style="color:#e6db74">{1312.6114}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">archivePrefix</span>=<span style="color:#e6db74">{arXiv}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">primaryClass</span>=<span style="color:#e6db74">{stat.ML}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">url</span>=<span style="color:#e6db74">{https://arxiv.org/abs/1312.6114}</span>,
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://en.wikipedia.org/wiki/Variational_autoencoder">Wikipedia: Variational Autoencoder</a> - General overview</li>
<li><a href="https://openreview.net/forum?id=33X9fd2-9FyZd">OpenReview</a> - Original peer review with author responses</li>
<li><a href="/posts/modern-variational-autoencoder-in-pytorch/">Modern VAE in PyTorch</a> - Implementation tutorial on this site</li>
</ul>
]]></content:encoded></item><item><title>SELFIES: The Original Paper on Robust Molecular Strings</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/notations/selfies-original-paper/</link><pubDate>Sun, 12 Oct 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/notations/selfies-original-paper/</guid><description>The 2020 paper introducing SELFIES, the 100% robust molecular representation that solves SMILES validity problems in ML applications.</description><content:encoded><![CDATA[<h2 id="contribution-a-100-robust-representation-for-ml">Contribution: A 100% Robust Representation for ML</h2>
<p>This is a <strong>Method</strong> paper that introduces a new molecular string representation designed specifically for machine learning applications.</p>
<h2 id="motivation-the-invalidity-bottleneck">Motivation: The Invalidity Bottleneck</h2>
<p>When neural networks generate molecules using <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES notation</a>, a large fraction of the output strings are invalid, containing either syntax errors or chemically impossible structures. This is a fundamental bottleneck: every invalid string wastes computational effort and restricts how much of chemical space the model can explore.</p>
<h2 id="novelty-a-formal-grammar-approach">Novelty: A Formal Grammar Approach</h2>
<p>The authors&rsquo; key insight was using a <strong>formal grammar approach</strong> (specifically, a Chomsky type-2, context-free grammar with self-referencing functions) where each symbol is interpreted based on chemical context. The &ldquo;state of the derivation&rdquo; tracks available valence bonds, preventing impossible structures like a carbon with five single bonds.</p>
<p>For example, generating 2-Fluoroethenimine (<code>FC=C=N</code>) follows a state derivation where each step restricts the available valency for the next element:</p>
<p>$$
\mathbf{X}_0 \xrightarrow{[F]} \text{F } \mathbf{X}_1 \xrightarrow{[=C]} \text{FC } \mathbf{X}_3 \xrightarrow{[=C]} \text{FC=C } \mathbf{X}_2 \xrightarrow{[\#N]} \text{FC=C=N}
$$</p>
<p>This approach guarantees 100% validity: every SELFIES string corresponds to a valid molecule, and every valid molecule can be represented.</p>
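<p>The state-tracking idea can be illustrated with a toy derivation that reproduces the <code>FC=C=N</code> example. This is a deliberately simplified sketch, not the SELFIES grammar: it handles only linear chains, and the symbol table, valences, and function name are assumptions made for illustration (no branches, rings, or the real rule table).</p>

```python
# Hypothetical rule table: symbol -> (atom, requested bond order, maximum valence).
TOY_SYMBOLS = {
    "[F]": ("F", 1, 1),
    "[C]": ("C", 1, 4),
    "[=C]": ("C", 2, 4),
    "[N]": ("N", 1, 3),
    "[#N]": ("N", 3, 3),
    "[O]": ("O", 1, 2),
    "[=O]": ("O", 2, 2),
}
BOND_CHARS = {0: "", 1: "", 2: "=", 3: "#"}

def toy_derive(symbols):
    """Derive a SMILES-like string while tracking the remaining valence (the state)."""
    smiles = ""
    state = None  # X_0: nothing emitted yet
    for symbol in symbols:
        atom, requested, valence = TOY_SYMBOLS[symbol]
        if state is None:
            bond = 0  # the first atom has no incoming bond
        elif state == 0:
            break  # previous atom is saturated, so the derivation stops
        else:
            # Demote the requested bond order to what both atoms can support.
            bond = min(requested, state, valence)
        smiles += BOND_CHARS[bond] + atom
        state = valence - bond  # valence left for the next symbol
    return smiles

print(toy_derive(["[F]", "[=C]", "[=C]", "[#N]"]))  # FC=C=N
print(toy_derive(["[F]", "[F]", "[C]"]))            # FF
```

<p>In the first call the states run through 1, 3, 2, matching $\mathbf{X}_1$, $\mathbf{X}_3$, $\mathbf{X}_2$ in the derivation above: the first <code>[=C]</code> is demoted to a single bond because fluorine has one valence left, and <code>[#N]</code> is demoted to a double bond because the carbon has two.</p>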
<h2 id="methodology--experiments-validating-robustness">Methodology &amp; Experiments: Validating Robustness</h2>
<p>The authors ran several experiments to demonstrate SELFIES&rsquo; robustness:</p>
<h3 id="random-mutation-test">Random Mutation Test</h3>
<p>They took the SELFIES and <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> representations of MDMA and introduced random changes:</p>
<ul>
<li><strong>SMILES</strong>: After just one random mutation, only 9.9% of strings remained valid (dropping to 1.1% after three mutations).</li>
<li><strong>SELFIES</strong>: 100% of mutated strings still represented valid molecules (though different from the original).</li>
</ul>
<p>This empirical difference demonstrates why SELFIES is well suited for evolutionary algorithms and genetic programming approaches to molecular design, where random mutations of strings are a core operation.</p>
<h3 id="generative-model-performance">Generative Model Performance</h3>
<p>The real test came with actual machine learning models. The authors trained Variational Autoencoders (VAEs) and Generative Adversarial Networks (GANs) on both representations:</p>
<p><strong>VAE Results:</strong></p>
<ul>
<li>SMILES-based VAE: Large invalid regions scattered throughout the latent space</li>
<li>SELFIES-based VAE: Every point in the continuous latent space mapped to a valid molecule</li>
<li>The SELFIES model encoded <strong>over 100 times more diverse molecules</strong></li>
</ul>
<p><strong>GAN Results:</strong></p>
<ul>
<li>Best SMILES GAN: 18.6% diverse, valid molecules</li>
<li>Best SELFIES GAN: 78.9% diverse, valid molecules</li>
</ul>
<p><strong>Evaluation Metrics:</strong></p>
<ul>
<li><strong>Validity</strong>: Percentage of generated strings representing valid molecular structures</li>
<li><strong>Diversity</strong>: Number of unique valid molecules produced</li>
<li><strong>Reconstruction Accuracy</strong>: How well the autoencoder reproduced input molecules</li>
</ul>
<h3 id="scalability-test">Scalability Test</h3>
<p>The authors showed SELFIES works beyond toy molecules by successfully encoding and decoding all <strong>72 million molecules</strong> from the PubChem database (with fewer than 500 SMILES characters per molecule), demonstrating practical applicability to real chemical databases.</p>
<h2 id="results--conclusions-chemical-space-exploration">Results &amp; Conclusions: Chemical Space Exploration</h2>
<p><strong>Key Findings:</strong></p>
<ul>
<li>SELFIES guarantees 100% validity: every string represents a valid molecule</li>
<li>SELFIES-based VAEs encode over 100x more diverse molecules than SMILES-based models</li>
<li>SELFIES-based GANs produce 78.9% diverse valid molecules vs. 18.6% for SMILES GANs</li>
<li>Successfully validated on all 72 million PubChem molecules</li>
</ul>
<p><strong>Limitations Acknowledged:</strong></p>
<ul>
<li>No standardization or canonicalization method at time of publication</li>
<li>The initial grammar covered only small biomolecules; extensions for stereochemistry, ions, polyvalency, and full periodic table coverage were planned</li>
<li>Requires community testing and adoption</li>
</ul>
<p><strong>Impact:</strong></p>
<p>This work demonstrated that designing ML-native molecular representations could enable new approaches in drug discovery and materials science. SELFIES was subsequently evaluated as an alternative input representation to SMILES in <a href="/notes/chemistry/molecular-representations/encoders/chemberta/">ChemBERTa</a>, a transformer pretrained on molecular strings for property prediction, where it performed comparably to SMILES on the Tox21 benchmark, though the comparison was limited to a single task.</p>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The machine learning experiments used two distinct datasets:</p>
<ul>
<li><strong><a href="/notes/chemistry/datasets/qm9/">QM9</a></strong> (134k molecules): Primary training dataset for VAE and GAN models</li>
<li><strong>PubChem</strong> (72M molecules): Used only to test representation coverage and scalability; not used for model training</li>
</ul>
<h3 id="models">Models</h3>
<p>The VAE implementation included:</p>
<ul>
<li><strong>Latent space</strong>: 241-dimensional with Gaussian distributions</li>
<li><strong>Input encoding</strong>: One-hot encoding of SELFIES/SMILES strings</li>
<li>Full architectural details (encoder/decoder structures, layer types) provided in Supplementary Information</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<p>The authors found GAN performance was highly sensitive to hyperparameter selection:</p>
<ul>
<li>Searched <strong>200 different hyperparameter configurations</strong> to achieve the reported 78.9% diversity</li>
<li>Specific optimizers, learning rates, and training duration detailed in Supplementary Information</li>
<li>Full rule generation algorithm provided in Table 2</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p>All models were evaluated on:</p>
<ul>
<li><strong>Validity rate</strong>: Percentage of syntactically and chemically valid outputs</li>
<li><strong>Diversity</strong>: Count of unique valid molecules generated</li>
<li><strong>Reconstruction accuracy</strong>: Fidelity of autoencoder reconstruction (VAEs only)</li>
</ul>
<h3 id="hardware">Hardware</h3>
<ul>
<li>Training performed on the SciNet supercomputing infrastructure.</li>
<li>The paper does not specify GPU types or training times.</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/aspuru-guzik-group/selfies">SELFIES GitHub Repository</a></td>
          <td>Code</td>
          <td>Apache-2.0</td>
          <td>Official implementation; has evolved significantly since the original paper</td>
      </tr>
  </tbody>
</table>
<h3 id="replication-resources">Replication Resources</h3>
<p>Technical replication is straightforward because the paper is published open access in <em>Machine Learning: Science and Technology</em>. It primarily requires:</p>
<ul>
<li>The full rule generation algorithm (Table 2 in paper)</li>
<li>Code: <a href="https://github.com/aspuru-guzik-group/selfies">https://github.com/aspuru-guzik-group/selfies</a></li>
<li>Supplementary Information for complete architectural and hyperparameter specifications</li>
</ul>
<p><strong>Note</strong>: The <a href="/notes/chemistry/molecular-representations/notations/selfies/">modern SELFIES library</a> has evolved significantly since this foundational paper, addressing many of the implementation challenges identified by the authors.</p>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Krenn, M., Häse, F., Nigam, A., Friederich, P., &amp; Aspuru-Guzik, A. (2020). Self-referencing embedded strings (SELFIES): A 100% robust molecular string representation. <em>Machine Learning: Science and Technology</em>, <em>1</em>(4), 045024. <a href="https://doi.org/10.1088/2632-2153/aba947">https://doi.org/10.1088/2632-2153/aba947</a></p>
<p><strong>Publication</strong>: Machine Learning: Science and Technology, 2020</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{Krenn_2020,
</span></span><span style="display:flex;"><span>	<span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1088/2632-2153/aba947}</span>,
</span></span><span style="display:flex;"><span>	<span style="color:#a6e22e">url</span> = <span style="color:#e6db74">{https://doi.org/10.1088%2F2632-2153%2Faba947}</span>,
</span></span><span style="display:flex;"><span>	<span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2020</span>,
</span></span><span style="display:flex;"><span>	<span style="color:#a6e22e">month</span> = <span style="color:#e6db74">{aug}</span>,
</span></span><span style="display:flex;"><span>	<span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{{IOP} Publishing}</span>,
</span></span><span style="display:flex;"><span>	<span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>	<span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{4}</span>,
</span></span><span style="display:flex;"><span>	<span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{045024}</span>,
</span></span><span style="display:flex;"><span>	<span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Mario Krenn and Florian H{\&#34;{a}}se and AkshatKumar Nigam and Pascal Friederich and Alan Aspuru-Guzik}</span>,
</span></span><span style="display:flex;"><span>	<span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Self-referencing embedded strings ({SELFIES}): A 100{\%} robust molecular string representation}</span>,
</span></span><span style="display:flex;"><span>	<span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Machine Learning: Science and Technology}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/aspuru-guzik-group/selfies">GitHub Repository</a></li>
<li><a href="/notes/chemistry/molecular-representations/notations/selfies/">Modern SELFIES Documentation</a></li>
</ul>
]]></content:encoded></item><item><title>Recent Advances in the SELFIES Library: 2023 Update</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/notations/selfies-2023/</link><pubDate>Sun, 12 Oct 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/notations/selfies-2023/</guid><description>Major updates to the SELFIES library, improved performance, expanded chemistry support, and new customization features.</description><content:encoded><![CDATA[<h2 id="overview">Overview</h2>
<p>This software update paper documents major improvements to the SELFIES Python library (version 2.1.1), covering its history, underlying algorithms, design, and performance.</p>
<h2 id="limitations-in-the-original-selfies-implementation">Limitations in the Original SELFIES Implementation</h2>
<p>While the <a href="/notes/chemistry/molecular-representations/notations/selfies-original-paper/">original SELFIES concept</a> was promising, the initial 2019 implementation had critical limitations that prevented widespread adoption:</p>
<ol>
<li><strong>Performance</strong>: Too slow for production ML workflows</li>
<li><strong>Limited chemistry</strong>: Couldn&rsquo;t represent aromatic molecules, stereochemistry, or many other important chemical features</li>
<li><strong>Poor usability</strong>: Lacked user-friendly APIs for common tasks</li>
</ol>
<p>These barriers meant that despite SELFIES&rsquo; theoretical advantages (100% validity guarantee), researchers couldn&rsquo;t practically use it for real-world applications like drug discovery or materials science.</p>
<h2 id="architectural-refactoring-and-new-ml-integrations">Architectural Refactoring and New ML Integrations</h2>
<p>The 2023 update refactors the underlying SELFIES engine with improvements to design, efficiency, and supported features. The key updates include:</p>
<ol>
<li>
<p><strong>Streamlined Grammar</strong>: The underlying context-free grammar has been generalized and streamlined, improving execution speed and extensibility while maintaining the 100% validity guarantee.</p>
</li>
<li>
<p><strong>Expanded Chemical Support</strong>: Adds support for aromatic systems (via internal kekulization), stereochemistry (chirality, cis/trans), charged species, and isotopic data, covering nearly all features supported by SMILES while preserving the validity guarantee.</p>
</li>
<li>
<p><strong>Semantic Constraint API</strong>: Introduces the <code>set_semantic_constraints()</code> function, allowing specification of custom valence definitions useful for theoretical studies or hypervalent states.</p>
</li>
<li>
<p><strong>ML Utility Functions</strong>: Provides tokenization (<code>split_selfies</code>), length estimation (<code>len_selfies</code>), label/one-hot encoding (<code>selfies_to_encoding</code>), vocabulary extraction, and attribution tracking for integration with neural network pipelines.</p>
</li>
</ol>
<h2 id="performance-benchmarks--validity-testing">Performance Benchmarks &amp; Validity Testing</h2>
<p>The authors validated the library through several benchmarks:</p>
<p><strong>Performance testing</strong>: Roundtrip conversion (SMILES to SELFIES to SMILES) on the DTP open compound collection (slightly over 300K molecules) completed in 252 seconds total (136s encoding, 116s decoding), using pure Python with no external dependencies.</p>
<p><strong>Random SELFIES generation</strong>: Demonstrated that random SELFIES strings of varying lengths always decode to valid molecules, with the size distribution of generated molecules controllable by filtering the sampling alphabet (e.g., removing multi-bond and low-valence atom symbols shifts the distribution toward larger molecules).</p>
<p><strong>Validity guarantee</strong>: By construction, every SELFIES string decodes to a valid molecule. The grammar&rsquo;s bond demotion and deferred ring closure mechanisms make it impossible to generate chemically invalid structures.</p>
<p><strong>Attribution system</strong>: Showed both encoder and decoder can track which input symbols produce which output symbols, useful for property alignment.</p>
<h2 id="future-trajectories-for-general-chemical-representations">Future Trajectories for General Chemical Representations</h2>
<p>The 2023 update addresses the main adoption barriers. The library is now:</p>
<ol>
<li><strong>Fast enough</strong> for large-scale ML applications (300K molecules in ~4 minutes)</li>
<li><strong>Chemically comprehensive</strong> enough for drug discovery and materials science</li>
<li><strong>User-friendly</strong> enough for straightforward integration into existing workflows</li>
</ol>
<p>The validity guarantee, SELFIES&rsquo; core advantage, is now practically accessible for real-world research. The roadmap includes future extensions for polymers, crystals, chemical reactions, and non-covalent interactions, which would expand SELFIES&rsquo; applicability beyond small-molecule chemistry.</p>
<p><strong>Limitations acknowledged</strong>: The paper focuses on implementation improvements. Support for more complex chemical systems, such as polymers and crystals, is left to future work.</p>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/aspuru-guzik-group/selfies">selfies</a></td>
          <td>Code</td>
          <td>Apache 2.0</td>
          <td>Official Python library, installable via <code>pip install selfies</code></td>
      </tr>
  </tbody>
</table>
<h3 id="code">Code</h3>
<p>The <code>selfies</code> library is completely open-source and written in pure Python. It requires no extra dependencies and is available on GitHub, installable via <code>pip install selfies</code>. The repository includes testing suites (<code>tox</code>) and example benchmarking scripts to reproduce the translation speeds reported in the paper.</p>
<h3 id="hardware">Hardware</h3>
<p>Performance benchmarks (e.g., the 252-second roundtrip conversion on 300K molecules) were executed on Google Colaboratory using two 2.20 GHz Intel Xeon CPUs.</p>
<h3 id="algorithms">Algorithms</h3>
<h4 id="technical-specification-the-grammar">Technical Specification: The Grammar</h4>
<p>The core innovation of SELFIES is a <strong>Context-Free Grammar (CFG) augmented with state-machine logic</strong> to ensure that every derived string represents a valid molecule. While the software features are important, understanding the underlying derivation rules is essential for replication or extension of the system.</p>
<p><strong>1. Derivation Rules: The Atom State Machine</strong></p>
<p>The fundamental mechanism that guarantees validity is a <strong>state machine</strong> that tracks the remaining valence of the most recently added atom:</p>
<ul>
<li><strong>State Tracking</strong>: The derivation maintains a non-terminal state $X_l$, where $l$ represents the current atom&rsquo;s remaining valence (number of bonds it can still form)</li>
<li><strong>Standard Derivation</strong>: An atom symbol $[\beta \alpha]$ (bond order + atom type) transitions the state from $S$ (start) to $X_l$, where $l$ is calculated from the atom&rsquo;s standard valence minus the incoming bond order</li>
<li><strong>Bond Demotion (The Key Rule)</strong>: When deriving atom symbol $[\beta \alpha]$ in state $X_i$, the actual bond order used is $d_0 = \min(\ell, i, d(\beta))$, where $\ell$ is the new atom&rsquo;s valence, $i$ is the previous atom&rsquo;s remaining capacity, and $d(\beta)$ is the requested bond order. This automatic downward adjustment is the mathematical core of the validity guarantee.</li>
</ul>
<p>This state machine ensures that no atom ever exceeds its allowed valence, making it impossible to generate chemically invalid structures.</p>
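<p>The bond demotion rule is small enough to sketch directly. The function names and the two example atoms below are illustrative assumptions; only the <code>min</code> expression comes from the rule stated above.</p>

```python
def demote(new_atom_valence, remaining_capacity, requested_order):
    """Bond order actually used: d_0 = min(l, i, d(beta))."""
    return min(new_atom_valence, remaining_capacity, requested_order)


def step(new_atom_valence, remaining_capacity, requested_order):
    """Return (bond order used, remaining valence of the new atom)."""
    used = demote(new_atom_valence, remaining_capacity, requested_order)
    return used, new_atom_valence - used


# A triple bond to nitrogen (valence 3) is requested, but the previous atom
# can only form one more bond, so the bond is demoted to a single bond.
print(step(new_atom_valence=3, remaining_capacity=1, requested_order=3))  # (1, 2)

# A double bond to fluorine (valence 1) is requested; fluorine caps it at one.
print(step(new_atom_valence=1, remaining_capacity=3, requested_order=2))  # (1, 0)
```

<p>The second element of the result is the index of the next state $X_l$: the new atom continues the chain with whatever valence is left after the (possibly demoted) incoming bond.</p>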
<p><strong>2. Control Symbols: Branches and Rings</strong></p>
<p>Branch length calculation: SELFIES uses a <strong>hexadecimal encoding</strong> to determine branch lengths. A branch symbol <code>[Branch l]</code> consumes the next $\ell$ symbols from the queue and converts them to integer indices $c_1, \dots, c_\ell$ via a fixed mapping (Table III in the paper). The number of symbols $N$ to include in the branch is then:</p>
<p>$$
N = 1 + \sum_{k=1}^{\ell} 16^{\ell - k} \, c_k
$$</p>
<p>This formula interprets the indices as hexadecimal digits, allowing compact specification of branches up to hundreds of symbols long.</p>
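<p>As a worked example of the formula, the sketch below evaluates $N$ for a few index sequences. It assumes the symbol-to-index mapping has already been applied, so the input is a list of integers between 0 and 15; the function name is hypothetical.</p>

```python
def branch_length(indices):
    """Evaluate N = 1 + sum_k 16**(l - k) * c_k for index values c_1..c_l."""
    value = 0
    for c in indices:
        if c not in range(16):
            raise ValueError("each index must be a hexadecimal digit (0 to 15)")
        value = 16 * value + c  # Horner form of the base-16 expansion
    return 1 + value


print(branch_length([0]))           # 1
print(branch_length([15]))          # 16
print(branch_length([1, 0]))        # 17
print(branch_length([15, 15, 15]))  # 4096
```

<p>The leading <code>1 +</code> means the all-zero index sequence still yields a branch of one symbol, so an empty branch cannot be encoded.</p>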
<p>Ring closure queue system: Ring formation uses a <strong>deferred evaluation</strong> strategy to maintain validity. Ring symbols don&rsquo;t create bonds immediately; instead, they push closure candidates into a queue $R$. These candidates are resolved after the main derivation completes. A ring closure candidate is <strong>rejected</strong> if either ring atom has no remaining valence ($m_1 = 0$ or $m_2 = 0$), or if the left and right ring atoms are not distinct (to avoid self-loops). If a prior bond already exists between the two atoms, the bond order is incremented rather than duplicated. This deferred validation prevents invalid ring structures while keeping the grammar context-free during the main derivation.</p>
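<p>The deferred resolution step can be sketched as a loop over the queue. The data layout (a list of remaining valences and a dictionary of bonds) is hypothetical, and two details are assumptions rather than statements from the paper: the bond order is clamped to the remaining valences, mirroring bond demotion, and the total order between two atoms is capped at a triple bond.</p>

```python
MAX_BOND_ORDER = 3  # assumption: nothing above a triple bond


def resolve_ring_closures(remaining, bonds, candidates):
    """Apply queued ring closures in order, updating remaining and bonds in place."""
    for left, right, requested in candidates:
        m1, m2 = remaining[left], remaining[right]
        if left == right or m1 == 0 or m2 == 0:
            continue  # rejected: self-loop or no valence left
        key = (min(left, right), max(left, right))
        existing = bonds.get(key, 0)
        order = min(requested, m1, m2, MAX_BOND_ORDER - existing)
        if order == 0:
            continue
        bonds[key] = existing + order  # increment instead of duplicating
        remaining[left] -= order
        remaining[right] -= order
    return bonds


# Six carbons in a chain: end atoms have 3 bonds left, inner atoms have 2.
remaining = [3, 2, 2, 2, 2, 3]
bonds = {(i, i + 1): 1 for i in range(5)}
queue = [(0, 5, 1), (2, 2, 1), (0, 1, 1)]

resolve_ring_closures(remaining, bonds, queue)
print(bonds[(0, 5)], bonds[(0, 1)], remaining)  # 1 2 [1, 1, 2, 2, 2, 2]
```

<p>The first candidate closes the six-membered ring, the second is dropped as a self-loop, and the third raises the existing bond between atoms 0 and 1 to a double bond.</p>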
<p><strong>3. Symbol Structure and Standardization</strong></p>
<p>SELFIES enforces a strict, standardized format for atom symbols to eliminate ambiguity:</p>
<ul>
<li><strong>Canonical Format</strong>: Atom symbols follow the structure <code>[Bond, Isotope, Element, Chirality, H-count, Charge]</code></li>
<li><strong>No Variation</strong>: There is only one way to write each symbol (e.g., <code>[Fe++]</code> and <code>[Fe+2]</code> are standardized to a single form)</li>
<li><strong>Order Matters</strong>: The components must appear in the specified order</li>
</ul>
<p><strong>4. Default Semantic Constraints</strong></p>
<p>By default, the library enforces standard organic chemistry valence rules:</p>
<ul>
<li><strong>Charge-Dependent Valences</strong>: Default constraints specify maximum bonds per charge state (e.g., C: 4/5/3 for neutral/+1/-1; S: 6/7/5). Unlisted atom types default to 8 maximum bonds as a catch-all.</li>
<li><strong>Preset Options</strong>: Three preset constraint sets are available: <code>default</code>, <code>octet_rule</code>, and <code>hypervalent</code>.</li>
<li><strong>Customizable</strong>: Constraints can be modified via <code>set_semantic_constraints()</code> for specialized applications (hypervalent compounds, theoretical studies, etc.)</li>
</ul>
<p>The combination of these grammar rules with the state machine ensures that <strong>every valid SELFIES string decodes to a chemically valid molecule</strong>, regardless of how the string was generated (random, ML model output, manual construction, etc.).</p>
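<p>A constraint table of this kind can be pictured as a lookup keyed by element and charge with a catch-all default. The sketch below uses only the carbon and sulfur values quoted above; the tuple-keyed layout and the function name are assumptions for illustration and do not reproduce the data structure used by the library.</p>

```python
CATCH_ALL_MAX_BONDS = 8  # unlisted atom types fall back to 8 bonds

# (element, formal charge) mapped to the maximum number of bonds.
DEFAULT_CONSTRAINTS = {
    ("C", 0): 4, ("C", 1): 5, ("C", -1): 3,
    ("S", 0): 6, ("S", 1): 7, ("S", -1): 5,
}


def max_bonds(element, charge=0, table=DEFAULT_CONSTRAINTS):
    """Look up the bond capacity for an atom type, with a catch-all default."""
    return table.get((element, charge), CATCH_ALL_MAX_BONDS)


print(max_bonds("C"), max_bonds("C", 1), max_bonds("C", -1))  # 4 5 3
print(max_bonds("Xx"))                                        # 8

# A customized table for a theoretical study (the value is hypothetical).
custom = dict(DEFAULT_CONSTRAINTS)
custom[("C", 0)] = 5
print(max_bonds("C", table=custom))                           # 5
```

<p>Because the state machine only consults these capacities, swapping the table changes which molecules count as valid without touching the grammar itself.</p>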
<h3 id="data">Data</h3>
<p><strong>Benchmark dataset</strong>: DTP (Developmental Therapeutics Program) open compound collection with slightly over 300K SMILES strings, a set of molecules tested experimentally for potential treatment against cancer and AIDS.</p>
<p><strong>Random generation testing</strong>: Random SELFIES strings of varying lengths (10, 100, 250 symbols) generated from both basic and filtered alphabets to test decoding validity and molecule size distributions.</p>
<h3 id="evaluation">Evaluation</h3>
<p><strong>Performance metric</strong>: Roundtrip conversion time (SMILES to SELFIES to SMILES) is 252 seconds for 300K+ molecules (136s encoding, 116s decoding). Times averaged over 3 replicate trials on Google Colaboratory.</p>
<p><strong>Validity testing</strong>: Random SELFIES strings of lengths 10, 100, and 250 all decode to valid molecules. Decoding 1000 random strings of length 250 from the basic alphabet takes 0.341s; from the filtered alphabet, 1.633s.</p>
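<p>The reported timings translate into approximate throughput figures. The short calculation below derives them from the numbers above; since the collection holds slightly over 300K molecules, the per-second rates for the roundtrip benchmark are lower bounds.</p>

```python
MOLECULES = 300_000          # lower bound: the collection is slightly over 300K
ENCODE_S, DECODE_S = 136, 116

roundtrip_s = ENCODE_S + DECODE_S
rates = {
    "encode": MOLECULES / ENCODE_S,
    "decode": MOLECULES / DECODE_S,
    "roundtrip": MOLECULES / roundtrip_s,
}
# Random-string decoding: 1000 strings of length 250 per timing.
random_decode = {"basic": 1000 / 0.341, "filtered": 1000 / 1.633}

for name, rate in {**rates, **random_decode}.items():
    print(f"{name}: about {rate:.0f} per second")
```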
<p><strong>Attribution system</strong>: Both <code>encoder()</code> and <code>decoder()</code> support an <code>attribute</code> flag that returns <code>AttributionMap</code> objects, tracing which input symbols produce which output symbols for property alignment.</p>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Lo, A., Pollice, R., Nigam, A., White, A. D., Krenn, M., &amp; Aspuru-Guzik, A. (2023). Recent advances in the self-referencing embedded strings (SELFIES) library. <em>Digital Discovery</em>, <em>2</em>(4), 897-908. <a href="https://doi.org/10.1039/D3DD00044C">https://doi.org/10.1039/D3DD00044C</a></p>
<p><strong>Publication</strong>: Digital Discovery 2023</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{lo2023recent,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Recent advances in the self-referencing embedded strings (SELFIES) library}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Lo, Alston and Pollice, Robert and Nigam, AkshatKumar and White, Andrew D and Krenn, Mario and Aspuru-Guzik, Al{\&#39;a}n}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Digital Discovery}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{2}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{4}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{897--908}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Royal Society of Chemistry}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1039/D3DD00044C}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/aspuru-guzik-group/selfies">SELFIES GitHub Repository</a></li>
<li><a href="/notes/chemistry/molecular-representations/notations/selfies-original-paper/">Original SELFIES Paper (2020)</a></li>
<li><a href="/notes/chemistry/molecular-representations/notations/selfies/">SELFIES Format Overview</a></li>
</ul>
]]></content:encoded></item><item><title>αExtractor: Chemical Info from Biomedical Literature</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-sequence/alpha-extractor/</link><pubDate>Sat, 11 Oct 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-sequence/alpha-extractor/</guid><description>αExtractor uses ResNet-Transformer to extract chemical structures from literature images, including noisy and hand-drawn molecules.</description><content:encoded><![CDATA[<h2 id="methodological-contribution-a-robust-optical-recognition-system">Methodological Contribution: A Robust Optical Recognition System</h2>
<p>This is primarily a <strong>Method</strong> ($\Psi_{\text{Method}}$) paper with a significant secondary <strong>Resource</strong> ($\Psi_{\text{Resource}}$) contribution (see the <a href="/notes/interdisciplinary/research-methods/ai-physical-sciences-paper-taxonomy/">AI and Physical Sciences paper taxonomy</a> for more on these categories).</p>
<p>The dominant methodological contribution is a ResNet-Transformer recognition architecture that outperforms existing OCSR tools across multiple benchmarks. Its robustness comes from training on roughly 20 million synthetic images with aggressive augmentation, which prepares the model for degraded image conditions. The work answers the core methodological question &ldquo;How well does this work?&rdquo; through extensive benchmarking against existing OCSR tools and ablation studies validating architectural choices.</p>
<p>The secondary resource contribution comes from releasing αExtractor as a freely available web service, correcting labeling errors in standard benchmarks (CLEF, UOB, JPO), and providing an end-to-end document processing pipeline for biomedical literature mining.</p>
<h2 id="motivation-extracting-visual-chemical-knowledge-from-biomedical-literature">Motivation: Extracting Visual Chemical Knowledge from Biomedical Literature</h2>
<p>The work targets a familiar pain point in cheminformatics, here in a biomedical setting. Vast amounts of chemical knowledge in biomedical literature exist only as images: molecular structures embedded in figures, chemical synthesis schemes, and compound diagrams. This visual knowledge is effectively invisible to computational methods, creating a bottleneck for drug discovery research, systematic reviews, and large-scale chemical database construction.</p>
<p>Existing OCSR tools face two critical problems when applied to biomedical literature:</p>
<ol>
<li>
<p><strong>Real-world image quality</strong>: Biomedical papers often contain low-resolution figures, images with complex backgrounds, noise from scanning/digitization, and inconsistent drawing styles across different journals and decades of publications.</p>
</li>
<li>
<p><strong>End-to-end extraction</strong>: Most OCSR systems assume the presence of clean, cropped molecular images. In practice, you need to first find the molecular structures within multi-panel figures, reaction schemes, and dense document layouts before you can recognize them.</p>
</li>
</ol>
<p>The authors argue that a practical literature mining system must solve both problems together: robust recognition under noisy conditions and automated detection of molecular images within complex documents.</p>
<h2 id="core-innovation-robust-resnet-transformer-architecture">Core Innovation: Robust ResNet-Transformer Architecture</h2>
<p>The core innovation lies in combining a competition-winning recognition architecture with extensive robustness engineering and end-to-end document processing. The key contributions include:</p>
<ol>
<li>
<p><strong>ResNet-Transformer Recognition Model</strong>: The core recognition system uses a <strong>Residual Neural Network (ResNet)</strong> encoder paired with a <strong>Transformer decoder</strong> in an image-captioning framework. This architecture won first place in a Kaggle molecular translation competition, which provided a strong foundation for the recognition task. Let the input image be $I$. The model is trained to maximize the joint likelihood of the SMILES tokens $T$ and the coordinate sequences $X, Y$, which is equivalent to minimizing the negative log-likelihood:
$$
\begin{aligned}
\mathcal{L}_{\text{total}} = - \sum_{i=1}^{L} \log P(T_i \mid I, T_{&lt;i}) - \lambda \sum_{i=1}^{L} \big(\log P(X_i \mid I, X_{&lt;i}) + \log P(Y_i \mid I, Y_{&lt;i})\big)
\end{aligned}
$$
Here, $L$ is the shared sequence length and $\lambda$ weights the coordinate terms relative to the SMILES term. The continuous $X$ and $Y$ atom coordinates are discretized into 200 bins, so coordinate prediction becomes a standard classification task alongside SMILES generation.</p>
</li>
<li>
<p><strong>Enhanced Molecular Representation</strong>: The model produces an augmented representation that encompasses:</p>
<ul>
<li>Standard molecular connectivity information</li>
<li><strong>Bond type tokens</strong> (solid wedge bonds, dashed bonds, etc.) that preserve 3D stereochemical information</li>
<li><strong>Atom coordinate predictions</strong> that allow reconstruction of the exact molecular pose from the original image</li>
</ul>
<p>This dual prediction of discrete structure and atom coordinates keeps the output closer to the source depiction and enables better quality assessment.</p>
</li>
<li>
<p><strong>Massive Synthetic Training Dataset</strong>: The model was trained on approximately <strong>20 million synthetic molecular images</strong> generated from PubChem SMILES with aggressive data augmentation. The augmentation strategy randomized visual styles, image quality, and rendering parameters to create maximum diversity, ensuring the network rarely saw the same molecular depiction twice. This forces the model to learn robust, style-invariant features.</p>
</li>
<li>
<p><strong>End-to-End Document Processing Pipeline</strong>: αExtractor integrates <strong>object detection</strong> and <strong>structure recognition</strong> into a complete document mining system:</p>
<ul>
<li>An object detection model automatically locates molecular images within PDF documents</li>
<li>The recognition model converts detected images to structured representations</li>
<li>A web service interface makes the entire pipeline accessible to researchers without machine learning expertise</li>
</ul>
</li>
<li>
<p><strong>Robustness-First Design</strong>: The system was explicitly designed to handle degraded image conditions that break traditional OCSR tools, including low resolution, background interference, color variations, and scanning artifacts commonly found in legacy biomedical literature.</p>
</li>
</ol>
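<p>The training objective from the first item above can be evaluated by hand for a toy sequence. In this sketch the inputs are the probabilities the model assigned to the correct token at each step; the function name, the example probabilities, and the value of $\lambda$ are all hypothetical, since the notes do not report the weight used.</p>

```python
import math


def total_loss(p_tokens, p_x, p_y, lam):
    """Negative log-likelihood of SMILES tokens plus weighted coordinate terms."""
    if not (len(p_tokens) == len(p_x) == len(p_y)):
        raise ValueError("the three sequences must be length-aligned")
    smiles_nll = -sum(math.log(p) for p in p_tokens)
    coord_nll = -sum(math.log(px) + math.log(py) for px, py in zip(p_x, p_y))
    return smiles_nll + lam * coord_nll


# Two decoding steps; each list holds the probability of the correct class.
loss = total_loss(
    p_tokens=[0.5, 0.5],
    p_x=[1.0, 0.5],
    p_y=[1.0, 1.0],
    lam=0.5,  # hypothetical weight
)
print(round(loss, 4))  # 1.7329, i.e. 2.5 * ln(2)
```

<p>Because the coordinates are binned, each of the three sums is an ordinary cross-entropy over a finite set of classes, which is why a single classification loss covers both structure and pose.</p>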
<h2 id="experimental-methodology-stress-testing-under-real-world-conditions">Experimental Methodology: Stress Testing under Real-World Conditions</h2>
<p>The evaluation focused on demonstrating robust performance across diverse image conditions, from pristine benchmarks to challenging real-world scenarios:</p>
<ol>
<li>
<p><strong>Benchmark Dataset Evaluation</strong>: αExtractor was tested on four standard OCSR benchmarks:</p>
<ul>
<li><strong>CLEF</strong>: Chemical structure recognition challenge dataset</li>
<li><strong>UOB</strong>: University of Birmingham patent images</li>
<li><strong>JPO</strong>: Japan Patent Office molecular diagrams</li>
<li><strong>USPTO</strong>: US Patent and Trademark Office structures</li>
</ul>
<p>Performance was measured using exact SMILES match accuracy.</p>
</li>
<li>
<p><strong>Error Analysis and Dataset Correction</strong>: During evaluation, the researchers discovered numerous labeling errors in the original benchmark datasets. They systematically identified and corrected these errors, then re-evaluated all methods on the cleaned datasets to get more accurate performance measurements.</p>
</li>
<li>
<p><strong>Robustness Stress Testing</strong>: The system was evaluated on two challenging datasets specifically designed to test robustness:</p>
<ul>
<li><strong>Color background images</strong> (200 samples): Molecular structures on complex, colorful backgrounds that simulate real figure conditions</li>
<li><strong>Low-quality images</strong> (200 samples): Degraded images with noise, blur, and artifacts typical of scanned documents</li>
</ul>
<p>These tests compared αExtractor against three open-source tools (OSRA, MolVec, and Imago) under realistic degradation conditions.</p>
</li>
<li>
<p><strong>Generalization Testing</strong>: In the most challenging experiment, αExtractor was tested on the <strong>DECIMER hand-drawn molecule images dataset</strong> (Brinkhaus et al., 2022), a visual domain entirely absent from the training data. This tested whether the learned features could generalize beyond digital rendering styles to human-drawn chemistry.</p>
</li>
<li>
<p><strong>End-to-End Document Extraction</strong>: The complete pipeline was evaluated on 50 PDF files containing 2,336 molecular images. This tested both the object detection component (finding molecules in complex documents) and the recognition component (converting them to SMILES) in a realistic literature mining scenario.</p>
</li>
<li>
<p><strong>Speed Benchmarking</strong>: Inference time was measured to demonstrate the practical efficiency needed for large-scale document processing.</p>
</li>
</ol>
<h2 id="results--conclusions-strong-performance-on-degraded-images">Results &amp; Conclusions: Strong Performance on Degraded Images</h2>
<ul>
<li>
<p><strong>Substantial Accuracy Gains</strong>: On the four benchmark datasets, αExtractor achieved accuracies of 91.83% (CLEF), 98.47% (UOB), 88.67% (JPO), and 93.64% (USPTO), compared to previous best results of 84.6%, 90.0%, 72.2%, and 89.9% respectively. After correcting dataset labeling errors, the true accuracies were even higher, reaching <strong>95.77% on CLEF, 99.86% on UOB, and 92.44% on JPO</strong>.</p>
</li>
<li>
<p><strong>Robustness on Degraded Images</strong>: Open-source competitors struggled on degraded images (achieving 5.5% accuracy at best). αExtractor maintained <strong>over 90% accuracy</strong> on both color background and low-quality image datasets, demonstrating the effectiveness of the synthetic training strategy.</p>
</li>
<li>
<p><strong>Generalization to Hand-Drawn Molecules</strong>: On hand-drawn molecules, a domain completely absent from training data, αExtractor achieved <strong>61.4% accuracy</strong> while other tools scored between 0.69% and 2.93%. This suggests the model learned genuinely chemical features rather than style-specific patterns.</p>
</li>
<li>
<p><strong>Practical End-to-End Performance</strong>: In the complete document processing evaluation, αExtractor detected <strong>95.1% of molecular images</strong> (2,221 out of 2,336) and correctly recognized <strong>94.5% of detected structures</strong> (2,098 correct predictions). This demonstrates the system&rsquo;s readiness for real-world literature mining applications.</p>
</li>
<li>
<p><strong>Ablation Results</strong>: Ablation experiments confirmed that each architectural component (ResNet backbone, Transformer encoder, Transformer decoder) contributes to performance, with the Transformer decoder having the largest impact. Replacing the Transformer decoder with an LSTM decoder substantially reduced accuracy (Table S6 in the paper).</p>
</li>
<li>
<p><strong>Dataset Quality Issues</strong>: The systematic discovery of labeling errors in standard benchmarks highlights a broader problem in OCSR evaluation. The corrected datasets provide more reliable baselines for future method development.</p>
</li>
<li>
<p><strong>Spatial Layout Limitation</strong>: αExtractor correctly identifies molecular connectivity, but the re-rendered structures may have different spatial layouts than the originals. This could complicate visual verification for complex molecules, even if the chemical information remains accurate.</p>
</li>
<li>
<p><strong>Non-Standard Depiction Handling</strong>: For images with non-standard bond depictions or atomic valences, αExtractor correctly identifies and normalizes them to standard representations. While chemically accurate, this means the re-rendered structure may visually differ from the original image.</p>
</li>
</ul>
<p>Overall, αExtractor combines accurate recognition (over 90% on degraded images), end-to-end document processing, and strong generalization across image conditions. It targets large-scale literature mining tasks where previous tools struggled with degraded inputs. The focus on real-world robustness over benchmark optimization reflects a practical approach to deploying machine learning in scientific workflows.</p>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<p>This paper is <strong>Partially Reproducible</strong>. While the authors detail the model architectures and training techniques, the source code, training dataset (20M synthetic images), and pre-trained weights have not been released. The authors released a sample of their test data and host an online web server for running inference.</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Artifact</th>
          <th style="text-align: left">Type</th>
          <th style="text-align: left">License</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><a href="https://github.com/jiachengxiong/alpha-Extractor/tree/main/CLEF_corrected">Corrected CLEF Dataset</a></td>
          <td style="text-align: left">Dataset</td>
          <td style="text-align: left">Unknown</td>
          <td style="text-align: left">Authors&rsquo; corrected version of the CLEF benchmark.</td>
      </tr>
      <tr>
          <td style="text-align: left"><a href="https://github.com/jiachengxiong/alpha-Extractor/tree/main/UOB_corrected">Corrected UOB Dataset</a></td>
          <td style="text-align: left">Dataset</td>
          <td style="text-align: left">Unknown</td>
          <td style="text-align: left">Authors&rsquo; corrected version of the UOB benchmark.</td>
      </tr>
      <tr>
          <td style="text-align: left"><a href="https://github.com/jiachengxiong/alpha-Extractor/tree/main/JPO_corrected">Corrected JPO Dataset</a></td>
          <td style="text-align: left">Dataset</td>
          <td style="text-align: left">Unknown</td>
          <td style="text-align: left">Authors&rsquo; corrected version of the JPO benchmark.</td>
      </tr>
      <tr>
          <td style="text-align: left"><a href="https://github.com/jiachengxiong/alpha-Extractor/tree/main/Colored_Background">Color Background Dataset</a></td>
          <td style="text-align: left">Dataset</td>
          <td style="text-align: left">Unknown</td>
          <td style="text-align: left">200 samples of molecular structures on complex, colorful backgrounds.</td>
      </tr>
      <tr>
          <td style="text-align: left"><a href="https://github.com/jiachengxiong/alpha-Extractor/tree/main/Low_Quality">Low Quality Dataset</a></td>
          <td style="text-align: left">Dataset</td>
          <td style="text-align: left">Unknown</td>
          <td style="text-align: left">200 samples of degraded images with noise, blur, and artifacts.</td>
      </tr>
      <tr>
          <td style="text-align: left"><a href="https://github.com/jiachengxiong/alpha-Extractor/tree/main/PDF">PDF Test Set</a></td>
          <td style="text-align: left">Dataset</td>
          <td style="text-align: left">Unknown</td>
          <td style="text-align: left">Sample PDF files for end-to-end document extraction evaluation.</td>
      </tr>
      <tr>
          <td style="text-align: left"><a href="https://extractor.alphama.com.cn/csr">αExtractor Web Server</a></td>
          <td style="text-align: left">Other</td>
          <td style="text-align: left">Unknown</td>
          <td style="text-align: left">Online service for running inference using the proprietary system.</td>
      </tr>
  </tbody>
</table>
<h3 id="models">Models</h3>
<p><strong>Image Recognition Model:</strong></p>
<ul>
<li><strong>Backbone:</strong> ResNet50 producing output of shape $2048 \times 19 \times 19$, projected to 512 channels via a feed-forward layer</li>
<li><strong>Transformer Architecture:</strong> 3 encoder layers and 3 decoder layers with hidden dimension of 512</li>
<li><strong>Output Format:</strong> Generates SMILES tokens plus two auxiliary coordinate sequences (X-axis and Y-axis) that are length-aligned with the SMILES tokens via padding</li>
</ul>
<p><strong>Object Detection Model:</strong></p>
<ul>
<li><strong>Architecture:</strong> DETR (Detection Transformer) with ResNet101 backbone</li>
<li><strong>Transformer Architecture:</strong> 6 encoder layers and 6 decoder layers with hidden dimension of 256</li>
<li><strong>Purpose:</strong> Locates molecular images within PDF pages before recognition</li>
</ul>
<p><strong>Coordinate Prediction:</strong></p>
<ul>
<li>Continuous X/Y coordinates are discretized into <strong>200 discrete bins</strong></li>
<li>Padding tokens are added to the coordinate sequences so that they align with the SMILES token sequence, enabling simultaneous structure and pose prediction</li>
</ul>
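<p>The binning and alignment described above can be sketched as follows. Several details are assumptions made for illustration: coordinates are taken to be normalized to the range 0 to 1, bins are uniform, atom tokens are identified with a simple alphabetic test, and <code>PAD</code> is a hypothetical padding id.</p>

```python
NUM_BINS = 200
PAD = -1  # hypothetical padding id for tokens that carry no coordinate


def to_bin(value, low=0.0, high=1.0, num_bins=NUM_BINS):
    """Map a coordinate in [low, high] to an integer bin in [0, num_bins - 1]."""
    fraction = (value - low) / (high - low)
    return min(num_bins - 1, max(0, int(fraction * num_bins)))


def align_coordinates(tokens, atom_coords, is_atom=str.isalpha):
    """Build X and Y bin sequences with the same length as the token sequence."""
    xs, ys = [], []
    coords = iter(atom_coords)
    for token in tokens:
        if is_atom(token):
            x, y = next(coords)
            xs.append(to_bin(x))
            ys.append(to_bin(y))
        else:
            xs.append(PAD)  # bond and ring tokens get padding
            ys.append(PAD)
    return xs, ys


tokens = ["C", "=", "O"]
xs, ys = align_coordinates(tokens, [(0.0, 0.5), (1.0, 0.25)])
print(xs)  # [0, -1, 199]
print(ys)  # [100, -1, 50]
```

<p>With all three sequences sharing one length, a single decoding step can emit a SMILES token together with its X and Y bins.</p>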
<h3 id="data">Data</h3>
<p><strong>Training Data:</strong></p>
<ul>
<li><strong>Synthetic Generation:</strong> Python script rendering PubChem SMILES into 2D images</li>
<li><strong>Dataset Size:</strong> Approximately 20.3 million synthetic molecular images from PubChem</li>
<li><strong>Superatom Handling:</strong> 50% of molecules had functional groups replaced with superatoms (e.g., &ldquo;COOH&rdquo;) or generic labels (R1, X1) to match literature drawing conventions</li>
<li><strong>Rendering Augmentation:</strong> Randomized bond thickness, bond spacing, font size, font color, and padding size</li>
</ul>
<p><strong>Geometric Augmentation:</strong></p>
<ul>
<li>Shear along x-axis: $\pm 15^\circ$</li>
<li>Rotation: $\pm 15^\circ$</li>
<li>Piecewise affine scaling</li>
</ul>
<p><strong>Noise Injection:</strong></p>
<ul>
<li>Pepper noise: 0-2%</li>
<li>Salt noise: 0-40%</li>
<li>Gaussian noise: scale 0-0.16</li>
</ul>
<p><strong>Destructive Augmentation:</strong></p>
<ul>
<li>JPEG compression: severity levels 2-5</li>
<li>Random masking</li>
</ul>
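<p>The noise injection ranges listed above can be illustrated with a small salt-and-pepper routine. This is a minimal sketch that does not reproduce the augmentation code used by the authors: it assumes a grayscale image stored as nested lists, interprets the percentages as per-pixel probabilities, and samples one rate per image from each range.</p>

```python
import random


def salt_and_pepper(image, rng, max_pepper=0.02, max_salt=0.40):
    """Return a copy of image with random black (pepper) and white (salt) pixels."""
    pepper_rate = rng.uniform(0.0, max_pepper)
    salt_rate = rng.uniform(0.0, max_salt)
    keep_rate = 1.0 - pepper_rate - salt_rate
    noisy = []
    for row in image:
        noisy.append([
            rng.choices((0, 255, pixel), weights=(pepper_rate, salt_rate, keep_rate))[0]
            for pixel in row
        ])
    return noisy


rng = random.Random(0)
image = [[128] * 8 for _ in range(8)]  # uniform gray test image
noisy = salt_and_pepper(image, rng)
changed = sum(a != b for row_a, row_b in zip(image, noisy) for a, b in zip(row_a, row_b))
print(f"{changed} of 64 pixels changed")
```

<p>Drawing fresh rates for every image, rather than fixing them, is one way to obtain the kind of variety in degradation that the training strategy relies on.</p>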
<p><strong>Evaluation Datasets:</strong></p>
<ul>
<li><strong>CLEF</strong>: Chemical structure recognition challenge dataset</li>
<li><strong>UOB</strong>: University of Birmingham patent images</li>
<li><strong>JPO</strong>: Japan Patent Office molecular diagrams</li>
<li><strong>USPTO</strong>: US Patent and Trademark Office structures</li>
<li><strong>Color background images</strong>: 200 samples</li>
<li><strong>Low-quality images</strong>: 200 samples</li>
<li><strong>Hand-drawn structures</strong>: Test set for generalization</li>
<li><strong>End-to-end document extraction</strong>: 50 PDFs (567 pages, 2,336 molecular images)</li>
</ul>
<h3 id="training">Training</h3>
<p><strong>Image Recognition Model:</strong></p>
<ul>
<li><strong>Optimizer:</strong> Adam with learning rate of 1e-4</li>
<li><strong>Batch Size:</strong> 100</li>
<li><strong>Epochs:</strong> 5</li>
<li><strong>Loss Function:</strong> Cross-entropy loss for both SMILES prediction and coordinate prediction</li>
</ul>
<p><strong>Object Detection Model:</strong></p>
<ul>
<li><strong>Optimizer:</strong> Adam with learning rate of 1e-4</li>
<li><strong>Batch Size:</strong> 24</li>
<li><strong>Training Strategy:</strong> Pre-trained on synthetic &ldquo;Lower Quality&rdquo; data for 5 epochs, then fine-tuned on annotated real &ldquo;High Quality&rdquo; data for 30 epochs</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p><strong>Metrics:</strong></p>
<ul>
<li><strong>Recognition</strong>: SMILES accuracy (exact match)</li>
<li><strong>End-to-End Pipeline</strong>:
<ul>
<li><strong>Recall</strong>: 95.1% for detection</li>
<li><strong>Accuracy</strong>: 94.5% for recognition</li>
</ul>
</li>
</ul>
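<p>The two pipeline metrics follow directly from the reported counts. The short calculation below reproduces them and also derives the share of all molecular images that end up correctly recognized; that last figure is computed here from the counts and is not reported in these notes.</p>

```python
total_images = 2336   # molecular images in the 50 PDFs
detected = 2221       # images found by the detection model
correct = 2098        # detected images with a correct structure

recall = detected / total_images
accuracy_on_detected = correct / detected
end_to_end = correct / total_images  # derived: detection and recognition combined

print(f"detection recall:     {recall:.1%}")                # 95.1%
print(f"recognition accuracy: {accuracy_on_detected:.1%}")  # 94.5%
print(f"end-to-end yield:     {end_to_end:.1%}")            # 89.8%
```

<p>The recognition accuracy is conditional on detection, so the end-to-end yield is the product of the two rates and is lower than either one.</p>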
<h3 id="hardware">Hardware</h3>
<p><strong>Inference Hardware:</strong></p>
<ul>
<li>Cloud CPU server (8 CPUs, 64 GB RAM)</li>
<li><strong>Throughput:</strong> Processed 50 PDFs (567 pages) in 40 minutes</li>
</ul>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Xiong, J., Liu, X., Li, Z., Xiao, H., Wang, G., Niu, Z., Fei, C., Zhong, F., Wang, G., Zhang, W., Fu, Z., Liu, Z., Chen, K., Jiang, H., &amp; Zheng, M. (2023). αExtractor: a system for automatic extraction of chemical information from biomedical literature. <em>Science China Life Sciences</em>, 67(3), 618-621. <a href="https://doi.org/10.1007/s11427-023-2388-x">https://doi.org/10.1007/s11427-023-2388-x</a></p>
<p><strong>Publication</strong>: Science China Life Sciences (2023)</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://doi.org/10.1007/s11427-023-2388-x">Paper on Springer</a></li>
</ul>
]]></content:encoded></item><item><title>What is Optical Chemical Structure Recognition (OCSR)?</title><link>https://hunterheidenreich.com/posts/what-is-ocsr/</link><pubDate>Sat, 11 Oct 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/what-is-ocsr/</guid><description>A micro-review of Optical Chemical Structure Recognition (OCSR), covering rule-based systems to modern deep learning models.</description><content:encoded><![CDATA[<h2 id="introduction">Introduction</h2>
<p>Decades of chemical research, breakthroughs in medicine, and novel materials are archived in journals, patents, and textbooks.
A huge portion of this knowledge is stored as images, a format inaccessible to standard computational tools.
This hampers both data retrieval and the use of modern computational tools to analyze and predict chemical properties. The cost compounds across the literature: knowledge locked in image form is invisible to search, mining, and downstream model training.</p>
<p>This is the central challenge that <strong>Optical Chemical Structure Recognition (OCSR)</strong> aims to solve. At its heart, OCSR is to chemistry what OCR (Optical Character Recognition) is to text: a technology that teaches computers to extract chemical information directly from 2D diagrams of molecules. It&rsquo;s the bridge between a picture of a molecule and a machine-readable format like <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> (Simplified Molecular Input Line Entry System) that can be stored, searched, and used to power new discoveries.</p>















<figure class="post-figure center ">
    <img src="/img/ocsr/img2smiles.webp"
         alt="The transformation from a 2D chemical structure image to a SMILES representation."
         title="The transformation from a 2D chemical structure image to a SMILES representation."
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The transformation from a 2D chemical structure image to a SMILES representation.</figcaption>
    
</figure>

<p>Teaching a computer to read a chemical structure requires specialized techniques.</p>
<h2 id="the-complexity-of-chemical-graphs">The Complexity of Chemical Graphs</h2>
<p>Standard Optical Character Recognition (OCR) reads characters arranged in lines, but a molecule is a <em>graph</em>: a collection of atoms (nodes) connected by bonds (edges). Recognizing one means recovering that connectivity from the drawing as well as reading the symbols.</p>
<blockquote>
<p>(While this simplified view excludes complex structures like coordination compounds and polymers, it provides a highly effective starting point for this discussion.)</p></blockquote>
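<p>To make the graph view concrete, here is a minimal sketch of a molecule stored as atoms and bonds. The class and method names (<code>Atom</code>, <code>Bond</code>, <code>MoleculeGraph</code>, <code>neighbors</code>) are illustrative assumptions and do not come from any particular OCSR library.</p>

```python
from dataclasses import dataclass, field


@dataclass(frozen=True)
class Atom:
    index: int
    element: str


@dataclass(frozen=True)
class Bond:
    begin: int
    end: int
    order: int = 1  # 1 = single, 2 = double, 3 = triple


@dataclass
class MoleculeGraph:
    atoms: list = field(default_factory=list)
    bonds: list = field(default_factory=list)

    def add_atom(self, element):
        """Append an atom (node) and return its index."""
        atom = Atom(index=len(self.atoms), element=element)
        self.atoms.append(atom)
        return atom.index

    def add_bond(self, begin, end, order=1):
        """Append a bond (edge) between two atom indices."""
        self.bonds.append(Bond(begin, end, order))

    def neighbors(self, index):
        """Return the indices of atoms bonded to the given atom."""
        result = []
        for bond in self.bonds:
            if bond.begin == index:
                result.append(bond.end)
            elif bond.end == index:
                result.append(bond.begin)
        return result


# Ethanol (SMILES: CCO) with hydrogens left implicit.
ethanol = MoleculeGraph()
c1 = ethanol.add_atom("C")
c2 = ethanol.add_atom("C")
o = ethanol.add_atom("O")
ethanol.add_bond(c1, c2)
ethanol.add_bond(c2, o)

print(ethanol.neighbors(c2))  # [0, 2]
```

<p>An OCSR system has to recover exactly this kind of structure (which atoms exist and which pairs are bonded) from pixels alone.</p>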
<p>An OCSR system must overcome several hurdles:</p>
<ul>
<li><strong>Varying Styles:</strong> Chemical drawings vary widely across publications. Bond lengths, angles, and fonts can differ dramatically from one document to another.</li>
</ul>















<figure class="post-figure center ">
    <img src="/img/ocsr/acs.orglett.2c02187_1.webp"
         alt="An example from the Colored Background OCSR Benchmark, showing a complex and colorful chemical structure."
         title="An example from the Colored Background OCSR Benchmark, showing a complex and colorful chemical structure."
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">An example from the <a href="https://huggingface.co/datasets/hheiden/Colored_Background_OCSR_benchmark">Colored Background OCSR Benchmark</a>, showing a complex and colorful chemical structure.</figcaption>
    
</figure>

<ul>
<li><strong>Image Quality:</strong> Older documents might be scanned at low resolutions, containing noise, blur, or other artifacts that make interpretation difficult.</li>
</ul>















<figure class="post-figure center ">
    <img src="/img/ocsr/2008239616_449_chem.webp"
         alt="A challenging chemical structure image from the JPO benchmark, difficult due to its low quality."
         title="A challenging chemical structure image from the JPO benchmark, difficult due to its low quality."
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">A challenging chemical structure image from the <a href="https://huggingface.co/datasets/hheiden/JPO_OCSR_benchmark">JPO benchmark</a>, difficult due to its low quality.</figcaption>
    
</figure>

<ul>
<li><strong>Structural Complexity:</strong> From simple rings to sprawling polymers and complex <strong>Markush structures</strong> (common in patents to represent a whole family of related compounds), the variety is immense.</li>
</ul>















<figure class="post-figure center ">
    <img src="/img/ocsr/markush.webp"
         alt="An example of a Markush structure, illustrating the complexity and variety of chemical compounds."
         title="An example of a Markush structure, illustrating the complexity and variety of chemical compounds."
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">An example of a Markush structure, illustrating the complexity and variety of chemical compounds.</figcaption>
    
</figure>

<h2 id="the-evolution-of-ocsr">The Evolution of OCSR</h2>
<p>The quest to automate this process has evolved significantly, moving from brittle, hand-coded systems to sophisticated AI that can learn from data.</p>
<h3 id="act-1-the-rule-based-pioneers-ocr-10">Act 1: The Rule-Based Pioneers (OCR-1.0)</h3>
<p>The first OCSR systems, developed in the early 1990s, represent what we can now call the <strong>&ldquo;OCR-1.0&rdquo; era</strong>. Tools like <a href="https://pubs.acs.org/doi/10.1021/ci00008a018">Kekulé</a>, and later open-source solutions like <a href="/notes/chemistry/optical-structure-recognition/rule-based/osra/">OSRA</a> and <a href="https://github.com/ncats/molvec">MolVec</a>, operated like meticulous draftsmen. Their approach was methodical:</p>
<ol>
<li><strong>Vectorize the Image:</strong> Convert the pixel-based image into a collection of lines and shapes</li>
<li><strong>Identify Components:</strong> Use a set of hard-coded rules to classify these components. &ldquo;This thick line is a wedge bond.&rdquo; &ldquo;This group of pixels is the letter &lsquo;O&rsquo;.&rdquo;</li>
<li><strong>Reconstruct the Graph:</strong> Piece together the identified atoms and bonds into a coherent molecular graph</li>
</ol>
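<p>The third step can be illustrated with a toy sketch. It assumes the first two steps have already produced line segments and positioned atom labels, and it applies the skeletal-formula convention that an unlabeled line end is a carbon. The function name, input format, and snapping tolerance are all hypothetical; real tools such as OSRA or MolVec apply far more rules than this.</p>

```python
import math


def reconstruct_graph(segments, labels, tol=5.0):
    """Toy graph reconstruction from vectorized lines and recognized labels.

    segments: list of (x1, y1, x2, y2) bond lines in pixel coordinates.
    labels:   list of (element, x, y) atom labels found by character recognition.
    tol:      snapping distance in pixels (an arbitrary illustrative value).
    """
    nodes = [(text, x, y) for text, x, y in labels]

    def node_for(x, y):
        # Snap the endpoint to an existing node if one is close enough.
        for i, (_, nx, ny) in enumerate(nodes):
            if math.hypot(nx - x, ny - y) <= tol:
                return i
        # Otherwise the line end is an unlabeled vertex: an implicit carbon.
        nodes.append(("C", x, y))
        return len(nodes) - 1

    bonds = []
    for x1, y1, x2, y2 in segments:
        bonds.append((node_for(x1, y1), node_for(x2, y2)))
    return [text for text, _, _ in nodes], bonds


# A zigzag drawing of ethanol: two lines, with an "O" label near the last endpoint.
segments = [(0, 0, 20, 10), (20, 10, 40, 0)]
labels = [("O", 42, 0)]

atoms, bonds = reconstruct_graph(segments, labels)
print(atoms)  # ['O', 'C', 'C']
print(bonds)  # [(1, 2), (2, 0)]
```

<p>Even in this toy, the fixed tolerance hints at the brittleness of the approach: a drawing with different bond lengths or label spacing needs a different value.</p>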
<p>This rule-based approach was a real first step, but it was brittle. It struggled with the messiness of real-world documents and was expensive to maintain because each new drawing style or failure mode required new rules.</p>
<p>Additionally, these early systems were designed as interactive tools to assist human experts in digitizing chemical structures.
The working assumption was always that a human would review and correct the output.</p>
<p>As a concrete case-study, consider the (reproduced) results from <a href="https://arxiv.org/abs/2411.11098">MolParser</a>:</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Method</th>
          <th style="text-align: center">USPTO</th>
          <th style="text-align: center">UoB</th>
          <th style="text-align: center">CLEF</th>
          <th style="text-align: center">JPO</th>
          <th style="text-align: center">ColoredBG</th>
          <th style="text-align: center">USPTO-10K</th>
          <th style="text-align: center">WildMol-10K</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>Rule-based methods</strong></td>
          <td style="text-align: center"></td>
          <td style="text-align: center"></td>
          <td style="text-align: center"></td>
          <td style="text-align: center"></td>
          <td style="text-align: center"></td>
          <td style="text-align: center"></td>
          <td style="text-align: center"></td>
      </tr>
      <tr>
          <td style="text-align: left">OSRA 2.1 *</td>
          <td style="text-align: center">89.3</td>
          <td style="text-align: center">86.3</td>
          <td style="text-align: center"><strong>93.4</strong></td>
          <td style="text-align: center">56.3</td>
          <td style="text-align: center">5.5</td>
          <td style="text-align: center">89.7</td>
          <td style="text-align: center">26.3</td>
      </tr>
      <tr>
          <td style="text-align: left">MolVec 0.9.7 *</td>
          <td style="text-align: center">91.6</td>
          <td style="text-align: center">79.7</td>
          <td style="text-align: center">81.2</td>
          <td style="text-align: center">66.8</td>
          <td style="text-align: center">8.0</td>
          <td style="text-align: center">92.4</td>
          <td style="text-align: center">26.4</td>
      </tr>
      <tr>
          <td style="text-align: left">Imago 2.0 *</td>
          <td style="text-align: center">89.4</td>
          <td style="text-align: center">63.9</td>
          <td style="text-align: center">68.2</td>
          <td style="text-align: center">41.0</td>
          <td style="text-align: center">2.0</td>
          <td style="text-align: center">89.9</td>
          <td style="text-align: center">6.9</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Only synthetic training</strong></td>
          <td style="text-align: center"></td>
          <td style="text-align: center"></td>
          <td style="text-align: center"></td>
          <td style="text-align: center"></td>
          <td style="text-align: center"></td>
          <td style="text-align: center"></td>
          <td style="text-align: center"></td>
      </tr>
      <tr>
          <td style="text-align: left">Img2Mol *</td>
          <td style="text-align: center">30.0</td>
          <td style="text-align: center">68.1</td>
          <td style="text-align: center">17.9</td>
          <td style="text-align: center">16.1</td>
          <td style="text-align: center">3.5</td>
          <td style="text-align: center">33.7</td>
          <td style="text-align: center">24.4</td>
      </tr>
      <tr>
          <td style="text-align: left">MolGrapher †*</td>
          <td style="text-align: center">91.5</td>
          <td style="text-align: center"><strong>94.9</strong></td>
          <td style="text-align: center">90.5</td>
          <td style="text-align: center">67.5</td>
          <td style="text-align: center">7.5</td>
          <td style="text-align: center">93.3</td>
          <td style="text-align: center">45.5</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Real data finetuning</strong></td>
          <td style="text-align: center"></td>
          <td style="text-align: center"></td>
          <td style="text-align: center"></td>
          <td style="text-align: center"></td>
          <td style="text-align: center"></td>
          <td style="text-align: center"></td>
          <td style="text-align: center"></td>
      </tr>
      <tr>
          <td style="text-align: left">DECIMER 2.7 *</td>
          <td style="text-align: center">59.9</td>
          <td style="text-align: center">88.3</td>
          <td style="text-align: center">72.0</td>
          <td style="text-align: center">64.0</td>
          <td style="text-align: center">14.5</td>
          <td style="text-align: center">82.4</td>
          <td style="text-align: center">56.0</td>
      </tr>
      <tr>
          <td style="text-align: left">MolScribe *</td>
          <td style="text-align: center"><u>93.1</u></td>
          <td style="text-align: center">87.4</td>
          <td style="text-align: center">88.9</td>
          <td style="text-align: center">76.2</td>
          <td style="text-align: center">21.0</td>
          <td style="text-align: center"><strong>96.0</strong></td>
          <td style="text-align: center">66.4</td>
      </tr>
      <tr>
          <td style="text-align: left">MolParser-Tiny (Ours)</td>
          <td style="text-align: center">93.0</td>
          <td style="text-align: center">91.6</td>
          <td style="text-align: center"><u>91.0</u></td>
          <td style="text-align: center">75.6</td>
          <td style="text-align: center"><strong>58.5</strong></td>
          <td style="text-align: center">89.5</td>
          <td style="text-align: center">73.1</td>
      </tr>
      <tr>
          <td style="text-align: left">MolParser-Small (Ours)</td>
          <td style="text-align: center"><strong>93.1</strong></td>
          <td style="text-align: center">91.1</td>
          <td style="text-align: center">90.8</td>
          <td style="text-align: center">76.2</td>
          <td style="text-align: center">57.0</td>
          <td style="text-align: center"><u>94.8</u></td>
          <td style="text-align: center">76.3</td>
      </tr>
      <tr>
          <td style="text-align: left">MolParser-Base (Ours)</td>
          <td style="text-align: center">93.0</td>
          <td style="text-align: center"><u>91.8</u></td>
          <td style="text-align: center">90.7</td>
          <td style="text-align: center"><strong>78.9</strong></td>
          <td style="text-align: center">57.0</td>
          <td style="text-align: center">94.5</td>
          <td style="text-align: center"><strong>76.9</strong></td>
      </tr>
  </tbody>
</table>
<blockquote>
<p><strong>Table 2. Comparison of our method with existing OCSR models.</strong> We report the accuracy. We use <strong>bold</strong> to indicate the best performance and <u>underline</u> to denote the second-best performance. *: re-implemented results. †: results from original publications.</p></blockquote>
<p>In this table, the rule-based methods (OSRA, MolVec, Imago) are competitive on cleaner patent datasets such as USPTO and USPTO-10K, and OSRA even posts the best CLEF score (93.4). They collapse on harder inputs, however, scoring in the single digits on ColoredBG and below 27% on WildMol-10K. Learned methods gain the most on exactly these benchmarks, especially when fine-tuned on real data: MolParser-Base reaches 76.9% on WildMol-10K, and MolParser-Tiny reaches 58.5% on ColoredBG.</p>
<h3 id="act-2-the-ai-fork-in-the-road-2010s-2020s">Act 2: The AI Fork in the Road (2010s-2020s)</h3>
<p>The rise of deep learning in the 2010s brought new paradigms that could learn from data. Here, the field split into two distinct paths.</p>
<h4 id="path-a-the-rise-of-the-specialists-graph-based-ai">Path A: The Rise of the Specialists (Graph-Based AI)</h4>
<p>Some models replaced the hard-coded rules with AI components. Systems like <a href="https://github.com/DS4SD/MolGrapher">MolGrapher</a> and <a href="https://github.com/thomas0809/MolScribe">MolScribe</a> use a two-stage process:</p>
<ul>
<li><strong>Atom Detection:</strong> A neural network first identifies all the atoms in the image</li>
<li><strong>Bond Prediction:</strong> A second process then predicts the connections (bonds) between those atoms to form the final graph</li>
</ul>
<p>These are highly specialized tools, trained specifically for the task of building a molecular graph.</p>
<h4 id="path-b-the-rise-of-the-generalists-lvlms">Path B: The Rise of the Generalists (LVLMs)</h4>
<p>Another, more direct method treats OCSR as an image captioning task. This approach aligns with the broader trend of <strong>Large Vision-Language Models (LVLMs)</strong>: massive, general-purpose AIs like GPT-4V. Models like <a href="https://github.com/Kohulan/DECIMER-Image_Transformer">DECIMER</a> and <a href="/notes/chemistry/optical-structure-recognition/vision-language/mol-parser/">MolParser</a> look at a molecular image and directly generate its textual representation, most commonly a <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES string</a>. This direct, end-to-end approach is powerful, though it requires enormous datasets to train effectively.</p>
<h2 id="the-next-frontier-the-ocr-20-vision-2024">The Next Frontier: The OCR-2.0 Vision (2024+)</h2>
<p>Recently, a proposal has emerged that charts a third path forward: <strong>OCR-2.0</strong>. This vision, proposed by <a href="https://arxiv.org/abs/2409.01704">Wei et al.</a> in 2024, argues for a new class of models that combine the best of both worlds. An OCR-2.0 model should be:</p>
<ol>
<li><strong>End-to-End:</strong> A single, unified model that simplifies maintenance</li>
<li><strong>Efficient &amp; Low-Cost:</strong> A specialized, highly efficient perception engine. The paper argues that using a giant LVLM for a pure recognition task is often inefficient</li>
<li><strong>Versatile:</strong> Capable of handling diverse artificial optical signals</li>
</ol>
<p>The flagship model for this theory is <a href="https://huggingface.co/stepfun-ai/GOT-OCR2_0">GOT (General OCR Theory)</a>. It&rsquo;s a single, unified model that can read an image and output structured text for a wide variety of inputs. It can translate a molecular diagram into a SMILES string, transcribe sheet music into musical notation, parse a bar chart into a data table, and describe a geometric shape using code.</p>
<p>This demonstrates that OCSR can be integrated into broader systems for processing human visual information. The same OCR-2.0 philosophy extends beyond chemistry: <a href="/research/gutenocr-grounded-vision-language-frontend/">GutenOCR</a>, for instance, applies grounded vision-language modeling to general document OCR, producing both text transcriptions and bounding-box outputs from a single model.</p>
<h2 id="pushing-the-boundaries-of-recognition">Pushing the Boundaries of Recognition</h2>
<p>OCR-2.0 models like GOT push for <em>breadth</em>, while other state-of-the-art research pursues <em>depth</em> on the uniquely complex task of chemical recognition.</p>
<h3 id="deepening-reasoning-with-a-visual-chain-of-thought">Deepening Reasoning with a &ldquo;Visual Chain of Thought&rdquo;</h3>
<p>The <a href="https://arxiv.org/abs/2506.07553">GTR-Mol-VLM</a> model makes recognition more intelligent by mimicking how a person might analyze a complex diagram. The model traverses the molecule step-by-step, predicting an atom, then its bond, then the next atom, and so on. This &ldquo;Visual Chain of Thought&rdquo; improves accuracy, especially for complex molecules. It also faithfully recognizes abbreviations like &ldquo;Ph&rdquo; as single units, better representing the source image.</p>
<h3 id="deepening-application-with-visual-fingerprinting">Deepening Application with &ldquo;Visual Fingerprinting&rdquo;</h3>
<p><a href="https://link.springer.com/article/10.1186/s13321-025-01091-4">SubGrapher</a> rethinks the end goal. Many applications (like searching a patent database) require only the identification of specific molecular features. SubGrapher detects key functional groups and backbones directly from the image and creates a visual fingerprint. This approach mirrors identifying a person by key features (&ldquo;has glasses, a mustache&rdquo;), making it well-suited to finding matches in a large set.</p>
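<p>The retrieval idea can be sketched with a toy example. It assumes each image has already been reduced to a set of detected substructure names and ranks database entries by Tanimoto (Jaccard) overlap with the query. The feature names, the set-based fingerprint, and the function names are illustrative assumptions; the fingerprint described in the paper is constructed differently.</p>

```python
def tanimoto(a, b):
    """Tanimoto (Jaccard) similarity between two feature sets."""
    union = a | b
    if not union:
        return 0.0
    return len(a & b) / len(union)


def rank_matches(query, database, top_k=3):
    """Return the top_k (score, name) pairs, most similar first."""
    scored = [(tanimoto(query, feats), name) for name, feats in database.items()]
    # Sort by descending score, breaking ties alphabetically by name.
    scored.sort(key=lambda pair: (-pair[0], pair[1]))
    return scored[:top_k]


# Hypothetical detected features for three known molecules.
database = {
    "aspirin": {"benzene", "carboxylic_acid", "ester"},
    "paracetamol": {"benzene", "amide", "phenol"},
    "ethyl_acetate": {"ester"},
}

# Features detected in a query image.
query = {"benzene", "ester"}

print([name for _, name in rank_matches(query, database)])
# ['aspirin', 'ethyl_acetate', 'paracetamol']
```

<p>No full molecular graph is ever reconstructed here, which is the point: for search, a coarse but reliable description can be enough.</p>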
<h2 id="why-it-matters">Why It Matters</h2>
<p>The evolution of OCSR matters because it directly enables practical scientific advances.</p>
<h3 id="searching-past-knowledge">Searching Past Knowledge</h3>
<p>OCSR digitizes decades of research from patents and journals, making it searchable and accessible for data mining. Imagine being able to search through every molecule ever published with a simple query. Or consider the practical impact: pharmaceutical companies can now automatically scan thousands of patent documents to ensure their new drug candidates don&rsquo;t infringe existing intellectual property, a process that previously required substantial manual review by patent analysts.</p>
<h3 id="accelerating-drug-discovery">Accelerating Drug Discovery</h3>
<p>By extracting vast datasets of molecules, scientists can train AI models to predict drug efficacy and toxicity, speeding up the discovery pipeline. The more molecular data we can digitize, the better our predictive models become.</p>
<h3 id="building-universal-document-intelligence">Building Universal Document Intelligence</h3>
<p>OCSR contributes to building AI systems capable of processing complex human documents. A scientific paper is a mix of text, equations, charts, tables, and molecular diagrams. Unified OCR-2.0 models are the key to making all of this knowledge searchable holistically.</p>
<h2 id="looking-forward">Looking Forward</h2>
<p>The goal is a loop where scientific knowledge, regardless of how it is stored, can be fed back into systems that read, search, and reason over it.</p>
<p>From the rule-based systems of the 1990s to today&rsquo;s models that read many printed diagrams reliably (though hard cases like low-quality scans and Markush structures remain open), OCSR has improved a great deal. As accuracy, efficiency, and breadth improve, more of the chemical literature becomes machine-readable.</p>
<p>This entire process begins with teaching a computer how to read a picture.</p>
]]></content:encoded></item><item><title>MolParser-7M &amp; WildMol: Large-Scale OCSR Datasets</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/vision-language/molparser_7m-wildmol/</link><pubDate>Fri, 03 Oct 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/vision-language/molparser_7m-wildmol/</guid><description>MolParser-7M is the largest open-source OCSR dataset with 7.7M image-SMILES pairs including 400k real-world annotated samples.</description><content:encoded><![CDATA[<h2 id="dataset-examples">Dataset Examples</h2>















<figure class="post-figure center ">
    <img src="/img/molparser-markush-example.webp"
         alt="Example of a complex Markush structure"
         title="Example of a complex Markush structure"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">An example of a complex Markush structure that can be represented by the E-SMILES format but not by standard SMILES or FG-SMILES.</figcaption>
    
</figure>
















<figure class="post-figure center ">
    <img src="/img/molparser-low-quality-example.webp"
         alt="Sample from the WildMol benchmark"
         title="Sample from the WildMol benchmark"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">A sample from the WildMol benchmark, showing a low-quality, noisy molecular image cropped from real-world literature that challenges OCSR systems.</figcaption>
    
</figure>
















<figure class="post-figure center ">
    <img src="/img/molparser-colored-example.webp"
         alt="Colored molecule with annotations"
         title="Colored molecule with annotations"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">A colored molecule with annotations, representing the diverse drawing styles found in scientific papers that OCSR models must handle.</figcaption>
    
</figure>

<h2 id="dataset-subsets">Dataset Subsets</h2>
<table>
  <thead>
      <tr>
          <th>Subset</th>
          <th>Count</th>
          <th>Description</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>MolParser-7M (Training Set)</strong></td>
          <td>7,740,871</td>
          <td>A large-scale dataset for training OCSR models, split into pre-training and fine-tuning stages.</td>
      </tr>
      <tr>
          <td><strong>WildMol (Test Set)</strong></td>
          <td>20,000</td>
          <td>A benchmark of 20,000 human-annotated samples cropped from real PDF files to evaluate OCSR models in &lsquo;in-the-wild&rsquo; scenarios. Comprises WildMol-10k (10k ordinary molecules) and WildMol-10k-M (10k Markush structures).</td>
      </tr>
  </tbody>
</table>
<h2 id="benchmarks">Benchmarks</h2>

<div class="benchmarks-content">
  <div class="benchmark-section">
    <h3 id="wildmol-10k-accuracy">WildMol-10K Accuracy<a hidden class="anchor" aria-hidden="true" href="#wildmol-10k-accuracy">#</a></h3>
    <p class="benchmark-description">Evaluation of OCSR models on 10,000 real-world molecular images cropped from scientific literature and patents</p>
    <table class="benchmark-table">
      <thead>
        <tr>
          <th>Rank</th>
          <th>Model</th>
          <th>Accuracy (%)</th>
        </tr>
      </thead>
      <tbody>
        <tr class="top-result">
          <td>🥇 1</td>
          <td>
            <strong>MolParser-Base</strong><br><small>End-to-end visual recognition trained on MolParser-7M</small>
          </td>
          <td>76.9</td>
        </tr>
        <tr class="top-result">
          <td>🥈 2</td>
          <td>
            <strong>MolScribe</strong><br><small>Transformer-based OCSR system</small>
          </td>
          <td>66.4</td>
        </tr>
        <tr class="top-result">
          <td>🥉 3</td>
          <td>
            <strong>DECIMER 2.7</strong><br><small>Deep learning for chemical image recognition</small>
          </td>
          <td>56.0</td>
        </tr>
        <tr>
          <td>4</td>
          <td>
            <strong>MolGrapher</strong><br><small>Graph-based molecular structure recognition</small>
          </td>
          <td>45.5</td>
        </tr>
        <tr>
          <td>5</td>
          <td>
            <strong>MolVec 0.9.7</strong><br><small>Vector-based structure recognition</small>
          </td>
          <td>26.4</td>
        </tr>
        <tr>
          <td>6</td>
          <td>
            <strong>OSRA 2.1</strong><br><small>Optical Structure Recognition Application</small>
          </td>
          <td>26.3</td>
        </tr>
        <tr>
          <td>7</td>
          <td>
            <strong>Img2Mol</strong><br><small>Image-to-molecule translation</small>
          </td>
          <td>24.4</td>
        </tr>
        <tr>
          <td>8</td>
          <td>
            <strong>Imago 2.0</strong><br><small>Chemical structure recognition toolkit</small>
          </td>
          <td>6.9</td>
        </tr>
      </tbody>
    </table>
  </div>
</div>

<h2 id="key-contribution">Key Contribution</h2>
<p>Introduces MolParser-7M, the largest open-source Optical Chemical Structure Recognition (OCSR) dataset, uniquely combining diverse synthetic data with a large volume of manually-annotated, &ldquo;in-the-wild&rdquo; images from real scientific documents to improve model robustness. Also introduces WildMol, a new challenging benchmark for evaluating OCSR performance on real-world data, including Markush structures.</p>
<h2 id="overview">Overview</h2>
<p>The MolParser project addresses the challenge of recognizing molecular structures from images found in real-world scientific documents. Unlike existing OCSR datasets that rely primarily on synthetically generated images, MolParser-7M incorporates 400,000 manually annotated images cropped from actual patents and scientific papers, making it the first large-scale dataset to bridge the gap between synthetic training data and real-world deployment scenarios.</p>
<h2 id="strengths">Strengths</h2>
<ul>
<li>Largest open-source OCSR dataset with over 7.7 million pairs</li>
<li>The only large-scale OCSR training set that includes a significant amount (400k) of &ldquo;in-the-wild&rdquo; data cropped from real patents and literature</li>
<li>High diversity of molecular structures from numerous sources (PubChem, ChEMBL, polymers, etc.)</li>
<li>Introduces the WildMol benchmark for evaluating performance on challenging, real-world data, including Markush structures</li>
<li>The &ldquo;in-the-wild&rdquo; fine-tuning data (MolParser-SFT-400k) was curated via an efficient active learning data engine with human-in-the-loop validation</li>
</ul>
<h2 id="limitations">Limitations</h2>
<ul>
<li>The E-SMILES format cannot represent certain complex cases, such as coordination bonds, dashed abstract rings, Markush structures depicted with special patterns, and replication of long structural segments on the skeleton</li>
<li>The model and data do not yet fully exploit molecular chirality, which is critical for chemical properties</li>
<li>Performance could be further improved by scaling up the amount of real annotated training data</li>
</ul>
<h2 id="technical-notes">Technical Notes</h2>
<h3 id="synthetic-data-generation">Synthetic Data Generation</h3>
<p>To ensure diversity, molecular structures were collected from databases like ChEMBL, PubChem, and Kaggle BMS. A significant number of Markush, polymer, and fused-ring structures were also randomly generated. Images were rendered using RDKit and epam.indigo with randomized parameters (e.g., bond width, font size, rotation) to increase visual diversity. The pretraining dataset is composed of the following subsets:</p>
<table>
  <thead>
      <tr>
          <th>Subset</th>
          <th>Ratio</th>
          <th>Source</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Markush-3M</td>
          <td>40%</td>
          <td>Random groups replacement from PubChem</td>
      </tr>
      <tr>
          <td>ChEMBL-2M</td>
          <td>27%</td>
          <td>Molecules selected from ChEMBL</td>
      </tr>
      <tr>
          <td>Polymer-1M</td>
          <td>14%</td>
          <td>Randomly generated polymer molecules</td>
      </tr>
      <tr>
          <td>PAH-600k</td>
          <td>8%</td>
          <td>Randomly generated fused-ring molecules</td>
      </tr>
      <tr>
          <td>BMS-360k</td>
          <td>5%</td>
          <td>Molecules with long carbon chains from BMS</td>
      </tr>
      <tr>
          <td>MolGrapher-300K</td>
          <td>4%</td>
          <td>Training data from MolGrapher</td>
      </tr>
      <tr>
          <td>Pauling-100k</td>
          <td>2%</td>
          <td>Pauling-style images drawn using epam.indigo</td>
      </tr>
  </tbody>
</table>
<h3 id="in-the-wild-data-engine-molparser-sft-400k">In-the-Wild Data Engine (MolParser-SFT-400k)</h3>
<p>A YOLO11 object detection model (MolDet) located and cropped over 20 million molecule images from 1.22 million real PDFs (patents and papers). After de-duplication via p-hash similarity, 4 million unique images remained.</p>
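<p>The de-duplication step can be sketched as follows, assuming each cropped image has already been reduced to a 64-bit perceptual hash stored as an integer. The greedy strategy, the function names, and the Hamming-distance threshold are assumptions for illustration; the source states only that p-hash similarity was used.</p>

```python
def hamming(a, b):
    """Number of differing bits between two integer hashes."""
    return bin(a ^ b).count("1")


def deduplicate(hashes, max_distance=4):
    """Greedily keep images whose hash is far from every hash kept so far.

    hashes:       dict mapping image id to an integer perceptual hash.
    max_distance: hashes within this many bits count as duplicates
                  (a hypothetical threshold).
    """
    kept = {}
    for image_id, h in hashes.items():
        if all(hamming(h, other) > max_distance for other in kept.values()):
            kept[image_id] = h
    return list(kept)


hashes = {
    "a": 0b11110000,
    "b": 0b11110001,  # one bit away from "a": treated as a near-duplicate
    "c": 0b00001111,  # eight bits away from "a": kept
}

print(deduplicate(hashes))  # ['a', 'c']
```

<p>This pairwise loop is quadratic in the number of images, so at the scale of 20 million crops a real pipeline would need an index or bucketing scheme in place of the exhaustive comparison.</p>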
<p>An active learning algorithm was used to select the most informative samples for annotation, targeting images where an ensemble of 5-fold models showed moderate confidence (0.6-0.9 Tanimoto similarity), indicating they were challenging but learnable.</p>
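<p>A minimal sketch of the selection rule, assuming the pairwise Tanimoto similarities between the ensemble's predictions have already been computed for each image. Averaging the pairwise scores into one confidence value, along with the function and variable names, is an assumption; only the 0.6-0.9 band comes from the source.</p>

```python
from statistics import mean


def select_for_annotation(pairwise_scores, low=0.6, high=0.9):
    """Pick images whose ensemble agreement falls in the moderate band.

    pairwise_scores: dict mapping image id to a list of Tanimoto
                     similarities between pairs of ensemble predictions.
    """
    selected = []
    for image_id, scores in pairwise_scores.items():
        confidence = mean(scores)
        # Above the band the models already agree; below it the sample
        # is likely too hard (or too noisy) to be worth annotating yet.
        if low <= confidence <= high:
            selected.append(image_id)
    return sorted(selected)


pairwise_scores = {
    "img_easy": [1.0, 1.0, 0.98],
    "img_medium": [0.7, 0.8, 0.75],
    "img_hard": [0.2, 0.3, 0.1],
}

print(select_for_annotation(pairwise_scores))  # ['img_medium']
```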
<p>This active learning approach with model pre-annotations reduced manual annotation time per molecule to 30 seconds, a saving of approximately 90% compared to annotating from scratch. In the final fine-tuning dataset, 56.04% of annotations directly utilized raw model pre-annotations, 20.97% passed review after a single manual correction, 13.87% were accepted after a second round of annotation, and 9.13% required three or more rounds.</p>
<p>The fine-tuning dataset is composed of:</p>
<table>
  <thead>
      <tr>
          <th>Subset</th>
          <th>Ratio</th>
          <th>Source</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MolParser-SFT-400k</td>
          <td>66%</td>
          <td>Manually annotated data obtained via data engine</td>
      </tr>
      <tr>
          <td>MolParser-Gen-200k</td>
          <td>32%</td>
          <td>Synthetic data selected from pretraining stage</td>
      </tr>
      <tr>
          <td>Handwrite-5k</td>
          <td>1%</td>
          <td>Handwritten molecules selected from Img2Mol</td>
      </tr>
  </tbody>
</table>
<h3 id="e-smiles-specification">E-SMILES Specification</h3>
<p>To accommodate complex patent structures that standard SMILES cannot support, the authors introduced an Extended SMILES format (<code>SMILES&lt;sep&gt;EXTENSION</code>). The <code>EXTENSION</code> component uses XML-like tokens to manage complexities:</p>
<ul>
<li><code>&lt;a&gt;...&lt;/a&gt;</code> encapsulates Markush R-groups and abbreviation groups.</li>
<li><code>&lt;r&gt;...&lt;/r&gt;</code> denotes ring attachments with uncertain positions.</li>
<li><code>&lt;c&gt;...&lt;/c&gt;</code> defines abstract rings.</li>
<li><code>&lt;dum&gt;</code> identifies a connection point.</li>
</ul>
<p>This format enables Markush-molecule matching and LLM integration, while retaining RDKit compatibility for the standard SMILES portion.</p>
<h2 id="reproducibility">Reproducibility</h2>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://huggingface.co/datasets/UniParser/MolParser-7M">MolParser-7M</a></td>
          <td>Dataset</td>
          <td>CC-BY-NC-SA-4.0</td>
          <td>Training and test data on HuggingFace. SFT subset is partially released.</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/UniParser/MolDet">MolDet (YOLO11)</a></td>
          <td>Model</td>
          <td>Unknown</td>
          <td>Molecule detection model on HuggingFace</td>
      </tr>
      <tr>
          <td><a href="https://ocsr.dp.tech/">MolParser Demo</a></td>
          <td>Other</td>
          <td>N/A</td>
          <td>Online OCSR demo using MolParser-Base</td>
      </tr>
  </tbody>
</table>
<p>The dataset is publicly available on HuggingFace under a CC-BY-NC-SA-4.0 (non-commercial) license. The MolParser-SFT-400k subset is only partially released. The YOLO11-based MolDet detection model is also available on HuggingFace. No public code repository is provided for the MolParser recognition model itself. All experiments were conducted on 8 NVIDIA RTX 4090D GPUs, and throughput benchmarks were measured on a single RTX 4090D GPU.</p>
]]></content:encoded></item><item><title>SELFIES: A Robust Molecular String Representation</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/notations/selfies/</link><pubDate>Fri, 12 Sep 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/notations/selfies/</guid><description>SELFIES is a robust molecular string representation for ML where every string decodes to a valid molecule, implemented in the selfies Python library.</description><content:encoded><![CDATA[<h2 id="overview">Overview</h2>
<p><strong>SELFIES (SELF-referencIng Embedded Strings)</strong> is a string-based molecular representation where every possible string, even one generated randomly, corresponds to a syntactically and semantically valid molecule. This property addresses a major limitation of <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a>, where a large fraction of strings produced by machine learning models represent invalid chemical structures.</p>
<p>The format is implemented in an open-source Python library called <code>selfies</code>. Since the <a href="/notes/chemistry/molecular-representations/notations/selfies-original-paper/">original publication</a>, the library has undergone significant architectural changes, most notably replacing the original string-manipulation engine with a graph-based internal representation that improved both performance and extensibility (see <a href="#recent-developments">Recent Developments</a>).</p>
<h3 id="key-characteristics">Key Characteristics</h3>
<ul>
<li><strong>Guaranteed Validity</strong>: Every possible SELFIES string can be decoded into a valid molecular graph that obeys chemical valence rules. This is its fundamental advantage over SMILES.</li>
<li><strong>Machine Learning Friendly</strong>: Can be used directly in any machine learning model (like VAEs or GANs) without adaptation, guaranteeing that all generated outputs are valid molecules.</li>
<li><strong>Customizable Constraints</strong>: The underlying chemical rules, such as maximum valence for different atoms, can be customized by the user. The library provides presets (e.g., for hypervalent species) and allows users to define their own rule sets.</li>
<li><strong>Human-readable</strong>: With some familiarity, SELFIES strings are human-readable, allowing interpretation of functional groups and connectivity.</li>
<li><strong>Local Operations</strong>: SELFIES encodes branch length and ring size as adjacent symbols in the string (rather than requiring matched delimiters or repeated digits at distant positions, as SMILES does), preventing common syntactical errors like unmatched parentheses or mismatched ring-closure digits.</li>
<li><strong>Broad Support</strong>: The current <code>selfies</code> library supports aromatic molecules (via kekulization), isotopes, charges, radicals, and stereochemistry. It also includes a dot symbol (<code>.</code>) for representing disconnected molecular fragments.</li>
</ul>
<h2 id="basic-syntax">Basic Syntax</h2>
<p>SELFIES uses symbols enclosed in square brackets (e.g., <code>[C]</code>, <code>[O]</code>, <code>[#N]</code>). The interpretation of each symbol depends on the current <strong>state of the derivation</strong> (described below), which ensures chemical valence rules are strictly obeyed. The syntax is formally defined by a Chomsky type-2 context-free grammar.</p>
<h3 id="derivation-rules">Derivation Rules</h3>
<p>SELFIES are constructed using a table of derivation rules. The process starts in an initial state (e.g., $X_0$) and reads the SELFIES string symbol by symbol. Each symbol, combined with the current state, determines the resulting atom/bond and the next state. The derivation state $X_n$ intuitively tracks that the previously added atom can form a maximum of $n$ additional bonds.</p>
<p>For example, the string <code>[F][=C][=C][#N]</code> is derived as follows, where $X_n$ indicates that the most recently added atom can form up to $n$ additional bonds. Note the bond demotion: the first <code>[=C]</code> requests a double bond, but state $X_1$ permits only a single bond. Likewise, the final <code>[#N]</code> requests a triple bond, but state $X_2$ caps it at a double bond.</p>
<p>$$
\begin{aligned}
\text{State } X_0 + \text{[F]} &amp;\rightarrow \text{F} + \text{State } X_1 \\
\text{State } X_1 + \text{[=C]} &amp;\rightarrow \text{F-C} + \text{State } X_3 \\
\text{State } X_3 + \text{[=C]} &amp;\rightarrow \text{F-C=C} + \text{State } X_2 \\
\text{State } X_2 + [\#\text{N}] &amp;\rightarrow \text{F-C=C=N} + \text{Final}
\end{aligned}
$$</p>
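<p>The derivation above can be sketched as a small state machine. This is a minimal illustration and not the <code>selfies</code> library's implementation: the valence table, the function name <code>derive_chain</code>, and the restriction to linear chains of neutral atoms are assumptions made for this example. The state is the number of bonds the previous atom can still form, and each requested bond order is capped by that state.</p>

```python
import re

# Toy valence table (assumption: neutral atoms, no branches or rings).
MAX_VALENCE = {"F": 1, "C": 4, "N": 3, "O": 2}
BOND_ORDER = {"": 1, "=": 2, "#": 3}
BOND_CHAR = {1: "-", 2: "=", 3: "#"}


def derive_chain(selfies_str):
    """Derive a linear chain and return (structure, list of states)."""
    symbols = re.findall(r"\[([=#]?)([A-Za-z]+)\]", selfies_str)
    parts = []
    states = []
    state = None  # None plays the role of the initial state X_0
    for prefix, element in symbols:
        valence = MAX_VALENCE[element]
        if state is None:
            bond = 0  # the first atom has nothing to bond to
        else:
            if state == 0:
                break  # previous atom is saturated; derivation stops
            # Bond demotion: cap the requested order by what both atoms allow.
            bond = min(BOND_ORDER[prefix], state, valence)
            parts.append(BOND_CHAR[bond])
        parts.append(element)
        state = valence - bond  # bonds the new atom can still form
        states.append(state)
    return "".join(parts), states


print(derive_chain("[F][=C][=C][#N]"))
# -> ('F-C=C=N', [1, 3, 2, 1])
```

<p>The returned states <code>[1, 3, 2, 1]</code> match $X_1$, $X_3$, $X_2$ in the table above. The final value of 1 is the one bond the nitrogen could still form, which a real decoder would fill with an implicit hydrogen.</p>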
<h3 id="structural-features">Structural Features</h3>
<ul>
<li><strong>Branches</strong>: Represented by a <code>[Branch]</code> symbol. The symbols immediately following it are interpreted as an index that specifies the number of SELFIES symbols belonging to that branch. This structure prevents errors like unmatched parentheses in SMILES.</li>
<li><strong>Rings</strong>: Represented by a <code>[Ring]</code> symbol. Similar to branches, subsequent symbols specify an index that indicates which previous atom to connect to, forming a ring closure. To avoid violating valence constraints, ring bond creation is postponed to a final post-processing step, where it is only completed if the target atom has available bonds.</li>
</ul>
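<p>The index that follows a branch or ring symbol can be read with a few lines of code. The sketch below assumes the 16-symbol index alphabet as described in the <code>selfies</code> 2.x documentation, read as a base-16 number. Treat the table as an assumption and check it against the installed library version. The helper name <code>read_index</code> is hypothetical.</p>

```python
# Index alphabet as documented for selfies 2.x (assumption: verify against
# your installed version). Position in the list is the symbol's index value.
INDEX_ALPHABET = [
    "[C]", "[Ring1]", "[Ring2]",
    "[Branch1]", "[=Branch1]", "[#Branch1]",
    "[Branch2]", "[=Branch2]", "[#Branch2]",
    "[O]", "[N]", "[=N]", "[=C]", "[#C]", "[S]", "[P]",
]
INDEX_CODE = {symbol: value for value, symbol in enumerate(INDEX_ALPHABET)}


def read_index(index_symbols):
    """Interpret a run of SELFIES symbols as a base-16 number."""
    value = 0
    for symbol in index_symbols:
        # Symbols outside the index alphabet are treated as 0.
        value = value * 16 + INDEX_CODE.get(symbol, 0)
    return value


# Benzene ends in [Ring1][=Branch1]: one index symbol follows [Ring1].
print(read_index(["[=Branch1]"]))  # -> 4

# A branch symbol followed by [C], as in the aspirin string.
print(read_index(["[C]"]))  # -> 0
```

<p>Under this reading, an index of $Q$ stands for $Q + 1$: the benzene ring closure reaches back 5 atoms to close a six-membered ring, and the aspirin branch contains a single symbol.</p>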
<h2 id="examples">Examples</h2>
<p>To see how these derivation rules work in practice, here are SELFIES representations for common molecules of increasing complexity:</p>
<figure class="post-figure center ">
    <img src="/img/selfies/ethanol.webp"
         alt="Ethanol molecule from SELFIES"
         title="Ethanol molecule from SELFIES"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Ethanol: <code>[C][C][O]</code></figcaption>
    
</figure>
<figure class="post-figure center ">
    <img src="/img/selfies/benzene.webp"
         alt="Benzene molecule from SELFIES"
         title="Benzene molecule from SELFIES"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Benzene: <code>[C][=C][C][=C][C][=C][Ring1][=Branch1]</code></figcaption>
    
</figure>
<figure class="post-figure center ">
    <img src="/img/selfies/aspirin.webp"
         alt="Aspirin molecule from SELFIES"
         title="Aspirin molecule from SELFIES"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Aspirin: <code>[C][C][=Branch1][C][=O][O][C][=C][C][=C][C][=C][Ring1][=Branch1][C][=Branch1][C][=O][O]</code></figcaption>
    
</figure>

<h2 id="the-selfies-python-library">The <code>selfies</code> Python Library</h2>
<p>The <code>selfies</code> library provides a dependency-free Python implementation. Here are the core operations:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> selfies <span style="color:#66d9ef">as</span> sf
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># SMILES -&gt; SELFIES</span>
</span></span><span style="display:flex;"><span>smiles <span style="color:#f92672">=</span> <span style="color:#e6db74">&#34;c1ccc(C(=O)O)cc1&#34;</span>  <span style="color:#75715e"># benzoic acid</span>
</span></span><span style="display:flex;"><span>encoded <span style="color:#f92672">=</span> sf<span style="color:#f92672">.</span>encoder(smiles)
</span></span><span style="display:flex;"><span>print(encoded)
</span></span><span style="display:flex;"><span><span style="color:#75715e"># -&gt; a bracketed SELFIES string (exact symbols depend on the selfies version)</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># SELFIES -&gt; SMILES</span>
</span></span><span style="display:flex;"><span>decoded <span style="color:#f92672">=</span> sf<span style="color:#f92672">.</span>decoder(encoded)
</span></span><span style="display:flex;"><span>print(decoded)
</span></span><span style="display:flex;"><span><span style="color:#75715e"># -&gt; a kekulized SMILES for benzoic acid (atom order may differ from the input)</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Robustness: random strings always decode to valid molecules</span>
</span></span><span style="display:flex;"><span>random_selfies <span style="color:#f92672">=</span> <span style="color:#e6db74">&#34;[C][F][Ring1][O][=N][Branch1][C][S]&#34;</span>
</span></span><span style="display:flex;"><span>print(sf<span style="color:#f92672">.</span>decoder(random_selfies))
</span></span><span style="display:flex;"><span><span style="color:#75715e"># -&gt; always returns a valid molecule</span>
</span></span></code></pre></div><h3 id="tokenization-and-encoding">Tokenization and Encoding</h3>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> selfies <span style="color:#66d9ef">as</span> sf
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>selfies_str <span style="color:#f92672">=</span> <span style="color:#e6db74">&#34;[C][=C][C][=C][C][Branch1][C][=O][O][=C][Ring1][=Branch1]&#34;</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Tokenize into individual symbols</span>
</span></span><span style="display:flex;"><span>tokens <span style="color:#f92672">=</span> list(sf<span style="color:#f92672">.</span>split_selfies(selfies_str))
</span></span><span style="display:flex;"><span>print(tokens)
</span></span><span style="display:flex;"><span><span style="color:#75715e"># -&gt; [&#39;[C]&#39;, &#39;[=C]&#39;, &#39;[C]&#39;, &#39;[=C]&#39;, &#39;[C]&#39;, &#39;[Branch1]&#39;, &#39;[C]&#39;,</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e">#     &#39;[=O]&#39;, &#39;[O]&#39;, &#39;[=C]&#39;, &#39;[Ring1]&#39;, &#39;[=Branch1]&#39;]</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Get the alphabet (unique token set) from a dataset</span>
</span></span><span style="display:flex;"><span>dataset <span style="color:#f92672">=</span> [<span style="color:#e6db74">&#34;[C][C][O]&#34;</span>, <span style="color:#e6db74">&#34;[C][=C][C][=C][C][=C][Ring1][=Branch1]&#34;</span>]
</span></span><span style="display:flex;"><span>alphabet <span style="color:#f92672">=</span> sf<span style="color:#f92672">.</span>get_alphabet_from_selfies(dataset)
</span></span><span style="display:flex;"><span>print(sorted(alphabet))
</span></span><span style="display:flex;"><span><span style="color:#75715e"># -&gt; [&#39;[=Branch1]&#39;, &#39;[=C]&#39;, &#39;[C]&#39;, &#39;[O]&#39;, &#39;[Ring1]&#39;]</span>
</span></span><span style="display:flex;"><span>
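</span></span><span style="display:flex;"><span><span style="color:#75715e"># Convert to integer encoding for ML pipelines.</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># The vocabulary must cover every symbol in the string plus the [nop] padding symbol.</span>
</span></span><span style="display:flex;"><span>vocab <span style="color:#f92672">=</span> sf<span style="color:#f92672">.</span>get_alphabet_from_selfies([selfies_str])
</span></span><span style="display:flex;"><span>vocab<span style="color:#f92672">.</span>add(<span style="color:#e6db74">&#34;[nop]&#34;</span>)
</span></span><span style="display:flex;"><span>encoding <span style="color:#f92672">=</span> sf<span style="color:#f92672">.</span>selfies_to_encoding(
</span></span><span style="display:flex;"><span>    selfies<span style="color:#f92672">=</span>selfies_str,
</span></span><span style="display:flex;"><span>    vocab_stoi<span style="color:#f92672">=</span>{s: i <span style="color:#66d9ef">for</span> i, s <span style="color:#f92672">in</span> enumerate(sorted(vocab))},
</span></span><span style="display:flex;"><span>    pad_to_len<span style="color:#f92672">=</span><span style="color:#ae81ff">20</span>,
</span></span><span style="display:flex;"><span>    enc_type<span style="color:#f92672">=</span><span style="color:#e6db74">&#34;label&#34;</span>,  <span style="color:#75715e"># returns a single list of integer labels</span>
</span></span><span style="display:flex;"><span>)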
</span></span></code></pre></div><h3 id="customizing-valence-constraints">Customizing Valence Constraints</h3>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> selfies <span style="color:#66d9ef">as</span> sf
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># View current constraints</span>
</span></span><span style="display:flex;"><span>print(sf<span style="color:#f92672">.</span>get_semantic_constraints())
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Switch to the built-in hypervalent preset</span>
</span></span><span style="display:flex;"><span>sf<span style="color:#f92672">.</span>set_semantic_constraints(<span style="color:#e6db74">&#34;hypervalent&#34;</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Or define custom constraints. Start from a preset so the</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># required &#34;?&#34; fallback entry is present.</span>
</span></span><span style="display:flex;"><span>constraints <span style="color:#f92672">=</span> sf<span style="color:#f92672">.</span>get_preset_constraints(<span style="color:#e6db74">&#34;default&#34;</span>)
</span></span><span style="display:flex;"><span>constraints<span style="color:#f92672">.</span>update({
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;S&#34;</span>: <span style="color:#ae81ff">6</span>,  <span style="color:#75715e"># allow hexavalent sulfur</span>
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;P&#34;</span>: <span style="color:#ae81ff">5</span>,  <span style="color:#75715e"># allow pentavalent phosphorus</span>
</span></span><span style="display:flex;"><span>})
</span></span><span style="display:flex;"><span>sf<span style="color:#f92672">.</span>set_semantic_constraints(constraints)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Reset to defaults</span>
</span></span><span style="display:flex;"><span>sf<span style="color:#f92672">.</span>set_semantic_constraints(<span style="color:#e6db74">&#34;default&#34;</span>)
</span></span></code></pre></div><h2 id="selfies-in-machine-learning">SELFIES in Machine Learning</h2>
<h3 id="molecular-generation">Molecular Generation</h3>
<p>SELFIES is particularly advantageous for generative models in computational chemistry. When used in a VAE, the entire continuous latent space decodes to valid molecules, unlike SMILES where large regions of the latent space are invalid. The <a href="/notes/chemistry/molecular-representations/notations/selfies-original-paper/">original SELFIES paper</a> demonstrated this concretely: a VAE trained with SELFIES stored two orders of magnitude more diverse molecules than a SMILES-based VAE, and a GAN produced 78.9% diverse valid molecules compared to 18.6% for SMILES (Krenn et al., 2020).</p>
<p>Several generation approaches build directly on SELFIES:</p>
<ul>
<li><strong>Latent space optimization</strong>: <a href="/notes/chemistry/molecular-design/generation/latent-space/limo-latent-inceptionism/">LIMO</a> uses a SELFIES-based VAE with gradient-based optimization to generate molecules with nanomolar binding affinities, achieving 6-8x speedup over RL baselines (Eckmann et al., 2022).</li>
<li><strong>Training-free generation</strong>: <a href="/notes/chemistry/molecular-design/generation/search-based/stoned-selfies-chemical-space-exploration/">STONED</a> demonstrates that simple character-level mutations in SELFIES (replacement, deletion, insertion) produce valid molecules by construction, eliminating the need for neural networks entirely. STONED achieved a GuacaMol score of 14.70, competitive with deep generative models (Nigam et al., 2021).</li>
<li><strong>Gradient-based dreaming</strong>: <a href="/notes/chemistry/molecular-design/generation/latent-space/deep-molecular-dreaming-pasithea/">PASITHEA</a> computes gradients with respect to one-hot encoded SELFIES inputs to steer molecules toward target property values. Because SELFIES&rsquo; surjective mapping guarantees every intermediate representation is a valid molecule, this continuous optimization over the input space is feasible. PASITHEA generated molecules with properties outside the training data range (logP up to 4.24 vs. a training max of 3.08), with 97.2% novelty (Shen et al., 2021).</li>
<li><strong>Large-scale pre-training</strong>: <a href="/notes/chemistry/molecular-design/generation/autoregressive/molgen-molecular-generation-chemical-feedback/">MolGen</a> is a BART-based model pre-trained on 100M+ SELFIES molecules. It achieves 100% validity and an FCD of 0.0015 on MOSES (vs. 0.0061 for Chemformer), and introduces chemical feedback to align outputs with preference rankings (Fang et al., 2024).</li>
</ul>
<p>In benchmarks, SELFIES performs well for optimization-oriented tasks. In the <a href="/notes/chemistry/molecular-design/generation/evaluation/pmo-sample-efficient-molecular-optimization/">PMO benchmark</a> of 25 methods, SELFIES-REINVENT ranked 3rd and STONED ranked 5th. SELFIES-based genetic algorithms outperformed SMILES-based GAs, likely because SELFIES provides more intuitive mutation operations (Gao et al., 2022). The <a href="/notes/chemistry/molecular-design/generation/evaluation/tartarus-inverse-molecular-design/">Tartarus benchmark</a> corroborates this across more diverse real-world objectives (organic emitters, protein ligands, reaction substrates): SELFIES-VAE consistently outperforms SMILES-VAE, and the representation matters most where validity is a bottleneck (Nigam et al., 2022).</p>
<p>SELFIES mutations provide a simple but effective way to explore chemical space:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> selfies <span style="color:#66d9ef">as</span> sf
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> random
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">mutate_selfies</span>(selfies_str, mutation_type<span style="color:#f92672">=</span><span style="color:#e6db74">&#34;replace&#34;</span>):
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;Mutate a SELFIES string. Every output is a valid molecule.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>    tokens <span style="color:#f92672">=</span> list(sf<span style="color:#f92672">.</span>split_selfies(selfies_str))
</span></span><span style="display:flex;"><span>    alphabet <span style="color:#f92672">=</span> list(sf<span style="color:#f92672">.</span>get_semantic_robust_alphabet())
</span></span><span style="display:flex;"><span>    idx <span style="color:#f92672">=</span> random<span style="color:#f92672">.</span>randint(<span style="color:#ae81ff">0</span>, len(tokens) <span style="color:#f92672">-</span> <span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">if</span> mutation_type <span style="color:#f92672">==</span> <span style="color:#e6db74">&#34;replace&#34;</span>:
</span></span><span style="display:flex;"><span>        tokens[idx] <span style="color:#f92672">=</span> random<span style="color:#f92672">.</span>choice(alphabet)
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">elif</span> mutation_type <span style="color:#f92672">==</span> <span style="color:#e6db74">&#34;insert&#34;</span>:
</span></span><span style="display:flex;"><span>        tokens<span style="color:#f92672">.</span>insert(idx, random<span style="color:#f92672">.</span>choice(alphabet))
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">elif</span> mutation_type <span style="color:#f92672">==</span> <span style="color:#e6db74">&#34;delete&#34;</span> <span style="color:#f92672">and</span> len(tokens) <span style="color:#f92672">&gt;</span> <span style="color:#ae81ff">1</span>:
</span></span><span style="display:flex;"><span>        tokens<span style="color:#f92672">.</span>pop(idx)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> <span style="color:#e6db74">&#34;&#34;</span><span style="color:#f92672">.</span>join(tokens)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Every mutation produces a valid molecule</span>
</span></span><span style="display:flex;"><span>original <span style="color:#f92672">=</span> sf<span style="color:#f92672">.</span>encoder(<span style="color:#e6db74">&#34;c1ccccc1&#34;</span>)  <span style="color:#75715e"># benzene</span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">for</span> _ <span style="color:#f92672">in</span> range(<span style="color:#ae81ff">5</span>):
</span></span><span style="display:flex;"><span>    mutant <span style="color:#f92672">=</span> mutate_selfies(original)
</span></span><span style="display:flex;"><span>    print(sf<span style="color:#f92672">.</span>decoder(mutant))  <span style="color:#75715e"># always valid</span>
</span></span></code></pre></div><h3 id="property-prediction-and-pretraining">Property Prediction and Pretraining</h3>
<p><a href="/notes/chemistry/molecular-representations/encoders/selformer/">SELFormer</a> is a RoBERTa-based chemical language model pretrained on 2M ChEMBL compounds using SELFIES as input. Because every masked token prediction corresponds to a valid molecular fragment, the model never wastes capacity learning invalid chemistry. SELFormer outperformed <a href="/notes/chemistry/molecular-representations/encoders/chemberta-2/">ChemBERTa-2</a> by approximately 12% on average across BACE, BBBP, and HIV classification benchmarks (Yüksel et al., 2023). <a href="/notes/chemistry/molecular-representations/encoders/chemberta/">ChemBERTa</a> also evaluated SELFIES as an input representation, finding comparable performance to SMILES on the Tox21 task (Chithrananda et al., 2020).</p>
<p>The <a href="/notes/chemistry/molecular-design/property-prediction/regression-transformer/">Regression Transformer</a> demonstrated that SELFIES achieves ~100% validity vs. ~40% for SMILES in conditional molecular generation, while performing comparably for property prediction. This dual prediction-generation capability is enabled by interleaving numerical property tokens with SELFIES molecular tokens in a single sequence (Born &amp; Manica, 2023).</p>
<p>At larger scales, <a href="/notes/chemistry/molecular-representations/encoders/neural-scaling-of-deep-chemical-models/">ChemGPT</a> (up to 1B parameters) uses a GPT-Neo backbone with SELFIES tokenization for autoregressive molecular generation, demonstrating that SELFIES follows the same power-law neural scaling behavior observed in NLP (Frey et al., 2023).</p>
<h3 id="optical-chemical-structure-recognition">Optical Chemical Structure Recognition</h3>
<p>In image-to-text chemical structure recognition, <a href="/notes/chemistry/optical-structure-recognition/benchmarks/rajan-string-representations-2022/">Rajan et al. (2022)</a> compared SMILES, DeepSMILES, SELFIES, and InChI as output formats using the same transformer architecture. SELFIES achieved 100% structural validity (every prediction could be decoded), while SMILES predictions occasionally contained syntax errors. The trade-off: SMILES achieved higher exact match accuracy (88.62%) partly because SELFIES strings are longer, producing more tokens for the decoder to predict.</p>
<h3 id="chemical-name-translation">Chemical Name Translation</h3>
<p><a href="/notes/chemistry/molecular-representations/name-translation/stout/">STOUT</a> uses SELFIES as its internal representation for translating between chemical line notations and IUPAC names. All SMILES are converted to SELFIES before processing, and the model achieves a BLEU score of 0.94 for IUPAC-to-SELFIES translation and 0.98 Tanimoto similarity on valid outputs. The authors found SELFIES&rsquo; syntactic robustness particularly valuable for this sequence-to-sequence task, where the decoder must produce a chemically valid output string (Rajan et al., 2021).</p>
<h3 id="tokenization">Tokenization</h3>
<p>Converting SELFIES strings into tokens for neural models is more straightforward than SMILES tokenization. Each bracket-enclosed symbol (<code>[C]</code>, <code>[=C]</code>, <code>[Branch1]</code>) is a natural token boundary. <a href="/notes/chemistry/molecular-representations/notations/smiles-selfies-tokenization-chemical-lm/">Atom Pair Encoding (APE)</a> extends byte pair encoding with chemistry-aware constraints for both SMILES and SELFIES. For SELFIES specifically, APE preserves atomic identity during subword merging, and SELFIES models showed strong inter-tokenizer agreement: all true positives from SELFIES-BPE were captured by SELFIES-APE (Leon et al., 2024).</p>
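<p>Because every symbol is bracket-delimited, a symbol-level tokenizer needs only a regular expression. The following is a minimal standard-library sketch, not the library's own tokenizer: the helper names <code>tokenize</code>, <code>build_vocab</code>, and <code>encode</code> are hypothetical, and reserving ID 0 for the <code>[nop]</code> padding symbol is a choice made for this example.</p>

```python
import re

# One token per bracketed symbol; "." separates disconnected fragments.
SYMBOL_PATTERN = re.compile(r"\[[^\[\]]+\]|\.")
PAD = "[nop]"  # padding symbol; giving it ID 0 is this sketch's choice


def tokenize(selfies_str):
    """Split a SELFIES string into its bracketed symbols."""
    tokens = SYMBOL_PATTERN.findall(selfies_str)
    if "".join(tokens) != selfies_str:
        raise ValueError("input contains text outside bracketed symbols")
    return tokens


def build_vocab(dataset):
    """Map each distinct symbol in the dataset to an integer ID."""
    symbols = {tok for s in dataset for tok in tokenize(s)}
    vocab = {PAD: 0}
    for symbol in sorted(symbols - {PAD}):
        vocab[symbol] = len(vocab)
    return vocab


def encode(selfies_str, vocab, pad_to_len):
    """Convert a SELFIES string to a fixed-length list of integer IDs."""
    ids = [vocab[tok] for tok in tokenize(selfies_str)]
    if len(ids) > pad_to_len:
        raise ValueError("sequence is longer than pad_to_len")
    return ids + [vocab[PAD]] * (pad_to_len - len(ids))


dataset = ["[C][C][O]", "[C][=C][C][=C][C][=C][Ring1][=Branch1]"]
vocab = build_vocab(dataset)
print(encode("[C][C][O]", vocab, pad_to_len=5))
# -> [3, 3, 4, 0, 0]
```

<p>Subword schemes such as BPE or APE would then merge these symbol-level tokens. That step is outside the scope of this sketch.</p>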
<h2 id="limitations-and-trade-offs">Limitations and Trade-offs</h2>
<h3 id="validity-constraints-can-introduce-bias">Validity Constraints Can Introduce Bias</h3>
<p>The guarantee that every string decodes to a valid molecule is SELFIES&rsquo; core advantage, but recent work has shown this comes with trade-offs. <a href="/notes/chemistry/molecular-representations/notations/invalid-smiles-help/">Skinnider (2024)</a> demonstrated that SMILES-based models consistently outperform SELFIES-based models on distribution-learning tasks. The mechanism: invalid SMILES represent a model&rsquo;s least confident predictions, and filtering them out acts as implicit quality control. SELFIES models, by construction, cannot discard low-confidence outputs this way. Furthermore, SELFIES validity constraints introduce systematic structural biases, generating fewer aromatic rings and more aliphatic structures compared to training data. When SELFIES constraints were relaxed to allow invalid generation (&ldquo;unconstrained SELFIES&rdquo;), performance improved, providing causal evidence that the ability to generate and discard invalid outputs benefits distribution learning.</p>
<p>This finding reframes the SMILES vs. SELFIES choice as context-dependent. As Grisoni (2023) summarizes in a <a href="/notes/chemistry/molecular-design/generation/evaluation/clms-de-novo-drug-design-review/">review of chemical language models</a>: &ldquo;SMILES offer a richer, more interpretable language with well-studied augmentation strategies, while SELFIES guarantee validity at the cost of chemical realism and edit interpretability.&rdquo;</p>
<p>The <a href="/notes/chemistry/molecular-design/generation/evaluation/pmo-sample-efficient-molecular-optimization/">PMO benchmark</a> provides further nuance: SELFIES-based variants of language model methods (REINVENT, LSTM HC, VAE) generally do not outperform their SMILES counterparts, because modern language models learn SMILES grammar well enough that syntactic invalidity is no longer a practical bottleneck. The exception is genetic algorithms, where SELFIES mutations are naturally well-suited.</p>
<p>A study on <a href="/notes/chemistry/molecular-design/property-prediction/lm-complex-molecular-distributions/">complex molecular distributions</a> paints a consistent picture: SELFIES-trained RNNs achieve better standard metrics (validity, uniqueness, novelty), while SMILES-trained RNNs achieve better distributional fidelity as measured by Wasserstein distance (Flam-Shepherd et al., 2022). Taken together, these findings suggest that SELFIES and SMILES have genuinely complementary strengths, and the best choice depends on whether the task prioritizes validity/novelty or distributional faithfulness.</p>
<h3 id="degenerate-outputs">Degenerate Outputs</h3>
<p>Although every SELFIES string decodes to a valid molecule, the decoded molecule may not always be chemically meaningful in context. The <a href="/notes/chemistry/molecular-design/property-prediction/regression-transformer/">Regression Transformer</a> reported ~1.9% defective generations where the output molecule had fewer than 50% of the seed molecule&rsquo;s atoms (Born &amp; Manica, 2023). This highlights a distinction between syntactic validity (which SELFIES guarantees) and semantic appropriateness (which it does not).</p>
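<p>A size check of this kind is easy to add as a post-generation filter. The sketch below assumes atom counts for the seed and generated molecules have already been computed elsewhere (for example with a cheminformatics toolkit). The function names and the example counts are hypothetical, and only the 50% threshold comes from the text above.</p>

```python
def size_ratio(seed_atoms, generated_atoms):
    """Fraction of the seed molecule's atom count retained by the output."""
    if seed_atoms <= 0:
        raise ValueError("seed molecule must contain at least one atom")
    return generated_atoms / seed_atoms


def is_defective(seed_atoms, generated_atoms, min_ratio=0.5):
    """Flag outputs with fewer than min_ratio of the seed's atoms."""
    return size_ratio(seed_atoms, generated_atoms) < min_ratio


# Hypothetical (seed, generated) atom counts for three generations.
pairs = [(20, 18), (20, 9), (12, 6)]
print([is_defective(seed, gen) for seed, gen in pairs])
# -> [False, True, False]
```

<p>The comparison is strict, so an output with exactly half of the seed's atoms is kept. This matches the "fewer than 50%" wording.</p>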
<h3 id="other-limitations">Other Limitations</h3>
<ul>
<li><strong>Indirect Canonicalization</strong>: A canonical SELFIES string is currently generated by first creating a canonical SMILES string and then converting it to SELFIES. Direct canonicalization is a goal for future development.</li>
<li><strong>String Length</strong>: SELFIES strings are generally longer than their corresponding SMILES strings, which can impact storage, processing times, and sequence modeling difficulty for very large datasets.</li>
<li><strong>Ongoing Standardization</strong>: While the library now supports most major features found in SMILES, work is ongoing to extend the format to more complex systems like polymers, crystals, and reactions.</li>
</ul>
<h2 id="variants-and-extensions">Variants and Extensions</h2>
<h3 id="group-selfies">Group SELFIES</h3>
<p><a href="/notes/chemistry/molecular-representations/notations/group-selfies-fragment-molecular-representation/">Group SELFIES</a> extends the representation with group tokens that represent functional groups or entire substructures (e.g., a benzene ring or carboxyl group) as single units. Each group token has labeled attachment points with specified valency, allowing the decoder to continue tracking available bonds. Group SELFIES maintains the validity guarantee while producing shorter, more human-readable strings. On MOSES VAE benchmarks, Group SELFIES achieved an FCD of 0.1787 versus 0.6351 for standard SELFIES, indicating substantially better distribution learning (Cheng et al., 2023).</p>
<h3 id="stoned-algorithms">STONED Algorithms</h3>
<p><a href="/notes/chemistry/molecular-design/generation/search-based/stoned-selfies-chemical-space-exploration/">STONED</a> (Superfast Traversal, Optimization, Novelty, Exploration and Discovery) is a suite of algorithms that exploit SELFIES&rsquo; validity guarantee for training-free molecular design through point mutations, interpolation, and optimization (Nigam et al., 2021). See <a href="#molecular-generation">Molecular Generation</a> above for benchmark results.</p>
<h2 id="recent-developments">Recent Developments</h2>
<p>The <a href="/notes/chemistry/molecular-representations/notations/selfies-2023/">2023 library update</a> replaced the original string-manipulation engine with a graph-based internal representation. This change resolved several long-standing limitations: the original approach could not handle aromatics (requiring kekulization), stereochemistry, or charged species. The graph-based engine now supports all of these, and processes 300K+ molecules in approximately 4 minutes in pure Python. The library has been validated on all 72 million molecules from PubChem.</p>
<p>Looking forward, researchers have outlined <a href="/notes/chemistry/molecular-representations/notations/selfies-2022/">16 future research directions</a> for extending robust representations to complex systems like polymers, crystals, and chemical reactions.</p>
<h2 id="further-reading">Further Reading</h2>
<ul>
<li><a href="/posts/visualizing-smiles-and-selfies-strings/"><strong>Converting SELFIES Strings to 2D Molecular Images</strong></a>: Hands-on tutorial demonstrating SELFIES robustness and building visualization tools</li>
</ul>
<h2 id="references">References</h2>
<ul>
<li>Krenn, M., Häse, F., Nigam, A., Friederich, P., &amp; Aspuru-Guzik, A. (2020). Self-referencing embedded strings (SELFIES): A 100% robust molecular string representation. <a href="https://doi.org/10.1088/2632-2153/aba947"><em>Machine Learning: Science and Technology</em>, <em>1</em>(4), 045024.</a></li>
<li>Krenn, M., Ai, Q., Barthel, S., Carson, N., Frei, A., Frey, N. C., &hellip; &amp; Aspuru-Guzik, A. (2022). SELFIES and the future of molecular string representations. <a href="https://doi.org/10.1016/j.patter.2022.100588"><em>Patterns</em>, <em>3</em>(10), 100588.</a></li>
<li>Lo, A., Pollice, R., Nigam, A., White, A. D., Krenn, M., &amp; Aspuru-Guzik, A. (2023). Recent advances in the self-referencing embedded strings (SELFIES) library. <a href="https://doi.org/10.1039/d3dd00044c"><em>Digital Discovery</em>, <em>2</em>, 897-908.</a></li>
<li>Skinnider, M. A. (2024). Invalid SMILES are beneficial rather than detrimental to chemical language models. <a href="https://doi.org/10.1038/s42256-024-00821-x"><em>Nature Machine Intelligence</em>, <em>6</em>, 437-448.</a></li>
<li>Shen, C., Krenn, M., Eppel, S., &amp; Aspuru-Guzik, A. (2021). Deep molecular dreaming: inverse machine learning for de-novo molecular design and interpretability with surjective representations. <a href="https://doi.org/10.1088/2632-2153/ac09d6"><em>Machine Learning: Science and Technology</em>, <em>2</em>(3), 03LT02.</a></li>
<li>Fang, Y., et al. (2024). Domain-agnostic molecular generation with chemical feedback. <a href="https://openreview.net/forum?id=9rnerQyXlh"><em>ICLR 2024</em>.</a></li>
<li>Born, J., &amp; Manica, M. (2023). Regression Transformer enables concurrent sequence regression and generation for molecular language modelling. <a href="https://doi.org/10.1038/s42256-023-00639-z"><em>Nature Machine Intelligence</em>, <em>5</em>, 432-444.</a></li>
<li>Frey, N. C., Soklaski, R., Axelrod, S., Samsi, S., Gómez-Bombarelli, R., Coley, C. W., &amp; Gadepally, V. (2023). Neural scaling of deep chemical models. <a href="https://doi.org/10.1038/s42256-023-00740-3"><em>Nature Machine Intelligence</em>, <em>5</em>, 1297-1305.</a></li>
<li>Rajan, K., Zielesny, A., &amp; Steinbeck, C. (2021). STOUT: SMILES to IUPAC names using neural machine translation. <a href="https://doi.org/10.1186/s13321-021-00512-4"><em>Journal of Cheminformatics</em>, <em>13</em>, 34.</a></li>
<li>Nigam, A., Pollice, R., &amp; Aspuru-Guzik, A. (2022). Tartarus: A benchmarking platform for realistic and practical inverse molecular design. <a href="https://openreview.net/forum?id=sLFDE2MHzHO"><em>NeurIPS 2022 Datasets and Benchmarks</em>.</a></li>
<li><a href="https://github.com/aspuru-guzik-group/selfies">SELFIES GitHub Repository</a></li>
</ul>
]]></content:encoded></item><item><title>Communication in the Presence of Noise: Shannon's 1949 Paper</title><link>https://hunterheidenreich.com/notes/machine-learning/model-architectures/communication-in-the-presence-of-noise/</link><pubDate>Mon, 08 Sep 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/model-architectures/communication-in-the-presence-of-noise/</guid><description>Shannon's 1949 foundational paper establishing information theory, channel capacity, and the sampling theorem for communication systems.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is a foundational <strong>Theory</strong> paper. It establishes the mathematical framework for modern information theory and defines the ultimate physical limits of communication for an entire system, from the information source to the final destination.</p>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>The central motivation was to develop a general theory of communication that could quantify information and determine the maximum rate at which it can be transmitted reliably over a noisy channel. Prior to this work, communication system design was largely empirical. Shannon sought to create a mathematical foundation to understand the trade-offs between key parameters like bandwidth, power, and noise, independent of any specific hardware or modulation scheme. To frame this, he conceptualized a general communication system as consisting of five essential elements: an information source, a transmitter, a channel, a receiver, and a destination.</p>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The novelty is a complete, end-to-end mathematical theory of communication built upon several key concepts and theorems:</p>
<ol>
<li><strong>Geometric Representation of Signals</strong>: Shannon introduced the idea of representing signals as points in a high-dimensional vector space. A signal of duration $T$ and bandwidth $W$ is uniquely specified by $2TW$ numbers (its samples), which are treated as coordinates in a $2TW$-dimensional space. This transformed problems in communication into problems of high-dimensional geometry. In this representation, signal energy corresponds to squared distance from the origin, and noise introduces a &ldquo;sphere of uncertainty&rdquo; around each transmitted point.</li>
</ol>
<figure class="post-figure center ">
    <img src="/img/notes/geometric-interpretation-of-signals-as-spheres.webp"
         alt="Sphere packing illustration showing the geometric interpretation of channel capacity. A large dashed circle represents the total signal space with radius proportional to the square root of P&#43;N. Inside are multiple smaller blue circles (uncertainty spheres) with radius proportional to the square root of N, each centered on a distinct message point."
         title="Sphere packing illustration showing the geometric interpretation of channel capacity. A large dashed circle represents the total signal space with radius proportional to the square root of P&#43;N. Inside are multiple smaller blue circles (uncertainty spheres) with radius proportional to the square root of N, each centered on a distinct message point."
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption"><strong>Sphere Packing and Channel Capacity</strong>: Each transmitted message corresponds to a point in high-dimensional signal space. Noise creates an &lsquo;uncertainty sphere&rsquo; with radius proportional to $\sqrt{N}$ around each point. The channel capacity is set by the logarithm of how many non-overlapping uncertainty spheres fit inside the total signal sphere, whose radius is proportional to $\sqrt{P+N}$.</figcaption>
    
</figure>

<ol start="2">
<li>
<p><strong>Theorem 1 (The Sampling Theorem)</strong>: The paper provides an explicit statement and proof that a signal containing no frequencies higher than $W$ is perfectly determined by its samples taken at a rate of $2W$ samples per second (i.e., spaced $1/2W$ seconds apart). Shannon credits Nyquist for pointing out the fundamental importance of the time interval $1/2W$ seconds in connection with telegraphy, and names this the &ldquo;Nyquist interval&rdquo; corresponding to the band $W$. This theorem is the theoretical bedrock of all modern digital signal processing.</p>
</li>
<li>
<p><strong>Theorem 2 (Channel Capacity for AWGN)</strong>: This is the paper&rsquo;s most celebrated result, now known as the <strong>Shannon-Hartley theorem</strong> (a name assigned retrospectively, not used in the paper itself). It provides an exact formula for the capacity $C$ (the maximum rate of error-free communication) of a channel with bandwidth $W$, signal power $P$, and additive white Gaussian noise of power $N$:
$$ C = W \log_2 \left(1 + \frac{P}{N}\right) $$
It proves that for any transmission rate below $C$, a coding scheme exists that can achieve an arbitrarily low error frequency.</p>
<p><strong>Random Coding Proof Technique</strong>: Shannon&rsquo;s proof employs a <strong>random coding argument</strong>: he proved that if you choose signal points at random from the sphere of radius $\sqrt{2TWP}$, the average error frequency vanishes for any transmission rate below capacity. This non-constructive proof (meaning it establishes that good codes must exist without constructing any specific one) established that &ldquo;good&rdquo; codes exist almost everywhere in the signal space, even if we don&rsquo;t know how to build them efficiently. The random coding argument became a fundamental tool in information theory, shifting the focus from building specific codes to proving existence and understanding fundamental limits.</p>
</li>
</ol>
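<p>As a rough numerical illustration of Theorem 2, the sketch below evaluates $C = W \log_2(1 + P/N)$ for a hypothetical 3 kHz channel at 30 dB signal-to-noise ratio. The function names <code>awgn_capacity</code> and <code>db_to_ratio</code> and the example numbers are assumptions chosen for illustration; they do not come from the paper.</p>

```python
import math

def awgn_capacity(bandwidth_hz, signal_power, noise_power):
    """Shannon capacity in bits per second for an AWGN channel."""
    return bandwidth_hz * math.log2(1.0 + signal_power / noise_power)

def db_to_ratio(db):
    """Convert a power ratio in decibels to a linear ratio."""
    return 10.0 ** (db / 10.0)

# Hypothetical channel: W = 3000 Hz, P/N = 30 dB (a linear ratio of 1000).
snr = db_to_ratio(30.0)
capacity = awgn_capacity(3000.0, snr, 1.0)
print(f"capacity = {capacity:.0f} bits/s")
```

<p>Because the dependence on $P/N$ is logarithmic, each further doubling of signal power at high SNR adds only about $W$ bits per second, while capacity scales linearly with bandwidth at fixed $P/N$.</p>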
<figure class="post-figure center ">
    <img src="/img/notes/shannons-ideal-capacity-curve-and-sampling-theorem.webp"
         alt="Two plots illustrating Shannon&#39;s key theorems: (Left) The ideal capacity curve showing bits per cycle vs SNR in dB, with an example operating point at 10dB. (Right) The sampling theorem demonstrating how a continuous signal is perfectly captured by samples taken at the Nyquist rate of 2W."
         title="Two plots illustrating Shannon&#39;s key theorems: (Left) The ideal capacity curve showing bits per cycle vs SNR in dB, with an example operating point at 10dB. (Right) The sampling theorem demonstrating how a continuous signal is perfectly captured by samples taken at the Nyquist rate of 2W."
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption"><strong>Left</strong>: Shannon&rsquo;s ideal capacity curve, showing how channel capacity (in bits per cycle) increases logarithmically with signal-to-noise ratio. <strong>Right</strong>: The sampling theorem in action, where a band-limited continuous signal is fully determined by discrete samples taken at twice its maximum frequency.</figcaption>
    
</figure>
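<p>The sampling theorem can be checked numerically with the interpolation formula $f(t) = \sum_n f\left(\frac{n}{2W}\right) \operatorname{sinc}(2Wt - n)$. The sketch below is a minimal illustration under assumed values (a 3 Hz cosine, a band limit of $W = 5$ Hz, and a sum truncated to 500 samples on each side). The truncation means the reconstruction is approximate, where the theorem itself is exact for the infinite sum.</p>

```python
import math

def sinc(x):
    """Normalized sinc, sin(pi x) / (pi x), with the limit value 1 at x = 0."""
    return 1.0 if x == 0.0 else math.sin(math.pi * x) / (math.pi * x)

def reconstruct(samples, first_index, bandwidth, t):
    """Interpolate f(t) from samples taken at t_n = n / (2W)."""
    rate = 2.0 * bandwidth
    return sum(
        value * sinc(rate * t - (first_index + k))
        for k, value in enumerate(samples)
    )

W = 5.0         # assumed band limit in Hz
F0 = 3.0        # test tone below the band limit
N_SIDE = 500    # samples kept on each side of t = 0 (finite truncation)

def signal(t):
    return math.cos(2.0 * math.pi * F0 * t)

samples = [signal(n / (2.0 * W)) for n in range(-N_SIDE, N_SIDE + 1)]
t_query = 0.033  # a time that falls between two sample instants
estimate = reconstruct(samples, -N_SIDE, W, t_query)
error = abs(estimate - signal(t_query))
print(f"reconstruction error at t = {t_query}: {error:.2e}")
```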

<ol start="4">
<li>
<p><strong>Theorem 3 (Channel Capacity for Arbitrary Noise)</strong>: Shannon generalized the capacity concept to channels with any type of noise, showing that the capacity of a channel with arbitrary noise of power $N$ is bounded in terms of the noise&rsquo;s <strong>entropy power</strong> $N_1$. Entropy power is defined as $N_1 = \frac{1}{2\pi e} e^{2h(X)}$, where $h(X)$ is the differential entropy of the noise distribution in nats (the continuous analog of discrete entropy $H$: where $H$ counts the average bits per symbol from a discrete source, $h(X)$ measures the same unpredictability for continuous-valued random variables). It quantifies how spread out a distribution is in an information-theoretic sense: $N_1$ is the power of the white Gaussian noise that has the same entropy. Because the Gaussian distribution maximizes entropy for a given variance, any noise with power $N$ satisfies $N_1 \leq N$, with equality only in the Gaussian case. Since the capacity bounds shrink as entropy power grows, this proves that <strong>white Gaussian noise is the worst possible type of noise</strong> for any given noise power: it attains the highest $N_1$ (equal to $N$) and therefore the lowest capacity. A system designed to handle white Gaussian noise will therefore perform at least as well against any other noise type of the same power.</p>
<p><strong>Arbitrary Gaussian Noise and the Water-Filling Principle</strong>: Shannon extended his analysis to Gaussian noise with a non-flat power spectrum $N(f)$, using the calculus of variations (a technique for optimizing over functions rather than fixed variables) to find the power allocation $P(f)$ that maximizes capacity. He proved that optimal capacity is achieved when the sum $P(f) + N(f)$ is constant across the utilized frequency band. This leads to what is now known as the &ldquo;water-filling&rdquo; principle: allocate more signal power to quieter frequency bands, and allocate zero power to any band where noise exceeds the constant threshold. This provides the foundation for modern adaptive power allocation across frequency bands.</p>
</li>
</ol>
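<p>To make the inequality $N_1 \leq N$ concrete, the sketch below compares the entropy power of Gaussian noise with that of uniform noise at the same variance. The closed-form differential entropies used here ($\frac{1}{2}\ln(2\pi e \sigma^2)$ for a Gaussian and $\ln a$ for a uniform density of width $a$) are standard results; the choice of a uniform distribution as the comparison case is an assumption made for illustration.</p>

```python
import math

def entropy_power(diff_entropy_nats):
    """N_1 = exp(2 h) / (2 pi e), with h in nats for a one-dimensional density."""
    return math.exp(2.0 * diff_entropy_nats) / (2.0 * math.pi * math.e)

variance = 1.0  # common noise power N for both distributions

# Gaussian noise: h = 0.5 * ln(2 pi e sigma^2)
h_gauss = 0.5 * math.log(2.0 * math.pi * math.e * variance)

# Uniform noise on an interval of width a: variance a^2 / 12, h = ln(a)
width = math.sqrt(12.0 * variance)
h_uniform = math.log(width)

n1_gauss = entropy_power(h_gauss)
n1_uniform = entropy_power(h_uniform)
print(f"Gaussian N_1 = {n1_gauss:.4f}, uniform N_1 = {n1_uniform:.4f}")
```

<p>The Gaussian case recovers $N_1 = N$, while the uniform case gives $N_1 = 12 / (2\pi e)$, which is strictly smaller than the noise power.</p>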
<figure class="post-figure center ">
    <img src="/img/notes/the-water-filling-principle.webp"
         alt="The water-filling principle for optimal power allocation. Gray area shows noise power N(f) varying across frequencies, orange area shows signal power P(f) allocated to fill up to a constant level lambda, such that P(f) &#43; N(f) equals a constant."
         title="The water-filling principle for optimal power allocation. Gray area shows noise power N(f) varying across frequencies, orange area shows signal power P(f) allocated to fill up to a constant level lambda, such that P(f) &#43; N(f) equals a constant."
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption"><strong>The Water-Filling Principle</strong>: The condition $P(f) + N(f) = \lambda$ is Shannon&rsquo;s derivation; &lsquo;water-filling&rsquo; is the modern retrospective label for it. When noise power varies across frequencies, optimal capacity is achieved by allocating more signal power to &lsquo;quieter&rsquo; frequency bands. Like filling a container with water, power is poured in until the total (signal + noise) reaches a constant level $\lambda$. Frequencies with noise above this threshold receive no power at all.</figcaption>
    
</figure>
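<p>A discrete version of the condition $P(f) + N(f) = \lambda$ can be sketched by splitting the spectrum into equal-width bands and bisecting on the water level. This is a minimal illustration and not Shannon&rsquo;s derivation: the function name <code>water_fill</code>, the three-band noise profile, and the bisection search are all assumptions.</p>

```python
def water_fill(noise, total_power, tol=1e-12):
    """Return (level, powers) with powers[i] = max(level - noise[i], 0).

    The level is found by bisection so that the powers sum to total_power.
    """
    lo, hi = min(noise), max(noise) + total_power
    while hi - lo > tol:
        level = 0.5 * (lo + hi)
        used = sum(max(level - n, 0.0) for n in noise)
        if used > total_power:
            hi = level   # level too high: more power used than available
        else:
            lo = level   # level too low (or exact): raise it
    level = 0.5 * (lo + hi)
    return level, [max(level - n, 0.0) for n in noise]

# Hypothetical noise power in three equal-width bands and a power budget of 3.
noise = [1.0, 2.0, 5.0]
level, power = water_fill(noise, total_power=3.0)
print(level, power)
```

<p>In this example the two quiet bands share the budget so that signal plus noise is level across them, and the noisiest band sits above the water level and receives no power.</p>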

<ol start="5">
<li>
<p><strong>Theorem 4 (now known as the Source Coding Theorem)</strong>: This theorem addresses the information source itself. It proves that it&rsquo;s possible to encode messages from a discrete source into binary digits such that the average number of bits per source symbol approaches the source&rsquo;s <strong>entropy</strong>, $H$. This establishes entropy as the fundamental limit of data compression.</p>
</li>
<li>
<p><strong>Theorem 5 (Information Rate for Continuous Sources)</strong>: For continuous (analog) signals, Shannon defined the rate $R$ at which a source generates information relative to a specific fidelity criterion (i.e., a tolerable amount of error, $N_1$, in the reproduction). This is an early theoretical foundation for what later became rate-distortion theory.</p>
</li>
</ol>
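<p>The entropy limit in Theorem 4 can be illustrated with a small discrete source. The sketch below uses the textbook length assignment $\lceil -\log_2 p \rceil$, a standard construction swapped in for illustration and not the argument used in the paper, on a hypothetical four-symbol source whose probabilities are powers of two.</p>

```python
import math

def entropy_bits(probs):
    """Shannon entropy H in bits per symbol."""
    return -sum(p * math.log2(p) for p in probs if p > 0.0)

def shannon_code_lengths(probs):
    """Codeword lengths ceil(-log2 p); these always satisfy the Kraft inequality."""
    return [math.ceil(-math.log2(p)) for p in probs]

probs = [0.5, 0.25, 0.125, 0.125]   # hypothetical dyadic source
lengths = shannon_code_lengths(probs)
H = entropy_bits(probs)
average_length = sum(p * n for p, n in zip(probs, lengths))
kraft_sum = sum(2.0 ** -n for n in lengths)
print(lengths, H, average_length, kraft_sum)
```

<p>For dyadic probabilities the average codeword length equals $H$ exactly. For general sources this assignment stays within one bit of $H$ per symbol, and coding long blocks of symbols closes the gap.</p>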
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The paper is primarily theoretical, with &ldquo;experiments&rdquo; consisting of rigorous <strong>mathematical derivations and proofs</strong>. The channel capacity theorem, for instance, is proven using a geometric sphere-packing argument in the high-dimensional signal space.</p>
<p>However, Shannon does include a quantitative <strong>theoretical benchmark against existing 1949 technology</strong>. In Figure 6 he plots his theoretical &ldquo;Ideal Curve&rdquo; against calculated limits of Pulse Code Modulation (PCM) and Pulse Position Modulation (PPM) systems. The PCM points were calculated from formulas in another paper, and the PPM points were from unpublished calculations by B. McMillan. The comparison shows that the plotted points for these contemporary systems sat approximately <strong>8 dB</strong> from the ideal curve over most of the practical range. PPM systems came within <strong>3 dB</strong> of the ideal curve at very small $P/N$ ratios, which shows that different modulation schemes suit different regimes (PCM for high SNR, PPM for power-limited scenarios).</p>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>
<p>The primary outcome was a complete, unified theory that quantifies both information itself (entropy) and the ability of a channel to transmit it (capacity).</p>
<ul>
<li>
<p><strong>Decoupling of Source and Channel</strong>: A key conclusion is that the problem of communication can be split into two distinct parts: encoding sequences of message symbols into sequences of binary digits (where the average digits per symbol approaches the entropy $H$), and then mapping these binary digits into a particular signal function of long duration to combat noise. A source can be transmitted reliably if and only if its rate $R$ (or entropy $H$) is less than the channel capacity $C$.</p>
</li>
<li>
<p><strong>The Limit is on Rate</strong>: A central conclusion is that noise in a channel imposes a maximum <strong>rate</strong> of transmission. Below this rate, error-free communication is theoretically possible.</p>
</li>
<li>
<p><strong>The Threshold Effect and Topological Necessity</strong>: To approach capacity, one must map a lower-dimensional message space into the high-dimensional signal space efficiently, winding through the available signal sphere to fill its volume (as illustrated with the efficient mapping in Fig. 4 of the paper). This complex mapping creates a sharp <strong>threshold effect</strong>: below a certain noise level, recovery is essentially perfect; above it, the system fails catastrophically because the &ldquo;uncertainty spheres&rdquo; around signal points begin to overlap. Shannon provides a topological explanation for why this threshold is unavoidable: it is not possible to map a region of higher dimensionality into a region of lower dimensionality continuously. To compress bandwidth (reducing the number of dimensions in signal space), the mapping from message space to signal space must necessarily be discontinuous. This required discontinuity creates vulnerable points where a small noise perturbation can cause the received signal to &ldquo;jump&rdquo; to an entirely different interpretation. The threshold is an inevitable consequence of dimensional reduction.</p>
</li>
<li>
<p><strong>The Exchange Relation</strong>: Shannon explicitly states that the key parameters $T$ (time), $W$ (bandwidth), $P$ (power), and $N$ (noise) can be &ldquo;altered at will&rdquo; without changing the total information transmitted, provided $TW \log(1 + P/N)$ is held constant. This exchangeability enables trade-offs such as using more bandwidth to compensate for lower power.</p>
</li>
<li>
<p><strong>Characteristics of an Ideal System</strong>: The theory implies that to approach the channel capacity limit, one must use very complex and long codes. An ideal system exhibits five key properties: (1) the transmission rate approaches $C$, (2) the error probability approaches zero, (3) the transmitted signal&rsquo;s statistical properties approach those of white noise, (4) the threshold effect becomes very sharp (errors increase rapidly if noise exceeds the designed value), and (5) <strong>the required delay increases indefinitely</strong>. This final constraint is a crucial practical limitation: achieving near-capacity performance requires encoding over increasingly long message blocks, introducing latency that may be unacceptable for real-time applications.</p>
</li>
</ul>
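<p>The exchange relation can be made concrete with a short calculation. The sketch below, using hypothetical numbers, asks what $P/N$ is needed to keep $TW \log_2(1 + P/N)$ fixed after the bandwidth is doubled. It works directly with the ratio $P/N$ and, as a simplifying assumption, ignores that the noise power admitted by a wider band would itself grow.</p>

```python
import math

def information_bits(duration, bandwidth, snr):
    """Total information T * W * log2(1 + P/N) in bits."""
    return duration * bandwidth * math.log2(1.0 + snr)

def snr_for_new_bandwidth(snr, bandwidth, new_bandwidth):
    """P/N that keeps W * log2(1 + P/N) unchanged when W changes."""
    return (1.0 + snr) ** (bandwidth / new_bandwidth) - 1.0

# Hypothetical starting point: 1 second, 1 kHz, P/N = 15.
T, W, snr = 1.0, 1000.0, 15.0
wide_snr = snr_for_new_bandwidth(snr, W, 2.0 * W)
before = information_bits(T, W, snr)
after = information_bits(T, 2.0 * W, wide_snr)
print(wide_snr, before, after)
```

<p>Doubling the bandwidth lets the required ratio drop from 15 to 3 while the total information stays at 4000 bits, which is the bandwidth-for-power trade described above.</p>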
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="algorithms">Algorithms</h3>
<p>The paper introduces the theoretical foundation for the <strong>water-filling algorithm</strong> for optimal power allocation across frequency bands with varying noise levels. The mathematical condition derived is that $P(f) + N(f)$ must be constant across the utilized frequency band.</p>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Shannon, C. E. (1949). Communication in the Presence of Noise. <em>Proceedings of the IRE</em>, 37(1), 10-21. <a href="https://doi.org/10.1109/JRPROC.1949.232969">https://doi.org/10.1109/JRPROC.1949.232969</a></p>
<p><strong>Publication</strong>: Proceedings of the IRE, 1949</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{shannon1949communication,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Shannon, C. E.}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Proceedings of the IRE}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Communication in the Presence of Noise}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{1949}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{37}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{10-21}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1109/JRPROC.1949.232969}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Implementing the Müller-Brown Potential in PyTorch</title><link>https://hunterheidenreich.com/posts/muller-brown-in-pytorch/</link><pubDate>Wed, 27 Aug 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/muller-brown-in-pytorch/</guid><description>Guide to implementing the Müller-Brown potential in PyTorch, comparing analytical vs automatic differentiation with performance analysis.</description><content:encoded><![CDATA[<h2 id="introduction">Introduction</h2>
<p>The Müller-Brown potential reads, in hindsight, like an adversarial example for optimization algorithms.</p>
<p>Designed in 1979 to break naive path-finding methods, this deceptively simple 2D surface features deep minima, high barriers, and tricky saddle points. For nearly five decades, it has served as a ground-truth benchmark for computational chemistry.</p>
<p>Today, it finds new life as a testbed for machine learning. In the 1970s, chemists struggled to find transition states, saddle points where standard gradient descent fails catastrophically. Modern machine learning engineers face a strikingly similar challenge: escaping saddle points in high-dimensional loss landscapes. The Müller-Brown potential was the original stress test for these algorithms, and it remains a perfect, low-cost sandbox for benchmarking modern optimizers.</p>
<p>Whether you&rsquo;re training neural network potentials or benchmarking reinforcement learning agents for exploration, the Müller-Brown potential offers a fast, noise-free, and mathematically exact environment.</p>
<p>In this guide, we&rsquo;ll implement it in PyTorch, work through the engineering trade-off between <strong>analytical derivatives</strong> and <strong>Autograd</strong>, and measure the roughly <strong>4x</strong> force-evaluation speedup of the compiled analytical kernel over the Autograd reference.</p>
<h2 id="the-problem-finding-saddle-points-in-the-1970s">The Problem: Finding Saddle Points in the 1970s</h2>
<p>In the 1970s, finding energy minima was straightforward: follow the gradient downhill. But finding transition states proved much more challenging. These saddle points are maxima along the reaction coordinate but minima in all other directions, like standing at the top of a mountain pass.</p>
<figure class="post-figure center ">
    <img src="/img/muller-brown/muller-brown-saddle.webp"
         alt="Diagram showing gradient descent getting stuck at a saddle point, where the surface curves up in one direction and down in another"
         title="Diagram showing gradient descent getting stuck at a saddle point, where the surface curves up in one direction and down in another"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The saddle point problem: standard gradient descent sees a minimum in one direction but a maximum in another. Modern optimizers like SGD and Adam face this same challenge in high-dimensional loss landscapes.</figcaption>
    
</figure>

<p>Standard minimizers, whether 1970s simplex methods or modern first-order optimizers like SGD and Adam, are designed to drive a function downhill. Start them near a saddle point, and they slide into the nearest valley: at the saddle itself the gradient vanishes, and nearby it points away from the saddle along the unstable direction. Specialized algorithms were needed to navigate this mixed landscape of ups and downs.</p>
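<p>A toy example makes the failure mode visible. The sketch below runs plain gradient descent on the assumed surface $f(x, y) = x^2 - y^2$, which is not the Müller-Brown potential but has a single saddle at the origin. The starting point, learning rate, and step count are arbitrary illustrative choices.</p>

```python
def gradient(x, y):
    """Gradient of the toy saddle f(x, y) = x**2 - y**2."""
    return 2.0 * x, -2.0 * y

def descend(x, y, lr=0.1, steps=50):
    """Plain gradient descent with a fixed learning rate."""
    for _ in range(steps):
        gx, gy = gradient(x, y)
        x, y = x - lr * gx, y - lr * gy
    return x, y

# Start almost exactly on the line y = 0 that leads straight into the saddle.
x_end, y_end = descend(1.0, 1e-6)
print(x_end, y_end)
```

<p>The $x$ coordinate collapses toward the saddle, while the tiny initial offset in $y$ grows by a constant factor every step, so the iterate ends up leaving the saddle along the downhill direction.</p>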
<p>The computational reality made this worse. Early quantum chemistry programs like ATMOL and Gaussian made energy calculations possible, but each computation was expensive. Gradients required even more resources, and second derivatives were rarely computed.</p>
<p>This created a catch-22: sophisticated algorithms were needed to find saddle points, but researchers couldn&rsquo;t afford to test them on real molecular systems. Every calculation represented a major investment of time and computational resources.</p>
<h2 id="müller-and-browns-solution">Müller and Brown&rsquo;s Solution</h2>
<p>Müller and Brown&rsquo;s insight<sup id="fnref:1"><a href="#fn:1" class="footnote-ref" role="doc-noteref">1</a></sup> was to create a simple analytical test function that captured the essential difficulties of real chemical systems without the computational cost. Their potential offered three key advantages:</p>
<ul>
<li><strong>Negligible computational cost</strong> - Evaluate millions of points instantly</li>
<li><strong>Analytical derivatives</strong> - Exact gradients and Hessians available immediately</li>
<li><strong>Realistic challenges</strong> - Multiple minima, saddle points, and curved pathways</li>
</ul>
<p>The clever part was the deliberate design to break naive approaches. Early methods often assumed linear paths between reactants and products. The Müller-Brown potential has a curved minimum energy path that punishes this assumption. Try to take shortcuts, and algorithms climb over high-energy barriers.</p>
<h2 id="the-mathematical-foundation">The Mathematical Foundation</h2>
<p>The Müller-Brown potential combines four two-dimensional Gaussian functions:</p>
<p>$$V(x,y) = \sum_{k=1}^{4} A_k \exp\left[a_k(x-x_k^0)^2 + b_k(x-x_k^0)(y-y_k^0) + c_k(y-y_k^0)^2\right]$$</p>
<p>Each term contributes a different &ldquo;bump&rdquo; or &ldquo;well&rdquo; to the landscape. The parameters control amplitude ($A_k$), width and orientation ($a_k$, $b_k$, $c_k$), and center position ($x_k^0$, $y_k^0$).</p>
<h3 id="the-standard-parameters">The Standard Parameters</h3>
<p>The specific parameter values that define the canonical Müller-Brown surface are:</p>
<table>
  <thead>
      <tr>
          <th>k</th>
          <th>$A_k$</th>
          <th>$a_k$</th>
          <th>$b_k$</th>
          <th>$c_k$</th>
          <th>$x_k^0$</th>
          <th>$y_k^0$</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>1</td>
          <td>-200</td>
          <td>-1</td>
          <td>0</td>
          <td>-10</td>
          <td>1</td>
          <td>0</td>
      </tr>
      <tr>
          <td>2</td>
          <td>-100</td>
          <td>-1</td>
          <td>0</td>
          <td>-10</td>
          <td>0</td>
          <td>0.5</td>
      </tr>
      <tr>
          <td>3</td>
          <td>-170</td>
          <td>-6.5</td>
          <td>11</td>
          <td>-6.5</td>
          <td>-0.5</td>
          <td>1.5</td>
      </tr>
      <tr>
          <td>4</td>
          <td>15</td>
          <td>0.7</td>
          <td>0.6</td>
          <td>0.7</td>
          <td>-1</td>
          <td>1</td>
      </tr>
  </tbody>
</table>
<p>Notice that the first three terms have negative amplitudes and negative-definite exponents (creating energy wells), while the fourth has a positive amplitude and a positive-definite exponent, so it rises away from its center and adds a repulsive wall in place of a well. The cross-term $b_k$ in the third Gaussian creates the tilted orientation that gives the surface its characteristic curved pathways.</p>
<p><a href="/muller-brown-optimized"><strong>View Interactive Müller-Brown Potential Energy Surface →</strong></a></p>
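<p>Before moving to PyTorch, the formula and parameter table above can be sanity-checked in a few lines of standard-library Python. The sketch below is an illustrative reference and not the implementation developed in this post: it evaluates the potential, writes out the analytical gradient via the chain rule, and cross-checks it against central finite differences at an arbitrarily chosen point. The function names are assumptions.</p>

```python
import math

# (A, a, b, c, x0, y0) for k = 1..4, copied from the parameter table.
PARAMS = [
    (-200.0, -1.0, 0.0, -10.0, 1.0, 0.0),
    (-100.0, -1.0, 0.0, -10.0, 0.0, 0.5),
    (-170.0, -6.5, 11.0, -6.5, -0.5, 1.5),
    (15.0, 0.7, 0.6, 0.7, -1.0, 1.0),
]

def potential(x, y):
    """Mueller-Brown energy V(x, y) as a sum of four exponential terms."""
    total = 0.0
    for A, a, b, c, x0, y0 in PARAMS:
        dx, dy = x - x0, y - y0
        total += A * math.exp(a * dx * dx + b * dx * dy + c * dy * dy)
    return total

def gradient(x, y):
    """Analytical (dV/dx, dV/dy) from the chain rule on each term."""
    gx = gy = 0.0
    for A, a, b, c, x0, y0 in PARAMS:
        dx, dy = x - x0, y - y0
        term = A * math.exp(a * dx * dx + b * dx * dy + c * dy * dy)
        gx += term * (2.0 * a * dx + b * dy)
        gy += term * (b * dx + 2.0 * c * dy)
    return gx, gy

def numerical_gradient(x, y, h=1e-6):
    """Central finite differences, used only as a cross-check."""
    gx = (potential(x + h, y) - potential(x - h, y)) / (2.0 * h)
    gy = (potential(x, y + h) - potential(x, y - h)) / (2.0 * h)
    return gx, gy

point = (0.1, 0.9)  # arbitrary test location
analytic = gradient(*point)
numeric = numerical_gradient(*point)
print(analytic, numeric)
```

<p>The force on a particle is the negative of this gradient, so the same expressions give the analytical forces used in dynamics or optimization.</p>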
<h3 id="the-resulting-landscape">The Resulting Landscape</h3>
<figure class="post-figure center ">
    <img src="/img/muller-brown/muller-brown-potential-surface.webp"
         alt="Müller-Brown Potential Energy Surface showing the three minima (dark blue regions) and two saddle points"
         title="Müller-Brown Potential Energy Surface showing the three minima (dark blue regions) and two saddle points"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The Müller-Brown potential energy surface showing the three minima (dark blue regions) and two saddle points.</figcaption>
    
</figure>

<p>This simple formula creates a surprisingly rich topography with exactly the features needed to challenge optimization algorithms:</p>
<table>
  <thead>
      <tr>
          <th><strong>Stationary Point</strong></th>
          <th><strong>Coordinates</strong></th>
          <th><strong>Energy</strong></th>
          <th><strong>Type</strong></th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MA (Reactant)</td>
          <td>(-0.558, 1.442)</td>
          <td>-146.70</td>
          <td>Deep minimum</td>
      </tr>
      <tr>
          <td>MC (Intermediate)</td>
          <td>(-0.050, 0.467)</td>
          <td>-80.77</td>
          <td>Shallow minimum</td>
      </tr>
      <tr>
          <td>MB (Product)</td>
          <td>(0.623, 0.028)</td>
          <td>-108.17</td>
          <td>Medium minimum</td>
      </tr>
      <tr>
          <td>S1</td>
          <td>(-0.822, 0.624)</td>
          <td>-40.66</td>
          <td>First saddle point</td>
      </tr>
      <tr>
          <td>S2</td>
          <td>(0.212, 0.293)</td>
          <td>-72.25</td>
          <td>Second saddle point</td>
      </tr>
  </tbody>
</table>
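<p>The table can be checked numerically. The sketch below, an illustration with assumed helper names, evaluates the energy at each listed coordinate and classifies the point from the sign of the determinant and trace of a finite-difference Hessian. Because the coordinates are rounded to three decimals, the energies agree with the table only to about the second decimal place.</p>

```python
import math

# (A, a, b, c, x0, y0) for k = 1..4, from the standard parameter table.
PARAMS = [
    (-200.0, -1.0, 0.0, -10.0, 1.0, 0.0),
    (-100.0, -1.0, 0.0, -10.0, 0.0, 0.5),
    (-170.0, -6.5, 11.0, -6.5, -0.5, 1.5),
    (15.0, 0.7, 0.6, 0.7, -1.0, 1.0),
]

def potential(x, y):
    return sum(
        A * math.exp(a * (x - x0) ** 2 + b * (x - x0) * (y - y0) + c * (y - y0) ** 2)
        for A, a, b, c, x0, y0 in PARAMS
    )

def hessian(x, y, h=1e-4):
    """Second derivatives (Vxx, Vxy, Vyy) by central finite differences."""
    v0 = potential(x, y)
    vxx = (potential(x + h, y) - 2.0 * v0 + potential(x - h, y)) / h**2
    vyy = (potential(x, y + h) - 2.0 * v0 + potential(x, y - h)) / h**2
    vxy = (
        potential(x + h, y + h) - potential(x + h, y - h)
        - potential(x - h, y + h) + potential(x - h, y - h)
    ) / (4.0 * h**2)
    return vxx, vxy, vyy

def classify(x, y):
    """A negative determinant means curvatures of mixed sign, i.e. a saddle."""
    vxx, vxy, vyy = hessian(x, y)
    det = vxx * vyy - vxy * vxy
    if det < 0.0:
        return "saddle"
    return "minimum" if vxx + vyy > 0.0 else "maximum"

POINTS = {
    "MA": (-0.558, 1.442), "MC": (-0.050, 0.467), "MB": (0.623, 0.028),
    "S1": (-0.822, 0.624), "S2": (0.212, 0.293),
}
report = {name: (potential(x, y), classify(x, y)) for name, (x, y) in POINTS.items()}
for name, (energy, kind) in report.items():
    print(f"{name}: {energy:8.2f}  {kind}")
```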
<h3 id="the-key-challenge-curved-pathways">The Key Challenge: Curved Pathways</h3>
<p>The path from the deep reactant minimum (MA) to the product minimum (MB) doesn&rsquo;t go directly over a single barrier. Instead, it follows a curved route:</p>
<ol>
<li><strong>MA → S1 → MC</strong>: First transition over the higher, rate-limiting barrier (S1) into an intermediate basin</li>
<li><strong>MC → S2 → MB</strong>: Second transition over a much lower barrier (S2) to the product</li>
</ol>
<p>This two-step pathway breaks linear interpolation methods. Algorithms that draw a straight line from reactant to product miss both the intermediate minimum and the correct transition states, climbing over much higher energy regions instead.</p>
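<p>The cost of the straight-line shortcut can be estimated directly. The sketch below samples the potential along the segment from MA to MB and compares the highest energy encountered with the S1 saddle energy from the table. The number of sample points and the helper names are assumptions, and the coordinates are the rounded table values.</p>

```python
import math

# (A, a, b, c, x0, y0) for k = 1..4, from the standard parameter table.
PARAMS = [
    (-200.0, -1.0, 0.0, -10.0, 1.0, 0.0),
    (-100.0, -1.0, 0.0, -10.0, 0.0, 0.5),
    (-170.0, -6.5, 11.0, -6.5, -0.5, 1.5),
    (15.0, 0.7, 0.6, 0.7, -1.0, 1.0),
]

def potential(x, y):
    return sum(
        A * math.exp(a * (x - x0) ** 2 + b * (x - x0) * (y - y0) + c * (y - y0) ** 2)
        for A, a, b, c, x0, y0 in PARAMS
    )

MA = (-0.558, 1.442)   # reactant minimum
MB = (0.623, 0.028)    # product minimum
S1_ENERGY = -40.66     # higher saddle on the minimum energy path (table value)

def straight_line_profile(start, end, n_points=201):
    """Energies at evenly spaced points on the segment from start to end."""
    profile = []
    for i in range(n_points):
        t = i / (n_points - 1)
        x = start[0] + t * (end[0] - start[0])
        y = start[1] + t * (end[1] - start[1])
        profile.append(potential(x, y))
    return profile

profile = straight_line_profile(MA, MB)
naive_summit = max(profile)
excess = naive_summit - S1_ENERGY
print(f"naive summit = {naive_summit:.1f}, excess over S1 = {excess:.1f}")
```

<p>The straight line tops out at a positive energy, well above the S1 saddle that the curved path crosses, which is why interpolation-based guesses need to be relaxed toward the minimum energy path.</p>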
<figure class="post-figure center ">
    <img src="/img/muller-brown/naive-versus-minimum-path.webp"
         alt="Two-panel comparison showing naive linear interpolation versus minimum energy path. Left panel shows the contour map with both paths overlaid. Right panel shows the energy profile along each path, revealing the naive path hits a far higher barrier."
         title="Two-panel comparison showing naive linear interpolation versus minimum energy path. Left panel shows the contour map with both paths overlaid. Right panel shows the energy profile along each path, revealing the naive path hits a far higher barrier."
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Why naive optimization fails: The left panel shows a straight-line path (red dashed) versus the true minimum energy path (green solid) on the potential surface. The right panel reveals the energetic cost. The naive path climbs a barrier roughly 53 reduced units higher that the curved path avoids entirely. This is the &lsquo;adversarial&rsquo; nature of the Müller-Brown surface.</figcaption>
    
</figure>

<p>The energy profile comparison makes the failure mode concrete. A naive optimizer following the red dashed path would encounter a barrier roughly <strong>53 reduced units higher</strong> than necessary (the naive summit sits at about +13 while the true path tops out near -41, the S1 saddle). The green minimum energy path navigates through the valleys, passing through the intermediate basin MC and crossing only the low-lying saddle points S1 and S2.</p>
<h2 id="why-it-works-as-a-benchmark">Why It Works as a Benchmark</h2>
<p>The Müller-Brown potential has served as a computational chemistry benchmark for over four decades because of four key characteristics:</p>
<p><strong>Low dimensionality</strong>: As a 2D surface, you can visualize the entire landscape and see exactly why algorithms succeed or fail.</p>
<p><strong>Analytical form</strong>: Energy and gradient calculations cost virtually nothing, enabling exhaustive testing impossible with quantum mechanical surfaces.</p>
<p><strong>Non-trivial topology</strong>: The curved minimum energy path and shallow intermediate minimum challenge sophisticated methods while remaining manageable.</p>
<p><strong>Known ground truth</strong>: All minima and saddle points are precisely known, providing unambiguous success metrics.</p>















<figure class="post-figure center ">
    <img src="/img/muller-brown/muller-brown-basins-of-attraction.webp"
         alt="Basins of attraction map showing which regions of the Müller-Brown surface lead to each minimum under gradient descent. Blue region flows to MA, green to MC, and yellow to MB."
         title="Basins of attraction map showing which regions of the Müller-Brown surface lead to each minimum under gradient descent. Blue region flows to MA, green to MC, and yellow to MB."
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Basins of attraction: the &lsquo;optimization map&rsquo; of the Müller-Brown surface. Each color indicates which minimum a gradient descent optimizer will reach from that starting point. Saddle points (red x) sit precisely at the basin boundaries, the unstable equilibria that algorithms must navigate to find reaction paths.</figcaption>
    
</figure>

<p>This basin map reveals why the Müller-Brown potential is such an effective benchmark. Standard gradient descent from any point in the blue region inevitably falls into the deep MA minimum; from green, into the intermediate MC; from yellow, into MB. The saddle points S1 and S2 lie exactly on the boundaries between basins: an infinitesimal perturbation at one of these points sends an optimizer tumbling into one valley or the other. Finding these saddle points requires algorithms that can locate and stabilize at these boundary regions.</p>
<p>What makes this particularly valuable is the contrast with other classic potentials. While the Lennard-Jones potential<sup id="fnref:2"><a href="#fn:2" class="footnote-ref" role="doc-noteref">2</a></sup> serves as the benchmark for equilibrium properties with its single energy minimum, Müller-Brown explicitly models reactive landscapes. Its multiple minima and connecting barriers make it the testing ground for algorithms that find reaction paths, the methods that reveal how chemistry actually happens.</p>
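<p>The basin map described above can be reproduced point by point with plain gradient descent. The sketch below is a hedged illustration, not the code behind the figure: the step size, iteration cap, and the nearest-minimum labelling rule are my assumptions, and the minima coordinates are the commonly quoted literature values.</p>

```python
import math

# Standard Müller-Brown parameters: (A, a, b, c, x0, y0) per term.
PARAMS = (
    (-200.0, -1.0, 0.0, -10.0, 1.0, 0.0),
    (-100.0, -1.0, 0.0, -10.0, 0.0, 0.5),
    (-170.0, -6.5, 11.0, -6.5, -0.5, 1.5),
    (15.0, 0.7, 0.6, 0.7, -1.0, 1.0),
)

# Assumed inputs: literature coordinates of the three minima.
MINIMA = {"MA": (-0.558, 1.442), "MB": (0.623, 0.028), "MC": (-0.050, 0.467)}


def gradient(x, y):
    """Analytical gradient of the Müller-Brown energy."""
    gx = gy = 0.0
    for A, a, b, c, x0, y0 in PARAMS:
        dx, dy = x - x0, y - y0
        w = A * math.exp(a * dx * dx + b * dx * dy + c * dy * dy)
        gx += w * (2 * a * dx + b * dy)
        gy += w * (b * dx + 2 * c * dy)
    return gx, gy


def descend(x, y, lr=1e-4, steps=20000, tol=1e-12):
    """Fixed-step gradient descent; lr is kept small because the wells are stiff."""
    for _ in range(steps):
        gx, gy = gradient(x, y)
        if gx * gx + gy * gy < tol:
            break
        x -= lr * gx
        y -= lr * gy
    return x, y


def basin_label(x, y):
    """Name of the minimum closest to where descent ends up."""
    end = descend(x, y)
    return min(MINIMA, key=lambda name: math.dist(end, MINIMA[name]))


for start in [(-0.6, 1.4), (0.6, 0.1), (-0.05, 0.45)]:
    print(start, "->", basin_label(*start))
```

<p>Evaluating <code>basin_label</code> on a dense grid of starting points and coloring by the result would give a map of the same kind as the figure, at the cost of one descent per grid point.</p>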
<h3 id="applications-across-decades">Applications Across Decades</h3>
<p>The potential has evolved with the field&rsquo;s changing focus:</p>
<p><strong>1980s-1990s</strong>: Testing path-finding methods like Nudged Elastic Band (NEB)<sup id="fnref:3"><a href="#fn:3" class="footnote-ref" role="doc-noteref">3</a></sup>, which creates discrete representations of reaction pathways and optimizes them to find minimum energy paths.</p>
<p><strong>2000s-2010s</strong>: Validating Transition Path Sampling (TPS) methods<sup id="fnref:4"><a href="#fn:4" class="footnote-ref" role="doc-noteref">4</a></sup> that harvest statistical ensembles of reactive trajectories.</p>
<p><strong>2020s</strong>: Benchmarking machine learning models and generative approaches that learn to sample transition paths or approximate potential energy surfaces.</p>
<h2 id="modern-applications-in-machine-learning">Modern Applications in Machine Learning</h2>
<p>The rise of machine learning has given the Müller-Brown potential renewed purpose. Modern <strong>Machine Learning Interatomic Potentials (MLIPs)</strong><sup id="fnref:5"><a href="#fn:5" class="footnote-ref" role="doc-noteref">5</a></sup><sup id="fnref:6"><a href="#fn:6" class="footnote-ref" role="doc-noteref">6</a></sup> aim to bridge the gap between quantum mechanical accuracy and classical force field efficiency by training flexible models on expensive quantum chemistry data.</p>
<p>This creates a benchmarking challenge: with countless ML architectures available, how do you objectively compare them? The Müller-Brown potential offers a clean answer: an exactly known potential energy surface that can generate unlimited, noise-free training data.</p>
<p>This enables researchers to ask fundamental questions:</p>
<ul>
<li>How well does a given architecture learn complex, curved surfaces?</li>
<li>How many training points are needed for acceptable accuracy?</li>
<li>How does the model behave when extrapolating beyond training data?</li>
<li>Can it correctly identify minima and saddle points?</li>
</ul>
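<p>As a rough sketch of how such a benchmark might be set up, the snippet below draws noise-free training samples and scores any candidate model at the known stationary points. The sampling box, the function names, and the idea of passing the model as a plain callable are assumptions made for illustration; the stationary-point coordinates and energies are the commonly quoted literature values.</p>

```python
import math
import random

# Standard Müller-Brown parameters: (A, a, b, c, x0, y0) per term.
PARAMS = (
    (-200.0, -1.0, 0.0, -10.0, 1.0, 0.0),
    (-100.0, -1.0, 0.0, -10.0, 0.0, 0.5),
    (-170.0, -6.5, 11.0, -6.5, -0.5, 1.5),
    (15.0, 0.7, 0.6, 0.7, -1.0, 1.0),
)

# Assumed reference values: (coordinates, energy) of minima and saddle points.
STATIONARY = {
    "MA": ((-0.558, 1.442), -146.70),
    "MB": ((0.623, 0.028), -108.17),
    "MC": ((-0.050, 0.467), -80.77),
    "S1": ((-0.822, 0.624), -40.66),
    "S2": ((0.212, 0.293), -72.25),
}


def energy(x, y):
    """Exact Müller-Brown energy, the ground truth for the benchmark."""
    return sum(
        A * math.exp(a * (x - x0) ** 2 + b * (x - x0) * (y - y0) + c * (y - y0) ** 2)
        for A, a, b, c, x0, y0 in PARAMS
    )


def make_dataset(n, seed=0, x_range=(-1.5, 1.0), y_range=(-0.5, 2.0)):
    """n uniformly sampled points with exact (noise-free) energy labels."""
    rng = random.Random(seed)
    points = [(rng.uniform(*x_range), rng.uniform(*y_range)) for _ in range(n)]
    return [((x, y), energy(x, y)) for x, y in points]


def stationary_point_errors(model):
    """Absolute energy error of model(x, y) at each known stationary point."""
    return {
        name: abs(model(*xy) - e_ref) for name, (xy, e_ref) in STATIONARY.items()
    }


train = make_dataset(1000)
# Sanity check: the exact potential scored against itself should show ~zero error.
print(stationary_point_errors(energy))
```

<p>Swapping <code>energy</code> for a trained surrogate in the last line turns the same dictionary into a direct answer to the minima and saddle-point question.</p>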















<figure class="post-figure center ">
    <img src="/img/muller-brown/muller-brown-ml-benchmark.webp"
         alt="Three-panel comparison showing the analytical Müller-Brown surface (left), a neural network at epoch 50 with high error (middle), and a converged neural network at epoch 1000 (right)"
         title="Three-panel comparison showing the analytical Müller-Brown surface (left), a neural network at epoch 50 with high error (middle), and a converged neural network at epoch 1000 (right)"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Visualizing the ML benchmark: A comparison of the analytical ground truth (left) versus a neural network potential at early (middle) and late (right) stages of training. Notice how the model in early training quickly grasps the deep minima regions but struggles significantly with the complex topography of the saddle points and energy barriers: the curved pathways are smoothed into simpler shapes. This illustrates why the Müller-Brown surface remains a challenging test case for modern architectures.</figcaption>
    
</figure>

<p>The potential has evolved from a simple model system into a <strong>reference benchmark</strong>: a fixed, exactly-known surface against which AI learning capacity is measured. Any prediction error is due to model limitations, not data quality.</p>
<p>Beyond static benchmarking, a PyTorch implementation enables <strong>differentiable simulation</strong>. Because the potential, forces, and integrator are all differentiable tensor operations, gradients can be backpropagated <em>through time</em> (via the trajectory) to optimize force field parameters or control policies directly. This capability connects classical molecular simulation to modern gradient-based machine learning in a single computational graph.</p>
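<p>A minimal sketch of that idea, under several assumptions of my own: unit mass, a plain velocity-Verlet integrator without a thermostat, and a hypothetical objective that asks which initial velocity steers a particle from MA toward MC. The point is only that the gradient of the final position with respect to the initial velocity flows back through every integration step.</p>

```python
import torch

DTYPE = torch.float64
A = torch.tensor([-200.0, -100.0, -170.0, 15.0], dtype=DTYPE)
a = torch.tensor([-1.0, -1.0, -6.5, 0.7], dtype=DTYPE)
b = torch.tensor([0.0, 0.0, 11.0, 0.6], dtype=DTYPE)
c = torch.tensor([-10.0, -10.0, -6.5, 0.7], dtype=DTYPE)
X0 = torch.tensor([1.0, 0.0, -0.5, -1.0], dtype=DTYPE)
Y0 = torch.tensor([0.0, 0.5, 1.5, 1.0], dtype=DTYPE)


def force(xy: torch.Tensor) -> torch.Tensor:
    """Analytical force built from ordinary tensor ops, so it is differentiable."""
    dx = xy[..., 0:1] - X0
    dy = xy[..., 1:2] - Y0
    w = A * torch.exp(a * dx**2 + b * dx * dy + c * dy**2)
    fx = -(w * (2 * a * dx + b * dy)).sum(-1)
    fy = -(w * (b * dx + 2 * c * dy)).sum(-1)
    return torch.stack([fx, fy], dim=-1)


def rollout(x0: torch.Tensor, v0: torch.Tensor, dt: float = 0.005, steps: int = 50):
    """Velocity-Verlet trajectory (unit mass); every step stays on the autograd graph."""
    x, v = x0, v0
    f = force(x)
    for _ in range(steps):
        v = v + 0.5 * dt * f
        x = x + dt * v
        f = force(x)
        v = v + 0.5 * dt * f
    return x


x_start = torch.tensor([-0.558, 1.442], dtype=DTYPE)   # MA (literature value)
target = torch.tensor([-0.050, 0.467], dtype=DTYPE)    # MC, hypothetical target
v0 = torch.tensor([1.0, -1.0], dtype=DTYPE, requires_grad=True)

x_final = rollout(x_start, v0)
loss = ((x_final - target) ** 2).sum()
loss.backward()  # backpropagation through time, via the whole trajectory
print("d(loss)/d(v0) =", v0.grad)
```

<p>Feeding <code>v0.grad</code> to any optimizer would close the loop; the same pattern applies if the differentiated quantity is a force field parameter instead of an initial condition.</p>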
<p>These benchmarking principles extend beyond abstract test cases. Real molecular dynamics simulations, such as those studying <a href="/posts/adatom-cu-diffusion/">adatom diffusion on metal surfaces</a>, face similar challenges in understanding energy landscapes and transition pathways. The Müller-Brown potential provides a controlled environment for developing methods that eventually tackle these complex realistic systems.</p>
<h2 id="extension-to-higher-dimensions">Extension to Higher Dimensions</h2>
<p>The canonical Müller-Brown potential can be extended beyond two dimensions to create more challenging test cases that better reflect real molecular systems. This extensibility demonstrates why it remains such an effective template for computational method development.</p>
<h3 id="why-higher-dimensions-matter">Why Higher Dimensions Matter</h3>
<p>Real molecules have dozens or hundreds of degrees of freedom. Understanding how algorithms scale with dimensionality is crucial for practical applications. Higher-dimensional extensions allow researchers to systematically test:</p>
<ul>
<li><strong>Algorithm scaling</strong> - Does performance degrade gracefully as dimensions increase?</li>
<li><strong>Model robustness</strong> - Do machine learning approaches maintain accuracy in high-dimensional spaces?</li>
<li><strong>Parallel efficiency</strong> - Can massively parallel methods exploit additional dimensions effectively?</li>
</ul>
<h3 id="extension-approaches">Extension Approaches</h3>
<p><strong>Harmonic constraints</strong>: Add quadratic wells in orthogonal dimensions while preserving the complex 2D landscape<sup id="fnref:7"><a href="#fn:7" class="footnote-ref" role="doc-noteref">7</a></sup>:</p>
<p>$$V_{5D}(x_1, x_2, x_3, x_4, x_5) = V(x_1, x_3) + \kappa(x_2^2 + x_4^2 + x_5^2)$$</p>
<p>The parameter $\kappa$ controls constraint strength: small values create nearly flat directions that test algorithmic efficiency.</p>
<p><strong>Collective variables</strong>: Define new coordinates that mix multiple dimensions<sup id="fnref:8"><a href="#fn:8" class="footnote-ref" role="doc-noteref">8</a></sup>:</p>
<p>$$\tilde{x} = \sqrt{x_1^2 + x_2^2 + \epsilon x_5^2},\quad \tilde{y} = \sqrt{x_3^2 + x_4^2}$$</p>
<p>where $\epsilon \ll 1$. The 5D potential becomes $V_{5D}(x_1, \dots, x_5) = V(\tilde{x}, \tilde{y})$, embedding the original surface in a higher-dimensional space.</p>
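<p>Both constructions are short enough to sketch directly. The code below follows the two formulas as written; the default values of <code>kappa</code> and <code>eps</code> and the function names are placeholders I chose, not values taken from the cited work.</p>

```python
import math

# Standard Müller-Brown parameters: (A, a, b, c, x0, y0) per term.
PARAMS = (
    (-200.0, -1.0, 0.0, -10.0, 1.0, 0.0),
    (-100.0, -1.0, 0.0, -10.0, 0.0, 0.5),
    (-170.0, -6.5, 11.0, -6.5, -0.5, 1.5),
    (15.0, 0.7, 0.6, 0.7, -1.0, 1.0),
)


def energy_2d(x, y):
    """Canonical 2D Müller-Brown energy V(x, y)."""
    return sum(
        A * math.exp(a * (x - x0) ** 2 + b * (x - x0) * (y - y0) + c * (y - y0) ** 2)
        for A, a, b, c, x0, y0 in PARAMS
    )


def harmonic_5d(p, kappa=1.0):
    """V(x1, x3) plus quadratic wells in the orthogonal directions x2, x4, x5."""
    x1, x2, x3, x4, x5 = p
    return energy_2d(x1, x3) + kappa * (x2**2 + x4**2 + x5**2)


def collective_5d(p, eps=1e-3):
    """V evaluated on collective variables that mix the five coordinates."""
    x1, x2, x3, x4, x5 = p
    x_tilde = math.sqrt(x1**2 + x2**2 + eps * x5**2)
    y_tilde = math.sqrt(x3**2 + x4**2)
    return energy_2d(x_tilde, y_tilde)


point = (0.623, 0.0, 0.028, 0.0, 0.0)  # MB embedded in the active subspace
print(harmonic_5d(point), collective_5d(point), energy_2d(0.623, 0.028))
```

<p>With the extra coordinates set to zero, both extensions reduce to the 2D energy at the same point, which is the ground-truth preservation property that makes them useful as test cases.</p>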
<h3 id="value-for-algorithm-development">Value for Algorithm Development</h3>
<p>This extensibility makes the Müller-Brown potential ideal for systematic testing:</p>
<ul>
<li><strong>Progressive complexity</strong>: Debug on 2D, then scale to higher dimensions</li>
<li><strong>Ground truth preservation</strong>: Known minima and saddle points remain in the active subspace</li>
<li><strong>Realistic challenges</strong>: Captures the &ldquo;needle in a haystack&rdquo; problem of transition state finding while maintaining analytical tractability</li>
</ul>
<p>These extensions transform a simple 2D benchmark into a scalable testbed for modern computational methods, probing specific challenges in high-dimensional optimization.</p>
<h2 id="implementation-in-pytorch">Implementation in PyTorch</h2>
<p>Now let&rsquo;s implement the Müller-Brown potential in PyTorch. A practical implementation needs to handle batch processing, support both analytical and automatic differentiation, and be optimized for performance.</p>
<h3 id="core-implementation">Core Implementation</h3>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> torch
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> torch.nn <span style="color:#66d9ef">as</span> nn
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> torch <span style="color:#f92672">import</span> Tensor
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">class</span> <span style="color:#a6e22e">MuellerBrownPotential</span>(nn<span style="color:#f92672">.</span>Module):
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;Müller-Brown potential with a torch.compile-accelerated force kernel.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">__init__</span>(
</span></span><span style="display:flex;"><span>        self,
</span></span><span style="display:flex;"><span>        device: str <span style="color:#f92672">|</span> torch<span style="color:#f92672">.</span>device <span style="color:#f92672">=</span> <span style="color:#e6db74">&#34;cpu&#34;</span>,
</span></span><span style="display:flex;"><span>        dtype: torch<span style="color:#f92672">.</span>dtype <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>float64,
</span></span><span style="display:flex;"><span>        use_autograd: bool <span style="color:#f92672">=</span> <span style="color:#66d9ef">False</span>
</span></span><span style="display:flex;"><span>    ):
</span></span><span style="display:flex;"><span>        super()<span style="color:#f92672">.</span><span style="color:#a6e22e">__init__</span>()
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>use_autograd <span style="color:#f92672">=</span> use_autograd
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Standard Müller-Brown parameters</span>
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>register_buffer(
</span></span><span style="display:flex;"><span>            <span style="color:#e6db74">&#34;A&#34;</span>, torch<span style="color:#f92672">.</span>tensor([<span style="color:#f92672">-</span><span style="color:#ae81ff">200.0</span>, <span style="color:#f92672">-</span><span style="color:#ae81ff">100.0</span>, <span style="color:#f92672">-</span><span style="color:#ae81ff">170.0</span>, <span style="color:#ae81ff">15.0</span>],
</span></span><span style="display:flex;"><span>                             device<span style="color:#f92672">=</span>device, dtype<span style="color:#f92672">=</span>dtype)
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>register_buffer(
</span></span><span style="display:flex;"><span>            <span style="color:#e6db74">&#34;a&#34;</span>, torch<span style="color:#f92672">.</span>tensor([<span style="color:#f92672">-</span><span style="color:#ae81ff">1.0</span>, <span style="color:#f92672">-</span><span style="color:#ae81ff">1.0</span>, <span style="color:#f92672">-</span><span style="color:#ae81ff">6.5</span>, <span style="color:#ae81ff">0.7</span>],
</span></span><span style="display:flex;"><span>                             device<span style="color:#f92672">=</span>device, dtype<span style="color:#f92672">=</span>dtype)
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>register_buffer(
</span></span><span style="display:flex;"><span>            <span style="color:#e6db74">&#34;b&#34;</span>, torch<span style="color:#f92672">.</span>tensor([<span style="color:#ae81ff">0.0</span>, <span style="color:#ae81ff">0.0</span>, <span style="color:#ae81ff">11.0</span>, <span style="color:#ae81ff">0.6</span>],
</span></span><span style="display:flex;"><span>                             device<span style="color:#f92672">=</span>device, dtype<span style="color:#f92672">=</span>dtype)
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>register_buffer(
</span></span><span style="display:flex;"><span>            <span style="color:#e6db74">&#34;c&#34;</span>, torch<span style="color:#f92672">.</span>tensor([<span style="color:#f92672">-</span><span style="color:#ae81ff">10.0</span>, <span style="color:#f92672">-</span><span style="color:#ae81ff">10.0</span>, <span style="color:#f92672">-</span><span style="color:#ae81ff">6.5</span>, <span style="color:#ae81ff">0.7</span>],
</span></span><span style="display:flex;"><span>                             device<span style="color:#f92672">=</span>device, dtype<span style="color:#f92672">=</span>dtype)
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>register_buffer(
</span></span><span style="display:flex;"><span>            <span style="color:#e6db74">&#34;x_centers&#34;</span>, torch<span style="color:#f92672">.</span>tensor([<span style="color:#ae81ff">1.0</span>, <span style="color:#ae81ff">0.0</span>, <span style="color:#f92672">-</span><span style="color:#ae81ff">0.5</span>, <span style="color:#f92672">-</span><span style="color:#ae81ff">1.0</span>],
</span></span><span style="display:flex;"><span>                                    device<span style="color:#f92672">=</span>device, dtype<span style="color:#f92672">=</span>dtype)
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>register_buffer(
</span></span><span style="display:flex;"><span>            <span style="color:#e6db74">&#34;y_centers&#34;</span>, torch<span style="color:#f92672">.</span>tensor([<span style="color:#ae81ff">0.0</span>, <span style="color:#ae81ff">0.5</span>, <span style="color:#ae81ff">1.5</span>, <span style="color:#ae81ff">1.0</span>],
</span></span><span style="display:flex;"><span>                                    device<span style="color:#f92672">=</span>device, dtype<span style="color:#f92672">=</span>dtype)
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">forward</span>(self, coordinates: Tensor) <span style="color:#f92672">-&gt;</span> Tensor:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Compute potential energy.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> _calculate_potential(
</span></span><span style="display:flex;"><span>            coordinates, self<span style="color:#f92672">.</span>A, self<span style="color:#f92672">.</span>a, self<span style="color:#f92672">.</span>b, self<span style="color:#f92672">.</span>c,
</span></span><span style="display:flex;"><span>            self<span style="color:#f92672">.</span>x_centers, self<span style="color:#f92672">.</span>y_centers
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">force</span>(self, coordinates: Tensor) <span style="color:#f92672">-&gt;</span> Tensor:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Compute forces (negative gradient).&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">if</span> self<span style="color:#f92672">.</span>use_autograd:
</span></span><span style="display:flex;"><span>            coordinates <span style="color:#f92672">=</span> coordinates<span style="color:#f92672">.</span>requires_grad_(<span style="color:#66d9ef">True</span>)
</span></span><span style="display:flex;"><span>            potential <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>forward(coordinates)
</span></span><span style="display:flex;"><span>            grad <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>autograd<span style="color:#f92672">.</span>grad(potential<span style="color:#f92672">.</span>sum(), coordinates)[<span style="color:#ae81ff">0</span>]
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">return</span> <span style="color:#f92672">-</span>grad
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">else</span>:
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">return</span> _calculate_force(
</span></span><span style="display:flex;"><span>                coordinates, self<span style="color:#f92672">.</span>A, self<span style="color:#f92672">.</span>a, self<span style="color:#f92672">.</span>b, self<span style="color:#f92672">.</span>c,
</span></span><span style="display:flex;"><span>                self<span style="color:#f92672">.</span>x_centers, self<span style="color:#f92672">.</span>y_centers
</span></span><span style="display:flex;"><span>            )
</span></span></code></pre></div><p>The implementation stores its parameters with <code>register_buffer</code>, a subtle but important PyTorch best practice. Buffers move with the module when you call <code>.to(device)</code>, which avoids the device-mismatch errors that plain tensor attributes commonly cause. Buffers are also part of the module state that <strong>DistributedDataParallel (DDP)</strong> broadcasts across ranks, so the parameters stay consistent when scaling to multi-GPU clusters. The <code>use_autograd</code> flag switches between analytical and automatic differentiation.</p>
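<p>The difference between a buffer and a plain tensor attribute is easy to see in isolation. This small sketch uses a dtype conversion as a stand-in for a device move so that it runs without a GPU; the two toy classes are hypothetical and exist only for the comparison.</p>

```python
import torch
import torch.nn as nn

VALUES = [-200.0, -100.0, -170.0, 15.0]


class WithBuffer(nn.Module):
    """Stores the amplitudes as a registered buffer."""

    def __init__(self):
        super().__init__()
        self.register_buffer("A", torch.tensor(VALUES, dtype=torch.float64))


class WithPlainAttribute(nn.Module):
    """Stores the same tensor as an ordinary Python attribute."""

    def __init__(self):
        super().__init__()
        self.A = torch.tensor(VALUES, dtype=torch.float64)


buffered = WithBuffer().to(torch.float32)
plain = WithPlainAttribute().to(torch.float32)

print(buffered.A.dtype)                  # torch.float32: the buffer followed .to()
print(plain.A.dtype)                     # torch.float64: the attribute was left behind
print("A" in buffered.state_dict())      # True: saved and restored with the module
print("A" in plain.state_dict())         # False
print(len(list(buffered.parameters())))  # 0: buffers are not trainable parameters
```

<p>The same mechanism applies to <code>.to("cuda")</code>: only registered parameters and buffers are moved, so a plain attribute would stay on the CPU.</p>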
<h3 id="compiling-the-force-kernel">Compiling the Force Kernel</h3>
<p>Two decisions make the force path fast while keeping the rest flexible:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">_calculate_potential</span>(coordinates: Tensor, A: Tensor, a: Tensor,
</span></span><span style="display:flex;"><span>                        b: Tensor, c: Tensor, x_centers: Tensor,
</span></span><span style="display:flex;"><span>                        y_centers: Tensor) <span style="color:#f92672">-&gt;</span> Tensor:
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;Energy. Left eager (uncompiled) so autograd second derivatives,
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    e.g. the Hessian, keep working: torch.compile does not support
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    double-backward, and energy is computed per save, not per step.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>    coords <span style="color:#f92672">=</span> coordinates<span style="color:#f92672">.</span>view(<span style="color:#f92672">-</span><span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">2</span>)
</span></span><span style="display:flex;"><span>    x, y <span style="color:#f92672">=</span> coords[:, <span style="color:#ae81ff">0</span>], coords[:, <span style="color:#ae81ff">1</span>]
</span></span><span style="display:flex;"><span>    dx <span style="color:#f92672">=</span> x<span style="color:#f92672">.</span>unsqueeze(<span style="color:#f92672">-</span><span style="color:#ae81ff">1</span>) <span style="color:#f92672">-</span> x_centers
</span></span><span style="display:flex;"><span>    dy <span style="color:#f92672">=</span> y<span style="color:#f92672">.</span>unsqueeze(<span style="color:#f92672">-</span><span style="color:#ae81ff">1</span>) <span style="color:#f92672">-</span> y_centers
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    potential <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>sum(
</span></span><span style="display:flex;"><span>        A <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>exp(a <span style="color:#f92672">*</span> dx<span style="color:#f92672">**</span><span style="color:#ae81ff">2</span> <span style="color:#f92672">+</span> b <span style="color:#f92672">*</span> dx <span style="color:#f92672">*</span> dy <span style="color:#f92672">+</span> c <span style="color:#f92672">*</span> dy<span style="color:#f92672">**</span><span style="color:#ae81ff">2</span>),
</span></span><span style="display:flex;"><span>        dim<span style="color:#f92672">=-</span><span style="color:#ae81ff">1</span>
</span></span><span style="display:flex;"><span>    )
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> potential<span style="color:#f92672">.</span>view(coordinates<span style="color:#f92672">.</span>shape[:<span style="color:#f92672">-</span><span style="color:#ae81ff">1</span>])
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#a6e22e">@torch.compile</span>(dynamic<span style="color:#f92672">=</span><span style="color:#66d9ef">True</span>)
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">_calculate_force</span>(coordinates: Tensor, A: Tensor, a: Tensor,
</span></span><span style="display:flex;"><span>                    b: Tensor, c: Tensor, x_centers: Tensor,
</span></span><span style="display:flex;"><span>                    y_centers: Tensor) <span style="color:#f92672">-&gt;</span> Tensor:
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;Forces (negative gradient). Compiled: this is the hot path,
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    called once per simulation step.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>    coords <span style="color:#f92672">=</span> coordinates<span style="color:#f92672">.</span>view(<span style="color:#f92672">-</span><span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">2</span>)
</span></span><span style="display:flex;"><span>    x, y <span style="color:#f92672">=</span> coords[:, <span style="color:#ae81ff">0</span>], coords[:, <span style="color:#ae81ff">1</span>]
</span></span><span style="display:flex;"><span>    dx <span style="color:#f92672">=</span> x<span style="color:#f92672">.</span>unsqueeze(<span style="color:#f92672">-</span><span style="color:#ae81ff">1</span>) <span style="color:#f92672">-</span> x_centers
</span></span><span style="display:flex;"><span>    dy <span style="color:#f92672">=</span> y<span style="color:#f92672">.</span>unsqueeze(<span style="color:#f92672">-</span><span style="color:#ae81ff">1</span>) <span style="color:#f92672">-</span> y_centers
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    exp_terms <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>exp(a <span style="color:#f92672">*</span> dx<span style="color:#f92672">**</span><span style="color:#ae81ff">2</span> <span style="color:#f92672">+</span> b <span style="color:#f92672">*</span> dx <span style="color:#f92672">*</span> dy <span style="color:#f92672">+</span> c <span style="color:#f92672">*</span> dy<span style="color:#f92672">**</span><span style="color:#ae81ff">2</span>)
</span></span><span style="display:flex;"><span>    A_exp <span style="color:#f92672">=</span> A <span style="color:#f92672">*</span> exp_terms
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    grad_x <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>sum(A_exp <span style="color:#f92672">*</span> (<span style="color:#ae81ff">2</span> <span style="color:#f92672">*</span> a <span style="color:#f92672">*</span> dx <span style="color:#f92672">+</span> b <span style="color:#f92672">*</span> dy), dim<span style="color:#f92672">=-</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>    grad_y <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>sum(A_exp <span style="color:#f92672">*</span> (b <span style="color:#f92672">*</span> dx <span style="color:#f92672">+</span> <span style="color:#ae81ff">2</span> <span style="color:#f92672">*</span> c <span style="color:#f92672">*</span> dy), dim<span style="color:#f92672">=-</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    forces <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>stack([<span style="color:#f92672">-</span>grad_x, <span style="color:#f92672">-</span>grad_y], dim<span style="color:#f92672">=-</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>    new_shape <span style="color:#f92672">=</span> list(coordinates<span style="color:#f92672">.</span>shape[:<span style="color:#f92672">-</span><span style="color:#ae81ff">1</span>]) <span style="color:#f92672">+</span> [<span style="color:#ae81ff">2</span>]
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> forces<span style="color:#f92672">.</span>view(new_shape)
</span></span></code></pre></div><p>The force kernel is decorated with <code>@torch.compile(dynamic=True)</code>, which traces it and hands it to TorchInductor. Inductor fuses the pointwise operations (the exponentials and polynomial terms) and cuts Python&rsquo;s dispatch overhead in the inner loop that runs millions of times per simulation. The <code>dynamic=True</code> flag asks for shape-generic code, so a single compiled trace stays valid across particle counts and changing the batch size should not normally trigger a recompile. The energy is left eager on purpose: <code>torch.compile</code> does not support double-backward, so leaving <code>forward</code> uncompiled keeps autograd second derivatives (the Hessian) available. Since the energy is only evaluated when an observable is saved, it is not on the hot path anyway.</p>
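<p>As an example of what the eager energy path buys, the sketch below computes the Hessian with <code>torch.autograd.functional.hessian</code> and uses its eigenvalues to classify a stationary point. It re-declares a single-point energy function so that it runs on its own; the coordinates of MA and S1 are literature values, and the helper name is hypothetical.</p>

```python
import torch

DTYPE = torch.float64
A = torch.tensor([-200.0, -100.0, -170.0, 15.0], dtype=DTYPE)
a = torch.tensor([-1.0, -1.0, -6.5, 0.7], dtype=DTYPE)
b = torch.tensor([0.0, 0.0, 11.0, 0.6], dtype=DTYPE)
c = torch.tensor([-10.0, -10.0, -6.5, 0.7], dtype=DTYPE)
X0 = torch.tensor([1.0, 0.0, -0.5, -1.0], dtype=DTYPE)
Y0 = torch.tensor([0.0, 0.5, 1.5, 1.0], dtype=DTYPE)


def energy(xy: torch.Tensor) -> torch.Tensor:
    """Eager (uncompiled) energy of a single point of shape (2,)."""
    dx = xy[0] - X0
    dy = xy[1] - Y0
    return (A * torch.exp(a * dx**2 + b * dx * dy + c * dy**2)).sum()


def hessian_eigenvalues(point) -> torch.Tensor:
    """Eigenvalues (ascending) of the 2x2 Hessian at the given point."""
    xy = torch.tensor(point, dtype=DTYPE)
    hess = torch.autograd.functional.hessian(energy, xy)
    return torch.linalg.eigvalsh(hess)


eig_minimum = hessian_eigenvalues((-0.558, 1.442))  # MA
eig_saddle = hessian_eigenvalues((-0.822, 0.624))   # S1

print("MA:", eig_minimum)  # both eigenvalues positive at a minimum
print("S1:", eig_saddle)   # one negative, one positive at a first-order saddle
```

<p>A single negative eigenvalue is the signature of a transition state, which is why second derivatives are worth keeping available even though they are never needed inside the integration loop.</p>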
<h3 id="performance-analytical-vs-automatic-differentiation">Performance: Analytical vs. Automatic Differentiation</h3>
<p>A key design decision is whether to use analytical derivatives or automatic differentiation. I benchmarked both on an Apple M1 Max (CPU, PyTorch 2.x). To get a stable measurement, each configuration runs 100 warm-up iterations, and I then report the median wall-clock time over 5 runs of 1000 iterations, which filters out operating-system jitter.</p>
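<p>The measurement protocol can be sketched with the standard library alone. This is an illustrative harness following the warm-up and median scheme described above, not the exact script used for the figures; <code>median_time_per_call</code> and the toy workload are placeholders.</p>

```python
import statistics
import time


def median_time_per_call(fn, warmup=100, runs=5, iters=1000):
    """Median over `runs` of the mean wall-clock time per call of fn()."""
    # Warm-up calls absorb one-time costs such as compilation and cache fills.
    for _ in range(warmup):
        fn()

    samples = []
    for _ in range(runs):
        start = time.perf_counter()
        for _ in range(iters):
            fn()
        samples.append((time.perf_counter() - start) / iters)

    # The median discards runs that were disturbed by background activity.
    return statistics.median(samples)


def toy_workload():
    """Stand-in for a force evaluation."""
    return sum(i * i for i in range(100))


seconds = median_time_per_call(toy_workload)
print(f"{seconds * 1e6:.2f} microseconds per call")
```

<p>In the real benchmark the callable would wrap a force evaluation at a fixed batch size. This sketch targets CPU timing; on a GPU the timed region would also need an explicit synchronization, which is omitted here.</p>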















<figure class="post-figure center ">
    <img src="/img/muller-brown/muller-brown-throughput-analysis.webp"
         alt="Throughput comparison showing the analytical force kernel outperforming autograd across batch sizes"
         title="Throughput comparison showing the analytical force kernel outperforming autograd across batch sizes"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Force-evaluation throughput, analytical vs autograd, across batch sizes. The analytical kernel is about 4x faster (3-7x depending on batch size).</figcaption>
    
</figure>
















<figure class="post-figure center ">
    <img src="/img/muller-brown/muller-brown-time-per-particle.webp"
         alt="Per-particle computation time showing analytical derivatives maintain sub-microsecond performance for large systems"
         title="Per-particle computation time showing analytical derivatives maintain sub-microsecond performance for large systems"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Per-particle computation time. Analytical derivatives maintain sub-microsecond performance for large systems.</figcaption>
    
</figure>

<p>The speedup comes from bypassing the computational graph that PyTorch&rsquo;s autograd engine builds. With a hand-derived analytical gradient, we skip that machinery entirely. Every autograd force evaluation must:</p>
<ol>
<li><strong>Build a tape</strong> of operations during the forward pass</li>
<li><strong>Traverse the graph</strong> backward to compute gradients</li>
<li><strong>Allocate intermediate tensors</strong> for each node</li>
</ol>
<p>For iterative workloads like molecular dynamics (millions of force evaluations per trajectory), eliminating this overhead is critical. The analytical kernel computes forces directly in a single fused operation: no graph, no tape, no intermediate allocations.</p>
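<p>To make the idea of a hand-derived force concrete, here is a minimal pure-Python sketch. It is not the compiled PyTorch kernel from the repository. It assumes the standard four-term Müller-Brown parameters from the 1979 paper, and it checks the analytical force against a central finite difference instead of autograd so that it runs without PyTorch. The function names are illustrative.</p>

```python
import math

# Standard Muller-Brown parameters (assumed to match the 1979 paper).
A = (-200.0, -100.0, -170.0, 15.0)
a = (-1.0, -1.0, -6.5, 0.7)
b = (0.0, 0.0, 11.0, 0.6)
c = (-10.0, -10.0, -6.5, 0.7)
X0 = (1.0, 0.0, -0.5, -1.0)
Y0 = (0.0, 0.5, 1.5, 1.0)


def potential(x, y):
    """V(x, y) = sum_k A_k exp(a_k dx^2 + b_k dx dy + c_k dy^2)."""
    total = 0.0
    for k in range(4):
        dx, dy = x - X0[k], y - Y0[k]
        expo = a[k] * dx * dx + b[k] * dx * dy + c[k] * dy * dy
        total += A[k] * math.exp(expo)
    return total


def analytical_force(x, y):
    """F = -grad V, from the hand-derived derivative of each term."""
    fx = fy = 0.0
    for k in range(4):
        dx, dy = x - X0[k], y - Y0[k]
        expo = a[k] * dx * dx + b[k] * dx * dy + c[k] * dy * dy
        term = A[k] * math.exp(expo)
        fx -= term * (2.0 * a[k] * dx + b[k] * dy)
        fy -= term * (b[k] * dx + 2.0 * c[k] * dy)
    return fx, fy


def finite_difference_force(x, y, h=1e-5):
    """Central-difference reference, used here in place of autograd."""
    fx = -(potential(x + h, y) - potential(x - h, y)) / (2.0 * h)
    fy = -(potential(x, y + h) - potential(x, y - h)) / (2.0 * h)
    return fx, fy


point = (0.0, 0.5)
print("analytical:", analytical_force(*point))
print("reference: ", finite_difference_force(*point))
```

<p>The analytical version reuses the exponential of each term for both force components, which is the kind of sharing a fused kernel exploits.</p>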
<p><strong>When to use each approach:</strong></p>
<ul>
<li><strong>Analytical</strong>: Best for production molecular dynamics where forces are computed millions of times. The speedup directly reduces wall-clock simulation time.</li>
<li><strong>Autograd</strong>: Better for prototyping, machine learning training loops, or when implementing new potentials where correctness verification is paramount. The convenience, and the assurance that the gradients match the potential as implemented, often outweigh performance costs during development.</li>
</ul>
<h3 id="molecular-dynamics-simulations">Molecular Dynamics Simulations</h3>
<p>To demonstrate the PyTorch implementation in action, I performed Langevin dynamics simulations in different energy basins. These simulations reveal how particles behave when confined to different regions of the potential energy surface.</p>
<h4 id="simulation-parameters">Simulation Parameters</h4>
<p>Each simulation ran for 3600 steps with a step size of 0.01 time units (36 time units in total), using a friction coefficient of 1.0 and a temperature of 25.0 in reduced units. Each run started from the equilibrium position of its basin and shows the characteristic thermal fluctuations around the local minimum.</p>
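<p>A BAOAB Langevin step with these parameters can be sketched as follows. This is a minimal pure-Python illustration, not the simulator from the repository. Unit mass, the fixed random seed, and the <code>harmonic_force</code> stand-in are assumptions made here to keep the example short. In practice the force callable would be the Müller-Brown force.</p>

```python
import math
import random


def baoab_trajectory(force, x0, n_steps=3600, dt=0.01, gamma=1.0,
                     kT=25.0, mass=1.0, seed=0):
    """Integrate Langevin dynamics with the BAOAB splitting.

    mass=1.0 and the seed are assumptions made for this sketch.
    """
    rng = random.Random(seed)
    dim = len(x0)
    x = list(x0)
    v = [0.0] * dim
    f = force(x)
    # Coefficients of the exact Ornstein-Uhlenbeck (friction + noise) step.
    c1 = math.exp(-gamma * dt)
    c2 = math.sqrt(kT / mass * (1.0 - c1 * c1))
    traj = [tuple(x)]
    for _ in range(n_steps):
        for i in range(dim):
            v[i] += 0.5 * dt * f[i] / mass               # B: half kick
            x[i] += 0.5 * dt * v[i]                      # A: half drift
            v[i] = c1 * v[i] + c2 * rng.gauss(0.0, 1.0)  # O: thermostat
            x[i] += 0.5 * dt * v[i]                      # A: half drift
        f = force(x)  # one force evaluation per step
        for i in range(dim):
            v[i] += 0.5 * dt * f[i] / mass               # B: half kick
        traj.append(tuple(x))
    return traj


def harmonic_force(x, k=100.0):
    # Hypothetical stand-in: a 2D harmonic well centered at the origin.
    return [-k * xi for xi in x]


trajectory = baoab_trajectory(harmonic_force, (0.0, 0.0))
print(len(trajectory), trajectory[-1])
```

<p>Setting <code>gamma=0</code> turns off both friction and noise, so the scheme reduces to velocity Verlet, which is a convenient sanity check.</p>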
<h4 id="basin-ma-deep-reactant-minimum">Basin MA: Deep Reactant Minimum</h4>
<p>The deepest energy well (-146.70 in reduced units) shows highly constrained motion due to the steep energy barriers surrounding it.</p>















<figure class="post-figure center ">
    <img src="/img/muller-brown/muller-brown-basin-ma-position-distributions.webp"
         alt="Position distributions in Basin MA showing tight confinement"
         title="Position distributions in Basin MA showing tight confinement"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Position distributions in Basin MA. The particle remains tightly confined around (-0.558, 1.442) due to the deep potential well.</figcaption>
    
</figure>
















<figure class="post-figure center ">
    <img src="/img/muller-brown/muller-brown-basin-ma-time-series.webp"
         alt="Time evolution of coordinates in Basin MA"
         title="Time evolution of coordinates in Basin MA"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Time series showing small-amplitude oscillations around the equilibrium position. The deep well severely restricts thermal motion.</figcaption>
    
</figure>
















<figure class="post-figure center ">
    <img src="/img/muller-brown/muller-brown-basin-ma-trajectory.webp"
         alt="Trajectory overlaid on potential surface for Basin MA"
         title="Trajectory overlaid on potential surface for Basin MA"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Trajectory visualization showing the particle&rsquo;s motion confined to a small region around the minimum. High energy barriers prevent escape on this time scale.</figcaption>
    
</figure>

<div style="position: relative; padding-bottom: 56.25%; height: 0; overflow: hidden;">
      <iframe allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share; fullscreen" loading="eager" referrerpolicy="strict-origin-when-cross-origin" src="https://www.youtube-nocookie.com/embed/woVM90qXUQs?autoplay=0&amp;controls=1&amp;end=0&amp;loop=0&amp;mute=0&amp;start=0" style="position: absolute; top: 0; left: 0; width: 100%; height: 100%; border:0;" title="YouTube video"></iframe>
    </div>

<h4 id="basin-mb-product-minimum">Basin MB: Product Minimum</h4>
<p>The product minimum (-108.17 in reduced units) shows behavior intermediate between the deep reactant well (MA) and the shallow intermediate basin (MC).</p>















<figure class="post-figure center ">
    <img src="/img/muller-brown/muller-brown-basin-mb-position-distributions.webp"
         alt="Position distributions in Basin MB showing moderate confinement"
         title="Position distributions in Basin MB showing moderate confinement"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Position distributions in Basin MB. The particle shows moderate thermal motion around (0.623, 0.028), with confinement between the deep MA basin and shallow MC basin.</figcaption>
    
</figure>
















<figure class="post-figure center ">
    <img src="/img/muller-brown/muller-brown-basin-mb-time-series.webp"
         alt="Time evolution showing moderate amplitude fluctuations in Basin MB"
         title="Time evolution showing moderate amplitude fluctuations in Basin MB"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Time series demonstrating moderate amplitude fluctuations. The particle explores a region larger than MA but more constrained than the shallow MC basin.</figcaption>
    
</figure>
















<figure class="post-figure center ">
    <img src="/img/muller-brown/muller-brown-basin-mb-trajectory.webp"
         alt="Basin MB trajectory showing balanced exploration"
         title="Basin MB trajectory showing balanced exploration"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The trajectory shows balanced thermal exploration within the product basin. The moderate well depth allows reasonable sampling while maintaining basin confinement.</figcaption>
    
</figure>

<div style="position: relative; padding-bottom: 56.25%; height: 0; overflow: hidden;">
      <iframe allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share; fullscreen" loading="eager" referrerpolicy="strict-origin-when-cross-origin" src="https://www.youtube-nocookie.com/embed/gdAHme07bGs?autoplay=0&amp;controls=1&amp;end=0&amp;loop=0&amp;mute=0&amp;start=0" style="position: absolute; top: 0; left: 0; width: 100%; height: 100%; border:0;" title="YouTube video"></iframe>
    </div>

<h4 id="transition-example">Transition Example</h4>
<p>Running a longer simulation demonstrates transitions between basins:</p>















<figure class="post-figure center ">
    <img src="/img/muller-brown/muller-brown-transition-trajectory.webp"
         alt="Transition trajectory between basins"
         title="Transition trajectory between basins"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The trajectory illustrates the particle&rsquo;s movement between the different basins, highlighting the energy barriers and pathways involved.</figcaption>
    
</figure>
















<figure class="post-figure center ">
    <img src="/img/muller-brown/muller-brown-transition-time-series.webp"
         alt="Time evolution of positions during transition"
         title="Time evolution of positions during transition"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Time series showing the evolution of the particle&rsquo;s position during the transition between basins.</figcaption>
    
</figure>
















<figure class="post-figure center ">
    <img src="/img/muller-brown/muller-brown-transition-position-distributions.webp"
         alt="Transition trajectory on potential surface"
         title="Transition trajectory on potential surface"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The trajectory on the potential surface highlights the energy landscape the particle navigates during the transition.</figcaption>
    
</figure>

<div style="position: relative; padding-bottom: 56.25%; height: 0; overflow: hidden;">
      <iframe allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share; fullscreen" loading="eager" referrerpolicy="strict-origin-when-cross-origin" src="https://www.youtube-nocookie.com/embed/dVFe_4KZbps?autoplay=0&amp;controls=1&amp;end=0&amp;loop=0&amp;mute=0&amp;start=0" style="position: absolute; top: 0; left: 0; width: 100%; height: 100%; border:0;" title="YouTube video"></iframe>
    </div>

<h3 id="integration-with-modern-workflows">Integration with Modern Workflows</h3>
<p>The PyTorch implementation integrates naturally with machine learning workflows. It can serve as a component in larger computational graphs, enable gradient-based optimization, and bridge classical molecular simulation with deep learning approaches.</p>
<p>For readers interested in broader molecular ML applications, this implementation pairs well with other molecular representation methods. The <a href="/posts/molecular-descriptor-coulomb-matrix/">Coulomb matrix approach</a> offers complementary perspectives on encoding molecular structure, while the <a href="/posts/kabsch-algorithm/#the-math">Kabsch algorithm</a> provides essential tools for structural alignment.</p>
<p>The complete implementation is available on <a href="https://github.com/hunter-heidenreich/Muller-Brown-Potential">GitHub</a><sup id="fnref:9"><a href="#fn:9" class="footnote-ref" role="doc-noteref">9</a></sup><sup id="fnref:10"><a href="#fn:10" class="footnote-ref" role="doc-noteref">10</a></sup><sup id="fnref:11"><a href="#fn:11" class="footnote-ref" role="doc-noteref">11</a></sup>, including benchmarking scripts, visualization tools, the test suite, and examples for optimization and molecular dynamics.</p>
<h2 id="conclusion">Conclusion</h2>
<p>The Müller-Brown potential exemplifies how a well-designed benchmark can evolve with a field. Born from 1970s computational constraints, it provided a simple way to test algorithms when quantum chemistry calculations were expensive. Its design (simple enough to compute instantly, complex enough to break naive approaches) made it invaluable for algorithm development.</p>
<p>Today, it serves new purposes in the machine learning era. This PyTorch implementation pairs a hand-derived analytical force kernel (compiled with <code>torch.compile</code>) with an autograd reference, a BAOAB Langevin sampler, and a test suite that checks the sampler against the canonical distribution. The analytical kernel&rsquo;s roughly 4x advantage matters for intensive simulations, while PyTorch&rsquo;s flexibility enables integration with neural network potentials and enhanced sampling methods.</p>
<p>The potential&rsquo;s evolution from practical necessity to pedagogical tool to machine learning benchmark demonstrates the value of foundational test cases. As computational chemistry continues evolving, reliable standards like the Müller-Brown potential become even more important for rigorous method development and comparison.</p>
<p>For the complete implementation with benchmarking scripts, the BAOAB Langevin simulator, and visualization tools, see the <a href="https://github.com/hunter-heidenreich/Muller-Brown-Potential">GitHub repository</a>. The full project with architecture details and performance results is at the <a href="/projects/muller-brown-pytorch/">Müller-Brown Potential: A PyTorch ML Testbed project page</a>.</p>
<div class="footnotes" role="doc-endnotes">
<hr>
<ol>
<li id="fn:1">
<p>Müller, K., &amp; Brown, L. D. (1979). Location of saddle points and minimum energy paths by a constrained simplex optimization procedure. <em>Theoretica Chimica Acta</em>, 53, 75-93. <a href="https://link.springer.com/article/10.1007/BF00547608">https://link.springer.com/article/10.1007/BF00547608</a>&#160;<a href="#fnref:1" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:2">
<p>Lennard-Jones, J. E. (1931). Cohesion. <em>Proceedings of the Physical Society</em>, 43(5), 461-482. <a href="https://doi.org/10.1088/0959-5309/43/5/301">https://doi.org/10.1088/0959-5309/43/5/301</a>&#160;<a href="#fnref:2" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:3">
<p>Henkelman, G., &amp; Jónsson, H. (2000). Improved tangent estimate in the nudged elastic band method for finding minimum energy paths and saddle points. <em>Journal of Chemical Physics</em>, 113(22), 9901-9904. <a href="https://doi.org/10.1063/1.1329672">https://doi.org/10.1063/1.1329672</a>&#160;<a href="#fnref:3" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:4">
<p>Dellago, C., Bolhuis, P. G., Csajka, F. S., &amp; Chandler, D. (1998). Transition path sampling and the calculation of rate constants. <em>Journal of Chemical Physics</em>, 108(5), 1964-1977. <a href="https://doi.org/10.1063/1.475562">https://doi.org/10.1063/1.475562</a>&#160;<a href="#fnref:4" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:5">
<p>Behler, J., &amp; Parrinello, M. (2007). Generalized neural-network representation of high-dimensional potential-energy surfaces. <em>Physical Review Letters</em>, 98(14), 146401. <a href="https://doi.org/10.1103/PhysRevLett.98.146401">https://doi.org/10.1103/PhysRevLett.98.146401</a>&#160;<a href="#fnref:5" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:6">
<p>Smith, J. S., Isayev, O., &amp; Roitberg, A. E. (2017). ANI-1: an extensible neural network potential with DFT accuracy at force field computational cost. <em>Chemical Science</em>, 8(4), 3192-3203. <a href="https://doi.org/10.1039/C6SC05720A">https://doi.org/10.1039/C6SC05720A</a>&#160;<a href="#fnref:6" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:7">
<p>Sipka, M., Dietschreit, J. C. B., Grajciar, L., &amp; Gómez-Bombarelli, R. (2023). Differentiable simulations for enhanced sampling of rare events. In <em>International Conference on Machine Learning</em> (pp. 31990-32007). PMLR.&#160;<a href="#fnref:7" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:8">
<p>Sun, L., Vandermause, J., Batzner, S., Xie, Y., Clark, D., Chen, W., &amp; Kozinsky, B. (2022). Multitask machine learning of collective variables for enhanced sampling of rare events. <em>Journal of Chemical Theory and Computation</em>, 18(4), 2341-2353.&#160;<a href="#fnref:8" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:9">
<p>LED-Molecular Repository - Original implementation of the Müller-Brown potential. <a href="https://github.com/cselab/LED-Molecular">https://github.com/cselab/LED-Molecular</a>&#160;<a href="#fnref:9" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:10">
<p>Vlachas, P. R., Zavadlav, J., Praprotnik, M., &amp; Koumoutsakos, P. (2022). Accelerated simulations of molecular systems through learning of effective dynamics. <em>Journal of Chemical Theory and Computation</em>, 18(1), 538-549. <a href="https://pubs.acs.org/doi/10.1021/acs.jctc.1c00809">https://pubs.acs.org/doi/10.1021/acs.jctc.1c00809</a>&#160;<a href="#fnref:10" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
<li id="fn:11">
<p>Vlachas, P. R., Arampatzis, G., Uhler, C., &amp; Koumoutsakos, P. (2022). Multiscale simulations of complex systems by learning their effective dynamics. <em>Nature Machine Intelligence</em>. <a href="https://www.nature.com/articles/s42256-022-00464-w">https://www.nature.com/articles/s42256-022-00464-w</a>&#160;<a href="#fnref:11" class="footnote-backref" role="doc-backlink">&#x21a9;&#xfe0e;</a></p>
</li>
</ol>
</div>
]]></content:encoded></item><item><title>DenoiseVAE: Adaptive Noise for Molecular Pre-training</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-simulation/ml-potentials/denoise-vae/</link><pubDate>Sun, 24 Aug 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-simulation/ml-potentials/denoise-vae/</guid><description>Liu et al.'s ICLR 2025 paper introducing DenoiseVAE, which learns adaptive, atom-specific noise distributions for better molecular force fields.</description><content:encoded><![CDATA[<h2 id="paper-contribution-type">Paper Contribution Type</h2>
<p>This is a <strong>method paper</strong> with a supporting theoretical component. It introduces a new pre-training framework, DenoiseVAE, that challenges the standard practice of using fixed, hand-crafted noise distributions in denoising-based molecular representation learning.</p>
<h2 id="motivation-the-inter--and-intra-molecular-variations-problem">Motivation: The Inter- and Intra-molecular Variations Problem</h2>
<p>The motivation is to create a more physically principled denoising pre-training task for 3D molecules. The core idea of denoising is to learn molecular force fields by corrupting an equilibrium conformation with noise and then learning to recover it. However, existing methods use a single, hand-crafted noise strategy (e.g., Gaussian noise of a fixed scale) for all atoms across all molecules. This is physically unrealistic for two main reasons:</p>
<ol>
<li><strong>Inter-molecular differences</strong>: Different molecules have unique Potential Energy Surfaces (PES), meaning the space of low-energy (i.e., physically plausible) conformations is highly molecule-specific.</li>
<li><strong>Intra-molecular differences (Anisotropy)</strong>: Within a single molecule, different atoms have different degrees of freedom. For instance, an atom in a rigid functional group can move much less than one connected by a single, rotatable bond.</li>
</ol>
<p>The authors argue that this &ldquo;one-size-fits-all&rdquo; noise approach leads to inaccurate force field learning because it samples many physically improbable conformations.</p>
<h2 id="novelty-a-learnable-atom-specific-noise-generator">Novelty: A Learnable, Atom-Specific Noise Generator</h2>
<p>The core novelty is a framework that learns to generate noise tailored to each specific molecule and atom. This is achieved through three key innovations:</p>
<ol>
<li><strong>Learnable Noise Generator</strong>: The authors introduce a Noise Generator module (a 4-layer Equivariant Graph Neural Network) that takes a molecule&rsquo;s equilibrium conformation $X$ as input and outputs a unique, atom-specific Gaussian noise distribution (i.e., a different variance $\sigma_i^2$ for each atom $i$). This directly addresses the issues of PES specificity and force field anisotropy.</li>
<li><strong>Variational Autoencoder (VAE) Framework</strong>: The Noise Generator (encoder) and a Denoising Module (a 7-layer EGNN decoder) are trained jointly within a VAE paradigm. The noisy conformation is sampled using the reparameterization trick, where $\epsilon \sim \mathcal{N}(0, I)$ is standard Gaussian noise:
$$
\begin{aligned}
\tilde{x}_i &amp;= x_i + \epsilon \sigma_i
\end{aligned}
$$</li>
<li><strong>Principled Optimization Objective</strong>: The training loss balances two competing goals:
$$
\begin{aligned}
\mathcal{L}_{DenoiseVAE} &amp;= \mathcal{L}_{Denoise} + \lambda \mathcal{L}_{KL}
\end{aligned}
$$
<ul>
<li>A denoising reconstruction loss ($\mathcal{L}_{Denoise}$) encourages the Noise Generator to produce physically plausible perturbations from which the original conformation can be recovered. This implicitly constrains the noise to respect the molecule&rsquo;s underlying force fields.</li>
<li>A KL divergence regularization term ($\mathcal{L}_{KL}$) pushes the generated noise distributions towards a predefined prior. This prevents the trivial solution of generating zero noise and encourages the model to explore a diverse set of low-energy conformations.</li>
</ul>
</li>
</ol>
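<p>The noise sampling and the role of the KL term can be illustrated with a small pure-Python sketch. It is not the authors&rsquo; implementation. It assumes isotropic, zero-mean Gaussian noise per atom and a closed-form KL divergence to a prior with standard deviation 0.1, so the paper&rsquo;s exact loss may differ. The function names and the toy coordinates are hypothetical.</p>

```python
import math
import random


def sample_noisy_conformation(coords, sigmas, rng):
    """Reparameterized sample: x_tilde_i = x_i + eps * sigma_i."""
    return [
        tuple(xc + rng.gauss(0.0, 1.0) * s for xc in atom)
        for atom, s in zip(coords, sigmas)
    ]


def kl_to_prior(sigmas, prior_sigma=0.1, dim=3):
    """Mean KL(N(0, sigma_i^2 I) || N(0, prior_sigma^2 I)) over atoms.

    Assumes isotropic, equal-mean Gaussians (an assumption of this sketch).
    """
    total = 0.0
    for s in sigmas:
        ratio = (s / prior_sigma) ** 2
        total += 0.5 * dim * (ratio - 1.0 - math.log(ratio))
    return total / len(sigmas)


def total_loss(denoise_loss, sigmas, lam=1.0):
    """L = L_Denoise + lambda * L_KL, with L_Denoise supplied by a decoder."""
    return denoise_loss + lam * kl_to_prior(sigmas)


rng = random.Random(0)
coords = [(0.0, 0.0, 0.0), (1.1, 0.0, 0.0), (1.6, 1.0, 0.0)]  # toy molecule
sigmas = [0.05, 0.1, 0.2]  # hypothetical per-atom noise scales
print(sample_noisy_conformation(coords, sigmas, rng))
print("KL term:", kl_to_prior(sigmas))
```

<p>Under these assumptions the KL term is zero when every $\sigma_i$ equals the prior and grows without bound as $\sigma_i \to 0$, which is how it discourages the trivial zero-noise solution.</p>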
<p>The authors also provide a theoretical analysis showing that optimizing their objective is equivalent to maximizing the Evidence Lower Bound (ELBO) on the log-likelihood of observing physically realistic conformations.</p>
<h2 id="methodology--experimental-baselines">Methodology &amp; Experimental Baselines</h2>
<p>The model was pretrained on the PCQM4Mv2 dataset (approximately 3.4 million organic molecules) and then evaluated on a comprehensive suite of downstream tasks to test the quality of the learned representations:</p>
<ol>
<li><strong>Molecular Property Prediction (<a href="/notes/chemistry/datasets/qm9/">QM9</a>)</strong>: The model was evaluated on 12 quantum chemical property prediction tasks for small molecules (134k molecules; 100k train, 18k val, 13k test split). DenoiseVAE achieved state-of-the-art or second-best performance on 11 of the 12 tasks, with particularly significant gains on $C_v$ (heat capacity), indicating better capture of vibrational modes.</li>
<li><strong>Force Prediction (MD17)</strong>: The task was to predict atomic forces from molecular dynamics trajectories for 8 different small molecules (9,500 train, 500 val split). DenoiseVAE was the top performer on 5 of the 8 molecules (Aspirin, Benzene, Ethanol, Naphthalene, Toluene), though it underperformed Frad on Malonaldehyde, Salicylic Acid, and Uracil by significant margins.</li>
<li><strong>Ligand Binding Affinity (PDBBind v2019)</strong>: On the PDBBind dataset with 30% and 60% protein sequence identity splits, the model showed strong generalization, outperforming baselines like Uni-Mol particularly on the more stringent 30% split across RMSE, Pearson correlation, and Spearman correlation.</li>
<li><strong>PCQM4Mv2 Validation</strong>: DenoiseVAE achieved a validation MAE of 0.0777 on the PCQM4Mv2 HOMO-LUMO gap prediction task with only 1.44M parameters, competitive with models 10-40x larger (e.g., GPS++ at 44.3M params achieves 0.0778).</li>
<li><strong>Ablation Studies</strong>: The authors analyzed the sensitivity to key hyperparameters, namely the prior&rsquo;s standard deviation ($\sigma$) and the KL-divergence weight ($\lambda$), confirming that $\lambda=1$ and $\sigma=0.1$ are optimal. Removing the KL term leads to trivial solutions (near-zero noise). An additional ablation on the Noise Generator depth found 4 EGNN layers optimal over 2 layers. A comparison of independent (diagonal) versus non-independent (full covariance) noise sampling showed comparable results, suggesting the EGNN already captures inter-atomic dependencies implicitly.</li>
<li><strong>Case Studies</strong>: Visualizations of the learned noise variances for different molecules confirmed that the model learns chemically intuitive noise patterns. For example, it applies smaller perturbations to atoms in a rigid bicyclic norcamphor derivative and larger ones to atoms in flexible functional groups of a cyclopropane derivative. Even identical functional groups (e.g., hydroxyl) receive different noise scales in different molecular contexts.</li>
</ol>
<h2 id="key-findings-on-force-field-learning">Key Findings on Force Field Learning</h2>
<ul>
<li><strong>Primary Conclusion</strong>: Learning a <strong>molecule-adaptive and atom-specific</strong> noise distribution is a superior strategy for denoising-based pre-training compared to using fixed, hand-crafted heuristics. This more physically-grounded approach leads to representations that better capture molecular force fields.</li>
<li><strong>Strong Benchmark Performance</strong>: DenoiseVAE achieves best or second-best results on 11 of 12 QM9 tasks, the top result on 5 of 8 MD17 molecules, and the lead on the stringent 30% LBA split. It trails Frad on the remaining three MD17 molecules (Malonaldehyde, Salicylic Acid, Uracil).</li>
<li><strong>Effective Framework</strong>: The proposed VAE-based framework, which jointly trains a Noise Generator and a Denoising Module, is an effective and theoretically sound method for implementing this adaptive noise strategy. The interplay between the reconstruction loss and the KL-divergence regularization is key to its success.</li>
<li><strong>Limitation and Future Direction</strong>: The method is based on classical force field assumptions. The authors note that integrating more accurate force fields represents a promising direction for future work.</li>
</ul>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/Serendipity-r/DenoiseVAE">Serendipity-r/DenoiseVAE</a></td>
          <td>Code</td>
          <td>Unknown</td>
          <td>Official implementation</td>
      </tr>
  </tbody>
</table>
<h3 id="reproducibility-status">Reproducibility Status</h3>
<ul>
<li><strong>Source Code</strong>: The authors have released their code at <a href="https://github.com/Serendipity-r/DenoiseVAE">Serendipity-r/DenoiseVAE</a> on GitHub. No license is specified in the repository.</li>
<li><strong>Implementation</strong>: Hyperparameters and architectures are detailed in the paper&rsquo;s appendix (A.14), and the repository provides reference implementations.</li>
</ul>
<h3 id="data">Data</h3>
<ul>
<li><strong>Pre-training Dataset</strong>: <a href="https://ogb.stanford.edu/docs/lsc/pcqm4mv2/">PCQM4Mv2</a> (approximately 3.4 million organic molecules)</li>
<li><strong>Property Prediction</strong>: <a href="https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.datasets.QM9.html">QM9 dataset</a> (134k molecules; 100k train, 18k val, 13k test split) for 12 quantum chemical properties</li>
<li><strong>Force Prediction</strong>: <a href="http://www.sgdml.org/#datasets">MD17 dataset</a> (9,500 train, 500 val split) for 8 different small molecules</li>
<li><strong>Ligand Binding Affinity</strong>: PDBBind v2019 (4,463 protein-ligand complexes) with 30% and 60% sequence identity splits</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Noise Generator</strong>: 4-layer Equivariant Graph Neural Network (EGNN) that outputs atom-specific Gaussian noise distributions</li>
<li><strong>Denoising Module</strong>: 7-layer EGNN decoder</li>
<li><strong>Training Objective</strong>: $\mathcal{L}_{DenoiseVAE} = \mathcal{L}_{Denoise} + \lambda \mathcal{L}_{KL}$ with $\lambda=1$</li>
<li><strong>Noise Sampling</strong>: Reparameterization trick with $\tilde{x}_i = x_i + \epsilon \sigma_i$</li>
<li><strong>Prior Distribution</strong>: Standard deviation $\sigma=0.1$</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li><strong>Model Size</strong>: 1.44M parameters total</li>
<li><strong>Fine-tuning Protocol</strong>: Noise Generator discarded after pre-training; only the pre-trained Denoising Module (7-layer EGNN) is retained for downstream fine-tuning</li>
<li><strong>Optimizer</strong>: AdamW with cosine learning rate decay (max LR of 0.0005)</li>
<li><strong>Batch Size</strong>: 128</li>
<li><strong>System Training</strong>: Fine-tuned end-to-end for specific tasks; force prediction involves computing the gradient of the predicted energy</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Ablation Studies</strong>: Sensitivity analysis confirmed $\lambda=1$ and $\sigma=0.1$ as optimal hyperparameters; removing the KL term leads to trivial solutions (near-zero noise)</li>
<li><strong>Noise Generator Depth</strong>: 4 EGNN layers outperformed 2 layers across both QM9 and MD17 benchmarks</li>
<li><strong>Covariance Structure</strong>: Full covariance matrix (non-independent noise sampling) yielded comparable results to diagonal variance (independent sampling), likely because the EGNN already integrates neighboring atom information</li>
<li><strong>O(3) Invariance</strong>: The method satisfies O(3) probabilistic invariance, meaning the noise distribution is unchanged under rotations and reflections</li>
</ul>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>GPU Configuration</strong>: Experiments conducted on a single RTX A3090 GPU; 6 GPUs with 144GB total memory sufficient for full reproduction</li>
<li><strong>CPU</strong>: Intel Xeon Gold 5318Y @ 2.10GHz</li>
</ul>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Liu, Y., Chen, J., Jiao, R., Li, J., Huang, W., &amp; Su, B. (2025). DenoiseVAE: Learning Molecule-Adaptive Noise Distributions for Denoising-based 3D Molecular Pre-training. <em>The Thirteenth International Conference on Learning Representations (ICLR)</em>.</p>
<p><strong>Publication</strong>: ICLR 2025</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{liu2025denoisevae,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{DenoiseVAE: Learning Molecule-Adaptive Noise Distributions for Denoising-based 3D Molecular Pre-training}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Yurou Liu and Jiahao Chen and Rui Jiao and Jiangmeng Li and Wenbing Huang and Bing Su}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{The Thirteenth International Conference on Learning Representations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2025}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span>=<span style="color:#e6db74">{https://openreview.net/forum?id=ym7pr83XQr}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://iclr.cc/virtual/2025/poster/27701">ICLR 2025 poster page</a></li>
<li><a href="https://openreview.net/forum?id=ym7pr83XQr">OpenReview forum</a></li>
<li><a href="https://openreview.net/pdf?id=ym7pr83XQr">PDF on OpenReview</a></li>
</ul>
]]></content:encoded></item><item><title>eSEN: Smooth Interatomic Potentials (ICML Spotlight)</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-simulation/ml-potentials/learning-smooth-interatomic-potentials/</link><pubDate>Sat, 23 Aug 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-simulation/ml-potentials/learning-smooth-interatomic-potentials/</guid><description>Fu et al. propose energy conservation as a key MLIP diagnostic and introduce eSEN, bridging test accuracy and real performance.</description><content:encoded><![CDATA[<h2 id="paper-overview">Paper Overview</h2>
<p>This is a <strong>method paper</strong>. It addresses a critical disconnect in the evaluation of Machine Learning Interatomic Potentials (MLIPs) and introduces a novel architecture, <strong>eSEN</strong>, designed based on insights from this analysis. The paper proposes a new standard for evaluating MLIPs beyond simple test-set errors.</p>
<h2 id="the-energy-conservation-gap-in-mlip-evaluation">The Energy Conservation Gap in MLIP Evaluation</h2>
<p>The paper targets a well-known but under-addressed problem in the field: improvements in standard MLIP metrics (lower energy/force MAE on static test sets) do not reliably translate to better performance on complex downstream tasks like molecular dynamics (MD) simulations, materials stability prediction, or phonon calculations. The authors seek to understand why this gap exists and how to design models that are both accurate on test sets and physically reliable in practical scientific workflows.</p>
<h2 id="the-esen-architecture-and-continuous-representation">The eSEN Architecture and Continuous Representation</h2>
<p>The novelty is twofold, spanning both a conceptual framework for evaluation and a new model architecture:</p>
<ol>
<li>
<p><strong>Energy Conservation as a Diagnostic Test</strong>: The core conceptual contribution is using an MLIP&rsquo;s ability to conserve energy in out-of-distribution MD simulations as a crucial diagnostic test. The authors demonstrate that for models passing this test, a strong correlation between test-set error and downstream task performance is restored.</p>
</li>
<li>
<p><strong>The eSEN Architecture</strong>: The paper introduces the <strong>equivariant Smooth Energy Network (eSEN)</strong>, designed with specific choices to ensure a smooth and well-behaved Potential Energy Surface (PES):</p>
<ul>
<li><strong>Strictly Conservative Forces</strong>: Forces are computed exclusively as the negative gradient of the energy ($F = -\nabla E$) instead of by a faster direct-force prediction head.</li>
<li><strong>Continuous Representations</strong>: Maintains strict equivariance and smoothness by using equivariant gated non-linearities instead of discretizing spherical harmonic representations during nodewise processing.</li>
<li><strong>Smooth PES Construction</strong>: Critical design choices include using distance cutoffs, polynomial envelope functions ensuring derivatives go to zero at cutoffs, and limited radial basis functions to avoid overly sensitive PES.</li>
</ul>
</li>
<li>
<p><strong>Efficient Training Strategy</strong>: A two-stage training regimen pre-trains quickly with a non-conservative direct-force model, then fine-tunes with conservative forces to enforce energy conservation. This retains most of the efficiency of direct-force training while ensuring physical robustness.</p>
</li>
</ol>
<h2 id="evaluating-ood-energy-conservation-and-physical-properties">Evaluating OOD Energy Conservation and Physical Properties</h2>
<p>The paper presents a comprehensive experimental validation:</p>
<ol>
<li>
<p><strong>Ablation Studies on Energy Conservation</strong>: MD simulations on out-of-distribution systems (TM23 and MD22 datasets) systematically tested key design choices (direct-force vs. conservative, representation discretization, neighbor limits, envelope functions). This empirically demonstrated which choices lead to energy drift despite negligible impact on test-set MAE.</p>
</li>
<li>
<p><strong>Physical Property Prediction Benchmarks</strong>: The eSEN model was evaluated on challenging downstream tasks:</p>
<ul>
<li><strong>Matbench-Discovery</strong>: Materials stability and thermal conductivity prediction, where eSEN achieved the highest F1 score among compliant models and excelled at both metrics simultaneously.</li>
<li><strong>MDR Phonon Benchmark</strong>: Predicting phonon properties that test accurate second and third-order derivatives of the PES. eSEN achieved state-of-the-art results, particularly outperforming direct-force models.</li>
<li><strong>SPICE-MACE-OFF</strong>: Standard energy and force prediction on organic molecules, demonstrating that physical plausibility design choices enhanced raw accuracy.</li>
</ul>
</li>
<li>
<p><strong>Correlation Analysis</strong>: Explicit plots of test-set energy MAE versus performance on downstream benchmarks showed weak overall correlation that becomes strong and predictive when restricted to models passing the energy conservation test.</p>
</li>
</ol>
<h2 id="outcomes-and-conclusions">Outcomes and Conclusions</h2>
<ul>
<li>
<p><strong>Primary Conclusion</strong>: Energy conservation is a critical, practical property for MLIPs. Using it as a filter re-establishes test-set error as a reliable proxy during model development, which shortens the iteration cycle. Models that are not conservative, even with low test error, are unreliable for many scientific applications.</p>
</li>
<li>
<p><strong>Model Performance</strong>: The eSEN architecture outperforms baseline models across diverse tasks, from energy/force prediction to geometry optimization, phonon calculations, and thermal conductivity prediction.</p>
</li>
<li>
<p><strong>Actionable Design Principles</strong>: The paper provides experimentally-validated architectural choices that promote physical plausibility. Seemingly minor details, like how atomic neighbors are selected, can have profound impacts on a model&rsquo;s utility in simulations.</p>
</li>
<li>
<p><strong>Efficient Path to Robust Models</strong>: The direct-force pre-training plus conservative fine-tuning strategy offers a practical method for developing physically robust models without incurring the full computational cost of conservative training from scratch.</p>
</li>
</ul>
<hr>
<h2 id="reproducibility">Reproducibility</h2>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/facebookresearch/fairchem">fairchem (GitHub)</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official implementation within FAIR Chemistry framework</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/facebook/OMAT24">OMAT24 (Hugging Face)</a></td>
          <td>Model</td>
          <td>FAIR Acceptable Use Policy</td>
          <td>Pre-trained eSEN-30M-MP and eSEN-30M-OAM checkpoints</td>
      </tr>
      <tr>
          <td><a href="https://openreview.net/forum?id=R0PBjxIbgm">OpenReview</a></td>
          <td>Paper</td>
          <td>CC BY 4.0</td>
          <td>ICML 2025 camera-ready paper</td>
      </tr>
  </tbody>
</table>
<h3 id="models">Models</h3>
<p>The eSEN architecture builds on components from <strong>eSCN</strong> (Equivariant Spherical Channel Network) and <strong>Equiformer</strong>, combining them with design choices that prioritize smoothness and energy conservation. The implementation integrates into the standard <code>fairchem</code> Open Catalyst experimental framework.</p>
<h4 id="layer-structure">Layer Structure</h4>
<ul>
<li><strong>Edgewise Convolution</strong>: Uses <code>SO2</code> convolution layers (from eSCN) with an envelope function applied. Source and target embeddings are concatenated before convolution.</li>
<li><strong>Nodewise Feed-Forward</strong>: Two equivariant linear layers with an intermediate <strong>SiLU-based gated non-linearity</strong> (from Equiformer).</li>
<li><strong>Normalization</strong>: Equivariant Layer Normalization (from Equiformer).</li>
</ul>
<h4 id="smoothness-design-choices">Smoothness Design Choices</h4>
<p>Several architectural decisions distinguish eSEN from prior work:</p>
<ul>
<li><strong>No Grid Projection</strong>: eSEN performs operations directly in the spherical harmonic space to maintain equivariance and energy conservation, bypassing the projection of spherical harmonics to spatial grids for non-linearity.</li>
<li><strong>Distance Cutoff for Graph Construction</strong>: Uses a strict distance cutoff (6 Å for MPTrj models, 5 Å for SPICE models). Neighbor limits introduce discontinuities that break energy conservation.</li>
<li><strong>Polynomial Envelope Functions</strong>: Ensures derivatives go to zero smoothly at the cutoff radius.</li>
</ul>
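<p>The effect of the last two choices can be seen in a toy 1D setting. The sketch below is only an illustration under stated assumptions: the envelope uses the DimeNet-style polynomial form (the text above only says &ldquo;polynomial envelope&rdquo;), and the weighted $1/r$ pair term, the neighbor weights, and the function names are hypothetical stand-ins, not eSEN code. It compares a hard $K$-nearest-neighbor selection against a smooth distance cutoff when one neighbor moves by 0.002 across another neighbor&rsquo;s distance.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"># Toy 1D comparison: hard K-nearest-neighbor selection vs. a smooth cutoff envelope.
# The envelope form, pair term, and neighbor weights are illustrative assumptions.

def polynomial_envelope(x, p=5):
    """DimeNet-style envelope on x = r / r_cut; for p = 5 the value and the
    first two derivatives vanish at x = 1."""
    if x >= 1.0:
        return 0.0
    a = -(p + 1) * (p + 2) / 2.0
    b = p * (p + 2)
    c = -p * (p + 1) / 2.0
    return 1.0 + a * x**p + b * x**(p + 1) + c * x**(p + 2)

def energy_knn(neighbors, k=2):
    """Sum a weighted 1/r pair term over only the k nearest neighbors."""
    nearest = sorted(neighbors, key=lambda n: n[0])[:k]
    return sum(w / r for r, w in nearest)

def energy_envelope(neighbors, r_cut=3.0):
    """Sum the same pair term over all neighbors, damped smoothly to zero at r_cut."""
    return sum(w * polynomial_envelope(r / r_cut) / r for r, w in neighbors)

def neighbors_with_third_at(r):
    # (distance, weight) pairs; the third neighbor carries a different weight
    return [(1.0, 1.0), (2.0, 1.0), (r, 3.0)]

inside = neighbors_with_third_at(1.999)
outside = neighbors_with_third_at(2.001)
jump_knn = abs(energy_knn(inside) - energy_knn(outside))
jump_env = abs(energy_envelope(inside) - energy_envelope(outside))
print(f"K-nearest jump: {jump_knn:.4f}, envelope change: {jump_env:.4f}")
</code></pre></div>
<p>With the neighbor limit, the energy jumps by an amount that does not shrink with the displacement, because the selected set changes. With the cutoff and envelope, the change stays proportional to the displacement.</p>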
<h3 id="algorithms">Algorithms</h3>
<h4 id="two-stage-training-esen-30m-mp">Two-Stage Training (eSEN-30M-MP)</h4>
<ol>
<li><strong>Direct-Force Pre-training</strong> (60 epochs): Uses <strong>DeNS</strong> (Denoising Non-equilibrium Structures) to reduce overfitting. This stage is fast because it does not require backpropagation through energy gradients.</li>
<li><strong>Conservative Fine-tuning</strong> (40 epochs): The direct-force head is removed, and forces are calculated via gradients ($F = -\nabla E$). This enforces energy conservation.</li>
</ol>
<p><strong>Important</strong>: DeNS is used exclusively during the direct-force pre-training stage, with a noising probability of 0.5, a standard deviation of 0.1 Å for the added Gaussian noise, and a DeNS loss coefficient of 10. The fine-tuning strategy reduces the wall-clock time for model training by 40%.</p>
<h4 id="optimization">Optimization</h4>
<ul>
<li><strong>Optimizer</strong>: AdamW with cosine learning rate scheduler</li>
<li><strong>Max Learning Rate</strong>: $4 \times 10^{-4}$</li>
<li><strong>Batch Size</strong>: 512 (for MPTrj models)</li>
<li><strong>Weight Decay</strong>: $1 \times 10^{-3}$</li>
<li><strong>Gradient Clipping</strong>: Norm of 100</li>
<li><strong>Warmup</strong>: 0.1 epochs with a factor of 0.2</li>
</ul>
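<p>One plausible reading of the warmup and cosine settings above is sketched below. Several details are assumptions rather than facts from the paper: that warmup is linear and starts at 0.2 times the maximum learning rate, that the cosine decays to zero, and the hypothetical <code>steps_per_epoch</code> and <code>total_steps</code> values used in the example.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import math

def lr_at(step, total_steps, steps_per_epoch,
          max_lr=4e-4, warmup_epochs=0.1, warmup_factor=0.2):
    """Learning rate at a given step: linear warmup, then cosine decay (assumed shape)."""
    warmup_steps = round(warmup_epochs * steps_per_epoch)
    if warmup_steps > step:
        # Ramp linearly from warmup_factor * max_lr up to max_lr
        frac = step / warmup_steps
        return max_lr * (warmup_factor + (1.0 - warmup_factor) * frac)
    # Cosine decay from max_lr down to zero over the remaining steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * max_lr * (1.0 + math.cos(math.pi * min(1.0, progress)))

# Hypothetical run: 1,000 steps per epoch, 10 epochs in total
for step in (0, 50, 100, 5000, 10000):
    print(step, lr_at(step, total_steps=10000, steps_per_epoch=1000))
</code></pre></div>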
<h4 id="loss-function">Loss Function</h4>
<p>A composite loss combining per-atom energy MAE, force $L_2$ loss, and stress MAE:</p>
<p>$$
\begin{aligned}
\mathcal{L} = \lambda_{\text{e}} \frac{1}{N} \sum_{i=1}^N \lvert E_{i} - \hat{E}_{i} \rvert + \lambda_{\text{f}} \frac{1}{3N} \sum_{i=1}^N \lVert \mathbf{F}_{i} - \hat{\mathbf{F}}_{i} \rVert_2^2 + \lambda_{\text{s}} \lVert \mathbf{S} - \hat{\mathbf{S}} \rVert_1
\end{aligned}
$$</p>
<p>For the 30M-parameter MPTrj model (eSEN-30M-MP), the weighting coefficients are set to $\lambda_{\text{e}} = 20$, $\lambda_{\text{f}} = 20$, and $\lambda_{\text{s}} = 5$.</p>
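<p>As a minimal sketch, the formula above can be evaluated term by term in plain Python. The function name, argument layout (flat lists for energies and stress components, one 3-vector per force), and the sample values are hypothetical; only the three terms and the default coefficients come from the text.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">def composite_loss(e_true, e_pred, f_true, f_pred, s_true, s_pred,
                   lam_e=20.0, lam_f=20.0, lam_s=5.0):
    """Energy MAE + squared-L2 force term + L1 stress term, following the formula above."""
    n = len(e_true)
    # (1/N) * sum of absolute energy errors
    energy_term = sum(abs(a - b) for a, b in zip(e_true, e_pred)) / n
    # (1/3N) * sum of squared Euclidean norms of the force errors
    force_term = sum(
        sum((x - y) ** 2 for x, y in zip(fa, fb))
        for fa, fb in zip(f_true, f_pred)
    ) / (3 * n)
    # L1 norm of the stress error over its components
    stress_term = sum(abs(a - b) for a, b in zip(s_true, s_pred))
    return lam_e * energy_term + lam_f * force_term + lam_s * stress_term

# Made-up values for two entries, purely to exercise the function
loss = composite_loss(
    e_true=[1.0, 2.0], e_pred=[1.5, 2.0],
    f_true=[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]],
    f_pred=[[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]],
    s_true=[0.0] * 6, s_pred=[0.1, 0.0, 0.0, 0.0, 0.0, 0.0],
)
print(loss)
</code></pre></div>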
<h3 id="data">Data</h3>
<h4 id="training-data">Training Data</h4>
<ul>
<li><strong>Inorganic</strong>: MPTrj (Materials Project Trajectory) dataset</li>
<li><strong>Organic</strong>: SPICE-MACE-OFF dataset</li>
</ul>
<h4 id="test-data-construction">Test Data Construction</h4>
<ul>
<li><strong>MPTrj Testing</strong>: Since MPTrj lacks an official test split, the authors created a test set using 5,000 random samples from the <strong>subsampled Alexandria (sAlex)</strong> dataset to ensure fair comparison.</li>
<li><strong>Out-of-Distribution Conservation Testing</strong>:
<ul>
<li><em>Inorganic</em>: <strong>TM23</strong> dataset (transition metal defects). Simulation: 100 ps, 5 fs timestep.</li>
<li><em>Organic</em>: <strong>MD22</strong> dataset (large molecules). Simulation: 100 ps, 1 fs timestep.</li>
</ul>
</li>
</ul>
<h3 id="hardware">Hardware</h3>
<p>Training was run predominantly on <strong>80GB NVIDIA A100 GPUs</strong>.</p>
<h4 id="inference-efficiency">Inference Efficiency</h4>
<p>For a periodic system of <strong>216 atoms</strong> on a single A100 (PyTorch 2.4.0, CUDA 12.1, no compile/torchscript), the 2-layer eSEN models achieve approximately <strong>0.8 million steps per day</strong> (3.2M parameters) and <strong>0.4 million steps per day</strong> (6.5M parameters), comparable to MACE-OFF-L at 0.7 million steps per day.</p>
<h3 id="evaluation">Evaluation</h3>
<p>The paper evaluated eSEN across three major benchmark tasks. Key evaluation metrics included energy MAE (meV/atom), force MAE (meV/Å), stress MAE (meV/Å/atom), F1 score for stability prediction, $\kappa_{\text{SRME}}$ for thermal conductivity, and phonon frequency accuracy.</p>
<h4 id="ablation-test-set-mae-table-1">Ablation Test-Set MAE (Table 1)</h4>
<p>Design choices that dramatically affect energy conservation have negligible impact on static test-set MAE, which is precisely why test-set error alone is misleading. All models are 2-layer with 3.2M parameters, $L_{\text{max}} = 2$, $M_{\text{max}} = 2$:</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Energy MAE</th>
          <th>Force MAE</th>
          <th>Stress MAE</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>eSEN (default)</td>
          <td>17.02</td>
          <td>43.96</td>
          <td>0.14</td>
      </tr>
      <tr>
          <td>eSEN, direct-force</td>
          <td>18.66</td>
          <td>43.62</td>
          <td>0.16</td>
      </tr>
      <tr>
          <td>eSEN, neighbor limit</td>
          <td>17.30</td>
          <td>44.11</td>
          <td>0.14</td>
      </tr>
      <tr>
          <td>eSEN, no envelope</td>
          <td>17.60</td>
          <td>44.69</td>
          <td>0.14</td>
      </tr>
      <tr>
          <td>eSEN, $N_{\text{basis}} = 512$</td>
          <td>19.87</td>
          <td>48.29</td>
          <td>0.15</td>
      </tr>
      <tr>
          <td>eSEN, Bessel</td>
          <td>17.65</td>
          <td>44.83</td>
          <td>0.15</td>
      </tr>
      <tr>
          <td>eSEN, discrete, res=6</td>
          <td>17.05</td>
          <td>43.10</td>
          <td>0.14</td>
      </tr>
      <tr>
          <td>eSEN, discrete, res=10</td>
          <td>17.11</td>
          <td>43.13</td>
          <td>0.14</td>
      </tr>
      <tr>
          <td>eSEN, discrete, res=14</td>
          <td>17.12</td>
          <td>43.09</td>
          <td>0.14</td>
      </tr>
  </tbody>
</table>
<p>Energy MAE in meV/atom. Force MAE in meV/Å. Stress MAE in meV/Å/atom.</p>
<h4 id="matbench-discovery-tables-2-and-3">Matbench-Discovery (Tables 2 and 3)</h4>
<p><strong>Compliant models</strong> (trained only on MPTrj or its subset), unique prototype split:</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>F1</th>
          <th>DAF</th>
          <th>$\kappa_{\text{SRME}}$</th>
          <th>RMSD</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>eSEN-30M-MP</strong></td>
          <td><strong>0.831</strong></td>
          <td><strong>5.260</strong></td>
          <td><strong>0.340</strong></td>
          <td><strong>0.0752</strong></td>
      </tr>
      <tr>
          <td>eqV2-S-DeNS</td>
          <td>0.815</td>
          <td>5.042</td>
          <td>1.676</td>
          <td>0.0757</td>
      </tr>
      <tr>
          <td>MatRIS-MP</td>
          <td>0.809</td>
          <td>5.049</td>
          <td>0.861</td>
          <td>0.0773</td>
      </tr>
      <tr>
          <td>AlphaNet-MP</td>
          <td>0.799</td>
          <td>4.863</td>
          <td>1.31</td>
          <td>0.1067</td>
      </tr>
      <tr>
          <td>DPA3-v2-MP</td>
          <td>0.786</td>
          <td>4.822</td>
          <td>0.959</td>
          <td>0.0823</td>
      </tr>
      <tr>
          <td>ORB v2 MPtrj</td>
          <td>0.765</td>
          <td>4.702</td>
          <td>1.725</td>
          <td>0.1007</td>
      </tr>
      <tr>
          <td>SevenNet-13i5</td>
          <td>0.760</td>
          <td>4.629</td>
          <td>0.550</td>
          <td>0.0847</td>
      </tr>
      <tr>
          <td>GRACE-2L-MPtrj</td>
          <td>0.691</td>
          <td>4.163</td>
          <td>0.525</td>
          <td>0.0897</td>
      </tr>
      <tr>
          <td>MACE-MP-0</td>
          <td>0.669</td>
          <td>3.777</td>
          <td>0.647</td>
          <td>0.0915</td>
      </tr>
      <tr>
          <td>CHGNet</td>
          <td>0.613</td>
          <td>3.361</td>
          <td>1.717</td>
          <td>0.0949</td>
      </tr>
      <tr>
          <td>M3GNet</td>
          <td>0.569</td>
          <td>2.882</td>
          <td>1.412</td>
          <td>0.1117</td>
      </tr>
  </tbody>
</table>
<p>eSEN-30M-MP excels at both F1 and $\kappa_{\text{SRME}}$ simultaneously, while all previous models only achieve SOTA on one or the other.</p>
<p><strong>Non-compliant models</strong> (trained on additional datasets):</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>F1</th>
          <th>$\kappa_{\text{SRME}}$</th>
          <th>RMSD</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>eSEN-30M-OAM</strong></td>
          <td><strong>0.925</strong></td>
          <td><strong>0.170</strong></td>
          <td><strong>0.0608</strong></td>
      </tr>
      <tr>
          <td>eqV2-M-OAM</td>
          <td>0.917</td>
          <td>1.771</td>
          <td>0.0691</td>
      </tr>
      <tr>
          <td>ORB v3</td>
          <td>0.905</td>
          <td>0.210</td>
          <td>0.0750</td>
      </tr>
      <tr>
          <td>SevenNet-MF-ompa</td>
          <td>0.901</td>
          <td>0.317</td>
          <td>0.0639</td>
      </tr>
      <tr>
          <td>DPA3-v2-OpenLAM</td>
          <td>0.890</td>
          <td>0.687</td>
          <td>0.0679</td>
      </tr>
      <tr>
          <td>GRACE-2L-OAM</td>
          <td>0.880</td>
          <td>0.294</td>
          <td>0.0666</td>
      </tr>
      <tr>
          <td>MatterSim-v1-5M</td>
          <td>0.862</td>
          <td>0.574</td>
          <td>0.0733</td>
      </tr>
      <tr>
          <td>MACE-MPA-0</td>
          <td>0.852</td>
          <td>0.412</td>
          <td>0.0731</td>
      </tr>
  </tbody>
</table>
<p>The eSEN-30M-OAM model is pre-trained on the OMat24 dataset, then fine-tuned on the subsampled Alexandria (sAlex) dataset and MPTrj dataset.</p>
<h4 id="mdr-phonon-benchmark-table-4">MDR Phonon Benchmark (Table 4)</h4>
<p>Metrics: maximum phonon frequency MAE($\omega_{\text{max}}$) in K, vibrational entropy MAE($S$) in J/K/mol, Helmholtz free energy MAE($F$) in kJ/mol, heat capacity MAE($C_V$) in J/K/mol.</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>MAE($\omega_{\text{max}}$)</th>
          <th>MAE($S$)</th>
          <th>MAE($F$)</th>
          <th>MAE($C_V$)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>eSEN-30M-MP</strong></td>
          <td><strong>21</strong></td>
          <td><strong>13</strong></td>
          <td><strong>5</strong></td>
          <td><strong>4</strong></td>
      </tr>
      <tr>
          <td>SevenNet-13i5</td>
          <td>26</td>
          <td>28</td>
          <td>10</td>
          <td>5</td>
      </tr>
      <tr>
          <td>GRACE-2L (r6)</td>
          <td>40</td>
          <td>25</td>
          <td>9</td>
          <td>5</td>
      </tr>
      <tr>
          <td>SevenNet-0</td>
          <td>40</td>
          <td>48</td>
          <td>19</td>
          <td>9</td>
      </tr>
      <tr>
          <td>MACE</td>
          <td>61</td>
          <td>60</td>
          <td>24</td>
          <td>13</td>
      </tr>
      <tr>
          <td>CHGNet</td>
          <td>89</td>
          <td>114</td>
          <td>45</td>
          <td>21</td>
      </tr>
      <tr>
          <td>M3GNet</td>
          <td>98</td>
          <td>150</td>
          <td>56</td>
          <td>22</td>
      </tr>
  </tbody>
</table>
<p>Direct-force models perform dramatically worse at the standard 0.01 Å displacement (e.g., eqV2-S-DeNS: 280/224/54/94, listed in the column order of the table above) but improve at larger displacements (0.2 Å: 58/26/8/8), revealing that their PES is rough near energy minima.</p>
<h4 id="spice-mace-off-table-5">SPICE-MACE-OFF (Table 5)</h4>
<p>Test set MAE for organic molecule energy/force prediction. Energy MAE in meV/atom, force MAE in meV/Å:</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>MACE-4.7M (E/F)</th>
          <th>EscAIP-45M* (E/F)</th>
          <th>eSEN-3.2M (E/F)</th>
          <th>eSEN-6.5M (E/F)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>PubChem</td>
          <td>0.88 / 14.75</td>
          <td>0.53 / 5.86</td>
          <td>0.22 / 6.10</td>
          <td><strong>0.15</strong> / <strong>4.21</strong></td>
      </tr>
      <tr>
          <td>DES370K M.</td>
          <td>0.59 / 6.58</td>
          <td>0.41 / 3.48</td>
          <td>0.17 / 1.85</td>
          <td><strong>0.13</strong> / <strong>1.24</strong></td>
      </tr>
      <tr>
          <td>DES370K D.</td>
          <td>0.54 / 6.62</td>
          <td>0.38 / 2.18</td>
          <td>0.20 / 2.77</td>
          <td><strong>0.15</strong> / <strong>2.12</strong></td>
      </tr>
      <tr>
          <td>Dipeptides</td>
          <td>0.42 / 10.19</td>
          <td>0.31 / 5.21</td>
          <td>0.10 / 3.04</td>
          <td><strong>0.07</strong> / <strong>2.00</strong></td>
      </tr>
      <tr>
          <td>Sol. AA</td>
          <td>0.98 / 19.43</td>
          <td>0.61 / 11.52</td>
          <td>0.30 / 5.76</td>
          <td><strong>0.25</strong> / <strong>3.68</strong></td>
      </tr>
      <tr>
          <td>Water</td>
          <td>0.83 / 13.57</td>
          <td>0.72 / 10.31</td>
          <td>0.24 / 3.88</td>
          <td><strong>0.15</strong> / <strong>2.50</strong></td>
      </tr>
      <tr>
          <td>QMugs</td>
          <td>0.45 / 16.93</td>
          <td>0.41 / 8.74</td>
          <td>0.16 / 5.70</td>
          <td><strong>0.12</strong> / <strong>3.78</strong></td>
      </tr>
  </tbody>
</table>
<p>*EscAIP-45M is a direct-force model. eSEN-6.5M outperforms MACE-OFF-L and EscAIP on all test splits. The smaller eSEN-3.2M has inference efficiency comparable to MACE-4.7M while achieving lower MAE.</p>
<hr>
<h2 id="why-these-design-choices-matter">Why These Design Choices Matter</h2>
<h3 id="bounded-energy-derivatives-and-the-verlet-integrator">Bounded Energy Derivatives and the Verlet Integrator</h3>
<p>The theoretical foundation for why smoothness matters comes from Theorem 5.1 of Hairer et al. (2003). For the Verlet integrator (the standard NVE integrator), the total energy drift satisfies:</p>
<p>$$
|E(\mathbf{r}_T, \mathbf{a}) - E(\mathbf{r}_0, \mathbf{a})| \leq C \Delta t^2 + C_N \Delta t^N T
$$</p>
<p>where $T$ is the total simulation time ($T \leq \Delta t^{-N}$), $N$ is the highest order for which the $N$th derivative of $E$ is continuously differentiable with bounded derivative, and $C$, $C_N$ are constants independent of $T$ and $\Delta t$. The first term is a time-independent fluctuation of $O(\Delta t^2)$; the second term governs long-term conservation. This means the PES must be continuously differentiable to high order, with bounded derivatives, for energy conservation in long-time simulations.</p>
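<p>The first term of the bound is easy to observe on a smooth toy potential. The sketch below runs velocity Verlet on a 1D harmonic oscillator (unit mass and spring constant, chosen here purely for illustration) and records the largest energy deviation at two timesteps. Halving $\Delta t$ should shrink the deviation by roughly a factor of four, with no long-term drift.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"># Velocity Verlet on a 1D harmonic oscillator (m = k = 1): a smooth PES, so the
# energy error should stay bounded and scale as dt**2. Toy system for illustration.

def max_energy_error(dt, total_time=100.0):
    x, v = 1.0, 0.0
    force = -x                      # F = -dE/dx for E = x**2 / 2
    e0 = 0.5 * v * v + 0.5 * x * x
    worst = 0.0
    for _ in range(round(total_time / dt)):
        v += 0.5 * dt * force       # half kick
        x += dt * v                 # drift
        force = -x                  # recompute force at the new position
        v += 0.5 * dt * force       # half kick
        worst = max(worst, abs(0.5 * v * v + 0.5 * x * x - e0))
    return worst

coarse = max_energy_error(0.1)
fine = max_energy_error(0.05)
print(f"dt=0.10: {coarse:.2e}  dt=0.05: {fine:.2e}  ratio: {coarse / fine:.2f}")
</code></pre></div>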
<h3 id="architectural-choices-that-break-conservation">Architectural Choices That Break Conservation</h3>
<p>The authors provide theoretical justification for why specific architectural choices break energy conservation:</p>
<ul>
<li><strong>Max Neighbor Limit (KNN)</strong>: Introduces discontinuity in the PES. If a neighbor at distance $r$ moves to $r + \epsilon$ and drops out of the top-$K$, the energy changes discontinuously.</li>
<li><strong>Grid Discretization</strong>: Projecting spherical harmonics to a spatial grid introduces discretization errors in energy gradients that break conservation. This can be mitigated with higher-resolution grids but not eliminated.</li>
<li><strong>Direct-Force Prediction</strong>: Imposes no mathematical constraint that forces must be the gradient of an energy scalar field. In other words, $\nabla \times \mathbf{F} \neq 0$ is permitted, violating the requirement for a conservative force field.</li>
</ul>
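<p>The last point has a simple numerical signature: a conservative field does zero work around any closed path, while a field with nonzero curl does not. The sketch below checks this on a 2D toy example. The quadratic energy, the rotational term of strength <code>c</code>, and the function names are hypothetical and stand in for what an unconstrained force head could output.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import math

def work_around_unit_circle(force, n_segments=10000):
    """Line integral of force . dr around the unit circle (midpoint rule on chords)."""
    total = 0.0
    for k in range(n_segments):
        t0 = 2.0 * math.pi * k / n_segments
        t1 = 2.0 * math.pi * (k + 1) / n_segments
        tm = 0.5 * (t0 + t1)
        fx, fy = force(math.cos(tm), math.sin(tm))
        total += fx * (math.cos(t1) - math.cos(t0)) + fy * (math.sin(t1) - math.sin(t0))
    return total

def conservative_force(x, y):
    # F = -grad E for E = (x**2 + y**2) / 2
    return -x, -y

def direct_force_with_curl(x, y, c=0.1):
    # Same field plus a small rotational part that is not the gradient of any energy
    return -x - c * y, -y + c * x

w_cons = work_around_unit_circle(conservative_force)
w_curl = work_around_unit_circle(direct_force_with_curl)
print(f"closed-loop work: conservative {w_cons:.2e}, with curl {w_curl:.4f}")
</code></pre></div>
<p>In an MD simulation, that nonzero loop work shows up as energy gained or lost on every cycle, which is the drift the conservation test detects.</p>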
<h3 id="displacement-sensitivity-in-phonon-calculations">Displacement Sensitivity in Phonon Calculations</h3>
<p>An important empirical finding concerns how displacement values affect phonon predictions. Conservative models (eSEN, MACE) show convergent phonon band structures as displacement decreases toward zero. In contrast, direct-force models (eqV2-S-DeNS) fail to converge, exhibiting missing acoustic branches and spurious imaginary frequencies at small displacements. While direct-force models achieve competitive thermodynamic property accuracy at large displacements (0.2 Å), this is deceptive: the underlying phonon band structures remain inaccurate, and the apparent accuracy comes from Boltzmann-weighted integrals smoothing over errors.</p>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Fu, X., Wood, B. M., Barroso-Luque, L., Levine, D. S., Gao, M., Dzamba, M., &amp; Zitnick, C. L. (2025). Learning Smooth and Expressive Interatomic Potentials for Physical Property Prediction. <em>Proceedings of the 42nd International Conference on Machine Learning (ICML)</em>, PMLR 267:17875–17893.</p>
<p><strong>Publication</strong>: ICML 2025 (Spotlight)</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{fu2025learning,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Learning Smooth and Expressive Interatomic Potentials for Physical Property Prediction}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Fu, Xiang and Wood, Brandon M. and Barroso-Luque, Luis and Levine, Daniel S. and Gao, Meng and Dzamba, Misko and Zitnick, C. Lawrence}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Proceedings of the 42nd International Conference on Machine Learning}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">series</span>=<span style="color:#e6db74">{Proceedings of Machine Learning Research}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{267}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{17875--17893}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{PMLR}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2025}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://icml.cc/virtual/2025/poster/45302">ICML 2025 poster page</a></li>
<li><a href="https://openreview.net/forum?id=R0PBjxIbgm">OpenReview forum</a></li>
<li><a href="https://openreview.net/pdf?id=R0PBjxIbgm">PDF on OpenReview</a></li>
<li><a href="https://huggingface.co/facebook/OMAT24">OMAT24 model on Hugging Face</a></li>
<li><a href="https://github.com/facebookresearch/fairchem">Code on GitHub (fairchem)</a></li>
</ul>
]]></content:encoded></item><item><title>Efficient DFT Hamiltonian Prediction via Adaptive Sparsity</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-simulation/ml-potentials/efficient-dft-hamiltonian-predicton-sphnet/</link><pubDate>Sat, 23 Aug 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-simulation/ml-potentials/efficient-dft-hamiltonian-predicton-sphnet/</guid><description>Luo et al. introduce SPHNet, using adaptive sparsity to achieve up to 7x speedup in SE(3)-equivariant Hamiltonian prediction.</description><content:encoded><![CDATA[<h2 id="core-innovation-adaptive-sparsity-in-se3-networks">Core Innovation: Adaptive Sparsity in SE(3) Networks</h2>
<p>This is a <strong>methodological paper</strong> introducing a novel architecture and training curriculum to solve efficiency bottlenecks in Geometric Deep Learning. It directly tackles the primary computational bottleneck in modern SE(3)-equivariant graph neural networks (the tensor product operation) and proposes a generalizable solution through adaptive network sparsification.</p>
<h2 id="the-computational-bottleneck-in-dft-hamiltonian-prediction">The Computational Bottleneck in DFT Hamiltonian Prediction</h2>
<p>SE(3)-equivariant networks are accurate but unscalable for DFT Hamiltonian prediction due to two key bottlenecks:</p>
<ul>
<li><strong>Atom Scaling</strong>: The cost of Tensor Product (TP) operations grows quadratically with the number of atoms ($N^2$).</li>
<li><strong>Basis Set Scaling</strong>: Computational complexity grows with the sixth power of the angular momentum order ($L^6$). Larger basis sets (e.g., def2-TZVP) require higher orders ($L=6$), making them prohibitively slow.</li>
</ul>
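<p>As a rough illustration of these two scaling laws (the baseline order $L = 4$ is a hypothetical comparison point, not a figure from the paper), doubling the atom count and raising the maximum order from 4 to 6 multiply the cost by:</p>
<p>$$
\begin{aligned}
\frac{(2N)^2}{N^2} = 4, \qquad \frac{6^6}{4^6} = \left(\frac{3}{2}\right)^6 \approx 11.4
\end{aligned}
$$</p>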
<p>Existing SE(3)-equivariant models cannot handle large molecules (40-100 atoms) with high-quality basis sets, limiting their practical applicability in computational chemistry.</p>
<h2 id="sphnet-architecture-and-the-three-phase-sparsity-scheduler">SPHNet Architecture and the Three-Phase Sparsity Scheduler</h2>
<p><strong>SPHNet</strong> introduces <strong>Adaptive Sparsity</strong>, which prunes redundant computations with two learned gates and trains them with a dedicated scheduler:</p>
<ol>
<li><strong>Sparse Pair Gate</strong>: Learns which atom pairs to include in message passing, adapting the interaction graph based on importance.</li>
<li><strong>Sparse TP Gate</strong>: Filters which spherical harmonic triplets $(l_1, l_2, l_3)$ are computed in tensor product operations, pruning higher-order combinations that contribute less to accuracy.</li>
<li><strong>Three-Phase Sparsity Scheduler</strong>: A training curriculum (Random → Adaptive → Fixed) that enables stable convergence to high-performing sparse subnetworks.</li>
</ol>
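<p>To make the Sparse TP Gate concrete, the sketch below enumerates the triplets allowed by the angular momentum triangle rule and keeps only the highest-scoring fraction. It is a simplified illustration, not SPHNet&rsquo;s implementation: the random scores stand in for learned gate weights, and the top-$k$ selection rule, the function names, and <code>l_max=4</code> are assumptions.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import random

def valid_triplets(l_max):
    """All (l1, l2, l3) allowed by the triangle rule with every order at most l_max."""
    return [
        (l1, l2, l3)
        for l1 in range(l_max + 1)
        for l2 in range(l_max + 1)
        for l3 in range(abs(l1 - l2), min(l1 + l2, l_max) + 1)
    ]

def prune_triplets(scores, sparsity):
    """Keep the highest-scoring triplets, dropping a `sparsity` fraction of them."""
    n_keep = max(1, round(len(scores) * (1.0 - sparsity)))
    ranked = sorted(scores, key=scores.get, reverse=True)
    return set(ranked[:n_keep])

random.seed(0)
triplets = valid_triplets(l_max=4)
scores = {t: random.random() for t in triplets}   # stand-in for learned gate weights
kept = prune_triplets(scores, sparsity=0.7)
print(f"{len(kept)} of {len(triplets)} triplets kept")
</code></pre></div>
<p>Only the kept triplets would then be evaluated in the tensor product, which is where the savings come from.</p>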
<p>Key insight: the Sparse Pair Gate learns to preserve long-range interactions (16-25 Å) at higher rates than short-range ones. Short-range pairs are abundant and easier to learn, while long-range pairs are rare and need more samples to be represented accurately, so retaining them matters more.</p>
<h2 id="benchmarks-and-ablation-studies">Benchmarks and Ablation Studies</h2>
<p>The authors evaluated SPHNet on three datasets (MD17, QH9, and PubChemQH) with varying molecule sizes and basis set complexities. Baselines include SchNOrb, PhiSNet, QHNet, and WANet. SchNOrb and PhiSNet results are limited to MD17, as those models are designed for trajectory datasets. WANet was not open-sourced, so only partial metrics from its paper are reported.</p>
<h3 id="evaluation-metrics">Evaluation Metrics</h3>
<ul>
<li><strong>Hamiltonian MAE ($H$)</strong>: Mean absolute error between predicted and DFT-computed Hamiltonian matrices, in Hartrees ($E_h$)</li>
<li><strong>Occupied Orbital Energy MAE ($\epsilon$)</strong>: Mean absolute error of all occupied molecular orbital energies derived from the predicted Hamiltonian</li>
<li><strong>Orbital Coefficient Similarity ($\psi$)</strong>: Cosine similarity of occupied molecular orbital coefficients between predicted and reference wavefunctions</li>
</ul>
<h3 id="ablation-studies">Ablation Studies</h3>
<p><strong>Sparse Gates</strong> (on PubChemQH):</p>
<table>
  <thead>
      <tr>
          <th>Configuration</th>
          <th>$H$ [$10^{-6} E_h$] $\downarrow$</th>
          <th>Memory [GB] $\downarrow$</th>
          <th>Speedup $\uparrow$</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Both gates</td>
          <td>97.31</td>
          <td>5.62</td>
          <td>7.09x</td>
      </tr>
      <tr>
          <td>Pair Gate only</td>
          <td>87.70</td>
          <td>6.98</td>
          <td>2.73x</td>
      </tr>
      <tr>
          <td>TP Gate only</td>
          <td>94.31</td>
          <td>8.04</td>
          <td>3.98x</td>
      </tr>
      <tr>
          <td>Neither gate</td>
          <td>86.35</td>
          <td>10.91</td>
          <td>1.73x</td>
      </tr>
  </tbody>
</table>
<p>Relative to the TP-gate-only variant, adding the Sparse Pair Gate contributes a 78% speedup with a 30% memory reduction. Relative to the pair-gate-only variant, adding the Sparse TP Gate (pruning 70% of combinations) yields a 160% speedup. Both gates together achieve the highest speedup, though accuracy decreases slightly compared to no gating.</p>
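<p>The quoted percentages can be recovered from the table by comparing the full model against the variant that lacks one gate. This reading is inferred from the numbers, which match to rounding:</p>
<p>$$
\begin{aligned}
\text{Pair Gate speedup: } \frac{7.09}{3.98} \approx 1.78 \ (+78\%), \qquad
\text{Pair Gate memory: } 1 - \frac{5.62}{8.04} \approx 0.30, \qquad
\text{TP Gate speedup: } \frac{7.09}{2.73} \approx 2.60 \ (+160\%)
\end{aligned}
$$</p>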
<p><strong>Three-Phase Scheduler</strong>: Removing the random phase causes convergence to local optima ($112.68 \pm 10.75$ vs $97.31 \pm 0.52$). Removing the adaptive phase increases variance and lowers accuracy ($122.79 \pm 19.02$). Removing the fixed phase has minimal accuracy impact but reduces speedup from 7.09x to 5.45x due to dynamic graph overhead.</p>
<p><strong>Sparsity Rate</strong>: The critical sparsity threshold scales with system complexity: 30% for MD17 (small molecules), 40% for QH9 (medium), and 70% for PubChemQH (large). Beyond the threshold, MAE increases sharply. Computational cost decreases approximately linearly with sparsity rate.</p>
<h3 id="transferability-to-other-models">Transferability to Other Models</h3>
<p>To demonstrate the speedup is architecture-agnostic, the authors applied the Sparse Pair Gate and Sparse TP Gate to the QHNet baseline on PubChemQH:</p>
<table>
  <thead>
      <tr>
          <th>Configuration</th>
          <th>$H$ [$10^{-6} E_h$] $\downarrow$</th>
          <th>Memory [GB] $\downarrow$</th>
          <th>Speedup $\uparrow$</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>QHNet baseline</td>
          <td>123.74</td>
          <td>22.50</td>
          <td>1.00x</td>
      </tr>
      <tr>
          <td>+ TP Gate</td>
          <td>128.16</td>
          <td>12.68</td>
          <td>2.04x</td>
      </tr>
      <tr>
          <td>+ Pair Gate</td>
          <td>126.27</td>
          <td>10.07</td>
          <td>1.66x</td>
      </tr>
      <tr>
          <td>+ Both gates</td>
          <td>128.89</td>
          <td>8.46</td>
          <td>3.30x</td>
      </tr>
  </tbody>
</table>
<p>The gates reduced QHNet&rsquo;s memory by 62% and improved speed by 3.3x with modest accuracy trade-off, confirming the gates are portable modules applicable to other SE(3)-equivariant architectures.</p>
<h2 id="performance-results">Performance Results</h2>
<h3 id="qh9-134k-molecules-leq-20-atoms">QH9 (134k molecules, $\leq$ 20 atoms)</h3>
<p>SPHNet achieves 3.3x to 4.0x speedup over QHNet across all four QH9 splits, with improved Hamiltonian MAE and orbital energy MAE. Memory drops to 0.23 GB/sample (33% of QHNet&rsquo;s 0.70 GB). On the stable-iid split, Hamiltonian MAE improves from 76.31 to 45.48 ($10^{-6} E_h$).</p>
<h3 id="pubchemqh-50k-molecules-40-100-atoms">PubChemQH (50k molecules, 40-100 atoms)</h3>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>$H$ [$10^{-6} E_h$] $\downarrow$</th>
          <th>$\epsilon$ [$E_h$] $\downarrow$</th>
          <th>$\psi$ [$10^{-2}$] $\uparrow$</th>
          <th>Memory [GB] $\downarrow$</th>
          <th>Speedup $\uparrow$</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>QHNet</td>
          <td>123.74</td>
          <td>3.33</td>
          <td>2.32</td>
          <td>22.5</td>
          <td>1.0x</td>
      </tr>
      <tr>
          <td>WANet</td>
          <td>99.98</td>
          <td><strong>1.17</strong></td>
          <td><strong>3.13</strong></td>
          <td>15.0</td>
          <td>2.4x</td>
      </tr>
      <tr>
          <td>SPHNet</td>
          <td><strong>97.31</strong></td>
          <td>2.16</td>
          <td>2.97</td>
          <td><strong>5.62</strong></td>
          <td><strong>7.1x</strong></td>
      </tr>
  </tbody>
</table>
<p>SPHNet achieves the best Hamiltonian MAE and efficiency, though WANet outperforms on orbital energy MAE and coefficient similarity. The higher speedup on PubChemQH (vs QH9) reflects greater computational redundancy in larger systems with higher-order basis sets ($L_{max} = 6$ for def2-TZVP vs $L_{max} = 4$ for def2-SVP).</p>
<h3 id="md17-small-molecule-trajectories">MD17 (Small Molecule Trajectories)</h3>
<p>SPHNet achieves accuracy comparable to QHNet and PhiSNet on four MD17 molecules (water, ethanol, malondialdehyde, uracil; 3-12 atoms). MD17 represents a simpler task where baseline models already perform well, leaving limited room for improvement. For water (3 atoms), the number of interaction combinations is inherently small, limiting the benefit of adaptive sparsification.</p>
<h3 id="scaling-limit">Scaling Limit</h3>
<p>SPHNet can train on systems with approximately 3000 atomic orbitals on a single A6000 GPU; the QHNet baseline runs out of memory at approximately 1800 orbitals. Memory consumption scales more favorably as molecule size increases.</p>
<h3 id="key-findings">Key Findings</h3>
<ul>
<li><strong>Adaptive sparsity scales with system complexity</strong>: The method is most effective for large systems where redundancy is high. For small molecules (e.g., water with only 3 atoms), every interaction is critical, so pruning hurts accuracy and yields negligible speedup.</li>
<li><strong>Long-range pair preservation</strong>: The Sparse Pair Gate selects long-range pairs (16-25 Angstrom) at higher rates than short-range ones. Short-range pairs are numerous and easier to learn, while rare long-range interactions are harder to represent and thus more critical to retain.</li>
<li><strong>Generalizable components</strong>: The sparsification techniques are portable modules, demonstrated by successful integration into QHNet with 3.3x speedup.</li>
<li><strong>Architecture ablation</strong>: Removing one Vectorial Node Interaction block or Spherical Node Interaction block significantly hurts accuracy, confirming the importance of the progressive order-increase design. Removing one Pair Construction block has less impact, suggesting room for further speedup.</li>
</ul>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/microsoft/SPHNet">SPHNet (GitHub)</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official implementation; archived by Microsoft (Dec 2025), read-only</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/datasets/EperLuo/PubChemQH">PubChemQH (Hugging Face)</a></td>
          <td>Dataset</td>
          <td>MIT</td>
          <td>50k molecules, 40-100 atoms, def2-TZVP basis</td>
      </tr>
  </tbody>
</table>
<p>No pre-trained model weights are provided. MD17 and QH9 are publicly available community datasets. Training requires 4x NVIDIA A100 (80GB) GPUs; benchmarking uses a single NVIDIA RTX A6000 (46GB).</p>
<h3 id="data">Data</h3>
<p>The experiments evaluated SPHNet on three datasets with different molecular sizes and basis set complexities. All datasets use DFT calculations as ground truth, with MD17 using the PBE exchange-correlation functional and QH9/PubChemQH using B3LYP.</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>Molecules</th>
          <th>Molecule Size</th>
          <th>Basis Set</th>
          <th>$L_{max}$</th>
          <th>Functional</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MD17</td>
          <td>4 systems</td>
          <td>3-12 atoms (water, ethanol, malondialdehyde, uracil)</td>
          <td>def2-SVP</td>
          <td>4</td>
          <td>PBE</td>
      </tr>
      <tr>
          <td>QH9</td>
          <td>134k</td>
          <td>$\leq$ 20 atoms (Stable/Dynamic splits)</td>
          <td>def2-SVP</td>
          <td>4</td>
          <td>B3LYP</td>
      </tr>
      <tr>
          <td>PubChemQH</td>
          <td>50k</td>
          <td>40-100 atoms</td>
          <td>def2-TZVP</td>
          <td>6</td>
          <td>B3LYP</td>
      </tr>
  </tbody>
</table>
<p><strong>Data Availability</strong>:</p>
<ul>
<li><strong>MD17 &amp; QH9</strong>: Publicly available</li>
<li><strong>PubChemQH</strong>: Publicly available on Hugging Face (<a href="https://huggingface.co/datasets/EperLuo/PubChemQH">EperLuo/PubChemQH</a>)</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<p><strong>Loss Function</strong>:</p>
<p>The model learns the <strong>residual</strong> $\Delta H$:</p>
<p>$$
\begin{aligned}
\Delta H &amp;= H_{\text{ref}} - H_{\text{init}} \\
\mathcal{L} &amp;= \text{MAE}(H_{\text{ref}}, H_{\text{pred}}) + \text{MSE}(H_{\text{ref}}, H_{\text{pred}})
\end{aligned}
$$</p>
<p>where $H_{\text{init}}$ is a computationally inexpensive initial guess computed via PySCF.</p>
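<p>A minimal sketch of this objective in plain Python follows. It assumes the prediction is reconstructed as $H_{\text{pred}} = H_{\text{init}} + \Delta H_{\text{pred}}$, which the residual formulation implies but the text does not spell out; the function names and the toy matrices are hypothetical.</p>

```python
def mae(a, b):
    """Mean absolute error over all matrix entries."""
    diffs = [abs(x - y) for row_a, row_b in zip(a, b) for x, y in zip(row_a, row_b)]
    return sum(diffs) / len(diffs)


def mse(a, b):
    """Mean squared error over all matrix entries."""
    diffs = [(x - y) ** 2 for row_a, row_b in zip(a, b) for x, y in zip(row_a, row_b)]
    return sum(diffs) / len(diffs)


def add(a, b):
    """Entry-wise matrix sum."""
    return [[x + y for x, y in zip(row_a, row_b)] for row_a, row_b in zip(a, b)]


def residual_loss(h_ref, h_init, delta_pred):
    """MAE + MSE between the reference and the reconstructed Hamiltonian.

    Assumes the network output is the residual, so H_pred = H_init + delta_pred.
    """
    h_pred = add(h_init, delta_pred)
    return mae(h_ref, h_pred) + mse(h_ref, h_pred)


# Toy 2x2 example with made-up numbers.
h_ref = [[1.0, 0.2], [0.2, 2.0]]
h_init = [[0.9, 0.1], [0.1, 1.8]]
perfect_delta = [[0.1, 0.1], [0.1, 0.2]]
zero_delta = [[0.0, 0.0], [0.0, 0.0]]

print(residual_loss(h_ref, h_init, perfect_delta))  # close to zero
print(residual_loss(h_ref, h_init, zero_delta))     # error of the initial guess alone
```

<p>Because $H_{\text{init}}$ cancels inside the difference, the same loss could equivalently be written on $\Delta H$ and its prediction.</p>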
<p><strong>Hyperparameters</strong>:</p>
<table>
  <thead>
      <tr>
          <th>Parameter</th>
          <th>PubChemQH</th>
          <th>QH9</th>
          <th>MD17</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Batch Size</td>
          <td>8</td>
          <td>32</td>
          <td>10 (uracil: 5)</td>
      </tr>
      <tr>
          <td>Training Steps</td>
          <td>300k</td>
          <td>260k</td>
          <td>200k</td>
      </tr>
      <tr>
          <td>Warmup Steps</td>
          <td>1k</td>
          <td>1k</td>
          <td>1k</td>
      </tr>
      <tr>
          <td>Learning Rate</td>
          <td>1e-3</td>
          <td>1e-3</td>
          <td>5e-4</td>
      </tr>
      <tr>
          <td>Sparsity Rate</td>
          <td>0.7</td>
          <td>0.4</td>
          <td>0.1-0.3</td>
      </tr>
      <tr>
          <td>TSS Epoch $t$</td>
          <td>3</td>
          <td>3</td>
          <td>3</td>
      </tr>
  </tbody>
</table>
<p><strong>Sparse Pair Gate</strong>: Adapts the interaction graph. It concatenates zero-order features and inner products of atom pairs, then passes them through a linear layer $F_p$ with sigmoid activation to learn a weight $W_p^{ij}$ for every pair. Pairs are kept only if selected by the scheduler ($U_p^{TSS}$). The overhead comes primarily from the linear layer $F_p$.</p>
<p><strong>Sparse TP Gate</strong>: Filters triplets $(l_1, l_2, l_3)$ inside the TP operation. Higher-order combinations are more likely to be pruned. Complexity: $\mathcal{O}(L^3)$.</p>
<p><strong>Three-Phase Sparsity Scheduler</strong>: Training curriculum designed to optimize the sparse gates effectively:</p>
<ul>
<li><strong>Phase 1 (Random)</strong>: Each element is kept at random with probability $1-k$, where $k$ is the sparsity rate, to ensure unbiased weight updates. Complexity: $\mathcal{O}(|U|)$.</li>
<li><strong>Phase 2 (Adaptive)</strong>: Keeps the top $(1-k)$ fraction of elements ranked by learned weight magnitude. Complexity: $\mathcal{O}(|U|\log|U|)$.</li>
<li><strong>Phase 3 (Fixed)</strong>: Freezes the connectivity mask for maximum inference speed. No overhead.</li>
</ul>
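<p>The three phases above can be sketched as a single mask-selection routine. This is an illustration under stated assumptions, not the authors&rsquo; implementation: the function name <code>select_mask</code>, the phase labels, and the choice to pass the frozen mask explicitly are all hypothetical, and the candidates stand for either atom pairs or TP triplets.</p>

```python
import random


def select_mask(weights, sparsity, phase, rng, frozen_mask=None):
    """Return a keep/drop mask over candidate elements (pairs or TP triplets).

    weights: learned gate weights, one per candidate.
    sparsity: rate k; roughly a (1 - k) fraction of candidates is kept.
    phase: "random", "adaptive", or "fixed".
    """
    n = len(weights)
    if phase == "random":
        # Phase 1: keep each candidate independently with probability 1 - k.
        return [rng.random() >= sparsity for _ in range(n)]
    if phase == "adaptive":
        # Phase 2: keep the top (1 - k) fraction ranked by learned weight.
        n_keep = max(1, round((1.0 - sparsity) * n))
        ranked = sorted(range(n), key=lambda i: weights[i], reverse=True)
        kept = set(ranked[:n_keep])
        return [i in kept for i in range(n)]
    if phase == "fixed":
        # Phase 3: reuse the mask frozen at the end of the adaptive phase.
        if frozen_mask is None:
            raise ValueError("fixed phase needs a previously frozen mask")
        return list(frozen_mask)
    raise ValueError(f"unknown phase: {phase}")


rng = random.Random(0)
weights = [0.9, 0.1, 0.5, 0.7, 0.3, 0.8, 0.2, 0.6, 0.4, 1.0]
random_mask = select_mask(weights, 0.7, "random", rng)
adaptive_mask = select_mask(weights, 0.7, "adaptive", rng)
fixed_mask = select_mask(weights, 0.7, "fixed", rng, frozen_mask=adaptive_mask)
print(adaptive_mask)
```

<p>The sort in the adaptive branch is what gives the $\mathcal{O}(|U|\log|U|)$ cost, while the fixed branch only copies an existing mask.</p>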
<p><strong>Weight Initialization</strong>: Learnable sparsity weights ($W$) are initialized as an all-ones vector.</p>
<h3 id="models">Models</h3>
<p>The model predicts the Hamiltonian matrix $H$ from atomic numbers $Z$ and coordinates $r$.</p>
<p><strong>Inputs</strong>: Atomic numbers ($Z$) and 3D coordinates ($r$).</p>
<p><strong>Backbone Structure</strong>:</p>
<ol>
<li><strong>Vectorial Node Interaction (x4)</strong>: Uses long-short range message passing. Extracts vectorial representations ($l=1$) without high-order TPs to save cost.</li>
<li><strong>Spherical Node Interaction (x2)</strong>: Projects features to high-order spherical harmonics (up to $L_{max}$). The first block increases the maximum order from 0 to $L_{max}$ without the Sparse Pair Gate; the second block applies the <strong>Sparse Pair Gate</strong> to filter node pairs.</li>
<li><strong>Pair Construction Block (x2)</strong>: Splits into <strong>Diagonal</strong> (self-interaction) and <strong>Non-Diagonal</strong> (cross-interaction) blocks. Both use the <strong>Sparse TP Gate</strong> to prune cross-order combinations $(l_1, l_2, l_3)$. The Non-Diagonal blocks also use the <strong>Sparse Pair Gate</strong> to filter atom pairs. The two Pair Construction blocks receive representations from the two Spherical Node Interaction blocks respectively, and their outputs are summed.</li>
<li><strong>Expansion Block</strong>: Reconstructs the full Hamiltonian matrix from the sparse irreducible representations, exploiting symmetry ($H_{ji} = H_{ij}^T$) to halve computations.</li>
</ol>
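<p>The symmetry used by the Expansion Block can be illustrated with a small block-assembly routine. The sketch assumes the per-pair orbital blocks are already available as dense nested lists, so it skips the conversion from irreducible representations; <code>expand_hamiltonian</code> and its arguments are hypothetical names.</p>

```python
def transpose(block):
    """Transpose a dense block stored as a list of rows."""
    return [list(col) for col in zip(*block)]


def expand_hamiltonian(diag_blocks, upper_blocks):
    """Assemble a full block matrix from diagonal and upper-triangle blocks.

    diag_blocks[i] is the block for atom i with itself.
    upper_blocks[(i, j)] with j > i is the block for the pair (i, j);
    every such pair must be present. The lower-triangle block (j, i)
    is filled in as the transpose of block (i, j).
    """
    n_atoms = len(diag_blocks)
    blocks = {}
    for i in range(n_atoms):
        blocks[(i, i)] = diag_blocks[i]
    for (i, j), block in upper_blocks.items():
        blocks[(i, j)] = block
        blocks[(j, i)] = transpose(block)  # H_ji = H_ij^T, no second prediction

    # Stitch the blocks into one dense matrix, one row of blocks at a time.
    full = []
    for i in range(n_atoms):
        n_rows = len(diag_blocks[i])
        for r in range(n_rows):
            row = []
            for j in range(n_atoms):
                row.extend(blocks[(i, j)][r])
            full.append(row)
    return full


# Toy system: atom 0 has 1 orbital, atom 1 has 2 orbitals (made-up values).
diag = [[[1.0]], [[2.0, 0.3], [0.3, 3.0]]]
upper = {(0, 1): [[0.5, 0.7]]}
H = expand_hamiltonian(diag, upper)
for row in H:
    print(row)
```

<p>Only the blocks with $j \geq i$ need to be predicted; the rest follow from the transpose.</p>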
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Training</strong>: 4x NVIDIA A100 (80GB)</li>
<li><strong>Benchmarking</strong>: Single NVIDIA RTX A6000 (46GB)</li>
</ul>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Luo, E., Wei, X., Huang, L., Li, Y., Yang, H., Xia, Z., Wang, Z., Liu, C., Shao, B., &amp; Zhang, J. (2025). Efficient and Scalable Density Functional Theory Hamiltonian Prediction through Adaptive Sparsity. <em>Proceedings of the 42nd International Conference on Machine Learning</em>, PMLR 267:41368&ndash;41390.</p>
<p><strong>Publication</strong>: ICML 2025</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{luo2025efficient,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Efficient and Scalable Density Functional Theory Hamiltonian Prediction through Adaptive Sparsity}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Luo, Erpai and Wei, Xinran and Huang, Lin and Li, Yunyang and Yang, Han and Xia, Zaishuo and Wang, Zun and Liu, Chang and Shao, Bin and Zhang, Jia}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Proceedings of the 42nd International Conference on Machine Learning}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{41368--41390}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2025}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{267}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">series</span>=<span style="color:#e6db74">{Proceedings of Machine Learning Research}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{PMLR}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://icml.cc/virtual/2025/poster/45656">ICML 2025 poster page</a></li>
<li><a href="https://openreview.net/forum?id=K3lykWhXON">OpenReview forum</a></li>
<li><a href="https://openreview.net/pdf?id=K3lykWhXON">PDF on OpenReview</a></li>
<li><a href="https://github.com/microsoft/SPHNet">GitHub Repository</a> <em>(Note: The official repository was archived by Microsoft in December 2025. It is available for reference but no longer actively maintained.)</em></li>
</ul>
]]></content:encoded></item><item><title>Dark Side of Forces: Non-Conservative ML Force Models</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-simulation/ml-potentials/dark-side-of-forces/</link><pubDate>Sat, 23 Aug 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-simulation/ml-potentials/dark-side-of-forces/</guid><description>Bigi et al. critique non-conservative force models in ML potentials, showing their simulation failures and proposing hybrid solutions.</description><content:encoded><![CDATA[<h2 id="contribution-systematic-assessment-of-non-conservative-ml-force-models">Contribution: Systematic Assessment of Non-Conservative ML Force Models</h2>
<p>This is a <strong>Systematization</strong> paper. It systematically catalogs the exact failure modes of existing non-conservative force approaches, quantifies them with a new diagnostic metric, and proposes a hybrid Multiple Time-Stepping solution combining the speed benefits of direct force prediction with the physical correctness of conservative models.</p>
<h2 id="motivation-the-speed-accuracy-trade-off-in-ml-force-fields">Motivation: The Speed-Accuracy Trade-off in ML Force Fields</h2>
<p>Many recent machine learning interatomic potential (MLIP) architectures predict forces directly ($F_\theta(r)$). This &ldquo;non-conservative&rdquo; approach avoids the computational overhead of automatic differentiation, yielding faster inference (typically 2-3x speedup) and faster training (up to 3x). However, it sacrifices energy conservation and rotational constraints, potentially destabilizing molecular dynamics simulations. The field lacks rigorous quantification of when this trade-off breaks down and how to mitigate the failures.</p>
<h2 id="novelty-jacobian-asymmetry-and-hybrid-architectures">Novelty: Jacobian Asymmetry and Hybrid Architectures</h2>
<p>The paper makes four key contributions:</p>
<ol>
<li>
<p><strong>Jacobian Asymmetry Metric ($\lambda$):</strong> A quantitative diagnostic for non-conservation. Since conservative forces derive from a scalar field, their Jacobian (the Hessian of energy) must be symmetric. The normalized norm of the antisymmetric part quantifies the degree of violation:
$$ \lambda = \frac{|| \mathbf{J}_{\text{anti}} ||_F}{|| \mathbf{J} ||_F} $$
where $\mathbf{J}_{\text{anti}} = (\mathbf{J} - \mathbf{J}^\top)/2$. Measured values range from $\lambda \approx 0.004$ (PET-NC) to $\lambda \approx 0.032$ (SOAP-BPNN-NC), with ORB at 0.015 and EquiformerV2 at 0.017. Notably, the pairwise $\lambda_{ij}$ approaches 1 at large interatomic distances, meaning non-conservative artifacts disproportionately affect long-range and collective interactions.</p>
</li>
<li>
<p><strong>Systematic Failure Mode Catalog:</strong> First comprehensive demonstration that non-conservative models cause runaway heating in NVE ensembles (temperature drifts of ~7,000 billion K/s for PET-NC and ~10x larger for ORB) and equipartition violations in NVT ensembles where different atom types equilibrate to different temperatures, a physical impossibility.</p>
</li>
<li>
<p><strong>Theoretical Analysis of Force vs. Energy Training:</strong> Force-only training overemphasizes high-frequency vibrational modes because force labels carry per-atom gradients that are dominated by stiff, short-range interactions. Energy labels provide a more balanced representation across the frequency spectrum. Additionally, conservative models benefit from backpropagation extending the effective receptive field to approximately 2x the interaction cutoff, while direct-force models are limited to the nominal cutoff radius.</p>
</li>
<li>
<p><strong>Hybrid Training and Inference Protocol:</strong> A practical workflow that combines fast direct-force prediction with conservative corrections:</p>
<ul>
<li><strong>Training:</strong> Pre-train on direct forces, then fine-tune on energy gradients (2-4x faster than training conservative models from scratch)</li>
<li><strong>Inference:</strong> Multiple Time-Stepping (MTS) where fast non-conservative forces are periodically corrected by slower conservative forces</li>
</ul>
</li>
</ol>
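<p>The Jacobian asymmetry metric from the first contribution can be checked on toy force fields. The sketch below uses central finite differences in place of the automatic differentiation one would use on a real model, and the two force functions are made-up examples chosen to sit at the two extremes of $\lambda$.</p>

```python
import math


def jacobian(force, r, h=1e-5):
    """Central finite-difference Jacobian J[a][b] = dF_a / dr_b."""
    n = len(r)
    jac = [[0.0] * n for _ in range(n)]
    for b in range(n):
        plus = list(r)
        minus = list(r)
        plus[b] += h
        minus[b] -= h
        f_plus = force(plus)
        f_minus = force(minus)
        for a in range(n):
            jac[a][b] = (f_plus[a] - f_minus[a]) / (2.0 * h)
    return jac


def frobenius(m):
    """Frobenius norm of a matrix stored as a list of rows."""
    return math.sqrt(sum(x * x for row in m for x in row))


def asymmetry(jac):
    """lambda = ||J_anti||_F / ||J||_F with J_anti = (J - J^T) / 2."""
    n = len(jac)
    anti = [[0.5 * (jac[a][b] - jac[b][a]) for b in range(n)] for a in range(n)]
    return frobenius(anti) / frobenius(jac)


def harmonic_force(r):
    # Conservative: F = -grad(0.5 * |r|^2).
    return [-x for x in r]


def swirl_force(r):
    # Purely rotational 2D field, not the gradient of any energy.
    return [-r[1], r[0]]


point = [0.3, -0.8]
lam_harmonic = asymmetry(jacobian(harmonic_force, point))
lam_swirl = asymmetry(jacobian(swirl_force, point))
print(lam_harmonic, lam_swirl)
```

<p>The harmonic force has a symmetric Jacobian and gives $\lambda = 0$, while the purely rotational field has an antisymmetric Jacobian and gives $\lambda \approx 1$. The measured model values quoted above (0.004 to 0.032) sit close to the conservative end of that range.</p>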
<h2 id="methodology-systematic-failure-mode-analysis">Methodology: Systematic Failure Mode Analysis</h2>
<p>The evaluation systematically tests multiple state-of-the-art models across diverse simulation scenarios:</p>
<p><strong>Models tested:</strong></p>
<ul>
<li><strong>PET-C/PET-NC</strong> (Point Edge Transformer, conservative and non-conservative variants)</li>
<li><strong>PET-M</strong> (hybrid variant jointly predicting both conservative and non-conservative forces)</li>
<li><strong>ORB-v2</strong> (non-conservative, trained on Alexandria/MPtrj)</li>
<li><strong>EquiformerV2</strong> (non-conservative equivariant Transformer)</li>
<li><strong>MACE-MP-0</strong> (conservative message-passing)</li>
<li><strong>SevenNet</strong> (conservative message-passing)</li>
<li><strong>SOAP-BPNN-C/SOAP-BPNN-NC</strong> (descriptor-based baseline, both conservative and non-conservative variants)</li>
</ul>
<p><strong>Test scenarios:</strong></p>
<ol>
<li><strong>NVE stability tests</strong> on bulk liquid water, graphene, amorphous carbon, and FCC aluminum</li>
<li><strong>Thermostat artifact analysis</strong> with Langevin and GLE thermostats</li>
<li><strong>Geometry optimization</strong> on water snapshots and <a href="/notes/chemistry/datasets/qm9/">QM9</a> molecules using FIRE and L-BFGS</li>
<li><strong>MTS validation</strong> on OC20 catalysis dataset</li>
<li><strong>Species-resolved temperature measurements</strong> for equipartition testing</li>
</ol>
<p><strong>Key metrics:</strong></p>
<ul>
<li>Jacobian asymmetry ($\lambda$)</li>
<li>Kinetic temperature drift in NVE</li>
<li>Velocity-velocity correlations</li>
<li>Radial distribution functions</li>
<li>Species-resolved temperatures</li>
<li>Inference speed benchmarks</li>
</ul>
<h2 id="results-simulation-instability-and-hybrid-solutions">Results: Simulation Instability and Hybrid Solutions</h2>
<p>Purely non-conservative models are <strong>unsuitable for production simulations</strong> due to uncontrollable unphysical artifacts that no thermostat can correct. Key findings:</p>
<p><strong>Performance failures:</strong></p>
<ul>
<li>Non-conservative models exhibited catastrophic temperature drift in NVE simulations: ~7,000 billion K/s for PET-NC and ~70,000 billion K/s for ORB, with EquiformerV2 comparable to PET-NC</li>
<li>Strong Langevin thermostats ($\tau=10$ fs) damped diffusion by ~5x, negating the speed benefits of non-conservative models</li>
<li>Advanced GLE thermostats also failed to control non-conservative drift (ORB reached 1181 K vs. 300 K target)</li>
<li>Equipartition violations: under stochastic velocity rescaling, O and H atoms equilibrated at different temperatures. For ORB, H atoms reached 336 K and O atoms 230 K against a 300 K target. For PET-NC, deviations were smaller but still significant (H at 296 K, O at 310 K).</li>
<li>Geometry optimization was more fragile with non-conservative forces: inaccurate NC models (SOAP-BPNN-NC) failed catastrophically, while more accurate ones (PET-NC) could converge with FIRE but showed large force fluctuations with L-BFGS. Non-conservative models consistently had lower success rates across water and QM9 benchmarks.</li>
</ul>
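<p>The equipartition check in the list above relies on species-resolved kinetic temperatures. A minimal sketch follows, assuming SI units and three translational degrees of freedom per atom with no constraint or centre-of-mass corrections; the velocities are synthetic and constructed only so that the helper returns the ORB-like values quoted above.</p>

```python
import math

K_B = 1.380649e-23          # Boltzmann constant, J/K
AMU = 1.66053906660e-27     # atomic mass unit, kg
MASS = {"H": 1.008 * AMU, "O": 15.999 * AMU}


def species_temperatures(species, velocities):
    """Kinetic temperature per species from T = sum(m v^2) / (3 N k_B).

    species: list of element symbols, one per atom.
    velocities: list of (vx, vy, vz) tuples in m/s.
    """
    twice_kinetic = {}
    counts = {}
    for symbol, (vx, vy, vz) in zip(species, velocities):
        m = MASS[symbol]
        term = m * (vx * vx + vy * vy + vz * vz)
        twice_kinetic[symbol] = twice_kinetic.get(symbol, 0.0) + term
        counts[symbol] = counts.get(symbol, 0) + 1
    return {s: twice_kinetic[s] / (3.0 * counts[s] * K_B) for s in counts}


def speed_for(symbol, temperature):
    """Speed at which a single atom carries exactly 3/2 k_B T of kinetic energy."""
    return math.sqrt(3.0 * K_B * temperature / MASS[symbol])


# Synthetic water molecule: O deliberately colder than H.
species = ["O", "H", "H"]
velocities = [
    (speed_for("O", 230.0), 0.0, 0.0),
    (speed_for("H", 336.0), 0.0, 0.0),
    (0.0, speed_for("H", 336.0), 0.0),
]
temps = species_temperatures(species, velocities)
print(temps)
```

<p>In an equilibrated conservative simulation both entries should agree with the thermostat target within statistical noise; a persistent gap between species is the signature of the violation described above.</p>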
<p><strong>Hybrid solution success:</strong></p>
<ul>
<li>MTS with non-conservative forces corrected every 8 steps ($M=8$) achieved conservative stability with only ~20% overhead compared to a purely non-conservative trajectory. Results were essentially indistinguishable from fully conservative simulations. Higher stride values ($M=16$) became unstable due to resonances between fast degrees of freedom and integration errors.</li>
<li>Conservative fine-tuning achieved the accuracy of from-scratch training in about 1/3 the total training time (2-4x resource reduction)</li>
<li>Validated on OC20 catalysis benchmark</li>
</ul>
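<p>The MTS idea can be sketched on a toy 2D oscillator. The code assumes a RESPA-style impulse splitting in which the cheap non-conservative force drives every inner step and the difference between conservative and non-conservative forces is applied as a kick every $M$ steps. The force functions, the constant <code>EPS</code>, and all helper names are invented for illustration and do not reproduce the paper&rsquo;s i-PI setup or models.</p>

```python
EPS = 0.05  # strength of the artificial non-conservative (rotational) term


def f_cons(r):
    """Conservative reference force: F = -grad(0.5 * |r|^2)."""
    return (-r[0], -r[1])


def f_nc(r):
    """Cheap direct force: the conservative force plus a small curl component."""
    return (-r[0] - EPS * r[1], -r[1] + EPS * r[0])


def energy(r, v):
    """Total energy measured against the conservative potential."""
    return 0.5 * (v[0] ** 2 + v[1] ** 2) + 0.5 * (r[0] ** 2 + r[1] ** 2)


def kick(v, f, dt):
    return (v[0] + dt * f[0], v[1] + dt * f[1])


def drift(r, v, dt):
    return (r[0] + dt * v[0], r[1] + dt * v[1])


def verlet_step(r, v, force, dt):
    """One velocity-Verlet step with unit mass."""
    v = kick(v, force(r), 0.5 * dt)
    r = drift(r, v, dt)
    v = kick(v, force(r), 0.5 * dt)
    return r, v


def mts_step(r, v, dt, stride):
    """One outer step: correction kicks wrap `stride` cheap inner steps."""
    def correction(pos):
        fc, fn = f_cons(pos), f_nc(pos)
        return (fc[0] - fn[0], fc[1] - fn[1])

    outer_dt = stride * dt
    v = kick(v, correction(r), 0.5 * outer_dt)
    for _ in range(stride):
        r, v = verlet_step(r, v, f_nc, dt)
    v = kick(v, correction(r), 0.5 * outer_dt)
    return r, v


def relative_drift(use_mts, n_steps=5000, dt=0.01, stride=8):
    """Relative energy change after n_steps inner steps."""
    r, v = (1.0, 0.0), (0.0, 0.0)
    e0 = energy(r, v)
    if use_mts:
        for _ in range(n_steps // stride):
            r, v = mts_step(r, v, dt, stride)
    else:
        for _ in range(n_steps):
            r, v = verlet_step(r, v, f_nc, dt)
    return energy(r, v) / e0 - 1.0


nc_drift = relative_drift(False)
mts_drift = relative_drift(True)
print("pure non-conservative:", nc_drift)
print("MTS, stride 8:        ", mts_drift)
```

<p>On this toy system the pure non-conservative trajectory gains energy steadily, while the corrected trajectory stays close to its initial energy. In a real simulation the conservative force is the expensive call, so evaluating it once per $M$ steps is what keeps the overhead small.</p>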
<p><strong>Scaling caveat:</strong> The authors note that as training datasets grow and models become more expressive, non-conservative artifacts should diminish because accurate models naturally exhibit less non-conservative behavior. However, they argue the best path forward is hybrid approaches rather than waiting for scale to solve the problem.</p>
<p><strong>Recommendation:</strong> The optimal production path is hybrid architectures using direct forces for acceleration (via MTS and pre-training) while anchoring models in conservative energy surfaces. This captures computational benefits without sacrificing physical reliability.</p>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p><strong>Primary training/evaluation:</strong></p>
<ul>
<li><strong>Bulk Liquid Water</strong> (Cheng et al., 2019): revPBE0-D3 calculations with over 250,000 force/energy targets, chosen for rigorous thermodynamic testing</li>
</ul>
<p><strong>Generalization tests:</strong></p>
<ul>
<li>Graphene, amorphous carbon, FCC aluminum (tested with general-purpose foundation models)</li>
</ul>
<p><strong>Benchmarks:</strong></p>
<ul>
<li><strong>QM9</strong>: Geometry optimization tests</li>
<li><strong>OC20</strong> (Open Catalyst): Oxygen on alloy surfaces for MTS validation</li>
</ul>
<p>All datasets are publicly available through the cited sources.</p>
<h3 id="models">Models</h3>
<p><strong>Point Edge Transformer (PET)</strong> variants:</p>
<ul>
<li><strong>PET-C (Conservative)</strong>: Forces via energy backpropagation</li>
<li><strong>PET-NC (Non-Conservative)</strong>: Direct force prediction head, slightly higher parameter count</li>
<li><strong>PET-M (Hybrid)</strong>: Jointly predicts both conservative and non-conservative forces, accuracy within ~10% of the best single-task models</li>
</ul>
<p><strong>Baseline comparisons:</strong></p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Type</th>
          <th>Training Data</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ORB-v2</td>
          <td>Non-conservative</td>
          <td>Alexandria/MPtrj</td>
          <td>Rotationally unconstrained</td>
      </tr>
      <tr>
          <td>EquiformerV2</td>
          <td>Non-conservative</td>
          <td>Alexandria/MPtrj</td>
          <td>Equivariant Transformer</td>
      </tr>
      <tr>
          <td>MACE-MP-0</td>
          <td>Conservative</td>
          <td>MPtrj</td>
          <td>Equivariant message-passing</td>
      </tr>
      <tr>
          <td>SevenNet</td>
          <td>Conservative</td>
          <td>MPtrj</td>
          <td>Equivariant message-passing</td>
      </tr>
      <tr>
          <td>SOAP-BPNN-C</td>
          <td>Conservative</td>
          <td>Bulk water</td>
          <td>Descriptor-based baseline</td>
      </tr>
      <tr>
          <td>SOAP-BPNN-NC</td>
          <td>Non-conservative</td>
          <td>Bulk water</td>
          <td>Descriptor-based baseline</td>
      </tr>
  </tbody>
</table>
<p><strong>Training details:</strong></p>
<ul>
<li><strong>Loss functions</strong>: PET-C uses joint Energy + Force $L^2$ loss; PET-NC uses Force-only $L^2$ loss</li>
<li><strong>Fine-tuning protocol</strong>: PET-NC converted to conservative via energy head fine-tuning</li>
<li><strong>MTS configuration</strong>: Non-conservative forces with conservative corrections every 8 steps ($M=8$)</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p><strong>Metrics &amp; Software:</strong>
Molecular dynamics evaluations were performed using <strong>i-PI</strong>, while geometry optimizations used <strong>ASE (Atomic Simulation Environment)</strong>. Note that primary code reproducibility is provided via an archived Zenodo snapshot; the authors did not link a live, public GitHub repository.</p>
<ol>
<li><strong>Jacobian asymmetry</strong> ($\lambda$): Quantifies non-conservation via antisymmetric component</li>
<li><strong>Temperature drift</strong>: NVE ensemble stability</li>
<li><strong>Velocity-velocity correlation</strong> ($\hat{c}_{vv}(\omega)$): Thermostat artifact detection</li>
<li><strong>Radial distribution functions</strong> ($g(r)$): Structural accuracy</li>
<li><strong>Species-resolved temperature</strong>: Equipartition testing</li>
<li><strong>Inference speed</strong>: Wall-clock time per MD step</li>
</ol>
<p><strong>Key results:</strong></p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Speed (ms/step)</th>
          <th>NVE Stability</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>PET-NC</td>
          <td>8.58</td>
          <td>Failed</td>
          <td>~7,000 billion K/s drift</td>
      </tr>
      <tr>
          <td>PET-C</td>
          <td>19.4</td>
          <td>Stable</td>
          <td>2.3x slower than PET-NC</td>
      </tr>
      <tr>
          <td>SevenNet</td>
          <td>52.8</td>
          <td>Stable</td>
          <td>Conservative baseline</td>
      </tr>
      <tr>
          <td><strong>PET Hybrid (MTS)</strong></td>
          <td><strong>~10.3</strong></td>
          <td><strong>Stable</strong></td>
          <td><strong>~20% overhead vs. pure NC</strong></td>
      </tr>
  </tbody>
</table>
<p><strong>Thermostat artifacts:</strong></p>
<ul>
<li>Langevin ($\tau=10$ fs) damped diffusion by ~5x (weaker coupling at $\tau=100$ fs reduced diffusion by ~1.5x)</li>
<li>GLE thermostats also failed to control non-conservative drift</li>
<li>Equipartition violations under SVR: ORB showed H at 336 K and O at 230 K (target 300 K); PET-NC showed smaller but significant species-resolved deviations</li>
</ul>
<p><strong>Optimization failures:</strong></p>
<ul>
<li>Non-conservative models showed lower geometry optimization success rates across water and QM9 benchmarks, with inaccurate NC models failing catastrophically</li>
</ul>
<h3 id="hardware">Hardware</h3>
<p><strong>Compute resources:</strong></p>
<ul>
<li><strong>Training</strong>: From-scratch baseline models were trained on 4x NVIDIA H100 GPUs for around two days.</li>
<li><strong>Fine-Tuning</strong>: Conservative fine-tuning ran on a single NVIDIA H100 GPU for one day.</li>
<li><strong>Savings</strong>: The hybrid fine-tuning approach reduced computational resources by 2-4x compared to training conservative models from scratch.</li>
</ul>
<p><strong>Reproduction resources:</strong></p>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://zenodo.org/records/14778891">Zenodo repository</a></td>
          <td>Code/Data</td>
          <td>Unknown</td>
          <td>Code and data to reproduce all results</td>
      </tr>
      <tr>
          <td><a href="https://atomistic-cookbook.org/examples/pet-mad-nc/pet-mad-nc.html">MTS inference tutorial</a></td>
          <td>Other</td>
          <td>Unknown</td>
          <td>Multiple time-stepping dynamics tutorial</td>
      </tr>
      <tr>
          <td><a href="https://atomistic-cookbook.org/examples/pet-finetuning/pet-ft-nc.html">Conservative fine-tuning tutorial</a></td>
          <td>Other</td>
          <td>Unknown</td>
          <td>Fine-tuning workflow tutorial</td>
      </tr>
  </tbody>
</table>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Bigi, F., Langer, M. F., &amp; Ceriotti, M. (2025). The dark side of the forces: assessing non-conservative force models for atomistic machine learning. <em>Proceedings of the 42nd International Conference on Machine Learning</em>, PMLR 267.</p>
<p><strong>Publication</strong>: ICML 2025</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{bigi2025dark,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{The dark side of the forces: assessing non-conservative force models for atomistic machine learning}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Bigi, Filippo and Langer, Marcel F and Ceriotti, Michele}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Proceedings of the 42nd International Conference on Machine Learning}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">series</span>=<span style="color:#e6db74">{Proceedings of Machine Learning Research}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{267}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">address</span>=<span style="color:#e6db74">{Vancouver, Canada}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2025}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://icml.cc/virtual/2025/poster/45458">ICML 2025 poster page</a></li>
<li><a href="https://openreview.net/pdf?id=OEl3L8osas">PDF on OpenReview</a></li>
<li><a href="https://zenodo.org/records/14778891">Zenodo repository</a></li>
<li><a href="https://atomistic-cookbook.org/examples/pet-mad-nc/pet-mad-nc.html">MTS Inference Tutorial</a></li>
<li><a href="https://atomistic-cookbook.org/examples/pet-finetuning/pet-ft-nc.html">Conservative Fine-Tuning Tutorial</a></li>
</ul>
]]></content:encoded></item><item><title>Beyond Atoms: 3D Space Modeling for Molecular Pretraining</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-simulation/ml-potentials/beyond-atoms/</link><pubDate>Sat, 23 Aug 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-simulation/ml-potentials/beyond-atoms/</guid><description>Lu et al. introduce SpaceFormer, a Transformer that models entire 3D molecular space including atoms for superior representations.</description><content:encoded><![CDATA[<h2 id="paper-typology-and-contribution">Paper Typology and Contribution</h2>
<p>This is a <strong>Method</strong> paper. It challenges the atom-centric paradigm of molecular representation learning by proposing a novel framework that models the continuous 3D space surrounding atoms. The core contribution is <strong>SpaceFormer</strong>, a Transformer-based architecture that discretizes molecular space into grids to capture physical phenomena (electron density, electromagnetic fields) often missed by traditional point-cloud models.</p>
<h2 id="the-physical-intuition-modeling-empty-space">The Physical Intuition: Modeling &ldquo;Empty&rdquo; Space</h2>
<p><strong>The Gap</strong>: Prior 3D molecular representation models, such as Uni-Mol, treat molecules as discrete sets of atoms, essentially point clouds in 3D space. However, from a quantum physics perspective, the &ldquo;empty&rdquo; space between atoms is far from empty. It is permeated by electron density distributions and electromagnetic fields that determine molecular properties.</p>
<p><strong>The Hypothesis</strong>: Explicitly modeling this continuous 3D space alongside discrete atom positions yields superior representations for downstream tasks, particularly for computational properties that depend on electronic structure, such as HOMO/LUMO energies and energy gaps.</p>
<h2 id="a-surprising-observation-virtual-points-improve-representations">A Surprising Observation: Virtual Points Improve Representations</h2>
<p>Before proposing SpaceFormer, the authors present a simple yet revealing experiment. They augment Uni-Mol by adding randomly sampled virtual points (VPs) from the 3D space within the circumscribed cuboid of each molecule. These VPs carry no chemical information whatsoever: they are purely random noise points.</p>
<p>The result is surprising: adding just 10 random VPs already yields a noticeable improvement in validation loss. The improvement remains consistent and gradually increases as the number of VPs grows, eventually reaching a plateau. This observation holds across downstream tasks as well, with Uni-Mol + VPs improving on several quantum property predictions (LUMO, E1-CC2, E2-CC2) compared to vanilla Uni-Mol.</p>
<p>The implication is that even uninformative spatial context helps the model learn better representations, motivating a principled framework for modeling the full 3D molecular space.</p>
<h2 id="spaceformer-voxelization-and-3d-positional-encodings">SpaceFormer: Voxelization and 3D Positional Encodings</h2>
<p>The key innovation is treating the molecular representation problem as <strong>3D space modeling</strong>. SpaceFormer follows these core steps:</p>
<ol>
<li><strong>Voxelizes the entire 3D space</strong> into a grid with cells of $0.49\text{\AA}$ (based on O-H bond length to ensure at most one atom per cell).</li>
<li><strong>Uses adaptive multi-resolution grids</strong> to efficiently handle empty space, keeping it fine-grained near atoms and coarse-grained far away.</li>
<li><strong>Applies Transformers to 3D spatial tokens</strong> with custom positional encodings that achieve linear complexity.</li>
</ol>
<p>Specifically, the model utilizes two forms of 3D Positional Encoding:</p>
<p><strong>3D Directional PE (RoPE Extension)</strong>
The authors extend Rotary Position Embedding (RoPE) to continuous 3D space by splitting the Query and Key vectors into three blocks, one per spatial axis, and rotating each block by the coordinate along its axis. The resulting directional attention score depends only on relative coordinates:</p>
<p>$$
\begin{aligned}
\mathbf{q}_{i}^{\top} \mathbf{k}_{j} = \sum_{s=1}^{3} \mathbf{q}_{i,s}^{\top} \mathbf{R}(c_{j,s} - c_{i,s}) \mathbf{k}_{j,s}
\end{aligned}
$$</p>
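<p>The relative-position property in the equation above can be checked numerically. The following is a minimal sketch, not the authors&rsquo; implementation: the helper names, the 12-dimensional vectors, and the two per-pair frequencies are illustrative assumptions. It rotates each axis block of the query and key by that axis&rsquo;s coordinate and confirms that the score is unchanged when both positions are translated together.</p>

```python
import math
import random

def rotate_pairs(vec, coord, freqs):
    """Rotate consecutive (even, odd) pairs of vec by the angle coord * freq."""
    out = []
    for p, freq in enumerate(freqs):
        a, b = vec[2 * p], vec[2 * p + 1]
        theta = coord * freq
        cos_t, sin_t = math.cos(theta), math.sin(theta)
        out.extend([a * cos_t - b * sin_t, a * sin_t + b * cos_t])
    return out

def rope_3d(vec, position, freqs):
    """Split vec into x, y, z blocks and rotate each block by its axis coordinate."""
    block = len(vec) // 3
    out = []
    for axis in range(3):
        chunk = vec[axis * block:(axis + 1) * block]
        out.extend(rotate_pairs(chunk, position[axis], freqs))
    return out

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

rng = random.Random(0)
freqs = [1.0, 0.1]                          # hypothetical per-pair frequencies
q = [rng.gauss(0, 1) for _ in range(12)]    # 3 axes x 2 pairs x 2 components
k = [rng.gauss(0, 1) for _ in range(12)]
c_i, c_j = (0.3, -1.2, 2.0), (1.1, 0.4, -0.7)
shift = (5.0, -3.0, 1.5)

# Score from absolutely rotated query and key
score = dot(rope_3d(q, c_i, freqs), rope_3d(k, c_j, freqs))

# Same score written with the relative rotation R(c_j - c_i) applied to the key only
delta = tuple(b - a for a, b in zip(c_i, c_j))
relative = dot(q, rope_3d(k, delta, freqs))

# Translating both positions by the same shift leaves the score unchanged
c_i_s = tuple(a + s for a, s in zip(c_i, shift))
c_j_s = tuple(a + s for a, s in zip(c_j, shift))
score_shifted = dot(rope_3d(q, c_i_s, freqs), rope_3d(k, c_j_s, freqs))
```

<p>Because each token is rotated independently by its own coordinates, no pairwise position tensor has to be materialized, which is what keeps this encoding compatible with linear-memory attention kernels.</p>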
<p><strong>3D Distance PE (RFF Approximation)</strong>
To compute invariant geometric distance without incurring quadratic memory overhead, they use Random Fourier Features (RFF) to approximate a Gaussian kernel of pairwise distances:</p>
<p>$$
\begin{aligned}
\exp \left( - \frac{\lVert \mathbf{c}_i - \mathbf{c}_j \rVert_2^2}{2\sigma^2} \right) &amp;\approx z(\mathbf{c}_i)^\top z(\mathbf{c}_j) \\
z(\mathbf{c}_i) &amp;= \sqrt{\frac{2}{d}} \cos(\sigma^{-1} \mathbf{c}_i^\top \boldsymbol{\omega} + \mathbf{b})
\end{aligned}
$$</p>
<p>This approach enables the model to natively encode complex field-like phenomena without computing exhaustive $O(N^2)$ distance matrices.</p>
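<p>The RFF approximation above can be tried out in a few lines. This is a minimal sketch under stated assumptions: the frequencies $\boldsymbol{\omega}$ are drawn from a standard normal, the phases $\mathbf{b}$ uniformly from $[0, 2\pi)$, and the feature dimension (4096) and $\sigma$ (1.5) are arbitrary illustrative values, not settings from the paper.</p>

```python
import math
import random

def make_rff(dim, sigma, seed=0):
    """Build z(c) = sqrt(2/d) * cos(omega^T c / sigma + b) for 3D points c."""
    rng = random.Random(seed)
    omega = [[rng.gauss(0.0, 1.0) for _ in range(3)] for _ in range(dim)]
    bias = [rng.uniform(0.0, 2.0 * math.pi) for _ in range(dim)]
    scale = math.sqrt(2.0 / dim)

    def z(c):
        return [
            scale * math.cos(sum(w * x for w, x in zip(row, c)) / sigma + b)
            for row, b in zip(omega, bias)
        ]

    return z

def gaussian_kernel(c_i, c_j, sigma):
    """Exact Gaussian kernel of the squared Euclidean distance."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(c_i, c_j))
    return math.exp(-sq_dist / (2.0 * sigma ** 2))

sigma = 1.5
z = make_rff(dim=4096, sigma=sigma)
c_i, c_j = (0.0, 0.0, 0.0), (1.0, 0.5, -0.5)

approx = sum(a * b for a, b in zip(z(c_i), z(c_j)))  # feature dot product
exact = gaussian_kernel(c_i, c_j, sigma)             # reference value
```

<p>The approximation error shrinks roughly as $1/\sqrt{d}$. The practical benefit is that each point is featurized once, so the kernel between any pair is a dot product of precomputed vectors.</p>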
<h2 id="experimental-setup-and-downstream-tasks">Experimental Setup and Downstream Tasks</h2>
<p><strong>Pretraining Data</strong>: 19 million unlabeled molecules from the same dataset used by Uni-Mol.</p>
<p><strong>Downstream Benchmarks</strong>: The authors propose a new benchmark of 15 tasks, motivated by known limitations of MoleculeNet: invalid structures, inconsistent chemical representations, data curation errors, and an inability to adequately distinguish model performance. The tasks split into two categories:</p>
<ol>
<li>
<p><strong>Computational Properties (Quantum Mechanics)</strong></p>
<ul>
<li>Subsets of <a href="/notes/chemistry/datasets/gdb-17/">GDB-17</a> (HOMO, LUMO, GAP energy prediction, 20K samples; E1-CC2, E2-CC2, f1-CC2, f2-CC2, 21.7K samples)</li>
<li>Cata-condensed polybenzenoid hydrocarbons (Dipole moment, adiabatic ionization potential, D3 dispersion correction, 8,678 samples)</li>
<li>Metric: Mean Absolute Error (MAE)</li>
</ul>
</li>
<li>
<p><strong>Experimental Properties (Pharma/Bio)</strong></p>
<ul>
<li>MoleculeNet tasks (BBBP, BACE for drug discovery)</li>
<li>Biogen ADME tasks (HLM, MME, Solubility)</li>
<li>Metrics: AUC for classification, MAE for regression</li>
</ul>
</li>
</ol>
<p><strong>Splitting Strategy</strong>: All datasets use 8:1:1 train/validation/test ratio with <strong>scaffold splitting</strong> to test out-of-distribution generalization.</p>
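<p>For readers unfamiliar with scaffold splitting, here is a minimal sketch of the commonly used greedy procedure. It is an assumption that the paper follows this variant, since the exact splits are not released. Scaffold labels are taken as precomputed strings (in practice they would come from a cheminformatics toolkit), and the function name is hypothetical.</p>

```python
from collections import defaultdict

def scaffold_split(scaffolds, fractions=(0.8, 0.1, 0.1)):
    """Assign whole scaffold groups to train/valid/test, largest groups first."""
    groups = defaultdict(list)
    for index, scaffold in enumerate(scaffolds):
        groups[scaffold].append(index)
    # Largest groups first; ties broken by first index for determinism
    ordered = sorted(groups.values(), key=lambda g: (-len(g), g[0]))

    n = len(scaffolds)
    train_cap = round(fractions[0] * n)
    valid_cap = round((fractions[0] + fractions[1]) * n)
    train, valid, test = [], [], []
    for group in ordered:
        if train_cap >= len(train) + len(group):
            train.extend(group)
        elif valid_cap >= len(train) + len(valid) + len(group):
            valid.extend(group)
        else:
            test.extend(group)
    return train, valid, test

# Hypothetical scaffold labels for 10 molecules
scaffolds = ["A"] * 5 + ["B"] * 3 + ["C", "D"]
train, valid, test = scaffold_split(scaffolds)
```

<p>No scaffold appears in more than one partition, so validation and test molecules have ring systems the model never saw during training.</p>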
<p><strong>Training Setup</strong>:</p>
<ul>
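<li><strong>Objective</strong>: Masked Auto-Encoder pretraining with 30% random masking (not to be confused with the Mean Absolute Error metric, also abbreviated MAE). The model predicts whether a masked cell contains an atom and, if so, predicts both the atom type and the precise in-cell offset position.</li>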
<li><strong>Hardware</strong>: ~50 hours on 8 NVIDIA A100 GPUs</li>
<li><strong>Optimizer</strong>: Adam ($\beta_1=0.9, \beta_2=0.99$)</li>
<li><strong>Learning Rate</strong>: Peak 1e-4 with linear decay and 0.01 warmup ratio</li>
<li><strong>Batch Size</strong>: 128</li>
<li><strong>Total Updates</strong>: 1 million</li>
</ul>
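<p>The learning-rate settings above translate into a simple schedule. This is a sketch, assuming the warmup is linear from zero (the note only specifies the warmup ratio and linear decay) and that decay ends at zero on the final update. The function name is illustrative.</p>

```python
def learning_rate(step, total_steps=1_000_000, peak=1e-4, warmup_ratio=0.01):
    """Linear warmup to the peak rate, then linear decay to zero."""
    warmup_steps = round(total_steps * warmup_ratio)   # 10,000 updates here
    if warmup_steps > step:
        return peak * step / warmup_steps
    remaining = total_steps - step
    return peak * remaining / (total_steps - warmup_steps)

lr_start = learning_rate(0)
lr_peak = learning_rate(10_000)
lr_mid = learning_rate(505_000)
lr_end = learning_rate(1_000_000)
```

<p>With a 0.01 warmup ratio and 1 million updates, the peak rate is reached after 10,000 updates and halves by update 505,000.</p>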
<p><strong>Baseline Comparisons</strong>: GROVER (2D graph-based molecular pretrained representation, MPR), GEM (2D graph enhanced with 3D information), 3D Infomax (GNN with 3D information), Uni-Mol (3D MPR, primary baseline using the same pretraining dataset), and Mol-AE (extends Uni-Mol with atom-based masked auto-encoder pretraining).</p>
<h2 id="results-and-analysis">Results and Analysis</h2>
<p><strong>Strong Overall Performance</strong>: SpaceFormer ranked 1st in 10 of 15 tasks and in the top 2 for 14 of 15 tasks. It surpassed the runner-up models by approximately 20% on quantum property tasks (HOMO, LUMO, GAP, E1-CC2, Dipmom), supporting the claim that modeling non-atom space captures electronic structure better than atom-only models.</p>
<h3 id="key-results-on-quantum-properties">Key Results on Quantum Properties</h3>
<table>
  <thead>
      <tr>
          <th>Task</th>
          <th>GROVER</th>
          <th>GEM</th>
          <th>3D Infomax</th>
          <th>Uni-Mol</th>
          <th>Mol-AE</th>
          <th><strong>SpaceFormer</strong></th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>HOMO (Ha)</td>
          <td>0.0075</td>
          <td>0.0068</td>
          <td>0.0065</td>
          <td>0.0052</td>
          <td>0.0050</td>
          <td><strong>0.0042</strong></td>
      </tr>
      <tr>
          <td>LUMO (Ha)</td>
          <td>0.0086</td>
          <td>0.0080</td>
          <td>0.0070</td>
          <td>0.0060</td>
          <td>0.0057</td>
          <td><strong>0.0040</strong></td>
      </tr>
      <tr>
          <td>GAP (Ha)</td>
          <td>0.0109</td>
          <td>0.0107</td>
          <td>0.0095</td>
          <td>0.0081</td>
          <td>0.0080</td>
          <td><strong>0.0064</strong></td>
      </tr>
      <tr>
          <td>E1-CC2 (eV)</td>
          <td>0.0101</td>
          <td>0.0090</td>
          <td>0.0089</td>
          <td>0.0067</td>
          <td>0.0070</td>
          <td><strong>0.0058</strong></td>
      </tr>
      <tr>
          <td>Dipmom (Debye)</td>
          <td>0.0752</td>
          <td>0.0289</td>
          <td>0.0291</td>
          <td>0.0106</td>
          <td>0.0113</td>
          <td><strong>0.0083</strong></td>
      </tr>
  </tbody>
</table>
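<p>The &ldquo;approximately 20%&rdquo; figure can be sanity-checked against the table. The sketch below computes the relative MAE reduction of SpaceFormer over the best baseline in each row. Averaging the five rows is an assumption about how the headline number was obtained.</p>

```python
# (baseline MAEs in table order, SpaceFormer MAE), copied from the table above
results = {
    "HOMO": ([0.0075, 0.0068, 0.0065, 0.0052, 0.0050], 0.0042),
    "LUMO": ([0.0086, 0.0080, 0.0070, 0.0060, 0.0057], 0.0040),
    "GAP": ([0.0109, 0.0107, 0.0095, 0.0081, 0.0080], 0.0064),
    "E1-CC2": ([0.0101, 0.0090, 0.0089, 0.0067, 0.0070], 0.0058),
    "Dipmom": ([0.0752, 0.0289, 0.0291, 0.0106, 0.0113], 0.0083),
}

# Relative improvement over the strongest (lowest-MAE) baseline per task
improvements = {
    task: (min(baselines) - ours) / min(baselines)
    for task, (baselines, ours) in results.items()
}
mean_improvement = sum(improvements.values()) / len(improvements)

for task, value in improvements.items():
    print(f"{task}: {100 * value:.1f}%")
print(f"mean: {100 * mean_improvement:.1f}%")
```

<p>Per-task reductions range from about 13% (E1-CC2) to about 30% (LUMO), and their mean is close to 20%.</p>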
<p>SpaceFormer&rsquo;s advantage is most pronounced on computational properties that depend on electronic structure. On experimental biological tasks (e.g., BBBP), where measurements are noisy, the advantage narrows or reverses: Uni-Mol achieves 0.9066 AUC on BBBP compared to SpaceFormer&rsquo;s 0.8605.</p>
<h3 id="ablation-studies">Ablation Studies</h3>
<p>The authors present several ablations that isolate the source of SpaceFormer&rsquo;s improvements:</p>
<p><strong>MAE vs. Denoising</strong>: SpaceFormer with MAE pretraining outperforms SpaceFormer with denoising on all four ablation tasks. The MAE objective requires predicting <em>whether</em> an atom exists in a masked voxel, which forces the model to learn global structural dependencies. In the denoising variant, only atom cells are masked so the model never needs to predict atom existence, reducing the task to coordinate regression.</p>
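<p>The difference between the two objectives comes down to which cells can be masked. The toy sketch below uses a flat list of cells (atom type or <code>None</code> for empty). The function names and the 20-cell example are hypothetical, and applying the same 30% ratio to the denoising variant is an assumption made only for illustration.</p>

```python
import random

def mae_targets(cells, mask_ratio=0.3, seed=0):
    """Mask a random subset of ALL cells; the target may be an atom type or None."""
    rng = random.Random(seed)
    n_mask = round(mask_ratio * len(cells))
    masked = rng.sample(range(len(cells)), n_mask)
    return {i: cells[i] for i in masked}

def denoising_targets(cells, mask_ratio=0.3, seed=0):
    """Mask only cells that contain atoms; the target is always an atom."""
    rng = random.Random(seed)
    atom_cells = [i for i, atom in enumerate(cells) if atom is not None]
    n_mask = max(1, round(mask_ratio * len(atom_cells)))
    masked = rng.sample(atom_cells, n_mask)
    return {i: cells[i] for i in masked}

# Hypothetical molecule: 3 atom cells and 17 empty cells
cells = ["O", "H", "H"] + [None] * 17
mae = mae_targets(cells)
denoise = denoising_targets(cells)

# Existence labels the model would have to predict under each objective
mae_existence = [atom is not None for atom in mae.values()]
denoise_existence = [atom is not None for atom in denoise.values()]
```

<p>Under the denoising variant every existence label is <code>True</code>, so that prediction carries no signal. Under the MAE objective the label varies with the sampled cells, so the model has to infer where atoms are from the surrounding context.</p>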
<p><strong>FLOPs Control</strong>: A SpaceFormer-Large model (4x width, atom-only) trained with comparable FLOPs still falls short of SpaceFormer with 1000 non-atom cells on most downstream tasks. This confirms the improvement comes from modeling 3D space, not from additional compute.</p>
<p><strong>Virtual Points vs. SpaceFormer</strong>: Adding up to 200 random virtual points to Uni-Mol improves some tasks but leaves a significant gap compared to SpaceFormer, demonstrating that principled space discretization outperforms naive point augmentation.</p>
<p><strong>Efficiency Validation</strong>: The Adaptive Grid Merging method reduces the number of cells by roughly 10x with virtually no performance degradation. The 3D positional encodings scale linearly with the number of cells, while Uni-Mol&rsquo;s pretraining cost scales quadratically.</p>
<h3 id="scope-and-future-directions">Scope and Future Directions</h3>
<p>SpaceFormer does not incorporate built-in SE(3) equivariance, relying instead on data augmentation (random rotations and random boundary padding) during training. The authors identify extending SpaceFormer to force field tasks and larger systems such as proteins and complexes as promising future directions.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="code-and-data-availability">Code and Data Availability</h3>
<ul>
<li><strong>Source Code</strong>: As of March 2026, the authors have not released the official source code or pre-trained weights.</li>
<li><strong>Datasets</strong>: Pretraining utilized the same 19M unlabeled molecule dataset as Uni-Mol. Downstream tasks use a newly curated internal benchmark built from subsets of GDB-17, MoleculeNet, and Biogen ADME. The exact customized scaffold splits for these evaluations are pending the official code release.</li>
<li><strong>Compute</strong>: Pretraining the base SpaceFormer encoder (~67.8M parameters, configured to merge level 3) required approximately 50 hours on 8 NVIDIA A100 GPUs.</li>
</ul>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Source code</td>
          <td>Code</td>
          <td>N/A</td>
          <td>Not publicly released as of March 2026</td>
      </tr>
      <tr>
          <td>Pre-trained weights</td>
          <td>Model</td>
          <td>N/A</td>
          <td>Not publicly released</td>
      </tr>
      <tr>
          <td>Pretraining data (19M molecules)</td>
          <td>Dataset</td>
          <td>Unknown</td>
          <td>Same dataset as Uni-Mol; not independently released</td>
      </tr>
      <tr>
          <td>Downstream benchmark splits</td>
          <td>Dataset</td>
          <td>N/A</td>
          <td>Custom scaffold splits pending code release</td>
      </tr>
  </tbody>
</table>
<h3 id="models">Models</h3>
<p>The model treats a molecule as a 3D &ldquo;image&rdquo; via voxelization, processed by a Transformer.</p>
<p><strong>Input Representation</strong>:</p>
<ul>
<li><strong>Discretization</strong>: 3D space divided into grid cells with length <strong>$0.49\text{\AA}$</strong> (based on O-H bond length to ensure at most one atom per cell)</li>
<li><strong>Tokenization</strong>: Tokens are pairs $(t_i, c_i)$ where $t_i$ is atom type (or NULL) and $c_i$ is the coordinate</li>
<li><strong>Embeddings</strong>: Continuous embeddings with dimension 512. Inner-cell positions discretized with $0.01\text{\AA}$ precision</li>
</ul>
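<p>The discretization and tokenization described above can be sketched as follows. The cell length and in-cell precision come from the text. The grid origin at $(0, 0, 0)$, the water coordinates, and the function name are illustrative assumptions (the note mentions random boundary padding during training, which would shift the origin).</p>

```python
import math

CELL = 0.49       # grid cell length in angstroms (from the text)
PRECISION = 0.01  # in-cell position resolution in angstroms (from the text)

def tokenize(atoms):
    """Map (element, (x, y, z)) atoms to {cell_index: (element, quantized_offset)}."""
    tokens = {}
    for element, position in atoms:
        # Integer index of the cell containing the atom along each axis
        cell = tuple(math.floor(x / CELL) for x in position)
        # Offset from the lower cell corner, in units of PRECISION
        offset = tuple(
            round((x - i * CELL) / PRECISION) for x, i in zip(position, cell)
        )
        if cell in tokens:
            raise ValueError(f"two atoms fall in cell {cell}")
        tokens[cell] = (element, offset)
    return tokens

# Approximate water geometry in angstroms (illustrative coordinates)
water = [
    ("O", (0.0, 0.0, 0.0)),
    ("H", (0.96, 0.0, 0.0)),
    ("H", (-0.24, 0.93, 0.0)),
]
tokens = tokenize(water)
```

<p>Here the oxygen lands in cell $(0, 0, 0)$ with zero offset, and the two hydrogens land in cells $(1, 0, 0)$ and $(-1, 1, 0)$ with offsets $(47, 0, 0)$ and $(25, 44, 0)$. Every cell not present in the dictionary corresponds to a NULL-type token.</p>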
<p><strong>Transformer Specifications</strong>:</p>
<table>
  <thead>
      <tr>
          <th>Component</th>
          <th>Layers</th>
          <th>Attention Heads</th>
          <th>Embedding Dim</th>
          <th>FFN Dim</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Encoder</strong></td>
          <td>16</td>
          <td>8</td>
          <td>512</td>
          <td>2048</td>
      </tr>
      <tr>
          <td><strong>Decoder</strong> (MAE)</td>
          <td>4</td>
          <td>4</td>
          <td>256</td>
          <td>1024</td>
      </tr>
  </tbody>
</table>
<p><strong>Attention Mechanism</strong>: FlashAttention for efficient handling of large sequence lengths.</p>
<p><strong>Positional Encodings</strong>:</p>
<ol>
<li><strong>3D Directional PE</strong>: Extension of Rotary Position Embedding (RoPE) to continuous 3D space, capturing relative directionality</li>
<li><strong>3D Distance PE</strong>: Random Fourier Features (RFF) that approximate a Gaussian kernel of pairwise distances with linear complexity</li>
</ol>
<h4 id="visualizing-rff-and-rope">Visualizing RFF and RoPE</h4>















<figure class="post-figure center ">
    <img src="/img/notes/spaceformer-rff-rope-visualization.webp"
         alt="Four-panel visualization showing RFF distance encoding and RoPE directional encoding mechanisms"
         title="Four-panel visualization showing RFF distance encoding and RoPE directional encoding mechanisms"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Visual intuition for SpaceFormer&rsquo;s positional encodings: Top row shows RFF distance encoding (Gaussian-like attention decay and high-frequency feature fingerprints). Bottom row shows RoPE directional encoding (vector rotation fields and resulting attention patterns).</figcaption>
    
</figure>

<p><strong>Top Row (Distance / RFF):</strong> Shows how the model encodes &ldquo;closeness.&rdquo; Each distance is represented by a &ldquo;fingerprint&rdquo; of cosine waves whose dot products reproduce a Gaussian-like decay with distance.</p>
<ul>
<li><strong>Top Left (The Force Field):</strong> The attention score (the dot product of two RFF feature vectors) follows a Gaussian curve. It is high when atoms are close and decays to zero as they move apart. This builds a smooth, distance-dependent decay into attention, so the model does not have to learn it from data.</li>
<li><strong>Top Right (The Fingerprint):</strong> Each dimension oscillates at a different frequency. A specific distance (e.g., $d=2$) has a unique combination of high and low values across these dimensions, creating a unique &ldquo;fingerprint&rdquo; for that exact distance.</li>
</ul>
<p><strong>Bottom Row (Direction / RoPE):</strong> Shows how the model learns &ldquo;relative position.&rdquo; It visualizes the vector rotation and how that creates a grid-like attention pattern.</p>
<ul>
<li><strong>Bottom Left (The Rotation):</strong> This visualizes the &ldquo;X-axis chunk&rdquo; of the vector. As you move from left ($x=-3$) to right ($x=3$), the arrows rotate. The model compares angles between atoms to determine relative positions.</li>
<li><strong>Bottom Right (The Grid):</strong> The resulting attention pattern when combining X-rotations and Y-rotations. The red/blue regions show where the model pays attention relative to the center, forming a grid-like interference pattern that distinguishes relative positions (e.g., &ldquo;top-right&rdquo; vs &ldquo;bottom-left&rdquo;).</li>
</ul>
<h4 id="adaptive-grid-merging">Adaptive Grid Merging</h4>
<p>To make the 3D grid approach computationally tractable, two key strategies are employed:</p>
<ol>
<li><strong>Grid Sampling</strong>: Randomly selecting 10-20% of empty cells during training</li>
<li><strong>Adaptive Grid Merging</strong>: Recursively merging $2 \times 2 \times 2$ blocks of empty cells into larger &ldquo;coarse&rdquo; cells, creating a multi-resolution view that is fine-grained near atoms and coarse-grained in empty space (merging set to Level 3)</li>
</ol>
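<p>Adaptive grid merging can be sketched as an octree-style recursion. This is a simplified illustration, not the authors&rsquo; code: it works top-down (which yields the same cells as recursively merging aligned empty $2 \times 2 \times 2$ blocks), assumes a cubic grid whose side is a power of two, and interprets &ldquo;Level 3&rdquo; as blocks of $2^3 = 8$ base cells per side.</p>

```python
def merge_cells(origin, size, occupied, max_size=8):
    """Cover a cubic block with cells; returns a list of (origin, size) pairs.

    size is the block side length in base cells and must be a power of two.
    """
    x0, y0, z0 = origin
    has_atom = any(
        x0 + size > x >= x0 and y0 + size > y >= y0 and z0 + size > z >= z0
        for x, y, z in occupied
    )
    # Keep the block whole if it is a base cell or an empty block within the merge limit
    if size == 1 or (not has_atom and max_size >= size):
        return [(origin, size)]
    # Otherwise split into eight children and recurse
    half = size // 2
    cells = []
    for dx in (0, half):
        for dy in (0, half):
            for dz in (0, half):
                child = (x0 + dx, y0 + dy, z0 + dz)
                cells.extend(merge_cells(child, half, occupied, max_size))
    return cells

occupied = {(0, 0, 0)}                       # one atom cell in an 8 x 8 x 8 grid
cells = merge_cells((0, 0, 0), 8, occupied)
dense_count = 8 ** 3
```

<p>For this toy grid the dense representation has 512 cells, while the merged one has 22: seven empty blocks of side 4, seven of side 2, and eight base cells around the atom. Real molecules have many atoms spread over the box, so the reduction is less extreme.</p>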
<p><strong>Visualizing Adaptive Grid Merging</strong>:</p>















<figure class="post-figure center ">
    <img src="/img/notes/spaceformer-adaptive-grid-merging.webp"
         alt="2D simulation of adaptive grid merging for an H2O molecule showing multi-resolution cells"
         title="2D simulation of adaptive grid merging for an H2O molecule showing multi-resolution cells"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Adaptive grid merging demonstrated on H₂O. Red cells (Level 0) contain atoms and remain at full resolution. Progressively darker blue cells represent merged empty regions at higher levels, covering the same volume with fewer tokens.</figcaption>
    
</figure>

<p>The adaptive grid process compresses empty space around molecules while maintaining high resolution near atoms:</p>
<ul>
<li><strong>Red Cells (Level 0):</strong> The smallest squares ($0.49\text{\AA}$) containing atoms. These are kept at highest resolution because electron density changes rapidly here.</li>
<li><strong>Light Blue Cells (Level 0/1):</strong> Small empty regions close to atoms.</li>
<li><strong>Darker Blue Cells (Level 2/3):</strong> Large blocks of empty space further away.</li>
</ul>
<p>If we used a naive uniform grid, we would have to process thousands of empty &ldquo;Level 0&rdquo; cells containing almost zero information. By merging them into larger blocks (the dark blue squares), the model covers the same volume with significantly fewer input tokens, reducing the number of tokens by roughly <strong>10x</strong> compared to a dense grid.</p>















<figure class="post-figure center ">
    <img src="/img/notes/spaceformer-adaptive-grid-benzene.webp"
         alt="Adaptive grid merging visualization for benzene molecule showing hexagonal ring with multi-resolution grid cells"
         title="Adaptive grid merging visualization for benzene molecule showing hexagonal ring with multi-resolution grid cells"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Adaptive grid merging for benzene (C₆H₆). The model maintains maximum resolution (red Level 0 cells) only where atoms exist, while merging vast empty regions into large blocks (dark blue L3/L4 cells). This allows the model to focus computational power on chemically active zones.</figcaption>
    
</figure>

<p>The benzene example above demonstrates how this scales to larger molecules. The characteristic hexagonal ring of 6 carbon atoms (black) and 6 hydrogen atoms (white) occupies a small fraction of the total grid. The dark blue corners (L3, L4) represent large merged blocks of empty space, so most of the token budget is spent on the fine-grained red &ldquo;active&rdquo; zones where the chemistry actually happens.</p>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Lu, S., Ji, X., Zhang, B., Yao, L., Liu, S., Gao, Z., Zhang, L., &amp; Ke, G. (2025). Beyond Atoms: Enhancing Molecular Pretrained Representations with 3D Space Modeling. <em>Proceedings of the 42nd International Conference on Machine Learning (ICML)</em>, 267, 40491-40504. <a href="https://proceedings.mlr.press/v267/lu25e.html">https://proceedings.mlr.press/v267/lu25e.html</a></p>
<p><strong>Publication</strong>: ICML 2025</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{lu2025beyond,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Beyond Atoms: Enhancing Molecular Pretrained Representations with 3D Space Modeling}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Lu, Shuqi and Ji, Xiaohong and Zhang, Bohang and Yao, Lin and Liu, Siyuan and Gao, Zhifeng and Zhang, Linfeng and Ke, Guolin}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Proceedings of the 42nd International Conference on Machine Learning}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{40491--40504}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{267}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">series</span>=<span style="color:#e6db74">{Proceedings of Machine Learning Research}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{PMLR}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2025}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://openreview.net/forum?id=Wd9KPQCKwq">OpenReview forum</a></li>
<li><a href="https://openreview.net/pdf?id=Wd9KPQCKwq">PDF on OpenReview</a></li>
<li><a href="https://icml.cc/virtual/2025/poster/45004">ICML 2025 poster page</a></li>
</ul>
]]></content:encoded></item><item><title>Contrastive Learning for Variational Autoencoder Priors</title><link>https://hunterheidenreich.com/notes/machine-learning/generative-models/contrastive-learning-for-vae-priors/</link><pubDate>Sun, 17 Aug 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/generative-models/contrastive-learning-for-vae-priors/</guid><description>Aneja et al.'s NeurIPS 2021 paper introducing Noise Contrastive Priors (NCPs) to address VAE's 'prior hole' problem with energy-based priors.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is a <strong>method paper</strong> that introduces a training approach for Variational Autoencoders (VAEs) to address fundamental limitations in their generative quality through improved prior learning.</p>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>The work is motivated by a critical limitation in Variational Autoencoders known as the <strong>&ldquo;prior hole&rdquo; problem</strong>, where the prior distribution $p(z)$ fails to match the aggregate approximate posterior $q(z)$. This mismatch leads to areas in the latent space with high density under the prior that don&rsquo;t map to realistic data samples, resulting in poor generative quality compared to GANs and other generative models.</p>















<figure class="post-figure center ">
    <img src="/img/notes/vae-prior-hole-problem-illustrated.webp"
         alt="Visualization of the VAE prior hole problem showing a ring-shaped aggregate posterior q(z) with an empty center, while the standard Gaussian prior p(z) has highest density at the center where no data exists"
         title="Visualization of the VAE prior hole problem showing a ring-shaped aggregate posterior q(z) with an empty center, while the standard Gaussian prior p(z) has highest density at the center where no data exists"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The &lsquo;prior hole&rsquo; problem: the standard Gaussian prior (red dashed contours) assigns highest probability to the center, but the aggregate posterior (blue dots) forms a ring with no data in that region.</figcaption>
    
</figure>

<p>The figure above illustrates this mismatch. The blue dots represent where a trained encoder actually places data in the latent space (the aggregate posterior $q(z)$), which often forms complex, non-Gaussian shapes. The red dashed contours show the standard Gaussian prior $p(z) = \mathcal{N}(0, I)$, which assumes data is centered at the origin. When generating new samples, we draw from this prior, making it likely to sample from the empty &ldquo;hole&rdquo; where the decoder has never seen training data, producing unrealistic outputs.</p>
<p>A natural question arises: the prior $p(z)$ is used for <em>sampling</em> at inference time, so why does learning a better prior also improve <em>likelihood</em> (NLL)? The answer lies in the VAE objective. VAEs maximize the Evidence Lower Bound (ELBO):</p>
<p>$$ \log p(x) \geq \mathcal{L}_{\text{ELBO}}(x) = \underbrace{\mathbb{E}_{q(z|x)}[\log p(x|z)]}_{\text{Reconstruction}} - \underbrace{\text{KL}(q(z|x) \parallel p(z))}_{\text{Regularization}} $$</p>
<p>The KL divergence term penalizes the mismatch between each data point&rsquo;s approximate posterior $q(z|x)$ and the prior $p(z)$. When the prior is a simple Gaussian but the aggregate posterior forms a complex shape (as in the figure above), this KL term remains unnecessarily high for every data point.</p>
<p>By replacing the simple prior with a learned $p_{\text{NCP}}(z)$ that matches the aggregate posterior, the KL penalty decreases, tightening the ELBO and improving NLL. The learned prior thus provides a <strong>unified solution</strong>: better likelihood during training (tighter bound) and better sampling at inference (no &ldquo;holes&rdquo;).</p>
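<p>A one-dimensional toy calculation makes the KL argument concrete. The numbers are hypothetical, and a shifted Gaussian stands in for the learned prior purely for illustration (the actual NCP prior is an energy-based reweighting, not a Gaussian). The closed-form KL divergence between two univariate Gaussians is used.</p>

```python
import math

def kl_gaussians(mu_q, sigma_q, mu_p, sigma_p):
    """KL(N(mu_q, sigma_q^2) || N(mu_p, sigma_p^2)) in nats."""
    return (
        math.log(sigma_p / sigma_q)
        + (sigma_q ** 2 + (mu_q - mu_p) ** 2) / (2.0 * sigma_p ** 2)
        - 0.5
    )

# Hypothetical per-example posterior q(z|x) = N(2.0, 0.5^2)
kl_standard = kl_gaussians(2.0, 0.5, 0.0, 1.0)  # against the prior N(0, 1)
kl_matched = kl_gaussians(2.0, 0.5, 2.0, 1.0)   # against a prior placed where the data lives
```

<p>The KL term falls from about 2.32 nats to about 0.32 nats, a 2-nat improvement in the bound for this example from moving prior mass to where the encoder places data.</p>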
<p>The OpenReview discussion contains a significant theoretical debate regarding the paper&rsquo;s core premise. Reviewers argued that the &ldquo;prior hole&rdquo; problem is actually a failure of the posterior to match the prior, or a failure of the encoder. The authors defended their approach by noting that even with a perfect posterior, a simple Normal prior might fail because the decoder lacks capacity to map a simple distribution to complex data without dropping modes. This justifies fixing the prior by making it learned and complex.</p>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The authors propose an <strong>energy-based model (EBM) prior</strong> that is trained using <strong>Noise Contrastive Estimation (NCE)</strong>, which they term a <strong>Noise Contrastive Prior (NCP)</strong>. The key innovations are:</p>
<ul>
<li><strong>Two-Stage Training Process</strong>: First, a standard VAE is trained with a simple base prior. Then, the VAE weights are frozen and a binary classifier learns to distinguish between samples from the aggregate posterior $q(z)$ and the base prior $p(z)$.</li>
<li><strong>Reweighting Strategy</strong>: The core idea is to reweight a base prior distribution $p(z)$ with a learned reweighting factor $r(z)$ to make the resulting prior $p_{\text{NCP}}(z)$ better match the aggregate posterior $q(z)$.</li>
<li><strong>NCE for EBM Training</strong>: The method frames EBM training as a binary classification task to avoid computationally expensive MCMC sampling.</li>
<li><strong>Scalability to Hierarchical Models</strong>: For hierarchical VAEs with multiple latent groups, the NCP approach can be applied independently and in parallel to each group&rsquo;s conditional prior.</li>
</ul>
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The method was evaluated on several standard image generation benchmarks:</p>
<ul>
<li><strong>MNIST</strong> (dynamically binarized): Likelihood evaluation on a controlled, small-latent-space task</li>
<li><strong>CIFAR-10</strong>: Standard computer vision benchmark for generative modeling</li>
<li><strong>CelebA 64x64</strong>: Applied to both standard VAE architectures and more advanced VAEs with GMM priors (RAE model)</li>
<li><strong>CelebA HQ 256x256</strong>: High-resolution face generation task</li>
</ul>
<p>The hierarchical NVAE models used 30 latent groups for CIFAR-10 and CelebA-64, 20 groups for CelebA-HQ-256, and 10 groups of $4 \times 4$ latent variables for MNIST (deliberately small to enable reliable partition function estimation). The experiments compared FID scores, likelihood metrics, and qualitative sample quality between baseline VAEs and NCP-enhanced versions, with particular focus on hierarchical VAEs (NVAE).</p>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>
<p>The proposed NCP method demonstrated improvements in generative quality across evaluated datasets, with modest gains on standard VAEs and particularly large gains on hierarchical models like NVAE:</p>
<ul>
<li><strong>CelebA-64</strong>: NCP improved FID scores from 48.12 to 41.28 for standard VAEs, and from 40.95 to 39.00 for RAE models with GMM priors.</li>
<li><strong>Hierarchical Models (NVAE)</strong>: The impact was particularly pronounced on hierarchical VAEs:
<ul>
<li><strong>CIFAR-10</strong>: FID improved from 51.71 to 24.08</li>
<li><strong>CelebA-64</strong>: FID improved from 13.48 to 5.25, making it competitive with GANs</li>
<li><strong>CelebA HQ 256x256</strong>: FID reduced from 40.26 to 24.79</li>
</ul>
</li>
<li><strong>Likelihood Performance</strong>: On MNIST, NCP-VAE achieved 78.10 nats NLL vs. baseline NVAE&rsquo;s 78.67 nats</li>
</ul>
<p>On CIFAR-10 and CelebA-HQ-256, the concurrent VAEBM method (which forms an EBM on the data space rather than the latent space) outperforms NCP-VAE. However, the authors argue the two approaches are complementary: NCP-VAE targets the latent space while VAEBM operates in data space, and combining them could yield further gains. NCP-VAE also has the advantage of applicability to discrete data (e.g., binarized MNIST) and simpler setup since it only requires training binary classifiers rather than MCMC-based training and sampling.</p>
<p>The key conclusions are that <strong>two-stage training with noise contrastive estimation</strong> provides an effective framework for learning expressive energy-based priors that addresses the prior hole problem while scaling efficiently to hierarchical models.</p>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Artifact</th>
          <th style="text-align: left">Type</th>
          <th style="text-align: left">License</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><a href="https://drive.google.com/drive/folders/15tCGruQcSdm2G4yLkUpKvGASluSZPIBD">Code (Google Drive)</a></td>
          <td style="text-align: left">Code</td>
          <td style="text-align: left">Unknown</td>
          <td style="text-align: left">Official implementation; hosted on Google Drive (may become inaccessible)</td>
      </tr>
      <tr>
          <td style="text-align: left"><a href="https://openreview.net/forum?id=LcSfRundgwI">OpenReview</a></td>
          <td style="text-align: left">Other</td>
          <td style="text-align: left">N/A</td>
          <td style="text-align: left">Reviews, author responses, and supplementary material</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<h4 id="the-reweighting-mechanism">The Reweighting Mechanism</h4>
<p>The core innovation is defining the NCP prior as $p_{\text{NCP}}(z) \propto p(z)r(z)$. The reweighting factor $r(z)$ is derived from the binary classifier $D(z)$ using the <strong>likelihood ratio trick</strong>:</p>
<p>$$ r(z) \approx \frac{D(z)}{1 - D(z)} $$</p>
<p>Here, $D(z)$ is the sigmoid output of the trained discriminator, representing the probability that sample $z$ came from the aggregate posterior $q(z)$ (&ldquo;real&rdquo;). For an optimal discriminator $D^*(z)$, this ratio exactly equals $\frac{q(z)}{p(z)}$, allowing the model to approximate the density ratio without explicit density estimation.</p>
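<p>The identity above can be checked numerically on a toy problem. The sketch below is a minimal illustration, assuming two hand-picked 1D Gaussians stand in for $p(z)$ and $q(z)$ and that the classifier is trained on balanced samples from both; the function names are hypothetical and are not taken from the official code.</p>

```python
import math

def gaussian_pdf(z, mean, std):
    """Density of a 1D normal distribution."""
    return math.exp(-0.5 * ((z - mean) / std) ** 2) / (std * math.sqrt(2.0 * math.pi))

def optimal_discriminator(z, q_pdf, p_pdf):
    """Bayes-optimal classifier for balanced q-vs-p samples: D*(z) = q / (q + p)."""
    q, p = q_pdf(z), p_pdf(z)
    return q / (q + p)

def reweighting_factor(d):
    """Likelihood ratio trick: r(z) = D(z) / (1 - D(z))."""
    return d / (1.0 - d)

# Toy densities (assumed for illustration): a standard normal prior and a
# shifted, narrower "aggregate posterior".
p_pdf = lambda z: gaussian_pdf(z, 0.0, 1.0)
q_pdf = lambda z: gaussian_pdf(z, 1.0, 0.5)

z = 0.7
r = reweighting_factor(optimal_discriminator(z, q_pdf, p_pdf))
true_ratio = q_pdf(z) / p_pdf(z)  # should match r up to floating-point error
```

<p>When the classifier outputs a logit $\ell(z)$ with $D(z) = \sigma(\ell(z))$, the same ratio simplifies to $r(z) = \exp(\ell(z))$, which is usually the more numerically stable form to work with.</p>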

<figure class="post-figure center ">
    <img src="/img/notes/ncp-vae-reweighting-the-prior-posterior.webp"
         alt="Visualization of the NCP reweighting mechanism showing three 1D distributions: q(z) the complex bimodal aggregate posterior, p(z) the simple Gaussian prior, and r(z) the learned reweighting factor that transforms p(z) to match q(z)"
         title="Visualization of the NCP reweighting mechanism showing three 1D distributions: q(z) the complex bimodal aggregate posterior, p(z) the simple Gaussian prior, and r(z) the learned reweighting factor that transforms p(z) to match q(z)"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The reweighting mechanism: the learned factor $r(z)$ (bottom) reweights the simple Gaussian prior $p(z)$ (middle) to approximate the complex aggregate posterior $q(z)$ (top). Where $q(z)$ has high density but $p(z)$ is low, $r(z)$ compensates with high values.</figcaption>
    
</figure>

<h4 id="hierarchical-architecture-strategy">Hierarchical Architecture Strategy</h4>
<p>For hierarchical models (like NVAE), the method trains $K$ binary classifiers in parallel, one for each latent group. To keep this efficient, the classifiers reuse the <strong>context feature</strong> $c(z_{&lt;k})$ extracted by the frozen VAE&rsquo;s prior network, so no classifier has to re-encode the preceding latent groups from scratch.</p>
<h4 id="test-time-sampling-inference">Test-Time Sampling (Inference)</h4>
<p>Since $p_{\text{NCP}}(z)$ is an energy-based model with an unknown normalizing constant, it cannot be sampled from directly. The paper employs two approximate methods to generate samples:</p>
<ul>
<li><strong>Sampling-Importance-Resampling (SIR):</strong> Used for most results. It draws $M$ samples (e.g., $M=5000$) from the base prior $p(z)$ and resamples them based on weights $w^{(m)} = r(z^{(m)})$.</li>
<li><strong>Langevin Dynamics (LD):</strong> An iterative refinement method using the gradient of the energy function $E(z) = -\log r(z) - \log p(z)$.</li>
</ul>
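<p>The SIR procedure described above can be sketched in a few lines. This is a minimal illustration and not the authors&rsquo; implementation: the toy prior, the closed-form <code>log_r</code>, and all function names are assumptions chosen so the result is easy to verify.</p>

```python
import math
import random

def sir_sample(sample_prior, log_r, num_proposals, rng):
    """Draw one approximate sample from p_NCP(z), proportional to p(z) r(z), via SIR."""
    proposals = [sample_prior(rng) for _ in range(num_proposals)]
    log_w = [log_r(z) for z in proposals]
    # Subtract the max log-weight before exponentiating for numerical stability.
    max_log_w = max(log_w)
    weights = [math.exp(lw - max_log_w) for lw in log_w]
    # Resample one proposal with probability proportional to its weight.
    return rng.choices(proposals, weights=weights, k=1)[0]

# Toy setup (hypothetical, not from the paper): base prior N(0, 1) and a
# target N(1, 0.5^2), so log r(z) = log q(z) - log p(z) is known in closed form.
def sample_prior(rng):
    return rng.gauss(0.0, 1.0)

def log_r(z):
    log_q = -0.5 * ((z - 1.0) / 0.5) ** 2 - math.log(0.5)
    log_p = -0.5 * z ** 2
    return log_q - log_p

rng = random.Random(0)
samples = [sir_sample(sample_prior, log_r, num_proposals=500, rng=rng) for _ in range(300)]
sample_mean = sum(samples) / len(samples)  # should sit near the target mean of 1.0
```

<p>Working with log-weights matters in practice: in a hierarchical model the weights are products over many latent groups, and exponentiating them naively can overflow or underflow.</p>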
<h3 id="models">Models</h3>
<h4 id="decoder-architecture">Decoder Architecture</h4>
<p>For RGB datasets (CIFAR-10, CelebA), the output likelihood must be changed from <strong>Discretized Logistic</strong> (standard NVAE) to a <strong>Normal distribution</strong>. The authors note this change alone led to &ldquo;significant improvements in the base model performance.&rdquo; Using the standard NVAE decoder will result in a weaker baseline than reported.</p>
<h4 id="discriminator-architecture">Discriminator Architecture</h4>
<p>The binary classifier uses a ResNet-style architecture with <strong>Squeeze-and-Excitation (SE)</strong> blocks:</p>
<ul>
<li><strong>Activation:</strong> Swish</li>
<li><strong>Normalization:</strong> Batch Normalization</li>
<li><strong>Optimization:</strong> Adam with Cosine Annealing (learning rate: $10^{-3} \to 10^{-7}$)</li>
</ul>
<p>The SE blocks recalibrate features channel by channel, which helps the classifier pick up subtle differences between the prior and the aggregate posterior in high-dimensional latent spaces.</p>
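<p>For reference, the learning-rate range listed above corresponds to the following schedule under the standard cosine annealing formula (no warm restarts). Whether the authors step the schedule per epoch or per iteration is not stated here, so <code>total_steps</code> is left as a free parameter and the function name is hypothetical.</p>

```python
import math

def cosine_annealed_lr(step, total_steps, lr_max=1e-3, lr_min=1e-7):
    """Standard cosine annealing from lr_max (step 0) down to lr_min (final step)."""
    progress = step / total_steps
    return lr_min + 0.5 * (lr_max - lr_min) * (1.0 + math.cos(math.pi * progress))

# Example: 100 scheduler steps, e.g. one per epoch (an assumption).
schedule = [cosine_annealed_lr(s, 100) for s in range(101)]
```

<p>The schedule spends most of its steps near the two endpoints and moves fastest through the middle of training.</p>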
<h3 id="hardware">Hardware</h3>
<p>The main paper is vague on training time, but the OpenReview rebuttal explicitly lists hardware costs:</p>
<ul>
<li><strong>Hardware:</strong> NVIDIA Tesla V100 (32GB) GPUs</li>
<li><strong>Per-Discriminator Training:</strong> ~13 hours for 100 epochs</li>
<li><strong>Parallelization:</strong> Because latent groups are independent, all discriminators can train in parallel</li>
<li><strong>Total Cost (CelebA-64):</strong> ~8.1 GPU-days</li>
<li><strong>Comparison:</strong> The authors argue this is efficient compared to VDVAE, which requires ~560 GPU-days</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<h4 id="inference-speed-vs-quality-trade-off">Inference Speed vs. Quality Trade-off</h4>
<p>Reviewers flagged that SIR sampling can be prohibitively slow. The authors clarified the exact trade-off:</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Proposal Samples ($M$)</th>
          <th style="text-align: left">Time per Image</th>
          <th style="text-align: left">FID (CelebA-64)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left">5,000 (paper default)</td>
          <td style="text-align: left">~10.11 seconds</td>
          <td style="text-align: left">5.25</td>
      </tr>
      <tr>
          <td style="text-align: left">500 (practical)</td>
          <td style="text-align: left">~1.25 seconds</td>
          <td style="text-align: left">6.76</td>
      </tr>
  </tbody>
</table>
<p>The quality gain from 500 to 5,000 samples is modest (FID difference of 1.51) while inference time increases roughly 8x, suggesting $M=500$ may be a practical default.</p>
<h4 id="hyperparameters">Hyperparameters</h4>
<ul>
<li><strong>FID Calculation:</strong> 50,000 samples</li>
<li><strong>SIR Proposals:</strong> 5,000 samples (paper default) or 500 (practical)</li>
<li><strong>MNIST:</strong> Dynamically binarized version used for likelihood evaluation</li>
<li><strong>Optimizers:</strong> The study largely adopts hyperparameters from baseline papers (e.g., Lawson et al. for MNIST, Ghosh et al. for RAE)</li>
</ul>
<h4 id="debugging-benchmark-25-gaussians">Debugging Benchmark: 25-Gaussians</h4>
<p>The supplement provides a toy experiment ideal for verifying a new implementation before running on expensive image datasets:</p>
<ul>
<li><strong>Setup:</strong> Synthetic dataset of 25 2D-Gaussians arranged on a grid</li>
<li><strong>Target log-likelihood:</strong> approximately $-0.954$ nats (NCP) vs. $-2.753$ nats (vanilla VAE); higher is better</li>
<li><strong>Success Criterion:</strong> Samples should avoid low-density regions between grid points. A standard VAE will generate samples in these &ldquo;prior holes,&rdquo; while a working NCP implementation should cleanly remove these artifacts.</li>
</ul>
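<p>A toy dataset of this kind is quick to generate. The sketch below is only an approximation of the benchmark: the grid spacing, the per-mode standard deviation, and the helper names are assumptions, so the supplement should be consulted for the exact values.</p>

```python
import random

def make_25_gaussians(num_points, spacing=2.0, std=0.05, seed=0):
    """Sample 2D points from a mixture of 25 isotropic Gaussians on a 5x5 grid."""
    rng = random.Random(seed)
    centers = [(spacing * i, spacing * j) for i in range(-2, 3) for j in range(-2, 3)]
    points = []
    for _ in range(num_points):
        cx, cy = rng.choice(centers)  # pick a mode uniformly at random
        points.append((rng.gauss(cx, std), rng.gauss(cy, std)))
    return centers, points

def fraction_near_modes(points, centers, radius):
    """Share of points within `radius` of the nearest grid center (a crude prior-hole check)."""
    near = 0
    for x, y in points:
        nearest_sq = min((x - cx) ** 2 + (y - cy) ** 2 for cx, cy in centers)
        if radius ** 2 >= nearest_sq:
            near += 1
    return near / len(points)

centers, points = make_25_gaussians(2000)
coverage = fraction_near_modes(points, centers, radius=0.2)
```

<p>Running <code>fraction_near_modes</code> on generated samples gives a simple numeric proxy for the success criterion: samples that land between grid points lower the fraction.</p>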
<h4 id="implementation-warnings">Implementation Warnings</h4>
<ul>
<li><strong>SIR Failure Mode:</strong> If the learned prior $p_{\text{NCP}}$ deviates too far from the base prior, SIR sampling collapses (low effective sample size). The paper shows a strong correlation between the NCE classification loss and the effective sample size (Fig. 5(b)), indicating that SIR reliability depends on how well the base prior matches the aggregate posterior.</li>
<li><strong>Temperature Scaling:</strong> The qualitative images in the paper use reduced temperature for improved visual sharpness (Section 5.3). The FID tables do not specify a temperature, so results may or may not use $T=1.0$.</li>
</ul>
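<p>The SIR failure mode can be monitored with the effective sample size of the importance weights. The sketch below uses the common Kish estimator, $(\sum_m w^{(m)})^2 / \sum_m (w^{(m)})^2$; this is an assumption about the definition, and the paper should be checked for the exact form it reports.</p>

```python
def effective_sample_size(weights):
    """Kish effective sample size of a list of non-negative importance weights."""
    total = sum(weights)
    return total * total / sum(w * w for w in weights)

# Uniform weights: every proposal contributes equally, so ESS equals the count.
ess_uniform = effective_sample_size([1.0] * 500)

# One dominant weight: resampling almost always returns the same proposal.
ess_collapsed = effective_sample_size([1000.0] + [1e-3] * 499)
```

<p>An ESS close to 1 means the resampling step is effectively returning a single proposal, which is the collapse described above.</p>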
<h3 id="data">Data</h3>
<p>The method was evaluated on several standard image generation benchmarks:</p>
<ul>
<li><strong>MNIST</strong> (dynamically binarized): Likelihood evaluation on a controlled, small-latent-space task</li>
<li><strong>CIFAR-10</strong>: Standard computer vision benchmark for generative modeling (32x32 RGB images)</li>
<li><strong>CelebA 64x64</strong>: Face generation task with moderate resolution</li>
<li><strong>CelebA HQ 256x256</strong>: High-resolution face generation task</li>
</ul>
<p>All datasets use standard train/test splits from the computer vision literature.</p>
<h4 id="additional-metrics">Additional Metrics</h4>
<p>Beyond FID and NLL, the paper uses:</p>
<ul>
<li><strong>Effective Sample Size (ESS):</strong> Validates reliability of the SIR sampling procedure</li>
<li><strong>Maximum Mean Discrepancy (MMD):</strong> Measures distance between aggregate posterior and NCP prior distributions</li>
</ul>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Aneja, J., Schwing, A. G., Kautz, J., &amp; Vahdat, A. (2021). A contrastive learning approach for training variational autoencoder priors. <em>Advances in Neural Information Processing Systems</em>, 34, 29604-29616. <a href="https://proceedings.neurips.cc/paper/2021/hash/0496604c1d80f66fbeb963c12e570a26-Abstract.html">https://proceedings.neurips.cc/paper/2021/hash/0496604c1d80f66fbeb963c12e570a26-Abstract.html</a></p>
<p><strong>Publication</strong>: NeurIPS 2021</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{aneja2021contrastive,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{A Contrastive Learning Approach for Training Variational Autoencoder Priors}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Aneja, Jyoti and Schwing, Alexander G and Kautz, Jan and Vahdat, Arash}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Advances in Neural Information Processing Systems}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{34}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{29604--29616}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2021}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://openreview.net/forum?id=LcSfRundgwI">OpenReview Discussion</a></li>
<li><a href="https://drive.google.com/drive/folders/15tCGruQcSdm2G4yLkUpKvGASluSZPIBD">Code Repository</a> (Google Drive; link may become inaccessible over time)</li>
</ul>
]]></content:encoded></item><item><title>GEOM Dataset: 3D Molecular Conformer Generation</title><link>https://hunterheidenreich.com/posts/geom-conformer-generation-dataset/</link><pubDate>Fri, 15 Aug 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/geom-conformer-generation-dataset/</guid><description>Learn how GEOM transforms 2D molecular graphs into dynamic 3D conformer ensembles for molecular machine learning applications.</description><content:encoded><![CDATA[<h2 id="introduction">Introduction</h2>
<p>In molecular machine learning, we often start with a 2D graph, a blueprint of atoms and bonds. But a molecule&rsquo;s function is deeply tied to its dynamic 3D shape. Molecules are flexible entities that exist as an <strong>ensemble of low-energy conformations</strong>, and capturing these 3D shapes is crucial for predicting molecular behavior.</p>
<p>The <a href="/notes/chemistry/datasets/geom/">GEOM</a> (Geometric Ensemble Of Molecules) dataset was created to bridge this gap. It provides a massive collection of high-quality 3D conformer ensembles, transforming static 2D graphs into something much closer to physical reality. This makes it an invaluable resource for anyone working in geometric deep learning for chemistry and drug discovery.</p>

<figure class="post-figure center ">
    <img src="https://media.springernature.com/full/springer-static/image/art%3A10.1038%2Fs41597-022-01288-4/MediaObjects/41597_2022_1288_Fig1_HTML.png?as=webp"
         alt="Overlay of conformers for a complex molecule"
         title="Overlay of conformers for a complex molecule"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">3D conformer ensembles expand upon 2D blueprints by revealing the diverse shapes the latanoprost molecule adopts.</figcaption>
    
</figure>

<h2 id="the-challenge-of-conformer-generation">The Challenge of Conformer Generation</h2>
<p>Generating 3D structures for every molecule is computationally hard for two main reasons:</p>
<ol>
<li><strong>Combinatorial Explosion</strong>: Think of a molecule with several rotatable bonds. Each bond is like a joint that can be twisted. The number of possible 3D shapes grows exponentially with each new joint. Trying every combination is impractical for most molecules.</li>
<li><strong>Speed vs. Accuracy</strong>: We need to calculate the energy of each shape to know if it&rsquo;s realistic (low energy). Classical <strong>force fields</strong> are fast but approximate. <strong>Density Functional Theory (DFT)</strong> provides quantum mechanical accuracy but is far too expensive to run on millions of structures.</li>
</ol>
<p>GEOM uses a semi-empirical method to capture the underlying quantum mechanics efficiently, enabling the generation of millions of conformations for a large dataset.</p>
<h2 id="a-deeper-look-inside-the-geom-dataset">A Deeper Look Inside the GEOM Dataset</h2>
<p>The scale of GEOM is impressive: over <strong>37 million conformations</strong> for more than <strong>450,000 unique molecules</strong>. But the numbers in the paper&rsquo;s tables tell a more interesting story about the dataset&rsquo;s composition.</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">AICures drug dataset (N=304,466)</th>
          <th style="text-align: left">Mean</th>
          <th style="text-align: left">Max</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left">Number of heavy atoms</td>
          <td style="text-align: left">24.9</td>
          <td style="text-align: left">91</td>
      </tr>
      <tr>
          <td style="text-align: left">Number of rotatable bonds</td>
          <td style="text-align: left">6.5</td>
          <td style="text-align: left">53</td>
      </tr>
      <tr>
          <td style="text-align: left">Conformers</td>
          <td style="text-align: left">102.6</td>
          <td style="text-align: left">7,451</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>QM9 dataset (N=133,258)</strong></td>
          <td style="text-align: left"><strong>Mean</strong></td>
          <td style="text-align: left"><strong>Max</strong></td>
      </tr>
      <tr>
          <td style="text-align: left">Number of heavy atoms</td>
          <td style="text-align: left">8.8</td>
          <td style="text-align: left">9</td>
      </tr>
      <tr>
          <td style="text-align: left">Number of rotatable bonds</td>
          <td style="text-align: left">2.2</td>
          <td style="text-align: left">8</td>
      </tr>
      <tr>
          <td style="text-align: left">Conformers</td>
          <td style="text-align: left">13.5</td>
          <td style="text-align: left">1,101</td>
      </tr>
  </tbody>
</table>
<p><em>A simplified view of Tables 1 &amp; 4 from the paper, highlighting the key differences.</em></p>
<p>What does this tell us?</p>
<ul>
<li><strong>Two Worlds of Molecules</strong>: The dataset is clearly split. The <strong>QM9</strong> subset contains small, relatively rigid molecules (mean of 2.2 rotatable bonds). In contrast, the <strong>AICures</strong> subset contains larger, more flexible drug-like molecules (mean of 6.5 rotatable bonds, with one molecule having 53!). This diversity is ideal for training machine learning models that need to generalize from simple cases to complex, real-world examples.</li>
<li><strong>Conformational Complexity</strong>: The number of conformers found per molecule reflects this flexibility. A typical QM9 molecule has about 13 conformers, while a drug-like molecule has over 100 on average. This highlights the necessity of 3D ensembles for flexible molecules.</li>
</ul>
<p>Beyond the structures themselves, GEOM is rich with experimental data, connecting the 3D shapes to real-world properties. The molecules are labeled with data for everything from <strong>water solubility</strong> and <strong>blood-brain barrier penetration</strong> to <strong>toxicity</strong> and inhibition of key viral targets like the <strong>SARS-CoV-2 3CL protease</strong>. This makes it a powerful tool for developing property prediction models.</p>
<p>In fact, this creates a benchmark for:</p>
<ul>
<li>Property prediction models that can leverage conformer ensembles (or members of the ensemble) as input.</li>
<li>Conformer generation models that must transform 2D graphs into realistic, 3D distributions.</li>
<li>End-to-end property-based evaluation of the conformer ensembles generated by a model.</li>
</ul>
<h2 id="the-toolbox-behind-geom-key-techniques-explained">The Toolbox Behind GEOM: Key Techniques Explained</h2>
<p>The GEOM paper mentions several advanced computational chemistry methods. Let&rsquo;s briefly break down the most important ones:</p>
<ul>
<li><strong>GFN2-xTB</strong>: This is the semi-empirical quantum mechanical method used to calculate energies and forces in GEOM. Think of it as a &ldquo;middle ground&rdquo; method. It provides greater speed than full DFT while capturing electronic effects absent in classical force fields, making it a pragmatic choice for generating a large dataset.</li>
<li><strong>CREST</strong>: This is the program that actually performs the conformer search. It uses a clever technique based on <strong>metadynamics</strong>, where it simulates the molecule&rsquo;s movement and adds a &ldquo;penalty&rdquo; potential to discourage it from revisiting shapes it has already seen. This pushes the molecule to explore its conformational space efficiently, finding many diverse, low-energy structures.</li>
<li><strong>CENSO</strong>: For a small subset of molecules, the authors went a step further with CENSO. This program takes the conformers found by CREST and refines them with more accurate (and expensive) DFT calculations. It&rsquo;s a way of getting very high-quality &ldquo;gold standard&rdquo; data for benchmarking.</li>
<li><strong>Implicit Solvent Models</strong>: Molecules in the body exist in aqueous environments. Methods like <strong>C-PCM</strong> and <strong>ALPB</strong> model water as a continuous medium, which affects the molecule&rsquo;s preferred shape and energy. This is crucial for biological applications.</li>
</ul>
<h2 id="the-math-behind-the-molecules-explained-simply">The Math Behind the Molecules (Explained Simply)</h2>
<p>The paper includes a couple of equations based on the Boltzmann distribution, which is a fundamental concept from statistical mechanics that tells us the probability of finding a system in a certain state.</p>
<p>The key equation used by CREST to assign a probability (or &ldquo;statistical weight&rdquo;) to the <em>i</em>-th conformer is:</p>
<p>$$ P_{i}^{\text{CREST}} = \frac{d_{i}\exp(-E_{i}/k_{B}T)}{\sum_{j}d_{j}\exp(-E_{j}/k_{B}T)} $$</p>
<p>Let&rsquo;s demystify this:</p>
<ul>
<li>$E_i$ is the energy of the conformer. The negative sign and the exponential mean that <strong>lower energy leads to a much higher probability</strong>.</li>
<li>$k_B T$ is the thermal energy at a given temperature $T$. It sets the energy scale. If the energy difference between two conformers is much larger than $k_B T$, the higher-energy one will be virtually nonexistent.</li>
<li>$d_i$ represents the degeneracy of the conformer, which accounts for the number of equivalent states or configurations that share the same energy $E_i$.
<ul>
<li>Degeneracy refers to the number of equivalent, indistinguishable atomic arrangements (rotamers) that correspond to a single overall molecular shape (conformer). For example, the rotation of a methyl group ($-\text{CH}_3$) produces multiple identical-looking orientations of its hydrogen atoms.</li>
</ul>
</li>
<li>The denominator, $\sum_{j}d_{j}\exp(-E_{j}/k_{B}T)$, is the <strong>partition function</strong>. Its job is to sum up the terms from all possible conformers to ensure that all the probabilities add up to 100%.</li>
</ul>
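<p>The weighting formula is short enough to compute by hand. Below is a minimal sketch, assuming energies are given in kcal/mol and the temperature in kelvin; the function name and the example values are hypothetical.</p>

```python
import math

K_B_KCAL_PER_MOL_K = 0.0019872041  # Boltzmann constant in kcal/(mol*K)

def boltzmann_weights(energies, degeneracies, temperature=298.15):
    """Degeneracy-weighted Boltzmann probabilities for a set of conformers."""
    kt = K_B_KCAL_PER_MOL_K * temperature
    # Shifting by the minimum energy leaves the ratios unchanged and avoids underflow.
    e_min = min(energies)
    unnormalized = [d * math.exp(-(e - e_min) / kt) for e, d in zip(energies, degeneracies)]
    partition = sum(unnormalized)
    return [u / partition for u in unnormalized]

# Two conformers with equal energy but degeneracies 1 and 3.
equal_energy = boltzmann_weights([0.0, 0.0], [1, 3])

# Three conformers at 0, 1, and 3 kcal/mol with no degeneracy.
spread = boltzmann_weights([0.0, 1.0, 3.0], [1, 1, 1])
```

<p>In the first example the degeneracy alone sets the weights (one quarter vs. three quarters). In the second, the conformer 3 kcal/mol above the minimum receives well under 1% of the population at room temperature.</p>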
<p>For the high-quality CENSO calculations, the equation uses the <strong>Gibbs Free Energy ($G_i$)</strong>. Free energy provides a complete measure by including the molecule&rsquo;s internal energy, its interaction with a solvent, and entropic effects (like how much it can &ldquo;wiggle&rdquo;). This gives a more accurate ranking of the conformer probabilities.</p>
<h2 id="a-closer-look-at-the-figures-what-the-data-really-shows">A Closer Look at the Figures: What the Data Really Shows</h2>
<p>The paper&rsquo;s figures offer some honest insights into the dataset&rsquo;s quality and the trade-offs involved.</p>

<figure class="post-figure center ">
    <img src="https://media.springernature.com/full/springer-static/image/art%3A10.1038%2Fs41597-022-01288-4/MediaObjects/41597_2022_1288_Fig4_HTML.png?as=webp"
         alt="Scatter plot comparing energy calculation methods."
         title="Scatter plot comparing energy calculation methods."
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Comparing the &lsquo;fast&rsquo; GFN2-xTB energies with &lsquo;accurate&rsquo; DFT energies. (a) There&rsquo;s a clear correlation, but also a lot of spread. (b) The ranking accuracy (Spearman ρ) is decent on average (0.39) but highly variable.</figcaption>
    
</figure>

<p>Figure 4 is particularly important. It compares the fast GFN2-xTB (CREST) energies with much more accurate single-point r2scan-3c DFT energies.</p>
<ul>
<li>The <strong>Mean Absolute Error (MAE) of 1.96 kcal/mol</strong> shows that, on average, the fast method gets the energy wrong by about 2 kcal/mol. At room temperature, the thermal energy ($k_B T$) is only about 0.6 kcal/mol. Because the Boltzmann probability depends on the energy <em>exponentially</em>, a 2 kcal/mol error can dramatically change the predicted importance of a conformer.</li>
<li>The <strong>Spearman correlation plot</strong> (right side) shows how well GFN2-xTB <em>ranks</em> the conformers from lowest to highest energy compared to DFT. An average correlation of 0.39 is modest, and the wide distribution indicates that performance varies strongly across molecules: the ranking is nearly perfect for some and deviates significantly for others.</li>
</ul>
<p>This is a key takeaway: the GFN2-xTB/CREST method excels at <strong>discovering</strong> low-energy shapes. For accurate probability <strong>ranking</strong>, the higher-level DFT energies that GEOM provides for a subset of molecules are the more reliable reference.</p>
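<p>To put the 1.96 kcal/mol figure in perspective, the short calculation below converts an energy error into a multiplicative error on a population ratio. It assumes the error applies to the <em>relative</em> energy of two conformers at 298.15 K, which is a simplification of what an MAE over many conformers actually measures.</p>

```python
import math

K_B_KCAL_PER_MOL_K = 0.0019872041  # Boltzmann constant in kcal/(mol*K)

kt = K_B_KCAL_PER_MOL_K * 298.15   # thermal energy at room temperature, ~0.59 kcal/mol
energy_error = 1.96                # MAE between GFN2-xTB and DFT, in kcal/mol

# Factor by which the predicted population ratio of two conformers is off
# when their relative energy is wrong by `energy_error`.
ratio_error = math.exp(energy_error / kt)
```

<p>Under these assumptions the ratio is off by a factor of roughly 27, which is why a method can find the right conformers and still assign them badly wrong weights.</p>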
<h2 id="conclusion-what-this-means-for-machine-learning">Conclusion: What This Means for Machine Learning</h2>
<p>For researchers at the intersection of machine learning and chemistry, GEOM provides a realistic foundation to build upon. By shifting the focus from static 2D graphs to dynamic 3D ensembles, GEOM enables a new generation of models.</p>
<p>This dataset is an ideal training ground for models designed to understand 3D geometry, such as <strong>SE(3)-equivariant neural networks</strong>, <strong>diffusion models</strong>, <strong>transformers</strong>, and <strong>VAEs</strong>, which can learn to generate conformer ensembles directly from a 2D graph. By training on GEOM, these models can learn the complex relationship between a molecule&rsquo;s chemical blueprint and its real-world, flexible nature.</p>
<p>For a comprehensive technical reference including detailed specifications, quality metrics, and performance leaderboards, see my <a href="/notes/chemistry/datasets/geom/">GEOM Dataset Card</a>.</p>
<p>Explore the GEOM dataset further by visiting its <a href="https://github.com/learningmatter-mit/geom">GitHub repository</a>.</p>
<h2 id="references">References</h2>
<ul>
<li>Axelrod, S. &amp; Gómez-Bombarelli, R. &ldquo;GEOM, energy-annotated molecular conformations for property prediction and molecular generation.&rdquo; <em>Scientific Data</em> 9, 185 (2022). <a href="https://doi.org/10.1038/s41597-022-01288-4">https://doi.org/10.1038/s41597-022-01288-4</a></li>
<li>GitHub repositories:
<ul>
<li><a href="https://github.com/learningmatter-mit/geom">learningmatter-mit/geom</a></li>
<li><a href="https://github.com/learningmatter-mit/NeuralForceField">learningmatter-mit/NeuralForceField</a></li>
</ul>
</li>
</ul>
]]></content:encoded></item><item><title>SubGrapher: Visual Fingerprinting of Chemical Structures</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/vision-language/subgrapher/</link><pubDate>Mon, 28 Apr 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/vision-language/subgrapher/</guid><description>SubGrapher creates molecular fingerprints directly from chemical structure images through functional group segmentation for database retrieval.</description><content:encoded><![CDATA[<h2 id="paper-classification-and-taxonomy">Paper Classification and Taxonomy</h2>
<p>This is primarily a <strong>Methodological Paper ($\Psi_{\text{Method}}$)</strong> with a secondary <strong>Resource ($\Psi_{\text{Resource}}$)</strong> contribution. Using the <a href="/notes/interdisciplinary/research-methods/ai-physical-sciences-paper-taxonomy/">AI and Physical Sciences paper taxonomy</a> framework:</p>
<p><strong>Primary Classification: Method</strong></p>
<p>The dominant basis vector is Methodological because SubGrapher introduces an architecture that replaces the two-step OCSR workflow (image, then structure, then fingerprint) with single-step fingerprinting (image to visual fingerprint). The paper validates this approach through systematic comparison against state-of-the-art methods (MolGrapher, OSRA, DECIMER, MolScribe), demonstrating superior performance on specific tasks like retrieval and robustness to image quality degradation.</p>
<p><strong>Secondary Classification: Resource</strong></p>
<p>The paper makes non-negligible resource contributions by releasing:</p>
<ul>
<li>Code and model weights on <a href="https://github.com/DS4SD/SubGrapher">GitHub</a> and <a href="https://huggingface.co/docling-project/SubGrapher">HuggingFace</a></li>
<li>Five new visual fingerprinting benchmark datasets for molecule retrieval tasks</li>
<li>Comprehensive functional group knowledge base (1,534 substructures)</li>
</ul>
<h2 id="motivation-extracting-complex-structures-from-noisy-images">Motivation: Extracting Complex Structures from Noisy Images</h2>
<p>The motivation tackles a fundamental challenge in chemical informatics: extracting molecular information from the vast amounts of unstructured scientific literature, particularly patents. Millions of molecular structures exist only as images in these documents, making them inaccessible for computational analysis, database searches, or machine learning applications.</p>
<p>Traditional Optical Chemical Structure Recognition (OCSR) tools attempt to fully reconstruct molecular graphs from images, converting them into machine-readable formats like SMILES. However, these approaches face two critical limitations:</p>
<ol>
<li><strong>Brittleness to image quality</strong>: Poor resolution, noise, or unconventional drawing styles frequently degrade recognition accuracy</li>
<li><strong>Limited handling of complex structures</strong>: Markush structures, generic molecular templates with variable R-groups commonly used in patents, are poorly supported by most conventional OCSR methods</li>
</ol>
<p>The key insight driving SubGrapher is that full molecular reconstruction may be unnecessary for many applications. For tasks like database searching, similarity analysis, or document retrieval, a molecular fingerprint (a vectorized representation capturing structural features) is often sufficient. This realization opens up a new approach: bypass the fragile reconstruction step and create fingerprints directly from visual information.</p>
<h2 id="key-innovation-direct-visual-fingerprinting">Key Innovation: Direct Visual Fingerprinting</h2>
<p>SubGrapher takes a different approach to extracting chemical information from images. It creates &ldquo;visual fingerprints&rdquo; through functional group recognition. The key innovations are:</p>
<ol>
<li>
<p><strong>Direct Image-to-Fingerprint Pipeline</strong>: SubGrapher eliminates the traditional two-step process (image → structure → fingerprint) by generating fingerprints directly from pixels. This single-stage approach avoids error accumulation from failed structure reconstructions and can handle images where conventional OCSR tools produce invalid outputs.</p>
</li>
<li>
<p><strong>Dual Instance Segmentation Architecture</strong>: The system employs two specialized Mask-RCNN networks working in parallel:</p>
<ul>
<li><strong>Functional group detector</strong>: Trained to identify 1,534 expert-defined functional groups using pixel-level segmentation masks</li>
<li><strong>Carbon backbone detector</strong>: Recognizes 27 common carbon chain patterns to capture the molecular scaffold</li>
</ul>
<p>Using instance segmentation provides detailed spatial information and higher accuracy through richer supervision during training.</p>
</li>
<li>
<p><strong>Extensive Functional Group Knowledge Base</strong>: The method uses one of the most comprehensive open-source collections of functional groups, encompassing 1,534 substructures. These were systematically defined by:</p>
<ul>
<li>Starting with chemically logical atom combinations (C, O, S, N, B, P)</li>
<li>Expanding to include relevant subgroups and variations</li>
<li>Filtering based on frequency (appearing ~1,000+ times in PubChem)</li>
<li>Adding halogen substituents and organometallic groups relevant to EUV photoresists</li>
<li>Manually curating entries with SMILES, SMARTS, and descriptive names</li>
</ul>
</li>
<li>
<p><strong>Substructure-Graph Construction</strong>: After detecting functional groups and carbon backbones, SubGrapher builds a connectivity graph where:</p>
<ul>
<li>Each node represents an identified substructure</li>
<li>Edges connect substructures whose bounding boxes overlap (with 10% margin expansion)</li>
<li>This graph captures both the chemical components and their spatial relationships</li>
</ul>
</li>
<li>
<p><strong>Substructure-based Visual Molecular Fingerprint (SVMF)</strong>: The final output is a continuous, count-based fingerprint formally defined as a matrix $SVMF(m) \in \mathbb{R}^{n \times n}$ where $n=1561$ (1,534 functional groups + 27 carbon backbones). The matrix is symmetric, so only its upper triangle is stored, in compressed form:</p>
<p><strong>Diagonal elements</strong> ($i = j$): Weighted count of substructure $i$ plus self-intersection
$$SVMF_{ii}(m) = h_1 \cdot n_i + g_{ii}$$
where $h_1 = 10$ is the diagonal weight hyperparameter, $n_i$ is the instance count, and $g_{ii}$ is the self-intersection coefficient.</p>
<p><strong>Off-diagonal elements</strong> ($i \neq j$): Intersection coefficient based on shortest path distance $d$ in the substructure graph
$$SVMF_{ij}(m) = h_2(d) \cdot \text{intersection}(s_i, s_j)$$
where the distance decay function $h_2(d)$ is:</p>
<ul>
<li>$d \leq 1$: weight = 2</li>
<li>$d = 2$: weight = 2/4 = 0.5</li>
<li>$d = 3$: weight = 2/16 = 0.125</li>
<li>$d = 4$: weight = $2/256 \approx 0.0078$</li>
<li>$d &gt; 4$: weight = 0</li>
</ul>
<p><strong>Key properties</strong>:</p>
<ul>
<li>Carbon chain intersection coefficients are divided by 2, giving functional groups higher effective weight</li>
<li>Similarity between fingerprints is calculated using a normalized Euclidean distance (the ratio of the L2 norm of the difference to the L2 norm of the sum)</li>
<li>Resulting fingerprints are highly sparse (average 0.001% non-zero elements)</li>
<li>Compressed storage enables efficient database searches</li>
</ul>
</li>
<li>
<p><strong>Markush Structure Compatibility</strong>: SubGrapher processes Markush structures by recognizing their constituent functional groups and creating meaningful fingerprints for similarity searches, achieving higher accuracy than existing OCSR methods on the USPTO-Markush benchmark (S-F1: 88).</p>
</li>
</ol>
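<p>The two SVMF entry formulas above can be sketched in a few lines of Python. This is a minimal illustration, not the authors&rsquo; code: the function names are hypothetical, and the intersection coefficients are passed in as plain numbers because their exact definition is not reproduced here. Only $h_1 = 10$ and the $h_2(d)$ weights come from the text.</p>

```python
# Minimal sketch of the SVMF entry formulas (hypothetical helper names).

H1 = 10.0  # diagonal weight hyperparameter h_1 from the text

# Distance decay h_2(d), indexed by shortest-path distance d in the
# substructure graph. Values are copied from the list above.
H2_BY_DISTANCE = {0: 2.0, 1: 2.0, 2: 2.0 / 4, 3: 2.0 / 16, 4: 2.0 / 256}


def h2(d):
    """Return the off-diagonal weight for distance d (zero beyond d = 4)."""
    return H2_BY_DISTANCE.get(d, 0.0)


def diagonal_entry(count, self_intersection):
    """SVMF_ii = h_1 * n_i + g_ii."""
    return H1 * count + self_intersection


def off_diagonal_entry(d, intersection):
    """SVMF_ij = h_2(d) * intersection(s_i, s_j)."""
    return h2(d) * intersection
```

<p>The lookup table makes the cutoff explicit: any pair of substructures more than four hops apart contributes nothing, which is consistent with the high sparsity reported for the fingerprints.</p>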
<h2 id="experimental-validation-and-benchmarks">Experimental Validation and Benchmarks</h2>
<p>The evaluation focused on demonstrating SubGrapher&rsquo;s effectiveness across two critical tasks: accurate substructure detection and robust molecule retrieval from diverse image collections.</p>
<h4 id="substructure-detection-performance">Substructure Detection Performance</h4>
<p>SubGrapher&rsquo;s ability to identify functional groups was tested on three challenging benchmarks that expose different failure modes of OCSR systems:</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>Size</th>
          <th>Description</th>
          <th>Key Challenge</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>JPO</strong></td>
          <td>341 images</td>
          <td>Japanese Patent Office images (molecules with abbreviations removed)</td>
          <td>Low quality, noise, artifacts, non-standard drawing styles</td>
      </tr>
      <tr>
          <td><strong>USPTO-10K-L</strong></td>
          <td>1,000 images</td>
          <td>Large molecules (&gt;70 atoms)</td>
          <td>Scale variation, structural complexity, many functional groups</td>
      </tr>
      <tr>
          <td><strong>USPTO-Markush</strong></td>
          <td>74 images</td>
          <td>Generic Markush structures</td>
          <td>Variable R-groups, abstract patterns, template representation</td>
      </tr>
  </tbody>
</table>
<p><strong>Key findings:</strong></p>
<ol>
<li>
<p><strong>JPO Dataset (Low-Quality Patent Images)</strong>: SubGrapher achieved the highest Molecule Exact Match rate (83%), demonstrating robustness to image quality degradation where rule-based methods like OSRA scored lower (67% M-EM).</p>
</li>
<li>
<p><strong>USPTO-10K-L (Large Molecules)</strong>: SubGrapher achieved an S-F1 of 97, matching the rule-based OSRA and outperforming all other learning-based methods (MolScribe: 90, DECIMER: 86, MolGrapher: 56). The instance segmentation approach handled scale variation better than other deep-learning OCSR tools on these large molecules.</p>
</li>
<li>
<p><strong>USPTO-Markush (Generic Structures)</strong>: SubGrapher achieved the highest Substructure F1-score (88) on this benchmark, outperforming MolScribe (86), OSRA (74), and DECIMER (10). While other OCSR tools can attempt these images, they have limited support for Markush features. SubGrapher&rsquo;s instance segmentation approach handles complex Markush structures more effectively by focusing on relevant image regions.</p>
</li>
</ol>
<p>Qualitative analysis revealed that SubGrapher correctly identified functional groups in scenarios where other methods failed completely: images with captions, unconventional drawing styles, or significant quality degradation.</p>
<h4 id="visual-fingerprinting-for-molecule-retrieval">Visual Fingerprinting for Molecule Retrieval</h4>
<p>The core application was evaluated using a retrieval task designed to simulate real-world database searching:</p>
<ol>
<li>
<p><strong>Benchmark Creation</strong>: Five benchmark datasets were constructed around structurally similar molecules (adenosine, camphor, cholesterol, limonene, and pyridine), each containing 500 molecules sampled from PubChem with at least 90% Tanimoto similarity to the reference molecule, rendered as augmented images.</p>
</li>
<li>
<p><strong>Retrieval Task</strong>: Given a SMILES string as a query, the goal was to find the corresponding molecular image within the dataset of 500 visually similar structures. This tests whether the visual fingerprint can distinguish between closely related molecules.</p>
</li>
<li>
<p><strong>Performance Comparison</strong>: SubGrapher outperformed the baseline methods, retrieving the correct molecule at an average rank of 95 out of 500, compared with 138-185 for OSRA-based and 181-241 for MolScribe-based fingerprints. The key advantage was robustness: SubGrapher generates a fingerprint for every image, even with partial or uncertain predictions. In contrast, OCSR-based methods frequently fail to produce valid SMILES, making them unable to generate fingerprints for comparison.</p>
</li>
<li>
<p><strong>Real-World Case Study</strong>: A practical demonstration involved searching a 54-page patent document containing 356 chemical images for a specific Markush structure. SubGrapher successfully located the target structure, highlighting its utility for large-scale document mining.</p>
</li>
</ol>
<h4 id="training-data-generation">Training Data Generation</h4>
<p>Since no public datasets existed with the required pixel-level mask annotations for functional groups, the researchers developed a comprehensive synthetic data generation pipeline:</p>
<ol>
<li>
<p><strong>Extended MolDepictor</strong>: They enhanced existing molecular rendering tools to create images from SMILES strings and generate corresponding segmentation masks for all substructures present in each molecule.</p>
</li>
<li>
<p><strong>Markush Structure Rendering</strong>: The pipeline was extended to handle complex generic structures using CXSMILES representations and the CDK library for rendering, creating training data for molecular templates with structural, positional, and frequency variation indicators.</p>
</li>
<li>
<p><strong>Diverse Molecular Sources</strong>: Training molecules were sourced from PubChem to ensure broad chemical diversity and coverage of different structural families.</p>
</li>
</ol>
<h2 id="results-impact-and-limitations">Results, Impact, and Limitations</h2>
<ul>
<li><strong>Superior Robustness to Image Quality</strong>: SubGrapher consistently outperformed traditional OCSR methods on degraded images, particularly the JPO patent dataset. SubGrapher&rsquo;s learned representations proved more resilient to noise, artifacts, and unconventional drawing styles than rule-based alternatives like OSRA (M-EM: 83 vs. 67 on JPO).</li>
</ul>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>SubGrapher</th>
          <th>MolScribe</th>
          <th>OSRA</th>
          <th>DECIMER</th>
          <th>MolGrapher</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>S-F1</strong> (JPO)</td>
          <td>92</td>
          <td><strong>94</strong></td>
          <td>81</td>
          <td>86</td>
          <td>89</td>
      </tr>
      <tr>
          <td><strong>M-EM</strong> (JPO)</td>
          <td><strong>83</strong></td>
          <td>82</td>
          <td>67</td>
          <td>79</td>
          <td>80</td>
      </tr>
      <tr>
          <td><strong>S-F1</strong> (USPTO-10K-L)</td>
          <td><strong>97</strong></td>
          <td>90</td>
          <td><strong>97</strong></td>
          <td>86</td>
          <td>56</td>
      </tr>
      <tr>
          <td><strong>M-EM</strong> (USPTO-10K-L)</td>
          <td>55</td>
          <td>55</td>
          <td><strong>75</strong></td>
          <td>66</td>
          <td>31</td>
      </tr>
      <tr>
          <td><strong>S-F1</strong> (USPTO-Markush)</td>
          <td><strong>88</strong></td>
          <td>86</td>
          <td>74</td>
          <td>10</td>
          <td>35</td>
      </tr>
      <tr>
          <td><strong>M-EM</strong> (USPTO-Markush)</td>
          <td>82</td>
          <td><strong>86</strong></td>
          <td>70</td>
          <td>11</td>
          <td>30</td>
      </tr>
      <tr>
          <td><strong>Avg Retrieval Rank</strong></td>
          <td><strong>95/500</strong></td>
          <td>181-241/500</td>
          <td>138-185/500</td>
          <td>N/A</td>
          <td>N/A</td>
      </tr>
  </tbody>
</table>
<p>Note: Retrieval rank ranges reflect the best and worst fingerprint method pairing for each OCSR model (RDKit Daylight or MHFP).</p>
<ul>
<li>
<p><strong>Effective Handling of Scale and Complexity</strong>: The instance segmentation approach managed large molecules and complex structures where several graph-reconstruction methods struggled. SubGrapher tied OSRA for the best Substructure F1-score on USPTO-10K-L (97) and led on USPTO-Markush (88), although its exact-match rate on USPTO-10K-L (M-EM: 55) trailed OSRA (75) and DECIMER (66).</p>
</li>
<li>
<p><strong>Markush Structure Processing</strong>: SubGrapher achieves the highest Substructure F1-score on Markush structures (88 vs. MolScribe&rsquo;s 86 and OSRA&rsquo;s 74). While other OCSR methods can attempt Markush images, they support only limited features such as abbreviation-based variable groups. SubGrapher handles complex Markush features more effectively, expanding the scope of automatically extractable chemical information from patent literature.</p>
</li>
<li>
<p><strong>Robust Molecule Retrieval Performance</strong>: The visual fingerprinting approach achieved reliable retrieval performance (average rank 95/500) across diverse molecular families. The key advantage was consistency: SubGrapher generates meaningful fingerprints even from partial or uncertain predictions, while OCSR-based methods often fail to produce any usable output.</p>
</li>
<li>
<p><strong>Practical Document Mining Capability</strong>: The successful identification of specific Markush structures within large patent documents (54 pages, 356 images) demonstrates real-world applicability for large-scale literature mining and intellectual property analysis.</p>
</li>
<li>
<p><strong>Single-Stage Architecture Benefits</strong>: By eliminating the traditional image → structure → fingerprint pipeline, SubGrapher avoids error accumulation from failed molecular reconstructions. Every input image produces a fingerprint, making the system more reliable for batch processing of diverse document collections.</p>
</li>
<li>
<p><strong>Limitations and Scope</strong>: The method remains focused on common organic functional groups and may struggle with inorganic chemistry, organometallic complexes, or highly specialized molecular classes not well-represented in the training data. The 1,534 functional groups, while extensive, represent a curated subset of chemical space. SubGrapher also cannot distinguish enantiomers, as the detected substructures lack stereochemistry information. Additionally, the method currently cannot recognize substructures in abbreviations or single-atom fragments.</p>
</li>
</ul>
<p>The work demonstrates that direct fingerprint generation can be more robust and practical than traditional structure reconstruction approaches. SubGrapher&rsquo;s ability to handle Markush structures and degraded images makes it particularly valuable for patent analysis and large-scale document mining, where traditional OCSR methods frequently fail. The approach suggests that task-specific learning (fingerprints for retrieval) can outperform general-purpose reconstruction methods in many practical applications.</p>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p><strong>Training Data Generation</strong>: The paper developed a custom synthetic data pipeline since no public datasets existed with pixel-level mask annotations for functional groups:</p>
<ul>
<li><strong>Extended MolDepictor</strong>: Enhanced molecular rendering tool to generate both images and corresponding segmentation masks for all substructures</li>
<li><strong>Markush Structure Rendering</strong>: Pipeline extended to handle complex generic structures</li>
<li><strong>Source Molecules</strong>: PubChem for broad chemical diversity</li>
</ul>
<p><strong>Evaluation Benchmarks</strong>:</p>
<ul>
<li><strong>JPO Dataset</strong>: Real patent images with poor resolution, noise, and artifacts</li>
<li><strong>USPTO-10K-L</strong>: Large complex molecular structures</li>
<li><strong>USPTO-Markush</strong>: Generic structures with variable R-groups</li>
<li><strong>Retrieval Benchmarks</strong>: Five datasets (adenosine, camphor, cholesterol, limonene, pyridine), each with 500 similar molecular images</li>
</ul>
<h3 id="models">Models</h3>
<p><strong>Architecture</strong>: Dual instance segmentation system using Mask-RCNN</p>
<ul>
<li><strong>Functional Group Detector</strong>: Mask-RCNN trained to identify 1,534 expert-defined functional groups</li>
<li><strong>Carbon Backbone Detector</strong>: Mask-RCNN trained to recognize 27 common carbon chain patterns</li>
<li><strong>Backbone Network</strong>: Not specified in the paper</li>
</ul>
<p><strong>Functional Group Knowledge Base</strong>: 1,534 substructures systematically defined by:</p>
<ul>
<li>Starting with chemically logical atom combinations (C, O, S, N, B, P)</li>
<li>Expanding to include relevant subgroups and variations</li>
<li>Filtering based on frequency (appearing ~1,000+ times in PubChem)</li>
<li>Manual curation with SMILES, SMARTS, and descriptive names</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<p><strong>Functional Group Definition</strong>:</p>
<ul>
<li><strong>1,534 Functional Groups</strong>: Defined by manually curated SMARTS patterns
<ul>
<li>Must contain heteroatoms (O, N, S, P, B)</li>
<li>Frequency threshold: ~1,000+ occurrences in PubChem</li>
<li>Systematically constructed from chemically logical atom combinations</li>
<li>Manual curation with SMILES, SMARTS, and descriptive names</li>
</ul>
</li>
<li><strong>27 Carbon Backbones</strong>: Patterns of 3-6 carbon atoms (rings and chains) to capture molecular scaffolds</li>
</ul>
<p><strong>Substructure-Graph Construction</strong>:</p>
<ol>
<li>Detect functional groups and carbon backbones using Mask-RCNN models</li>
<li>Build connectivity graph:
<ul>
<li>Each node represents an identified substructure instance</li>
<li>Edges connect substructures whose bounding boxes overlap</li>
<li>Bounding boxes expanded by 10% of smallest box&rsquo;s diagonal to ensure connectivity between adjacent groups</li>
<li>Carbon chain intersection coefficients divided by 2, giving functional groups higher effective weight</li>
</ul>
</li>
</ol>
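<p>The graph-construction steps can be sketched as follows. This is an illustrative sketch under stated assumptions, not the released implementation: boxes are <code>(x_min, y_min, x_max, y_max)</code> tuples, the margin is taken as 10% of the diagonal of the smallest box among all detections in the image (one reading of the description above), and all function names are hypothetical. Shortest-path distances, which feed the $h_2(d)$ weights, are obtained with a breadth-first search.</p>

```python
from collections import deque
from math import hypot


def diagonal(box):
    """Length of the diagonal of an (x_min, y_min, x_max, y_max) box."""
    return hypot(box[2] - box[0], box[3] - box[1])


def expand(box, margin):
    """Grow a box by `margin` on every side."""
    return (box[0] - margin, box[1] - margin, box[2] + margin, box[3] + margin)


def overlaps(a, b):
    """True if two axis-aligned boxes intersect (touching counts)."""
    return a[0] <= b[2] and b[0] <= a[2] and a[1] <= b[3] and b[1] <= a[3]


def build_graph(boxes, ratio=0.10):
    """Adjacency sets: one node per detection, edges join overlapping boxes.

    Assumption: the margin is `ratio` times the smallest box diagonal.
    """
    margin = ratio * min(diagonal(b) for b in boxes) if boxes else 0.0
    grown = [expand(b, margin) for b in boxes]
    graph = {i: set() for i in range(len(boxes))}
    for i in range(len(grown)):
        for j in range(i + 1, len(grown)):
            if overlaps(grown[i], grown[j]):
                graph[i].add(j)
                graph[j].add(i)
    return graph


def shortest_path_lengths(graph, source):
    """Breadth-first hop counts from `source`; unreachable nodes are omitted."""
    dist = {source: 0}
    queue = deque([source])
    while queue:
        node = queue.popleft()
        for neighbor in graph[node]:
            if neighbor not in dist:
                dist[neighbor] = dist[node] + 1
                queue.append(neighbor)
    return dist
```

<p>For three boxes laid out in a row with one-pixel gaps, the expansion bridges each gap, so the first and last box end up two hops apart and would receive the $d = 2$ weight.</p>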
<p><strong>SVMF Fingerprint Generation</strong>:</p>
<ul>
<li>Matrix form: $SVMF(m) \in \mathbb{R}^{n \times n}$ where $n=1561$</li>
<li>Stored as compressed sparse upper triangular matrix</li>
<li><strong>Diagonal elements</strong>: $SVMF_{ii} = h_1 \cdot n_i + g_{ii}$ where $h_1 = 10$</li>
<li><strong>Off-diagonal elements</strong>: $SVMF_{ij} = h_2(d) \cdot \text{intersection}(s_i, s_j)$ where:
<ul>
<li>$h_2(d) = 2$ for $d = 0, 1$</li>
<li>$h_2(2) = 2/4$, $h_2(3) = 2/16$, $h_2(4) = 2/256$</li>
<li>$h_2(d) = 0$ for $d &gt; 4$</li>
</ul>
</li>
<li>Average sparsity: 0.001% non-zero elements</li>
<li>Similarity metric: Normalized Euclidean distance (L2 norm of difference divided by L2 norm of sum)</li>
</ul>
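<p>The similarity metric is simple to express on sparse fingerprints. The sketch below assumes each fingerprint is stored as a dictionary mapping upper-triangular index pairs <code>(i, j)</code> to non-zero values; that storage layout and the function name are illustrative choices, not details taken from the paper.</p>

```python
from math import sqrt


def normalized_euclidean_distance(fp_a, fp_b):
    """||a - b||_2 / ||a + b||_2 over sparse {(i, j): value} fingerprints."""
    keys = set(fp_a) | set(fp_b)
    diff_sq = sum((fp_a.get(k, 0.0) - fp_b.get(k, 0.0)) ** 2 for k in keys)
    sum_sq = sum((fp_a.get(k, 0.0) + fp_b.get(k, 0.0)) ** 2 for k in keys)
    if sum_sq == 0.0:
        # Assumption: two empty fingerprints are treated as identical.
        return 0.0
    return sqrt(diff_sq / sum_sq)
```

<p>With non-negative entries the distance lies between 0 (identical fingerprints) and 1 (no shared non-zero entries), and only the non-zero entries of either fingerprint are ever visited.</p>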
<h3 id="evaluation">Evaluation</h3>
<p><strong>Metrics</strong>:</p>
<ul>
<li><strong>Substructure F1-score (S-F1)</strong>: Harmonic mean of precision and recall for individual substructure detection across all molecules in the dataset</li>
<li><strong>Molecule Exact Match (M-EM)</strong>: Percentage of molecules where S-F1 = 1.0 (all substructures correctly identified)</li>
<li><strong>Retrieval Rank</strong>: Average rank of ground truth molecule in candidate list of 500 similar structures when querying with SMILES fingerprint, averaged across 50 queries per benchmark</li>
</ul>
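<p>The three metrics can be sketched as below. The exact counting protocol is an assumption: this version pools true positives, false positives, and false negatives over the whole dataset (micro-averaging) and compares substructure labels as multisets. Function names are hypothetical.</p>

```python
from collections import Counter


def match_counts(predicted, expected):
    """Return (tp, fp, fn) for one molecule using multiset overlap of labels."""
    pred, gold = Counter(predicted), Counter(expected)
    tp = sum((pred & gold).values())
    return tp, sum(pred.values()) - tp, sum(gold.values()) - tp


def f1(tp, fp, fn):
    """Harmonic mean of precision and recall, written in terms of counts."""
    denom = 2 * tp + fp + fn
    # Assumption: nothing predicted and nothing expected counts as perfect.
    return 2 * tp / denom if denom else 1.0


def substructure_f1(predictions, references):
    """S-F1 pooled over all molecules in the dataset."""
    totals = [0, 0, 0]
    for predicted, expected in zip(predictions, references):
        for k, value in enumerate(match_counts(predicted, expected)):
            totals[k] += value
    return f1(*totals)


def molecule_exact_match(predictions, references):
    """M-EM: fraction of molecules whose substructures all match exactly."""
    exact = sum(
        1 for p, g in zip(predictions, references) if Counter(p) == Counter(g)
    )
    return exact / len(references)


def retrieval_rank(distances, target_index):
    """1-based rank of the target when candidates are sorted by distance."""
    order = sorted(range(len(distances)), key=lambda i: distances[i])
    return order.index(target_index) + 1
```

<p>Separating S-F1 from M-EM explains how a model can score high on one and low on the other: a single missed group in a large molecule barely moves the pooled F1 but turns that molecule into an exact-match failure.</p>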
<p><strong>Baselines</strong>: Compared against state-of-the-art OCSR methods:</p>
<ul>
<li>Deep learning: MolScribe, MolGrapher, DECIMER</li>
<li>Rule-based: OSRA</li>
<li>Fingerprint methods: RDKit Daylight, MHFP (applied to OCSR outputs)</li>
</ul>
<h3 id="hardware">Hardware</h3>
<p>Not specified in the paper. The main text gives no training or inference hardware details; any such details would have to come from the supplementary materials or the code repository.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/DS4SD/SubGrapher">SubGrapher (GitHub)</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official inference code</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/docling-project/SubGrapher">SubGrapher (HuggingFace)</a></td>
          <td>Model</td>
          <td>MIT</td>
          <td>Pre-trained model weights</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/datasets/docling-project/SubGrapher-Datasets">SubGrapher-Datasets (HuggingFace)</a></td>
          <td>Dataset</td>
          <td>CC-BY-4.0</td>
          <td>Visual fingerprinting benchmark datasets</td>
      </tr>
  </tbody>
</table>
<h3 id="implementation-gaps">Implementation Gaps</h3>
<p>The following details are not available in the paper and would require access to the code repository or supplementary information:</p>
<ul>
<li>Specific backbone architecture for Mask-RCNN (ResNet variant, Swin Transformer, etc.)</li>
<li>Optimizer type (AdamW, SGD, etc.)</li>
<li>Learning rate and scheduler</li>
<li>Batch size and number of training epochs</li>
<li>Loss function weights (box loss vs. mask loss)</li>
<li>GPU/TPU specifications used for training</li>
<li>Training time and computational requirements</li>
</ul>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Morin, L., Meijer, G. I., Weber, V., Van Gool, L., &amp; Staar, P. W. J. (2025). SubGrapher: Visual fingerprinting of chemical structures. Journal of Cheminformatics, 17(1), 149. <a href="https://doi.org/10.1186/s13321-025-01091-4">https://doi.org/10.1186/s13321-025-01091-4</a></p>
<p><strong>Publication</strong>: Journal of Cheminformatics (2025)</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{morinSubGrapherVisualFingerprinting2025,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{SubGrapher: Visual Fingerprinting of Chemical Structures}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">shorttitle</span> = <span style="color:#e6db74">{SubGrapher}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Morin, Lucas and Meijer, Gerhard Ingmar and Weber, Valéry and Van Gool, Luc and Staar, Peter W. J.}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2025}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Journal of Cheminformatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{17}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{149}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1186/s13321-025-01091-4}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Synthetic Isomer Data Generation Pipeline</title><link>https://hunterheidenreich.com/projects/isomer-dataset-generation/</link><pubDate>Sat, 09 Mar 2024 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/projects/isomer-dataset-generation/</guid><description>An end-to-end cheminformatics pipeline transforming 1D chemical formulas into 3D conformer datasets using graph enumeration and physics-based featurization.</description><content:encoded><![CDATA[<h2 id="overview">Overview</h2>
<p>In computational drug discovery, data scarcity is often the bottleneck. This project builds a synthetic data generator that creates labeled 3D molecular datasets starting from nothing but a raw chemical formula (e.g., $C_6H_{14}$).</p>
<p>The pipeline bridges the gap between <strong>1D Chemical Information</strong> (stoichiometry) and <strong>3D Geometric Data</strong> (conformers), effectively serving as a &ldquo;data factory&rdquo; for training molecular machine learning models.</p>
<h2 id="features">Features</h2>
<h3 id="1-graph-enumeration--3d-embedding">1. Graph Enumeration &amp; 3D Embedding</h3>
<p>The core of the project is <code>pysomer/data/gen.py</code>, which orchestrates a multi-step generation process:</p>
<ul>
<li><strong>Structural Isomerism:</strong> Uses <strong>MAYGEN</strong> (via a Java bridge) to mathematically enumerate all valid graph connectivities for a given formula</li>
<li><strong>Conformer Sampling:</strong> Uses <strong>RDKit</strong> to embed these graphs into 3D space, generating multiple conformers (rotamers) per isomer to capture flexibility</li>
<li><strong>IUPAC Labeling:</strong> Automatically queries PubChem APIs to assign human-readable labels (e.g., &ldquo;2-methylpentane&rdquo;) to the generated structures</li>
</ul>
<h3 id="2-physics-aware-featurization">2. Physics-Aware Featurization</h3>
<p>The pipeline computes <strong>Coulomb Matrices</strong>, a representation that is invariant to translation and rotation of the molecule:</p>
<p>$$C_{ij} = \begin{cases} 0.5 Z_i^{2.4} &amp; i = j \\ \frac{Z_i Z_j}{|R_i - R_j|} &amp; i \neq j \end{cases}$$</p>
<p>Here $Z_i$ is the nuclear charge of atom $i$ and $R_i$ is its position. The off-diagonal entries encode the pairwise Coulomb repulsion between nuclei, providing a more informative signal for the neural network than raw Cartesian coordinates.</p>
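<p>As a minimal sketch of the formula above (not the project&rsquo;s actual implementation), the matrix can be computed with the standard library alone. The function name and the plain-list inputs are illustrative assumptions; distances are in whatever unit the coordinates use.</p>

```python
from math import dist


def coulomb_matrix(charges, coords):
    """Build the Coulomb matrix for nuclear charges Z and 3D positions R.

    Diagonal: 0.5 * Z_i ** 2.4. Off-diagonal: Z_i * Z_j / |R_i - R_j|.
    """
    n = len(charges)
    matrix = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            if i == j:
                matrix[i][j] = 0.5 * charges[i] ** 2.4
            else:
                matrix[i][j] = charges[i] * charges[j] / dist(coords[i], coords[j])
    return matrix
```

<p>Because only interatomic distances enter the off-diagonal terms, translating or rotating the coordinates leaves the matrix unchanged. Reordering the atoms does permute its rows and columns, which is one reason order-independent summaries such as sorted eigenvalues are often used downstream.</p>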
<h3 id="3-hdf5-data-storage">3. HDF5 Data Storage</h3>
<p>To handle the large volume of generated conformers, the system writes to hierarchical <strong>HDF5</strong> files. This allows for efficient, chunked I/O during training, a critical pattern for scaling to larger chemical spaces.</p>
<h2 id="usage">Usage</h2>
<p>The pipeline is executed via a CLI, taking a chemical formula as input and outputting an HDF5 dataset of 3D conformers.</p>
<h2 id="results">Results</h2>
<p>This project serves as a &ldquo;vertical slice&rdquo; of a cheminformatics workflow.</p>
<ul>
<li><strong>The Good:</strong> The separation of concerns is clean: <code>dataclasses</code> for configuration and HDF5 for storage keep the data-engineering layer tidy and extensible.</li>
<li><strong>The &ldquo;Old School&rdquo;:</strong> The model used is a simple Multi-Layer Perceptron (MLP) on flattened Coulomb Matrices. In a modern production setting (post-2020), I would replace this with an <strong>E(3)-invariant or E(3)-equivariant GNN</strong> (like SchNet or an e3nn-based model, respectively) to handle rotational symmetry natively, eliminating manual feature engineering.</li>
<li><strong>Dependency Management:</strong> The reliance on an external Java JAR (<code>MAYGEN</code>) for graph enumeration makes the environment brittle. Today, I would likely swap this for a pure Python enumerator or a containerized microservice to improve portability.</li>
</ul>
<h2 id="related-work">Related Work</h2>
<p>This data pipeline powers the analysis in my comprehensive guide on molecular representation:</p>
<ul>
<li><a href="/posts/alkane-constitutional-isomer-classification/">Coulomb Matrix Eigenvalues: Can You Hear the Shape of a Molecule?</a>: A deep dive into data generation, unsupervised clustering, and supervised classification of alkane isomers.</li>
</ul>
<p>See also:</p>
<ul>
<li><a href="/posts/molecular-descriptor-coulomb-matrix/">The Coulomb Matrix</a>: Deep dive into the physics-based featurization used here</li>
<li><a href="/notes/chemistry/molecular-representations/notations/number-of-isomeric-hydrocarbons/">The Number of Isomeric Hydrocarbons</a>: The foundational 1931 paper on alkane enumeration</li>
</ul>
]]></content:encoded></item><item><title>Modern PyTorch VAEs: A Detailed Implementation Guide</title><link>https://hunterheidenreich.com/posts/modern-variational-autoencoder-in-pytorch/</link><pubDate>Sun, 03 Mar 2024 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/modern-variational-autoencoder-in-pytorch/</guid><description>Complete PyTorch VAE tutorial: Copy-paste code, ELBO derivation, KL annealing, and stable softplus parameterization.</description><content:encoded><![CDATA[<h2 id="what-is-a-variational-autoencoder">What is a Variational Autoencoder?</h2>
<p>A Variational Autoencoder (VAE) is a type of <strong>generative model</strong>, meaning its primary purpose is to learn the underlying structure of a dataset so it can generate new, similar data.</p>
<p>Whether the data is images, raw audio clips, or 2D graphs of drug-like molecules, a VAE aims to capture the essential features that define the data distribution. Once trained, it should be able to create entirely new samples that resemble the training data without simply copying specific examples.</p>
<p>Introduced by Kingma and Welling in 2013 (<a href="/notes/machine-learning/generative-models/autoencoding-variational-bayes/">Auto-Encoding Variational Bayes</a>, <a href="https://arxiv.org/abs/1312.6114">Paper</a>), VAEs are used for:</p>
<ul>
<li><strong>Generation</strong>: Creating new data (images, music, text).</li>
<li><strong>Dimensionality Reduction</strong>: Compressing data into a much smaller, meaningful representation (a &ldquo;latent space&rdquo;).</li>
<li><strong>Imputation</strong>: Intelligently filling in missing data (e.g., denoising images).</li>
</ul>
<p>Importantly, they aim to provide a structured and continuous latent space, which allows for smooth interpolation between data points and meaningful manipulations of generated samples (think: optimization).</p>
<h2 id="tldr-the-complete-pytorch-implementation">TL;DR: The Complete PyTorch Implementation</h2>
<p>For those who just want the code, here is a complete, modern VAE implementation in PyTorch. It features <strong>softplus standard deviation parameterization</strong> for numerical stability and a <strong>custom training step</strong> that handles the ELBO loss correctly.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> torch
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> torch.nn <span style="color:#66d9ef">as</span> nn
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> torch.nn.functional <span style="color:#66d9ef">as</span> F
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> dataclasses <span style="color:#f92672">import</span> dataclass
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#a6e22e">@dataclass</span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">class</span> <span style="color:#a6e22e">VAEOutput</span>:
</span></span><span style="display:flex;"><span>    z: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    mu: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    std: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    x_recon: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    loss: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    loss_recon: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    loss_kl: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">class</span> <span style="color:#a6e22e">VAE</span>(nn<span style="color:#f92672">.</span>Module):
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">__init__</span>(self, input_dim<span style="color:#f92672">=</span><span style="color:#ae81ff">784</span>, hidden_dim<span style="color:#f92672">=</span><span style="color:#ae81ff">512</span>, latent_dim<span style="color:#f92672">=</span><span style="color:#ae81ff">16</span>):
</span></span><span style="display:flex;"><span>        super()<span style="color:#f92672">.</span><span style="color:#a6e22e">__init__</span>()
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>encoder <span style="color:#f92672">=</span> nn<span style="color:#f92672">.</span>Sequential(
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(input_dim, hidden_dim),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Tanh(),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(hidden_dim, hidden_dim),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Tanh()
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>fc_mu <span style="color:#f92672">=</span> nn<span style="color:#f92672">.</span>Linear(hidden_dim, latent_dim)
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>fc_std <span style="color:#f92672">=</span> nn<span style="color:#f92672">.</span>Linear(hidden_dim, latent_dim)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>decoder <span style="color:#f92672">=</span> nn<span style="color:#f92672">.</span>Sequential(
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(latent_dim, hidden_dim),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Tanh(),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(hidden_dim, hidden_dim),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Tanh(),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(hidden_dim, input_dim)
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">encode</span>(self, x):
</span></span><span style="display:flex;"><span>        h <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>encoder(x)
</span></span><span style="display:flex;"><span>        mu <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>fc_mu(h)
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Softplus + epsilon for stable std deviation</span>
</span></span><span style="display:flex;"><span>        std <span style="color:#f92672">=</span> F<span style="color:#f92672">.</span>softplus(self<span style="color:#f92672">.</span>fc_std(h)) <span style="color:#f92672">+</span> <span style="color:#ae81ff">1e-6</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> mu, std
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">reparameterize</span>(self, mu, std):
</span></span><span style="display:flex;"><span>        eps <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>randn_like(std)
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> mu <span style="color:#f92672">+</span> eps <span style="color:#f92672">*</span> std
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">decode</span>(self, z):
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> self<span style="color:#f92672">.</span>decoder(z)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">forward</span>(self, x, kl_weight<span style="color:#f92672">=</span><span style="color:#ae81ff">1.0</span>):
</span></span><span style="display:flex;"><span>        mu, std <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>encode(x)
</span></span><span style="display:flex;"><span>        z <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>reparameterize(mu, std)
</span></span><span style="display:flex;"><span>        x_recon <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>decode(z)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># 1. Reconstruction Loss (Binary Cross Entropy for MNIST)</span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Sum over features, mean over batch</span>
</span></span><span style="display:flex;"><span>        recon_loss <span style="color:#f92672">=</span> F<span style="color:#f92672">.</span>binary_cross_entropy_with_logits(x_recon, x, reduction<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;none&#39;</span>)<span style="color:#f92672">.</span>sum(dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>)<span style="color:#f92672">.</span>mean()
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># 2. KL Divergence</span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Analytic KL for Normal distributions</span>
</span></span><span style="display:flex;"><span>        kl_loss <span style="color:#f92672">=</span> <span style="color:#f92672">-</span><span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>sum(<span style="color:#ae81ff">1</span> <span style="color:#f92672">+</span> torch<span style="color:#f92672">.</span>log(std<span style="color:#f92672">**</span><span style="color:#ae81ff">2</span>) <span style="color:#f92672">-</span> mu<span style="color:#f92672">**</span><span style="color:#ae81ff">2</span> <span style="color:#f92672">-</span> std<span style="color:#f92672">**</span><span style="color:#ae81ff">2</span>, dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>)<span style="color:#f92672">.</span>mean()
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># 3. Total Loss (ELBO)</span>
</span></span><span style="display:flex;"><span>        loss <span style="color:#f92672">=</span> recon_loss <span style="color:#f92672">+</span> (kl_weight <span style="color:#f92672">*</span> kl_loss)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> VAEOutput(z, mu, std, x_recon, loss, recon_loss, kl_loss)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># --- Training Loop Example ---</span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">train_step</span>(model, batch, optimizer, kl_weight<span style="color:#f92672">=</span><span style="color:#ae81ff">1.0</span>):
</span></span><span style="display:flex;"><span>    model<span style="color:#f92672">.</span>train()
</span></span><span style="display:flex;"><span>    optimizer<span style="color:#f92672">.</span>zero_grad()
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Forward pass</span>
</span></span><span style="display:flex;"><span>    output <span style="color:#f92672">=</span> model(batch, kl_weight)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Backward pass</span>
</span></span><span style="display:flex;"><span>    output<span style="color:#f92672">.</span>loss<span style="color:#f92672">.</span>backward()
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Gradient clipping (recommended)</span>
</span></span><span style="display:flex;"><span>    torch<span style="color:#f92672">.</span>nn<span style="color:#f92672">.</span>utils<span style="color:#f92672">.</span>clip_grad_norm_(model<span style="color:#f92672">.</span>parameters(), max_norm<span style="color:#f92672">=</span><span style="color:#ae81ff">1.0</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    optimizer<span style="color:#f92672">.</span>step()
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> output<span style="color:#f92672">.</span>loss<span style="color:#f92672">.</span>item()
</span></span></code></pre></div><h3 id="the-core-idea-learning-to-generate">The Core Idea: Learning to Generate</h3>
<p>The VAE is built on a key assumption: our complex, high-dimensional data (like a $28 \times 28$ pixel image, $\mathbf{x}$) is actually <em>generated</em> by some simpler, low-dimensional, unobserved variable (a &ldquo;latent&rdquo; variable, $\mathbf{z}$).</p>
<blockquote>
<p><strong>A Physical Metaphor: Water Molecules and Phase Diagrams</strong></p>
<p>Consider a glass of water. At the microscopic level, you have more than $10^{24}$ $\text{H}_2\text{O}$ molecules bouncing around in an incredibly high-dimensional space. Each molecule has its own position, velocity, and interactions with its neighbors, which makes the full microscopic state computationally intractable to track directly. Yet we can describe the <em>macroscopic behavior</em> of all these molecules using just two simple variables: <strong>temperature</strong> and <strong>pressure</strong>. These two dimensions create a &ldquo;phase diagram&rdquo; that tells us whether our water will be ice, liquid, or vapor. Temperature and pressure act as &ldquo;latent variables&rdquo; that capture the essential physics governing this complex molecular dance.</p></blockquote>















<figure class="post-figure center ">
    <img src="/img/vae-tut/phase-diagram.webp"
         alt="Water phase diagram showing solid, liquid, and gas phases as functions of temperature and pressure"
         title="Water phase diagram showing solid, liquid, and gas phases as functions of temperature and pressure"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">A water phase diagram: Complex molecular behavior reduced to two simple variables (temperature and pressure). This illustrates how high-dimensional systems can often be understood through low-dimensional latent representations.</figcaption>
    
</figure>

<p>A VAE makes the same assumption: complex data (like images) emerges from simpler underlying factors. A handwritten digit might be generated by latent factors like &ldquo;pen thickness,&rdquo; &ldquo;writing angle,&rdquo; &ldquo;digit style,&rdquo; and &ldquo;size,&rdquo; a much simpler description than tracking all 784 pixel values independently.</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/hypothetical-mnist-factors.webp"
         alt="Hypothetical illustration of MNIST digits generated from latent factors like pen thickness, angle, style, and size"
         title="Hypothetical illustration of MNIST digits generated from latent factors like pen thickness, angle, style, and size"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">A hypothetical illustration showing how MNIST digits could be generated from a few latent factors like pen thickness, writing angle, digit style, and size.</figcaption>
    
</figure>

<p>The VAE learns two functions: one that maps from complex data ($\mathbf{x}$) to these descriptive factors ($\mathbf{z}$), and another that maps from these factors back to the data. It accomplishes this with two main components, typically implemented as neural networks:</p>
<p><strong>1. The Encoder (Recognition Model)</strong></p>
<p>This network takes a complex data point $\mathbf{x}$ (an image) and determines the &ldquo;knob settings&rdquo; $\mathbf{z}$ that could explain or generate it. This allows us to <em>compress</em> or <em>understand</em> the data.</p>
<p>$$q_{\phi}(\mathbf{z} | \mathbf{x})$$</p>
<p>It&rsquo;s like examining a container of molecules and summarizing their complex arrangement into key parameters like temperature and pressure.</p>
<p>Crucially, the encoder outputs the <em>parameters</em> of a probability distribution (a simple Gaussian) that describes $\mathbf{z}$.</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/encoding-diagram.webp"
         alt="Diagram mapping MNIST five to a Gaussian distribution in latent space with mean and standard deviation"
         title="Diagram mapping MNIST five to a Gaussian distribution in latent space with mean and standard deviation"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The Encoder maps an input image (e.g., an MNIST digit &lsquo;5&rsquo;) to a Gaussian distribution in latent space, characterized by a mean vector and a standard deviation vector.</figcaption>
    
</figure>

<p>For each input $\mathbf{x}$, the encoder network outputs:</p>
<ul>
<li>A vector of means, $\mathbf{\mu}$</li>
<li>A vector of standard deviations, $\mathbf{\sigma}$</li>
</ul>
<p>These parameters define our approximation $q_{\phi}(\mathbf{z} | \mathbf{x}) = \mathcal{N}(\mathbf{z} \mid \mathbf{\mu}, \mathbf{\sigma}^2\mathbf{I})$, where $\mathbf{\sigma}^2\mathbf{I}$ is shorthand for a diagonal covariance matrix with one variance per latent dimension. We then <em>sample</em> from this distribution to get the $\mathbf{z}$ that we feed to the decoder. Combined with the KL regularizer described later, this probabilistic step encourages a continuous, structured latent space: similar inputs tend to map to nearby regions, which enables smooth interpolation and generation.</p>
<p><strong>2. The Decoder (Generative Model)</strong></p>
<p>This network learns the &ldquo;generative process.&rdquo; It takes a simple latent vector $\mathbf{z}$ and reconstructs the complex data $\mathbf{x}$. This allows us to <em>generate</em> new data by feeding it a random $\mathbf{z}$ and observing what image $\mathbf{x}$ it produces.</p>
<p>$$p_{\theta}(\mathbf{x} | \mathbf{z})$$</p>
<p>The decoder reverses the encoder: it takes the simple latent representation and &ldquo;paints&rdquo; the full, complex image from it. It&rsquo;s like taking temperature and pressure values and producing a detailed arrangement of water molecules consistent with those conditions. During training, the goal is to reproduce the original input as closely as possible.</p>
<p>After training, we have two networks that can be used for a variety of purposes:</p>
<ul>
<li><strong>Generation</strong>: If the latent space is well-structured, we can sample random $\mathbf{z}$ vectors from a simple distribution (like a standard normal) and feed them into the Decoder to generate new images. This is particularly useful for searching for data points with desired properties, like in drug discovery, where we might want to generate molecules with specific characteristics.</li>
<li><strong>Compression</strong>: The Encoder can compress complex data into a low-dimensional latent space, which can be useful for visualization or as a feature extractor for other tasks.</li>
</ul>
<h3 id="the-variational-problem">The &ldquo;Variational&rdquo; Problem</h3>
<p>Calculating the <em>true</em> distribution of latent variables $p_{\theta}(\mathbf{z}|\mathbf{x})$ (the posterior) is mathematically intractable.</p>
<p>This intractability arises from Bayes&rsquo; theorem:</p>
<p>$$p_{\theta}(\mathbf{z} | \mathbf{x}) = \frac{p_{\theta}(\mathbf{x} | \mathbf{z}) p_{\theta}(\mathbf{z})}{p_{\theta}(\mathbf{x})}$$</p>
<p>Breaking down each component:</p>
<ul>
<li>$p_{\theta}(\mathbf{x} | \mathbf{z})$ is our decoder, which is straightforward to compute given our likelihood model.</li>
<li>$p_{\theta}(\mathbf{z})$ is our prior over latent variables, typically a simple distribution like a standard normal, making it easy to compute.</li>
<li>$p_{\theta}(\mathbf{x})$ is the marginal likelihood of the data. And here lies the problem. It requires integrating over all possible latent variables that could have generated $\mathbf{x}$:
$$p_{\theta}(\mathbf{x}) = \int p_{\theta}(\mathbf{x} | \mathbf{z}) p_{\theta}(\mathbf{z}) d\mathbf{z}$$
It is the normalization factor that ensures the posterior is a valid probability distribution (i.e., integrates to 1 over all $\mathbf{z}$).</li>
</ul>
<p>This integral is intractable because the likelihood is defined by a nonlinear neural network and the integral runs over the entire latent space. No closed-form solution exists in general, and numerical integration quickly becomes computationally prohibitive as the latent dimension grows.</p>
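<p>To make the marginal-likelihood integral concrete, here is a minimal sketch on a toy model where the answer happens to be known. It assumes a one-dimensional latent with prior $\mathcal{N}(0, 1)$ and likelihood $\mathcal{N}(x \mid z, 1)$, so that $p(x) = \mathcal{N}(x \mid 0, 2)$ exactly. The toy model and the helper names are illustrative assumptions and are not part of the VAE implementation in this post.</p>

```python
import math
import random


def normal_pdf(x, mean, var):
    """Density of a univariate Gaussian N(x | mean, var)."""
    return math.exp(-0.5 * (x - mean) ** 2 / var) / math.sqrt(2.0 * math.pi * var)


def marginal_likelihood_mc(x, n_samples, rng):
    """Estimate p(x) = integral of p(x|z) p(z) dz by averaging p(x|z) over prior samples."""
    total = 0.0
    for _ in range(n_samples):
        z = rng.gauss(0.0, 1.0)         # z ~ p(z) = N(0, 1)
        total += normal_pdf(x, z, 1.0)  # p(x | z) = N(x | z, 1)
    return total / n_samples


rng = random.Random(0)
x = 1.0
estimate = marginal_likelihood_mc(x, 100_000, rng)
exact = normal_pdf(x, 0.0, 2.0)  # closed form, available only for this toy model
print(f"Monte Carlo: {estimate:.4f}, exact: {exact:.4f}")
```

<p>In one dimension, sampling from the prior works well. With a neural-network decoder and a latent space of many dimensions, most prior samples explain a given $\mathbf{x}$ poorly, so this naive estimator tends to need far more samples to stay accurate. That inefficiency is one motivation for the variational approximation.</p>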
<p>This is where the &ldquo;variational&rdquo; approach provides the solution. We approximate the true posterior by learning an encoder, $q_{\phi}(\mathbf{z} | \mathbf{x})$, that serves as a variational approximation to this intractable true distribution. The VAE&rsquo;s training process optimizes this approximation to be as accurate as possible, pushing this learned distribution closer to the true posterior.</p>
<h3 id="the-vae-objective-a-balancing-act">The VAE Objective: A Balancing Act</h3>
<p>To get these two networks (parameterized by $\theta$ and $\phi$) to work together, we train them jointly with a special loss function. This objective has two parts that balance two different goals:</p>
<h4 id="1-reconstruction-loss">1. Reconstruction Loss</h4>
<p>$$E_{q_{\phi}(\mathbf{z} | \mathbf{x})}[\log p_{\theta}(\mathbf{x} | \mathbf{z})]$$</p>
<p>This term asks: &ldquo;How well can we reconstruct our original image?&rdquo; It is an expected log-likelihood, so we want it to be large; the reconstruction <em>loss</em> is its negative. The process goes:</p>
<ol>
<li>Take an input point $\mathbf{x}$.</li>
<li>Use the <strong>Encoder</strong> to get its latent representation $\mathbf{z} \sim q_{\phi}(\mathbf{z} | \mathbf{x})$.</li>
<li>Use the <strong>Decoder</strong> to generate a new image $\mathbf{x}'$ from $\mathbf{z}$, i.e., $\mathbf{x}' \sim p_{\theta}(\mathbf{x} | \mathbf{z})$.</li>
<li>Compare $\mathbf{x}$ and $\mathbf{x}'$.</li>
</ol>
<p>The reconstruction loss measures the difference between the original and the reconstructed image.</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/reconstruction-loss-graphic.webp"
         alt="Graphic illustrating the reconstruction loss between original and reconstructed images"
         title="Graphic illustrating the reconstruction loss between original and reconstructed images"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The Reconstruction Loss measures how closely the Decoder&rsquo;s output matches the original input image.</figcaption>
    
</figure>

<ul>
<li><strong>For continuous inputs</strong> (like general images), this is often Mean Squared Error (MSE), which corresponds to a Gaussian likelihood with fixed variance.</li>
<li><strong>For inputs in $[0, 1]$</strong> (like MNIST pixel intensities, which after <code>ToTensor()</code> are continuous values in $[0, 1]$), we use Binary Cross-Entropy (BCE). We treat each pixel as an independent Bernoulli variable whose target is its intensity. The decoder outputs the <em>logits</em> for each pixel, and the BCE-with-logits loss (e.g., <code>F.binary_cross_entropy_with_logits</code>) is the numerically stable way to compute the negative log-likelihood.</li>
<li><strong>More generally</strong>, you can output parameters of a desired output distribution. What if you wanted a mixture of Gaussians? The decoder could output the means, variances, and mixture weights, and you could compute the negative log-likelihood accordingly.</li>
</ul>
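<p>The BCE-with-logits computation mentioned above can be sketched for a single pixel in plain Python. This uses one standard numerically stable rewrite, $\max(\ell, 0) - \ell t + \log(1 + e^{-|\ell|})$ for logit $\ell$ and target $t$. The function names are hypothetical; in practice you would call <code>F.binary_cross_entropy_with_logits</code> rather than write this yourself.</p>

```python
import math


def bce_naive(logit, target):
    """BCE via an explicit sigmoid; breaks down for large-magnitude logits."""
    p = 1.0 / (1.0 + math.exp(-logit))
    return -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))


def bce_with_logits(logit, target):
    """Same quantity, rewritten as softplus(logit) - logit * target."""
    softplus = max(logit, 0.0) + math.log1p(math.exp(-abs(logit)))
    return softplus - logit * target


# The two forms agree for moderate logits, even with a non-binary target.
print(bce_naive(2.0, 0.7), bce_with_logits(2.0, 0.7))

# For an extreme logit the sigmoid saturates to exactly 1.0 in floating point,
# so the naive form would call log(0); the stable form stays finite.
print(bce_with_logits(1000.0, 0.0))
```

<p>The stable form never evaluates a sigmoid or a logarithm of a value near zero, which is why the decoder is left to output raw logits.</p>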
<p>This loss pushes the encoder to produce useful $\mathbf{z}$ vectors and pushes the decoder to learn how to interpret them accurately.</p>
<h4 id="2-the-kl-divergence-the-regularizer">2. The KL Divergence (The Regularizer)</h4>
<p>$$D_{KL}(q_{\phi}(\mathbf{z} | \mathbf{x}) || p_{\theta}(\mathbf{z}))$$</p>
<p>On its own, the reconstruction loss might &ldquo;cheat.&rdquo; The encoder could learn to map every image to a different, specific point in the latent space, essentially &ldquo;memorizing&rdquo; the data. While this minimizes reconstruction error, it creates a meaningless latent space that fails at generation.</p>
<p>The KL divergence term fixes this. It&rsquo;s a regularizer that forces the latent space to be organized and smooth.</p>
<p>We push the encoder&rsquo;s output, $q_{\phi}(\mathbf{z} | \mathbf{x})$, to stay close to a simple, predefined <em>prior distribution</em>, $p_{\theta}(\mathbf{z})$. In practice this prior is usually a standard normal distribution, $\mathcal{N}(\mathbf{0}, \mathbf{I})$, because it is mathematically convenient, easy to sample from, and encourages a well-behaved latent space.</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/kl-loss-graphic.webp"
         alt="Graphic illustrating the KL divergence between the encoder&#39;s output distribution and the prior distribution"
         title="Graphic illustrating the KL divergence between the encoder&#39;s output distribution and the prior distribution"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The KL Divergence measures how much the Encoder&rsquo;s output distribution diverges from the simple prior distribution.</figcaption>
    
</figure>

<p>This regularization term acts as a penalty, measuring how much the encoder&rsquo;s output distribution diverges from the simple prior. By minimizing this KL divergence, we encourage the model to:</p>
<ul>
<li><strong>Avoid overfitting</strong> by preventing the encoder from memorizing specific locations for each input</li>
<li><strong>Create meaningful clusters</strong> where similar inputs map to nearby regions in the latent space</li>
<li><strong>Maintain continuity</strong> so that points close together in latent space (like different variations of the digit &ldquo;7&rdquo;) decode into visually similar outputs</li>
</ul>
<p>This smooth, structured latent space is what enables generation: we can sample random points from our prior distribution and decode them into realistic new data.</p>
<p>Ultimately, the optimizer finds a balance between these two objectives: reconstructing the data well while keeping the latent space organized and regularized.</p>
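<p>For a diagonal Gaussian encoder and a standard normal prior, the KL term has a closed form, $\frac{1}{2}\sum_j \left(\mu_j^2 + \sigma_j^2 - 1 - \log \sigma_j^2\right)$, which is the same expression as the <code>kl_loss</code> line in the quick-start code. A minimal plain-Python sketch (the function name is hypothetical) shows how the penalty behaves:</p>

```python
import math


def kl_to_standard_normal(mu, sigma):
    """KL( N(mu, diag(sigma^2)) || N(0, I) ), summed over latent dimensions."""
    return 0.5 * sum(
        m * m + s * s - 1.0 - math.log(s * s) for m, s in zip(mu, sigma)
    )


# Matching the prior exactly costs nothing.
print(kl_to_standard_normal([0.0, 0.0], [1.0, 1.0]))  # 0.0
# Shifting one mean by 1 costs 0.5 nats.
print(kl_to_standard_normal([1.0, 0.0], [1.0, 1.0]))  # 0.5
# Variances that are too wide or too narrow are both penalized.
print(kl_to_standard_normal([0.0], [2.0]), kl_to_standard_normal([0.0], [0.5]))
```

<p>The penalty grows as the mean moves away from zero or as the standard deviation moves away from one in either direction, so an encoder that collapses each input to a near-zero-variance point pays a large KL cost.</p>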
<h3 id="the-reparameterization-trick-making-it-all-trainable">The Reparameterization Trick: Making it All Trainable</h3>
<p>We have a problem. The training process requires sampling:</p>
<ol>
<li>Encoder produces $\mathbf{\mu}$ and $\mathbf{\sigma}$.</li>
<li>We <strong>sample</strong> $\mathbf{z} \sim \mathcal{N}(\mathbf{\mu}, \mathbf{\sigma}^2\mathbf{I})$.</li>
<li>Decoder uses $\mathbf{z}$ to reconstruct $\mathbf{x}'$.</li>
<li>We calculate the loss.</li>
</ol>
<p>The &ldquo;sampling&rdquo; step is a random, non-differentiable operation. We can&rsquo;t backpropagate the reconstruction loss from the decoder <em>through</em> this random node to update the encoder&rsquo;s weights.</p>
<p>The <strong>reparameterization trick</strong> makes the sampling process differentiable. We express $\mathbf{z}$ as a deterministic function of $\mathbf{\mu}$, $\mathbf{\sigma}$, and an independent noise vector:</p>
<ol>
<li>Sample a random noise vector $\mathbf{\epsilon}$ from a simple, fixed distribution (e.g., the standard normal $\mathcal{N}(\mathbf{0}, \mathbf{I})$).</li>
<li>Compute $\mathbf{z}$ as: $\mathbf{z} = \mathbf{\mu} + \mathbf{\sigma} \odot \mathbf{\epsilon}$</li>
</ol>
<p>This simple change moves the randomness &ldquo;outside&rdquo; the network. The gradient can now flow deterministically from $\mathbf{z}$ back through the $\mathbf{\mu}$ and $\mathbf{\sigma}$ nodes to the encoder network. This is the key engineering insight that allows us to train the entire model end-to-end with standard backpropagation.</p>
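<p>As a minimal sketch of why this works, take a scalar latent and a toy objective $f(z) = z^2$. This objective is an assumption chosen because the exact answers are known: $E[z^2] = \mu^2 + \sigma^2$, so the gradients are $2\mu$ and $2\sigma$. Writing $z = \mu + \sigma \epsilon$ lets us differentiate through each sample with the chain rule, which is what autograd does automatically in the real model. The function name is hypothetical.</p>

```python
import random


def reparam_gradients(mu, sigma, n_samples, rng):
    """Estimate d/dmu and d/dsigma of E[z^2] using z = mu + sigma * eps."""
    grad_mu = 0.0
    grad_sigma = 0.0
    for _ in range(n_samples):
        eps = rng.gauss(0.0, 1.0)  # randomness lives outside the parameters
        z = mu + sigma * eps       # deterministic function of mu and sigma
        df_dz = 2.0 * z            # derivative of the toy objective f(z) = z^2
        grad_mu += df_dz * 1.0     # chain rule: dz/dmu = 1
        grad_sigma += df_dz * eps  # chain rule: dz/dsigma = eps
    return grad_mu / n_samples, grad_sigma / n_samples


rng = random.Random(0)
g_mu, g_sigma = reparam_gradients(mu=1.5, sigma=0.8, n_samples=200_000, rng=rng)
print(f"d/dmu ~ {g_mu:.3f} (exact 3.0), d/dsigma ~ {g_sigma:.3f} (exact 1.6)")
```

<p>Both estimates should land close to the analytic values, even though every forward pass involved a random draw.</p>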
<h3 id="where-does-this-objective-come-from-the-math">Where Does This Objective Come From? (The Math)</h3>
<p>This two-part loss function is derived directly from the goal of maximizing the marginal likelihood of the data, $\log p_{\theta}(\mathbf{x})$.</p>
<p>For a single data point $\mathbf{x}^{(i)}$, we can write:
$$\log p_{\theta}(\mathbf{x}^{(i)}) = D_{KL}(q_\phi(\mathbf{z} | \mathbf{x}^{(i)}) || p_{\theta}(\mathbf{z} | \mathbf{x}^{(i)})) + \mathcal{L}(\theta, \phi; \mathbf{x}^{(i)})$$</p>
<ul>
<li>The first term is the KL divergence between our encoder&rsquo;s approximation and the (intractable) true posterior. It is always non-negative, but we cannot compute it because it depends on the true posterior.</li>
<li>The second term, $\mathcal{L}$, is the Variational Lower Bound (also known as the Evidence Lower Bound, or ELBO). Since the KL term is $\ge 0$, we know that $\log p_{\theta}(\mathbf{x}^{(i)}) \ge \mathcal{L}$.</li>
</ul>
<p>By maximizing this lower bound $\mathcal{L}$, we push up the &ldquo;floor&rdquo; on the true likelihood of our data. This is a problem we can solve.</p>
<p>When we expand this $\mathcal{L}$ term, we get our famous two-part objective:</p>
<p>$$\mathcal{L}(\theta, \phi; \mathbf{x}^{(i)}) = E_{q_{\phi}(\mathbf{z} | \mathbf{x}^{(i)})}[\log p_{\theta}(\mathbf{x}^{(i)} | \mathbf{z})] - D_{KL}(q_\phi(\mathbf{z} | \mathbf{x}^{(i)}) || p_{\theta}(\mathbf{z}))$$</p>
<ul>
<li><strong>Term 1:</strong> The expected log-likelihood of reconstructing $\mathbf{x}^{(i)}$ from $\mathbf{z}$. Maximizing this is the same as minimizing the Reconstruction Loss.</li>
<li><strong>Term 2:</strong> The negative KL divergence between our encoder and the simple prior. Maximizing this is the same as minimizing the KL Divergence Loss.</li>
</ul>
<p>Thus, the VAE&rsquo;s objective balances these two critical goals: faithfully reconstructing the data while maintaining a simple, regularized latent structure that is useful for generation.</p>
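<p>For readers who want the missing steps, here is a sketch of the standard derivation behind the decomposition above, dropping the superscript $(i)$ for brevity. Because $\log p_{\theta}(\mathbf{x})$ does not depend on $\mathbf{z}$, we can take its expectation under $q_{\phi}(\mathbf{z} | \mathbf{x})$, substitute $p_{\theta}(\mathbf{x}) = p_{\theta}(\mathbf{x}, \mathbf{z}) / p_{\theta}(\mathbf{z} | \mathbf{x})$, and then multiply and divide by $q_{\phi}(\mathbf{z} | \mathbf{x})$ inside the logarithm:</p>
<p>$$\log p_{\theta}(\mathbf{x}) = E_{q_{\phi}(\mathbf{z} | \mathbf{x})}\left[\log \frac{p_{\theta}(\mathbf{x}, \mathbf{z})}{p_{\theta}(\mathbf{z} | \mathbf{x})}\right] = E_{q_{\phi}(\mathbf{z} | \mathbf{x})}\left[\log \frac{q_{\phi}(\mathbf{z} | \mathbf{x})}{p_{\theta}(\mathbf{z} | \mathbf{x})}\right] + E_{q_{\phi}(\mathbf{z} | \mathbf{x})}\left[\log \frac{p_{\theta}(\mathbf{x}, \mathbf{z})}{q_{\phi}(\mathbf{z} | \mathbf{x})}\right]$$</p>
<p>The first expectation is $D_{KL}(q_\phi(\mathbf{z} | \mathbf{x}) || p_{\theta}(\mathbf{z} | \mathbf{x}))$ and the second is $\mathcal{L}(\theta, \phi; \mathbf{x})$. Factoring the joint as $p_{\theta}(\mathbf{x}, \mathbf{z}) = p_{\theta}(\mathbf{x} | \mathbf{z}) p_{\theta}(\mathbf{z})$ then gives the two-part form:</p>
<p>$$\mathcal{L}(\theta, \phi; \mathbf{x}) = E_{q_{\phi}(\mathbf{z} | \mathbf{x})}[\log p_{\theta}(\mathbf{x} | \mathbf{z})] + E_{q_{\phi}(\mathbf{z} | \mathbf{x})}\left[\log \frac{p_{\theta}(\mathbf{z})}{q_{\phi}(\mathbf{z} | \mathbf{x})}\right] = E_{q_{\phi}(\mathbf{z} | \mathbf{x})}[\log p_{\theta}(\mathbf{x} | \mathbf{z})] - D_{KL}(q_\phi(\mathbf{z} | \mathbf{x}) || p_{\theta}(\mathbf{z}))$$</p>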
<h3 id="from-elbo-to-practical-loss">From ELBO to Practical Loss</h3>
<p>Remember, our goal is to <strong>maximize</strong> the ELBO:</p>
<p>$$\mathcal{L}(\theta, \phi; \mathbf{x}) = E_{q_{\phi}(\mathbf{z} | \mathbf{x})}[\log p_{\theta}(\mathbf{x} | \mathbf{z})] - D_{KL}(q_\phi(\mathbf{z} | \mathbf{x}) || p_{\theta}(\mathbf{z}))$$</p>
<p>Since deep learning libraries are built to <strong>minimize</strong> a loss function, we simply flip the sign and <strong>minimize the negative ELBO ($-\mathcal{L}$)</strong>.</p>
<p>This gives us our final, practical loss function:</p>
<p>$$\text{Loss} = -\mathcal{L} = -E_{q_{\phi}(\mathbf{z} | \mathbf{x})}[\log p_{\theta}(\mathbf{x} | \mathbf{z})] + D_{KL}(q_\phi(\mathbf{z} | \mathbf{x}) || p_{\theta}(\mathbf{z}))$$</p>
<p>This is the function you actually implement. Minimizing this loss achieves both of our goals:</p>
<ol>
<li>It <strong>minimizes the Reconstruction Loss</strong> (which is the same as maximizing the expected log-likelihood of the data under the decoder).</li>
<li>It <strong>minimizes the KL Divergence</strong>, forcing the encoder to match the prior.</li>
</ol>
<h2 id="modern-pytorch-vae-implementation">Modern PyTorch VAE Implementation</h2>
<p>Now that we understand the VAE architecture and objective, let&rsquo;s implement a modern VAE in PyTorch. I&rsquo;ll focus primarily on the model and loss function here, though the full code is available <a href="https://github.com/hunter-heidenreich/vae">on GitHub</a>.</p>
<p>My VAE implementation uses a configuration <code>dataclass</code>, an output <code>dataclass</code>, and a VAE class extending <code>nn.Module</code>.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#e6db74">&#34;&#34;&#34;Variational Autoencoder (VAE) model implementation.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> dataclasses <span style="color:#f92672">import</span> dataclass
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> torch
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> torch.nn <span style="color:#66d9ef">as</span> nn
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> torch.nn.functional <span style="color:#66d9ef">as</span> F
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">get_activation</span>(activation: str) <span style="color:#f92672">-&gt;</span> nn<span style="color:#f92672">.</span>Module:
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;Get activation function by name.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>    activation_lower <span style="color:#f92672">=</span> activation<span style="color:#f92672">.</span>lower()
</span></span><span style="display:flex;"><span>    ACTIVATION_MAP <span style="color:#f92672">=</span> {
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;relu&#34;</span>: nn<span style="color:#f92672">.</span>ReLU(),
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;tanh&#34;</span>: nn<span style="color:#f92672">.</span>Tanh(),
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;sigmoid&#34;</span>: nn<span style="color:#f92672">.</span>Sigmoid(),
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;leaky_relu&#34;</span>: nn<span style="color:#f92672">.</span>LeakyReLU(),
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;elu&#34;</span>: nn<span style="color:#f92672">.</span>ELU(),
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;gelu&#34;</span>: nn<span style="color:#f92672">.</span>GELU(),
</span></span><span style="display:flex;"><span>    }
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">if</span> activation_lower <span style="color:#f92672">not</span> <span style="color:#f92672">in</span> ACTIVATION_MAP:
</span></span><span style="display:flex;"><span>        supported <span style="color:#f92672">=</span> <span style="color:#e6db74">&#34;, &#34;</span><span style="color:#f92672">.</span>join(ACTIVATION_MAP<span style="color:#f92672">.</span>keys())
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">raise</span> <span style="color:#a6e22e">ValueError</span>(
</span></span><span style="display:flex;"><span>            <span style="color:#e6db74">f</span><span style="color:#e6db74">&#34;Unsupported activation &#39;</span><span style="color:#e6db74">{</span>activation<span style="color:#e6db74">}</span><span style="color:#e6db74">&#39;. Supported: </span><span style="color:#e6db74">{</span>supported<span style="color:#e6db74">}</span><span style="color:#e6db74">&#34;</span>
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> ACTIVATION_MAP[activation_lower]
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#a6e22e">@dataclass</span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">class</span> <span style="color:#a6e22e">VAEConfig</span>:
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;VAE model configuration specifying architecture and behavior.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    hidden_dim: int
</span></span><span style="display:flex;"><span>    latent_dim: int
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    input_shape: tuple[int, int, int] <span style="color:#f92672">=</span> (<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">28</span>, <span style="color:#ae81ff">28</span>)  <span style="color:#75715e"># Default: MNIST</span>
</span></span><span style="display:flex;"><span>    activation: str <span style="color:#f92672">=</span> <span style="color:#e6db74">&#34;tanh&#34;</span>  <span style="color:#75715e"># Default: tanh, what was used in the original VAE paper</span>
</span></span><span style="display:flex;"><span>    use_softplus_std: bool <span style="color:#f92672">=</span> <span style="color:#66d9ef">False</span>  <span style="color:#75715e"># Whether to use softplus for std parameterization</span>
</span></span><span style="display:flex;"><span>    n_samples: int <span style="color:#f92672">=</span> <span style="color:#ae81ff">1</span>  <span style="color:#75715e"># Number of latent samples per input during training</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#a6e22e">@dataclass</span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">class</span> <span style="color:#a6e22e">VAEOutput</span>:
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;VAE forward pass output containing all relevant tensors and optional losses.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    x_logits: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    z: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    mu: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    std: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    x_recon: torch<span style="color:#f92672">.</span>Tensor <span style="color:#f92672">|</span> <span style="color:#66d9ef">None</span> <span style="color:#f92672">=</span> <span style="color:#66d9ef">None</span>
</span></span><span style="display:flex;"><span>    loss: torch<span style="color:#f92672">.</span>Tensor <span style="color:#f92672">|</span> <span style="color:#66d9ef">None</span> <span style="color:#f92672">=</span> <span style="color:#66d9ef">None</span>
</span></span><span style="display:flex;"><span>    loss_recon: torch<span style="color:#f92672">.</span>Tensor <span style="color:#f92672">|</span> <span style="color:#66d9ef">None</span> <span style="color:#f92672">=</span> <span style="color:#66d9ef">None</span>
</span></span><span style="display:flex;"><span>    loss_kl: torch<span style="color:#f92672">.</span>Tensor <span style="color:#f92672">|</span> <span style="color:#66d9ef">None</span> <span style="color:#f92672">=</span> <span style="color:#66d9ef">None</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">class</span> <span style="color:#a6e22e">VAE</span>(nn<span style="color:#f92672">.</span>Module):
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;Variational Autoencoder with support for deterministic and probabilistic reconstruction.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    DEFAULT_EPS <span style="color:#f92672">=</span> <span style="color:#ae81ff">1e-8</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">__init__</span>(self, config: VAEConfig) <span style="color:#f92672">-&gt;</span> <span style="color:#66d9ef">None</span>:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Initialize VAE with given configuration.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">        Args:
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">            config: VAE configuration specifying architecture and behavior
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">        &#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        super()<span style="color:#f92672">.</span><span style="color:#a6e22e">__init__</span>()
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>config <span style="color:#f92672">=</span> config
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Build encoder: input -&gt; hidden -&gt; latent parameters (mu, sigma)</span>
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>encoder <span style="color:#f92672">=</span> nn<span style="color:#f92672">.</span>Sequential(
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Flatten(),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(
</span></span><span style="display:flex;"><span>                int(torch<span style="color:#f92672">.</span>prod(torch<span style="color:#f92672">.</span>tensor(config<span style="color:#f92672">.</span>input_shape))), config<span style="color:#f92672">.</span>hidden_dim
</span></span><span style="display:flex;"><span>            ),
</span></span><span style="display:flex;"><span>            get_activation(config<span style="color:#f92672">.</span>activation),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(config<span style="color:#f92672">.</span>hidden_dim, config<span style="color:#f92672">.</span>latent_dim <span style="color:#f92672">*</span> <span style="color:#ae81ff">2</span>),
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Build decoder: latent -&gt; hidden -&gt; reconstructed input</span>
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>decoder <span style="color:#f92672">=</span> nn<span style="color:#f92672">.</span>Sequential(
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(config<span style="color:#f92672">.</span>latent_dim, config<span style="color:#f92672">.</span>hidden_dim),
</span></span><span style="display:flex;"><span>            get_activation(config<span style="color:#f92672">.</span>activation),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(
</span></span><span style="display:flex;"><span>                config<span style="color:#f92672">.</span>hidden_dim, int(torch<span style="color:#f92672">.</span>prod(torch<span style="color:#f92672">.</span>tensor(config<span style="color:#f92672">.</span>input_shape)))
</span></span><span style="display:flex;"><span>            ),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Unflatten(<span style="color:#ae81ff">1</span>, config<span style="color:#f92672">.</span>input_shape),
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">encode</span>(self, x: torch<span style="color:#f92672">.</span>Tensor) <span style="color:#f92672">-&gt;</span> tuple[torch<span style="color:#f92672">.</span>Tensor, torch<span style="color:#f92672">.</span>Tensor]:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Encode input to latent distribution parameters.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        encoder_output <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>encoder(x)
</span></span><span style="display:flex;"><span>        mu, sigma <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>chunk(encoder_output, <span style="color:#ae81ff">2</span>, dim<span style="color:#f92672">=-</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> mu, sigma
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">decode</span>(self, z: torch<span style="color:#f92672">.</span>Tensor) <span style="color:#f92672">-&gt;</span> torch<span style="color:#f92672">.</span>Tensor:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Decode latent representation to reconstruction logits.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> self<span style="color:#f92672">.</span>decoder(z)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">reparameterize</span>(self, mu: torch<span style="color:#f92672">.</span>Tensor, std: torch<span style="color:#f92672">.</span>Tensor) <span style="color:#f92672">-&gt;</span> torch<span style="color:#f92672">.</span>Tensor:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Apply reparameterization trick for differentiable sampling.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        epsilon <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>randn_like(std)
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> mu <span style="color:#f92672">+</span> std <span style="color:#f92672">*</span> epsilon
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">forward</span>(
</span></span><span style="display:flex;"><span>        self,
</span></span><span style="display:flex;"><span>        x: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>        compute_loss: bool <span style="color:#f92672">=</span> <span style="color:#66d9ef">True</span>,
</span></span><span style="display:flex;"><span>        reconstruct: bool <span style="color:#f92672">=</span> <span style="color:#66d9ef">False</span>,
</span></span><span style="display:flex;"><span>        eps: float <span style="color:#f92672">=</span> DEFAULT_EPS,
</span></span><span style="display:flex;"><span>    ) <span style="color:#f92672">-&gt;</span> VAEOutput:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Forward pass through the VAE.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">        Args:
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">            x: Input tensor of shape (batch_size, *input_shape)
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">            compute_loss: Whether to compute VAE loss components
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">            reconstruct: Whether to also return sigmoid-activated reconstructions (x_recon) alongside the logits
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">            eps: Small epsilon value for numerical stability
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">        Returns:
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">            VAEOutput containing all relevant tensors and optionally computed losses
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">        &#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Prepare input for multiple sampling if needed</span>
</span></span><span style="display:flex;"><span>        x_expanded <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>_expand_for_sampling(x) <span style="color:#66d9ef">if</span> self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>n_samples <span style="color:#f92672">&gt;</span> <span style="color:#ae81ff">1</span> <span style="color:#66d9ef">else</span> x
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Encode and sample from latent space</span>
</span></span><span style="display:flex;"><span>        mu, sigma <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>encode(x)
</span></span><span style="display:flex;"><span>        std <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>_sigma_to_std(sigma, eps<span style="color:#f92672">=</span>eps)
</span></span><span style="display:flex;"><span>        mu_expanded, std_expanded <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>_expand_latent_params(mu, std)
</span></span><span style="display:flex;"><span>        z <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>reparameterize(mu_expanded, std_expanded)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Decode latent samples</span>
</span></span><span style="display:flex;"><span>        x_logits <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>decode(z)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Create output object</span>
</span></span><span style="display:flex;"><span>        output <span style="color:#f92672">=</span> VAEOutput(
</span></span><span style="display:flex;"><span>            x_logits<span style="color:#f92672">=</span>x_logits,
</span></span><span style="display:flex;"><span>            z<span style="color:#f92672">=</span>z,
</span></span><span style="display:flex;"><span>            mu<span style="color:#f92672">=</span>mu,
</span></span><span style="display:flex;"><span>            std<span style="color:#f92672">=</span>std,
</span></span><span style="display:flex;"><span>            x_recon<span style="color:#f92672">=</span>torch<span style="color:#f92672">.</span>sigmoid(x_logits) <span style="color:#66d9ef">if</span> reconstruct <span style="color:#66d9ef">else</span> <span style="color:#66d9ef">None</span>,
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Compute losses if requested</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">if</span> compute_loss:
</span></span><span style="display:flex;"><span>            loss, loss_recon, loss_kl <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>_compute_loss(
</span></span><span style="display:flex;"><span>                x_expanded, x_logits, mu, sigma, std
</span></span><span style="display:flex;"><span>            )
</span></span><span style="display:flex;"><span>            output<span style="color:#f92672">.</span>loss <span style="color:#f92672">=</span> loss
</span></span><span style="display:flex;"><span>            output<span style="color:#f92672">.</span>loss_recon <span style="color:#f92672">=</span> loss_recon
</span></span><span style="display:flex;"><span>            output<span style="color:#f92672">.</span>loss_kl <span style="color:#f92672">=</span> loss_kl
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> output
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># ==================== Helper Methods ====================</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">_sigma_to_std</span>(
</span></span><span style="display:flex;"><span>        self, sigma: torch<span style="color:#f92672">.</span>Tensor, eps: float <span style="color:#f92672">=</span> DEFAULT_EPS
</span></span><span style="display:flex;"><span>    ) <span style="color:#f92672">-&gt;</span> torch<span style="color:#f92672">.</span>Tensor:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Convert sigma parameter to standard deviation.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">if</span> self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>use_softplus_std:
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">return</span> F<span style="color:#f92672">.</span>softplus(sigma) <span style="color:#f92672">+</span> eps
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">else</span>:
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">return</span> torch<span style="color:#f92672">.</span>exp(<span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> sigma)  <span style="color:#75715e"># sigma represents log-variance</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">_expand_for_sampling</span>(self, x: torch<span style="color:#f92672">.</span>Tensor) <span style="color:#f92672">-&gt;</span> torch<span style="color:#f92672">.</span>Tensor:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Expand input tensor for multiple sampling.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        shape_dims <span style="color:#f92672">=</span> [<span style="color:#ae81ff">1</span>] <span style="color:#f92672">*</span> len(self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>input_shape)
</span></span><span style="display:flex;"><span>        x_expanded <span style="color:#f92672">=</span> x<span style="color:#f92672">.</span>unsqueeze(<span style="color:#ae81ff">1</span>)<span style="color:#f92672">.</span>repeat(<span style="color:#ae81ff">1</span>, self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>n_samples, <span style="color:#f92672">*</span>shape_dims)
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> x_expanded<span style="color:#f92672">.</span>view(<span style="color:#f92672">-</span><span style="color:#ae81ff">1</span>, <span style="color:#f92672">*</span>self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>input_shape)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">_expand_latent_params</span>(
</span></span><span style="display:flex;"><span>        self, mu: torch<span style="color:#f92672">.</span>Tensor, std: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    ) <span style="color:#f92672">-&gt;</span> tuple[torch<span style="color:#f92672">.</span>Tensor, torch<span style="color:#f92672">.</span>Tensor]:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Expand latent parameters for multiple sampling.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">if</span> self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>n_samples <span style="color:#f92672">==</span> <span style="color:#ae81ff">1</span>:
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">return</span> mu, std
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        mu_expanded <span style="color:#f92672">=</span> (
</span></span><span style="display:flex;"><span>            mu<span style="color:#f92672">.</span>unsqueeze(<span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>            <span style="color:#f92672">.</span>repeat(<span style="color:#ae81ff">1</span>, self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>n_samples, <span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>            <span style="color:#f92672">.</span>view(<span style="color:#f92672">-</span><span style="color:#ae81ff">1</span>, self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>latent_dim)
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>        std_expanded <span style="color:#f92672">=</span> (
</span></span><span style="display:flex;"><span>            std<span style="color:#f92672">.</span>unsqueeze(<span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>            <span style="color:#f92672">.</span>repeat(<span style="color:#ae81ff">1</span>, self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>n_samples, <span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>            <span style="color:#f92672">.</span>view(<span style="color:#f92672">-</span><span style="color:#ae81ff">1</span>, self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>latent_dim)
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> mu_expanded, std_expanded
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># ==================== Loss Computation ====================</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">_compute_loss</span>(
</span></span><span style="display:flex;"><span>        self,
</span></span><span style="display:flex;"><span>        x: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>        x_logits: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>        mu: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>        sigma: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>        std: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>    ) <span style="color:#f92672">-&gt;</span> tuple[torch<span style="color:#f92672">.</span>Tensor, torch<span style="color:#f92672">.</span>Tensor, torch<span style="color:#f92672">.</span>Tensor]:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Compute VAE loss components for deterministic reconstruction.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        loss_recon <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>_compute_reconstruction_loss(x, x_logits)
</span></span><span style="display:flex;"><span>        loss_kl <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>_compute_kl_loss(mu, sigma, std)
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> loss_recon <span style="color:#f92672">+</span> loss_kl, loss_recon, loss_kl
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">_compute_reconstruction_loss</span>(
</span></span><span style="display:flex;"><span>        self, x: torch<span style="color:#f92672">.</span>Tensor, x_logits: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    ) <span style="color:#f92672">-&gt;</span> torch<span style="color:#f92672">.</span>Tensor:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Compute reconstruction loss using binary cross-entropy.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> F<span style="color:#f92672">.</span>binary_cross_entropy_with_logits(
</span></span><span style="display:flex;"><span>            x_logits, x, reduction<span style="color:#f92672">=</span><span style="color:#e6db74">&#34;sum&#34;</span>
</span></span><span style="display:flex;"><span>        ) <span style="color:#f92672">/</span> x<span style="color:#f92672">.</span>size(<span style="color:#ae81ff">0</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">_compute_kl_loss</span>(
</span></span><span style="display:flex;"><span>        self,
</span></span><span style="display:flex;"><span>        mu: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>        sigma: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>        std: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>        eps: float <span style="color:#f92672">=</span> DEFAULT_EPS,
</span></span><span style="display:flex;"><span>    ) <span style="color:#f92672">-&gt;</span> torch<span style="color:#f92672">.</span>Tensor:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Compute KL divergence between latent distribution and standard normal prior.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Analytical KL: KL(N(μ,σ²) || N(0,1)) = 0.5 * Σ(μ² + σ² - 1 - log(σ²))</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">if</span> self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>use_softplus_std:
</span></span><span style="display:flex;"><span>            <span style="color:#75715e"># sigma is just the raw output, need to use std directly: σ</span>
</span></span><span style="display:flex;"><span>            kl_per_sample <span style="color:#f92672">=</span> <span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>sum(
</span></span><span style="display:flex;"><span>                mu<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">+</span> std<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">-</span> <span style="color:#ae81ff">1</span> <span style="color:#f92672">-</span> torch<span style="color:#f92672">.</span>log(std<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">+</span> eps), dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>
</span></span><span style="display:flex;"><span>            )
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">else</span>:
</span></span><span style="display:flex;"><span>            <span style="color:#75715e"># sigma represents log-variance parameterization: log(σ²)</span>
</span></span><span style="display:flex;"><span>            kl_per_sample <span style="color:#f92672">=</span> <span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>sum(mu<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">+</span> sigma<span style="color:#f92672">.</span>exp() <span style="color:#f92672">-</span> <span style="color:#ae81ff">1</span> <span style="color:#f92672">-</span> sigma, dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> kl_per_sample<span style="color:#f92672">.</span>mean()
</span></span></code></pre></div><h3 id="loss-scaling">Loss Scaling</h3>
<p>Both components of the VAE loss should be summed over their data dimensions (pixels for the reconstruction term, latent dimensions for the KL term) and then averaged over the batch. A common mistake is using the <code>reduction=&quot;mean&quot;</code> option in PyTorch loss functions, which averages over every element in the tensor, batch and data dimensions alike.</p>
<ul>
<li>The <strong>KL Divergence</strong> (<code>loss_kl</code>) is a penalization term. Each dimension of the latent space has the potential to add complexity and deviate from the prior. As you increase the latent dimensionality, you typically see the KL loss increase in magnitude. That&rsquo;s the cost of having a more expressive latent space.</li>
<li>The <strong>Reconstruction Loss</strong> (<code>loss_recon</code>) measures how well the model reconstructs the input data. Because it is summed over input dimensions, it grows with input dimensionality, so higher-dimensional data puts more weight on reconstruction relative to the KL term.</li>
</ul>
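<p>To make the KL scaling concrete, here is a minimal sketch that assumes every latent dimension shares the same posterior mean and standard deviation. The helper <code>kl_to_standard_normal</code> and the values of <code>mu</code> and <code>std</code> are hypothetical, chosen only for illustration.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import math


def kl_to_standard_normal(mu, std, latent_dim):
    """KL(N(mu, std^2 I) || N(0, I)) when every latent dimension shares mu and std."""
    per_dim = 0.5 * (mu**2 + std**2 - 1.0 - math.log(std**2))
    return latent_dim * per_dim


# Hypothetical posterior statistics, chosen only for illustration
mu, std = 0.5, 0.8

# The summed KL grows linearly with the number of latent dimensions
kl_by_dim = {d: kl_to_standard_normal(mu, std, d) for d in (2, 8, 32)}
for d, kl in kl_by_dim.items():
    print(f"latent_dim={d:2d}  KL={kl:.4f}")

# A posterior that matches the prior exactly contributes no KL at all
kl_at_prior = kl_to_standard_normal(0.0, 1.0, 32)
</code></pre></div>
<p>Under this simplifying assumption the KL term is exactly proportional to the latent dimensionality. A trained model will not have identical statistics in every dimension, but the trend is the same.</p>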
<p>In the case of MNIST, if we used <code>reduction=&quot;mean&quot;</code> for BCE, it would be averaged over all $784 \times \text{batch size}$ pixels, making it 784 times smaller than the per-sample sum and tiny compared to the KL loss. The KL term would then dominate, and the model would learn to ignore the input, potentially leading to posterior collapse.</p>
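<p>A quick way to see the size of this effect is to compute both reductions on the same tensors. The sketch below uses random stand-ins for an MNIST batch and for the decoder logits (an assumption for illustration, not real data or a trained model).</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import torch
import torch.nn.functional as F

torch.manual_seed(0)

batch_size, n_pixels = 32, 784  # 784 = 28 * 28, the MNIST image size

# Random stand-ins for a batch of images and the decoder logits
x = torch.rand(batch_size, n_pixels)
x_logits = torch.randn(batch_size, n_pixels)

# reduction="mean": averaged over every pixel in the batch
bce_mean = F.binary_cross_entropy_with_logits(x_logits, x, reduction="mean")

# Sum over pixels, then average over the batch (the convention used above)
bce_per_sample = (
    F.binary_cross_entropy_with_logits(x_logits, x, reduction="sum") / batch_size
)

# The two values differ by a factor equal to the number of pixels
ratio = (bce_per_sample / bce_mean).item()
print(f"mean: {bce_mean.item():.4f}  per-sample sum: {bce_per_sample.item():.2f}  ratio: {ratio:.1f}")
</code></pre></div>
<p>The ratio equals the number of pixels per image, so switching reductions silently rescales the reconstruction term relative to the KL term.</p>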
<p>While modern optimizers can handle a variety of scenarios and you can still learn effective models with imperfect scaling, the original VAE paper used the scaling described above, and I recommend following that convention.</p>
<h3 id="mitigating-posterior-collapse-kl-annealingwarmup">Mitigating Posterior Collapse: KL Annealing/Warmup</h3>
<p>One common issue in training VAEs, especially with powerful decoders (like RNNs or deep CNNs), is <strong>posterior collapse</strong>. This happens when the KL term dominates the loss early in training. The encoder quickly learns to output the prior distribution ($q(z|x) \approx p(z)$) to drive the KL loss to zero, so the latent code $z$ carries no information about the input. The decoder then models the data on its own (for example, autoregressively) and ignores the latent input.</p>
<p>To prevent this, we often use <strong>KL Annealing</strong> (or Warmup). We introduce a weight $\beta$ for the KL term that starts at 0 and slowly increases to 1 over the first $N$ steps or epochs.</p>
<p>$$ \mathcal{L} = \mathcal{L}_{recon} + \beta \cdot D_{KL} $$</p>
<p>This lets the model focus purely on reconstruction first (using the full latent capacity) before the regularization pressure is gradually introduced.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#75715e"># Simple Linear Annealing Scheduler</span>
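</span></span><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">get_kl_weight</span>(step, total_steps, max_val<span style="color:#f92672">=</span><span style="color:#ae81ff">1.0</span>):
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">if</span> total_steps <span style="color:#f92672">&lt;=</span> <span style="color:#ae81ff">0</span>:  <span style="color:#75715e"># No warmup requested: use the full KL weight</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> max_val
</span></span><span style="display:flex;"><span>    val <span style="color:#f92672">=</span> (step <span style="color:#f92672">/</span> total_steps) <span style="color:#f92672">*</span> max_val
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> min(max(val, <span style="color:#ae81ff">0.0</span>), max_val)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># In your training loop (model and the batch x come from your own setup):</span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">for</span> epoch <span style="color:#f92672">in</span> range(epochs):
</span></span><span style="display:flex;"><span>    beta <span style="color:#f92672">=</span> get_kl_weight(epoch, warmup_epochs)
</span></span><span style="display:flex;"><span>    output <span style="color:#f92672">=</span> model(x)  <span style="color:#75715e"># VAEOutput with loss_recon and loss_kl</span>
</span></span><span style="display:flex;"><span>    loss <span style="color:#f92672">=</span> output<span style="color:#f92672">.</span>loss_recon <span style="color:#f92672">+</span> beta <span style="color:#f92672">*</span> output<span style="color:#f92672">.</span>loss_kl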
</span></span></code></pre></div><h4 id="parameterizing-standard-deviation">Parameterizing Standard Deviation</h4>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">_sigma_to_std</span>(self, sigma: torch<span style="color:#f92672">.</span>Tensor, eps: float <span style="color:#f92672">=</span> DEFAULT_EPS) <span style="color:#f92672">-&gt;</span> torch<span style="color:#f92672">.</span>Tensor:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Convert sigma parameter to standard deviation.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">if</span> self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>bound_std <span style="color:#f92672">is</span> <span style="color:#f92672">not</span> <span style="color:#66d9ef">None</span>:
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">return</span> torch<span style="color:#f92672">.</span>sigmoid(sigma) <span style="color:#f92672">*</span> self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>bound_std <span style="color:#f92672">+</span> eps
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">elif</span> self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>use_softplus_std:
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">return</span> F<span style="color:#f92672">.</span>softplus(sigma) <span style="color:#f92672">+</span> eps
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">else</span>:
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">return</span> torch<span style="color:#f92672">.</span>exp(<span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> sigma)  <span style="color:#75715e"># sigma represents log-variance</span>
</span></span></code></pre></div><p>Parameterizing the mean of the latent distribution is straightforward since $\mu \in \mathbb{R}$. However, the standard deviation $\sigma$ must be strictly positive (as must the variance $\sigma^2$). This type of constrained optimization is challenging for neural networks.</p>
<p><strong>Log-Variance</strong>
One common approach is to have the network output the <strong>log-variance</strong> ($\log \sigma^2$). This is what the original VAE paper did. The idea is to allow the network to output any real number and treat that value as the log-variance, $s = \log \sigma^2$. We can then compute the standard deviation as $\sigma = \exp(0.5 s)$, which is always positive.</p>
<p>The KL divergence formula simplifies nicely with this parameterization:</p>
<p>$$
\text{KL}( \mathcal{N}(\mu, \sigma^2) || \mathcal{N}(0, 1) ) = \frac{1}{2} \sum_{i=1}^d (\mu_i^2 + \sigma_i^2 - 1 - \log \sigma_i^2) = \frac{1}{2} \sum_{i=1}^d (\mu_i^2 + e^{s_i} - 1 - s_i)
$$</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>sum(mu<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">+</span> s<span style="color:#f92672">.</span>exp() <span style="color:#f92672">-</span> <span style="color:#ae81ff">1</span> <span style="color:#f92672">-</span> s, dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>)
</span></span></code></pre></div><p><strong>Softplus Standard Deviation</strong>
An alternative is to have the network output $\sigma$ directly. This must be handled with care to ensure positivity. Strictly positive activations like <code>softplus</code> are required. Activations like <code>ReLU</code> can output zero, leading to numerical instability (during training) and deterministic behavior (during sampling). Additionally, adding a small epsilon value ensures numerical stability by preventing $\sigma$ from being exactly zero.</p>
<p>The KL divergence formula becomes slightly more complex:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>sum(
</span></span><span style="display:flex;"><span>    mu<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">+</span> std<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">-</span> <span style="color:#ae81ff">1</span> <span style="color:#f92672">-</span> torch<span style="color:#f92672">.</span>log(std<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">+</span> eps), dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>
</span></span><span style="display:flex;"><span>)
</span></span></code></pre></div><p><strong>Bounded Standard Deviation</strong>
Another option is to bound the standard deviation to a maximum value using a <code>sigmoid</code> transformation (or similar). This replaces mapping to $(0, \infty)$ with mapping to $(0, \text{bound})$. This helps prevent extremely high variance values that might destabilize training, while limiting the expressiveness of the latent distribution. Like with <code>softplus</code>, adding a small epsilon ensures numerical stability by preventing $\sigma$ from being exactly zero or approaching it too closely.</p>
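<p>As a sanity check, the two closed-form KL expressions above can be compared against PyTorch&rsquo;s built-in Gaussian KL. This is a minimal sketch: the batch size, latent size, random inputs, and <code>eps</code> value are illustrative assumptions, and <code>kl_reference</code> is a hypothetical helper name, not part of the tutorial&rsquo;s code.</p>

```python
import torch
import torch.nn.functional as F
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
batch, latent_dim = 4, 2  # illustrative sizes (assumption)
mu = torch.randn(batch, latent_dim, dtype=torch.float64)
raw = torch.randn(batch, latent_dim, dtype=torch.float64)  # unconstrained network output


def kl_reference(mu: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    """KL(N(mu, std^2) || N(0, 1)) summed over latent dims, via torch.distributions."""
    prior = Normal(torch.zeros_like(mu), torch.ones_like(std))
    return kl_divergence(Normal(mu, std), prior).sum(dim=1)


# Log-variance: the raw output is read as s = log(sigma^2).
s = raw
kl_logvar = 0.5 * torch.sum(mu.pow(2) + s.exp() - 1 - s, dim=1)

# Softplus: the raw output is mapped straight to sigma.
eps = 1e-8  # assumed stabilizer value
std = F.softplus(raw) + eps
kl_softplus = 0.5 * torch.sum(
    mu.pow(2) + std.pow(2) - 1 - torch.log(std.pow(2) + eps), dim=1
)

print(torch.allclose(kl_logvar, kl_reference(mu, torch.exp(0.5 * s)), atol=1e-4))
print(torch.allclose(kl_softplus, kl_reference(mu, std), atol=1e-4))
```

<p>The extra <code>eps</code> inside the logarithm makes the softplus version a very slightly biased estimate, which is why the comparison uses a tolerance.</p>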
<p><strong>Gradient Behavior</strong>
All three parameterizations can work well in practice, but they have different gradient behaviors. Think of $g(s)$ as the transformation from the raw network output $s$ to the proper domain of $\sigma$ (or $\sigma^2$): in the log-variance case, $g(s) = \exp(s)$ yields $\sigma^2$ (equivalently, $\sigma = \exp(0.5 s)$), while in the softplus case, $g(s) = \text{softplus}(s) + \epsilon$ yields $\sigma$ directly.</p>
<p>The gradient of the loss with respect to these outputs can be written using the chain rule:</p>
<p>$$
\frac{\partial \mathcal{L}}{\partial s} = \frac{\partial \mathcal{L}}{\partial \sigma} \cdot \frac{\partial g(s)}{\partial s}
$$</p>
<p>where $\frac{\partial g(s)}{\partial s}$ is the derivative of the transformation function.</p>
<p>We need to guard against two pathological cases:</p>
<ul>
<li>$\frac{\partial g(s)}{\partial s} \rightarrow 0$: This leads to vanishing gradients, making it hard for the network to learn.</li>
<li>$\frac{\partial g(s)}{\partial s} \rightarrow \infty$: This leads to exploding gradients, causing instability during training and potentially divergence.</li>
</ul>
<p>The log-variance parameterization, with its exponential transformation that is its own derivative, exhibits both issues at extreme values. If $s \rightarrow -\infty$, then $\sigma \rightarrow 0$ and the gradient vanishes. If $s \rightarrow \infty$, then $\sigma \rightarrow \infty$ and the gradient explodes. Since the interval $(0, 1)$ is mapped to $(-\infty, 0)$ in log-space, it&rsquo;s much more difficult for the network to drive $\sigma$ to small values. In practice, exploding gradients at high values have been more problematic in my experience. Gradient clipping, learning rate scheduling, and clamping the log-variance output to a maximum value can help mitigate this.</p>
<p>What about softplus? The derivative of <code>softplus</code> is the <code>sigmoid</code> function, which smoothly maps $(-\infty, \infty)$ to $(0, 1)$. Gradients are always bounded by unity, preventing explosion (barring explosion from other parts of the network). However, as $s \rightarrow -\infty$, the gradient approaches zero, leading to vanishing gradients. Adding a small epsilon keeps $\sigma$ from reaching exactly zero, which keeps the log term in the KL finite, but it does not change the derivative, so learning can still slow down in this regime.</p>
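<p>For bounded standard deviation, the derivative of the <code>sigmoid</code> function is also bounded, preventing exploding gradients. The gradient of <code>sigmoid</code> is defined in terms of itself, $\text{sig}'(x) = \text{sig}(x)(1 - \text{sig}(x))$, with a maximum value of $0.25$ at $x=0$, so the scaled derivative is at most a quarter of the bound. The cost is that this derivative vanishes as $s \rightarrow \pm\infty$, so the output can saturate at either end of its range.</p>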
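<p>To make these gradient behaviors concrete, here is a small autograd sketch that evaluates $\partial \sigma / \partial s$ for each parameterization at a few raw outputs. The names <code>std_from_raw</code> and <code>dstd_ds</code>, the <code>EPS</code> value, and the bound of 10 are assumptions chosen for illustration.</p>

```python
import torch
import torch.nn.functional as F

EPS = 1e-6    # assumed stabilizer, standing in for DEFAULT_EPS
BOUND = 10.0  # assumed upper bound on the standard deviation


def std_from_raw(s: torch.Tensor, mode: str) -> torch.Tensor:
    """Map an unconstrained network output s to a positive standard deviation."""
    if mode == "logvar":
        return torch.exp(0.5 * s)  # s is treated as log-variance
    if mode == "softplus":
        return F.softplus(s) + EPS
    if mode == "bounded":
        return torch.sigmoid(s) * BOUND + EPS
    raise ValueError(f"unknown mode: {mode}")


def dstd_ds(s_value: float, mode: str) -> float:
    """Return d(std)/ds at a single point using autograd."""
    s = torch.tensor(s_value, dtype=torch.float64, requires_grad=True)
    std_from_raw(s, mode).backward()
    return s.grad.item()


for mode in ("logvar", "softplus", "bounded"):
    grads = [dstd_ds(v, mode) for v in (-20.0, 0.0, 20.0)]
    print(mode, [f"{g:.3e}" for g in grads])
```

<p>At $s = 0$ the three derivatives are 0.5, 0.5, and 2.5. At the extremes, only the log-variance derivative grows without bound, while softplus vanishes on the negative side and the bounded sigmoid vanishes on both sides.</p>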















<figure class="post-figure center ">
    <img src="/img/vae-tut/gradient_behaviors.webp"
         alt="Graph comparing gradient behaviors of log-variance, softplus, and bounded standard deviation parameterizations"
         title="Graph comparing gradient behaviors of log-variance, softplus, and bounded standard deviation parameterizations"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Gradient behaviors of different standard deviation parameterizations: Log-Variance (exponential), Softplus, and Bounded Standard Deviation (sigmoid). Each has unique characteristics affecting training stability.</figcaption>
    
</figure>

<h2 id="experiments">Experiments</h2>
<h3 id="2d-mnist-vae-with-different-std-dev-parameterizations">2D MNIST VAE with Different Std. Dev. Parameterizations</h3>
<p>First, let&rsquo;s run an experiment that is close to what was done in the original VAE paper. We&rsquo;ll use MNIST as our dataset, a simple feedforward architecture with <code>tanh</code> activations, and the log-variance parameterization for the latent distribution.</p>
<p>Some of the differences from the original paper include:</p>
<ul>
<li>Using a hidden size of 512 (the original used 500)</li>
<li>Using the AdamW optimizer (the original used vanilla Adagrad)</li>
<li>Applying similar weight decay, although the mechanism differs because of the optimizer change</li>
<li>Focusing primarily on 2D latent spaces (for now)</li>
</ul>
<p>This results in a network with 807,700 parameters.
I train each model for 150 epochs at most and highlight the best based on the reconstruction loss on the test set.
Just for fun, I sweep across different standard deviation parameterizations and learning rate warmup strategies.</p>
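<p>The parameter count can be reproduced with a little arithmetic. This sketch assumes a single 512-unit hidden layer in both the encoder and the decoder, plus two linear heads (mean and standard-deviation parameter) on the encoder; <code>linear_params</code> is a hypothetical helper. Under those assumptions the total matches the figure quoted above.</p>

```python
def linear_params(n_in: int, n_out: int) -> int:
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out


input_dim, hidden_dim, latent_dim = 28 * 28, 512, 2

encoder = linear_params(input_dim, hidden_dim)     # 784 -> 512
heads = 2 * linear_params(hidden_dim, latent_dim)  # mean head + std-dev head
decoder = linear_params(latent_dim, hidden_dim) + linear_params(hidden_dim, input_dim)

total = encoder + heads + decoder
print(total)  # 807700
```

<p>Almost all of the parameters sit in the two 784-by-512 layers, so changing the latent size barely moves the total.</p>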
<table>
  <thead>
      <tr>
          <th>Std. Dev. Param</th>
          <th>Warmup Steps</th>
          <th>Test Recon. Loss</th>
          <th>Test KL Loss</th>
          <th>Test Total Loss</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Log-Variance</td>
          <td>0</td>
          <td>140.88</td>
          <td>6.96</td>
          <td>147.84</td>
      </tr>
      <tr>
          <td>Log-Variance</td>
          <td>600</td>
          <td>141.41</td>
          <td>6.63</td>
          <td>148.04</td>
      </tr>
      <tr>
          <td>Softplus</td>
          <td>0</td>
          <td>141.51</td>
          <td>6.56</td>
          <td>148.07</td>
      </tr>
      <tr>
          <td>Softplus</td>
          <td>600</td>
          <td>140.37</td>
          <td>6.67</td>
          <td>147.04</td>
      </tr>
      <tr>
          <td>Bounded Std. Dev. (10)</td>
          <td>0</td>
          <td>140.96</td>
          <td>6.82</td>
          <td>147.79</td>
      </tr>
      <tr>
          <td>Bounded Std. Dev. (10)</td>
          <td>600</td>
          <td>141.78</td>
          <td>6.68</td>
          <td>148.45</td>
      </tr>
  </tbody>
</table>
<p>From this summary table, all three parameterizations work well. The differences in final loss values are quite small. This could be due to the simplicity of the dataset and model architecture, further amplified by forcing the network to compress images into a very low-dimensional latent space (2D).</p>
<p>Since the softplus parameterization with learning rate warmup achieved the best reconstruction loss, let&rsquo;s visualize some of its training dynamics and results more closely.</p>
<h4 id="loss-dynamics">Loss Dynamics</h4>
<p>To understand the VAE&rsquo;s behavior, we must look at the ELBO and its two components: the Reconstruction Loss and the KL Divergence.</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-elbo_epochs.webp"
         alt="Plot showing training and testing ELBO across 150 epochs for the softplus parameterization with learning rate warmup"
         title="Plot showing training and testing ELBO across 150 epochs for the softplus parameterization with learning rate warmup"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Total ELBO: Training and testing ELBO across 150 epochs.</figcaption>
    
</figure>
















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-reconstruction_loss_epochs.webp"
         alt="Plot showing training and testing reconstruction loss across 150 epochs for the softplus parameterization with learning rate warmup"
         title="Plot showing training and testing reconstruction loss across 150 epochs for the softplus parameterization with learning rate warmup"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Reconstruction Loss: Training and testing reconstruction loss across 150 epochs.</figcaption>
    
</figure>
















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-kl_loss_epochs.webp"
         alt="Plot showing training and testing KL divergence loss across 150 epochs for the softplus parameterization with learning rate warmup"
         title="Plot showing training and testing KL divergence loss across 150 epochs for the softplus parameterization with learning rate warmup"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">KL Divergence: Training and testing KL divergence loss across 150 epochs.</figcaption>
    
</figure>

<p>These plots reveal a clear narrative:</p>
<ol>
<li><strong>Rapid Initial Learning:</strong> Performance skyrockets in the first ~15 epochs.</li>
<li><strong>Overfitting:</strong> The <strong>Reconstruction Loss</strong> (middle) flatlines for the test set while continuing to improve for training, a classic sign of memorization.</li>
<li><strong>The Balancing Act:</strong> The <strong>KL Divergence</strong> (bottom) initially rises (&ldquo;The Cost of Learning&rdquo;) as the model stretches the latent space to encode digits, then saturates.</li>
<li><strong>Equilibrium:</strong> The total <strong>ELBO</strong> (top) improves slowly, driven by the model finding the optimal trade-off between reconstruction and regularization. Notice that the test and train KL track each other closely: a sign of good regularization!</li>
</ol>
<p><strong>Visualizing the VAE Trade-Off: BCE vs. KL</strong></p>
<p>While the line plots visualize progress over time, they miss the evolving <em>relationship</em> between our two competing objectives.</p>
<p>A VAE is fundamentally a multi-objective optimization problem. We want to:</p>
<ol>
<li>Minimize Reconstruction Loss (BCE)</li>
<li>Minimize KL Divergence</li>
</ol>
<p>Combining them as the ELBO is common and effective, though it can mask some of the underlying dynamics.</p>
<p>These two goals are in direct conflict. To get perfect reconstruction (BCE = 0), the encoder would need to &ldquo;memorize&rdquo; each input, mapping it to a unique, precise point in latent space. This would cause the KL divergence to skyrocket, as these specific, &ldquo;pointy&rdquo; distributions are nothing like our smooth <code>N(0, 1)</code> prior.</p>
<p>Conversely, to get perfect KL divergence (KL = 0), the encoder must <em>always</em> output <code>N(0, 1)</code>, regardless of the input. This perfectly matches the prior. Since the latent code $\mathbf{z}$ now contains zero information about the input $\mathbf{x}$, the decoder can only learn to output the &ldquo;average&rdquo; image, resulting in terrible reconstruction.</p>
<p>The training process is a search for the best compromise.</p>
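<p>The two extremes are easy to check numerically for a single latent dimension. In this sketch the mean of 2.0 and the list of standard deviations are arbitrary illustrative values, and <code>kl_to_standard_normal</code> is a hypothetical helper that applies the closed-form Gaussian KL.</p>

```python
import math


def kl_to_standard_normal(mu: float, sigma: float) -> float:
    """KL( N(mu, sigma^2) || N(0, 1) ) for a single latent dimension."""
    return 0.5 * (mu**2 + sigma**2 - 1.0 - math.log(sigma**2))


# An encoder that ignores its input and always outputs the prior pays no KL cost.
print(kl_to_standard_normal(0.0, 1.0))  # 0.0

# Increasingly "pointy" posteriors centered at an illustrative mean of 2.0.
for sigma in (1.0, 0.1, 0.01, 0.001):
    print(sigma, round(kl_to_standard_normal(2.0, sigma), 3))
```

<p>The penalty grows like $-\log \sigma$ as the posterior narrows, so pinning every input to a precise point becomes progressively more expensive.</p>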















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-loss_scatter_epochs.webp"
         alt="Scatter plot of Test BCE vs KL Divergence, showing the training path from epoch 0 to 150"
         title="Scatter plot of Test BCE vs KL Divergence, showing the training path from epoch 0 to 150"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The training path on the Test set, plotting Reconstruction Loss (BCE) vs. KL Divergence. The model&rsquo;s journey clearly shows the trade-off between these two objectives.</figcaption>
    
</figure>

<p>This plot shows the test set&rsquo;s BCE (y-axis) vs. KL Divergence (x-axis) at every evaluation step. The color gradient from cool (blue) to warm (red) represents the training progress from Epoch 0 to 150.</p>
<p>Here&rsquo;s how to interpret this training path:</p>
<ol>
<li>
<p><strong>The Start (Green Diamond, ~Epoch 0):</strong> The model starts at the top-left.</p>
<ul>
<li><strong>High BCE (Reconstruction):</strong> The decoder is random and hasn&rsquo;t learned to reconstruct anything. Reconstruction is terrible.</li>
<li><strong>Low KL Divergence:</strong> The <em>encoder</em> is also untrained. With typical small-weight initialization, its outputs sit near zero for every input, so each $q_{\phi}(\mathbf{z} | \mathbf{x})$ is a broad distribution centered near the origin that barely depends on $\mathbf{x}$. That is already close to the prior $p_{\theta}(\mathbf{z})$, so the KL penalty is low. The model isn&rsquo;t encoding any useful information yet, so it&rsquo;s not paying a high price for it.</li>
</ul>
</li>
<li>
<p><strong>Phase 1: The Initial Plunge (Blue Path):</strong> The path moves almost <em>straight down</em>.</p>
<ul>
<li><strong>BCE Plummets:</strong> The model&rsquo;s first and easiest task is to learn to reconstruct <em>something</em>. The optimizer finds massive, easy gains by making the decoder output &ldquo;blurry digits&rdquo; to replace the initial noise.</li>
<li><strong>KL Stays Low:</strong> The model achieves this huge reconstruction win without needing to learn a very complex latent space. It&rsquo;s the &ldquo;low-hanging fruit&rdquo; of training.</li>
</ul>
</li>
<li>
<p><strong>Phase 2: The Trade-Off (The &ldquo;Elbow&rdquo;):</strong> The path stops dropping vertically and starts moving to the <em>right and down</em>.</p>
<ul>
<li><strong>&ldquo;Spending&rdquo; KL to &ldquo;Buy&rdquo; Reconstruction:</strong> This is the true VAE trade-off in action. The easy wins are gone. To make the reconstructions sharper and more accurate (lowering BCE further), the model must now learn a more complex, informative latent representation.</li>
<li>It &ldquo;stretches&rdquo; the latent distributions $q_{\phi}(\mathbf{z} | \mathbf{x})$ to encode more details about each specific digit. This &ldquo;stretching&rdquo; moves it further from the simple <code>N(0, 1)</code> prior, and the KL divergence (the &ldquo;cost&rdquo;) goes up.</li>
</ul>
</li>
<li>
<p><strong>The End Game (Red Path &amp; Star):</strong> The path settles in the bottom-right corner.</p>
<ul>
<li><strong>Finding the &ldquo;Elbow&rdquo;:</strong> The model finds an equilibrium. It has pushed the KL divergence as high as it&rsquo;s &ldquo;worth&rdquo; for the reconstruction gains it gets. Trying to get even better reconstruction (moving further down) would cost an enormous, disproportionate amount in KL divergence (moving far to the right), and the total loss would increase.</li>
<li><strong>Best Recon (Orange Star):</strong> The best reconstruction model (Epoch 118) is found right at this &ldquo;elbow,&rdquo; representing the best-found balance point on the trade-off frontier.</li>
</ul>
</li>
</ol>
<p>This single plot visualizes the entire training dynamic as a journey along the <strong>Pareto frontier</strong>: the set of optimal solutions where you can&rsquo;t improve one objective (BCE) without worsening the other (KL).</p>
<h4 id="generative-performance">Generative Performance</h4>
<p>Let&rsquo;s take a look at how well this model can decode samples.</p>
<p><strong>Reconstruction Performance</strong></p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-reconstructions.webp"
         alt="Grid of original and reconstructed MNIST images from the test set using the trained VAE model"
         title="Grid of original and reconstructed MNIST images from the test set using the trained VAE model"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Original (top row) vs. Reconstructed (bottom row) MNIST images from the test set using the trained VAE model.</figcaption>
    
</figure>

<p>Immediately, we see a couple of key points:</p>
<ul>
<li>Reconstructions are quite blurry compared to the originals. This is expected given the low capacity of the model and the extreme compression into a 2D latent space. General structure is typically preserved, while fine details are lost.</li>
<li>The network struggles with 4s and 9s, often mixing them up or producing ambiguous shapes. This is a common failure mode in MNIST models due to the similarity of these digits.</li>
</ul>
<p><strong>Sampling from the Prior</strong></p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-samples.webp"
         alt="Grid of MNIST-like images generated by sampling from the prior distribution using the trained VAE model"
         title="Grid of MNIST-like images generated by sampling from the prior distribution using the trained VAE model"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">MNIST-like images generated by sampling from the prior distribution using the trained VAE model.</figcaption>
    
</figure>

<p>If we sample from the prior <code>N(0, 1)</code> and decode those samples, we get a variety of digit-like images. Almost all digits appear in this random sampling, which suggests the prior covers the digit classes fairly well. Again, we see the standard blurriness.</p>
<p><strong>Sweeping the Latent Space</strong></p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-generation_interpolation.webp"
         alt="Grid of images generated by sweeping across the 2D latent space of the trained VAE model"
         title="Grid of images generated by sweeping across the 2D latent space of the trained VAE model"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Images generated by sweeping across the 2D latent space of the trained VAE model.</figcaption>
    
</figure>

<p>We can select two points at random (here, two zeros), embed them into our latent space, and then walk across that latent space to interpolate between the two data points.
Here, we see a walk that takes us from a zero that is askew to one that is more upright.</p>
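<p>A latent walk like this one can be sketched as a straight-line interpolation between two latent codes, decoding each intermediate point. The decoder below is an untrained stand-in so the snippet runs on its own, and the two latent codes are made-up values; in practice you would use the trained decoder and the encoder&rsquo;s posterior means for the two chosen digits.</p>

```python
import torch
from torch import nn

torch.manual_seed(0)
latent_dim = 2

# Untrained stand-in decoder (assumption); swap in the trained decoder.
decoder = nn.Sequential(
    nn.Linear(latent_dim, 512), nn.Tanh(),
    nn.Linear(512, 28 * 28), nn.Sigmoid(),
)


def interpolate_latents(z_start: torch.Tensor, z_end: torch.Tensor, steps: int) -> torch.Tensor:
    """Return `steps` evenly spaced points on the line from z_start to z_end."""
    weights = torch.linspace(0.0, 1.0, steps).unsqueeze(1)  # (steps, 1)
    return z_start + weights * (z_end - z_start)            # (steps, latent_dim)


# Hypothetical posterior means of two encoded digits.
z_a = torch.tensor([-1.5, 0.5])
z_b = torch.tensor([1.0, -0.25])

path = interpolate_latents(z_a, z_b, steps=8)
with torch.no_grad():
    frames = decoder(path).view(-1, 28, 28)  # one 28x28 image per step
print(path.shape, frames.shape)
```

<p>Linear interpolation is the simplest choice in a 2D latent space; other schemes exist, but nothing here assumes the figure above was produced in exactly this way.</p>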















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-generation_latent_sweep.webp"
         alt="2D latent sweep, varying one dimension at a time while holding the other constant"
         title="2D latent sweep, varying one dimension at a time while holding the other constant"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">2D latent sweep, varying one dimension at a time while holding the other constant.</figcaption>
    
</figure>

<p>Finally, we can also sweep each latent dimension independently to see how they affect the generated images.</p>
<ol>
<li>Sweeping <code>z_1</code> (top row), we see a 5 become an 8 and then a 9. The slant shifts from left to right as we sweep the dimension.</li>
<li>Sweeping <code>z_2</code> (bottom row), we see a 4 become a 9 and then an 8. Then it becomes a 3, a 2, some nonsense, and a 6.</li>
</ol>
<p>So clearly each latent dimension is encoding some high-level features of the digits, and we can manipulate those features by moving in latent space.</p>
<h4 id="inspecting-the-latent-space">Inspecting the Latent Space</h4>
<p>What does the actual latent space look like?</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-latent_combined.webp"
         alt="2D latent space visualization with points colored by their true digit labels"
         title="2D latent space visualization with points colored by their true digit labels"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">2D latent space visualization with points colored by their true digit labels (left) and 2D heatmap of latent space density (right).</figcaption>
    
</figure>

<p>Even without class information, the network organizes the latent space to encode digit structure effectively. It also becomes immediately apparent why the model confuses 4s and 9s so often: the region they occupy is a dense mixture of the two.</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-latent_marginals.webp"
         alt="1D histograms of each latent dimension compared to the standard normal distribution"
         title="1D histograms of each latent dimension compared to the standard normal distribution"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">1D histograms of each latent dimension compared to the standard normal distribution.</figcaption>
    
</figure>

<p>We can also look at the marginal distributions of each latent dimension to see how well they match the prior <code>N(0, 1)</code>. Here, <code>z_1</code> is closer to the prior than <code>z_2</code>. <code>z_2</code> exhibits a bimodal marginal distribution, indicating that the encoder is using this dimension to separate two distinct clusters of data.</p>
<p>We also might want to understand how the log-variance of the latent distributions behaves.</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-logvar_combined.webp"
         alt="2D latent space visualization with log-variance values and 1D histograms of log-variance for each latent dimension"
         title="2D latent space visualization with log-variance values and 1D histograms of log-variance for each latent dimension"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">2D latent space visualization with log-variance values with respect to digit class (left) and 2D heatmap of log-variance magnitude (right).</figcaption>
    
</figure>

<p>For the most part, the log-variances are similar across the latent space. Some digit classes are encoded with slightly tighter posteriors than others, though in general the difference is slight.</p>
<h3 id="beyond-2d-higher-dimensional-latent-spaces">Beyond 2D: Higher-Dimensional Latent Spaces</h3>
<p>What happens as we increase the latent dimensionality? To visualize these latent spaces we must apply dimensionality reduction, which gives only an approximate sense of how the latent space is organized.</p>
<table>
  <thead>
      <tr>
          <th>Latent Dimensionality</th>
          <th>Test Recon. Loss</th>
          <th>Test KL Loss</th>
          <th>Test Total Loss</th>
          <th>KL per Dim</th>
          <th>Active Dims (KL &gt; 0.1)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>2</td>
          <td>140.37</td>
          <td>6.67</td>
          <td>147.04</td>
          <td>3.34</td>
          <td>2</td>
      </tr>
      <tr>
          <td>4</td>
          <td>114.94</td>
          <td>10.61</td>
          <td>125.56</td>
          <td>2.65</td>
          <td>4</td>
      </tr>
      <tr>
          <td>8</td>
          <td>89.46</td>
          <td>16.84</td>
          <td>106.31</td>
          <td>2.11</td>
          <td>8</td>
      </tr>
      <tr>
          <td>16</td>
          <td>76.57</td>
          <td>23.65</td>
          <td>100.21</td>
          <td>1.48</td>
          <td>16</td>
      </tr>
      <tr>
          <td>32</td>
          <td>74.65</td>
          <td>25.59</td>
          <td>100.25</td>
          <td>0.80</td>
          <td>24</td>
      </tr>
  </tbody>
</table>
<p>As we double the dimensionality, we see a dominant trend at first:</p>
<ul>
<li>The reconstruction loss goes down</li>
<li>The KL loss goes up</li>
<li>The KL loss per dimension goes down</li>
</ul>
<p>Something odd happens when we jump from 16 to 32 latent dimensions: only 24 of the 32 dimensions stay active (KL &gt; 0.1), so the rest become degenerate and stop encoding useful information. The total loss also stops improving (100.21 at 16 dimensions vs. 100.25 at 32).
This could be an indication that we need to choose our hyperparameters a little more cautiously. Perhaps we need a different architecture. Or maybe there is an intrinsic limit to the dimensionality needed for this dataset, past which it&rsquo;s not really helpful to keep scaling the latent dimension.</p>
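<p>One way to count active dimensions is to average the per-dimension KL over a set of inputs and compare it against the 0.1 threshold from the table. The sketch below assumes a log-variance encoder output and uses synthetic tensors in place of real encoder outputs; <code>kl_per_dimension</code> and <code>count_active_dims</code> are hypothetical helper names, and the exact procedure behind the table may differ.</p>

```python
import torch


def kl_per_dimension(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """Average KL(q(z|x) || N(0, 1)) over the batch, kept separate per latent dim."""
    kl = 0.5 * (mu.pow(2) + logvar.exp() - 1.0 - logvar)  # (batch, latent_dim)
    return kl.mean(dim=0)                                  # (latent_dim,)


def count_active_dims(mu: torch.Tensor, logvar: torch.Tensor, threshold: float = 0.1) -> int:
    """Number of latent dims whose average KL exceeds the threshold."""
    return int((kl_per_dimension(mu, logvar) > threshold).sum().item())


# Synthetic encoder outputs: 3 informative dims, 1 collapsed onto the prior.
torch.manual_seed(0)
mu = torch.randn(256, 4)
logvar = torch.full((256, 4), -2.0)
mu[:, 3] = 0.0
logvar[:, 3] = 0.0

print(kl_per_dimension(mu, logvar))
print(count_active_dims(mu, logvar))  # 3
```

<p>A collapsed dimension outputs the prior regardless of the input, so its KL is zero and the decoder gains nothing from it.</p>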
<h4 id="training-dynamics">Training Dynamics</h4>















<figure class="post-figure center ">
    <img src="/img/vae-tut/loss_scatter_epochs.webp"
         alt="Scatter plot of Test BCE vs KL Divergence for different latent dimensionalities, showing training paths from epoch 0 to 150"
         title="Scatter plot of Test BCE vs KL Divergence for different latent dimensionalities, showing training paths from epoch 0 to 150"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The training paths on the Test set for different latent dimensionalities, plotting Reconstruction Loss (BCE) vs. KL Divergence. Each path shows the model&rsquo;s journey, clearly illustrating the trade-off between these two objectives.</figcaption>
    
</figure>

<p>The training dynamics show the battle between reconstruction and KL divergence for different latent dimensionalities. As we increase the latent dimensionality, the oscillation in the KL divergence becomes more pronounced. Particularly chaotic is the $D=16$ case, which struggles to find a stable equilibrium. By the time we expand to $D=32$, the KL penalty seems to overpower the ability to encode information in the latent space, leading to many inactive dimensions. The drop in KL complexity has staircase-like steps without clearly gaining reconstruction ability.</p>
<h4 id="reconstruction-and-generation">Reconstruction and Generation</h4>
<p>As we increase the latent dimensionality, the reconstruction quality improves significantly.</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/reconstructions.webp"
         alt="Grid of original and reconstructed MNIST images from the test set using trained VAE models with different latent dimensionalities"
         title="Grid of original and reconstructed MNIST images from the test set using trained VAE models with different latent dimensionalities"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Original (top row) vs. Reconstructed (bottom row) MNIST images from the test set using trained VAE models with different latent dimensionalities.</figcaption>
    
</figure>

<p>As we increase the dimensionality, we see the increase in quality we&rsquo;d expect given the reduction in BCE reconstruction loss. In the jump to 4D, we&rsquo;re able to better resolve the differences between 4s and 9s. Images become much sharper by the time we hit 16 dimensions. The differences between 16 and 32 dimensions, however, are marginal.</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/samples.webp"
         alt="Grid of MNIST-like images generated by sampling from the prior distribution using trained VAE models with different latent dimensionalities"
         title="Grid of MNIST-like images generated by sampling from the prior distribution using trained VAE models with different latent dimensionalities"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">MNIST-like images generated by sampling from the prior distribution using trained VAE models with different latent dimensionalities.</figcaption>
    
</figure>

<p>Sampling quality also improves with latent dimensionality. Images are sharper as we increase the dimensionality. However, the space seems to get sparser as we increase to the largest dimensionalities, which makes sense given the size and nature of our dataset.</p>
<h4 id="latent-space-visualizations">Latent Space Visualizations</h4>















<figure class="post-figure center ">
    <img src="/img/vae-tut/latent_combined.webp"
         alt="2D PCA projections of higher-dimensional latent spaces colored by their true digit labels"
         title="2D PCA projections of higher-dimensional latent spaces colored by their true digit labels"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">2D PCA projections of higher-dimensional latent spaces colored by their true digit labels.</figcaption>
    
</figure>

<p>The challenge with visualizing higher-dimensional latent spaces is that we must reduce their dimensionality to 2D, and a 2D PCA projection retains less of the total variance as the latent dimensionality grows. The 4D and 8D plots suggest increasingly better separation of the digit classes. However, the 16D and 32D plots only show 10-20% of the variance and give a misleading impression of overlap.</p>
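<p>To see why two principal components capture so little, consider an idealized case (an assumption, not a measurement from these models): if the encoded latents were perfectly isotropic, as the standard normal prior encourages, every direction would carry equal variance and the top two components would explain only 2/d of the total. The helper below is a hypothetical illustration of that arithmetic.</p>

```python
def explained_variance_ratio(eigenvalues, k=2):
    """Fraction of total variance captured by the k largest eigenvalues."""
    ordered = sorted(eigenvalues, reverse=True)
    return sum(ordered[:k]) / sum(ordered)


# Idealized assumption: an isotropic latent space with unit variance per dimension.
for dim in (4, 8, 16, 32):
    ratio = explained_variance_ratio([1.0] * dim, k=2)
    print(f"{dim:>2}D latent: top-2 components explain {ratio:.1%} of the variance")
```

<p>Under this assumption the 16D and 32D cases retain 2/16 and 2/32 of the variance, which is the same order of magnitude as the 10-20% seen in the plots.</p>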
<h2 id="conclusion">Conclusion</h2>
<p>In this tutorial, we&rsquo;ve journeyed from the core theory of Variational Autoencoders to a practical, modern PyTorch implementation and a series of experiments on the MNIST dataset. Our findings highlight several key takeaways for practitioners:</p>
<ol>
<li>
<p><strong>The VAE is a Balancing Act:</strong> The fundamental tension between reconstruction fidelity and latent space regularization is the core of the VAE. Our visualization of the BCE vs. KL loss trade-off clearly showed training as a search for an optimal point on this Pareto frontier, where improving one objective necessarily means sacrificing the other.</p>
</li>
<li>
<p><strong>Latent Dimensionality is a Critical Hyperparameter:</strong> Increasing the latent dimension consistently improved reconstruction quality with diminishing returns. As we saw in the jump from 16 to 32 dimensions, too much capacity can lead to &ldquo;inactive&rdquo; dimensions, where the KL penalty overpowers the model&rsquo;s ability to encode useful information. This demonstrates that choosing the right latent size is crucial for both performance and efficiency.</p>
</li>
<li>
<p><strong>VAEs Learn Meaningful Unsupervised Representations:</strong> Without any labels, our VAE successfully organized the latent space, clustering similar digits and enabling smooth interpolations. This underscores the power of VAEs for unsupervised learning, dimensionality reduction, and discovering the underlying structure in complex data.</p>
</li>
<li>
<p><strong>Implementation Details Matter:</strong> While different standard deviation parameterizations yielded similar results on this simple problem, understanding their gradient behaviors is key for tackling more complex datasets where training stability can be a major challenge. Proper loss scaling is similarly crucial to prevent one term from dominating the other and leading to issues like posterior collapse.</p>
</li>
</ol>
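<p>The &ldquo;inactive&rdquo; dimensions mentioned in the second takeaway can be looked for directly. The sketch below assumes the encoder outputs a mean and a log-variance per latent dimension (one of several possible parameterizations) and uses the standard closed-form KL divergence between a diagonal Gaussian and the unit Gaussian prior. The function names and the 0.01 threshold are hypothetical choices for illustration.</p>

```python
import math


def kl_per_dimension(mu, logvar):
    """KL(N(mu, sigma^2) || N(0, 1)) for each latent dimension of one sample."""
    return [0.5 * (m * m + math.exp(lv) - 1.0 - lv) for m, lv in zip(mu, logvar)]


def inactive_dimensions(batch_mu, batch_logvar, threshold=0.01):
    """Indices whose batch-averaged KL falls below a (hypothetical) threshold."""
    n = len(batch_mu)
    mean_kl = [0.0] * len(batch_mu[0])
    for mu, logvar in zip(batch_mu, batch_logvar):
        for i, kl in enumerate(kl_per_dimension(mu, logvar)):
            mean_kl[i] += kl / n
    return [i for i, kl in enumerate(mean_kl) if threshold > kl]


# Toy encoder outputs (made up): dimension 1 matches the prior exactly.
batch_mu = [[1.5, 0.0], [-2.0, 0.0]]
batch_logvar = [[-1.0, 0.0], [-0.5, 0.0]]
print(inactive_dimensions(batch_mu, batch_logvar))
```

<p>A dimension whose posterior equals the prior for every input contributes zero KL and carries no information about the input, which is what &ldquo;inactive&rdquo; means here.</p>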
<p>While the classic VAE produces characteristically blurry reconstructions, it remains a foundational generative model. The principles we&rsquo;ve explored here (the ELBO, the reparameterization trick, and the trade-off between reconstruction and regularization) are central to many more advanced generative models used today.</p>
<p><strong>Questions or feedback?</strong> Feel free to reach out. I&rsquo;d love to hear about your experiences with VAE experiments!</p>
]]></content:encoded></item><item><title>Sarcasm Detection with Transformers: A Cautionary Tale</title><link>https://hunterheidenreich.com/posts/sarcasm-detection-with-transformers/</link><pubDate>Sun, 25 Feb 2024 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/sarcasm-detection-with-transformers/</guid><description>Learn how dataset bias can lead to misleading results in NLP: a sarcasm detection model that learned to classify news sources.</description><content:encoded><![CDATA[<h2 id="why-sarcasm-detection-is-hard">Why Sarcasm Detection Is Hard</h2>
<p>Sarcasm detection represents one of the most challenging problems in NLP. The difficulties include:</p>
<p><strong>Context dependence</strong>: Sarcasm relies on situational knowledge and shared understanding that extends beyond the text itself.</p>
<p><strong>Subtlety</strong>: Even humans struggle with sarcastic interpretation, especially in written text without vocal cues.</p>
<p><strong>Cultural variability</strong>: Sarcastic expressions vary significantly across cultures and regions.</p>
<p><strong>Annotation disagreement</strong>: Human annotators often disagree on what constitutes sarcasm.</p>
<p>These challenges raise a fundamental question: can sarcasm detection be well-defined as a computational problem? This case study explores what happens when we try (and reveals a common pitfall in dataset construction).</p>
<h2 id="the-dataset-a-hidden-flaw">The Dataset: A Hidden Flaw</h2>
<p>I used the <a href="https://huggingface.co/datasets/raquiba/Sarcasm_News_Headline">Sarcasm News Headlines dataset</a>, which combines headlines from <a href="https://theonion.com/">The Onion</a> (satirical) and <a href="https://www.huffpost.com/">The Huffington Post</a> (traditional news). The dataset contains ~50,000 examples.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">from</span> datasets <span style="color:#f92672">import</span> load_dataset
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>dataset <span style="color:#f92672">=</span> load_dataset(<span style="color:#e6db74">&#34;raquiba/Sarcasm_News_Headline&#34;</span>)
</span></span><span style="display:flex;"><span>print(dataset[<span style="color:#e6db74">&#34;train&#34;</span>][<span style="color:#ae81ff">0</span>])
</span></span><span style="display:flex;"><span>print(dataset[<span style="color:#e6db74">&#34;train&#34;</span>][<span style="color:#ae81ff">1</span>])
</span></span></code></pre></div><div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-plaintext" data-lang="plaintext"><span style="display:flex;"><span>{&#39;headline&#39;: &#39;thirtysomething scientists unveil doomsday clock of hair loss&#39;,
</span></span><span style="display:flex;"><span> &#39;is_sarcastic&#39;: 1}
</span></span><span style="display:flex;"><span>{&#39;headline&#39;: &#39;dem rep. totally nails why congress is falling short on gender, racial equality&#39;,
</span></span><span style="display:flex;"><span> &#39;is_sarcastic&#39;: 0}
</span></span></code></pre></div><p><strong>The critical flaw</strong>: The labels are assigned by source domain instead of by judging each headline individually. Headlines from The Onion are labeled sarcastic, and HuffPost headlines are not. This creates a dangerous shortcut: a model can score well simply by learning to detect the publication source.</p>
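<p>One way to confirm this kind of flaw is to check whether the label can be recovered from the source alone. The sketch below assumes rows in the dataset&rsquo;s schema (<code>headline</code>, <code>article_link</code>, <code>is_sarcastic</code>); the rows themselves are made up and <code>labels_by_source</code> is a hypothetical helper, not part of any library.</p>

```python
from collections import defaultdict
from urllib.parse import urlparse


def labels_by_source(rows):
    """Map each publication domain to the set of labels it appears with."""
    seen = defaultdict(set)
    for row in rows:
        domain = urlparse(row["article_link"]).netloc.removeprefix("www.")
        seen[domain].add(row["is_sarcastic"])
    return dict(seen)


# Made-up rows in the dataset's schema (headline, article_link, is_sarcastic).
rows = [
    {"headline": "a", "article_link": "https://www.theonion.com/a", "is_sarcastic": 1},
    {"headline": "b", "article_link": "https://www.theonion.com/b", "is_sarcastic": 1},
    {"headline": "c", "article_link": "https://www.huffpost.com/c", "is_sarcastic": 0},
]
mapping = labels_by_source(rows)
print(mapping)
# If every source maps to exactly one label, the label is recoverable from the source alone.
print(all(len(labels) == 1 for labels in mapping.values()))
```

<p>When every domain maps to a single label, any feature that identifies the publication is enough to predict the label perfectly.</p>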
<p>After preprocessing to standardize column names:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span>dataset <span style="color:#f92672">=</span> dataset<span style="color:#f92672">.</span>map(
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">lambda</span> example: {<span style="color:#e6db74">&#34;text&#34;</span>: example[<span style="color:#e6db74">&#34;headline&#34;</span>], <span style="color:#e6db74">&#34;label&#34;</span>: example[<span style="color:#e6db74">&#34;is_sarcastic&#34;</span>]},
</span></span><span style="display:flex;"><span>    remove_columns<span style="color:#f92672">=</span>[<span style="color:#e6db74">&#34;headline&#34;</span>, <span style="color:#e6db74">&#34;article_link&#34;</span>, <span style="color:#e6db74">&#34;is_sarcastic&#34;</span>]
</span></span><span style="display:flex;"><span>)
</span></span></code></pre></div><h2 id="fine-tuning-roberta">Fine-Tuning RoBERTa</h2>
<p>I fine-tuned a pre-trained RoBERTa model using standard practices:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">from</span> transformers <span style="color:#f92672">import</span> AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>model_name <span style="color:#f92672">=</span> <span style="color:#e6db74">&#34;FacebookAI/roberta-base&#34;</span>
</span></span><span style="display:flex;"><span>tokenizer <span style="color:#f92672">=</span> AutoTokenizer<span style="color:#f92672">.</span>from_pretrained(model_name)
</span></span><span style="display:flex;"><span>model <span style="color:#f92672">=</span> AutoModelForSequenceClassification<span style="color:#f92672">.</span>from_pretrained(model_name, num_labels<span style="color:#f92672">=</span><span style="color:#ae81ff">2</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Tokenize the data</span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">tokenize_function</span>(examples):
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> tokenizer(examples[<span style="color:#e6db74">&#34;text&#34;</span>], truncation<span style="color:#f92672">=</span><span style="color:#66d9ef">True</span>, max_length<span style="color:#f92672">=</span><span style="color:#ae81ff">512</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>tokenized_datasets <span style="color:#f92672">=</span> dataset<span style="color:#f92672">.</span>map(tokenize_function, batched<span style="color:#f92672">=</span><span style="color:#66d9ef">True</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Training configuration</span>
</span></span><span style="display:flex;"><span>training_args <span style="color:#f92672">=</span> TrainingArguments(
</span></span><span style="display:flex;"><span>    output_dir<span style="color:#f92672">=</span><span style="color:#e6db74">&#34;./results&#34;</span>,
</span></span><span style="display:flex;"><span>    num_train_epochs<span style="color:#f92672">=</span><span style="color:#ae81ff">5</span>,
</span></span><span style="display:flex;"><span>    per_device_train_batch_size<span style="color:#f92672">=</span><span style="color:#ae81ff">32</span>,
</span></span><span style="display:flex;"><span>    evaluation_strategy<span style="color:#f92672">=</span><span style="color:#e6db74">&#34;epoch&#34;</span>,
</span></span><span style="display:flex;"><span>    save_strategy<span style="color:#f92672">=</span><span style="color:#e6db74">&#34;epoch&#34;</span>,
</span></span><span style="display:flex;"><span>    load_best_model_at_end<span style="color:#f92672">=</span><span style="color:#66d9ef">True</span>,
</span></span><span style="display:flex;"><span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>trainer <span style="color:#f92672">=</span> Trainer(
</span></span><span style="display:flex;"><span>    model<span style="color:#f92672">=</span>model,
</span></span><span style="display:flex;"><span>    args<span style="color:#f92672">=</span>training_args,
</span></span><span style="display:flex;"><span>    train_dataset<span style="color:#f92672">=</span>tokenized_datasets[<span style="color:#e6db74">&#34;train&#34;</span>],
</span></span><span style="display:flex;"><span>    eval_dataset<span style="color:#f92672">=</span>tokenized_datasets[<span style="color:#e6db74">&#34;test&#34;</span>],
</span></span><span style="display:flex;"><span>    tokenizer<span style="color:#f92672">=</span>tokenizer,
</span></span><span style="display:flex;"><span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>trainer<span style="color:#f92672">.</span>train()
</span></span></code></pre></div><h2 id="results-too-good-to-be-true">Results: Too Good to Be True</h2>
<p>The model achieved high accuracy:</p>
<table>
  <thead>
      <tr>
          <th>Epoch</th>
          <th>Test Accuracy</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>1</td>
          <td>96.3%</td>
      </tr>
      <tr>
          <td>2</td>
          <td>97.8%</td>
      </tr>
      <tr>
          <td>3</td>
          <td>99.4%</td>
      </tr>
      <tr>
          <td>4</td>
          <td>99.8%</td>
      </tr>
      <tr>
          <td>5</td>
          <td>99.8%</td>
      </tr>
  </tbody>
</table>
<p>This should immediately raise red flags. Sarcasm detection is notoriously difficult, even for humans. Such high accuracy suggests the model learned a proxy task.</p>
<p>My hypothesis: <strong>The model bypassed sarcasm detection entirely, learning only to distinguish between The Onion and HuffPost writing styles.</strong></p>
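<p>The training setup above does not show how accuracy was computed. One common option is to pass a <code>compute_metrics</code> callable to <code>Trainer</code>; the dependency-free sketch below is an assumption about how such a function could look, not necessarily what produced the table.</p>

```python
def compute_metrics(eval_pred):
    """Accuracy from (logits, labels); logits holds one row of class scores per example."""
    logits, labels = eval_pred
    predictions = [max(range(len(row)), key=lambda i: row[i]) for row in logits]
    correct = sum(int(p == y) for p, y in zip(predictions, labels))
    return {"accuracy": correct / len(predictions)}


# Toy logits for four examples (made up); the third is misclassified.
toy_logits = [[2.0, -1.0], [0.1, 0.9], [1.2, 0.3], [-0.5, 0.4]]
toy_labels = [0, 1, 1, 1]
print(compute_metrics((toy_logits, toy_labels)))
```

<p>Whatever the metric code looks like, it measures agreement with the dataset&rsquo;s labels, so it cannot distinguish &ldquo;detects sarcasm&rdquo; from &ldquo;detects the source&rdquo; when the two coincide.</p>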
<h2 id="interacting-with-the-model">Interacting with the Model</h2>
<p>Let&rsquo;s test our hypothesis by interacting with the model.</p>
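<p>First, let&rsquo;s load the fine-tuned checkpoint and wrap it in a classification pipeline, reusing the imports and <code>tokenizer</code> from the training step:</p>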
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">from</span> transformers <span style="color:#f92672">import</span> pipeline
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>model <span style="color:#f92672">=</span> AutoModelForSequenceClassification<span style="color:#f92672">.</span>from_pretrained(<span style="color:#e6db74">&#39;results/2024-02-25_20-24-51/checkpoint-4475&#39;</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>clf <span style="color:#f92672">=</span> pipeline(<span style="color:#e6db74">&#39;text-classification&#39;</span>, model<span style="color:#f92672">=</span>model, tokenizer<span style="color:#f92672">=</span>tokenizer)
</span></span></code></pre></div><p>Now, let&rsquo;s test the model with some examples.</p>
<p>First, let&rsquo;s try an Onion article from this week, something I know to be sarcastic and not in the training data.
Let&rsquo;s use <a href="https://theonion.com/alabama-supreme-court-justice-invokes-veggietales-in-1851282252/">&ldquo;Alabama Supreme Court Justice Invokes &lsquo;VeggieTales&rsquo; In Ruling&rdquo;</a>:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span>clf(<span style="color:#e6db74">&#34;Alabama Supreme Court Justice Invokes ‘VeggieTales&#39; In Ruling&#34;</span>)
</span></span></code></pre></div><div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-plaintext" data-lang="plaintext"><span style="display:flex;"><span>[{&#39;label&#39;: &#39;LABEL_0&#39;, &#39;score&#39;: 0.99916672706604}]
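</span></span></code></pre></div><p>The model is extremely confident that this is not sarcastic (<code>LABEL_0</code> corresponds to <code>is_sarcastic = 0</code>, and <code>LABEL_1</code> to <code>is_sarcastic = 1</code>).</p>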
<p>Let&rsquo;s try a different Onion article, possibly even more difficult: <a href="https://theonion.com/trump-booed-frozen-burritos-and-more-this-week-in-br-1851282066/">Breaking News Trump Booed, Frozen Burritos, And More: This Week In Breaking News February 24, 2024</a>:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span>clf(<span style="color:#e6db74">&#34;Breaking News Trump Booed, Frozen Burritos, And More: This Week In Breaking News February 24, 2024&#34;</span>)
</span></span></code></pre></div><div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-plaintext" data-lang="plaintext"><span style="display:flex;"><span>[{&#39;label&#39;: &#39;LABEL_0&#39;, &#39;score&#39;: 0.9993497729301453}]
</span></span></code></pre></div><p>Again, the model is very confident that this is not sarcastic. One possible explanation is temporal drift: a model trained on older headlines may simply fail to capture The Onion&rsquo;s style in 2024.</p>
<p>Let&rsquo;s try one more Onion article, still recent but an easier target: <a href="https://theonion.com/mom-only-likes-the-other-outback-steakhouse-1851265335/">Mom Only Likes The Other Outback Steakhouse</a>:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span>clf(<span style="color:#e6db74">&#34;Mom Only Likes The Other Outback Steakhouse&#34;</span>)
</span></span></code></pre></div><div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-plaintext" data-lang="plaintext"><span style="display:flex;"><span>[{&#39;label&#39;: &#39;LABEL_1&#39;, &#39;score&#39;: 0.9997231364250183}]
</span></span></code></pre></div><p>Finally, a correct prediction! The model is confident that this is sarcastic.
Our model detects only very specific types of sarcasm. It fails to generalize to new, unseen data within the same domain.</p>
<p>Let&rsquo;s also try some headlines from the Huffington Post, which the model should predict as not sarcastic.
Here are the five most recent headlines at the time of writing:</p>
<ul>
<li><a href="https://www.huffpost.com/entry/donald-trump-south-carolina-nikki-haley_n_65db61f5e4b0e4346d52bed8">Donald Trump Won South Carolina - But There&rsquo;s 1 Big Caveat</a></li>
<li><a href="https://www.huffpost.com/entry/israeli-embassy-washington-man-set-fire_n_65db9364e4b0e4346d52ce3d">Man Sets Himself On Fire In Front Of Israeli Embassy In Washington</a></li>
<li><a href="https://www.huffpost.com/entry/bc-ml-israel-palestinians-temporary-truce-cease-fire_n_65db2e9ae4b0189a6a7e32ea">Israeli Media Report Progress On Reaching A Temporary Truce In Gaza And A Hostage-Prisoner Exchange</a></li>
<li><a href="https://www.huffpost.com/entry/george-latimer-race-comments-democratic-primary_n_65d8fac3e4b0cc1f2f7bafd8">A White Liberal Is Trying To Oust A Progressive Black Congressman. His Comments Could Make That Job Harder.</a></li>
<li><a href="https://www.huffpost.com/entry/mongolia-climate-change-extreme-weather_n_65d90294e4b0cc1f2f7bb527">Climate Change-Fueled Winter Extremes Put 90% Of This Country At &lsquo;High Risk&rsquo;</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span>clf([
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;Donald Trump Won South Carolina - But There&#39;s 1 Big Caveat&#34;</span>,
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;Man Sets Himself On Fire In Front Of Israeli Embassy In Washington&#34;</span>,
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;Israeli Media Report Progress On Reaching A Temporary Truce In Gaza And A Hostage-Prisoner Exchange&#34;</span>,
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;A White Liberal Is Trying To Oust A Progressive Black Congressman. His Comments Could Make That Job Harder.&#34;</span>,
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;Climate Change-Fueled Winter Extremes Put 90% Of This Country At &#39;High Risk&#39;&#34;</span>
</span></span><span style="display:flex;"><span>])
</span></span></code></pre></div><div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-plaintext" data-lang="plaintext"><span style="display:flex;"><span>[{&#39;label&#39;: &#39;LABEL_0&#39;, &#39;score&#39;: 0.9993808269500732},
</span></span><span style="display:flex;"><span> {&#39;label&#39;: &#39;LABEL_0&#39;, &#39;score&#39;: 0.9993786811828613},
</span></span><span style="display:flex;"><span> {&#39;label&#39;: &#39;LABEL_0&#39;, &#39;score&#39;: 0.9985186457633972},
</span></span><span style="display:flex;"><span> {&#39;label&#39;: &#39;LABEL_0&#39;, &#39;score&#39;: 0.9993883371353149},
</span></span><span style="display:flex;"><span> {&#39;label&#39;: &#39;LABEL_0&#39;, &#39;score&#39;: 0.9993487000465393}]
</span></span></code></pre></div><p>The model is extremely confident that these are not sarcastic.</p>
<p>The model detects sarcasm only in limited cases and fails to generalize to new, unseen headlines from the same sources. This is a common problem in machine learning: training a model that performs well on a specific dataset is straightforward, while training one that generalizes to new data remains a significant challenge.
Furthermore, our sarcasm detection project effectively produced a domain classifier. For fuzzier concepts like sarcasm, it&rsquo;s important to be clear about what we&rsquo;re actually detecting, and to collect data at the scale needed to capture the full range of the concept.</p>
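<p>The probes above can be tallied in a few lines. The predicted labels are copied from the pipeline outputs shown earlier; the mapping of <code>LABEL_1</code> to sarcastic is an assumption based on the dataset&rsquo;s <code>is_sarcastic</code> column, and <code>probe_accuracy</code> is a hypothetical helper.</p>

```python
# Predicted labels copied from the pipeline outputs shown above.
onion_preds = ["LABEL_0", "LABEL_0", "LABEL_1"]
huffpost_preds = ["LABEL_0"] * 5

# Assumed mapping: LABEL_1 is is_sarcastic = 1, LABEL_0 is is_sarcastic = 0.
EXPECTED = {"onion": "LABEL_1", "huffpost": "LABEL_0"}


def probe_accuracy(predictions, expected_label):
    """Fraction of probe headlines assigned the expected label."""
    return sum(p == expected_label for p in predictions) / len(predictions)


onion_acc = probe_accuracy(onion_preds, EXPECTED["onion"])
huffpost_acc = probe_accuracy(huffpost_preds, EXPECTED["huffpost"])
print(f"Onion probes correct: {onion_acc:.0%}")
print(f"HuffPost probes correct: {huffpost_acc:.0%}")
```

<p>Eight headlines are far too few for a real estimate, but the gap between the two sources is hard to reconcile with the near-perfect test accuracy.</p>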
<h2 id="key-takeaways">Key Takeaways</h2>
<p>This case study reveals a fundamental problem in ML: <strong>high test accuracy only certifies performance on the distribution the dataset was drawn from</strong>. Here&rsquo;s what actually happened:</p>
<ol>
<li><strong>Dataset bias</strong>: Using publication source as a proxy for sarcasm created a shortcut for the model</li>
<li><strong>Domain classification</strong>: The model appears to have learned to distinguish writing styles instead of sarcasm</li>
<li><strong>Poor generalization</strong>: Two of the three new Onion headlines were misclassified with high confidence</li>
</ol>
<p>This is a common pitfall when building datasets for subjective concepts. The lesson: high accuracy must be accompanied by validation of the model&rsquo;s actual learned behavior.</p>
<p>For better sarcasm detection, we&rsquo;d need:</p>
<ul>
<li>Diverse sources beyond two publications</li>
<li>Human annotation across multiple contexts</li>
<li>Careful evaluation on out-of-domain examples</li>
</ul>
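<p>The last point, out-of-domain evaluation, can be sketched as a leave-one-source-out split in which no publication appears in both training and test data. The corpus below is a hypothetical placeholder with made-up source names; only the splitting logic is the point.</p>

```python
def leave_one_source_out(examples):
    """Yield (held_out_source, train, test) with no source shared between train and test."""
    sources = sorted({ex["source"] for ex in examples})
    for held_out in sources:
        train = [ex for ex in examples if ex["source"] != held_out]
        test = [ex for ex in examples if ex["source"] == held_out]
        yield held_out, train, test


# Hypothetical multi-source corpus; source names and labels are placeholders.
examples = [
    {"text": "t1", "label": 1, "source": "satire_site_a"},
    {"text": "t2", "label": 0, "source": "news_site_a"},
    {"text": "t3", "label": 1, "source": "satire_site_b"},
    {"text": "t4", "label": 0, "source": "news_site_b"},
]
for held_out, train, test in leave_one_source_out(examples):
    print(held_out, len(train), len(test))
```

<p>A model that only learned publication style should score poorly on a held-out source, which makes this kind of split a more honest test than a random one.</p>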
<p>Instructive failures in ML projects provide valuable lessons about our assumptions and the limitations of our approaches.</p>
]]></content:encoded></item><item><title>Hearing Molecular Shape via Coulomb Matrix Eigenvalues</title><link>https://hunterheidenreich.com/posts/alkane-constitutional-isomer-classification/</link><pubDate>Sat, 24 Feb 2024 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/alkane-constitutional-isomer-classification/</guid><description>Explore molecular shape recognition using Coulomb matrix eigenvalues. An analysis of alkane isomers, clustering limits, and supervised classification.</description><content:encoded><![CDATA[<h2 id="introduction">Introduction</h2>
<p>Can you determine a molecule&rsquo;s shape from mathematical fingerprints alone? This question drives some of the most fundamental challenges in computational chemistry and machine learning. In the broader ML context, this is the classic search for the right <em>inductive bias</em> or <em>invariant representation</em>. Whether we are processing messy documents, natural language, or molecular dynamics, finding a representation that captures essential structure while ignoring irrelevant variations is critical.</p>
<p>I recently encountered a paper with an intriguing title: <a href="https://doi.org/10.1021/acs.jcim.0c00631">&ldquo;Can One Hear the Shape of a Molecule (from its Coulomb Matrix Eigenvalues)?&rdquo;</a> The title references Mark Kac&rsquo;s famous mathematical question <a href="https://www.math.ucdavis.edu/~hunter/m207b/kac.pdf">&ldquo;Can One Hear the Shape of a Drum?&rdquo;</a> exploring whether a drum&rsquo;s shape dictates its sound frequencies.</p>
<p>The molecular version asks: can we determine a molecule&rsquo;s structure from the eigenvalues of its <a href="/posts/molecular-descriptor-coulomb-matrix/">Coulomb matrix</a>?</p>
<p>Molecular representations are the foundation of machine learning in chemistry. If eigenvalues can capture structural information, they become powerful features for property prediction. Successfully separating simple structural differences is a prerequisite for handling more complex molecules.</p>
<p>The original authors tested this hypothesis using alkane constitutional isomers (molecules with identical formulas but different structural arrangements). I decided to replicate and extend their work to better understand both the methods and their limitations.</p>
<p>In this post, we will explore molecular representation through eigenvalue analysis, covering data generation, unsupervised clustering approaches, and supervised classification methods. I&rsquo;ll also explore log-transformed Coulomb matrices, which can reveal structural details that standard matrices miss.</p>
<h2 id="why-alkanes-make-ideal-test-cases">Why Alkanes Make Ideal Test Cases</h2>
<p><a href="https://en.wikipedia.org/wiki/Alkane">Alkanes</a> are the simplest organic molecules: carbon and hydrogen connected by single bonds with the general formula $C_{n}H_{2n+2}$.</p>
<p>What makes them perfect for testing molecular representations is their constitutional isomers: molecules with identical formulas but different structural arrangements. For small alkanes ($n \leq 3$), atoms can connect in only one way. Starting with butane ($n = 4$), multiple arrangements become possible:</p>















<figure class="post-figure center ">
    <img src="/img/alkane-constitutional-isomers/4-Butane-3D-balls.webp"
         alt="Butane as a ball-and-stick model."
         title="Butane as a ball-and-stick model."
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Butane: a linear chain</figcaption>
    
</figure>
















<figure class="post-figure center ">
    <img src="/img/alkane-constitutional-isomers/4-Isobutane-3D-balls.webp"
         alt="Isobutane as a ball-and-stick model."
         title="Isobutane as a ball-and-stick model."
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Isobutane: a branched structure</figcaption>
    
</figure>

<p>The number of isomers grows rapidly with molecular size. By undecane ($n = 11$), there are 159 different structural arrangements:</p>
<table>
  <thead>
      <tr>
          <th>Alkane</th>
          <th>n</th>
          <th>Isomers</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Butane</td>
          <td>4</td>
          <td>2</td>
      </tr>
      <tr>
          <td>Pentane</td>
          <td>5</td>
          <td>3</td>
      </tr>
      <tr>
          <td>Hexane</td>
          <td>6</td>
          <td>5</td>
      </tr>
      <tr>
          <td>Heptane</td>
          <td>7</td>
          <td>9</td>
      </tr>
      <tr>
          <td>Octane</td>
          <td>8</td>
          <td>18</td>
      </tr>
      <tr>
          <td>Nonane</td>
          <td>9</td>
          <td>35</td>
      </tr>
      <tr>
          <td>Decane</td>
          <td>10</td>
          <td>75</td>
      </tr>
      <tr>
          <td>Undecane</td>
          <td>11</td>
          <td>159</td>
      </tr>
  </tbody>
</table>
<p>This creates a natural classification challenge: can Coulomb matrix eigenvalues distinguish these structural differences? Successfully separating simple alkane isomers is a prerequisite for handling more complex molecules.</p>
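<p>The table also doubles as a sanity check for whatever tool enumerates the isomers: the number of SMILES strings produced for a given formula should match the tabulated count. The helper names below are hypothetical, and the example uses the two well-known SMILES strings for n-butane and isobutane.</p>

```python
# Isomer counts from the table above, keyed by carbon count n.
EXPECTED_ISOMERS = {4: 2, 5: 3, 6: 5, 7: 9, 8: 18, 9: 35, 10: 75, 11: 159}


def count_smiles(lines):
    """Count non-empty SMILES entries, one per line."""
    return sum(1 for line in lines if line.strip())


def matches_table(n, lines):
    """True when the number of SMILES entries equals the tabulated isomer count."""
    return count_smiles(lines) == EXPECTED_ISOMERS[n]


# Butane: the SMILES strings for n-butane and isobutane.
print(matches_table(4, ["CCCC", "CC(C)C"]))
```

<p>In practice <code>lines</code> would come from reading a generated <code>.smi</code> file, one SMILES string per line.</p>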
<h2 id="computational-pipeline">Computational Pipeline</h2>
<p>The analysis requires three computational steps:</p>
<ol>
<li><strong>Generate constitutional isomers</strong> for each alkane formula</li>
<li><strong>Create multiple 3D conformations</strong> for each isomer</li>
<li><strong>Calculate Coulomb matrix eigenvalues</strong> for each conformation</li>
</ol>
<h3 id="generating-constitutional-isomers">Generating Constitutional Isomers</h3>
<p>Enumerating all possible carbon skeletons is a combinatorial problem. I used <a href="https://github.com/MehmetAzizYirik/MAYGEN">MAYGEN</a>, an open-source Java tool for generating molecular structures from chemical formulas.</p>
<p>For butane ($C_{4}H_{10}$):</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bash" data-lang="bash"><span style="display:flex;"><span>java -jar MAYGEN-1.8.jar -v -m -f C4H10 -smi -o butane_conformers.smi
</span></span></code></pre></div><p>This generates:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-plaintext" data-lang="plaintext"><span style="display:flex;"><span>CCCC
</span></span><span style="display:flex;"><span>CC(C)C
</span></span></code></pre></div><p>The first is n-butane (linear), the second is isobutane (branched). We can automate this across all alkanes:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> os
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>os<span style="color:#f92672">.</span>makedirs(<span style="color:#e6db74">&#39;isomers&#39;</span>, exist_ok<span style="color:#f92672">=</span><span style="color:#66d9ef">True</span>)
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">for</span> n <span style="color:#f92672">in</span> range(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">12</span>):
</span></span><span style="display:flex;"><span>    cmd <span style="color:#f92672">=</span> <span style="color:#e6db74">f</span><span style="color:#e6db74">&#34;java -jar MAYGEN-1.8.jar -f C</span><span style="color:#e6db74">{</span>n<span style="color:#e6db74">}</span><span style="color:#e6db74">H</span><span style="color:#e6db74">{</span><span style="color:#ae81ff">2</span><span style="color:#f92672">*</span>n <span style="color:#f92672">+</span> <span style="color:#ae81ff">2</span><span style="color:#e6db74">}</span><span style="color:#e6db74"> -smi -o isomers/C</span><span style="color:#e6db74">{</span>n<span style="color:#e6db74">}</span><span style="color:#e6db74">H</span><span style="color:#e6db74">{</span><span style="color:#ae81ff">2</span><span style="color:#f92672">*</span>n <span style="color:#f92672">+</span> <span style="color:#ae81ff">2</span><span style="color:#e6db74">}</span><span style="color:#e6db74">.smi&#34;</span>
</span></span><span style="display:flex;"><span>    os<span style="color:#f92672">.</span>system(cmd)
</span></span></code></pre></div><h3 id="generating-3d-conformations">Generating 3D Conformations</h3>
<p>For machine learning applications, we need multiple 3D structures of each isomer to capture conformational flexibility. I used <a href="https://github.com/rdkit/rdkit">RDKit</a>&rsquo;s ETKDG method, which <a href="https://doi.org/10.1021/acs.jcim.7b00505">remains competitive</a> with commercial alternatives:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> rdkit.Chem
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> rdkit.Chem <span style="color:#f92672">import</span> AllChem
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">smiles_str_to_rdkit_mol</span>(smiles_str: str) <span style="color:#f92672">-&gt;</span> rdkit<span style="color:#f92672">.</span>Chem<span style="color:#f92672">.</span>Mol:
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;Convert a SMILES string to an RDKit mol object.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    Args:
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    - smiles_str (str): A SMILES string representing a molecule.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    Returns:
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    - mol (rdkit.Chem.Mol): An RDKit mol object representing the molecule.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    &#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Convert SMILES string to RDKit mol object</span>
</span></span><span style="display:flex;"><span>    mol <span style="color:#f92672">=</span> rdkit<span style="color:#f92672">.</span>Chem<span style="color:#f92672">.</span>MolFromSmiles(smiles_str)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Add hydrogens to the molecule</span>
</span></span><span style="display:flex;"><span>    mol <span style="color:#f92672">=</span> rdkit<span style="color:#f92672">.</span>Chem<span style="color:#f92672">.</span>AddHs(mol)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Assign 3D coordinates with ETKDG (EmbedMolecule returns -1 if embedding fails)</span>
</span></span><span style="display:flex;"><span>    AllChem<span style="color:#f92672">.</span>EmbedMolecule(mol)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> mol
</span></span></code></pre></div><h3 id="computing-coulomb-matrix-eigenvalues">Computing Coulomb Matrix Eigenvalues</h3>
<p>The Coulomb matrix encodes 3D structure in a rotation- and translation-invariant way, because it depends only on nuclear charges and interatomic distances. Its eigenvalues keep that invariance and, unlike the matrix itself, do not depend on the order in which the atoms are listed.</p>
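<p>To make the invariance concrete, here is a minimal NumPy sketch of the standard Coulomb matrix definition (diagonal $0.5 Z_i^{2.4}$, off-diagonal $Z_i Z_j / R_{ij}$). It is not DScribe&rsquo;s implementation: the helper names <code>coulomb_matrix</code> and <code>sorted_eigenvalues</code> and the rough water-like geometry are assumptions made for illustration. It checks that the spectrum is unchanged when the molecule is rotated, translated, and its atoms are reordered.</p>

```python
import numpy as np


def coulomb_matrix(numbers, positions):
    """Coulomb matrix: 0.5 * Z_i**2.4 on the diagonal, Z_i * Z_j / R_ij elsewhere."""
    z = np.asarray(numbers, dtype=float)
    r = np.asarray(positions, dtype=float)
    # Pairwise distances R_ij between all atoms
    dist = np.linalg.norm(r[:, None, :] - r[None, :, :], axis=-1)
    np.fill_diagonal(dist, 1.0)  # placeholder so the division below is defined
    matrix = np.outer(z, z) / dist
    np.fill_diagonal(matrix, 0.5 * z ** 2.4)
    return matrix


def sorted_eigenvalues(matrix):
    """Eigenvalues of a symmetric matrix, largest first."""
    return np.linalg.eigvalsh(matrix)[::-1]


# Hypothetical water-like geometry (O, H, H); coordinates are illustrative only
numbers = np.array([8, 1, 1])
positions = np.array([[0.00, 0.00, 0.0],
                      [0.96, 0.00, 0.0],
                      [-0.24, 0.93, 0.0]])

# Rotate about the z axis, translate, and reorder the atoms
theta = 0.7
rotation = np.array([[np.cos(theta), -np.sin(theta), 0.0],
                     [np.sin(theta), np.cos(theta), 0.0],
                     [0.0, 0.0, 1.0]])
order = [2, 0, 1]
moved_positions = (positions @ rotation.T + np.array([1.0, -2.0, 3.0]))[order]
moved_numbers = numbers[order]

eig_ref = sorted_eigenvalues(coulomb_matrix(numbers, positions))
eig_moved = sorted_eigenvalues(coulomb_matrix(moved_numbers, moved_positions))
print(np.allclose(eig_ref, eig_moved))
```

<p>Sorting in descending order puts the largest eigenvalue first, which matters whenever column 0 of a spectrum array is read as the largest eigenvalue. <code>np.linalg.eigvals</code> on its own makes no ordering guarantee.</p>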
<p>First, I wrote a helper function to convert RDKit molecules into <a href="https://ase-lib.org/">ASE</a> <code>Atoms</code> objects that <a href="https://singroup.github.io/dscribe/">DScribe</a> can process:</p>
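<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> ase
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> rdkit.Chem
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">rdkit_mol_to_ase_atoms</span>(rdkit_mol: rdkit<span style="color:#f92672">.</span>Chem<span style="color:#f92672">.</span>Mol) <span style="color:#f92672">-&gt;</span> ase<span style="color:#f92672">.</span>Atoms: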
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;Convert an RDKit molecule to an ASE Atoms object.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    Args:
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">        rdkit_mol: RDKit molecule object.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    Returns:
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">        ASE Atoms object.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    &#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>    ase_atoms <span style="color:#f92672">=</span> ase<span style="color:#f92672">.</span>Atoms(
</span></span><span style="display:flex;"><span>        numbers<span style="color:#f92672">=</span>[
</span></span><span style="display:flex;"><span>            atom<span style="color:#f92672">.</span>GetAtomicNum() <span style="color:#66d9ef">for</span> atom <span style="color:#f92672">in</span> rdkit_mol<span style="color:#f92672">.</span>GetAtoms()
</span></span><span style="display:flex;"><span>        ],
</span></span><span style="display:flex;"><span>        positions<span style="color:#f92672">=</span>rdkit_mol<span style="color:#f92672">.</span>GetConformer()<span style="color:#f92672">.</span>GetPositions()
</span></span><span style="display:flex;"><span>    )
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> ase_atoms
</span></span></code></pre></div><p>Then I computed Coulomb matrix eigenvalues using DScribe, with optional log transformation:</p>
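<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> ase
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> dscribe.descriptors
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> numpy <span style="color:#f92672">as</span> np
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">ase_atoms_to_coloumb_matrix_eigenvalues</span>(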
</span></span><span style="display:flex;"><span>    ase_atoms: ase<span style="color:#f92672">.</span>Atoms,
</span></span><span style="display:flex;"><span>    log: bool <span style="color:#f92672">=</span> <span style="color:#66d9ef">False</span>
</span></span><span style="display:flex;"><span>) <span style="color:#f92672">-&gt;</span> np<span style="color:#f92672">.</span>ndarray:
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;Convert an ASE Atoms object to a Coulomb matrix and calculate its eigenvalues.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    Args:
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">        ase_atoms: ASE Atoms object.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">        log: Whether to log transform the Coulomb matrix prior to calculating the eigenvalues.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    Returns:
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">        Eigenvalues of the Coulomb matrix.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    &#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Create a Coulomb matrix</span>
</span></span><span style="display:flex;"><span>    coulomb_matrix <span style="color:#f92672">=</span> dscribe<span style="color:#f92672">.</span>descriptors<span style="color:#f92672">.</span>CoulombMatrix(
</span></span><span style="display:flex;"><span>        n_atoms_max<span style="color:#f92672">=</span>ase_atoms<span style="color:#f92672">.</span>get_global_number_of_atoms(),
</span></span><span style="display:flex;"><span>    )
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Calculate the Coulomb matrix</span>
</span></span><span style="display:flex;"><span>    coulomb_matrix <span style="color:#f92672">=</span> coulomb_matrix<span style="color:#f92672">.</span>create(ase_atoms)
</span></span><span style="display:flex;"><span>    coulomb_matrix <span style="color:#f92672">=</span> coulomb_matrix<span style="color:#f92672">.</span>reshape(
</span></span><span style="display:flex;"><span>        ase_atoms<span style="color:#f92672">.</span>get_global_number_of_atoms(),
</span></span><span style="display:flex;"><span>        ase_atoms<span style="color:#f92672">.</span>get_global_number_of_atoms())
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">if</span> log:
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Log transform the Coulomb matrix</span>
</span></span><span style="display:flex;"><span>        coulomb_matrix <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>log(coulomb_matrix)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Calculate the eigenvalues of the Coulomb matrix (np.linalg.eigvals returns them in no guaranteed order)</span>
</span></span><span style="display:flex;"><span>    eigenvalues <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>linalg<span style="color:#f92672">.</span>eigvals(coulomb_matrix)
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> eigenvalues
</span></span></code></pre></div><p>Combining these functions enables efficient data generation:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#75715e"># Generate 1000 conformations per isomer for each alkane</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># gen_n_spectra() combines the above functions to generate multiple conformations</span>
</span></span><span style="display:flex;"><span>os<span style="color:#f92672">.</span>makedirs(<span style="color:#e6db74">&#39;spectra&#39;</span>, exist_ok<span style="color:#f92672">=</span><span style="color:#66d9ef">True</span>)
</span></span><span style="display:flex;"><span>n_confs <span style="color:#f92672">=</span> <span style="color:#ae81ff">1000</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">for</span> n <span style="color:#f92672">in</span> range(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">12</span>):
</span></span><span style="display:flex;"><span>    print(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#39;Generating spectra for C</span><span style="color:#e6db74">{</span>n<span style="color:#e6db74">}</span><span style="color:#e6db74">H</span><span style="color:#e6db74">{</span><span style="color:#ae81ff">2</span><span style="color:#f92672">*</span>n <span style="color:#f92672">+</span> <span style="color:#ae81ff">2</span><span style="color:#e6db74">}</span><span style="color:#e6db74">&#39;</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">with</span> open(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#39;isomers/C</span><span style="color:#e6db74">{</span>n<span style="color:#e6db74">}</span><span style="color:#e6db74">H</span><span style="color:#e6db74">{</span><span style="color:#ae81ff">2</span><span style="color:#f92672">*</span>n <span style="color:#f92672">+</span> <span style="color:#ae81ff">2</span><span style="color:#e6db74">}</span><span style="color:#e6db74">.smi&#39;</span>) <span style="color:#66d9ef">as</span> f:
</span></span><span style="display:flex;"><span>        smiles_list <span style="color:#f92672">=</span> [line<span style="color:#f92672">.</span>strip() <span style="color:#66d9ef">for</span> line <span style="color:#f92672">in</span> f <span style="color:#66d9ef">if</span> line<span style="color:#f92672">.</span>strip()]
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">for</span> i, smiles <span style="color:#f92672">in</span> enumerate(smiles_list):
</span></span><span style="display:flex;"><span>        spectra <span style="color:#f92672">=</span> gen_n_spectra(n_confs, smiles, log<span style="color:#f92672">=</span><span style="color:#66d9ef">False</span>)
</span></span><span style="display:flex;"><span>        np<span style="color:#f92672">.</span>save(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#39;spectra/C</span><span style="color:#e6db74">{</span>n<span style="color:#e6db74">}</span><span style="color:#e6db74">H</span><span style="color:#e6db74">{</span><span style="color:#ae81ff">2</span><span style="color:#f92672">*</span>n <span style="color:#f92672">+</span> <span style="color:#ae81ff">2</span><span style="color:#e6db74">}</span><span style="color:#e6db74">_</span><span style="color:#e6db74">{</span>i<span style="color:#e6db74">:</span><span style="color:#e6db74">03d</span><span style="color:#e6db74">}</span><span style="color:#e6db74">.npy&#39;</span>, spectra)
</span></span></code></pre></div><h2 id="reproducing-the-original-results">Reproducing the Original Results</h2>
<p>To validate the pipeline, I replicated key figures from the original paper. Matching them is a check that this implementation captures the same phenomena the authors observed.</p>
<p>We generate data using 1000 conformations per isomer:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span>os<span style="color:#f92672">.</span>makedirs(<span style="color:#e6db74">&#39;spectra&#39;</span>, exist_ok<span style="color:#f92672">=</span><span style="color:#66d9ef">True</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>n_confs <span style="color:#f92672">=</span> <span style="color:#ae81ff">1000</span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">for</span> n <span style="color:#f92672">in</span> range(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">12</span>):
</span></span><span style="display:flex;"><span>    print(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#39;Generating spectra for C</span><span style="color:#e6db74">{</span>n<span style="color:#e6db74">}</span><span style="color:#e6db74">H</span><span style="color:#e6db74">{</span><span style="color:#ae81ff">2</span><span style="color:#f92672">*</span>n <span style="color:#f92672">+</span> <span style="color:#ae81ff">2</span><span style="color:#e6db74">}</span><span style="color:#e6db74">&#39;</span>)
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">with</span> open(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#39;isomers/C</span><span style="color:#e6db74">{</span>n<span style="color:#e6db74">}</span><span style="color:#e6db74">H</span><span style="color:#e6db74">{</span><span style="color:#ae81ff">2</span><span style="color:#f92672">*</span>n <span style="color:#f92672">+</span> <span style="color:#ae81ff">2</span><span style="color:#e6db74">}</span><span style="color:#e6db74">.smi&#39;</span>) <span style="color:#66d9ef">as</span> f:
</span></span><span style="display:flex;"><span>        lines <span style="color:#f92672">=</span> f<span style="color:#f92672">.</span>readlines()
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">for</span> i, line <span style="color:#f92672">in</span> enumerate(lines):
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">if</span> <span style="color:#f92672">not</span> line<span style="color:#f92672">.</span>strip():
</span></span><span style="display:flex;"><span>                <span style="color:#66d9ef">continue</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>            smiles <span style="color:#f92672">=</span> line<span style="color:#f92672">.</span>strip()
</span></span><span style="display:flex;"><span>            spectra <span style="color:#f92672">=</span> gen_n_spectra(n_confs, smiles, log<span style="color:#f92672">=</span><span style="color:#66d9ef">False</span>)
</span></span><span style="display:flex;"><span>            np<span style="color:#f92672">.</span>save(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#39;spectra/C</span><span style="color:#e6db74">{</span>n<span style="color:#e6db74">}</span><span style="color:#e6db74">H</span><span style="color:#e6db74">{</span><span style="color:#ae81ff">2</span><span style="color:#f92672">*</span>n <span style="color:#f92672">+</span> <span style="color:#ae81ff">2</span><span style="color:#e6db74">}</span><span style="color:#e6db74">_</span><span style="color:#e6db74">{</span>i<span style="color:#e6db74">:</span><span style="color:#e6db74">03d</span><span style="color:#e6db74">}</span><span style="color:#e6db74">.npy&#39;</span>, spectra)
</span></span></code></pre></div><p>After generation, we can load the data into a structured format:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> re
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> glob <span style="color:#f92672">import</span> glob
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>spectra <span style="color:#f92672">=</span> {}
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">for</span> n <span style="color:#f92672">in</span> range(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">12</span>):
</span></span><span style="display:flex;"><span>    spectra[n] <span style="color:#f92672">=</span> {}
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">for</span> f <span style="color:#f92672">in</span> glob(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#39;spectra/C</span><span style="color:#e6db74">{</span>n<span style="color:#e6db74">}</span><span style="color:#e6db74">H</span><span style="color:#e6db74">{</span><span style="color:#ae81ff">2</span><span style="color:#f92672">*</span>n <span style="color:#f92672">+</span> <span style="color:#ae81ff">2</span><span style="color:#e6db74">}</span><span style="color:#e6db74">_*.npy&#39;</span>):
</span></span><span style="display:flex;"><span>        j <span style="color:#f92672">=</span> int(re<span style="color:#f92672">.</span>search(<span style="color:#e6db74">rf</span><span style="color:#e6db74">&#39;C</span><span style="color:#e6db74">{</span>n<span style="color:#e6db74">}</span><span style="color:#e6db74">H</span><span style="color:#e6db74">{</span><span style="color:#ae81ff">2</span><span style="color:#f92672">*</span>n <span style="color:#f92672">+</span> <span style="color:#ae81ff">2</span><span style="color:#e6db74">}</span><span style="color:#e6db74">_(\d+)\.npy&#39;</span>, f)<span style="color:#f92672">.</span>group(<span style="color:#ae81ff">1</span>))
</span></span><span style="display:flex;"><span>        spectra[n][j] <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>load(f)
</span></span></code></pre></div><h3 id="largest-eigenvalues-across-alkane-series">Largest Eigenvalues Across Alkane Series</h3>
<p>The first analysis examines how the largest Coulomb matrix eigenvalue varies across constitutional isomers for each alkane formula. For each isomer, the largest eigenvalue is averaged over its conformations, and the plot shows whether that single number can separate molecular formulas and the isomers within them.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span>fig, ax <span style="color:#f92672">=</span> plt<span style="color:#f92672">.</span>subplots(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">1</span>, figsize<span style="color:#f92672">=</span>(<span style="color:#ae81ff">10</span>, <span style="color:#ae81ff">5</span>))
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">for</span> n <span style="color:#f92672">in</span> range(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">12</span>):
</span></span><span style="display:flex;"><span>    eigenvalues <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>array([spectra[n][i][:, <span style="color:#ae81ff">0</span>]<span style="color:#f92672">.</span>mean() <span style="color:#66d9ef">for</span> i <span style="color:#f92672">in</span> spectra[n]])
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Scatter the per-isomer means as black dots, jittered horizontally so they do not stack on the center line</span>
</span></span><span style="display:flex;"><span>    jitter <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>random<span style="color:#f92672">.</span>normal(<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">0.1</span>, size<span style="color:#f92672">=</span>len(eigenvalues))
</span></span><span style="display:flex;"><span>    ax<span style="color:#f92672">.</span>scatter(np<span style="color:#f92672">.</span>full(len(eigenvalues), n) <span style="color:#f92672">+</span> jitter, eigenvalues, color<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;black&#39;</span>, s<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>, alpha<span style="color:#f92672">=</span><span style="color:#ae81ff">0.3</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Plot median</span>
</span></span><span style="display:flex;"><span>    ax<span style="color:#f92672">.</span>scatter([n], [np<span style="color:#f92672">.</span>median(eigenvalues)], color<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;red&#39;</span>, alpha<span style="color:#f92672">=</span><span style="color:#ae81ff">0.5</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Plot range</span>
</span></span><span style="display:flex;"><span>    ax<span style="color:#f92672">.</span>plot([n <span style="color:#f92672">-</span> <span style="color:#ae81ff">0.5</span>, n <span style="color:#f92672">+</span> <span style="color:#ae81ff">0.5</span>], [np<span style="color:#f92672">.</span>min(eigenvalues), np<span style="color:#f92672">.</span>min(eigenvalues)], <span style="color:#e6db74">&#39;k-&#39;</span>, alpha<span style="color:#f92672">=</span><span style="color:#ae81ff">0.5</span>)
</span></span><span style="display:flex;"><span>    ax<span style="color:#f92672">.</span>plot([n <span style="color:#f92672">-</span> <span style="color:#ae81ff">0.5</span>, n <span style="color:#f92672">+</span> <span style="color:#ae81ff">0.5</span>], [np<span style="color:#f92672">.</span>max(eigenvalues), np<span style="color:#f92672">.</span>max(eigenvalues)], <span style="color:#e6db74">&#39;k-&#39;</span>, alpha<span style="color:#f92672">=</span><span style="color:#ae81ff">0.5</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>ax<span style="color:#f92672">.</span>set_xlabel(<span style="color:#e6db74">&#39;Molecular formula&#39;</span>)
</span></span><span style="display:flex;"><span>ax<span style="color:#f92672">.</span>set_ylabel(<span style="color:#e6db74">&#39;Largest eigenvalue&#39;</span>)
</span></span><span style="display:flex;"><span>ax<span style="color:#f92672">.</span>set_xticks(range(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">12</span>))
</span></span><span style="display:flex;"><span>ax<span style="color:#f92672">.</span>set_xticklabels([<span style="color:#e6db74">f</span><span style="color:#e6db74">&#39;C</span><span style="color:#e6db74">{</span>n<span style="color:#e6db74">}</span><span style="color:#e6db74">H</span><span style="color:#e6db74">{</span><span style="color:#ae81ff">2</span><span style="color:#f92672">*</span>n <span style="color:#f92672">+</span> <span style="color:#ae81ff">2</span><span style="color:#e6db74">}</span><span style="color:#e6db74">&#39;</span> <span style="color:#66d9ef">for</span> n <span style="color:#f92672">in</span> range(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">12</span>)])
</span></span><span style="display:flex;"><span>ax<span style="color:#f92672">.</span>set_title(<span style="color:#e6db74">&#39;Largest eigenvalues of the Coulomb matrix for alkane constitutional isomers&#39;</span>)
</span></span><span style="display:flex;"><span>plt<span style="color:#f92672">.</span>savefig(<span style="color:#e6db74">&#39;alkane_coulomb_matrix_largest_eigenvalues.webp&#39;</span>, bbox_inches<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;tight&#39;</span>)
</span></span></code></pre></div>














<figure class="post-figure center ">
    <img src="/img/alkane-constitutional-isomers/alkane_coulomb_matrix_largest_eigenvalues.webp"
         alt="Largest eigenvalues of the Coulomb matrix for alkane constitutional isomers."
         title="Largest eigenvalues of the Coulomb matrix for alkane constitutional isomers."
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Largest eigenvalues show sub-linear growth with molecular size and increasing overlap between isomers.</figcaption>
    
</figure>

<p>Our results match the original paper. We observe sub-linear growth in the largest eigenvalue with carbon number, and critically, increasing overlap between isomers as molecules grow larger. The largest eigenvalue alone cannot reliably distinguish constitutional isomers for larger alkanes.</p>
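<p>One way to put a number on &ldquo;sub-linear&rdquo; is to fit a power law $\lambda_{\max} \approx a\, n^{b}$ on a log-log scale and check that the exponent $b$ is below 1. The sketch below is an illustration and not part of the original analysis: <code>fit_power_law</code> is a hypothetical helper, and the input values are synthetic, generated from a known exponent. They are not measured eigenvalues.</p>

```python
import numpy as np


def fit_power_law(n_carbons, values):
    """Least-squares fit of values = a * n**b in log-log space; returns (a, b)."""
    slope, intercept = np.polyfit(np.log(n_carbons), np.log(values), 1)
    return np.exp(intercept), slope


# Synthetic stand-in for per-formula median eigenvalues (NOT real results)
n_carbons = np.arange(1, 12)
synthetic_values = 3.0 * n_carbons ** 0.7

prefactor, exponent = fit_power_law(n_carbons, synthetic_values)
print(f"prefactor = {prefactor:.3f}, exponent = {exponent:.3f}")
```

<p>With the real data, <code>values</code> could be the per-formula medians of the largest eigenvalue; a fitted exponent below 1 would correspond to the sub-linear trend in the figure.</p>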
<h3 id="eigenvalue-distributions-for-heptane-isomers">Eigenvalue Distributions for Heptane Isomers</h3>
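<p>Looking deeper at a specific case, I estimated the probability density function of the largest eigenvalue for each heptane ($C_7H_{16}$) isomer, using a Gaussian kernel density estimate with Silverman&rsquo;s bandwidth rule. Heptane has nine constitutional isomers, providing a good test of discrimination power. The same loop also produces the plots for butane, pentane, and hexane shown further down.</p>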
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#66d9ef">for</span> n_sel <span style="color:#f92672">in</span> range(<span style="color:#ae81ff">4</span>, <span style="color:#ae81ff">8</span>):
</span></span><span style="display:flex;"><span>    fig, ax <span style="color:#f92672">=</span> plt<span style="color:#f92672">.</span>subplots(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">1</span>, figsize<span style="color:#f92672">=</span>(<span style="color:#ae81ff">10</span>, <span style="color:#ae81ff">5</span>))
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    smiles <span style="color:#f92672">=</span> []
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">with</span> open(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#39;isomers/C</span><span style="color:#e6db74">{</span>n_sel<span style="color:#e6db74">}</span><span style="color:#e6db74">H</span><span style="color:#e6db74">{</span><span style="color:#ae81ff">2</span><span style="color:#f92672">*</span>n_sel <span style="color:#f92672">+</span> <span style="color:#ae81ff">2</span><span style="color:#e6db74">}</span><span style="color:#e6db74">.smi&#39;</span>) <span style="color:#66d9ef">as</span> f:
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">for</span> line <span style="color:#f92672">in</span> f:
</span></span><span style="display:flex;"><span>            smiles<span style="color:#f92672">.</span>append(line<span style="color:#f92672">.</span>strip())
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">for</span> i <span style="color:#f92672">in</span> range(len(spectra[n_sel])):
</span></span><span style="display:flex;"><span>        eigenvalues <span style="color:#f92672">=</span> spectra[n_sel][i][:, <span style="color:#ae81ff">0</span>]
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># kde plot with the following params</span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># - Gaussian kernel</span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># - bandwidth with Silverman&#39;s rule of thumb</span>
</span></span><span style="display:flex;"><span>        ax <span style="color:#f92672">=</span> sns<span style="color:#f92672">.</span>kdeplot(eigenvalues, bw_method<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;silverman&#39;</span>, label<span style="color:#f92672">=</span>get_iupac_name(smiles[i]))
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    ax<span style="color:#f92672">.</span>set_xlabel(<span style="color:#e6db74">&#39;Largest eigenvalue&#39;</span>)
</span></span><span style="display:flex;"><span>    ax<span style="color:#f92672">.</span>set_ylabel(<span style="color:#e6db74">&#39;Density&#39;</span>)
</span></span><span style="display:flex;"><span>    ax<span style="color:#f92672">.</span>set_title(<span style="color:#e6db74">&#39;PDF of the largest eigenvalue for $C_&#39;</span> <span style="color:#f92672">+</span> str(n_sel) <span style="color:#f92672">+</span> <span style="color:#e6db74">&#39;H_{&#39;</span> <span style="color:#f92672">+</span> str(<span style="color:#ae81ff">2</span><span style="color:#f92672">*</span>n_sel <span style="color:#f92672">+</span> <span style="color:#ae81ff">2</span>) <span style="color:#f92672">+</span> <span style="color:#e6db74">&#39;}$&#39;</span>)
</span></span><span style="display:flex;"><span>    ax<span style="color:#f92672">.</span>legend()
</span></span><span style="display:flex;"><span>    plt<span style="color:#f92672">.</span>savefig(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#39;pdf_largest_eigenvalue_C</span><span style="color:#e6db74">{</span>n_sel<span style="color:#e6db74">}</span><span style="color:#e6db74">H</span><span style="color:#e6db74">{</span><span style="color:#ae81ff">2</span><span style="color:#f92672">*</span>n_sel <span style="color:#f92672">+</span> <span style="color:#ae81ff">2</span><span style="color:#e6db74">}</span><span style="color:#e6db74">.webp&#39;</span>, bbox_inches<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;tight&#39;</span>)
</span></span><span style="display:flex;"><span>    plt<span style="color:#f92672">.</span>close()
</span></span></code></pre></div>














<figure class="post-figure center ">
    <img src="/img/alkane-constitutional-isomers/pdf_largest_eigenvalue_C7H16.webp"
         alt="PDFs of the largest eigenvalue for heptane isomers."
         title="PDFs of the largest eigenvalue for heptane isomers."
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Heptane isomers show distinct eigenvalue ranges: n-heptane (smallest), 2,2,3-trimethylbutane (largest), with others overlapping.</figcaption>
    
</figure>

<p>The pattern is clear:</p>
<ul>
<li><strong>n-heptane</strong> (linear chain) has the smallest eigenvalues</li>
<li><strong>2,2,3-trimethylbutane</strong> (highly branched) has the largest</li>
<li><strong>Seven other isomers</strong> fall in between with substantial overlap</li>
</ul>
<p>This demonstrates the fundamental limitation: while extreme structural differences (linear vs. highly branched) create separable eigenvalue distributions, intermediate structures become indistinguishable.</p>
<p>For smaller alkanes, the separation is more promising:</p>















<figure class="post-figure center ">
    <img src="/img/alkane-constitutional-isomers/pdf_largest_eigenvalue_C4H10.webp"
         alt="PDFs for butane isomers."
         title="PDFs for butane isomers."
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Butane (n=4): Clean separation between linear and branched structures.</figcaption>
    
</figure>
















<figure class="post-figure center ">
    <img src="/img/alkane-constitutional-isomers/pdf_largest_eigenvalue_C5H12.webp"
         alt="PDFs for pentane isomers."
         title="PDFs for pentane isomers."
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Pentane (n=5): Good separation between most isomers.</figcaption>
    
</figure>
















<figure class="post-figure center ">
    <img src="/img/alkane-constitutional-isomers/pdf_largest_eigenvalue_C6H14.webp"
         alt="PDFs for hexane isomers."
         title="PDFs for hexane isomers."
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Hexane (n=6): Some isomers (2-methylpentane, 3-methylpentane) become difficult to distinguish.</figcaption>
    
</figure>

<p>The progression is clear: eigenvalue-based discrimination works well for small alkanes and degrades as molecular complexity increases.</p>
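<p>One way to put a number on the overlap seen in these density plots is the overlap coefficient, the area shared by two normalized histograms. The following is a minimal sketch, not part of the original analysis: the helper name <code>overlap_coefficient</code> is hypothetical, and the Gaussian samples are synthetic stand-ins for the largest eigenvalues of two isomers.</p>

```python
import numpy as np


def overlap_coefficient(a, b, bins=50):
    """Estimate the area shared by two empirical densities (0 = disjoint, 1 = identical)."""
    lo = min(a.min(), b.min())
    hi = max(a.max(), b.max())
    edges = np.linspace(lo, hi, bins + 1)
    # density=True normalizes each histogram so that it integrates to 1 over the shared bins
    pa, _ = np.histogram(a, bins=edges, density=True)
    pb, _ = np.histogram(b, bins=edges, density=True)
    width = edges[1] - edges[0]
    return float(np.sum(np.minimum(pa, pb)) * width)


rng = np.random.default_rng(0)
# Synthetic stand-ins for largest-eigenvalue samples (not real alkane data)
far_a = rng.normal(100.0, 1.0, size=1000)
far_b = rng.normal(120.0, 1.0, size=1000)
near_a = rng.normal(100.0, 1.0, size=1000)
near_b = rng.normal(100.5, 1.0, size=1000)

print("well separated:", overlap_coefficient(far_a, far_b))
print("overlapping:   ", overlap_coefficient(near_a, near_b))
```

<p>A value near zero corresponds to the clean separation seen for butane, while values approaching one correspond to the heavily overlapping intermediate isomers of heptane. The estimate depends on the bin count, so it is best used for relative comparisons.</p>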
<h3 id="two-dimensional-eigenvalue-space">Two-Dimensional Eigenvalue Space</h3>
<p>Can we improve discrimination by using multiple eigenvalues? For butane, plotting the two largest eigenvalues against each other reveals interesting structure:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span>fig, ax <span style="color:#f92672">=</span> plt<span style="color:#f92672">.</span>subplots(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">1</span>, figsize<span style="color:#f92672">=</span>(<span style="color:#ae81ff">10</span>, <span style="color:#ae81ff">5</span>))
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>n_sel <span style="color:#f92672">=</span> <span style="color:#ae81ff">4</span>
</span></span><span style="display:flex;"><span>smiles <span style="color:#f92672">=</span> []
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">with</span> open(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#39;isomers/C</span><span style="color:#e6db74">{</span>n_sel<span style="color:#e6db74">}</span><span style="color:#e6db74">H</span><span style="color:#e6db74">{</span><span style="color:#ae81ff">2</span><span style="color:#f92672">*</span>n_sel <span style="color:#f92672">+</span> <span style="color:#ae81ff">2</span><span style="color:#e6db74">}</span><span style="color:#e6db74">.smi&#39;</span>) <span style="color:#66d9ef">as</span> f:
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">for</span> line <span style="color:#f92672">in</span> f:
</span></span><span style="display:flex;"><span>        smiles<span style="color:#f92672">.</span>append(line<span style="color:#f92672">.</span>strip())
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">for</span> i <span style="color:#f92672">in</span> range(len(spectra[n_sel])):
</span></span><span style="display:flex;"><span>    eigenvalues <span style="color:#f92672">=</span> spectra[n_sel][i][:, :<span style="color:#ae81ff">2</span>]
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    ax<span style="color:#f92672">.</span>scatter(eigenvalues[:, <span style="color:#ae81ff">0</span>], eigenvalues[:, <span style="color:#ae81ff">1</span>], label<span style="color:#f92672">=</span>get_iupac_name(smiles[i]))
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>ax<span style="color:#f92672">.</span>set_xlabel(<span style="color:#e6db74">&#39;Largest eigenvalue&#39;</span>)
</span></span><span style="display:flex;"><span>ax<span style="color:#f92672">.</span>set_ylabel(<span style="color:#e6db74">&#39;Second largest eigenvalue&#39;</span>)
</span></span><span style="display:flex;"><span>ax<span style="color:#f92672">.</span>set_title(<span style="color:#e6db74">&#39;2D plot of the first two eigenvalues for $C_4H_</span><span style="color:#e6db74">{10}</span><span style="color:#e6db74">$ conformers&#39;</span>)
</span></span><span style="display:flex;"><span>ax<span style="color:#f92672">.</span>legend()
</span></span><span style="display:flex;"><span>plt<span style="color:#f92672">.</span>savefig(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#39;2d_largest_eigenvalue_C</span><span style="color:#e6db74">{</span>n_sel<span style="color:#e6db74">}</span><span style="color:#e6db74">H</span><span style="color:#e6db74">{</span><span style="color:#ae81ff">2</span><span style="color:#f92672">*</span>n_sel <span style="color:#f92672">+</span> <span style="color:#ae81ff">2</span><span style="color:#e6db74">}</span><span style="color:#e6db74">.webp&#39;</span>, bbox_inches<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;tight&#39;</span>)
</span></span></code></pre></div>














<figure class="post-figure center ">
    <img src="/img/alkane-constitutional-isomers/2d_largest_eigenvalue_C4H10.webp"
         alt="2D eigenvalue space for butane isomers."
         title="2D eigenvalue space for butane isomers."
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Perfect linear separation of butane isomers using the first two eigenvalues.</figcaption>
    
</figure>

<p>The two isomers cluster distinctly, demonstrating that multi-dimensional eigenvalue features can achieve perfect separation for simple cases. The outlier point in the lower right likely results from conformational sampling noise.</p>
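<p>Linear separability of two clusters can be checked without fitting a classifier. The sketch below projects both clusters onto the line through their centroids and tests whether the projected ranges overlap. This is a sufficient test only: clusters that fail it may still be separable along another direction. The function name and the synthetic 2D points are illustrative assumptions, not the post&rsquo;s data.</p>

```python
import numpy as np


def separated_along_centroid_axis(a, b):
    """Return True if clusters a and b do not overlap when projected onto
    the line through their centroids (a sufficient test for linear separability)."""
    direction = b.mean(axis=0) - a.mean(axis=0)
    direction = direction / np.linalg.norm(direction)
    proj_a = a @ direction
    proj_b = b @ direction
    # The centroid of b lies further along the direction, so compare max(a) to min(b)
    return bool(proj_b.min() - proj_a.max() > 0)


rng = np.random.default_rng(0)
# Synthetic 2D points standing in for (largest, second-largest) eigenvalue pairs
cluster_a = rng.normal([100.0, 20.0], 0.5, size=(500, 2))
cluster_b = rng.normal([110.0, 25.0], 0.5, size=(500, 2))
cluster_c = rng.normal([100.2, 20.1], 0.5, size=(500, 2))

print(separated_along_centroid_axis(cluster_a, cluster_b))  # distant clusters
print(separated_along_centroid_axis(cluster_a, cluster_c))  # nearly coincident clusters
```

<p>A single stray conformer is enough to make this check fail, so it is a strict criterion compared with a visual judgment of the scatter plot.</p>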
<h3 id="dimensionality-and-information-content">Dimensionality and Information Content</h3>
<p>How many eigenvalues do we actually need? Principal component analysis reveals the effective dimensionality of the eigenvalue representations:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">from</span> sklearn.decomposition <span style="color:#f92672">import</span> PCA
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>fig, ax <span style="color:#f92672">=</span> plt<span style="color:#f92672">.</span>subplots(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">1</span>, figsize<span style="color:#f92672">=</span>(<span style="color:#ae81ff">10</span>, <span style="color:#ae81ff">5</span>))
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>n_components <span style="color:#f92672">=</span> {}
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">for</span> n <span style="color:#f92672">in</span> range(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">12</span>):
</span></span><span style="display:flex;"><span>    n_components[n] <span style="color:#f92672">=</span> []
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">for</span> i <span style="color:#f92672">in</span> range(len(spectra[n])):
</span></span><span style="display:flex;"><span>        eigenvalues <span style="color:#f92672">=</span> spectra[n][i]
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># PCA</span>
</span></span><span style="display:flex;"><span>        pca <span style="color:#f92672">=</span> PCA(n_components<span style="color:#f92672">=</span><span style="color:#ae81ff">0.99</span>, svd_solver<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;full&#39;</span>, whiten<span style="color:#f92672">=</span><span style="color:#66d9ef">False</span>, random_state<span style="color:#f92672">=</span><span style="color:#ae81ff">42</span>)
</span></span><span style="display:flex;"><span>        pca<span style="color:#f92672">.</span>fit(eigenvalues)
</span></span><span style="display:flex;"><span>        n_components[n]<span style="color:#f92672">.</span>append(pca<span style="color:#f92672">.</span>n_components_)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    ax<span style="color:#f92672">.</span>scatter([n] <span style="color:#f92672">*</span> len(n_components[n]), n_components[n], alpha<span style="color:#f92672">=</span><span style="color:#ae81ff">0.3</span>)
</span></span><span style="display:flex;"><span>    ax<span style="color:#f92672">.</span>plot([n <span style="color:#f92672">-</span> <span style="color:#ae81ff">0.25</span>, n <span style="color:#f92672">+</span> <span style="color:#ae81ff">0.25</span>], [np<span style="color:#f92672">.</span>mean(n_components[n]), np<span style="color:#f92672">.</span>mean(n_components[n])], <span style="color:#e6db74">&#39;k-&#39;</span>, alpha<span style="color:#f92672">=</span><span style="color:#ae81ff">0.5</span>)  <span style="color:#75715e"># Draw a line for the mean</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>ax<span style="color:#f92672">.</span>plot([<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">11</span>], [<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">11</span>], <span style="color:#e6db74">&#39;k--&#39;</span>, alpha<span style="color:#f92672">=</span><span style="color:#ae81ff">0.5</span>, label<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;y = num carbon&#39;</span>)
</span></span><span style="display:flex;"><span>ax<span style="color:#f92672">.</span>plot([<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">11</span>], [<span style="color:#ae81ff">5</span>, <span style="color:#ae81ff">35</span>], <span style="color:#e6db74">&#39;r--&#39;</span>, alpha<span style="color:#f92672">=</span><span style="color:#ae81ff">0.5</span>, label<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;y = num atoms&#39;</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>ax<span style="color:#f92672">.</span>set_xlabel(<span style="color:#e6db74">&#39;Number of carbon atoms&#39;</span>)
</span></span><span style="display:flex;"><span>ax<span style="color:#f92672">.</span>set_ylabel(<span style="color:#e6db74">&#39;Number of principal components&#39;</span>)
</span></span><span style="display:flex;"><span>ax<span style="color:#f92672">.</span>set_title(<span style="color:#e6db74">&#39;99% variance explained by number of principal components&#39;</span>)
</span></span><span style="display:flex;"><span>plt<span style="color:#f92672">.</span>legend()
</span></span><span style="display:flex;"><span>plt<span style="color:#f92672">.</span>savefig(<span style="color:#e6db74">&#39;99_variance_explained.webp&#39;</span>, bbox_inches<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;tight&#39;</span>)
</span></span></code></pre></div>














<figure class="post-figure center ">
    <img src="/img/alkane-constitutional-isomers/99_variance_explained.webp"
         alt="Principal components needed for 99% variance."
         title="Principal components needed for 99% variance."
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The eigenvalue space compresses efficiently. Far fewer components than the $3n+2$ eigenvalue dimensions are needed.</figcaption>
    
</figure>

<p>Key observations:</p>
<ul>
<li><strong>High compressibility</strong>: Far fewer than the $3n+2$ eigenvalue dimensions (one per atom) are needed</li>
<li><strong>Linear scaling</strong>: Principal components grow roughly linearly with carbon number</li>
<li><strong>Efficient representation</strong>: The eigenvalue space has lower effective dimensionality than expected</li>
</ul>
<p>This suggests the eigenvalue dimensions are highly correlated, so the representation can be reduced substantially while still retaining 99% of its variance.</p>
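<p>To make the 99% criterion concrete, the sketch below counts components by hand from the singular values of the centered data, in the spirit of what a fractional <code>n_components</code> does in scikit-learn. The helper name <code>n_components_for_variance</code> is hypothetical, and the data are synthetic: 14 observed dimensions (the $3n+2$ of butane) driven by only 3 latent factors.</p>

```python
import numpy as np


def n_components_for_variance(x, threshold=0.99):
    """Smallest number of principal components whose cumulative explained
    variance ratio reaches `threshold`."""
    centered = x - x.mean(axis=0)
    # Squared singular values of the centered data are proportional to component variances
    singular_values = np.linalg.svd(centered, compute_uv=False)
    variance = singular_values ** 2
    ratio = np.cumsum(variance) / variance.sum()
    # Index of the first cumulative ratio that reaches the threshold, converted to a count
    return int(np.searchsorted(ratio, threshold) + 1)


rng = np.random.default_rng(0)
# Synthetic stand-in: 14 correlated dimensions generated from 3 latent factors plus tiny noise
latent = rng.normal(size=(1000, 3))
mixing, _ = np.linalg.qr(rng.normal(size=(14, 3)))  # 14 x 3 matrix with orthonormal columns
x = latent @ mixing.T + 1e-3 * rng.normal(size=(1000, 14))

print(n_components_for_variance(x))
```

<p>Because every observed dimension is a linear mix of the same three factors, the count recovers the latent dimensionality and not the 14 nominal dimensions. The same mechanism would explain why the eigenvalue spectra compress so well.</p>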
<h2 id="log-transformed-coulomb-matrices">Log-Transformed Coulomb Matrices</h2>
<p>As explored in our <a href="/posts/molecular-descriptor-coulomb-matrix/">previous post on Coulomb matrices</a>, log transformation can reveal different structural information. Standard Coulomb matrices emphasize heavy-atom interactions. Log transformation expands the influence of hydrogen atoms by mapping magnitudes in $(0,1]$ to $(-\infty,0]$, so the small entries involving hydrogen become large in absolute value.</p>
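<p>The effect is easiest to see on a single small molecule. The sketch below builds a standard Coulomb matrix for an idealized methane geometry and applies an element-wise natural log. Several things here are assumptions for illustration: the helper <code>coulomb_matrix</code>, the use of angstroms, the geometry itself, and the exact form of the transform, which may differ from the one used in the previous post.</p>

```python
import numpy as np


def coulomb_matrix(charges, coords):
    """Standard Coulomb matrix: 0.5 * Z_i**2.4 on the diagonal, Z_i * Z_j / r_ij elsewhere."""
    charges = np.asarray(charges, dtype=float)
    coords = np.asarray(coords, dtype=float)
    diff = coords[:, None, :] - coords[None, :, :]
    dist = np.linalg.norm(diff, axis=-1)
    np.fill_diagonal(dist, 1.0)  # placeholder to avoid dividing by zero on the diagonal
    cm = np.outer(charges, charges) / dist
    np.fill_diagonal(cm, 0.5 * charges ** 2.4)
    return cm


# Idealized tetrahedral methane with 1.09 angstrom C-H bonds (illustrative, not optimized)
d = 1.09 / np.sqrt(3.0)
coords = np.array([
    [0.0, 0.0, 0.0],  # C
    [d, d, d],        # H
    [d, -d, -d],      # H
    [-d, d, -d],      # H
    [-d, -d, d],      # H
])
charges = [6, 1, 1, 1, 1]

cm = coulomb_matrix(charges, coords)
log_cm = np.log(cm)  # assumption: element-wise natural log of the (all positive) entries

print("C diagonal:", cm[0, 0], "log:", log_cm[0, 0])
print("H-H entry: ", cm[1, 2], "log:", log_cm[1, 2])
print("eigenvalues (standard):", np.linalg.eigvalsh(cm))
print("eigenvalues (log):     ", np.linalg.eigvalsh(log_cm))
```

<p>The hydrogen-hydrogen entries fall below one in the standard matrix and turn negative after the log, which is how negative eigenvalues enter the log-transformed spectra.</p>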
<p>I generated equivalent datasets using log-transformed matrices to test how this affects discriminative power.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span>os<span style="color:#f92672">.</span>makedirs(<span style="color:#e6db74">&#39;spectra&#39;</span>, exist_ok<span style="color:#f92672">=</span><span style="color:#66d9ef">True</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>n_confs <span style="color:#f92672">=</span> <span style="color:#ae81ff">1000</span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">for</span> n <span style="color:#f92672">in</span> range(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">12</span>):
</span></span><span style="display:flex;"><span>    print(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#39;Generating spectra for C</span><span style="color:#e6db74">{</span>n<span style="color:#e6db74">}</span><span style="color:#e6db74">H</span><span style="color:#e6db74">{</span><span style="color:#ae81ff">2</span><span style="color:#f92672">*</span>n <span style="color:#f92672">+</span> <span style="color:#ae81ff">2</span><span style="color:#e6db74">}</span><span style="color:#e6db74">&#39;</span>)
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">with</span> open(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#39;isomers/C</span><span style="color:#e6db74">{</span>n<span style="color:#e6db74">}</span><span style="color:#e6db74">H</span><span style="color:#e6db74">{</span><span style="color:#ae81ff">2</span><span style="color:#f92672">*</span>n <span style="color:#f92672">+</span> <span style="color:#ae81ff">2</span><span style="color:#e6db74">}</span><span style="color:#e6db74">.smi&#39;</span>) <span style="color:#66d9ef">as</span> f:
</span></span><span style="display:flex;"><span>        lines <span style="color:#f92672">=</span> f<span style="color:#f92672">.</span>readlines()
</span></span><span style="display:flex;"><span>        print(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#39;</span><span style="color:#ae81ff">\t</span><span style="color:#e6db74">Number of SMILES strings: </span><span style="color:#e6db74">{</span>len(lines)<span style="color:#e6db74">}</span><span style="color:#e6db74">&#39;</span>)
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">for</span> i, line <span style="color:#f92672">in</span> enumerate(tqdm(lines)):
</span></span><span style="display:flex;"><span>            print(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#39;</span><span style="color:#ae81ff">\t\t</span><span style="color:#e6db74">{</span>i <span style="color:#f92672">+</span> <span style="color:#ae81ff">1</span><span style="color:#e6db74">}</span><span style="color:#e6db74">/</span><span style="color:#e6db74">{</span>len(lines)<span style="color:#e6db74">}</span><span style="color:#e6db74"> - </span><span style="color:#e6db74">{</span>line<span style="color:#f92672">.</span>strip()<span style="color:#e6db74">}</span><span style="color:#e6db74">&#39;</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">if</span> <span style="color:#f92672">not</span> line<span style="color:#f92672">.</span>strip():
</span></span><span style="display:flex;"><span>                <span style="color:#66d9ef">continue</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">if</span> os<span style="color:#f92672">.</span>path<span style="color:#f92672">.</span>exists(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#39;spectra/log-C</span><span style="color:#e6db74">{</span>n<span style="color:#e6db74">}</span><span style="color:#e6db74">H</span><span style="color:#e6db74">{</span><span style="color:#ae81ff">2</span><span style="color:#f92672">*</span>n <span style="color:#f92672">+</span> <span style="color:#ae81ff">2</span><span style="color:#e6db74">}</span><span style="color:#e6db74">_</span><span style="color:#e6db74">{</span>i<span style="color:#e6db74">:</span><span style="color:#e6db74">03d</span><span style="color:#e6db74">}</span><span style="color:#e6db74">.npy&#39;</span>):
</span></span><span style="display:flex;"><span>                <span style="color:#66d9ef">continue</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>            smiles <span style="color:#f92672">=</span> line<span style="color:#f92672">.</span>strip()
</span></span><span style="display:flex;"><span>            spectra <span style="color:#f92672">=</span> gen_n_spectra(n<span style="color:#f92672">=</span>n_confs, smiles_str<span style="color:#f92672">=</span>smiles, log<span style="color:#f92672">=</span><span style="color:#66d9ef">True</span>)
</span></span><span style="display:flex;"><span>            np<span style="color:#f92672">.</span>save(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#39;spectra/log-C</span><span style="color:#e6db74">{</span>n<span style="color:#e6db74">}</span><span style="color:#e6db74">H</span><span style="color:#e6db74">{</span><span style="color:#ae81ff">2</span><span style="color:#f92672">*</span>n <span style="color:#f92672">+</span> <span style="color:#ae81ff">2</span><span style="color:#e6db74">}</span><span style="color:#e6db74">_</span><span style="color:#e6db74">{</span>i<span style="color:#e6db74">:</span><span style="color:#e6db74">03d</span><span style="color:#e6db74">}</span><span style="color:#e6db74">.npy&#39;</span>, spectra)
</span></span></code></pre></div><h3 id="log-transformed-eigenvalue-distributions">Log-Transformed Eigenvalue Distributions</h3>
<p>The log transformation substantially changes the eigenvalue landscape:</p>















<figure class="post-figure center ">
    <img src="/img/alkane-constitutional-isomers/alkane_log_coulomb_matrix_largest_eigenvalues.webp"
         alt="Log-transformed eigenvalues across alkane series."
         title="Log-transformed eigenvalues across alkane series."
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Log transformation emphasizes hydrogen interactions, creating larger eigenvalue ranges and more negative values.</figcaption>
    
</figure>

<p>Log-transformed versions exhibit distinct characteristics:</p>
<ul>
<li><strong>Span large negative values</strong> due to hydrogen atom emphasis</li>
<li><strong>Show increasing variance</strong> between isomers as molecular size grows</li>
<li><strong>Demonstrate greater discrimination potential</strong> for some isomers</li>
</ul>
<p>This comes with trade-offs. The distributions can become significantly broader, as seen in the heptane analysis:</p>















<figure class="post-figure center ">
    <img src="/img/alkane-constitutional-isomers/pdf_log_largest_eigenvalue_C7H16.webp"
         alt="Log-transformed eigenvalue PDFs for heptane."
         title="Log-transformed eigenvalue PDFs for heptane."
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Log transformation creates wider, more overlapping distributions that may reduce discrimination power.</figcaption>
    
</figure>

<p>The log scale on the y-axis is necessary because unbranched isomers become nearly invisible due to the highly concentrated distributions of branched isomers.</p>
<h3 id="two-dimensional-log-transformed-space">Two-Dimensional Log-Transformed Space</h3>
<p>The 2D eigenvalue plot for log-transformed butane shows similar clustering behavior:</p>















<figure class="post-figure center ">
    <img src="/img/alkane-constitutional-isomers/2d_log_largest_eigenvalue_C4H10.webp"
         alt="2D log-transformed eigenvalue space for butane."
         title="2D log-transformed eigenvalue space for butane."
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Log transformation brings isomers closer together while maintaining separability for simple cases.</figcaption>
    
</figure>

<p>The transformation reduces the separation distance, yet linear discriminability remains intact for this simple case.</p>
<h3 id="dimensionality-of-log-transformed-features">Dimensionality of Log-Transformed Features</h3>
<p>Principal component analysis of log-transformed eigenvalues reveals similar compression properties:</p>















<figure class="post-figure center ">
    <img src="/img/alkane-constitutional-isomers/99_variance_explained_log.webp"
         alt="Principal components for log-transformed eigenvalues."
         title="Principal components for log-transformed eigenvalues."
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Log transformation requires slightly more principal components while maintaining efficient compression.</figcaption>
    
</figure>

<p>The log-transformed features show comparable dimensionality reduction with marginally higher component requirements.</p>
<h2 id="testing-eigenvalue-separability">Testing Eigenvalue Separability</h2>
<p>Our exploratory analysis revealed concerning patterns that hint at fundamental limitations: high correlation between eigenvalue dimensions, rapid dimensionality compression via PCA, and overlapping distributions for larger molecules ($n \geq 6$).</p>
<p>These findings leave the question open. We now test eigenvalues directly to see if they can actually separate constitutional isomers without supervision.</p>
<p>We&rsquo;ll use two complementary clustering metrics to measure how well eigenvalues separate constitutional isomers. This is a fair test. We only compare isomers with identical molecular formulas, keeping eigenvalue dimensions constant.</p>
<h3 id="dunn-index-global-cluster-quality">Dunn Index: Global Cluster Quality</h3>
<p>The <a href="https://en.wikipedia.org/wiki/Dunn_index">Dunn Index</a> provides a single metric capturing cluster quality. It asks: &ldquo;Are the closest different clusters still farther apart than the most spread-out individual cluster?&rdquo;</p>
<p>$$
\text{Dunn Index} = \frac{\text{smallest distance between different clusters}}{\text{largest diameter within any cluster}}
$$</p>
<p>Higher values indicate better separation. As the index approaches zero, the nearest pair of clusters sits far closer together than the widest cluster is across, so the clusters become indistinguishable. That is exactly what the overlapping eigenvalue distributions above led us to suspect.</p>
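<p>For readers who want to see the ratio spelled out, here is a minimal NumPy sketch of one common variant: the minimum point-to-point distance between clusters divided by the maximum point-to-point distance within a cluster. The name <code>dunn_index_sketch</code> is hypothetical, and the post&rsquo;s own <code>dunn_index</code> helper may be implemented differently. Only the returned keys are chosen to mirror the output shown in this post.</p>

```python
import numpy as np


def pairwise_distances(a, b):
    """Euclidean distances between every row of a and every row of b."""
    diff = a[:, None, :] - b[None, :, :]
    return np.sqrt((diff ** 2).sum(axis=-1))


def dunn_index_sketch(clusters):
    """Dunn index for a list of (n_points, n_features) arrays, one array per cluster."""
    # Largest diameter: greatest distance between two points of the same cluster
    diameter = max(pairwise_distances(c, c).max() for c in clusters)
    # Smallest distance between points that belong to two different clusters
    distance = min(
        pairwise_distances(clusters[i], clusters[j]).min()
        for i in range(len(clusters))
        for j in range(i + 1, len(clusters))
    )
    return {
        "diameter": float(diameter),
        "distance": float(distance),
        "dunn_index": float(distance / diameter),
    }


# Two tiny hand-made clusters: each is 1 unit wide, and the closest points are 3 units apart
toy = [
    np.array([[0.0, 0.0], [1.0, 0.0]]),
    np.array([[4.0, 0.0], [5.0, 0.0]]),
]
print(dunn_index_sketch(toy))
```

<p>The broadcasted distance computation holds every pairwise difference in memory at once, so this form suits small examples. Large conformer sets would need chunking or a dedicated distance routine.</p>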
<p>Computing the Dunn Index for each alkane series:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> time
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>dunn_scores <span style="color:#f92672">=</span> {}
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">for</span> n <span style="color:#f92672">in</span> range(<span style="color:#ae81ff">4</span>, <span style="color:#ae81ff">12</span>):
</span></span><span style="display:flex;"><span>    tik <span style="color:#f92672">=</span> time<span style="color:#f92672">.</span>time()
</span></span><span style="display:flex;"><span>    dunn_scores[n] <span style="color:#f92672">=</span> dunn_index([spectra[n][i] <span style="color:#66d9ef">for</span> i <span style="color:#f92672">in</span> spectra[n]])
</span></span><span style="display:flex;"><span>    tok <span style="color:#f92672">=</span> time<span style="color:#f92672">.</span>time()
</span></span><span style="display:flex;"><span>    dunn_scores[n][<span style="color:#e6db74">&#39;time&#39;</span>] <span style="color:#f92672">=</span> tok <span style="color:#f92672">-</span> tik
</span></span><span style="display:flex;"><span>    print(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#39;C</span><span style="color:#e6db74">{</span>n<span style="color:#e6db74">}</span><span style="color:#e6db74">H</span><span style="color:#e6db74">{</span><span style="color:#ae81ff">2</span><span style="color:#f92672">*</span>n <span style="color:#f92672">+</span> <span style="color:#ae81ff">2</span><span style="color:#e6db74">}</span><span style="color:#e6db74">:&#39;</span>, dunn_scores[n])
</span></span></code></pre></div><div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-plaintext" data-lang="plaintext"><span style="display:flex;"><span>C4H10: {&#39;diameter&#39;: 21.43072917950398, &#39;distance&#39;: 8.316362440688767, &#39;dunn_index&#39;: 0.3880578383978837, &#39;time&#39;: 0.06010293960571289}
</span></span><span style="display:flex;"><span>C5H12: {&#39;diameter&#39;: 23.449286379564892, &#39;distance&#39;: 2.4693042873545856, &#39;dunn_index&#39;: 0.10530402705587172, &#39;time&#39;: 0.10832405090332031}
</span></span><span style="display:flex;"><span>C6H14: {&#39;diameter&#39;: 19.602363375467938, &#39;distance&#39;: 1.4477574259511048, &#39;dunn_index&#39;: 0.07385626917634591, &#39;time&#39;: 0.28030991554260254}
</span></span><span style="display:flex;"><span>C7H16: {&#39;diameter&#39;: 20.065014927470955, &#39;distance&#39;: 0.4050094394280803, &#39;dunn_index&#39;: 0.02018485612355977, &#39;time&#39;: 1.0307331085205078}
</span></span><span style="display:flex;"><span>C8H18: {&#39;diameter&#39;: 24.794154667613665, &#39;distance&#39;: 0.5013450168168625, &#39;dunn_index&#39;: 0.020220290771668196, &#39;time&#39;: 4.199508905410767}
</span></span><span style="display:flex;"><span>C9H20: {&#39;diameter&#39;: 21.811025941686033, &#39;distance&#39;: 0.34381162248560415, &#39;dunn_index&#39;: 0.01576320267578513, &#39;time&#39;: 17.400264978408813}
</span></span><span style="display:flex;"><span>C10H22: {&#39;diameter&#39;: 27.180773716656066, &#39;distance&#39;: 0.4986608768730121, &#39;dunn_index&#39;: 0.0183460883811206, &#39;time&#39;: 86.00787401199341}
</span></span><span style="display:flex;"><span>C11H24: {&#39;diameter&#39;: 25.58731511020692, &#39;distance&#39;: 0.5490373275460223, &#39;dunn_index&#39;: 0.021457402825629343, &#39;time&#39;: 424.4431610107422}
</span></span></code></pre></div><p>The computation time grows dramatically, to over 7 minutes for $C_{11}H_{24}$, because the number of cluster pairs scales quadratically with the number of isomers (159 isomers give $159 \times 158 / 2$ = 12,561 pairwise comparisons).</p>
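<p>The quadratic growth is easy to tabulate. This short sketch counts the unordered pairs of isomer clusters for each alkane, using the known constitutional isomer counts for $C_4H_{10}$ through $C_{11}H_{24}$. The dictionary name is an illustrative choice.</p>

```python
from math import comb

# Constitutional isomer counts for the alkanes C4H10 through C11H24
isomer_counts = {4: 2, 5: 3, 6: 5, 7: 9, 8: 18, 9: 35, 10: 75, 11: 159}

for n, k in isomer_counts.items():
    pairs = comb(k, 2)  # unordered pairs of isomer clusters: k * (k - 1) / 2
    print(f"C{n}H{2 * n + 2}: {k} isomers, {pairs} cluster pairs")
```

<p>Going from decane to undecane multiplies the pair count by about 4.5 (2,775 to 12,561), which is roughly in line with the measured runtime increase of about 4.9 times (86 s to 424 s).</p>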
<p>The results reveal a clear trend:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span>fig, axs <span style="color:#f92672">=</span> plt<span style="color:#f92672">.</span>subplots(<span style="color:#ae81ff">2</span>, <span style="color:#ae81ff">2</span>, figsize<span style="color:#f92672">=</span>(<span style="color:#ae81ff">15</span>, <span style="color:#ae81ff">10</span>))
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># in axs[0, 0] - diameter vs number of carbon atoms</span>
</span></span><span style="display:flex;"><span>axs[<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">0</span>]<span style="color:#f92672">.</span>plot(
</span></span><span style="display:flex;"><span>    list(range(<span style="color:#ae81ff">4</span>, <span style="color:#ae81ff">12</span>)),
</span></span><span style="display:flex;"><span>    [dunn_scores[n][<span style="color:#e6db74">&#39;diameter&#39;</span>] <span style="color:#66d9ef">for</span> n <span style="color:#f92672">in</span> range(<span style="color:#ae81ff">4</span>, <span style="color:#ae81ff">12</span>)],
</span></span><span style="display:flex;"><span>    marker<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;o&#39;</span>
</span></span><span style="display:flex;"><span>)
</span></span><span style="display:flex;"><span>axs[<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">0</span>]<span style="color:#f92672">.</span>set_xlabel(<span style="color:#e6db74">&#39;Number of carbon atoms&#39;</span>)
</span></span><span style="display:flex;"><span>axs[<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">0</span>]<span style="color:#f92672">.</span>set_ylabel(<span style="color:#e6db74">&#39;Diameter&#39;</span>)
</span></span><span style="display:flex;"><span>axs[<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">0</span>]<span style="color:#f92672">.</span>set_title(<span style="color:#e6db74">&#39;Diameter vs number of carbon atoms&#39;</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># in axs[0, 1] - distance vs number of carbon atoms</span>
</span></span><span style="display:flex;"><span>axs[<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">1</span>]<span style="color:#f92672">.</span>plot(
</span></span><span style="display:flex;"><span>    list(range(<span style="color:#ae81ff">4</span>, <span style="color:#ae81ff">12</span>)),
</span></span><span style="display:flex;"><span>    [dunn_scores[n][<span style="color:#e6db74">&#39;distance&#39;</span>] <span style="color:#66d9ef">for</span> n <span style="color:#f92672">in</span> range(<span style="color:#ae81ff">4</span>, <span style="color:#ae81ff">12</span>)],
</span></span><span style="display:flex;"><span>    marker<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;o&#39;</span>
</span></span><span style="display:flex;"><span>)
</span></span><span style="display:flex;"><span>axs[<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">1</span>]<span style="color:#f92672">.</span>set_xlabel(<span style="color:#e6db74">&#39;Number of carbon atoms&#39;</span>)
</span></span><span style="display:flex;"><span>axs[<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">1</span>]<span style="color:#f92672">.</span>set_ylabel(<span style="color:#e6db74">&#39;Distance&#39;</span>)
</span></span><span style="display:flex;"><span>axs[<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">1</span>]<span style="color:#f92672">.</span>set_title(<span style="color:#e6db74">&#39;Distance vs number of carbon atoms&#39;</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># in axs[1, 0] - dunn index vs number of carbon atoms</span>
</span></span><span style="display:flex;"><span>axs[<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">0</span>]<span style="color:#f92672">.</span>plot(
</span></span><span style="display:flex;"><span>    list(range(<span style="color:#ae81ff">4</span>, <span style="color:#ae81ff">12</span>)),
</span></span><span style="display:flex;"><span>    [dunn_scores[n][<span style="color:#e6db74">&#39;dunn_index&#39;</span>] <span style="color:#66d9ef">for</span> n <span style="color:#f92672">in</span> range(<span style="color:#ae81ff">4</span>, <span style="color:#ae81ff">12</span>)],
</span></span><span style="display:flex;"><span>    marker<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;o&#39;</span>
</span></span><span style="display:flex;"><span>)
</span></span><span style="display:flex;"><span>axs[<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">0</span>]<span style="color:#f92672">.</span>set_xlabel(<span style="color:#e6db74">&#39;Number of carbon atoms&#39;</span>)
</span></span><span style="display:flex;"><span>axs[<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">0</span>]<span style="color:#f92672">.</span>set_ylabel(<span style="color:#e6db74">&#39;Dunn index&#39;</span>)
</span></span><span style="display:flex;"><span>axs[<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">0</span>]<span style="color:#f92672">.</span>set_title(<span style="color:#e6db74">&#39;Dunn index vs number of carbon atoms&#39;</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># in axs[1, 1] - time vs number of carbon atoms</span>
</span></span><span style="display:flex;"><span>axs[<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">1</span>]<span style="color:#f92672">.</span>plot(
</span></span><span style="display:flex;"><span>    list(range(<span style="color:#ae81ff">4</span>, <span style="color:#ae81ff">12</span>)),
</span></span><span style="display:flex;"><span>    [dunn_scores[n][<span style="color:#e6db74">&#39;time&#39;</span>] <span style="color:#66d9ef">for</span> n <span style="color:#f92672">in</span> range(<span style="color:#ae81ff">4</span>, <span style="color:#ae81ff">12</span>)],
</span></span><span style="display:flex;"><span>    marker<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;o&#39;</span>
</span></span><span style="display:flex;"><span>)
</span></span><span style="display:flex;"><span>axs[<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">1</span>]<span style="color:#f92672">.</span>set_xlabel(<span style="color:#e6db74">&#39;Number of carbon atoms&#39;</span>)
</span></span><span style="display:flex;"><span>axs[<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">1</span>]<span style="color:#f92672">.</span>set_ylabel(<span style="color:#e6db74">&#39;Time (s)&#39;</span>)
</span></span><span style="display:flex;"><span>axs[<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">1</span>]<span style="color:#f92672">.</span>set_title(<span style="color:#e6db74">&#39;Time vs number of carbon atoms&#39;</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>plt<span style="color:#f92672">.</span>tight_layout()
</span></span><span style="display:flex;"><span>plt<span style="color:#f92672">.</span>savefig(<span style="color:#e6db74">&#39;dunn_index_vs_num_carbon_atoms.webp&#39;</span>, bbox_inches<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;tight&#39;</span>)
</span></span></code></pre></div>














<figure class="post-figure center ">
    <img src="/img/alkane-constitutional-isomers/dunn_index_vs_num_carbon_atoms.webp"
         alt="Dunn Index analysis showing separability metrics, distances, and computation time versus molecular size"
         title="Dunn Index analysis showing separability metrics, distances, and computation time versus molecular size"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Dunn Index analysis reveals deteriorating separability as molecular complexity increases.</figcaption>
    
</figure>

<p>The trend confirms our earlier concerns:</p>
<ul>
<li><strong>$C_{4}H_{10}$</strong>: Excellent separation (Dunn Index = 0.39) between butane and isobutane</li>
<li><strong>$C_{5}H_{12}$ to $C_{6}H_{14}$</strong>: Rapid decline in separability</li>
<li><strong>$C_{7}H_{16}$ and beyond</strong>: Poor separation (Dunn Index $\approx$ 0.02)</li>
</ul>
<p>This validates our computational pipeline and matches the original paper&rsquo;s findings. For larger molecules, eigenvalue clusters become nearly indistinguishable, confirming the overlapping distributions we observed earlier.</p>
<h3 id="silhouette-analysis-individual-conformation-assessment">Silhouette Analysis: Individual Conformation Assessment</h3>
<p>The Dunn Index provides the global view. We must also consider individual molecules. The <a href="https://en.wikipedia.org/wiki/Silhouette_(clustering)">silhouette score</a> evaluates each conformation separately, asking: &ldquo;Is this molecule closer to its own isomer family or to a different one?&rdquo;</p>
<p>For each molecular conformation $i$:</p>
<p>$$
s(i) = \frac{b(i) - a(i)}{\max(a(i), b(i))}
$$</p>
<p>where:</p>
<ul>
<li>$a(i)$ = average distance to other conformations of the <strong>same</strong> isomer</li>
<li>$b(i)$ = average distance to conformations of the <strong>nearest different</strong> isomer</li>
</ul>
<p><strong>Interpretation:</strong></p>
<ul>
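<li><strong>Score near +1</strong>: Conformation sits well inside its own isomer&rsquo;s cluster (good clustering)</li>
<li><strong>Score near 0</strong>: Conformation lies on the boundary between two isomers</li>
<li><strong>Score near -1</strong>: Conformation is closer to a different isomer (misclassification)</li>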
</ul>
<p>This enables two critical measurements:</p>
<ol>
<li>How many isomers have <strong>any</strong> misclassified conformations?</li>
<li>What fraction of <strong>individual conformations</strong> get misclassified?</li>
</ol>
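<p>To make the formula concrete, here is a minimal sketch that computes $s(i)$ directly from the definitions of $a(i)$ and $b(i)$ and checks it against scikit-learn. The <code>silhouette_by_hand</code> helper and the five-point toy dataset are hypothetical illustrations, not part of the alkane pipeline.</p>

```python
import numpy as np
from sklearn.metrics import pairwise_distances, silhouette_samples


def silhouette_by_hand(X, y):
    """Compute s(i) = (b(i) - a(i)) / max(a(i), b(i)) for every sample."""
    X = np.asarray(X)
    y = np.asarray(y)
    D = pairwise_distances(X)
    labels = np.unique(y)
    s = np.zeros(len(X))
    for i in range(len(X)):
        same = (y == y[i])
        same[i] = False  # a(i) averages over the *other* members of the cluster
        if not same.any():
            continue  # singleton cluster: score left at 0, as scikit-learn does
        a = D[i, same].mean()
        # b(i): mean distance to the nearest *other* cluster
        b = min(D[i, y == other].mean() for other in labels if other != y[i])
        s[i] = (b - a) / max(a, b)  # assumes no exact duplicates (max(a, b) != 0)
    return s


# Toy 1-D data: the last point carries label 0 but sits among the label-1 points.
X_toy = np.array([[0.0], [0.2], [1.0], [1.3], [0.9]])
y_toy = np.array([0, 0, 1, 1, 0])

by_hand = silhouette_by_hand(X_toy, y_toy)
print(by_hand)
print(silhouette_samples(X_toy, y_toy))
```

<p>The deliberately mislabeled last point receives a negative score, which is exactly the signal the two measurements above count.</p>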
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">from</span> sklearn.metrics <span style="color:#f92672">import</span> silhouette_samples
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> tqdm <span style="color:#f92672">import</span> tqdm
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>s_scores <span style="color:#f92672">=</span> {}
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">for</span> n <span style="color:#f92672">in</span> tqdm(range(<span style="color:#ae81ff">4</span>, <span style="color:#ae81ff">12</span>)):
</span></span><span style="display:flex;"><span>    X <span style="color:#f92672">=</span> []
</span></span><span style="display:flex;"><span>    y <span style="color:#f92672">=</span> []
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">for</span> i <span style="color:#f92672">in</span> spectra[n]:
</span></span><span style="display:flex;"><span>        X<span style="color:#f92672">.</span>append(spectra[n][i])
</span></span><span style="display:flex;"><span>        y<span style="color:#f92672">.</span>extend(np<span style="color:#f92672">.</span>full(spectra[n][i]<span style="color:#f92672">.</span>shape[<span style="color:#ae81ff">0</span>], i))
</span></span><span style="display:flex;"><span>    X <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>concatenate(X)
</span></span><span style="display:flex;"><span>    y <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>array(y)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    s_scores[n] <span style="color:#f92672">=</span> silhouette_samples(X, y)
</span></span></code></pre></div><p>Computing both clustering quality metrics:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#75715e"># Metric 1: Fraction of isomers with ANY negative scores</span>
</span></span><span style="display:flex;"><span>neg_iso <span style="color:#f92672">=</span> {}
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">for</span> n <span style="color:#f92672">in</span> range(<span style="color:#ae81ff">4</span>, <span style="color:#ae81ff">12</span>):
</span></span><span style="display:flex;"><span>    n_iso <span style="color:#f92672">=</span> s_scores[n]<span style="color:#f92672">.</span>shape[<span style="color:#ae81ff">0</span>] <span style="color:#f92672">//</span> <span style="color:#ae81ff">1000</span>
</span></span><span style="display:flex;"><span>    n_has_neg <span style="color:#f92672">=</span> <span style="color:#ae81ff">0</span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">for</span> i <span style="color:#f92672">in</span> range(n_iso):
</span></span><span style="display:flex;"><span>        chunk <span style="color:#f92672">=</span> s_scores[n][i <span style="color:#f92672">*</span> <span style="color:#ae81ff">1000</span>:(i <span style="color:#f92672">+</span> <span style="color:#ae81ff">1</span>) <span style="color:#f92672">*</span> <span style="color:#ae81ff">1000</span>]
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">if</span> np<span style="color:#f92672">.</span>any(chunk <span style="color:#f92672">&lt;</span> <span style="color:#ae81ff">0</span>):
</span></span><span style="display:flex;"><span>            n_has_neg <span style="color:#f92672">+=</span> <span style="color:#ae81ff">1</span>
</span></span><span style="display:flex;"><span>    neg_iso[n] <span style="color:#f92672">=</span> n_has_neg <span style="color:#f92672">/</span> n_iso
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Metric 2: Individual conformation misclassification rates</span>
</span></span><span style="display:flex;"><span>neg_confs <span style="color:#f92672">=</span> {}
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">for</span> n <span style="color:#f92672">in</span> range(<span style="color:#ae81ff">4</span>, <span style="color:#ae81ff">12</span>):
</span></span><span style="display:flex;"><span>    n_iso <span style="color:#f92672">=</span> s_scores[n]<span style="color:#f92672">.</span>shape[<span style="color:#ae81ff">0</span>] <span style="color:#f92672">//</span> <span style="color:#ae81ff">1000</span>
</span></span><span style="display:flex;"><span>    neg_confs[n] <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>zeros(n_iso)
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">for</span> i <span style="color:#f92672">in</span> range(n_iso):
</span></span><span style="display:flex;"><span>        isomer_scores <span style="color:#f92672">=</span> s_scores[n][i <span style="color:#f92672">*</span> <span style="color:#ae81ff">1000</span>:(i <span style="color:#f92672">+</span> <span style="color:#ae81ff">1</span>) <span style="color:#f92672">*</span> <span style="color:#ae81ff">1000</span>]
</span></span><span style="display:flex;"><span>        neg_confs[n][i] <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>sum(isomer_scores <span style="color:#f92672">&lt;</span> <span style="color:#ae81ff">0</span>) <span style="color:#f92672">/</span> isomer_scores<span style="color:#f92672">.</span>shape[<span style="color:#ae81ff">0</span>]
</span></span></code></pre></div><h4 id="isomer-level-analysis">Isomer-Level Analysis</h4>















<figure class="post-figure center ">
    <img src="/img/alkane-constitutional-isomers/fraction_of_negative_silhouette_scores_vs_num_carbon_atoms.webp"
         alt="Chart showing fraction of isomers with at least one misclassified conformation"
         title="Chart showing fraction of isomers with at least one misclassified conformation"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Fraction of isomers with at least one misclassified conformation (a stringent test of cluster purity).</figcaption>
    
</figure>

<p>The trend is concerning: by $C_{11}H_{24}$, 97% of isomers have at least one conformation that would be misclassified. This metric is deliberately strict. Even a single misplaced conformation marks the entire isomer as problematic.</p>
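<p>The chunked slicing in the earlier code assumes every isomer contributes exactly 1000 conformations stored contiguously. As a hedged alternative, the same two quantities can be computed by grouping on the label array instead. The <code>negative_score_summary</code> helper below is a hypothetical sketch run on made-up scores, not the values from this analysis.</p>

```python
import numpy as np


def negative_score_summary(scores, labels):
    """Return (per-label fraction of negative scores, share of labels with any negative)."""
    scores = np.asarray(scores)
    labels = np.asarray(labels)
    per_label = {
        label.item(): float(np.mean(scores[labels == label] < 0))
        for label in np.unique(labels)
    }
    # A label counts as problematic if even one of its scores is negative.
    frac_any = float(np.mean([frac > 0 for frac in per_label.values()]))
    return per_label, frac_any


# Made-up silhouette scores for two isomers with three conformations each.
toy_scores = [0.5, -0.1, 0.3, 0.2, 0.4, 0.1]
toy_labels = [0, 0, 0, 1, 1, 1]

per_label, frac_any = negative_score_summary(toy_scores, toy_labels)
print(per_label, frac_any)
```

<p>Grouping by label keeps the calculation correct if isomers ever have unequal conformation counts or if the rows get shuffled.</p>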
<h4 id="conformation-level-analysis">Conformation-Level Analysis</h4>















<figure class="post-figure center ">
    <img src="/img/alkane-constitutional-isomers/fraction_of_negative_silhouette_scores_vs_num_carbon_atoms_individual.webp"
         alt="Chart showing individual misclassification rates per isomer with horizontal lines showing range for each molecular size"
         title="Chart showing individual misclassification rates per isomer with horizontal lines showing range for each molecular size"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Individual misclassification rates per isomer. Each point represents one isomer; horizontal lines show the range for each molecular size.</figcaption>
    
</figure>

<p>The individual analysis reveals dramatic variation:</p>
<ul>
<li><strong>$C_{4}H_{10}$</strong>: Perfect clustering (0% misclassification), confirming our earlier 2D separation plots</li>
<li><strong>$C_{5}H_{12}$ to $C_{6}H_{14}$</strong>: Modest problems (1-8% misclassification rates)</li>
<li><strong>$C_{11}H_{24}$</strong>: On average, 35% of conformations misclassified per isomer</li>
</ul>
<p>Some isomers experience up to 99.5% conformation misclassification (they become essentially unrecognizable in eigenvalue space). This directly connects to our earlier observation: mathematical representations that appear elegant may lack the structural nuances needed for practical discrimination.</p>
<h2 id="supervised-learning-finding-hidden-structure">Supervised Learning: Finding Hidden Structure</h2>
<p>Both clustering metrics deliver the same conclusion: Coulomb matrix eigenvalues alone struggle to reliably distinguish constitutional isomers for larger alkanes. The mathematical elegance of eigenvalues encounters practical limitations as molecular complexity increases.</p>
<p>Supervised learning offers an alternative approach. Providing labels allows models to extract hidden patterns that elude clustering algorithms. The mathematical structure often requires explicit guidance for discovery.</p>
<p>I&rsquo;ll focus on two baseline approaches: k-nearest neighbors and logistic regression. These represent fundamentally different learning paradigms (one memorizes patterns, the other learns linear boundaries), giving us insight into what types of structure might exist in eigenvalue space.</p>
<h2 id="k-nearest-neighbors-pattern-recognition-through-memory">k-Nearest Neighbors: Pattern Recognition Through Memory</h2>
<p>k-NN represents the simplest supervised learning approach: it stores all training examples and classifies new samples based on their closest neighbors. If eigenvalue patterns truly distinguish isomers, nearby points in eigenvalue space should belong to the same class.</p>
<p>This directly tests the local structure. Local neighborhoods often preserve meaningful distinctions even when global structure appears diffuse.</p>
<h3 id="testing-different-feature-representations">Testing Different Feature Representations</h3>
<p>We compare three approaches: full eigenvalue vectors, top 10 eigenvalues only, and PCA-reduced representations.</p>
<p>Testing 1-nearest neighbor with full dimensionality:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">from</span> sklearn.model_selection <span style="color:#f92672">import</span> cross_val_score, StratifiedKFold
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> sklearn.neighbors <span style="color:#f92672">import</span> KNeighborsClassifier
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>df_1nn <span style="color:#f92672">=</span> []
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">for</span> n <span style="color:#f92672">in</span> range(<span style="color:#ae81ff">4</span>, <span style="color:#ae81ff">12</span>):
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Prepare the data for CnH2n+2</span>
</span></span><span style="display:flex;"><span>    X, y <span style="color:#f92672">=</span> prep_data(n<span style="color:#f92672">=</span>n)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Create knn classifier</span>
</span></span><span style="display:flex;"><span>    knn <span style="color:#f92672">=</span> KNeighborsClassifier(n_neighbors<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Set up stratified 5-fold cross-validation</span>
</span></span><span style="display:flex;"><span>    cv <span style="color:#f92672">=</span> StratifiedKFold(n_splits<span style="color:#f92672">=</span><span style="color:#ae81ff">5</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Perform cross-validation. Since &#39;cross_val_score&#39; computes accuracy, we compute misclassification rate by subtracting accuracy from 1.</span>
</span></span><span style="display:flex;"><span>    acc_scores <span style="color:#f92672">=</span> cross_val_score(knn, X, y, cv<span style="color:#f92672">=</span>cv, scoring<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;accuracy&#39;</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Convert accuracy scores to misclassification error rates</span>
</span></span><span style="display:flex;"><span>    misclassification_error_rates <span style="color:#f92672">=</span> <span style="color:#ae81ff">1</span> <span style="color:#f92672">-</span> acc_scores
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Calculate the average and standard deviation of the misclassification error rates</span>
</span></span><span style="display:flex;"><span>    avg_misclassification_error <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>mean(misclassification_error_rates)
</span></span><span style="display:flex;"><span>    std_misclassification_error <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>std(misclassification_error_rates)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    print(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#39;C</span><span style="color:#e6db74">{</span>n<span style="color:#e6db74">}</span><span style="color:#e6db74">H</span><span style="color:#e6db74">{</span><span style="color:#ae81ff">2</span><span style="color:#f92672">*</span>n <span style="color:#f92672">+</span> <span style="color:#ae81ff">2</span><span style="color:#e6db74">}</span><span style="color:#e6db74">: </span><span style="color:#e6db74">{</span>avg_misclassification_error<span style="color:#e6db74">:</span><span style="color:#e6db74">.2%</span><span style="color:#e6db74">}</span><span style="color:#e6db74"> ± </span><span style="color:#e6db74">{</span>std_misclassification_error<span style="color:#e6db74">:</span><span style="color:#e6db74">.2%</span><span style="color:#e6db74">}</span><span style="color:#e6db74">&#39;</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    df_1nn<span style="color:#f92672">.</span>append({
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#39;molecule&#39;</span>: <span style="color:#e6db74">f</span><span style="color:#e6db74">&#39;C</span><span style="color:#e6db74">{</span>n<span style="color:#e6db74">}</span><span style="color:#e6db74">H</span><span style="color:#e6db74">{</span><span style="color:#ae81ff">2</span><span style="color:#f92672">*</span>n <span style="color:#f92672">+</span> <span style="color:#ae81ff">2</span><span style="color:#e6db74">}</span><span style="color:#e6db74">&#39;</span>,
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#39;avg_misclassification_error&#39;</span>: avg_misclassification_error,
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#39;std_misclassification_error&#39;</span>: std_misclassification_error,
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#39;n&#39;</span>: n,
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#39;representation&#39;</span>: <span style="color:#e6db74">&#39;full&#39;</span>,
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#39;model&#39;</span>: <span style="color:#e6db74">&#39;1nn&#39;</span>,
</span></span><span style="display:flex;"><span>    })
</span></span></code></pre></div><p>The results are remarkable compared to unsupervised clustering:</p>
<pre><code>C4H10: 0.00% ± 0.00%
C5H12: 0.00% ± 0.00%
C6H14: 0.00% ± 0.00%
C7H16: 0.07% ± 0.05%
C8H18: 0.11% ± 0.05%
C9H20: 0.51% ± 0.09%
C10H22: 1.31% ± 0.09%
C11H24: 3.24% ± 0.09%
</code></pre>
<p><strong>Perfect classification</strong> for molecules up to $C_{6}H_{14}$, with low error rates even for $C_{11}H_{24}$ (3.24%). This is a large improvement over clustering, where an average of 35% of conformations per $C_{11}H_{24}$ isomer were misclassified and 97% of isomers had at least one misclassified conformation.</p>
<p><strong>Note on feature scaling:</strong> Standardizing features significantly degraded performance; eigenvalue magnitudes carry crucial structural information.</p>
<p>Comparing performance across different feature representations:</p>















<figure class="post-figure center ">
    <img src="/img/alkane-constitutional-isomers/alkane-classification-1nn.webp"
         alt="1-NN performance across different representations"
         title="1-NN performance across different representations"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">1-NN classification performance across different eigenvalue representations shows similar results, with slight advantages for full representations on larger molecules.</figcaption>
    
</figure>

<p><strong>Key insights:</strong></p>
<ul>
<li><strong>Representation choice matters little</strong> for 1-NN. Full, top-10, and PCA representations perform nearly identically</li>
<li><strong>PCA slightly outperforms</strong> top-10 eigenvalues for larger molecules, capturing more structural variance</li>
<li><strong>Perfect classification</strong> persists through $C_{6}H_{14}$ regardless of representation</li>
</ul>
<p>This confirms that discriminative information concentrates in the largest eigenvalues, validating our earlier PCA findings.</p>
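<p>For reference, the three representations can be sketched as scikit-learn pipelines that share one cross-validation loop. This is a minimal sketch under assumptions: <code>first_k_columns</code> and <code>build_1nn_variants</code> are hypothetical helpers, the columns are assumed to be sorted so the first ones hold the largest eigenvalues, and the synthetic array stands in for the output of <code>prep_data</code>.</p>

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import FunctionTransformer


def first_k_columns(X, k=10):
    # Assumes columns are sorted so the first k hold the largest eigenvalues.
    return X[:, :k]


def build_1nn_variants(k=10):
    """Hypothetical helper: one 1-NN pipeline per feature representation."""
    return {
        'full': make_pipeline(KNeighborsClassifier(n_neighbors=1)),
        'top_k': make_pipeline(
            FunctionTransformer(first_k_columns, kw_args={'k': k}),
            KNeighborsClassifier(n_neighbors=1),
        ),
        'pca': make_pipeline(PCA(n_components=k), KNeighborsClassifier(n_neighbors=1)),
    }


# Synthetic stand-in for prep_data(n): 3 well-separated classes, 14 features.
rng = np.random.default_rng(0)
y_toy = np.repeat([0, 1, 2], 40)
X_toy = rng.normal(scale=0.1, size=(120, 14)) + 5.0 * y_toy[:, None]

cv = StratifiedKFold(n_splits=5)
errors = {
    name: float(1 - cross_val_score(model, X_toy, y_toy, cv=cv, scoring='accuracy').mean())
    for name, model in build_1nn_variants(k=10).items()
}
print(errors)
```

<p>Putting PCA inside the pipeline means it is refit on each training fold, so no information from the held-out fold leaks into the projection.</p>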
<h3 id="the-neighbor-count-effect">The Neighbor Count Effect</h3>
<p>Testing k-NN with different neighbor counts (k=1, 3, 5) reveals a counterintuitive pattern:</p>















<figure class="post-figure center ">
    <img src="/img/alkane-constitutional-isomers/alkane-classification-knn.webp"
         alt="k-NN performance for different k values"
         title="k-NN performance for different k values"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">k-NN classification performance decreases as k increases. More neighbors actually hurt accuracy.</figcaption>
    
</figure>

<p><strong>Why does performance degrade with more neighbors?</strong> This connects directly to our earlier clustering analysis. Class-consistent structure in eigenvalue space is confined to very small neighborhoods. When k-NN looks beyond the immediate nearest neighbor, it increasingly finds examples from different classes.</p>
<p>This validates our unsupervised findings: in the absence of clear cluster boundaries, examining more neighbors introduces noise.</p>
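<p>One way to probe this explanation is to measure how often a point&rsquo;s k nearest neighbors share its label as k grows. The sketch below is an illustration under assumptions: <code>neighbor_purity</code> is a hypothetical helper, and the toy data (tight same-class pairs whose class alternates along a line) is constructed to mimic the effect, not drawn from the alkane dataset.</p>

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors


def neighbor_purity(X, y, k):
    """Mean fraction of each point's k nearest neighbors that share its label."""
    X = np.asarray(X)
    y = np.asarray(y)
    # Ask for k + 1 neighbors because each point comes back as its own
    # nearest neighbor (distance 0), assuming no exact duplicate rows.
    nn = NearestNeighbors(n_neighbors=k + 1).fit(X)
    _, idx = nn.kneighbors(X)
    neighbor_labels = y[idx[:, 1:]]
    return float(np.mean(neighbor_labels == y[:, None]))


# Toy 1-D data: tight same-class pairs, with the class alternating between pairs.
centers = np.arange(6, dtype=float)
X_toy = np.column_stack([centers, centers + 0.1]).reshape(-1, 1)
y_toy = np.repeat(centers.astype(int) % 2, 2)

for k in (1, 3, 5):
    print(k, neighbor_purity(X_toy, y_toy, k))
```

<p>In this toy setup the single nearest neighbor always shares the label, while wider neighborhoods are dominated by the other class. Running the same function on the real eigenvalue features would show whether the alkane data behaves the same way.</p>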
<h2 id="logistic-regression-learning-linear-decision-boundaries">Logistic Regression: Learning Linear Decision Boundaries</h2>
<p>Logistic regression takes a fundamentally different approach: it learns linear decision boundaries in eigenvalue space. If eigenvalues encode structural information linearly, this should work well.</p>
<p>We&rsquo;ll focus on PCA-reduced representations to keep computation manageable, using insights from the k-NN analysis.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">from</span> sklearn.linear_model <span style="color:#f92672">import</span> LogisticRegression
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> sklearn.pipeline <span style="color:#f92672">import</span> Pipeline
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> sklearn.decomposition <span style="color:#f92672">import</span> PCA
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>df_lr <span style="color:#f92672">=</span> []
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">for</span> n <span style="color:#f92672">in</span> range(<span style="color:#ae81ff">4</span>, <span style="color:#ae81ff">12</span>):
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Prepare the data for CnH2n+2</span>
</span></span><span style="display:flex;"><span>    X, y <span style="color:#f92672">=</span> prep_data(n<span style="color:#f92672">=</span>n)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Create logistic regression classifier with PCA</span>
</span></span><span style="display:flex;"><span>    lr <span style="color:#f92672">=</span> Pipeline([
</span></span><span style="display:flex;"><span>        (<span style="color:#e6db74">&#39;pca&#39;</span>, PCA(n_components<span style="color:#f92672">=</span><span style="color:#ae81ff">10</span>)),
</span></span><span style="display:flex;"><span>        (<span style="color:#e6db74">&#39;lr&#39;</span>, LogisticRegression(
</span></span><span style="display:flex;"><span>            max_iter<span style="color:#f92672">=</span><span style="color:#ae81ff">10_000</span>,
</span></span><span style="display:flex;"><span>            penalty<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;l2&#39;</span>,
</span></span><span style="display:flex;"><span>            solver<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;lbfgs&#39;</span>,
</span></span><span style="display:flex;"><span>            C<span style="color:#f92672">=</span><span style="color:#ae81ff">10.0</span>,  <span style="color:#75715e"># Reduced regularization</span>
</span></span><span style="display:flex;"><span>            random_state<span style="color:#f92672">=</span><span style="color:#ae81ff">42</span>,
</span></span><span style="display:flex;"><span>            n_jobs<span style="color:#f92672">=-</span><span style="color:#ae81ff">1</span>,
</span></span><span style="display:flex;"><span>        ))
</span></span><span style="display:flex;"><span>    ])
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># 5-fold stratified cross-validation</span>
</span></span><span style="display:flex;"><span>    cv <span style="color:#f92672">=</span> StratifiedKFold(n_splits<span style="color:#f92672">=</span><span style="color:#ae81ff">5</span>)
</span></span><span style="display:flex;"><span>    acc_scores <span style="color:#f92672">=</span> cross_val_score(lr, X, y, cv<span style="color:#f92672">=</span>cv, scoring<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;accuracy&#39;</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Convert to misclassification rates</span>
</span></span><span style="display:flex;"><span>    avg_error <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>mean(<span style="color:#ae81ff">1</span> <span style="color:#f92672">-</span> acc_scores)
</span></span><span style="display:flex;"><span>    std_error <span style="color:#f92672">=</span> np<span style="color:#f92672">.</span>std(<span style="color:#ae81ff">1</span> <span style="color:#f92672">-</span> acc_scores)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    print(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#39;C</span><span style="color:#e6db74">{</span>n<span style="color:#e6db74">}</span><span style="color:#e6db74">H</span><span style="color:#e6db74">{</span><span style="color:#ae81ff">2</span><span style="color:#f92672">*</span>n <span style="color:#f92672">+</span> <span style="color:#ae81ff">2</span><span style="color:#e6db74">}</span><span style="color:#e6db74">: </span><span style="color:#e6db74">{</span>avg_error<span style="color:#e6db74">:</span><span style="color:#e6db74">.2%</span><span style="color:#e6db74">}</span><span style="color:#e6db74"> ± </span><span style="color:#e6db74">{</span>std_error<span style="color:#e6db74">:</span><span style="color:#e6db74">.2%</span><span style="color:#e6db74">}</span><span style="color:#e6db74">&#39;</span>)
</span></span></code></pre></div><p>Comparing k-NN versus logistic regression performance:</p>















<figure class="post-figure center ">
    <img src="/img/alkane-constitutional-isomers/alkane-classification-1nn-lr.webp"
         alt="Comparison of 1-NN and Logistic Regression performance"
         title="Comparison of 1-NN and Logistic Regression performance"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">k-NN significantly outperforms logistic regression, especially for larger molecules. The performance gap widens as molecular complexity increases.</figcaption>
    
</figure>

<p><strong>Key observations:</strong></p>
<ul>
<li><strong>k-NN dominates</strong> across all molecular sizes</li>
<li><strong>Linear boundaries fail</strong> for larger molecules. This suggests nonlinear eigenvalue relationships.</li>
<li><strong>Performance gap grows</strong> with molecular complexity, indicating increasingly nonlinear structural patterns</li>
</ul>
<p>Logistic regression&rsquo;s poor performance suggests that the discriminative patterns in eigenvalue space are nonlinear. Capturing them requires memory-based or otherwise nonlinear approaches.</p>
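<p>The gap between a nearest-neighbor classifier and a linear one can be reproduced on any nonlinearly separable dataset. The following is a minimal sketch on synthetic concentric circles from scikit-learn, used here as an assumed stand-in and not the alkane eigenvalue features:</p>

```python
from sklearn.datasets import make_circles
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier

# Synthetic stand-in for a nonlinearly separable feature space (NOT the alkane data).
X, y = make_circles(n_samples=400, noise=0.05, factor=0.5, random_state=0)

models = {
    "1-NN": KNeighborsClassifier(n_neighbors=1),
    "Logistic regression": LogisticRegression(max_iter=1000),
}

scores = {}
for name, model in models.items():
    # Mean 5-fold cross-validated accuracy for each classifier
    scores[name] = cross_val_score(model, X, y, cv=5).mean()
    print(f"{name}: {scores[name]:.2%} accuracy")
```

<p>Because no straight line separates an inner ring from an outer ring, the linear model stays near chance while 1-NN relies only on local neighborhoods.</p>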
<h2 id="implications-for-molecular-representation">Implications for Molecular Representation</h2>
<p>Our supervised learning experiments reveal a nuanced picture of Coulomb matrix eigenvalues as molecular descriptors. Eigenvalues preserve sufficient local structure for nearest-neighbor classification to work remarkably well, despite lacking clean global clusters.</p>
<p>This analysis reveals important lessons about molecular representations:</p>
<ol>
<li><strong>Empirical performance and mathematical elegance are separate axes</strong>: an elegant descriptor can still fail in practice as the space gets more complex.</li>
<li><strong>Context matters</strong>: Representations exhibit distinct performance characteristics under supervised versus unsupervised conditions.</li>
<li><strong>Molecular complexity is challenging</strong>: Even simple alkanes test our best descriptors.</li>
<li><strong>Local vs. global structure</strong>: Local neighborhood structures often contain highly discriminative information.</li>
</ol>
<p>For practitioners working with molecular representations, it is worth testing multiple learning paradigms, because supervised and unsupervised approaches often yield different insights.</p>
<p><strong>Why this matters beyond the alkane case:</strong> molecular representations are the input layer for property prediction, and understanding where a simple descriptor like Coulomb-matrix eigenvalues fails (overlapping clusters for larger molecules, nonlinear class structure that defeats logistic regression) is what motivates moving to graph- or coordinate-aware models. The failure modes here are the argument for richer representations.</p>
<p>The data pipeline that generated the datasets used in this analysis is available at the <a href="/projects/isomer-dataset-generation/">Synthetic Isomer Data Generation Pipeline project page</a>.</p>
]]></content:encoded></item><item><title>Classifying Congressional Bills with Machine Learning</title><link>https://hunterheidenreich.com/posts/congressional-bill-policy-area-classification/</link><pubDate>Wed, 21 Feb 2024 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/congressional-bill-policy-area-classification/</guid><description>Testing ML classification of congressional bills by policy area. Comparing Naive Bayes, Logistic Regression, and XGBoost on legislative text.</description><content:encoded><![CDATA[<h2 id="introduction">Introduction</h2>
<p>This post explores machine learning approaches for classifying congressional bills by policy area, using data from the 115th to 117th Congresses (2017-2023). We&rsquo;ll examine:</p>
<ul>
<li>The fundamentals of bill classification</li>
<li>Traditional machine learning models as baselines</li>
<li>Performance analysis across different time periods and policy domains</li>
</ul>
<p>This work establishes baselines for future deep learning approaches to legislative text classification.</p>
<p><em>This post builds on the data foundation established in <a href="/posts/us-117th-congress-data-exploration/">Exploring the 117th U.S. Congress</a>.</em></p>
<h3 id="motivation">Motivation</h3>
<p>Automatically classifying congressional bills by policy area has practical value for researchers, journalists, and citizens who need to navigate thousands of bills each Congress. Machine learning can help identify patterns in legislative priorities and track policy trends over time.</p>
<h2 id="data">Data</h2>
<p>The data comes from scraping <a href="https://www.congress.gov/">Congress.gov</a> for all bills from the 115th through 117th Congresses. Each bill includes:</p>
<ul>
<li>Bill ID and title</li>
<li>Summary (when available): the earliest summary provided</li>
<li>Full text (when available): the earliest text version</li>
<li>Policy area classification</li>
</ul>
<p>Our task is to predict policy area from text features:</p>
<p>$$
f(X) = \hat{y}, \quad \text{where} \quad X = \{ \text{title}, \text{summary}, \text{text} \}, \quad \hat{y} \in \{ \text{policy areas} \}
$$</p>
<p>The complete dataset is available at <a href="https://huggingface.co/datasets/hheiden/us-congress-bill-policy-115_117">Hugging Face: hheiden/us-congress-bill-policy-115_117</a>.</p>
<h3 id="bills-by-congress">Bills by Congress</h3>
<p>Our dataset contains the following distribution:</p>
<table>
  <thead>
      <tr>
          <th>Congress</th>
          <th>Bills</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>115th</td>
          <td>13,556</td>
      </tr>
      <tr>
          <td>116th</td>
          <td>16,601</td>
      </tr>
      <tr>
          <td>117th</td>
          <td>17,817</td>
      </tr>
      <tr>
          <td><strong>Total</strong></td>
          <td><strong>47,974</strong></td>
      </tr>
  </tbody>
</table>
<h3 id="policy-areas">Policy Areas</h3>
<p>Each bill receives a policy area label from <a href="https://www.congress.gov/">Congress.gov</a> (see <a href="https://www.congress.gov/help/field-values/policy-area">glossary</a>). The dataset includes 33 policy areas, though these classes are highly imbalanced.</p>
<p>The following table shows the number of bills in each policy area across the three Congresses:</p>
<table>
  <thead>
      <tr>
          <th>Policy Area</th>
          <th>115th</th>
          <th>116th</th>
          <th>117th</th>
          <th>Total</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Agriculture and Food</td>
          <td>312</td>
          <td>328</td>
          <td>398</td>
          <td>1,038</td>
      </tr>
      <tr>
          <td>Animals</td>
          <td>96</td>
          <td>83</td>
          <td>71</td>
          <td>250</td>
      </tr>
      <tr>
          <td>Armed Forces and National Security</td>
          <td>1,108</td>
          <td>1,337</td>
          <td>1,399</td>
          <td>3,844</td>
      </tr>
      <tr>
          <td>Arts, Culture, Religion</td>
          <td>81</td>
          <td>79</td>
          <td>103</td>
          <td>263</td>
      </tr>
      <tr>
          <td>Civil Rights and Liberties, Minority Issues</td>
          <td>175</td>
          <td>205</td>
          <td>220</td>
          <td>600</td>
      </tr>
      <tr>
          <td>Commerce</td>
          <td>312</td>
          <td>593</td>
          <td>633</td>
          <td>1,538</td>
      </tr>
      <tr>
          <td>Congress</td>
          <td>594</td>
          <td>541</td>
          <td>640</td>
          <td>1,775</td>
      </tr>
      <tr>
          <td>Crime and Law Enforcement</td>
          <td>827</td>
          <td>904</td>
          <td>1,022</td>
          <td>2,753</td>
      </tr>
      <tr>
          <td>Economics and Public Finance</td>
          <td>176</td>
          <td>210</td>
          <td>197</td>
          <td>583</td>
      </tr>
      <tr>
          <td>Education</td>
          <td>607</td>
          <td>798</td>
          <td>801</td>
          <td>2,206</td>
      </tr>
      <tr>
          <td>Emergency Management</td>
          <td>207</td>
          <td>198</td>
          <td>202</td>
          <td>607</td>
      </tr>
      <tr>
          <td>Energy</td>
          <td>316</td>
          <td>370</td>
          <td>530</td>
          <td>1,216</td>
      </tr>
      <tr>
          <td>Environmental Protection</td>
          <td>352</td>
          <td>423</td>
          <td>464</td>
          <td>1,239</td>
      </tr>
      <tr>
          <td>Families</td>
          <td>79</td>
          <td>127</td>
          <td>139</td>
          <td>345</td>
      </tr>
      <tr>
          <td>Finance and Financial Sector</td>
          <td>556</td>
          <td>611</td>
          <td>601</td>
          <td>1,768</td>
      </tr>
      <tr>
          <td>Foreign Trade and International Finance</td>
          <td>120</td>
          <td>148</td>
          <td>212</td>
          <td>480</td>
      </tr>
      <tr>
          <td>Government Operations and Politics</td>
          <td>1,008</td>
          <td>1,258</td>
          <td>1,272</td>
          <td>3,538</td>
      </tr>
      <tr>
          <td>Health</td>
          <td>1,526</td>
          <td>2,109</td>
          <td>2,276</td>
          <td>5,911</td>
      </tr>
      <tr>
          <td>Housing and Community Development</td>
          <td>142</td>
          <td>250</td>
          <td>231</td>
          <td>623</td>
      </tr>
      <tr>
          <td>Immigration</td>
          <td>398</td>
          <td>466</td>
          <td>591</td>
          <td>1,455</td>
      </tr>
      <tr>
          <td>International Affairs</td>
          <td>918</td>
          <td>1,178</td>
          <td>1,390</td>
          <td>3,486</td>
      </tr>
      <tr>
          <td>Labor and Employment</td>
          <td>348</td>
          <td>452</td>
          <td>552</td>
          <td>1,352</td>
      </tr>
      <tr>
          <td>Law</td>
          <td>109</td>
          <td>162</td>
          <td>175</td>
          <td>446</td>
      </tr>
      <tr>
          <td>Native Americans</td>
          <td>175</td>
          <td>234</td>
          <td>245</td>
          <td>654</td>
      </tr>
      <tr>
          <td>Public Lands and Natural Resources</td>
          <td>718</td>
          <td>648</td>
          <td>642</td>
          <td>2,008</td>
      </tr>
      <tr>
          <td>Science, Technology, Communications</td>
          <td>389</td>
          <td>551</td>
          <td>505</td>
          <td>1,445</td>
      </tr>
      <tr>
          <td>Social Sciences and History</td>
          <td>5</td>
          <td>6</td>
          <td>4</td>
          <td>15</td>
      </tr>
      <tr>
          <td>Social Welfare</td>
          <td>177</td>
          <td>229</td>
          <td>199</td>
          <td>605</td>
      </tr>
      <tr>
          <td>Sports and Recreation</td>
          <td>92</td>
          <td>93</td>
          <td>125</td>
          <td>310</td>
      </tr>
      <tr>
          <td>Taxation</td>
          <td>983</td>
          <td>1,156</td>
          <td>1,078</td>
          <td>3,217</td>
      </tr>
      <tr>
          <td>Transportation and Public Works</td>
          <td>492</td>
          <td>672</td>
          <td>742</td>
          <td>1,906</td>
      </tr>
      <tr>
          <td>Water Resources Development</td>
          <td>89</td>
          <td>111</td>
          <td>110</td>
          <td>310</td>
      </tr>
      <tr>
          <td>Private Legislation</td>
          <td>69</td>
          <td>71</td>
          <td>48</td>
          <td>188</td>
      </tr>
  </tbody>
</table>
<p>The class imbalance is severe: <code>Social Sciences and History</code> has only 15 bills across all three Congresses, while <code>Health</code> has 5,911 bills. This imbalance presents modeling challenges, as minority classes may lack sufficient representative samples.</p>
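<p>To put a number on the imbalance, the sketch below uses the counts from the table above to compute the ratio between the largest and smallest class, along with the &ldquo;balanced&rdquo; class-weight heuristic <code>n_samples / (n_classes * class_count)</code>. The variable names are illustrative and not taken from the original analysis:</p>

```python
# Counts taken from the policy-area table above.
total_bills = 47_974
n_classes = 33
counts = {"Health": 5_911, "Social Sciences and History": 15}

# Ratio between the largest and smallest class
imbalance_ratio = counts["Health"] / counts["Social Sciences and History"]
print(f"Imbalance ratio: {imbalance_ratio:.0f}:1")

# "Balanced" weighting heuristic: n_samples / (n_classes * class_count)
balanced_weights = {
    area: total_bills / (n_classes * count) for area, count in counts.items()
}
for area, weight in balanced_weights.items():
    print(f"{area}: weight {weight:.2f}")
```

<p>The largest class is roughly 394 times the size of the smallest, so a balanced weighting scheme would scale up each <code>Social Sciences and History</code> bill by about the same factor relative to a <code>Health</code> bill.</p>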
<h3 id="text-statistics">Text Statistics</h3>
<p>We analyzed token counts using spaCy to understand the computational requirements for each text field.</p>
<p>Title Token Statistics:</p>
<table>
  <thead>
      <tr>
          <th>Congress</th>
          <th>Average Tokens</th>
          <th>Min Tokens</th>
          <th>Max Tokens</th>
          <th>Total Tokens</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>115th</td>
          <td>12.3</td>
          <td>1</td>
          <td>167</td>
          <td>166,763</td>
      </tr>
      <tr>
          <td>116th</td>
          <td>11.3</td>
          <td>1</td>
          <td>226</td>
          <td>188,158</td>
      </tr>
      <tr>
          <td>117th</td>
          <td>11.5</td>
          <td>1</td>
          <td>272</td>
          <td>204,978</td>
      </tr>
      <tr>
          <td>All</td>
          <td>11.7</td>
          <td>1</td>
          <td>272</td>
          <td>559,419</td>
      </tr>
  </tbody>
</table>
<p>Summary Token Statistics:</p>
<table>
  <thead>
      <tr>
          <th>Congress</th>
          <th>Average Tokens</th>
          <th>Min Tokens</th>
          <th>Max Tokens</th>
          <th>Total Tokens</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>115th</td>
          <td>109.1</td>
          <td>2</td>
          <td>6,839</td>
          <td>1,479,212</td>
      </tr>
      <tr>
          <td>116th</td>
          <td>94.9</td>
          <td>2</td>
          <td>5,886</td>
          <td>1,574,732</td>
      </tr>
      <tr>
          <td>117th</td>
          <td>95.1</td>
          <td>2</td>
          <td>502</td>
          <td>1,695,276</td>
      </tr>
      <tr>
          <td>All</td>
          <td>99.0</td>
          <td>2</td>
          <td>6,839</td>
          <td>4,749,220</td>
      </tr>
  </tbody>
</table>
<p>Full Text Token Statistics:</p>
<table>
  <thead>
      <tr>
          <th>Congress</th>
          <th>Average Tokens</th>
          <th>Min Tokens</th>
          <th>Max Tokens</th>
          <th>Total Tokens</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>115th</td>
          <td>2,588.7</td>
          <td>91</td>
          <td>304,478</td>
          <td>35,092,075</td>
      </tr>
      <tr>
          <td>116th</td>
          <td>2,760.3</td>
          <td>70</td>
          <td>973,173</td>
          <td>45,824,498</td>
      </tr>
      <tr>
          <td>117th</td>
          <td>2,706.7</td>
          <td>71</td>
          <td>1,013,608</td>
          <td>48,224,757</td>
      </tr>
      <tr>
          <td>All</td>
          <td>2,691.9</td>
          <td>70</td>
          <td>1,013,608</td>
          <td>129,141,330</td>
      </tr>
  </tbody>
</table>
<p>These statistics reveal computational trade-offs:</p>
<ul>
<li><strong>Titles</strong> average ~12 tokens: computationally efficient but limited information.</li>
<li><strong>Summaries</strong> average ~100 tokens: good balance of information and efficiency.</li>
<li><strong>Full text</strong> averages ~2,700 tokens with 129M total tokens: detailed but computationally expensive. Processing this volume of text introduces practical engineering challenges, such as memory constraints and the lower signal-to-noise ratio typical of long legal documents.</li>
</ul>
<p>We&rsquo;ll prototype with titles and summaries before considering full text, given the computational costs involved.</p>
<h2 id="evaluation-framework">Evaluation Framework</h2>
<h3 id="experimental-design">Experimental Design</h3>
<p>We train models on one Congress and evaluate on all three, creating a 3x3 evaluation grid. The diagonal cells measure within-Congress performance (train and test on the same Congress), and the off-diagonal cells measure cross-Congress generalization. We expect temporal drift between Congresses to reduce cross-Congress scores.</p>
<h3 id="metrics-and-hyperparameter-tuning">Metrics and Hyperparameter Tuning</h3>
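<p>We use the weighted-average F1 score: per-class F1 scores are averaged with weights proportional to each class&rsquo;s support (its number of true instances). This keeps the metric aligned with the actual label distribution under class imbalance, though frequent policy areas contribute more to the score than rare ones. Macro-averaged F1 would instead weight every policy area equally.</p>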
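<p>To see how the averaging choice behaves under imbalance, here is a small sketch with made-up labels and a degenerate classifier that always predicts the frequent class. The label values are illustrative only:</p>

```python
from sklearn.metrics import f1_score

# Toy labels: 8 bills from a frequent class, 2 from a rare class.
y_true = ["Health"] * 8 + ["Animals"] * 2
# A degenerate classifier that always predicts the frequent class.
y_pred = ["Health"] * 10

# Support-weighted average of per-class F1 (the metric used in this post)
weighted = f1_score(y_true, y_pred, average="weighted", zero_division=0)
# Unweighted mean of per-class F1, for comparison
macro = f1_score(y_true, y_pred, average="macro", zero_division=0)
print(f"weighted F1 = {weighted:.3f}, macro F1 = {macro:.3f}")
```

<p>The weighted score stays fairly high because the frequent class dominates the average, while the macro score is pulled down by the rare class that is never predicted. Reporting both can be useful when rare policy areas matter.</p>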
<p>For within-Congress evaluation, we report cross-validated scores. For cross-Congress evaluation, we test on the entire target Congress dataset.</p>
<p>Hyperparameter tuning uses grid search with cross-validation, with the number of folds set to <code>min(3, n_samples)</code> to ensure all classes are represented. The model refit with the best parameters on the training Congress is then evaluated, without further tuning, on the other Congresses.</p>
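<p>The evaluation protocol above can be sketched as a pair of nested loops. This is a minimal illustration under stated assumptions: the toy texts, the <code>evaluate_grid</code> helper, and the small parameter grid are hypothetical and are not the code used to produce the results in this post:</p>

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import f1_score
from sklearn.model_selection import GridSearchCV
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline

# Hypothetical toy data keyed by Congress: (texts, labels). Not the real dataset.
labels = ["Taxation"] * 3 + ["Health"] * 3
data = {
    115: (["tax credit bill", "income tax relief", "tax code reform",
           "health care access", "hospital care funding", "health plan coverage"], labels),
    116: (["tax relief act", "payroll tax credit", "income tax cut",
           "health care act", "rural hospital care", "health insurance plan"], labels),
    117: (["child tax credit", "tax code update", "income tax relief",
           "health care workforce", "hospital funding plan", "mental health care"], labels),
}


def evaluate_grid(data, param_grid, cv=3):
    """Return {(train_congress, test_congress): weighted F1} for every pair."""
    results = {}
    for train_id, (X_train, y_train) in data.items():
        pipeline = Pipeline([("tfidf", TfidfVectorizer()), ("clf", MultinomialNB())])
        search = GridSearchCV(pipeline, param_grid, scoring="f1_weighted", cv=cv, refit=True)
        search.fit(X_train, y_train)
        # Diagonal cell: cross-validated score on the training Congress
        results[(train_id, train_id)] = search.best_score_
        for test_id, (X_test, y_test) in data.items():
            if test_id == train_id:
                continue
            # Off-diagonal cell: refit model scored on the entire other Congress
            y_pred = search.predict(X_test)
            results[(train_id, test_id)] = f1_score(y_test, y_pred, average="weighted")
    return results


results = evaluate_grid(data, param_grid={"clf__alpha": [1.0, 0.1]})
for (train_id, test_id), score in sorted(results.items()):
    print(f"train {train_id}, test {test_id}: F1 = {score:.3f}")
```

<p>Note that the diagonal and off-diagonal cells are not strictly comparable: the diagonal averages over held-out folds, while the off-diagonal cells use a model trained on the full training Congress.</p>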
<h2 id="baseline-models">Baseline Models</h2>
<p>We evaluate three traditional machine learning approaches, each trained on TF-IDF features.</p>
<h3 id="text-preprocessing">Text Preprocessing</h3>
<p>We convert text to numerical features using TF-IDF (term frequency-inverse document frequency), which weighs word importance by frequency within documents relative to the entire corpus. This creates normalized feature vectors suitable for machine learning classification.</p>
<h3 id="multinomial-naive-bayes">Multinomial Naive Bayes</h3>
<p>We start with Multinomial Naive Bayes as our simplest baseline. Despite its &ldquo;naive&rdquo; independence assumption between features, this model often performs surprisingly well for text classification tasks and serves as an important benchmark. If more complex models can&rsquo;t beat Naive Bayes, it signals potential issues with the approach or data.</p>
<p>The model&rsquo;s <code>feature_log_prob_</code> attribute reveals the most influential words for each policy area, providing interpretable insights into classification patterns.</p>
<p>You can see the code for training the Naive Bayes model below:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> numpy <span style="color:#f92672">as</span> np
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> sklearn.feature_extraction.text <span style="color:#f92672">import</span> TfidfVectorizer
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> sklearn.model_selection <span style="color:#f92672">import</span> GridSearchCV
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> sklearn.pipeline <span style="color:#f92672">import</span> Pipeline
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> sklearn.naive_bayes <span style="color:#f92672">import</span> MultinomialNB
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Create a pipeline with TF-IDF vectorizer and Multinomial Naive Bayes classifier</span>
</span></span><span style="display:flex;"><span>pipeline <span style="color:#f92672">=</span> Pipeline([
</span></span><span style="display:flex;"><span>    (<span style="color:#e6db74">&#39;tfidf&#39;</span>, TfidfVectorizer(lowercase<span style="color:#f92672">=</span><span style="color:#66d9ef">True</span>, dtype<span style="color:#f92672">=</span>np<span style="color:#f92672">.</span>float32)),
</span></span><span style="display:flex;"><span>    (<span style="color:#e6db74">&#39;clf&#39;</span>, MultinomialNB()),
</span></span><span style="display:flex;"><span>])
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Define the parameters for grid search</span>
</span></span><span style="display:flex;"><span>parameters <span style="color:#f92672">=</span> {  
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#39;tfidf__ngram_range&#39;</span>: [(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">1</span>), (<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">2</span>), (<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">3</span>)],
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#39;tfidf__max_df&#39;</span>: (<span style="color:#ae81ff">0.05</span>, <span style="color:#ae81ff">0.1</span>, <span style="color:#ae81ff">0.25</span>, <span style="color:#ae81ff">0.5</span>),
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#39;tfidf__min_df&#39;</span>: (<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">2</span>, <span style="color:#ae81ff">5</span>, <span style="color:#ae81ff">10</span>),
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#39;clf__alpha&#39;</span>: (<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">0.1</span>, <span style="color:#ae81ff">0.01</span>, <span style="color:#ae81ff">0.001</span>),
</span></span><span style="display:flex;"><span>}
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Perform grid search with cross-validation</span>
</span></span><span style="display:flex;"><span>grid_search <span style="color:#f92672">=</span> GridSearchCV(
</span></span><span style="display:flex;"><span>    pipeline,
</span></span><span style="display:flex;"><span>    parameters,
</span></span><span style="display:flex;"><span>    scoring<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;f1_weighted&#39;</span>,
</span></span><span style="display:flex;"><span>    n_jobs<span style="color:#f92672">=-</span><span style="color:#ae81ff">1</span>,
</span></span><span style="display:flex;"><span>    refit<span style="color:#f92672">=</span><span style="color:#66d9ef">True</span>,
</span></span><span style="display:flex;"><span>    cv<span style="color:#f92672">=</span><span style="color:#ae81ff">3</span>,
</span></span><span style="display:flex;"><span>)
</span></span><span style="display:flex;"><span>grid_search<span style="color:#f92672">.</span>fit(X_train, y_train)
</span></span></code></pre></div><h3 id="logistic-regression">Logistic Regression</h3>
<p>Logistic regression provides a natural step up in complexity from Naive Bayes. It uses the logistic function to convert linear combinations of features into probabilities, making it an excellent baseline for comparison with more sophisticated models while remaining interpretable.</p>
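<p>For reference, with $K$ policy areas the multinomial form of logistic regression scores each class with a linear function of the TF-IDF vector $x$ and normalizes the scores with a softmax. The symbols $w_k$ and $b_k$ (per-class weights and bias) are notation introduced here for illustration:</p>
<p>$$
P(y = k \mid x) = \frac{\exp(w_k^\top x + b_k)}{\sum_{j=1}^{K} \exp(w_j^\top x + b_j)}
$$</p>
<p>With $K = 2$ this reduces to the familiar logistic (sigmoid) function. The learned weights $w_k$ can be inspected per policy area, which is what keeps the model interpretable.</p>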
<p>You can see the code for training the Logistic Regression model below:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> numpy <span style="color:#f92672">as</span> np
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> sklearn.feature_extraction.text <span style="color:#f92672">import</span> TfidfVectorizer
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> sklearn.model_selection <span style="color:#f92672">import</span> GridSearchCV
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> sklearn.pipeline <span style="color:#f92672">import</span> Pipeline
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> sklearn.linear_model <span style="color:#f92672">import</span> LogisticRegression
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Create a pipeline with TF-IDF vectorizer and Logistic Regression classifier</span>
</span></span><span style="display:flex;"><span>pipeline <span style="color:#f92672">=</span> Pipeline([
</span></span><span style="display:flex;"><span>    (<span style="color:#e6db74">&#39;tfidf&#39;</span>, TfidfVectorizer(lowercase<span style="color:#f92672">=</span><span style="color:#66d9ef">True</span>, dtype<span style="color:#f92672">=</span>np<span style="color:#f92672">.</span>float32)),
</span></span><span style="display:flex;"><span>    (<span style="color:#e6db74">&#39;clf&#39;</span>, LogisticRegression(max_iter<span style="color:#f92672">=</span><span style="color:#ae81ff">1000</span>, random_state<span style="color:#f92672">=</span><span style="color:#ae81ff">42</span>, class_weight<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;balanced&#39;</span>)),
</span></span><span style="display:flex;"><span>])
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Define the parameters for grid search</span>
</span></span><span style="display:flex;"><span>parameters <span style="color:#f92672">=</span> {  
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#39;tfidf__ngram_range&#39;</span>: [(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">1</span>), (<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">2</span>)],
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#39;tfidf__max_df&#39;</span>: (<span style="color:#ae81ff">0.05</span>, <span style="color:#ae81ff">0.1</span>, <span style="color:#ae81ff">0.25</span>),
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#39;clf__C&#39;</span>: [<span style="color:#ae81ff">0.1</span>, <span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">10</span>],
</span></span><span style="display:flex;"><span>}
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Perform grid search with cross-validation</span>
</span></span><span style="display:flex;"><span>grid_search <span style="color:#f92672">=</span> GridSearchCV(
</span></span><span style="display:flex;"><span>    pipeline,
</span></span><span style="display:flex;"><span>    parameters,
</span></span><span style="display:flex;"><span>    scoring<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;f1_weighted&#39;</span>,
</span></span><span style="display:flex;"><span>    n_jobs<span style="color:#f92672">=-</span><span style="color:#ae81ff">1</span>,
</span></span><span style="display:flex;"><span>    refit<span style="color:#f92672">=</span><span style="color:#66d9ef">True</span>,
</span></span><span style="display:flex;"><span>    cv<span style="color:#f92672">=</span><span style="color:#ae81ff">3</span>,
</span></span><span style="display:flex;"><span>)
</span></span><span style="display:flex;"><span>grid_search<span style="color:#f92672">.</span>fit(X_train, y_train)
</span></span></code></pre></div><h3 id="xgboost">XGBoost</h3>
<p>We include XGBoost as our tree-based ensemble method. While XGBoost typically excels on structured tabular data, we test whether its gradient boosting approach can effectively handle TF-IDF features for text classification.</p>
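<p>Two preparation steps are worth noting before fitting XGBoost on this task. Its scikit-learn wrapper has no <code>class_weight</code> option, so class imbalance is typically handled with per-sample weights, and recent releases expect integer class ids instead of string labels. The post does not show how its <code>sample_weight</code> array was built, so the following is only one plausible sketch, with made-up labels:</p>

```python
import numpy as np
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.class_weight import compute_sample_weight

# Hypothetical labels; in practice these would be the policy areas of the training Congress.
y_train_raw = ["Health", "Health", "Health", "Taxation", "Animals", "Health"]

# Map string labels to integer class ids 0..K-1.
encoder = LabelEncoder()
y_train = encoder.fit_transform(y_train_raw)

# One weight per sample, inversely proportional to its class frequency.
sample_weight = compute_sample_weight(class_weight="balanced", y=y_train)

for label, weight in zip(y_train_raw, sample_weight):
    print(f"{label}: {weight:.2f}")
```

<p>Predictions can be mapped back to policy-area names with <code>encoder.inverse_transform</code>. Inside a <code>Pipeline</code>, the weights are routed to the classifier step by prefixing the keyword with the step name, as in <code>clf__sample_weight</code>.</p>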
<p>You can see the code for training the XGBoost model below:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> numpy <span style="color:#f92672">as</span> np
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> sklearn.feature_extraction.text <span style="color:#f92672">import</span> TfidfVectorizer
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> sklearn.model_selection <span style="color:#f92672">import</span> GridSearchCV
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> sklearn.pipeline <span style="color:#f92672">import</span> Pipeline
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> xgboost <span style="color:#f92672">import</span> XGBClassifier
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Create a pipeline with TF-IDF vectorizer and XGBoost classifier</span>
</span></span><span style="display:flex;"><span>pipeline <span style="color:#f92672">=</span> Pipeline([
</span></span><span style="display:flex;"><span>    (<span style="color:#e6db74">&#39;tfidf&#39;</span>, TfidfVectorizer(lowercase<span style="color:#f92672">=</span><span style="color:#66d9ef">True</span>, dtype<span style="color:#f92672">=</span>np<span style="color:#f92672">.</span>float32)),
</span></span><span style="display:flex;"><span>    (<span style="color:#e6db74">&#39;clf&#39;</span>, XGBClassifier(use_label_encoder<span style="color:#f92672">=</span><span style="color:#66d9ef">False</span>, eval_metric<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;mlogloss&#39;</span>, objective<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;multi:softmax&#39;</span>, seed<span style="color:#f92672">=</span><span style="color:#ae81ff">42</span>, n_jobs<span style="color:#f92672">=-</span><span style="color:#ae81ff">1</span>)),
</span></span><span style="display:flex;"><span>])
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Define the parameters for grid search</span>
</span></span><span style="display:flex;"><span>parameters <span style="color:#f92672">=</span> {  
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#39;tfidf__max_df&#39;</span>: (<span style="color:#ae81ff">0.05</span>, <span style="color:#ae81ff">0.1</span>, <span style="color:#ae81ff">0.25</span>),
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#39;clf__max_depth&#39;</span>: (<span style="color:#ae81ff">3</span>, <span style="color:#ae81ff">6</span>, <span style="color:#ae81ff">9</span>),
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#39;clf__n_estimators&#39;</span>: (<span style="color:#ae81ff">100</span>, <span style="color:#ae81ff">200</span>, <span style="color:#ae81ff">300</span>),
</span></span><span style="display:flex;"><span>}
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Perform grid search with cross-validation</span>
</span></span><span style="display:flex;"><span>grid_search <span style="color:#f92672">=</span> GridSearchCV(
</span></span><span style="display:flex;"><span>    pipeline,
</span></span><span style="display:flex;"><span>    parameters,
</span></span><span style="display:flex;"><span>    scoring<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;f1_weighted&#39;</span>,
</span></span><span style="display:flex;"><span>    refit<span style="color:#f92672">=</span><span style="color:#66d9ef">True</span>,
</span></span><span style="display:flex;"><span>    cv<span style="color:#f92672">=</span><span style="color:#ae81ff">3</span>,
</span></span><span style="display:flex;"><span>    verbose<span style="color:#f92672">=</span><span style="color:#ae81ff">3</span>,
</span></span><span style="display:flex;"><span>)
</span></span><span style="display:flex;"><span>grid_search<span style="color:#f92672">.</span>fit(X_train, y_train, clf__sample_weight<span style="color:#f92672">=</span>sample_weight)
</span></span></code></pre></div><h2 id="results">Results</h2>
<p>We evaluate models on three input types:</p>
<ul>
<li><strong>Title-only</strong>: Quick prototyping with limited context</li>
<li><strong>Summary-only</strong>: Balanced information content and computational efficiency</li>
<li><strong>Full text</strong>: Maximum context with computational constraints (limited hyperparameter tuning)</li>
</ul>
<h3 id="title-only-inputs">Title-Only Inputs</h3>
<h4 id="naive-bayes">Naive Bayes</h4>
<p>Title-only Naive Bayes experiments are run with the following settings:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span>sweep_nb(
</span></span><span style="display:flex;"><span>    data,
</span></span><span style="display:flex;"><span>    X_key<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;title&#39;</span>,
</span></span><span style="display:flex;"><span>    y_key<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;policy_area&#39;</span>,
</span></span><span style="display:flex;"><span>    tfidf_params<span style="color:#f92672">=</span>{
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#39;lowercase&#39;</span>: <span style="color:#66d9ef">True</span>,
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#39;dtype&#39;</span>: np<span style="color:#f92672">.</span>float32,
</span></span><span style="display:flex;"><span>    },
</span></span><span style="display:flex;"><span>    tfidf_grid<span style="color:#f92672">=</span>{
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#39;ngram_range&#39;</span>: [(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">1</span>), (<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">2</span>)],
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#39;max_df&#39;</span>: (<span style="color:#ae81ff">0.05</span>, <span style="color:#ae81ff">0.1</span>, <span style="color:#ae81ff">0.25</span>, <span style="color:#ae81ff">0.5</span>),
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#39;min_df&#39;</span>: (<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">2</span>, <span style="color:#ae81ff">5</span>),
</span></span><span style="display:flex;"><span>    },
</span></span><span style="display:flex;"><span>    nb_params<span style="color:#f92672">=</span>{},
</span></span><span style="display:flex;"><span>    nb_grid<span style="color:#f92672">=</span>{
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#39;alpha&#39;</span>: (<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">0.1</span>, <span style="color:#ae81ff">0.01</span>, <span style="color:#ae81ff">0.001</span>),
</span></span><span style="display:flex;"><span>    },
</span></span><span style="display:flex;"><span>)
</span></span></code></pre></div><p>and the results:</p>
<pre><code>Training on Congress 115
Best score: 0.661
Refit Time: 0.570
Best parameters set:
	clf__alpha: 0.01
	tfidf__max_df: 0.05
	tfidf__min_df: 1
	tfidf__ngram_range: (1, 2)
Testing on Congress 116 F1: 0.6369760774921475
Testing on Congress 117 F1: 0.5488274400521962

Training on Congress 116
Best score: 0.677
Refit Time: 0.499
Best parameters set:
	clf__alpha: 0.01
	tfidf__max_df: 0.05
	tfidf__min_df: 1
	tfidf__ngram_range: (1, 2)
Testing on Congress 115 F1: 0.691175262953872
Testing on Congress 117 F1: 0.6798043069585031

Training on Congress 117
Best score: 0.670
Refit Time: 0.565
Best parameters set:
	clf__alpha: 0.01
	tfidf__max_df: 0.25
	tfidf__min_df: 1
	tfidf__ngram_range: (1, 2)
Testing on Congress 115 F1: 0.6168474701996426
Testing on Congress 116 F1: 0.6981574942116808

Mean fit time: 0.54 ± 0.03s
</code></pre>
<h4 id="results-summary">Results Summary</h4>
<p>The results demonstrate several key findings:</p>
<ul>
<li><strong>Fast training</strong>: Sub-second training times make this highly practical</li>
<li><strong>Solid baseline performance</strong>: Best grid-search scores of 0.66-0.68 and cross-Congress F1 scores of roughly 0.55-0.70 provide a reasonable starting point</li>
<li><strong>Consistent hyperparameters</strong>: Similar optimal settings across Congresses suggest stable patterns</li>
<li><strong>Temporal effects</strong>: Performance generally decreases when training and testing on Congresses further apart in time</li>
</ul>
<p>Training on the 116th Congress yields the best cross-Congress performance, likely due to its temporal proximity to both adjacent sessions.</p>
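<p>As a quick check on the temporal pattern, the cross-Congress F1 scores from the log above can be grouped by training Congress and by the gap between training and test Congress. This is a minimal sketch using only the standard library; the dictionary layout and variable names are illustrative, and the values are copied (rounded to four decimals) from the output above.</p>

```python
from statistics import mean

# Cross-Congress F1 scores for title-only Naive Bayes, keyed by
# (train_congress, test_congress); values rounded from the log above.
nb_title_f1 = {
    (115, 116): 0.6370,
    (115, 117): 0.5488,
    (116, 115): 0.6912,
    (116, 117): 0.6798,
    (117, 115): 0.6168,
    (117, 116): 0.6982,
}

# Average F1 for each training Congress.
by_train = {
    train: mean(f1 for (tr, _), f1 in nb_title_f1.items() if tr == train)
    for train in (115, 116, 117)
}

# Average F1 grouped by how many Congresses separate train and test.
by_gap = {
    gap: mean(f1 for (tr, te), f1 in nb_title_f1.items() if abs(tr - te) == gap)
    for gap in (1, 2)
}

for train, score in by_train.items():
    print(f"train on {train}: mean cross-Congress F1 = {score:.3f}")
for gap, score in by_gap.items():
    print(f"gap of {gap} Congress(es): mean F1 = {score:.3f}")
```

<p>On these values, training on the 116th Congress gives the highest mean (about 0.69), and adjacent-Congress pairs average about 0.68 compared with about 0.58 for pairs two Congresses apart.</p>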















<figure class="post-figure center ">
    <img src="/img/nb_title_policy_area/f1s.webp"
         alt="Naive Bayes Policy Area Classification F1 Score"
         title="Naive Bayes Policy Area Classification F1 Score"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Naive Bayes F1 scores show temporal effects, with better performance between adjacent Congresses</figcaption>
    
</figure>

<p>The model learns interpretable features for each policy area. For example, Agriculture bills are strongly associated with terms like &ldquo;farm,&rdquo; &ldquo;crop,&rdquo; and &ldquo;livestock,&rdquo; while Armed Forces bills correlate with &ldquo;military,&rdquo; &ldquo;defense,&rdquo; and &ldquo;veterans.&rdquo;</p>















<figure class="post-figure center ">
    <img src="/img/nb_title_policy_area/top-Agriculture_and_Food.webp"
         alt="Naive Bayes Top Features for Agriculture and Food"
         title="Naive Bayes Top Features for Agriculture and Food"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Naive Bayes Top Features for Agriculture and Food</figcaption>
    
</figure>
















<figure class="post-figure center ">
    <img src="/img/nb_title_policy_area/top-Armed_Forces_and_National_Security.webp"
         alt="Naive Bayes Top Features for Armed Forces and National Security"
         title="Naive Bayes Top Features for Armed Forces and National Security"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Naive Bayes Top Features for Armed Forces and National Security</figcaption>
    
</figure>
















<figure class="post-figure center ">
    <img src="/img/nb_title_policy_area/top-Health.webp"
         alt="Naive Bayes Top Features for Health"
         title="Naive Bayes Top Features for Health"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Naive Bayes Top Features for Health</figcaption>
    
</figure>

<h4 id="logistic-regression-1">Logistic Regression</h4>
<p>Title-only Logistic Regression experiments are run with the following settings:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span>sweep_logreg(
</span></span><span style="display:flex;"><span>    data,
</span></span><span style="display:flex;"><span>    X_key<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;title&#39;</span>,
</span></span><span style="display:flex;"><span>    y_key<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;policy_area&#39;</span>,
</span></span><span style="display:flex;"><span>    tfidf_params<span style="color:#f92672">=</span>{
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#39;lowercase&#39;</span>: <span style="color:#66d9ef">True</span>,
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#39;dtype&#39;</span>: np<span style="color:#f92672">.</span>float32,
</span></span><span style="display:flex;"><span>    },
</span></span><span style="display:flex;"><span>    tfidf_grid<span style="color:#f92672">=</span>{
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#39;ngram_range&#39;</span>: [(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">1</span>), (<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">2</span>)],
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#39;max_df&#39;</span>: (<span style="color:#ae81ff">0.05</span>, <span style="color:#ae81ff">0.1</span>, <span style="color:#ae81ff">0.25</span>),
</span></span><span style="display:flex;"><span>    },
</span></span><span style="display:flex;"><span>    logreg_params<span style="color:#f92672">=</span>{
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#39;max_iter&#39;</span>: <span style="color:#ae81ff">1000</span>,
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#39;random_state&#39;</span>: <span style="color:#ae81ff">42</span>,
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#39;class_weight&#39;</span>: <span style="color:#e6db74">&#39;balanced&#39;</span>,
</span></span><span style="display:flex;"><span>    },
</span></span><span style="display:flex;"><span>    logreg_grid<span style="color:#f92672">=</span>{
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#39;C&#39;</span>: [<span style="color:#ae81ff">0.1</span>, <span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">10</span>],
</span></span><span style="display:flex;"><span>    },
</span></span><span style="display:flex;"><span>)
</span></span></code></pre></div><p>and the results:</p>
<pre><code>Training on Congress 115
Best score: 0.704
Refit Time: 32.063
Best parameters set:
	clf__C: 10
	tfidf__max_df: 0.05
	tfidf__ngram_range: (1, 2)
Testing on Congress 116 F1: 0.6809188275881766
Testing on Congress 117 F1: 0.601917336933838

Training on Congress 116
Best score: 0.714
Refit Time: 31.227
Best parameters set:
	clf__C: 10
	tfidf__max_df: 0.05
	tfidf__ngram_range: (1, 2)
Testing on Congress 115 F1: 0.7408989977276476
Testing on Congress 117 F1: 0.7200639105208106

Training on Congress 117
Best score: 0.711
Refit Time: 34.083
Best parameters set:
	clf__C: 10
	tfidf__max_df: 0.05
	tfidf__ngram_range: (1, 2)
Testing on Congress 115 F1: 0.674418393892329
Testing on Congress 116 F1: 0.7405934743144291

Mean fit time: 32.46 ± 1.20s
</code></pre>
<h4 id="results-summary-1">Results Summary</h4>
<p>Logistic regression improves upon Naive Bayes performance:</p>
<ul>
<li><strong>Higher F1 scores</strong>: Roughly 4-6 percentage points better than Naive Bayes on every cross-Congress pair</li>
<li><strong>Consistent hyperparameters</strong>: Optimal settings remain stable across Congresses</li>
<li><strong>Reasonable training time</strong>: 30-35 seconds per model remains manageable</li>
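<li><strong>Reasonable cross-Congress generalization</strong>: F1 scores of 0.72-0.74 for the three best train/test pairs, falling to 0.60 when training on the 115th Congress and testing on the 117th</li>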
</ul>
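<p>The size of the improvement can be read directly off the two logs. The sketch below computes the per-pair gain of logistic regression over Naive Bayes in percentage points; the variable names are illustrative and the F1 values are rounded from the outputs above.</p>

```python
# Cross-Congress F1 scores keyed by (train_congress, test_congress) for
# title-only inputs, rounded from the Naive Bayes and logistic regression logs.
nb_f1 = {
    (115, 116): 0.6370,
    (115, 117): 0.5488,
    (116, 115): 0.6912,
    (116, 117): 0.6798,
    (117, 115): 0.6168,
    (117, 116): 0.6982,
}
logreg_f1 = {
    (115, 116): 0.6809,
    (115, 117): 0.6019,
    (116, 115): 0.7409,
    (116, 117): 0.7201,
    (117, 115): 0.6744,
    (117, 116): 0.7406,
}

# Gain of logistic regression over Naive Bayes, in percentage points of F1.
gains = {pair: 100 * (logreg_f1[pair] - nb_f1[pair]) for pair in nb_f1}

for (train, test), gain in sorted(gains.items()):
    print(f"train {train}, test {test}: +{gain:.1f} points")

smallest, largest = min(gains.values()), max(gains.values())
print(f"smallest gain: {smallest:.1f}, largest gain: {largest:.1f}")
```

<p>On these rounded values, logistic regression is ahead on every pair, by between roughly 4.0 and 5.8 points.</p>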















<figure class="post-figure center ">
    <img src="/img/logreg_title_policy_area/f1s.webp"
         alt="Logistic Regression Policy Area Classification F1 Score"
         title="Logistic Regression Policy Area Classification F1 Score"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Logistic Regression Policy Area Classification F1 Score</figcaption>
    
</figure>

<h4 id="xgboost-1">XGBoost</h4>
<p>Title-only XGBoost experiments are run with the following settings:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span>sweep_xgb(
</span></span><span style="display:flex;"><span>    data,
</span></span><span style="display:flex;"><span>    X_key<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;title&#39;</span>,
</span></span><span style="display:flex;"><span>    y_key<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;policy_area&#39;</span>,
</span></span><span style="display:flex;"><span>    tfidf_grid<span style="color:#f92672">=</span>{
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#39;max_df&#39;</span>: (<span style="color:#ae81ff">0.05</span>,),
</span></span><span style="display:flex;"><span>    },
</span></span><span style="display:flex;"><span>    xgb_grid<span style="color:#f92672">=</span>{
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#39;max_depth&#39;</span>: (<span style="color:#ae81ff">6</span>,),
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#39;eta&#39;</span>: (<span style="color:#ae81ff">0.3</span>,),
</span></span><span style="display:flex;"><span>    },
</span></span><span style="display:flex;"><span>)
</span></span></code></pre></div><p>and the results:</p>
<pre><code>Training on Congress 115
Best score: 0.591
Refit Time: 198.063
Best parameters set:
	clf__eta: 0.3
	clf__max_depth: 6
	clf__num_class: 33
	tfidf__max_df: 0.05
Testing on Congress 116 F1: 0.5649530686141018
Testing on Congress 117 F1: 0.5215939580735101

Training on Congress 116
Best score: 0.600
Refit Time: 264.824
Best parameters set:
	clf__eta: 0.3
	clf__max_depth: 6
	clf__num_class: 33
	tfidf__max_df: 0.05
Testing on Congress 115 F1: 0.6037922738570368
Testing on Congress 117 F1: 0.5965027418245722

Training on Congress 117
Best score: 0.595
Refit Time: 249.799
Best parameters set:
	clf__eta: 0.3
	clf__max_depth: 6
	clf__num_class: 33
	tfidf__max_df: 0.05
Testing on Congress 115 F1: 0.5600491477899472
Testing on Congress 116 F1: 0.60815381664894

Mean fit time: 237.56 ± 28.60s
</code></pre>
<h4 id="results-summary-2">Results Summary</h4>
<p>XGBoost underperforms relative to expectations:</p>
<ul>
<li><strong>Poor performance</strong>: Cross-Congress F1 scores well below the linear models (0.52-0.61 range)</li>
<li><strong>Long training times</strong>: Roughly 3-4.5 minutes per model, even with a single hyperparameter configuration</li>
<li><strong>Questionable value</strong>: The computational cost doesn&rsquo;t justify the poor performance</li>
</ul>
<p>Given these results, we focus on the more promising linear models for subsequent experiments with longer text inputs.</p>















<figure class="post-figure center ">
    <img src="/img/xgb_title_policy_area/f1s.webp"
         alt="XGBoost Policy Area Classification F1 Score"
         title="XGBoost Policy Area Classification F1 Score"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">XGBoost Policy Area Classification F1 Score</figcaption>
    
</figure>

<h4 id="training-efficiency">Training Efficiency</h4>
<p>The computational costs vary dramatically:</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Training Time</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Naive Bayes</td>
          <td>0.54 $\pm$ 0.03s</td>
      </tr>
      <tr>
          <td>Logistic Regression</td>
          <td>32.46 $\pm$ 1.20s</td>
      </tr>
      <tr>
          <td>XGBoost</td>
          <td>237.56 $\pm$ 28.60s</td>
      </tr>
  </tbody>
</table>
<p>XGBoost&rsquo;s poor performance despite high computational cost suggests that tree-based methods may not be well suited to sparse TF-IDF features, although only a single XGBoost configuration (<code>max_depth=6</code>, <code>eta=0.3</code>) was tried here. One plausible explanation is the &ldquo;curse of dimensionality&rdquo;: each tree split tests a single feature, and in a highly sparse, high-dimensional bag-of-words space most features are zero for most documents, so individual splits separate few examples. Linear models instead assign a weight to every feature simultaneously. We&rsquo;ll focus on linear models for the remaining experiments.</p>
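<p>To make the sparsity argument concrete, here is a small standard-library sketch that builds a unigram-plus-bigram count matrix for a handful of made-up titles and measures how much of it is nonzero. The titles and the <code>ngrams</code> helper are hypothetical and only illustrate the shape of the problem; the actual experiments use TF-IDF weighting over the real dataset.</p>

```python
# Toy bill titles (illustrative only, not drawn from the dataset).
titles = [
    "a bill to support farm and crop insurance programs",
    "a bill to authorize military construction projects",
    "a bill to expand veterans health care services",
    "a bill to improve livestock disease research",
]


def ngrams(text):
    """Return the unigrams and bigrams of a title, mirroring ngram_range=(1, 2)."""
    tokens = text.lower().split()
    bigrams = [" ".join(pair) for pair in zip(tokens, tokens[1:])]
    return tokens + bigrams


# Vocabulary: every distinct unigram or bigram seen in any title.
vocab = sorted({term for title in titles for term in ngrams(title)})

# Dense document-term count matrix: one row per title, one column per term.
matrix = [[ngrams(title).count(term) for term in vocab] for title in titles]

# Fraction of entries that are nonzero.
nonzero = sum(1 for row in matrix for value in row if value != 0)
total = len(titles) * len(vocab)
density = nonzero / total

# Features that occur in exactly one title: a split on such a feature
# can only separate a single document from the rest.
rare = sum(1 for column in zip(*matrix) if sum(1 for v in column if v) == 1)

print(f"{len(vocab)} features, {nonzero} of {total} entries nonzero ({density:.0%})")
print(f"{rare} of {len(vocab)} features occur in exactly one title")
```

<p>Even with four short titles, most features appear in a single document. With thousands of bills the vocabulary grows much faster than the number of nonzero entries per row, which is the regime where single-feature splits tend to be least informative.</p>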
<h3 id="summary-only-results">Summary-Only Results</h3>
<p>Using bill summaries provides substantially more context than titles alone, leading to significant performance improvements.</p>
<h4 id="naive-bayes-performance">Naive Bayes Performance</h4>
<p>The summary-based models show dramatic improvement over title-only versions:</p>
<ul>
<li><strong>F1 scores</strong>: 0.85+ within-Congress, 0.77-0.86 cross-Congress</li>
<li><strong>Training time</strong>: Still fast at ~3.4 seconds</li>
<li><strong>Strong generalization</strong>: Consistent performance across time periods</li>
</ul>















<figure class="post-figure center ">
    <img src="/img/nb_summary_policy_area/f1s.webp"
         alt="Naive Bayes Summary Performance"
         title="Naive Bayes Summary Performance"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Summary-based models achieve F1 scores above 0.80 across most Congress combinations</figcaption>
    
</figure>

<h4 id="logistic-regression-performance">Logistic Regression Performance</h4>
<p>Logistic regression slightly outperforms Naive Bayes on summaries:</p>
<ul>
<li><strong>F1 scores</strong>: 0.86+ within-Congress, 0.79-0.87 cross-Congress</li>
<li><strong>Training time</strong>: Reasonable at ~12 seconds</li>
<li><strong>Stable hyperparameters</strong>: Consistent optimal settings across Congresses</li>
</ul>















<figure class="post-figure center ">
    <img src="/img/logreg_summary_policy_area/f1s.webp"
         alt="Logistic Regression Summary Performance"
         title="Logistic Regression Summary Performance"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Logistic regression maintains slight performance advantage over Naive Bayes</figcaption>
    
</figure>

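<p>The small performance gap between the two models suggests they rely on similar feature patterns. Logistic regression&rsquo;s slight edge likely comes from learning feature weights discriminatively, whereas Naive Bayes assumes features are conditionally independent given the class.</p>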
<h4 id="logistic-regression-2">Logistic Regression</h4>
<p>Summary-only Logistic Regression experiments are run with the following settings:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span>sweep_logreg(
</span></span><span style="display:flex;"><span>    data,
</span></span><span style="display:flex;"><span>    X_key<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;summary&#39;</span>,
</span></span><span style="display:flex;"><span>    y_key<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;policy_area&#39;</span>,
</span></span><span style="display:flex;"><span>    tfidf_params<span style="color:#f92672">=</span>{
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#39;lowercase&#39;</span>: <span style="color:#66d9ef">True</span>,
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#39;dtype&#39;</span>: np<span style="color:#f92672">.</span>float32,
</span></span><span style="display:flex;"><span>    },
</span></span><span style="display:flex;"><span>    tfidf_grid<span style="color:#f92672">=</span>{
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># &#39;ngram_range&#39;: [(1, 1), (1, 2)],</span>
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#39;max_df&#39;</span>: (<span style="color:#ae81ff">0.05</span>, <span style="color:#ae81ff">0.1</span>, <span style="color:#ae81ff">0.25</span>),
</span></span><span style="display:flex;"><span>    },
</span></span><span style="display:flex;"><span>    logreg_params<span style="color:#f92672">=</span>{
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#39;max_iter&#39;</span>: <span style="color:#ae81ff">1000</span>,
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#39;random_state&#39;</span>: <span style="color:#ae81ff">42</span>,
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#39;class_weight&#39;</span>: <span style="color:#e6db74">&#39;balanced&#39;</span>,
</span></span><span style="display:flex;"><span>    },
</span></span><span style="display:flex;"><span>    logreg_grid<span style="color:#f92672">=</span>{
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#39;C&#39;</span>: [<span style="color:#ae81ff">0.1</span>, <span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">10</span>],
</span></span><span style="display:flex;"><span>    },
</span></span><span style="display:flex;"><span>)
</span></span></code></pre></div><p>and the results:</p>
<pre><code>Training on Congress 115
Best score: 0.862
Refit Time: 9.007
Best parameters set:
	clf__C: 10
	tfidf__max_df: 0.25
Testing on Congress 116 F1: 0.8284864693401133
Testing on Congress 117 F1: 0.7934161507811646

Training on Congress 116
Best score: 0.865
Refit Time: 13.897
Best parameters set:
	clf__C: 10
	tfidf__max_df: 0.25
Testing on Congress 115 F1: 0.8637852557418315
Testing on Congress 117 F1: 0.8594775615031977

Training on Congress 117
Best score: 0.862
Refit Time: 12.167
Best parameters set:
	clf__C: 10
	tfidf__max_df: 0.25
Testing on Congress 115 F1: 0.8355736563084967
Testing on Congress 116 F1: 0.8696403838390832

Mean fit time: 11.69 ± 2.02s
</code></pre>















<figure class="post-figure center ">
    <img src="/img/logreg_summary_policy_area/f1s.webp"
         alt="Logistic Regression Policy Area Classification F1 Score"
         title="Logistic Regression Policy Area Classification F1 Score"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Logistic Regression Policy Area Classification F1 Score</figcaption>
    
</figure>

<h3 id="full-text-results">Full Text Results</h3>
<p>We test whether complete bill text improves performance over summaries, using optimal hyperparameters from summary experiments.</p>
<h4 id="naive-bayes-on-full-text">Naive Bayes on Full Text</h4>
<p>Surprisingly, full text yields slightly lower performance than summaries:</p>
<ul>
<li><strong>F1 scores</strong>: 0.84-0.85 within-Congress, 0.77-0.86 cross-Congress</li>
<li><strong>Training time</strong>: ~50 seconds (roughly 15x slower than summaries)</li>
<li><strong>Performance drop</strong>: Likely due to increased noise in lengthy documents</li>
</ul>















<figure class="post-figure center ">
    <img src="/img/nb_text_policy_area/f1s.webp"
         alt="Naive Bayes Full Text Performance"
         title="Naive Bayes Full Text Performance"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Full text performance is slightly worse than summaries, suggesting diminishing returns</figcaption>
    
</figure>

<h4 id="logistic-regression-on-full-text">Logistic Regression on Full Text</h4>
<p>Logistic regression shows the strongest performance on full text:</p>
<ul>
<li><strong>F1 scores</strong>: 0.87-0.88 within-Congress, 0.83-0.89 cross-Congress</li>
<li><strong>Training time</strong>: ~70 seconds</li>
<li><strong>Best overall performance</strong>: up to 0.89 F1 on the strongest single cross-Congress pair (best within-Congress score 0.877)</li>
</ul>















<figure class="post-figure center ">
    <img src="/img/logreg_text_policy_area/f1s.webp"
         alt="Logistic Regression Full Text Performance"
         title="Logistic Regression Full Text Performance"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Logistic regression achieves the best performance using full bill text</figcaption>
    
</figure>

<p>The logistic regression model benefits from having access to complete legislative language while effectively regularizing against noise.</p>
<h2 id="key-findings">Key Findings</h2>
<p>This baseline study establishes several important results:</p>
<p><strong>Best performing model</strong>: Logistic regression trained on full bill text reaches up to 0.89 F1 on the strongest single cross-Congress pair (best within-Congress score 0.877), providing a strong benchmark for future deep learning approaches.</p>
<p><strong>Text input comparison</strong>:</p>
<ul>
<li>Titles: Limited but fast (cross-Congress F1 ~0.55-0.74 for the linear models)</li>
<li>Summaries: Good balance of performance and efficiency (F1 ~0.85)</li>
<li>Full text: Best performance but computationally expensive (certified weighted-F1 0.871-0.877; up to ~0.89 on the strongest single cross-Congress pair)</li>
</ul>
<p><strong>Cross-Congress generalization</strong>: Models trained on one Congress generalize reasonably well to others, though performance decreases with temporal distance between sessions.</p>
<p><strong>Model performance ranking</strong>: Logistic Regression &gt; Naive Bayes &raquo; XGBoost for this text classification task.</p>
<h2 id="next-steps">Next Steps</h2>
<p>The strong baseline performance sets the stage for several research directions:</p>
<ol>
<li><strong>Deep learning models</strong>: Transformer-based approaches using pre-trained language models</li>
<li><strong>Dataset expansion</strong>: Including additional Congresses and more detailed bill metadata</li>
<li><strong>Error analysis</strong>: Understanding failure cases and class-specific performance patterns</li>
<li><strong>Feature engineering</strong>: Exploring domain-specific text preprocessing and feature extraction</li>
</ol>
<p>The complete dataset and experimental code are available for researchers interested in building upon these baselines.</p>
<p><strong>Resources</strong>:</p>
<ul>
<li>Dataset: <a href="https://huggingface.co/datasets/hheiden/us-congress-bill-policy-115_117">Hugging Face: hheiden/us-congress-bill-policy-115_117</a></li>
<li>Leaderboard: <a href="/leaderboards/policy_area_classification_leaderboard/">Policy Area Classification Leaderboard</a></li>
<li>Project: <a href="/projects/congressional-data-analysis/">Congressional Knowledge Graph &amp; Policy Classification</a></li>
</ul>
]]></content:encoded></item><item><title>Coulomb Matrices for Molecular Machine Learning</title><link>https://hunterheidenreich.com/posts/molecular-descriptor-coulomb-matrix/</link><pubDate>Sat, 10 Feb 2024 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/molecular-descriptor-coulomb-matrix/</guid><description>Learn how Coulomb matrices encode 3D molecular structure for machine learning from basic theory to Python implementation and practical limitations.</description><content:encoded><![CDATA[<h2 id="introduction">Introduction</h2>
<p>When working with machine learning in chemistry, one of the first challenges you encounter is how to represent molecules in a way that algorithms can understand. You can&rsquo;t just feed raw atomic coordinates into a model. The representation needs to be invariant to rotation, translation, and atom ordering, since these operations don&rsquo;t change the molecule&rsquo;s fundamental properties.</p>
<p>The Coulomb matrix, introduced by Rupp et al. in 2012 <a href="#ref-1">[1]</a>, provides a straightforward solution to this problem. While newer methods have largely superseded it for practical applications, the Coulomb matrix remains an excellent starting point for understanding how molecular descriptors work.</p>
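<p>The key insight is simple: we encode pairwise relationships between atoms in a way that captures the essential physics and is invariant to rotation and translation by construction. Invariance to atom ordering requires an extra step, such as sorting the rows and columns by norm.</p>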
<h2 id="the-coulomb-matrix-theory-and-intuition">The Coulomb Matrix: Theory and Intuition</h2>
<p>The Coulomb matrix encodes molecular structure in a symmetric $N \times N$ matrix, where $N$ is the number of atoms. Each element $C_{ij}$ is defined as:</p>
<p>$$
C_{ij} = \begin{cases} 0.5 Z_i^{2.4} &amp; \text{if } i = j, \\ \frac{Z_i Z_j}{|\mathbf{R}_i - \mathbf{R}_j|} &amp; \text{if } i \neq j, \end{cases}
$$</p>
<p>Here, $Z_i$ is the atomic number of atom $i$, and $\mathbf{R}_i$ is its position in 3D space. The diagonal elements ($0.5 Z_i^{2.4}$) represent atomic self-energies, derived from a polynomial fit of free-atom energies to nuclear charge. The off-diagonal elements mimic the Coulomb repulsion between nuclei: they&rsquo;re proportional to the product of the atomic numbers and inversely proportional to distance, just like electrostatic potential energy <a href="#ref-3">[3]</a>.</p>
<p>This construction gives us several useful properties:</p>
<ul>
<li><strong>Rotation and translation invariant</strong>: Only relative distances matter</li>
<li><strong>Symmetric</strong>: $C_{ij} = C_{ji}$, which is physically sensible</li>
<li><strong>Size-dependent</strong>: Larger molecules produce larger matrices ($N \times N$), so comparing molecules of different sizes requires padding to a common dimension</li>
<li><strong>Captures 3D structure</strong>: Nearby atoms have larger interaction terms</li>
</ul>
<p>While more sophisticated methods exist today <a href="#ref-2">[2]</a>, the Coulomb matrix&rsquo;s simplicity makes it ideal for understanding the fundamentals of molecular representation.</p>
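<p>Before reaching for a library, it can help to see the definition written out directly. The following is a minimal pure-Python sketch of the formula above, not the implementation used by any particular package. The function names, the water-like geometry, and the choice of length units are all illustrative assumptions. It also checks two of the listed properties numerically: symmetry, and invariance under a rigid rotation plus translation.</p>

```python
import math


def coulomb_matrix(atomic_numbers, positions):
    """Build the Coulomb matrix from atomic numbers and 3D positions."""
    n = len(atomic_numbers)
    matrix = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            if i == j:
                # Diagonal term: 0.5 * Z_i^2.4
                matrix[i][j] = 0.5 * atomic_numbers[i] ** 2.4
            else:
                # Off-diagonal term: Z_i * Z_j / |R_i - R_j|
                distance = math.dist(positions[i], positions[j])
                matrix[i][j] = atomic_numbers[i] * atomic_numbers[j] / distance
    return matrix


def rotate_z_and_shift(point, angle, shift):
    """Rotate a point about the z-axis by `angle` radians, then translate it."""
    x, y, z = point
    c, s = math.cos(angle), math.sin(angle)
    return (c * x - s * y + shift[0], s * x + c * y + shift[1], z + shift[2])


# Hypothetical water-like geometry (O, H, H). The coordinates are
# illustrative, in arbitrary length units, and not taken from the article.
numbers = [8, 1, 1]
coords = [(0.0, 0.0, 0.0), (0.96, 0.0, 0.0), (-0.24, 0.93, 0.0)]
moved = [rotate_z_and_shift(p, 0.7, (1.5, -2.0, 3.0)) for p in coords]

original = coulomb_matrix(numbers, coords)
transformed = coulomb_matrix(numbers, moved)

pairs = [(i, j) for i in range(3) for j in range(3)]

# The matrix should be unchanged by rigid motion and equal to its transpose.
invariant = all(
    math.isclose(original[i][j], transformed[i][j], rel_tol=1e-9) for i, j in pairs
)
symmetric = all(
    math.isclose(original[i][j], original[j][i], rel_tol=1e-12) for i, j in pairs
)
print("rotation/translation invariant:", invariant)
print("symmetric:", symmetric)
```

<p>Because only interatomic distances enter the off-diagonal terms, any rigid motion of the molecule leaves the matrix unchanged. Reordering the atoms, by contrast, permutes the rows and columns.</p>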
<h3 id="hands-on-example-bicyclobutane">Hands-on Example: Bicyclobutane</h3>
<p>Let&rsquo;s calculate the Coulomb matrix for <a href="https://en.wikipedia.org/wiki/Bicyclobutane">bicyclobutane</a>, a strained but stable bicyclic system (bicyclo[1.1.0]butane, C4H6, two cis-fused cyclopropane rings). This example will show you exactly how the theory translates to practice.</p>















<figure class="post-figure center ">
    <img src="https://upload.wikimedia.org/wikipedia/commons/b/b4/Bicyclobutane-2.svg"
         alt="Bicyclobutane"
         title="Bicyclobutane"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Bicyclobutane structure (Smokefoot, Public domain, via Wikimedia Commons)</figcaption>
    
</figure>

<p>I&rsquo;ll use Python with the Atomic Simulation Environment (<code>ase</code>) for molecular structure <a href="#ref-4">[4]</a> and <code>dscribe</code> for the Coulomb matrix calculation <a href="#ref-2">[2]</a>:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">from</span> ase.build <span style="color:#f92672">import</span> molecule
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> ase.visualize <span style="color:#f92672">import</span> view
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Load the bicyclobutane structure</span>
</span></span><span style="display:flex;"><span>bicyclobutane <span style="color:#f92672">=</span> molecule(<span style="color:#e6db74">&#39;bicyclobutane&#39;</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Optional: visualize the structure</span>
</span></span><span style="display:flex;"><span>view(bicyclobutane, viewer<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;x3d&#39;</span>)
</span></span></code></pre></div>














<figure class="post-figure center ">
    <img src="/img/bicyclobutane_ase_1.webp"
         alt="Bicyclobutane 3D structure"
         title="Bicyclobutane 3D structure"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">3D structure of bicyclobutane</figcaption>
    
</figure>

<p>Now we calculate the Coulomb matrix using DScribe:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">from</span> dscribe.descriptors <span style="color:#f92672">import</span> CoulombMatrix
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Set up the descriptor</span>
</span></span><span style="display:flex;"><span>cm <span style="color:#f92672">=</span> CoulombMatrix(n_atoms_max<span style="color:#f92672">=</span>len(bicyclobutane))
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Calculate and reshape into matrix form</span>
</span></span><span style="display:flex;"><span>cm_bicyclobutane <span style="color:#f92672">=</span> cm<span style="color:#f92672">.</span>create(bicyclobutane)
</span></span><span style="display:flex;"><span>cm_bicyclobutane <span style="color:#f92672">=</span> cm_bicyclobutane<span style="color:#f92672">.</span>reshape(len(bicyclobutane), len(bicyclobutane))
</span></span></code></pre></div><h3 id="visualizing-the-results">Visualizing the Results</h3>
<p>The Coulomb matrix can be visualized as a heatmap. Let&rsquo;s look at both the raw matrix and its logarithm:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> matplotlib.pyplot <span style="color:#66d9ef">as</span> plt
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> numpy <span style="color:#66d9ef">as</span> np
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Raw Coulomb matrix</span>
</span></span><span style="display:flex;"><span>plt<span style="color:#f92672">.</span>figure(figsize<span style="color:#f92672">=</span>(<span style="color:#ae81ff">8</span>, <span style="color:#ae81ff">8</span>), dpi<span style="color:#f92672">=</span><span style="color:#ae81ff">150</span>)
</span></span><span style="display:flex;"><span>plt<span style="color:#f92672">.</span>imshow(cm_bicyclobutane, cmap<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;coolwarm&#39;</span>)
</span></span><span style="display:flex;"><span>plt<span style="color:#f92672">.</span>colorbar(label<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;Magnitude&#39;</span>)
</span></span><span style="display:flex;"><span>plt<span style="color:#f92672">.</span>title(<span style="color:#e6db74">&#39;Coulomb Matrix for Bicyclobutane&#39;</span>)
</span></span><span style="display:flex;"><span>plt<span style="color:#f92672">.</span>show()
</span></span></code></pre></div>

<figure class="post-figure center ">
    <img src="/img/cm_bicyclobutane.webp"
         alt="Coulomb matrix of bicyclobutane"
         title="Coulomb matrix of bicyclobutane"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Coulomb matrix for bicyclobutane</figcaption>
    
</figure>

<p>The raw matrix shows clear patterns:</p>
<ul>
<li><strong>Large diagonal elements</strong>: Carbon atoms (Z=6) dominate because the diagonal term $0.5 Z^{2.4}$ is about 36.9 for carbon</li>
<li><strong>Smaller off-diagonal elements</strong>: Pairwise repulsion terms $Z_i Z_j / |R_i - R_j|$, which shrink as atoms move apart</li>
<li><strong>Minimal hydrogen contribution</strong>: Hydrogen atoms (Z=1) contribute only 0.5 on the diagonal and small off-diagonal values</li>
</ul>
<p>Because the carbon diagonal dwarfs everything else, taking the logarithm compresses the range and reveals more of the structure:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span>plt<span style="color:#f92672">.</span>figure(figsize<span style="color:#f92672">=</span>(<span style="color:#ae81ff">8</span>, <span style="color:#ae81ff">8</span>), dpi<span style="color:#f92672">=</span><span style="color:#ae81ff">150</span>)
</span></span><span style="display:flex;"><span>plt<span style="color:#f92672">.</span>imshow(np<span style="color:#f92672">.</span>log(cm_bicyclobutane), cmap<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;coolwarm&#39;</span>)
</span></span><span style="display:flex;"><span>plt<span style="color:#f92672">.</span>colorbar(label<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;log(Magnitude)&#39;</span>)
</span></span><span style="display:flex;"><span>plt<span style="color:#f92672">.</span>title(<span style="color:#e6db74">&#39;Log Coulomb Matrix for Bicyclobutane&#39;</span>)
</span></span><span style="display:flex;"><span>plt<span style="color:#f92672">.</span>show()
</span></span></code></pre></div>

<figure class="post-figure center ">
    <img src="/img/cm_bicyclobutane_log.webp"
         alt="Log Coulomb matrix of bicyclobutane"
         title="Log Coulomb matrix of bicyclobutane"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Log-scale reveals more structural detail</figcaption>
    
</figure>

<h3 id="eigenvalue-analysis">Eigenvalue Analysis</h3>
<p>The eigenvalues of the Coulomb matrix provide another perspective on molecular structure:</p>

<figure class="post-figure center ">
    <img src="/img/cm_bicyclobutane_eigenvalues.webp"
         alt="Eigenvalues of Coulomb matrix"
         title="Eigenvalues of Coulomb matrix"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Eigenvalues of the Coulomb matrix</figcaption>
    
</figure>

<figure class="post-figure center ">
    <img src="/img/cm_bicyclobutane_log_eigenvalues.webp"
         alt="Eigenvalues of log Coulomb matrix"
         title="Eigenvalues of log Coulomb matrix"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Eigenvalues on logarithmic scale</figcaption>
    
</figure>

<p>These eigenvalues are often used as features themselves: an $N$-atom molecule yields $N$ eigenvalues instead of $N^2$ matrix entries, and the sorted spectrum does not depend on atom ordering.</p>
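<p>A spectrum like the ones plotted here can be computed with NumPy. The following is a minimal sketch, assuming the standard Coulomb matrix definition (diagonal $0.5 Z_i^{2.4}$, off-diagonal $Z_i Z_j / |R_i - R_j|$) and a hand-written <code>coulomb_matrix</code> helper in place of DScribe. The linear H-C-N geometry is a hypothetical stand-in, not an optimized structure:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import numpy as np

def coulomb_matrix(charges, positions):
    """Build a Coulomb matrix from nuclear charges and Cartesian coordinates."""
    z = np.asarray(charges, dtype=float)
    r = np.asarray(positions, dtype=float)
    # Pairwise distances; the zero diagonal is replaced before dividing
    dist = np.linalg.norm(r[:, None, :] - r[None, :, :], axis=-1)
    np.fill_diagonal(dist, 1.0)  # placeholder to avoid dividing by zero
    cm = np.outer(z, z) / dist
    np.fill_diagonal(cm, 0.5 * z ** 2.4)
    return cm

# Hypothetical linear H-C-N geometry (angstrom), for illustration only
charges = [1, 6, 7]
positions = [[0.0, 0.0, 0.0], [0.0, 0.0, 1.06], [0.0, 0.0, 2.22]]

cm = coulomb_matrix(charges, positions)
# eigvalsh assumes a symmetric matrix and returns eigenvalues in ascending order
eigenvalues = np.linalg.eigvalsh(cm)[::-1]
print(eigenvalues)
</code></pre></div>
<p>The result holds one eigenvalue per atom, sorted from largest to smallest, which is the usual form when the spectrum is fed to a model as a feature vector.</p>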
<h2 id="practical-limitations">Practical Limitations</h2>
<p>The Coulomb matrix has significant limitations that explain why it&rsquo;s been largely superseded by modern methods. Understanding these constraints is crucial for knowing when and how to use this descriptor.</p>
<h3 id="the-size-problem">The Size Problem</h3>
<p>Every molecule must be represented by the same size matrix, which creates several issues:</p>
<ul>
<li><strong>Padding overhead</strong>: Small molecules get padded with zeros up to the maximum size</li>
<li><strong>Quadratic scaling</strong>: An $N$-atom molecule requires $N^2$ features</li>
<li><strong>Fixed maximum size</strong>: You can&rsquo;t represent molecules larger than your preset limit</li>
<li><strong>Inefficient storage</strong>: Most elements are zero for small molecules in large matrices</li>
</ul>
<p>For a dataset ranging from 5-atom to 50-atom molecules, every molecule needs a 50x50 matrix. That&rsquo;s 2,500 features, and a 5-atom molecule fills only 25 of them (1%); the rest are zero padding.</p>
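<p>The padding overhead is easy to quantify. Here is a small sketch with NumPy, where the 5-atom matrix is a stand-in filled with ones (any nonzero values would do) and the sizes match the 5-atom and 50-atom case described here:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import numpy as np

n_atoms = 5        # size of the small molecule
n_atoms_max = 50   # largest molecule in the dataset

# Stand-in for a real 5x5 Coulomb matrix (all entries nonzero)
cm_small = np.ones((n_atoms, n_atoms))

# Zero-pad on the bottom and right up to the fixed size
pad = n_atoms_max - n_atoms
cm_padded = np.pad(cm_small, ((0, pad), (0, pad)), mode="constant")

features = cm_padded.size
nonzero_fraction = np.count_nonzero(cm_padded) / features
print(features, nonzero_fraction)
</code></pre></div>
<p>Only 25 of the 2,500 entries carry information, so 99% of the feature vector for this molecule is padding.</p>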
<h3 id="permutation-sensitivity">Permutation Sensitivity</h3>
<p>Despite being called &ldquo;invariant,&rdquo; the Coulomb matrix is only invariant to translation and rotation. Reordering the atoms in your input file permutes its rows and columns, so the same molecule can produce different matrices. The standard solution is to sort atoms by the L2 norm of their matrix rows, but this introduces its own problems:</p>
<ul>
<li><strong>Symmetry breaking</strong>: Equivalent atoms might be ordered differently</li>
<li><strong>Numerical instability</strong>: Small coordinate changes can flip the ordering</li>
<li><strong>Loss of chemical intuition</strong>: The sorted order doesn&rsquo;t reflect meaningful chemistry</li>
</ul>
<p>Interestingly, some studies suggest that adding controlled noise to create multiple permutations can actually improve machine learning performance <a href="#ref-5">[5]</a>.</p>
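<p>The ordering problem and the row-norm fix can be demonstrated on a small symmetric matrix. This is a sketch with illustrative values standing in for a 3-atom Coulomb matrix; <code>sort_by_row_norm</code> is a hand-written helper, not a DScribe function:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import numpy as np

# Illustrative symmetric matrix standing in for a 3-atom Coulomb matrix
cm = np.array([
    [0.5, 5.7, 3.2],
    [5.7, 36.9, 36.2],
    [3.2, 36.2, 53.4],
])

def sort_by_row_norm(matrix):
    """Reorder rows and columns together by descending L2 norm of each row."""
    order = np.argsort(-np.linalg.norm(matrix, axis=1))
    return matrix[np.ix_(order, order)]

# Listing the atoms in a different order permutes rows and columns together
perm = [2, 0, 1]
cm_perm = cm[np.ix_(perm, perm)]

raw_equal = np.allclose(cm, cm_perm)
sorted_equal = np.allclose(sort_by_row_norm(cm), sort_by_row_norm(cm_perm))
spectrum_equal = np.allclose(np.linalg.eigvalsh(cm), np.linalg.eigvalsh(cm_perm))
print(raw_equal, sorted_equal, spectrum_equal)
</code></pre></div>
<p>The raw matrices differ, while the sorted matrices and the eigenvalue spectra agree. The row norms in this example are well separated; when two norms are nearly equal, a small change in geometry can swap the order, which is the instability noted in the list of problems.</p>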
<h3 id="limited-scope">Limited Scope</h3>
<p>The Coulomb matrix works well only for specific types of systems:</p>
<ul>
<li><strong>Small molecules</strong>: Performance degrades for large systems due to size scaling</li>
<li><strong>Gas-phase</strong>: Not suitable for periodic systems like crystals or surfaces</li>
<li><strong>Single conformations</strong>: Each 3D structure gets its own matrix</li>
<li><strong>Non-reactive</strong>: Doesn&rsquo;t capture bond-breaking or formation</li>
</ul>
<p>For periodic systems, you&rsquo;d need specialized variants like the Ewald sum matrix <a href="#ref-6">[6]</a>.</p>
<h2 id="why-learn-it-anyway">Why Learn It Anyway?</h2>
<p>Given these limitations, why spend time understanding the Coulomb matrix? Several reasons:</p>
<p><strong>Educational value</strong>: It&rsquo;s conceptually straightforward and provides excellent intuition for how molecular descriptors work. The mathematical formulation is simple enough to implement from scratch.</p>
<p><strong>Historical importance</strong>: Many subsequent methods build on ideas first explored with Coulomb matrices. Understanding this foundation helps you appreciate why newer methods were developed.</p>
<p><strong>Benchmarking</strong>: It remains useful as a baseline method for comparing new descriptors on small molecular datasets.</p>
<p><strong>Proof of concept</strong>: For exploratory work on small, well-defined datasets, the Coulomb matrix can still provide quick insights.</p>
<p>If you&rsquo;re working on practical problems with larger datasets or diverse molecular sizes, consider modern alternatives like graph neural networks, descriptors from DScribe&rsquo;s extended library, or learned representations from transformer models.</p>
<h2 id="putting-it-in-context">Putting It in Context</h2>
<p>To see the Coulomb matrix applied to real problems, I&rsquo;ve written a detailed guide using it for molecular classification:</p>
<ul>
<li><a href="/posts/alkane-constitutional-isomer-classification/">Coulomb Matrix Eigenvalues: Can You Hear the Shape of a Molecule?</a>: A comprehensive analysis of alkane isomers, from unsupervised clustering limits to supervised classification successes.</li>
</ul>
<p>For comparison with modern approaches, check out my post on <a href="/posts/geom-conformer-generation-dataset/">3D conformer generation with the GEOM dataset</a>, which showcases more sophisticated molecular representations. For technical specifications and benchmarks, see the <a href="/notes/chemistry/datasets/geom/">GEOM dataset card</a>.</p>
<p>The Coulomb matrix may be dated, but it remains an excellent entry point into the world of molecular machine learning. Once you understand its strengths and limitations, you&rsquo;ll be better equipped to appreciate why the field has moved toward more sophisticated approaches.</p>
<hr>
<p><em>Have questions about molecular descriptors or want to discuss other approaches to molecular machine learning? I&rsquo;d be happy to explore these topics further.</em></p>
<h2 id="references">References</h2>
<ul>
<li><a id="ref-1"></a>[1] M. Rupp, A. Tkatchenko, K.-R. Müller, and O. A. von Lilienfeld, &ldquo;Fast and Accurate Modeling of Molecular Atomization Energies with Machine Learning,&rdquo; Physical Review Letters, 108(5), 058301 (2012). <a href="https://doi.org/10.1103/PhysRevLett.108.058301">https://doi.org/10.1103/PhysRevLett.108.058301</a> <a href="https://arxiv.org/abs/1109.2618">arXiv:1109.2618</a></li>
<li><a id="ref-2"></a>[2] L. Himanen, M. O. J. Jäger, E. V. Morooka, F. F. Canova, Y. S. Ranawat, D. Z. Gao, P. Rinke, and A. S. Foster, &ldquo;DScribe: Library of descriptors for machine learning in materials science,&rdquo; Computer Physics Communications, 247, 106949 (2020). <a href="https://doi.org/10.1016/j.cpc.2019.106949">https://doi.org/10.1016/j.cpc.2019.106949</a> <a href="https://arxiv.org/abs/1904.08875">arXiv:1904.08875</a></li>
<li><a id="ref-3"></a>[3] J. Schrier, &ldquo;Can one hear the shape of a molecule (from its Coulomb matrix eigenvalues)?,&rdquo; Journal of Chemical Information and Modeling, 60(8), 3804-3811 (2020). <a href="https://doi.org/10.1021/acs.jcim.0c00631">https://doi.org/10.1021/acs.jcim.0c00631</a></li>
<li><a id="ref-4"></a>[4] A. H. Larsen, J. J. Mortensen, J. Blomqvist, I. E. Castelli, R. Christensen, M. Dułak, J. Friis, M. N. Groves, B. Hammer, C. Hargus, E. D. Hermes, P. C. Jennings, P. B. Jensen, J. Kermode, J. R. Kitchin, E. L. Kolsbjerg, J. Kubal, K. Kaasbjerg, S. Lysgaard, J. B. Maronsson, T. Maxson, T. Olsen, L. Pastewka, A. Peterson, C. Rostgaard, J. Schiøtz, O. Schütt, M. Strange, K. S. Thygesen, T. Vegge, L. Vilhelmsen, M. Walter, Z. Zeng, and K. W. Jacobsen, &ldquo;The Atomic Simulation Environment - A Python library for working with atoms,&rdquo; J. Phys.: Condens. Matter, 29, 273002 (2017). <a href="https://doi.org/10.1088/1361-648X/aa680e">https://doi.org/10.1088/1361-648X/aa680e</a> <a href="https://ase-lib.org/index.html">documentation</a></li>
<li><a id="ref-5"></a>[5] G. Montavon, K. Hansen, S. Fazli, M. Rupp, F. Biegler, A. Ziehe, A. Tkatchenko, A. Lilienfeld, and K.-R. Müller, &ldquo;Learning invariant representations of molecules for atomization energy prediction,&rdquo; Advances in Neural Information Processing Systems, 25 (2012). Available online: <a href="https://proceedings.neurips.cc/paper_files/paper/2012/file/115f89503138416a242f40fb7d7f338e-Paper.pdf">https://proceedings.neurips.cc/paper_files/paper/2012/file/115f89503138416a242f40fb7d7f338e-Paper.pdf</a></li>
<li><a id="ref-6"></a>[6] F. Faber, A. Lindmaa, O. A. von Lilienfeld, and R. Armiento, &ldquo;Crystal structure representations for machine learning models of formation energies,&rdquo; International Journal of Quantum Chemistry, 115(16), 1094-1101 (2015). <a href="https://doi.org/10.1002/qua.24917">https://doi.org/10.1002/qua.24917</a></li>
</ul>
]]></content:encoded></item><item><title>How Does Congress Actually Work? Data from 15K Bills</title><link>https://hunterheidenreich.com/posts/us-117th-congress-data-exploration/</link><pubDate>Thu, 05 Oct 2023 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/us-117th-congress-data-exploration/</guid><description>What happens to bills in Congress? Analyzing 15K+ bills from the 117th Congress to understand legislative patterns, party dynamics, and success rates.</description><content:encoded><![CDATA[<h2 id="introduction">Introduction</h2>
<p>Analyzing congressional data reveals the underlying mechanics of the legislative process. Legislative text is a large, structured corpus well suited to text classification and other NLP tasks. I scraped data from Congress.gov to analyze what actually happens to the thousands of bills introduced each session and to build a foundational dataset for downstream machine learning tasks.</p>
<p>This analysis focuses on the 117th Congress (2021-2023), examining 15,000+ bills to understand basic patterns: Which bills get introduced? How many receive votes? What factors influence success?</p>
<p>This post covers the foundational exploratory analysis and data collection process, setting the stage for <a href="/posts/congressional-bill-policy-area-classification/">predictive modeling and policy area classification</a>.</p>
<h2 id="data-collection">Data Collection</h2>
<p>My primary source is <a href="https://www.congress.gov/">Congress.gov</a>, maintained by the Library of Congress. I focused on the 117th Congress (2021-2023), collecting data on bills and joint resolutions, omitting simple resolutions, concurrent resolutions, and amendments.</p>
<p><strong>Data collected:</strong></p>
<table>
  <thead>
      <tr>
          <th>Bill Type</th>
          <th>Introduced</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>House Bill</td>
          <td>9,698</td>
      </tr>
      <tr>
          <td>House Joint Resolution</td>
          <td>106</td>
      </tr>
      <tr>
          <td>Senate Bill</td>
          <td>5,357</td>
      </tr>
      <tr>
          <td>Senate Joint Resolution</td>
          <td>70</td>
      </tr>
      <tr>
          <td><strong>Total</strong></td>
          <td><strong>15,231</strong></td>
      </tr>
  </tbody>
</table>
<h3 id="technical-implementation">Technical Implementation</h3>
<p>Building a usable NLP dataset requires careful handling of the source. Congress.gov loads content dynamically and presents nested DOM structures, so the scraper combines static HTML parsing with a headless browser to render JavaScript before parsing.</p>
<p><strong>Implementation details:</strong></p>
<ul>
<li><a href="https://www.python.org/">Python</a> for core orchestration and data schema management</li>
<li><a href="https://www.selenium.dev/">Selenium</a> for executing JavaScript and loading dynamic page elements</li>
<li><a href="https://www.crummy.com/software/BeautifulSoup/bs4/doc/">BeautifulSoup</a> for structured HTML parsing</li>
<li>Regex for text normalization and extracting clean legislative text for language models</li>
</ul>
<p>The crawler waited 5 seconds between requests to respect server limits, which made for a roughly 3-day collection run. It handles edge cases in congressional text formatting and writes one JSON record per bill on a fixed schema. The crawler and processed data are available on <a href="https://github.com/hunter-heidenreich/congress-scraper">GitHub</a>.</p>
<p>For each bill, I queried two pages:</p>
<ul>
<li>All info page: <code>https://www.congress.gov/bill/117th-congress/{bill_type}/{bill_id}/all-info</code></li>
<li>Text page: <code>https://www.congress.gov/bill/117th-congress/{bill_type}/{bill_id}/text?format=txt</code></li>
</ul>
<p>The parsing process involved targeting specific HTML elements and implementing basic caching to avoid redundant requests.</p>
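<p>A fetch step along these lines can be sketched as follows. This is not the published crawler: the function names, the cache layout, and the injected <code>fetch</code> callable are assumptions made for illustration. Only the two URL patterns and the 5-second delay come from the description in this section:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import hashlib
import time
from pathlib import Path

BASE = "https://www.congress.gov/bill/117th-congress"

def bill_urls(bill_type, bill_id):
    """Return the all-info and text URLs for one bill."""
    root = f"{BASE}/{bill_type}/{bill_id}"
    return [f"{root}/all-info", f"{root}/text?format=txt"]

def fetch_cached(url, fetch, cache_dir, delay=5.0):
    """Return page HTML, calling `fetch` only when the page is not cached."""
    cache_dir = Path(cache_dir)
    cache_dir.mkdir(parents=True, exist_ok=True)
    # Hypothetical cache layout: one file per URL, named by its hash
    path = cache_dir / (hashlib.sha256(url.encode()).hexdigest() + ".html")
    if path.exists():
        return path.read_text(encoding="utf-8")
    html = fetch(url)   # e.g. a Selenium-backed function that renders the page
    path.write_text(html, encoding="utf-8")
    time.sleep(delay)   # pause only after a real request
    return html
</code></pre></div>
<p>Passing the fetch function in keeps the caching logic testable without a browser. Because raw HTML is stored before any parsing happens, the parser can be re-run on cached pages without sending new requests.</p>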
<h2 id="key-findings">Key Findings</h2>
<p>The analysis reveals clear patterns in congressional activity. Most bills never receive votes, and success rates vary significantly by party and policy area.</p>
<h3 id="legislative-outcomes">Legislative Outcomes</h3>
<p>The fundamental question: what happens to bills after introduction?</p>
<p>Each bill has a tracker status indicating its position in the legislative process. The eight possible statuses can be grouped into three meaningful categories:</p>
<ul>
<li><strong>Introduced</strong>: Bills introduced but never voted on</li>
<li><strong>Stalled</strong>: Bills that saw votes but didn&rsquo;t become law (since the 117th Congress ended, these effectively died)</li>
<li><strong>Law</strong>: Bills signed by the President</li>
</ul>
<table>
  <thead>
      <tr>
          <th></th>
          <th>Introduced</th>
          <th>Stalled</th>
          <th>Law</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>House Bill</td>
          <td>8,977</td>
          <td>523</td>
          <td>198</td>
      </tr>
      <tr>
          <td>House Joint Resolution</td>
          <td>102</td>
          <td>1</td>
          <td>3</td>
      </tr>
      <tr>
          <td>Senate Bill</td>
          <td>5,083</td>
          <td>114</td>
          <td>160</td>
      </tr>
      <tr>
          <td>Senate Joint Resolution</td>
          <td>57</td>
          <td>9</td>
          <td>4</td>
      </tr>
      <tr>
          <td><strong>Total</strong></td>
          <td><strong>14,219</strong></td>
          <td><strong>647</strong></td>
          <td><strong>365</strong></td>
      </tr>
  </tbody>
</table>
<p><strong>Key insights:</strong></p>
<ul>
<li>Only about 7% of introduced bills (1,012 of 15,231) ever receive a vote</li>
<li>Of bills that receive votes, 36% (365 of 1,012) become law</li>
<li>Overall, just 2.4% of introduced bills (365 of 15,231) become law</li>
</ul>
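<p>These percentages follow directly from the totals row of the table. Here is a short check in Python, treating Stalled and Law together as the bills that received a vote:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"># Totals from the outcomes table
introduced_only = 14_219  # never voted on
stalled = 647             # voted on, did not become law
law = 365                 # signed into law

total = introduced_only + stalled + law
voted = stalled + law

print(f"received a vote:   {voted / total:.1%}")
print(f"law, given a vote: {law / voted:.1%}")
print(f"law, overall:      {law / total:.1%}")
</code></pre></div>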
<h3 id="sponsor-analysis">Sponsor Analysis</h3>
<p>The bill sponsor (the primary member who introduces legislation) provides insights into party and geographic patterns.</p>
<h4 id="party-breakdown">Party Breakdown</h4>
<table>
  <thead>
      <tr>
          <th></th>
          <th>Introduced</th>
          <th>Stalled</th>
          <th>Law</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Democrat</td>
          <td>8,271</td>
          <td>437</td>
          <td>235</td>
      </tr>
      <tr>
          <td>Republican</td>
          <td>5,883</td>
          <td>210</td>
          <td>130</td>
      </tr>
      <tr>
          <td>Independent</td>
          <td>65</td>
          <td>0</td>
          <td>0</td>
      </tr>
  </tbody>
</table>
<p><strong>Party comparison:</strong></p>
<ul>
<li><strong>Democrats</strong>: 7.5% of bills moved beyond introduction; 2.6% became law</li>
<li><strong>Republicans</strong>: 5.5% of bills moved beyond introduction; 2.1% became law</li>
<li>When bills do advance, Republicans have a slightly higher success rate (38% vs 35%)</li>
</ul>
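<p>The party rates can be reproduced from the table in the same way. This is a small sketch; the dictionary layout and the <code>rates</code> helper are illustrative names, not part of the published code:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"># (introduced only, stalled, law) per sponsor party, from the party table
counts = {
    "Democrat": (8_271, 437, 235),
    "Republican": (5_883, 210, 130),
    "Independent": (65, 0, 0),
}

def rates(introduced_only, stalled, law):
    """Return (share advancing, share becoming law, law share among advanced)."""
    total = introduced_only + stalled + law
    advanced = stalled + law
    # Undefined when no bill advanced, as for Independents here
    law_given_advance = law / advanced if advanced else float("nan")
    return advanced / total, law / total, law_given_advance

for party, row in counts.items():
    advance, became_law, conditional = rates(*row)
    print(f"{party:12s} {advance:6.1%} {became_law:6.1%} {conditional:6.1%}")
</code></pre></div>
<p>The third column is the conditional rate behind the 38% vs 35% comparison: laws divided by bills that moved beyond introduction.</p>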
<h4 id="geographic-distribution">Geographic Distribution</h4>
<p><strong>Top 10 states by sponsored bills in each status category:</strong></p>
<table>
  <thead>
      <tr>
          <th>Ranking</th>
          <th>State: Introduced</th>
          <th>State: Stalled</th>
          <th>State: Law</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>1</td>
          <td>CA: 1,350</td>
          <td>CA: 93</td>
          <td>CA: 34</td>
      </tr>
      <tr>
          <td>2</td>
          <td>TX: 879</td>
          <td>NY: 44</td>
          <td>MI: 30</td>
      </tr>
      <tr>
          <td>3</td>
          <td>NY: 784</td>
          <td>TX: 43</td>
          <td>TX: 25</td>
      </tr>
      <tr>
          <td>4</td>
          <td>FL: 766</td>
          <td>MI: 28</td>
          <td>NY: 24</td>
      </tr>
      <tr>
          <td>5</td>
          <td>IL: 660</td>
          <td>NJ: 28</td>
          <td>MN: 17</td>
      </tr>
      <tr>
          <td>6</td>
          <td>PA: 521</td>
          <td>IL: 27</td>
          <td>IL: 16</td>
      </tr>
      <tr>
          <td>7</td>
          <td>NJ: 478</td>
          <td>VA: 26</td>
          <td>OH: 11</td>
      </tr>
      <tr>
          <td>8</td>
          <td>MI: 380</td>
          <td>FL: 24</td>
          <td>VA: 11</td>
      </tr>
      <tr>
          <td>9</td>
          <td>OH: 377</td>
          <td>PA: 22</td>
          <td>FL: 11</td>
      </tr>
      <tr>
          <td>10</td>
          <td>MA: 361</td>
          <td>OH: 19</td>
          <td>GA: 9</td>
      </tr>
  </tbody>
</table>
<p><strong>Normalizing by the number of legislators from each state reveals different patterns:</strong></p>
<table>
  <thead>
      <tr>
          <th>Ranking</th>
          <th>State: Introduced</th>
          <th>State: Stalled</th>
          <th>State: Law</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>1</td>
          <td>DC: 101.0</td>
          <td>DC: 7.0</td>
          <td>AK: 2.2</td>
      </tr>
      <tr>
          <td>2</td>
          <td>NH: 47.5</td>
          <td>AK: 2.8</td>
          <td>NH: 2.0</td>
      </tr>
      <tr>
          <td>3</td>
          <td>MT: 44.0</td>
          <td>IA: 2.3</td>
          <td>MT: 2.0</td>
      </tr>
      <tr>
          <td>4</td>
          <td>OR: 41.0</td>
          <td>SD: 2.3</td>
          <td>MI: 1.9</td>
      </tr>
      <tr>
          <td>5</td>
          <td>NV: 40.0</td>
          <td>NH: 2.2</td>
          <td>MN: 1.5</td>
      </tr>
      <tr>
          <td>6</td>
          <td>DE: 38.7</td>
          <td>VA: 2.0</td>
          <td>HI: 1.5</td>
      </tr>
      <tr>
          <td>7</td>
          <td>SD: 38.3</td>
          <td>NJ: 2.0</td>
          <td>CT: 1.3</td>
      </tr>
      <tr>
          <td>8</td>
          <td>IA: 37.7</td>
          <td>PR: 2.0</td>
          <td>IA: 1.2</td>
      </tr>
      <tr>
          <td>9</td>
          <td>RI: 36.5</td>
          <td>NV: 1.8</td>
          <td>OR: 1.1</td>
      </tr>
      <tr>
          <td>10</td>
          <td>UT: 36.0</td>
          <td>MO: 1.8</td>
          <td>SD: 1.0</td>
      </tr>
  </tbody>
</table>
<h4 id="top-individual-sponsors">Top Individual Sponsors</h4>
<p><strong>Top 10 individual sponsors in each status category:</strong></p>
<table>
  <thead>
      <tr>
          <th>Ranking</th>
          <th>Individual: Introduced</th>
          <th>Individual: Stalled</th>
          <th>Individual: Law</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>1</td>
          <td>Sen. Rubio (R-FL): 186</td>
          <td>Sen. Peters (D-MI): 11</td>
          <td>Sen. Peters (D-MI): 19</td>
      </tr>
      <tr>
          <td>2</td>
          <td>Sen. Klobuchar (D-MN): 143</td>
          <td>Sen. Cornyn (R-TX): 8</td>
          <td>Sen. Cornyn (R-TX): 15</td>
      </tr>
      <tr>
          <td>3</td>
          <td>Sen. Lee (R-UT): 125</td>
          <td>Rep. Connolly (D-VA-11): 8</td>
          <td>Sen. Klobuchar (D-MN): 7</td>
      </tr>
      <tr>
          <td>4</td>
          <td>Sen. Markey (D-MA): 118</td>
          <td>Rep. Takano (D-CA-41): 8</td>
          <td>Sen. Tester (D-MT): 6</td>
      </tr>
      <tr>
          <td>5</td>
          <td>Sen. Casey (D-PA): 116</td>
          <td>Sen. Grassley (R-IA): 7</td>
          <td>Sen. Rubio (R-FL): 6</td>
      </tr>
      <tr>
          <td>6</td>
          <td>Sen. Cortez Masto (D-NV): 109</td>
          <td>Del. Norton (D-DC): 7</td>
          <td>Rep. DeLauro (D-CT-3): 6</td>
      </tr>
      <tr>
          <td>7</td>
          <td>Sen. Booker (D-NJ): 106</td>
          <td>Rep. Johnson (D-TX-30): 7</td>
          <td>Sen. Grassley (R-IA): 5</td>
      </tr>
      <tr>
          <td>8</td>
          <td>Sen. Durbin (D-IL): 102</td>
          <td>Rep. Katko (R-NY-24): 7</td>
          <td>Sen. Ossoff (D-GA): 4</td>
      </tr>
      <tr>
          <td>9</td>
          <td>Del. Norton (D-DC): 101</td>
          <td>Rep. Dean (D-PA-4): 6</td>
          <td>Sen. Murkowski (R-AK): 4</td>
      </tr>
      <tr>
          <td>10</td>
          <td>Sen. Menendez (D-NJ): 99</td>
          <td>Rep. Wagner (R-MO-2): 6</td>
          <td>Sen. Padilla (D-CA): 4</td>
      </tr>
  </tbody>
</table>
<p><strong>Effectiveness score (laws enacted / total bills):</strong></p>
<p>$$
\text{effectiveness} = \frac{\text{bills that became law}}{\text{total bills introduced}}
$$</p>
<table>
  <thead>
      <tr>
          <th>Ranking</th>
          <th>Individual: Effectiveness Score</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>1</td>
          <td>Rep. Pelosi (D-CA-12): 0.500</td>
      </tr>
      <tr>
          <td>2</td>
          <td>Rep. Mrvan (D-IN-1): 0.444</td>
      </tr>
      <tr>
          <td>3</td>
          <td>Rep. Yarmuth (D-KY-3): 0.333</td>
      </tr>
      <tr>
          <td>4</td>
          <td>Rep. Stivers (R-OH-15): 0.250</td>
      </tr>
      <tr>
          <td>5</td>
          <td>Rep. Graves (R-MO-6): 0.222</td>
      </tr>
      <tr>
          <td>6</td>
          <td>Rep. Jeffries (D-NY-8): 0.200</td>
      </tr>
      <tr>
          <td>7</td>
          <td>Rep. Neal (D-MA-1): 0.200</td>
      </tr>
      <tr>
          <td>8</td>
          <td>Rep. Palazzo (R-MS-4): 0.200</td>
      </tr>
      <tr>
          <td>9</td>
          <td>Sen. Peters (D-MI): 0.186</td>
      </tr>
      <tr>
          <td>10</td>
          <td>Rep. Fischbach (R-MN-7): 0.176</td>
      </tr>
  </tbody>
</table>
<h3 id="policy-focus-areas">Policy Focus Areas</h3>
<p>Each bill is assigned a primary policy area. Here are the most active areas by legislative outcome:</p>
<table>
  <thead>
      <tr>
          <th>Ranking</th>
          <th>Policy Area: Introduced</th>
          <th>Policy Area: Stalled</th>
          <th>Policy Area: Law</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>1</td>
          <td>Health: 1,885</td>
          <td>Government Operations: 79</td>
          <td>Government Operations: 94</td>
      </tr>
      <tr>
          <td>2</td>
          <td>Armed Forces: 1,114</td>
          <td>Armed Forces: 60</td>
          <td>Armed Forces: 69</td>
      </tr>
      <tr>
          <td>3</td>
          <td>Taxation: 1,066</td>
          <td>International Affairs: 60</td>
          <td>Crime &amp; Law Enforcement: 31</td>
      </tr>
      <tr>
          <td>4</td>
          <td>Government Operations: 982</td>
          <td>Health: 56</td>
          <td>Health: 19</td>
      </tr>
      <tr>
          <td>5</td>
          <td>International Affairs: 866</td>
          <td>Crime &amp; Law Enforcement: 44</td>
          <td>Native Americans: 17</td>
      </tr>
      <tr>
          <td>6</td>
          <td>Crime &amp; Law Enforcement: 842</td>
          <td>Public Lands: 44</td>
          <td>International Affairs: 14</td>
      </tr>
      <tr>
          <td>7</td>
          <td>Education: 663</td>
          <td>Science &amp; Technology: 44</td>
          <td>Economics &amp; Finance: 13</td>
      </tr>
      <tr>
          <td>8</td>
          <td>Transportation: 663</td>
          <td>Commerce: 43</td>
          <td>Public Lands: 13</td>
      </tr>
      <tr>
          <td>9</td>
          <td>Public Lands: 548</td>
          <td>Finance: 34</td>
          <td>Commerce: 13</td>
      </tr>
      <tr>
          <td>10</td>
          <td>Finance: 547</td>
          <td>Emergency Management: 27</td>
          <td>Emergency Management: 11</td>
      </tr>
  </tbody>
</table>
<p>Notable patterns: Health leads in bills that never received a vote (1,885), yet only 19 of its 1,960 bills (about 1%) became law. Government Operations (94 of 1,155, about 8%) and Armed Forces (69 of 1,243, about 6%) bills were far more likely to become law.</p>
<h2 id="next-steps">Next Steps</h2>
<p>This analysis establishes baseline patterns: most bills fail, success rates differ by sponsor party, and certain policy areas perform better than others.</p>
<p>Future work could explore:</p>
<ul>
<li>Committee dynamics and voting patterns</li>
<li>Geographic analysis of state-level interests</li>
<li>Bill text analysis using NLP techniques</li>
<li>Predictive modeling for bill outcomes</li>
</ul>
<blockquote>
<p><strong>Update</strong>: I&rsquo;ve since applied machine learning to this type of data in <a href="/posts/congressional-bill-policy-area-classification/">Congressional Bill Policy Area Classification</a>, using 48K+ bills from three Congresses to automatically categorize bills by policy area.</p></blockquote>
<p>The complete dataset and code are publicly available to support further research into legislative transparency.</p>
]]></content:encoded></item><item><title>Congressional Knowledge Graph &amp; Policy Classification</title><link>https://hunterheidenreich.com/projects/congressional-data-analysis/</link><pubDate>Wed, 01 Mar 2023 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/projects/congressional-data-analysis/</guid><description>A 47,000+ bill knowledge graph from Congress.gov with co-sponsorship networks and TF-IDF baselines for 33-class policy-area classification.</description><content:encoded><![CDATA[<h2 id="overview">Overview</h2>
<p>A computational social science project that constructed a dataset of 47,000+ US congressional bills by extracting legislative text and metadata from the 115th-117th Congresses. The project creates a &ldquo;legislative graph&rdquo; (linking sponsors, committees, and bill text) and establishes TF-IDF baseline models for policy area classification across 33 highly imbalanced policy classes. The dataset is hosted on Hugging Face to support reproducible political science research.</p>
<h2 id="features">Features</h2>
<h3 id="intelligent-data-acquisition">Intelligent Data Acquisition</h3>
<p>Standard APIs impose strict rate limits. I built a Selenium-based extraction engine to handle Congress.gov&rsquo;s complex DOM structures.</p>
<ul>
<li><strong>Optimization</strong>: Targeted aggregate endpoints (e.g., <code>/all-info</code>) to pull each bill&rsquo;s text and metadata in fewer requests.</li>
<li><strong>Resilience</strong>: Implemented a local caching layer to store raw HTML, separating the fetch step from the parse step. This made the parse step re-runnable without re-fetching, and minimized server load during iterative development.</li>
<li><strong>Graph construction</strong>: Beyond simple text, the script extracts relational data including co-sponsorship networks, committee assignments, and related bill lineage.</li>
</ul>
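<p>As a rough illustration of the graph-construction step, co-sponsorship can be reduced to weighted edges between a sponsor and each co-sponsor. The record fields (<code>sponsor</code>, <code>cosponsors</code>) and the sample bills are hypothetical, not the project&rsquo;s actual schema:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">from collections import Counter

# Hypothetical bill records; field names are assumptions for this sketch
bills = [
    {"id": "hr-1", "sponsor": "A", "cosponsors": ["B", "C"]},
    {"id": "hr-2", "sponsor": "B", "cosponsors": ["A"]},
    {"id": "s-1", "sponsor": "C", "cosponsors": []},
]

def cosponsorship_edges(records):
    """Count how often each unordered pair appears as sponsor and co-sponsor."""
    edges = Counter()
    for bill in records:
        for cosponsor in bill["cosponsors"]:
            pair = tuple(sorted((bill["sponsor"], cosponsor)))
            edges[pair] += 1
    return edges

edges = cosponsorship_edges(bills)
print(edges)
</code></pre></div>
<p>Edge weights count shared bills, so legislators who repeatedly back each other&rsquo;s legislation end up with heavier edges. Committee assignments and related-bill links could be added as further edge types in the same way.</p>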
<h3 id="natural-language-processing">Natural Language Processing</h3>
<ul>
<li><strong>Corpus construction</strong>: Cleaned and normalized legislative text, removing procedural artifacts (e.g., &ldquo;A BILL TO&hellip;&rdquo;) to isolate semantic policy content.</li>
<li><strong>Feature engineering</strong>: Used TF-IDF vectorization over n-grams to capture multi-word legislative jargon.</li>
<li><strong>Modeling</strong>: Benchmarked Naive Bayes, Logistic Regression, and gradient-boosted trees (XGBoost), reaching ~0.86 weighted F1 on bill summaries and up to ~0.89 on full text (cross-validated). Weighted F1, not raw accuracy, is the honest metric here: the 33 policy classes are severely imbalanced (Health has 5,911 bills; Social Sciences and History has 15).</li>
</ul>
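<p>To make the metric concrete, here is a dependency-free sketch of support-weighted F1 on a toy label set. The labels and predictions are invented for illustration and are not drawn from the dataset. In practice a library implementation such as scikit-learn&rsquo;s <code>f1_score</code> with <code>average='weighted'</code> would normally be used:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">from collections import Counter

def weighted_f1(y_true, y_pred):
    """Per-class F1 averaged with weights proportional to class support."""
    support = Counter(y_true)
    total = 0.0
    for label, n in support.items():
        tp = sum(t == label and p == label for t, p in zip(y_true, y_pred))
        predicted = sum(p == label for p in y_pred)
        precision = tp / predicted if predicted else 0.0
        recall = tp / n
        denom = precision + recall
        f1 = 2 * precision * recall / denom if denom else 0.0
        total += n * f1
    return total / len(y_true)

# Toy example: a majority-class predictor on an 8-to-2 split
y_true = ["Health"] * 8 + ["History"] * 2
y_pred = ["Health"] * 10

accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
print(accuracy, round(weighted_f1(y_true, y_pred), 3))
</code></pre></div>
<p>Here accuracy is 0.8 while weighted F1 is about 0.711, because the missed minority class contributes an F1 of zero.</p>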
<h2 id="usage">Usage</h2>
<p>The dataset is available on Hugging Face and can be loaded directly via the <code>datasets</code> library. The scraper can be run locally to fetch new bills.</p>
<h2 id="results">Results</h2>
<ul>
<li><strong>The &ldquo;partisan vocabulary&rdquo;</strong>: Feature importance analysis revealed distinct linguistic markers separating Democratic and Republican legislation, identifiable even without metadata.</li>
<li><strong>Temporal drift</strong>: Policy priorities and terminology showed measurable shifts across congressional sessions (115th vs 117th).</li>
<li><strong>Classification success</strong>: Simple linear models (Logistic Regression and Naive Bayes) proved effective at distinguishing policy domains, outperforming gradient-boosted trees on these sparse TF-IDF features and suggesting legislative language is highly structured.</li>
</ul>
<h2 id="impact--deliverables">Impact &amp; Deliverables</h2>
<ul>
<li><strong>Hugging Face dataset</strong>: Released a machine-readable, ML-ready dataset of modern bills (115th-117th Congresses) on Hugging Face for reproducible research.</li>
<li><strong>Open source tooling</strong>: Published the scraper and parsing logic to allow others to extend the dataset to future congresses.</li>
<li><strong>Academic benchmark</strong>: Established a clear baseline for &ldquo;Government NLP&rdquo; tasks, aiding in the automated transparency and monitoring of new legislation.</li>
</ul>
<h2 id="related-work">Related Work</h2>
<ul>
<li><a href="/posts/us-117th-congress-data-exploration/">117th Congress Data Exploration</a></li>
<li><a href="/posts/congressional-bill-policy-area-classification/">Congressional Bill Policy Area Classification</a></li>
</ul>
]]></content:encoded></item><item><title>PyConversations: Social Media Conversational Analysis</title><link>https://hunterheidenreich.com/projects/pyconversations-social-media-analysis/</link><pubDate>Tue, 01 Jun 2021 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/projects/pyconversations-social-media-analysis/</guid><description>Undergraduate thesis exploring representation learning for social media text and developing tools for cross-platform conversational analysis.</description><content:encoded><![CDATA[<h2 id="overview">Overview</h2>
<p>Undergraduate thesis exploring representation learning for social media text and developing tools for cross-platform conversational analysis. Built PyConversations, a Python module for analyzing social media conversations, and found that domain-specific approaches often outperform large pre-trained models.</p>
<h2 id="features">Features</h2>
<h3 id="pyconversations-module">PyConversations Module</h3>
<ul>
<li><strong>Graph-based modeling</strong>: Models conversations as Directed Acyclic Graphs (DAGs) to quantify topological structure (depth, width, density)</li>
<li><strong>Unified interface</strong>: Polymorphic design normalizing heterogeneous data from Twitter, Reddit, 4chan, and Facebook into a single analysis schema</li>
<li><strong>Linguistic dynamics</strong>: Implements information-theoretic feature extraction, including harmonic mixing laws and entropy measures</li>
<li><strong>Stream processing</strong>: Memory-efficient generators to ingest and traverse multi-gigabyte JSON dumps (e.g., 135M+ Reddit posts) without loading the full corpus into RAM</li>
</ul>
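<p>As a rough illustration of the graph-based measures above, the sketch below computes depth and width for a reply tree (a simple special case of a DAG). The function name and the <code>(post_id, parent_id)</code> input format are assumptions made for this example and are not the PyConversations API.</p>

```python
from collections import defaultdict, deque


def depth_and_width(replies):
    """Return (depth, width) of a reply tree given (post_id, parent_id) pairs.

    A parent_id of None marks a root post. Depth is the number of levels;
    width is the largest number of posts on any one level.
    """
    children = defaultdict(list)
    roots = []
    for post_id, parent_id in replies:
        if parent_id is None:
            roots.append(post_id)
        else:
            children[parent_id].append(post_id)

    level_sizes = []
    frontier = deque(roots)
    while frontier:
        level_sizes.append(len(frontier))
        next_frontier = deque()
        for post_id in frontier:
            next_frontier.extend(children[post_id])
        frontier = next_frontier
    return len(level_sizes), max(level_sizes, default=0)


thread = [("a", None), ("b", "a"), ("c", "a"), ("d", "b"), ("e", "b"), ("f", "b")]
depth, width = depth_and_width(thread)
```

<p>Because <code>replies</code> is consumed in a single pass, it can be any iterable, including a generator that yields records from a large file one at a time.</p>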
<h3 id="research-contributions">Research Contributions</h3>
<ul>
<li><strong>Representation learning</strong>: Investigated domain-specific vs. general-purpose Transformers (BERT vs. specialized variants) on social media text</li>
<li><strong>Topological analysis</strong>: Demonstrated that conversational structure (context) is as critical as content for classification tasks</li>
<li><strong>Cross-platform study</strong>: Comparative analysis of communication dynamics across moderated (Reddit/Twitter) and unmoderated (4chan) spaces</li>
</ul>
<h2 id="usage">Usage</h2>
<p>The PyConversations module can be imported into Python scripts to parse and analyze social media datasets.</p>
<h2 id="results">Results</h2>
<ul>
<li><strong>Model performance</strong>: Smaller, domain-specific approaches frequently outperformed standard pre-trained models</li>
<li><strong>Context importance</strong>: Conversational context and dialogue structure proved crucial for understanding social media interactions</li>
<li><strong>Domain adaptation</strong>: Social media text benefits from specialized handling over generic approaches</li>
<li><strong>Cross-platform challenges</strong>: Different platforms require adapted approaches despite seeming similarities</li>
</ul>
<h2 id="team--recognition">Team &amp; Recognition</h2>
<ul>
<li><strong>Hunter Heidenreich</strong> - Lead Researcher and Developer</li>
<li><strong>Jake Williams</strong> - Faculty Advisor</li>
<li><strong>First Place - Research Undergraduate Senior Thesis</strong> at Drexel University</li>
</ul>
<h2 id="impact">Impact</h2>
<p>This library served as the engineering backbone for my thesis, <a href="/research/look-dont-tweet/">Look, Don&rsquo;t Tweet</a>, enabling the processing of 308 million posts to evaluate Transformer performance on toxic data.</p>
<p>The findings about model performance suggested that specialized domains require tailored model architectures, a perspective that has become more relevant as the field continues to evolve.</p>
]]></content:encoded></item><item><title>5 Axes of Multi-Arm Bandit Problems: A Practical Guide</title><link>https://hunterheidenreich.com/posts/a-roadmap-to-multi-arm-bandit-algorithms/</link><pubDate>Tue, 10 Nov 2020 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/a-roadmap-to-multi-arm-bandit-algorithms/</guid><description>Explore 5 key dimensions of multi-arm bandit problems to help practitioners better navigate the exploration-exploitation tradeoff in ML applications.</description><content:encoded><![CDATA[<h2 id="what-is-a-multi-arm-bandit-problem">What is a Multi-Arm Bandit Problem?</h2>
<p>Multi-arm bandit problems are a fundamental class of sequential decision-making problems in machine learning. They&rsquo;re less complex than a full reinforcement learning problem, but they capture a lot of the essential challenges of learning from interaction.</p>
<p>The name comes from the analogy of a gambler facing multiple slot machines (or &ldquo;one-armed bandits&rdquo;) and trying to figure out which one pays out the most over time. Do you keep playing the machine that&rsquo;s paid out well so far, or try others to see if they&rsquo;re better? This is the exploration-exploitation dilemma at the heart of bandit algorithms.</p>
<p>Bandit algorithms solve this by learning the reward distributions for each arm over time, balancing the need to explore new options with exploiting what they&rsquo;ve learned.</p>















<figure class="post-figure center ">
    <img src="/img/multi-arm-bandits/multi-arm-bandit-conceptual-graphic.webp"
         alt="Illustration of a vintage slot machine with multiple arms, representing the multi-arm bandit problem in machine learning where algorithms must balance exploration and exploitation"
         title="Illustration of a vintage slot machine with multiple arms, representing the multi-arm bandit problem in machine learning where algorithms must balance exploration and exploitation"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Multi-arm bandit algorithms balance exploration and exploitation</figcaption>
    
</figure>

<p>What I&rsquo;ve found helpful is thinking about any bandit problem along five key dimensions. Asking these five questions helps quickly identify which approaches work best for a given problem:</p>
<div style="display: flex; flex-direction: column; gap: 16px; max-width: 700px; margin: 30px auto; font-family: system-ui, -apple-system, sans-serif;">
  <div style="background-color: #5A9BD5; color: white; padding: 20px 24px; border-radius: 10px; box-shadow: 0 4px 6px rgba(0,0,0,0.05);">
    <h3 style="margin: 0 0 8px 0; font-size: 1.25rem; font-weight: 700; color: white;">1. Action Space</h3>
    <p style="margin: 0; font-size: 1.05rem; line-height: 1.5;"><em>What does the problem action space look like?</em> <br>Consider whether your options are finite vs. infinite, or single vs. combinatorial.</p>
  </div>
  <div style="background-color: #55C3BA; color: white; padding: 20px 24px; border-radius: 10px; box-shadow: 0 4px 6px rgba(0,0,0,0.05);">
    <h3 style="margin: 0 0 8px 0; font-size: 1.25rem; font-weight: 700; color: white;">2. Problem Structure</h3>
    <p style="margin: 0; font-size: 1.05rem; line-height: 1.5;"><em>Is there any structure to the problem?</em> <br>Determine whether choosing certain actions provides information about the expected rewards of other actions.</p>
  </div>
  <div style="background-color: #5CC382; color: white; padding: 20px 24px; border-radius: 10px; box-shadow: 0 4px 6px rgba(0,0,0,0.05);">
    <h3 style="margin: 0 0 8px 0; font-size: 1.25rem; font-weight: 700; color: white;">3. External Information</h3>
    <p style="margin: 0; font-size: 1.05rem; line-height: 1.5;"><em>Is there external information that my learner has access to?</em> <br>Assess the availability of contextual information (like user data or environment state) before an action is chosen.</p>
  </div>
  <div style="background-color: #56B14E; color: white; padding: 20px 24px; border-radius: 10px; box-shadow: 0 4px 6px rgba(0,0,0,0.05);">
    <h3 style="margin: 0 0 8px 0; font-size: 1.25rem; font-weight: 700; color: white;">4. Reward Mechanism</h3>
    <p style="margin: 0; font-size: 1.05rem; line-height: 1.5;"><em>How are rewards generated?</em> <br>Identify if the environment's payouts are stochastic (random but consistent), non-stationary (changing over time), or adversarial.</p>
  </div>
  <div style="background-color: #81B653; color: white; padding: 20px 24px; border-radius: 10px; box-shadow: 0 4px 6px rgba(0,0,0,0.05);">
    <h3 style="margin: 0 0 8px 0; font-size: 1.25rem; font-weight: 700; color: white;">5. Learner Feedback</h3>
    <p style="margin: 0; font-size: 1.05rem; line-height: 1.5;"><em>What kind of feedback does my learner receive each round?</em> <br>Clarify if the feedback loop is strict bandit (only seeing the reward for the chosen arm), full, partial, or semi-bandit feedback.</p>
  </div>
</div>
<p>Jump to <strong><a href="#real-examples">Real Examples</a></strong> to see these dimensions in practice.</p>
<hr>
<h2 id="1-what-can-you-do-action-space">1. What Can You Do? (Action Space)</h2>
<p>The first question is simple: what options do you have? Action spaces are generally categorized along two dimensions: size and complexity.</p>
<p><strong>Finite vs. Infinite (Size)</strong></p>
<ul>
<li><strong>Finite (Discretized):</strong> This is the scenario most people picture: the agent chooses among a small number of discrete options. Examples include selecting an arm in a standard 3-arm bandit or testing 5 different website layouts.</li>
<li><strong>Infinite (Continuous):</strong> Sometimes your action has continuous parameters. For example, selecting a bid price anywhere in the range of (0, 1), or adjusting a recommendation algorithm&rsquo;s temperature parameter. This requires different mathematical approaches.</li>
</ul>
<p><strong>Single vs. Combinatorial (Complexity)</strong></p>
<ul>
<li><strong>Single Action:</strong> Selection of exactly 1 action per round (e.g., showing a user <em>one</em> specific ad).</li>
<li><strong>Combinatorial Actions:</strong> Selection of several actions at once, represented as a vector. For example, selecting a subset of edges in a graph to form a path from node $s$ to node $t$. If you are picking 10 movies to populate a Netflix homepage at once, those selections might interact with one another.</li>
</ul>
<h2 id="2-do-your-choices-tell-you-about-other-choices-problem-structure">2. Do Your Choices Tell You About Other Choices? (Problem Structure)</h2>
<p>This is often overlooked but crucial: does trying option A teach you anything about option B?</p>
<ul>
<li><strong>Independent (Unstructured):</strong> Information gained from one action provides <em>zero insight</em> into the expected reward of other actions. For example, knowing the open rate of an email campaign tells you nothing about the click-through rate of a separate social media post.</li>
<li><strong>Correlated (Structured):</strong> Information gained from one action provides <em>valuable hints</em> about similar actions. If you test a 10% discount and see high conversions, it strongly implies a 15% discount will also perform well, helping you map the underlying demand curve.</li>
</ul>
<h2 id="3-what-extra-information-do-you-have-context">3. What Extra Information Do You Have? (Context)</h2>
<p>Real-world problems rarely happen in isolation. The question is: what additional predictive information might help you make better decisions <em>before</em> you pull the lever?</p>
<p>When you incorporate these state variables, you move from a Standard Bandit to a <strong>Contextual Bandit</strong>. This data usually falls into two buckets:</p>
<ul>
<li><strong>User Context:</strong> State information tied directly to the individual, such as age, location, or past purchase history.</li>
<li><strong>Environmental Context:</strong> State information tied to the surrounding conditions at the exact moment of decision, such as time of day, seasonality, or device type (mobile vs. desktop).</li>
</ul>
<h2 id="4-how-do-rewards-work-reward-mechanism">4. How Do Rewards Work? (Reward Mechanism)</h2>
<p>This dimension is about understanding the nature of the feedback you&rsquo;re getting. Are the rewards predictable, changing over time, or actively working against you?</p>
<blockquote>
<p><strong>Stochastic (Stable but Random)</strong>
Each action&rsquo;s rewards are drawn IID (Independent and Identically Distributed) from a fixed distribution, so the underlying mean rewards do not shift over time.
<em>Example: Clinical trials. A drug has a true underlying effectiveness rate, but individual patients respond with random variation around that mean.</em></p></blockquote>
<blockquote>
<p><strong>Non-Stationary (Changing Over Time)</strong>
Reward distributions <em>do</em> shift over time, usually following some underlying rule. This is a realistic relaxation of the stochastic model, but comes at a learning cost.
<em>Example: Stock trading or ad performance. What worked last month might not work today.</em></p></blockquote>
<blockquote>
<p><strong>Adversarial (Actively Working Against You)</strong>
An adversary selects the worst-case rewards for your options, often with full knowledge of your learner&rsquo;s policy. Randomization is key to remaining unpredictable.
<em>Example: Cybersecurity defenses. Attackers actively adapt their strategies to exploit your algorithm.</em></p></blockquote>
<h2 id="5-how-much-do-you-learn-from-each-decision-feedback">5. How Much Do You Learn From Each Decision? (Feedback)</h2>
<p>The last dimension is about information flow. How much do you learn each time you make a choice?</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Feedback Type</th>
          <th style="text-align: left">What You Learn</th>
          <th style="text-align: left">Real-World Example</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>Bandit</strong></td>
          <td style="text-align: left">You only observe the reward for the specific action selected. You have no knowledge of what could have been gained from other options.</td>
          <td style="text-align: left">A literal slot machine payout.</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Semi-Bandit</strong></td>
          <td style="text-align: left">Common in combinatorial settings. You see the individual rewards associated with each <em>sub-action</em> you took.</td>
          <td style="text-align: left">Learning exactly which specific edges of a graph &ldquo;dropped&rdquo; or succeeded in a path-finding algorithm.</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Full</strong></td>
          <td style="text-align: left">You see all reward signals for <em>every</em> action, including the ones you didn&rsquo;t take.</td>
          <td style="text-align: left">Analyzing historical stock market data where you can see all alternative outcomes.</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Partial Monitoring</strong></td>
          <td style="text-align: left">You observe a feedback signal that may not reveal the reward itself, and some rounds may reveal nothing about it. You are occasionally flying blind.</td>
          <td style="text-align: left"><em>(Note: Because the reward is not observed directly, this is usually treated as its own problem class and not as a standard bandit problem.)</em></td>
      </tr>
  </tbody>
</table>
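<p>The first three rows of the table can be sketched as a single function that decides what the learner gets to see. This is an illustrative toy under assumed names (<code>observe</code>, <code>chosen</code>), not code from any bandit library.</p>

```python
def observe(rewards, chosen, feedback):
    """Return what a learner sees this round.

    rewards  -- list with the true reward of every arm this round
    chosen   -- list of arm indices played (one index for a single action)
    feedback -- "bandit", "semi-bandit", or "full"
    """
    if feedback == "full":
        # Every arm's reward is revealed, played or not.
        return dict(enumerate(rewards))
    if feedback == "semi-bandit":
        # One reward per sub-action that was actually played.
        return {arm: rewards[arm] for arm in chosen}
    if feedback == "bandit":
        # Only the total reward of the chosen action is revealed.
        return {tuple(chosen): sum(rewards[arm] for arm in chosen)}
    raise ValueError(f"unknown feedback type: {feedback}")


rewards = [0.0, 0.5, 0.25, 1.0]  # hypothetical per-arm rewards for one round
chosen = [1, 2]                  # a combinatorial action made of two arms

seen_bandit = observe(rewards, chosen, "bandit")
seen_semi = observe(rewards, chosen, "semi-bandit")
seen_full = observe(rewards, chosen, "full")
```

<p>The same round yields one number under bandit feedback, two under semi-bandit feedback, and four under full feedback, which is why algorithms designed for richer feedback can learn faster.</p>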
<h2 id="real-examples">Real Examples</h2>
<p>If you are trying to figure out which bandit algorithm to use for your project, it helps to map out your problem first. Here is how these five dimensions play out in three common real-world scenarios:</p>
<h3 id="1-e-commerce-recommendations">1. E-commerce Recommendations</h3>















<figure class="post-figure center ">
    <img src="/img/multi-arm-bandits/multi-arm-bandit-for-ecommerce.webp"
         alt="Illustration of a multi-arm bandit algorithm applied to e-commerce recommendations, showing a user interacting with a carousel of product recommendations on a website"
         title="Illustration of a multi-arm bandit algorithm applied to e-commerce recommendations, showing a user interacting with a carousel of product recommendations on a website"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The algorithm must populate an entire &lsquo;Recommended for You&rsquo; carousel with multiple items simultaneously.</figcaption>
    
</figure>

<ul>
<li><strong>Action Space:</strong> Combinatorial <em>(Pick multiple products to show at once)</em></li>
<li><strong>Problem Structure:</strong> Structured <em>(Similar products perform similarly)</em></li>
<li><strong>Context:</strong> Available <em>(User history, time of day, device type)</em></li>
<li><strong>Reward Mechanism:</strong> Stochastic <em>(Relatively stable underlying user preferences)</em></li>
<li><strong>Feedback:</strong> Semi-Bandit <em>(You see clicks on each of the specific items you showed, and nothing about the items you did not show)</em></li>
</ul>
<h3 id="2-online-ad-bidding">2. Online Ad Bidding</h3>















<figure class="post-figure center ">
    <img src="/img/multi-arm-bandits/multi-arm-bandit-for-ad-bidding.webp"
         alt="Illustration of a multi-arm bandit algorithm applied to online ad bidding, showing a marketer adjusting bid prices in real-time auctions for ad placements"
         title="Illustration of a multi-arm bandit algorithm applied to online ad bidding, showing a marketer adjusting bid prices in real-time auctions for ad placements"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The algorithm must learn the optimal price to bid in real-time auctions.</figcaption>
    
</figure>

<ul>
<li><strong>Action Space:</strong> Infinite / Continuous <em>(Bid any amount from 0.01 to 10.00)</em></li>
<li><strong>Problem Structure:</strong> Structured <em>(Similar bid prices yield similar win rates)</em></li>
<li><strong>Context:</strong> Available <em>(User demographics, search terms, ad relevance)</em></li>
<li><strong>Reward Mechanism:</strong> Non-stationary <em>(Market conditions and competitor budgets change constantly)</em></li>
<li><strong>Feedback:</strong> Bandit <em>(You only see the results from your winning bids)</em></li>
</ul>
<h3 id="3-content-personalization">3. Content Personalization</h3>
<p><em>A media site dynamically selects articles or videos based on trending topics and user habits to create a personalized homepage.</em></p>
<ul>
<li><strong>Action Space:</strong> Combinatorial <em>(Select a layout of multiple articles/videos)</em></li>
<li><strong>Problem Structure:</strong> Structured <em>(Content categories and tags have predictable patterns)</em></li>
<li><strong>Context:</strong> Available <em>(User profile, current geographic trends, browsing history)</em></li>
<li><strong>Reward Mechanism:</strong> Non-stationary <em>(User interests and news cycles evolve over time)</em></li>
<li><strong>Feedback:</strong> Semi-Bandit <em>(You see the performance of the individual content pieces within the selected layout)</em></li>
</ul>
<h2 id="the-algorithm-selection-cheat-sheet">The Algorithm Selection Cheat Sheet</h2>
<p>Here is how the five questions translate into algorithm choices, step by step:</p>
<p><strong>Step 1: What are my options? (Action Space)</strong></p>
<ul>
<li><strong>Few discrete choices?</strong> $\rightarrow$ Start with standard <strong>UCB</strong> (Upper Confidence Bound) or <strong>Thompson Sampling</strong>.</li>
<li><strong>Continuous parameters?</strong> $\rightarrow$ Look into <strong>Gaussian Process</strong> methods.</li>
</ul>
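<p>For the discrete case, a minimal UCB1 sketch looks like the following. The arm means and the <code>pull</code> callback are hypothetical, chosen only to make the example runnable; rewards are assumed to lie in [0, 1].</p>

```python
import math
import random


def ucb1(pull, n_arms, n_rounds):
    """Run UCB1 and return how many times each arm was pulled.

    pull(arm) must return a reward in [0, 1].
    """
    counts = [0] * n_arms
    means = [0.0] * n_arms
    for t in range(1, n_rounds + 1):
        if t > n_arms:
            # Empirical mean plus an exploration bonus that shrinks with pulls.
            arm = max(
                range(n_arms),
                key=lambda a: means[a] + math.sqrt(2 * math.log(t) / counts[a]),
            )
        else:
            arm = t - 1  # play every arm once first
        reward = pull(arm)
        counts[arm] += 1
        means[arm] += (reward - means[arm]) / counts[arm]  # running average
    return counts


rng = random.Random(0)
true_means = [0.2, 0.5, 0.8]  # hypothetical Bernoulli arms
counts = ucb1(lambda a: float(true_means[a] > rng.random()), 3, 2000)
```

<p>The exploration bonus is large for rarely pulled arms and shrinks as evidence accumulates, so pulls concentrate on the best arm without ever ruling the others out completely.</p>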
<p><strong>Step 2: Do my choices relate to each other? (Structure)</strong></p>
<ul>
<li><strong>Yes?</strong> $\rightarrow$ Use <strong>Linear Bandits</strong> or kernel methods to share information across arms.</li>
<li><strong>No?</strong> $\rightarrow$ Treat each option completely independently.</li>
</ul>
<p><strong>Step 3: What extra information do I have? (Context)</strong></p>
<ul>
<li><strong>Rich context available?</strong> $\rightarrow$ Use contextual bandits (LinUCB is the standard choice here).</li>
<li><strong>No context?</strong> $\rightarrow$ Stick with standard, context-free approaches.</li>
</ul>
<p><strong>Step 4: How do rewards behave? (Mechanism)</strong></p>
<ul>
<li><strong>Stable?</strong> $\rightarrow$ <strong>UCB</strong> and <strong>Thompson Sampling</strong> are your go-to choices.</li>
<li><strong>Changing over time?</strong> $\rightarrow$ Add forgetting factors (discounting) or use sliding-window variants of standard algorithms.</li>
<li><strong>Actively working against me?</strong> $\rightarrow$ You need adversarial approaches, most notably the <strong>EXP3</strong> algorithm.</li>
</ul>
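<p>For the adversarial case, here is a compact sketch of EXP3. The parameter <code>gamma</code>, the seed, and the reward callback are assumptions for the example, and the toy reward function at the bottom is fixed, not a real adversary.</p>

```python
import math
import random


def exp3(get_reward, n_arms, n_rounds, gamma=0.1, seed=0):
    """Run EXP3 and return the arm probabilities used in the final round.

    get_reward(t, arm) must return a reward in [0, 1]; it may depend on t
    in any way, which is what makes the setting adversarial.
    """
    rng = random.Random(seed)
    weights = [1.0] * n_arms
    probs = [1.0 / n_arms] * n_arms
    for t in range(n_rounds):
        total = sum(weights)
        # Mix the weight distribution with uniform exploration.
        probs = [(1 - gamma) * w / total + gamma / n_arms for w in weights]
        arm = rng.choices(range(n_arms), weights=probs)[0]
        reward = get_reward(t, arm)
        # Importance-weight the reward so the estimate is unbiased.
        estimate = reward / probs[arm]
        weights[arm] *= math.exp(gamma * estimate / n_arms)
        # Rescale to avoid overflow; this leaves the probabilities unchanged.
        top = max(weights)
        weights = [w / top for w in weights]
    return probs


# Toy check: arm 1 always pays, the others never do.
probs = exp3(lambda t, arm: 1.0 if arm == 1 else 0.0, n_arms=3, n_rounds=2000)
```

<p>Unlike UCB, the choice is sampled from a probability distribution every round, which is the randomization that keeps the learner unpredictable.</p>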
<p><strong>Step 5: How much do I learn each time? (Feedback)</strong></p>
<ul>
<li><strong>Just my choice?</strong> $\rightarrow$ Standard bandit algorithms.</li>
<li><strong>Everything?</strong> $\rightarrow$ Online gradient descent or multiplicative weights.</li>
<li><strong>Something in between?</strong> $\rightarrow$ Look for specialized algorithms that exploit your specific combinatorial structure.</li>
</ul>
<p>This framework keeps the focus exactly where it needs to be: on what actually matters for the problem at hand.</p>
<h2 id="summary">Summary</h2>
<p>These five questions have helped me navigate bandit problems more systematically:</p>
<ol>
<li><strong>What can you do?</strong> (Action space: few vs many options, single vs multiple choices)</li>
<li><strong>Do choices relate?</strong> (Whether trying one option teaches you about others)</li>
<li><strong>What extra info do you have?</strong> (Context that might improve decisions)</li>
<li><strong>How do rewards work?</strong> (Stable, changing, or adversarial)</li>
<li><strong>How much do you learn?</strong> (Feedback from just your choice vs everything)</li>
</ol>
<p>This framework helps avoid getting overwhelmed by the academic literature and focuses attention on what matters for real problems. Structured problems let you learn faster by sharing information between options, while adversarial settings require completely different approaches.</p>
<h2 id="if-you-want-to-learn-more">If You Want to Learn More</h2>
<ul>
<li><strong><a href="https://tor-lattimore.com/downloads/book/book.pdf">Bandit Algorithms by Lattimore &amp; Szepesvári</a></strong> - The comprehensive textbook (free PDF)</li>
<li><strong><a href="https://arxiv.org/abs/1904.07272">Introduction to Multi-Armed Bandits</a></strong> - Aleksandrs Slivkins&rsquo; survey paper is more accessible</li>
</ul>
<h2 id="get-your-hands-dirty">Get Your Hands Dirty</h2>
<p>Want to see these dimensions in code? I&rsquo;ve implemented the foundational algorithms (Explore-Then-Commit and Follow-The-Leader) in Python. Check out the <strong><a href="https://github.com/hunter-heidenreich/Bandit-Algorithms">Bandit Algorithms Repository</a></strong> to run the simulations yourself.</p>
]]></content:encoded></item><item><title>A Guide to Neuroevolution: NEAT and HyperNEAT</title><link>https://hunterheidenreich.com/posts/neuroevolution-neat-and-hyperneat/</link><pubDate>Wed, 02 Jan 2019 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/neuroevolution-neat-and-hyperneat/</guid><description>Explore the evolution of neural network topologies with NEAT and how HyperNEAT scales this approach using geometric patterns and indirect encoding.</description><content:encoded><![CDATA[<h2 id="automating-neural-architecture-design">Automating Neural Architecture Design</h2>
<p>Designing neural network architectures is typically a manual, iterative process. Researchers experiment with different layer configurations, activation functions, and connection patterns, often guided by intuition and empirical results. Evolution offers an automated alternative to this design process.</p>
<p><a href="https://nn.cs.utexas.edu/downloads/papers/stanley.ec02.pdf">NEAT (NeuroEvolution of Augmenting Topologies)</a>, introduced in 2002, optimizes network weights while also evolving the network structure itself, starting from minimal topologies and growing complexity only when beneficial.</p>
<p>NEAT&rsquo;s core innovations solved fundamental problems that had plagued earlier attempts at topology evolution. Its solutions for genetic encoding, structural crossover, and innovation protection remain influential today, especially as neural architecture search and automated ML gain prominence.</p>
<h2 id="the-core-challenges-of-neat">The Core Challenges of NEAT</h2>
<p>Evolving neural network topologies presents several fundamental challenges that NEAT elegantly addressed. Understanding these problems helps explain why NEAT&rsquo;s solutions were so influential.</p>
<h3 id="genetic-encoding-how-to-represent-networks">Genetic Encoding: How to Represent Networks</h3>
<p>Evolutionary algorithms require a genetic representation, a way to encode individuals that enables meaningful selection, mutation, and crossover. For neural networks, this choice is critical.</p>
<p><strong>Direct encoding</strong> explicitly represents each network component. Genes directly correspond to nodes and connections. This approach is intuitive and readable, and it works well for smaller networks.</p>
<p><strong>Indirect encoding</strong> specifies construction rules or processes. These encodings are more compact and can generate highly complex structures from simple rules.</p>
<p>NEAT chose direct encoding with a simple two-part structure: separate gene lists for nodes and connections. This balances simplicity with the flexibility needed for evolutionary operations.</p>















<figure class="post-figure center ">
    <img src="/img/neat_genomes.webp"
         alt="NEAT genome encoding showing node genes and connection genes with innovation numbers"
         title="NEAT genome encoding showing node genes and connection genes with innovation numbers"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">NEAT&rsquo;s direct encoding: node genes (top) and connection genes (bottom) with historical markings</figcaption>
    
</figure>

<p>Connection genes specify the source and target nodes, weight, enabled status, and an innovation number for historical tracking. Input and output nodes are fixed; only hidden nodes evolve.</p>
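<p>One way to picture this encoding is as two plain lists of records. The class and field names below are illustrative assumptions, not taken from the NEAT paper or any particular NEAT library.</p>

```python
import random
from dataclasses import dataclass, field


@dataclass
class NodeGene:
    node_id: int
    kind: str  # "input", "hidden", or "output"


@dataclass
class ConnectionGene:
    source: int
    target: int
    weight: float
    enabled: bool
    innovation: int  # historical marking


@dataclass
class Genome:
    nodes: list = field(default_factory=list)
    connections: list = field(default_factory=list)


def minimal_genome(n_inputs, n_outputs, rng):
    """Build a starting topology: every input wired directly to every output."""
    genome = Genome()
    genome.nodes += [NodeGene(i, "input") for i in range(n_inputs)]
    genome.nodes += [NodeGene(n_inputs + o, "output") for o in range(n_outputs)]
    innovation = 0
    for i in range(n_inputs):
        for o in range(n_outputs):
            genome.connections.append(
                ConnectionGene(i, n_inputs + o, rng.uniform(-1, 1), True, innovation)
            )
            innovation += 1
    return genome


genome = minimal_genome(3, 2, random.Random(0))
```

<p>Keeping nodes and connections in separate flat lists makes mutation and crossover simple list operations, at the cost of a genome that grows linearly with network size.</p>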
<h3 id="structural-mutations-growing-complexity">Structural Mutations: Growing Complexity</h3>
<p>NEAT employs two categories of mutations to evolve both weights and structure:</p>
<p><strong>Weight mutations</strong> adjust existing connection strengths using standard perturbation methods, the familiar approach from traditional neuroevolution.</p>
<p><strong>Structural mutations</strong> add new network components:</p>
<ul>
<li><strong>Add connection</strong>: Creates a new link between existing nodes with a random initial weight</li>
<li><strong>Add node</strong>: Splits an existing connection by inserting a new node. The original connection is disabled, and two new connections replace it: the one into the new node starts at weight 1.0, and the one out of it inherits the original weight</li>
</ul>
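<p>This node-splitting approach minimizes disruption. The new node initially behaves almost like a pass-through (exactly so only with a linear activation), so the network&rsquo;s output changes little, giving the new structure time to prove useful before selection pressure intensifies.</p>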
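<p>A minimal sketch of the add-node mutation follows. The <code>Conn</code> record and the <code>innovations</code> counter are hypothetical stand-ins; the weight assignment (1.0 into the new node, the old weight out of it) follows the original NEAT paper.</p>

```python
from dataclasses import dataclass, replace
from itertools import count


@dataclass(frozen=True)
class Conn:
    source: int
    target: int
    weight: float
    enabled: bool
    innovation: int


def add_node(connections, index, new_node_id, innovations):
    """Split connections[index] with a new hidden node; return a new gene list."""
    old = connections[index]
    result = list(connections)
    result[index] = replace(old, enabled=False)  # keep the gene, switch it off
    # Weight 1.0 into the new node, original weight out of it.
    result.append(Conn(old.source, new_node_id, 1.0, True, next(innovations)))
    result.append(Conn(new_node_id, old.target, old.weight, True, next(innovations)))
    return result


innovations = count(start=1)  # hypothetical global innovation counter
genes = [Conn(0, 2, 0.7, True, 0)]
genes = add_node(genes, 0, new_node_id=3, innovations=innovations)
```

<p>The disabled gene stays in the genome so that its innovation number remains available for alignment during crossover.</p>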
<h3 id="solving-the-competing-conventions-problem">Solving the Competing Conventions Problem</h3>
<p>Performing crossover between networks with different structures presents a fundamental challenge. Consider two networks that solve the same problem using different internal organizations. Naive crossover between them typically produces broken offspring.</p>















<figure class="post-figure center ">
    <img src="/img/competing_conventions.webp"
         alt="Two neural networks performing the same function but with different internal structures"
         title="Two neural networks performing the same function but with different internal structures"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The competing conventions problem: identical functions, different implementations</figcaption>
    
</figure>

<p>NEAT&rsquo;s solution draws inspiration from biology through <strong>historical markings</strong>. Each structural innovation (adding a node or connection) receives a unique innovation number, a timestamp of when that change first appeared in the population.</p>
<p>During crossover, genes with matching innovation numbers are aligned and combined. This biological concept of homology enables meaningful recombination between networks of different sizes and structures.</p>
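<p>Here is a rough sketch of that alignment. It assumes each parent is a dictionary keyed by innovation number and applies the rule from the original NEAT paper: matching genes are inherited from either parent at random, while unmatched (disjoint or excess) genes come from the fitter parent. The gene values are placeholder strings.</p>

```python
import random


def crossover(fitter, other, rng):
    """Combine two parents' connection genes, keyed by innovation number.

    Each parent is a dict mapping innovation number -> gene (any object).
    """
    child = {}
    for innovation, gene in fitter.items():
        if innovation in other:
            # Matching gene: inherit from either parent at random.
            child[innovation] = rng.choice([gene, other[innovation]])
        else:
            # Disjoint or excess gene: keep the fitter parent's copy.
            child[innovation] = gene
    return child


parent_a = {1: "a1", 2: "a2", 4: "a4", 6: "a6"}  # assumed to be the fitter parent
parent_b = {1: "b1", 2: "b2", 3: "b3", 5: "b5"}
child = crossover(parent_a, parent_b, random.Random(0))
```

<p>No topology analysis is needed: the innovation numbers alone say which genes correspond, which is what makes recombination between differently shaped networks cheap.</p>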















<figure class="post-figure center ">
    <img src="/img/neat_crossover.webp"
         alt="Diagram showing how NEAT aligns genes during crossover using innovation numbers"
         title="Diagram showing how NEAT aligns genes during crossover using innovation numbers"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">NEAT crossover using historical markings for gene alignment</figcaption>
    
</figure>

<h3 id="protecting-innovation-through-speciation">Protecting Innovation Through Speciation</h3>
<p>New structural innovations face a harsh reality: they usually perform worse initially. Adding nodes or connections typically decreases performance before optimization can improve the new structure. Without protection, these innovations disappear before realizing their potential.</p>
<p>NEAT addresses this through <strong>speciation</strong>: dividing the population into species based on structural and weight similarity. The historical markings that enable crossover also measure compatibility between individuals.</p>
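<p>The compatibility measure from the original NEAT paper combines the number of excess genes $E$, disjoint genes $D$, and the mean weight difference of matching genes $\bar{W}$ as $\delta = c_1 E / N + c_2 D / N + c_3 \bar{W}$, where $N$ is the size of the larger genome. The sketch below assumes genomes are non-empty dictionaries from innovation number to weight; the default coefficients are illustrative choices.</p>

```python
def compatibility(a, b, c1=1.0, c2=1.0, c3=0.4):
    """NEAT-style distance between two genomes given as {innovation: weight}."""
    shared = set(a).intersection(b)
    unmatched = set(a).symmetric_difference(b)
    # Genes newer than the other genome's newest gene count as "excess".
    cutoff = min(max(a), max(b))
    excess = sum(1 for i in unmatched if i > cutoff)
    disjoint = len(unmatched) - excess
    n = max(len(a), len(b))
    mean_diff = sum(abs(a[i] - b[i]) for i in shared) / len(shared) if shared else 0.0
    return (c1 * excess + c2 * disjoint) / n + c3 * mean_diff


genome_a = {1: 0.5, 2: -0.5, 4: 1.0}
genome_b = {1: 0.5, 2: 0.5, 3: 1.0, 5: 1.0, 6: 1.0}
delta = compatibility(genome_a, genome_b)
```

<p>Two individuals would then be placed in the same species when their distance falls below a chosen threshold, which is a tunable parameter.</p>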
<p>Crucially, individuals only compete within their species. This gives new structural innovations time to optimize without immediately competing against established, well-tuned networks.</p>
<p><strong>Explicit fitness sharing</strong> enhances this protection: each individual&rsquo;s fitness is divided by the number of members in its species, which prevents any single species from taking over the population and maintains diversity for continued exploration.</p>
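<p>As a small sketch of fitness sharing, the function below divides each raw fitness by the size of the individual&rsquo;s species. The individual names, species labels, and fitness values are made up for the example.</p>

```python
from collections import Counter


def shared_fitness(fitness, species_of):
    """Divide each individual's raw fitness by the size of its species."""
    sizes = Counter(species_of.values())
    return {ind: f / sizes[species_of[ind]] for ind, f in fitness.items()}


fitness = {"a": 8.0, "b": 6.0, "c": 4.0, "d": 5.0}
species_of = {"a": "s1", "b": "s1", "c": "s1", "d": "s2"}
adjusted = shared_fitness(fitness, species_of)
```

<p>Individual <code>d</code> has a lower raw fitness than <code>a</code>, but as the only member of its species it keeps its full score, so a small species with a new structure is not crowded out by a large established one.</p>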
<h3 id="complexification-starting-minimal">Complexification: Starting Minimal</h3>
<p>NEAT begins with the simplest possible networks (just input and output nodes connected by random weights). No hidden layers exist initially. Complexity emerges only when mutations that add structure prove beneficial.</p>
<p>This complexification approach builds efficient solutions that solve problems with minimal structure. Combined with speciation, it tends to produce highly optimized architectures.</p>
<h2 id="scaling-up-hyperneat">Scaling Up: HyperNEAT</h2>
<p>NEAT evolved networks through direct encoding, where each gene explicitly specifies nodes and connections. Scaling this approach to larger architectures requires a fundamentally different method. Evolving networks with billions of connections like the brain requires indirect encoding.</p>
<p><a href="https://doi.org/10.1162/artl.2009.15.2.15202">HyperNEAT</a> introduces <strong>indirect encoding</strong> through geometric principles. HyperNEAT evolves geometric patterns that generate connections based on spatial relationships. This enables the evolution of large networks with biological regularities: symmetry, repetition, and locality.</p>
<p>The key insight is leveraging Compositional Pattern Producing Networks (CPPNs) to map coordinates to connection weights, exploiting the geometric organization found in natural neural networks.</p>
<h3 id="biological-motivation">Biological Motivation</h3>
<p>The human brain exhibits remarkable organizational principles:</p>
<ul>
<li><strong>Scale</strong>: ~86 billion neurons with ~100 trillion connections</li>
<li><strong>Repetition</strong>: Structural patterns reused across regions</li>
<li><strong>Symmetry</strong>: Mirrored structures like bilateral visual processing</li>
<li><strong>Locality</strong>: Spatial proximity influences connectivity and function</li>
</ul>
<p>HyperNEAT aims to evolve networks that capture these biological regularities, leading to more efficient and interpretable architectures.</p>
<h3 id="compositional-pattern-producing-networks">Compositional Pattern Producing Networks</h3>
<p>CPPNs are the foundation of HyperNEAT&rsquo;s indirect encoding. Think of them as pattern generators that create complex spatial structures from simple coordinate inputs.</p>
<p>DNA exemplifies indirect encoding (roughly 20,000 protein-coding genes specify a brain with trillions of connections). This massive compression ratio suggests that simple rules can generate complex structures through developmental processes.</p>
<p>CPPNs abstract this concept, using compositions of mathematical functions to create patterns in coordinate space. The same genetic program (function composition) can be reused across different locations and scales, much as developmental genes control pattern formation throughout an organism.</p>















<figure class="post-figure center ">
    <img src="/img/hyperneat_cppns.webp"
         alt="Various symmetric and repetitive patterns created by CPPNs"
         title="Various symmetric and repetitive patterns created by CPPNs"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Complex patterns generated by CPPNs through function composition</figcaption>
    
</figure>

<h3 id="pattern-generation-through-function-composition">Pattern Generation Through Function Composition</h3>
<p>CPPNs generate patterns by composing simple mathematical functions. Key function types include:</p>
<ul>
<li><strong>Gaussian functions</strong>: Create symmetric patterns and gradients</li>
<li><strong>Trigonometric functions</strong>: Generate periodic/repetitive structures</li>
<li><strong>Linear functions</strong>: Produce gradients and asymmetric patterns</li>
<li><strong>Sigmoid functions</strong>: Create sharp transitions and boundaries</li>
</ul>
<p>By combining these functions, CPPNs can encode complex regularities that would require many explicit rules in direct encoding.</p>
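<p>A minimal sketch of function composition, assuming a hand-written pattern in place of an evolved CPPN: a Gaussian in $x$ multiplied by a sine in $y$. The function names and the grid spacing are illustrative choices, not part of HyperNEAT itself.</p>

```python
import math


def gaussian(z):
    """Even function: gaussian(z) equals gaussian(-z)."""
    return math.exp(-z * z)


def pattern(x, y):
    """A hand-written composition: symmetric in x, periodic in y."""
    return gaussian(x) * math.sin(2.0 * math.pi * y)


# Sample the pattern on a 9 x 9 grid over [-1, 1] x [-1, 1].
coords = [i / 4.0 for i in range(-4, 5)]
grid = [[pattern(x, y) for x in coords] for y in coords]

# Symmetry: mirrored x coordinates give the same value.
print(pattern(0.5, 0.25), pattern(-0.5, 0.25))
# Repetition: shifting y by one full period gives (nearly) the same value.
print(pattern(0.3, 0.1), pattern(0.3, 1.1))
```

<p>Two short functions are enough to encode a mirrored, repeating pattern over the whole grid, which a direct encoding would have to spell out cell by cell.</p>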
<h3 id="evolution-of-cppns">Evolution of CPPNs</h3>
<p>HyperNEAT uses NEAT to evolve the CPPN structure. This brings several advantages:</p>
<ul>
<li><strong>Complexification</strong>: CPPNs start simple and grow more complex only when beneficial</li>
<li><strong>Historical markings</strong>: Enable proper crossover between different CPPN topologies</li>
<li><strong>Speciation</strong>: Protects innovative CPPN patterns during evolution</li>
</ul>
<p>Additional activation functions beyond standard neural networks are crucial:</p>
<ul>
<li>Gaussian functions for symmetry</li>
<li>Sine/cosine for repetition</li>
<li>Specialized functions for specific geometric patterns</li>
</ul>
<h2 id="the-hyperneat-process">The HyperNEAT Process</h2>
<h3 id="substrates-geometric-organization">Substrates: Geometric Organization</h3>
<p>A <strong>substrate</strong> defines the spatial arrangement of neurons. Substrates embed neurons in geometric space (2D grids, 3D volumes, etc.), providing an alternative to layer-based connectivity rules.</p>
<p>The CPPN maps from coordinates to connection weights:</p>
<p>$$\text{CPPN}(x_1, y_1, x_2, y_2) = w$$</p>
<p>Where $(x_1, y_1)$ and $(x_2, y_2)$ are the coordinates of two neurons, and $w$ determines their connection weight.</p>
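<p>The mapping above can be sketched as a loop that queries the CPPN once for every ordered pair of substrate nodes. The <code>cppn</code> function here is a hand-written stand-in for an evolved network, and the threshold of 0.2 for expressing a connection is an assumed value chosen for illustration.</p>

```python
import math
from itertools import product


def cppn(x1, y1, x2, y2):
    """Stand-in for an evolved CPPN (hypothetical, hand-written)."""
    dist_sq = (x1 - x2) ** 2 + (y1 - y2) ** 2
    return math.exp(-dist_sq) * math.cos(math.pi * (x1 - x2))


def build_weights(nodes, threshold=0.2):
    """Query the CPPN for every ordered pair of substrate nodes."""
    weights = {}
    for source, target in product(nodes, repeat=2):
        w = cppn(*source, *target)
        if abs(w) > threshold:  # weak outputs are treated as no connection
            weights[(source, target)] = w
    return weights


# A 3 x 3 grid substrate with coordinates in [-1, 1].
axis = [-1.0, 0.0, 1.0]
nodes = [(x, y) for x in axis for y in axis]
weights = build_weights(nodes)
print(len(weights), "connections expressed out of", len(nodes) ** 2)
```

<p>With this particular stand-in, only self-connections and connections between horizontally or vertically adjacent nodes survive the threshold, so the resulting network is local even though no connection was specified by hand.</p>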















<figure class="post-figure center ">
    <img src="/img/hyperneat_cppn_basics.webp"
         alt="Diagram showing CPPN taking four coordinate inputs and outputting connection weight"
         title="Diagram showing CPPN taking four coordinate inputs and outputting connection weight"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Basic CPPN architecture mapping coordinates to connection weights</figcaption>
    
</figure>

<p>This geometric approach enables several key properties:</p>
<ul>
<li><strong>Locality</strong>: Nearby neurons tend to have similar connectivity patterns</li>
<li><strong>Symmetry</strong>: Patterns can be mirrored across spatial axes</li>
<li><strong>Repetition</strong>: Periodic functions create repeating motifs</li>
<li><strong>Scalability</strong>: The same pattern can be applied at different resolutions</li>
</ul>
<h3 id="emergent-regularities">Emergent Regularities</h3>
<p>The geometric encoding naturally produces the desired biological patterns:</p>
<p><strong>Symmetry</strong> emerges from symmetric functions. A Gaussian centered at the origin is an even function, so it returns the same value for a coordinate and its mirror image, and the connectivity it produces is mirrored across the corresponding axis.</p>
<p><strong>Repetition</strong> arises from periodic functions like sine and cosine. These create repeating connectivity motifs across the substrate.</p>
<p><strong>Locality</strong> results from functions that vary smoothly across space. Nearby coordinates produce similar outputs, leading to local connectivity patterns.</p>
<p><strong>Imperfect regularity</strong> occurs when these patterns are modulated by additional coordinate dependencies, creating biological-like variation within the basic structure.</p>
<h3 id="substrate-configurations">Substrate Configurations</h3>
<p>The choice of substrate geometry critically influences network behavior. Several standard configurations exist:</p>















<figure class="post-figure center ">
    <img src="/img/hyperneat_substrate_configurations.webp"
         alt="Various substrate layouts including grids, 3D arrangements, and circular patterns"
         title="Various substrate layouts including grids, 3D arrangements, and circular patterns"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Common substrate geometries for different problem types</figcaption>
    
</figure>

<p><strong>2D Grid</strong>: Simple planar arrangement, CPPN takes four coordinates $(x_1, y_1, x_2, y_2)$</p>
<p><strong>3D Volume</strong>: Extends to three dimensions, CPPN takes six coordinates $(x_1, y_1, z_1, x_2, y_2, z_2)$</p>
<p><strong>Sandwich</strong>: Input layer connects only to output layer, useful for sensory-motor tasks</p>
<p><strong>Circular</strong>: Radial geometry enables rotation-invariant patterns and cyclic behaviors</p>
<p>The substrate must be chosen before evolution begins, making domain knowledge important for success.</p>
<h3 id="exploiting-input-output-geometry">Exploiting Input-Output Geometry</h3>
<p>HyperNEAT exploits the spatial organization of inputs and outputs. For visual tasks, pixel coordinates provide meaningful geometric information. For control problems, sensor and actuator layouts can guide connectivity patterns.</p>















<figure class="post-figure center ">
    <img src="/img/hyperneat_inputs_outputs.webp"
         alt="Visual representation of how HyperNEAT maps spatial input arrangements to output patterns"
         title="Visual representation of how HyperNEAT maps spatial input arrangements to output patterns"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Spatial organization of inputs and outputs enables geometric exploitation</figcaption>
    
</figure>

<p>This spatial awareness allows HyperNEAT to:</p>
<ul>
<li>Develop receptive fields similar to biological vision systems</li>
<li>Create locally connected patterns for spatial processing</li>
<li>Generate symmetric motor control patterns</li>
<li>Scale across different input resolutions</li>
</ul>
<h3 id="resolution-independence">Resolution Independence</h3>
<p>A distinctive advantage of HyperNEAT is <strong>substrate resolution independence</strong>. A CPPN evolved on a low-resolution substrate can be queried at the node coordinates of a higher-resolution substrate, producing a larger network without further evolution. The CPPN&rsquo;s coordinate-based mapping scales naturally across different granularities.</p>
<p>This property suggests that evolved patterns capture fundamental spatial relationships, providing a key insight for scalable neural architecture design.</p>
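<p>A small sketch of resolution independence, assuming a hand-written <code>cppn</code> in place of an evolved one: the same function is queried on a 3 x 3 and a 5 x 5 substrate, and every connection between coordinates that exist in both grids gets the same weight. The helper names are hypothetical.</p>

```python
import math


def cppn(x1, y1, x2, y2):
    """Stand-in for an evolved CPPN (hypothetical, hand-written)."""
    return math.exp(-((x1 - x2) ** 2 + (y1 - y2) ** 2))


def grid(n):
    """n x n substrate nodes evenly spaced over [-1, 1] x [-1, 1]."""
    axis = [-1.0 + 2.0 * i / (n - 1) for i in range(n)]
    return [(x, y) for x in axis for y in axis]


def weights(nodes):
    """Query the CPPN for every ordered pair of nodes."""
    return {(a, b): cppn(*a, *b) for a in nodes for b in nodes}


coarse = weights(grid(3))  # 9 nodes, 81 candidate connections
fine = weights(grid(5))    # 25 nodes, 625 candidate connections

# Connections between coordinates present in both substrates.
shared = [pair for pair in coarse if pair in fine]
print(len(coarse), len(fine), len(shared))
```

<p>The finer substrate adds new nodes between the old ones, and the CPPN fills in their connections from the same underlying pattern.</p>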
<h2 id="impact-and-future-directions">Impact and Future Directions</h2>
<p>NEAT and HyperNEAT demonstrated that evolution could design neural network topologies and scale them through indirect encoding. The algorithms&rsquo; key insights (exploiting geometry, generating patterns through function composition, and scaling across resolutions) continue to influence modern research.</p>
<p>Extensions like ES-HyperNEAT add even more sophisticated capabilities by evolving the substrate itself. As neural architecture search becomes increasingly important, these principles find new applications in hybrid approaches that combine evolutionary pattern generation with gradient-based optimization.</p>
<p>The emphasis on spatial organization and regularity also connects to contemporary work on geometric deep learning and equivariant networks, suggesting that evolution and hand-design converge on similar organizing principles for building structured, efficient neural architectures.</p>
]]></content:encoded></item><item><title>Breaking Down Machine Learning for the Average Person</title><link>https://hunterheidenreich.com/posts/breaking-down-ml-for-the-average-person/</link><pubDate>Tue, 04 Dec 2018 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/breaking-down-ml-for-the-average-person/</guid><description>Discover how machine learning actually works through three fundamental approaches, explained with everyday examples you already know and use.</description><content:encoded><![CDATA[<h2 id="machine-learning">Machine Learning</h2>
<p>Machine learning is about teaching computer programs to improve at tasks through experience. We show algorithms examples and let them discover patterns in data.</p>
<p>There are three main approaches to machine learning: supervised learning, unsupervised learning, and reinforcement learning. Each works differently and suits different types of problems.</p>















<figure class="post-figure center ">
    <img src="/img/types-of-machine-learning.webp"
         alt="Diagram showing the three main types of machine learning: supervised, unsupervised, and reinforcement learning"
         title="Diagram showing the three main types of machine learning: supervised, unsupervised, and reinforcement learning"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Three fundamental approaches to machine learning, each suited to different types of problems and data</figcaption>
    
</figure>

<p>Each type addresses different kinds of problems and works with different data requirements.</p>
<h3 id="supervised-learning">Supervised Learning</h3>
<p>Supervised learning works like teaching with examples and answers. You show the algorithm many input-output pairs, and it learns to predict outputs for new inputs it hasn&rsquo;t seen before.</p>
<p>The algorithm learns by comparing its predictions to the correct answers. Over time, it gets better at finding patterns that connect inputs to outputs. Once trained, it can make predictions on new data.</p>
<p>Common examples you encounter:</p>
<ul>
<li><strong>Email Spam Filtering</strong>: Email systems learn to identify spam by training on thousands of emails labeled as spam or legitimate.</li>
<li><strong>Advertisement Targeting</strong>: Algorithms predict which ads you might click based on your browsing history and demographics.</li>
<li><strong>Face Recognition</strong>: Social media platforms use tagged photos to learn who appears in new images.</li>
</ul>
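<p>For readers who like to see code, here is a toy version of the spam filter idea. It is only a sketch: the four training emails are made up, and counting shared words is a deliberately simple stand-in for what real filters do.</p>

```python
from collections import Counter

# Hypothetical labeled examples: (email text, correct answer).
training = [
    ("win free money now", "spam"),
    ("free prize claim now", "spam"),
    ("meeting notes attached", "ham"),
    ("lunch meeting tomorrow", "ham"),
]

# "Training": count how often each word appears under each label.
counts = {"spam": Counter(), "ham": Counter()}
for text, label in training:
    counts[label].update(text.split())


def predict(text):
    """Pick the label whose training emails share more words with the text."""
    scores = {
        label: sum(word_counts[word] for word in text.split())
        for label, word_counts in counts.items()
    }
    return max(scores, key=scores.get)


print(predict("claim your free money"))
print(predict("notes from the meeting"))
```

<p>The program was never told which words signal spam. It worked that out from the labeled examples, which is the core of supervised learning.</p>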
<h3 id="unsupervised-learning">Unsupervised Learning</h3>
<p>Unsupervised learning works without correct answers. Instead, algorithms analyze data to find patterns, group similar items, or discover structure that wasn&rsquo;t obvious before.</p>
<p>This approach is useful because most real-world data doesn&rsquo;t come with labels. Unsupervised algorithms can process large amounts of data to find patterns that might not be obvious to humans.</p>
<p>Examples include:</p>
<ul>
<li><strong>Recommendation Systems</strong>: Netflix and YouTube analyze viewing patterns to suggest content, even without explicit ratings.</li>
<li><strong>Customer Segmentation</strong>: Companies group customers by purchasing behavior for targeted marketing.</li>
<li><strong>Problem Identification</strong>: Tech companies automatically group similar bug reports to identify common issues.</li>
</ul>
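<p>Customer segmentation can be sketched in a few lines. The example below assumes a made-up list of monthly spending amounts and splits it into two groups without being told which customer belongs where. It is a simplified, two-group version of the k-means idea and assumes the list contains at least two different values.</p>

```python
def two_means(values, steps=10):
    """Split numbers into two groups around two moving centers."""
    low, high = min(values), max(values)
    for _ in range(steps):
        groups = ([], [])
        for v in values:
            # Assign each value to whichever center is closer.
            if abs(v - high) >= abs(v - low):
                groups[0].append(v)
            else:
                groups[1].append(v)
        # Move each center to the average of its group.
        low = sum(groups[0]) / len(groups[0])
        high = sum(groups[1]) / len(groups[1])
    return groups


# Hypothetical monthly spending for six customers.
spending = [12, 15, 14, 90, 95, 88]
small_spenders, big_spenders = two_means(spending)
print(small_spenders, big_spenders)
```

<p>No labels were provided. The two groups come purely from which numbers sit close together.</p>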
<h3 id="reinforcement-learning">Reinforcement Learning</h3>
<p>Reinforcement learning works through trial and error. An algorithm tries different actions in an environment and learns from the consequences, getting rewards for good choices and penalties for poor ones.</p>
<p>This mirrors how many animals learn: through consequences. Good behavior gets rewards, bad behavior gets correction.</p>
<p>Consider an algorithm learning to play Mario:</p>
<ul>
<li><strong>Agent</strong>: The learning algorithm</li>
<li><strong>Environment</strong>: The game world</li>
<li><strong>Actions</strong>: Controller inputs (jump, run, etc.)</li>
<li><strong>State</strong>: Current game screen</li>
<li><strong>Reward</strong>: Points gained or lost</li>
</ul>
<p>The algorithm tries different button combinations, sees what happens, and gradually learns strategies that lead to higher scores.</p>
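<p>The same loop can be sketched on a much smaller game than Mario. The example below assumes a five-square corridor where the agent starts on the left and earns a reward for reaching the rightmost square. It uses tabular Q-learning with purely random exploration, and the learning rate (0.5), discount (0.9), and episode count are arbitrary choices for illustration.</p>

```python
import random

random.seed(0)  # fixed seed so repeated runs behave the same

GOAL = 4  # rightmost square of a 5-square corridor
ACTIONS = {"left": -1, "right": +1}

# Estimated value of taking each action in each state, initially zero.
q = {(s, a): 0.0 for s in range(GOAL + 1) for a in ACTIONS}

for episode in range(500):
    state = 0
    while state != GOAL:
        action = random.choice(list(ACTIONS))        # try something
        next_state = min(max(state + ACTIONS[action], 0), GOAL)
        reward = 1.0 if next_state == GOAL else 0.0  # see what happens
        best_next = max(q[(next_state, a)] for a in ACTIONS)
        target = reward + 0.9 * best_next
        q[(state, action)] += 0.5 * (target - q[(state, action)])
        state = next_state

# The learned strategy: the best-looking action in each square.
policy = {s: max(ACTIONS, key=lambda a: q[(s, a)]) for s in range(GOAL)}
print(policy)
```

<p>Nobody tells the agent that moving right is the goal. It discovers this because actions that lead toward the reward end up with higher estimated value.</p>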
<p>Real applications include:</p>
<ul>
<li><strong>Game AI</strong>: AlphaGo and similar systems learned to play complex games through self-play.</li>
<li><strong>Robotics</strong>: Factory robots learn assembly processes through trial and error in simulated environments.</li>
<li><strong>Resource Management</strong>: Google uses reinforcement learning to manage data center cooling, reducing energy costs.</li>
</ul>
<h3 id="putting-it-together">Putting It Together</h3>
<p>In practice, these approaches often work together. Many real systems combine different learning methods depending on the problem and available data.</p>
<p>For example:</p>
<ul>
<li>A game AI might use supervised learning to recognize objects and reinforcement learning for strategy</li>
<li>A language model might learn word relationships without supervision, then improve with supervised training</li>
<li>A recommendation system could group users without labels, then use supervised learning to predict preferences</li>
</ul>
<p>These three approaches cover most machine learning applications. Understanding them helps explain how the AI systems we use daily actually work, whether it&rsquo;s email filters, recommendation engines, or game-playing algorithms.</p>
<p>Machine learning is pattern recognition and learning from data, not magic. The more we understand these basics, the better we can work with and build these systems.</p>
]]></content:encoded></item><item><title>Foundations of AI: Knowledge-Based Agents and Logic</title><link>https://hunterheidenreich.com/posts/knowledge-based-agents-and-logic/</link><pubDate>Sat, 01 Dec 2018 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/knowledge-based-agents-and-logic/</guid><description>A look back at classic symbolic AI: how knowledge-based agents use logic, reasoning, and inference to build intelligent behavior.</description><content:encoded><![CDATA[<p><em>Note from 2026: I originally wrote these notes in 2018 while studying for my undergraduate AI final. They cover classic symbolic AI (often called Good Old-Fashioned AI, or GOFAI). Looking back from the current era of Large Language Models (LLMs) and Vision-Language Models (VLMs), it is fascinating to see how these foundational concepts have evolved. The classic &ldquo;knowledge base&rdquo; and &ldquo;Ask/Tell&rdquo; operations conceptually mirror modern Retrieval-Augmented Generation (RAG) systems, and the classic problem of logical &ldquo;grounding&rdquo; is exactly what we tackle today with multimodal visual grounding in models like GutenOCR. I have preserved and merged these notes here as a time capsule of my learning journey.</em></p>
<hr>
<h2 id="the-evolution-of-knowledge-based-systems">The Evolution of Knowledge-Based Systems</h2>
<p>Early AI research focused heavily on knowledge bases and agents that could interact with them, leading to <strong>expert systems</strong>. These systems relied on central knowledge bases to make decisions through &ldquo;if-then&rdquo; reasoning patterns. While some consider expert systems the first major AI breakthrough, others debate whether they truly belong in the AI category.</p>
<p>Knowledge-based agents remain relevant in modern AI. If you&rsquo;ve worked in natural language processing, you&rsquo;ve likely encountered knowledge bases like WordNet. Wikipedia represents another massive knowledge base, encoding semantic relationships between countless entities.</p>
<p>This abundance of knowledge bases raises important questions: How do we build agents that effectively interface with these repositories? How can they update knowledge bases and make decisions based on stored information?</p>
<p>This article explores these questions through a practical introduction to knowledge-based agents, knowledge representation, and logic.</p>
<h2 id="anatomy-of-a-knowledge-based-agent">Anatomy of a Knowledge-Based Agent</h2>
<p>An <strong>agent</strong> is any entity that acts within an environment. In AI, we build <strong>rational agents</strong>: entities that act sensibly within their environment. In well-understood environments, rational agents choose actions that yield desired outcomes. In uncertain environments, they act to maximize expected positive outcomes.</p>
<p>Consider agents that maintain internal knowledge, reason over that knowledge, and update their understanding through observations and actions. This is the foundation of <strong>knowledge-based agents</strong>.</p>
<h3 id="knowledge-bases-the-foundation">Knowledge Bases: The Foundation</h3>
<p>The core component of any knowledge-based agent is its <strong>knowledge base</strong> (KB): the repository of what the agent knows about the world.</p>
<p>Every KB consists of <strong>sentences</strong>: statements written in a <strong>knowledge representation language</strong>. These specialized languages express assertions about the world in a format that enables systematic reasoning, which is why we use them for agent knowledge bases.</p>
<h4 id="types-of-knowledge-in-a-kb">Types of Knowledge in a KB</h4>
<p><strong>Axioms</strong> are foundational sentences assumed to be true without derivation from other KB content. If you&rsquo;re familiar with mathematics, this concept will feel natural.</p>
<p><strong>Inferred sentences</strong> are derived from existing KB content through logical reasoning. These are systematically derived using reasoning rules that may be built into the knowledge representation language itself.</p>
<h3 id="core-agent-operations">Core Agent Operations</h3>
<p>Knowledge-based agents interact with their KBs through two fundamental operations: <strong>Ask</strong> and <strong>Tell</strong>.</p>
<h4 id="ask-querying-knowledge">Ask: Querying Knowledge</h4>
<p><strong>Ask</strong> is how agents extract information from their KB. When an agent asks a &ldquo;question&rdquo; or queries its KB, it must format the request in the KB&rsquo;s expected format (typically in the knowledge representation language).</p>
<p>The KB responds with sentences that are either:</p>
<ul>
<li>Directly stored in the KB, or</li>
<li>Inferred from existing KB information</li>
</ul>
<p>Provided the inference procedure is sound, this guarantees that responses never contradict the KB&rsquo;s knowledge.</p>
<h4 id="tell-adding-knowledge">Tell: Adding Knowledge</h4>
<p><strong>Tell</strong> updates the KB with new information. Agents use this when they:</p>
<ul>
<li>Observe environmental changes</li>
<li>Update the KB with planned or completed actions</li>
<li>Add newly learned facts</li>
</ul>
<p>Like Ask operations, new information must be properly formatted. The KB stores the new sentence and may trigger reasoning and inference processes to derive additional conclusions.</p>
<h3 id="a-generic-knowledge-based-agent-architecture">A Generic Knowledge-Based Agent Architecture</h3>
<p>Here&rsquo;s a high-level view of how a knowledge-based agent operates:</p>
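<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">kb_agent</span>(percept):
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">global</span> t  <span style="color:#75715e"># KB and the time counter t persist across calls</span>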
</span></span><span style="display:flex;"><span>    Tell(KB, make_percept_sentence(percept, t))
</span></span><span style="display:flex;"><span>    action <span style="color:#f92672">=</span> Ask(KB, make_action_query(t))
</span></span><span style="display:flex;"><span>    Tell(KB, make_action_sentence(action, t))
</span></span><span style="display:flex;"><span>    t <span style="color:#f92672">=</span> t <span style="color:#f92672">+</span> <span style="color:#ae81ff">1</span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> action
</span></span></code></pre></div><p>This function illustrates the agent&rsquo;s reasoning cycle:</p>
<ol>
<li><strong>Perceive</strong>: Convert environmental observations into KB-compatible sentences via <code>make_percept_sentence()</code></li>
<li><strong>Reason</strong>: Query the KB for the appropriate action using <code>make_action_query()</code></li>
<li><strong>Act</strong>: Record the chosen action in the KB through <code>make_action_sentence()</code></li>
<li><strong>Update</strong>: Increment the time step and return the action</li>
</ol>
<p>The helper functions handle the crucial task of translating between the external world and the knowledge representation language. This architecture focuses purely on the reasoning process, the &ldquo;brains&rdquo; of the operation, while abstracting away perception and action execution details.</p>
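<p>To make the cycle concrete, here is a minimal runnable sketch. Everything in it is an assumption for illustration: the <code>KnowledgeBase</code> class stores sentences as plain tuples, performs no real inference, and answers action queries with a hard-coded rule for a toy cleaning robot.</p>

```python
class KnowledgeBase:
    """Toy KB: stores sentences as tuples; no inference is performed."""

    def __init__(self):
        self.sentences = []

    def tell(self, sentence):
        self.sentences.append(sentence)

    def ask(self, query):
        # Hypothetical policy: clean if the percept recorded for the
        # queried time step reports dirt, otherwise move on.
        _, t = query
        dirty = ("percept", "dirty", t) in self.sentences
        return "clean" if dirty else "move"


class KBAgent:
    def __init__(self):
        self.kb = KnowledgeBase()
        self.t = 0  # time step counter, persists between calls

    def __call__(self, percept):
        self.kb.tell(("percept", percept, self.t))   # perceive
        action = self.kb.ask(("action?", self.t))    # reason
        self.kb.tell(("action", action, self.t))     # record the action
        self.t += 1                                  # update
        return action


agent = KBAgent()
actions = [agent(p) for p in ["dirty", "clean", "dirty"]]
print(actions)
```

<p>Keeping the KB and the time counter on the agent object avoids global state while preserving the same Tell, Ask, Tell sequence.</p>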
<h2 id="design-perspectives-for-knowledge-based-agents">Design Perspectives for Knowledge-Based Agents</h2>
<p>Knowledge-based agents can be analyzed and designed from three distinct levels, each addressing different aspects of the system.</p>
<h3 id="knowledge-level-the-strategic-view">Knowledge Level: The Strategic View</h3>
<p>The <strong>knowledge level</strong> represents the highest level of analysis, focusing on:</p>
<ul>
<li><strong>Goals</strong>: What objectives does the agent pursue?</li>
<li><strong>Knowledge scope</strong>: How much does the agent know about its world initially?</li>
</ul>
<p>This level helps us understand the agent&rsquo;s capabilities and limitations from a strategic perspective, independent of implementation details.</p>
<h3 id="logical-level-the-representation-view">Logical Level: The Representation View</h3>
<p>The <strong>logical level</strong> examines how knowledge is represented and reasoned about:</p>
<ul>
<li><strong>Knowledge representation language</strong>: Which language best suits our domain?</li>
<li><strong>Logical framework</strong>: Are we using propositional logic, first-order logic, or something else?</li>
</ul>
<p>Each choice carries trade-offs. Some languages excel at expressing certain types of knowledge but struggle with others. The logical framework determines what kinds of reasoning the agent can perform.</p>
<h3 id="implementation-level-the-technical-view">Implementation Level: The Technical View</h3>
<p>The <strong>implementation level</strong> addresses the concrete technical decisions:</p>
<ul>
<li><strong>Data structures</strong>: How is knowledge stored? (structs, databases, objects, vectors?)</li>
<li><strong>Algorithms</strong>: Which inference procedures are used?</li>
<li><strong>Performance</strong>: How do design choices affect speed and memory usage?</li>
</ul>
<p>These decisions significantly impact the agent&rsquo;s practical performance and scalability.</p>
<h2 id="learning-and-knowledge-acquisition">Learning and Knowledge Acquisition</h2>
<h3 id="declarative-vs-procedural-approaches">Declarative vs. Procedural Approaches</h3>
<p>Knowledge-based agents can be constructed through two primary approaches:</p>
<p><strong>Declarative approach</strong>: Initialize the agent with an empty KB, then systematically Tell it all the knowledge it needs. This explicit knowledge encoding offers transparency and modularity.</p>
<p><strong>Procedural approach</strong>: Write programs that encode knowledge directly into the agent&rsquo;s behavior. This approach can be more efficient but less transparent.</p>
<p>Real-world systems typically benefit from combining both approaches, leveraging the strengths of each method.</p>
<h3 id="incorporating-learning">Incorporating Learning</h3>
<p>Learning enhances knowledge-based agents in several ways:</p>
<p><strong>Perceptual learning</strong>: Agents can learn to combine observations in novel ways, creating new sentences that improve goal achievement. These learned patterns become part of the KB for future reasoning.</p>
<p><strong>Inference optimization</strong>: Learning algorithms can identify efficient reasoning paths within existing KBs, speeding up the inference process.</p>
<p><strong>Knowledge base expansion</strong>: Research continues into generating new connections in existing KBs like WordNet. As semantic webs proliferate online, efficient methods for KB expansion become increasingly valuable.</p>
<p>The integration of learning with knowledge-based reasoning represents an active area of AI research, with promising applications in knowledge graph completion, automated reasoning, and intelligent system adaptation.</p>
<h2 id="the-components-of-logic">The Components of Logic</h2>
<p>Every logical system has three components that create a framework for representing and reasoning about knowledge.</p>
<h3 id="syntax-the-rules-of-formation">Syntax: The Rules of Formation</h3>
<p>Syntax defines how sentences are constructed: it is the set of rules that determine whether a sentence is well-formed.</p>
<p>Consider English: &ldquo;name Hunter mine is&rdquo; violates syntax rules, while &ldquo;my name is Hunter&rdquo; follows them. In mathematics, &ldquo;1+=2 3&rdquo; breaks syntax, but &ldquo;1+2=3&rdquo; follows it.</p>
<p>Knowledge bases need proper syntax because they&rsquo;re built from collections of sentences.</p>
<h3 id="semantics-the-meaning-behind-sentences">Semantics: The Meaning Behind Sentences</h3>
<p>While syntax governs structure, semantics determines meaning: whether a sentence is true or false in a given context.</p>
<p>The sentence &ldquo;my name is Hunter&rdquo; is true in our context, but &ldquo;my name is Paul&rdquo; would be false. In math, &ldquo;x=5&rdquo; is true only when x actually equals 5.</p>
<p>We use the term <strong>model</strong> to describe a specific world state where variables have particular values. Different models let us evaluate sentence truth. One model might have x=4, another x=5.</p>
<p>When a sentence is true under a model, the model <strong>satisfies</strong> the sentence. If model₅ has x=5, then model₅ satisfies <code>x=5</code>.</p>
<p>For any sentence A, $M(A)$ represents all models that satisfy A.</p>
<h3 id="entailment-logical-relationships">Entailment: Logical Relationships</h3>
<p><strong>Entailment</strong> enables reasoning. When sentence A entails sentence B (written $A \models B$), B must be true whenever A is true.</p>
<p>More precisely, A entails B if B is true in every model where A is true:</p>
<ul>
<li>A is the <strong>premise</strong></li>
<li>B is the <strong>consequent</strong></li>
<li>B is a necessary consequence of A</li>
</ul>
<p><strong>Example:</strong> <code>x=1</code> entails <code>xy=y</code>. If x equals 1, then xy will always equal y, regardless of y&rsquo;s value.</p>
<p>This helps knowledge bases reason. When we&rsquo;re uncertain about <code>xy=y</code> but add <code>x=1</code> to our KB, entailment lets us assert that <code>xy=y</code> is now true.</p>
<p>To check if a KB entails sentence A, we verify that $M(KB) \subseteq M(A)$: every model satisfying our KB also satisfies A.</p>
<p>When we add sentences to our KB, we need systematic ways to discover newly entailed sentences. <strong>Inference algorithms</strong> do this work.</p>
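<p>The subset test can be sketched by brute force. The example below assumes a tiny world in which x and y only take integer values from -3 to 3, enumerates every model, and checks whether <code>x=1</code> entails <code>xy=y</code>. The restricted domain and the function names are illustrative assumptions, and a check over a finite range only establishes entailment within that range.</p>

```python
from itertools import product

DOMAIN = range(-3, 4)  # assumption: only small integer values of x and y


def models(sentence):
    """M(sentence): all (x, y) assignments in DOMAIN that satisfy it."""
    return {(x, y) for x, y in product(DOMAIN, repeat=2) if sentence(x, y)}


def entails(premise, conclusion):
    """premise entails conclusion when M(premise) is a subset of M(conclusion)."""
    return models(premise).issubset(models(conclusion))


def x_is_one(x, y):
    return x == 1


def xy_is_y(x, y):
    return x * y == y


print(entails(x_is_one, xy_is_y))  # every model of x=1 also satisfies xy=y
print(entails(xy_is_y, x_is_one))  # fails: x=0, y=0 satisfies xy=y only
```

<p>Entailment is not symmetric: <code>xy=y</code> also holds whenever y is 0, so it has models in which x is not 1.</p>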
<h2 id="inference-algorithms-deriving-new-knowledge">Inference Algorithms: Deriving New Knowledge</h2>
<p>Inference algorithms systematically derive new sentences entailed by our knowledge base. These algorithms need two properties to be useful and trustworthy.</p>
<h3 id="soundness-truth-preservation">Soundness: Truth Preservation</h3>
<p>An inference algorithm is <strong>sound</strong> if it only derives sentences actually entailed by the KB. If an algorithm adds unjustified sentences, it is fabricating information, undermining the KB&rsquo;s reliability.</p>
<h3 id="completeness-finding-everything">Completeness: Finding Everything</h3>
<p>A <strong>complete</strong> inference algorithm finds all sentences entailed by the KB, missing no valid conclusions. This ensures we don&rsquo;t overlook important logical consequences.</p>
<p>Achieving both soundness and completeness grows challenging as environments become more complex. Small, bounded models make this manageable. Unbounded environments require sophisticated algorithms that guarantee both properties while maintaining reasonable performance.</p>
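<p>One setting where both properties are achievable is a KB made only of facts and if-then rules (definite clauses). The sketch below uses forward chaining, which is not discussed above and is introduced here purely as an illustration: it repeatedly applies modus ponens until nothing new can be derived. The facts and rules are hypothetical.</p>

```python
def forward_chain(facts, rules):
    """Apply modus ponens until no new fact can be derived.

    rules is a list of (premises, conclusion) pairs (definite clauses).
    """
    known = set(facts)
    changed = True
    while changed:
        changed = False
        for premises, conclusion in rules:
            if conclusion not in known and set(premises).issubset(known):
                known.add(conclusion)
                changed = True
    return known


# Hypothetical KB: two facts and three rules.
facts = {"rain", "outside"}
rules = [
    (["rain", "outside"], "wet"),
    (["wet"], "cold"),
    (["snow"], "white"),
]
derived = forward_chain(facts, rules)
print(sorted(derived))
```

<p>The procedure is sound because each step applies a rule whose premises are already known, and for definite clauses it is complete because it keeps going until no rule can fire. The fact <code>white</code> is never derived, since <code>snow</code> is not in the KB.</p>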
<h3 id="grounding-connecting-logic-to-reality">Grounding: Connecting Logic to Reality</h3>
<p>Beyond formal algorithm properties lies a practical concern: <strong>grounding</strong>. Is our knowledge base actually grounded in reality? Does it accurately represent what&rsquo;s true?</p>
<p>Grounding depends on how knowledge enters the system:</p>
<ul>
<li><strong>Sensors</strong>: Agent perceptions are only as accurate as their sensors</li>
<li><strong>Learning algorithms</strong>: Learned sentences are only as reliable as the learning process</li>
</ul>
<p>With accurate sensors and reliable learning, we can trust our KB&rsquo;s grounding. This remains important when deploying knowledge-based systems in real applications.</p>
<h2 id="conclusion">Conclusion</h2>
<p>Logic provides the foundation for knowledge representation in AI systems. Understanding syntax (forming valid sentences), semantics (sentence meaning), entailment (how conclusions follow from premises), and inference algorithms (deriving new knowledge) enables building knowledge-based agents.</p>
<p>The interplay between these components, governed by soundness, completeness, and grounding, determines how effectively our AI systems represent and reason about the world.</p>
]]></content:encoded></item><item><title>QuAC: Question Answering in Context Dataset</title><link>https://hunterheidenreich.com/posts/quac-question-answering-in-context/</link><pubDate>Wed, 31 Oct 2018 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/quac-question-answering-in-context/</guid><description>Analysis of QuAC's conversational QA through student-teacher interactions, featuring ~100K context-dependent questions and coreference challenges.</description><content:encoded><![CDATA[<h2 id="introduction">Introduction</h2>
<p>The <a href="https://aclanthology.org/D18-1241/">QuAC dataset</a> (Question Answering in Context) presents a conversational question answering approach that models student-teacher interactions. Published at EMNLP 2018, this work by Choi et al. addresses how systems can understand dialogue context, resolve references across conversation turns, and handle the ambiguity of natural conversation, whereas previous datasets treated each question independently.</p>
<p>The dataset addresses limitations in question answering research by incorporating real-world information-seeking dialogue complexities, where questions build upon previous exchanges and context drives understanding.</p>
<p>For comparison with related work, see my analysis of <a href="/posts/coqa-conversation-question-answering/">CoQA</a>.</p>
<h2 id="the-student-teacher-framework">The Student-Teacher Framework</h2>
<p>QuAC models information-seeking dialogue through a student-teacher setup:</p>
<ul>
<li><strong>Teacher</strong>: Has complete access to information (Wikipedia passage)</li>
<li><strong>Student</strong>: Seeks knowledge through questioning with limited initial context</li>
<li><strong>Interaction</strong>: Handles context-dependent questions, abstract inquiries, and unanswerable requests</li>
</ul>
<p>This framework mirrors real-world scenarios where one party has expertise while another seeks to learn through dialogue. AI systems must act as effective teachers, using available information to provide helpful responses despite ambiguous or incomplete questions.</p>
<p>The dataset contains roughly 100K questions across ~14K dialogues (precisely 98,407 questions and 13,594 dialogues), providing substantial scale for training and evaluation.</p>















<figure class="post-figure center ">
    <img src="/img/quac_stats.webp"
         alt="QuAC dataset statistics and scale"
         title="QuAC dataset statistics and scale"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">QuAC dataset statistics and scale</figcaption>
    
</figure>

<h2 id="dataset-construction">Dataset Construction</h2>
<p>QuAC was built using Amazon Mechanical Turk with a two-person dialogue setup:</p>
<p><strong>Teacher role</strong>: Has access to the complete Wikipedia passage and provides answers extracted directly from the text</p>
<p><strong>Student role</strong>: Sees only the article title, introduction paragraph, and section heading, then asks questions to learn about the content</p>
<p>This asymmetric information design ensures student questions naturally differ from the passage content, creating realistic information-seeking scenarios. The extractive answer requirement maintains objective evaluation while simplifying scoring.</p>
<p><strong>Dialogue termination</strong>:</p>
<ul>
<li>12 questions answered</li>
<li>Manual termination by either participant</li>
<li>Two consecutive unanswerable questions</li>
</ul>















<figure class="post-figure center ">
    <img src="/img/quac_convo.webp"
         alt="Example QuAC conversation showing student-teacher interaction"
         title="Example QuAC conversation showing student-teacher interaction"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Example QuAC conversation showing student-teacher interaction</figcaption>
    
</figure>

<h3 id="content-selection">Content Selection</h3>
<p>QuAC focuses on Wikipedia biographical articles for several practical reasons:</p>
<ul>
<li><strong>Reduced complexity</strong>: People-focused content requires less specialized domain knowledge</li>
<li><strong>Natural question flow</strong>: Biographical information lends itself to sequential questioning</li>
<li><strong>Quality control</strong>: Articles filtered to include only subjects with 100+ incoming links, ensuring content depth</li>
</ul>
<p>This focused scope enables consistent evaluation while maintaining broad coverage through diverse biographical subjects across fields and time periods.</p>
<h2 id="key-dataset-characteristics">Key Dataset Characteristics</h2>
<p>QuAC introduces several features that distinguish it from existing question answering benchmarks:</p>















<figure class="post-figure center ">
    <img src="/img/quac_comparison.webp"
         alt="Comparative analysis of QuAC against other QA datasets"
         title="Comparative analysis of QuAC against other QA datasets"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Comparative analysis of QuAC against other QA datasets</figcaption>
    
</figure>

<p><strong>Notable features</strong>:</p>
<ul>
<li><strong>High contextual dependency</strong>: a large majority of questions depend on the conversation context, and a substantial share require coreference resolution</li>
<li><strong>Non-factoid focus</strong>: 54% of questions go beyond simple fact retrieval</li>
<li><strong>Extended answers</strong>: Responses are longer and more detailed</li>
<li><strong>Unanswerable questions</strong>: Realistic scenarios where information isn&rsquo;t available</li>
</ul>















<figure class="post-figure center ">
    <img src="/img/quac_dist.webp"
         alt="Distribution of question types in QuAC"
         title="Distribution of question types in QuAC"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Distribution of question types in QuAC</figcaption>
    
</figure>

<h3 id="the-coreference-resolution-challenge">The Coreference Resolution Challenge</h3>
<p>QuAC&rsquo;s complexity stems from its heavy reliance on coreference resolution across multiple contexts:</p>
<p><strong>Reference types</strong>:</p>
<ul>
<li><strong>Passage references</strong>: Pronouns and references to entities in the source text</li>
<li><strong>Dialogue references</strong>: References to previously discussed topics</li>
<li><strong>Abstract references</strong>: Challenging cases like &ldquo;what else?&rdquo; that require inferring the inquiry scope</li>
</ul>















<figure class="post-figure center ">
    <img src="/img/quac_coref.webp"
         alt="Types and distribution of coreferences in QuAC"
         title="Types and distribution of coreferences in QuAC"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Types and distribution of coreferences in QuAC</figcaption>
    
</figure>

<p>The prevalence of coreference resolution makes QuAC particularly challenging, as this remains an active research problem in NLP. Models must understand passage content, track dialogue history, and resolve complex referential expressions simultaneously.</p>
<h2 id="performance-results">Performance Results</h2>
<p>Models face substantial challenges on QuAC, with significant gaps between human and machine performance:</p>















<figure class="post-figure center ">
    <img src="/img/quac_performance.webp"
         alt="Baseline model performance comparison on QuAC"
         title="Baseline model performance comparison on QuAC"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Baseline model performance comparison on QuAC</figcaption>
    
</figure>

<p><strong>Performance summary</strong>:</p>
<ul>
<li><strong>Human performance</strong>: 81.1% F1 score</li>
<li><strong>Best baseline</strong>: BiDAF++ with context achieves 60.2% F1</li>
<li><strong>Performance gap</strong>: A difference of more than 20 F1 points shows substantial room for improvement</li>
</ul>
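<p>The F1 numbers above measure word overlap between a predicted answer and a reference answer. The following is a minimal sketch of that idea, assuming lowercase whitespace tokenization and a single reference. The name <code>token_f1</code> is illustrative, and the official QuAC evaluation script may normalize text and combine multiple references differently.</p>

```python
from collections import Counter


def token_f1(prediction: str, reference: str) -> float:
    """Word-overlap F1 between a predicted answer and one reference answer."""
    pred_tokens = prediction.lower().split()
    ref_tokens = reference.lower().split()
    if not pred_tokens or not ref_tokens:
        # Both empty counts as a match; exactly one empty counts as a miss.
        return float(pred_tokens == ref_tokens)
    pred_counts = Counter(pred_tokens)
    ref_counts = Counter(ref_tokens)
    # Each shared token counts as often as it appears in both answers.
    overlap = sum(min(count, ref_counts[token]) for token, count in pred_counts.items())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


exact = token_f1("in 1990", "in 1990")         # identical answers
partial = token_f1("born in 1990", "in 1990")  # one extra predicted word
```

<p>In this sketch an identical answer scores 1.0, while the prediction with one extra word keeps full recall but loses precision, giving 0.8.</p>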
<h3 id="human-equivalence-metrics">Human Equivalence Metrics</h3>
<p>QuAC introduces evaluation metrics beyond traditional F1 scores:</p>
<p><strong>HEQ-Q (Human Equivalence Question-level)</strong>: Percentage of questions where the model achieves human-level or better performance</p>
<p><strong>HEQ-D (Human Equivalence Dialogue-level)</strong>: Percentage of complete dialogues where the model matches or exceeds human performance on every question</p>
<p><strong>Current results</strong>:</p>
<ul>
<li>Human baseline: 100% HEQ-Q, 100% HEQ-D (by definition)</li>
<li>Best model: 55.1% HEQ-Q, 5.2% HEQ-D</li>
</ul>
<p>These metrics show both average performance and consistency across questions and conversations, important for practical dialogue systems.</p>
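<p>Both HEQ metrics reduce to counting per-question comparisons. The sketch below assumes that an F1 score for the model and for a human annotator is already available for every question; the function name <code>heq_scores</code> and the input layout are illustrative assumptions, not the official evaluation code.</p>

```python
def heq_scores(dialogues):
    """Compute (HEQ-Q, HEQ-D) as percentages.

    dialogues: non-empty list of dialogues, each a list of
    (model_f1, human_f1) pairs, one pair per question.
    """
    total_questions = 0
    questions_at_human_level = 0
    dialogues_at_human_level = 0
    for dialogue in dialogues:
        # A question counts when the model matches or exceeds the human score.
        hits = [model_f1 >= human_f1 for model_f1, human_f1 in dialogue]
        total_questions += len(hits)
        questions_at_human_level += sum(hits)
        # A dialogue counts only when every one of its questions counts.
        dialogues_at_human_level += all(hits)
    heq_q = 100.0 * questions_at_human_level / total_questions
    heq_d = 100.0 * dialogues_at_human_level / len(dialogues)
    return heq_q, heq_d


example = [
    [(0.9, 0.8), (0.5, 0.7)],  # second question falls short of the human score
    [(1.0, 1.0), (0.8, 0.8)],  # every question matches the human score
]
heq_q, heq_d = heq_scores(example)
```

<p>Here three of four questions reach human level (HEQ-Q of 75.0) but only one of two dialogues does so throughout (HEQ-D of 50.0). This is why HEQ-D is so much lower than HEQ-Q in practice: a single weak turn disqualifies the whole dialogue.</p>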
<h2 id="research-impact">Research Impact</h2>
<p>QuAC represents an important step in question answering research by introducing realistic conversational dynamics that existing datasets lack. The student-teacher framework captures natural information-seeking behavior while maintaining extractive evaluation for objective assessment.</p>
<p><strong>Key contributions</strong>:</p>
<ul>
<li><strong>Conversational realism</strong>: Context-dependent questions that mirror dialogue patterns</li>
<li><strong>Coreference complexity</strong>: Integration of challenging NLP problems into QA evaluation</li>
<li><strong>Evaluation metrics</strong>: HEQ scores that measure consistency alongside average performance</li>
<li><strong>Large-scale framework</strong>: Substantial dataset enabling robust model training and evaluation</li>
</ul>
<p>The dataset&rsquo;s <a href="https://quac.ai/">leaderboard</a> provides researchers with a challenging benchmark for developing conversational AI systems. As models improve on QuAC, we can expect progress in dialogue agents, virtual assistants, and educational AI systems that engage in more natural, context-aware conversations.</p>
<p>QuAC&rsquo;s focus on dialogue context and reference resolution pushes the field toward AI systems that can engage in genuine conversation and understand complex dialogue flows.</p>
<h2 id="a-builders-perspective-quac-and-modern-instruction-tuning">A Builder&rsquo;s Perspective: QuAC and Modern Instruction Tuning</h2>
<p>Looking at QuAC through the lens of modern production ML, the student-teacher framework maps directly onto how we now train and evaluate assistants. Today, assistants are built by adapting foundation models with instruction tuning and Reinforcement Learning from Human Feedback (RLHF), both of which rely heavily on multi-turn, context-aware interactions.</p>
<p>When building a system like GutenOCR, users rarely ask perfectly formulated, context-free questions. They ask follow-ups, use pronouns, and expect the system to act as a knowledgeable &ldquo;teacher&rdquo; guiding them through the document. QuAC was an early dataset to formalize this asymmetric information dynamic. It highlighted the necessity of handling unanswerable questions gracefully, a critical feature for preventing hallucinations in today&rsquo;s production LLMs.</p>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{choi-etal-2018-quac,
</span></span><span style="display:flex;"><span>    <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">&#34;{Q}u{AC}: Question Answering in Context&#34;</span>,
</span></span><span style="display:flex;"><span>    <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">&#34;Choi, Eunsol and He, He and Iyyer, Mohit and Yatskar, Mark and Yih, Wen-tau and Choi, Yejin and Liang, Percy and Zettlemoyer, Luke&#34;</span>,
</span></span><span style="display:flex;"><span>    <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">&#34;Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing&#34;</span>,
</span></span><span style="display:flex;"><span>    <span style="color:#a6e22e">month</span> = oct # <span style="color:#e6db74">&#34;-&#34;</span> # nov,
</span></span><span style="display:flex;"><span>    <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">&#34;2018&#34;</span>,
</span></span><span style="display:flex;"><span>    <span style="color:#a6e22e">address</span> = <span style="color:#e6db74">&#34;Brussels, Belgium&#34;</span>,
</span></span><span style="display:flex;"><span>    <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">&#34;Association for Computational Linguistics&#34;</span>,
</span></span><span style="display:flex;"><span>    <span style="color:#a6e22e">url</span> = <span style="color:#e6db74">&#34;https://aclanthology.org/D18-1241/&#34;</span>,
</span></span><span style="display:flex;"><span>    <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">&#34;10.18653/v1/D18-1241&#34;</span>,
</span></span><span style="display:flex;"><span>    <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">&#34;2174--2184&#34;</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>CoQA Dataset: Advancing Conversational Question Answering</title><link>https://hunterheidenreich.com/posts/coqa-conversation-question-answering/</link><pubDate>Thu, 23 Aug 2018 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/coqa-conversation-question-answering/</guid><description>Analysis of CoQA, a conversational QA dataset with multi-turn dialogue, coreference resolution, and natural answers for QA research.</description><content:encoded><![CDATA[<h2 id="introduction">Introduction</h2>
<p>The <a href="https://doi.org/10.1162/tacl_a_00266">CoQA dataset</a> (Reddy et al., 2019) introduces conversational dynamics to question answering research. CoQA requires models to maintain context across multi-turn conversations while reading and reasoning about text passages, whereas previous datasets focused on isolated question-answer pairs.</p>
<p>This dataset addresses a gap in conversational AI research by providing a benchmark for systems that must understand dialogue flow and implicit references, both key components of natural human conversation.</p>
<p>For related work on conversational question answering, see my analysis of <a href="/posts/quac-question-answering-in-context/">QuAC</a>.</p>
<h2 id="what-makes-conversational-qa-different">What Makes Conversational QA Different</h2>
<p>Conversational question answering introduces challenges beyond traditional reading comprehension:</p>
<ol>
<li><strong>Context dependency</strong>: Questions rely on previous dialogue turns for meaning</li>
<li><strong>Coreference resolution</strong>: Understanding pronouns and implicit references</li>
<li><strong>Abstractive answering</strong>: Rephrasing information to generate natural responses</li>
<li><strong>Multi-turn reasoning</strong>: Maintaining coherent dialogue across multiple exchanges</li>
</ol>
<p>These requirements differentiate CoQA from existing question answering datasets that treat each question independently.</p>
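<p>One simple way to give a reader model the context it needs is to prepend the most recent dialogue turns to the current question. The sketch below illustrates that idea only; the function name, the <code>Q:</code>/<code>A:</code> layout, and the <code>max_turns</code> parameter are assumptions made for illustration, not the setup used by the CoQA authors.</p>

```python
def build_model_input(passage, history, question, max_turns=2):
    """Combine a passage, recent dialogue turns, and the current question.

    history: list of (question, answer) pairs from earlier turns, oldest first.
    max_turns: how many of the most recent turns to keep as context.
    """
    # Slicing with -0 would keep everything, so handle zero explicitly.
    recent = history[-max_turns:] if max_turns > 0 else []
    lines = []
    for past_question, past_answer in recent:
        lines.append(f"Q: {past_question}")
        lines.append(f"A: {past_answer}")
    lines.append(f"Q: {question}")
    return passage + "\n\n" + "\n".join(lines)


history = [("Who wrote the report?", "Dana"), ("When did she finish it?", "In March")]
text = build_model_input("PASSAGE", history, "Where?", max_turns=1)
```

<p>With <code>max_turns=1</code>, the one-word question &ldquo;Where?&rdquo; arrives together with the previous exchange, which is the only thing that makes it interpretable.</p>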
<h2 id="why-coqa-matters">Why CoQA Matters</h2>
<p>Question answering systems typically excel at finding specific information in text. However, they often struggle with natural conversation. Human communication involves building on previous exchanges, using pronouns and implicit references, and expressing ideas in varied ways.</p>
<p>CoQA addresses this by creating a large-scale dataset for conversational question answering with three primary characteristics:</p>
<ol>
<li>
<p><strong>Conversation-dependent questions</strong>: The dataset contains 127,000 questions across 8,000 conversations, and after the first question in a conversation, every question depends on the dialogue history</p>
</li>
<li>
<p><strong>Natural, abstractive answers</strong>: CoQA requires rephrased responses that sound natural in conversation. The answerer first highlighted the relevant text span, then rephrased the information.</p>
</li>
<li>
<p><strong>Domain diversity</strong>: Training covers 5 domains with testing on 7 domains, including 2 unseen during training</p>
</li>
</ol>
<p>The performance gap is notable: humans achieve 88.8% F1 score while the best models at the time reached 65.1% F1, indicating substantial room for improvement.</p>
<h2 id="dataset-construction">Dataset Construction</h2>
<p>CoQA was constructed using Amazon Mechanical Turk, pairing workers in a question-answer dialogue setup. One worker asked questions about a given passage while another provided answers. The answerer first highlighted the relevant text span, then rephrased the information using different words to create natural, abstractive responses.</p>
<p>This methodology produces answers that sound conversational while staying tied to evidence in the passage, which makes the dataset realistic for dialogue applications.</p>
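<p>The highlight-then-rephrase procedure means each turn carries two pieces of answer information: an evidence span in the passage and a free-form answer. A minimal sketch of such a record follows; the class and field names are hypothetical, the passage is made up, and the released CoQA files may use a different layout.</p>

```python
from dataclasses import dataclass


@dataclass
class Turn:
    question: str
    rationale_start: int  # character offset where the highlighted span begins
    rationale_end: int    # character offset just past the highlighted span
    answer: str           # free-form answer written after highlighting

    def rationale(self, passage: str) -> str:
        """Return the highlighted evidence span from the passage."""
        return passage[self.rationale_start:self.rationale_end]


# Made-up passage and turn, for illustration only.
passage = "The museum opened in 1901. It closed for repairs in 1950."
evidence = "opened in 1901"
start = passage.index(evidence)
turn = Turn(
    question="When did it open?",
    rationale_start=start,
    rationale_end=start + len(evidence),
    answer="In 1901",
)
```

<p>Keeping the span alongside the rephrased answer lets an evaluator check that a natural-sounding answer is still grounded in the text.</p>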
<h3 id="domain-coverage">Domain Coverage</h3>
<p>CoQA spans diverse text types to ensure evaluation across different writing styles and topics:</p>
<p><strong>Training domains (5):</strong></p>
<ul>
<li>Children&rsquo;s stories from <a href="https://web.archive.org/web/20180829214346/https://uclmr.github.io/ai4exams/data.html#mctest">MCTest</a></li>
<li>Literature from <a href="https://www.gutenberg.org/">Project Gutenberg</a></li>
<li>Educational content from <a href="https://www.cs.cmu.edu/~glai1/data/race/">RACE</a> (middle/high school English)</li>
<li>CNN news articles</li>
<li>Wikipedia articles</li>
</ul>
<p><strong>Test-only domains (2):</strong></p>
<ul>
<li>Science articles from <a href="http://data.allenai.org/ai2-science-questions/">AI2 Science Questions</a></li>
<li>Creative writing from <a href="https://www.reddit.com/r/WritingPrompts/">Reddit WritingPrompts</a></li>
</ul>















<figure class="post-figure center ">
    <img src="/img/coqa_domains.webp"
         alt="Domain distribution in the CoQA dataset"
         title="Domain distribution in the CoQA dataset"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Domain distribution in the CoQA dataset</figcaption>
    
</figure>

<p>The inclusion of test-only domains provides a rigorous evaluation of model generalization to unseen text types.</p>
<h2 id="comparison-with-existing-datasets">Comparison with Existing Datasets</h2>
<p>Prior to CoQA, the dominant question answering benchmark was <a href="https://rajpurkar.github.io/SQuAD-explorer/">SQuAD (Stanford Question Answering Dataset)</a>. SQuAD established foundations for reading comprehension but came with specific constraints:</p>
<ul>
<li><strong>SQuAD 1.0</strong>: 100,000+ questions requiring exact text extraction from Wikipedia passages</li>
<li><strong>SQuAD 2.0</strong>: Added 50,000+ unanswerable questions to test when no answer exists</li>
</ul>















<figure class="post-figure center ">
    <img src="/img/squad_coqa_size.webp"
         alt="Scale comparison between SQuAD and CoQA datasets"
         title="Scale comparison between SQuAD and CoQA datasets"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Scale comparison between SQuAD and CoQA datasets</figcaption>
    
</figure>

<p>SQuAD treats each question independently and requires only extractive answers. CoQA addresses these constraints through conversational context and abstractive responses.</p>
<h3 id="question-and-answer-analysis">Question and Answer Analysis</h3>
<p>The differences between SQuAD and CoQA extend beyond conversational context:</p>
<p><strong>Question diversity</strong>: SQuAD heavily favors &ldquo;what&rdquo; questions (~50%). CoQA shows a more balanced distribution across question types, reflecting natural conversation patterns.</p>















<figure class="post-figure center ">
    <img src="/img/squad_v_coqa.webp"
         alt="Question type distribution comparison between SQuAD and CoQA"
         title="Question type distribution comparison between SQuAD and CoQA"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Question type distribution comparison between SQuAD and CoQA</figcaption>
    
</figure>

<p><strong>Context dependence</strong>: CoQA includes challenging single-word questions like &ldquo;who?&rdquo;, &ldquo;where?&rdquo;, or &ldquo;why?&rdquo; that depend entirely on dialogue history.</p>
<p><strong>Answer characteristics</strong>: CoQA answers vary significantly in length and style, whereas SQuAD answers are extractive spans copied from the passage.</p>















<figure class="post-figure center ">
    <img src="/img/squad_coqa_answers.webp"
         alt="Answer length distribution in SQuAD vs CoQA"
         title="Answer length distribution in SQuAD vs CoQA"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Answer length distribution in SQuAD vs CoQA</figcaption>
    
</figure>

<h2 id="the-coreference-challenge">The Coreference Challenge</h2>
<p>CoQA&rsquo;s difficulty stems largely from its reliance on coreference resolution (determining when different expressions refer to the same entity). This remains a challenging research problem in NLP.</p>
<p><strong>Coreference types in CoQA</strong>:</p>
<ul>
<li><strong>Explicit coreferences</strong> (~50% of questions): Clear indicators like pronouns (&ldquo;him,&rdquo; &ldquo;it,&rdquo; &ldquo;her,&rdquo; &ldquo;that&rdquo;)</li>
<li><strong>Implicit coreferences</strong> (~20% of questions): Context-dependent references requiring inference (e.g., asking &ldquo;where?&rdquo; without specifying what)</li>
</ul>















<figure class="post-figure center ">
    <img src="/img/coqa_coreferences.webp"
         alt="Distribution of coreference types in CoQA questions"
         title="Distribution of coreference types in CoQA questions"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Distribution of coreference types in CoQA questions</figcaption>
    
</figure>

<p>These linguistic phenomena make CoQA more difficult than traditional reading comprehension, as models must resolve references across dialogue turns while maintaining conversational coherence.</p>
<h2 id="performance-benchmarks">Performance Benchmarks</h2>
<p>Models faced significant challenges on CoQA, with substantial room for improvement:</p>















<figure class="post-figure center ">
    <img src="/img/coqa_scores.webp"
         alt="Performance comparison on CoQA across different model types"
         title="Performance comparison on CoQA across different model types"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Performance comparison on CoQA across different model types</figcaption>
    
</figure>

<p>The performance gap between human and machine capabilities highlighted conversational question answering as a challenging frontier in NLP research.</p>
<h2 id="research-impact-and-future-directions">Research Impact and Future Directions</h2>
<p>CoQA represents a step toward more natural conversational AI systems. By requiring models to handle dialogue context, coreference resolution, and abstractive reasoning simultaneously, it challenges current NLP system capabilities.</p>
<p>The dataset&rsquo;s <a href="https://stanfordnlp.github.io/coqa/">leaderboard</a> provides a benchmark for measuring progress on this task. As models improve on CoQA, we can expect advances in conversational AI applications, from chatbots to virtual assistants that engage in more natural, context-aware dialogue.</p>
<p>CoQA aims to do for conversational question answering what ImageNet did for computer vision: provide a challenging, well-constructed benchmark that drives research toward more capable AI systems.</p>
<h2 id="a-builders-perspective-coqa-in-the-era-of-llms">A Builder&rsquo;s Perspective: CoQA in the Era of LLMs</h2>
<p>Looking back at CoQA from the perspective of modern production systems, the dataset anticipated where the field went. The challenges it introduced, such as multi-turn reasoning, coreference resolution, and abstractive answering, are the exact capabilities we now expect from instruction-tuned Large Language Models (LLMs).</p>
<p>Production document-processing pipelines rarely extract isolated facts. Users want to chat with their documents, asking follow-up questions like, &ldquo;What does that mean for the Q3 budget?&rdquo; Resolving &ldquo;that&rdquo; to a previous turn&rsquo;s context is exactly the problem CoQA formalized. Datasets like CoQA shifted the field&rsquo;s focus from simple extraction toward dialogue comprehension, the foundation modern conversational document interfaces are built on.</p>
<h2 id="references">References</h2>
<p>Reddy, S., Chen, D., &amp; Manning, C. D. (2019). CoQA: A conversational question answering challenge. <em>Transactions of the Association for Computational Linguistics</em>, 7, 249-266.</p>
]]></content:encoded></item><item><title>Understanding GANs: From Fundamentals to Objective Functions</title><link>https://hunterheidenreich.com/posts/what-is-a-gan/</link><pubDate>Sat, 18 Aug 2018 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/what-is-a-gan/</guid><description>A complete guide to Generative Adversarial Networks (GANs), covering intuitive explanations, mathematical foundations, and objective functions.</description><content:encoded><![CDATA[<h2 id="understanding-generative-models">Understanding Generative Models</h2>
<p>Modern generative AI is dominated by diffusion models and autoregressive transformers. The adversarial training dynamics and objective functions introduced by <a href="https://arxiv.org/abs/1406.2661">Generative Adversarial Networks</a> (GANs) still inform how the field thinks about loss-function design and training stability today. Before diving into GANs, let&rsquo;s establish what we&rsquo;re trying to accomplish with generative models.</p>
<p><strong>The core goal</strong>: Create a system that can generate new, realistic data that appears to come from the same distribution as our training data.</p>
<p>Think of having a model that can create images, text, or audio that are difficult to distinguish from human-created content. This is what generative modeling aims to achieve.</p>
<h3 id="the-mathematical-foundation">The Mathematical Foundation</h3>
<p>Generative models aim to estimate the probability distribution of real data. If we have parameters $\theta$, we want to find the optimal $\theta^*$ that maximizes the likelihood of observing our real samples:</p>
<p>$$
\theta^* = \arg\max_\theta \prod_{i=1}^{n} p_\theta(x_i)
$$</p>
<p>Maximizing this likelihood is equivalent to minimizing a divergence between the true data distribution and our estimated distribution. The standard choice is the <a href="https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence">Kullback-Leibler divergence</a>: maximizing the log-likelihood minimizes the KL divergence from the data distribution to the model.</p>
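<p>A short derivation sketch shows why the two views coincide. Taking the logarithm of the likelihood turns the product into a sum without moving the maximizer, and dividing by $n$ gives a sample average that approximates an expectation under the data distribution:</p>
<p>$$
\theta^* = \arg\max_\theta \frac{1}{n} \sum_{i=1}^{n} \log p_\theta(x_i) \approx \arg\max_\theta \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log p_\theta(x)]
$$</p>
<p>The KL divergence from the data distribution to the model splits into a term that does not depend on $\theta$ and the negative of that same expectation:</p>
<p>$$
D_{\text{KL}}(p_{\text{data}} \,\|\, p_\theta) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log p_{\text{data}}(x)] - \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log p_\theta(x)]
$$</p>
<p>Since the first term is constant in $\theta$, minimizing the divergence over $\theta$ is the same as maximizing the expected log-likelihood. The approximation step assumes the samples $x_i$ are drawn independently from $p_{\text{data}}$.</p>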
<h3 id="two-approaches-to-generative-modeling">Two Approaches to Generative Modeling</h3>
<h4 id="explicit-distribution-models">Explicit Distribution Models</h4>
<p>These models define an explicit probability distribution and refine it through training.</p>
<p><strong>Example</strong>: <a href="https://arxiv.org/abs/1606.05908">Variational Auto-Encoders</a> (VAEs) require:</p>
<ul>
<li>An explicitly assumed prior distribution</li>
<li>A likelihood distribution</li>
<li>A &ldquo;variational approximation&rdquo; to the posterior, used to bound the likelihood during training</li>
</ul>
<h4 id="implicit-distribution-models">Implicit Distribution Models</h4>
<p>These models never write down a probability density. Instead, they learn a procedure for drawing samples from the learned distribution. GANs exemplify this implicit approach, learning distributions through adversarial competition.</p>















<figure class="post-figure center ">
    <img src="/img/gen_ai_types.webp"
         alt="Types of deep generative models showing taxonomy"
         title="Types of deep generative models showing taxonomy"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption"><strong>Taxonomy of Deep Generative Models</strong>: GANs fall into the implicit density category, learning distributions through adversarial training. <em>Source: NeurIPS 2016 tutorial on Generative Adversarial Networks</em></figcaption>
    
</figure>

<h2 id="the-gan-architecture-a-game-of-deception">The GAN Architecture: A Game of Deception</h2>
<p>Generative Adversarial Networks get their name from three key components:</p>
<ul>
<li><strong>Generative</strong>: They create new data</li>
<li><strong>Adversarial</strong>: Two networks compete against each other</li>
<li><strong>Networks</strong>: Built using neural networks</li>
</ul>
<p>The core innovation is the adversarial setup: two neural networks compete against each other, driving mutual improvement.</p>















<figure class="post-figure center ">
    <img src="/img/GAN-70.webp"
         alt="Diagram showing data flow through a GAN architecture"
         title="Diagram showing data flow through a GAN architecture"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption"><strong>GAN Data Flow</strong>: The generator creates fake samples from random noise, while the discriminator tries to distinguish real from fake data. This adversarial competition drives both networks to improve.</figcaption>
    
</figure>

<h3 id="the-generator-the-forger">The Generator: The Forger</h3>
<p><strong>Role</strong>: Create convincing fake data from random noise</p>
<p>The generator network $G$ learns a mapping function:
$$z \rightarrow G(z) \approx x_{\text{real}}$$</p>
<p>Where:</p>
<ul>
<li>$z$ is a random latent vector (the &ldquo;noise&rdquo;)</li>
<li>$G(z)$ is the generated sample</li>
<li>The goal is making $G(z)$ indistinguishable from real data</li>
</ul>
<p><strong>Key insight</strong>: The latent space is continuous, so in a well-trained generator small changes in $z$ tend to produce smooth, meaningful changes in the generated output.</p>
<h3 id="the-discriminator-the-detective">The Discriminator: The Detective</h3>
<p><strong>Role</strong>: Distinguish between real and generated samples</p>
<p>The discriminator network $D$ outputs a probability:
$$D(x) = P(x \text{ is real})$$</p>
<ul>
<li>$D(x) \approx 1$ for real samples</li>
<li>$D(x) \approx 0$ for fake samples</li>
</ul>
<p>It functions as an &ldquo;authenticity detector&rdquo; that progressively improves.</p>
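<p>To make the two roles concrete, here is a deliberately tiny sketch in which both networks are replaced by one-parameter functions on scalars. Everything in it is an assumption made for illustration: real GANs use neural networks, and the weights below are hand-picked, not learned.</p>

```python
import math
import random


def sigmoid(t):
    """Numerically stable logistic function."""
    if t >= 0:
        return 1.0 / (1.0 + math.exp(-t))
    e = math.exp(t)
    return e / (1.0 + e)


def generator(z, weight=2.0, bias=1.0):
    """Toy G: maps a latent scalar z to a sample with an affine map."""
    return weight * z + bias


def discriminator(x, weight=1.5, bias=-1.0):
    """Toy D: returns an estimate of P(x is real), a value between 0 and 1."""
    return sigmoid(weight * x + bias)


rng = random.Random(0)        # fixed seed so the sketch is reproducible
z = rng.gauss(0.0, 1.0)       # latent noise z
fake = generator(z)           # G(z)
score = discriminator(fake)   # D(G(z))
nudged = generator(z + 1e-3)  # a small step in latent space
```

<p>The generator turns noise into a sample, the discriminator turns a sample into a probability, and a small step in $z$ moves the toy generator&rsquo;s output by a proportionally small amount.</p>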
<h3 id="the-adversarial-competition">The Adversarial Competition</h3>
<p>This adversarial dynamic drives the training process. The generator and discriminator have <strong>directly opposing objectives</strong>:</p>
<table>
  <thead>
      <tr>
          <th>Generator Goal</th>
          <th>Discriminator Goal</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Fool the discriminator</td>
          <td>Correctly classify all samples</td>
      </tr>
      <tr>
          <td>Maximize $D(G(z))$ (equivalently, minimize $\log(1 - D(G(z)))$)</td>
          <td>Maximize $D(x_{\text{real}})$ and minimize $D(G(z))$</td>
      </tr>
      <tr>
          <td>&ldquo;Create convincing fakes&rdquo;</td>
          <td>&ldquo;Never be fooled&rdquo;</td>
      </tr>
  </tbody>
</table>
<p>This creates a dynamic where both networks continuously improve:</p>
<ul>
<li>Generator creates better fakes to fool the discriminator</li>
<li>Discriminator becomes better at detecting fakes</li>
<li>The cycle continues until equilibrium</li>
</ul>















<figure class="post-figure center ">
    <img src="/img/GAN-SUMMARY-50.webp"
         alt="Illustration of GAN training process showing adversarial competition"
         title="Illustration of GAN training process showing adversarial competition"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption"><strong>The Adversarial Training Process</strong>: Through competition, both networks improve. The generator learns to create increasingly realistic samples while the discriminator becomes more discerning.</figcaption>
    
</figure>

<h2 id="learning-through-metaphors">Learning Through Metaphors</h2>
<p>Relatable analogies often clarify complex concepts. Here are two metaphors that capture different aspects of how GANs work.</p>
<h3 id="the-art-forger-vs-critic">The Art Forger vs. Critic</h3>
<p><strong>Generator = Art Forger</strong><br>
<strong>Discriminator = Art Critic</strong></p>
<p>A criminal forger tries to create fake masterpieces, while an art critic must identify authentic works. Each interaction teaches both parties:</p>
<ul>
<li>The forger learns what makes art look authentic</li>
<li>The critic develops a keener eye for detecting fakes</li>
<li>Eventually, the forger becomes so skilled that even experts can&rsquo;t tell the difference</li>
</ul>
<p><em>This captures the adversarial nature and continuous improvement aspect of GANs.</em></p>
<h3 id="the-counterfeiter-vs-bank-teller">The Counterfeiter vs. Bank Teller</h3>
<p><strong>Generator = Counterfeiter</strong><br>
<strong>Discriminator = Bank Teller</strong></p>
<p>Day 1: The counterfeiter brings a crayon drawing of a dollar bill. Even a new teller spots this fake.</p>
<p>Day 100: The counterfeiter has learned better techniques. The teller has developed expertise in security features.</p>
<p>Day 1000: The fake money is so convincing that detecting it requires advanced equipment.</p>
<p><em>This illustrates the progressive improvement and escalating sophistication in both networks.</em></p>
<h2 id="the-mathematical-foundation-1">The Mathematical Foundation</h2>
<p>Now let&rsquo;s examine the mathematical framework that makes GANs work. The core of GAN training is solving a <strong>minimax optimization problem</strong>.</p>
<h3 id="the-minimax-objective">The Minimax Objective</h3>
<p>$$
\min_{G} \max_{D} V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]
$$</p>
<p><strong>Breaking this down:</strong></p>
<ul>
<li>$\mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)]$: The expected log-probability for real data.
<ul>
<li><strong>Discriminator&rsquo;s Goal</strong>: Maximize this term to correctly classify real samples.</li>
</ul>
</li>
<li>$\mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$: The expected log-probability for fake data being correctly identified as fake.
<ul>
<li><strong>Discriminator&rsquo;s Goal</strong>: Maximize this term.</li>
<li><strong>Generator&rsquo;s Goal</strong>: Minimize this term to fool the discriminator.</li>
</ul>
</li>
</ul>
<h3 id="why-minimax">Why &ldquo;Minimax&rdquo;?</h3>
<ul>
<li><strong>Discriminator ($D$)</strong>: Tries to <strong>maximize</strong> the objective → Better at distinguishing real from fake.</li>
<li><strong>Generator ($G$)</strong>: Tries to <strong>minimize</strong> the objective → Better at fooling the discriminator.</li>
</ul>
<h3 id="a-practical-challenge-vanishing-gradients">A Practical Challenge: Vanishing Gradients</h3>
<p>The minimax objective presents a practical problem early in training. When the generator is poor, the discriminator can easily distinguish real from fake samples with high confidence ($D(G(z)) \approx 0$). This causes $\log(1 - D(G(z)))$ to saturate and results in vanishing gradients for the generator, which effectively stalls learning.</p>
<p><strong>The Solution</strong>: Practitioners typically train the generator to <strong>maximize</strong> $\log(D(G(z)))$ to provide stronger gradients early in training. This non-saturating heuristic prevents the learning process from stalling.</p>
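<p>A small numeric check makes the difference concrete. The sketch below assumes the discriminator ends in a sigmoid, so $D(G(z)) = \sigma(a)$ for some logit $a$, and compares the gradient with respect to $a$ of the two generator losses. The function names and the logit value are illustrative only.</p>

```python
import math


def sigmoid(a):
    """Logistic function mapping a logit to a probability."""
    return 1.0 / (1.0 + math.exp(-a))


def generator_grad_saturating(logit):
    """d/d(logit) of log(1 - sigmoid(logit)), the term G minimizes in the minimax form."""
    return -sigmoid(logit)


def generator_grad_non_saturating(logit):
    """d/d(logit) of -log(sigmoid(logit)), the heuristic loss G minimizes instead."""
    return -(1.0 - sigmoid(logit))


# A confident discriminator early in training: D(G(z)) is close to 0.
logit = -6.0
print("D(G(z))            :", sigmoid(logit))
print("saturating grad    :", generator_grad_saturating(logit))
print("non-saturating grad:", generator_grad_non_saturating(logit))
```

<p>With the logit at $-6$, the saturating gradient has magnitude below 0.01 while the non-saturating gradient is close to 1 in magnitude, which is why the heuristic keeps the generator learning when the discriminator is confident.</p>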
<h3 id="the-training-process">The Training Process</h3>
<p>GAN training alternates between two optimization steps:</p>
<ol>
<li><strong>Fix $G$, train $D$</strong>: Make the discriminator optimal for the current generator</li>
<li><strong>Fix $D$, train $G$</strong>: Improve the generator against the current discriminator</li>
<li><strong>Repeat</strong>: Continue until reaching Nash equilibrium</li>
</ol>
<h3 id="theoretical-goal-nash-equilibrium">Theoretical Goal: Nash Equilibrium</h3>
<p>At the theoretical equilibrium, the discriminator outputs $D(x) = 0.5$ for every sample, real or generated, because it can no longer tell them apart. This happens exactly when $p_{\text{generator}} = p_{\text{data}}$: the generator has learned the true data distribution. In practice, training rarely reaches this point exactly.</p>
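<p>The equilibrium can be checked numerically on a small discrete example. The sketch below assumes the closed-form optimal discriminator from the original GAN paper, $D^*(x) = p_{\text{data}}(x) / (p_{\text{data}}(x) + p_g(x))$, which this post does not derive. The three-outcome distributions are made up for illustration and are assumed to be strictly positive.</p>

```python
import math


def optimal_discriminator(p_data, p_gen):
    """D*(x) = p_data(x) / (p_data(x) + p_gen(x)) for each discrete outcome x."""
    return [p / (p + q) for p, q in zip(p_data, p_gen)]


def value_function(p_data, p_gen, d):
    """V(D, G) = E_data[log D(x)] + E_gen[log(1 - D(x))] for discrete distributions."""
    real_term = sum(p * math.log(dx) for p, dx in zip(p_data, d))
    fake_term = sum(q * math.log(1.0 - dx) for q, dx in zip(p_gen, d))
    return real_term + fake_term


p_data = [0.1, 0.4, 0.5]
p_gen_poor = [0.6, 0.3, 0.1]  # a generator that has not matched the data yet

d_poor = optimal_discriminator(p_data, p_gen_poor)
d_equilibrium = optimal_discriminator(p_data, p_data)

print("D* against a poor generator:", d_poor)
print("D* at equilibrium          :", d_equilibrium)
print("V at equilibrium           :", value_function(p_data, p_data, d_equilibrium))
print("-log(4)                    :", -math.log(4))
```

<p>When the generator matches the data, every entry of $D^*$ is 0.5 and the value function equals $-\log 4$. Any mismatch gives the optimal discriminator a strictly larger value.</p>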
<h2 id="the-evolution-of-objective-functions">The Evolution of Objective Functions</h2>
<p>The objective function is the mathematical heart of any GAN. It defines how we measure the &ldquo;distance&rdquo; between our generated distribution and the real data distribution. This choice profoundly impacts:</p>
<ul>
<li><strong>Training stability</strong>: Some objectives lead to more stable convergence</li>
<li><strong>Sample quality</strong>: Different losses emphasize different aspects of realism</li>
<li><strong>Mode collapse</strong>: The tendency to generate limited variety</li>
<li><strong>Computational efficiency</strong>: Some objectives are faster to compute</li>
</ul>
<p>The original GAN uses Jensen-Shannon Divergence (JSD), but researchers have discovered many alternatives that address specific limitations. Let&rsquo;s explore this evolution.</p>
<h3 id="the-original-gan-jensen-shannon-divergence">The Original GAN: Jensen-Shannon Divergence</h3>
<p>With an optimal discriminator, the original GAN objective reduces (up to constants) to minimizing the Jensen-Shannon Divergence:</p>
<p>$$
\text{JSD}(P, Q) = \frac{1}{2} \text{KL}(P \| M) + \frac{1}{2} \text{KL}(Q \| M)
$$</p>
<p>Where $M = \frac{1}{2}(P + Q)$ is the average distribution, and $\text{KL}$ is the <a href="https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence">Kullback-Leibler Divergence</a>.</p>
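<p>As a rough illustration of the definition, and of why JSD gives little signal when distributions do not overlap, the sketch below computes it for discrete distributions. The helper names and the four-bin example are assumptions made for this example.</p>

```python
import math


def kl_divergence(p, q):
    """KL(P || Q) for discrete distributions; terms with p_i = 0 contribute 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def js_divergence(p, q):
    """JSD(P, Q) = 0.5 * KL(P || M) + 0.5 * KL(Q || M), with M = (P + Q) / 2."""
    m = [0.5 * (pi + qi) for pi, qi in zip(p, q)]
    return 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)


# Two point masses with disjoint support, one bin apart and three bins apart.
near = js_divergence([1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0])
far = js_divergence([1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0])
print(near, far, math.log(2))
```

<p>Both values equal $\log 2$: once the supports are disjoint, JSD no longer reflects how far apart the distributions are, so it offers the generator no direction to move in.</p>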
<p><strong>Strengths</strong>: Solid theoretical foundation, introduced adversarial training<br>
<strong>Limitations</strong>: Can suffer from vanishing gradients and mode collapse</p>
<h3 id="wasserstein-gan-wgan">Wasserstein GAN (WGAN)</h3>
<p>The <a href="https://arxiv.org/abs/1701.07875">Wasserstein GAN</a> replaced Jensen-Shannon divergence with the Earth-Mover (Wasserstein) distance, which gives meaningful gradients even when the real and generated distributions do not overlap.</p>
<h4 id="understanding-earth-mover-distance">Understanding Earth-Mover Distance</h4>
<p>The Wasserstein distance, also known as Earth-Mover distance, has an intuitive interpretation:</p>
<blockquote>
<p><strong>Imagine two probability distributions as piles of dirt.</strong> The Earth-Mover distance measures the minimum cost to transform one pile into the other, where cost = mass &times; distance moved.</p></blockquote>
<p>Mathematically:</p>
<p>$$
W_p(\mu, \nu) = \left( \inf_{\gamma \in \Gamma(\mu, \nu)} \int_{M \times M} d(x, y)^p \, d\gamma(x, y) \right)^{1/p}
$$</p>
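<p>In one dimension with $p = 1$, the infimum has a simple closed form: the area between the two cumulative distribution functions. The sketch below uses that fact for histograms on a shared, equally spaced grid. The function name and the example histograms are illustrative.</p>

```python
from itertools import accumulate


def wasserstein_1d(p, q, bin_width=1.0):
    """W_1 between two histograms defined on the same equally spaced 1-D grid.

    In one dimension the optimal transport cost equals the area between
    the two cumulative distribution functions.
    """
    cdf_p = list(accumulate(p))
    cdf_q = list(accumulate(q))
    return bin_width * sum(abs(a - b) for a, b in zip(cdf_p, cdf_q))


source = [1.0, 0.0, 0.0, 0.0]
one_bin_away = wasserstein_1d(source, [0.0, 1.0, 0.0, 0.0])
three_bins_away = wasserstein_1d(source, [0.0, 0.0, 0.0, 1.0])
print(one_bin_away, three_bins_away)
```

<p>Unlike a divergence that saturates on disjoint supports, the distance grows with how far the pile has to move, which is the property WGAN exploits.</p>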
<h4 id="why-earth-mover-distance-matters">Why Earth-Mover Distance Matters</h4>
<table>
  <thead>
      <tr>
          <th>Jensen-Shannon Divergence</th>
          <th>Earth-Mover Distance</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Can be discontinuous</td>
          <td><strong>Always continuous</strong></td>
      </tr>
      <tr>
          <td>May have vanishing gradients</td>
          <td><strong>Meaningful gradients everywhere</strong></td>
      </tr>
      <tr>
          <td>Limited convergence guarantees</td>
          <td><strong>Broader convergence properties</strong></td>
      </tr>
  </tbody>
</table>
<h4 id="wgan-implementation">WGAN Implementation</h4>
<p>Since we can&rsquo;t compute Wasserstein distance directly, WGAN uses the <strong>Kantorovich-Rubinstein duality</strong>:</p>
<ol>
<li><strong>Train a critic function</strong> $f$ to approximate the Wasserstein distance</li>
<li><strong>Constrain the critic</strong> to be 1-Lipschitz, which the original paper approximates by clipping weights to a small range</li>
<li><strong>Optimize the generator</strong> to minimize this distance</li>
</ol>
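<p>The critic side of these steps can be sketched with a deliberately tiny model. The example assumes a one-parameter linear critic $f(x) = w x$ on scalar data and plain gradient ascent. Real WGAN critics are neural networks, so treat this as a toy, with all names and values chosen for illustration.</p>

```python
def mean(values):
    return sum(values) / len(values)


def train_linear_critic(real, fake, clip=0.01, lr=0.005, steps=5):
    """Gradient ascent on E[f(real)] - E[f(fake)] for f(x) = w * x, with weight clipping."""
    w = 0.0
    for _ in range(steps):
        grad_w = mean(real) - mean(fake)  # derivative of the critic objective w.r.t. w
        w += lr * grad_w                  # ascent step: the critic maximizes the gap
        w = max(-clip, min(clip, w))      # weight clipping keeps |w| bounded by clip
    return w


real = [2.9, 3.1, 3.0, 3.2]
fake = [0.1, -0.2, 0.0, 0.3]

w = train_linear_critic(real, fake)
estimate = w * (mean(real) - mean(fake))
print("clipped weight :", w)
print("critic estimate:", estimate)
```

<p>The weight saturates at the clipping bound, so the critic's estimate is the gap between the two sample means scaled by that bound. The generator would then be updated to shrink this estimate.</p>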















<figure class="post-figure center ">
    <img src="/img/wasserstein.webp"
         alt="WGAN training results showing stable convergence"
         title="WGAN training results showing stable convergence"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption"><strong>WGAN Results</strong>: Demonstrating improved training stability and meaningful loss curves. <em>Source: Wasserstein GAN paper</em></figcaption>
    
</figure>

<h4 id="key-wgan-benefits">Key WGAN Benefits</h4>
<p><strong>Meaningful loss function</strong>: Loss correlates with sample quality<br>
<strong>Improved stability</strong>: Less prone to mode collapse<br>
<strong>Theoretical guarantees</strong>: Solid mathematical foundation<br>
<strong>Better convergence</strong>: Works even when distributions don&rsquo;t overlap</p>
<h3 id="improved-wgan-solving-the-weight-clipping-problem">Improved WGAN: Solving the Weight Clipping Problem</h3>
<p><a href="https://arxiv.org/abs/1704.00028">Improved WGAN</a> (WGAN-GP) addresses a critical flaw in the original WGAN: <strong>weight clipping</strong>.</p>
<h4 id="the-problem-with-weight-clipping">The Problem with Weight Clipping</h4>
<p>Original WGAN clips weights to maintain the 1-Lipschitz constraint:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#75715e"># Problematic approach</span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">for</span> param <span style="color:#f92672">in</span> critic<span style="color:#f92672">.</span>parameters():
</span></span><span style="display:flex;"><span>    param<span style="color:#f92672">.</span>data<span style="color:#f92672">.</span>clamp_(<span style="color:#f92672">-</span><span style="color:#ae81ff">0.01</span>, <span style="color:#ae81ff">0.01</span>)
</span></span></code></pre></div><p><strong>Issues with clipping</strong>:</p>
<ul>
<li>Forces critic to use extremely simple functions</li>
<li>Pushes weights toward extreme values ($\pm c$)</li>
<li>Can lead to poor gradient flow</li>
<li>Capacity limitations hurt performance</li>
</ul>
<h4 id="the-gradient-penalty-solution">The Gradient Penalty Solution</h4>
<p>WGAN-GP introduces a <strong>gradient penalty term</strong> to constrain the critic:</p>
<p>$$
L = E_{\tilde{x} \sim P_g}[D(\tilde{x})] - E_{x \sim P_r}[D(x)] + \lambda E_{\hat{x}}[(||\nabla_{\hat{x}} D(\hat{x})||_2 - 1)^2]
$$</p>
<p>Where $\hat{x}$ are points sampled uniformly along straight lines between real and generated data points.</p>
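<p>The penalty term can be sketched without an autodiff library by choosing a critic whose input gradient is known analytically. The example below assumes a toy critic $f(x) = \tanh(w \cdot x)$. The critic, the batches, and the default coefficient of 10 are illustrative assumptions, not a reference implementation.</p>

```python
import math
import random


def critic_grad(w, x):
    """Gradient of the toy critic f(x) = tanh(w . x) with respect to the input x."""
    s = sum(wi * xi for wi, xi in zip(w, x))
    scale = 1.0 - math.tanh(s) ** 2
    return [scale * wi for wi in w]


def gradient_penalty(w, real_batch, fake_batch, lam=10.0, rng=random):
    """lam * mean((||grad f(x_hat)||_2 - 1)^2) over interpolated points x_hat."""
    total = 0.0
    for x, x_tilde in zip(real_batch, fake_batch):
        eps = rng.random()  # uniform position along the line between the two points
        x_hat = [eps * a + (1.0 - eps) * b for a, b in zip(x, x_tilde)]
        grad_norm = math.sqrt(sum(g * g for g in critic_grad(w, x_hat)))
        total += (grad_norm - 1.0) ** 2
    return lam * total / len(real_batch)


real_batch = [[1.0, 2.0], [0.5, 1.5]]
fake_batch = [[-1.0, 0.0], [0.0, -0.5]]
penalty = gradient_penalty([0.6, 0.8], real_batch, fake_batch, rng=random.Random(0))
print(penalty)
```

<p>The penalty is zero only where the critic's gradient norm is exactly 1 at the interpolated points. A critic with zero weights, for example, has gradient norm 0 everywhere and pays the full coefficient.</p>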
<p><strong>Advantages</strong>:</p>
<ul>
<li>No capacity limitations</li>
<li>Better gradient flow</li>
<li>More stable training</li>
<li>Works across different architectures</li>
</ul>
<h3 id="lsgan-the-power-of-least-squares">LSGAN: The Power of Least Squares</h3>
<p><a href="https://arxiv.org/abs/1611.04076">Least Squares GAN</a> takes a different approach. It replaces the logarithmic loss with <strong>L2 (least squares) loss</strong>.</p>
<h4 id="motivation-beyond-binary-classification">Motivation: Beyond Binary Classification</h4>
<p>Traditional GANs use log loss, which focuses primarily on correct classification:</p>
<ul>
<li>Real sample correctly classified → minimal penalty</li>
<li>Fake sample correctly classified → minimal penalty</li>
<li>Distance from decision boundary ignored</li>
</ul>
<h4 id="l2-loss-distance-matters">L2 Loss: Distance Matters</h4>
<p>LSGAN uses L2 loss, which <strong>penalizes proportionally to distance</strong>:</p>
<p>$$
\min_D V_{LSGAN}(D) = \frac{1}{2}E_{x \sim p_{data}(x)}[(D(x) - b)^2] + \frac{1}{2}E_{z \sim p_z(z)}[(D(G(z)) - a)^2]
$$</p>
<p>$$
\min_G V_{LSGAN}(G) = \frac{1}{2}E_{z \sim p_z(z)}[(D(G(z)) - c)^2]
$$</p>
<p>Where typically: $a = 0$ (fake label), $b = c = 1$ (real label)</p>
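<p>The two objectives translate almost directly into code. The sketch below evaluates them on lists of raw discriminator outputs using the labels $a = 0$, $b = c = 1$. The function names and sample values are made up for illustration.</p>

```python
def lsgan_discriminator_loss(d_real, d_fake, a=0.0, b=1.0):
    """0.5 * E[(D(x) - b)^2] + 0.5 * E[(D(G(z)) - a)^2]."""
    real_term = sum((d - b) ** 2 for d in d_real) / len(d_real)
    fake_term = sum((d - a) ** 2 for d in d_fake) / len(d_fake)
    return 0.5 * real_term + 0.5 * fake_term


def lsgan_generator_loss(d_fake, c=1.0):
    """0.5 * E[(D(G(z)) - c)^2]."""
    return 0.5 * sum((d - c) ** 2 for d in d_fake) / len(d_fake)


d_real = [0.9, 1.1]   # discriminator outputs on real samples
d_fake = [0.2, -0.1]  # discriminator outputs on generated samples

d_loss = lsgan_discriminator_loss(d_real, d_fake)
g_loss = lsgan_generator_loss(d_fake)
print(d_loss, g_loss)
```

<p>Because the penalty is quadratic, a generated sample scored at $-0.1$ contributes more to the generator loss than one scored at $0.2$, even though the discriminator rejects both.</p>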
<h4 id="benefits-of-l2-loss">Benefits of L2 Loss</h4>
<table>
  <thead>
      <tr>
          <th>Log Loss</th>
          <th>L2 Loss</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Binary focus</td>
          <td><strong>Distance-aware</strong></td>
      </tr>
      <tr>
          <td>Can saturate</td>
          <td><strong>Informative gradients</strong></td>
      </tr>
      <tr>
          <td>Sharp decision boundary</td>
          <td><strong>Smooth decision regions</strong></td>
      </tr>
  </tbody>
</table>















<figure class="post-figure center ">
    <img src="/img/lsgan-result.webp"
         alt="LSGAN generated samples showing improved quality"
         title="LSGAN generated samples showing improved quality"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption"><strong>LSGAN Results</strong>: Demonstrating improved sample quality through distance-aware loss functions. <em>Source: LSGAN paper</em></figcaption>
    
</figure>

<p><strong>Key insight</strong>: For suitable label choices ($b - c = 1$ and $b - a = 2$), the LSGAN objective minimizes the Pearson χ² divergence, which gives a smoother optimization landscape than JSD.</p>
<h3 id="relaxed-wasserstein-gan-rwgan">Relaxed Wasserstein GAN (RWGAN)</h3>
<p><a href="https://arxiv.org/abs/1705.07164">Relaxed WGAN</a> bridges the gap between WGAN and WGAN-GP, proposing a <strong>general framework</strong> for designing GAN objectives.</p>
<h4 id="key-innovations">Key Innovations</h4>
<p><strong>Asymmetric weight clamping</strong>: RWGAN introduces an asymmetric approach that provides better balance.</p>
<p><strong>Relaxed Wasserstein divergences</strong>: A generalized framework that extends the Wasserstein distance, enabling systematic design of new GAN variants while maintaining theoretical guarantees.</p>
<h4 id="benefits">Benefits</h4>
<ul>
<li>Better convergence properties than standard WGAN</li>
<li>Framework for designing new loss functions and GAN architectures</li>
<li>Competitive performance with other Wasserstein-based methods</li>
</ul>
<p><strong>Key insight</strong>: RWGAN parameterized with KL divergence shows excellent performance while maintaining the theoretical foundations that make Wasserstein GANs attractive.</p>
<h3 id="statistical-distance-approaches">Statistical Distance Approaches</h3>
<p>Several GAN variants focus on minimizing specific statistical distances between distributions.</p>
<h4 id="mcgan-mean-and-covariance-matching">McGAN: Mean and Covariance Matching</h4>
<p><a href="https://arxiv.org/abs/1702.08398">McGAN</a> belongs to the Integral Probability Metric (IPM) family, using <strong>statistical moments</strong> as the distance measure.</p>
<p><strong>Approach</strong>: Match first and second-order statistics:</p>
<ul>
<li><strong>Mean matching</strong>: Align distribution centers</li>
<li><strong>Covariance matching</strong>: Align distribution shapes</li>
</ul>
<p>Moment-matching objectives like this are conceptually related to settings where aligning statistical moments matters, such as matching a generated distribution to a target physical distribution (e.g., molecular conformations). McGAN itself, however, was introduced and demonstrated as an IPM method for image generation.</p>
<p><strong>Limitation</strong>: Relies on weight clipping, like the original WGAN.</p>
<h4 id="gmmn-maximum-mean-discrepancy">GMMN: Maximum Mean Discrepancy</h4>
<p><a href="https://arxiv.org/abs/1502.02761">Generative Moment Matching Networks</a> eliminate the discriminator entirely and directly minimize the <strong>Maximum Mean Discrepancy (MMD)</strong>.</p>
<p><strong>MMD Intuition</strong>: Compare distributions by their means in a high-dimensional feature space:</p>
<p>$$
\text{MMD}^2(X, Y) = ||E[\phi(x)] - E[\phi(y)]||^2
$$</p>
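<p>In practice the feature map $\phi$ is never computed explicitly. Expanding the squared norm leaves only inner products, which a kernel can supply. The sketch below assumes a Gaussian kernel with a fixed bandwidth and uses the simple biased estimator. The sample points are made up.</p>

```python
import math


def gaussian_kernel(x, y, sigma=1.0):
    """k(x, y) = exp(-||x - y||^2 / (2 * sigma^2))."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(x, y))
    return math.exp(-sq_dist / (2.0 * sigma ** 2))


def mmd_squared(xs, ys, sigma=1.0):
    """Biased sample estimate: mean k(x, x') + mean k(y, y') - 2 * mean k(x, y)."""
    def mean_kernel(us, vs):
        total = sum(gaussian_kernel(u, v, sigma) for u in us for v in vs)
        return total / (len(us) * len(vs))

    return mean_kernel(xs, xs) + mean_kernel(ys, ys) - 2.0 * mean_kernel(xs, ys)


xs = [[0.0, 0.0], [0.1, -0.1], [-0.2, 0.1]]
ys = [[3.0, 3.0], [2.9, 3.2], [3.1, 2.8]]

same = mmd_squared(xs, xs)
different = mmd_squared(xs, ys)
print(same, different)
```

<p>The nested loops make the cost quadratic in the batch size, which is one source of the computational expense listed in the drawbacks.</p>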
<p><strong>Benefits</strong>:</p>
<ul>
<li>Simple, discriminator-free training</li>
<li>Theoretical guarantees</li>
<li>Can incorporate autoencoders for better MMD estimation</li>
</ul>
<p><strong>Drawbacks</strong>:</p>
<ul>
<li>Computationally expensive</li>
<li>Often weaker empirical results</li>
</ul>
<h4 id="mmd-gan-learning-better-kernels">MMD GAN: Learning Better Kernels</h4>
<p><a href="https://arxiv.org/abs/1705.08584">MMD GAN</a> improves on GMMN by <strong>learning the kernel adversarially</strong> instead of relying on fixed Gaussian kernels.</p>
<p><strong>Innovation</strong>: Combine GAN adversarial training with MMD objective for the best of both worlds.</p>
<h3 id="different-distance-metrics">Different Distance Metrics</h3>
<h4 id="cramer-gan-addressing-sample-bias">Cramer GAN: Addressing Sample Bias</h4>
<p><a href="https://arxiv.org/abs/1705.10743">Cramer GAN</a> identifies a critical issue with WGAN: <strong>biased sample gradients</strong>.</p>
<p><strong>The Problem</strong>: The Cramer GAN paper asks for three properties of a distance. The Wasserstein distance used by WGAN has the first two but lacks the third:</p>
<ol>
<li><strong>Sum invariance</strong> (satisfied)</li>
<li><strong>Scale sensitivity</strong> (satisfied)</li>
<li><strong>Unbiased sample gradients</strong> (not satisfied)</li>
</ol>
<p><strong>The Solution</strong>: Use the <strong>Cramer distance</strong>, which satisfies all three properties:</p>
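<p>$$
d_C^2(\mu, \nu) = \int_{-\infty}^{\infty} \left( F_\mu(x) - F_\nu(x) \right)^2 dx
$$</p>
<p>Here $F_\mu$ and $F_\nu$ are the cumulative distribution functions of $\mu$ and $\nu$ in the univariate case. For multivariate data, the paper works with the energy distance, which generalizes this definition.</p>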
<p><strong>Benefit</strong>: More reliable gradients lead to better training dynamics.</p>
<h4 id="fisher-gan-chi-square-distance">Fisher GAN: Chi-Square Distance</h4>
<p><a href="https://arxiv.org/abs/1705.09675">Fisher GAN</a> uses a <strong>data-dependent constraint</strong> on the critic&rsquo;s second-order moments (variance).</p>
<p><strong>Key Innovation</strong>: The constraint naturally bounds the critic without manual techniques:</p>
<ul>
<li>No weight clipping needed</li>
<li>No gradient penalties required</li>
<li>Constraint emerges from the objective itself</li>
</ul>
<p><strong>Distance</strong>: Approximates the <strong>Chi-square distance</strong> as critic capacity increases:</p>
<p>$$
\chi^2(P, Q) = \int \frac{(P(x) - Q(x))^2}{Q(x)} dx
$$</p>
<p>The Fisher GAN essentially measures the Mahalanobis distance, which accounts for correlated variables relative to the distribution&rsquo;s centroid. This ensures the generator and critic remain bounded, and as the critic&rsquo;s capacity increases, it estimates the Chi-square distance.</p>
<p><strong>Benefits</strong>:</p>
<ul>
<li>Efficient computation</li>
<li>Training stability</li>
<li>Unconstrained critic capacity</li>
</ul>
<h3 id="beyond-traditional-gans-alternative-approaches">Beyond Traditional GANs: Alternative Approaches</h3>
<p>The following variants explore fundamentally different architectures and training paradigms.</p>
<h4 id="ebgan-energy-based-discrimination">EBGAN: Energy-Based Discrimination</h4>
<p><a href="https://arxiv.org/abs/1609.03126">Energy-Based GAN</a> replaces the discriminator with an <strong>autoencoder</strong>.</p>
<p><strong>Key insight</strong>: Use reconstruction error as the discrimination signal:</p>
<ul>
<li>Good data → Low reconstruction error</li>
<li>Poor data → High reconstruction error</li>
</ul>
<p><strong>Architecture</strong>:</p>
<ol>
<li>Train autoencoder on real data</li>
<li>Generator creates samples</li>
<li>Poor generated samples have high reconstruction loss</li>
<li>This loss drives generator improvement</li>
</ol>
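<p>A minimal sketch of how reconstruction error becomes a training signal, assuming the margin-based hinge form from the EBGAN paper: the discriminator loss is $D(x) + \max(0, m - D(G(z)))$ and the generator loss is $D(G(z))$, where $D$ returns the reconstruction energy. The reconstructions below are hypothetical stand-ins for autoencoder outputs, and the margin is arbitrary.</p>

```python
def reconstruction_energy(x, reconstruction):
    """Mean squared reconstruction error, used here as the energy D(x)."""
    return sum((a - b) ** 2 for a, b in zip(x, reconstruction)) / len(x)


def ebgan_discriminator_loss(energy_real, energy_fake, margin):
    """Low energy for real data; push fake energy up until it reaches the margin."""
    return energy_real + max(0.0, margin - energy_fake)


def ebgan_generator_loss(energy_fake):
    """The generator tries to produce samples the autoencoder reconstructs well."""
    return energy_fake


# Hypothetical autoencoder outputs: the real sample is reconstructed closely,
# the generated sample poorly.
energy_real = reconstruction_energy([1.0, 2.0, 3.0], [1.1, 1.9, 3.0])
energy_fake = reconstruction_energy([1.0, -2.0, 0.0], [0.0, -1.0, 1.0])

d_loss = ebgan_discriminator_loss(energy_real, energy_fake, margin=2.0)
g_loss = ebgan_generator_loss(energy_fake)
print(energy_real, energy_fake, d_loss, g_loss)
```

<p>Once a generated sample's energy exceeds the margin, the hinge term is zero and that sample stops contributing to the discriminator's gradient.</p>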
<p><strong>Benefits</strong>:</p>
<ul>
<li>Fast and stable training</li>
<li>Robust to hyperparameter changes</li>
<li>No need to balance discriminator/generator</li>
</ul>
<h4 id="began-boundary-equilibrium">BEGAN: Boundary Equilibrium</h4>
<p><a href="https://arxiv.org/abs/1703.10717">BEGAN</a> combines EBGAN&rsquo;s autoencoder approach with WGAN-style loss functions.</p>
<p><strong>Innovation</strong>: Dynamic equilibrium parameter $k_t$ that balances:</p>
<ul>
<li>Real data reconstruction quality</li>
<li>Generated data reconstruction quality</li>
</ul>
<p><strong>Equilibrium equations</strong>:</p>
<p>$$
L_D = L(x) - k_t L(G(z))
$$</p>
<p>$$
k_{t+1} = k_t + \lambda(\gamma L(x) - L(G(z)))
$$</p>
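<p>One update of these equations can be sketched as follows. The generator loss $L(G(z))$ and the clamp of $k_t$ to $[0, 1]$ follow my reading of the BEGAN paper and are not stated above. The values of $\gamma$, $\lambda$, and the losses are illustrative.</p>

```python
def began_update(loss_real, loss_fake, k, gamma=0.5, lambda_k=0.001):
    """One BEGAN step: return (discriminator loss, generator loss, next k)."""
    d_loss = loss_real - k * loss_fake
    g_loss = loss_fake  # assumed generator objective: reconstruction loss of fakes
    k_next = k + lambda_k * (gamma * loss_real - loss_fake)
    k_next = min(1.0, max(0.0, k_next))  # assumed clamp keeping k in [0, 1]
    return d_loss, g_loss, k_next


# Fakes are currently reconstructed better than gamma * real, so k increases
# and the discriminator puts more weight on separating the fakes.
d_loss, g_loss, k_next = began_update(loss_real=0.2, loss_fake=0.05, k=0.0)
print(d_loss, g_loss, k_next)
```

<p>The update acts like a simple feedback controller: $k_t$ rises while $L(G(z))$ is below $\gamma L(x)$ and falls when it is above, steering the ratio of the two losses toward $\gamma$.</p>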
<h4 id="magan-adaptive-margins">MAGAN: Adaptive Margins</h4>
<p><a href="https://arxiv.org/abs/1704.03817">MAGAN</a> improves EBGAN by making the margin in the hinge loss <strong>adaptive over time</strong>.</p>
<p><strong>Concept</strong>: Start with a large margin, gradually reduce it as training progresses:</p>
<ul>
<li>Early training: Focus on major differences</li>
<li>Later training: Fine-tune subtle details</li>
</ul>
<p><strong>Result</strong>: Better sample quality and training stability.</p>
<h2 id="summary-the-evolution-of-gan-objectives">Summary: The Evolution of GAN Objectives</h2>
<p>The evolution of GAN objective functions reflects the field&rsquo;s progression toward more stable and theoretically grounded training procedures. Each variant addresses specific limitations in earlier approaches.</p>
<h3 id="complete-reference-table">Complete Reference Table</h3>
<table>
  <thead>
      <tr>
          <th><strong>GAN Variant</strong></th>
          <th><strong>Key Innovation</strong></th>
          <th><strong>Main Benefit</strong></th>
          <th><strong>Limitation</strong></th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Original GAN</strong></td>
          <td>Jensen-Shannon divergence</td>
          <td>Foundation of adversarial training</td>
          <td>Vanishing gradients, mode collapse</td>
      </tr>
      <tr>
          <td><strong>WGAN</strong></td>
          <td>Earth-Mover distance</td>
          <td>Meaningful loss, better stability</td>
          <td>Weight clipping issues</td>
      </tr>
      <tr>
          <td><strong>WGAN-GP</strong></td>
          <td>Gradient penalty</td>
          <td>Solves weight clipping problems</td>
          <td>Additional hyperparameter tuning</td>
      </tr>
      <tr>
          <td><strong>LSGAN</strong></td>
          <td>Least squares loss</td>
          <td>Better gradients, less saturation</td>
          <td>May converge to non-optimal points</td>
      </tr>
      <tr>
          <td><strong>RWGAN</strong></td>
          <td>Relaxed Wasserstein framework</td>
          <td>General framework for new designs</td>
          <td>Complex theoretical setup</td>
      </tr>
      <tr>
          <td><strong>McGAN</strong></td>
          <td>Mean/covariance matching</td>
          <td>Simple statistical alignment</td>
          <td>Limited by weight clipping</td>
      </tr>
      <tr>
          <td><strong>GMMN</strong></td>
          <td>Maximum mean discrepancy</td>
          <td>No discriminator needed</td>
          <td>Computationally expensive</td>
      </tr>
      <tr>
          <td><strong>MMD GAN</strong></td>
          <td>Adversarial kernels for MMD</td>
          <td>Improved GMMN performance</td>
          <td>Still computationally heavy</td>
      </tr>
      <tr>
          <td><strong>Cramer GAN</strong></td>
          <td>Cramer distance</td>
          <td>Unbiased sample gradients</td>
          <td>Complex implementation</td>
      </tr>
      <tr>
          <td><strong>Fisher GAN</strong></td>
          <td>Chi-square distance</td>
          <td>Self-constraining critic</td>
          <td>Limited empirical validation</td>
      </tr>
      <tr>
          <td><strong>EBGAN</strong></td>
          <td>Autoencoder discriminator</td>
          <td>Fast, stable training</td>
          <td>Requires careful regularization</td>
      </tr>
      <tr>
          <td><strong>BEGAN</strong></td>
          <td>Boundary equilibrium</td>
          <td>Dynamic training balance</td>
          <td>Additional equilibrium parameter</td>
      </tr>
      <tr>
          <td><strong>MAGAN</strong></td>
          <td>Adaptive margin</td>
          <td>Progressive refinement</td>
          <td>Margin scheduling complexity</td>
      </tr>
  </tbody>
</table>
<h3 id="practical-recommendations">Practical Recommendations</h3>
<p>For practitioners, the choice depends on specific requirements and engineering tradeoffs:</p>
<ul>
<li><strong>WGAN-GP</strong>: Best balance of stability and performance for most applications. However, results can be sensitive to the gradient-penalty coefficient $\lambda$ in practice.</li>
<li><strong>LSGAN</strong>: Simpler implementation with good empirical results.</li>
<li><strong>EBGAN</strong>: Fast experimentation and prototyping.</li>
<li><strong>Original GAN</strong>: Educational purposes and understanding fundamentals.</li>
</ul>
<p><strong>Real-World Impact:</strong> In my work training VLMs on terabyte-scale multimodal data and forecasting chaotic physical systems, these foundational dynamics still matter. Most generation today runs on diffusion models or autoregressive transformers, but the loss-design and training-stability lessons that came out of GAN research carry over. The choice of objective function shapes generation quality, training stability, and compute cost.</p>
<hr>
<p><strong>Acknowledgments</strong>: This post was inspired by the excellent survey &ldquo;<a href="https://arxiv.org/abs/1711.05914">How Generative Adversarial Networks and Their Variants Work: An Overview of GAN</a>&rdquo;.</p>
]]></content:encoded></item><item><title>Word Embeddings in NLP: An Introduction</title><link>https://hunterheidenreich.com/posts/intro-to-word-embeddings/</link><pubDate>Sun, 05 Aug 2018 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/intro-to-word-embeddings/</guid><description>Learn about word embeddings in NLP: from basic one-hot encoding to contextual models like ELMo. Guide with examples.</description><content:encoded><![CDATA[<h2 id="understanding-word-embeddings">Understanding Word Embeddings</h2>
<p>A word embedding maps words to real-valued vectors:</p>
<p>$$
\text{word} \rightarrow \mathbb{R}^n
$$</p>
<p>where $n$ represents the dimensionality of the embedding space.</p>
<p>The goal is simple: position semantically similar words close together in vector space. This dense representation typically uses hundreds of dimensions, a massive reduction from the millions required by one-hot encoding.</p>
<p>Word embeddings are grounded in <a href="https://en.wikipedia.org/wiki/Distributional_semantics">Zellig Harris&rsquo; distributional hypothesis</a>: words appearing in similar contexts tend to have similar meanings. This forms the foundation of distributional semantics.</p>















<figure class="post-figure center ">
    <img src="/img/distributional_semantics-50.webp"
         alt="Distributional semantics visualization"
         title="Distributional semantics visualization"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Words embedded in three-dimensional space, organized by semantic similarity</figcaption>
    
</figure>

<p>Different embedding algorithms capture various aspects of this distributional principle. This post explores the main methods for creating word embeddings and their applications in natural language processing.</p>
<p>While modern foundation models and large Vision-Language Models rely on subword tokenizers (like BPE) and Transformer embedding layers, the goal is the same: mapping discrete text to a continuous vector space where math can capture meaning. These foundational techniques build the intuition for the embedding layers in today&rsquo;s models.</p>
<h2 id="why-word-embeddings-matter-in-nlp">Why Word Embeddings Matter in NLP</h2>
<p>Computers require numerical representations to apply machine learning algorithms to text. Word embeddings bridge this gap by converting text into dense vectors that preserve semantic and syntactic relationships.</p>
<p><strong>Key advantages:</strong></p>
<ol>
<li><strong>Dense representation</strong>: Hundreds of dimensions provide a compact alternative to vocabulary-sized sparse vectors.</li>
<li><strong>Semantic preservation</strong>: Similar words cluster together in vector space.</li>
<li><strong>Mathematical operations</strong>: Enable analogical reasoning ($\text{king} - \text{man} + \text{woman} \approx \text{queen}$).</li>
<li><strong>Transfer learning</strong>: Pre-trained embeddings work across multiple tasks and domains.</li>
</ol>
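<p>The analogy arithmetic is easy to try on toy data. The sketch below uses hand-made three-dimensional vectors (roughly royalty, maleness, femaleness) instead of learned embeddings, so it only shows the mechanics: add and subtract vectors, then look up the nearest remaining word by cosine similarity.</p>

```python
import math


def cosine(u, v):
    """Cosine similarity between two vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


# Hand-made vectors for illustration only; real embeddings are learned from text.
vectors = {
    "king": [0.9, 0.8, 0.1],
    "queen": [0.9, 0.1, 0.8],
    "man": [0.1, 0.9, 0.1],
    "woman": [0.1, 0.1, 0.9],
    "apple": [0.0, 0.1, 0.1],
}

# king - man + woman, computed component by component.
target = [
    k - m + w
    for k, m, w in zip(vectors["king"], vectors["man"], vectors["woman"])
]

# Exclude the query words, then pick the most similar remaining word.
query_words = {"king", "man", "woman"}
candidates = {word: vec for word, vec in vectors.items() if word not in query_words}
nearest = max(candidates, key=lambda word: cosine(target, candidates[word]))
print(nearest)
```

<p>Excluding the query words before the lookup is a common convention for analogy tests, since the result vector often stays close to one of its own inputs.</p>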
<p>Modern deep learning architectures leverage these properties extensively. The development of universal, pre-trained embeddings was a significant step forward. We can use versatile embeddings that generalize across applications, eliminating the need to train task-specific representations from scratch.</p>
<h2 id="word-embedding-approaches">Word Embedding Approaches</h2>
<h3 id="one-hot-encoding-and-count-vectorization">One-Hot Encoding and Count Vectorization</h3>
<p>One-hot encoding represents the simplest approach to word vectorization. Each word gets a unique dimension in a vocabulary-sized vector, marked with 1 for presence and 0 elsewhere. Count vectorization extends this by counting the occurrences of each word in a document.</p>















<figure class="post-figure center ">
    <img src="/img/word_vector_onehot-50.webp"
         alt="One-hot encoding visualization"
         title="One-hot encoding visualization"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">One-hot encoding creates sparse vectors with single active dimensions</figcaption>
    
</figure>

<p><strong>Characteristics:</strong></p>
<ul>
<li><strong>High dimensionality</strong>: Vector length equals vocabulary size.</li>
<li><strong>Extreme sparsity</strong>: Most dimensions contain zeros.</li>
<li><strong>No relationships</strong>: Treats all words as equally distant.</li>
<li><strong>Computational efficiency</strong>: Simple to implement and understand.</li>
</ul>
<p>While lacking semantic information, count vectorization serves as a foundation for more complex methods. Let&rsquo;s look at a practical implementation using scikit-learn&rsquo;s <code>CountVectorizer</code>.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">from</span> sklearn.feature_extraction.text <span style="color:#f92672">import</span> CountVectorizer
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Initialize the vectorizer</span>
</span></span><span style="display:flex;"><span>vectorizer <span style="color:#f92672">=</span> CountVectorizer()
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Sample text for demonstration</span>
</span></span><span style="display:flex;"><span>sample_text <span style="color:#f92672">=</span> [<span style="color:#e6db74">&#34;One of the most basic ways we can numerically represent words &#34;</span>
</span></span><span style="display:flex;"><span>               <span style="color:#e6db74">&#34;is through the one-hot encoding method (also sometimes called &#34;</span>
</span></span><span style="display:flex;"><span>               <span style="color:#e6db74">&#34;count vectorizing).&#34;</span>]
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Fit the vectorizer to our text data</span>
</span></span><span style="display:flex;"><span>vectorizer<span style="color:#f92672">.</span>fit(sample_text)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Examine the vocabulary and word indices</span>
</span></span><span style="display:flex;"><span>print(<span style="color:#e6db74">&#39;Vocabulary:&#39;</span>)
</span></span><span style="display:flex;"><span>print(vectorizer<span style="color:#f92672">.</span>vocabulary_)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Transform text to vectors</span>
</span></span><span style="display:flex;"><span>vector <span style="color:#f92672">=</span> vectorizer<span style="color:#f92672">.</span>transform(sample_text)
</span></span><span style="display:flex;"><span>print(<span style="color:#e6db74">&#39;Full vector:&#39;</span>)
</span></span><span style="display:flex;"><span>print(vector<span style="color:#f92672">.</span>toarray())
</span></span></code></pre></div><p>At scale, count vectorization introduces engineering challenges. With millions of documents, the vocabulary grows large, and the sparse matrices become expensive to store and compute on. In these scaling scenarios, practitioners often turn to the <strong>Hashing Trick</strong> (via <code>HashingVectorizer</code>) to bound the dimensionality, or they move entirely to the dense embeddings discussed later in this post.</p>
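<p>As a rough illustration of the hashing trick, the sketch below maps each token to one of a fixed number of buckets, so the vector length never depends on vocabulary size. It is a stdlib-only toy, not how <code>HashingVectorizer</code> is implemented (scikit-learn uses MurmurHash3 and an optional sign); the function name <code>hashed_bag_of_words</code> and the bucket count are assumptions made for the example.</p>

```python
import hashlib

def hashed_bag_of_words(text, n_features=16):
    """Map a text to a fixed-length count vector via the hashing trick."""
    vector = [0] * n_features
    for token in text.lower().split():
        # Use a stable hash: Python's built-in hash() is salted per process.
        digest = hashlib.md5(token.encode("utf-8")).hexdigest()
        index = int(digest, 16) % n_features
        vector[index] += 1
    return vector

docs = ["the cat sat on the mat", "a completely unseen vocabulary appears here"]
vectors = [hashed_bag_of_words(doc) for doc in docs]
for vec in vectors:
    # Every document gets the same length, with no fitted vocabulary.
    print(len(vec), vec)
```

<p>The trade-off is that distinct tokens can collide in the same bucket, and the mapping cannot be inverted to recover which word produced a feature.</p>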
<p>We can see count vectorization in action with a real dataset, building a simple text classifier for the <a href="https://www.kaggle.com/datasets/crawford/20-newsgroups">20 Newsgroups dataset</a>:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">from</span> sklearn.datasets <span style="color:#f92672">import</span> fetch_20newsgroups
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> sklearn.feature_extraction.text <span style="color:#f92672">import</span> CountVectorizer
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> sklearn.naive_bayes <span style="color:#f92672">import</span> MultinomialNB
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> sklearn <span style="color:#f92672">import</span> metrics
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Load train and test splits, removing metadata for a cleaner signal</span>
</span></span><span style="display:flex;"><span>newsgroups_train <span style="color:#f92672">=</span> fetch_20newsgroups(subset<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;train&#39;</span>,
</span></span><span style="display:flex;"><span>                                      remove<span style="color:#f92672">=</span>(<span style="color:#e6db74">&#39;headers&#39;</span>, <span style="color:#e6db74">&#39;footers&#39;</span>, <span style="color:#e6db74">&#39;quotes&#39;</span>))
</span></span><span style="display:flex;"><span>newsgroups_test <span style="color:#f92672">=</span> fetch_20newsgroups(subset<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;test&#39;</span>,
</span></span><span style="display:flex;"><span>                                     remove<span style="color:#f92672">=</span>(<span style="color:#e6db74">&#39;headers&#39;</span>, <span style="color:#e6db74">&#39;footers&#39;</span>, <span style="color:#e6db74">&#39;quotes&#39;</span>))
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Initialize and fit vectorizer on training data</span>
</span></span><span style="display:flex;"><span>vectorizer <span style="color:#f92672">=</span> CountVectorizer()
</span></span><span style="display:flex;"><span>X_train <span style="color:#f92672">=</span> vectorizer<span style="color:#f92672">.</span>fit_transform(newsgroups_train<span style="color:#f92672">.</span>data)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Build and train classifier</span>
</span></span><span style="display:flex;"><span>classifier <span style="color:#f92672">=</span> MultinomialNB(alpha<span style="color:#f92672">=</span><span style="color:#ae81ff">0.01</span>)
</span></span><span style="display:flex;"><span>classifier<span style="color:#f92672">.</span>fit(X_train, newsgroups_train<span style="color:#f92672">.</span>target)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Transform test data and make predictions</span>
</span></span><span style="display:flex;"><span>X_test <span style="color:#f92672">=</span> vectorizer<span style="color:#f92672">.</span>transform(newsgroups_test<span style="color:#f92672">.</span>data)
</span></span><span style="display:flex;"><span>y_pred <span style="color:#f92672">=</span> classifier<span style="color:#f92672">.</span>predict(X_test)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Evaluate performance</span>
</span></span><span style="display:flex;"><span>accuracy <span style="color:#f92672">=</span> metrics<span style="color:#f92672">.</span>accuracy_score(newsgroups_test<span style="color:#f92672">.</span>target, y_pred)
</span></span><span style="display:flex;"><span>print(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#39;Accuracy: </span><span style="color:#e6db74">{</span>accuracy<span style="color:#e6db74">:</span><span style="color:#e6db74">.3f</span><span style="color:#e6db74">}</span><span style="color:#e6db74">&#39;</span>)
</span></span></code></pre></div><p>This provides a solid baseline. To capture actual semantic meaning and reduce dimensionality, we must move beyond simple counting.</p>
<h3 id="tf-idf-term-frequency-inverse-document-frequency">TF-IDF (Term Frequency-Inverse Document Frequency)</h3>
<p><a href="https://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.TfidfVectorizer.html">TF-IDF</a> extends count vectorization by weighting each term according to its importance across a document collection. TF-IDF combines:</p>
<ul>
<li><strong>Term Frequency (TF)</strong>: How often a word appears in a document</li>
<li><strong>Inverse Document Frequency (IDF)</strong>: How rare a word is across all documents</li>
</ul>
<p>This weighting scheme reduces the impact of common words (like &ldquo;the&rdquo; or &ldquo;and&rdquo;) while emphasizing distinctive terms that appear frequently in specific documents but rarely elsewhere.</p>
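<p>A minimal worked example, assuming the textbook definition (term frequency times $\log(N / \text{df})$) on a three-document toy corpus. scikit-learn&rsquo;s <code>TfidfVectorizer</code> applies a smoothed IDF and L2 normalization by default, so its numbers will differ from this sketch.</p>

```python
import math
from collections import Counter

docs = [
    "the cat sat on the mat",
    "the dog sat on the log",
    "the cat chased the dog",
]
tokenized = [doc.split() for doc in docs]
n_docs = len(tokenized)

# Document frequency: number of documents containing each term.
doc_freq = Counter(term for tokens in tokenized for term in set(tokens))

def tf_idf(tokens):
    """Unsmoothed TF-IDF weights for one tokenized document."""
    counts = Counter(tokens)
    return {
        term: (count / len(tokens)) * math.log(n_docs / doc_freq[term])
        for term, count in counts.items()
    }

weights = tf_idf(tokenized[0])
for term, weight in sorted(weights.items(), key=lambda kv: -kv[1]):
    print(f"{term}: {weight:.3f}")
```

<p>Here &ldquo;the&rdquo; appears in every document, so its weight is zero, while &ldquo;mat&rdquo; (unique to the first document) receives the largest weight.</p>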
<p><strong>Advantages:</strong></p>
<ul>
<li>Captures document-level importance</li>
<li>Reduces impact of stop words</li>
<li>Effective for information retrieval tasks</li>
</ul>
<p><strong>Limitations:</strong></p>
<ul>
<li>Still high-dimensional and sparse</li>
<li>No semantic relationships between terms</li>
<li>Context-independent representation</li>
</ul>
<h3 id="co-occurrence-matrices">Co-Occurrence Matrices</h3>
<p>Co-occurrence matrices capture word relationships by recording which terms appear together within defined contexts (sentences, paragraphs, or fixed windows). For a vocabulary of size $|V|$, the resulting matrix is $|V| \times |V|$, and each entry records how often the corresponding pair of words co-occurs.</p>















<figure class="post-figure center ">
    <img src="/img/Word_co-occurrence_network_%28range_3_words%29_-_ENG-50.webp"
         alt="Co-occurrence network visualization"
         title="Co-occurrence network visualization"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Co-occurrence relationships within a three-word window</figcaption>
    
</figure>

<p><strong>Key properties:</strong></p>
<ul>
<li><strong>Global statistics</strong>: Captures corpus-wide word relationships</li>
<li><strong>Symmetric relationships</strong>: Mutual co-occurrence patterns</li>
<li><strong>Extreme dimensionality</strong>: Vocabulary size squared creates storage challenges</li>
<li><strong>Sparse representation</strong>: Most word pairs never co-occur</li>
</ul>
<p>While computationally expensive to store and process, co-occurrence matrices form the foundation for advanced methods like GloVe that compress this information into dense representations.</p>
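<p>The window-based counting described in this section can be sketched in a few lines. This is an illustrative toy that assumes whitespace tokenization and a symmetric window; the function name <code>cooccurrence_counts</code> is hypothetical, and the counts are stored as a dictionary of pairs to avoid materializing the mostly empty matrix.</p>

```python
from collections import Counter

def cooccurrence_counts(sentences, window=2):
    """Count symmetric co-occurrences within `window` tokens on each side."""
    counts = Counter()
    for sentence in sentences:
        tokens = sentence.lower().split()
        for i, word in enumerate(tokens):
            # Look only to the right, then record both orderings for symmetry.
            for j in range(i + 1, min(i + window + 1, len(tokens))):
                counts[(word, tokens[j])] += 1
                counts[(tokens[j], word)] += 1
    return counts

corpus = ["the cat sat on the mat", "the dog sat on the log"]
counts = cooccurrence_counts(corpus, window=2)
vocab = sorted({w for s in corpus for w in s.split()})
print(counts[("sat", "on")], counts[("on", "sat")])
print(f"{len(counts)} nonzero entries out of {len(vocab) ** 2} possible")
```

<p>Even on this tiny corpus, many of the possible word pairs never occur, which is the sparsity noted in the list above.</p>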
<h2 id="neural-network-based-embeddings">Neural Network-Based Embeddings</h2>
<h3 id="neural-probabilistic-language-models">Neural Probabilistic Language Models</h3>
<p><a href="https://www.jmlr.org/papers/volume3/bengio03a/bengio03a.pdf">Neural probabilistic models</a> pioneered the use of neural networks for learning word embeddings. These models learn dense representations as a byproduct of language modeling, predicting the next word in a sequence.</p>















<figure class="post-figure center ">
    <img src="/img/bengio-npm-50.webp"
         alt="Neural probabilistic model diagram"
         title="Neural probabilistic model diagram"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Architecture of neural probabilistic language models</figcaption>
    
</figure>

<p><strong>Training process:</strong></p>
<ol>
<li>Initialize random dense embeddings for each vocabulary word</li>
<li>Feed the embeddings of the preceding words into a network that predicts the next word</li>
<li>Update embeddings through backpropagation based on prediction errors</li>
<li>Resulting embeddings capture patterns useful for the training task</li>
</ol>
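<p>The training loop above can be sketched with a deliberately tiny model. This is a simplification, not the architecture from the paper: it conditions on a single previous word, has no hidden layer, and uses plain Python lists. The corpus, dimensions, and learning rate are arbitrary choices for illustration.</p>

```python
import math
import random

random.seed(0)
corpus = "the cat sat on the mat the dog sat on the log".split()
vocab = sorted(set(corpus))
index = {word: i for i, word in enumerate(vocab)}
pairs = [(index[a], index[b]) for a, b in zip(corpus, corpus[1:])]

dim, lr = 8, 0.1
# Step 1: random dense embeddings (plus an output layer over the vocabulary).
emb = [[random.uniform(-0.5, 0.5) for _ in range(dim)] for _ in vocab]
out = [[random.uniform(-0.5, 0.5) for _ in range(dim)] for _ in vocab]

def softmax(logits):
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def train_epoch():
    total_loss = 0.0
    for prev, target in pairs:
        # Step 2: use the previous word's embedding to predict the next word.
        e = emb[prev]
        probs = softmax([sum(w * x for w, x in zip(row, e)) for row in out])
        total_loss -= math.log(probs[target])
        # Step 3: backpropagate the cross-entropy error into both tables.
        errors = [p - (1.0 if k == target else 0.0) for k, p in enumerate(probs)]
        grad_e = [sum(errors[k] * out[k][d] for k in range(len(vocab)))
                  for d in range(dim)]
        for k, err in enumerate(errors):
            out[k] = [w - lr * err * x for w, x in zip(out[k], e)]
        emb[prev] = [x - lr * g for x, g in zip(e, grad_e)]
    return total_loss / len(pairs)

# Step 4: after training, `emb` holds vectors shaped by the prediction task.
losses = [train_epoch() for _ in range(200)]
print(f"mean loss: {losses[0]:.3f} -> {losses[-1]:.3f}")
```

<p>The embeddings are never supervised directly; they change only because doing so lowers the next-word prediction loss.</p>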
<p>This approach demonstrated that task-specific embeddings could be learned jointly with model objectives, establishing the foundation for modern embedding methods.</p>
<h3 id="word2vec">Word2Vec</h3>
<p><a href="https://code.google.com/archive/p/word2vec/">Word2Vec</a> made word embeddings practical at scale by introducing efficient training algorithms for massive corpora. It popularized compelling vector arithmetic properties, enabling analogical reasoning like the famous &ldquo;$\text{king} - \text{man} + \text{woman} \approx \text{queen}$&rdquo; example (a vector-offset regularity first reported by Mikolov, Yih &amp; Zweig (2013) on recurrent-network language-model embeddings).</p>















<figure class="post-figure center ">
    <img src="/img/Word_vector_illustration.webp"
         alt="Word2Vec vector arithmetic visualization"
         title="Word2Vec vector arithmetic visualization"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Word2Vec demonstrates analogical relationships through vector arithmetic</figcaption>
    
</figure>
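<p>To make the vector arithmetic concrete, here is a sketch using hand-built three-dimensional vectors. These are not trained embeddings: the values are invented so that one axis loosely tracks royalty and another gender, and the helper names are hypothetical. With real Word2Vec vectors the same nearest-neighbor search runs over the full vocabulary.</p>

```python
import math

# Hand-built toy vectors: axis 0 ~ royalty, axis 1 ~ gender, axis 2 ~ unrelated.
vectors = {
    "king":  [0.9,  0.8, 0.1],
    "queen": [0.9, -0.8, 0.1],
    "man":   [0.1,  0.8, 0.0],
    "woman": [0.1, -0.8, 0.0],
    "apple": [0.0,  0.0, 1.0],
}

def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def analogy(a, b, c):
    """Return the word closest to vec(a) - vec(b) + vec(c), excluding the inputs."""
    target = [x - y + z for x, y, z in zip(vectors[a], vectors[b], vectors[c])]
    candidates = [w for w in vectors if w not in (a, b, c)]
    return max(candidates, key=lambda w: cosine(target, vectors[w]))

print(analogy("king", "man", "woman"))  # prints "queen" for these toy vectors
```

<p>Excluding the three query words from the candidates is standard practice in analogy evaluation; without it, the nearest neighbor is often one of the inputs.</p>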

<p><strong>Two training architectures:</strong></p>
<h4 id="continuous-bag-of-words-cbow">Continuous Bag-of-Words (CBOW)</h4>
<p>Predicts target words from surrounding context words. Given a window of context words, the model learns to predict the central word.</p>
<h4 id="skip-gram">Skip-Gram</h4>
<p>Predicts context words from target words. Given a central word, the model learns to predict surrounding words within a defined window.</p>
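<p>The two architectures differ mainly in how training examples are cut from the same text window. The sketch below only generates those examples (it does not train a model), and the function names are hypothetical.</p>

```python
def skipgram_pairs(tokens, window=2):
    """(center, context) pairs: the center word predicts each neighbor."""
    pairs = []
    for i, center in enumerate(tokens):
        lo, hi = max(0, i - window), min(len(tokens), i + window + 1)
        pairs.extend((center, tokens[j]) for j in range(lo, hi) if j != i)
    return pairs

def cbow_examples(tokens, window=2):
    """(context_words, center) examples: the neighbors predict the center word."""
    examples = []
    for i, center in enumerate(tokens):
        lo, hi = max(0, i - window), min(len(tokens), i + window + 1)
        context = [tokens[j] for j in range(lo, hi) if j != i]
        examples.append((context, center))
    return examples

tokens = "the quick brown fox jumps".split()
print(skipgram_pairs(tokens, window=1))
print(cbow_examples(tokens, window=1))
```

<p>Skip-gram produces one example per (center, neighbor) pair, while CBOW produces one example per position with the whole context as input.</p>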
<p><strong>Key advantages:</strong></p>
<ul>
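<li><strong>Computational efficiency</strong>: Much faster to train than earlier neural probabilistic language models, since it drops the non-linear hidden layer and avoids a full softmax via hierarchical softmax or negative sampling</li>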
<li><strong>Scalable training</strong>: Can process billion-word corpora effectively</li>
<li><strong>Quality embeddings</strong>: Captures semantic and syntactic relationships</li>
<li><strong>Flexible context</strong>: Window size controls topical vs. functional similarity</li>
</ul>
<p>The choice of window size significantly impacts learned relationships. Larger windows capture topical associations, while smaller windows focus on syntactic and functional similarities.</p>
<h3 id="glove-global-vectors">GloVe (Global Vectors)</h3>
<p><a href="https://nlp.stanford.edu/projects/glove/">GloVe</a> combines the best aspects of matrix factorization methods (which capture global corpus statistics) and local context window approaches like Word2Vec. Matrix factorization methods excel at global patterns but struggle with analogical reasoning, while Word2Vec captures local relationships but may miss global structure.</p>
<p><strong>Key innovation:</strong>
GloVe trains on a global word-context co-occurrence matrix, incorporating corpus-wide statistical information while maintaining the analogical reasoning capabilities that made Word2Vec successful.</p>
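<p>For reference, the weighted least-squares objective from the GloVe paper can be written as follows, where $X_{ij}$ is the co-occurrence count of word $i$ with context word $j$, $w_i$ and $\tilde{w}_j$ are the word and context vectors, and $b_i$, $\tilde{b}_j$ are bias terms:</p>
<p>$$J = \sum_{i,j=1}^{|V|} f(X_{ij}) \left( w_i^{\top} \tilde{w}_j + b_i + \tilde{b}_j - \log X_{ij} \right)^2, \qquad f(x) = \min\left(1, (x / x_{\max})^{\alpha}\right)$$</p>
<p>The weighting function $f$ vanishes at $x = 0$, so pairs that never co-occur contribute nothing, and it caps the influence of very frequent pairs. The paper&rsquo;s experiments use $x_{\max} = 100$ and $\alpha = 3/4$.</p>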
<p><strong>Advantages over Word2Vec:</strong></p>
<ul>
<li><strong>Global optimization</strong>: Leverages entire corpus statistics</li>
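<li><strong>Reported performance</strong>: The GloVe paper reports better results than Word2Vec on word similarity and analogy benchmarks, though the gap depends on the corpus and hyperparameters</li>
<li><strong>Simple objective</strong>: Fits a weighted least-squares loss to aggregated co-occurrence counts instead of streaming over individual context windows</li>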
</ul>
<p>The result is embeddings that capture both local syntactic patterns and global semantic relationships more effectively.</p>
<h2 id="contextual-embedding-methods">Subword and Hierarchical Embedding Methods</h2>
<h3 id="fasttext">FastText</h3>
<p><a href="https://github.com/facebookresearch/fastText">FastText</a> addresses a critical limitation of previous methods: handling out-of-vocabulary (OOV) words. By incorporating subword information, FastText can generate meaningful representations for previously unseen words.</p>
<p><strong>Subword approach:</strong></p>
<ul>
<li>Decomposes words into character n-grams (typically 3-6 characters)</li>
<li>Represents words as sums of their component n-grams</li>
<li>Trains using skip-gram objective with negative sampling</li>
</ul>
<p><strong>Key advantages:</strong></p>
<ul>
<li><strong>OOV handling</strong>: Can embed unseen words using known subword components</li>
<li><strong>Morphological awareness</strong>: Captures relationships between related word forms</li>
<li><strong>Multilingual support</strong>: Facebook released pre-trained embeddings for 294 languages</li>
<li><strong>Robust performance</strong>: Particularly effective for morphologically rich languages</li>
</ul>
<p>For example, if the model knows &ldquo;navigate,&rdquo; it can provide a meaningful representation for &ldquo;circumnavigate&rdquo; by leveraging shared subword components, even if &ldquo;circumnavigate&rdquo; wasn&rsquo;t in the training data.</p>
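<p>The subword decomposition behind this example can be sketched as follows. The code only extracts character n-grams (3 to 6 characters, with word-boundary markers) and compares two words; it does not look up or sum any vectors. FastText additionally keeps the whole word as its own unit and hashes n-grams into a fixed number of buckets, both of which are omitted here. The function name <code>char_ngrams</code> is hypothetical.</p>

```python
def char_ngrams(word, n_min=3, n_max=6):
    """Character n-grams of a word wrapped in boundary markers."""
    wrapped = f"<{word}>"
    grams = set()
    for n in range(n_min, n_max + 1):
        for start in range(len(wrapped) - n + 1):
            grams.add(wrapped[start:start + n])
    return grams

known = char_ngrams("navigate")
unseen = char_ngrams("circumnavigate")
shared = known.intersection(unseen)
print(f"{len(shared)} of {len(known)} n-grams of 'navigate' are reused")
print(sorted(shared))
```

<p>Because the unseen word&rsquo;s vector is a sum over its n-grams, every shared n-gram pulls its representation toward that of the known word.</p>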
<h3 id="poincaré-embeddings">Poincaré Embeddings</h3>
<p><a href="https://radimrehurek.com/gensim/models/poincare.html">Poincaré embeddings</a> introduce a novel approach by learning representations in hyperbolic space. This geometric innovation specifically targets hierarchical relationships in data.</p>
<p><strong>Hyperbolic geometry advantages:</strong></p>
<ul>
<li><strong>Natural hierarchy encoding</strong>: Distance represents similarity, while norm encodes hierarchical level (general concepts sit near the origin, specific ones near the boundary)</li>
<li><strong>Efficient representation</strong>: Requires fewer dimensions for hierarchical data</li>
<li><strong>Mathematical elegance</strong>: Leverages properties of hyperbolic space for embedding optimization</li>
</ul>
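<p>The geometry can be made concrete with the distance function of the Poincaré ball model, $d(u, v) = \operatorname{arcosh}\left(1 + \frac{2 \lVert u - v \rVert^2}{(1 - \lVert u \rVert^2)(1 - \lVert v \rVert^2)}\right)$. The sketch below evaluates it for hand-picked 2D points; it shows only the metric, not the Riemannian optimization used to learn embeddings.</p>

```python
import math

def poincare_distance(u, v):
    """Distance between two points strictly inside the unit ball."""
    sq_norm_u = sum(x * x for x in u)
    sq_norm_v = sum(x * x for x in v)
    sq_diff = sum((a - b) ** 2 for a, b in zip(u, v))
    return math.acosh(1 + 2 * sq_diff / ((1 - sq_norm_u) * (1 - sq_norm_v)))

# Two pairs with the same Euclidean separation (0.1) at different radii.
near_origin = poincare_distance([0.0, 0.0], [0.1, 0.0])
near_boundary = poincare_distance([0.85, 0.0], [0.95, 0.0])
print(f"near origin: {near_origin:.3f}, near boundary: {near_boundary:.3f}")
```

<p>The same Euclidean gap is several times longer near the boundary. That stretching is what gives hyperbolic space room for the exponentially growing number of nodes at deeper levels of a tree.</p>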
<p><strong>Applications:</strong>
Particularly effective for data with inherent hierarchical structure, such as:</p>
<ul>
<li>WordNet taxonomies</li>
<li>Organizational charts</li>
<li>Computer network topologies</li>
<li>Knowledge graphs</li>
</ul>
<p>The <a href="https://arxiv.org/abs/1705.08039">original paper</a> reports that Poincaré embeddings reconstruct the WordNet noun hierarchy more accurately than Euclidean embeddings while using far fewer dimensions.</p>
<h2 id="contextual-embeddings">Contextual Embeddings</h2>
<h3 id="elmo-embeddings-from-language-models">ELMo (Embeddings from Language Models)</h3>
<p><a href="https://github.com/allenai/allennlp-models">ELMo</a> represents a paradigm shift toward contextual word representations. Where the methods above assign each word a single fixed vector, ELMo computes a word&rsquo;s representation from the entire sentence it appears in.</p>
<p><strong>Architecture:</strong></p>
<ul>
<li><strong>Bidirectional LSTM</strong>: Processes text in both forward and backward directions</li>
<li><strong>Character-level input</strong>: Handles OOV words and captures morphological patterns</li>
<li><strong>Multi-layer representations</strong>: Combines different abstraction levels</li>
</ul>
<p><strong>Layer specialization:</strong></p>
<ul>
<li><strong>Lower layers</strong>: Excel at syntactic tasks (POS tagging, parsing)</li>
<li><strong>Higher layers</strong>: Capture semantic relationships (word sense disambiguation)</li>
<li><strong>Combined layers</strong>: A learned, task-specific weighted combination of all layers is passed to the downstream model</li>
</ul>
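<p>The layer combination can be sketched as a softmax-weighted sum of per-layer vectors scaled by a scalar $\gamma$, following the form given in the ELMo paper. The layer outputs below are made-up 4-dimensional vectors for a single token, and the function name <code>scalar_mix</code> is an assumption for this example; in practice the scores and $\gamma$ are learned with the downstream task.</p>

```python
import math

def scalar_mix(layer_vectors, layer_scores, gamma=1.0):
    """Collapse per-layer vectors for one token into a single vector."""
    # Softmax-normalize the per-layer scores into mixing weights.
    peak = max(layer_scores)
    exps = [math.exp(s - peak) for s in layer_scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    dim = len(layer_vectors[0])
    return [
        gamma * sum(w * layer[d] for w, layer in zip(weights, layer_vectors))
        for d in range(dim)
    ]

# Hypothetical 4-dimensional outputs for one token from three layers.
layers = [
    [1.0, 0.0, 0.0, 0.0],  # character-based token layer
    [0.0, 1.0, 0.0, 0.0],  # first biLSTM layer
    [0.0, 0.0, 1.0, 0.0],  # second biLSTM layer
]
mixed = scalar_mix(layers, layer_scores=[0.0, 0.0, 0.0])
print(mixed)
```

<p>With equal scores every layer contributes equally; a syntactic task could learn larger scores for the lower layers and a semantic task for the higher ones.</p>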
<p><strong>Key innovation:</strong>
ELMo embeddings vary by context. The word &ldquo;bank&rdquo; receives different representations in &ldquo;river bank&rdquo; versus &ldquo;financial bank,&rdquo; addressing polysemy directly through contextual awareness.</p>
<p>This approach achieved strong performance across numerous NLP tasks by providing context-sensitive representations that adapt to word usage patterns.</p>
<h3 id="probabilistic-fasttext">Probabilistic FastText</h3>
<p><a href="https://github.com/benathi/multisense-prob-fasttext">Probabilistic FastText</a> addresses polysemy (words with multiple meanings) through probabilistic modeling. Traditional embeddings conflate different word senses into single representations, limiting their precision.</p>
<p><strong>The polysemy problem:</strong>
Consider &ldquo;rock&rdquo; which can mean:</p>
<ul>
<li>Rock music (genre)</li>
<li>A stone (geological object)</li>
<li>Rocking motion (verb)</li>
</ul>
<p>Standard embeddings average these meanings, producing representations that may not capture any sense precisely.</p>
<p><strong>Probabilistic approach:</strong>
Probabilistic FastText represents words as Gaussian mixture models: probability distributions that can capture multiple distinct meanings as separate components.</p>
<p><strong>Advantages:</strong></p>
<ul>
<li><strong>Multi-sense representation</strong>: Each word sense gets its own distribution</li>
<li><strong>Context sensitivity</strong>: Can select appropriate sense based on usage context</li>
<li><strong>Uncertainty quantification</strong>: Probabilistic framework captures embedding confidence</li>
</ul>
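<p>As a generic illustration of the mixture idea (not the training objective or similarity measure from the Probabilistic FastText paper), the sketch below models &ldquo;rock&rdquo; as two spherical Gaussians in a made-up 2D space and computes which component better explains a given context point. All means, variances, and weights are invented for the example.</p>

```python
import math

def gaussian_density(x, mean, variance):
    """Density of a spherical Gaussian with the given mean and variance."""
    dim = len(x)
    sq_dist = sum((a - b) ** 2 for a, b in zip(x, mean))
    norm = (2 * math.pi * variance) ** (dim / 2)
    return math.exp(-sq_dist / (2 * variance)) / norm

# Hypothetical two-sense mixture for "rock": (weight, mean, variance).
rock = [
    (0.5, [1.0, 0.0], 0.1),  # component placed near music-like contexts
    (0.5, [0.0, 1.0], 0.1),  # component placed near stone-like contexts
]

def sense_posterior(context_vector, mixture):
    """Posterior probability of each component given a context point."""
    joint = [w * gaussian_density(context_vector, m, v) for w, m, v in mixture]
    total = sum(joint)
    return [j / total for j in joint]

posterior = sense_posterior([0.9, 0.1], rock)
print([round(p, 4) for p in posterior])
```

<p>A context point close to the first component assigns it almost all of the probability mass, whereas a single averaged vector would sit between the two senses and match neither.</p>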
<p>This approach provides a more nuanced treatment of lexical ambiguity, particularly valuable for words with distinct, context-dependent meanings.</p>
<h2 id="summary-and-future-directions">Summary and Future Directions</h2>
<p>Word embeddings have evolved from simple one-hot encodings to contextual representations that capture nuanced linguistic relationships. Each approach offers distinct advantages:</p>
<p><strong>Static embeddings</strong> (Word2Vec, GloVe, FastText) provide:</p>
<ul>
<li>Computational efficiency for large-scale applications</li>
<li>Pre-trained models available for numerous languages</li>
<li>Clear analogical reasoning capabilities</li>
<li>Good performance on many downstream tasks</li>
</ul>
<p><strong>Contextual embeddings</strong> (ELMo, BERT, GPT) offer:</p>
<ul>
<li>Dynamic representations based on sentence context</li>
<li>Better handling of polysemy and word sense disambiguation</li>
<li>Strong performance on complex NLP tasks</li>
<li>Ability to capture subtle contextual nuances</li>
</ul>
<p><strong>Choosing the right approach</strong> depends on:</p>
<ul>
<li><strong>Task requirements</strong>: Static embeddings for efficiency, contextual for accuracy</li>
<li><strong>Data availability</strong>: Pre-trained models vs. domain-specific training</li>
<li><strong>Computational constraints</strong>: Static embeddings require less processing power</li>
<li><strong>Language coverage</strong>: Consider availability of pre-trained models for target languages</li>
</ul>
<p>The field continues advancing toward more efficient contextual models, better multilingual representations, and embeddings that capture increasingly complex linguistic phenomena.</p>
<p>For a from-scratch Word2Vec implementation in PyTorch (Skip-gram and CBOW, with hierarchical softmax and negative sampling) that takes these concepts further, see the <a href="/projects/modern-word2vec/">PyTorch Word2Vec project</a>.</p>
]]></content:encoded></item></channel></rss>